系统架构层面流量控制功能完善

This commit is contained in:
flswld
2023-02-05 07:18:43 +08:00
parent cfb001c18a
commit 94c8db402a
51 changed files with 1049 additions and 2408 deletions

View File

@@ -6,6 +6,7 @@ import (
"os"
"os/signal"
"strings"
"sync/atomic"
"syscall"
"time"
@@ -50,6 +51,7 @@ func Run(ctx context.Context, configFile string) error {
_, err := client.Discovery.KeepaliveServer(context.TODO(), &api.KeepaliveServerReq{
ServerType: api.GATE,
AppId: APPID,
LoadCount: uint32(atomic.LoadInt32(&net.CLIENT_CONN_NUM)),
})
if err != nil {
logger.Error("keepalive error: %v", err)
@@ -75,7 +77,8 @@ func Run(ctx context.Context, configFile string) error {
go func() {
outputChan := connectManager.GetKcpEventOutputChan()
for {
<-outputChan
kcpEvent := <-outputChan
logger.Info("kcpEvent: %v", kcpEvent)
}
}()

View File

@@ -1,64 +0,0 @@
package kcp
const maxAutoTuneSamples = 258
// pulse represents a 0/1 signal with time sequence
type pulse struct {
bit bool // 0 or 1
seq uint32 // sequence of the signal
}
// autoTune object
type autoTune struct {
pulses [maxAutoTuneSamples]pulse
}
// Sample adds a signal sample to the pulse buffer
func (tune *autoTune) Sample(bit bool, seq uint32) {
tune.pulses[seq%maxAutoTuneSamples] = pulse{bit, seq}
}
// Find a period for a given signal
// returns -1 if not found
//
// --- ------
// | |
// |______________|
// Period
// Falling Edge Rising Edge
func (tune *autoTune) FindPeriod(bit bool) int {
// last pulse and initial index setup
lastPulse := tune.pulses[0]
idx := 1
// left edge
var leftEdge int
for ; idx < len(tune.pulses); idx++ {
if lastPulse.bit != bit && tune.pulses[idx].bit == bit { // edge found
if lastPulse.seq+1 == tune.pulses[idx].seq { // ensure edge continuity
leftEdge = idx
break
}
}
lastPulse = tune.pulses[idx]
}
// right edge
var rightEdge int
lastPulse = tune.pulses[leftEdge]
idx = leftEdge + 1
for ; idx < len(tune.pulses); idx++ {
if lastPulse.seq+1 == tune.pulses[idx].seq { // ensure pulses in this level monotonic
if lastPulse.bit == bit && tune.pulses[idx].bit != bit { // edge found
rightEdge = idx
break
}
} else {
return -1
}
lastPulse = tune.pulses[idx]
}
return rightEdge - leftEdge
}

View File

@@ -1,47 +0,0 @@
package kcp
import (
"testing"
"github.com/stretchr/testify/assert"
)
func TestAutoTune(t *testing.T) {
signals := []uint32{0, 0, 0, 0, 0, 0}
tune := autoTune{}
for i := 0; i < len(signals); i++ {
if signals[i] == 0 {
tune.Sample(false, uint32(i))
} else {
tune.Sample(true, uint32(i))
}
}
assert.Equal(t, -1, tune.FindPeriod(false))
assert.Equal(t, -1, tune.FindPeriod(true))
signals = []uint32{1, 0, 1, 0, 0, 1}
tune = autoTune{}
for i := 0; i < len(signals); i++ {
if signals[i] == 0 {
tune.Sample(false, uint32(i))
} else {
tune.Sample(true, uint32(i))
}
}
assert.Equal(t, 1, tune.FindPeriod(false))
assert.Equal(t, 1, tune.FindPeriod(true))
signals = []uint32{1, 0, 0, 0, 0, 1}
tune = autoTune{}
for i := 0; i < len(signals); i++ {
if signals[i] == 0 {
tune.Sample(false, uint32(i))
} else {
tune.Sample(true, uint32(i))
}
}
assert.Equal(t, -1, tune.FindPeriod(true))
assert.Equal(t, 4, tune.FindPeriod(false))
}

View File

@@ -1,617 +0,0 @@
package kcp
import (
"crypto/aes"
"crypto/cipher"
"crypto/des"
"crypto/sha1"
"unsafe"
xor "github.com/templexxx/xorsimd"
"github.com/tjfoc/gmsm/sm4"
"golang.org/x/crypto/blowfish"
"golang.org/x/crypto/cast5"
"golang.org/x/crypto/pbkdf2"
"golang.org/x/crypto/salsa20"
"golang.org/x/crypto/tea"
"golang.org/x/crypto/twofish"
"golang.org/x/crypto/xtea"
)
var (
initialVector = []byte{167, 115, 79, 156, 18, 172, 27, 1, 164, 21, 242, 193, 252, 120, 230, 107}
saltxor = `sH3CIVoF#rWLtJo6`
)
// BlockCrypt defines encryption/decryption methods for a given byte slice.
// Notes on implementing: the data to be encrypted contains a builtin
// nonce at the first 16 bytes
type BlockCrypt interface {
// Encrypt encrypts the whole block in src into dst.
// Dst and src may point at the same memory.
Encrypt(dst, src []byte)
// Decrypt decrypts the whole block in src into dst.
// Dst and src may point at the same memory.
Decrypt(dst, src []byte)
}
type salsa20BlockCrypt struct {
key [32]byte
}
// NewSalsa20BlockCrypt https://en.wikipedia.org/wiki/Salsa20
func NewSalsa20BlockCrypt(key []byte) (BlockCrypt, error) {
c := new(salsa20BlockCrypt)
copy(c.key[:], key)
return c, nil
}
func (c *salsa20BlockCrypt) Encrypt(dst, src []byte) {
salsa20.XORKeyStream(dst[8:], src[8:], src[:8], &c.key)
copy(dst[:8], src[:8])
}
func (c *salsa20BlockCrypt) Decrypt(dst, src []byte) {
salsa20.XORKeyStream(dst[8:], src[8:], src[:8], &c.key)
copy(dst[:8], src[:8])
}
type sm4BlockCrypt struct {
encbuf [sm4.BlockSize]byte // 64bit alignment enc/dec buffer
decbuf [2 * sm4.BlockSize]byte
block cipher.Block
}
// NewSM4BlockCrypt https://github.com/tjfoc/gmsm/tree/master/sm4
func NewSM4BlockCrypt(key []byte) (BlockCrypt, error) {
c := new(sm4BlockCrypt)
block, err := sm4.NewCipher(key)
if err != nil {
return nil, err
}
c.block = block
return c, nil
}
func (c *sm4BlockCrypt) Encrypt(dst, src []byte) { encrypt(c.block, dst, src, c.encbuf[:]) }
func (c *sm4BlockCrypt) Decrypt(dst, src []byte) { decrypt(c.block, dst, src, c.decbuf[:]) }
type twofishBlockCrypt struct {
encbuf [twofish.BlockSize]byte
decbuf [2 * twofish.BlockSize]byte
block cipher.Block
}
// NewTwofishBlockCrypt https://en.wikipedia.org/wiki/Twofish
func NewTwofishBlockCrypt(key []byte) (BlockCrypt, error) {
c := new(twofishBlockCrypt)
block, err := twofish.NewCipher(key)
if err != nil {
return nil, err
}
c.block = block
return c, nil
}
func (c *twofishBlockCrypt) Encrypt(dst, src []byte) { encrypt(c.block, dst, src, c.encbuf[:]) }
func (c *twofishBlockCrypt) Decrypt(dst, src []byte) { decrypt(c.block, dst, src, c.decbuf[:]) }
type tripleDESBlockCrypt struct {
encbuf [des.BlockSize]byte
decbuf [2 * des.BlockSize]byte
block cipher.Block
}
// NewTripleDESBlockCrypt https://en.wikipedia.org/wiki/Triple_DES
func NewTripleDESBlockCrypt(key []byte) (BlockCrypt, error) {
c := new(tripleDESBlockCrypt)
block, err := des.NewTripleDESCipher(key)
if err != nil {
return nil, err
}
c.block = block
return c, nil
}
func (c *tripleDESBlockCrypt) Encrypt(dst, src []byte) { encrypt(c.block, dst, src, c.encbuf[:]) }
func (c *tripleDESBlockCrypt) Decrypt(dst, src []byte) { decrypt(c.block, dst, src, c.decbuf[:]) }
type cast5BlockCrypt struct {
encbuf [cast5.BlockSize]byte
decbuf [2 * cast5.BlockSize]byte
block cipher.Block
}
// NewCast5BlockCrypt https://en.wikipedia.org/wiki/CAST-128
func NewCast5BlockCrypt(key []byte) (BlockCrypt, error) {
c := new(cast5BlockCrypt)
block, err := cast5.NewCipher(key)
if err != nil {
return nil, err
}
c.block = block
return c, nil
}
func (c *cast5BlockCrypt) Encrypt(dst, src []byte) { encrypt(c.block, dst, src, c.encbuf[:]) }
func (c *cast5BlockCrypt) Decrypt(dst, src []byte) { decrypt(c.block, dst, src, c.decbuf[:]) }
type blowfishBlockCrypt struct {
encbuf [blowfish.BlockSize]byte
decbuf [2 * blowfish.BlockSize]byte
block cipher.Block
}
// NewBlowfishBlockCrypt https://en.wikipedia.org/wiki/Blowfish_(cipher)
func NewBlowfishBlockCrypt(key []byte) (BlockCrypt, error) {
c := new(blowfishBlockCrypt)
block, err := blowfish.NewCipher(key)
if err != nil {
return nil, err
}
c.block = block
return c, nil
}
func (c *blowfishBlockCrypt) Encrypt(dst, src []byte) { encrypt(c.block, dst, src, c.encbuf[:]) }
func (c *blowfishBlockCrypt) Decrypt(dst, src []byte) { decrypt(c.block, dst, src, c.decbuf[:]) }
type aesBlockCrypt struct {
encbuf [aes.BlockSize]byte
decbuf [2 * aes.BlockSize]byte
block cipher.Block
}
// NewAESBlockCrypt https://en.wikipedia.org/wiki/Advanced_Encryption_Standard
func NewAESBlockCrypt(key []byte) (BlockCrypt, error) {
c := new(aesBlockCrypt)
block, err := aes.NewCipher(key)
if err != nil {
return nil, err
}
c.block = block
return c, nil
}
func (c *aesBlockCrypt) Encrypt(dst, src []byte) { encrypt(c.block, dst, src, c.encbuf[:]) }
func (c *aesBlockCrypt) Decrypt(dst, src []byte) { decrypt(c.block, dst, src, c.decbuf[:]) }
type teaBlockCrypt struct {
encbuf [tea.BlockSize]byte
decbuf [2 * tea.BlockSize]byte
block cipher.Block
}
// NewTEABlockCrypt https://en.wikipedia.org/wiki/Tiny_Encryption_Algorithm
func NewTEABlockCrypt(key []byte) (BlockCrypt, error) {
c := new(teaBlockCrypt)
block, err := tea.NewCipherWithRounds(key, 16)
if err != nil {
return nil, err
}
c.block = block
return c, nil
}
func (c *teaBlockCrypt) Encrypt(dst, src []byte) { encrypt(c.block, dst, src, c.encbuf[:]) }
func (c *teaBlockCrypt) Decrypt(dst, src []byte) { decrypt(c.block, dst, src, c.decbuf[:]) }
type xteaBlockCrypt struct {
encbuf [xtea.BlockSize]byte
decbuf [2 * xtea.BlockSize]byte
block cipher.Block
}
// NewXTEABlockCrypt https://en.wikipedia.org/wiki/XTEA
func NewXTEABlockCrypt(key []byte) (BlockCrypt, error) {
c := new(xteaBlockCrypt)
block, err := xtea.NewCipher(key)
if err != nil {
return nil, err
}
c.block = block
return c, nil
}
func (c *xteaBlockCrypt) Encrypt(dst, src []byte) { encrypt(c.block, dst, src, c.encbuf[:]) }
func (c *xteaBlockCrypt) Decrypt(dst, src []byte) { decrypt(c.block, dst, src, c.decbuf[:]) }
type simpleXORBlockCrypt struct {
xortbl []byte
}
// NewSimpleXORBlockCrypt simple xor with key expanding
func NewSimpleXORBlockCrypt(key []byte) (BlockCrypt, error) {
c := new(simpleXORBlockCrypt)
c.xortbl = pbkdf2.Key(key, []byte(saltxor), 32, mtuLimit, sha1.New)
return c, nil
}
func (c *simpleXORBlockCrypt) Encrypt(dst, src []byte) { xor.Bytes(dst, src, c.xortbl) }
func (c *simpleXORBlockCrypt) Decrypt(dst, src []byte) { xor.Bytes(dst, src, c.xortbl) }
type noneBlockCrypt struct{}
// NewNoneBlockCrypt does nothing but copying
func NewNoneBlockCrypt(key []byte) (BlockCrypt, error) {
return new(noneBlockCrypt), nil
}
func (c *noneBlockCrypt) Encrypt(dst, src []byte) { copy(dst, src) }
func (c *noneBlockCrypt) Decrypt(dst, src []byte) { copy(dst, src) }
// packet encryption with local CFB mode
func encrypt(block cipher.Block, dst, src, buf []byte) {
switch block.BlockSize() {
case 8:
encrypt8(block, dst, src, buf)
case 16:
encrypt16(block, dst, src, buf)
default:
panic("unsupported cipher block size")
}
}
// optimized encryption for the ciphers which works in 8-bytes
func encrypt8(block cipher.Block, dst, src, buf []byte) {
tbl := buf[:8]
block.Encrypt(tbl, initialVector)
n := len(src) / 8
base := 0
repeat := n / 8
left := n % 8
ptr_tbl := (*uint64)(unsafe.Pointer(&tbl[0]))
for i := 0; i < repeat; i++ {
s := src[base:][0:64]
d := dst[base:][0:64]
// 1
*(*uint64)(unsafe.Pointer(&d[0])) = *(*uint64)(unsafe.Pointer(&s[0])) ^ *ptr_tbl
block.Encrypt(tbl, d[0:8])
// 2
*(*uint64)(unsafe.Pointer(&d[8])) = *(*uint64)(unsafe.Pointer(&s[8])) ^ *ptr_tbl
block.Encrypt(tbl, d[8:16])
// 3
*(*uint64)(unsafe.Pointer(&d[16])) = *(*uint64)(unsafe.Pointer(&s[16])) ^ *ptr_tbl
block.Encrypt(tbl, d[16:24])
// 4
*(*uint64)(unsafe.Pointer(&d[24])) = *(*uint64)(unsafe.Pointer(&s[24])) ^ *ptr_tbl
block.Encrypt(tbl, d[24:32])
// 5
*(*uint64)(unsafe.Pointer(&d[32])) = *(*uint64)(unsafe.Pointer(&s[32])) ^ *ptr_tbl
block.Encrypt(tbl, d[32:40])
// 6
*(*uint64)(unsafe.Pointer(&d[40])) = *(*uint64)(unsafe.Pointer(&s[40])) ^ *ptr_tbl
block.Encrypt(tbl, d[40:48])
// 7
*(*uint64)(unsafe.Pointer(&d[48])) = *(*uint64)(unsafe.Pointer(&s[48])) ^ *ptr_tbl
block.Encrypt(tbl, d[48:56])
// 8
*(*uint64)(unsafe.Pointer(&d[56])) = *(*uint64)(unsafe.Pointer(&s[56])) ^ *ptr_tbl
block.Encrypt(tbl, d[56:64])
base += 64
}
switch left {
case 7:
*(*uint64)(unsafe.Pointer(&dst[base])) = *(*uint64)(unsafe.Pointer(&src[base])) ^ *ptr_tbl
block.Encrypt(tbl, dst[base:])
base += 8
fallthrough
case 6:
*(*uint64)(unsafe.Pointer(&dst[base])) = *(*uint64)(unsafe.Pointer(&src[base])) ^ *ptr_tbl
block.Encrypt(tbl, dst[base:])
base += 8
fallthrough
case 5:
*(*uint64)(unsafe.Pointer(&dst[base])) = *(*uint64)(unsafe.Pointer(&src[base])) ^ *ptr_tbl
block.Encrypt(tbl, dst[base:])
base += 8
fallthrough
case 4:
*(*uint64)(unsafe.Pointer(&dst[base])) = *(*uint64)(unsafe.Pointer(&src[base])) ^ *ptr_tbl
block.Encrypt(tbl, dst[base:])
base += 8
fallthrough
case 3:
*(*uint64)(unsafe.Pointer(&dst[base])) = *(*uint64)(unsafe.Pointer(&src[base])) ^ *ptr_tbl
block.Encrypt(tbl, dst[base:])
base += 8
fallthrough
case 2:
*(*uint64)(unsafe.Pointer(&dst[base])) = *(*uint64)(unsafe.Pointer(&src[base])) ^ *ptr_tbl
block.Encrypt(tbl, dst[base:])
base += 8
fallthrough
case 1:
*(*uint64)(unsafe.Pointer(&dst[base])) = *(*uint64)(unsafe.Pointer(&src[base])) ^ *ptr_tbl
block.Encrypt(tbl, dst[base:])
base += 8
fallthrough
case 0:
xorBytes(dst[base:], src[base:], tbl)
}
}
// optimized encryption for the ciphers which works in 16-bytes
func encrypt16(block cipher.Block, dst, src, buf []byte) {
tbl := buf[:16]
block.Encrypt(tbl, initialVector)
n := len(src) / 16
base := 0
repeat := n / 8
left := n % 8
for i := 0; i < repeat; i++ {
s := src[base:][0:128]
d := dst[base:][0:128]
// 1
xor.Bytes16Align(d[0:16], s[0:16], tbl)
block.Encrypt(tbl, d[0:16])
// 2
xor.Bytes16Align(d[16:32], s[16:32], tbl)
block.Encrypt(tbl, d[16:32])
// 3
xor.Bytes16Align(d[32:48], s[32:48], tbl)
block.Encrypt(tbl, d[32:48])
// 4
xor.Bytes16Align(d[48:64], s[48:64], tbl)
block.Encrypt(tbl, d[48:64])
// 5
xor.Bytes16Align(d[64:80], s[64:80], tbl)
block.Encrypt(tbl, d[64:80])
// 6
xor.Bytes16Align(d[80:96], s[80:96], tbl)
block.Encrypt(tbl, d[80:96])
// 7
xor.Bytes16Align(d[96:112], s[96:112], tbl)
block.Encrypt(tbl, d[96:112])
// 8
xor.Bytes16Align(d[112:128], s[112:128], tbl)
block.Encrypt(tbl, d[112:128])
base += 128
}
switch left {
case 7:
xor.Bytes16Align(dst[base:], src[base:], tbl)
block.Encrypt(tbl, dst[base:])
base += 16
fallthrough
case 6:
xor.Bytes16Align(dst[base:], src[base:], tbl)
block.Encrypt(tbl, dst[base:])
base += 16
fallthrough
case 5:
xor.Bytes16Align(dst[base:], src[base:], tbl)
block.Encrypt(tbl, dst[base:])
base += 16
fallthrough
case 4:
xor.Bytes16Align(dst[base:], src[base:], tbl)
block.Encrypt(tbl, dst[base:])
base += 16
fallthrough
case 3:
xor.Bytes16Align(dst[base:], src[base:], tbl)
block.Encrypt(tbl, dst[base:])
base += 16
fallthrough
case 2:
xor.Bytes16Align(dst[base:], src[base:], tbl)
block.Encrypt(tbl, dst[base:])
base += 16
fallthrough
case 1:
xor.Bytes16Align(dst[base:], src[base:], tbl)
block.Encrypt(tbl, dst[base:])
base += 16
fallthrough
case 0:
xorBytes(dst[base:], src[base:], tbl)
}
}
// decryption
func decrypt(block cipher.Block, dst, src, buf []byte) {
switch block.BlockSize() {
case 8:
decrypt8(block, dst, src, buf)
case 16:
decrypt16(block, dst, src, buf)
default:
panic("unsupported cipher block size")
}
}
// decrypt 8 bytes block, all byte slices are supposed to be 64bit aligned
func decrypt8(block cipher.Block, dst, src, buf []byte) {
tbl := buf[0:8]
next := buf[8:16]
block.Encrypt(tbl, initialVector)
n := len(src) / 8
base := 0
repeat := n / 8
left := n % 8
ptr_tbl := (*uint64)(unsafe.Pointer(&tbl[0]))
ptr_next := (*uint64)(unsafe.Pointer(&next[0]))
for i := 0; i < repeat; i++ {
s := src[base:][0:64]
d := dst[base:][0:64]
// 1
block.Encrypt(next, s[0:8])
*(*uint64)(unsafe.Pointer(&d[0])) = *(*uint64)(unsafe.Pointer(&s[0])) ^ *ptr_tbl
// 2
block.Encrypt(tbl, s[8:16])
*(*uint64)(unsafe.Pointer(&d[8])) = *(*uint64)(unsafe.Pointer(&s[8])) ^ *ptr_next
// 3
block.Encrypt(next, s[16:24])
*(*uint64)(unsafe.Pointer(&d[16])) = *(*uint64)(unsafe.Pointer(&s[16])) ^ *ptr_tbl
// 4
block.Encrypt(tbl, s[24:32])
*(*uint64)(unsafe.Pointer(&d[24])) = *(*uint64)(unsafe.Pointer(&s[24])) ^ *ptr_next
// 5
block.Encrypt(next, s[32:40])
*(*uint64)(unsafe.Pointer(&d[32])) = *(*uint64)(unsafe.Pointer(&s[32])) ^ *ptr_tbl
// 6
block.Encrypt(tbl, s[40:48])
*(*uint64)(unsafe.Pointer(&d[40])) = *(*uint64)(unsafe.Pointer(&s[40])) ^ *ptr_next
// 7
block.Encrypt(next, s[48:56])
*(*uint64)(unsafe.Pointer(&d[48])) = *(*uint64)(unsafe.Pointer(&s[48])) ^ *ptr_tbl
// 8
block.Encrypt(tbl, s[56:64])
*(*uint64)(unsafe.Pointer(&d[56])) = *(*uint64)(unsafe.Pointer(&s[56])) ^ *ptr_next
base += 64
}
switch left {
case 7:
block.Encrypt(next, src[base:])
*(*uint64)(unsafe.Pointer(&dst[base])) = *(*uint64)(unsafe.Pointer(&src[base])) ^ *(*uint64)(unsafe.Pointer(&tbl[0]))
tbl, next = next, tbl
base += 8
fallthrough
case 6:
block.Encrypt(next, src[base:])
*(*uint64)(unsafe.Pointer(&dst[base])) = *(*uint64)(unsafe.Pointer(&src[base])) ^ *(*uint64)(unsafe.Pointer(&tbl[0]))
tbl, next = next, tbl
base += 8
fallthrough
case 5:
block.Encrypt(next, src[base:])
*(*uint64)(unsafe.Pointer(&dst[base])) = *(*uint64)(unsafe.Pointer(&src[base])) ^ *(*uint64)(unsafe.Pointer(&tbl[0]))
tbl, next = next, tbl
base += 8
fallthrough
case 4:
block.Encrypt(next, src[base:])
*(*uint64)(unsafe.Pointer(&dst[base])) = *(*uint64)(unsafe.Pointer(&src[base])) ^ *(*uint64)(unsafe.Pointer(&tbl[0]))
tbl, next = next, tbl
base += 8
fallthrough
case 3:
block.Encrypt(next, src[base:])
*(*uint64)(unsafe.Pointer(&dst[base])) = *(*uint64)(unsafe.Pointer(&src[base])) ^ *(*uint64)(unsafe.Pointer(&tbl[0]))
tbl, next = next, tbl
base += 8
fallthrough
case 2:
block.Encrypt(next, src[base:])
*(*uint64)(unsafe.Pointer(&dst[base])) = *(*uint64)(unsafe.Pointer(&src[base])) ^ *(*uint64)(unsafe.Pointer(&tbl[0]))
tbl, next = next, tbl
base += 8
fallthrough
case 1:
block.Encrypt(next, src[base:])
*(*uint64)(unsafe.Pointer(&dst[base])) = *(*uint64)(unsafe.Pointer(&src[base])) ^ *(*uint64)(unsafe.Pointer(&tbl[0]))
tbl, next = next, tbl
base += 8
fallthrough
case 0:
xorBytes(dst[base:], src[base:], tbl)
}
}
func decrypt16(block cipher.Block, dst, src, buf []byte) {
tbl := buf[0:16]
next := buf[16:32]
block.Encrypt(tbl, initialVector)
n := len(src) / 16
base := 0
repeat := n / 8
left := n % 8
for i := 0; i < repeat; i++ {
s := src[base:][0:128]
d := dst[base:][0:128]
// 1
block.Encrypt(next, s[0:16])
xor.Bytes16Align(d[0:16], s[0:16], tbl)
// 2
block.Encrypt(tbl, s[16:32])
xor.Bytes16Align(d[16:32], s[16:32], next)
// 3
block.Encrypt(next, s[32:48])
xor.Bytes16Align(d[32:48], s[32:48], tbl)
// 4
block.Encrypt(tbl, s[48:64])
xor.Bytes16Align(d[48:64], s[48:64], next)
// 5
block.Encrypt(next, s[64:80])
xor.Bytes16Align(d[64:80], s[64:80], tbl)
// 6
block.Encrypt(tbl, s[80:96])
xor.Bytes16Align(d[80:96], s[80:96], next)
// 7
block.Encrypt(next, s[96:112])
xor.Bytes16Align(d[96:112], s[96:112], tbl)
// 8
block.Encrypt(tbl, s[112:128])
xor.Bytes16Align(d[112:128], s[112:128], next)
base += 128
}
switch left {
case 7:
block.Encrypt(next, src[base:])
xor.Bytes16Align(dst[base:], src[base:], tbl)
tbl, next = next, tbl
base += 16
fallthrough
case 6:
block.Encrypt(next, src[base:])
xor.Bytes16Align(dst[base:], src[base:], tbl)
tbl, next = next, tbl
base += 16
fallthrough
case 5:
block.Encrypt(next, src[base:])
xor.Bytes16Align(dst[base:], src[base:], tbl)
tbl, next = next, tbl
base += 16
fallthrough
case 4:
block.Encrypt(next, src[base:])
xor.Bytes16Align(dst[base:], src[base:], tbl)
tbl, next = next, tbl
base += 16
fallthrough
case 3:
block.Encrypt(next, src[base:])
xor.Bytes16Align(dst[base:], src[base:], tbl)
tbl, next = next, tbl
base += 16
fallthrough
case 2:
block.Encrypt(next, src[base:])
xor.Bytes16Align(dst[base:], src[base:], tbl)
tbl, next = next, tbl
base += 16
fallthrough
case 1:
block.Encrypt(next, src[base:])
xor.Bytes16Align(dst[base:], src[base:], tbl)
tbl, next = next, tbl
base += 16
fallthrough
case 0:
xorBytes(dst[base:], src[base:], tbl)
}
}
// per bytes xors
func xorBytes(dst, a, b []byte) int {
n := len(a)
if len(b) < n {
n = len(b)
}
if n == 0 {
return 0
}
for i := 0; i < n; i++ {
dst[i] = a[i] ^ b[i]
}
return n
}

View File

@@ -1,289 +0,0 @@
package kcp
import (
"bytes"
"crypto/aes"
"crypto/md5"
"crypto/rand"
"crypto/sha1"
"hash/crc32"
"io"
"testing"
)
func TestSM4(t *testing.T) {
bc, err := NewSM4BlockCrypt(pass[:16])
if err != nil {
t.Fatal(err)
}
cryptTest(t, bc)
}
func TestAES(t *testing.T) {
bc, err := NewAESBlockCrypt(pass[:32])
if err != nil {
t.Fatal(err)
}
cryptTest(t, bc)
}
func TestTEA(t *testing.T) {
bc, err := NewTEABlockCrypt(pass[:16])
if err != nil {
t.Fatal(err)
}
cryptTest(t, bc)
}
func TestXOR(t *testing.T) {
bc, err := NewSimpleXORBlockCrypt(pass[:32])
if err != nil {
t.Fatal(err)
}
cryptTest(t, bc)
}
func TestBlowfish(t *testing.T) {
bc, err := NewBlowfishBlockCrypt(pass[:32])
if err != nil {
t.Fatal(err)
}
cryptTest(t, bc)
}
func TestNone(t *testing.T) {
bc, err := NewNoneBlockCrypt(pass[:32])
if err != nil {
t.Fatal(err)
}
cryptTest(t, bc)
}
func TestCast5(t *testing.T) {
bc, err := NewCast5BlockCrypt(pass[:16])
if err != nil {
t.Fatal(err)
}
cryptTest(t, bc)
}
func Test3DES(t *testing.T) {
bc, err := NewTripleDESBlockCrypt(pass[:24])
if err != nil {
t.Fatal(err)
}
cryptTest(t, bc)
}
func TestTwofish(t *testing.T) {
bc, err := NewTwofishBlockCrypt(pass[:32])
if err != nil {
t.Fatal(err)
}
cryptTest(t, bc)
}
func TestXTEA(t *testing.T) {
bc, err := NewXTEABlockCrypt(pass[:16])
if err != nil {
t.Fatal(err)
}
cryptTest(t, bc)
}
func TestSalsa20(t *testing.T) {
bc, err := NewSalsa20BlockCrypt(pass[:32])
if err != nil {
t.Fatal(err)
}
cryptTest(t, bc)
}
func cryptTest(t *testing.T, bc BlockCrypt) {
data := make([]byte, mtuLimit)
io.ReadFull(rand.Reader, data)
dec := make([]byte, mtuLimit)
enc := make([]byte, mtuLimit)
bc.Encrypt(enc, data)
bc.Decrypt(dec, enc)
if !bytes.Equal(data, dec) {
t.Fail()
}
}
func BenchmarkSM4(b *testing.B) {
bc, err := NewSM4BlockCrypt(pass[:16])
if err != nil {
b.Fatal(err)
}
benchCrypt(b, bc)
}
func BenchmarkAES128(b *testing.B) {
bc, err := NewAESBlockCrypt(pass[:16])
if err != nil {
b.Fatal(err)
}
benchCrypt(b, bc)
}
func BenchmarkAES192(b *testing.B) {
bc, err := NewAESBlockCrypt(pass[:24])
if err != nil {
b.Fatal(err)
}
benchCrypt(b, bc)
}
func BenchmarkAES256(b *testing.B) {
bc, err := NewAESBlockCrypt(pass[:32])
if err != nil {
b.Fatal(err)
}
benchCrypt(b, bc)
}
func BenchmarkTEA(b *testing.B) {
bc, err := NewTEABlockCrypt(pass[:16])
if err != nil {
b.Fatal(err)
}
benchCrypt(b, bc)
}
func BenchmarkXOR(b *testing.B) {
bc, err := NewSimpleXORBlockCrypt(pass[:32])
if err != nil {
b.Fatal(err)
}
benchCrypt(b, bc)
}
func BenchmarkBlowfish(b *testing.B) {
bc, err := NewBlowfishBlockCrypt(pass[:32])
if err != nil {
b.Fatal(err)
}
benchCrypt(b, bc)
}
func BenchmarkNone(b *testing.B) {
bc, err := NewNoneBlockCrypt(pass[:32])
if err != nil {
b.Fatal(err)
}
benchCrypt(b, bc)
}
func BenchmarkCast5(b *testing.B) {
bc, err := NewCast5BlockCrypt(pass[:16])
if err != nil {
b.Fatal(err)
}
benchCrypt(b, bc)
}
func Benchmark3DES(b *testing.B) {
bc, err := NewTripleDESBlockCrypt(pass[:24])
if err != nil {
b.Fatal(err)
}
benchCrypt(b, bc)
}
func BenchmarkTwofish(b *testing.B) {
bc, err := NewTwofishBlockCrypt(pass[:32])
if err != nil {
b.Fatal(err)
}
benchCrypt(b, bc)
}
func BenchmarkXTEA(b *testing.B) {
bc, err := NewXTEABlockCrypt(pass[:16])
if err != nil {
b.Fatal(err)
}
benchCrypt(b, bc)
}
func BenchmarkSalsa20(b *testing.B) {
bc, err := NewSalsa20BlockCrypt(pass[:32])
if err != nil {
b.Fatal(err)
}
benchCrypt(b, bc)
}
func benchCrypt(b *testing.B, bc BlockCrypt) {
data := make([]byte, mtuLimit)
io.ReadFull(rand.Reader, data)
dec := make([]byte, mtuLimit)
enc := make([]byte, mtuLimit)
b.ReportAllocs()
b.SetBytes(int64(len(enc) * 2))
b.ResetTimer()
for i := 0; i < b.N; i++ {
bc.Encrypt(enc, data)
bc.Decrypt(dec, enc)
}
}
func BenchmarkCRC32(b *testing.B) {
content := make([]byte, 1024)
b.SetBytes(int64(len(content)))
for i := 0; i < b.N; i++ {
crc32.ChecksumIEEE(content)
}
}
func BenchmarkCsprngSystem(b *testing.B) {
data := make([]byte, md5.Size)
b.SetBytes(int64(len(data)))
for i := 0; i < b.N; i++ {
io.ReadFull(rand.Reader, data)
}
}
func BenchmarkCsprngMD5(b *testing.B) {
var data [md5.Size]byte
b.SetBytes(md5.Size)
for i := 0; i < b.N; i++ {
data = md5.Sum(data[:])
}
}
func BenchmarkCsprngSHA1(b *testing.B) {
var data [sha1.Size]byte
b.SetBytes(sha1.Size)
for i := 0; i < b.N; i++ {
data = sha1.Sum(data[:])
}
}
func BenchmarkCsprngNonceMD5(b *testing.B) {
var ng nonceMD5
ng.Init()
b.SetBytes(md5.Size)
data := make([]byte, md5.Size)
for i := 0; i < b.N; i++ {
ng.Fill(data)
}
}
func BenchmarkCsprngNonceAES128(b *testing.B) {
var ng nonceAES128
ng.Init()
b.SetBytes(aes.BlockSize)
data := make([]byte, aes.BlockSize)
for i := 0; i < b.N; i++ {
ng.Fill(data)
}
}

View File

@@ -1,52 +0,0 @@
package kcp
import (
"crypto/aes"
"crypto/cipher"
"crypto/md5"
"crypto/rand"
"io"
)
// Entropy defines a entropy source
type Entropy interface {
Init()
Fill(nonce []byte)
}
// nonceMD5 nonce generator for packet header
type nonceMD5 struct {
seed [md5.Size]byte
}
func (n *nonceMD5) Init() { /*nothing required*/ }
func (n *nonceMD5) Fill(nonce []byte) {
if n.seed[0] == 0 { // entropy update
io.ReadFull(rand.Reader, n.seed[:])
}
n.seed = md5.Sum(n.seed[:])
copy(nonce, n.seed[:])
}
// nonceAES128 nonce generator for packet headers
type nonceAES128 struct {
seed [aes.BlockSize]byte
block cipher.Block
}
func (n *nonceAES128) Init() {
var key [16]byte // aes-128
io.ReadFull(rand.Reader, key[:])
io.ReadFull(rand.Reader, n.seed[:])
block, _ := aes.NewCipher(key[:])
n.block = block
}
func (n *nonceAES128) Fill(nonce []byte) {
if n.seed[0] == 0 { // entropy update
io.ReadFull(rand.Reader, n.seed[:])
}
n.block.Encrypt(n.seed[:], n.seed[:])
copy(nonce, n.seed[:])
}

View File

@@ -1,381 +0,0 @@
package kcp
import (
"encoding/binary"
"sync/atomic"
"github.com/klauspost/reedsolomon"
)
const (
fecHeaderSize = 6
fecHeaderSizePlus2 = fecHeaderSize + 2 // plus 2B data size
typeData = 0xf1
typeParity = 0xf2
fecExpire = 60000
rxFECMulti = 3 // FEC keeps rxFECMulti* (dataShard+parityShard) ordered packets in memory
)
// fecPacket is a decoded FEC packet
type fecPacket []byte
func (bts fecPacket) seqid() uint32 { return binary.LittleEndian.Uint32(bts) }
func (bts fecPacket) flag() uint16 { return binary.LittleEndian.Uint16(bts[4:]) }
func (bts fecPacket) data() []byte { return bts[6:] }
// fecElement has auxcilliary time field
type fecElement struct {
fecPacket
ts uint32
}
// fecDecoder for decoding incoming packets
type fecDecoder struct {
rxlimit int // queue size limit
dataShards int
parityShards int
shardSize int
rx []fecElement // ordered receive queue
// caches
decodeCache [][]byte
flagCache []bool
// zeros
zeros []byte
// RS decoder
codec reedsolomon.Encoder
// auto tune fec parameter
autoTune autoTune
}
func newFECDecoder(dataShards, parityShards int) *fecDecoder {
if dataShards <= 0 || parityShards <= 0 {
return nil
}
dec := new(fecDecoder)
dec.dataShards = dataShards
dec.parityShards = parityShards
dec.shardSize = dataShards + parityShards
dec.rxlimit = rxFECMulti * dec.shardSize
codec, err := reedsolomon.New(dataShards, parityShards)
if err != nil {
return nil
}
dec.codec = codec
dec.decodeCache = make([][]byte, dec.shardSize)
dec.flagCache = make([]bool, dec.shardSize)
dec.zeros = make([]byte, mtuLimit)
return dec
}
// decode a fec packet
func (dec *fecDecoder) decode(in fecPacket) (recovered [][]byte) {
// sample to auto FEC tuner
if in.flag() == typeData {
dec.autoTune.Sample(true, in.seqid())
} else {
dec.autoTune.Sample(false, in.seqid())
}
// check if FEC parameters is out of sync
var shouldTune bool
if int(in.seqid())%dec.shardSize < dec.dataShards {
if in.flag() != typeData { // expect typeData
shouldTune = true
}
} else {
if in.flag() != typeParity {
shouldTune = true
}
}
if shouldTune {
autoDS := dec.autoTune.FindPeriod(true)
autoPS := dec.autoTune.FindPeriod(false)
// edges found, we can tune parameters now
if autoDS > 0 && autoPS > 0 && autoDS < 256 && autoPS < 256 {
// and make sure it's different
if autoDS != dec.dataShards || autoPS != dec.parityShards {
dec.dataShards = autoDS
dec.parityShards = autoPS
dec.shardSize = autoDS + autoPS
dec.rxlimit = rxFECMulti * dec.shardSize
codec, err := reedsolomon.New(autoDS, autoPS)
if err != nil {
return nil
}
dec.codec = codec
dec.decodeCache = make([][]byte, dec.shardSize)
dec.flagCache = make([]bool, dec.shardSize)
// log.Println("autotune to :", dec.dataShards, dec.parityShards)
}
}
}
// insertion
n := len(dec.rx) - 1
insertIdx := 0
for i := n; i >= 0; i-- {
if in.seqid() == dec.rx[i].seqid() { // de-duplicate
return nil
} else if _itimediff(in.seqid(), dec.rx[i].seqid()) > 0 { // insertion
insertIdx = i + 1
break
}
}
// make a copy
pkt := fecPacket(xmitBuf.Get().([]byte)[:len(in)])
copy(pkt, in)
elem := fecElement{pkt, currentMs()}
// insert into ordered rx queue
if insertIdx == n+1 {
dec.rx = append(dec.rx, elem)
} else {
dec.rx = append(dec.rx, fecElement{})
copy(dec.rx[insertIdx+1:], dec.rx[insertIdx:]) // shift right
dec.rx[insertIdx] = elem
}
// shard range for current packet
shardBegin := pkt.seqid() - pkt.seqid()%uint32(dec.shardSize)
shardEnd := shardBegin + uint32(dec.shardSize) - 1
// max search range in ordered queue for current shard
searchBegin := insertIdx - int(pkt.seqid()%uint32(dec.shardSize))
if searchBegin < 0 {
searchBegin = 0
}
searchEnd := searchBegin + dec.shardSize - 1
if searchEnd >= len(dec.rx) {
searchEnd = len(dec.rx) - 1
}
// re-construct datashards
if searchEnd-searchBegin+1 >= dec.dataShards {
var numshard, numDataShard, first, maxlen int
// zero caches
shards := dec.decodeCache
shardsflag := dec.flagCache
for k := range dec.decodeCache {
shards[k] = nil
shardsflag[k] = false
}
// shard assembly
for i := searchBegin; i <= searchEnd; i++ {
seqid := dec.rx[i].seqid()
if _itimediff(seqid, shardEnd) > 0 {
break
} else if _itimediff(seqid, shardBegin) >= 0 {
shards[seqid%uint32(dec.shardSize)] = dec.rx[i].data()
shardsflag[seqid%uint32(dec.shardSize)] = true
numshard++
if dec.rx[i].flag() == typeData {
numDataShard++
}
if numshard == 1 {
first = i
}
if len(dec.rx[i].data()) > maxlen {
maxlen = len(dec.rx[i].data())
}
}
}
if numDataShard == dec.dataShards {
// case 1: no loss on data shards
dec.rx = dec.freeRange(first, numshard, dec.rx)
} else if numshard >= dec.dataShards {
// case 2: loss on data shards, but it's recoverable from parity shards
for k := range shards {
if shards[k] != nil {
dlen := len(shards[k])
shards[k] = shards[k][:maxlen]
copy(shards[k][dlen:], dec.zeros)
} else if k < dec.dataShards {
shards[k] = xmitBuf.Get().([]byte)[:0]
}
}
if err := dec.codec.ReconstructData(shards); err == nil {
for k := range shards[:dec.dataShards] {
if !shardsflag[k] {
// recovered data should be recycled
recovered = append(recovered, shards[k])
}
}
}
dec.rx = dec.freeRange(first, numshard, dec.rx)
}
}
// keep rxlimit
if len(dec.rx) > dec.rxlimit {
if dec.rx[0].flag() == typeData { // track the unrecoverable data
atomic.AddUint64(&DefaultSnmp.FECShortShards, 1)
}
dec.rx = dec.freeRange(0, 1, dec.rx)
}
// timeout policy
current := currentMs()
numExpired := 0
for k := range dec.rx {
if _itimediff(current, dec.rx[k].ts) > fecExpire {
numExpired++
continue
}
break
}
if numExpired > 0 {
dec.rx = dec.freeRange(0, numExpired, dec.rx)
}
return
}
// free a range of fecPacket
func (dec *fecDecoder) freeRange(first, n int, q []fecElement) []fecElement {
for i := first; i < first+n; i++ { // recycle buffer
xmitBuf.Put([]byte(q[i].fecPacket))
}
if first == 0 && n < cap(q)/2 {
return q[n:]
}
copy(q[first:], q[first+n:])
return q[:len(q)-n]
}
// release all segments back to xmitBuf
func (dec *fecDecoder) release() {
if n := len(dec.rx); n > 0 {
dec.rx = dec.freeRange(0, n, dec.rx)
}
}
type (
// fecEncoder for encoding outgoing packets
fecEncoder struct {
dataShards int
parityShards int
shardSize int
paws uint32 // Protect Against Wrapped Sequence numbers
next uint32 // next seqid
shardCount int // count the number of datashards collected
maxSize int // track maximum data length in datashard
headerOffset int // FEC header offset
payloadOffset int // FEC payload offset
// caches
shardCache [][]byte
encodeCache [][]byte
// zeros
zeros []byte
// RS encoder
codec reedsolomon.Encoder
}
)
func newFECEncoder(dataShards, parityShards, offset int) *fecEncoder {
if dataShards <= 0 || parityShards <= 0 {
return nil
}
enc := new(fecEncoder)
enc.dataShards = dataShards
enc.parityShards = parityShards
enc.shardSize = dataShards + parityShards
enc.paws = 0xffffffff / uint32(enc.shardSize) * uint32(enc.shardSize)
enc.headerOffset = offset
enc.payloadOffset = enc.headerOffset + fecHeaderSize
codec, err := reedsolomon.New(dataShards, parityShards)
if err != nil {
return nil
}
enc.codec = codec
// caches
enc.encodeCache = make([][]byte, enc.shardSize)
enc.shardCache = make([][]byte, enc.shardSize)
for k := range enc.shardCache {
enc.shardCache[k] = make([]byte, mtuLimit)
}
enc.zeros = make([]byte, mtuLimit)
return enc
}
// encodes the packet, outputs parity shards if we have collected quorum datashards
// notice: the contents of 'ps' will be re-written in successive calling
func (enc *fecEncoder) encode(b []byte) (ps [][]byte) {
// The header format:
// | FEC SEQID(4B) | FEC TYPE(2B) | SIZE (2B) | PAYLOAD(SIZE-2) |
// |<-headerOffset |<-payloadOffset
enc.markData(b[enc.headerOffset:])
binary.LittleEndian.PutUint16(b[enc.payloadOffset:], uint16(len(b[enc.payloadOffset:])))
// copy data from payloadOffset to fec shard cache
sz := len(b)
enc.shardCache[enc.shardCount] = enc.shardCache[enc.shardCount][:sz]
copy(enc.shardCache[enc.shardCount][enc.payloadOffset:], b[enc.payloadOffset:])
enc.shardCount++
// track max datashard length
if sz > enc.maxSize {
enc.maxSize = sz
}
// Generation of Reed-Solomon Erasure Code
if enc.shardCount == enc.dataShards {
// fill '0' into the tail of each datashard
for i := 0; i < enc.dataShards; i++ {
shard := enc.shardCache[i]
slen := len(shard)
copy(shard[slen:enc.maxSize], enc.zeros)
}
// construct equal-sized slice with stripped header
cache := enc.encodeCache
for k := range cache {
cache[k] = enc.shardCache[k][enc.payloadOffset:enc.maxSize]
}
// encoding
if err := enc.codec.Encode(cache); err == nil {
ps = enc.shardCache[enc.dataShards:]
for k := range ps {
enc.markParity(ps[k][enc.headerOffset:])
ps[k] = ps[k][:enc.maxSize]
}
}
// counters resetting
enc.shardCount = 0
enc.maxSize = 0
}
return
}
func (enc *fecEncoder) markData(data []byte) {
binary.LittleEndian.PutUint32(data, enc.next)
binary.LittleEndian.PutUint16(data[4:], typeData)
enc.next++
}
func (enc *fecEncoder) markParity(data []byte) {
binary.LittleEndian.PutUint32(data, enc.next)
binary.LittleEndian.PutUint16(data[4:], typeParity)
// sequence wrap will only happen at parity shard
enc.next = (enc.next + 1) % enc.paws
}

View File

@@ -1,43 +0,0 @@
package kcp
import (
"encoding/binary"
"math/rand"
"testing"
)
func BenchmarkFECDecode(b *testing.B) {
const dataSize = 10
const paritySize = 3
const payLoad = 1500
decoder := newFECDecoder(dataSize, paritySize)
b.ReportAllocs()
b.SetBytes(payLoad)
for i := 0; i < b.N; i++ {
if rand.Int()%(dataSize+paritySize) == 0 { // random loss
continue
}
pkt := make([]byte, payLoad)
binary.LittleEndian.PutUint32(pkt, uint32(i))
if i%(dataSize+paritySize) >= dataSize {
binary.LittleEndian.PutUint16(pkt[4:], typeParity)
} else {
binary.LittleEndian.PutUint16(pkt[4:], typeData)
}
decoder.decode(pkt)
}
}
func BenchmarkFECEncode(b *testing.B) {
const dataSize = 10
const paritySize = 3
const payLoad = 1500
b.ReportAllocs()
b.SetBytes(payLoad)
encoder := newFECEncoder(dataSize, paritySize, 0)
for i := 0; i < b.N; i++ {
data := make([]byte, payLoad)
encoder.encode(data)
}
}

View File

@@ -74,8 +74,8 @@ func TestLossyConn4(t *testing.T) {
func testlink(t *testing.T, client *lossyconn.LossyConn, server *lossyconn.LossyConn, nodelay, interval, resend, nc int) {
t.Log("testing with nodelay parameters:", nodelay, interval, resend, nc)
sess, _ := NewConn2(server.LocalAddr(), nil, 0, 0, client)
listener, _ := ServeConn(nil, 0, 0, server)
sess, _ := NewConn2(server.LocalAddr(), client)
listener, _ := ServeConn(server)
echoServer := func(l *Listener) {
for {
conn, err := l.AcceptKCP()

View File

@@ -1,12 +0,0 @@
//go:build !linux
// +build !linux
package kcp
func (s *UDPSession) readLoop() {
s.defaultReadLoop()
}
func (l *Listener) monitor() {
l.defaultMonitor()
}

View File

@@ -11,7 +11,6 @@ import (
"crypto/rand"
"encoding/binary"
"encoding/hex"
"hash/crc32"
"io"
"net"
"sync"
@@ -65,16 +64,16 @@ type (
ownConn bool // true if we created conn internally, false if provided by caller
kcp *KCP // KCP ARQ protocol
l *Listener // pointing to the Listener object if it's been accepted by a Listener
block BlockCrypt // block encryption object
// block BlockCrypt // block encryption object
// kcp receiving is based on packets
// recvbuf turns packets into stream
recvbuf []byte
bufptr []byte
// FEC codec
fecDecoder *fecDecoder
fecEncoder *fecEncoder
// // FEC codec
// fecDecoder *fecDecoder
// fecEncoder *fecEncoder
// settings
remote net.Addr // remote peer address
@@ -99,8 +98,8 @@ type (
socketReadErrorOnce sync.Once
socketWriteErrorOnce sync.Once
// nonce generator
nonce Entropy
// // nonce generator
// nonce Entropy
// packets waiting to be sent on wire
txqueue []ipv4.Message
@@ -124,11 +123,11 @@ type (
)
// newUDPSession create a new udp session for client or server
func newUDPSession(conv uint64, dataShards, parityShards int, l *Listener, conn net.PacketConn, ownConn bool, remote net.Addr, block BlockCrypt) *UDPSession {
func newUDPSession(conv uint64, l *Listener, conn net.PacketConn, ownConn bool, remote net.Addr) *UDPSession {
sess := new(UDPSession)
sess.die = make(chan struct{})
sess.nonce = new(nonceAES128)
sess.nonce.Init()
// sess.nonce = new(nonceAES128)
// sess.nonce.Init()
sess.chReadEvent = make(chan struct{}, 1)
sess.chWriteEvent = make(chan struct{}, 1)
sess.chSocketReadError = make(chan struct{})
@@ -137,7 +136,7 @@ func newUDPSession(conv uint64, dataShards, parityShards int, l *Listener, conn
sess.conn = conn
sess.ownConn = ownConn
sess.l = l
sess.block = block
// sess.block = block
sess.recvbuf = make([]byte, mtuLimit)
// cast to writebatch conn
@@ -152,21 +151,21 @@ func newUDPSession(conv uint64, dataShards, parityShards int, l *Listener, conn
}
}
// FEC codec initialization
sess.fecDecoder = newFECDecoder(dataShards, parityShards)
if sess.block != nil {
sess.fecEncoder = newFECEncoder(dataShards, parityShards, cryptHeaderSize)
} else {
sess.fecEncoder = newFECEncoder(dataShards, parityShards, 0)
}
// // FEC codec initialization
// sess.fecDecoder = newFECDecoder(dataShards, parityShards)
// if sess.block != nil {
// sess.fecEncoder = newFECEncoder(dataShards, parityShards, cryptHeaderSize)
// } else {
// sess.fecEncoder = newFECEncoder(dataShards, parityShards, 0)
// }
// calculate additional header size introduced by FEC and encryption
if sess.block != nil {
sess.headerSize += cryptHeaderSize
}
if sess.fecEncoder != nil {
sess.headerSize += fecHeaderSizePlus2
}
// // calculate additional header size introduced by FEC and encryption
// if sess.block != nil {
// sess.headerSize += cryptHeaderSize
// }
// if sess.fecEncoder != nil {
// sess.headerSize += fecHeaderSizePlus2
// }
sess.kcp = NewKCP(conv, func(buf []byte, size int) {
if size >= IKCP_OVERHEAD+sess.headerSize {
@@ -371,9 +370,9 @@ func (s *UDPSession) Close() error {
s.uncork()
// release pending segments
s.kcp.ReleaseTX()
if s.fecDecoder != nil {
s.fecDecoder.release()
}
// if s.fecDecoder != nil {
// s.fecDecoder.release()
// }
s.mu.Unlock()
if s.l != nil { // belongs to listener
@@ -552,25 +551,25 @@ func (s *UDPSession) SetWriteBuffer(bytes int) error {
func (s *UDPSession) output(buf []byte) {
var ecc [][]byte
// 1. FEC encoding
if s.fecEncoder != nil {
ecc = s.fecEncoder.encode(buf)
}
// // 1. FEC encoding
// if s.fecEncoder != nil {
// ecc = s.fecEncoder.encode(buf)
// }
// 2&3. crc32 & encryption
if s.block != nil {
s.nonce.Fill(buf[:nonceSize])
checksum := crc32.ChecksumIEEE(buf[cryptHeaderSize:])
binary.LittleEndian.PutUint32(buf[nonceSize:], checksum)
s.block.Encrypt(buf, buf)
for k := range ecc {
s.nonce.Fill(ecc[k][:nonceSize])
checksum := crc32.ChecksumIEEE(ecc[k][cryptHeaderSize:])
binary.LittleEndian.PutUint32(ecc[k][nonceSize:], checksum)
s.block.Encrypt(ecc[k], ecc[k])
}
}
// // 2&3. crc32 & encryption
// if s.block != nil {
// s.nonce.Fill(buf[:nonceSize])
// checksum := crc32.ChecksumIEEE(buf[cryptHeaderSize:])
// binary.LittleEndian.PutUint32(buf[nonceSize:], checksum)
// s.block.Encrypt(buf, buf)
//
// for k := range ecc {
// s.nonce.Fill(ecc[k][:nonceSize])
// checksum := crc32.ChecksumIEEE(ecc[k][cryptHeaderSize:])
// binary.LittleEndian.PutUint32(ecc[k][nonceSize:], checksum)
// s.block.Encrypt(ecc[k], ecc[k])
// }
// }
// 4. TxQueue
var msg ipv4.Message
@@ -663,114 +662,130 @@ func (s *UDPSession) notifyWriteError(err error) {
// packet input stage
func (s *UDPSession) packetInput(data []byte) {
decrypted := false
if s.block != nil && len(data) >= cryptHeaderSize {
s.block.Decrypt(data, data)
data = data[nonceSize:]
checksum := crc32.ChecksumIEEE(data[crcSize:])
if checksum == binary.LittleEndian.Uint32(data) {
data = data[crcSize:]
decrypted = true
} else {
atomic.AddUint64(&DefaultSnmp.InCsumErrors, 1)
}
} else if s.block == nil {
decrypted = true
}
// decrypted := false
// if s.block != nil && len(data) >= cryptHeaderSize {
// s.block.Decrypt(data, data)
// data = data[nonceSize:]
// checksum := crc32.ChecksumIEEE(data[crcSize:])
// if checksum == binary.LittleEndian.Uint32(data) {
// data = data[crcSize:]
// decrypted = true
// } else {
// atomic.AddUint64(&DefaultSnmp.InCsumErrors, 1)
// }
// } else if s.block == nil {
// decrypted = true
// }
decrypted := true
if decrypted && len(data) >= IKCP_OVERHEAD {
s.kcpInput(data)
}
}
func (s *UDPSession) kcpInput(data []byte) {
var kcpInErrors, fecErrs, fecRecovered, fecParityShards uint64
// var kcpInErrors, fecErrs, fecRecovered, fecParityShards uint64
//
// fecFlag := binary.LittleEndian.Uint16(data[8:])
// if fecFlag == typeData || fecFlag == typeParity { // 16bit kcp cmd [81-84] and frg [0-255] will not overlap with FEC type 0x00f1 0x00f2
// if len(data) >= fecHeaderSizePlus2 {
// f := fecPacket(data)
// if f.flag() == typeParity {
// fecParityShards++
// }
//
// // lock
// s.mu.Lock()
// // if fecDecoder is not initialized, create one with default parameter
// if s.fecDecoder == nil {
// s.fecDecoder = newFECDecoder(1, 1)
// }
// recovers := s.fecDecoder.decode(f)
// if f.flag() == typeData {
// if ret := s.kcp.Input(data[fecHeaderSizePlus2:], true, s.ackNoDelay); ret != 0 {
// kcpInErrors++
// }
// }
//
// for _, r := range recovers {
// if len(r) >= 2 { // must be larger than 2bytes
// sz := binary.LittleEndian.Uint16(r)
// if int(sz) <= len(r) && sz >= 2 {
// if ret := s.kcp.Input(r[2:sz], false, s.ackNoDelay); ret == 0 {
// fecRecovered++
// } else {
// kcpInErrors++
// }
// } else {
// fecErrs++
// }
// } else {
// fecErrs++
// }
// // recycle the recovers
// xmitBuf.Put(r)
// }
//
// // to notify the readers to receive the data
// if n := s.kcp.PeekSize(); n > 0 {
// s.notifyReadEvent()
// }
// // to notify the writers
// waitsnd := s.kcp.WaitSnd()
// if waitsnd < int(s.kcp.snd_wnd) && waitsnd < int(s.kcp.rmt_wnd) {
// s.notifyWriteEvent()
// }
//
// s.uncork()
// s.mu.Unlock()
// } else {
// atomic.AddUint64(&DefaultSnmp.InErrs, 1)
// }
// } else {
// s.mu.Lock()
// if ret := s.kcp.Input(data, true, s.ackNoDelay); ret != 0 {
// kcpInErrors++
// }
// if n := s.kcp.PeekSize(); n > 0 {
// s.notifyReadEvent()
// }
// waitsnd := s.kcp.WaitSnd()
// if waitsnd < int(s.kcp.snd_wnd) && waitsnd < int(s.kcp.rmt_wnd) {
// s.notifyWriteEvent()
// }
// s.uncork()
// s.mu.Unlock()
// }
fecFlag := binary.LittleEndian.Uint16(data[8:])
if fecFlag == typeData || fecFlag == typeParity { // 16bit kcp cmd [81-84] and frg [0-255] will not overlap with FEC type 0x00f1 0x00f2
if len(data) >= fecHeaderSizePlus2 {
f := fecPacket(data)
if f.flag() == typeParity {
fecParityShards++
}
// lock
s.mu.Lock()
// if fecDecoder is not initialized, create one with default parameter
if s.fecDecoder == nil {
s.fecDecoder = newFECDecoder(1, 1)
}
recovers := s.fecDecoder.decode(f)
if f.flag() == typeData {
if ret := s.kcp.Input(data[fecHeaderSizePlus2:], true, s.ackNoDelay); ret != 0 {
kcpInErrors++
}
}
for _, r := range recovers {
if len(r) >= 2 { // must be larger than 2bytes
sz := binary.LittleEndian.Uint16(r)
if int(sz) <= len(r) && sz >= 2 {
if ret := s.kcp.Input(r[2:sz], false, s.ackNoDelay); ret == 0 {
fecRecovered++
} else {
kcpInErrors++
}
} else {
fecErrs++
}
} else {
fecErrs++
}
// recycle the recovers
xmitBuf.Put(r)
}
// to notify the readers to receive the data
if n := s.kcp.PeekSize(); n > 0 {
s.notifyReadEvent()
}
// to notify the writers
waitsnd := s.kcp.WaitSnd()
if waitsnd < int(s.kcp.snd_wnd) && waitsnd < int(s.kcp.rmt_wnd) {
s.notifyWriteEvent()
}
s.uncork()
s.mu.Unlock()
} else {
atomic.AddUint64(&DefaultSnmp.InErrs, 1)
}
} else {
s.mu.Lock()
if ret := s.kcp.Input(data, true, s.ackNoDelay); ret != 0 {
kcpInErrors++
}
if n := s.kcp.PeekSize(); n > 0 {
s.notifyReadEvent()
}
waitsnd := s.kcp.WaitSnd()
if waitsnd < int(s.kcp.snd_wnd) && waitsnd < int(s.kcp.rmt_wnd) {
s.notifyWriteEvent()
}
s.uncork()
s.mu.Unlock()
var kcpInErrors uint64
s.mu.Lock()
if ret := s.kcp.Input(data, true, s.ackNoDelay); ret != 0 {
kcpInErrors++
}
if n := s.kcp.PeekSize(); n > 0 {
s.notifyReadEvent()
}
waitsnd := s.kcp.WaitSnd()
if waitsnd < int(s.kcp.snd_wnd) && waitsnd < int(s.kcp.rmt_wnd) {
s.notifyWriteEvent()
}
s.uncork()
s.mu.Unlock()
atomic.AddUint64(&DefaultSnmp.InPkts, 1)
atomic.AddUint64(&DefaultSnmp.InBytes, uint64(len(data)))
if fecParityShards > 0 {
atomic.AddUint64(&DefaultSnmp.FECParityShards, fecParityShards)
}
if kcpInErrors > 0 {
atomic.AddUint64(&DefaultSnmp.KCPInErrors, kcpInErrors)
}
if fecErrs > 0 {
atomic.AddUint64(&DefaultSnmp.FECErrs, fecErrs)
}
if fecRecovered > 0 {
atomic.AddUint64(&DefaultSnmp.FECRecovered, fecRecovered)
}
// if fecParityShards > 0 {
// atomic.AddUint64(&DefaultSnmp.FECParityShards, fecParityShards)
// }
// if fecErrs > 0 {
// atomic.AddUint64(&DefaultSnmp.FECErrs, fecErrs)
// }
// if fecRecovered > 0 {
// atomic.AddUint64(&DefaultSnmp.FECRecovered, fecRecovered)
// }
}
@@ -828,11 +843,11 @@ const (
type (
// Listener defines a server which will be waiting to accept incoming connections
Listener struct {
block BlockCrypt // block encryption
dataShards int // FEC data shard
parityShards int // FEC parity shard
conn net.PacketConn // the underlying packet connection
ownConn bool // true if we created conn internally, false if provided by caller
// block BlockCrypt // block encryption
// dataShards int // FEC data shard
// parityShards int // FEC parity shard
conn net.PacketConn // the underlying packet connection
ownConn bool // true if we created conn internally, false if provided by caller
// 网络切换会话保持改造 将convId作为会话的唯一标识 不再校验源地址
sessions map[uint64]*UDPSession // all sessions accepted by this Listener
@@ -856,21 +871,22 @@ type (
// packet input stage
func (l *Listener) packetInput(data []byte, addr net.Addr, convId uint64) {
decrypted := false
if l.block != nil && len(data) >= cryptHeaderSize {
l.block.Decrypt(data, data)
data = data[nonceSize:]
checksum := crc32.ChecksumIEEE(data[crcSize:])
if checksum == binary.LittleEndian.Uint32(data) {
data = data[crcSize:]
decrypted = true
} else {
atomic.AddUint64(&DefaultSnmp.InCsumErrors, 1)
}
} else if l.block == nil {
decrypted = true
}
// decrypted := false
// if l.block != nil && len(data) >= cryptHeaderSize {
// l.block.Decrypt(data, data)
// data = data[nonceSize:]
// checksum := crc32.ChecksumIEEE(data[crcSize:])
// if checksum == binary.LittleEndian.Uint32(data) {
// data = data[crcSize:]
// decrypted = true
// } else {
// atomic.AddUint64(&DefaultSnmp.InCsumErrors, 1)
// }
// } else if l.block == nil {
// decrypted = true
// }
decrypted := true
if decrypted && len(data) >= IKCP_OVERHEAD {
l.sessionLock.RLock()
s, ok := l.sessions[convId]
@@ -879,20 +895,25 @@ func (l *Listener) packetInput(data []byte, addr net.Addr, convId uint64) {
var conv uint64
var sn uint32
convRecovered := false
fecFlag := binary.LittleEndian.Uint16(data[8:])
if fecFlag == typeData || fecFlag == typeParity { // 16bit kcp cmd [81-84] and frg [0-255] will not overlap with FEC type 0x00f1 0x00f2
// packet with FEC
if fecFlag == typeData && len(data) >= fecHeaderSizePlus2+IKCP_OVERHEAD {
conv = binary.LittleEndian.Uint64(data[fecHeaderSizePlus2:])
sn = binary.LittleEndian.Uint32(data[fecHeaderSizePlus2+IKCP_SN_OFFSET:])
convRecovered = true
}
} else {
// packet without FEC
conv = binary.LittleEndian.Uint64(data)
sn = binary.LittleEndian.Uint32(data[IKCP_SN_OFFSET:])
convRecovered = true
}
// fecFlag := binary.LittleEndian.Uint16(data[8:])
// if fecFlag == typeData || fecFlag == typeParity { // 16bit kcp cmd [81-84] and frg [0-255] will not overlap with FEC type 0x00f1 0x00f2
// // packet with FEC
// if fecFlag == typeData && len(data) >= fecHeaderSizePlus2+IKCP_OVERHEAD {
// conv = binary.LittleEndian.Uint64(data[fecHeaderSizePlus2:])
// sn = binary.LittleEndian.Uint32(data[fecHeaderSizePlus2+IKCP_SN_OFFSET:])
// convRecovered = true
// }
// } else {
// // packet without FEC
// conv = binary.LittleEndian.Uint64(data)
// sn = binary.LittleEndian.Uint32(data[IKCP_SN_OFFSET:])
// convRecovered = true
// }
// packet without FEC
conv = binary.LittleEndian.Uint64(data)
sn = binary.LittleEndian.Uint32(data[IKCP_SN_OFFSET:])
convRecovered = true
if ok { // existing connection
if !convRecovered || conv == s.kcp.conv { // parity data or valid conversation
@@ -906,7 +927,7 @@ func (l *Listener) packetInput(data []byte, addr net.Addr, convId uint64) {
if s == nil && convRecovered { // new session
if len(l.chAccepts) < cap(l.chAccepts) { // do not let the new sessions overwhelm accept queue
s := newUDPSession(conv, l.dataShards, l.parityShards, l, l.conn, false, addr, l.block)
s := newUDPSession(conv, l, l.conn, false, addr)
s.kcpInput(data)
l.sessionLock.Lock()
l.sessions[convId] = s
@@ -1047,7 +1068,7 @@ func (l *Listener) closeSession(convId uint64) (ret bool) {
func (l *Listener) Addr() net.Addr { return l.conn.LocalAddr() }
// Listen listens for incoming KCP packets addressed to the local address laddr on the network "udp",
func Listen(laddr string) (net.Listener, error) { return ListenWithOptions(laddr, nil, 0, 0) }
func Listen(laddr string) (net.Listener, error) { return ListenWithOptions(laddr) }
// ListenWithOptions listens for incoming KCP packets addressed to the local address laddr on the network "udp" with packet encryption.
//
@@ -1056,7 +1077,7 @@ func Listen(laddr string) (net.Listener, error) { return ListenWithOptions(laddr
// 'dataShards', 'parityShards' specify how many parity packets will be generated following the data packets.
//
// Check https://github.com/klauspost/reedsolomon for details
func ListenWithOptions(laddr string, block BlockCrypt, dataShards, parityShards int) (*Listener, error) {
func ListenWithOptions(laddr string) (*Listener, error) {
udpaddr, err := net.ResolveUDPAddr("udp", laddr)
if err != nil {
return nil, errors.WithStack(err)
@@ -1066,15 +1087,15 @@ func ListenWithOptions(laddr string, block BlockCrypt, dataShards, parityShards
return nil, errors.WithStack(err)
}
return serveConn(block, dataShards, parityShards, conn, true)
return serveConn(conn, true)
}
// ServeConn serves KCP protocol for a single packet connection.
func ServeConn(block BlockCrypt, dataShards, parityShards int, conn net.PacketConn) (*Listener, error) {
return serveConn(block, dataShards, parityShards, conn, false)
func ServeConn(conn net.PacketConn) (*Listener, error) {
return serveConn(conn, false)
}
func serveConn(block BlockCrypt, dataShards, parityShards int, conn net.PacketConn, ownConn bool) (*Listener, error) {
func serveConn(conn net.PacketConn, ownConn bool) (*Listener, error) {
l := new(Listener)
l.conn = conn
l.ownConn = ownConn
@@ -1082,9 +1103,9 @@ func serveConn(block BlockCrypt, dataShards, parityShards int, conn net.PacketCo
l.chAccepts = make(chan *UDPSession, acceptBacklog)
l.chSessionClosed = make(chan net.Addr)
l.die = make(chan struct{})
l.dataShards = dataShards
l.parityShards = parityShards
l.block = block
// l.dataShards = dataShards
// l.parityShards = parityShards
// l.block = block
l.chSocketReadError = make(chan struct{})
l.EnetNotify = make(chan *Enet, 1000)
go l.monitor()
@@ -1092,7 +1113,7 @@ func serveConn(block BlockCrypt, dataShards, parityShards int, conn net.PacketCo
}
// Dial connects to the remote address "raddr" on the network "udp" without encryption and FEC
func Dial(raddr string) (net.Conn, error) { return DialWithOptions(raddr, nil, 0, 0) }
func Dial(raddr string) (net.Conn, error) { return DialWithOptions(raddr) }
// DialWithOptions connects to the remote address "raddr" on the network "udp" with packet encryption
//
@@ -1101,7 +1122,7 @@ func Dial(raddr string) (net.Conn, error) { return DialWithOptions(raddr, nil, 0
// 'dataShards', 'parityShards' specify how many parity packets will be generated following the data packets.
//
// Check https://github.com/klauspost/reedsolomon for details
func DialWithOptions(raddr string, block BlockCrypt, dataShards, parityShards int) (*UDPSession, error) {
func DialWithOptions(raddr string) (*UDPSession, error) {
// network type detection
udpaddr, err := net.ResolveUDPAddr("udp", raddr)
if err != nil {
@@ -1119,26 +1140,26 @@ func DialWithOptions(raddr string, block BlockCrypt, dataShards, parityShards in
var convid uint64
binary.Read(rand.Reader, binary.LittleEndian, &convid)
return newUDPSession(convid, dataShards, parityShards, nil, conn, true, udpaddr, block), nil
return newUDPSession(convid, nil, conn, true, udpaddr), nil
}
// NewConn3 establishes a session and talks KCP protocol over a packet connection.
func NewConn3(convid uint64, raddr net.Addr, block BlockCrypt, dataShards, parityShards int, conn net.PacketConn) (*UDPSession, error) {
return newUDPSession(convid, dataShards, parityShards, nil, conn, false, raddr, block), nil
func NewConn3(convid uint64, raddr net.Addr, conn net.PacketConn) (*UDPSession, error) {
return newUDPSession(convid, nil, conn, false, raddr), nil
}
// NewConn2 establishes a session and talks KCP protocol over a packet connection.
func NewConn2(raddr net.Addr, block BlockCrypt, dataShards, parityShards int, conn net.PacketConn) (*UDPSession, error) {
func NewConn2(raddr net.Addr, conn net.PacketConn) (*UDPSession, error) {
var convid uint64
binary.Read(rand.Reader, binary.LittleEndian, &convid)
return NewConn3(convid, raddr, block, dataShards, parityShards, conn)
return NewConn3(convid, raddr, conn)
}
// NewConn establishes a session and talks KCP protocol over a packet connection.
func NewConn(raddr string, block BlockCrypt, dataShards, parityShards int, conn net.PacketConn) (*UDPSession, error) {
func NewConn(raddr string, conn net.PacketConn) (*UDPSession, error) {
udpaddr, err := net.ResolveUDPAddr("udp", raddr)
if err != nil {
return nil, errors.WithStack(err)
}
return NewConn2(udpaddr, block, dataShards, parityShards, conn)
return NewConn2(udpaddr, conn)
}

View File

@@ -33,8 +33,8 @@ func dialEcho(port int) (*UDPSession, error) {
// block, _ := NewSimpleXORBlockCrypt(pass)
// block, _ := NewTEABlockCrypt(pass[:16])
// block, _ := NewAESBlockCrypt(pass)
block, _ := NewSalsa20BlockCrypt(pass)
sess, err := DialWithOptions(fmt.Sprintf("127.0.0.1:%v", port), block, 10, 3)
// block, _ := NewSalsa20BlockCrypt(pass)
sess, err := DialWithOptions(fmt.Sprintf("127.0.0.1:%v", port))
if err != nil {
panic(err)
}
@@ -57,7 +57,7 @@ func dialEcho(port int) (*UDPSession, error) {
}
func dialSink(port int) (*UDPSession, error) {
sess, err := DialWithOptions(fmt.Sprintf("127.0.0.1:%v", port), nil, 0, 0)
sess, err := DialWithOptions(fmt.Sprintf("127.0.0.1:%v", port))
if err != nil {
panic(err)
}
@@ -79,8 +79,8 @@ func dialTinyBufferEcho(port int) (*UDPSession, error) {
// block, _ := NewSimpleXORBlockCrypt(pass)
// block, _ := NewTEABlockCrypt(pass[:16])
// block, _ := NewAESBlockCrypt(pass)
block, _ := NewSalsa20BlockCrypt(pass)
sess, err := DialWithOptions(fmt.Sprintf("127.0.0.1:%v", port), block, 10, 3)
// block, _ := NewSalsa20BlockCrypt(pass)
sess, err := DialWithOptions(fmt.Sprintf("127.0.0.1:%v", port))
if err != nil {
panic(err)
}
@@ -93,20 +93,20 @@ func listenEcho(port int) (net.Listener, error) {
// block, _ := NewSimpleXORBlockCrypt(pass)
// block, _ := NewTEABlockCrypt(pass[:16])
// block, _ := NewAESBlockCrypt(pass)
block, _ := NewSalsa20BlockCrypt(pass)
return ListenWithOptions(fmt.Sprintf("127.0.0.1:%v", port), block, 10, 0)
// block, _ := NewSalsa20BlockCrypt(pass)
return ListenWithOptions(fmt.Sprintf("127.0.0.1:%v", port))
}
func listenTinyBufferEcho(port int) (net.Listener, error) {
// block, _ := NewNoneBlockCrypt(pass)
// block, _ := NewSimpleXORBlockCrypt(pass)
// block, _ := NewTEABlockCrypt(pass[:16])
// block, _ := NewAESBlockCrypt(pass)
block, _ := NewSalsa20BlockCrypt(pass)
return ListenWithOptions(fmt.Sprintf("127.0.0.1:%v", port), block, 10, 3)
// block, _ := NewSalsa20BlockCrypt(pass)
return ListenWithOptions(fmt.Sprintf("127.0.0.1:%v", port))
}
func listenSink(port int) (net.Listener, error) {
return ListenWithOptions(fmt.Sprintf("127.0.0.1:%v", port), nil, 0, 0)
return ListenWithOptions(fmt.Sprintf("127.0.0.1:%v", port))
}
func echoServer(port int) net.Listener {
@@ -541,7 +541,7 @@ func TestSNMP(t *testing.T) {
func TestListenerClose(t *testing.T) {
port := int(atomic.AddUint32(&baseport, 1))
l, err := ListenWithOptions(fmt.Sprintf("127.0.0.1:%v", port), nil, 10, 3)
l, err := ListenWithOptions(fmt.Sprintf("127.0.0.1:%v", port))
if err != nil {
t.Fail()
}
@@ -580,7 +580,7 @@ func newClosedFlagPacketConn(c net.PacketConn) *closedFlagPacketConn {
// https://github.com/xtaci/kcp-go/issues/165
func TestListenerOwnedPacketConn(t *testing.T) {
// ListenWithOptions creates its own net.PacketConn.
l, err := ListenWithOptions("127.0.0.1:0", nil, 0, 0)
l, err := ListenWithOptions("127.0.0.1:0")
if err != nil {
panic(err)
}
@@ -616,7 +616,7 @@ func TestListenerNonOwnedPacketConn(t *testing.T) {
// Make it remember when it has been closed.
pconn := newClosedFlagPacketConn(c)
l, err := ServeConn(nil, 0, 0, pconn)
l, err := ServeConn(pconn)
if err != nil {
panic(err)
}
@@ -643,7 +643,7 @@ func TestUDPSessionOwnedPacketConn(t *testing.T) {
defer l.Close()
// DialWithOptions creates its own net.PacketConn.
client, err := DialWithOptions(l.Addr().String(), nil, 0, 0)
client, err := DialWithOptions(l.Addr().String())
if err != nil {
panic(err)
}
@@ -682,7 +682,7 @@ func TestUDPSessionNonOwnedPacketConn(t *testing.T) {
// Make it remember when it has been closed.
pconn := newClosedFlagPacketConn(c)
client, err := NewConn2(l.Addr(), nil, 0, 0, pconn)
client, err := NewConn2(l.Addr(), pconn)
if err != nil {
panic(err)
}

View File

@@ -7,14 +7,14 @@ import (
// Snmp defines network statistics indicator
type Snmp struct {
BytesSent uint64 // bytes sent from upper level
BytesReceived uint64 // bytes received to upper level
MaxConn uint64 // max number of connections ever reached
ActiveOpens uint64 // accumulated active open connections
PassiveOpens uint64 // accumulated passive open connections
CurrEstab uint64 // current number of established connections
InErrs uint64 // UDP read errors reported from net.PacketConn
InCsumErrors uint64 // checksum errors from CRC32
BytesSent uint64 // bytes sent from upper level
BytesReceived uint64 // bytes received to upper level
MaxConn uint64 // max number of connections ever reached
ActiveOpens uint64 // accumulated active open connections
PassiveOpens uint64 // accumulated passive open connections
CurrEstab uint64 // current number of established connections
// InErrs uint64 // UDP read errors reported from net.PacketConn
// InCsumErrors uint64 // checksum errors from CRC32
KCPInErrors uint64 // packet iput errors reported from KCP
InPkts uint64 // incoming packets count
OutPkts uint64 // outgoing packets count
@@ -27,10 +27,10 @@ type Snmp struct {
EarlyRetransSegs uint64 // accmulated early retransmitted segments
LostSegs uint64 // number of segs inferred as lost
RepeatSegs uint64 // number of segs duplicated
FECRecovered uint64 // correct packets recovered from FEC
FECErrs uint64 // incorrect packets recovered from FEC
FECParityShards uint64 // FEC segments received
FECShortShards uint64 // number of data shards that's not enough for recovery
// FECRecovered uint64 // correct packets recovered from FEC
// FECErrs uint64 // incorrect packets recovered from FEC
// FECParityShards uint64 // FEC segments received
// FECShortShards uint64 // number of data shards that's not enough for recovery
}
func newSnmp() *Snmp {
@@ -46,8 +46,8 @@ func (s *Snmp) Header() []string {
"ActiveOpens",
"PassiveOpens",
"CurrEstab",
"InErrs",
"InCsumErrors",
// "InErrs",
// "InCsumErrors",
"KCPInErrors",
"InPkts",
"OutPkts",
@@ -60,10 +60,10 @@ func (s *Snmp) Header() []string {
"EarlyRetransSegs",
"LostSegs",
"RepeatSegs",
"FECParityShards",
"FECErrs",
"FECRecovered",
"FECShortShards",
// "FECParityShards",
// "FECErrs",
// "FECRecovered",
// "FECShortShards",
}
}
@@ -77,8 +77,8 @@ func (s *Snmp) ToSlice() []string {
fmt.Sprint(snmp.ActiveOpens),
fmt.Sprint(snmp.PassiveOpens),
fmt.Sprint(snmp.CurrEstab),
fmt.Sprint(snmp.InErrs),
fmt.Sprint(snmp.InCsumErrors),
// fmt.Sprint(snmp.InErrs),
// fmt.Sprint(snmp.InCsumErrors),
fmt.Sprint(snmp.KCPInErrors),
fmt.Sprint(snmp.InPkts),
fmt.Sprint(snmp.OutPkts),
@@ -91,10 +91,10 @@ func (s *Snmp) ToSlice() []string {
fmt.Sprint(snmp.EarlyRetransSegs),
fmt.Sprint(snmp.LostSegs),
fmt.Sprint(snmp.RepeatSegs),
fmt.Sprint(snmp.FECParityShards),
fmt.Sprint(snmp.FECErrs),
fmt.Sprint(snmp.FECRecovered),
fmt.Sprint(snmp.FECShortShards),
// fmt.Sprint(snmp.FECParityShards),
// fmt.Sprint(snmp.FECErrs),
// fmt.Sprint(snmp.FECRecovered),
// fmt.Sprint(snmp.FECShortShards),
}
}
@@ -107,8 +107,8 @@ func (s *Snmp) Copy() *Snmp {
d.ActiveOpens = atomic.LoadUint64(&s.ActiveOpens)
d.PassiveOpens = atomic.LoadUint64(&s.PassiveOpens)
d.CurrEstab = atomic.LoadUint64(&s.CurrEstab)
d.InErrs = atomic.LoadUint64(&s.InErrs)
d.InCsumErrors = atomic.LoadUint64(&s.InCsumErrors)
// d.InErrs = atomic.LoadUint64(&s.InErrs)
// d.InCsumErrors = atomic.LoadUint64(&s.InCsumErrors)
d.KCPInErrors = atomic.LoadUint64(&s.KCPInErrors)
d.InPkts = atomic.LoadUint64(&s.InPkts)
d.OutPkts = atomic.LoadUint64(&s.OutPkts)
@@ -121,10 +121,10 @@ func (s *Snmp) Copy() *Snmp {
d.EarlyRetransSegs = atomic.LoadUint64(&s.EarlyRetransSegs)
d.LostSegs = atomic.LoadUint64(&s.LostSegs)
d.RepeatSegs = atomic.LoadUint64(&s.RepeatSegs)
d.FECParityShards = atomic.LoadUint64(&s.FECParityShards)
d.FECErrs = atomic.LoadUint64(&s.FECErrs)
d.FECRecovered = atomic.LoadUint64(&s.FECRecovered)
d.FECShortShards = atomic.LoadUint64(&s.FECShortShards)
// d.FECParityShards = atomic.LoadUint64(&s.FECParityShards)
// d.FECErrs = atomic.LoadUint64(&s.FECErrs)
// d.FECRecovered = atomic.LoadUint64(&s.FECRecovered)
// d.FECShortShards = atomic.LoadUint64(&s.FECShortShards)
return d
}
@@ -136,8 +136,8 @@ func (s *Snmp) Reset() {
atomic.StoreUint64(&s.ActiveOpens, 0)
atomic.StoreUint64(&s.PassiveOpens, 0)
atomic.StoreUint64(&s.CurrEstab, 0)
atomic.StoreUint64(&s.InErrs, 0)
atomic.StoreUint64(&s.InCsumErrors, 0)
// atomic.StoreUint64(&s.InErrs, 0)
// atomic.StoreUint64(&s.InCsumErrors, 0)
atomic.StoreUint64(&s.KCPInErrors, 0)
atomic.StoreUint64(&s.InPkts, 0)
atomic.StoreUint64(&s.OutPkts, 0)
@@ -150,10 +150,10 @@ func (s *Snmp) Reset() {
atomic.StoreUint64(&s.EarlyRetransSegs, 0)
atomic.StoreUint64(&s.LostSegs, 0)
atomic.StoreUint64(&s.RepeatSegs, 0)
atomic.StoreUint64(&s.FECParityShards, 0)
atomic.StoreUint64(&s.FECErrs, 0)
atomic.StoreUint64(&s.FECRecovered, 0)
atomic.StoreUint64(&s.FECShortShards, 0)
// atomic.StoreUint64(&s.FECParityShards, 0)
// atomic.StoreUint64(&s.FECErrs, 0)
// atomic.StoreUint64(&s.FECRecovered, 0)
// atomic.StoreUint64(&s.FECShortShards, 0)
}
// DefaultSnmp is the global KCP connection statistics collector

View File

@@ -1,80 +0,0 @@
package kcp
import (
"net"
"sync/atomic"
"github.com/pkg/errors"
"golang.org/x/net/ipv4"
)
func buildEnet(connType uint8, enetType uint32, conv uint64) []byte {
data := make([]byte, 20)
if connType == ConnEnetSyn {
copy(data[0:4], MagicEnetSynHead)
copy(data[16:20], MagicEnetSynTail)
} else if connType == ConnEnetEst {
copy(data[0:4], MagicEnetEstHead)
copy(data[16:20], MagicEnetEstTail)
} else if connType == ConnEnetFin {
copy(data[0:4], MagicEnetFinHead)
copy(data[16:20], MagicEnetFinTail)
} else {
return nil
}
// conv的高四个字节和低四个字节分开
// 例如 00 00 01 45 | LL LL LL LL | HH HH HH HH | 49 96 02 d2 | 14 51 45 45
data[4] = uint8(conv >> 24)
data[5] = uint8(conv >> 16)
data[6] = uint8(conv >> 8)
data[7] = uint8(conv >> 0)
data[8] = uint8(conv >> 56)
data[9] = uint8(conv >> 48)
data[10] = uint8(conv >> 40)
data[11] = uint8(conv >> 32)
// Enet
data[12] = uint8(enetType >> 24)
data[13] = uint8(enetType >> 16)
data[14] = uint8(enetType >> 8)
data[15] = uint8(enetType >> 0)
return data
}
func (l *Listener) defaultSendEnetNotifyToClient(enet *Enet) {
remoteAddr, err := net.ResolveUDPAddr("udp", enet.Addr)
if err != nil {
return
}
data := buildEnet(enet.ConnType, enet.EnetType, enet.ConvId)
if data == nil {
return
}
_, _ = l.conn.WriteTo(data, remoteAddr)
}
func (s *UDPSession) defaultSendEnetNotify(enet *Enet) {
data := buildEnet(enet.ConnType, enet.EnetType, s.GetConv())
if data == nil {
return
}
s.defaultTx([]ipv4.Message{{
Buffers: [][]byte{data},
Addr: s.remote,
}})
}
func (s *UDPSession) defaultTx(txqueue []ipv4.Message) {
nbytes := 0
npkts := 0
for k := range txqueue {
if n, err := s.conn.WriteTo(txqueue[k].Buffers[0], txqueue[k].Addr); err == nil {
nbytes += n
npkts++
} else {
s.notifyWriteError(errors.WithStack(err))
break
}
}
atomic.AddUint64(&DefaultSnmp.OutPkts, uint64(npkts))
atomic.AddUint64(&DefaultSnmp.OutBytes, uint64(nbytes))
}

View File

@@ -1,102 +0,0 @@
//go:build linux
// +build linux
package kcp
import (
"net"
"os"
"sync/atomic"
"github.com/pkg/errors"
"golang.org/x/net/ipv4"
"golang.org/x/net/ipv6"
)
func (l *Listener) SendEnetNotifyToClient(enet *Enet) {
var xconn batchConn
_, ok := l.conn.(*net.UDPConn)
if !ok {
return
}
localAddr, err := net.ResolveUDPAddr("udp", l.conn.LocalAddr().String())
if err != nil {
return
}
if localAddr.IP.To4() != nil {
xconn = ipv4.NewPacketConn(l.conn)
} else {
xconn = ipv6.NewPacketConn(l.conn)
}
// default version
if xconn == nil {
l.defaultSendEnetNotifyToClient(enet)
return
}
remoteAddr, err := net.ResolveUDPAddr("udp", enet.Addr)
if err != nil {
return
}
data := buildEnet(enet.ConnType, enet.EnetType, enet.ConvId)
if data == nil {
return
}
_, _ = xconn.WriteBatch([]ipv4.Message{{
Buffers: [][]byte{data},
Addr: remoteAddr,
}}, 0)
}
func (s *UDPSession) SendEnetNotify(enet *Enet) {
data := buildEnet(enet.ConnType, enet.EnetType, s.GetConv())
if data == nil {
return
}
s.tx([]ipv4.Message{{
Buffers: [][]byte{data},
Addr: s.remote,
}})
}
func (s *UDPSession) tx(txqueue []ipv4.Message) {
// default version
if s.xconn == nil || s.xconnWriteError != nil {
s.defaultTx(txqueue)
return
}
// x/net version
nbytes := 0
npkts := 0
for len(txqueue) > 0 {
if n, err := s.xconn.WriteBatch(txqueue, 0); err == nil {
for k := range txqueue[:n] {
nbytes += len(txqueue[k].Buffers[0])
}
npkts += n
txqueue = txqueue[n:]
} else {
// compatibility issue:
// for linux kernel<=2.6.32, support for sendmmsg is not available
// an error of type os.SyscallError will be returned
if operr, ok := err.(*net.OpError); ok {
if se, ok := operr.Err.(*os.SyscallError); ok {
if se.Syscall == "sendmmsg" {
s.xconnWriteError = se
s.defaultTx(txqueue)
return
}
}
}
s.notifyWriteError(errors.WithStack(err))
break
}
}
atomic.AddUint64(&DefaultSnmp.OutPkts, uint64(npkts))
atomic.AddUint64(&DefaultSnmp.OutBytes, uint64(nbytes))
}

View File

@@ -3,8 +3,11 @@ package kcp
import (
"bytes"
"encoding/binary"
"net"
"sync/atomic"
"github.com/pkg/errors"
"golang.org/x/net/ipv4"
)
func (s *UDPSession) defaultReadLoop() {
@@ -125,3 +128,74 @@ func (l *Listener) defaultMonitor() {
}
}
}
func buildEnet(connType uint8, enetType uint32, conv uint64) []byte {
data := make([]byte, 20)
if connType == ConnEnetSyn {
copy(data[0:4], MagicEnetSynHead)
copy(data[16:20], MagicEnetSynTail)
} else if connType == ConnEnetEst {
copy(data[0:4], MagicEnetEstHead)
copy(data[16:20], MagicEnetEstTail)
} else if connType == ConnEnetFin {
copy(data[0:4], MagicEnetFinHead)
copy(data[16:20], MagicEnetFinTail)
} else {
return nil
}
// conv的高四个字节和低四个字节分开
// 例如 00 00 01 45 | LL LL LL LL | HH HH HH HH | 49 96 02 d2 | 14 51 45 45
data[4] = uint8(conv >> 24)
data[5] = uint8(conv >> 16)
data[6] = uint8(conv >> 8)
data[7] = uint8(conv >> 0)
data[8] = uint8(conv >> 56)
data[9] = uint8(conv >> 48)
data[10] = uint8(conv >> 40)
data[11] = uint8(conv >> 32)
// Enet
data[12] = uint8(enetType >> 24)
data[13] = uint8(enetType >> 16)
data[14] = uint8(enetType >> 8)
data[15] = uint8(enetType >> 0)
return data
}
func (l *Listener) defaultSendEnetNotifyToClient(enet *Enet) {
remoteAddr, err := net.ResolveUDPAddr("udp", enet.Addr)
if err != nil {
return
}
data := buildEnet(enet.ConnType, enet.EnetType, enet.ConvId)
if data == nil {
return
}
_, _ = l.conn.WriteTo(data, remoteAddr)
}
func (s *UDPSession) defaultSendEnetNotify(enet *Enet) {
data := buildEnet(enet.ConnType, enet.EnetType, s.GetConv())
if data == nil {
return
}
s.defaultTx([]ipv4.Message{{
Buffers: [][]byte{data},
Addr: s.remote,
}})
}
func (s *UDPSession) defaultTx(txqueue []ipv4.Message) {
nbytes := 0
npkts := 0
for k := range txqueue {
if n, err := s.conn.WriteTo(txqueue[k].Buffers[0], txqueue[k].Addr); err == nil {
nbytes += n
npkts++
} else {
s.notifyWriteError(errors.WithStack(err))
break
}
}
atomic.AddUint64(&DefaultSnmp.OutPkts, uint64(npkts))
atomic.AddUint64(&DefaultSnmp.OutBytes, uint64(nbytes))
}

View File

@@ -8,6 +8,7 @@ import (
"encoding/binary"
"net"
"os"
"sync/atomic"
"github.com/pkg/errors"
"golang.org/x/net/ipv4"
@@ -198,3 +199,91 @@ func (l *Listener) monitor() {
}
}
}
func (l *Listener) SendEnetNotifyToClient(enet *Enet) {
var xconn batchConn
_, ok := l.conn.(*net.UDPConn)
if !ok {
return
}
localAddr, err := net.ResolveUDPAddr("udp", l.conn.LocalAddr().String())
if err != nil {
return
}
if localAddr.IP.To4() != nil {
xconn = ipv4.NewPacketConn(l.conn)
} else {
xconn = ipv6.NewPacketConn(l.conn)
}
// default version
if xconn == nil {
l.defaultSendEnetNotifyToClient(enet)
return
}
remoteAddr, err := net.ResolveUDPAddr("udp", enet.Addr)
if err != nil {
return
}
data := buildEnet(enet.ConnType, enet.EnetType, enet.ConvId)
if data == nil {
return
}
_, _ = xconn.WriteBatch([]ipv4.Message{{
Buffers: [][]byte{data},
Addr: remoteAddr,
}}, 0)
}
func (s *UDPSession) SendEnetNotify(enet *Enet) {
data := buildEnet(enet.ConnType, enet.EnetType, s.GetConv())
if data == nil {
return
}
s.tx([]ipv4.Message{{
Buffers: [][]byte{data},
Addr: s.remote,
}})
}
func (s *UDPSession) tx(txqueue []ipv4.Message) {
// default version
if s.xconn == nil || s.xconnWriteError != nil {
s.defaultTx(txqueue)
return
}
// x/net version
nbytes := 0
npkts := 0
for len(txqueue) > 0 {
if n, err := s.xconn.WriteBatch(txqueue, 0); err == nil {
for k := range txqueue[:n] {
nbytes += len(txqueue[k].Buffers[0])
}
npkts += n
txqueue = txqueue[n:]
} else {
// compatibility issue:
// for linux kernel<=2.6.32, support for sendmmsg is not available
// an error of type os.SyscallError will be returned
if operr, ok := err.(*net.OpError); ok {
if se, ok := operr.Err.(*os.SyscallError); ok {
if se.Syscall == "sendmmsg" {
s.xconnWriteError = se
s.defaultTx(txqueue)
return
}
}
}
s.notifyWriteError(errors.WithStack(err))
break
}
}
atomic.AddUint64(&DefaultSnmp.OutPkts, uint64(npkts))
atomic.AddUint64(&DefaultSnmp.OutBytes, uint64(nbytes))
}

View File

@@ -7,6 +7,14 @@ import (
"golang.org/x/net/ipv4"
)
func (s *UDPSession) readLoop() {
s.defaultReadLoop()
}
func (l *Listener) monitor() {
l.defaultMonitor()
}
func (l *Listener) SendEnetNotifyToClient(enet *Enet) {
l.defaultSendEnetNotifyToClient(enet)
}

View File

@@ -6,6 +6,7 @@ import (
"encoding/binary"
"strconv"
"sync"
"sync/atomic"
"time"
"hk4e/common/config"
@@ -21,27 +22,36 @@ import (
)
const (
PacketFreqLimit = 1000
PacketMaxLen = 343 * 1024
ConnRecvTimeout = 30
ConnSendTimeout = 10
ConnSynPacketFreqLimit = 100 // 连接建立握手包每秒发包频率限制
RecvPacketFreqLimit = 100 // 客户端上行每秒发包频率限制
SendPacketFreqLimit = 1000 // 服务器下行每秒发包频率限制
PacketMaxLen = 343 * 1024 // 最大应用层包长度
ConnRecvTimeout = 30 // 收包超时时间 秒
ConnSendTimeout = 10 // 发包超时时间 秒
MaxClientConnNumLimit = 100 // 最大客户端连接数限制
)
var CLIENT_CONN_NUM int32 = 0 // 当前客户端连接数
type KcpConnectManager struct {
discovery *rpc.DiscoveryClient
openState bool
discovery *rpc.DiscoveryClient // node服务器客户端
openState bool // 网关开放状态
// 会话
sessionConvIdMap map[uint64]*Session
sessionUserIdMap map[uint32]*Session
sessionMapLock sync.RWMutex
kcpEventInput chan *KcpEvent
kcpEventOutput chan *KcpEvent
serverCmdProtoMap *cmd.CmdProtoMap
clientCmdProtoMap *client_proto.ClientCmdProtoMap
messageQueue *mq.MessageQueue
localMsgOutput chan *ProtoMsg
createSessionChan chan *Session
destroySessionChan chan *Session
// 密钥相关
// 连接事件
kcpEventInput chan *KcpEvent
kcpEventOutput chan *KcpEvent
// 协议
serverCmdProtoMap *cmd.CmdProtoMap
clientCmdProtoMap *client_proto.ClientCmdProtoMap
// 输入输出管道
messageQueue *mq.MessageQueue
localMsgOutput chan *ProtoMsg
// 密钥
dispatchKey []byte
signRsaKey []byte
encRsaKeyMap map[string][]byte
@@ -53,6 +63,8 @@ func NewKcpConnectManager(messageQueue *mq.MessageQueue, discovery *rpc.Discover
r.openState = true
r.sessionConvIdMap = make(map[uint64]*Session)
r.sessionUserIdMap = make(map[uint32]*Session)
r.createSessionChan = make(chan *Session, 1000)
r.destroySessionChan = make(chan *Session, 1000)
r.kcpEventInput = make(chan *KcpEvent, 1000)
r.kcpEventOutput = make(chan *KcpEvent, 1000)
r.serverCmdProtoMap = cmd.NewCmdProtoMap()
@@ -61,8 +73,6 @@ func NewKcpConnectManager(messageQueue *mq.MessageQueue, discovery *rpc.Discover
}
r.messageQueue = messageQueue
r.localMsgOutput = make(chan *ProtoMsg, 1000)
r.createSessionChan = make(chan *Session, 1000)
r.destroySessionChan = make(chan *Session, 1000)
r.run()
return r
}
@@ -86,7 +96,7 @@ func (k *KcpConnectManager) run() {
k.dispatchKey = regionEc2b.XorKey()
// kcp
port := strconv.Itoa(int(config.CONF.Hk4e.KcpPort))
listener, err := kcp.ListenWithOptions("0.0.0.0:"+port, nil, 0, 0)
listener, err := kcp.ListenWithOptions("0.0.0.0:" + port)
if err != nil {
logger.Error("listen kcp err: %v", err)
return
@@ -95,29 +105,56 @@ func (k *KcpConnectManager) run() {
go k.eventHandle()
go k.sendMsgHandle()
go k.acceptHandle(listener)
go k.gateNetInfo()
}
func (k *KcpConnectManager) Close() {
k.closeAllKcpConn()
// 等待所有连接关闭时需要发送的消息发送完毕
time.Sleep(time.Second * 3)
}
func (k *KcpConnectManager) gateNetInfo() {
ticker := time.NewTicker(time.Second * 60)
kcpErrorCount := uint64(0)
for {
<-ticker.C
snmp := kcp.DefaultSnmp.Copy()
kcpErrorCount += snmp.KCPInErrors
logger.Info("kcp send: %v B/s, kcp recv: %v B/s", snmp.BytesSent/60, snmp.BytesReceived/60)
logger.Info("udp send: %v B/s, udp recv: %v B/s", snmp.OutBytes/60, snmp.InBytes/60)
logger.Info("udp send: %v pps, udp recv: %v pps", snmp.OutPkts/60, snmp.InPkts/60)
clientConnNum := atomic.LoadInt32(&CLIENT_CONN_NUM)
logger.Info("conn num: %v, new conn num: %v, kcp error num: %v", clientConnNum, snmp.CurrEstab, kcpErrorCount)
kcp.DefaultSnmp.Reset()
}
}
// 接收并创建新连接处理函数
func (k *KcpConnectManager) acceptHandle(listener *kcp.Listener) {
logger.Debug("accept handle start")
logger.Info("accept handle start")
for {
conn, err := listener.AcceptKCP()
if err != nil {
logger.Error("accept kcp err: %v", err)
return
}
convId := conn.GetConv()
if k.openState == false {
logger.Error("gate not open, convId: %v", convId)
_ = conn.Close()
continue
}
clientConnNum := atomic.AddInt32(&CLIENT_CONN_NUM, 1)
if clientConnNum > MaxClientConnNumLimit {
logger.Error("gate conn num limit, convId: %v", convId)
_ = conn.Close()
atomic.AddInt32(&CLIENT_CONN_NUM, -1)
continue
}
conn.SetACKNoDelay(true)
conn.SetWriteDelay(false)
convId := conn.GetConv()
logger.Debug("client connect, convId: %v", convId)
logger.Info("client connect, convId: %v", convId)
kcpRawSendChan := make(chan *ProtoMsg, 1000)
session := &Session{
conn: conn,
@@ -145,61 +182,76 @@ func (k *KcpConnectManager) acceptHandle(listener *kcp.Listener) {
}
}
// 连接事件处理函数
func (k *KcpConnectManager) enetHandle(listener *kcp.Listener) {
logger.Debug("enet handle start")
logger.Info("enet handle start")
// conv短时间内唯一生成
convGenMap := make(map[uint64]int64)
pktFreqLimitCounter := 0
pktFreqLimitTimer := time.Now().UnixNano()
for {
enetNotify := <-listener.EnetNotify
logger.Info("[Enet Notify], addr: %v, conv: %v, conn: %v, enet: %v", enetNotify.Addr, enetNotify.ConvId, enetNotify.ConnType, enetNotify.EnetType)
switch enetNotify.ConnType {
case kcp.ConnEnetSyn:
if enetNotify.EnetType == kcp.EnetClientConnectKey {
// 清理老旧的conv
// 连接建立握手包频率限制
pktFreqLimitCounter++
if pktFreqLimitCounter > ConnSynPacketFreqLimit {
now := time.Now().UnixNano()
oldConvList := make([]uint64, 0)
for conv, timestamp := range convGenMap {
if now-timestamp > int64(time.Hour) {
oldConvList = append(oldConvList, conv)
}
if now-pktFreqLimitTimer > int64(time.Second) {
pktFreqLimitCounter = 0
pktFreqLimitTimer = now
} else {
continue
}
delConvList := make([]uint64, 0)
k.sessionMapLock.RLock()
for _, conv := range oldConvList {
_, exist := k.sessionConvIdMap[conv]
if !exist {
delConvList = append(delConvList, conv)
delete(convGenMap, conv)
}
}
k.sessionMapLock.RUnlock()
logger.Info("clean dead conv list: %v", delConvList)
// 生成没用过的conv
var conv uint64
for {
convData := random.GetRandomByte(8)
convDataBuffer := bytes.NewBuffer(convData)
_ = binary.Read(convDataBuffer, binary.LittleEndian, &conv)
_, exist := convGenMap[conv]
if exist {
continue
} else {
convGenMap[conv] = time.Now().UnixNano()
break
}
}
listener.SendEnetNotifyToClient(&kcp.Enet{
Addr: enetNotify.Addr,
ConvId: conv,
ConnType: kcp.ConnEnetEst,
EnetType: enetNotify.EnetType,
})
}
if enetNotify.EnetType != kcp.EnetClientConnectKey {
continue
}
// 清理老旧的conv
now := time.Now().UnixNano()
oldConvList := make([]uint64, 0)
for conv, timestamp := range convGenMap {
if now-timestamp > int64(time.Hour) {
oldConvList = append(oldConvList, conv)
}
}
delConvList := make([]uint64, 0)
k.sessionMapLock.RLock()
for _, conv := range oldConvList {
_, exist := k.sessionConvIdMap[conv]
if !exist {
delConvList = append(delConvList, conv)
delete(convGenMap, conv)
}
}
k.sessionMapLock.RUnlock()
logger.Info("clean dead conv list: %v", delConvList)
// 生成没用过的conv
var conv uint64
for {
convData := random.GetRandomByte(8)
convDataBuffer := bytes.NewBuffer(convData)
_ = binary.Read(convDataBuffer, binary.LittleEndian, &conv)
_, exist := convGenMap[conv]
if exist {
continue
} else {
convGenMap[conv] = time.Now().UnixNano()
break
}
}
listener.SendEnetNotifyToClient(&kcp.Enet{
Addr: enetNotify.Addr,
ConvId: conv,
ConnType: kcp.ConnEnetEst,
EnetType: enetNotify.EnetType,
})
case kcp.ConnEnetEst:
case kcp.ConnEnetFin:
session := k.GetSessionByConvId(enetNotify.ConvId)
if session == nil {
logger.Error("session not exist, convId: %v", enetNotify.ConvId)
logger.Error("session not exist, conv: %v", enetNotify.ConvId)
continue
}
session.conn.SendEnetNotify(&kcp.Enet{
@@ -219,6 +271,7 @@ func (k *KcpConnectManager) enetHandle(listener *kcp.Listener) {
}
}
// Session 连接会话结构
type Session struct {
conn *kcp.UDPSession
connState uint8
@@ -237,13 +290,13 @@ type Session struct {
// 接收
func (k *KcpConnectManager) recvHandle(session *Session) {
logger.Debug("recv handle start")
logger.Info("recv handle start")
conn := session.conn
convId := conn.GetConv()
pktFreqLimitCounter := 0
pktFreqLimitTimer := time.Now().UnixNano()
recvBuf := make([]byte, PacketMaxLen)
dataBuf := make([]byte, 0, 1500)
pktFreqLimitCounter := 0
pktFreqLimitTimer := time.Now().UnixNano()
for {
_ = conn.SetReadDeadline(time.Now().Add(time.Second * ConnRecvTimeout))
recvLen, err := conn.Read(recvBuf)
@@ -254,16 +307,16 @@ func (k *KcpConnectManager) recvHandle(session *Session) {
}
// 收包频率限制
pktFreqLimitCounter++
now := time.Now().UnixNano()
if now-pktFreqLimitTimer > int64(time.Second) {
if pktFreqLimitCounter > PacketFreqLimit {
if pktFreqLimitCounter > RecvPacketFreqLimit {
now := time.Now().UnixNano()
if now-pktFreqLimitTimer > int64(time.Second) {
pktFreqLimitCounter = 0
pktFreqLimitTimer = now
} else {
logger.Error("exit recv loop, client packet send freq too high, convId: %v, pps: %v", convId, pktFreqLimitCounter)
k.closeKcpConn(session, kcp.EnetPacketFreqTooHigh)
break
} else {
pktFreqLimitCounter = 0
}
pktFreqLimitTimer = now
}
recvData := recvBuf[:recvLen]
kcpMsgList := make([]*KcpMsg, 0)
@@ -279,9 +332,11 @@ func (k *KcpConnectManager) recvHandle(session *Session) {
// 发送
func (k *KcpConnectManager) sendHandle(session *Session) {
logger.Debug("send handle start")
logger.Info("send handle start")
conn := session.conn
convId := conn.GetConv()
pktFreqLimitCounter := 0
pktFreqLimitTimer := time.Now().UnixNano()
for {
protoMsg, ok := <-session.kcpRawSendChan
if !ok {
@@ -302,9 +357,22 @@ func (k *KcpConnectManager) sendHandle(session *Session) {
k.closeKcpConn(session, kcp.EnetServerKick)
break
}
// 发包频率限制
pktFreqLimitCounter++
if pktFreqLimitCounter > SendPacketFreqLimit {
now := time.Now().UnixNano()
if now-pktFreqLimitTimer > int64(time.Second) {
pktFreqLimitCounter = 0
pktFreqLimitTimer = now
} else {
logger.Error("exit send loop, server packet send freq too high, convId: %v, pps: %v", convId, pktFreqLimitCounter)
k.closeKcpConn(session, kcp.EnetPacketFreqTooHigh)
break
}
}
if session.changeXorKeyFin == false && protoMsg.CmdId == cmd.GetPlayerTokenRsp {
// XOR密钥切换
logger.Debug("change session xor key, convId: %v", convId)
logger.Info("change session xor key, convId: %v", convId)
session.changeXorKeyFin = true
keyBlock := random.NewKeyBlock(session.seed, session.useMagicSeed)
xorKey := keyBlock.XorKey()
@@ -315,8 +383,8 @@ func (k *KcpConnectManager) sendHandle(session *Session) {
}
}
// 强制关闭指定连接
func (k *KcpConnectManager) forceCloseKcpConn(convId uint64, reason uint32) {
// 强制关闭某个连接
session := k.GetSessionByConvId(convId)
if session == nil {
logger.Error("session not exist, convId: %v", convId)
@@ -326,6 +394,7 @@ func (k *KcpConnectManager) forceCloseKcpConn(convId uint64, reason uint32) {
logger.Info("conn has been force close, convId: %v", convId)
}
// 关闭指定连接
func (k *KcpConnectManager) closeKcpConn(session *Session, enetType uint32) {
if session.connState == ConnClose {
return
@@ -358,8 +427,10 @@ func (k *KcpConnectManager) closeKcpConn(session *Session, enetType uint32) {
})
logger.Info("send to gs user offline, ConvId: %v, UserId: %v", convId, connCtrlMsg.UserId)
k.destroySessionChan <- session
atomic.AddInt32(&CLIENT_CONN_NUM, -1)
}
// 关闭所有连接
func (k *KcpConnectManager) closeAllKcpConn() {
sessionList := make([]*Session, 0)
k.sessionMapLock.RLock()