diff --git a/common/mq/net_msg.go b/common/mq/net_msg.go index d4a45ee9..61c75bf8 100644 --- a/common/mq/net_msg.go +++ b/common/mq/net_msg.go @@ -111,6 +111,7 @@ type UserMpInfo struct { ApplyPlayerOnlineInfo *UserBaseInfo `msgpack:"ApplyPlayerOnlineInfo"` ApplyOk bool `msgpack:"ApplyOk"` Agreed bool `msgpack:"Agreed"` + Reason int32 `msgpack:"Reason"` HostNickname string `msgpack:"HostNickname"` } diff --git a/dispatch/controller/fixed_controller.go b/dispatch/controller/constant_controller.go similarity index 100% rename from dispatch/controller/fixed_controller.go rename to dispatch/controller/constant_controller.go diff --git a/dispatch/controller/gate_controller.go b/dispatch/controller/gate_controller.go index 62b1ee30..a9cde7cf 100644 --- a/dispatch/controller/gate_controller.go +++ b/dispatch/controller/gate_controller.go @@ -22,24 +22,42 @@ type TokenVerifyRsp struct { } func (c *Controller) gateTokenVerify(context *gin.Context) { - tokenVerifyReq := new(TokenVerifyReq) - err := context.ShouldBindJSON(tokenVerifyReq) - if err != nil { - return - } - logger.Debug("gate token verify, req: %v", tokenVerifyReq) - accountId, err := strconv.ParseUint(tokenVerifyReq.AccountId, 10, 64) - if err != nil { - return - } - account, err := c.dao.QueryAccountByField("accountID", accountId) - if err != nil || account == nil { + VerifyFail := func() { context.JSON(http.StatusOK, &TokenVerifyRsp{ Valid: false, Forbid: false, ForbidEndTime: 0, PlayerID: 0, }) + } + tokenVerifyReq := new(TokenVerifyReq) + err := context.ShouldBindJSON(tokenVerifyReq) + if err != nil { + VerifyFail() + return + } + logger.Info("gate token verify, req: %v", tokenVerifyReq) + accountId, err := strconv.ParseUint(tokenVerifyReq.AccountId, 10, 64) + if err != nil { + VerifyFail() + return + } + account, err := c.dao.QueryAccountByField("accountID", accountId) + if err != nil || account == nil { + VerifyFail() + return + } + if tokenVerifyReq.AccountToken != account.ComboToken { + VerifyFail() + return + } + if account.ComboTokenUsed { + VerifyFail() + return + } + _, err = c.dao.UpdateAccountFieldByFieldName("accountID", account.AccountID, "comboTokenUsed", true) + if err != nil { + VerifyFail() return } context.JSON(http.StatusOK, &TokenVerifyRsp{ diff --git a/dispatch/controller/log_controller.go b/dispatch/controller/log_controller.go index 1c631c7c..2a65deef 100644 --- a/dispatch/controller/log_controller.go +++ b/dispatch/controller/log_controller.go @@ -1,6 +1,11 @@ package controller -import "github.com/gin-gonic/gin" +import ( + "hk4e/dispatch/model" + "hk4e/pkg/logger" + + "github.com/gin-gonic/gin" +) // POST https://log-upload-os.mihoyo.com/sdk/dataUpload HTTP/1.1 func (c *Controller) sdkDataUpload(context *gin.Context) { @@ -22,6 +27,17 @@ func (c *Controller) perfDataUpload(context *gin.Context) { // POST http://overseauspider.yuanshen.com:8888/log HTTP/1.1 func (c *Controller) log8888(context *gin.Context) { + clientLog := new(model.ClientLog) + err := context.ShouldBindJSON(clientLog) + if err != nil { + logger.Error("parse client log error: %v", err) + return + } + _, err = c.dao.InsertClientLog(clientLog) + if err != nil { + logger.Error("insert client log error: %v", err) + return + } context.Header("Content-type", "application/json") _, _ = context.Writer.WriteString("{\"code\":0}") } diff --git a/dispatch/controller/login_controller.go b/dispatch/controller/login_controller.go index b58a635b..c2307eb0 100644 --- a/dispatch/controller/login_controller.go +++ b/dispatch/controller/login_controller.go @@ -7,6 +7,7 @@ import ( "regexp" "strconv" "strings" + "time" "hk4e/dispatch/api" "hk4e/dispatch/model" @@ -98,7 +99,7 @@ func (c *Controller) apiLogin(context *gin.Context) { return } if account == nil { - // 注册一个原神account + // 自动注册 accountId, err := c.dao.GetNextAccountId() if err != nil { responseData.Retcode = -201 @@ -138,7 +139,7 @@ func (c *Controller) apiLogin(context *gin.Context) { context.JSON(http.StatusOK, responseData) return } - // 生产新的token + // 生成新的token account.Token = base64.StdEncoding.EncodeToString(random.GetRandomByte(24)) _, err = c.dao.UpdateAccountFieldByFieldName("accountID", account.AccountID, "token", account.Token) if err != nil { @@ -147,6 +148,13 @@ func (c *Controller) apiLogin(context *gin.Context) { context.JSON(http.StatusOK, responseData) return } + _, err = c.dao.UpdateAccountFieldByFieldName("accountID", account.AccountID, "tokenCreateTime", time.Now().UnixMilli()) + if err != nil { + responseData.Retcode = -201 + responseData.Message = "服务器内部错误:-5" + context.JSON(http.StatusOK, responseData) + return + } responseData.Message = "OK" responseData.Data.Account.Uid = strconv.FormatInt(int64(account.AccountID), 10) responseData.Data.Account.Token = account.Token @@ -178,6 +186,12 @@ func (c *Controller) apiVerify(context *gin.Context) { context.JSON(http.StatusOK, responseData) return } + if uint64(time.Now().UnixMilli())-account.TokenCreateTime > uint64(time.Hour.Milliseconds()*24*7) { + responseData.Retcode = -111 + responseData.Message = "登录已失效" + context.JSON(http.StatusOK, responseData) + return + } responseData.Message = "OK" responseData.Data.Account.Uid = requestData.Uid responseData.Data.Account.Token = requestData.Token @@ -225,6 +239,13 @@ func (c *Controller) v2Login(context *gin.Context) { context.JSON(http.StatusOK, responseData) return } + _, err = c.dao.UpdateAccountFieldByFieldName("accountID", account.AccountID, "comboTokenUsed", false) + if err != nil { + responseData.Retcode = -201 + responseData.Message = "服务器内部错误:-2" + context.JSON(http.StatusOK, responseData) + return + } responseData.Message = "OK" responseData.Data.OpenID = loginData.Uid responseData.Data.ComboID = "0" diff --git a/dispatch/dao/account_mongo.go b/dispatch/dao/account_mongo.go index 234a1ee5..bfe3ce5f 100644 --- a/dispatch/dao/account_mongo.go +++ b/dispatch/dao/account_mongo.go @@ -6,88 +6,10 @@ import ( "hk4e/dispatch/model" "hk4e/pkg/logger" - "github.com/pkg/errors" "go.mongodb.org/mongo-driver/bson" "go.mongodb.org/mongo-driver/bson/primitive" - "go.mongodb.org/mongo-driver/mongo" ) -func (d *Dao) _GetNextAccountId() (uint64, error) { - db := d.db.Collection("account_id_counter") - find := db.FindOne(context.TODO(), bson.D{{"_id", "default"}}) - item := new(model.AccountIDCounter) - err := find.Decode(item) - if err != nil { - if err == mongo.ErrNoDocuments { - item := &model.AccountIDCounter{ - ID: "default", - AccountID: 1, - } - _, err := db.InsertOne(context.TODO(), item) - if err != nil { - return 0, errors.New("insert new AccountID error") - } - return item.AccountID, nil - } else { - return 0, err - } - } - item.AccountID++ - _, err = db.UpdateOne( - context.TODO(), - bson.D{ - {"_id", "default"}, - }, - bson.D{ - {"$set", bson.D{ - {"AccountID", item.AccountID}, - }}, - }, - ) - if err != nil { - return 0, err - } - return item.AccountID, nil -} - -func (d *Dao) _GetNextYuanShenUid() (uint64, error) { - db := d.db.Collection("player_id_counter") - find := db.FindOne(context.TODO(), bson.D{{"_id", "default"}}) - item := new(model.PlayerIDCounter) - err := find.Decode(item) - if err != nil { - if err == mongo.ErrNoDocuments { - item := &model.PlayerIDCounter{ - ID: "default", - PlayerID: 100000001, - } - _, err := db.InsertOne(context.TODO(), item) - if err != nil { - return 0, errors.New("insert new PlayerID error") - } - return item.PlayerID, nil - } else { - return 0, err - } - } - item.PlayerID++ - _, err = db.UpdateOne( - context.TODO(), - bson.D{ - {"_id", "default"}, - }, - bson.D{ - {"$set", bson.D{ - {"PlayerID", item.PlayerID}, - }}, - }, - ) - if err != nil { - return 0, err - } - return item.PlayerID, nil -} - func (d *Dao) InsertAccount(account *model.Account) (primitive.ObjectID, error) { db := d.db.Collection("account") id, err := db.InsertOne(context.TODO(), account) @@ -103,21 +25,6 @@ func (d *Dao) InsertAccount(account *model.Account) (primitive.ObjectID, error) } } -func (d *Dao) DeleteAccountByUsername(username string) (int64, error) { - db := d.db.Collection("account") - deleteCount, err := db.DeleteOne( - context.TODO(), - bson.D{ - {"username", username}, - }, - ) - if err != nil { - return 0, err - } else { - return deleteCount.DeletedCount, nil - } -} - func (d *Dao) UpdateAccountFieldByFieldName(fieldName string, fieldValue any, fieldUpdateName string, fieldUpdateValue any) (int64, error) { db := d.db.Collection("account") updateCount, err := db.UpdateOne( diff --git a/dispatch/dao/client_log_mongo.go b/dispatch/dao/client_log_mongo.go new file mode 100644 index 00000000..3068eb3d --- /dev/null +++ b/dispatch/dao/client_log_mongo.go @@ -0,0 +1,25 @@ +package dao + +import ( + "context" + + "hk4e/dispatch/model" + "hk4e/pkg/logger" + + "go.mongodb.org/mongo-driver/bson/primitive" +) + +func (d *Dao) InsertClientLog(clientLog *model.ClientLog) (primitive.ObjectID, error) { + db := d.db.Collection("client_log") + id, err := db.InsertOne(context.TODO(), clientLog) + if err != nil { + return primitive.ObjectID{}, err + } else { + _id, ok := id.InsertedID.(primitive.ObjectID) + if !ok { + logger.Error("get insert id error") + return primitive.ObjectID{}, nil + } + return _id, nil + } +} diff --git a/dispatch/model/account.go b/dispatch/model/account.go index 6ecfe5ab..229c5ac6 100644 --- a/dispatch/model/account.go +++ b/dispatch/model/account.go @@ -3,13 +3,15 @@ package model import "go.mongodb.org/mongo-driver/bson/primitive" type Account struct { - ID primitive.ObjectID `bson:"_id,omitempty"` - AccountID uint64 `bson:"accountID"` - Username string `bson:"username"` - Password string `bson:"password"` - PlayerID uint64 `bson:"playerID"` - Token string `bson:"token"` - ComboToken string `bson:"comboToken"` - Forbid bool `bson:"forbid"` - ForbidEndTime uint64 `bson:"forbidEndTime"` + ID primitive.ObjectID `bson:"_id,omitempty"` + AccountID uint64 `bson:"accountID"` + Username string `bson:"username"` + Password string `bson:"password"` + PlayerID uint64 `bson:"playerID"` + Token string `bson:"token"` + TokenCreateTime uint64 `bson:"tokenCreateTime"` // 毫秒时间戳 + ComboToken string `bson:"comboToken"` + ComboTokenUsed bool `bson:"comboTokenUsed"` + Forbid bool `bson:"forbid"` + ForbidEndTime uint64 `bson:"forbidEndTime"` // 秒时间戳 } diff --git a/dispatch/model/account_id_counter.go b/dispatch/model/account_id_counter.go deleted file mode 100644 index b1a1b5f5..00000000 --- a/dispatch/model/account_id_counter.go +++ /dev/null @@ -1,6 +0,0 @@ -package model - -type AccountIDCounter struct { - ID string `bson:"_id"` - AccountID uint64 `bson:"AccountID"` -} diff --git a/dispatch/model/client_log.go b/dispatch/model/client_log.go new file mode 100644 index 00000000..7b5fa637 --- /dev/null +++ b/dispatch/model/client_log.go @@ -0,0 +1,24 @@ +package model + +import ( + "go.mongodb.org/mongo-driver/bson/primitive" +) + +type ClientLog struct { + ID primitive.ObjectID `json:"-" bson:"_id,omitempty"` + Auid string `json:"auid" bson:"auid"` + ClientIp string `json:"clientIp" bson:"clientIp"` + CpuInfo string `json:"cpuInfo" bson:"cpuInfo"` + DeviceModel string `json:"deviceModel" bson:"deviceModel"` + DeviceName string `json:"deviceName" bson:"deviceName"` + GpuInfo string `json:"gpuInfo" bson:"gpuInfo"` + Guid string `json:"guid" bson:"guid"` + LogStr string `json:"logStr" bson:"logStr"` + LogType string `json:"logType" bson:"logType"` + OperatingSystem string `json:"operatingSystem" bson:"operatingSystem"` + StackTrace string `json:"stackTrace" bson:"stackTrace"` + Time string `json:"time" bson:"time"` + Uid uint64 `json:"uid" bson:"uid"` + UserName string `json:"userName" bson:"userName"` + Version string `json:"version" bson:"version"` +} diff --git a/dispatch/model/player_id_counter.go b/dispatch/model/player_id_counter.go deleted file mode 100644 index 7309263b..00000000 --- a/dispatch/model/player_id_counter.go +++ /dev/null @@ -1,6 +0,0 @@ -package model - -type PlayerIDCounter struct { - ID string `bson:"_id"` - PlayerID uint64 `bson:"PlayerID"` -} diff --git a/dispatch/service/forbid_user_info.go b/dispatch/service/forbid_user_info.go deleted file mode 100644 index b0ecf106..00000000 --- a/dispatch/service/forbid_user_info.go +++ /dev/null @@ -1,6 +0,0 @@ -package service - -type ForbidUserInfo struct { - UserId uint32 - ForbidEndTime uint64 -} diff --git a/dispatch/service/service.go b/dispatch/service/service.go index b5f18fa3..401e3c4c 100644 --- a/dispatch/service/service.go +++ b/dispatch/service/service.go @@ -8,71 +8,36 @@ type Service struct { dao *dao.Dao } -// 用户密码改变 -func (f *Service) UserPasswordChange(uid uint32) bool { - // dispatch登录态失效 - _, err := f.dao.UpdateAccountFieldByFieldName("accountID", uid, "token", "") +// UserPasswordChange 用户密码改变 +func (s *Service) UserPasswordChange(uid uint32) bool { + // http登录态失效 + _, err := s.dao.UpdateAccountFieldByFieldName("playerID", uid, "tokenCreateTime", 0) if err != nil { return false } - // 游戏内登录态失效 - account, err := f.dao.QueryAccountByField("accountID", uid) - if err != nil { - return false - } - if account == nil { - return false - } - // convId, exist := f.getConvIdByUserId(uint32(account.PlayerID)) - // if !exist { - // return true - // } - // f.kcpEventInput <- &net.KcpEvent{ - // ConvId: convId, - // EventId: net.KcpConnForceClose, - // EventMessage: uint32(kcp.EnetAccountPasswordChange), - // } + // TODO 游戏内登录态失效 return true } -// 封号 -func (f *Service) ForbidUser(info *ForbidUserInfo) bool { - if info == nil { - return false - } +// ForbidUser 封号 +func (s *Service) ForbidUser(uid uint32, forbidEndTime uint64) bool { // 写入账号封禁信息 - _, err := f.dao.UpdateAccountFieldByFieldName("accountID", info.UserId, "forbid", true) + _, err := s.dao.UpdateAccountFieldByFieldName("playerID", uid, "forbid", true) if err != nil { return false } - _, err = f.dao.UpdateAccountFieldByFieldName("accountID", info.UserId, "forbidEndTime", info.ForbidEndTime) + _, err = s.dao.UpdateAccountFieldByFieldName("playerID", uid, "forbidEndTime", forbidEndTime) if err != nil { return false } - // 游戏强制下线 - account, err := f.dao.QueryAccountByField("accountID", info.UserId) - if err != nil { - return false - } - if account == nil { - return false - } - // convId, exist := f.getConvIdByUserId(uint32(account.PlayerID)) - // if !exist { - // return true - // } - // f.kcpEventInput <- &net.KcpEvent{ - // ConvId: convId, - // EventId: net.KcpConnForceClose, - // EventMessage: uint32(kcp.EnetServerKillClient), - // } + // TODO 游戏强制下线 return true } -// 解封 +// UnForbidUser 解封 func (s *Service) UnForbidUser(uid uint32) bool { // 解除账号封禁 - _, err := s.dao.UpdateAccountFieldByFieldName("accountID", uid, "forbid", false) + _, err := s.dao.UpdateAccountFieldByFieldName("playerID", uid, "forbid", false) if err != nil { return false } diff --git a/gate/app/app.go b/gate/app/app.go index 24126c8a..9b2e54fd 100644 --- a/gate/app/app.go +++ b/gate/app/app.go @@ -6,6 +6,7 @@ import ( "os" "os/signal" "strings" + "sync/atomic" "syscall" "time" @@ -50,6 +51,7 @@ func Run(ctx context.Context, configFile string) error { _, err := client.Discovery.KeepaliveServer(context.TODO(), &api.KeepaliveServerReq{ ServerType: api.GATE, AppId: APPID, + LoadCount: uint32(atomic.LoadInt32(&net.CLIENT_CONN_NUM)), }) if err != nil { logger.Error("keepalive error: %v", err) @@ -75,7 +77,8 @@ func Run(ctx context.Context, configFile string) error { go func() { outputChan := connectManager.GetKcpEventOutputChan() for { - <-outputChan + kcpEvent := <-outputChan + logger.Info("kcpEvent: %v", kcpEvent) } }() diff --git a/gate/kcp/autotune.go b/gate/kcp/autotune.go deleted file mode 100644 index 7c189951..00000000 --- a/gate/kcp/autotune.go +++ /dev/null @@ -1,64 +0,0 @@ -package kcp - -const maxAutoTuneSamples = 258 - -// pulse represents a 0/1 signal with time sequence -type pulse struct { - bit bool // 0 or 1 - seq uint32 // sequence of the signal -} - -// autoTune object -type autoTune struct { - pulses [maxAutoTuneSamples]pulse -} - -// Sample adds a signal sample to the pulse buffer -func (tune *autoTune) Sample(bit bool, seq uint32) { - tune.pulses[seq%maxAutoTuneSamples] = pulse{bit, seq} -} - -// Find a period for a given signal -// returns -1 if not found -// -// --- ------ -// | | -// |______________| -// Period -// Falling Edge Rising Edge -func (tune *autoTune) FindPeriod(bit bool) int { - // last pulse and initial index setup - lastPulse := tune.pulses[0] - idx := 1 - - // left edge - var leftEdge int - for ; idx < len(tune.pulses); idx++ { - if lastPulse.bit != bit && tune.pulses[idx].bit == bit { // edge found - if lastPulse.seq+1 == tune.pulses[idx].seq { // ensure edge continuity - leftEdge = idx - break - } - } - lastPulse = tune.pulses[idx] - } - - // right edge - var rightEdge int - lastPulse = tune.pulses[leftEdge] - idx = leftEdge + 1 - - for ; idx < len(tune.pulses); idx++ { - if lastPulse.seq+1 == tune.pulses[idx].seq { // ensure pulses in this level monotonic - if lastPulse.bit == bit && tune.pulses[idx].bit != bit { // edge found - rightEdge = idx - break - } - } else { - return -1 - } - lastPulse = tune.pulses[idx] - } - - return rightEdge - leftEdge -} diff --git a/gate/kcp/autotune_test.go b/gate/kcp/autotune_test.go deleted file mode 100644 index 3dc1ecc6..00000000 --- a/gate/kcp/autotune_test.go +++ /dev/null @@ -1,47 +0,0 @@ -package kcp - -import ( - "testing" - - "github.com/stretchr/testify/assert" -) - -func TestAutoTune(t *testing.T) { - signals := []uint32{0, 0, 0, 0, 0, 0} - - tune := autoTune{} - for i := 0; i < len(signals); i++ { - if signals[i] == 0 { - tune.Sample(false, uint32(i)) - } else { - tune.Sample(true, uint32(i)) - } - } - - assert.Equal(t, -1, tune.FindPeriod(false)) - assert.Equal(t, -1, tune.FindPeriod(true)) - - signals = []uint32{1, 0, 1, 0, 0, 1} - tune = autoTune{} - for i := 0; i < len(signals); i++ { - if signals[i] == 0 { - tune.Sample(false, uint32(i)) - } else { - tune.Sample(true, uint32(i)) - } - } - assert.Equal(t, 1, tune.FindPeriod(false)) - assert.Equal(t, 1, tune.FindPeriod(true)) - - signals = []uint32{1, 0, 0, 0, 0, 1} - tune = autoTune{} - for i := 0; i < len(signals); i++ { - if signals[i] == 0 { - tune.Sample(false, uint32(i)) - } else { - tune.Sample(true, uint32(i)) - } - } - assert.Equal(t, -1, tune.FindPeriod(true)) - assert.Equal(t, 4, tune.FindPeriod(false)) -} diff --git a/gate/kcp/crypt.go b/gate/kcp/crypt.go deleted file mode 100644 index 89d82e49..00000000 --- a/gate/kcp/crypt.go +++ /dev/null @@ -1,617 +0,0 @@ -package kcp - -import ( - "crypto/aes" - "crypto/cipher" - "crypto/des" - "crypto/sha1" - "unsafe" - - xor "github.com/templexxx/xorsimd" - "github.com/tjfoc/gmsm/sm4" - "golang.org/x/crypto/blowfish" - "golang.org/x/crypto/cast5" - "golang.org/x/crypto/pbkdf2" - "golang.org/x/crypto/salsa20" - "golang.org/x/crypto/tea" - "golang.org/x/crypto/twofish" - "golang.org/x/crypto/xtea" -) - -var ( - initialVector = []byte{167, 115, 79, 156, 18, 172, 27, 1, 164, 21, 242, 193, 252, 120, 230, 107} - saltxor = `sH3CIVoF#rWLtJo6` -) - -// BlockCrypt defines encryption/decryption methods for a given byte slice. -// Notes on implementing: the data to be encrypted contains a builtin -// nonce at the first 16 bytes -type BlockCrypt interface { - // Encrypt encrypts the whole block in src into dst. - // Dst and src may point at the same memory. - Encrypt(dst, src []byte) - - // Decrypt decrypts the whole block in src into dst. - // Dst and src may point at the same memory. - Decrypt(dst, src []byte) -} - -type salsa20BlockCrypt struct { - key [32]byte -} - -// NewSalsa20BlockCrypt https://en.wikipedia.org/wiki/Salsa20 -func NewSalsa20BlockCrypt(key []byte) (BlockCrypt, error) { - c := new(salsa20BlockCrypt) - copy(c.key[:], key) - return c, nil -} - -func (c *salsa20BlockCrypt) Encrypt(dst, src []byte) { - salsa20.XORKeyStream(dst[8:], src[8:], src[:8], &c.key) - copy(dst[:8], src[:8]) -} -func (c *salsa20BlockCrypt) Decrypt(dst, src []byte) { - salsa20.XORKeyStream(dst[8:], src[8:], src[:8], &c.key) - copy(dst[:8], src[:8]) -} - -type sm4BlockCrypt struct { - encbuf [sm4.BlockSize]byte // 64bit alignment enc/dec buffer - decbuf [2 * sm4.BlockSize]byte - block cipher.Block -} - -// NewSM4BlockCrypt https://github.com/tjfoc/gmsm/tree/master/sm4 -func NewSM4BlockCrypt(key []byte) (BlockCrypt, error) { - c := new(sm4BlockCrypt) - block, err := sm4.NewCipher(key) - if err != nil { - return nil, err - } - c.block = block - return c, nil -} - -func (c *sm4BlockCrypt) Encrypt(dst, src []byte) { encrypt(c.block, dst, src, c.encbuf[:]) } -func (c *sm4BlockCrypt) Decrypt(dst, src []byte) { decrypt(c.block, dst, src, c.decbuf[:]) } - -type twofishBlockCrypt struct { - encbuf [twofish.BlockSize]byte - decbuf [2 * twofish.BlockSize]byte - block cipher.Block -} - -// NewTwofishBlockCrypt https://en.wikipedia.org/wiki/Twofish -func NewTwofishBlockCrypt(key []byte) (BlockCrypt, error) { - c := new(twofishBlockCrypt) - block, err := twofish.NewCipher(key) - if err != nil { - return nil, err - } - c.block = block - return c, nil -} - -func (c *twofishBlockCrypt) Encrypt(dst, src []byte) { encrypt(c.block, dst, src, c.encbuf[:]) } -func (c *twofishBlockCrypt) Decrypt(dst, src []byte) { decrypt(c.block, dst, src, c.decbuf[:]) } - -type tripleDESBlockCrypt struct { - encbuf [des.BlockSize]byte - decbuf [2 * des.BlockSize]byte - block cipher.Block -} - -// NewTripleDESBlockCrypt https://en.wikipedia.org/wiki/Triple_DES -func NewTripleDESBlockCrypt(key []byte) (BlockCrypt, error) { - c := new(tripleDESBlockCrypt) - block, err := des.NewTripleDESCipher(key) - if err != nil { - return nil, err - } - c.block = block - return c, nil -} - -func (c *tripleDESBlockCrypt) Encrypt(dst, src []byte) { encrypt(c.block, dst, src, c.encbuf[:]) } -func (c *tripleDESBlockCrypt) Decrypt(dst, src []byte) { decrypt(c.block, dst, src, c.decbuf[:]) } - -type cast5BlockCrypt struct { - encbuf [cast5.BlockSize]byte - decbuf [2 * cast5.BlockSize]byte - block cipher.Block -} - -// NewCast5BlockCrypt https://en.wikipedia.org/wiki/CAST-128 -func NewCast5BlockCrypt(key []byte) (BlockCrypt, error) { - c := new(cast5BlockCrypt) - block, err := cast5.NewCipher(key) - if err != nil { - return nil, err - } - c.block = block - return c, nil -} - -func (c *cast5BlockCrypt) Encrypt(dst, src []byte) { encrypt(c.block, dst, src, c.encbuf[:]) } -func (c *cast5BlockCrypt) Decrypt(dst, src []byte) { decrypt(c.block, dst, src, c.decbuf[:]) } - -type blowfishBlockCrypt struct { - encbuf [blowfish.BlockSize]byte - decbuf [2 * blowfish.BlockSize]byte - block cipher.Block -} - -// NewBlowfishBlockCrypt https://en.wikipedia.org/wiki/Blowfish_(cipher) -func NewBlowfishBlockCrypt(key []byte) (BlockCrypt, error) { - c := new(blowfishBlockCrypt) - block, err := blowfish.NewCipher(key) - if err != nil { - return nil, err - } - c.block = block - return c, nil -} - -func (c *blowfishBlockCrypt) Encrypt(dst, src []byte) { encrypt(c.block, dst, src, c.encbuf[:]) } -func (c *blowfishBlockCrypt) Decrypt(dst, src []byte) { decrypt(c.block, dst, src, c.decbuf[:]) } - -type aesBlockCrypt struct { - encbuf [aes.BlockSize]byte - decbuf [2 * aes.BlockSize]byte - block cipher.Block -} - -// NewAESBlockCrypt https://en.wikipedia.org/wiki/Advanced_Encryption_Standard -func NewAESBlockCrypt(key []byte) (BlockCrypt, error) { - c := new(aesBlockCrypt) - block, err := aes.NewCipher(key) - if err != nil { - return nil, err - } - c.block = block - return c, nil -} - -func (c *aesBlockCrypt) Encrypt(dst, src []byte) { encrypt(c.block, dst, src, c.encbuf[:]) } -func (c *aesBlockCrypt) Decrypt(dst, src []byte) { decrypt(c.block, dst, src, c.decbuf[:]) } - -type teaBlockCrypt struct { - encbuf [tea.BlockSize]byte - decbuf [2 * tea.BlockSize]byte - block cipher.Block -} - -// NewTEABlockCrypt https://en.wikipedia.org/wiki/Tiny_Encryption_Algorithm -func NewTEABlockCrypt(key []byte) (BlockCrypt, error) { - c := new(teaBlockCrypt) - block, err := tea.NewCipherWithRounds(key, 16) - if err != nil { - return nil, err - } - c.block = block - return c, nil -} - -func (c *teaBlockCrypt) Encrypt(dst, src []byte) { encrypt(c.block, dst, src, c.encbuf[:]) } -func (c *teaBlockCrypt) Decrypt(dst, src []byte) { decrypt(c.block, dst, src, c.decbuf[:]) } - -type xteaBlockCrypt struct { - encbuf [xtea.BlockSize]byte - decbuf [2 * xtea.BlockSize]byte - block cipher.Block -} - -// NewXTEABlockCrypt https://en.wikipedia.org/wiki/XTEA -func NewXTEABlockCrypt(key []byte) (BlockCrypt, error) { - c := new(xteaBlockCrypt) - block, err := xtea.NewCipher(key) - if err != nil { - return nil, err - } - c.block = block - return c, nil -} - -func (c *xteaBlockCrypt) Encrypt(dst, src []byte) { encrypt(c.block, dst, src, c.encbuf[:]) } -func (c *xteaBlockCrypt) Decrypt(dst, src []byte) { decrypt(c.block, dst, src, c.decbuf[:]) } - -type simpleXORBlockCrypt struct { - xortbl []byte -} - -// NewSimpleXORBlockCrypt simple xor with key expanding -func NewSimpleXORBlockCrypt(key []byte) (BlockCrypt, error) { - c := new(simpleXORBlockCrypt) - c.xortbl = pbkdf2.Key(key, []byte(saltxor), 32, mtuLimit, sha1.New) - return c, nil -} - -func (c *simpleXORBlockCrypt) Encrypt(dst, src []byte) { xor.Bytes(dst, src, c.xortbl) } -func (c *simpleXORBlockCrypt) Decrypt(dst, src []byte) { xor.Bytes(dst, src, c.xortbl) } - -type noneBlockCrypt struct{} - -// NewNoneBlockCrypt does nothing but copying -func NewNoneBlockCrypt(key []byte) (BlockCrypt, error) { - return new(noneBlockCrypt), nil -} - -func (c *noneBlockCrypt) Encrypt(dst, src []byte) { copy(dst, src) } -func (c *noneBlockCrypt) Decrypt(dst, src []byte) { copy(dst, src) } - -// packet encryption with local CFB mode -func encrypt(block cipher.Block, dst, src, buf []byte) { - switch block.BlockSize() { - case 8: - encrypt8(block, dst, src, buf) - case 16: - encrypt16(block, dst, src, buf) - default: - panic("unsupported cipher block size") - } -} - -// optimized encryption for the ciphers which works in 8-bytes -func encrypt8(block cipher.Block, dst, src, buf []byte) { - tbl := buf[:8] - block.Encrypt(tbl, initialVector) - n := len(src) / 8 - base := 0 - repeat := n / 8 - left := n % 8 - ptr_tbl := (*uint64)(unsafe.Pointer(&tbl[0])) - - for i := 0; i < repeat; i++ { - s := src[base:][0:64] - d := dst[base:][0:64] - // 1 - *(*uint64)(unsafe.Pointer(&d[0])) = *(*uint64)(unsafe.Pointer(&s[0])) ^ *ptr_tbl - block.Encrypt(tbl, d[0:8]) - // 2 - *(*uint64)(unsafe.Pointer(&d[8])) = *(*uint64)(unsafe.Pointer(&s[8])) ^ *ptr_tbl - block.Encrypt(tbl, d[8:16]) - // 3 - *(*uint64)(unsafe.Pointer(&d[16])) = *(*uint64)(unsafe.Pointer(&s[16])) ^ *ptr_tbl - block.Encrypt(tbl, d[16:24]) - // 4 - *(*uint64)(unsafe.Pointer(&d[24])) = *(*uint64)(unsafe.Pointer(&s[24])) ^ *ptr_tbl - block.Encrypt(tbl, d[24:32]) - // 5 - *(*uint64)(unsafe.Pointer(&d[32])) = *(*uint64)(unsafe.Pointer(&s[32])) ^ *ptr_tbl - block.Encrypt(tbl, d[32:40]) - // 6 - *(*uint64)(unsafe.Pointer(&d[40])) = *(*uint64)(unsafe.Pointer(&s[40])) ^ *ptr_tbl - block.Encrypt(tbl, d[40:48]) - // 7 - *(*uint64)(unsafe.Pointer(&d[48])) = *(*uint64)(unsafe.Pointer(&s[48])) ^ *ptr_tbl - block.Encrypt(tbl, d[48:56]) - // 8 - *(*uint64)(unsafe.Pointer(&d[56])) = *(*uint64)(unsafe.Pointer(&s[56])) ^ *ptr_tbl - block.Encrypt(tbl, d[56:64]) - base += 64 - } - - switch left { - case 7: - *(*uint64)(unsafe.Pointer(&dst[base])) = *(*uint64)(unsafe.Pointer(&src[base])) ^ *ptr_tbl - block.Encrypt(tbl, dst[base:]) - base += 8 - fallthrough - case 6: - *(*uint64)(unsafe.Pointer(&dst[base])) = *(*uint64)(unsafe.Pointer(&src[base])) ^ *ptr_tbl - block.Encrypt(tbl, dst[base:]) - base += 8 - fallthrough - case 5: - *(*uint64)(unsafe.Pointer(&dst[base])) = *(*uint64)(unsafe.Pointer(&src[base])) ^ *ptr_tbl - block.Encrypt(tbl, dst[base:]) - base += 8 - fallthrough - case 4: - *(*uint64)(unsafe.Pointer(&dst[base])) = *(*uint64)(unsafe.Pointer(&src[base])) ^ *ptr_tbl - block.Encrypt(tbl, dst[base:]) - base += 8 - fallthrough - case 3: - *(*uint64)(unsafe.Pointer(&dst[base])) = *(*uint64)(unsafe.Pointer(&src[base])) ^ *ptr_tbl - block.Encrypt(tbl, dst[base:]) - base += 8 - fallthrough - case 2: - *(*uint64)(unsafe.Pointer(&dst[base])) = *(*uint64)(unsafe.Pointer(&src[base])) ^ *ptr_tbl - block.Encrypt(tbl, dst[base:]) - base += 8 - fallthrough - case 1: - *(*uint64)(unsafe.Pointer(&dst[base])) = *(*uint64)(unsafe.Pointer(&src[base])) ^ *ptr_tbl - block.Encrypt(tbl, dst[base:]) - base += 8 - fallthrough - case 0: - xorBytes(dst[base:], src[base:], tbl) - } -} - -// optimized encryption for the ciphers which works in 16-bytes -func encrypt16(block cipher.Block, dst, src, buf []byte) { - tbl := buf[:16] - block.Encrypt(tbl, initialVector) - n := len(src) / 16 - base := 0 - repeat := n / 8 - left := n % 8 - for i := 0; i < repeat; i++ { - s := src[base:][0:128] - d := dst[base:][0:128] - // 1 - xor.Bytes16Align(d[0:16], s[0:16], tbl) - block.Encrypt(tbl, d[0:16]) - // 2 - xor.Bytes16Align(d[16:32], s[16:32], tbl) - block.Encrypt(tbl, d[16:32]) - // 3 - xor.Bytes16Align(d[32:48], s[32:48], tbl) - block.Encrypt(tbl, d[32:48]) - // 4 - xor.Bytes16Align(d[48:64], s[48:64], tbl) - block.Encrypt(tbl, d[48:64]) - // 5 - xor.Bytes16Align(d[64:80], s[64:80], tbl) - block.Encrypt(tbl, d[64:80]) - // 6 - xor.Bytes16Align(d[80:96], s[80:96], tbl) - block.Encrypt(tbl, d[80:96]) - // 7 - xor.Bytes16Align(d[96:112], s[96:112], tbl) - block.Encrypt(tbl, d[96:112]) - // 8 - xor.Bytes16Align(d[112:128], s[112:128], tbl) - block.Encrypt(tbl, d[112:128]) - base += 128 - } - - switch left { - case 7: - xor.Bytes16Align(dst[base:], src[base:], tbl) - block.Encrypt(tbl, dst[base:]) - base += 16 - fallthrough - case 6: - xor.Bytes16Align(dst[base:], src[base:], tbl) - block.Encrypt(tbl, dst[base:]) - base += 16 - fallthrough - case 5: - xor.Bytes16Align(dst[base:], src[base:], tbl) - block.Encrypt(tbl, dst[base:]) - base += 16 - fallthrough - case 4: - xor.Bytes16Align(dst[base:], src[base:], tbl) - block.Encrypt(tbl, dst[base:]) - base += 16 - fallthrough - case 3: - xor.Bytes16Align(dst[base:], src[base:], tbl) - block.Encrypt(tbl, dst[base:]) - base += 16 - fallthrough - case 2: - xor.Bytes16Align(dst[base:], src[base:], tbl) - block.Encrypt(tbl, dst[base:]) - base += 16 - fallthrough - case 1: - xor.Bytes16Align(dst[base:], src[base:], tbl) - block.Encrypt(tbl, dst[base:]) - base += 16 - fallthrough - case 0: - xorBytes(dst[base:], src[base:], tbl) - } -} - -// decryption -func decrypt(block cipher.Block, dst, src, buf []byte) { - switch block.BlockSize() { - case 8: - decrypt8(block, dst, src, buf) - case 16: - decrypt16(block, dst, src, buf) - default: - panic("unsupported cipher block size") - } -} - -// decrypt 8 bytes block, all byte slices are supposed to be 64bit aligned -func decrypt8(block cipher.Block, dst, src, buf []byte) { - tbl := buf[0:8] - next := buf[8:16] - block.Encrypt(tbl, initialVector) - n := len(src) / 8 - base := 0 - repeat := n / 8 - left := n % 8 - ptr_tbl := (*uint64)(unsafe.Pointer(&tbl[0])) - ptr_next := (*uint64)(unsafe.Pointer(&next[0])) - - for i := 0; i < repeat; i++ { - s := src[base:][0:64] - d := dst[base:][0:64] - // 1 - block.Encrypt(next, s[0:8]) - *(*uint64)(unsafe.Pointer(&d[0])) = *(*uint64)(unsafe.Pointer(&s[0])) ^ *ptr_tbl - // 2 - block.Encrypt(tbl, s[8:16]) - *(*uint64)(unsafe.Pointer(&d[8])) = *(*uint64)(unsafe.Pointer(&s[8])) ^ *ptr_next - // 3 - block.Encrypt(next, s[16:24]) - *(*uint64)(unsafe.Pointer(&d[16])) = *(*uint64)(unsafe.Pointer(&s[16])) ^ *ptr_tbl - // 4 - block.Encrypt(tbl, s[24:32]) - *(*uint64)(unsafe.Pointer(&d[24])) = *(*uint64)(unsafe.Pointer(&s[24])) ^ *ptr_next - // 5 - block.Encrypt(next, s[32:40]) - *(*uint64)(unsafe.Pointer(&d[32])) = *(*uint64)(unsafe.Pointer(&s[32])) ^ *ptr_tbl - // 6 - block.Encrypt(tbl, s[40:48]) - *(*uint64)(unsafe.Pointer(&d[40])) = *(*uint64)(unsafe.Pointer(&s[40])) ^ *ptr_next - // 7 - block.Encrypt(next, s[48:56]) - *(*uint64)(unsafe.Pointer(&d[48])) = *(*uint64)(unsafe.Pointer(&s[48])) ^ *ptr_tbl - // 8 - block.Encrypt(tbl, s[56:64]) - *(*uint64)(unsafe.Pointer(&d[56])) = *(*uint64)(unsafe.Pointer(&s[56])) ^ *ptr_next - base += 64 - } - - switch left { - case 7: - block.Encrypt(next, src[base:]) - *(*uint64)(unsafe.Pointer(&dst[base])) = *(*uint64)(unsafe.Pointer(&src[base])) ^ *(*uint64)(unsafe.Pointer(&tbl[0])) - tbl, next = next, tbl - base += 8 - fallthrough - case 6: - block.Encrypt(next, src[base:]) - *(*uint64)(unsafe.Pointer(&dst[base])) = *(*uint64)(unsafe.Pointer(&src[base])) ^ *(*uint64)(unsafe.Pointer(&tbl[0])) - tbl, next = next, tbl - base += 8 - fallthrough - case 5: - block.Encrypt(next, src[base:]) - *(*uint64)(unsafe.Pointer(&dst[base])) = *(*uint64)(unsafe.Pointer(&src[base])) ^ *(*uint64)(unsafe.Pointer(&tbl[0])) - tbl, next = next, tbl - base += 8 - fallthrough - case 4: - block.Encrypt(next, src[base:]) - *(*uint64)(unsafe.Pointer(&dst[base])) = *(*uint64)(unsafe.Pointer(&src[base])) ^ *(*uint64)(unsafe.Pointer(&tbl[0])) - tbl, next = next, tbl - base += 8 - fallthrough - case 3: - block.Encrypt(next, src[base:]) - *(*uint64)(unsafe.Pointer(&dst[base])) = *(*uint64)(unsafe.Pointer(&src[base])) ^ *(*uint64)(unsafe.Pointer(&tbl[0])) - tbl, next = next, tbl - base += 8 - fallthrough - case 2: - block.Encrypt(next, src[base:]) - *(*uint64)(unsafe.Pointer(&dst[base])) = *(*uint64)(unsafe.Pointer(&src[base])) ^ *(*uint64)(unsafe.Pointer(&tbl[0])) - tbl, next = next, tbl - base += 8 - fallthrough - case 1: - block.Encrypt(next, src[base:]) - *(*uint64)(unsafe.Pointer(&dst[base])) = *(*uint64)(unsafe.Pointer(&src[base])) ^ *(*uint64)(unsafe.Pointer(&tbl[0])) - tbl, next = next, tbl - base += 8 - fallthrough - case 0: - xorBytes(dst[base:], src[base:], tbl) - } -} - -func decrypt16(block cipher.Block, dst, src, buf []byte) { - tbl := buf[0:16] - next := buf[16:32] - block.Encrypt(tbl, initialVector) - n := len(src) / 16 - base := 0 - repeat := n / 8 - left := n % 8 - for i := 0; i < repeat; i++ { - s := src[base:][0:128] - d := dst[base:][0:128] - // 1 - block.Encrypt(next, s[0:16]) - xor.Bytes16Align(d[0:16], s[0:16], tbl) - // 2 - block.Encrypt(tbl, s[16:32]) - xor.Bytes16Align(d[16:32], s[16:32], next) - // 3 - block.Encrypt(next, s[32:48]) - xor.Bytes16Align(d[32:48], s[32:48], tbl) - // 4 - block.Encrypt(tbl, s[48:64]) - xor.Bytes16Align(d[48:64], s[48:64], next) - // 5 - block.Encrypt(next, s[64:80]) - xor.Bytes16Align(d[64:80], s[64:80], tbl) - // 6 - block.Encrypt(tbl, s[80:96]) - xor.Bytes16Align(d[80:96], s[80:96], next) - // 7 - block.Encrypt(next, s[96:112]) - xor.Bytes16Align(d[96:112], s[96:112], tbl) - // 8 - block.Encrypt(tbl, s[112:128]) - xor.Bytes16Align(d[112:128], s[112:128], next) - base += 128 - } - - switch left { - case 7: - block.Encrypt(next, src[base:]) - xor.Bytes16Align(dst[base:], src[base:], tbl) - tbl, next = next, tbl - base += 16 - fallthrough - case 6: - block.Encrypt(next, src[base:]) - xor.Bytes16Align(dst[base:], src[base:], tbl) - tbl, next = next, tbl - base += 16 - fallthrough - case 5: - block.Encrypt(next, src[base:]) - xor.Bytes16Align(dst[base:], src[base:], tbl) - tbl, next = next, tbl - base += 16 - fallthrough - case 4: - block.Encrypt(next, src[base:]) - xor.Bytes16Align(dst[base:], src[base:], tbl) - tbl, next = next, tbl - base += 16 - fallthrough - case 3: - block.Encrypt(next, src[base:]) - xor.Bytes16Align(dst[base:], src[base:], tbl) - tbl, next = next, tbl - base += 16 - fallthrough - case 2: - block.Encrypt(next, src[base:]) - xor.Bytes16Align(dst[base:], src[base:], tbl) - tbl, next = next, tbl - base += 16 - fallthrough - case 1: - block.Encrypt(next, src[base:]) - xor.Bytes16Align(dst[base:], src[base:], tbl) - tbl, next = next, tbl - base += 16 - fallthrough - case 0: - xorBytes(dst[base:], src[base:], tbl) - } -} - -// per bytes xors -func xorBytes(dst, a, b []byte) int { - n := len(a) - if len(b) < n { - n = len(b) - } - if n == 0 { - return 0 - } - - for i := 0; i < n; i++ { - dst[i] = a[i] ^ b[i] - } - return n -} diff --git a/gate/kcp/crypt_test.go b/gate/kcp/crypt_test.go deleted file mode 100644 index 2ef4dc8a..00000000 --- a/gate/kcp/crypt_test.go +++ /dev/null @@ -1,289 +0,0 @@ -package kcp - -import ( - "bytes" - "crypto/aes" - "crypto/md5" - "crypto/rand" - "crypto/sha1" - "hash/crc32" - "io" - "testing" -) - -func TestSM4(t *testing.T) { - bc, err := NewSM4BlockCrypt(pass[:16]) - if err != nil { - t.Fatal(err) - } - cryptTest(t, bc) -} - -func TestAES(t *testing.T) { - bc, err := NewAESBlockCrypt(pass[:32]) - if err != nil { - t.Fatal(err) - } - cryptTest(t, bc) -} - -func TestTEA(t *testing.T) { - bc, err := NewTEABlockCrypt(pass[:16]) - if err != nil { - t.Fatal(err) - } - cryptTest(t, bc) -} - -func TestXOR(t *testing.T) { - bc, err := NewSimpleXORBlockCrypt(pass[:32]) - if err != nil { - t.Fatal(err) - } - cryptTest(t, bc) -} - -func TestBlowfish(t *testing.T) { - bc, err := NewBlowfishBlockCrypt(pass[:32]) - if err != nil { - t.Fatal(err) - } - cryptTest(t, bc) -} - -func TestNone(t *testing.T) { - bc, err := NewNoneBlockCrypt(pass[:32]) - if err != nil { - t.Fatal(err) - } - cryptTest(t, bc) -} - -func TestCast5(t *testing.T) { - bc, err := NewCast5BlockCrypt(pass[:16]) - if err != nil { - t.Fatal(err) - } - cryptTest(t, bc) -} - -func Test3DES(t *testing.T) { - bc, err := NewTripleDESBlockCrypt(pass[:24]) - if err != nil { - t.Fatal(err) - } - cryptTest(t, bc) -} - -func TestTwofish(t *testing.T) { - bc, err := NewTwofishBlockCrypt(pass[:32]) - if err != nil { - t.Fatal(err) - } - cryptTest(t, bc) -} - -func TestXTEA(t *testing.T) { - bc, err := NewXTEABlockCrypt(pass[:16]) - if err != nil { - t.Fatal(err) - } - cryptTest(t, bc) -} - -func TestSalsa20(t *testing.T) { - bc, err := NewSalsa20BlockCrypt(pass[:32]) - if err != nil { - t.Fatal(err) - } - cryptTest(t, bc) -} - -func cryptTest(t *testing.T, bc BlockCrypt) { - data := make([]byte, mtuLimit) - io.ReadFull(rand.Reader, data) - dec := make([]byte, mtuLimit) - enc := make([]byte, mtuLimit) - bc.Encrypt(enc, data) - bc.Decrypt(dec, enc) - if !bytes.Equal(data, dec) { - t.Fail() - } -} - -func BenchmarkSM4(b *testing.B) { - bc, err := NewSM4BlockCrypt(pass[:16]) - if err != nil { - b.Fatal(err) - } - benchCrypt(b, bc) -} - -func BenchmarkAES128(b *testing.B) { - bc, err := NewAESBlockCrypt(pass[:16]) - if err != nil { - b.Fatal(err) - } - - benchCrypt(b, bc) -} - -func BenchmarkAES192(b *testing.B) { - bc, err := NewAESBlockCrypt(pass[:24]) - if err != nil { - b.Fatal(err) - } - - benchCrypt(b, bc) -} - -func BenchmarkAES256(b *testing.B) { - bc, err := NewAESBlockCrypt(pass[:32]) - if err != nil { - b.Fatal(err) - } - - benchCrypt(b, bc) -} - -func BenchmarkTEA(b *testing.B) { - bc, err := NewTEABlockCrypt(pass[:16]) - if err != nil { - b.Fatal(err) - } - benchCrypt(b, bc) -} - -func BenchmarkXOR(b *testing.B) { - bc, err := NewSimpleXORBlockCrypt(pass[:32]) - if err != nil { - b.Fatal(err) - } - benchCrypt(b, bc) -} - -func BenchmarkBlowfish(b *testing.B) { - bc, err := NewBlowfishBlockCrypt(pass[:32]) - if err != nil { - b.Fatal(err) - } - benchCrypt(b, bc) -} - -func BenchmarkNone(b *testing.B) { - bc, err := NewNoneBlockCrypt(pass[:32]) - if err != nil { - b.Fatal(err) - } - benchCrypt(b, bc) -} - -func BenchmarkCast5(b *testing.B) { - bc, err := NewCast5BlockCrypt(pass[:16]) - if err != nil { - b.Fatal(err) - } - benchCrypt(b, bc) -} - -func Benchmark3DES(b *testing.B) { - bc, err := NewTripleDESBlockCrypt(pass[:24]) - if err != nil { - b.Fatal(err) - } - benchCrypt(b, bc) -} - -func BenchmarkTwofish(b *testing.B) { - bc, err := NewTwofishBlockCrypt(pass[:32]) - if err != nil { - b.Fatal(err) - } - benchCrypt(b, bc) -} - -func BenchmarkXTEA(b *testing.B) { - bc, err := NewXTEABlockCrypt(pass[:16]) - if err != nil { - b.Fatal(err) - } - benchCrypt(b, bc) -} - -func BenchmarkSalsa20(b *testing.B) { - bc, err := NewSalsa20BlockCrypt(pass[:32]) - if err != nil { - b.Fatal(err) - } - benchCrypt(b, bc) -} - -func benchCrypt(b *testing.B, bc BlockCrypt) { - data := make([]byte, mtuLimit) - io.ReadFull(rand.Reader, data) - dec := make([]byte, mtuLimit) - enc := make([]byte, mtuLimit) - - b.ReportAllocs() - b.SetBytes(int64(len(enc) * 2)) - b.ResetTimer() - for i := 0; i < b.N; i++ { - bc.Encrypt(enc, data) - bc.Decrypt(dec, enc) - } -} - -func BenchmarkCRC32(b *testing.B) { - content := make([]byte, 1024) - b.SetBytes(int64(len(content))) - for i := 0; i < b.N; i++ { - crc32.ChecksumIEEE(content) - } -} - -func BenchmarkCsprngSystem(b *testing.B) { - data := make([]byte, md5.Size) - b.SetBytes(int64(len(data))) - - for i := 0; i < b.N; i++ { - io.ReadFull(rand.Reader, data) - } -} - -func BenchmarkCsprngMD5(b *testing.B) { - var data [md5.Size]byte - b.SetBytes(md5.Size) - - for i := 0; i < b.N; i++ { - data = md5.Sum(data[:]) - } -} -func BenchmarkCsprngSHA1(b *testing.B) { - var data [sha1.Size]byte - b.SetBytes(sha1.Size) - - for i := 0; i < b.N; i++ { - data = sha1.Sum(data[:]) - } -} - -func BenchmarkCsprngNonceMD5(b *testing.B) { - var ng nonceMD5 - ng.Init() - b.SetBytes(md5.Size) - data := make([]byte, md5.Size) - for i := 0; i < b.N; i++ { - ng.Fill(data) - } -} - -func BenchmarkCsprngNonceAES128(b *testing.B) { - var ng nonceAES128 - ng.Init() - - b.SetBytes(aes.BlockSize) - data := make([]byte, aes.BlockSize) - for i := 0; i < b.N; i++ { - ng.Fill(data) - } -} diff --git a/gate/kcp/entropy.go b/gate/kcp/entropy.go deleted file mode 100644 index a86cf4f4..00000000 --- a/gate/kcp/entropy.go +++ /dev/null @@ -1,52 +0,0 @@ -package kcp - -import ( - "crypto/aes" - "crypto/cipher" - "crypto/md5" - "crypto/rand" - "io" -) - -// Entropy defines a entropy source -type Entropy interface { - Init() - Fill(nonce []byte) -} - -// nonceMD5 nonce generator for packet header -type nonceMD5 struct { - seed [md5.Size]byte -} - -func (n *nonceMD5) Init() { /*nothing required*/ } - -func (n *nonceMD5) Fill(nonce []byte) { - if n.seed[0] == 0 { // entropy update - io.ReadFull(rand.Reader, n.seed[:]) - } - n.seed = md5.Sum(n.seed[:]) - copy(nonce, n.seed[:]) -} - -// nonceAES128 nonce generator for packet headers -type nonceAES128 struct { - seed [aes.BlockSize]byte - block cipher.Block -} - -func (n *nonceAES128) Init() { - var key [16]byte // aes-128 - io.ReadFull(rand.Reader, key[:]) - io.ReadFull(rand.Reader, n.seed[:]) - block, _ := aes.NewCipher(key[:]) - n.block = block -} - -func (n *nonceAES128) Fill(nonce []byte) { - if n.seed[0] == 0 { // entropy update - io.ReadFull(rand.Reader, n.seed[:]) - } - n.block.Encrypt(n.seed[:], n.seed[:]) - copy(nonce, n.seed[:]) -} diff --git a/gate/kcp/fec.go b/gate/kcp/fec.go deleted file mode 100644 index a6126c03..00000000 --- a/gate/kcp/fec.go +++ /dev/null @@ -1,381 +0,0 @@ -package kcp - -import ( - "encoding/binary" - "sync/atomic" - - "github.com/klauspost/reedsolomon" -) - -const ( - fecHeaderSize = 6 - fecHeaderSizePlus2 = fecHeaderSize + 2 // plus 2B data size - typeData = 0xf1 - typeParity = 0xf2 - fecExpire = 60000 - rxFECMulti = 3 // FEC keeps rxFECMulti* (dataShard+parityShard) ordered packets in memory -) - -// fecPacket is a decoded FEC packet -type fecPacket []byte - -func (bts fecPacket) seqid() uint32 { return binary.LittleEndian.Uint32(bts) } -func (bts fecPacket) flag() uint16 { return binary.LittleEndian.Uint16(bts[4:]) } -func (bts fecPacket) data() []byte { return bts[6:] } - -// fecElement has auxcilliary time field -type fecElement struct { - fecPacket - ts uint32 -} - -// fecDecoder for decoding incoming packets -type fecDecoder struct { - rxlimit int // queue size limit - dataShards int - parityShards int - shardSize int - rx []fecElement // ordered receive queue - - // caches - decodeCache [][]byte - flagCache []bool - - // zeros - zeros []byte - - // RS decoder - codec reedsolomon.Encoder - - // auto tune fec parameter - autoTune autoTune -} - -func newFECDecoder(dataShards, parityShards int) *fecDecoder { - if dataShards <= 0 || parityShards <= 0 { - return nil - } - - dec := new(fecDecoder) - dec.dataShards = dataShards - dec.parityShards = parityShards - dec.shardSize = dataShards + parityShards - dec.rxlimit = rxFECMulti * dec.shardSize - codec, err := reedsolomon.New(dataShards, parityShards) - if err != nil { - return nil - } - dec.codec = codec - dec.decodeCache = make([][]byte, dec.shardSize) - dec.flagCache = make([]bool, dec.shardSize) - dec.zeros = make([]byte, mtuLimit) - return dec -} - -// decode a fec packet -func (dec *fecDecoder) decode(in fecPacket) (recovered [][]byte) { - // sample to auto FEC tuner - if in.flag() == typeData { - dec.autoTune.Sample(true, in.seqid()) - } else { - dec.autoTune.Sample(false, in.seqid()) - } - - // check if FEC parameters is out of sync - var shouldTune bool - if int(in.seqid())%dec.shardSize < dec.dataShards { - if in.flag() != typeData { // expect typeData - shouldTune = true - } - } else { - if in.flag() != typeParity { - shouldTune = true - } - } - - if shouldTune { - autoDS := dec.autoTune.FindPeriod(true) - autoPS := dec.autoTune.FindPeriod(false) - - // edges found, we can tune parameters now - if autoDS > 0 && autoPS > 0 && autoDS < 256 && autoPS < 256 { - // and make sure it's different - if autoDS != dec.dataShards || autoPS != dec.parityShards { - dec.dataShards = autoDS - dec.parityShards = autoPS - dec.shardSize = autoDS + autoPS - dec.rxlimit = rxFECMulti * dec.shardSize - codec, err := reedsolomon.New(autoDS, autoPS) - if err != nil { - return nil - } - dec.codec = codec - dec.decodeCache = make([][]byte, dec.shardSize) - dec.flagCache = make([]bool, dec.shardSize) - // log.Println("autotune to :", dec.dataShards, dec.parityShards) - } - } - } - - // insertion - n := len(dec.rx) - 1 - insertIdx := 0 - for i := n; i >= 0; i-- { - if in.seqid() == dec.rx[i].seqid() { // de-duplicate - return nil - } else if _itimediff(in.seqid(), dec.rx[i].seqid()) > 0 { // insertion - insertIdx = i + 1 - break - } - } - - // make a copy - pkt := fecPacket(xmitBuf.Get().([]byte)[:len(in)]) - copy(pkt, in) - elem := fecElement{pkt, currentMs()} - - // insert into ordered rx queue - if insertIdx == n+1 { - dec.rx = append(dec.rx, elem) - } else { - dec.rx = append(dec.rx, fecElement{}) - copy(dec.rx[insertIdx+1:], dec.rx[insertIdx:]) // shift right - dec.rx[insertIdx] = elem - } - - // shard range for current packet - shardBegin := pkt.seqid() - pkt.seqid()%uint32(dec.shardSize) - shardEnd := shardBegin + uint32(dec.shardSize) - 1 - - // max search range in ordered queue for current shard - searchBegin := insertIdx - int(pkt.seqid()%uint32(dec.shardSize)) - if searchBegin < 0 { - searchBegin = 0 - } - searchEnd := searchBegin + dec.shardSize - 1 - if searchEnd >= len(dec.rx) { - searchEnd = len(dec.rx) - 1 - } - - // re-construct datashards - if searchEnd-searchBegin+1 >= dec.dataShards { - var numshard, numDataShard, first, maxlen int - - // zero caches - shards := dec.decodeCache - shardsflag := dec.flagCache - for k := range dec.decodeCache { - shards[k] = nil - shardsflag[k] = false - } - - // shard assembly - for i := searchBegin; i <= searchEnd; i++ { - seqid := dec.rx[i].seqid() - if _itimediff(seqid, shardEnd) > 0 { - break - } else if _itimediff(seqid, shardBegin) >= 0 { - shards[seqid%uint32(dec.shardSize)] = dec.rx[i].data() - shardsflag[seqid%uint32(dec.shardSize)] = true - numshard++ - if dec.rx[i].flag() == typeData { - numDataShard++ - } - if numshard == 1 { - first = i - } - if len(dec.rx[i].data()) > maxlen { - maxlen = len(dec.rx[i].data()) - } - } - } - - if numDataShard == dec.dataShards { - // case 1: no loss on data shards - dec.rx = dec.freeRange(first, numshard, dec.rx) - } else if numshard >= dec.dataShards { - // case 2: loss on data shards, but it's recoverable from parity shards - for k := range shards { - if shards[k] != nil { - dlen := len(shards[k]) - shards[k] = shards[k][:maxlen] - copy(shards[k][dlen:], dec.zeros) - } else if k < dec.dataShards { - shards[k] = xmitBuf.Get().([]byte)[:0] - } - } - if err := dec.codec.ReconstructData(shards); err == nil { - for k := range shards[:dec.dataShards] { - if !shardsflag[k] { - // recovered data should be recycled - recovered = append(recovered, shards[k]) - } - } - } - dec.rx = dec.freeRange(first, numshard, dec.rx) - } - } - - // keep rxlimit - if len(dec.rx) > dec.rxlimit { - if dec.rx[0].flag() == typeData { // track the unrecoverable data - atomic.AddUint64(&DefaultSnmp.FECShortShards, 1) - } - dec.rx = dec.freeRange(0, 1, dec.rx) - } - - // timeout policy - current := currentMs() - numExpired := 0 - for k := range dec.rx { - if _itimediff(current, dec.rx[k].ts) > fecExpire { - numExpired++ - continue - } - break - } - if numExpired > 0 { - dec.rx = dec.freeRange(0, numExpired, dec.rx) - } - return -} - -// free a range of fecPacket -func (dec *fecDecoder) freeRange(first, n int, q []fecElement) []fecElement { - for i := first; i < first+n; i++ { // recycle buffer - xmitBuf.Put([]byte(q[i].fecPacket)) - } - - if first == 0 && n < cap(q)/2 { - return q[n:] - } - copy(q[first:], q[first+n:]) - return q[:len(q)-n] -} - -// release all segments back to xmitBuf -func (dec *fecDecoder) release() { - if n := len(dec.rx); n > 0 { - dec.rx = dec.freeRange(0, n, dec.rx) - } -} - -type ( - // fecEncoder for encoding outgoing packets - fecEncoder struct { - dataShards int - parityShards int - shardSize int - paws uint32 // Protect Against Wrapped Sequence numbers - next uint32 // next seqid - - shardCount int // count the number of datashards collected - maxSize int // track maximum data length in datashard - - headerOffset int // FEC header offset - payloadOffset int // FEC payload offset - - // caches - shardCache [][]byte - encodeCache [][]byte - - // zeros - zeros []byte - - // RS encoder - codec reedsolomon.Encoder - } -) - -func newFECEncoder(dataShards, parityShards, offset int) *fecEncoder { - if dataShards <= 0 || parityShards <= 0 { - return nil - } - enc := new(fecEncoder) - enc.dataShards = dataShards - enc.parityShards = parityShards - enc.shardSize = dataShards + parityShards - enc.paws = 0xffffffff / uint32(enc.shardSize) * uint32(enc.shardSize) - enc.headerOffset = offset - enc.payloadOffset = enc.headerOffset + fecHeaderSize - - codec, err := reedsolomon.New(dataShards, parityShards) - if err != nil { - return nil - } - enc.codec = codec - - // caches - enc.encodeCache = make([][]byte, enc.shardSize) - enc.shardCache = make([][]byte, enc.shardSize) - for k := range enc.shardCache { - enc.shardCache[k] = make([]byte, mtuLimit) - } - enc.zeros = make([]byte, mtuLimit) - return enc -} - -// encodes the packet, outputs parity shards if we have collected quorum datashards -// notice: the contents of 'ps' will be re-written in successive calling -func (enc *fecEncoder) encode(b []byte) (ps [][]byte) { - // The header format: - // | FEC SEQID(4B) | FEC TYPE(2B) | SIZE (2B) | PAYLOAD(SIZE-2) | - // |<-headerOffset |<-payloadOffset - enc.markData(b[enc.headerOffset:]) - binary.LittleEndian.PutUint16(b[enc.payloadOffset:], uint16(len(b[enc.payloadOffset:]))) - - // copy data from payloadOffset to fec shard cache - sz := len(b) - enc.shardCache[enc.shardCount] = enc.shardCache[enc.shardCount][:sz] - copy(enc.shardCache[enc.shardCount][enc.payloadOffset:], b[enc.payloadOffset:]) - enc.shardCount++ - - // track max datashard length - if sz > enc.maxSize { - enc.maxSize = sz - } - - // Generation of Reed-Solomon Erasure Code - if enc.shardCount == enc.dataShards { - // fill '0' into the tail of each datashard - for i := 0; i < enc.dataShards; i++ { - shard := enc.shardCache[i] - slen := len(shard) - copy(shard[slen:enc.maxSize], enc.zeros) - } - - // construct equal-sized slice with stripped header - cache := enc.encodeCache - for k := range cache { - cache[k] = enc.shardCache[k][enc.payloadOffset:enc.maxSize] - } - - // encoding - if err := enc.codec.Encode(cache); err == nil { - ps = enc.shardCache[enc.dataShards:] - for k := range ps { - enc.markParity(ps[k][enc.headerOffset:]) - ps[k] = ps[k][:enc.maxSize] - } - } - - // counters resetting - enc.shardCount = 0 - enc.maxSize = 0 - } - - return -} - -func (enc *fecEncoder) markData(data []byte) { - binary.LittleEndian.PutUint32(data, enc.next) - binary.LittleEndian.PutUint16(data[4:], typeData) - enc.next++ -} - -func (enc *fecEncoder) markParity(data []byte) { - binary.LittleEndian.PutUint32(data, enc.next) - binary.LittleEndian.PutUint16(data[4:], typeParity) - // sequence wrap will only happen at parity shard - enc.next = (enc.next + 1) % enc.paws -} diff --git a/gate/kcp/fec_test.go b/gate/kcp/fec_test.go deleted file mode 100644 index 59b64aca..00000000 --- a/gate/kcp/fec_test.go +++ /dev/null @@ -1,43 +0,0 @@ -package kcp - -import ( - "encoding/binary" - "math/rand" - "testing" -) - -func BenchmarkFECDecode(b *testing.B) { - const dataSize = 10 - const paritySize = 3 - const payLoad = 1500 - decoder := newFECDecoder(dataSize, paritySize) - b.ReportAllocs() - b.SetBytes(payLoad) - for i := 0; i < b.N; i++ { - if rand.Int()%(dataSize+paritySize) == 0 { // random loss - continue - } - pkt := make([]byte, payLoad) - binary.LittleEndian.PutUint32(pkt, uint32(i)) - if i%(dataSize+paritySize) >= dataSize { - binary.LittleEndian.PutUint16(pkt[4:], typeParity) - } else { - binary.LittleEndian.PutUint16(pkt[4:], typeData) - } - decoder.decode(pkt) - } -} - -func BenchmarkFECEncode(b *testing.B) { - const dataSize = 10 - const paritySize = 3 - const payLoad = 1500 - - b.ReportAllocs() - b.SetBytes(payLoad) - encoder := newFECEncoder(dataSize, paritySize, 0) - for i := 0; i < b.N; i++ { - data := make([]byte, payLoad) - encoder.encode(data) - } -} diff --git a/gate/kcp/kcp_test.go b/gate/kcp/kcp_test.go index 49d55d5a..d23ecabe 100644 --- a/gate/kcp/kcp_test.go +++ b/gate/kcp/kcp_test.go @@ -74,8 +74,8 @@ func TestLossyConn4(t *testing.T) { func testlink(t *testing.T, client *lossyconn.LossyConn, server *lossyconn.LossyConn, nodelay, interval, resend, nc int) { t.Log("testing with nodelay parameters:", nodelay, interval, resend, nc) - sess, _ := NewConn2(server.LocalAddr(), nil, 0, 0, client) - listener, _ := ServeConn(nil, 0, 0, server) + sess, _ := NewConn2(server.LocalAddr(), client) + listener, _ := ServeConn(server) echoServer := func(l *Listener) { for { conn, err := l.AcceptKCP() diff --git a/gate/kcp/readloop_generic.go b/gate/kcp/readloop_generic.go deleted file mode 100644 index 08e72cb8..00000000 --- a/gate/kcp/readloop_generic.go +++ /dev/null @@ -1,12 +0,0 @@ -//go:build !linux -// +build !linux - -package kcp - -func (s *UDPSession) readLoop() { - s.defaultReadLoop() -} - -func (l *Listener) monitor() { - l.defaultMonitor() -} diff --git a/gate/kcp/sess.go b/gate/kcp/sess.go index 85ea6672..4eae1ca5 100644 --- a/gate/kcp/sess.go +++ b/gate/kcp/sess.go @@ -11,7 +11,6 @@ import ( "crypto/rand" "encoding/binary" "encoding/hex" - "hash/crc32" "io" "net" "sync" @@ -65,16 +64,16 @@ type ( ownConn bool // true if we created conn internally, false if provided by caller kcp *KCP // KCP ARQ protocol l *Listener // pointing to the Listener object if it's been accepted by a Listener - block BlockCrypt // block encryption object + // block BlockCrypt // block encryption object // kcp receiving is based on packets // recvbuf turns packets into stream recvbuf []byte bufptr []byte - // FEC codec - fecDecoder *fecDecoder - fecEncoder *fecEncoder + // // FEC codec + // fecDecoder *fecDecoder + // fecEncoder *fecEncoder // settings remote net.Addr // remote peer address @@ -99,8 +98,8 @@ type ( socketReadErrorOnce sync.Once socketWriteErrorOnce sync.Once - // nonce generator - nonce Entropy + // // nonce generator + // nonce Entropy // packets waiting to be sent on wire txqueue []ipv4.Message @@ -124,11 +123,11 @@ type ( ) // newUDPSession create a new udp session for client or server -func newUDPSession(conv uint64, dataShards, parityShards int, l *Listener, conn net.PacketConn, ownConn bool, remote net.Addr, block BlockCrypt) *UDPSession { +func newUDPSession(conv uint64, l *Listener, conn net.PacketConn, ownConn bool, remote net.Addr) *UDPSession { sess := new(UDPSession) sess.die = make(chan struct{}) - sess.nonce = new(nonceAES128) - sess.nonce.Init() + // sess.nonce = new(nonceAES128) + // sess.nonce.Init() sess.chReadEvent = make(chan struct{}, 1) sess.chWriteEvent = make(chan struct{}, 1) sess.chSocketReadError = make(chan struct{}) @@ -137,7 +136,7 @@ func newUDPSession(conv uint64, dataShards, parityShards int, l *Listener, conn sess.conn = conn sess.ownConn = ownConn sess.l = l - sess.block = block + // sess.block = block sess.recvbuf = make([]byte, mtuLimit) // cast to writebatch conn @@ -152,21 +151,21 @@ func newUDPSession(conv uint64, dataShards, parityShards int, l *Listener, conn } } - // FEC codec initialization - sess.fecDecoder = newFECDecoder(dataShards, parityShards) - if sess.block != nil { - sess.fecEncoder = newFECEncoder(dataShards, parityShards, cryptHeaderSize) - } else { - sess.fecEncoder = newFECEncoder(dataShards, parityShards, 0) - } + // // FEC codec initialization + // sess.fecDecoder = newFECDecoder(dataShards, parityShards) + // if sess.block != nil { + // sess.fecEncoder = newFECEncoder(dataShards, parityShards, cryptHeaderSize) + // } else { + // sess.fecEncoder = newFECEncoder(dataShards, parityShards, 0) + // } - // calculate additional header size introduced by FEC and encryption - if sess.block != nil { - sess.headerSize += cryptHeaderSize - } - if sess.fecEncoder != nil { - sess.headerSize += fecHeaderSizePlus2 - } + // // calculate additional header size introduced by FEC and encryption + // if sess.block != nil { + // sess.headerSize += cryptHeaderSize + // } + // if sess.fecEncoder != nil { + // sess.headerSize += fecHeaderSizePlus2 + // } sess.kcp = NewKCP(conv, func(buf []byte, size int) { if size >= IKCP_OVERHEAD+sess.headerSize { @@ -371,9 +370,9 @@ func (s *UDPSession) Close() error { s.uncork() // release pending segments s.kcp.ReleaseTX() - if s.fecDecoder != nil { - s.fecDecoder.release() - } + // if s.fecDecoder != nil { + // s.fecDecoder.release() + // } s.mu.Unlock() if s.l != nil { // belongs to listener @@ -552,25 +551,25 @@ func (s *UDPSession) SetWriteBuffer(bytes int) error { func (s *UDPSession) output(buf []byte) { var ecc [][]byte - // 1. FEC encoding - if s.fecEncoder != nil { - ecc = s.fecEncoder.encode(buf) - } + // // 1. FEC encoding + // if s.fecEncoder != nil { + // ecc = s.fecEncoder.encode(buf) + // } - // 2&3. crc32 & encryption - if s.block != nil { - s.nonce.Fill(buf[:nonceSize]) - checksum := crc32.ChecksumIEEE(buf[cryptHeaderSize:]) - binary.LittleEndian.PutUint32(buf[nonceSize:], checksum) - s.block.Encrypt(buf, buf) - - for k := range ecc { - s.nonce.Fill(ecc[k][:nonceSize]) - checksum := crc32.ChecksumIEEE(ecc[k][cryptHeaderSize:]) - binary.LittleEndian.PutUint32(ecc[k][nonceSize:], checksum) - s.block.Encrypt(ecc[k], ecc[k]) - } - } + // // 2&3. crc32 & encryption + // if s.block != nil { + // s.nonce.Fill(buf[:nonceSize]) + // checksum := crc32.ChecksumIEEE(buf[cryptHeaderSize:]) + // binary.LittleEndian.PutUint32(buf[nonceSize:], checksum) + // s.block.Encrypt(buf, buf) + // + // for k := range ecc { + // s.nonce.Fill(ecc[k][:nonceSize]) + // checksum := crc32.ChecksumIEEE(ecc[k][cryptHeaderSize:]) + // binary.LittleEndian.PutUint32(ecc[k][nonceSize:], checksum) + // s.block.Encrypt(ecc[k], ecc[k]) + // } + // } // 4. TxQueue var msg ipv4.Message @@ -663,114 +662,130 @@ func (s *UDPSession) notifyWriteError(err error) { // packet input stage func (s *UDPSession) packetInput(data []byte) { - decrypted := false - if s.block != nil && len(data) >= cryptHeaderSize { - s.block.Decrypt(data, data) - data = data[nonceSize:] - checksum := crc32.ChecksumIEEE(data[crcSize:]) - if checksum == binary.LittleEndian.Uint32(data) { - data = data[crcSize:] - decrypted = true - } else { - atomic.AddUint64(&DefaultSnmp.InCsumErrors, 1) - } - } else if s.block == nil { - decrypted = true - } + // decrypted := false + // if s.block != nil && len(data) >= cryptHeaderSize { + // s.block.Decrypt(data, data) + // data = data[nonceSize:] + // checksum := crc32.ChecksumIEEE(data[crcSize:]) + // if checksum == binary.LittleEndian.Uint32(data) { + // data = data[crcSize:] + // decrypted = true + // } else { + // atomic.AddUint64(&DefaultSnmp.InCsumErrors, 1) + // } + // } else if s.block == nil { + // decrypted = true + // } + decrypted := true if decrypted && len(data) >= IKCP_OVERHEAD { s.kcpInput(data) } } func (s *UDPSession) kcpInput(data []byte) { - var kcpInErrors, fecErrs, fecRecovered, fecParityShards uint64 + // var kcpInErrors, fecErrs, fecRecovered, fecParityShards uint64 + // + // fecFlag := binary.LittleEndian.Uint16(data[8:]) + // if fecFlag == typeData || fecFlag == typeParity { // 16bit kcp cmd [81-84] and frg [0-255] will not overlap with FEC type 0x00f1 0x00f2 + // if len(data) >= fecHeaderSizePlus2 { + // f := fecPacket(data) + // if f.flag() == typeParity { + // fecParityShards++ + // } + // + // // lock + // s.mu.Lock() + // // if fecDecoder is not initialized, create one with default parameter + // if s.fecDecoder == nil { + // s.fecDecoder = newFECDecoder(1, 1) + // } + // recovers := s.fecDecoder.decode(f) + // if f.flag() == typeData { + // if ret := s.kcp.Input(data[fecHeaderSizePlus2:], true, s.ackNoDelay); ret != 0 { + // kcpInErrors++ + // } + // } + // + // for _, r := range recovers { + // if len(r) >= 2 { // must be larger than 2bytes + // sz := binary.LittleEndian.Uint16(r) + // if int(sz) <= len(r) && sz >= 2 { + // if ret := s.kcp.Input(r[2:sz], false, s.ackNoDelay); ret == 0 { + // fecRecovered++ + // } else { + // kcpInErrors++ + // } + // } else { + // fecErrs++ + // } + // } else { + // fecErrs++ + // } + // // recycle the recovers + // xmitBuf.Put(r) + // } + // + // // to notify the readers to receive the data + // if n := s.kcp.PeekSize(); n > 0 { + // s.notifyReadEvent() + // } + // // to notify the writers + // waitsnd := s.kcp.WaitSnd() + // if waitsnd < int(s.kcp.snd_wnd) && waitsnd < int(s.kcp.rmt_wnd) { + // s.notifyWriteEvent() + // } + // + // s.uncork() + // s.mu.Unlock() + // } else { + // atomic.AddUint64(&DefaultSnmp.InErrs, 1) + // } + // } else { + // s.mu.Lock() + // if ret := s.kcp.Input(data, true, s.ackNoDelay); ret != 0 { + // kcpInErrors++ + // } + // if n := s.kcp.PeekSize(); n > 0 { + // s.notifyReadEvent() + // } + // waitsnd := s.kcp.WaitSnd() + // if waitsnd < int(s.kcp.snd_wnd) && waitsnd < int(s.kcp.rmt_wnd) { + // s.notifyWriteEvent() + // } + // s.uncork() + // s.mu.Unlock() + // } - fecFlag := binary.LittleEndian.Uint16(data[8:]) - if fecFlag == typeData || fecFlag == typeParity { // 16bit kcp cmd [81-84] and frg [0-255] will not overlap with FEC type 0x00f1 0x00f2 - if len(data) >= fecHeaderSizePlus2 { - f := fecPacket(data) - if f.flag() == typeParity { - fecParityShards++ - } - - // lock - s.mu.Lock() - // if fecDecoder is not initialized, create one with default parameter - if s.fecDecoder == nil { - s.fecDecoder = newFECDecoder(1, 1) - } - recovers := s.fecDecoder.decode(f) - if f.flag() == typeData { - if ret := s.kcp.Input(data[fecHeaderSizePlus2:], true, s.ackNoDelay); ret != 0 { - kcpInErrors++ - } - } - - for _, r := range recovers { - if len(r) >= 2 { // must be larger than 2bytes - sz := binary.LittleEndian.Uint16(r) - if int(sz) <= len(r) && sz >= 2 { - if ret := s.kcp.Input(r[2:sz], false, s.ackNoDelay); ret == 0 { - fecRecovered++ - } else { - kcpInErrors++ - } - } else { - fecErrs++ - } - } else { - fecErrs++ - } - // recycle the recovers - xmitBuf.Put(r) - } - - // to notify the readers to receive the data - if n := s.kcp.PeekSize(); n > 0 { - s.notifyReadEvent() - } - // to notify the writers - waitsnd := s.kcp.WaitSnd() - if waitsnd < int(s.kcp.snd_wnd) && waitsnd < int(s.kcp.rmt_wnd) { - s.notifyWriteEvent() - } - - s.uncork() - s.mu.Unlock() - } else { - atomic.AddUint64(&DefaultSnmp.InErrs, 1) - } - } else { - s.mu.Lock() - if ret := s.kcp.Input(data, true, s.ackNoDelay); ret != 0 { - kcpInErrors++ - } - if n := s.kcp.PeekSize(); n > 0 { - s.notifyReadEvent() - } - waitsnd := s.kcp.WaitSnd() - if waitsnd < int(s.kcp.snd_wnd) && waitsnd < int(s.kcp.rmt_wnd) { - s.notifyWriteEvent() - } - s.uncork() - s.mu.Unlock() + var kcpInErrors uint64 + s.mu.Lock() + if ret := s.kcp.Input(data, true, s.ackNoDelay); ret != 0 { + kcpInErrors++ } + if n := s.kcp.PeekSize(); n > 0 { + s.notifyReadEvent() + } + waitsnd := s.kcp.WaitSnd() + if waitsnd < int(s.kcp.snd_wnd) && waitsnd < int(s.kcp.rmt_wnd) { + s.notifyWriteEvent() + } + s.uncork() + s.mu.Unlock() atomic.AddUint64(&DefaultSnmp.InPkts, 1) atomic.AddUint64(&DefaultSnmp.InBytes, uint64(len(data))) - if fecParityShards > 0 { - atomic.AddUint64(&DefaultSnmp.FECParityShards, fecParityShards) - } if kcpInErrors > 0 { atomic.AddUint64(&DefaultSnmp.KCPInErrors, kcpInErrors) } - if fecErrs > 0 { - atomic.AddUint64(&DefaultSnmp.FECErrs, fecErrs) - } - if fecRecovered > 0 { - atomic.AddUint64(&DefaultSnmp.FECRecovered, fecRecovered) - } + // if fecParityShards > 0 { + // atomic.AddUint64(&DefaultSnmp.FECParityShards, fecParityShards) + // } + // if fecErrs > 0 { + // atomic.AddUint64(&DefaultSnmp.FECErrs, fecErrs) + // } + // if fecRecovered > 0 { + // atomic.AddUint64(&DefaultSnmp.FECRecovered, fecRecovered) + // } } @@ -828,11 +843,11 @@ const ( type ( // Listener defines a server which will be waiting to accept incoming connections Listener struct { - block BlockCrypt // block encryption - dataShards int // FEC data shard - parityShards int // FEC parity shard - conn net.PacketConn // the underlying packet connection - ownConn bool // true if we created conn internally, false if provided by caller + // block BlockCrypt // block encryption + // dataShards int // FEC data shard + // parityShards int // FEC parity shard + conn net.PacketConn // the underlying packet connection + ownConn bool // true if we created conn internally, false if provided by caller // 网络切换会话保持改造 将convId作为会话的唯一标识 不再校验源地址 sessions map[uint64]*UDPSession // all sessions accepted by this Listener @@ -856,21 +871,22 @@ type ( // packet input stage func (l *Listener) packetInput(data []byte, addr net.Addr, convId uint64) { - decrypted := false - if l.block != nil && len(data) >= cryptHeaderSize { - l.block.Decrypt(data, data) - data = data[nonceSize:] - checksum := crc32.ChecksumIEEE(data[crcSize:]) - if checksum == binary.LittleEndian.Uint32(data) { - data = data[crcSize:] - decrypted = true - } else { - atomic.AddUint64(&DefaultSnmp.InCsumErrors, 1) - } - } else if l.block == nil { - decrypted = true - } + // decrypted := false + // if l.block != nil && len(data) >= cryptHeaderSize { + // l.block.Decrypt(data, data) + // data = data[nonceSize:] + // checksum := crc32.ChecksumIEEE(data[crcSize:]) + // if checksum == binary.LittleEndian.Uint32(data) { + // data = data[crcSize:] + // decrypted = true + // } else { + // atomic.AddUint64(&DefaultSnmp.InCsumErrors, 1) + // } + // } else if l.block == nil { + // decrypted = true + // } + decrypted := true if decrypted && len(data) >= IKCP_OVERHEAD { l.sessionLock.RLock() s, ok := l.sessions[convId] @@ -879,20 +895,25 @@ func (l *Listener) packetInput(data []byte, addr net.Addr, convId uint64) { var conv uint64 var sn uint32 convRecovered := false - fecFlag := binary.LittleEndian.Uint16(data[8:]) - if fecFlag == typeData || fecFlag == typeParity { // 16bit kcp cmd [81-84] and frg [0-255] will not overlap with FEC type 0x00f1 0x00f2 - // packet with FEC - if fecFlag == typeData && len(data) >= fecHeaderSizePlus2+IKCP_OVERHEAD { - conv = binary.LittleEndian.Uint64(data[fecHeaderSizePlus2:]) - sn = binary.LittleEndian.Uint32(data[fecHeaderSizePlus2+IKCP_SN_OFFSET:]) - convRecovered = true - } - } else { - // packet without FEC - conv = binary.LittleEndian.Uint64(data) - sn = binary.LittleEndian.Uint32(data[IKCP_SN_OFFSET:]) - convRecovered = true - } + // fecFlag := binary.LittleEndian.Uint16(data[8:]) + // if fecFlag == typeData || fecFlag == typeParity { // 16bit kcp cmd [81-84] and frg [0-255] will not overlap with FEC type 0x00f1 0x00f2 + // // packet with FEC + // if fecFlag == typeData && len(data) >= fecHeaderSizePlus2+IKCP_OVERHEAD { + // conv = binary.LittleEndian.Uint64(data[fecHeaderSizePlus2:]) + // sn = binary.LittleEndian.Uint32(data[fecHeaderSizePlus2+IKCP_SN_OFFSET:]) + // convRecovered = true + // } + // } else { + // // packet without FEC + // conv = binary.LittleEndian.Uint64(data) + // sn = binary.LittleEndian.Uint32(data[IKCP_SN_OFFSET:]) + // convRecovered = true + // } + + // packet without FEC + conv = binary.LittleEndian.Uint64(data) + sn = binary.LittleEndian.Uint32(data[IKCP_SN_OFFSET:]) + convRecovered = true if ok { // existing connection if !convRecovered || conv == s.kcp.conv { // parity data or valid conversation @@ -906,7 +927,7 @@ func (l *Listener) packetInput(data []byte, addr net.Addr, convId uint64) { if s == nil && convRecovered { // new session if len(l.chAccepts) < cap(l.chAccepts) { // do not let the new sessions overwhelm accept queue - s := newUDPSession(conv, l.dataShards, l.parityShards, l, l.conn, false, addr, l.block) + s := newUDPSession(conv, l, l.conn, false, addr) s.kcpInput(data) l.sessionLock.Lock() l.sessions[convId] = s @@ -1047,7 +1068,7 @@ func (l *Listener) closeSession(convId uint64) (ret bool) { func (l *Listener) Addr() net.Addr { return l.conn.LocalAddr() } // Listen listens for incoming KCP packets addressed to the local address laddr on the network "udp", -func Listen(laddr string) (net.Listener, error) { return ListenWithOptions(laddr, nil, 0, 0) } +func Listen(laddr string) (net.Listener, error) { return ListenWithOptions(laddr) } // ListenWithOptions listens for incoming KCP packets addressed to the local address laddr on the network "udp" with packet encryption. // @@ -1056,7 +1077,7 @@ func Listen(laddr string) (net.Listener, error) { return ListenWithOptions(laddr // 'dataShards', 'parityShards' specify how many parity packets will be generated following the data packets. // // Check https://github.com/klauspost/reedsolomon for details -func ListenWithOptions(laddr string, block BlockCrypt, dataShards, parityShards int) (*Listener, error) { +func ListenWithOptions(laddr string) (*Listener, error) { udpaddr, err := net.ResolveUDPAddr("udp", laddr) if err != nil { return nil, errors.WithStack(err) @@ -1066,15 +1087,15 @@ func ListenWithOptions(laddr string, block BlockCrypt, dataShards, parityShards return nil, errors.WithStack(err) } - return serveConn(block, dataShards, parityShards, conn, true) + return serveConn(conn, true) } // ServeConn serves KCP protocol for a single packet connection. -func ServeConn(block BlockCrypt, dataShards, parityShards int, conn net.PacketConn) (*Listener, error) { - return serveConn(block, dataShards, parityShards, conn, false) +func ServeConn(conn net.PacketConn) (*Listener, error) { + return serveConn(conn, false) } -func serveConn(block BlockCrypt, dataShards, parityShards int, conn net.PacketConn, ownConn bool) (*Listener, error) { +func serveConn(conn net.PacketConn, ownConn bool) (*Listener, error) { l := new(Listener) l.conn = conn l.ownConn = ownConn @@ -1082,9 +1103,9 @@ func serveConn(block BlockCrypt, dataShards, parityShards int, conn net.PacketCo l.chAccepts = make(chan *UDPSession, acceptBacklog) l.chSessionClosed = make(chan net.Addr) l.die = make(chan struct{}) - l.dataShards = dataShards - l.parityShards = parityShards - l.block = block + // l.dataShards = dataShards + // l.parityShards = parityShards + // l.block = block l.chSocketReadError = make(chan struct{}) l.EnetNotify = make(chan *Enet, 1000) go l.monitor() @@ -1092,7 +1113,7 @@ func serveConn(block BlockCrypt, dataShards, parityShards int, conn net.PacketCo } // Dial connects to the remote address "raddr" on the network "udp" without encryption and FEC -func Dial(raddr string) (net.Conn, error) { return DialWithOptions(raddr, nil, 0, 0) } +func Dial(raddr string) (net.Conn, error) { return DialWithOptions(raddr) } // DialWithOptions connects to the remote address "raddr" on the network "udp" with packet encryption // @@ -1101,7 +1122,7 @@ func Dial(raddr string) (net.Conn, error) { return DialWithOptions(raddr, nil, 0 // 'dataShards', 'parityShards' specify how many parity packets will be generated following the data packets. // // Check https://github.com/klauspost/reedsolomon for details -func DialWithOptions(raddr string, block BlockCrypt, dataShards, parityShards int) (*UDPSession, error) { +func DialWithOptions(raddr string) (*UDPSession, error) { // network type detection udpaddr, err := net.ResolveUDPAddr("udp", raddr) if err != nil { @@ -1119,26 +1140,26 @@ func DialWithOptions(raddr string, block BlockCrypt, dataShards, parityShards in var convid uint64 binary.Read(rand.Reader, binary.LittleEndian, &convid) - return newUDPSession(convid, dataShards, parityShards, nil, conn, true, udpaddr, block), nil + return newUDPSession(convid, nil, conn, true, udpaddr), nil } // NewConn3 establishes a session and talks KCP protocol over a packet connection. -func NewConn3(convid uint64, raddr net.Addr, block BlockCrypt, dataShards, parityShards int, conn net.PacketConn) (*UDPSession, error) { - return newUDPSession(convid, dataShards, parityShards, nil, conn, false, raddr, block), nil +func NewConn3(convid uint64, raddr net.Addr, conn net.PacketConn) (*UDPSession, error) { + return newUDPSession(convid, nil, conn, false, raddr), nil } // NewConn2 establishes a session and talks KCP protocol over a packet connection. -func NewConn2(raddr net.Addr, block BlockCrypt, dataShards, parityShards int, conn net.PacketConn) (*UDPSession, error) { +func NewConn2(raddr net.Addr, conn net.PacketConn) (*UDPSession, error) { var convid uint64 binary.Read(rand.Reader, binary.LittleEndian, &convid) - return NewConn3(convid, raddr, block, dataShards, parityShards, conn) + return NewConn3(convid, raddr, conn) } // NewConn establishes a session and talks KCP protocol over a packet connection. -func NewConn(raddr string, block BlockCrypt, dataShards, parityShards int, conn net.PacketConn) (*UDPSession, error) { +func NewConn(raddr string, conn net.PacketConn) (*UDPSession, error) { udpaddr, err := net.ResolveUDPAddr("udp", raddr) if err != nil { return nil, errors.WithStack(err) } - return NewConn2(udpaddr, block, dataShards, parityShards, conn) + return NewConn2(udpaddr, conn) } diff --git a/gate/kcp/sess_test.go b/gate/kcp/sess_test.go index 053e6caa..82d82090 100644 --- a/gate/kcp/sess_test.go +++ b/gate/kcp/sess_test.go @@ -33,8 +33,8 @@ func dialEcho(port int) (*UDPSession, error) { // block, _ := NewSimpleXORBlockCrypt(pass) // block, _ := NewTEABlockCrypt(pass[:16]) // block, _ := NewAESBlockCrypt(pass) - block, _ := NewSalsa20BlockCrypt(pass) - sess, err := DialWithOptions(fmt.Sprintf("127.0.0.1:%v", port), block, 10, 3) + // block, _ := NewSalsa20BlockCrypt(pass) + sess, err := DialWithOptions(fmt.Sprintf("127.0.0.1:%v", port)) if err != nil { panic(err) } @@ -57,7 +57,7 @@ func dialEcho(port int) (*UDPSession, error) { } func dialSink(port int) (*UDPSession, error) { - sess, err := DialWithOptions(fmt.Sprintf("127.0.0.1:%v", port), nil, 0, 0) + sess, err := DialWithOptions(fmt.Sprintf("127.0.0.1:%v", port)) if err != nil { panic(err) } @@ -79,8 +79,8 @@ func dialTinyBufferEcho(port int) (*UDPSession, error) { // block, _ := NewSimpleXORBlockCrypt(pass) // block, _ := NewTEABlockCrypt(pass[:16]) // block, _ := NewAESBlockCrypt(pass) - block, _ := NewSalsa20BlockCrypt(pass) - sess, err := DialWithOptions(fmt.Sprintf("127.0.0.1:%v", port), block, 10, 3) + // block, _ := NewSalsa20BlockCrypt(pass) + sess, err := DialWithOptions(fmt.Sprintf("127.0.0.1:%v", port)) if err != nil { panic(err) } @@ -93,20 +93,20 @@ func listenEcho(port int) (net.Listener, error) { // block, _ := NewSimpleXORBlockCrypt(pass) // block, _ := NewTEABlockCrypt(pass[:16]) // block, _ := NewAESBlockCrypt(pass) - block, _ := NewSalsa20BlockCrypt(pass) - return ListenWithOptions(fmt.Sprintf("127.0.0.1:%v", port), block, 10, 0) + // block, _ := NewSalsa20BlockCrypt(pass) + return ListenWithOptions(fmt.Sprintf("127.0.0.1:%v", port)) } func listenTinyBufferEcho(port int) (net.Listener, error) { // block, _ := NewNoneBlockCrypt(pass) // block, _ := NewSimpleXORBlockCrypt(pass) // block, _ := NewTEABlockCrypt(pass[:16]) // block, _ := NewAESBlockCrypt(pass) - block, _ := NewSalsa20BlockCrypt(pass) - return ListenWithOptions(fmt.Sprintf("127.0.0.1:%v", port), block, 10, 3) + // block, _ := NewSalsa20BlockCrypt(pass) + return ListenWithOptions(fmt.Sprintf("127.0.0.1:%v", port)) } func listenSink(port int) (net.Listener, error) { - return ListenWithOptions(fmt.Sprintf("127.0.0.1:%v", port), nil, 0, 0) + return ListenWithOptions(fmt.Sprintf("127.0.0.1:%v", port)) } func echoServer(port int) net.Listener { @@ -541,7 +541,7 @@ func TestSNMP(t *testing.T) { func TestListenerClose(t *testing.T) { port := int(atomic.AddUint32(&baseport, 1)) - l, err := ListenWithOptions(fmt.Sprintf("127.0.0.1:%v", port), nil, 10, 3) + l, err := ListenWithOptions(fmt.Sprintf("127.0.0.1:%v", port)) if err != nil { t.Fail() } @@ -580,7 +580,7 @@ func newClosedFlagPacketConn(c net.PacketConn) *closedFlagPacketConn { // https://github.com/xtaci/kcp-go/issues/165 func TestListenerOwnedPacketConn(t *testing.T) { // ListenWithOptions creates its own net.PacketConn. - l, err := ListenWithOptions("127.0.0.1:0", nil, 0, 0) + l, err := ListenWithOptions("127.0.0.1:0") if err != nil { panic(err) } @@ -616,7 +616,7 @@ func TestListenerNonOwnedPacketConn(t *testing.T) { // Make it remember when it has been closed. pconn := newClosedFlagPacketConn(c) - l, err := ServeConn(nil, 0, 0, pconn) + l, err := ServeConn(pconn) if err != nil { panic(err) } @@ -643,7 +643,7 @@ func TestUDPSessionOwnedPacketConn(t *testing.T) { defer l.Close() // DialWithOptions creates its own net.PacketConn. - client, err := DialWithOptions(l.Addr().String(), nil, 0, 0) + client, err := DialWithOptions(l.Addr().String()) if err != nil { panic(err) } @@ -682,7 +682,7 @@ func TestUDPSessionNonOwnedPacketConn(t *testing.T) { // Make it remember when it has been closed. pconn := newClosedFlagPacketConn(c) - client, err := NewConn2(l.Addr(), nil, 0, 0, pconn) + client, err := NewConn2(l.Addr(), pconn) if err != nil { panic(err) } diff --git a/gate/kcp/snmp.go b/gate/kcp/snmp.go index f9618107..4101b1c8 100644 --- a/gate/kcp/snmp.go +++ b/gate/kcp/snmp.go @@ -7,14 +7,14 @@ import ( // Snmp defines network statistics indicator type Snmp struct { - BytesSent uint64 // bytes sent from upper level - BytesReceived uint64 // bytes received to upper level - MaxConn uint64 // max number of connections ever reached - ActiveOpens uint64 // accumulated active open connections - PassiveOpens uint64 // accumulated passive open connections - CurrEstab uint64 // current number of established connections - InErrs uint64 // UDP read errors reported from net.PacketConn - InCsumErrors uint64 // checksum errors from CRC32 + BytesSent uint64 // bytes sent from upper level + BytesReceived uint64 // bytes received to upper level + MaxConn uint64 // max number of connections ever reached + ActiveOpens uint64 // accumulated active open connections + PassiveOpens uint64 // accumulated passive open connections + CurrEstab uint64 // current number of established connections + // InErrs uint64 // UDP read errors reported from net.PacketConn + // InCsumErrors uint64 // checksum errors from CRC32 KCPInErrors uint64 // packet iput errors reported from KCP InPkts uint64 // incoming packets count OutPkts uint64 // outgoing packets count @@ -27,10 +27,10 @@ type Snmp struct { EarlyRetransSegs uint64 // accmulated early retransmitted segments LostSegs uint64 // number of segs inferred as lost RepeatSegs uint64 // number of segs duplicated - FECRecovered uint64 // correct packets recovered from FEC - FECErrs uint64 // incorrect packets recovered from FEC - FECParityShards uint64 // FEC segments received - FECShortShards uint64 // number of data shards that's not enough for recovery + // FECRecovered uint64 // correct packets recovered from FEC + // FECErrs uint64 // incorrect packets recovered from FEC + // FECParityShards uint64 // FEC segments received + // FECShortShards uint64 // number of data shards that's not enough for recovery } func newSnmp() *Snmp { @@ -46,8 +46,8 @@ func (s *Snmp) Header() []string { "ActiveOpens", "PassiveOpens", "CurrEstab", - "InErrs", - "InCsumErrors", + // "InErrs", + // "InCsumErrors", "KCPInErrors", "InPkts", "OutPkts", @@ -60,10 +60,10 @@ func (s *Snmp) Header() []string { "EarlyRetransSegs", "LostSegs", "RepeatSegs", - "FECParityShards", - "FECErrs", - "FECRecovered", - "FECShortShards", + // "FECParityShards", + // "FECErrs", + // "FECRecovered", + // "FECShortShards", } } @@ -77,8 +77,8 @@ func (s *Snmp) ToSlice() []string { fmt.Sprint(snmp.ActiveOpens), fmt.Sprint(snmp.PassiveOpens), fmt.Sprint(snmp.CurrEstab), - fmt.Sprint(snmp.InErrs), - fmt.Sprint(snmp.InCsumErrors), + // fmt.Sprint(snmp.InErrs), + // fmt.Sprint(snmp.InCsumErrors), fmt.Sprint(snmp.KCPInErrors), fmt.Sprint(snmp.InPkts), fmt.Sprint(snmp.OutPkts), @@ -91,10 +91,10 @@ func (s *Snmp) ToSlice() []string { fmt.Sprint(snmp.EarlyRetransSegs), fmt.Sprint(snmp.LostSegs), fmt.Sprint(snmp.RepeatSegs), - fmt.Sprint(snmp.FECParityShards), - fmt.Sprint(snmp.FECErrs), - fmt.Sprint(snmp.FECRecovered), - fmt.Sprint(snmp.FECShortShards), + // fmt.Sprint(snmp.FECParityShards), + // fmt.Sprint(snmp.FECErrs), + // fmt.Sprint(snmp.FECRecovered), + // fmt.Sprint(snmp.FECShortShards), } } @@ -107,8 +107,8 @@ func (s *Snmp) Copy() *Snmp { d.ActiveOpens = atomic.LoadUint64(&s.ActiveOpens) d.PassiveOpens = atomic.LoadUint64(&s.PassiveOpens) d.CurrEstab = atomic.LoadUint64(&s.CurrEstab) - d.InErrs = atomic.LoadUint64(&s.InErrs) - d.InCsumErrors = atomic.LoadUint64(&s.InCsumErrors) + // d.InErrs = atomic.LoadUint64(&s.InErrs) + // d.InCsumErrors = atomic.LoadUint64(&s.InCsumErrors) d.KCPInErrors = atomic.LoadUint64(&s.KCPInErrors) d.InPkts = atomic.LoadUint64(&s.InPkts) d.OutPkts = atomic.LoadUint64(&s.OutPkts) @@ -121,10 +121,10 @@ func (s *Snmp) Copy() *Snmp { d.EarlyRetransSegs = atomic.LoadUint64(&s.EarlyRetransSegs) d.LostSegs = atomic.LoadUint64(&s.LostSegs) d.RepeatSegs = atomic.LoadUint64(&s.RepeatSegs) - d.FECParityShards = atomic.LoadUint64(&s.FECParityShards) - d.FECErrs = atomic.LoadUint64(&s.FECErrs) - d.FECRecovered = atomic.LoadUint64(&s.FECRecovered) - d.FECShortShards = atomic.LoadUint64(&s.FECShortShards) + // d.FECParityShards = atomic.LoadUint64(&s.FECParityShards) + // d.FECErrs = atomic.LoadUint64(&s.FECErrs) + // d.FECRecovered = atomic.LoadUint64(&s.FECRecovered) + // d.FECShortShards = atomic.LoadUint64(&s.FECShortShards) return d } @@ -136,8 +136,8 @@ func (s *Snmp) Reset() { atomic.StoreUint64(&s.ActiveOpens, 0) atomic.StoreUint64(&s.PassiveOpens, 0) atomic.StoreUint64(&s.CurrEstab, 0) - atomic.StoreUint64(&s.InErrs, 0) - atomic.StoreUint64(&s.InCsumErrors, 0) + // atomic.StoreUint64(&s.InErrs, 0) + // atomic.StoreUint64(&s.InCsumErrors, 0) atomic.StoreUint64(&s.KCPInErrors, 0) atomic.StoreUint64(&s.InPkts, 0) atomic.StoreUint64(&s.OutPkts, 0) @@ -150,10 +150,10 @@ func (s *Snmp) Reset() { atomic.StoreUint64(&s.EarlyRetransSegs, 0) atomic.StoreUint64(&s.LostSegs, 0) atomic.StoreUint64(&s.RepeatSegs, 0) - atomic.StoreUint64(&s.FECParityShards, 0) - atomic.StoreUint64(&s.FECErrs, 0) - atomic.StoreUint64(&s.FECRecovered, 0) - atomic.StoreUint64(&s.FECShortShards, 0) + // atomic.StoreUint64(&s.FECParityShards, 0) + // atomic.StoreUint64(&s.FECErrs, 0) + // atomic.StoreUint64(&s.FECRecovered, 0) + // atomic.StoreUint64(&s.FECShortShards, 0) } // DefaultSnmp is the global KCP connection statistics collector diff --git a/gate/kcp/tx.go b/gate/kcp/tx.go deleted file mode 100644 index e39f6a0a..00000000 --- a/gate/kcp/tx.go +++ /dev/null @@ -1,80 +0,0 @@ -package kcp - -import ( - "net" - "sync/atomic" - - "github.com/pkg/errors" - "golang.org/x/net/ipv4" -) - -func buildEnet(connType uint8, enetType uint32, conv uint64) []byte { - data := make([]byte, 20) - if connType == ConnEnetSyn { - copy(data[0:4], MagicEnetSynHead) - copy(data[16:20], MagicEnetSynTail) - } else if connType == ConnEnetEst { - copy(data[0:4], MagicEnetEstHead) - copy(data[16:20], MagicEnetEstTail) - } else if connType == ConnEnetFin { - copy(data[0:4], MagicEnetFinHead) - copy(data[16:20], MagicEnetFinTail) - } else { - return nil - } - // conv的高四个字节和低四个字节分开 - // 例如 00 00 01 45 | LL LL LL LL | HH HH HH HH | 49 96 02 d2 | 14 51 45 45 - data[4] = uint8(conv >> 24) - data[5] = uint8(conv >> 16) - data[6] = uint8(conv >> 8) - data[7] = uint8(conv >> 0) - data[8] = uint8(conv >> 56) - data[9] = uint8(conv >> 48) - data[10] = uint8(conv >> 40) - data[11] = uint8(conv >> 32) - // Enet - data[12] = uint8(enetType >> 24) - data[13] = uint8(enetType >> 16) - data[14] = uint8(enetType >> 8) - data[15] = uint8(enetType >> 0) - return data -} - -func (l *Listener) defaultSendEnetNotifyToClient(enet *Enet) { - remoteAddr, err := net.ResolveUDPAddr("udp", enet.Addr) - if err != nil { - return - } - data := buildEnet(enet.ConnType, enet.EnetType, enet.ConvId) - if data == nil { - return - } - _, _ = l.conn.WriteTo(data, remoteAddr) -} - -func (s *UDPSession) defaultSendEnetNotify(enet *Enet) { - data := buildEnet(enet.ConnType, enet.EnetType, s.GetConv()) - if data == nil { - return - } - s.defaultTx([]ipv4.Message{{ - Buffers: [][]byte{data}, - Addr: s.remote, - }}) -} - -func (s *UDPSession) defaultTx(txqueue []ipv4.Message) { - nbytes := 0 - npkts := 0 - for k := range txqueue { - if n, err := s.conn.WriteTo(txqueue[k].Buffers[0], txqueue[k].Addr); err == nil { - nbytes += n - npkts++ - } else { - s.notifyWriteError(errors.WithStack(err)) - break - } - } - atomic.AddUint64(&DefaultSnmp.OutPkts, uint64(npkts)) - atomic.AddUint64(&DefaultSnmp.OutBytes, uint64(nbytes)) -} diff --git a/gate/kcp/tx_linux.go b/gate/kcp/tx_linux.go deleted file mode 100644 index 5cb81557..00000000 --- a/gate/kcp/tx_linux.go +++ /dev/null @@ -1,102 +0,0 @@ -//go:build linux -// +build linux - -package kcp - -import ( - "net" - "os" - "sync/atomic" - - "github.com/pkg/errors" - "golang.org/x/net/ipv4" - "golang.org/x/net/ipv6" -) - -func (l *Listener) SendEnetNotifyToClient(enet *Enet) { - var xconn batchConn - _, ok := l.conn.(*net.UDPConn) - if !ok { - return - } - localAddr, err := net.ResolveUDPAddr("udp", l.conn.LocalAddr().String()) - if err != nil { - return - } - if localAddr.IP.To4() != nil { - xconn = ipv4.NewPacketConn(l.conn) - } else { - xconn = ipv6.NewPacketConn(l.conn) - } - - // default version - if xconn == nil { - l.defaultSendEnetNotifyToClient(enet) - return - } - - remoteAddr, err := net.ResolveUDPAddr("udp", enet.Addr) - if err != nil { - return - } - - data := buildEnet(enet.ConnType, enet.EnetType, enet.ConvId) - if data == nil { - return - } - - _, _ = xconn.WriteBatch([]ipv4.Message{{ - Buffers: [][]byte{data}, - Addr: remoteAddr, - }}, 0) -} - -func (s *UDPSession) SendEnetNotify(enet *Enet) { - data := buildEnet(enet.ConnType, enet.EnetType, s.GetConv()) - if data == nil { - return - } - s.tx([]ipv4.Message{{ - Buffers: [][]byte{data}, - Addr: s.remote, - }}) -} - -func (s *UDPSession) tx(txqueue []ipv4.Message) { - // default version - if s.xconn == nil || s.xconnWriteError != nil { - s.defaultTx(txqueue) - return - } - - // x/net version - nbytes := 0 - npkts := 0 - for len(txqueue) > 0 { - if n, err := s.xconn.WriteBatch(txqueue, 0); err == nil { - for k := range txqueue[:n] { - nbytes += len(txqueue[k].Buffers[0]) - } - npkts += n - txqueue = txqueue[n:] - } else { - // compatibility issue: - // for linux kernel<=2.6.32, support for sendmmsg is not available - // an error of type os.SyscallError will be returned - if operr, ok := err.(*net.OpError); ok { - if se, ok := operr.Err.(*os.SyscallError); ok { - if se.Syscall == "sendmmsg" { - s.xconnWriteError = se - s.defaultTx(txqueue) - return - } - } - } - s.notifyWriteError(errors.WithStack(err)) - break - } - } - - atomic.AddUint64(&DefaultSnmp.OutPkts, uint64(npkts)) - atomic.AddUint64(&DefaultSnmp.OutBytes, uint64(nbytes)) -} diff --git a/gate/kcp/readloop.go b/gate/kcp/udp_socket.go similarity index 64% rename from gate/kcp/readloop.go rename to gate/kcp/udp_socket.go index 1fa87832..734b0429 100644 --- a/gate/kcp/readloop.go +++ b/gate/kcp/udp_socket.go @@ -3,8 +3,11 @@ package kcp import ( "bytes" "encoding/binary" + "net" + "sync/atomic" "github.com/pkg/errors" + "golang.org/x/net/ipv4" ) func (s *UDPSession) defaultReadLoop() { @@ -125,3 +128,74 @@ func (l *Listener) defaultMonitor() { } } } + +func buildEnet(connType uint8, enetType uint32, conv uint64) []byte { + data := make([]byte, 20) + if connType == ConnEnetSyn { + copy(data[0:4], MagicEnetSynHead) + copy(data[16:20], MagicEnetSynTail) + } else if connType == ConnEnetEst { + copy(data[0:4], MagicEnetEstHead) + copy(data[16:20], MagicEnetEstTail) + } else if connType == ConnEnetFin { + copy(data[0:4], MagicEnetFinHead) + copy(data[16:20], MagicEnetFinTail) + } else { + return nil + } + // conv的高四个字节和低四个字节分开 + // 例如 00 00 01 45 | LL LL LL LL | HH HH HH HH | 49 96 02 d2 | 14 51 45 45 + data[4] = uint8(conv >> 24) + data[5] = uint8(conv >> 16) + data[6] = uint8(conv >> 8) + data[7] = uint8(conv >> 0) + data[8] = uint8(conv >> 56) + data[9] = uint8(conv >> 48) + data[10] = uint8(conv >> 40) + data[11] = uint8(conv >> 32) + // Enet + data[12] = uint8(enetType >> 24) + data[13] = uint8(enetType >> 16) + data[14] = uint8(enetType >> 8) + data[15] = uint8(enetType >> 0) + return data +} + +func (l *Listener) defaultSendEnetNotifyToClient(enet *Enet) { + remoteAddr, err := net.ResolveUDPAddr("udp", enet.Addr) + if err != nil { + return + } + data := buildEnet(enet.ConnType, enet.EnetType, enet.ConvId) + if data == nil { + return + } + _, _ = l.conn.WriteTo(data, remoteAddr) +} + +func (s *UDPSession) defaultSendEnetNotify(enet *Enet) { + data := buildEnet(enet.ConnType, enet.EnetType, s.GetConv()) + if data == nil { + return + } + s.defaultTx([]ipv4.Message{{ + Buffers: [][]byte{data}, + Addr: s.remote, + }}) +} + +func (s *UDPSession) defaultTx(txqueue []ipv4.Message) { + nbytes := 0 + npkts := 0 + for k := range txqueue { + if n, err := s.conn.WriteTo(txqueue[k].Buffers[0], txqueue[k].Addr); err == nil { + nbytes += n + npkts++ + } else { + s.notifyWriteError(errors.WithStack(err)) + break + } + } + atomic.AddUint64(&DefaultSnmp.OutPkts, uint64(npkts)) + atomic.AddUint64(&DefaultSnmp.OutBytes, uint64(nbytes)) +} diff --git a/gate/kcp/readloop_linux.go b/gate/kcp/udp_socket_linux.go similarity index 73% rename from gate/kcp/readloop_linux.go rename to gate/kcp/udp_socket_linux.go index 1e11d297..32309b52 100644 --- a/gate/kcp/readloop_linux.go +++ b/gate/kcp/udp_socket_linux.go @@ -8,6 +8,7 @@ import ( "encoding/binary" "net" "os" + "sync/atomic" "github.com/pkg/errors" "golang.org/x/net/ipv4" @@ -198,3 +199,91 @@ func (l *Listener) monitor() { } } } + +func (l *Listener) SendEnetNotifyToClient(enet *Enet) { + var xconn batchConn + _, ok := l.conn.(*net.UDPConn) + if !ok { + return + } + localAddr, err := net.ResolveUDPAddr("udp", l.conn.LocalAddr().String()) + if err != nil { + return + } + if localAddr.IP.To4() != nil { + xconn = ipv4.NewPacketConn(l.conn) + } else { + xconn = ipv6.NewPacketConn(l.conn) + } + + // default version + if xconn == nil { + l.defaultSendEnetNotifyToClient(enet) + return + } + + remoteAddr, err := net.ResolveUDPAddr("udp", enet.Addr) + if err != nil { + return + } + + data := buildEnet(enet.ConnType, enet.EnetType, enet.ConvId) + if data == nil { + return + } + + _, _ = xconn.WriteBatch([]ipv4.Message{{ + Buffers: [][]byte{data}, + Addr: remoteAddr, + }}, 0) +} + +func (s *UDPSession) SendEnetNotify(enet *Enet) { + data := buildEnet(enet.ConnType, enet.EnetType, s.GetConv()) + if data == nil { + return + } + s.tx([]ipv4.Message{{ + Buffers: [][]byte{data}, + Addr: s.remote, + }}) +} + +func (s *UDPSession) tx(txqueue []ipv4.Message) { + // default version + if s.xconn == nil || s.xconnWriteError != nil { + s.defaultTx(txqueue) + return + } + + // x/net version + nbytes := 0 + npkts := 0 + for len(txqueue) > 0 { + if n, err := s.xconn.WriteBatch(txqueue, 0); err == nil { + for k := range txqueue[:n] { + nbytes += len(txqueue[k].Buffers[0]) + } + npkts += n + txqueue = txqueue[n:] + } else { + // compatibility issue: + // for linux kernel<=2.6.32, support for sendmmsg is not available + // an error of type os.SyscallError will be returned + if operr, ok := err.(*net.OpError); ok { + if se, ok := operr.Err.(*os.SyscallError); ok { + if se.Syscall == "sendmmsg" { + s.xconnWriteError = se + s.defaultTx(txqueue) + return + } + } + } + s.notifyWriteError(errors.WithStack(err)) + break + } + } + + atomic.AddUint64(&DefaultSnmp.OutPkts, uint64(npkts)) + atomic.AddUint64(&DefaultSnmp.OutBytes, uint64(nbytes)) +} diff --git a/gate/kcp/tx_generic.go b/gate/kcp/udp_socket_windows.go similarity index 75% rename from gate/kcp/tx_generic.go rename to gate/kcp/udp_socket_windows.go index 58f1f73c..aa76bd6d 100644 --- a/gate/kcp/tx_generic.go +++ b/gate/kcp/udp_socket_windows.go @@ -7,6 +7,14 @@ import ( "golang.org/x/net/ipv4" ) +func (s *UDPSession) readLoop() { + s.defaultReadLoop() +} + +func (l *Listener) monitor() { + l.defaultMonitor() +} + func (l *Listener) SendEnetNotifyToClient(enet *Enet) { l.defaultSendEnetNotifyToClient(enet) } diff --git a/gate/net/kcp_connect_manager.go b/gate/net/kcp_connect_manager.go index 960d9aab..e4c612fb 100644 --- a/gate/net/kcp_connect_manager.go +++ b/gate/net/kcp_connect_manager.go @@ -6,6 +6,7 @@ import ( "encoding/binary" "strconv" "sync" + "sync/atomic" "time" "hk4e/common/config" @@ -21,27 +22,36 @@ import ( ) const ( - PacketFreqLimit = 1000 - PacketMaxLen = 343 * 1024 - ConnRecvTimeout = 30 - ConnSendTimeout = 10 + ConnSynPacketFreqLimit = 100 // 连接建立握手包每秒发包频率限制 + RecvPacketFreqLimit = 100 // 客户端上行每秒发包频率限制 + SendPacketFreqLimit = 1000 // 服务器下行每秒发包频率限制 + PacketMaxLen = 343 * 1024 // 最大应用层包长度 + ConnRecvTimeout = 30 // 收包超时时间 秒 + ConnSendTimeout = 10 // 发包超时时间 秒 + MaxClientConnNumLimit = 100 // 最大客户端连接数限制 ) +var CLIENT_CONN_NUM int32 = 0 // 当前客户端连接数 + type KcpConnectManager struct { - discovery *rpc.DiscoveryClient - openState bool + discovery *rpc.DiscoveryClient // node服务器客户端 + openState bool // 网关开放状态 + // 会话 sessionConvIdMap map[uint64]*Session sessionUserIdMap map[uint32]*Session sessionMapLock sync.RWMutex - kcpEventInput chan *KcpEvent - kcpEventOutput chan *KcpEvent - serverCmdProtoMap *cmd.CmdProtoMap - clientCmdProtoMap *client_proto.ClientCmdProtoMap - messageQueue *mq.MessageQueue - localMsgOutput chan *ProtoMsg createSessionChan chan *Session destroySessionChan chan *Session - // 密钥相关 + // 连接事件 + kcpEventInput chan *KcpEvent + kcpEventOutput chan *KcpEvent + // 协议 + serverCmdProtoMap *cmd.CmdProtoMap + clientCmdProtoMap *client_proto.ClientCmdProtoMap + // 输入输出管道 + messageQueue *mq.MessageQueue + localMsgOutput chan *ProtoMsg + // 密钥 dispatchKey []byte signRsaKey []byte encRsaKeyMap map[string][]byte @@ -53,6 +63,8 @@ func NewKcpConnectManager(messageQueue *mq.MessageQueue, discovery *rpc.Discover r.openState = true r.sessionConvIdMap = make(map[uint64]*Session) r.sessionUserIdMap = make(map[uint32]*Session) + r.createSessionChan = make(chan *Session, 1000) + r.destroySessionChan = make(chan *Session, 1000) r.kcpEventInput = make(chan *KcpEvent, 1000) r.kcpEventOutput = make(chan *KcpEvent, 1000) r.serverCmdProtoMap = cmd.NewCmdProtoMap() @@ -61,8 +73,6 @@ func NewKcpConnectManager(messageQueue *mq.MessageQueue, discovery *rpc.Discover } r.messageQueue = messageQueue r.localMsgOutput = make(chan *ProtoMsg, 1000) - r.createSessionChan = make(chan *Session, 1000) - r.destroySessionChan = make(chan *Session, 1000) r.run() return r } @@ -86,7 +96,7 @@ func (k *KcpConnectManager) run() { k.dispatchKey = regionEc2b.XorKey() // kcp port := strconv.Itoa(int(config.CONF.Hk4e.KcpPort)) - listener, err := kcp.ListenWithOptions("0.0.0.0:"+port, nil, 0, 0) + listener, err := kcp.ListenWithOptions("0.0.0.0:" + port) if err != nil { logger.Error("listen kcp err: %v", err) return @@ -95,29 +105,56 @@ func (k *KcpConnectManager) run() { go k.eventHandle() go k.sendMsgHandle() go k.acceptHandle(listener) + go k.gateNetInfo() } func (k *KcpConnectManager) Close() { k.closeAllKcpConn() + // 等待所有连接关闭时需要发送的消息发送完毕 time.Sleep(time.Second * 3) } +func (k *KcpConnectManager) gateNetInfo() { + ticker := time.NewTicker(time.Second * 60) + kcpErrorCount := uint64(0) + for { + <-ticker.C + snmp := kcp.DefaultSnmp.Copy() + kcpErrorCount += snmp.KCPInErrors + logger.Info("kcp send: %v B/s, kcp recv: %v B/s", snmp.BytesSent/60, snmp.BytesReceived/60) + logger.Info("udp send: %v B/s, udp recv: %v B/s", snmp.OutBytes/60, snmp.InBytes/60) + logger.Info("udp send: %v pps, udp recv: %v pps", snmp.OutPkts/60, snmp.InPkts/60) + clientConnNum := atomic.LoadInt32(&CLIENT_CONN_NUM) + logger.Info("conn num: %v, new conn num: %v, kcp error num: %v", clientConnNum, snmp.CurrEstab, kcpErrorCount) + kcp.DefaultSnmp.Reset() + } +} + +// 接收并创建新连接处理函数 func (k *KcpConnectManager) acceptHandle(listener *kcp.Listener) { - logger.Debug("accept handle start") + logger.Info("accept handle start") for { conn, err := listener.AcceptKCP() if err != nil { logger.Error("accept kcp err: %v", err) return } + convId := conn.GetConv() if k.openState == false { + logger.Error("gate not open, convId: %v", convId) _ = conn.Close() continue } + clientConnNum := atomic.AddInt32(&CLIENT_CONN_NUM, 1) + if clientConnNum > MaxClientConnNumLimit { + logger.Error("gate conn num limit, convId: %v", convId) + _ = conn.Close() + atomic.AddInt32(&CLIENT_CONN_NUM, -1) + continue + } conn.SetACKNoDelay(true) conn.SetWriteDelay(false) - convId := conn.GetConv() - logger.Debug("client connect, convId: %v", convId) + logger.Info("client connect, convId: %v", convId) kcpRawSendChan := make(chan *ProtoMsg, 1000) session := &Session{ conn: conn, @@ -145,61 +182,76 @@ func (k *KcpConnectManager) acceptHandle(listener *kcp.Listener) { } } +// 连接事件处理函数 func (k *KcpConnectManager) enetHandle(listener *kcp.Listener) { - logger.Debug("enet handle start") + logger.Info("enet handle start") // conv短时间内唯一生成 convGenMap := make(map[uint64]int64) + pktFreqLimitCounter := 0 + pktFreqLimitTimer := time.Now().UnixNano() for { enetNotify := <-listener.EnetNotify logger.Info("[Enet Notify], addr: %v, conv: %v, conn: %v, enet: %v", enetNotify.Addr, enetNotify.ConvId, enetNotify.ConnType, enetNotify.EnetType) switch enetNotify.ConnType { case kcp.ConnEnetSyn: - if enetNotify.EnetType == kcp.EnetClientConnectKey { - // 清理老旧的conv + // 连接建立握手包频率限制 + pktFreqLimitCounter++ + if pktFreqLimitCounter > ConnSynPacketFreqLimit { now := time.Now().UnixNano() - oldConvList := make([]uint64, 0) - for conv, timestamp := range convGenMap { - if now-timestamp > int64(time.Hour) { - oldConvList = append(oldConvList, conv) - } + if now-pktFreqLimitTimer > int64(time.Second) { + pktFreqLimitCounter = 0 + pktFreqLimitTimer = now + } else { + continue } - delConvList := make([]uint64, 0) - k.sessionMapLock.RLock() - for _, conv := range oldConvList { - _, exist := k.sessionConvIdMap[conv] - if !exist { - delConvList = append(delConvList, conv) - delete(convGenMap, conv) - } - } - k.sessionMapLock.RUnlock() - logger.Info("clean dead conv list: %v", delConvList) - // 生成没用过的conv - var conv uint64 - for { - convData := random.GetRandomByte(8) - convDataBuffer := bytes.NewBuffer(convData) - _ = binary.Read(convDataBuffer, binary.LittleEndian, &conv) - _, exist := convGenMap[conv] - if exist { - continue - } else { - convGenMap[conv] = time.Now().UnixNano() - break - } - } - listener.SendEnetNotifyToClient(&kcp.Enet{ - Addr: enetNotify.Addr, - ConvId: conv, - ConnType: kcp.ConnEnetEst, - EnetType: enetNotify.EnetType, - }) } + if enetNotify.EnetType != kcp.EnetClientConnectKey { + continue + } + // 清理老旧的conv + now := time.Now().UnixNano() + oldConvList := make([]uint64, 0) + for conv, timestamp := range convGenMap { + if now-timestamp > int64(time.Hour) { + oldConvList = append(oldConvList, conv) + } + } + delConvList := make([]uint64, 0) + k.sessionMapLock.RLock() + for _, conv := range oldConvList { + _, exist := k.sessionConvIdMap[conv] + if !exist { + delConvList = append(delConvList, conv) + delete(convGenMap, conv) + } + } + k.sessionMapLock.RUnlock() + logger.Info("clean dead conv list: %v", delConvList) + // 生成没用过的conv + var conv uint64 + for { + convData := random.GetRandomByte(8) + convDataBuffer := bytes.NewBuffer(convData) + _ = binary.Read(convDataBuffer, binary.LittleEndian, &conv) + _, exist := convGenMap[conv] + if exist { + continue + } else { + convGenMap[conv] = time.Now().UnixNano() + break + } + } + listener.SendEnetNotifyToClient(&kcp.Enet{ + Addr: enetNotify.Addr, + ConvId: conv, + ConnType: kcp.ConnEnetEst, + EnetType: enetNotify.EnetType, + }) case kcp.ConnEnetEst: case kcp.ConnEnetFin: session := k.GetSessionByConvId(enetNotify.ConvId) if session == nil { - logger.Error("session not exist, convId: %v", enetNotify.ConvId) + logger.Error("session not exist, conv: %v", enetNotify.ConvId) continue } session.conn.SendEnetNotify(&kcp.Enet{ @@ -219,6 +271,7 @@ func (k *KcpConnectManager) enetHandle(listener *kcp.Listener) { } } +// Session 连接会话结构 type Session struct { conn *kcp.UDPSession connState uint8 @@ -237,13 +290,13 @@ type Session struct { // 接收 func (k *KcpConnectManager) recvHandle(session *Session) { - logger.Debug("recv handle start") + logger.Info("recv handle start") conn := session.conn convId := conn.GetConv() - pktFreqLimitCounter := 0 - pktFreqLimitTimer := time.Now().UnixNano() recvBuf := make([]byte, PacketMaxLen) dataBuf := make([]byte, 0, 1500) + pktFreqLimitCounter := 0 + pktFreqLimitTimer := time.Now().UnixNano() for { _ = conn.SetReadDeadline(time.Now().Add(time.Second * ConnRecvTimeout)) recvLen, err := conn.Read(recvBuf) @@ -254,16 +307,16 @@ func (k *KcpConnectManager) recvHandle(session *Session) { } // 收包频率限制 pktFreqLimitCounter++ - now := time.Now().UnixNano() - if now-pktFreqLimitTimer > int64(time.Second) { - if pktFreqLimitCounter > PacketFreqLimit { + if pktFreqLimitCounter > RecvPacketFreqLimit { + now := time.Now().UnixNano() + if now-pktFreqLimitTimer > int64(time.Second) { + pktFreqLimitCounter = 0 + pktFreqLimitTimer = now + } else { logger.Error("exit recv loop, client packet send freq too high, convId: %v, pps: %v", convId, pktFreqLimitCounter) k.closeKcpConn(session, kcp.EnetPacketFreqTooHigh) break - } else { - pktFreqLimitCounter = 0 } - pktFreqLimitTimer = now } recvData := recvBuf[:recvLen] kcpMsgList := make([]*KcpMsg, 0) @@ -279,9 +332,11 @@ func (k *KcpConnectManager) recvHandle(session *Session) { // 发送 func (k *KcpConnectManager) sendHandle(session *Session) { - logger.Debug("send handle start") + logger.Info("send handle start") conn := session.conn convId := conn.GetConv() + pktFreqLimitCounter := 0 + pktFreqLimitTimer := time.Now().UnixNano() for { protoMsg, ok := <-session.kcpRawSendChan if !ok { @@ -302,9 +357,22 @@ func (k *KcpConnectManager) sendHandle(session *Session) { k.closeKcpConn(session, kcp.EnetServerKick) break } + // 发包频率限制 + pktFreqLimitCounter++ + if pktFreqLimitCounter > SendPacketFreqLimit { + now := time.Now().UnixNano() + if now-pktFreqLimitTimer > int64(time.Second) { + pktFreqLimitCounter = 0 + pktFreqLimitTimer = now + } else { + logger.Error("exit send loop, server packet send freq too high, convId: %v, pps: %v", convId, pktFreqLimitCounter) + k.closeKcpConn(session, kcp.EnetPacketFreqTooHigh) + break + } + } if session.changeXorKeyFin == false && protoMsg.CmdId == cmd.GetPlayerTokenRsp { // XOR密钥切换 - logger.Debug("change session xor key, convId: %v", convId) + logger.Info("change session xor key, convId: %v", convId) session.changeXorKeyFin = true keyBlock := random.NewKeyBlock(session.seed, session.useMagicSeed) xorKey := keyBlock.XorKey() @@ -315,8 +383,8 @@ func (k *KcpConnectManager) sendHandle(session *Session) { } } +// 强制关闭指定连接 func (k *KcpConnectManager) forceCloseKcpConn(convId uint64, reason uint32) { - // 强制关闭某个连接 session := k.GetSessionByConvId(convId) if session == nil { logger.Error("session not exist, convId: %v", convId) @@ -326,6 +394,7 @@ func (k *KcpConnectManager) forceCloseKcpConn(convId uint64, reason uint32) { logger.Info("conn has been force close, convId: %v", convId) } +// 关闭指定连接 func (k *KcpConnectManager) closeKcpConn(session *Session, enetType uint32) { if session.connState == ConnClose { return @@ -358,8 +427,10 @@ func (k *KcpConnectManager) closeKcpConn(session *Session, enetType uint32) { }) logger.Info("send to gs user offline, ConvId: %v, UserId: %v", convId, connCtrlMsg.UserId) k.destroySessionChan <- session + atomic.AddInt32(&CLIENT_CONN_NUM, -1) } +// 关闭所有连接 func (k *KcpConnectManager) closeAllKcpConn() { sessionList := make([]*Session, 0) k.sessionMapLock.RLock() diff --git a/go.mod b/go.mod index 77de7909..b2a4c2a1 100644 --- a/go.mod +++ b/go.mod @@ -7,11 +7,7 @@ require github.com/BurntSushi/toml v0.3.1 // kcp require ( - github.com/klauspost/reedsolomon v1.9.14 github.com/pkg/errors v0.9.1 - github.com/stretchr/testify v1.7.0 - github.com/templexxx/xorsimd v0.4.1 - github.com/tjfoc/gmsm v1.4.1 github.com/xtaci/lossyconn v0.0.0-20200209145036-adba10fffc37 golang.org/x/crypto v0.0.0-20220926161630-eccd6366d1be golang.org/x/net v0.0.0-20211123203042-d83791d6bcd9 @@ -59,9 +55,13 @@ require gitlab.com/gomidi/midi/v2 v2.0.25 // lua require github.com/yuin/gopher-lua v1.0.0 +// lz4 +require github.com/pierrec/lz4/v4 v4.1.17 + +require golang.org/x/sys v0.0.0-20220928140112-f11e5e49a4ec + require ( github.com/cespare/xxhash/v2 v2.1.2 // indirect - github.com/davecgh/go-spew v1.1.1 // indirect github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect github.com/gin-contrib/sse v0.1.0 // indirect github.com/go-playground/locales v0.13.0 // indirect @@ -74,7 +74,6 @@ require ( github.com/inconshreveable/mousetrap v1.0.1 // indirect github.com/json-iterator/go v1.1.9 // indirect github.com/klauspost/compress v1.15.11 // indirect - github.com/klauspost/cpuid/v2 v2.0.6 // indirect github.com/leodido/go-urn v1.2.0 // indirect github.com/mattn/go-isatty v0.0.12 // indirect github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421 // indirect @@ -82,9 +81,7 @@ require ( github.com/nats-io/nats-server/v2 v2.9.7 // indirect github.com/nats-io/nkeys v0.3.0 // indirect github.com/nats-io/nuid v1.0.1 // indirect - github.com/pmezard/go-difflib v1.0.0 // indirect github.com/spf13/pflag v1.0.5 // indirect - github.com/templexxx/cpu v0.0.1 // indirect github.com/ugorji/go/codec v1.1.7 // indirect github.com/vmihailenco/tagparser/v2 v2.0.0 // indirect github.com/xdg-go/pbkdf2 v1.0.0 // indirect @@ -92,8 +89,6 @@ require ( github.com/xdg-go/stringprep v1.0.2 // indirect github.com/youmark/pkcs8 v0.0.0-20181117223130-1be2e3e5546d // indirect golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e // indirect - golang.org/x/sys v0.0.0-20220928140112-f11e5e49a4ec // indirect golang.org/x/text v0.3.6 // indirect gopkg.in/yaml.v2 v2.4.0 // indirect - gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/go.sum b/go.sum index f552564a..cc188245 100644 --- a/go.sum +++ b/go.sum @@ -1,24 +1,17 @@ -cloud.google.com/go v0.26.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw= github.com/BurntSushi/toml v0.3.1 h1:WXkYYl6Yr3qBf1K79EBnL4mak0OimBfB0XUf9Vl28OQ= github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU= github.com/arl/statsviz v0.5.1 h1:3HY0ZEB738JtguWsD1Tf1pFJZiCcWUmYRq/3OTYKaSI= github.com/arl/statsviz v0.5.1/go.mod h1:zDnjgRblGm1Dyd7J5YlbH7gM1/+HRC+SfkhZhQb5AnM= github.com/byebyebruce/natsrpc v0.5.5-0.20221125150611-56cd29a4e335 h1:V5qahA5kDL/TBnlwvYjemR5du/uQ7q75qkBBlTc4rXI= github.com/byebyebruce/natsrpc v0.5.5-0.20221125150611-56cd29a4e335/go.mod h1:w61gLVOQWr/Tq/1wxSOMLxDPbH66rEo8jEHMh7j3qjo= -github.com/census-instrumentation/opencensus-proto v0.2.1/go.mod h1:f6KPmirojxKA12rnyqOA5BBL4O983OfeGPqjHWSTneU= github.com/cespare/xxhash/v2 v2.1.2 h1:YRXhKfTDauu4ajMg1TPgFO5jnlC2HCbmLXMcTG5cbYE= github.com/cespare/xxhash/v2 v2.1.2/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= -github.com/client9/misspell v0.3.4/go.mod h1:qj6jICC3Q7zFZvVWo7KLAzC3yx5G7kyvSDkc90ppPyw= -github.com/cncf/udpa/go v0.0.0-20191209042840-269d4d468f6f/go.mod h1:M8M6+tZqaGXZJjfX53e64911xZQV5JYwmTeXPW+k8Sc= github.com/cpuguy83/go-md2man/v2 v2.0.2/go.mod h1:tgQtvFlXSQOSOSIRvRPT7W67SCa46tRHOmNcaadrF8o= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/rVNCu3HqELle0jiPLLBs70cWOduZpkS1E78= github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc= -github.com/envoyproxy/go-control-plane v0.9.0/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymFceY/DCBVvsKhRF0iEA4= -github.com/envoyproxy/go-control-plane v0.9.4/go.mod h1:6rpuAdCZL397s3pYoYcLgu1mIlRU8Am5FuJP05cCM98= -github.com/envoyproxy/protoc-gen-validate v0.1.0/go.mod h1:iSmxcyjqTsJpI2R4NaDN7+kN2VEUnK/pcBlmesArF7c= github.com/fsnotify/fsnotify v1.4.9 h1:hsms1Qyu0jgnwNXIxa+/V/PDsU6CfLf6CNO8H7IWoS4= github.com/gin-contrib/sse v0.1.0 h1:Y/yl/+YNO8GZSjAhjMsSuLt29uWRFHdHYUb5lYOV9qE= github.com/gin-contrib/sse v0.1.0/go.mod h1:RHrZQHXnP2xjPF+u1gW/2HnVO7nvIa9PG3Gm+fLHvGI= @@ -38,26 +31,12 @@ github.com/go-stack/stack v1.8.0 h1:5SgMzNM5HxrEjV0ww2lTmX6E2Izsfxas4+YHWRs3Lsk= github.com/go-stack/stack v1.8.0/go.mod h1:v0f6uXyyMGvRgIKkXu+yp6POWl0qKG85gN/melR3HDY= github.com/golang-jwt/jwt/v4 v4.4.0 h1:EmVIxB5jzbllGIjiCV5JG4VylbK3KE400tLGLI1cdfU= github.com/golang-jwt/jwt/v4 v4.4.0/go.mod h1:/xlHOz8bRuivTWchD4jCa+NbatV+wEUSzwAxVc6locg= -github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfUlaOkMMuAQtPIUF8ecWP5IEl/CR7VP2Q= -github.com/golang/mock v1.1.1/go.mod h1:oTYuIxOrZwtPieC+H1uAHpcLFnEyAGVDL/k47Jfbm0A= -github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= -github.com/golang/protobuf v1.3.2/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= github.com/golang/protobuf v1.3.3/go.mod h1:vzj43D7+SQXF/4pzW/hwtAqwc6iTitCiVSaWz5lYuqw= -github.com/golang/protobuf v1.4.0-rc.1/go.mod h1:ceaxUfeHdC40wWswd/P6IGgMaK3YpKi5j83Wpe3EHw8= -github.com/golang/protobuf v1.4.0-rc.1.0.20200221234624-67d41d38c208/go.mod h1:xKAWHe0F5eneWXFV3EuXVDTCmh+JuBKY0li0aMyXATA= -github.com/golang/protobuf v1.4.0-rc.2/go.mod h1:LlEzMj4AhA7rCAGe4KMBDvJI+AwstrUpVNzEA03Pprs= -github.com/golang/protobuf v1.4.0-rc.4.0.20200313231945-b860323f09d0/go.mod h1:WU3c8KckQ9AFe+yFwt9sWVRKCVIyN9cPHBJSNnbL67w= -github.com/golang/protobuf v1.4.0/go.mod h1:jodUvKwWbYaEsadDk5Fwe5c77LiNKVO9IDvqG2KuDX0= -github.com/golang/protobuf v1.4.2/go.mod h1:oDoupMAO8OvCJWAcko0GGGIgR6R6ocIYbsSw735rRwI= github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk= github.com/golang/protobuf v1.5.2 h1:ROPKBNFfQgOUMifHyP+KYbvpjbdoFNs+aK7DXlji0Tw= github.com/golang/protobuf v1.5.2/go.mod h1:XVQd3VNwM+JqD3oG2Ue2ip4fOMUkwXdXDdiuN0vRsmY= github.com/golang/snappy v0.0.1 h1:Qgr9rKW7uDUkrbSmQeiDsGa8SjGyCOGtuasMWwvp2P4= github.com/golang/snappy v0.0.1/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q= -github.com/google/go-cmp v0.2.0/go.mod h1:oXzfMopK8JAjlY9xF4vHSVASa0yLyX7SntLO5aqRK0M= -github.com/google/go-cmp v0.3.0/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU= -github.com/google/go-cmp v0.3.1/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU= -github.com/google/go-cmp v0.4.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/go-cmp v0.5.2/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/go-cmp v0.5.5 h1:Khx7svrCpmxxtHBq5j2mp/xVjsi8hQMfNLvJFAlrGgU= github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= @@ -75,10 +54,6 @@ github.com/jszwec/csvutil v1.7.1/go.mod h1:Rpu7Uu9giO9subDyMCIQfHVDuLrcaC36UA4Yc github.com/klauspost/compress v1.13.6/go.mod h1:/3/Vjq9QcHkK5uEr5lBEmyoZ1iFhe47etQ6QUkpK6sk= github.com/klauspost/compress v1.15.11 h1:Lcadnb3RKGin4FYM/orgq0qde+nc15E5Cbqg4B9Sx9c= github.com/klauspost/compress v1.15.11/go.mod h1:QPwzmACJjUTFsnSHH934V6woptycfrDDJnH7hvFVbGM= -github.com/klauspost/cpuid/v2 v2.0.6 h1:dQ5ueTiftKxp0gyjKSx5+8BtPWkyQbd95m8Gys/RarI= -github.com/klauspost/cpuid/v2 v2.0.6/go.mod h1:FInQzS24/EEf25PyTYn52gqo7WaD8xa0213Md/qVLRg= -github.com/klauspost/reedsolomon v1.9.14 h1:vkPCIhFMn2VdktLUcugqsU4vcLXN3dAhVd1uWA+TDD8= -github.com/klauspost/reedsolomon v1.9.14/go.mod h1:eqPAcE7xar5CIzcdfwydOEdcmchAKAP/qs14y4GCBOk= github.com/kr/pretty v0.1.0 h1:L/CwN0zerZDmRFUapSPitk6f+Q3+0za1rQkzVuMiMFI= github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= @@ -106,11 +81,12 @@ github.com/nats-io/nuid v1.0.1/go.mod h1:19wcPz3Ph3q0Jbyiqsd0kePYG7A95tJPxeL+1OS github.com/nxadm/tail v1.4.8 h1:nPr65rt6Y5JFSKQO7qToXr7pePgD6Gwiw05lkbyAQTE= github.com/onsi/ginkgo v1.16.5 h1:8xi0RTUf59SOSfEtZMvwTvXYMzG4gV23XVHOZiXNtnE= github.com/onsi/gomega v1.18.1 h1:M1GfJqGRrBrrGGsbxzV5dqM2U2ApXefZCQpkukxYRLE= +github.com/pierrec/lz4/v4 v4.1.17 h1:kV4Ip+/hUBC+8T6+2EgburRtkE9ef4nbY3f4dFhGjMc= +github.com/pierrec/lz4/v4 v4.1.17/go.mod h1:gZWDp/Ze/IJXGXf23ltt2EXimqmTUXEy0GFuRQyBid4= github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= -github.com/prometheus/client_model v0.0.0-20190812154241-14fe0d1b01d4/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA= github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM= github.com/spf13/cobra v1.6.1 h1:o94oiPyS4KD1mPy2fmcYYHHfCxLqYjJOhGsCHFZtEzA= github.com/spf13/cobra v1.6.1/go.mod h1:IOw/AERYS7UzyrGinqmz6HLUo219MORXGxhbaJUqzrY= @@ -121,15 +97,8 @@ github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UV github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.7.0 h1:nwc3DEeHmmLAfoZucVR881uASk0Mfjw8xYJ99tb5CcY= -github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= -github.com/templexxx/cpu v0.0.1 h1:hY4WdLOgKdc8y13EYklu9OUTXik80BkxHoWvTO6MQQY= -github.com/templexxx/cpu v0.0.1/go.mod h1:w7Tb+7qgcAlIyX4NhLuDKt78AHA5SzPmq0Wj6HiEnnk= -github.com/templexxx/xorsimd v0.4.1 h1:iUZcywbOYDRAZUasAs2eSCUW8eobuZDy0I9FJiORkVg= -github.com/templexxx/xorsimd v0.4.1/go.mod h1:W+ffZz8jJMH2SXwuKu9WhygqBMbFnp14G2fqEr8qaNo= github.com/tidwall/pretty v1.0.0 h1:HsD+QiTn7sK6flMKIvNmpqz1qrpP3Ps6jOKIKMooyg4= github.com/tidwall/pretty v1.0.0/go.mod h1:XNkn88O1ChpSDQmQeStsy+sBenx6DDtFZJxhVysOjyk= -github.com/tjfoc/gmsm v1.4.1 h1:aMe1GlZb+0bLjn+cKTPEvvn9oUEBlJitaZiiBwsbgho= -github.com/tjfoc/gmsm v1.4.1/go.mod h1:j4INPkHWMrhJb38G+J6W4Tw0AbuN8Thu3PbdVYhVcTE= github.com/ugorji/go v1.1.7/go.mod h1:kZn38zHttfInRq0xu/PH0az30d+z6vm202qpg1oXVMw= github.com/ugorji/go/codec v1.1.7 h1:2SvQaVZ1ouYrrKKwoSk2pzd4A9evlKJb9oTL+OaLUSs= github.com/ugorji/go/codec v1.1.7/go.mod h1:Ax+UKWsSmolVDwsd+7N3ZtXu+yMGCf907BLYF3GoBXY= @@ -154,37 +123,21 @@ gitlab.com/gomidi/midi/v2 v2.0.25/go.mod h1:quTyMKSQ4Klevxu6gY4gy2USbeZra0fV5Sal go.mongodb.org/mongo-driver v1.8.3 h1:TDKlTkGDKm9kkJVUOAXDK5/fkqKHJVwYQSpoRfB43R4= go.mongodb.org/mongo-driver v1.8.3/go.mod h1:0sQWfOeY63QTntERDJJ/0SuKK0T1uVSgKCuAROlKEPY= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= -golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= -golang.org/x/crypto v0.0.0-20201012173705-84dcc777aaee/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= golang.org/x/crypto v0.0.0-20201216223049-8b5274cf687f/go.mod h1:jdWPYTVW3xRLrWPugEBEK3UY2ZEsg3UU495nc5E+M+I= golang.org/x/crypto v0.0.0-20210314154223-e6e6c4f2bb5b/go.mod h1:T9bdIzuCu7OtxOm1hfPfRQxPLYneinmdGuTeoZ9dtd4= golang.org/x/crypto v0.0.0-20220926161630-eccd6366d1be h1:fmw3UbQh+nxngCAHrDCCztao/kbYFnWjoqop8dHx05A= golang.org/x/crypto v0.0.0-20220926161630-eccd6366d1be/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4= -golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= -golang.org/x/lint v0.0.0-20181026193005-c67002cb31c3/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE= -golang.org/x/lint v0.0.0-20190227174305-5b3e6a55c961/go.mod h1:wehouNa3lNwaWXcvxsM5YxQ5yQlVC4a0KAMCusXpPoU= -golang.org/x/lint v0.0.0-20190313153728-d0100b6bd8b3/go.mod h1:6SW0HCj/g11FgYtHlgUYUwCkIfeOF89ocIRzGO/8vkc= -golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= -golang.org/x/net v0.0.0-20180826012351-8a410e7b638d/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= -golang.org/x/net v0.0.0-20190213061140-3a22650c66bd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= -golang.org/x/net v0.0.0-20201010224723-4f7140c49acb/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU= golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= golang.org/x/net v0.0.0-20211123203042-d83791d6bcd9 h1:0qxwC5n+ttVOINCBeRHO0nq9X7uy8SDsPoi5OaCdIEI= golang.org/x/net v0.0.0-20211123203042-d83791d6bcd9/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= -golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= -golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e h1:vcxGaoTs7kV8m5Np9uUNQin4BrLOthgV7252N8V+FwY= golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= -golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20191026070338-33540a1f6037/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200116001909-b77594299b42/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20220928140112-f11e5e49a4ec h1:BkDtF2Ih9xZ7le9ndzTA7KJow28VbQW3odyk/8drmuI= golang.org/x/sys v0.0.0-20220928140112-f11e5e49a4ec/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= @@ -198,27 +151,9 @@ golang.org/x/text v0.3.6 h1:aRYxNxv6iGQlyVaZmk6ZgYEDa+Jg18DxebPSrd6bg1M= golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/time v0.0.0-20220922220347-f3bd1da661af h1:Yx9k8YCG3dvF87UAn2tu2HQLf2dt/eR1bXxpLMWeH+Y= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= -golang.org/x/tools v0.0.0-20190114222345-bf090417da8b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= -golang.org/x/tools v0.0.0-20190226205152-f727befe758c/go.mod h1:9Yl7xja0Znq3iFh3HoIrodX9oNMXvdceNzlUR8zjMvY= -golang.org/x/tools v0.0.0-20190311212946-11955173bddd/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs= -golang.org/x/tools v0.0.0-20190524140312-2c0ae7006135/go.mod h1:RgjU9mgBXZiqYHBnxXauZ1Gv1EHHAz9KjViQ78xBX0Q= golang.org/x/tools v0.0.0-20190531172133-b3315ee88b7d/go.mod h1:/rFqwRUd4F7ZHNgwSSTFct+R/Kf4OFW1sUzUTQQTgfc= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1 h1:go1bK/D/BFZV2I8cIQd1NKEZ+0owSTG1fDTci4IqFcE= -google.golang.org/appengine v1.1.0/go.mod h1:EbEs0AVv82hx2wNQdGPgUI5lhzA/G0D9YwlJXL52JkM= -google.golang.org/appengine v1.4.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4= -google.golang.org/genproto v0.0.0-20180817151627-c66870c02cf8/go.mod h1:JiN7NxoALGmiZfu7CAH4rXhgtRTLTxftemlI0sWmxmc= -google.golang.org/genproto v0.0.0-20190819201941-24fa4b261c55/go.mod h1:DMBHOl98Agz4BDEuKkezgsaosCRResVns1a3J2ZsMNc= -google.golang.org/grpc v1.19.0/go.mod h1:mqu4LbDTu4XGKhr4mRzUsmM4RtVoemTSY81AxZiDr8c= -google.golang.org/grpc v1.23.0/go.mod h1:Y5yQAOtifL1yxbo5wqy6BxZv8vAUGQwXBOALyacEbxg= -google.golang.org/grpc v1.25.1/go.mod h1:c3i+UQWmh7LiEpx4sFZnkU36qjEYZ0imhYfXVyQciAY= -google.golang.org/grpc v1.31.0/go.mod h1:N36X2cJ7JwdamYAgDz+s+rVMFjt3numwzf/HckM8pak= -google.golang.org/protobuf v0.0.0-20200109180630-ec00e32a8dfd/go.mod h1:DFci5gLYBciE7Vtevhsrf46CRTquxDuWsQurQQe4oz8= -google.golang.org/protobuf v0.0.0-20200221191635-4d8936d0db64/go.mod h1:kwYJMbMJ01Woi6D6+Kah6886xMZcty6N08ah7+eCXa0= -google.golang.org/protobuf v0.0.0-20200228230310-ab0ca4ff8a60/go.mod h1:cfTl7dwQJ+fmap5saPgwCLgHXTUD7jkjRqWcaiX5VyM= -google.golang.org/protobuf v1.20.1-0.20200309200217-e05f789c0967/go.mod h1:A+miEFZTKqfCUM6K7xSMQL9OKL/b6hQv+e19PK+JZNE= -google.golang.org/protobuf v1.21.0/go.mod h1:47Nbq4nVaFHyn7ilMalzfO3qCViNmqZ2kzikPIcrTAo= -google.golang.org/protobuf v1.23.0/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2l/sGQquU= google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw= google.golang.org/protobuf v1.26.0/go.mod h1:9q0QmTI4eRPtz6boOQmLYwt+qCgq0jsYwAQnmE0givc= google.golang.org/protobuf v1.28.0 h1:w43yiav+6bVFTBQFZX0r7ipe9JQ1QsbMgHwbBziscLw= @@ -234,5 +169,3 @@ gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= -honnef.co/go/tools v0.0.0-20190102054323-c2f93a96b099/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= -honnef.co/go/tools v0.0.0-20190523083050-ea95bdfd59fc/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= diff --git a/gs/app/app.go b/gs/app/app.go index ea4e9ce2..5fc342bc 100644 --- a/gs/app/app.go +++ b/gs/app/app.go @@ -5,6 +5,7 @@ import ( _ "net/http/pprof" "os" "os/signal" + "sync/atomic" "syscall" "time" @@ -49,6 +50,7 @@ func Run(ctx context.Context, configFile string) error { _, err := client.Discovery.KeepaliveServer(context.TODO(), &api.KeepaliveServerReq{ ServerType: api.GS, AppId: APPID, + LoadCount: uint32(atomic.LoadInt32(&game.ONLINE_PLAYER_NUM)), }) if err != nil { logger.Error("keepalive error: %v", err) diff --git a/gs/dao/player_redis.go b/gs/dao/player_redis.go index 966d0def..4660a6e9 100644 --- a/gs/dao/player_redis.go +++ b/gs/dao/player_redis.go @@ -1,13 +1,16 @@ package dao import ( + "bytes" "context" + "io" "strconv" "time" "hk4e/gs/model" "hk4e/pkg/logger" + "github.com/pierrec/lz4/v4" "github.com/vmihailenco/msgpack/v5" ) @@ -18,13 +21,28 @@ func (d *Dao) GetRedisPlayerKey(userId uint32) string { } func (d *Dao) GetRedisPlayer(userId uint32) *model.Player { - playerData, err := d.redis.Get(context.TODO(), d.GetRedisPlayerKey(userId)).Result() + playerDataLz4, err := d.redis.Get(context.TODO(), d.GetRedisPlayerKey(userId)).Result() if err != nil { logger.Error("get player from redis error: %v", err) return nil } + // 解压 + startTime := time.Now().UnixNano() + in := bytes.NewReader([]byte(playerDataLz4)) + out := new(bytes.Buffer) + lz4Reader := lz4.NewReader(in) + _, err = io.Copy(out, lz4Reader) + if err != nil { + logger.Error("lz4 decode player data error: %v", err) + return nil + } + playerData := out.Bytes() + endTime := time.Now().UnixNano() + costTime := endTime - startTime + logger.Debug("lz4 decode cost time: %v ns, before len: %v, after len: %v, ratio lz4/raw: %v", + costTime, len(playerDataLz4), len(playerData), float64(len(playerDataLz4))/float64(len(playerData))) player := new(model.Player) - err = msgpack.Unmarshal([]byte(playerData), player) + err = msgpack.Unmarshal(playerData, player) if err != nil { logger.Error("unmarshal player error: %v", err) return nil @@ -38,9 +56,29 @@ func (d *Dao) SetRedisPlayer(player *model.Player) { logger.Error("marshal player error: %v", err) return } - err = d.redis.Set(context.TODO(), d.GetRedisPlayerKey(player.PlayerID), playerData, time.Hour*24*30).Err() + // 压缩 + startTime := time.Now().UnixNano() + in := bytes.NewReader(playerData) + out := new(bytes.Buffer) + lz4Writer := lz4.NewWriter(out) + _, err = io.Copy(lz4Writer, in) if err != nil { - logger.Error("set player from redis error: %v", err) + logger.Error("lz4 encode player data error: %v", err) + return + } + err = lz4Writer.Close() + if err != nil { + logger.Error("lz4 encode player data error: %v", err) + return + } + playerDataLz4 := out.Bytes() + endTime := time.Now().UnixNano() + costTime := endTime - startTime + logger.Debug("lz4 encode cost time: %v ns, before len: %v, after len: %v, ratio lz4/raw: %v", + costTime, len(playerData), len(playerDataLz4), float64(len(playerDataLz4))/float64(len(playerData))) + err = d.redis.Set(context.TODO(), d.GetRedisPlayerKey(player.PlayerID), playerDataLz4, time.Hour*24*30).Err() + if err != nil { + logger.Error("set player to redis error: %v", err) return } } diff --git a/gs/game/command_gm.go b/gs/game/command_gm.go index 229e9346..336e07aa 100644 --- a/gs/game/command_gm.go +++ b/gs/game/command_gm.go @@ -54,9 +54,18 @@ func (c *CommandManager) GMAddUserAvatar(userId, avatarId uint32) { // GMAddUserAllItem 给予玩家所有物品 func (c *CommandManager) GMAddUserAllItem(userId, itemCount uint32) { + // 猜猜这样做为啥不行? + // for itemId := range GAME_MANAGER.GetAllItemDataConfig() { + // c.GMAddUserItem(userId, uint32(itemId), itemCount) + // } + itemList := make([]*UserItem, 0) for itemId := range GAME_MANAGER.GetAllItemDataConfig() { - c.GMAddUserItem(userId, uint32(itemId), itemCount) + itemList = append(itemList, &UserItem{ + ItemId: uint32(itemId), + ChangeCount: itemCount, + }) } + GAME_MANAGER.AddUserItem(userId, itemList, false, 0) } // GMAddUserAllWeapon 给予玩家所有武器 diff --git a/gs/game/game_manager.go b/gs/game/game_manager.go index b3010402..94eb5562 100644 --- a/gs/game/game_manager.go +++ b/gs/game/game_manager.go @@ -42,6 +42,8 @@ var COMMAND_MANAGER *CommandManager = nil var GCG_MANAGER *GCGManager = nil var MESSAGE_QUEUE *mq.MessageQueue +var ONLINE_PLAYER_NUM int32 = 0 // 当前在线玩家数 + var SELF *model.Player type GameManager struct { @@ -235,6 +237,7 @@ func (g *GameManager) gameMainLoop() { COMMAND_MANAGER.HandleCommand(command) end := time.Now().UnixNano() commandCost += end - start + logger.Info("run gm cmd cost: %v ns", commandCost) } } } diff --git a/gs/game/local_event_manager.go b/gs/game/local_event_manager.go index 05c5f0f3..947caf05 100644 --- a/gs/game/local_event_manager.go +++ b/gs/game/local_event_manager.go @@ -1,13 +1,15 @@ package game import ( + "sort" "time" "hk4e/common/mq" "hk4e/gdconf" "hk4e/gs/model" "hk4e/pkg/logger" - "hk4e/pkg/object" + + "github.com/vmihailenco/msgpack/v5" ) // 本地事件队列管理器 @@ -16,10 +18,10 @@ const ( LoadLoginUserFromDbFinish = iota // 玩家登录从数据库加载完成回调 CheckUserExistOnRegFromDbFinish // 玩家注册从数据库查询是否已存在完成回调 RunUserCopyAndSave // 执行一次在线玩家内存数据复制到数据库写入协程 - ExitRunUserCopyAndSave - UserOfflineSaveToDbFinish - ReloadGameDataConfig - ReloadGameDataConfigFinish + ExitRunUserCopyAndSave // 停服时执行全部玩家保存操作 + UserOfflineSaveToDbFinish // 玩家离线保存完成 + ReloadGameDataConfig // 执行热更表 + ReloadGameDataConfigFinish // 热更表完成 ) type LocalEvent struct { @@ -37,6 +39,20 @@ func NewLocalEventManager() (r *LocalEventManager) { return r } +type PlayerLastSaveTimeSortList []*model.Player + +func (p PlayerLastSaveTimeSortList) Len() int { + return len(p) +} + +func (p PlayerLastSaveTimeSortList) Less(i, j int) bool { + return p[i].LastSaveTime < p[j].LastSaveTime +} + +func (p PlayerLastSaveTimeSortList) Swap(i, j int) { + p[i], p[j] = p[j], p[i] +} + func (l *LocalEventManager) LocalEventHandle(localEvent *LocalEvent) { switch localEvent.EventId { case LoadLoginUserFromDbFinish: @@ -51,49 +67,45 @@ func (l *LocalEventManager) LocalEventHandle(localEvent *LocalEvent) { case ExitRunUserCopyAndSave: fallthrough case RunUserCopyAndSave: - saveUserIdList := localEvent.Msg.([]uint32) startTime := time.Now().UnixNano() - // 拷贝一份数据避免并发访问 - insertPlayerList := make([]*model.Player, 0) - updatePlayerList := make([]*model.Player, 0) - for _, uid := range saveUserIdList { - player := USER_MANAGER.GetOnlineUser(uid) - if player == nil { - logger.Error("try to save but user not exist or online, uid: %v", uid) + playerList := make(PlayerLastSaveTimeSortList, 0) + for _, player := range USER_MANAGER.playerMap { + if player.PlayerID < 100000000 { continue } - if uid < 100000000 { + playerList = append(playerList, player) + } + sort.Stable(playerList) + // 拷贝一份数据避免并发访问 + insertPlayerList := make([][]byte, 0) + updatePlayerList := make([][]byte, 0) + saveCount := 0 + for _, player := range playerList { + totalCostTime := time.Now().UnixNano() - startTime + if totalCostTime > time.Millisecond.Nanoseconds()*50 { + // 总耗时超过50ms就中止本轮保存 + logger.Debug("user copy loop overtime exit, total cost time: %v ns", totalCostTime) + break + } + playerData, err := msgpack.Marshal(player) + if err != nil { + logger.Error("marshal player data error: %v", err) continue } switch player.DbState { case model.DbNone: break case model.DbInsert: - playerCopy := new(model.Player) - err := object.FastDeepCopy(playerCopy, player) - if err != nil { - logger.Error("deep copy player error: %v", err) - continue - } - insertPlayerList = append(insertPlayerList, playerCopy) - USER_MANAGER.playerMap[uid].DbState = model.DbNormal + insertPlayerList = append(insertPlayerList, playerData) + USER_MANAGER.playerMap[player.PlayerID].DbState = model.DbNormal + player.LastSaveTime = uint32(time.Now().UnixMilli()) + saveCount++ case model.DbDelete: - playerCopy := new(model.Player) - err := object.FastDeepCopy(playerCopy, player) - if err != nil { - logger.Error("deep copy player error: %v", err) - continue - } - updatePlayerList = append(updatePlayerList, playerCopy) - delete(USER_MANAGER.playerMap, uid) + delete(USER_MANAGER.playerMap, player.PlayerID) case model.DbNormal: - playerCopy := new(model.Player) - err := object.FastDeepCopy(playerCopy, player) - if err != nil { - logger.Error("deep copy player error: %v", err) - continue - } - updatePlayerList = append(updatePlayerList, playerCopy) + updatePlayerList = append(updatePlayerList, playerData) + player.LastSaveTime = uint32(time.Now().UnixMilli()) + saveCount++ } } saveUserData := &SaveUserData{ @@ -107,7 +119,7 @@ func (l *LocalEventManager) LocalEventHandle(localEvent *LocalEvent) { USER_MANAGER.saveUserChan <- saveUserData endTime := time.Now().UnixNano() costTime := endTime - startTime - logger.Info("run save user copy cost time: %v ns", costTime) + logger.Debug("run save user copy cost time: %v ns, save user count: %v", costTime, saveCount) if localEvent.EventId == ExitRunUserCopyAndSave { // 在此阻塞掉主协程 不再进行任何消息和任务的处理 select {} diff --git a/gs/game/player_avatar.go b/gs/game/player_avatar.go index 7418a84f..e72b384c 100644 --- a/gs/game/player_avatar.go +++ b/gs/game/player_avatar.go @@ -1,6 +1,8 @@ package game import ( + "strconv" + "hk4e/common/constant" "hk4e/gdconf" "hk4e/gs/model" @@ -8,7 +10,6 @@ import ( "hk4e/pkg/object" "hk4e/protocol/cmd" "hk4e/protocol/proto" - "strconv" pb "google.golang.org/protobuf/proto" ) diff --git a/gs/game/player_login.go b/gs/game/player_login.go index 9839ad78..32e7a310 100644 --- a/gs/game/player_login.go +++ b/gs/game/player_login.go @@ -1,6 +1,7 @@ package game import ( + "sync/atomic" "time" "hk4e/common/constant" @@ -81,6 +82,8 @@ func (g *GameManager) OnLoginOk(userId uint32, player *model.Player, clientSeq u TICK_MANAGER.CreateUserGlobalTick(userId) TICK_MANAGER.CreateUserTimer(userId, UserTimerActionTest, 100) + + atomic.AddInt32(&ONLINE_PLAYER_NUM, 1) } func (g *GameManager) OnReg(userId uint32, clientSeq uint32, gateAppId string, payloadMsg pb.Message) { @@ -134,6 +137,8 @@ func (g *GameManager) OnUserOffline(userId uint32, changeGsInfo *ChangeGsInfo) { player.Online = false player.TotalOnlineTime += uint32(time.Now().UnixMilli()) - player.OnlineTime USER_MANAGER.OfflineUser(player, changeGsInfo) + + atomic.AddInt32(&ONLINE_PLAYER_NUM, -1) } func (g *GameManager) LoginNotify(userId uint32, player *model.Player, clientSeq uint32) { diff --git a/gs/game/player_multiplayer.go b/gs/game/player_multiplayer.go index a8c9c528..7a2aaf43 100644 --- a/gs/game/player_multiplayer.go +++ b/gs/game/player_multiplayer.go @@ -175,18 +175,18 @@ func (g *GameManager) SceneKickPlayerReq(player *model.Player, payloadMsg pb.Mes } func (g *GameManager) UserApplyEnterWorld(player *model.Player, targetUid uint32) { - applyFailNotify := func() { + applyFailNotify := func(reason proto.PlayerApplyEnterMpResultNotify_Reason) { playerApplyEnterMpResultNotify := &proto.PlayerApplyEnterMpResultNotify{ TargetUid: targetUid, TargetNickname: "", IsAgreed: false, - Reason: proto.PlayerApplyEnterMpResultNotify_PLAYER_CANNOT_ENTER_MP, + Reason: reason, } g.SendMsg(cmd.PlayerApplyEnterMpResultNotify, player.PlayerID, player.ClientSeq, playerApplyEnterMpResultNotify) } world := WORLD_MANAGER.GetWorldByID(player.WorldId) if world.multiplayer { - applyFailNotify() + applyFailNotify(proto.PlayerApplyEnterMpResultNotify_PLAYER_CANNOT_ENTER_MP) return } targetPlayer := USER_MANAGER.GetOnlineUser(targetUid) @@ -194,7 +194,7 @@ func (g *GameManager) UserApplyEnterWorld(player *model.Player, targetUid uint32 if !USER_MANAGER.GetRemoteUserOnlineState(targetUid) { // 全服不存在该在线玩家 logger.Error("target user not online in any game server, uid: %v", targetUid) - applyFailNotify() + applyFailNotify(proto.PlayerApplyEnterMpResultNotify_PLAYER_CANNOT_ENTER_MP) return } gsAppId := USER_MANAGER.GetRemoteUserGsAppId(targetUid) @@ -224,18 +224,32 @@ func (g *GameManager) UserApplyEnterWorld(player *model.Player, targetUid uint32 }) return } - applyTime, exist := targetPlayer.CoopApplyMap[player.PlayerID] - if exist && time.Now().UnixNano() < applyTime+int64(10*time.Second) { - applyFailNotify() + if WORLD_MANAGER.multiplayerWorldNum >= MAX_MULTIPLAYER_WORLD_NUM { + // 超过本服务器最大多人世界数量限制 + applyFailNotify(proto.PlayerApplyEnterMpResultNotify_MAX_PLAYER) return } - targetPlayer.CoopApplyMap[player.PlayerID] = time.Now().UnixNano() targetWorld := WORLD_MANAGER.GetWorldByID(targetPlayer.WorldId) if targetWorld.multiplayer && targetWorld.owner.PlayerID != targetPlayer.PlayerID { // 向同一世界内的非房主玩家申请时直接拒绝 - applyFailNotify() + applyFailNotify(proto.PlayerApplyEnterMpResultNotify_PLAYER_NOT_IN_PLAYER_WORLD) return } + mpSetting := targetPlayer.PropertiesMap[constant.PlayerPropertyConst.PROP_PLAYER_MP_SETTING_TYPE] + if mpSetting == 0 { + // 房主玩家没开权限 + applyFailNotify(proto.PlayerApplyEnterMpResultNotify_SCENE_CANNOT_ENTER) + return + } else if mpSetting == 1 { + g.UserDealEnterWorld(targetPlayer, player.PlayerID, true) + return + } + applyTime, exist := targetPlayer.CoopApplyMap[player.PlayerID] + if exist && time.Now().UnixNano() < applyTime+int64(10*time.Second) { + applyFailNotify(proto.PlayerApplyEnterMpResultNotify_PLAYER_CANNOT_ENTER_MP) + return + } + targetPlayer.CoopApplyMap[player.PlayerID] = time.Now().UnixNano() playerApplyEnterMpNotify := new(proto.PlayerApplyEnterMpNotify) playerApplyEnterMpNotify.SrcPlayerInfo = g.PacketOnlinePlayerInfo(player) @@ -520,7 +534,7 @@ func (g *GameManager) UpdateWorldPlayerInfo(hostWorld *World, excludePlayer *mod func (g *GameManager) ServerUserMpReq(userMpInfo *mq.UserMpInfo, gsAppId string) { switch userMpInfo.OriginInfo.CmdName { case "PlayerApplyEnterMpReq": - applyFailNotify := func() { + applyFailNotify := func(reason proto.PlayerApplyEnterMpResultNotify_Reason) { MESSAGE_QUEUE.SendToGs(gsAppId, &mq.NetMsg{ MsgType: mq.MsgTypeServer, EventId: mq.ServerUserMpRsp, @@ -529,6 +543,7 @@ func (g *GameManager) ServerUserMpReq(userMpInfo *mq.UserMpInfo, gsAppId string) OriginInfo: userMpInfo.OriginInfo, HostUserId: userMpInfo.HostUserId, ApplyOk: false, + Reason: int32(reason), }, }, }) @@ -536,21 +551,35 @@ func (g *GameManager) ServerUserMpReq(userMpInfo *mq.UserMpInfo, gsAppId string) hostPlayer := USER_MANAGER.GetOnlineUser(userMpInfo.HostUserId) if hostPlayer == nil { logger.Error("player is nil, uid: %v", userMpInfo.HostUserId) - applyFailNotify() + applyFailNotify(proto.PlayerApplyEnterMpResultNotify_PLAYER_CANNOT_ENTER_MP) + return + } + if WORLD_MANAGER.multiplayerWorldNum >= MAX_MULTIPLAYER_WORLD_NUM { + // 超过本服务器最大多人世界数量限制 + applyFailNotify(proto.PlayerApplyEnterMpResultNotify_MAX_PLAYER) + return + } + hostWorld := WORLD_MANAGER.GetWorldByID(hostPlayer.WorldId) + if hostWorld.multiplayer && hostWorld.owner.PlayerID != hostPlayer.PlayerID { + // 向同一世界内的非房主玩家申请时直接拒绝 + applyFailNotify(proto.PlayerApplyEnterMpResultNotify_PLAYER_NOT_IN_PLAYER_WORLD) + return + } + mpSetting := hostPlayer.PropertiesMap[constant.PlayerPropertyConst.PROP_PLAYER_MP_SETTING_TYPE] + if mpSetting == 0 { + // 房主玩家没开权限 + applyFailNotify(proto.PlayerApplyEnterMpResultNotify_SCENE_CANNOT_ENTER) + return + } else if mpSetting == 1 { + g.UserDealEnterWorld(hostPlayer, userMpInfo.ApplyUserId, true) return } applyTime, exist := hostPlayer.CoopApplyMap[userMpInfo.ApplyUserId] if exist && time.Now().UnixNano() < applyTime+int64(10*time.Second) { - applyFailNotify() + applyFailNotify(proto.PlayerApplyEnterMpResultNotify_PLAYER_CANNOT_ENTER_MP) return } hostPlayer.CoopApplyMap[userMpInfo.ApplyUserId] = time.Now().UnixNano() - hostWorld := WORLD_MANAGER.GetWorldByID(hostPlayer.WorldId) - if hostWorld.multiplayer && hostWorld.owner.PlayerID != hostPlayer.PlayerID { - // 向同一世界内的非房主玩家申请时直接拒绝 - applyFailNotify() - return - } playerApplyEnterMpNotify := new(proto.PlayerApplyEnterMpNotify) playerApplyEnterMpNotify.SrcPlayerInfo = &proto.OnlinePlayerInfo{ @@ -617,7 +646,7 @@ func (g *GameManager) ServerUserMpRsp(userMpInfo *mq.UserMpInfo) { TargetUid: userMpInfo.HostUserId, TargetNickname: "", IsAgreed: false, - Reason: proto.PlayerApplyEnterMpResultNotify_PLAYER_CANNOT_ENTER_MP, + Reason: proto.PlayerApplyEnterMpResultNotify_Reason(userMpInfo.Reason), } g.SendMsg(cmd.PlayerApplyEnterMpResultNotify, player.PlayerID, player.ClientSeq, playerApplyEnterMpResultNotify) } diff --git a/gs/game/player_scene.go b/gs/game/player_scene.go index 4e66e474..3ecded51 100644 --- a/gs/game/player_scene.go +++ b/gs/game/player_scene.go @@ -17,6 +17,10 @@ import ( pb "google.golang.org/protobuf/proto" ) +const ( + ENTITY_MAX_BATCH_SEND_NUM = 1000 // 单次同步的最大实体数量 +) + func (g *GameManager) EnterSceneReadyReq(player *model.Player, payloadMsg pb.Message) { logger.Debug("user enter scene ready, uid: %v", player.PlayerID) world := WORLD_MANAGER.GetWorldByID(player.WorldId) @@ -531,16 +535,14 @@ func (g *GameManager) RemoveSceneEntityNotifyBroadcast(scene *Scene, visionType } } -const ENTITY_BATCH_SIZE = 1000 - func (g *GameManager) AddSceneEntityNotify(player *model.Player, visionType proto.VisionType, entityIdList []uint32, broadcast bool, aec bool) { world := WORLD_MANAGER.GetWorldByID(player.WorldId) scene := world.GetSceneById(player.SceneId) // 如果总数量太多则分包发送 - times := int(math.Ceil(float64(len(entityIdList)) / float64(ENTITY_BATCH_SIZE))) + times := int(math.Ceil(float64(len(entityIdList)) / float64(ENTITY_MAX_BATCH_SEND_NUM))) for i := 0; i < times; i++ { - begin := ENTITY_BATCH_SIZE * i - end := ENTITY_BATCH_SIZE * (i + 1) + begin := ENTITY_MAX_BATCH_SEND_NUM * i + end := ENTITY_MAX_BATCH_SEND_NUM * (i + 1) if i == times-1 { end = len(entityIdList) } diff --git a/gs/game/player_weapon.go b/gs/game/player_weapon.go index f1400056..32e0d32a 100644 --- a/gs/game/player_weapon.go +++ b/gs/game/player_weapon.go @@ -1,14 +1,15 @@ package game import ( + "sort" + "strconv" + "hk4e/common/constant" "hk4e/gdconf" "hk4e/gs/model" "hk4e/pkg/logger" "hk4e/protocol/cmd" "hk4e/protocol/proto" - "sort" - "strconv" pb "google.golang.org/protobuf/proto" ) diff --git a/gs/game/tick_manager.go b/gs/game/tick_manager.go index 0d45613c..0dbe22d7 100644 --- a/gs/game/tick_manager.go +++ b/gs/game/tick_manager.go @@ -81,12 +81,6 @@ func (t *TickManager) onUserTickSecond(userId uint32, now int64) { } func (t *TickManager) onUserTickMinute(userId uint32, now int64) { - // 每分钟保存玩家数据 - saveUserIdList := []uint32{userId} - LOCAL_EVENT_MANAGER.localEventChan <- &LocalEvent{ - EventId: RunUserCopyAndSave, - Msg: saveUserIdList, - } } // 玩家定时任务常量 diff --git a/gs/game/user_manager.go b/gs/game/user_manager.go index b2da37d6..ede882e7 100644 --- a/gs/game/user_manager.go +++ b/gs/game/user_manager.go @@ -1,11 +1,14 @@ package game import ( + "time" + "hk4e/gs/dao" "hk4e/gs/model" "hk4e/pkg/logger" - "hk4e/pkg/object" "hk4e/protocol/proto" + + "github.com/vmihailenco/msgpack/v5" ) // 玩家管理器 @@ -167,14 +170,19 @@ type PlayerOfflineInfo struct { // OfflineUser 玩家离线 func (u *UserManager) OfflineUser(player *model.Player, changeGsInfo *ChangeGsInfo) { - playerCopy := new(model.Player) - err := object.FastDeepCopy(playerCopy, player) + playerData, err := msgpack.Marshal(player) if err != nil { - logger.Error("deep copy player error: %v", err) + logger.Error("marshal player data error: %v", err) return } - playerCopy.DbState = player.DbState go func() { + playerCopy := new(model.Player) + err := msgpack.Unmarshal(playerData, playerCopy) + if err != nil { + logger.Error("unmarshal player data error: %v", err) + return + } + playerCopy.DbState = player.DbState u.SaveUserToDbSync(playerCopy) u.SaveUserToRedisSync(playerCopy) LOCAL_EVENT_MANAGER.localEventChan <- &LocalEvent{ @@ -336,14 +344,19 @@ func (u *UserManager) SaveTempOfflineUser(player *model.Player) { // 主协程同步写入redis u.SaveUserToRedisSync(player) // 另一个协程异步的写回db - playerCopy := new(model.Player) - err := object.FastDeepCopy(playerCopy, player) + playerData, err := msgpack.Marshal(player) if err != nil { - logger.Error("deep copy player error: %v", err) + logger.Error("marshal player data error: %v", err) return } - playerCopy.DbState = player.DbState go func() { + playerCopy := new(model.Player) + err := msgpack.Unmarshal(playerData, playerCopy) + if err != nil { + logger.Error("unmarshal player data error: %v", err) + return + } + playerCopy.DbState = player.DbState u.SaveUserToDbSync(playerCopy) }() } @@ -351,21 +364,56 @@ func (u *UserManager) SaveTempOfflineUser(player *model.Player) { // db和redis相关操作 type SaveUserData struct { - insertPlayerList []*model.Player - updatePlayerList []*model.Player + insertPlayerList [][]byte + updatePlayerList [][]byte exitSave bool } func (u *UserManager) saveUserHandle() { - for { - saveUserData := <-u.saveUserChan - u.SaveUserListToDbSync(saveUserData) - u.SaveUserListToRedisSync(saveUserData) - if saveUserData.exitSave { - // 停服落地玩家数据完毕 通知APP主协程关闭程序 - EXIT_SAVE_FIN_CHAN <- true + go func() { + ticker := time.NewTicker(time.Minute) + for { + <-ticker.C + // 保存玩家数据 + LOCAL_EVENT_MANAGER.localEventChan <- &LocalEvent{ + EventId: RunUserCopyAndSave, + } } - } + }() + go func() { + for { + saveUserData := <-u.saveUserChan + insertPlayerList := make([]*model.Player, 0) + updatePlayerList := make([]*model.Player, 0) + setPlayerList := make([]*model.Player, 0) + for _, playerData := range saveUserData.insertPlayerList { + player := new(model.Player) + err := msgpack.Unmarshal(playerData, player) + if err != nil { + logger.Error("unmarshal player data error: %v", err) + continue + } + insertPlayerList = append(insertPlayerList, player) + setPlayerList = append(setPlayerList, player) + } + for _, playerData := range saveUserData.updatePlayerList { + player := new(model.Player) + err := msgpack.Unmarshal(playerData, player) + if err != nil { + logger.Error("unmarshal player data error: %v", err) + continue + } + updatePlayerList = append(updatePlayerList, player) + setPlayerList = append(setPlayerList, player) + } + u.SaveUserListToDbSync(insertPlayerList, updatePlayerList) + u.SaveUserListToRedisSync(setPlayerList) + if saveUserData.exitSave { + // 停服落地玩家数据完毕 通知APP主协程关闭程序 + EXIT_SAVE_FIN_CHAN <- true + } + } + }() } func (u *UserManager) LoadUserFromDbSync(userId uint32) *model.Player { @@ -395,18 +443,18 @@ func (u *UserManager) SaveUserToDbSync(player *model.Player) { } } -func (u *UserManager) SaveUserListToDbSync(saveUserData *SaveUserData) { - err := u.dao.InsertPlayerList(saveUserData.insertPlayerList) +func (u *UserManager) SaveUserListToDbSync(insertPlayerList []*model.Player, updatePlayerList []*model.Player) { + err := u.dao.InsertPlayerList(insertPlayerList) if err != nil { logger.Error("insert player list error: %v", err) return } - err = u.dao.UpdatePlayerList(saveUserData.updatePlayerList) + err = u.dao.UpdatePlayerList(updatePlayerList) if err != nil { logger.Error("update player list error: %v", err) return } - logger.Info("save user finish, insert user count: %v, update user count: %v", len(saveUserData.insertPlayerList), len(saveUserData.updatePlayerList)) + logger.Info("save user finish, insert user count: %v, update user count: %v", len(insertPlayerList), len(updatePlayerList)) } func (u *UserManager) LoadUserFromRedisSync(userId uint32) *model.Player { @@ -418,13 +466,6 @@ func (u *UserManager) SaveUserToRedisSync(player *model.Player) { u.dao.SetRedisPlayer(player) } -func (u *UserManager) SaveUserListToRedisSync(saveUserData *SaveUserData) { - setPlayerList := make([]*model.Player, 0, len(saveUserData.insertPlayerList)+len(saveUserData.updatePlayerList)) - for _, player := range saveUserData.insertPlayerList { - setPlayerList = append(setPlayerList, player) - } - for _, player := range saveUserData.updatePlayerList { - setPlayerList = append(setPlayerList, player) - } +func (u *UserManager) SaveUserListToRedisSync(setPlayerList []*model.Player) { u.dao.SetRedisPlayerList(setPlayerList) } diff --git a/gs/game/world_manager.go b/gs/game/world_manager.go index 2dced467..54ff56c2 100644 --- a/gs/game/world_manager.go +++ b/gs/game/world_manager.go @@ -16,11 +16,18 @@ import ( // 世界管理器 +const ( + ENTITY_NUM_UNLIMIT = false // 是否不限制场景内实体数量 + ENTITY_MAX_SEND_NUM = 300 // 场景内最大实体数量 + MAX_MULTIPLAYER_WORLD_NUM = 10 // 本服务器最大多人世界数量 +) + type WorldManager struct { - worldMap map[uint32]*World - snowflake *alg.SnowflakeWorker - aiWorld *World // 本服的Ai玩家世界 - sceneBlockAoiMap map[uint32]*alg.AoiManager // 全局各场景地图的aoi管理器 + worldMap map[uint32]*World + snowflake *alg.SnowflakeWorker + aiWorld *World // 本服的Ai玩家世界 + sceneBlockAoiMap map[uint32]*alg.AoiManager // 全局各场景地图的aoi管理器 + multiplayerWorldNum uint32 // 本服务器的多人世界数量 } func NewWorldManager(snowflake *alg.SnowflakeWorker) (r *WorldManager) { @@ -136,6 +143,7 @@ func NewWorldManager(snowflake *alg.SnowflakeWorker) (r *WorldManager) { } r.sceneBlockAoiMap[uint32(sceneConfig.Id)] = aoiManager } + r.multiplayerWorldNum = 0 return r } @@ -176,6 +184,9 @@ func (w *WorldManager) DestroyWorld(worldId uint32) { player.WorldId = 0 } delete(w.worldMap, worldId) + if world.multiplayer { + w.multiplayerWorldNum-- + } } // GetAiWorld 获取本服务器的Ai世界 @@ -618,6 +629,7 @@ func (w *World) GetChatList() []*proto.ChatInfo { // ChangeToMultiplayer 转换为多人世界 func (w *World) ChangeToMultiplayer() { + WORLD_MANAGER.multiplayerWorldNum++ w.multiplayer = true } @@ -665,10 +677,10 @@ type Scene struct { world *World playerMap map[uint32]*model.Player entityMap map[uint32]*Entity - objectIdEntityMap map[int64]*Entity - gameTime uint32 // 游戏内提瓦特大陆的时间 - createTime int64 - meeoIndex uint32 // 客户端风元素染色同步协议的计数器 + objectIdEntityMap map[int64]*Entity // 用于标识配置档里的唯一实体是否已被创建 + gameTime uint32 // 游戏内提瓦特大陆的时间 + createTime int64 // 场景创建时间 + meeoIndex uint32 // 客户端风元素染色同步协议的计数器 } func (s *Scene) GetAllPlayer() map[uint32]*model.Player { @@ -873,7 +885,7 @@ func (s *Scene) CreateEntityAvatar(player *model.Player, avatarId uint32) uint32 avatarId: avatarId, }, } - s.entityMap[entity.id] = entity + s.CreateEntity(entity, 0) MESSAGE_QUEUE.SendToFight(s.world.owner.FightAppId, &mq.NetMsg{ MsgType: mq.MsgTypeFight, EventId: mq.FightRoutineAddEntity, @@ -903,7 +915,7 @@ func (s *Scene) CreateEntityWeapon() uint32 { entityType: uint32(proto.ProtEntityType_PROT_ENTITY_WEAPON), level: 0, } - s.entityMap[entity.id] = entity + s.CreateEntity(entity, 0) return entity.id } @@ -931,8 +943,7 @@ func (s *Scene) CreateEntityMonster(pos, rot *model.Vector, monsterId uint32, le configId: configId, objectId: objectId, } - s.entityMap[entity.id] = entity - s.objectIdEntityMap[objectId] = entity + s.CreateEntity(entity, objectId) MESSAGE_QUEUE.SendToFight(s.world.owner.FightAppId, &mq.NetMsg{ MsgType: mq.MsgTypeFight, EventId: mq.FightRoutineAddEntity, @@ -976,8 +987,7 @@ func (s *Scene) CreateEntityNpc(pos, rot *model.Vector, npcId, roomId, parentQue configId: configId, objectId: objectId, } - s.entityMap[entity.id] = entity - s.objectIdEntityMap[objectId] = entity + s.CreateEntity(entity, objectId) return entity.id } @@ -1010,8 +1020,7 @@ func (s *Scene) CreateEntityGadgetNormal(pos, rot *model.Vector, gadgetId uint32 configId: configId, objectId: objectId, } - s.entityMap[entity.id] = entity - s.objectIdEntityMap[objectId] = entity + s.CreateEntity(entity, objectId) return entity.id } @@ -1047,8 +1056,7 @@ func (s *Scene) CreateEntityGadgetGather(pos, rot *model.Vector, gadgetId uint32 configId: configId, objectId: objectId, } - s.entityMap[entity.id] = entity - s.objectIdEntityMap[objectId] = entity + s.CreateEntity(entity, objectId) return entity.id } @@ -1081,7 +1089,7 @@ func (s *Scene) CreateEntityGadgetClient(pos, rot *model.Vector, entityId uint32 }, }, } - s.entityMap[entity.id] = entity + s.CreateEntity(entity, 0) } func (s *Scene) CreateEntityGadgetVehicle(uid uint32, pos, rot *model.Vector, vehicleId uint32) uint32 { @@ -1119,10 +1127,21 @@ func (s *Scene) CreateEntityGadgetVehicle(uid uint32, pos, rot *model.Vector, ve }, }, } - s.entityMap[entity.id] = entity + s.CreateEntity(entity, 0) return entity.id } +func (s *Scene) CreateEntity(entity *Entity, objectId int64) { + if len(s.entityMap) >= ENTITY_MAX_SEND_NUM && !ENTITY_NUM_UNLIMIT { + logger.Error("above max scene entity num limit: %v, id: %v, pos: %v", ENTITY_MAX_SEND_NUM, entity.id, entity.pos) + return + } + if objectId != 0 { + s.objectIdEntityMap[objectId] = entity + } + s.entityMap[entity.id] = entity +} + func (s *Scene) DestroyEntity(entityId uint32) { entity := s.GetEntity(entityId) if entity == nil { diff --git a/gs/model/player.go b/gs/model/player.go index ea011cbe..1caac8ec 100644 --- a/gs/model/player.go +++ b/gs/model/player.go @@ -57,6 +57,7 @@ type Player struct { GCGInfo *GCGInfo `bson:"gcgInfo"` // 七圣召唤信息 IsGM uint8 `bson:"isGM"` // 管理员权限等级 // 在线数据 请随意 记得加忽略字段的tag + LastSaveTime uint32 `bson:"-" msgpack:"-"` // 上一次保存时间 EnterSceneToken uint32 `bson:"-" msgpack:"-"` // 玩家的世界进入令牌 DbState int `bson:"-" msgpack:"-"` // 数据库存档状态 WorldId uint32 `bson:"-" msgpack:"-"` // 所在的世界id diff --git a/node/api/api.proto b/node/api/api.proto index 98d6be2c..5485ae6e 100644 --- a/node/api/api.proto +++ b/node/api/api.proto @@ -53,6 +53,7 @@ message CancelServerReq { message KeepaliveServerReq { string server_type = 1; string app_id = 2; + uint32 load_count = 3; } message GetGateServerAddrReq { diff --git a/node/service/discovery.go b/node/service/discovery.go index d38fd534..5f8ec0b0 100644 --- a/node/service/discovery.go +++ b/node/service/discovery.go @@ -2,6 +2,7 @@ package service import ( "context" + "math" "sort" "strings" "sync" @@ -46,6 +47,7 @@ type ServerInstance struct { version []string lastAliveTime int64 gsId uint32 + loadCount uint32 } type DiscoveryService struct { @@ -90,6 +92,7 @@ func (s *DiscoveryService) RegisterServer(ctx context.Context, req *api.Register serverType: req.ServerType, appId: appId, lastAliveTime: time.Now().Unix(), + loadCount: 0, } if req.ServerType == api.GATE { logger.Info("register new gate server, ip: %v, port: %v", req.GateServerAddr.KcpAddr, req.GateServerAddr.KcpPort) @@ -133,7 +136,7 @@ func (s *DiscoveryService) CancelServer(ctx context.Context, req *api.CancelServ // KeepaliveServer 服务器在线心跳保持 func (s *DiscoveryService) KeepaliveServer(ctx context.Context, req *api.KeepaliveServerReq) (*api.NullMsg, error) { - logger.Debug("server keepalive, server type: %v, appid: %v", req.ServerType, req.AppId) + logger.Debug("server keepalive, server type: %v, appid: %v, load: %v", req.ServerType, req.AppId, req.LoadCount) instMap, exist := s.serverInstanceMap[req.ServerType] if !exist { return nil, errors.New("server type not exist") @@ -145,6 +148,7 @@ func (s *DiscoveryService) KeepaliveServer(ctx context.Context, req *api.Keepali } serverInstance := inst.(*ServerInstance) serverInstance.lastAliveTime = time.Now().Unix() + serverInstance.loadCount = req.LoadCount return &api.NullMsg{}, nil } @@ -158,7 +162,12 @@ func (s *DiscoveryService) GetServerAppId(ctx context.Context, req *api.GetServe if s.getServerInstanceMapLen(instMap) == 0 { return nil, errors.New("no server found") } - inst := s.getRandomServerInstance(instMap) + var inst *ServerInstance = nil + if req.ServerType == api.GATE || req.ServerType == api.GS { + inst = s.getMinLoadServerInstance(instMap) + } else { + inst = s.getRandomServerInstance(instMap) + } logger.Debug("get server appid is: %v", inst.appId) return &api.GetServerAppIdRsp{ AppId: inst.appId, @@ -197,7 +206,7 @@ func (s *DiscoveryService) GetGateServerAddr(ctx context.Context, req *api.GetGa if s.getServerInstanceMapLen(&versionInstMap) == 0 { return nil, errors.New("no gate server found") } - inst := s.getRandomServerInstance(&versionInstMap) + inst := s.getMinLoadServerInstance(&versionInstMap) logger.Debug("get gate server addr is, ip: %v, port: %v", inst.gateServerKcpAddr, inst.gateServerKcpPort) return &api.GateServerAddr{ KcpAddr: inst.gateServerKcpAddr, @@ -270,6 +279,25 @@ func (s *DiscoveryService) getRandomServerInstance(instMap *sync.Map) *ServerIns return inst } +func (s *DiscoveryService) getMinLoadServerInstance(instMap *sync.Map) *ServerInstance { + instList := make(ServerInstanceSortList, 0) + instMap.Range(func(key, value any) bool { + instList = append(instList, value.(*ServerInstance)) + return true + }) + sort.Stable(instList) + minLoadInstIndex := 0 + minLoadInstCount := math.MaxUint32 + for index, inst := range instList { + if inst.loadCount < uint32(minLoadInstCount) { + minLoadInstCount = int(inst.loadCount) + minLoadInstIndex = index + } + } + inst := instList[minLoadInstIndex] + return inst +} + func (s *DiscoveryService) getServerInstanceMapLen(instMap *sync.Map) int { count := 0 instMap.Range(func(key, value any) bool { diff --git a/pkg/object/object.go b/pkg/object/object.go index 929f65f6..b1faaac2 100644 --- a/pkg/object/object.go +++ b/pkg/object/object.go @@ -8,12 +8,11 @@ import ( "strings" "github.com/pkg/errors" - "github.com/vmihailenco/msgpack/v5" "google.golang.org/protobuf/encoding/protojson" pb "google.golang.org/protobuf/proto" ) -func FullDeepCopy(dst, src any) error { +func DeepCopy(dst, src any) error { var buf bytes.Buffer err := gob.NewEncoder(&buf).Encode(src) if err != nil { @@ -26,18 +25,6 @@ func FullDeepCopy(dst, src any) error { return nil } -func FastDeepCopy(dst, src any) error { - data, err := msgpack.Marshal(src) - if err != nil { - return err - } - err = msgpack.Unmarshal(data, dst) - if err != nil { - return err - } - return nil -} - func CopyProtoBufSameField(dst, src pb.Message) ([]string, error) { data, err := protojson.Marshal(src) if err != nil {