diff --git a/gate/kcp/kcp_test.go b/gate/kcp/kcp_test.go deleted file mode 100644 index d23ecabe..00000000 --- a/gate/kcp/kcp_test.go +++ /dev/null @@ -1,135 +0,0 @@ -package kcp - -import ( - "io" - "net" - "sync" - "testing" - "time" - - "github.com/xtaci/lossyconn" -) - -const repeat = 16 - -func TestLossyConn1(t *testing.T) { - t.Log("testing loss rate 10%, rtt 200ms") - t.Log("testing link with nodelay parameters:1 10 2 1") - client, err := lossyconn.NewLossyConn(0.1, 100) - if err != nil { - t.Fatal(err) - } - - server, err := lossyconn.NewLossyConn(0.1, 100) - if err != nil { - t.Fatal(err) - } - testlink(t, client, server, 1, 10, 2, 1) -} - -func TestLossyConn2(t *testing.T) { - t.Log("testing loss rate 20%, rtt 200ms") - t.Log("testing link with nodelay parameters:1 10 2 1") - client, err := lossyconn.NewLossyConn(0.2, 100) - if err != nil { - t.Fatal(err) - } - - server, err := lossyconn.NewLossyConn(0.2, 100) - if err != nil { - t.Fatal(err) - } - testlink(t, client, server, 1, 10, 2, 1) -} - -func TestLossyConn3(t *testing.T) { - t.Log("testing loss rate 30%, rtt 200ms") - t.Log("testing link with nodelay parameters:1 10 2 1") - client, err := lossyconn.NewLossyConn(0.3, 100) - if err != nil { - t.Fatal(err) - } - - server, err := lossyconn.NewLossyConn(0.3, 100) - if err != nil { - t.Fatal(err) - } - testlink(t, client, server, 1, 10, 2, 1) -} - -func TestLossyConn4(t *testing.T) { - t.Log("testing loss rate 10%, rtt 200ms") - t.Log("testing link with nodelay parameters:1 10 2 0") - client, err := lossyconn.NewLossyConn(0.1, 100) - if err != nil { - t.Fatal(err) - } - - server, err := lossyconn.NewLossyConn(0.1, 100) - if err != nil { - t.Fatal(err) - } - testlink(t, client, server, 1, 10, 2, 0) -} - -func testlink(t *testing.T, client *lossyconn.LossyConn, server *lossyconn.LossyConn, nodelay, interval, resend, nc int) { - t.Log("testing with nodelay parameters:", nodelay, interval, resend, nc) - sess, _ := NewConn2(server.LocalAddr(), client) - listener, _ := ServeConn(server) - echoServer := func(l *Listener) { - for { - conn, err := l.AcceptKCP() - if err != nil { - return - } - go func() { - conn.SetNoDelay(nodelay, interval, resend, nc) - buf := make([]byte, 65536) - for { - n, err := conn.Read(buf) - if err != nil { - return - } - conn.Write(buf[:n]) - } - }() - } - } - - echoTester := func(s *UDPSession, raddr net.Addr) { - s.SetNoDelay(nodelay, interval, resend, nc) - buf := make([]byte, 64) - var rtt time.Duration - for i := 0; i < repeat; i++ { - start := time.Now() - s.Write(buf) - io.ReadFull(s, buf) - rtt += time.Since(start) - } - - t.Log("client:", client) - t.Log("server:", server) - t.Log("avg rtt:", rtt/repeat) - t.Logf("total time: %v for %v round trip:", rtt, repeat) - } - - go echoServer(listener) - echoTester(sess, server.LocalAddr()) -} - -func BenchmarkFlush(b *testing.B) { - kcp := NewKCP(1, func(buf []byte, size int) {}) - kcp.snd_buf = make([]segment, 1024) - for k := range kcp.snd_buf { - kcp.snd_buf[k].xmit = 1 - kcp.snd_buf[k].resendts = currentMs() + 10000 - } - b.ResetTimer() - b.ReportAllocs() - var mu sync.Mutex - for i := 0; i < b.N; i++ { - mu.Lock() - kcp.flush(false) - mu.Unlock() - } -} diff --git a/gate/kcp/sess_test.go b/gate/kcp/sess_test.go deleted file mode 100644 index 82d82090..00000000 --- a/gate/kcp/sess_test.go +++ /dev/null @@ -1,703 +0,0 @@ -package kcp - -import ( - "crypto/sha1" - "fmt" - "io" - "log" - "net" - "net/http" - _ "net/http/pprof" - "sync" - "sync/atomic" - "testing" - "time" - - "golang.org/x/crypto/pbkdf2" -) - -var baseport = uint32(10000) -var key = []byte("testkey") -var pass = pbkdf2.Key(key, []byte("testsalt"), 4096, 32, sha1.New) - -func init() { - go func() { - log.Println(http.ListenAndServe("0.0.0.0:6060", nil)) - }() - - log.Println("beginning tests, encryption:salsa20, fec:10/3") -} - -func dialEcho(port int) (*UDPSession, error) { - // block, _ := NewNoneBlockCrypt(pass) - // block, _ := NewSimpleXORBlockCrypt(pass) - // block, _ := NewTEABlockCrypt(pass[:16]) - // block, _ := NewAESBlockCrypt(pass) - // block, _ := NewSalsa20BlockCrypt(pass) - sess, err := DialWithOptions(fmt.Sprintf("127.0.0.1:%v", port)) - if err != nil { - panic(err) - } - - sess.SetStreamMode(true) - sess.SetStreamMode(false) - sess.SetStreamMode(true) - sess.SetWindowSize(1024, 1024) - sess.SetReadBuffer(16 * 1024 * 1024) - sess.SetWriteBuffer(16 * 1024 * 1024) - sess.SetStreamMode(true) - sess.SetNoDelay(1, 10, 2, 1) - sess.SetMtu(1400) - sess.SetMtu(1600) - sess.SetMtu(1400) - sess.SetACKNoDelay(true) - sess.SetACKNoDelay(false) - sess.SetDeadline(time.Now().Add(time.Minute)) - return sess, err -} - -func dialSink(port int) (*UDPSession, error) { - sess, err := DialWithOptions(fmt.Sprintf("127.0.0.1:%v", port)) - if err != nil { - panic(err) - } - - sess.SetStreamMode(true) - sess.SetWindowSize(1024, 1024) - sess.SetReadBuffer(16 * 1024 * 1024) - sess.SetWriteBuffer(16 * 1024 * 1024) - sess.SetStreamMode(true) - sess.SetNoDelay(1, 10, 2, 1) - sess.SetMtu(1400) - sess.SetACKNoDelay(false) - sess.SetDeadline(time.Now().Add(time.Minute)) - return sess, err -} - -func dialTinyBufferEcho(port int) (*UDPSession, error) { - // block, _ := NewNoneBlockCrypt(pass) - // block, _ := NewSimpleXORBlockCrypt(pass) - // block, _ := NewTEABlockCrypt(pass[:16]) - // block, _ := NewAESBlockCrypt(pass) - // block, _ := NewSalsa20BlockCrypt(pass) - sess, err := DialWithOptions(fmt.Sprintf("127.0.0.1:%v", port)) - if err != nil { - panic(err) - } - return sess, err -} - -// //////////////////////// -func listenEcho(port int) (net.Listener, error) { - // block, _ := NewNoneBlockCrypt(pass) - // block, _ := NewSimpleXORBlockCrypt(pass) - // block, _ := NewTEABlockCrypt(pass[:16]) - // block, _ := NewAESBlockCrypt(pass) - // block, _ := NewSalsa20BlockCrypt(pass) - return ListenWithOptions(fmt.Sprintf("127.0.0.1:%v", port)) -} -func listenTinyBufferEcho(port int) (net.Listener, error) { - // block, _ := NewNoneBlockCrypt(pass) - // block, _ := NewSimpleXORBlockCrypt(pass) - // block, _ := NewTEABlockCrypt(pass[:16]) - // block, _ := NewAESBlockCrypt(pass) - // block, _ := NewSalsa20BlockCrypt(pass) - return ListenWithOptions(fmt.Sprintf("127.0.0.1:%v", port)) -} - -func listenSink(port int) (net.Listener, error) { - return ListenWithOptions(fmt.Sprintf("127.0.0.1:%v", port)) -} - -func echoServer(port int) net.Listener { - l, err := listenEcho(port) - if err != nil { - panic(err) - } - - go func() { - kcplistener := l.(*Listener) - kcplistener.SetReadBuffer(4 * 1024 * 1024) - kcplistener.SetWriteBuffer(4 * 1024 * 1024) - kcplistener.SetDSCP(46) - for { - s, err := l.Accept() - if err != nil { - return - } - - // coverage test - s.(*UDPSession).SetReadBuffer(4 * 1024 * 1024) - s.(*UDPSession).SetWriteBuffer(4 * 1024 * 1024) - go handleEcho(s.(*UDPSession)) - } - }() - - return l -} - -func sinkServer(port int) net.Listener { - l, err := listenSink(port) - if err != nil { - panic(err) - } - - go func() { - kcplistener := l.(*Listener) - kcplistener.SetReadBuffer(4 * 1024 * 1024) - kcplistener.SetWriteBuffer(4 * 1024 * 1024) - kcplistener.SetDSCP(46) - for { - s, err := l.Accept() - if err != nil { - return - } - - go handleSink(s.(*UDPSession)) - } - }() - - return l -} - -func tinyBufferEchoServer(port int) net.Listener { - l, err := listenTinyBufferEcho(port) - if err != nil { - panic(err) - } - - go func() { - for { - s, err := l.Accept() - if err != nil { - return - } - go handleTinyBufferEcho(s.(*UDPSession)) - } - }() - return l -} - -// ///////////////////////// - -func handleEcho(conn *UDPSession) { - conn.SetStreamMode(true) - conn.SetWindowSize(4096, 4096) - conn.SetNoDelay(1, 10, 2, 1) - conn.SetDSCP(46) - conn.SetMtu(1400) - conn.SetACKNoDelay(false) - conn.SetReadDeadline(time.Now().Add(time.Hour)) - conn.SetWriteDeadline(time.Now().Add(time.Hour)) - buf := make([]byte, 65536) - for { - n, err := conn.Read(buf) - if err != nil { - return - } - conn.Write(buf[:n]) - } -} - -func handleSink(conn *UDPSession) { - conn.SetStreamMode(true) - conn.SetWindowSize(4096, 4096) - conn.SetNoDelay(1, 10, 2, 1) - conn.SetDSCP(46) - conn.SetMtu(1400) - conn.SetACKNoDelay(false) - conn.SetReadDeadline(time.Now().Add(time.Hour)) - conn.SetWriteDeadline(time.Now().Add(time.Hour)) - buf := make([]byte, 65536) - for { - _, err := conn.Read(buf) - if err != nil { - return - } - } -} - -func handleTinyBufferEcho(conn *UDPSession) { - conn.SetStreamMode(true) - buf := make([]byte, 2) - for { - n, err := conn.Read(buf) - if err != nil { - return - } - conn.Write(buf[:n]) - } -} - -// ///////////////////////// - -func TestTimeout(t *testing.T) { - port := int(atomic.AddUint32(&baseport, 1)) - l := echoServer(port) - defer l.Close() - - cli, err := dialEcho(port) - if err != nil { - panic(err) - } - buf := make([]byte, 10) - - // timeout - cli.SetDeadline(time.Now().Add(time.Second)) - <-time.After(2 * time.Second) - n, err := cli.Read(buf) - if n != 0 || err == nil { - t.Fail() - } - cli.Close() -} - -func TestSendRecv(t *testing.T) { - port := int(atomic.AddUint32(&baseport, 1)) - l := echoServer(port) - defer l.Close() - - cli, err := dialEcho(port) - if err != nil { - panic(err) - } - cli.SetWriteDelay(true) - cli.SetDUP(1) - const N = 100 - buf := make([]byte, 10) - for i := 0; i < N; i++ { - msg := fmt.Sprintf("hello%v", i) - cli.Write([]byte(msg)) - if n, err := cli.Read(buf); err == nil { - if string(buf[:n]) != msg { - t.Fail() - } - } else { - panic(err) - } - } - cli.Close() -} - -func TestSendVector(t *testing.T) { - port := int(atomic.AddUint32(&baseport, 1)) - l := echoServer(port) - defer l.Close() - - cli, err := dialEcho(port) - if err != nil { - panic(err) - } - cli.SetWriteDelay(false) - const N = 100 - buf := make([]byte, 20) - v := make([][]byte, 2) - for i := 0; i < N; i++ { - v[0] = []byte(fmt.Sprintf("hello%v", i)) - v[1] = []byte(fmt.Sprintf("world%v", i)) - msg := fmt.Sprintf("hello%vworld%v", i, i) - cli.WriteBuffers(v) - if n, err := cli.Read(buf); err == nil { - if string(buf[:n]) != msg { - t.Error(string(buf[:n]), msg) - } - } else { - panic(err) - } - } - cli.Close() -} - -func TestTinyBufferReceiver(t *testing.T) { - port := int(atomic.AddUint32(&baseport, 1)) - l := tinyBufferEchoServer(port) - defer l.Close() - - cli, err := dialTinyBufferEcho(port) - if err != nil { - panic(err) - } - const N = 100 - snd := byte(0) - fillBuffer := func(buf []byte) { - for i := 0; i < len(buf); i++ { - buf[i] = snd - snd++ - } - } - - rcv := byte(0) - check := func(buf []byte) bool { - for i := 0; i < len(buf); i++ { - if buf[i] != rcv { - return false - } - rcv++ - } - return true - } - sndbuf := make([]byte, 7) - rcvbuf := make([]byte, 7) - for i := 0; i < N; i++ { - fillBuffer(sndbuf) - cli.Write(sndbuf) - if n, err := io.ReadFull(cli, rcvbuf); err == nil { - if !check(rcvbuf[:n]) { - t.Fail() - } - } else { - panic(err) - } - } - cli.Close() -} - -func TestClose(t *testing.T) { - var n int - var err error - - port := int(atomic.AddUint32(&baseport, 1)) - l := echoServer(port) - defer l.Close() - - cli, err := dialEcho(port) - if err != nil { - panic(err) - } - - // double close - cli.Close() - if cli.Close() == nil { - t.Fatal("double close misbehavior") - } - - // write after close - buf := make([]byte, 10) - n, err = cli.Write(buf) - if n != 0 || err == nil { - t.Fatal("write after close misbehavior") - } - - // write, close, read, read - cli, err = dialEcho(port) - if err != nil { - panic(err) - } - if n, err = cli.Write(buf); err != nil { - t.Fatal("write misbehavior") - } - - // wait until data arrival - time.Sleep(2 * time.Second) - // drain - cli.Close() - n, err = io.ReadFull(cli, buf) - if err != nil { - t.Fatal("closed conn drain bytes failed", err, n) - } - - // after drain, read should return error - n, err = cli.Read(buf) - if n != 0 || err == nil { - t.Fatal("write->close->drain->read misbehavior", err, n) - } - cli.Close() -} - -func TestParallel1024CLIENT_64BMSG_64CNT(t *testing.T) { - port := int(atomic.AddUint32(&baseport, 1)) - l := echoServer(port) - defer l.Close() - - var wg sync.WaitGroup - wg.Add(1024) - for i := 0; i < 1024; i++ { - go parallel_client(&wg, port) - } - wg.Wait() -} - -func parallel_client(wg *sync.WaitGroup, port int) (err error) { - cli, err := dialEcho(port) - if err != nil { - panic(err) - } - - err = echo_tester(cli, 64, 64) - cli.Close() - wg.Done() - return -} - -func BenchmarkEchoSpeed4K(b *testing.B) { - speedclient(b, 4096) -} - -func BenchmarkEchoSpeed64K(b *testing.B) { - speedclient(b, 65536) -} - -func BenchmarkEchoSpeed512K(b *testing.B) { - speedclient(b, 524288) -} - -func BenchmarkEchoSpeed1M(b *testing.B) { - speedclient(b, 1048576) -} - -func speedclient(b *testing.B, nbytes int) { - port := int(atomic.AddUint32(&baseport, 1)) - l := echoServer(port) - defer l.Close() - - b.ReportAllocs() - cli, err := dialEcho(port) - if err != nil { - panic(err) - } - - if err := echo_tester(cli, nbytes, b.N); err != nil { - b.Fail() - } - b.SetBytes(int64(nbytes)) - cli.Close() -} - -func BenchmarkSinkSpeed4K(b *testing.B) { - sinkclient(b, 4096) -} - -func BenchmarkSinkSpeed64K(b *testing.B) { - sinkclient(b, 65536) -} - -func BenchmarkSinkSpeed256K(b *testing.B) { - sinkclient(b, 524288) -} - -func BenchmarkSinkSpeed1M(b *testing.B) { - sinkclient(b, 1048576) -} - -func sinkclient(b *testing.B, nbytes int) { - port := int(atomic.AddUint32(&baseport, 1)) - l := sinkServer(port) - defer l.Close() - - b.ReportAllocs() - cli, err := dialSink(port) - if err != nil { - panic(err) - } - - sink_tester(cli, nbytes, b.N) - b.SetBytes(int64(nbytes)) - cli.Close() -} - -func echo_tester(cli net.Conn, msglen, msgcount int) error { - buf := make([]byte, msglen) - for i := 0; i < msgcount; i++ { - // send packet - if _, err := cli.Write(buf); err != nil { - return err - } - - // receive packet - nrecv := 0 - for { - n, err := cli.Read(buf) - if err != nil { - return err - } else { - nrecv += n - if nrecv == msglen { - break - } - } - } - } - return nil -} - -func sink_tester(cli *UDPSession, msglen, msgcount int) error { - // sender - buf := make([]byte, msglen) - for i := 0; i < msgcount; i++ { - if _, err := cli.Write(buf); err != nil { - return err - } - } - return nil -} - -func TestSNMP(t *testing.T) { - t.Log(DefaultSnmp.Copy()) - t.Log(DefaultSnmp.Header()) - t.Log(DefaultSnmp.ToSlice()) - DefaultSnmp.Reset() - t.Log(DefaultSnmp.ToSlice()) -} - -func TestListenerClose(t *testing.T) { - port := int(atomic.AddUint32(&baseport, 1)) - l, err := ListenWithOptions(fmt.Sprintf("127.0.0.1:%v", port)) - if err != nil { - t.Fail() - } - l.SetReadDeadline(time.Now().Add(time.Second)) - l.SetWriteDeadline(time.Now().Add(time.Second)) - l.SetDeadline(time.Now().Add(time.Second)) - time.Sleep(2 * time.Second) - if _, err := l.Accept(); err == nil { - t.Fail() - } - - l.Close() - // fakeaddr, _ := net.ResolveUDPAddr("udp6", "127.0.0.1:1111") - fakeConvId := uint64(0) - if l.closeSession(fakeConvId) { - t.Fail() - } -} - -// A wrapper for net.PacketConn that remembers when Close has been called. -type closedFlagPacketConn struct { - net.PacketConn - Closed bool -} - -func (c *closedFlagPacketConn) Close() error { - c.Closed = true - return c.PacketConn.Close() -} - -func newClosedFlagPacketConn(c net.PacketConn) *closedFlagPacketConn { - return &closedFlagPacketConn{c, false} -} - -// Listener should close a net.PacketConn that it created. -// https://github.com/xtaci/kcp-go/issues/165 -func TestListenerOwnedPacketConn(t *testing.T) { - // ListenWithOptions creates its own net.PacketConn. - l, err := ListenWithOptions("127.0.0.1:0") - if err != nil { - panic(err) - } - defer l.Close() - // Replace the internal net.PacketConn with one that remembers when it - // has been closed. - pconn := newClosedFlagPacketConn(l.conn) - l.conn = pconn - - if pconn.Closed { - t.Fatal("owned PacketConn closed before Listener.Close()") - } - - err = l.Close() - if err != nil { - panic(err) - } - - if !pconn.Closed { - t.Fatal("owned PacketConn not closed after Listener.Close()") - } -} - -// Listener should not close a net.PacketConn that it did not create. -// https://github.com/xtaci/kcp-go/issues/165 -func TestListenerNonOwnedPacketConn(t *testing.T) { - // Create a net.PacketConn not owned by the Listener. - c, err := net.ListenPacket("udp", "127.0.0.1:0") - if err != nil { - panic(err) - } - defer c.Close() - // Make it remember when it has been closed. - pconn := newClosedFlagPacketConn(c) - - l, err := ServeConn(pconn) - if err != nil { - panic(err) - } - defer l.Close() - - if pconn.Closed { - t.Fatal("non-owned PacketConn closed before Listener.Close()") - } - - err = l.Close() - if err != nil { - panic(err) - } - - if pconn.Closed { - t.Fatal("non-owned PacketConn closed after Listener.Close()") - } -} - -// UDPSession should close a net.PacketConn that it created. -// https://github.com/xtaci/kcp-go/issues/165 -func TestUDPSessionOwnedPacketConn(t *testing.T) { - l := sinkServer(0) - defer l.Close() - - // DialWithOptions creates its own net.PacketConn. - client, err := DialWithOptions(l.Addr().String()) - if err != nil { - panic(err) - } - defer client.Close() - // Replace the internal net.PacketConn with one that remembers when it - // has been closed. - pconn := newClosedFlagPacketConn(client.conn) - client.conn = pconn - - if pconn.Closed { - t.Fatal("owned PacketConn closed before UDPSession.Close()") - } - - err = client.Close() - if err != nil { - panic(err) - } - - if !pconn.Closed { - t.Fatal("owned PacketConn not closed after UDPSession.Close()") - } -} - -// UDPSession should not close a net.PacketConn that it did not create. -// https://github.com/xtaci/kcp-go/issues/165 -func TestUDPSessionNonOwnedPacketConn(t *testing.T) { - l := sinkServer(0) - defer l.Close() - - // Create a net.PacketConn not owned by the UDPSession. - c, err := net.ListenPacket("udp", "127.0.0.1:0") - if err != nil { - panic(err) - } - defer c.Close() - // Make it remember when it has been closed. - pconn := newClosedFlagPacketConn(c) - - client, err := NewConn2(l.Addr(), pconn) - if err != nil { - panic(err) - } - defer client.Close() - - if pconn.Closed { - t.Fatal("non-owned PacketConn closed before UDPSession.Close()") - } - - err = client.Close() - if err != nil { - panic(err) - } - - if pconn.Closed { - t.Fatal("non-owned PacketConn closed after UDPSession.Close()") - } -} diff --git a/gate/net/kcp_connect_manager.go b/gate/net/kcp_connect_manager.go index 527be454..61d9b175 100644 --- a/gate/net/kcp_connect_manager.go +++ b/gate/net/kcp_connect_manager.go @@ -317,9 +317,9 @@ func (k *KcpConnectManager) recvHandle(session *Session) { } recvData := recvBuf[:recvLen] kcpMsgList := make([]*KcpMsg, 0) - k.decodeBinToPayload(recvData, &dataBuf, convId, &kcpMsgList, session.xorKey) + DecodeBinToPayload(recvData, &dataBuf, convId, &kcpMsgList, session.xorKey) for _, v := range kcpMsgList { - protoMsgList := k.protoDecode(v) + protoMsgList := ProtoDecode(v, k.serverCmdProtoMap, k.clientCmdProtoMap) for _, vv := range protoMsgList { k.recvMsgHandle(vv, session) } @@ -341,12 +341,12 @@ func (k *KcpConnectManager) sendHandle(session *Session) { k.closeKcpConn(session, kcp.EnetServerKick) break } - kcpMsg := k.protoEncode(protoMsg) + kcpMsg := ProtoEncode(protoMsg, k.serverCmdProtoMap, k.clientCmdProtoMap) if kcpMsg == nil { logger.Error("decode kcp msg is nil, convId: %v", convId) continue } - bin := k.encodePayloadToBin(kcpMsg, session.xorKey) + bin := EncodePayloadToBin(kcpMsg, session.xorKey) _ = conn.SetWriteDeadline(time.Now().Add(time.Second * ConnSendTimeout)) _, err := conn.Write(bin) if err != nil { diff --git a/gate/net/kcp_endecode.go b/gate/net/kcp_endecode.go index 4259cd6f..5c769c02 100644 --- a/gate/net/kcp_endecode.go +++ b/gate/net/kcp_endecode.go @@ -32,13 +32,13 @@ type KcpMsg struct { ProtoData []byte } -func (k *KcpConnectManager) decodeBinToPayload(data []byte, dataBuf *[]byte, convId uint64, kcpMsgList *[]*KcpMsg, xorKey []byte) { +func DecodeBinToPayload(data []byte, dataBuf *[]byte, convId uint64, kcpMsgList *[]*KcpMsg, xorKey []byte) { // xor解密 endec.Xor(data, xorKey) - k.decodeLoop(data, dataBuf, convId, kcpMsgList) + DecodeLoop(data, dataBuf, convId, kcpMsgList) } -func (k *KcpConnectManager) decodeLoop(data []byte, dataBuf *[]byte, convId uint64, kcpMsgList *[]*KcpMsg) { +func DecodeLoop(data []byte, dataBuf *[]byte, convId uint64, kcpMsgList *[]*KcpMsg) { if len(*dataBuf) != 0 { // 取出之前的缓冲区数据 data = append(*dataBuf, data...) @@ -93,11 +93,11 @@ func (k *KcpConnectManager) decodeLoop(data []byte, dataBuf *[]byte, convId uint *kcpMsgList = append(*kcpMsgList, kcpMsg) // 递归解析 if haveMorePacket { - k.decodeLoop(data[packetLen:], dataBuf, convId, kcpMsgList) + DecodeLoop(data[packetLen:], dataBuf, convId, kcpMsgList) } } -func (k *KcpConnectManager) encodePayloadToBin(kcpMsg *KcpMsg, xorKey []byte) (bin []byte) { +func EncodePayloadToBin(kcpMsg *KcpMsg, xorKey []byte) (bin []byte) { if kcpMsg.HeadData == nil { kcpMsg.HeadData = make([]byte, 0) } diff --git a/gate/net/proto_endecode.go b/gate/net/proto_endecode.go index 000c37dd..7808a70c 100644 --- a/gate/net/proto_endecode.go +++ b/gate/net/proto_endecode.go @@ -4,6 +4,7 @@ import ( "reflect" "hk4e/common/config" + "hk4e/gate/client_proto" "hk4e/pkg/logger" "hk4e/pkg/object" "hk4e/protocol/cmd" @@ -24,17 +25,18 @@ type ProtoMessage struct { message pb.Message } -func (k *KcpConnectManager) protoDecode(kcpMsg *KcpMsg) (protoMsgList []*ProtoMsg) { +func ProtoDecode(kcpMsg *KcpMsg, + serverCmdProtoMap *cmd.CmdProtoMap, clientCmdProtoMap *client_proto.ClientCmdProtoMap) (protoMsgList []*ProtoMsg) { protoMsgList = make([]*ProtoMsg, 0) if config.CONF.Hk4e.ClientProtoProxyEnable { clientCmdId := kcpMsg.CmdId clientProtoData := kcpMsg.ProtoData - cmdName := k.clientCmdProtoMap.GetClientCmdNameByCmdId(clientCmdId) + cmdName := clientCmdProtoMap.GetClientCmdNameByCmdId(clientCmdId) if cmdName == "" { logger.Error("get cmdName is nil, clientCmdId: %v", clientCmdId) return protoMsgList } - clientProtoObj := k.getClientProtoObjByName(cmdName) + clientProtoObj := GetClientProtoObjByName(cmdName, clientCmdProtoMap) if clientProtoObj == nil { logger.Error("get client proto obj is nil, cmdName: %v", cmdName) return protoMsgList @@ -44,12 +46,12 @@ func (k *KcpConnectManager) protoDecode(kcpMsg *KcpMsg) (protoMsgList []*ProtoMs logger.Error("unmarshal client proto error: %v", err) return protoMsgList } - serverCmdId := k.serverCmdProtoMap.GetCmdIdByCmdName(cmdName) + serverCmdId := serverCmdProtoMap.GetCmdIdByCmdName(cmdName) if serverCmdId == 0 { logger.Error("get server cmdId is nil, cmdName: %v", cmdName) return protoMsgList } - serverProtoObj := k.serverCmdProtoMap.GetProtoObjByCmdId(serverCmdId) + serverProtoObj := serverCmdProtoMap.GetProtoObjByCmdId(serverCmdId) if serverProtoObj == nil { logger.Error("get server proto obj is nil, serverCmdId: %v", serverCmdId) return protoMsgList @@ -87,7 +89,7 @@ func (k *KcpConnectManager) protoDecode(kcpMsg *KcpMsg) (protoMsgList []*ProtoMs } // payload msg protoMessageList := make([]*ProtoMessage, 0) - k.protoDecodePayloadLoop(kcpMsg.CmdId, kcpMsg.ProtoData, &protoMessageList) + ProtoDecodePayloadLoop(kcpMsg.CmdId, kcpMsg.ProtoData, &protoMessageList, serverCmdProtoMap, clientCmdProtoMap) if len(protoMessageList) == 0 { logger.Error("decode proto object is nil") return protoMsgList @@ -106,7 +108,8 @@ func (k *KcpConnectManager) protoDecode(kcpMsg *KcpMsg) (protoMsgList []*ProtoMs if msg.PayloadMessage != nil { cmdName = string(msg.PayloadMessage.ProtoReflect().Descriptor().FullName()) } - logger.Debug("[RECV UNION CMD], cmdId: %v, cmdName: %v, convId: %v, headMsg: %v", msg.CmdId, cmdName, msg.ConvId, msg.HeadMessage) + logger.Debug("[RECV UNION CMD], cmdId: %v, cmdName: %v, convId: %v, headMsg: %v", + msg.CmdId, cmdName, msg.ConvId, msg.HeadMessage) } } else { protoMsg.PayloadMessage = protoMessageList[0].message @@ -115,13 +118,15 @@ func (k *KcpConnectManager) protoDecode(kcpMsg *KcpMsg) (protoMsgList []*ProtoMs if protoMsg.PayloadMessage != nil { cmdName = string(protoMsg.PayloadMessage.ProtoReflect().Descriptor().FullName()) } - logger.Debug("[RECV], cmdId: %v, cmdName: %v, convId: %v, headMsg: %v", protoMsg.CmdId, cmdName, protoMsg.ConvId, protoMsg.HeadMessage) + logger.Debug("[RECV], cmdId: %v, cmdName: %v, convId: %v, headMsg: %v", + protoMsg.CmdId, cmdName, protoMsg.ConvId, protoMsg.HeadMessage) } return protoMsgList } -func (k *KcpConnectManager) protoDecodePayloadLoop(cmdId uint16, protoData []byte, protoMessageList *[]*ProtoMessage) { - protoObj := k.decodePayloadToProto(cmdId, protoData) +func ProtoDecodePayloadLoop(cmdId uint16, protoData []byte, protoMessageList *[]*ProtoMessage, + serverCmdProtoMap *cmd.CmdProtoMap, clientCmdProtoMap *client_proto.ClientCmdProtoMap) { + protoObj := DecodePayloadToProto(cmdId, protoData, serverCmdProtoMap) if protoObj == nil { logger.Error("decode proto object is nil") return @@ -137,12 +142,12 @@ func (k *KcpConnectManager) protoDecodePayloadLoop(cmdId uint16, protoData []byt if config.CONF.Hk4e.ClientProtoProxyEnable { clientCmdId := uint16(unionCmd.MessageId) clientProtoData := unionCmd.Body - cmdName := k.clientCmdProtoMap.GetClientCmdNameByCmdId(clientCmdId) + cmdName := clientCmdProtoMap.GetClientCmdNameByCmdId(clientCmdId) if cmdName == "" { logger.Error("get cmdName is nil, clientCmdId: %v", clientCmdId) continue } - clientProtoObj := k.getClientProtoObjByName(cmdName) + clientProtoObj := GetClientProtoObjByName(cmdName, clientCmdProtoMap) if clientProtoObj == nil { logger.Error("get client proto obj is nil, cmdName: %v", cmdName) continue @@ -152,12 +157,12 @@ func (k *KcpConnectManager) protoDecodePayloadLoop(cmdId uint16, protoData []byt logger.Error("unmarshal client proto error: %v", err) continue } - serverCmdId := k.serverCmdProtoMap.GetCmdIdByCmdName(cmdName) + serverCmdId := serverCmdProtoMap.GetCmdIdByCmdName(cmdName) if serverCmdId == 0 { logger.Error("get server cmdId is nil, cmdName: %v", cmdName) continue } - serverProtoObj := k.serverCmdProtoMap.GetProtoObjByCmdId(serverCmdId) + serverProtoObj := serverCmdProtoMap.GetProtoObjByCmdId(serverCmdId) if serverProtoObj == nil { logger.Error("get server proto obj is nil, serverCmdId: %v", serverCmdId) continue @@ -178,7 +183,8 @@ func (k *KcpConnectManager) protoDecodePayloadLoop(cmdId uint16, protoData []byt unionCmd.MessageId = uint32(serverCmdId) unionCmd.Body = serverProtoData } - k.protoDecodePayloadLoop(uint16(unionCmd.MessageId), unionCmd.Body, protoMessageList) + ProtoDecodePayloadLoop(uint16(unionCmd.MessageId), unionCmd.Body, protoMessageList, + serverCmdProtoMap, clientCmdProtoMap) } } *protoMessageList = append(*protoMessageList, &ProtoMessage{ @@ -187,12 +193,14 @@ func (k *KcpConnectManager) protoDecodePayloadLoop(cmdId uint16, protoData []byt }) } -func (k *KcpConnectManager) protoEncode(protoMsg *ProtoMsg) (kcpMsg *KcpMsg) { +func ProtoEncode(protoMsg *ProtoMsg, + serverCmdProtoMap *cmd.CmdProtoMap, clientCmdProtoMap *client_proto.ClientCmdProtoMap) (kcpMsg *KcpMsg) { cmdName := "" if protoMsg.PayloadMessage != nil { cmdName = string(protoMsg.PayloadMessage.ProtoReflect().Descriptor().FullName()) } - logger.Debug("[SEND], cmdId: %v, cmdName: %v, convId: %v, headMsg: %v", protoMsg.CmdId, cmdName, protoMsg.ConvId, protoMsg.HeadMessage) + logger.Debug("[SEND], cmdId: %v, cmdName: %v, convId: %v, headMsg: %v", + protoMsg.CmdId, cmdName, protoMsg.ConvId, protoMsg.HeadMessage) kcpMsg = new(KcpMsg) kcpMsg.ConvId = protoMsg.ConvId kcpMsg.CmdId = protoMsg.CmdId @@ -209,13 +217,14 @@ func (k *KcpConnectManager) protoEncode(protoMsg *ProtoMsg) (kcpMsg *KcpMsg) { } // payload msg if protoMsg.PayloadMessage != nil { - cmdId, protoData := k.encodeProtoToPayload(protoMsg.PayloadMessage) + cmdId, protoData := EncodeProtoToPayload(protoMsg.PayloadMessage, serverCmdProtoMap) if cmdId == 0 || protoData == nil { logger.Error("encode proto data is nil") return nil } if cmdId != 65535 && cmdId != protoMsg.CmdId { - logger.Error("cmd id is not match with proto obj, src cmd id: %v, found cmd id: %v", protoMsg.CmdId, cmdId) + logger.Error("cmd id is not match with proto obj, src cmd id: %v, found cmd id: %v", + protoMsg.CmdId, cmdId) return nil } kcpMsg.ProtoData = protoData @@ -225,7 +234,7 @@ func (k *KcpConnectManager) protoEncode(protoMsg *ProtoMsg) (kcpMsg *KcpMsg) { if config.CONF.Hk4e.ClientProtoProxyEnable { serverCmdId := kcpMsg.CmdId serverProtoData := kcpMsg.ProtoData - serverProtoObj := k.serverCmdProtoMap.GetProtoObjByCmdId(serverCmdId) + serverProtoObj := serverCmdProtoMap.GetProtoObjByCmdId(serverCmdId) if serverProtoObj == nil { logger.Error("get server proto obj is nil, serverCmdId: %v", serverCmdId) return nil @@ -235,12 +244,12 @@ func (k *KcpConnectManager) protoEncode(protoMsg *ProtoMsg) (kcpMsg *KcpMsg) { logger.Error("unmarshal server proto error: %v", err) return nil } - cmdName := k.serverCmdProtoMap.GetCmdNameByCmdId(serverCmdId) + cmdName := serverCmdProtoMap.GetCmdNameByCmdId(serverCmdId) if cmdName == "" { logger.Error("get cmdName is nil, serverCmdId: %v", serverCmdId) return nil } - clientProtoObj := k.getClientProtoObjByName(cmdName) + clientProtoObj := GetClientProtoObjByName(cmdName, clientCmdProtoMap) if clientProtoObj == nil { logger.Error("get client proto obj is nil, cmdName: %v", cmdName) return nil @@ -258,7 +267,7 @@ func (k *KcpConnectManager) protoEncode(protoMsg *ProtoMsg) (kcpMsg *KcpMsg) { logger.Error("marshal client proto error: %v", err) return nil } - clientCmdId := k.clientCmdProtoMap.GetClientCmdIdByCmdName(cmdName) + clientCmdId := clientCmdProtoMap.GetClientCmdIdByCmdName(cmdName) if clientCmdId == 0 { logger.Error("get client cmdId is nil, cmdName: %v", cmdName) return nil @@ -269,8 +278,8 @@ func (k *KcpConnectManager) protoEncode(protoMsg *ProtoMsg) (kcpMsg *KcpMsg) { return kcpMsg } -func (k *KcpConnectManager) decodePayloadToProto(cmdId uint16, protoData []byte) (protoObj pb.Message) { - protoObj = k.serverCmdProtoMap.GetProtoObjByCmdId(cmdId) +func DecodePayloadToProto(cmdId uint16, protoData []byte, serverCmdProtoMap *cmd.CmdProtoMap) (protoObj pb.Message) { + protoObj = serverCmdProtoMap.GetProtoObjByCmdId(cmdId) if protoObj == nil { logger.Error("get new proto object is nil") return nil @@ -283,8 +292,8 @@ func (k *KcpConnectManager) decodePayloadToProto(cmdId uint16, protoData []byte) return protoObj } -func (k *KcpConnectManager) encodeProtoToPayload(protoObj pb.Message) (cmdId uint16, protoData []byte) { - cmdId = k.serverCmdProtoMap.GetCmdIdByProtoObj(protoObj) +func EncodeProtoToPayload(protoObj pb.Message, serverCmdProtoMap *cmd.CmdProtoMap) (cmdId uint16, protoData []byte) { + cmdId = serverCmdProtoMap.GetCmdIdByProtoObj(protoObj) var err error = nil protoData, err = pb.Marshal(protoObj) if err != nil { @@ -294,8 +303,8 @@ func (k *KcpConnectManager) encodeProtoToPayload(protoObj pb.Message) (cmdId uin return cmdId, protoData } -func (k *KcpConnectManager) getClientProtoObjByName(protoObjName string) pb.Message { - fn := k.clientCmdProtoMap.RefValue.MethodByName("GetClientProtoObjByName") +func GetClientProtoObjByName(protoObjName string, clientCmdProtoMap *client_proto.ClientCmdProtoMap) pb.Message { + fn := clientCmdProtoMap.RefValue.MethodByName("GetClientProtoObjByName") ret := fn.Call([]reflect.Value{reflect.ValueOf(protoObjName)}) obj := ret[0].Interface() if obj == nil { diff --git a/gate/net/session.go b/gate/net/session.go index c2e39fef..60a1e267 100644 --- a/gate/net/session.go +++ b/gate/net/session.go @@ -482,10 +482,10 @@ func (k *KcpConnectManager) getPlayerToken(req *proto.GetPlayerTokenReq, session timeRand := random.GetTimeRand() serverSeedUint64 := timeRand.Uint64() session.seed = serverSeedUint64 - if req.GetKeyId() != 0 { + if req.KeyId != 0 { logger.Debug("do hk4e 2.8 rsa logic, uid: %v", uid) session.useMagicSeed = true - keyId := strconv.Itoa(int(req.GetKeyId())) + keyId := strconv.Itoa(int(req.KeyId)) encPubPrivKey, exist := k.encRsaKeyMap[keyId] if !exist { logger.Error("can not found key id: %v, uid: %v", keyId, uid) @@ -504,7 +504,7 @@ func (k *KcpConnectManager) getPlayerToken(req *proto.GetPlayerTokenReq, session loginFailClose() return nil } - clientSeedBase64 := req.GetClientRandKey() + clientSeedBase64 := req.ClientRandKey clientSeedEnc, err := base64.StdEncoding.DecodeString(clientSeedBase64) if err != nil { logger.Error("parse client seed base64 error: %v, uid: %v", err, uid) diff --git a/go.mod b/go.mod index b2a4c2a1..875246a5 100644 --- a/go.mod +++ b/go.mod @@ -5,14 +5,6 @@ go 1.18 // toml require github.com/BurntSushi/toml v0.3.1 -// kcp -require ( - github.com/pkg/errors v0.9.1 - github.com/xtaci/lossyconn v0.0.0-20200209145036-adba10fffc37 - golang.org/x/crypto v0.0.0-20220926161630-eccd6366d1be - golang.org/x/net v0.0.0-20211123203042-d83791d6bcd9 -) - // protobuf require google.golang.org/protobuf v1.28.0 @@ -58,7 +50,14 @@ require github.com/yuin/gopher-lua v1.0.0 // lz4 require github.com/pierrec/lz4/v4 v4.1.17 -require golang.org/x/sys v0.0.0-20220928140112-f11e5e49a4ec +// dpdk-go +require github.com/FlourishingWorld/dpdk-go v1.0.1 + +require ( + github.com/pkg/errors v0.9.1 + golang.org/x/net v0.0.0-20211112202133-69e39bad7dc2 + golang.org/x/sys v0.0.0-20220928140112-f11e5e49a4ec +) require ( github.com/cespare/xxhash/v2 v2.1.2 // indirect @@ -88,6 +87,7 @@ require ( github.com/xdg-go/scram v1.0.2 // indirect github.com/xdg-go/stringprep v1.0.2 // indirect github.com/youmark/pkcs8 v0.0.0-20181117223130-1be2e3e5546d // indirect + golang.org/x/crypto v0.0.0-20220926161630-eccd6366d1be // indirect golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e // indirect golang.org/x/text v0.3.6 // indirect gopkg.in/yaml.v2 v2.4.0 // indirect diff --git a/go.sum b/go.sum index cc188245..7fe8a13f 100644 --- a/go.sum +++ b/go.sum @@ -1,5 +1,7 @@ github.com/BurntSushi/toml v0.3.1 h1:WXkYYl6Yr3qBf1K79EBnL4mak0OimBfB0XUf9Vl28OQ= github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU= +github.com/FlourishingWorld/dpdk-go v1.0.1 h1:6PGyJ5dE8y7J/GlculPEGLrWlMVCfeThIUzfgaHse0c= +github.com/FlourishingWorld/dpdk-go v1.0.1/go.mod h1:dmizWx3DB1EK9Tu1/x4o6VpeGdp9YSpRhFPqYoLlNE0= github.com/arl/statsviz v0.5.1 h1:3HY0ZEB738JtguWsD1Tf1pFJZiCcWUmYRq/3OTYKaSI= github.com/arl/statsviz v0.5.1/go.mod h1:zDnjgRblGm1Dyd7J5YlbH7gM1/+HRC+SfkhZhQb5AnM= github.com/byebyebruce/natsrpc v0.5.5-0.20221125150611-56cd29a4e335 h1:V5qahA5kDL/TBnlwvYjemR5du/uQ7q75qkBBlTc4rXI= @@ -112,8 +114,6 @@ github.com/xdg-go/scram v1.0.2 h1:akYIkZ28e6A96dkWNJQu3nmCzH3YfwMPQExUYDaRv7w= github.com/xdg-go/scram v1.0.2/go.mod h1:1WAq6h33pAW+iRreB34OORO2Nf7qel3VV3fjBj+hCSs= github.com/xdg-go/stringprep v1.0.2 h1:6iq84/ryjjeRmMJwxutI51F2GIPlP5BfTvXHeYjyhBc= github.com/xdg-go/stringprep v1.0.2/go.mod h1:8F9zXuvzgwmyT5DUm4GUfZGDdT3W+LCvS6+da4O5kxM= -github.com/xtaci/lossyconn v0.0.0-20200209145036-adba10fffc37 h1:EWU6Pktpas0n8lLQwDsRyZfmkPeRbdgPtW609es+/9E= -github.com/xtaci/lossyconn v0.0.0-20200209145036-adba10fffc37/go.mod h1:HpMP7DB2CyokmAh4lp0EQnnWhmycP/TvwBGzvuie+H0= github.com/youmark/pkcs8 v0.0.0-20181117223130-1be2e3e5546d h1:splanxYIlg+5LfHAM6xpdFEAYOk8iySO56hMFq6uLyA= github.com/youmark/pkcs8 v0.0.0-20181117223130-1be2e3e5546d/go.mod h1:rHwXgn7JulP+udvsHwJoVG1YGAP6VLg4y9I5dyZdqmA= github.com/yuin/gopher-lua v1.0.0 h1:pQCf0LN67Kf7M5u7vRd40A8M1I8IMLrxlqngUJgZ0Ow= @@ -130,8 +130,8 @@ golang.org/x/crypto v0.0.0-20220926161630-eccd6366d1be/go.mod h1:IxCIyHEi3zRg3s0 golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= -golang.org/x/net v0.0.0-20211123203042-d83791d6bcd9 h1:0qxwC5n+ttVOINCBeRHO0nq9X7uy8SDsPoi5OaCdIEI= -golang.org/x/net v0.0.0-20211123203042-d83791d6bcd9/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= +golang.org/x/net v0.0.0-20211112202133-69e39bad7dc2 h1:CIJ76btIcR3eFI5EgSo6k1qKw9KJexJuRLI9G7Hp5wE= +golang.org/x/net v0.0.0-20211112202133-69e39bad7dc2/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e h1:vcxGaoTs7kV8m5Np9uUNQin4BrLOthgV7252N8V+FwY= golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= diff --git a/robot/cmd/application.toml b/robot/cmd/application.toml new file mode 100644 index 00000000..12346ebe --- /dev/null +++ b/robot/cmd/application.toml @@ -0,0 +1,5 @@ +[logger] +level = "DEBUG" +mode = "CONSOLE" +track = true +max_size = 10485760 diff --git a/robot/cmd/main.go b/robot/cmd/main.go new file mode 100644 index 00000000..3c53577c --- /dev/null +++ b/robot/cmd/main.go @@ -0,0 +1,67 @@ +package main + +import ( + "os" + "os/signal" + "syscall" + "time" + + "hk4e/common/config" + hk4egatenet "hk4e/gate/net" + "hk4e/pkg/logger" + "hk4e/protocol/cmd" + "hk4e/protocol/proto" + "hk4e/robot/net" + + "github.com/FlourishingWorld/dpdk-go/engine" +) + +func main() { + config.InitConfig("application.toml") + logger.InitLogger("robot") + + err := engine.InitEngine("00:0C:29:3E:3E:DF", "192.168.199.199", "255.255.255.0", "192.168.199.1") + if err != nil { + panic(err) + } + engine.RunEngine([]int{0, 1, 2, 3}, 1, "0.0.0.0") + + time.Sleep(time.Second * 30) + + session := net.NewSession("192.168.199.233:22222", []byte{0x00}) + go func() { + protoMsg := <-session.RecvChan + logger.Debug("%v", protoMsg) + }() + go func() { + session.SendChan <- &hk4egatenet.ProtoMsg{ + ConvId: 0, + CmdId: cmd.GetPlayerTokenReq, + HeadMessage: &proto.PacketHead{ + ClientSequenceId: 1, + SentMs: uint64(time.Now().UnixMilli()), + }, + PayloadMessage: &proto.GetPlayerTokenReq{ + AccountToken: "xxxxxx", + AccountUid: "10001", + KeyId: 0, + ClientRandKey: "", + }, + } + }() + + c := make(chan os.Signal, 1) + signal.Notify(c, syscall.SIGHUP, syscall.SIGQUIT, syscall.SIGTERM, syscall.SIGINT) + for { + s := <-c + switch s { + case syscall.SIGQUIT, syscall.SIGTERM, syscall.SIGINT: + engine.StopEngine() + time.Sleep(time.Second) + return + case syscall.SIGHUP: + default: + return + } + } +} diff --git a/robot/net/session.go b/robot/net/session.go new file mode 100644 index 00000000..c7320981 --- /dev/null +++ b/robot/net/session.go @@ -0,0 +1,109 @@ +package net + +import ( + "time" + + hk4egatenet "hk4e/gate/net" + "hk4e/pkg/logger" + "hk4e/pkg/random" + "hk4e/protocol/cmd" + + "github.com/FlourishingWorld/dpdk-go/protocol/kcp" +) + +type Session struct { + SendChan chan *hk4egatenet.ProtoMsg + RecvChan chan *hk4egatenet.ProtoMsg + conn *kcp.UDPSession + seed uint64 // TODO 密钥交换后收到的服务器生成的seed + xorKey []byte + changeXorKeyFin bool + useMagicSeed bool +} + +func NewSession(gateAddr string, dispatchKey []byte) (r *Session) { + conn, err := kcp.DialWithOptions(gateAddr, "0.0.0.0:30000") + if err != nil { + logger.Error("kcp client conn to server error: %v", err) + return + } + conn.SetACKNoDelay(true) + conn.SetWriteDelay(false) + sendChan := make(chan *hk4egatenet.ProtoMsg, 1000) + recvChan := make(chan *hk4egatenet.ProtoMsg, 1000) + r = &Session{ + SendChan: sendChan, + RecvChan: recvChan, + conn: conn, + seed: 0, + xorKey: dispatchKey, + changeXorKeyFin: false, + useMagicSeed: true, + } + go r.recvHandle() + go r.sendHandle() + return r +} + +func (s *Session) recvHandle() { + logger.Info("recv handle start") + conn := s.conn + convId := conn.GetConv() + recvBuf := make([]byte, hk4egatenet.PacketMaxLen) + dataBuf := make([]byte, 0, 1500) + for { + _ = conn.SetReadDeadline(time.Now().Add(time.Second * hk4egatenet.ConnRecvTimeout)) + recvLen, err := conn.Read(recvBuf) + if err != nil { + logger.Error("exit recv loop, conn read err: %v, convId: %v", err, convId) + _ = conn.Close() + break + } + recvData := recvBuf[:recvLen] + kcpMsgList := make([]*hk4egatenet.KcpMsg, 0) + hk4egatenet.DecodeBinToPayload(recvData, &dataBuf, convId, &kcpMsgList, s.xorKey) + for _, v := range kcpMsgList { + protoMsgList := hk4egatenet.ProtoDecode(v, nil, nil) + for _, vv := range protoMsgList { + s.RecvChan <- vv + if s.changeXorKeyFin == false && vv.CmdId == cmd.GetPlayerTokenRsp { + // XOR密钥切换 + logger.Info("change session xor key, convId: %v", convId) + s.changeXorKeyFin = true + keyBlock := random.NewKeyBlock(s.seed, s.useMagicSeed) + xorKey := keyBlock.XorKey() + key := make([]byte, 4096) + copy(key, xorKey[:]) + s.xorKey = key + } + } + } + } +} + +func (s *Session) sendHandle() { + logger.Info("send handle start") + conn := s.conn + convId := conn.GetConv() + for { + protoMsg, ok := <-s.SendChan + if !ok { + logger.Error("exit send loop, send chan close, convId: %v", convId) + _ = conn.Close() + break + } + kcpMsg := hk4egatenet.ProtoEncode(protoMsg, nil, nil) + if kcpMsg == nil { + logger.Error("decode kcp msg is nil, convId: %v", convId) + continue + } + bin := hk4egatenet.EncodePayloadToBin(kcpMsg, s.xorKey) + _ = conn.SetWriteDeadline(time.Now().Add(time.Second * hk4egatenet.ConnSendTimeout)) + _, err := conn.Write(bin) + if err != nil { + logger.Error("exit send loop, conn write err: %v, convId: %v", err, convId) + _ = conn.Close() + break + } + } +}