优化架构

This commit is contained in:
huangxiaolei
2022-11-23 18:05:11 +08:00
parent 3efed3defe
commit 43403202b5
6760 changed files with 33748 additions and 554768 deletions
+11
View File
@@ -0,0 +1,11 @@
[hk4e]
kcp_addr = "hk4e.flswld.com"
kcp_port = 22103
[logger]
level = "DEBUG"
method = "CONSOLE"
track_line = true
[mq]
nats_url = "nats://nats1:4222,nats://nats2:4222,nats://nats3:4222"
+70
View File
@@ -0,0 +1,70 @@
package main
import (
"github.com/arl/statsviz"
"hk4e/common/config"
"hk4e/gate/forward"
"hk4e/gate/mq"
"hk4e/gate/net"
"hk4e/logger"
"hk4e/protocol/cmd"
"net/http"
_ "net/http/pprof"
"os"
"os/signal"
"syscall"
"time"
)
func main() {
filePath := "./application.toml"
config.InitConfig(filePath)
logger.InitLogger("gate")
logger.LOG.Info("gate start")
go func() {
// 性能检测
err := statsviz.RegisterDefault()
if err != nil {
logger.LOG.Error("statsviz init error: %v", err)
}
err = http.ListenAndServe("0.0.0.0:2345", nil)
if err != nil {
logger.LOG.Error("perf debug http start error: %v", err)
}
}()
kcpEventInput := make(chan *net.KcpEvent)
kcpEventOutput := make(chan *net.KcpEvent)
protoMsgInput := make(chan *net.ProtoMsg, 10000)
protoMsgOutput := make(chan *net.ProtoMsg, 10000)
netMsgInput := make(chan *cmd.NetMsg, 10000)
netMsgOutput := make(chan *cmd.NetMsg, 10000)
connectManager := net.NewKcpConnectManager(protoMsgInput, protoMsgOutput, kcpEventInput, kcpEventOutput)
connectManager.Start()
forwardManager := forward.NewForwardManager(protoMsgInput, protoMsgOutput, kcpEventInput, kcpEventOutput, netMsgInput, netMsgOutput)
forwardManager.Start()
messageQueue := mq.NewMessageQueue(netMsgInput, netMsgOutput)
messageQueue.Start()
c := make(chan os.Signal, 1)
signal.Notify(c, syscall.SIGHUP, syscall.SIGQUIT, syscall.SIGTERM, syscall.SIGINT)
for {
s := <-c
logger.LOG.Info("get a signal %s", s.String())
switch s {
case syscall.SIGQUIT, syscall.SIGTERM, syscall.SIGINT:
logger.LOG.Info("gate exit")
messageQueue.Close()
time.Sleep(time.Second)
return
case syscall.SIGHUP:
default:
return
}
}
}
+1
View File
@@ -0,0 +1 @@
* binary
+15
View File
@@ -0,0 +1,15 @@
-----BEGIN RSA PRIVATE KEY-----
MIICXQIBAAKBgQDHPnAvEbJfMUwHXmRLiNDH1qFeGm0U/D6n7BjzEmJl5VtMKBZF
hnz+aKsyMo9aAowi2Fe/6iWUuzcbnAJS+4iLUxaeqOdvPe5LuR3wQxHKGJ8XsDkH
kt3T1operE1rpw9wX3xuUi0CA5aHqpC0ho0zMsk7nvxWQogv1G8uqcXmfQIDAQAB
AoGBALElFmEC/vAbyFkU119A+T9z2GzuWeW6j4qFI3mZ8tpdnVqMmaCe/irDrNIo
mcORWD7y0rHS4C7odQqbHoXhFXgXfrGJXcMu977uIxBKGj0UBz6YIciznk/8DrMo
o3q6+SGsNj5zvlU8oY6cpfC663VoQb7VWveUGN4zshdnvyiRAkEA8nlq/LEuQPCj
lp5wbUizJ3Uwll5N51N6Kzm1wRQ0vUtIzRK940lGMxlhihnJfifTColAnnzmWj/X
dWIULqIc0wJBANJbrnq1iim9Jue0UOhQn6hV8vvWHgLjK7zuEsUPDqzxfhmpmBEh
BwAaH3li6bGCbIfSJazs+LmNLB4YtMo6nW8CQAMtmjxjqiKJxOslen3ENSzwOUnP
RKAilPhaEkrMlABjKzoc48ZF4Jis3X1s5xozNW3u7JznMDHAondUaMVPtKcCQG45
9lp8aBJo+ErvlHm3TYHiz7kgwIcYzKFqStGRi0oaHM6LrJBFMyrdhWKQ7w3B3ubo
ui872TU5gUWgApP5VOcCQQDDvU76TpLQ2v2LO8D0L/Ds+t6HdGcPpKvlAm/YQYHL
X6Q435tFNbeWo3JzpGElb25zAQfXU5cvzYvg37f36iM6
-----END RSA PRIVATE KEY-----
Binary file not shown.
Binary file not shown.
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
+1
View File
@@ -0,0 +1 @@
ElIKBm9zX3VzYRIHQW1lcmljYRoKREVWX1BVQkxJQyIzaHR0cHM6Ly9vc3VzYWRpc3BhdGNoLnl1YW5zaGVuLmNvbS9xdWVyeV9jdXJfcmVnaW9uElMKB29zX2V1cm8SBkV1cm9wZRoKREVWX1BVQkxJQyI0aHR0cHM6Ly9vc2V1cm9kaXNwYXRjaC55dWFuc2hlbi5jb20vcXVlcnlfY3VyX3JlZ2lvbhJRCgdvc19hc2lhEgRBc2lhGgpERVZfUFVCTElDIjRodHRwczovL29zYXNpYWRpc3BhdGNoLnl1YW5zaGVuLmNvbS9xdWVyeV9jdXJfcmVnaW9uElUKBm9zX2NodBIKVFcsIEhLLCBNTxoKREVWX1BVQkxJQyIzaHR0cHM6Ly9vc2NodGRpc3BhdGNoLnl1YW5zaGVuLmNvbS9xdWVyeV9jdXJfcmVnaW9uKpwQRWMyYhAAAABbrAvbhfIRHfaSCN24qQyVAAgAAMs68ZiMdPfEj41O2wBCYqGiC/WdovvJvaw4t3/m1zIYDrt3/ftK9GKFb7C+2E8FmaHqOnwjJYBg2wI1sXpGmuSxkeWw8Avr36wlNtQjhXNV9zoNKstuZYuheyLlpbPRbYZ3UA6/BzTVsjIhjR1lcqFrigQnpV6MgRR9KqxakCaffK6qIzMlodx4ZPKlqseQhCiyVAvLWQSRqCRcZipzotXsmgLQbpDFtRzhgukXPjfW5dAlzMwswPuu7ZQsf1AKipI34dVQLu6gtXthGgbjn89h/79VR5AokLCPGqIV7/2s+gHfykrjDtyp5rwCcmGQqwV3gHy5LGrHl8Zm12jNd7Qcng51ydqtX4xzet6J2iMF6Dw5nPd/hTyxn+i3Ttk6fop9rbCq3iNgEw3+0cSDal1I1ThYdVnMgPhZgQkZc5/SpTaR+8vfDzRIKbSSrrPSEgLnQvWZOOugXhNdyuiaBc8rJveno7vvktmnhDUF3xWi6osj75j2KghRrdHfDR3Zuh4COrGZDRBSKHft2AvfrxaMT9O8hPzzzYk0U2iicVCDlNP/8wqaT9Vqt1kHmruLxqh377iyp0mxKfNt0+SNRzLyRoyvOar/z3AT6TU9LRoCFrkcJpVsUN+2MVeT52PfMbv5O/Nw9sqsFDlofCJJ/EknY0wDc+tNarYOhDM67/ojn/p6W3ZPBJxb2wcF1TOh9dpAeZdCGJusqhMIj5lpoW8nENTFhkEgMUv2Lh5Z6WpeOAKAu9eDpBMhlRNCccDaNYUgo6TdVDtWxtPrS3NRYqtkvb2I2SEFP0apht954oKdG3ncxyOgHRUkwgtxbCMAngzWo9+VWV3H3OlqeEOv7DdO2o0y95EvlHYb/qtosXPI2jC+6FPa+yl4xmLqcENRTUrU23dsmX3SyBEmZvML4dNeyC53B+mh7DUFtPFJFndxj2tGO9mTSDgy8eCmKG90AiJOMoxaLB2HpnDXN1sTiIcd3WraiE6ZCt4E54hKXvXHPyN52CHkxq1y/TeXHEq4X4MyHyDSRLHmzVs9pnwHM0ZLthKFNyvGfTvjiYokAWtNEuh74syt+m6Wietb6JvgibnnDj6uFKI3BbH4GUT9blsnMgug323bJ6bFvV4iESvz1fNnnUSokWQy5+fWzxPDohULgFzhDCpwov78Bp0E3t6DXSWnrUdNqpLbYKmXO1Hdbn+QH4B90p85UB1V5eSZgxPpUvZbIO4GPScil8K+dkDLdsFa1zypWNmlUN0Ns5H/iuzMuJql2QFYz+SnV1R1T+qywwqCNP9oswcLiAR3XnSacs52vd3PI9+0PZuoF6tVMWlvutsQ34IFZaAwIkdKigZcHumLBt/0KyFASBfN674n8FnHrHOQHU6oCeXkQA9kC8MtkvMb7fOLdzbTsD6SVojzZ64i9mDXxF+iLR9o52OxjIFzwLGRy/ivT/aAnHLZ3AsbnvslDjlQl2ADBFvf7xjmvFu0xlfK58TUpfVEkScFFapWJyKVybB4CRz1wKKz6n/a9581LpCVOWRsJa5p+j0zYcS2PfhmRf3RzwsDHeBjEVlIARbhxNKvmjdZyIidSdMMcsJHDRLE3bvo9kKfag0vRVKmuPLPc9FrACsz3vlkApcVQvzieHWoiP+foEvfj9+7Ti2tLfKdzVkMUmugZiZ46+7PKvIciiiuBPlyld0CCPTtTFHUOMO5dUfrUblX8K3awWiaNQFBS0J3iK08t1bgWfLhsKzsS32fRWugaqecwO9Rji9oHn+UuN8Nz9SgNxodroq9q7y/KHFxbqjCl62g25HN9zUa/s5wnIRwVAiWgTuOe3qGqjwp5m/GR8YVSSK/8mV9EL4AaF8d1uifdVA6wWSH1e/1UB8vcdU83P8ne3u1ho+Y/57WB7KnQaGaiD/108+wiAxNqMb2ex8on01VxdLKV1makXV3gzsvWaRevW8t/K11ZwYfo9g+guWADsA0JO0jWooiaupq1kNWrEheBdSRXBO7Jnb+56cTjPGwLpp7ZOHe/bSCJ4MGzPF3lK66LXhVo+rxvNjhoKVRjhGYxN4T8+AiRo3r+1KwdIGSrtODp3ri3JWAy6Eajp1Ukp9GaCbHSJFnYml84nKew7zLLe//ExQpjd4QAjMTvnbm+Ff6a1jf69QEVo0I33gI7/buwqgjiuvjeL6EYaMolKrKlHZHf/HwWbFbdID8T9aoyZJuCUd6YHaMPRAS6n5nvTwkRLlJ/f6wgyypUGZ22Bb1qGIb9SoPgSgIJkifUoewQW2EexqfoAsHXJVABLy+jp/SC4xzHZOSh42zU1k80HIgrnSOmu6T56F6gqy4Y2cZuZU8LXbO/01u8ifEz8yaXfEFSFdxE0TWl92OLKFtJZr9nNOBQQQr5FDGf6zB1/0CziG/5+PrUDgG3irzho6+7wXkc2CpxlBKOLWdjs3V/Lab6cURz1QZY4HYgUkJtm4U5OKUeO2+murlhC7SrnwyUtGrsD8NbCmI4SRHKPoeLBJQO/m3dRze5Ltr8N9IS7/ukPeOYe1O2agrmhH/JjYfz/l8Gmq8PGY+oavYp8I+2yKvGLD9kCxEgKcTeRh9AW/xPTLGsacrGKQCY+M76DfyLKxCZDiDY9xkBIKchxsMsn7FqZvRMMyJBHbqa3AKQyAN73NCSuFF5f1qDjARU/xqJFhOaKoR64c78oqh1GqOqEFbfNQIRw6WeFCGyW6v6p10uLdR7KXnR7+wub9aG992MpIBk0+gru74yO/WcA0vLdDEQIBwc+M0lmLB53ylsPtde3nliaC5ROHR1IS4LO8Q+3o0BHMr0my0bqFwwCAvZVXOFBHxXyUgrrmUTnZYVSQXNV6+MALBmmRU5yOzhhyHoEdj9YHZeyPpZkYc6DkJWCRYbFfmczNIs133KB9rlfug40w/hHa8pXyRyLaKQUMIUYEvt3Y4AQ==
+1
View File
@@ -0,0 +1 @@
ElIKBm9zX3VzYRIHQW1lcmljYRoKREVWX1BVQkxJQyIzaHR0cHM6Ly9vc3VzYWRpc3BhdGNoLnl1YW5zaGVuLmNvbS9xdWVyeV9jdXJfcmVnaW9uElMKB29zX2V1cm8SBkV1cm9wZRoKREVWX1BVQkxJQyI0aHR0cHM6Ly9vc2V1cm9kaXNwYXRjaC55dWFuc2hlbi5jb20vcXVlcnlfY3VyX3JlZ2lvbhJRCgdvc19hc2lhEgRBc2lhGgpERVZfUFVCTElDIjRodHRwczovL29zYXNpYWRpc3BhdGNoLnl1YW5zaGVuLmNvbS9xdWVyeV9jdXJfcmVnaW9uElUKBm9zX2NodBIKVFcsIEhLLCBNTxoKREVWX1BVQkxJQyIzaHR0cHM6Ly9vc2NodGRpc3BhdGNoLnl1YW5zaGVuLmNvbS9xdWVyeV9jdXJfcmVnaW9uKpwQRWMyYhAAAABbrAvbhfIRHfaSCN24qQyVAAgAAMs68ZiMdPfEj41O2wBCYqGiC/WdovvJvaw4t3/m1zIYDrt3/ftK9GKFb7C+2E8FmaHqOnwjJYBg2wI1sXpGmuSxkeWw8Avr36wlNtQjhXNV9zoNKstuZYuheyLlpbPRbYZ3UA6/BzTVsjIhjR1lcqFrigQnpV6MgRR9KqxakCaffK6qIzMlodx4ZPKlqseQhCiyVAvLWQSRqCRcZipzotXsmgLQbpDFtRzhgukXPjfW5dAlzMwswPuu7ZQsf1AKipI34dVQLu6gtXthGgbjn89h/79VR5AokLCPGqIV7/2s+gHfykrjDtyp5rwCcmGQqwV3gHy5LGrHl8Zm12jNd7Qcng51ydqtX4xzet6J2iMF6Dw5nPd/hTyxn+i3Ttk6fop9rbCq3iNgEw3+0cSDal1I1ThYdVnMgPhZgQkZc5/SpTaR+8vfDzRIKbSSrrPSEgLnQvWZOOugXhNdyuiaBc8rJveno7vvktmnhDUF3xWi6osj75j2KghRrdHfDR3Zuh4COrGZDRBSKHft2AvfrxaMT9O8hPzzzYk0U2iicVCDlNP/8wqaT9Vqt1kHmruLxqh377iyp0mxKfNt0+SNRzLyRoyvOar/z3AT6TU9LRoCFrkcJpVsUN+2MVeT52PfMbv5O/Nw9sqsFDlofCJJ/EknY0wDc+tNarYOhDM67/ojn/p6W3ZPBJxb2wcF1TOh9dpAeZdCGJusqhMIj5lpoW8nENTFhkEgMUv2Lh5Z6WpeOAKAu9eDpBMhlRNCccDaNYUgo6TdVDtWxtPrS3NRYqtkvb2I2SEFP0apht954oKdG3ncxyOgHRUkwgtxbCMAngzWo9+VWV3H3OlqeEOv7DdO2o0y95EvlHYb/qtosXPI2jC+6FPa+yl4xmLqcENRTUrU23dsmX3SyBEmZvML4dNeyC53B+mh7DUFtPFJFndxj2tGO9mTSDgy8eCmKG90AiJOMoxaLB2HpnDXN1sTiIcd3WraiE6ZCt4E54hKXvXHPyN52CHkxq1y/TeXHEq4X4MyHyDSRLHmzVs9pnwHM0ZLthKFNyvGfTvjiYokAWtNEuh74syt+m6Wietb6JvgibnnDj6uFKI3BbH4GUT9blsnMgug323bJ6bFvV4iESvz1fNnnUSokWQy5+fWzxPDohULgFzhDCpwov78Bp0E3t6DXSWnrUdNqpLbYKmXO1Hdbn+QH4B90p85UB1V5eSZgxPpUvZbIO4GPScil8K+dkDLdsFa1zypWNmlUN0Ns5H/iuzMuJql2QFYz+SnV1R1T+qywwqCNP9oswcLiAR3XnSacs52vd3PI9+0PZuoF6tVMWlvutsQ34IFZaAwIkdKigZcHumLBt/0KyFASBfN674n8FnHrHOQHU6oCeXkQA9kC8MtkvMb7fOLdzbTsD6SVojzZ64i9mDXxF+iLR9o52OxjIFzwLGRy/ivT/aAnHLZ3AsbnvslDjlQl2ADBFvf7xjmvFu0xlfK58TUpfVEkScFFapWJyKVybB4CRz1wKKz6n/a9581LpCVOWRsJa5p+j0zYcS2PfhmRf3RzwsDHeBjEVlIARbhxNKvmjdZyIidSdMMcsJHDRLE3bvo9kKfag0vRVKmuPLPc9FrACsz3vlkApcVQvzieHWoiP+foEvfj9+7Ti2tLfKdzVkMUmugZiZ46+7PKvIciiiuBPlyld0CCPTtTFHUOMO5dUfrUblX8K3awWiaNQFBS0J3iK08t1bgWfLhsKzsS32fRWugaqecwO9Rji9oHn+UuN8Nz9SgNxodroq9q7y/KHFxbqjCl62g25HN9zUa/s5wnIRwVAiWgTuOe3qGqjwp5m/GR8YVSSK/8mV9EL4AaF8d1uifdVA6wWSH1e/1UB8vcdU83P8ne3u1ho+Y/57WB7KnQaGaiD/108+wiAxNqMb2ex8on01VxdLKV1makXV3gzsvWaRevW8t/K11ZwYfo9g+guWADsA0JO0jWooiaupq1kNWrEheBdSRXBO7Jnb+56cTjPGwLpp7ZOHe/bSCJ4MGzPF3lK66LXhVo+rxvNjhoKVRjhGYxN4T8+AiRo3r+1KwdIGSrtODp3ri3JWAy6Eajp1Ukp9GaCbHSJFnYml84nKew7zLLe//ExQpjd4QAjMTvnbm+Ff6a1jf69QEVo0I33gI7/buwqgjiuvjeL6EYaMolKrKlHZHf/HwWbFbdID8T9aoyZJuCUd6YHaMPRAS6n5nvTwkRLlJ/f6wgyypUGZ22Bb1qGIb9SoPgSgIJkifUoewQW2EexqfoAsHXJVABLy+jp/SC4xzHZOSh42zU1k80HIgrnSOmu6T56F6gqy4Y2cZuZU8LXbO/01u8ifEz8yaXfEFSFdxE0TWl92OLKFtJZr9nNOBQQQr5FDGf6zB1/0CziG/5+PrUDgG3irzho6+7wXkc2CpxlBKOLWdjs3V/Lab6cURz1QZY4HYgUkJtm4U5OKUeO2+murlhC7SrnwyUtGrsD8NbCmI4SRHKPoeLBJQO/m3dRze5Ltr8N9IS7/ukPeOYe1O2agrmhH/JjYfz/l8Gmq8PGY+oavYp8I+2yKvGLD9kCxEgKcTeRh9AW/xPTLGsacrGKQCY+M76DfyLKxCZDiDY9xkBIKchxsMsn7FqZvRMMyJBHbqa3AKQyAN73NCSuFF5f1qDjARU/xqJFhOaKoR64c78oqh1GqOqEFbfNQIRw6WeFCGyW6v6p10uLdR7KXnR7+wub9aG992MpIBk0+gru74yO/WcA0vLdDEQIBwc+M0lmLB53ylsPtde3nliaC5ROHR1IS4LO8Q+3o0BHMr0my0bqFwwCAvZVXOFBHxXyUgrrmUTnZYVSQXNV6+MALBmmRU5yOzhhyHoEdj9YHZeyPpZkYc6DkJWCRYbFfmczNIs133KB9rlfug40w/hHa8pXyRyLaKQUMIUYEvt3Y4AQ==
+27
View File
@@ -0,0 +1,27 @@
-----BEGIN RSA PRIVATE KEY-----
MIIEpAIBAAKCAQEAz/fyfozlDIDWG9e3Lb29+7j3c66wvUJBaBWP10rB9HTE6prj
fcGMqC9imr6zAdD9q+Gr1j7egvqgi3Da+VBAMFH92/5wD5PsD7dX8Z2f4o65Vk2n
VOY8Dl75Z/uRhg0Euwnfrved69z9LG6utmlyv6YUPAflXh/JFw7Dq6c4EGeR+Kej
FTwmVhEdzPGHjXhFmsVt9HdXRYSf4NxHPzOwj8tiSaOQA0jC4E4mM7rvGSH5GX6h
ma+7pJnl/5+rEVM0mSQvm0m1XefmuFy040bEZ/6O7ZenOGBsvvwuG3TT4FNDNzW8
Dw9ExH1l6NoRGaVkDdtrl/nFu5+a09Pm/E0ElwIDAQABAoIBAQCtH17Cck+KJQYX
j29xqG4qykNUDawbILiKCMkBE7553Wq/UcjmuuR4bVnML8ucS3mgR/BgHV3l8vUK
nxvqRx/oGZkWNazbiuwL+ThAblLWqrEmYuZVCoQcAnvkT8tIqDWz7fhDEuZnnkMz
ZcATIZzgZUSa5IfP3u3rP+MrVbyaCdzJEeI0Yrv1XT+M5ddkKQrYgqC5kRiYi/Lj
NcLJhqSVt8p37CdJx1PGHFjKKb4MZpANlNRgeTtWpGVfS0PJLzaiI1NyPSJv7xWZ
gVhbK9+wQxqSG6KmZ4vpEvRI1zKiov5BsAFN+GfuD5mpn1Xo9CpzTfj/sO13VpHH
+Mt80+yBAoGBAPYXVEcXug5zqkqXup4dp1S05saz1zWPhUhQm+CrbhgeTqpjngJJ
EB79qMrGmyki0P/cGtbTcrHf8+i7gDlIGW0OMb4/jn4f5ACVD00iyvkHSGPn0Aim
MoNOMbkGot7SkSnncwxXdawwDyTu2dofXuBr72+GYqgRAG52IuA0C0pRAoGBANhX
p/UyW/htB27frKch/rTKQKm12kBV20AkkRUQUibiiQyWueWKs+5bVaW5R5oDIhWx
qftJtnEFWUvWaTHpHsB/bpjS3CJ6WknqNbpa3QIScpV1uw8V+Etz/K2/ftjyZzFo
nqc+Jud5364xFdIlOsRj9gZnK83Wcui6EFxAer5nAoGBAJzTzzSjLUHqejqhKR98
nFeCFZPJpjuO5AxqunvaJAYgwlcZtueT8j8dvgTDvrvfYTu85CnFhNFQfFrzqspW
ZUW3hwHL9R3xatboJ2Er7Bf5iSuJ3my0pXpCSbO1Q/QmUrZWtl3GGsqJsg0CXjkA
RvFUN7ll9ddPRmwewykIYa2RAoGAcmKuWFNfE1O6WWIELH4p6LcDR3fyRI/gk+KB
nyx48zxVkAVllrsmdYFvIGd9Ny4u6F9+a3HG960HULS1/ACxFMCL3lumrsgYUvp1
m+mM7xqH4QRVeh14oZRa5hbY36YS76nMMMsI0Ny8aqJjUjADCXF81FfabkPTj79J
BS3GeEMCgYAXmFIT079QokHjJrWz/UaoEUbrNkXB/8vKiA4ZGSzMtl3jUPQdXrVf
e0ofeKiqCQs4f4S0dYEjCv7/OIijV5L24mj/Z1Q4q++S5OksKLPPAd3gX4AYbRcg
PS4rUKl1oDk/eYN0CNYC/DYV9sAv65lX8b35HYhvXISVYtwwQu/+Yg==
-----END RSA PRIVATE KEY-----
+27
View File
@@ -0,0 +1,27 @@
-----BEGIN RSA PRIVATE KEY-----
MIIEpAIBAAKCAQEA02M1I1V/YvxANOvLFX8R7D8At40IlT7HDWpAW3t+tAgQ7sqj
CeYOxiXqOaaw2kJhM3HT5nZll48UmykVq45Q05J57nhdSsGXLJshtLcTg9liMEoW
61BjVZi9EPPRSnE05tBJc57iqZw+aEcaSU0awfzBc8IkRd6+pJ5iIgEVfuTluani
zhHWvRli3EkAF4VNhaTfP3EkYfr4NE899aUeScbbdLFI6u1XQudlJCPTxaISx5Zc
wM+nP3v242ABcjgUcfCbz0AY547WazK4bWP3qicyxo4MoLOoe9WBq6EuG4CuZQrz
Knq8ltSxud/6chdg8Mqp/IasEQ2TpvY78tEXDQIDAQABAoIBAQC4uPsYk4AsSe75
0Au6Dz7kSfIgdDhJ44AisvTmfLauMFZLtfxfjBDhCwTxuD7XnCZAxHm97Ty+AqSp
Km/raQQsvtWalMhBqYanzjDYMRv2niJ1vGjm3WrQxBaEF+yOtvrZsK5fQTslqInI
qknIQH7fgjazJ7Z28D18sYNj37qfFWSSymgFo+SoS/BKEr200lpRA/oaGXiHcyIO
jJidP6b7UGes7uhMXUvLrfozmCsSqslxXO5Uk5XN/fWl4LxCGX7mpNfPZIT5YBSj
HliFkNlxIjyJg8ORLGi82M2cuyxp39r93F6uaCjLtb+rdwlGur7npgXUkKfWQJf9
WE7uar6BAoGBAPXIuIuYFFUhqNz5CKU014jZu6Ql0z5ZA08V84cTJcfLIK4e2rqC
8DFTldA0FtVfOGt0V08H/x2pRChGOvUwGG5nn9Dqqh6BjByUrW4z2hnXzT3ZuSDh
6eapiCB1jl9meJ0snhF2Ps/hqWGL2b3SkCCe90qVTzOVOeLO6YUCIOq9AoGBANws
fQkAq/0xw8neRGNTrnXimvbS+VXPIF38widljubNN7DY5cIFTQJrnTBWKbuz/t9a
J8QX6TFL0ci/9vhPJoThfL12vL2kWGYgWkWRPmqaBW3yz7Hs5rt+xuH3/7A5w5vm
kEg1NZJgnsJ0rMUTu1Q6PM5CBg6OpyHY4ThBb8qRAoGAML8ciuMgtTm1yg3CPzHZ
xZSZeJbf7K+uzlKmOBX+GkAZPS91ZiRuCvpu7hpGpQ77m6Q5ZL1LRdC6adpz+wkM
72ix87d3AhHjfg+mzgKOsS1x0WCLLRBhWZQqIXXvRNCH/3RH7WKsVoKFG4mnJ9TJ
LQ8aMLqoOKzSDD/JZM3lRWkCgYA8hn5Y2zZshCGufMuQApETFxhCgfzI+geLztAQ
xHpkOEX296kxjQN+htbPUuBmGTUXcVE9NtWEF7Oz3BGocRnFrbb83odEGsmySXKH
bUYbR/v2Ham638UOBevmcqZ3a2m6kcdYEkiH1MfP7QMRqjr1DI1qpfvERLLtOxGu
xU5WAQKBgQCaVavyY6Slj3ZRQ7iKk9fHkge/bFl+zhANxRfWVOYMC8mD2gHDsq9C
IdCp1Mg0tJpWLaGgyDM1kgChZYsff4jRxHC4czvAtoPSlxWXF2nL31qJ3vk2Zzzc
a4GSHAInodXBrKstav5SIKosWKT2YysxgHlA9Sm2f4n09GjFbslEHg==
-----END RSA PRIVATE KEY-----
+27
View File
@@ -0,0 +1,27 @@
-----BEGIN RSA PRIVATE KEY-----
MIIEpAIBAAKCAQEAyaxqjPJP5+Innfv5IdfQqY/ftS++lnDRe3EczNkIjESWXhHS
OljEw9b9C+/BtF+fO9QZL7Z742y06eIdvsMPQKdGflB26+9OZ8AF4SpXDn3aVWGr
8+9qpB7BELRZI/Ph2FlFL4cobCzMHunncW8zTfMId48+fgHkAzCjRl5rC6XT0Yge
6+eKpXmF+hr0vGYWiTzqPzTABl44WZo3rw0yurZTzkrmRE4kR2VzkjY/rBnQAbFK
KFUKsUozjCXvSag4l461wDkhmmyivpNkK5cAxuDbsmC39iqagMt9438fajLVvYOv
pVs9ci5tiLcbBtfB4Rf/QVAkqtTm86Z0O3e7DwIDAQABAoIBAQCyma226vTW35LE
N5zXWuAg+hhcxk6bvofWMUMXKvGF/0vHPTMXlvuSkDeDNa4vBivneRthBNPMgb3q
DuTWxrogQMOOI8ZdhY3DFexfDvcQD2anDJuSqSmg9Nd36q+yxk3xIoXB5Ilo23dd
vTnJXHhsBNovv7zRLO134cAHFqDoKzt5EEHre0skUcn6HjHOek6c53jvpKr5LSrr
iwx5gMuY/7ZSIUDo9WGY70qbQFGY6bOlX9x8uNjcFF+7SztEVQ+vhJ/+7EvwqaJr
ysweo0l91TKM9WaMuwoucKeceVWuynEw6GGTw8UTLtltekLGe6bS8YxY8fVwnKkT
RwJYwAJRAoGBAP2rhcfOA+1Ja37hUHKebfp9rHsex4+pGyt3Kdu7WdqOn4sexmya
BuiHQcUchPDVla/ruQZ20+8LHgzBDo0m8sY7gpf715UV9NSVIRD0wu26SKRklOFz
J4HBOwU9hBGLSnRUJzyvVlt5O7E9hAv61SCrvWBEcow2YnKNQLwvjMVJAoGBAMuG
oSb3A/ulqtp2zpxVAclYe/bSItZZTOUWP6Vb4hOiHxIJ0n1H9ap6grOYkJ/Yn4gg
yYzKm/noF1wXP7Rj/xOahnvMkzhGdmOabvE9LH5HwQTWxBBWTkZzgBbYtbg+J5MT
cKqJaychSRjJj+xX+d90rtlSu/c27chlSRKAHXWXAoGAFTcIHDq9l1XBmL3tRXi8
h+uExlM/q2MgM5VmucrEbAPrke4D+Ec1drMBLCQDdkTWnPzg34qGlQJgA/8NYX61
ZSDK/j0AvaY1cKX8OvfNaaZftuf2j5ha4H4xmnGXnwQAORRkp62eUk4kUOFtLrdO
pcnXL7rpvZI6z4vCszpi0okCgYEAp3lZEl8g/+oK9UneKfYpSi1tlGTGFevVwozU
QpWhKta1CnraogycsnOtKWvZVi9C1xljwF7YioPY9QaMfTvroY3+K9DjM+OHd96U
fB4Chsc0pW60V10te/t+403f+oPqvLO6ehop+kEBjUwPCkQ6cQ3q8xmJYpvofoYZ
4wdZNnECgYBwG8Vrv7Z+kX9Zuh1FvcRoY57bYLU0cWW92SA3Nvi8pZOIEaLHrQyZ
pvvaLIicR1m9+KsOAmii7ru0zL7KsrGW+5migQsaDi4gzahKQpad/R7MLKi/L53r
Ymo0aZKARLHW82GbomQ0zxdRoo9vaqfGNpXkxyyt3k3GGDunmrskYw==
-----END RSA PRIVATE KEY-----
+27
View File
@@ -0,0 +1,27 @@
-----BEGIN RSA PRIVATE KEY-----
MIIEpAIBAAKCAQEAsJbFp3WcsiojjdQtVnTuvtawL2m4XxK93F6lCnFwcZqUP39t
xFGGlrogHMqreyawIUN7E5shtwGzigzjW8Ly5CryBJpXP3ehNTqJS7emb+9LlC19
Oxa1eQuUQnatgcsd16DPH7kJ5JzN3vXnhvUyk4Qficdmm0uk7FRaNYFi7EJs4xyq
FTrp3rDZ0dzBHumlIeK1om7FNt6Nyivgp+UybO7kl0NLFEeSlV4S+7ofitWQsO5x
YqKAzSzz+KIRQcxJidGBlZ1JN/g5DPDpx/ztvOWYUlM7TYk6xN3focZpU0kBzAw/
rn94yW9z8jpXfzk+MvWzVL/HAcPy4ySwkay0NwIDAQABAoIBADzKWpawbVYEHaM4
lLb7oCjAPXzE9zx7djLDvisfLCdfoINPedkoe52ty1o+BtRpWB7LXTY9pFic1FLE
5wvyy6zyf8hH3ZsysqNhWFxhh4FnLmx/UGokAir+anaK5mYVJ1vQtxzjlV1HAbQs
kRyrklKoHDdRFqiFXOwiib97oDNWhD+RxfyGwwJnynZZSXdLbLSiz/QHQNr/+Ufk
KRBaxt0CfU7mOLZxoy6fNAxHdBcBJPHCyh+aDvEbix7nSncSU8Ju/48YJ8DrglbZ
sXCYoA5Uz8NMDuaEMgoNWCFQVoEcRkEUoaH7BlWd3UUFRPnDZ1B4BmkrVoRE8a58
3OqSwakCgYEA19wQUISXtpnmCrEZfbyZ6IwOy8ZCVaVUtbTjVa8UyfNglzzJG3yz
cXU3X35v5/HNCHaXbG2qcbQLThnHBA+obW3RDo+Q49V84Zh1fUNH0ONHHuC09kB/
/gHqzn/4nLf1aJ2O0NrMyrZNsZ0ZKUKQuVCqWjBOmTNUitcc8RpXZ8sCgYEA0W09
POM/It7RoVGI+cfbbgSRmzFo9kzSp5lP7iZ81bnvUMabu2nv3OeGc3Pmdh1ZJFRw
6iDM6VVbG0uz8g+f8+JT32XdqM7MJAmgfcYfTVBMiVnh330WNkeRrGWqQzB2f2Wr
+0vJjU8CAAcOWDh0oNguJ1l1TSyKxqdL8FsA38UCgYEAudt1AJ7psgOYmqQZ+rUl
H6FYLAQsoWmVIk75XpE9KRUwmYdw8QXRy2LNpp9K4z7C9wKFJorWMsh+42Q2gzyo
HHBtjEf4zPLIb8XBg3UmpKjMV73Kkiy/B4nHDr4I5YdO+iCPEy0RH4kQJFnLjEcQ
LT9TLgxh4G7d4B2PgdjYYTkCgYEArdgiV2LETCvulBzcuYufqOn9/He9i4cl7p4j
bathQQFBmSnkqGQ+Cn/eagQxsKaYEsJNoOxtbNu/7x6eVzeFLawYt38Vy0UuzFN5
eC54WXNotTN5fk2VnKU4VYVnGrMmCobZhpbYzoZhQKiazby/g60wUtW9u7xXzqOd
M/428YkCgYBwbEOx1RboH8H+fP1CAiF+cqtq4Jrz9IRWPOgcDpt2Usk1rDweWrZx
bTRlwIaVc5csIEE2X02fut/TTXr1MoXHa6s2cQrnZYq44488NsO4TAC26hqs/x/H
bVOcX13gT26SYngAHHeh7xjWJr/KgIIwvcvgvoVs6lu7a8aLUvrOag==
-----END RSA PRIVATE KEY-----
+27
View File
@@ -0,0 +1,27 @@
-----BEGIN RSA PRIVATE KEY-----
MIIEpQIBAAKCAQEAxbbx2m1feHyrQ7jP+8mtDF/pyYLrJWKWAdEv3wZrOtjOZzeL
GPzsmkcgncgoRhX4dT+1itSMR9j9m0/OwsH2UoF6U32LxCOQWQD1AMgIZjAkJeJv
FTrtn8fMQ1701CkbaLTVIjRMlTw8kNXvNA/A9UatoiDmi4TFG6mrxTKZpIcTInvP
EpkK2A7Qsp1E4skFK8jmysy7uRhMaYHtPTsBvxP0zn3lhKB3W+HTqpneewXWHjCD
fL7Nbby91jbz5EKPZXWLuhXIvR1Cu4tiruorwXJxmXaP1HQZonytECNU/UOzP6GN
Ldq0eFDE4b04Wjp396551G99YiFP2nqHVJ5OMQIDAQABAoIBAQDEeYZhjyq+avUu
eSuFhOaIU4/ZhlXycsOqzpwJvzEz61tBSvrZPA5LSb9pzAvpic+7hDH94jX89+8d
NfO7qlADsVNEQJBxuv2o1MCjpCRkmBZz506IBGU60Kt1j5kwdCEergTW1q375z4w
l8f7LmSL2U6WvKcdojTVxohBkIUJ7shtmmukDi2YnMfe6T/2JuXDDL8rvIcnfr5E
MCgPQs+xLeLEGrIJdpUy1iIYZYrzvrpJwf9EJL3D0e7jkpbvAQZ8EF9YhEizJhOm
dzTqW4PgW2yUaHYd3q5QjiILy7AC+oOYoTZln3RfjPOxl+bYjeMOWlqkgtpPQkAE
4I64w8RZAoGBAPLR44pEkmTdfIIF8ZtzBiVfDZ29bT96J0CWXGVzp8x6bSu5J5jl
s7sP8DEcjGZ6vHsLGOvkcNxzcnR3l/5HOz6TIuvVuUm36b1jHltq1xZStjGeKZs1
ihhJSu2lIA+TrK8FCRnKARJ0ughXGNZFItgeM230Sgjp2RL4ISXJ724XAoGBANBy
S2RwNpUYvkCSZHSFnQM/jq1jldxw+0p4jAGpWLilEaA/8xWUnZrnCrPFF/t9llpb
dTR/dCI8ntIMAy2dH4IUHyYKUahyHSzCAUNKpS0s433kn5hy9tGvn7jyuOJ4dk9F
o1PIZM7qfzmkdCBbX3NF2TGpzOvbYGJHHC3ssVr3AoGBANHJDopN9iDYzpJTaktA
VEYDWnM2zmUyNylw/sDT7FwYRaup2xEZG2/5NC5qGM8NKTww+UYMZom/4FnJXXLd
vcyxOFGCpAORtoreUMLwioWJzkkN+apT1kxnPioVKJ7smhvYAOXcBZMZcAR2o0m0
D4eiiBJuJWyQBPCDmbfZQFffAoGBAKpcr4ewOrwS0/O8cgPV7CTqfjbyDFp1sLwF
2A/Hk66dotFBUvBRXZpruJCCxn4R/59r3lgAzy7oMrnjfXl7UHQk8+xIRMMSOQwK
p7OSv3szk96hy1pyo41vJ3CmWDsoTzGs7bcdMl72wvKemRaU92ckMEZpzAT8cEMC
cWKLb8yzAoGAMibG8IyHSo7CJz82+7UHm98jNOlg6s73CEjp0W/+FL45Ka7MF/lp
xtR3eSmxltvwvjQoti3V4Qboqtc2IPCt+EtapTM7Wo41wlLCWCNx4u25pZPH/c8g
1yQ+OvH+xOYG+SeO98Phw/8d3IRfR83aqisQHv5upo2Rozzo0Kh3OsE=
-----END RSA PRIVATE KEY-----
Binary file not shown.
+1
View File
@@ -0,0 +1 @@
lt1L ܟ.\pXP"ƀ(a
+6
View File
@@ -0,0 +1,6 @@
package gm
type KickPlayerInfo struct {
UserId uint32
Reason uint32
}
+11
View File
@@ -0,0 +1,11 @@
package gm
type OnlineUserList struct {
UserList []*OnlineUserInfo `json:"userList"`
}
type OnlineUserInfo struct {
Uid uint32 `json:"uid"`
ConvId uint64 `json:"convId"`
Addr string `json:"addr"`
}
+547
View File
@@ -0,0 +1,547 @@
package forward
import (
"hk4e/common/config"
"hk4e/common/region"
"hk4e/gate/entity/gm"
"hk4e/gate/kcp"
"hk4e/gate/net"
"hk4e/logger"
"hk4e/protocol/cmd"
"hk4e/protocol/proto"
"os"
"runtime"
"sync"
"time"
)
const (
ConnWaitToken = iota
ConnWaitLogin
ConnAlive
ConnClose
)
type ClientHeadMeta struct {
seq uint32
}
type ForwardManager struct {
dao string
protoMsgInput chan *net.ProtoMsg
protoMsgOutput chan *net.ProtoMsg
netMsgInput chan *cmd.NetMsg
netMsgOutput chan *cmd.NetMsg
// 玩家登录相关
connStateMap map[uint64]uint8
connStateMapLock sync.RWMutex
// kcpConv -> userID
convUserIdMap map[uint64]uint32
convUserIdMapLock sync.RWMutex
// userID -> kcpConv
userIdConvMap map[uint32]uint64
userIdConvMapLock sync.RWMutex
// kcpConv -> ipAddr
convAddrMap map[uint64]string
convAddrMapLock sync.RWMutex
// kcpConv -> headMeta
convHeadMetaMap map[uint64]*ClientHeadMeta
convHeadMetaMapLock sync.RWMutex
secretKeyBuffer []byte
kcpEventInput chan *net.KcpEvent
kcpEventOutput chan *net.KcpEvent
regionCurr *proto.QueryCurrRegionHttpRsp
signRsaKey []byte
encRsaKeyMap map[string][]byte
}
func NewForwardManager(
protoMsgInput chan *net.ProtoMsg, protoMsgOutput chan *net.ProtoMsg,
kcpEventInput chan *net.KcpEvent, kcpEventOutput chan *net.KcpEvent,
netMsgInput chan *cmd.NetMsg, netMsgOutput chan *cmd.NetMsg) (r *ForwardManager) {
r = new(ForwardManager)
r.protoMsgInput = protoMsgInput
r.protoMsgOutput = protoMsgOutput
r.netMsgInput = netMsgInput
r.netMsgOutput = netMsgOutput
r.connStateMap = make(map[uint64]uint8)
r.convUserIdMap = make(map[uint64]uint32)
r.userIdConvMap = make(map[uint32]uint64)
r.convAddrMap = make(map[uint64]string)
r.convHeadMetaMap = make(map[uint64]*ClientHeadMeta)
r.kcpEventInput = kcpEventInput
r.kcpEventOutput = kcpEventOutput
return r
}
func (f *ForwardManager) getHeadMsg(clientSeq uint32) (headMsg *proto.PacketHead) {
headMsg = new(proto.PacketHead)
if clientSeq != 0 {
headMsg.ClientSequenceId = clientSeq
headMsg.SentMs = uint64(time.Now().UnixMilli())
}
return headMsg
}
func (f *ForwardManager) kcpEventHandle() {
for {
event := <-f.kcpEventOutput
logger.LOG.Info("rpc manager recv event, ConvId: %v, EventId: %v", event.ConvId, event.EventId)
switch event.EventId {
case net.KcpPacketSendNotify:
// 发包通知
// 关闭发包监听
f.kcpEventInput <- &net.KcpEvent{
ConvId: event.ConvId,
EventId: net.KcpPacketSendListen,
EventMessage: "Disable",
}
// 登录成功 通知GS初始化相关数据
userId, exist := f.getUserIdByConvId(event.ConvId)
if !exist {
logger.LOG.Error("can not find userId by convId")
continue
}
headMeta, exist := f.getHeadMetaByConvId(event.ConvId)
if !exist {
logger.LOG.Error("can not find client head metadata by convId")
continue
}
netMsg := new(cmd.NetMsg)
netMsg.UserId = userId
netMsg.EventId = cmd.UserLoginNotify
netMsg.ClientSeq = headMeta.seq
f.netMsgInput <- netMsg
logger.LOG.Info("send to gs user login ok, ConvId: %v, UserId: %v", event.ConvId, netMsg.UserId)
case net.KcpConnCloseNotify:
// 连接断开通知
userId, exist := f.getUserIdByConvId(event.ConvId)
if !exist {
logger.LOG.Error("can not find userId by convId")
continue
}
if f.getConnState(event.ConvId) == ConnAlive {
// 通知GS玩家下线
netMsg := new(cmd.NetMsg)
netMsg.UserId = userId
netMsg.EventId = cmd.UserOfflineNotify
f.netMsgInput <- netMsg
logger.LOG.Info("send to gs user offline, ConvId: %v, UserId: %v", event.ConvId, netMsg.UserId)
}
// 删除各种map数据
f.deleteConnState(event.ConvId)
f.deleteUserIdByConvId(event.ConvId)
currConvId, currExist := f.getConvIdByUserId(userId)
if currExist && currConvId == event.ConvId {
// 防止误删顶号的新连接数据
f.deleteConvIdByUserId(userId)
}
f.deleteAddrByConvId(event.ConvId)
f.deleteHeadMetaByConvId(event.ConvId)
case net.KcpConnEstNotify:
// 连接建立通知
addr, ok := event.EventMessage.(string)
if !ok {
logger.LOG.Error("event KcpConnEstNotify msg type error")
continue
}
f.setAddrByConvId(event.ConvId, addr)
case net.KcpConnRttNotify:
// 客户端往返时延通知
rtt, ok := event.EventMessage.(int32)
if !ok {
logger.LOG.Error("event KcpConnRttNotify msg type error")
continue
}
// 通知GS玩家客户端往返时延
userId, exist := f.getUserIdByConvId(event.ConvId)
if !exist {
logger.LOG.Error("can not find userId by convId")
continue
}
netMsg := new(cmd.NetMsg)
netMsg.UserId = userId
netMsg.EventId = cmd.ClientRttNotify
netMsg.ClientRtt = uint32(rtt)
f.netMsgInput <- netMsg
case net.KcpConnAddrChangeNotify:
// 客户端网络地址改变通知
f.convAddrMapLock.Lock()
_, exist := f.convAddrMap[event.ConvId]
if !exist {
f.convAddrMapLock.Unlock()
logger.LOG.Error("conn addr change but conn can not be found")
continue
}
addr := event.EventMessage.(string)
f.convAddrMap[event.ConvId] = addr
f.convAddrMapLock.Unlock()
}
}
}
func (f *ForwardManager) Start() {
// 读取密钥相关文件
var err error = nil
f.secretKeyBuffer, err = os.ReadFile("static/secretKeyBuffer.bin")
if err != nil {
logger.LOG.Error("open secretKeyBuffer.bin error")
return
}
f.signRsaKey, f.encRsaKeyMap, _ = region.LoadRsaKey()
// region
regionCurr, _ := region.InitRegion(config.CONF.Hk4e.KcpAddr, config.CONF.Hk4e.KcpPort)
f.regionCurr = regionCurr
// kcp事件监听
go f.kcpEventHandle()
go f.recvNetMsgFromGameServer()
// 接收客户端消息
cpuCoreNum := runtime.NumCPU()
for i := 0; i < cpuCoreNum*10; i++ {
go f.sendNetMsgToGameServer()
}
}
// 发送消息到GS
func (f *ForwardManager) sendNetMsgToGameServer() {
for {
protoMsg := <-f.protoMsgOutput
if protoMsg.HeadMessage == nil {
logger.LOG.Error("recv null head msg: %v", protoMsg)
}
f.setHeadMetaByConvId(protoMsg.ConvId, &ClientHeadMeta{
seq: protoMsg.HeadMessage.ClientSequenceId,
})
connState := f.getConnState(protoMsg.ConvId)
// gate本地处理的请求
switch protoMsg.CmdId {
case cmd.GetPlayerTokenReq:
// 获取玩家token请求
if connState != ConnWaitToken {
continue
}
getPlayerTokenReq := protoMsg.PayloadMessage.(*proto.GetPlayerTokenReq)
getPlayerTokenRsp := f.getPlayerToken(protoMsg.ConvId, getPlayerTokenReq)
if getPlayerTokenRsp == nil {
continue
}
// 改变解密密钥
f.kcpEventInput <- &net.KcpEvent{
ConvId: protoMsg.ConvId,
EventId: net.KcpXorKeyChange,
EventMessage: "DEC",
}
// 返回数据到客户端
resp := new(net.ProtoMsg)
resp.ConvId = protoMsg.ConvId
resp.CmdId = cmd.GetPlayerTokenRsp
resp.HeadMessage = f.getHeadMsg(protoMsg.HeadMessage.ClientSequenceId)
resp.PayloadMessage = getPlayerTokenRsp
f.protoMsgInput <- resp
case cmd.PlayerLoginReq:
// 玩家登录请求
if connState != ConnWaitLogin {
continue
}
playerLoginReq := protoMsg.PayloadMessage.(*proto.PlayerLoginReq)
playerLoginRsp := f.playerLogin(protoMsg.ConvId, playerLoginReq)
if playerLoginRsp == nil {
continue
}
// 改变加密密钥
f.kcpEventInput <- &net.KcpEvent{
ConvId: protoMsg.ConvId,
EventId: net.KcpXorKeyChange,
EventMessage: "ENC",
}
// 开启发包监听
f.kcpEventInput <- &net.KcpEvent{
ConvId: protoMsg.ConvId,
EventId: net.KcpPacketSendListen,
EventMessage: "Enable",
}
go func() {
// 保证kcp事件已成功生效
time.Sleep(time.Millisecond * 50)
// 返回数据到客户端
resp := new(net.ProtoMsg)
resp.ConvId = protoMsg.ConvId
resp.CmdId = cmd.PlayerLoginRsp
resp.HeadMessage = f.getHeadMsg(protoMsg.HeadMessage.ClientSequenceId)
resp.PayloadMessage = playerLoginRsp
f.protoMsgInput <- resp
}()
case cmd.SetPlayerBornDataReq:
// 玩家注册请求
if connState != ConnAlive {
continue
}
userId, exist := f.getUserIdByConvId(protoMsg.ConvId)
if !exist {
logger.LOG.Error("can not find userId by convId")
continue
}
netMsg := new(cmd.NetMsg)
netMsg.UserId = userId
netMsg.EventId = cmd.UserRegNotify
netMsg.CmdId = cmd.SetPlayerBornDataReq
netMsg.ClientSeq = protoMsg.HeadMessage.ClientSequenceId
netMsg.PayloadMessage = protoMsg.PayloadMessage
f.netMsgInput <- netMsg
case cmd.PlayerForceExitRsp:
// 玩家退出游戏请求
if connState != ConnAlive {
continue
}
userId, exist := f.getUserIdByConvId(protoMsg.ConvId)
if !exist {
logger.LOG.Error("can not find userId by convId")
continue
}
f.setConnState(protoMsg.ConvId, ConnClose)
info := new(gm.KickPlayerInfo)
info.UserId = userId
info.Reason = uint32(kcp.EnetServerKick)
f.KickPlayer(info)
case cmd.PingReq:
// ping请求
if connState != ConnAlive {
continue
}
pingReq := protoMsg.PayloadMessage.(*proto.PingReq)
logger.LOG.Debug("user ping req, data: %v", pingReq.String())
// 返回数据到客户端
// TODO 记录客户端最后一次ping时间做超时下线处理
pingRsp := new(proto.PingRsp)
pingRsp.ClientTime = pingReq.ClientTime
resp := new(net.ProtoMsg)
resp.ConvId = protoMsg.ConvId
resp.CmdId = cmd.PingRsp
resp.HeadMessage = f.getHeadMsg(protoMsg.HeadMessage.ClientSequenceId)
resp.PayloadMessage = pingRsp
f.protoMsgInput <- resp
// 通知GS玩家客户端的本地时钟
userId, exist := f.getUserIdByConvId(protoMsg.ConvId)
if !exist {
logger.LOG.Error("can not find userId by convId")
continue
}
netMsg := new(cmd.NetMsg)
netMsg.UserId = userId
netMsg.EventId = cmd.ClientTimeNotify
netMsg.ClientTime = pingReq.ClientTime
f.netMsgInput <- netMsg
default:
// 转发到GS
// 未登录禁止访问GS
if connState != ConnAlive {
continue
}
netMsg := new(cmd.NetMsg)
userId, exist := f.getUserIdByConvId(protoMsg.ConvId)
if exist {
netMsg.UserId = userId
} else {
logger.LOG.Error("can not find userId by convId")
continue
}
netMsg.EventId = cmd.NormalMsg
netMsg.CmdId = protoMsg.CmdId
netMsg.ClientSeq = protoMsg.HeadMessage.ClientSequenceId
netMsg.PayloadMessage = protoMsg.PayloadMessage
f.netMsgInput <- netMsg
}
}
}
// 从GS接收消息
func (f *ForwardManager) recvNetMsgFromGameServer() {
for {
netMsg := <-f.netMsgOutput
convId, exist := f.getConvIdByUserId(netMsg.UserId)
if !exist {
logger.LOG.Error("can not find convId by userId")
continue
}
if netMsg.EventId == cmd.NormalMsg {
protoMsg := new(net.ProtoMsg)
protoMsg.ConvId = convId
protoMsg.CmdId = netMsg.CmdId
protoMsg.HeadMessage = f.getHeadMsg(netMsg.ClientSeq)
protoMsg.PayloadMessage = netMsg.PayloadMessage
f.protoMsgInput <- protoMsg
continue
} else {
logger.LOG.Error("recv unknown event from game server, event id: %v", netMsg.EventId)
continue
}
}
}
func (f *ForwardManager) getConnState(convId uint64) uint8 {
f.connStateMapLock.RLock()
connState, connStateExist := f.connStateMap[convId]
f.connStateMapLock.RUnlock()
if !connStateExist {
connState = ConnWaitToken
f.connStateMapLock.Lock()
f.connStateMap[convId] = ConnWaitToken
f.connStateMapLock.Unlock()
}
return connState
}
func (f *ForwardManager) setConnState(convId uint64, state uint8) {
f.connStateMapLock.Lock()
f.connStateMap[convId] = state
f.connStateMapLock.Unlock()
}
func (f *ForwardManager) deleteConnState(convId uint64) {
f.connStateMapLock.Lock()
delete(f.connStateMap, convId)
f.connStateMapLock.Unlock()
}
func (f *ForwardManager) getUserIdByConvId(convId uint64) (userId uint32, exist bool) {
f.convUserIdMapLock.RLock()
userId, exist = f.convUserIdMap[convId]
f.convUserIdMapLock.RUnlock()
return userId, exist
}
func (f *ForwardManager) setUserIdByConvId(convId uint64, userId uint32) {
f.convUserIdMapLock.Lock()
f.convUserIdMap[convId] = userId
f.convUserIdMapLock.Unlock()
}
func (f *ForwardManager) deleteUserIdByConvId(convId uint64) {
f.convUserIdMapLock.Lock()
delete(f.convUserIdMap, convId)
f.convUserIdMapLock.Unlock()
}
func (f *ForwardManager) getConvIdByUserId(userId uint32) (convId uint64, exist bool) {
f.userIdConvMapLock.RLock()
convId, exist = f.userIdConvMap[userId]
f.userIdConvMapLock.RUnlock()
return convId, exist
}
func (f *ForwardManager) setConvIdByUserId(userId uint32, convId uint64) {
f.userIdConvMapLock.Lock()
f.userIdConvMap[userId] = convId
f.userIdConvMapLock.Unlock()
}
func (f *ForwardManager) deleteConvIdByUserId(userId uint32) {
f.userIdConvMapLock.Lock()
delete(f.userIdConvMap, userId)
f.userIdConvMapLock.Unlock()
}
func (f *ForwardManager) getAddrByConvId(convId uint64) (addr string, exist bool) {
f.convAddrMapLock.RLock()
addr, exist = f.convAddrMap[convId]
f.convAddrMapLock.RUnlock()
return addr, exist
}
func (f *ForwardManager) setAddrByConvId(convId uint64, addr string) {
f.convAddrMapLock.Lock()
f.convAddrMap[convId] = addr
f.convAddrMapLock.Unlock()
}
func (f *ForwardManager) deleteAddrByConvId(convId uint64) {
f.convAddrMapLock.Lock()
delete(f.convAddrMap, convId)
f.convAddrMapLock.Unlock()
}
func (f *ForwardManager) getHeadMetaByConvId(convId uint64) (headMeta *ClientHeadMeta, exist bool) {
f.convHeadMetaMapLock.RLock()
headMeta, exist = f.convHeadMetaMap[convId]
f.convHeadMetaMapLock.RUnlock()
return headMeta, exist
}
func (f *ForwardManager) setHeadMetaByConvId(convId uint64, headMeta *ClientHeadMeta) {
f.convHeadMetaMapLock.Lock()
f.convHeadMetaMap[convId] = headMeta
f.convHeadMetaMapLock.Unlock()
}
func (f *ForwardManager) deleteHeadMetaByConvId(convId uint64) {
f.convHeadMetaMapLock.Lock()
delete(f.convHeadMetaMap, convId)
f.convHeadMetaMapLock.Unlock()
}
// 改变网关开放状态
func (f *ForwardManager) ChangeGateOpenState(isOpen bool) bool {
f.kcpEventInput <- &net.KcpEvent{
EventId: net.KcpGateOpenState,
EventMessage: isOpen,
}
logger.LOG.Info("change gate open state to: %v", isOpen)
return true
}
// 剔除玩家下线
func (f *ForwardManager) KickPlayer(info *gm.KickPlayerInfo) bool {
if info == nil {
return false
}
convId, exist := f.getConvIdByUserId(info.UserId)
if !exist {
return false
}
f.kcpEventInput <- &net.KcpEvent{
ConvId: convId,
EventId: net.KcpConnForceClose,
EventMessage: info.Reason,
}
return true
}
// 获取网关在线玩家信息
func (f *ForwardManager) GetOnlineUser(uid uint32) (list *gm.OnlineUserList) {
list = &gm.OnlineUserList{
UserList: make([]*gm.OnlineUserInfo, 0),
}
if uid == 0 {
// 获取全部玩家
f.convUserIdMapLock.RLock()
f.convAddrMapLock.RLock()
for convId, userId := range f.convUserIdMap {
addr := f.convAddrMap[convId]
info := &gm.OnlineUserInfo{
Uid: userId,
ConvId: convId,
Addr: addr,
}
list.UserList = append(list.UserList, info)
}
f.convAddrMapLock.RUnlock()
f.convUserIdMapLock.RUnlock()
} else {
// 获取指定uid玩家
convId, exist := f.getConvIdByUserId(uid)
if !exist {
return list
}
addr, exist := f.getAddrByConvId(convId)
if !exist {
return list
}
info := &gm.OnlineUserInfo{
Uid: uid,
ConvId: convId,
Addr: addr,
}
list.UserList = append(list.UserList, info)
}
return list
}
+182
View File
@@ -0,0 +1,182 @@
package forward
import (
"bytes"
"encoding/base64"
"encoding/binary"
"hk4e/common/utils/endec"
"hk4e/gate/kcp"
"hk4e/gate/net"
"hk4e/logger"
"hk4e/protocol/proto"
"strconv"
"strings"
)
func (f *ForwardManager) getPlayerToken(convId uint64, req *proto.GetPlayerTokenReq) (rsp *proto.GetPlayerTokenRsp) {
_ = req.AccountUid
_ = req.AccountToken
tokenValid := true
accountForbid := false
accountForbidEndTime := uint32(0)
accountPlayerID := uint32(100000001)
if !tokenValid {
logger.LOG.Error("token error")
return nil
}
// TODO 请求sdk验证token
// comboToken验证成功
if accountForbid {
// 封号通知
rsp = new(proto.GetPlayerTokenRsp)
rsp.Uid = accountPlayerID
rsp.IsProficientPlayer = true
rsp.Retcode = 21
rsp.Msg = "FORBID_CHEATING_PLUGINS"
//rsp.BlackUidEndTime = 2051193600 // 2035-01-01 00:00:00
rsp.BlackUidEndTime = accountForbidEndTime
rsp.RegPlatform = 3
rsp.CountryCode = "US"
addr, exist := f.getAddrByConvId(convId)
if !exist {
logger.LOG.Error("can not find addr by convId")
return nil
}
split := strings.Split(addr, ":")
rsp.ClientIpStr = split[0]
return rsp
}
oldConvId, oldExist := f.getConvIdByUserId(accountPlayerID)
if oldExist {
// 顶号
f.kcpEventInput <- &net.KcpEvent{
ConvId: oldConvId,
EventId: net.KcpConnForceClose,
EventMessage: uint32(kcp.EnetServerRelogin),
}
}
f.setUserIdByConvId(convId, accountPlayerID)
f.setConvIdByUserId(accountPlayerID, convId)
f.setConnState(convId, ConnWaitLogin)
// 返回响应
rsp = new(proto.GetPlayerTokenRsp)
rsp.Uid = accountPlayerID
// TODO 不同的token
rsp.Token = req.AccountToken
rsp.AccountType = 1
// TODO 要确定一下新注册的号这个值该返回什么
rsp.IsProficientPlayer = true
rsp.SecretKeySeed = 11468049314633205968
rsp.SecurityCmdBuffer = f.secretKeyBuffer
rsp.PlatformType = 3
rsp.ChannelId = 1
rsp.CountryCode = "US"
rsp.ClientVersionRandomKey = "c25-314dd05b0b5f"
rsp.RegPlatform = 3
addr, exist := f.getAddrByConvId(convId)
if !exist {
logger.LOG.Error("can not find addr by convId")
return nil
}
split := strings.Split(addr, ":")
rsp.ClientIpStr = split[0]
if req.GetKeyId() != 0 {
// pre check
logger.LOG.Debug("do hk4e 2.8 rsa logic")
keyId := strconv.Itoa(int(req.GetKeyId()))
encPubPrivKey, exist := f.encRsaKeyMap[keyId]
if !exist {
logger.LOG.Error("can not found key id: %v", keyId)
return
}
pubKey, err := endec.RsaParsePubKeyByPrivKey(encPubPrivKey)
if err != nil {
logger.LOG.Error("parse rsa pub key error: %v", err)
return nil
}
signPrivkey, err := endec.RsaParsePrivKey(f.signRsaKey)
if err != nil {
logger.LOG.Error("parse rsa priv key error: %v", err)
return nil
}
clientSeedBase64 := req.GetClientSeed()
clientSeedEnc, err := base64.StdEncoding.DecodeString(clientSeedBase64)
if err != nil {
logger.LOG.Error("parse client seed base64 error: %v", err)
return nil
}
// create error rsp info
clientSeedEncCopy := make([]byte, len(clientSeedEnc))
copy(clientSeedEncCopy, clientSeedEnc)
endec.Xor(clientSeedEncCopy, []byte{0x9f, 0x26, 0xb2, 0x17, 0x61, 0x5f, 0xc8, 0x00})
rsp.EncryptedSeed = base64.StdEncoding.EncodeToString(clientSeedEncCopy)
rsp.SeedSignature = "bm90aGluZyBoZXJl"
// do
clientSeed, err := endec.RsaDecrypt(clientSeedEnc, signPrivkey)
if err != nil {
logger.LOG.Error("rsa dec error: %v", err)
return rsp
}
clientSeedUint64 := uint64(0)
err = binary.Read(bytes.NewReader(clientSeed), binary.BigEndian, &clientSeedUint64)
if err != nil {
logger.LOG.Error("parse client seed to uint64 error: %v", err)
return rsp
}
seedUint64 := uint64(11468049314633205968) ^ clientSeedUint64
seedBuf := new(bytes.Buffer)
err = binary.Write(seedBuf, binary.BigEndian, seedUint64)
if err != nil {
logger.LOG.Error("conv seed uint64 to bytes error: %v", err)
return rsp
}
seed := seedBuf.Bytes()
seedEnc, err := endec.RsaEncrypt(seed, pubKey)
if err != nil {
logger.LOG.Error("rsa enc error: %v", err)
return rsp
}
seedSign, err := endec.RsaSign(seed, signPrivkey)
if err != nil {
logger.LOG.Error("rsa sign error: %v", err)
return rsp
}
rsp.EncryptedSeed = base64.StdEncoding.EncodeToString(seedEnc)
rsp.SeedSignature = base64.StdEncoding.EncodeToString(seedSign)
}
return rsp
}
func (f *ForwardManager) playerLogin(convId uint64, req *proto.PlayerLoginReq) (rsp *proto.PlayerLoginRsp) {
userId, exist := f.getUserIdByConvId(convId)
if !exist {
logger.LOG.Error("can not find userId by convId")
return nil
}
_ = userId
_ = req.Token
tokenValid := true
if !tokenValid {
logger.LOG.Error("token error")
return nil
}
// TODO 请求sdk验证token
// comboToken验证成功
f.setConnState(convId, ConnAlive)
// 返回响应
rsp = new(proto.PlayerLoginRsp)
rsp.IsUseAbilityHash = true
rsp.AbilityHashCode = 1844674
rsp.GameBiz = "hk4e_global"
rsp.ClientDataVersion = f.regionCurr.RegionInfo.ClientDataVersion
rsp.ClientSilenceDataVersion = f.regionCurr.RegionInfo.ClientSilenceDataVersion
rsp.ClientMd5 = f.regionCurr.RegionInfo.ClientDataMd5
rsp.ClientSilenceMd5 = f.regionCurr.RegionInfo.ClientSilenceDataMd5
rsp.ResVersionConfig = f.regionCurr.RegionInfo.ResVersionConfig
rsp.ClientVersionSuffix = f.regionCurr.RegionInfo.ClientVersionSuffix
rsp.ClientSilenceVersionSuffix = f.regionCurr.RegionInfo.ClientSilenceVersionSuffix
rsp.IsScOpen = false
rsp.RegisterCps = "mihoyo"
rsp.CountryCode = "US"
return rsp
}
+64
View File
@@ -0,0 +1,64 @@
package kcp
const maxAutoTuneSamples = 258
// pulse represents a 0/1 signal with time sequence
type pulse struct {
bit bool // 0 or 1
seq uint32 // sequence of the signal
}
// autoTune object
type autoTune struct {
pulses [maxAutoTuneSamples]pulse
}
// Sample adds a signal sample to the pulse buffer
func (tune *autoTune) Sample(bit bool, seq uint32) {
tune.pulses[seq%maxAutoTuneSamples] = pulse{bit, seq}
}
// Find a period for a given signal
// returns -1 if not found
//
// --- ------
// | |
// |______________|
// Period
// Falling Edge Rising Edge
func (tune *autoTune) FindPeriod(bit bool) int {
// last pulse and initial index setup
lastPulse := tune.pulses[0]
idx := 1
// left edge
var leftEdge int
for ; idx < len(tune.pulses); idx++ {
if lastPulse.bit != bit && tune.pulses[idx].bit == bit { // edge found
if lastPulse.seq+1 == tune.pulses[idx].seq { // ensure edge continuity
leftEdge = idx
break
}
}
lastPulse = tune.pulses[idx]
}
// right edge
var rightEdge int
lastPulse = tune.pulses[leftEdge]
idx = leftEdge + 1
for ; idx < len(tune.pulses); idx++ {
if lastPulse.seq+1 == tune.pulses[idx].seq { // ensure pulses in this level monotonic
if lastPulse.bit == bit && tune.pulses[idx].bit != bit { // edge found
rightEdge = idx
break
}
} else {
return -1
}
lastPulse = tune.pulses[idx]
}
return rightEdge - leftEdge
}
+47
View File
@@ -0,0 +1,47 @@
package kcp
import (
"testing"
"github.com/stretchr/testify/assert"
)
func TestAutoTune(t *testing.T) {
signals := []uint32{0, 0, 0, 0, 0, 0}
tune := autoTune{}
for i := 0; i < len(signals); i++ {
if signals[i] == 0 {
tune.Sample(false, uint32(i))
} else {
tune.Sample(true, uint32(i))
}
}
assert.Equal(t, -1, tune.FindPeriod(false))
assert.Equal(t, -1, tune.FindPeriod(true))
signals = []uint32{1, 0, 1, 0, 0, 1}
tune = autoTune{}
for i := 0; i < len(signals); i++ {
if signals[i] == 0 {
tune.Sample(false, uint32(i))
} else {
tune.Sample(true, uint32(i))
}
}
assert.Equal(t, 1, tune.FindPeriod(false))
assert.Equal(t, 1, tune.FindPeriod(true))
signals = []uint32{1, 0, 0, 0, 0, 1}
tune = autoTune{}
for i := 0; i < len(signals); i++ {
if signals[i] == 0 {
tune.Sample(false, uint32(i))
} else {
tune.Sample(true, uint32(i))
}
}
assert.Equal(t, -1, tune.FindPeriod(true))
assert.Equal(t, 4, tune.FindPeriod(false))
}
+12
View File
@@ -0,0 +1,12 @@
package kcp
import "golang.org/x/net/ipv4"
const (
batchSize = 16
)
type batchConn interface {
WriteBatch(ms []ipv4.Message, flags int) (int, error)
ReadBatch(ms []ipv4.Message, flags int) (int, error)
}
+618
View File
@@ -0,0 +1,618 @@
package kcp
import (
"crypto/aes"
"crypto/cipher"
"crypto/des"
"crypto/sha1"
"unsafe"
xor "github.com/templexxx/xorsimd"
"github.com/tjfoc/gmsm/sm4"
"golang.org/x/crypto/blowfish"
"golang.org/x/crypto/cast5"
"golang.org/x/crypto/pbkdf2"
"golang.org/x/crypto/salsa20"
"golang.org/x/crypto/tea"
"golang.org/x/crypto/twofish"
"golang.org/x/crypto/xtea"
)
var (
initialVector = []byte{167, 115, 79, 156, 18, 172, 27, 1, 164, 21, 242, 193, 252, 120, 230, 107}
saltxor = `sH3CIVoF#rWLtJo6`
)
// BlockCrypt defines encryption/decryption methods for a given byte slice.
// Notes on implementing: the data to be encrypted contains a builtin
// nonce at the first 16 bytes
type BlockCrypt interface {
// Encrypt encrypts the whole block in src into dst.
// Dst and src may point at the same memory.
Encrypt(dst, src []byte)
// Decrypt decrypts the whole block in src into dst.
// Dst and src may point at the same memory.
Decrypt(dst, src []byte)
}
type salsa20BlockCrypt struct {
key [32]byte
}
// NewSalsa20BlockCrypt https://en.wikipedia.org/wiki/Salsa20
func NewSalsa20BlockCrypt(key []byte) (BlockCrypt, error) {
c := new(salsa20BlockCrypt)
copy(c.key[:], key)
return c, nil
}
func (c *salsa20BlockCrypt) Encrypt(dst, src []byte) {
salsa20.XORKeyStream(dst[8:], src[8:], src[:8], &c.key)
copy(dst[:8], src[:8])
}
func (c *salsa20BlockCrypt) Decrypt(dst, src []byte) {
salsa20.XORKeyStream(dst[8:], src[8:], src[:8], &c.key)
copy(dst[:8], src[:8])
}
type sm4BlockCrypt struct {
encbuf [sm4.BlockSize]byte // 64bit alignment enc/dec buffer
decbuf [2 * sm4.BlockSize]byte
block cipher.Block
}
// NewSM4BlockCrypt https://github.com/tjfoc/gmsm/tree/master/sm4
func NewSM4BlockCrypt(key []byte) (BlockCrypt, error) {
c := new(sm4BlockCrypt)
block, err := sm4.NewCipher(key)
if err != nil {
return nil, err
}
c.block = block
return c, nil
}
func (c *sm4BlockCrypt) Encrypt(dst, src []byte) { encrypt(c.block, dst, src, c.encbuf[:]) }
func (c *sm4BlockCrypt) Decrypt(dst, src []byte) { decrypt(c.block, dst, src, c.decbuf[:]) }
type twofishBlockCrypt struct {
encbuf [twofish.BlockSize]byte
decbuf [2 * twofish.BlockSize]byte
block cipher.Block
}
// NewTwofishBlockCrypt https://en.wikipedia.org/wiki/Twofish
func NewTwofishBlockCrypt(key []byte) (BlockCrypt, error) {
c := new(twofishBlockCrypt)
block, err := twofish.NewCipher(key)
if err != nil {
return nil, err
}
c.block = block
return c, nil
}
func (c *twofishBlockCrypt) Encrypt(dst, src []byte) { encrypt(c.block, dst, src, c.encbuf[:]) }
func (c *twofishBlockCrypt) Decrypt(dst, src []byte) { decrypt(c.block, dst, src, c.decbuf[:]) }
type tripleDESBlockCrypt struct {
encbuf [des.BlockSize]byte
decbuf [2 * des.BlockSize]byte
block cipher.Block
}
// NewTripleDESBlockCrypt https://en.wikipedia.org/wiki/Triple_DES
func NewTripleDESBlockCrypt(key []byte) (BlockCrypt, error) {
c := new(tripleDESBlockCrypt)
block, err := des.NewTripleDESCipher(key)
if err != nil {
return nil, err
}
c.block = block
return c, nil
}
func (c *tripleDESBlockCrypt) Encrypt(dst, src []byte) { encrypt(c.block, dst, src, c.encbuf[:]) }
func (c *tripleDESBlockCrypt) Decrypt(dst, src []byte) { decrypt(c.block, dst, src, c.decbuf[:]) }
type cast5BlockCrypt struct {
encbuf [cast5.BlockSize]byte
decbuf [2 * cast5.BlockSize]byte
block cipher.Block
}
// NewCast5BlockCrypt https://en.wikipedia.org/wiki/CAST-128
func NewCast5BlockCrypt(key []byte) (BlockCrypt, error) {
c := new(cast5BlockCrypt)
block, err := cast5.NewCipher(key)
if err != nil {
return nil, err
}
c.block = block
return c, nil
}
func (c *cast5BlockCrypt) Encrypt(dst, src []byte) { encrypt(c.block, dst, src, c.encbuf[:]) }
func (c *cast5BlockCrypt) Decrypt(dst, src []byte) { decrypt(c.block, dst, src, c.decbuf[:]) }
type blowfishBlockCrypt struct {
encbuf [blowfish.BlockSize]byte
decbuf [2 * blowfish.BlockSize]byte
block cipher.Block
}
// NewBlowfishBlockCrypt https://en.wikipedia.org/wiki/Blowfish_(cipher)
func NewBlowfishBlockCrypt(key []byte) (BlockCrypt, error) {
c := new(blowfishBlockCrypt)
block, err := blowfish.NewCipher(key)
if err != nil {
return nil, err
}
c.block = block
return c, nil
}
func (c *blowfishBlockCrypt) Encrypt(dst, src []byte) { encrypt(c.block, dst, src, c.encbuf[:]) }
func (c *blowfishBlockCrypt) Decrypt(dst, src []byte) { decrypt(c.block, dst, src, c.decbuf[:]) }
type aesBlockCrypt struct {
encbuf [aes.BlockSize]byte
decbuf [2 * aes.BlockSize]byte
block cipher.Block
}
// NewAESBlockCrypt https://en.wikipedia.org/wiki/Advanced_Encryption_Standard
func NewAESBlockCrypt(key []byte) (BlockCrypt, error) {
c := new(aesBlockCrypt)
block, err := aes.NewCipher(key)
if err != nil {
return nil, err
}
c.block = block
return c, nil
}
func (c *aesBlockCrypt) Encrypt(dst, src []byte) { encrypt(c.block, dst, src, c.encbuf[:]) }
func (c *aesBlockCrypt) Decrypt(dst, src []byte) { decrypt(c.block, dst, src, c.decbuf[:]) }
type teaBlockCrypt struct {
encbuf [tea.BlockSize]byte
decbuf [2 * tea.BlockSize]byte
block cipher.Block
}
// NewTEABlockCrypt https://en.wikipedia.org/wiki/Tiny_Encryption_Algorithm
func NewTEABlockCrypt(key []byte) (BlockCrypt, error) {
c := new(teaBlockCrypt)
block, err := tea.NewCipherWithRounds(key, 16)
if err != nil {
return nil, err
}
c.block = block
return c, nil
}
func (c *teaBlockCrypt) Encrypt(dst, src []byte) { encrypt(c.block, dst, src, c.encbuf[:]) }
func (c *teaBlockCrypt) Decrypt(dst, src []byte) { decrypt(c.block, dst, src, c.decbuf[:]) }
type xteaBlockCrypt struct {
encbuf [xtea.BlockSize]byte
decbuf [2 * xtea.BlockSize]byte
block cipher.Block
}
// NewXTEABlockCrypt https://en.wikipedia.org/wiki/XTEA
func NewXTEABlockCrypt(key []byte) (BlockCrypt, error) {
c := new(xteaBlockCrypt)
block, err := xtea.NewCipher(key)
if err != nil {
return nil, err
}
c.block = block
return c, nil
}
func (c *xteaBlockCrypt) Encrypt(dst, src []byte) { encrypt(c.block, dst, src, c.encbuf[:]) }
func (c *xteaBlockCrypt) Decrypt(dst, src []byte) { decrypt(c.block, dst, src, c.decbuf[:]) }
type simpleXORBlockCrypt struct {
xortbl []byte
}
// NewSimpleXORBlockCrypt simple xor with key expanding
func NewSimpleXORBlockCrypt(key []byte) (BlockCrypt, error) {
c := new(simpleXORBlockCrypt)
c.xortbl = pbkdf2.Key(key, []byte(saltxor), 32, mtuLimit, sha1.New)
return c, nil
}
func (c *simpleXORBlockCrypt) Encrypt(dst, src []byte) { xor.Bytes(dst, src, c.xortbl) }
func (c *simpleXORBlockCrypt) Decrypt(dst, src []byte) { xor.Bytes(dst, src, c.xortbl) }
type noneBlockCrypt struct{}
// NewNoneBlockCrypt does nothing but copying
func NewNoneBlockCrypt(key []byte) (BlockCrypt, error) {
return new(noneBlockCrypt), nil
}
func (c *noneBlockCrypt) Encrypt(dst, src []byte) { copy(dst, src) }
func (c *noneBlockCrypt) Decrypt(dst, src []byte) { copy(dst, src) }
// packet encryption with local CFB mode
func encrypt(block cipher.Block, dst, src, buf []byte) {
switch block.BlockSize() {
case 8:
encrypt8(block, dst, src, buf)
case 16:
encrypt16(block, dst, src, buf)
default:
panic("unsupported cipher block size")
}
}
// optimized encryption for the ciphers which works in 8-bytes
func encrypt8(block cipher.Block, dst, src, buf []byte) {
tbl := buf[:8]
block.Encrypt(tbl, initialVector)
n := len(src) / 8
base := 0
repeat := n / 8
left := n % 8
ptr_tbl := (*uint64)(unsafe.Pointer(&tbl[0]))
for i := 0; i < repeat; i++ {
s := src[base:][0:64]
d := dst[base:][0:64]
// 1
*(*uint64)(unsafe.Pointer(&d[0])) = *(*uint64)(unsafe.Pointer(&s[0])) ^ *ptr_tbl
block.Encrypt(tbl, d[0:8])
// 2
*(*uint64)(unsafe.Pointer(&d[8])) = *(*uint64)(unsafe.Pointer(&s[8])) ^ *ptr_tbl
block.Encrypt(tbl, d[8:16])
// 3
*(*uint64)(unsafe.Pointer(&d[16])) = *(*uint64)(unsafe.Pointer(&s[16])) ^ *ptr_tbl
block.Encrypt(tbl, d[16:24])
// 4
*(*uint64)(unsafe.Pointer(&d[24])) = *(*uint64)(unsafe.Pointer(&s[24])) ^ *ptr_tbl
block.Encrypt(tbl, d[24:32])
// 5
*(*uint64)(unsafe.Pointer(&d[32])) = *(*uint64)(unsafe.Pointer(&s[32])) ^ *ptr_tbl
block.Encrypt(tbl, d[32:40])
// 6
*(*uint64)(unsafe.Pointer(&d[40])) = *(*uint64)(unsafe.Pointer(&s[40])) ^ *ptr_tbl
block.Encrypt(tbl, d[40:48])
// 7
*(*uint64)(unsafe.Pointer(&d[48])) = *(*uint64)(unsafe.Pointer(&s[48])) ^ *ptr_tbl
block.Encrypt(tbl, d[48:56])
// 8
*(*uint64)(unsafe.Pointer(&d[56])) = *(*uint64)(unsafe.Pointer(&s[56])) ^ *ptr_tbl
block.Encrypt(tbl, d[56:64])
base += 64
}
switch left {
case 7:
*(*uint64)(unsafe.Pointer(&dst[base])) = *(*uint64)(unsafe.Pointer(&src[base])) ^ *ptr_tbl
block.Encrypt(tbl, dst[base:])
base += 8
fallthrough
case 6:
*(*uint64)(unsafe.Pointer(&dst[base])) = *(*uint64)(unsafe.Pointer(&src[base])) ^ *ptr_tbl
block.Encrypt(tbl, dst[base:])
base += 8
fallthrough
case 5:
*(*uint64)(unsafe.Pointer(&dst[base])) = *(*uint64)(unsafe.Pointer(&src[base])) ^ *ptr_tbl
block.Encrypt(tbl, dst[base:])
base += 8
fallthrough
case 4:
*(*uint64)(unsafe.Pointer(&dst[base])) = *(*uint64)(unsafe.Pointer(&src[base])) ^ *ptr_tbl
block.Encrypt(tbl, dst[base:])
base += 8
fallthrough
case 3:
*(*uint64)(unsafe.Pointer(&dst[base])) = *(*uint64)(unsafe.Pointer(&src[base])) ^ *ptr_tbl
block.Encrypt(tbl, dst[base:])
base += 8
fallthrough
case 2:
*(*uint64)(unsafe.Pointer(&dst[base])) = *(*uint64)(unsafe.Pointer(&src[base])) ^ *ptr_tbl
block.Encrypt(tbl, dst[base:])
base += 8
fallthrough
case 1:
*(*uint64)(unsafe.Pointer(&dst[base])) = *(*uint64)(unsafe.Pointer(&src[base])) ^ *ptr_tbl
block.Encrypt(tbl, dst[base:])
base += 8
fallthrough
case 0:
xorBytes(dst[base:], src[base:], tbl)
}
}
// optimized encryption for the ciphers which works in 16-bytes
func encrypt16(block cipher.Block, dst, src, buf []byte) {
tbl := buf[:16]
block.Encrypt(tbl, initialVector)
n := len(src) / 16
base := 0
repeat := n / 8
left := n % 8
for i := 0; i < repeat; i++ {
s := src[base:][0:128]
d := dst[base:][0:128]
// 1
xor.Bytes16Align(d[0:16], s[0:16], tbl)
block.Encrypt(tbl, d[0:16])
// 2
xor.Bytes16Align(d[16:32], s[16:32], tbl)
block.Encrypt(tbl, d[16:32])
// 3
xor.Bytes16Align(d[32:48], s[32:48], tbl)
block.Encrypt(tbl, d[32:48])
// 4
xor.Bytes16Align(d[48:64], s[48:64], tbl)
block.Encrypt(tbl, d[48:64])
// 5
xor.Bytes16Align(d[64:80], s[64:80], tbl)
block.Encrypt(tbl, d[64:80])
// 6
xor.Bytes16Align(d[80:96], s[80:96], tbl)
block.Encrypt(tbl, d[80:96])
// 7
xor.Bytes16Align(d[96:112], s[96:112], tbl)
block.Encrypt(tbl, d[96:112])
// 8
xor.Bytes16Align(d[112:128], s[112:128], tbl)
block.Encrypt(tbl, d[112:128])
base += 128
}
switch left {
case 7:
xor.Bytes16Align(dst[base:], src[base:], tbl)
block.Encrypt(tbl, dst[base:])
base += 16
fallthrough
case 6:
xor.Bytes16Align(dst[base:], src[base:], tbl)
block.Encrypt(tbl, dst[base:])
base += 16
fallthrough
case 5:
xor.Bytes16Align(dst[base:], src[base:], tbl)
block.Encrypt(tbl, dst[base:])
base += 16
fallthrough
case 4:
xor.Bytes16Align(dst[base:], src[base:], tbl)
block.Encrypt(tbl, dst[base:])
base += 16
fallthrough
case 3:
xor.Bytes16Align(dst[base:], src[base:], tbl)
block.Encrypt(tbl, dst[base:])
base += 16
fallthrough
case 2:
xor.Bytes16Align(dst[base:], src[base:], tbl)
block.Encrypt(tbl, dst[base:])
base += 16
fallthrough
case 1:
xor.Bytes16Align(dst[base:], src[base:], tbl)
block.Encrypt(tbl, dst[base:])
base += 16
fallthrough
case 0:
xorBytes(dst[base:], src[base:], tbl)
}
}
// decryption
func decrypt(block cipher.Block, dst, src, buf []byte) {
switch block.BlockSize() {
case 8:
decrypt8(block, dst, src, buf)
case 16:
decrypt16(block, dst, src, buf)
default:
panic("unsupported cipher block size")
}
}
// decrypt 8 bytes block, all byte slices are supposed to be 64bit aligned
func decrypt8(block cipher.Block, dst, src, buf []byte) {
tbl := buf[0:8]
next := buf[8:16]
block.Encrypt(tbl, initialVector)
n := len(src) / 8
base := 0
repeat := n / 8
left := n % 8
ptr_tbl := (*uint64)(unsafe.Pointer(&tbl[0]))
ptr_next := (*uint64)(unsafe.Pointer(&next[0]))
for i := 0; i < repeat; i++ {
s := src[base:][0:64]
d := dst[base:][0:64]
// 1
block.Encrypt(next, s[0:8])
*(*uint64)(unsafe.Pointer(&d[0])) = *(*uint64)(unsafe.Pointer(&s[0])) ^ *ptr_tbl
// 2
block.Encrypt(tbl, s[8:16])
*(*uint64)(unsafe.Pointer(&d[8])) = *(*uint64)(unsafe.Pointer(&s[8])) ^ *ptr_next
// 3
block.Encrypt(next, s[16:24])
*(*uint64)(unsafe.Pointer(&d[16])) = *(*uint64)(unsafe.Pointer(&s[16])) ^ *ptr_tbl
// 4
block.Encrypt(tbl, s[24:32])
*(*uint64)(unsafe.Pointer(&d[24])) = *(*uint64)(unsafe.Pointer(&s[24])) ^ *ptr_next
// 5
block.Encrypt(next, s[32:40])
*(*uint64)(unsafe.Pointer(&d[32])) = *(*uint64)(unsafe.Pointer(&s[32])) ^ *ptr_tbl
// 6
block.Encrypt(tbl, s[40:48])
*(*uint64)(unsafe.Pointer(&d[40])) = *(*uint64)(unsafe.Pointer(&s[40])) ^ *ptr_next
// 7
block.Encrypt(next, s[48:56])
*(*uint64)(unsafe.Pointer(&d[48])) = *(*uint64)(unsafe.Pointer(&s[48])) ^ *ptr_tbl
// 8
block.Encrypt(tbl, s[56:64])
*(*uint64)(unsafe.Pointer(&d[56])) = *(*uint64)(unsafe.Pointer(&s[56])) ^ *ptr_next
base += 64
}
switch left {
case 7:
block.Encrypt(next, src[base:])
*(*uint64)(unsafe.Pointer(&dst[base])) = *(*uint64)(unsafe.Pointer(&src[base])) ^ *(*uint64)(unsafe.Pointer(&tbl[0]))
tbl, next = next, tbl
base += 8
fallthrough
case 6:
block.Encrypt(next, src[base:])
*(*uint64)(unsafe.Pointer(&dst[base])) = *(*uint64)(unsafe.Pointer(&src[base])) ^ *(*uint64)(unsafe.Pointer(&tbl[0]))
tbl, next = next, tbl
base += 8
fallthrough
case 5:
block.Encrypt(next, src[base:])
*(*uint64)(unsafe.Pointer(&dst[base])) = *(*uint64)(unsafe.Pointer(&src[base])) ^ *(*uint64)(unsafe.Pointer(&tbl[0]))
tbl, next = next, tbl
base += 8
fallthrough
case 4:
block.Encrypt(next, src[base:])
*(*uint64)(unsafe.Pointer(&dst[base])) = *(*uint64)(unsafe.Pointer(&src[base])) ^ *(*uint64)(unsafe.Pointer(&tbl[0]))
tbl, next = next, tbl
base += 8
fallthrough
case 3:
block.Encrypt(next, src[base:])
*(*uint64)(unsafe.Pointer(&dst[base])) = *(*uint64)(unsafe.Pointer(&src[base])) ^ *(*uint64)(unsafe.Pointer(&tbl[0]))
tbl, next = next, tbl
base += 8
fallthrough
case 2:
block.Encrypt(next, src[base:])
*(*uint64)(unsafe.Pointer(&dst[base])) = *(*uint64)(unsafe.Pointer(&src[base])) ^ *(*uint64)(unsafe.Pointer(&tbl[0]))
tbl, next = next, tbl
base += 8
fallthrough
case 1:
block.Encrypt(next, src[base:])
*(*uint64)(unsafe.Pointer(&dst[base])) = *(*uint64)(unsafe.Pointer(&src[base])) ^ *(*uint64)(unsafe.Pointer(&tbl[0]))
tbl, next = next, tbl
base += 8
fallthrough
case 0:
xorBytes(dst[base:], src[base:], tbl)
}
}
func decrypt16(block cipher.Block, dst, src, buf []byte) {
tbl := buf[0:16]
next := buf[16:32]
block.Encrypt(tbl, initialVector)
n := len(src) / 16
base := 0
repeat := n / 8
left := n % 8
for i := 0; i < repeat; i++ {
s := src[base:][0:128]
d := dst[base:][0:128]
// 1
block.Encrypt(next, s[0:16])
xor.Bytes16Align(d[0:16], s[0:16], tbl)
// 2
block.Encrypt(tbl, s[16:32])
xor.Bytes16Align(d[16:32], s[16:32], next)
// 3
block.Encrypt(next, s[32:48])
xor.Bytes16Align(d[32:48], s[32:48], tbl)
// 4
block.Encrypt(tbl, s[48:64])
xor.Bytes16Align(d[48:64], s[48:64], next)
// 5
block.Encrypt(next, s[64:80])
xor.Bytes16Align(d[64:80], s[64:80], tbl)
// 6
block.Encrypt(tbl, s[80:96])
xor.Bytes16Align(d[80:96], s[80:96], next)
// 7
block.Encrypt(next, s[96:112])
xor.Bytes16Align(d[96:112], s[96:112], tbl)
// 8
block.Encrypt(tbl, s[112:128])
xor.Bytes16Align(d[112:128], s[112:128], next)
base += 128
}
switch left {
case 7:
block.Encrypt(next, src[base:])
xor.Bytes16Align(dst[base:], src[base:], tbl)
tbl, next = next, tbl
base += 16
fallthrough
case 6:
block.Encrypt(next, src[base:])
xor.Bytes16Align(dst[base:], src[base:], tbl)
tbl, next = next, tbl
base += 16
fallthrough
case 5:
block.Encrypt(next, src[base:])
xor.Bytes16Align(dst[base:], src[base:], tbl)
tbl, next = next, tbl
base += 16
fallthrough
case 4:
block.Encrypt(next, src[base:])
xor.Bytes16Align(dst[base:], src[base:], tbl)
tbl, next = next, tbl
base += 16
fallthrough
case 3:
block.Encrypt(next, src[base:])
xor.Bytes16Align(dst[base:], src[base:], tbl)
tbl, next = next, tbl
base += 16
fallthrough
case 2:
block.Encrypt(next, src[base:])
xor.Bytes16Align(dst[base:], src[base:], tbl)
tbl, next = next, tbl
base += 16
fallthrough
case 1:
block.Encrypt(next, src[base:])
xor.Bytes16Align(dst[base:], src[base:], tbl)
tbl, next = next, tbl
base += 16
fallthrough
case 0:
xorBytes(dst[base:], src[base:], tbl)
}
}
// per bytes xors
func xorBytes(dst, a, b []byte) int {
n := len(a)
if len(b) < n {
n = len(b)
}
if n == 0 {
return 0
}
for i := 0; i < n; i++ {
dst[i] = a[i] ^ b[i]
}
return n
}
+289
View File
@@ -0,0 +1,289 @@
package kcp
import (
"bytes"
"crypto/aes"
"crypto/md5"
"crypto/rand"
"crypto/sha1"
"hash/crc32"
"io"
"testing"
)
func TestSM4(t *testing.T) {
bc, err := NewSM4BlockCrypt(pass[:16])
if err != nil {
t.Fatal(err)
}
cryptTest(t, bc)
}
func TestAES(t *testing.T) {
bc, err := NewAESBlockCrypt(pass[:32])
if err != nil {
t.Fatal(err)
}
cryptTest(t, bc)
}
func TestTEA(t *testing.T) {
bc, err := NewTEABlockCrypt(pass[:16])
if err != nil {
t.Fatal(err)
}
cryptTest(t, bc)
}
func TestXOR(t *testing.T) {
bc, err := NewSimpleXORBlockCrypt(pass[:32])
if err != nil {
t.Fatal(err)
}
cryptTest(t, bc)
}
func TestBlowfish(t *testing.T) {
bc, err := NewBlowfishBlockCrypt(pass[:32])
if err != nil {
t.Fatal(err)
}
cryptTest(t, bc)
}
func TestNone(t *testing.T) {
bc, err := NewNoneBlockCrypt(pass[:32])
if err != nil {
t.Fatal(err)
}
cryptTest(t, bc)
}
func TestCast5(t *testing.T) {
bc, err := NewCast5BlockCrypt(pass[:16])
if err != nil {
t.Fatal(err)
}
cryptTest(t, bc)
}
func Test3DES(t *testing.T) {
bc, err := NewTripleDESBlockCrypt(pass[:24])
if err != nil {
t.Fatal(err)
}
cryptTest(t, bc)
}
func TestTwofish(t *testing.T) {
bc, err := NewTwofishBlockCrypt(pass[:32])
if err != nil {
t.Fatal(err)
}
cryptTest(t, bc)
}
func TestXTEA(t *testing.T) {
bc, err := NewXTEABlockCrypt(pass[:16])
if err != nil {
t.Fatal(err)
}
cryptTest(t, bc)
}
func TestSalsa20(t *testing.T) {
bc, err := NewSalsa20BlockCrypt(pass[:32])
if err != nil {
t.Fatal(err)
}
cryptTest(t, bc)
}
func cryptTest(t *testing.T, bc BlockCrypt) {
data := make([]byte, mtuLimit)
io.ReadFull(rand.Reader, data)
dec := make([]byte, mtuLimit)
enc := make([]byte, mtuLimit)
bc.Encrypt(enc, data)
bc.Decrypt(dec, enc)
if !bytes.Equal(data, dec) {
t.Fail()
}
}
func BenchmarkSM4(b *testing.B) {
bc, err := NewSM4BlockCrypt(pass[:16])
if err != nil {
b.Fatal(err)
}
benchCrypt(b, bc)
}
func BenchmarkAES128(b *testing.B) {
bc, err := NewAESBlockCrypt(pass[:16])
if err != nil {
b.Fatal(err)
}
benchCrypt(b, bc)
}
func BenchmarkAES192(b *testing.B) {
bc, err := NewAESBlockCrypt(pass[:24])
if err != nil {
b.Fatal(err)
}
benchCrypt(b, bc)
}
func BenchmarkAES256(b *testing.B) {
bc, err := NewAESBlockCrypt(pass[:32])
if err != nil {
b.Fatal(err)
}
benchCrypt(b, bc)
}
func BenchmarkTEA(b *testing.B) {
bc, err := NewTEABlockCrypt(pass[:16])
if err != nil {
b.Fatal(err)
}
benchCrypt(b, bc)
}
func BenchmarkXOR(b *testing.B) {
bc, err := NewSimpleXORBlockCrypt(pass[:32])
if err != nil {
b.Fatal(err)
}
benchCrypt(b, bc)
}
func BenchmarkBlowfish(b *testing.B) {
bc, err := NewBlowfishBlockCrypt(pass[:32])
if err != nil {
b.Fatal(err)
}
benchCrypt(b, bc)
}
func BenchmarkNone(b *testing.B) {
bc, err := NewNoneBlockCrypt(pass[:32])
if err != nil {
b.Fatal(err)
}
benchCrypt(b, bc)
}
func BenchmarkCast5(b *testing.B) {
bc, err := NewCast5BlockCrypt(pass[:16])
if err != nil {
b.Fatal(err)
}
benchCrypt(b, bc)
}
func Benchmark3DES(b *testing.B) {
bc, err := NewTripleDESBlockCrypt(pass[:24])
if err != nil {
b.Fatal(err)
}
benchCrypt(b, bc)
}
func BenchmarkTwofish(b *testing.B) {
bc, err := NewTwofishBlockCrypt(pass[:32])
if err != nil {
b.Fatal(err)
}
benchCrypt(b, bc)
}
func BenchmarkXTEA(b *testing.B) {
bc, err := NewXTEABlockCrypt(pass[:16])
if err != nil {
b.Fatal(err)
}
benchCrypt(b, bc)
}
func BenchmarkSalsa20(b *testing.B) {
bc, err := NewSalsa20BlockCrypt(pass[:32])
if err != nil {
b.Fatal(err)
}
benchCrypt(b, bc)
}
func benchCrypt(b *testing.B, bc BlockCrypt) {
data := make([]byte, mtuLimit)
io.ReadFull(rand.Reader, data)
dec := make([]byte, mtuLimit)
enc := make([]byte, mtuLimit)
b.ReportAllocs()
b.SetBytes(int64(len(enc) * 2))
b.ResetTimer()
for i := 0; i < b.N; i++ {
bc.Encrypt(enc, data)
bc.Decrypt(dec, enc)
}
}
func BenchmarkCRC32(b *testing.B) {
content := make([]byte, 1024)
b.SetBytes(int64(len(content)))
for i := 0; i < b.N; i++ {
crc32.ChecksumIEEE(content)
}
}
func BenchmarkCsprngSystem(b *testing.B) {
data := make([]byte, md5.Size)
b.SetBytes(int64(len(data)))
for i := 0; i < b.N; i++ {
io.ReadFull(rand.Reader, data)
}
}
func BenchmarkCsprngMD5(b *testing.B) {
var data [md5.Size]byte
b.SetBytes(md5.Size)
for i := 0; i < b.N; i++ {
data = md5.Sum(data[:])
}
}
func BenchmarkCsprngSHA1(b *testing.B) {
var data [sha1.Size]byte
b.SetBytes(sha1.Size)
for i := 0; i < b.N; i++ {
data = sha1.Sum(data[:])
}
}
func BenchmarkCsprngNonceMD5(b *testing.B) {
var ng nonceMD5
ng.Init()
b.SetBytes(md5.Size)
data := make([]byte, md5.Size)
for i := 0; i < b.N; i++ {
ng.Fill(data)
}
}
func BenchmarkCsprngNonceAES128(b *testing.B) {
var ng nonceAES128
ng.Init()
b.SetBytes(aes.BlockSize)
data := make([]byte, aes.BlockSize)
for i := 0; i < b.N; i++ {
ng.Fill(data)
}
}
+52
View File
@@ -0,0 +1,52 @@
package kcp
import (
"crypto/aes"
"crypto/cipher"
"crypto/md5"
"crypto/rand"
"io"
)
// Entropy defines a entropy source
type Entropy interface {
Init()
Fill(nonce []byte)
}
// nonceMD5 nonce generator for packet header
type nonceMD5 struct {
seed [md5.Size]byte
}
func (n *nonceMD5) Init() { /*nothing required*/ }
func (n *nonceMD5) Fill(nonce []byte) {
if n.seed[0] == 0 { // entropy update
io.ReadFull(rand.Reader, n.seed[:])
}
n.seed = md5.Sum(n.seed[:])
copy(nonce, n.seed[:])
}
// nonceAES128 nonce generator for packet headers
type nonceAES128 struct {
seed [aes.BlockSize]byte
block cipher.Block
}
func (n *nonceAES128) Init() {
var key [16]byte //aes-128
io.ReadFull(rand.Reader, key[:])
io.ReadFull(rand.Reader, n.seed[:])
block, _ := aes.NewCipher(key[:])
n.block = block
}
func (n *nonceAES128) Fill(nonce []byte) {
if n.seed[0] == 0 { // entropy update
io.ReadFull(rand.Reader, n.seed[:])
}
n.block.Encrypt(n.seed[:], n.seed[:])
copy(nonce, n.seed[:])
}
+381
View File
@@ -0,0 +1,381 @@
package kcp
import (
"encoding/binary"
"sync/atomic"
"github.com/klauspost/reedsolomon"
)
const (
fecHeaderSize = 6
fecHeaderSizePlus2 = fecHeaderSize + 2 // plus 2B data size
typeData = 0xf1
typeParity = 0xf2
fecExpire = 60000
rxFECMulti = 3 // FEC keeps rxFECMulti* (dataShard+parityShard) ordered packets in memory
)
// fecPacket is a decoded FEC packet
type fecPacket []byte
func (bts fecPacket) seqid() uint32 { return binary.LittleEndian.Uint32(bts) }
func (bts fecPacket) flag() uint16 { return binary.LittleEndian.Uint16(bts[4:]) }
func (bts fecPacket) data() []byte { return bts[6:] }
// fecElement has auxcilliary time field
type fecElement struct {
fecPacket
ts uint32
}
// fecDecoder for decoding incoming packets
type fecDecoder struct {
rxlimit int // queue size limit
dataShards int
parityShards int
shardSize int
rx []fecElement // ordered receive queue
// caches
decodeCache [][]byte
flagCache []bool
// zeros
zeros []byte
// RS decoder
codec reedsolomon.Encoder
// auto tune fec parameter
autoTune autoTune
}
func newFECDecoder(dataShards, parityShards int) *fecDecoder {
if dataShards <= 0 || parityShards <= 0 {
return nil
}
dec := new(fecDecoder)
dec.dataShards = dataShards
dec.parityShards = parityShards
dec.shardSize = dataShards + parityShards
dec.rxlimit = rxFECMulti * dec.shardSize
codec, err := reedsolomon.New(dataShards, parityShards)
if err != nil {
return nil
}
dec.codec = codec
dec.decodeCache = make([][]byte, dec.shardSize)
dec.flagCache = make([]bool, dec.shardSize)
dec.zeros = make([]byte, mtuLimit)
return dec
}
// decode a fec packet
func (dec *fecDecoder) decode(in fecPacket) (recovered [][]byte) {
// sample to auto FEC tuner
if in.flag() == typeData {
dec.autoTune.Sample(true, in.seqid())
} else {
dec.autoTune.Sample(false, in.seqid())
}
// check if FEC parameters is out of sync
var shouldTune bool
if int(in.seqid())%dec.shardSize < dec.dataShards {
if in.flag() != typeData { // expect typeData
shouldTune = true
}
} else {
if in.flag() != typeParity {
shouldTune = true
}
}
if shouldTune {
autoDS := dec.autoTune.FindPeriod(true)
autoPS := dec.autoTune.FindPeriod(false)
// edges found, we can tune parameters now
if autoDS > 0 && autoPS > 0 && autoDS < 256 && autoPS < 256 {
// and make sure it's different
if autoDS != dec.dataShards || autoPS != dec.parityShards {
dec.dataShards = autoDS
dec.parityShards = autoPS
dec.shardSize = autoDS + autoPS
dec.rxlimit = rxFECMulti * dec.shardSize
codec, err := reedsolomon.New(autoDS, autoPS)
if err != nil {
return nil
}
dec.codec = codec
dec.decodeCache = make([][]byte, dec.shardSize)
dec.flagCache = make([]bool, dec.shardSize)
//log.Println("autotune to :", dec.dataShards, dec.parityShards)
}
}
}
// insertion
n := len(dec.rx) - 1
insertIdx := 0
for i := n; i >= 0; i-- {
if in.seqid() == dec.rx[i].seqid() { // de-duplicate
return nil
} else if _itimediff(in.seqid(), dec.rx[i].seqid()) > 0 { // insertion
insertIdx = i + 1
break
}
}
// make a copy
pkt := fecPacket(xmitBuf.Get().([]byte)[:len(in)])
copy(pkt, in)
elem := fecElement{pkt, currentMs()}
// insert into ordered rx queue
if insertIdx == n+1 {
dec.rx = append(dec.rx, elem)
} else {
dec.rx = append(dec.rx, fecElement{})
copy(dec.rx[insertIdx+1:], dec.rx[insertIdx:]) // shift right
dec.rx[insertIdx] = elem
}
// shard range for current packet
shardBegin := pkt.seqid() - pkt.seqid()%uint32(dec.shardSize)
shardEnd := shardBegin + uint32(dec.shardSize) - 1
// max search range in ordered queue for current shard
searchBegin := insertIdx - int(pkt.seqid()%uint32(dec.shardSize))
if searchBegin < 0 {
searchBegin = 0
}
searchEnd := searchBegin + dec.shardSize - 1
if searchEnd >= len(dec.rx) {
searchEnd = len(dec.rx) - 1
}
// re-construct datashards
if searchEnd-searchBegin+1 >= dec.dataShards {
var numshard, numDataShard, first, maxlen int
// zero caches
shards := dec.decodeCache
shardsflag := dec.flagCache
for k := range dec.decodeCache {
shards[k] = nil
shardsflag[k] = false
}
// shard assembly
for i := searchBegin; i <= searchEnd; i++ {
seqid := dec.rx[i].seqid()
if _itimediff(seqid, shardEnd) > 0 {
break
} else if _itimediff(seqid, shardBegin) >= 0 {
shards[seqid%uint32(dec.shardSize)] = dec.rx[i].data()
shardsflag[seqid%uint32(dec.shardSize)] = true
numshard++
if dec.rx[i].flag() == typeData {
numDataShard++
}
if numshard == 1 {
first = i
}
if len(dec.rx[i].data()) > maxlen {
maxlen = len(dec.rx[i].data())
}
}
}
if numDataShard == dec.dataShards {
// case 1: no loss on data shards
dec.rx = dec.freeRange(first, numshard, dec.rx)
} else if numshard >= dec.dataShards {
// case 2: loss on data shards, but it's recoverable from parity shards
for k := range shards {
if shards[k] != nil {
dlen := len(shards[k])
shards[k] = shards[k][:maxlen]
copy(shards[k][dlen:], dec.zeros)
} else if k < dec.dataShards {
shards[k] = xmitBuf.Get().([]byte)[:0]
}
}
if err := dec.codec.ReconstructData(shards); err == nil {
for k := range shards[:dec.dataShards] {
if !shardsflag[k] {
// recovered data should be recycled
recovered = append(recovered, shards[k])
}
}
}
dec.rx = dec.freeRange(first, numshard, dec.rx)
}
}
// keep rxlimit
if len(dec.rx) > dec.rxlimit {
if dec.rx[0].flag() == typeData { // track the unrecoverable data
atomic.AddUint64(&DefaultSnmp.FECShortShards, 1)
}
dec.rx = dec.freeRange(0, 1, dec.rx)
}
// timeout policy
current := currentMs()
numExpired := 0
for k := range dec.rx {
if _itimediff(current, dec.rx[k].ts) > fecExpire {
numExpired++
continue
}
break
}
if numExpired > 0 {
dec.rx = dec.freeRange(0, numExpired, dec.rx)
}
return
}
// free a range of fecPacket
func (dec *fecDecoder) freeRange(first, n int, q []fecElement) []fecElement {
for i := first; i < first+n; i++ { // recycle buffer
xmitBuf.Put([]byte(q[i].fecPacket))
}
if first == 0 && n < cap(q)/2 {
return q[n:]
}
copy(q[first:], q[first+n:])
return q[:len(q)-n]
}
// release all segments back to xmitBuf
func (dec *fecDecoder) release() {
if n := len(dec.rx); n > 0 {
dec.rx = dec.freeRange(0, n, dec.rx)
}
}
type (
// fecEncoder for encoding outgoing packets
fecEncoder struct {
dataShards int
parityShards int
shardSize int
paws uint32 // Protect Against Wrapped Sequence numbers
next uint32 // next seqid
shardCount int // count the number of datashards collected
maxSize int // track maximum data length in datashard
headerOffset int // FEC header offset
payloadOffset int // FEC payload offset
// caches
shardCache [][]byte
encodeCache [][]byte
// zeros
zeros []byte
// RS encoder
codec reedsolomon.Encoder
}
)
func newFECEncoder(dataShards, parityShards, offset int) *fecEncoder {
if dataShards <= 0 || parityShards <= 0 {
return nil
}
enc := new(fecEncoder)
enc.dataShards = dataShards
enc.parityShards = parityShards
enc.shardSize = dataShards + parityShards
enc.paws = 0xffffffff / uint32(enc.shardSize) * uint32(enc.shardSize)
enc.headerOffset = offset
enc.payloadOffset = enc.headerOffset + fecHeaderSize
codec, err := reedsolomon.New(dataShards, parityShards)
if err != nil {
return nil
}
enc.codec = codec
// caches
enc.encodeCache = make([][]byte, enc.shardSize)
enc.shardCache = make([][]byte, enc.shardSize)
for k := range enc.shardCache {
enc.shardCache[k] = make([]byte, mtuLimit)
}
enc.zeros = make([]byte, mtuLimit)
return enc
}
// encodes the packet, outputs parity shards if we have collected quorum datashards
// notice: the contents of 'ps' will be re-written in successive calling
func (enc *fecEncoder) encode(b []byte) (ps [][]byte) {
// The header format:
// | FEC SEQID(4B) | FEC TYPE(2B) | SIZE (2B) | PAYLOAD(SIZE-2) |
// |<-headerOffset |<-payloadOffset
enc.markData(b[enc.headerOffset:])
binary.LittleEndian.PutUint16(b[enc.payloadOffset:], uint16(len(b[enc.payloadOffset:])))
// copy data from payloadOffset to fec shard cache
sz := len(b)
enc.shardCache[enc.shardCount] = enc.shardCache[enc.shardCount][:sz]
copy(enc.shardCache[enc.shardCount][enc.payloadOffset:], b[enc.payloadOffset:])
enc.shardCount++
// track max datashard length
if sz > enc.maxSize {
enc.maxSize = sz
}
// Generation of Reed-Solomon Erasure Code
if enc.shardCount == enc.dataShards {
// fill '0' into the tail of each datashard
for i := 0; i < enc.dataShards; i++ {
shard := enc.shardCache[i]
slen := len(shard)
copy(shard[slen:enc.maxSize], enc.zeros)
}
// construct equal-sized slice with stripped header
cache := enc.encodeCache
for k := range cache {
cache[k] = enc.shardCache[k][enc.payloadOffset:enc.maxSize]
}
// encoding
if err := enc.codec.Encode(cache); err == nil {
ps = enc.shardCache[enc.dataShards:]
for k := range ps {
enc.markParity(ps[k][enc.headerOffset:])
ps[k] = ps[k][:enc.maxSize]
}
}
// counters resetting
enc.shardCount = 0
enc.maxSize = 0
}
return
}
func (enc *fecEncoder) markData(data []byte) {
binary.LittleEndian.PutUint32(data, enc.next)
binary.LittleEndian.PutUint16(data[4:], typeData)
enc.next++
}
func (enc *fecEncoder) markParity(data []byte) {
binary.LittleEndian.PutUint32(data, enc.next)
binary.LittleEndian.PutUint16(data[4:], typeParity)
// sequence wrap will only happen at parity shard
enc.next = (enc.next + 1) % enc.paws
}
+43
View File
@@ -0,0 +1,43 @@
package kcp
import (
"encoding/binary"
"math/rand"
"testing"
)
func BenchmarkFECDecode(b *testing.B) {
const dataSize = 10
const paritySize = 3
const payLoad = 1500
decoder := newFECDecoder(dataSize, paritySize)
b.ReportAllocs()
b.SetBytes(payLoad)
for i := 0; i < b.N; i++ {
if rand.Int()%(dataSize+paritySize) == 0 { // random loss
continue
}
pkt := make([]byte, payLoad)
binary.LittleEndian.PutUint32(pkt, uint32(i))
if i%(dataSize+paritySize) >= dataSize {
binary.LittleEndian.PutUint16(pkt[4:], typeParity)
} else {
binary.LittleEndian.PutUint16(pkt[4:], typeData)
}
decoder.decode(pkt)
}
}
func BenchmarkFECEncode(b *testing.B) {
const dataSize = 10
const paritySize = 3
const payLoad = 1500
b.ReportAllocs()
b.SetBytes(payLoad)
encoder := newFECEncoder(dataSize, paritySize, 0)
for i := 0; i < b.N; i++ {
data := make([]byte, payLoad)
encoder.encode(data)
}
}
+1094
View File
File diff suppressed because it is too large Load Diff
+135
View File
@@ -0,0 +1,135 @@
package kcp
import (
"io"
"net"
"sync"
"testing"
"time"
"github.com/xtaci/lossyconn"
)
const repeat = 16
func TestLossyConn1(t *testing.T) {
t.Log("testing loss rate 10%, rtt 200ms")
t.Log("testing link with nodelay parameters:1 10 2 1")
client, err := lossyconn.NewLossyConn(0.1, 100)
if err != nil {
t.Fatal(err)
}
server, err := lossyconn.NewLossyConn(0.1, 100)
if err != nil {
t.Fatal(err)
}
testlink(t, client, server, 1, 10, 2, 1)
}
func TestLossyConn2(t *testing.T) {
t.Log("testing loss rate 20%, rtt 200ms")
t.Log("testing link with nodelay parameters:1 10 2 1")
client, err := lossyconn.NewLossyConn(0.2, 100)
if err != nil {
t.Fatal(err)
}
server, err := lossyconn.NewLossyConn(0.2, 100)
if err != nil {
t.Fatal(err)
}
testlink(t, client, server, 1, 10, 2, 1)
}
func TestLossyConn3(t *testing.T) {
t.Log("testing loss rate 30%, rtt 200ms")
t.Log("testing link with nodelay parameters:1 10 2 1")
client, err := lossyconn.NewLossyConn(0.3, 100)
if err != nil {
t.Fatal(err)
}
server, err := lossyconn.NewLossyConn(0.3, 100)
if err != nil {
t.Fatal(err)
}
testlink(t, client, server, 1, 10, 2, 1)
}
func TestLossyConn4(t *testing.T) {
t.Log("testing loss rate 10%, rtt 200ms")
t.Log("testing link with nodelay parameters:1 10 2 0")
client, err := lossyconn.NewLossyConn(0.1, 100)
if err != nil {
t.Fatal(err)
}
server, err := lossyconn.NewLossyConn(0.1, 100)
if err != nil {
t.Fatal(err)
}
testlink(t, client, server, 1, 10, 2, 0)
}
func testlink(t *testing.T, client *lossyconn.LossyConn, server *lossyconn.LossyConn, nodelay, interval, resend, nc int) {
t.Log("testing with nodelay parameters:", nodelay, interval, resend, nc)
sess, _ := NewConn2(server.LocalAddr(), nil, 0, 0, client)
listener, _ := ServeConn(nil, 0, 0, server)
echoServer := func(l *Listener) {
for {
conn, err := l.AcceptKCP()
if err != nil {
return
}
go func() {
conn.SetNoDelay(nodelay, interval, resend, nc)
buf := make([]byte, 65536)
for {
n, err := conn.Read(buf)
if err != nil {
return
}
conn.Write(buf[:n])
}
}()
}
}
echoTester := func(s *UDPSession, raddr net.Addr) {
s.SetNoDelay(nodelay, interval, resend, nc)
buf := make([]byte, 64)
var rtt time.Duration
for i := 0; i < repeat; i++ {
start := time.Now()
s.Write(buf)
io.ReadFull(s, buf)
rtt += time.Since(start)
}
t.Log("client:", client)
t.Log("server:", server)
t.Log("avg rtt:", rtt/repeat)
t.Logf("total time: %v for %v round trip:", rtt, repeat)
}
go echoServer(listener)
echoTester(sess, server.LocalAddr())
}
func BenchmarkFlush(b *testing.B) {
kcp := NewKCP(1, func(buf []byte, size int) {})
kcp.snd_buf = make([]segment, 1024)
for k := range kcp.snd_buf {
kcp.snd_buf[k].xmit = 1
kcp.snd_buf[k].resendts = currentMs() + 10000
}
b.ResetTimer()
b.ReportAllocs()
var mu sync.Mutex
for i := 0; i < b.N; i++ {
mu.Lock()
kcp.flush(false)
mu.Unlock()
}
}
+126
View File
@@ -0,0 +1,126 @@
package kcp
import (
"bytes"
"encoding/binary"
"github.com/pkg/errors"
)
func (s *UDPSession) defaultReadLoop() {
buf := make([]byte, mtuLimit)
var src string
for {
if n, addr, err := s.conn.ReadFrom(buf); err == nil {
udpPayload := buf[:n]
// make sure the packet is from the same source
if src == "" { // set source address
src = addr.String()
} else if addr.String() != src {
//atomic.AddUint64(&DefaultSnmp.InErrs, 1)
//continue
s.remote = addr
src = addr.String()
}
s.packetInput(udpPayload)
} else {
s.notifyReadError(errors.WithStack(err))
return
}
}
}
func (l *Listener) defaultMonitor() {
buf := make([]byte, mtuLimit)
for {
if n, from, err := l.conn.ReadFrom(buf); err == nil {
udpPayload := buf[:n]
var convId uint64 = 0
if n == 20 {
// 原神KCP的Enet协议
// 提取convId
convId += uint64(udpPayload[4]) << 24
convId += uint64(udpPayload[5]) << 16
convId += uint64(udpPayload[6]) << 8
convId += uint64(udpPayload[7]) << 0
convId += uint64(udpPayload[8]) << 56
convId += uint64(udpPayload[9]) << 48
convId += uint64(udpPayload[10]) << 40
convId += uint64(udpPayload[11]) << 32
// 提取Enet协议头部和尾部幻数
udpPayloadEnetHead := udpPayload[:4]
udpPayloadEnetTail := udpPayload[len(udpPayload)-4:]
// 提取Enet协议类型
enetTypeData := udpPayload[12:16]
enetTypeDataBuffer := bytes.NewBuffer(enetTypeData)
var enetType uint32
_ = binary.Read(enetTypeDataBuffer, binary.BigEndian, &enetType)
equalHead := bytes.Compare(udpPayloadEnetHead, MagicEnetSynHead)
equalTail := bytes.Compare(udpPayloadEnetTail, MagicEnetSynTail)
if equalHead == 0 && equalTail == 0 {
// 客户端前置握手获取conv
l.EnetNotify <- &Enet{
Addr: from.String(),
ConvId: convId,
ConnType: ConnEnetSyn,
EnetType: enetType,
}
continue
}
equalHead = bytes.Compare(udpPayloadEnetHead, MagicEnetEstHead)
equalTail = bytes.Compare(udpPayloadEnetTail, MagicEnetEstTail)
if equalHead == 0 && equalTail == 0 {
// 连接建立
l.EnetNotify <- &Enet{
Addr: from.String(),
ConvId: convId,
ConnType: ConnEnetEst,
EnetType: enetType,
}
continue
}
equalHead = bytes.Compare(udpPayloadEnetHead, MagicEnetFinHead)
equalTail = bytes.Compare(udpPayloadEnetTail, MagicEnetFinTail)
if equalHead == 0 && equalTail == 0 {
// 连接断开
l.EnetNotify <- &Enet{
Addr: from.String(),
ConvId: convId,
ConnType: ConnEnetFin,
EnetType: enetType,
}
continue
}
} else {
// 正常KCP包
convId += uint64(udpPayload[0]) << 0
convId += uint64(udpPayload[1]) << 8
convId += uint64(udpPayload[2]) << 16
convId += uint64(udpPayload[3]) << 24
convId += uint64(udpPayload[4]) << 32
convId += uint64(udpPayload[5]) << 40
convId += uint64(udpPayload[6]) << 48
convId += uint64(udpPayload[7]) << 56
}
l.sessionLock.RLock()
conn, exist := l.sessions[convId]
l.sessionLock.RUnlock()
if exist {
if conn.remote.String() != from.String() {
conn.remote = from
// 连接地址改变
l.EnetNotify <- &Enet{
Addr: conn.remote.String(),
ConvId: convId,
ConnType: ConnEnetAddrChange,
}
}
}
l.packetInput(udpPayload, from, convId)
} else {
l.notifyReadError(errors.WithStack(err))
return
}
}
}
+12
View File
@@ -0,0 +1,12 @@
//go:build !linux
// +build !linux
package kcp
func (s *UDPSession) readLoop() {
s.defaultReadLoop()
}
func (l *Listener) monitor() {
l.defaultMonitor()
}
+199
View File
@@ -0,0 +1,199 @@
//go:build linux
// +build linux
package kcp
import (
"bytes"
"encoding/binary"
"github.com/pkg/errors"
"golang.org/x/net/ipv4"
"golang.org/x/net/ipv6"
"net"
"os"
)
// the read loop for a client session
func (s *UDPSession) readLoop() {
// default version
if s.xconn == nil {
s.defaultReadLoop()
return
}
// x/net version
var src string
msgs := make([]ipv4.Message, batchSize)
for k := range msgs {
msgs[k].Buffers = [][]byte{make([]byte, mtuLimit)}
}
for {
if count, err := s.xconn.ReadBatch(msgs, 0); err == nil {
for i := 0; i < count; i++ {
msg := &msgs[i]
// make sure the packet is from the same source
if src == "" { // set source address if nil
src = msg.Addr.String()
} else if msg.Addr.String() != src {
//atomic.AddUint64(&DefaultSnmp.InErrs, 1)
//continue
s.remote = msg.Addr
src = msg.Addr.String()
}
udpPayload := msg.Buffers[0][:msg.N]
// source and size has validated
s.packetInput(udpPayload)
}
} else {
// compatibility issue:
// for linux kernel<=2.6.32, support for sendmmsg is not available
// an error of type os.SyscallError will be returned
if operr, ok := err.(*net.OpError); ok {
if se, ok := operr.Err.(*os.SyscallError); ok {
if se.Syscall == "recvmmsg" {
s.defaultReadLoop()
return
}
}
}
s.notifyReadError(errors.WithStack(err))
return
}
}
}
// monitor incoming data for all connections of server
func (l *Listener) monitor() {
var xconn batchConn
if _, ok := l.conn.(*net.UDPConn); ok {
addr, err := net.ResolveUDPAddr("udp", l.conn.LocalAddr().String())
if err == nil {
if addr.IP.To4() != nil {
xconn = ipv4.NewPacketConn(l.conn)
} else {
xconn = ipv6.NewPacketConn(l.conn)
}
}
}
// default version
if xconn == nil {
l.defaultMonitor()
return
}
// x/net version
msgs := make([]ipv4.Message, batchSize)
for k := range msgs {
msgs[k].Buffers = [][]byte{make([]byte, mtuLimit)}
}
for {
if count, err := xconn.ReadBatch(msgs, 0); err == nil {
for i := 0; i < count; i++ {
msg := &msgs[i]
udpPayload := msg.Buffers[0][:msg.N]
var convId uint64 = 0
if msg.N == 20 {
// 原神KCP的Enet协议
// 提取convId
convId += uint64(udpPayload[4]) << 24
convId += uint64(udpPayload[5]) << 16
convId += uint64(udpPayload[6]) << 8
convId += uint64(udpPayload[7]) << 0
convId += uint64(udpPayload[8]) << 56
convId += uint64(udpPayload[9]) << 48
convId += uint64(udpPayload[10]) << 40
convId += uint64(udpPayload[11]) << 32
// 提取Enet协议头部和尾部幻数
udpPayloadEnetHead := udpPayload[:4]
udpPayloadEnetTail := udpPayload[len(udpPayload)-4:]
// 提取Enet协议类型
enetTypeData := udpPayload[12:16]
enetTypeDataBuffer := bytes.NewBuffer(enetTypeData)
var enetType uint32
_ = binary.Read(enetTypeDataBuffer, binary.BigEndian, &enetType)
equalHead := bytes.Compare(udpPayloadEnetHead, MagicEnetSynHead)
equalTail := bytes.Compare(udpPayloadEnetTail, MagicEnetSynTail)
if equalHead == 0 && equalTail == 0 {
// 客户端前置握手获取conv
l.EnetNotify <- &Enet{
Addr: msg.Addr.String(),
ConvId: convId,
ConnType: ConnEnetSyn,
EnetType: enetType,
}
continue
}
equalHead = bytes.Compare(udpPayloadEnetHead, MagicEnetEstHead)
equalTail = bytes.Compare(udpPayloadEnetTail, MagicEnetEstTail)
if equalHead == 0 && equalTail == 0 {
// 连接建立
l.EnetNotify <- &Enet{
Addr: msg.Addr.String(),
ConvId: convId,
ConnType: ConnEnetEst,
EnetType: enetType,
}
continue
}
equalHead = bytes.Compare(udpPayloadEnetHead, MagicEnetFinHead)
equalTail = bytes.Compare(udpPayloadEnetTail, MagicEnetFinTail)
if equalHead == 0 && equalTail == 0 {
// 连接断开
l.EnetNotify <- &Enet{
Addr: msg.Addr.String(),
ConvId: convId,
ConnType: ConnEnetFin,
EnetType: enetType,
}
continue
}
} else {
// 正常KCP包
convId += uint64(udpPayload[0]) << 0
convId += uint64(udpPayload[1]) << 8
convId += uint64(udpPayload[2]) << 16
convId += uint64(udpPayload[3]) << 24
convId += uint64(udpPayload[4]) << 32
convId += uint64(udpPayload[5]) << 40
convId += uint64(udpPayload[6]) << 48
convId += uint64(udpPayload[7]) << 56
}
l.sessionLock.RLock()
conn, exist := l.sessions[convId]
l.sessionLock.RUnlock()
if exist {
if conn.remote.String() != msg.Addr.String() {
conn.remote = msg.Addr
// 连接地址改变
l.EnetNotify <- &Enet{
Addr: conn.remote.String(),
ConvId: convId,
ConnType: ConnEnetAddrChange,
}
}
}
l.packetInput(udpPayload, msg.Addr, convId)
}
} else {
// compatibility issue:
// for linux kernel<=2.6.32, support for sendmmsg is not available
// an error of type os.SyscallError will be returned
if operr, ok := err.(*net.OpError); ok {
if se, ok := operr.Err.(*os.SyscallError); ok {
if se.Syscall == "recvmmsg" {
l.defaultMonitor()
return
}
}
}
l.notifyReadError(errors.WithStack(err))
return
}
}
}
+1144
View File
File diff suppressed because it is too large Load Diff
+703
View File
@@ -0,0 +1,703 @@
package kcp
import (
"crypto/sha1"
"fmt"
"io"
"log"
"net"
"net/http"
_ "net/http/pprof"
"sync"
"sync/atomic"
"testing"
"time"
"golang.org/x/crypto/pbkdf2"
)
var baseport = uint32(10000)
var key = []byte("testkey")
var pass = pbkdf2.Key(key, []byte("testsalt"), 4096, 32, sha1.New)
func init() {
go func() {
log.Println(http.ListenAndServe("0.0.0.0:6060", nil))
}()
log.Println("beginning tests, encryption:salsa20, fec:10/3")
}
func dialEcho(port int) (*UDPSession, error) {
//block, _ := NewNoneBlockCrypt(pass)
//block, _ := NewSimpleXORBlockCrypt(pass)
//block, _ := NewTEABlockCrypt(pass[:16])
//block, _ := NewAESBlockCrypt(pass)
block, _ := NewSalsa20BlockCrypt(pass)
sess, err := DialWithOptions(fmt.Sprintf("127.0.0.1:%v", port), block, 10, 3)
if err != nil {
panic(err)
}
sess.SetStreamMode(true)
sess.SetStreamMode(false)
sess.SetStreamMode(true)
sess.SetWindowSize(1024, 1024)
sess.SetReadBuffer(16 * 1024 * 1024)
sess.SetWriteBuffer(16 * 1024 * 1024)
sess.SetStreamMode(true)
sess.SetNoDelay(1, 10, 2, 1)
sess.SetMtu(1400)
sess.SetMtu(1600)
sess.SetMtu(1400)
sess.SetACKNoDelay(true)
sess.SetACKNoDelay(false)
sess.SetDeadline(time.Now().Add(time.Minute))
return sess, err
}
func dialSink(port int) (*UDPSession, error) {
sess, err := DialWithOptions(fmt.Sprintf("127.0.0.1:%v", port), nil, 0, 0)
if err != nil {
panic(err)
}
sess.SetStreamMode(true)
sess.SetWindowSize(1024, 1024)
sess.SetReadBuffer(16 * 1024 * 1024)
sess.SetWriteBuffer(16 * 1024 * 1024)
sess.SetStreamMode(true)
sess.SetNoDelay(1, 10, 2, 1)
sess.SetMtu(1400)
sess.SetACKNoDelay(false)
sess.SetDeadline(time.Now().Add(time.Minute))
return sess, err
}
func dialTinyBufferEcho(port int) (*UDPSession, error) {
//block, _ := NewNoneBlockCrypt(pass)
//block, _ := NewSimpleXORBlockCrypt(pass)
//block, _ := NewTEABlockCrypt(pass[:16])
//block, _ := NewAESBlockCrypt(pass)
block, _ := NewSalsa20BlockCrypt(pass)
sess, err := DialWithOptions(fmt.Sprintf("127.0.0.1:%v", port), block, 10, 3)
if err != nil {
panic(err)
}
return sess, err
}
// ////////////////////////
func listenEcho(port int) (net.Listener, error) {
//block, _ := NewNoneBlockCrypt(pass)
//block, _ := NewSimpleXORBlockCrypt(pass)
//block, _ := NewTEABlockCrypt(pass[:16])
//block, _ := NewAESBlockCrypt(pass)
block, _ := NewSalsa20BlockCrypt(pass)
return ListenWithOptions(fmt.Sprintf("127.0.0.1:%v", port), block, 10, 0)
}
func listenTinyBufferEcho(port int) (net.Listener, error) {
//block, _ := NewNoneBlockCrypt(pass)
//block, _ := NewSimpleXORBlockCrypt(pass)
//block, _ := NewTEABlockCrypt(pass[:16])
//block, _ := NewAESBlockCrypt(pass)
block, _ := NewSalsa20BlockCrypt(pass)
return ListenWithOptions(fmt.Sprintf("127.0.0.1:%v", port), block, 10, 3)
}
func listenSink(port int) (net.Listener, error) {
return ListenWithOptions(fmt.Sprintf("127.0.0.1:%v", port), nil, 0, 0)
}
func echoServer(port int) net.Listener {
l, err := listenEcho(port)
if err != nil {
panic(err)
}
go func() {
kcplistener := l.(*Listener)
kcplistener.SetReadBuffer(4 * 1024 * 1024)
kcplistener.SetWriteBuffer(4 * 1024 * 1024)
kcplistener.SetDSCP(46)
for {
s, err := l.Accept()
if err != nil {
return
}
// coverage test
s.(*UDPSession).SetReadBuffer(4 * 1024 * 1024)
s.(*UDPSession).SetWriteBuffer(4 * 1024 * 1024)
go handleEcho(s.(*UDPSession))
}
}()
return l
}
func sinkServer(port int) net.Listener {
l, err := listenSink(port)
if err != nil {
panic(err)
}
go func() {
kcplistener := l.(*Listener)
kcplistener.SetReadBuffer(4 * 1024 * 1024)
kcplistener.SetWriteBuffer(4 * 1024 * 1024)
kcplistener.SetDSCP(46)
for {
s, err := l.Accept()
if err != nil {
return
}
go handleSink(s.(*UDPSession))
}
}()
return l
}
func tinyBufferEchoServer(port int) net.Listener {
l, err := listenTinyBufferEcho(port)
if err != nil {
panic(err)
}
go func() {
for {
s, err := l.Accept()
if err != nil {
return
}
go handleTinyBufferEcho(s.(*UDPSession))
}
}()
return l
}
///////////////////////////
func handleEcho(conn *UDPSession) {
conn.SetStreamMode(true)
conn.SetWindowSize(4096, 4096)
conn.SetNoDelay(1, 10, 2, 1)
conn.SetDSCP(46)
conn.SetMtu(1400)
conn.SetACKNoDelay(false)
conn.SetReadDeadline(time.Now().Add(time.Hour))
conn.SetWriteDeadline(time.Now().Add(time.Hour))
buf := make([]byte, 65536)
for {
n, err := conn.Read(buf)
if err != nil {
return
}
conn.Write(buf[:n])
}
}
func handleSink(conn *UDPSession) {
conn.SetStreamMode(true)
conn.SetWindowSize(4096, 4096)
conn.SetNoDelay(1, 10, 2, 1)
conn.SetDSCP(46)
conn.SetMtu(1400)
conn.SetACKNoDelay(false)
conn.SetReadDeadline(time.Now().Add(time.Hour))
conn.SetWriteDeadline(time.Now().Add(time.Hour))
buf := make([]byte, 65536)
for {
_, err := conn.Read(buf)
if err != nil {
return
}
}
}
func handleTinyBufferEcho(conn *UDPSession) {
conn.SetStreamMode(true)
buf := make([]byte, 2)
for {
n, err := conn.Read(buf)
if err != nil {
return
}
conn.Write(buf[:n])
}
}
///////////////////////////
func TestTimeout(t *testing.T) {
port := int(atomic.AddUint32(&baseport, 1))
l := echoServer(port)
defer l.Close()
cli, err := dialEcho(port)
if err != nil {
panic(err)
}
buf := make([]byte, 10)
//timeout
cli.SetDeadline(time.Now().Add(time.Second))
<-time.After(2 * time.Second)
n, err := cli.Read(buf)
if n != 0 || err == nil {
t.Fail()
}
cli.Close()
}
func TestSendRecv(t *testing.T) {
port := int(atomic.AddUint32(&baseport, 1))
l := echoServer(port)
defer l.Close()
cli, err := dialEcho(port)
if err != nil {
panic(err)
}
cli.SetWriteDelay(true)
cli.SetDUP(1)
const N = 100
buf := make([]byte, 10)
for i := 0; i < N; i++ {
msg := fmt.Sprintf("hello%v", i)
cli.Write([]byte(msg))
if n, err := cli.Read(buf); err == nil {
if string(buf[:n]) != msg {
t.Fail()
}
} else {
panic(err)
}
}
cli.Close()
}
func TestSendVector(t *testing.T) {
port := int(atomic.AddUint32(&baseport, 1))
l := echoServer(port)
defer l.Close()
cli, err := dialEcho(port)
if err != nil {
panic(err)
}
cli.SetWriteDelay(false)
const N = 100
buf := make([]byte, 20)
v := make([][]byte, 2)
for i := 0; i < N; i++ {
v[0] = []byte(fmt.Sprintf("hello%v", i))
v[1] = []byte(fmt.Sprintf("world%v", i))
msg := fmt.Sprintf("hello%vworld%v", i, i)
cli.WriteBuffers(v)
if n, err := cli.Read(buf); err == nil {
if string(buf[:n]) != msg {
t.Error(string(buf[:n]), msg)
}
} else {
panic(err)
}
}
cli.Close()
}
func TestTinyBufferReceiver(t *testing.T) {
port := int(atomic.AddUint32(&baseport, 1))
l := tinyBufferEchoServer(port)
defer l.Close()
cli, err := dialTinyBufferEcho(port)
if err != nil {
panic(err)
}
const N = 100
snd := byte(0)
fillBuffer := func(buf []byte) {
for i := 0; i < len(buf); i++ {
buf[i] = snd
snd++
}
}
rcv := byte(0)
check := func(buf []byte) bool {
for i := 0; i < len(buf); i++ {
if buf[i] != rcv {
return false
}
rcv++
}
return true
}
sndbuf := make([]byte, 7)
rcvbuf := make([]byte, 7)
for i := 0; i < N; i++ {
fillBuffer(sndbuf)
cli.Write(sndbuf)
if n, err := io.ReadFull(cli, rcvbuf); err == nil {
if !check(rcvbuf[:n]) {
t.Fail()
}
} else {
panic(err)
}
}
cli.Close()
}
func TestClose(t *testing.T) {
var n int
var err error
port := int(atomic.AddUint32(&baseport, 1))
l := echoServer(port)
defer l.Close()
cli, err := dialEcho(port)
if err != nil {
panic(err)
}
// double close
cli.Close()
if cli.Close() == nil {
t.Fatal("double close misbehavior")
}
// write after close
buf := make([]byte, 10)
n, err = cli.Write(buf)
if n != 0 || err == nil {
t.Fatal("write after close misbehavior")
}
// write, close, read, read
cli, err = dialEcho(port)
if err != nil {
panic(err)
}
if n, err = cli.Write(buf); err != nil {
t.Fatal("write misbehavior")
}
// wait until data arrival
time.Sleep(2 * time.Second)
// drain
cli.Close()
n, err = io.ReadFull(cli, buf)
if err != nil {
t.Fatal("closed conn drain bytes failed", err, n)
}
// after drain, read should return error
n, err = cli.Read(buf)
if n != 0 || err == nil {
t.Fatal("write->close->drain->read misbehavior", err, n)
}
cli.Close()
}
func TestParallel1024CLIENT_64BMSG_64CNT(t *testing.T) {
port := int(atomic.AddUint32(&baseport, 1))
l := echoServer(port)
defer l.Close()
var wg sync.WaitGroup
wg.Add(1024)
for i := 0; i < 1024; i++ {
go parallel_client(&wg, port)
}
wg.Wait()
}
func parallel_client(wg *sync.WaitGroup, port int) (err error) {
cli, err := dialEcho(port)
if err != nil {
panic(err)
}
err = echo_tester(cli, 64, 64)
cli.Close()
wg.Done()
return
}
func BenchmarkEchoSpeed4K(b *testing.B) {
speedclient(b, 4096)
}
func BenchmarkEchoSpeed64K(b *testing.B) {
speedclient(b, 65536)
}
func BenchmarkEchoSpeed512K(b *testing.B) {
speedclient(b, 524288)
}
func BenchmarkEchoSpeed1M(b *testing.B) {
speedclient(b, 1048576)
}
func speedclient(b *testing.B, nbytes int) {
port := int(atomic.AddUint32(&baseport, 1))
l := echoServer(port)
defer l.Close()
b.ReportAllocs()
cli, err := dialEcho(port)
if err != nil {
panic(err)
}
if err := echo_tester(cli, nbytes, b.N); err != nil {
b.Fail()
}
b.SetBytes(int64(nbytes))
cli.Close()
}
func BenchmarkSinkSpeed4K(b *testing.B) {
sinkclient(b, 4096)
}
func BenchmarkSinkSpeed64K(b *testing.B) {
sinkclient(b, 65536)
}
func BenchmarkSinkSpeed256K(b *testing.B) {
sinkclient(b, 524288)
}
func BenchmarkSinkSpeed1M(b *testing.B) {
sinkclient(b, 1048576)
}
func sinkclient(b *testing.B, nbytes int) {
port := int(atomic.AddUint32(&baseport, 1))
l := sinkServer(port)
defer l.Close()
b.ReportAllocs()
cli, err := dialSink(port)
if err != nil {
panic(err)
}
sink_tester(cli, nbytes, b.N)
b.SetBytes(int64(nbytes))
cli.Close()
}
func echo_tester(cli net.Conn, msglen, msgcount int) error {
buf := make([]byte, msglen)
for i := 0; i < msgcount; i++ {
// send packet
if _, err := cli.Write(buf); err != nil {
return err
}
// receive packet
nrecv := 0
for {
n, err := cli.Read(buf)
if err != nil {
return err
} else {
nrecv += n
if nrecv == msglen {
break
}
}
}
}
return nil
}
func sink_tester(cli *UDPSession, msglen, msgcount int) error {
// sender
buf := make([]byte, msglen)
for i := 0; i < msgcount; i++ {
if _, err := cli.Write(buf); err != nil {
return err
}
}
return nil
}
func TestSNMP(t *testing.T) {
t.Log(DefaultSnmp.Copy())
t.Log(DefaultSnmp.Header())
t.Log(DefaultSnmp.ToSlice())
DefaultSnmp.Reset()
t.Log(DefaultSnmp.ToSlice())
}
func TestListenerClose(t *testing.T) {
port := int(atomic.AddUint32(&baseport, 1))
l, err := ListenWithOptions(fmt.Sprintf("127.0.0.1:%v", port), nil, 10, 3)
if err != nil {
t.Fail()
}
l.SetReadDeadline(time.Now().Add(time.Second))
l.SetWriteDeadline(time.Now().Add(time.Second))
l.SetDeadline(time.Now().Add(time.Second))
time.Sleep(2 * time.Second)
if _, err := l.Accept(); err == nil {
t.Fail()
}
l.Close()
//fakeaddr, _ := net.ResolveUDPAddr("udp6", "127.0.0.1:1111")
fakeConvId := uint64(0)
if l.closeSession(fakeConvId) {
t.Fail()
}
}
// A wrapper for net.PacketConn that remembers when Close has been called.
type closedFlagPacketConn struct {
net.PacketConn
Closed bool
}
func (c *closedFlagPacketConn) Close() error {
c.Closed = true
return c.PacketConn.Close()
}
func newClosedFlagPacketConn(c net.PacketConn) *closedFlagPacketConn {
return &closedFlagPacketConn{c, false}
}
// Listener should close a net.PacketConn that it created.
// https://github.com/xtaci/kcp-go/issues/165
func TestListenerOwnedPacketConn(t *testing.T) {
// ListenWithOptions creates its own net.PacketConn.
l, err := ListenWithOptions("127.0.0.1:0", nil, 0, 0)
if err != nil {
panic(err)
}
defer l.Close()
// Replace the internal net.PacketConn with one that remembers when it
// has been closed.
pconn := newClosedFlagPacketConn(l.conn)
l.conn = pconn
if pconn.Closed {
t.Fatal("owned PacketConn closed before Listener.Close()")
}
err = l.Close()
if err != nil {
panic(err)
}
if !pconn.Closed {
t.Fatal("owned PacketConn not closed after Listener.Close()")
}
}
// Listener should not close a net.PacketConn that it did not create.
// https://github.com/xtaci/kcp-go/issues/165
func TestListenerNonOwnedPacketConn(t *testing.T) {
// Create a net.PacketConn not owned by the Listener.
c, err := net.ListenPacket("udp", "127.0.0.1:0")
if err != nil {
panic(err)
}
defer c.Close()
// Make it remember when it has been closed.
pconn := newClosedFlagPacketConn(c)
l, err := ServeConn(nil, 0, 0, pconn)
if err != nil {
panic(err)
}
defer l.Close()
if pconn.Closed {
t.Fatal("non-owned PacketConn closed before Listener.Close()")
}
err = l.Close()
if err != nil {
panic(err)
}
if pconn.Closed {
t.Fatal("non-owned PacketConn closed after Listener.Close()")
}
}
// UDPSession should close a net.PacketConn that it created.
// https://github.com/xtaci/kcp-go/issues/165
func TestUDPSessionOwnedPacketConn(t *testing.T) {
l := sinkServer(0)
defer l.Close()
// DialWithOptions creates its own net.PacketConn.
client, err := DialWithOptions(l.Addr().String(), nil, 0, 0)
if err != nil {
panic(err)
}
defer client.Close()
// Replace the internal net.PacketConn with one that remembers when it
// has been closed.
pconn := newClosedFlagPacketConn(client.conn)
client.conn = pconn
if pconn.Closed {
t.Fatal("owned PacketConn closed before UDPSession.Close()")
}
err = client.Close()
if err != nil {
panic(err)
}
if !pconn.Closed {
t.Fatal("owned PacketConn not closed after UDPSession.Close()")
}
}
// UDPSession should not close a net.PacketConn that it did not create.
// https://github.com/xtaci/kcp-go/issues/165
func TestUDPSessionNonOwnedPacketConn(t *testing.T) {
l := sinkServer(0)
defer l.Close()
// Create a net.PacketConn not owned by the UDPSession.
c, err := net.ListenPacket("udp", "127.0.0.1:0")
if err != nil {
panic(err)
}
defer c.Close()
// Make it remember when it has been closed.
pconn := newClosedFlagPacketConn(c)
client, err := NewConn2(l.Addr(), nil, 0, 0, pconn)
if err != nil {
panic(err)
}
defer client.Close()
if pconn.Closed {
t.Fatal("non-owned PacketConn closed before UDPSession.Close()")
}
err = client.Close()
if err != nil {
panic(err)
}
if pconn.Closed {
t.Fatal("non-owned PacketConn closed after UDPSession.Close()")
}
}
+164
View File
@@ -0,0 +1,164 @@
package kcp
import (
"fmt"
"sync/atomic"
)
// Snmp defines network statistics indicator
type Snmp struct {
BytesSent uint64 // bytes sent from upper level
BytesReceived uint64 // bytes received to upper level
MaxConn uint64 // max number of connections ever reached
ActiveOpens uint64 // accumulated active open connections
PassiveOpens uint64 // accumulated passive open connections
CurrEstab uint64 // current number of established connections
InErrs uint64 // UDP read errors reported from net.PacketConn
InCsumErrors uint64 // checksum errors from CRC32
KCPInErrors uint64 // packet iput errors reported from KCP
InPkts uint64 // incoming packets count
OutPkts uint64 // outgoing packets count
InSegs uint64 // incoming KCP segments
OutSegs uint64 // outgoing KCP segments
InBytes uint64 // UDP bytes received
OutBytes uint64 // UDP bytes sent
RetransSegs uint64 // accmulated retransmited segments
FastRetransSegs uint64 // accmulated fast retransmitted segments
EarlyRetransSegs uint64 // accmulated early retransmitted segments
LostSegs uint64 // number of segs inferred as lost
RepeatSegs uint64 // number of segs duplicated
FECRecovered uint64 // correct packets recovered from FEC
FECErrs uint64 // incorrect packets recovered from FEC
FECParityShards uint64 // FEC segments received
FECShortShards uint64 // number of data shards that's not enough for recovery
}
func newSnmp() *Snmp {
return new(Snmp)
}
// Header returns all field names
func (s *Snmp) Header() []string {
return []string{
"BytesSent",
"BytesReceived",
"MaxConn",
"ActiveOpens",
"PassiveOpens",
"CurrEstab",
"InErrs",
"InCsumErrors",
"KCPInErrors",
"InPkts",
"OutPkts",
"InSegs",
"OutSegs",
"InBytes",
"OutBytes",
"RetransSegs",
"FastRetransSegs",
"EarlyRetransSegs",
"LostSegs",
"RepeatSegs",
"FECParityShards",
"FECErrs",
"FECRecovered",
"FECShortShards",
}
}
// ToSlice returns current snmp info as slice
func (s *Snmp) ToSlice() []string {
snmp := s.Copy()
return []string{
fmt.Sprint(snmp.BytesSent),
fmt.Sprint(snmp.BytesReceived),
fmt.Sprint(snmp.MaxConn),
fmt.Sprint(snmp.ActiveOpens),
fmt.Sprint(snmp.PassiveOpens),
fmt.Sprint(snmp.CurrEstab),
fmt.Sprint(snmp.InErrs),
fmt.Sprint(snmp.InCsumErrors),
fmt.Sprint(snmp.KCPInErrors),
fmt.Sprint(snmp.InPkts),
fmt.Sprint(snmp.OutPkts),
fmt.Sprint(snmp.InSegs),
fmt.Sprint(snmp.OutSegs),
fmt.Sprint(snmp.InBytes),
fmt.Sprint(snmp.OutBytes),
fmt.Sprint(snmp.RetransSegs),
fmt.Sprint(snmp.FastRetransSegs),
fmt.Sprint(snmp.EarlyRetransSegs),
fmt.Sprint(snmp.LostSegs),
fmt.Sprint(snmp.RepeatSegs),
fmt.Sprint(snmp.FECParityShards),
fmt.Sprint(snmp.FECErrs),
fmt.Sprint(snmp.FECRecovered),
fmt.Sprint(snmp.FECShortShards),
}
}
// Copy make a copy of current snmp snapshot
func (s *Snmp) Copy() *Snmp {
d := newSnmp()
d.BytesSent = atomic.LoadUint64(&s.BytesSent)
d.BytesReceived = atomic.LoadUint64(&s.BytesReceived)
d.MaxConn = atomic.LoadUint64(&s.MaxConn)
d.ActiveOpens = atomic.LoadUint64(&s.ActiveOpens)
d.PassiveOpens = atomic.LoadUint64(&s.PassiveOpens)
d.CurrEstab = atomic.LoadUint64(&s.CurrEstab)
d.InErrs = atomic.LoadUint64(&s.InErrs)
d.InCsumErrors = atomic.LoadUint64(&s.InCsumErrors)
d.KCPInErrors = atomic.LoadUint64(&s.KCPInErrors)
d.InPkts = atomic.LoadUint64(&s.InPkts)
d.OutPkts = atomic.LoadUint64(&s.OutPkts)
d.InSegs = atomic.LoadUint64(&s.InSegs)
d.OutSegs = atomic.LoadUint64(&s.OutSegs)
d.InBytes = atomic.LoadUint64(&s.InBytes)
d.OutBytes = atomic.LoadUint64(&s.OutBytes)
d.RetransSegs = atomic.LoadUint64(&s.RetransSegs)
d.FastRetransSegs = atomic.LoadUint64(&s.FastRetransSegs)
d.EarlyRetransSegs = atomic.LoadUint64(&s.EarlyRetransSegs)
d.LostSegs = atomic.LoadUint64(&s.LostSegs)
d.RepeatSegs = atomic.LoadUint64(&s.RepeatSegs)
d.FECParityShards = atomic.LoadUint64(&s.FECParityShards)
d.FECErrs = atomic.LoadUint64(&s.FECErrs)
d.FECRecovered = atomic.LoadUint64(&s.FECRecovered)
d.FECShortShards = atomic.LoadUint64(&s.FECShortShards)
return d
}
// Reset values to zero
func (s *Snmp) Reset() {
atomic.StoreUint64(&s.BytesSent, 0)
atomic.StoreUint64(&s.BytesReceived, 0)
atomic.StoreUint64(&s.MaxConn, 0)
atomic.StoreUint64(&s.ActiveOpens, 0)
atomic.StoreUint64(&s.PassiveOpens, 0)
atomic.StoreUint64(&s.CurrEstab, 0)
atomic.StoreUint64(&s.InErrs, 0)
atomic.StoreUint64(&s.InCsumErrors, 0)
atomic.StoreUint64(&s.KCPInErrors, 0)
atomic.StoreUint64(&s.InPkts, 0)
atomic.StoreUint64(&s.OutPkts, 0)
atomic.StoreUint64(&s.InSegs, 0)
atomic.StoreUint64(&s.OutSegs, 0)
atomic.StoreUint64(&s.InBytes, 0)
atomic.StoreUint64(&s.OutBytes, 0)
atomic.StoreUint64(&s.RetransSegs, 0)
atomic.StoreUint64(&s.FastRetransSegs, 0)
atomic.StoreUint64(&s.EarlyRetransSegs, 0)
atomic.StoreUint64(&s.LostSegs, 0)
atomic.StoreUint64(&s.RepeatSegs, 0)
atomic.StoreUint64(&s.FECParityShards, 0)
atomic.StoreUint64(&s.FECErrs, 0)
atomic.StoreUint64(&s.FECRecovered, 0)
atomic.StoreUint64(&s.FECShortShards, 0)
}
// DefaultSnmp is the global KCP connection statistics collector
var DefaultSnmp *Snmp
func init() {
DefaultSnmp = newSnmp()
}
+146
View File
@@ -0,0 +1,146 @@
package kcp
import (
"container/heap"
"runtime"
"sync"
"time"
)
// SystemTimedSched is the library level timed-scheduler
var SystemTimedSched = NewTimedSched(runtime.NumCPU())
type timedFunc struct {
execute func()
ts time.Time
}
// a heap for sorted timed function
type timedFuncHeap []timedFunc
func (h timedFuncHeap) Len() int { return len(h) }
func (h timedFuncHeap) Less(i, j int) bool { return h[i].ts.Before(h[j].ts) }
func (h timedFuncHeap) Swap(i, j int) { h[i], h[j] = h[j], h[i] }
func (h *timedFuncHeap) Push(x interface{}) { *h = append(*h, x.(timedFunc)) }
func (h *timedFuncHeap) Pop() interface{} {
old := *h
n := len(old)
x := old[n-1]
old[n-1].execute = nil // avoid memory leak
*h = old[0 : n-1]
return x
}
// TimedSched represents the control struct for timed parallel scheduler
type TimedSched struct {
// prepending tasks
prependTasks []timedFunc
prependLock sync.Mutex
chPrependNotify chan struct{}
// tasks will be distributed through chTask
chTask chan timedFunc
dieOnce sync.Once
die chan struct{}
}
// NewTimedSched creates a parallel-scheduler with given parallelization
func NewTimedSched(parallel int) *TimedSched {
ts := new(TimedSched)
ts.chTask = make(chan timedFunc)
ts.die = make(chan struct{})
ts.chPrependNotify = make(chan struct{}, 1)
for i := 0; i < parallel; i++ {
go ts.sched()
}
go ts.prepend()
return ts
}
func (ts *TimedSched) sched() {
var tasks timedFuncHeap
timer := time.NewTimer(0)
drained := false
for {
select {
case task := <-ts.chTask:
now := time.Now()
if now.After(task.ts) {
// already delayed! execute immediately
task.execute()
} else {
heap.Push(&tasks, task)
// properly reset timer to trigger based on the top element
stopped := timer.Stop()
if !stopped && !drained {
<-timer.C
}
timer.Reset(tasks[0].ts.Sub(now))
drained = false
}
case now := <-timer.C:
drained = true
for tasks.Len() > 0 {
if now.After(tasks[0].ts) {
heap.Pop(&tasks).(timedFunc).execute()
} else {
timer.Reset(tasks[0].ts.Sub(now))
drained = false
break
}
}
case <-ts.die:
return
}
}
}
func (ts *TimedSched) prepend() {
var tasks []timedFunc
for {
select {
case <-ts.chPrependNotify:
ts.prependLock.Lock()
// keep cap to reuse slice
if cap(tasks) < cap(ts.prependTasks) {
tasks = make([]timedFunc, 0, cap(ts.prependTasks))
}
tasks = tasks[:len(ts.prependTasks)]
copy(tasks, ts.prependTasks)
for k := range ts.prependTasks {
ts.prependTasks[k].execute = nil // avoid memory leak
}
ts.prependTasks = ts.prependTasks[:0]
ts.prependLock.Unlock()
for k := range tasks {
select {
case ts.chTask <- tasks[k]:
tasks[k].execute = nil // avoid memory leak
case <-ts.die:
return
}
}
tasks = tasks[:0]
case <-ts.die:
return
}
}
}
// Put a function 'f' awaiting to be executed at 'deadline'
func (ts *TimedSched) Put(f func(), deadline time.Time) {
ts.prependLock.Lock()
ts.prependTasks = append(ts.prependTasks, timedFunc{f, deadline})
ts.prependLock.Unlock()
select {
case ts.chPrependNotify <- struct{}{}:
default:
}
}
// Close terminates this scheduler
func (ts *TimedSched) Close() { ts.dieOnce.Do(func() { close(ts.die) }) }
+80
View File
@@ -0,0 +1,80 @@
package kcp
import (
"net"
"sync/atomic"
"github.com/pkg/errors"
"golang.org/x/net/ipv4"
)
func buildEnet(connType uint8, enetType uint32, conv uint64) []byte {
data := make([]byte, 20)
if connType == ConnEnetSyn {
copy(data[0:4], MagicEnetSynHead)
copy(data[16:20], MagicEnetSynTail)
} else if connType == ConnEnetEst {
copy(data[0:4], MagicEnetEstHead)
copy(data[16:20], MagicEnetEstTail)
} else if connType == ConnEnetFin {
copy(data[0:4], MagicEnetFinHead)
copy(data[16:20], MagicEnetFinTail)
} else {
return nil
}
// conv的高四个字节和低四个字节分开
// 例如 00 00 01 45 | LL LL LL LL | HH HH HH HH | 49 96 02 d2 | 14 51 45 45
data[4] = uint8(conv >> 24)
data[5] = uint8(conv >> 16)
data[6] = uint8(conv >> 8)
data[7] = uint8(conv >> 0)
data[8] = uint8(conv >> 56)
data[9] = uint8(conv >> 48)
data[10] = uint8(conv >> 40)
data[11] = uint8(conv >> 32)
// Enet
data[12] = uint8(enetType >> 24)
data[13] = uint8(enetType >> 16)
data[14] = uint8(enetType >> 8)
data[15] = uint8(enetType >> 0)
return data
}
func (l *Listener) defaultSendEnetNotifyToClient(enet *Enet) {
remoteAddr, err := net.ResolveUDPAddr("udp", enet.Addr)
if err != nil {
return
}
data := buildEnet(enet.ConnType, enet.EnetType, enet.ConvId)
if data == nil {
return
}
_, _ = l.conn.WriteTo(data, remoteAddr)
}
func (s *UDPSession) defaultSendEnetNotify(enet *Enet) {
data := buildEnet(enet.ConnType, enet.EnetType, s.GetConv())
if data == nil {
return
}
s.defaultTx([]ipv4.Message{{
Buffers: [][]byte{data},
Addr: s.remote,
}})
}
func (s *UDPSession) defaultTx(txqueue []ipv4.Message) {
nbytes := 0
npkts := 0
for k := range txqueue {
if n, err := s.conn.WriteTo(txqueue[k].Buffers[0], txqueue[k].Addr); err == nil {
nbytes += n
npkts++
} else {
s.notifyWriteError(errors.WithStack(err))
break
}
}
atomic.AddUint64(&DefaultSnmp.OutPkts, uint64(npkts))
atomic.AddUint64(&DefaultSnmp.OutBytes, uint64(nbytes))
}
+20
View File
@@ -0,0 +1,20 @@
//go:build !linux
// +build !linux
package kcp
import (
"golang.org/x/net/ipv4"
)
func (l *Listener) SendEnetNotifyToClient(enet *Enet) {
l.defaultSendEnetNotifyToClient(enet)
}
func (s *UDPSession) SendEnetNotify(enet *Enet) {
s.defaultSendEnetNotify(enet)
}
func (s *UDPSession) tx(txqueue []ipv4.Message) {
s.defaultTx(txqueue)
}
+102
View File
@@ -0,0 +1,102 @@
//go:build linux
// +build linux
package kcp
import (
"golang.org/x/net/ipv6"
"net"
"os"
"sync/atomic"
"github.com/pkg/errors"
"golang.org/x/net/ipv4"
)
func (l *Listener) SendEnetNotifyToClient(enet *Enet) {
var xconn batchConn
_, ok := l.conn.(*net.UDPConn)
if !ok {
return
}
localAddr, err := net.ResolveUDPAddr("udp", l.conn.LocalAddr().String())
if err != nil {
return
}
if localAddr.IP.To4() != nil {
xconn = ipv4.NewPacketConn(l.conn)
} else {
xconn = ipv6.NewPacketConn(l.conn)
}
// default version
if xconn == nil {
l.defaultSendEnetNotifyToClient(enet)
return
}
remoteAddr, err := net.ResolveUDPAddr("udp", enet.Addr)
if err != nil {
return
}
data := buildEnet(enet.ConnType, enet.EnetType, enet.ConvId)
if data == nil {
return
}
_, _ = xconn.WriteBatch([]ipv4.Message{{
Buffers: [][]byte{data},
Addr: remoteAddr,
}}, 0)
}
func (s *UDPSession) SendEnetNotify(enet *Enet) {
data := buildEnet(enet.ConnType, enet.EnetType, s.GetConv())
if data == nil {
return
}
s.tx([]ipv4.Message{{
Buffers: [][]byte{data},
Addr: s.remote,
}})
}
func (s *UDPSession) tx(txqueue []ipv4.Message) {
// default version
if s.xconn == nil || s.xconnWriteError != nil {
s.defaultTx(txqueue)
return
}
// x/net version
nbytes := 0
npkts := 0
for len(txqueue) > 0 {
if n, err := s.xconn.WriteBatch(txqueue, 0); err == nil {
for k := range txqueue[:n] {
nbytes += len(txqueue[k].Buffers[0])
}
npkts += n
txqueue = txqueue[n:]
} else {
// compatibility issue:
// for linux kernel<=2.6.32, support for sendmmsg is not available
// an error of type os.SyscallError will be returned
if operr, ok := err.(*net.OpError); ok {
if se, ok := operr.Err.(*os.SyscallError); ok {
if se.Syscall == "sendmmsg" {
s.xconnWriteError = se
s.defaultTx(txqueue)
return
}
}
}
s.notifyWriteError(errors.WithStack(err))
break
}
}
atomic.AddUint64(&DefaultSnmp.OutPkts, uint64(npkts))
atomic.AddUint64(&DefaultSnmp.OutBytes, uint64(nbytes))
}
+97
View File
@@ -0,0 +1,97 @@
package mq
import (
"github.com/nats-io/nats.go"
"github.com/vmihailenco/msgpack/v5"
pb "google.golang.org/protobuf/proto"
"hk4e/common/config"
"hk4e/logger"
"hk4e/protocol/cmd"
)
type MessageQueue struct {
natsConn *nats.Conn
natsMsgChan chan *nats.Msg
netMsgInput chan *cmd.NetMsg
netMsgOutput chan *cmd.NetMsg
cmdProtoMap *cmd.CmdProtoMap
}
func NewMessageQueue(netMsgInput chan *cmd.NetMsg, netMsgOutput chan *cmd.NetMsg) (r *MessageQueue) {
r = new(MessageQueue)
conn, err := nats.Connect(config.CONF.MQ.NatsUrl)
if err != nil {
logger.LOG.Error("connect nats error: %v", err)
return nil
}
r.natsConn = conn
r.natsMsgChan = make(chan *nats.Msg, 10000)
_, err = r.natsConn.ChanSubscribe("GATE_HK4E", r.natsMsgChan)
if err != nil {
logger.LOG.Error("nats subscribe error: %v", err)
return nil
}
r.netMsgInput = netMsgInput
r.netMsgOutput = netMsgOutput
r.cmdProtoMap = cmd.NewCmdProtoMap()
return r
}
func (m *MessageQueue) Start() {
go m.startRecvHandler()
go m.startSendHandler()
}
func (m *MessageQueue) Close() {
m.natsConn.Close()
}
func (m *MessageQueue) startRecvHandler() {
for {
natsMsg := <-m.natsMsgChan
// msgpack NetMsg
netMsg := new(cmd.NetMsg)
err := msgpack.Unmarshal(natsMsg.Data, netMsg)
if err != nil {
logger.LOG.Error("parse bin to net msg error: %v", err)
continue
}
if netMsg.EventId == cmd.NormalMsg {
// protobuf PayloadMessage
payloadMessage := m.cmdProtoMap.GetProtoObjByCmdId(netMsg.CmdId)
err = pb.Unmarshal(netMsg.PayloadMessageData, payloadMessage)
if err != nil {
logger.LOG.Error("parse bin to payload msg error: %v", err)
continue
}
netMsg.PayloadMessage = payloadMessage
}
m.netMsgOutput <- netMsg
}
}
func (m *MessageQueue) startSendHandler() {
for {
netMsg := <-m.netMsgInput
// protobuf PayloadMessage
payloadMessageData, err := pb.Marshal(netMsg.PayloadMessage)
if err != nil {
logger.LOG.Error("parse payload msg to bin error: %v", err)
continue
}
netMsg.PayloadMessageData = payloadMessageData
// msgpack NetMsg
netMsgData, err := msgpack.Marshal(netMsg)
if err != nil {
logger.LOG.Error("parse net msg to bin error: %v", err)
continue
}
natsMsg := nats.NewMsg("GS_HK4E")
natsMsg.Data = netMsgData
err = m.natsConn.PublishMsg(natsMsg)
if err != nil {
logger.LOG.Error("nats publish msg error: %v", err)
continue
}
}
}
+387
View File
@@ -0,0 +1,387 @@
package net
import (
"bytes"
"encoding/binary"
"hk4e/common/config"
"hk4e/common/utils/random"
"hk4e/gate/kcp"
"hk4e/logger"
"os"
"strconv"
"sync"
"time"
)
type KcpXorKey struct {
encKey []byte
decKey []byte
}
type KcpConnectManager struct {
openState bool
connMap map[uint64]*kcp.UDPSession
connMapLock sync.RWMutex
protoMsgInput chan *ProtoMsg
protoMsgOutput chan *ProtoMsg
kcpEventInput chan *KcpEvent
kcpEventOutput chan *KcpEvent
// 发送协程分发
kcpRawSendChanMap map[uint64]chan *ProtoMsg
kcpRawSendChanMapLock sync.RWMutex
// 收包发包监听标志
kcpRecvListenMap map[uint64]bool
kcpRecvListenMapLock sync.RWMutex
kcpSendListenMap map[uint64]bool
kcpSendListenMapLock sync.RWMutex
// key
dispatchKey []byte
secretKey []byte
kcpKeyMap map[uint64]*KcpXorKey
kcpKeyMapLock sync.RWMutex
// conv短时间内唯一生成
convGenMap map[uint64]int64
convGenMapLock sync.RWMutex
}
func NewKcpConnectManager(protoMsgInput chan *ProtoMsg, protoMsgOutput chan *ProtoMsg,
kcpEventInput chan *KcpEvent, kcpEventOutput chan *KcpEvent) (r *KcpConnectManager) {
r = new(KcpConnectManager)
r.openState = true
r.connMap = make(map[uint64]*kcp.UDPSession)
r.protoMsgInput = protoMsgInput
r.protoMsgOutput = protoMsgOutput
r.kcpEventInput = kcpEventInput
r.kcpEventOutput = kcpEventOutput
r.kcpRawSendChanMap = make(map[uint64]chan *ProtoMsg)
r.kcpRecvListenMap = make(map[uint64]bool)
r.kcpSendListenMap = make(map[uint64]bool)
r.kcpKeyMap = make(map[uint64]*KcpXorKey)
r.convGenMap = make(map[uint64]int64)
return r
}
func (k *KcpConnectManager) Start() {
go func() {
// key
var err error = nil
k.dispatchKey, err = os.ReadFile("static/dispatchKey.bin")
if err != nil {
logger.LOG.Error("open dispatchKey.bin error")
return
}
k.secretKey, err = os.ReadFile("static/secretKey.bin")
if err != nil {
logger.LOG.Error("open secretKey.bin error")
return
}
// kcp
port := strconv.FormatInt(int64(config.CONF.Hk4e.KcpPort), 10)
listener, err := kcp.ListenWithOptions("0.0.0.0:"+port, nil, 0, 0)
if err != nil {
logger.LOG.Error("listen kcp err: %v", err)
return
} else {
go k.enetHandle(listener)
go k.chanSendHandle()
go k.eventHandle()
for {
conn, err := listener.AcceptKCP()
if err != nil {
logger.LOG.Error("accept kcp err: %v", err)
return
}
if k.openState == false {
_ = conn.Close()
continue
}
conn.SetACKNoDelay(true)
conn.SetWriteDelay(false)
convId := conn.GetConv()
logger.LOG.Debug("client connect, convId: %v", convId)
// 连接建立成功通知
k.kcpEventOutput <- &KcpEvent{
ConvId: convId,
EventId: KcpConnEstNotify,
EventMessage: conn.RemoteAddr().String(),
}
k.connMapLock.Lock()
k.connMap[convId] = conn
k.connMapLock.Unlock()
k.kcpKeyMapLock.Lock()
k.kcpKeyMap[convId] = &KcpXorKey{
encKey: k.dispatchKey,
decKey: k.dispatchKey,
}
k.kcpKeyMapLock.Unlock()
go k.recvHandle(convId)
kcpRawSendChan := make(chan *ProtoMsg, 10000)
k.kcpRawSendChanMapLock.Lock()
k.kcpRawSendChanMap[convId] = kcpRawSendChan
k.kcpRawSendChanMapLock.Unlock()
go k.sendHandle(convId, kcpRawSendChan)
go k.rttMonitor(convId)
}
}
}()
go k.clearDeadConv()
}
func (k *KcpConnectManager) clearDeadConv() {
ticker := time.NewTicker(time.Minute)
for {
k.convGenMapLock.Lock()
now := time.Now().UnixNano()
oldConvList := make([]uint64, 0)
for conv, timestamp := range k.convGenMap {
if now-timestamp > int64(time.Hour) {
oldConvList = append(oldConvList, conv)
}
}
delConvList := make([]uint64, 0)
k.connMapLock.RLock()
for _, conv := range oldConvList {
_, exist := k.connMap[conv]
if !exist {
delConvList = append(delConvList, conv)
delete(k.convGenMap, conv)
}
}
k.connMapLock.RUnlock()
k.convGenMapLock.Unlock()
logger.LOG.Info("clean dead conv list: %v", delConvList)
<-ticker.C
}
}
func (k *KcpConnectManager) enetHandle(listener *kcp.Listener) {
for {
enetNotify := <-listener.EnetNotify
logger.LOG.Info("[Enet Notify], addr: %v, conv: %v, conn: %v, enet: %v", enetNotify.Addr, enetNotify.ConvId, enetNotify.ConnType, enetNotify.EnetType)
switch enetNotify.ConnType {
case kcp.ConnEnetSyn:
if enetNotify.EnetType == kcp.EnetClientConnectKey {
var conv uint64
k.convGenMapLock.Lock()
for {
convData := random.GetRandomByte(8)
convDataBuffer := bytes.NewBuffer(convData)
_ = binary.Read(convDataBuffer, binary.LittleEndian, &conv)
_, exist := k.convGenMap[conv]
if exist {
continue
} else {
k.convGenMap[conv] = time.Now().UnixNano()
break
}
}
k.convGenMapLock.Unlock()
listener.SendEnetNotifyToClient(&kcp.Enet{
Addr: enetNotify.Addr,
ConvId: conv,
ConnType: kcp.ConnEnetEst,
EnetType: enetNotify.EnetType,
})
}
case kcp.ConnEnetEst:
case kcp.ConnEnetFin:
k.closeKcpConn(enetNotify.ConvId, enetNotify.EnetType)
case kcp.ConnEnetAddrChange:
// 连接地址改变通知
k.kcpEventOutput <- &KcpEvent{
ConvId: enetNotify.ConvId,
EventId: KcpConnAddrChangeNotify,
EventMessage: enetNotify.Addr,
}
default:
}
}
}
func (k *KcpConnectManager) chanSendHandle() {
// 分发到每个连接具体的发送协程
for {
protoMsg := <-k.protoMsgInput
k.kcpRawSendChanMapLock.RLock()
kcpRawSendChan := k.kcpRawSendChanMap[protoMsg.ConvId]
k.kcpRawSendChanMapLock.RUnlock()
if kcpRawSendChan != nil {
select {
case kcpRawSendChan <- protoMsg:
default:
logger.LOG.Error("kcpRawSendChan is full, convId: %v", protoMsg.ConvId)
}
} else {
logger.LOG.Error("kcpRawSendChan is nil, convId: %v", protoMsg.ConvId)
}
}
}
func (k *KcpConnectManager) recvHandle(convId uint64) {
// 接收
k.connMapLock.RLock()
conn := k.connMap[convId]
k.connMapLock.RUnlock()
pktFreqLimitCounter := 0
pktFreqLimitTimer := time.Now().UnixNano()
protoEnDecode := NewProtoEnDecode()
recvBuf := make([]byte, conn.GetMaxPayloadLen())
for {
_ = conn.SetReadDeadline(time.Now().Add(time.Second * 30))
recvLen, err := conn.Read(recvBuf)
if err != nil {
logger.LOG.Error("exit recv loop, conn read err: %v, convId: %v", err, convId)
k.closeKcpConn(convId, kcp.EnetServerKick)
break
}
pktFreqLimitCounter++
now := time.Now().UnixNano()
if now-pktFreqLimitTimer > int64(time.Second) {
if pktFreqLimitCounter > 1000 {
logger.LOG.Error("exit recv loop, client packet send freq too high, convId: %v, pps: %v", convId, pktFreqLimitCounter)
k.closeKcpConn(convId, kcp.EnetPacketFreqTooHigh)
break
} else {
pktFreqLimitCounter = 0
}
pktFreqLimitTimer = now
}
recvData := recvBuf[:recvLen]
k.kcpRecvListenMapLock.RLock()
flag := k.kcpRecvListenMap[convId]
k.kcpRecvListenMapLock.RUnlock()
if flag {
// 收包通知
//recvMsg := make([]byte, len(recvData))
//copy(recvMsg, recvData)
k.kcpEventOutput <- &KcpEvent{
ConvId: convId,
EventId: KcpPacketRecvNotify,
EventMessage: recvData,
}
}
kcpMsgList := make([]*KcpMsg, 0)
k.decodeBinToPayload(recvData, convId, &kcpMsgList)
for _, v := range kcpMsgList {
protoMsgList := protoEnDecode.protoDecode(v)
for _, vv := range protoMsgList {
k.protoMsgOutput <- vv
}
}
}
}
func (k *KcpConnectManager) sendHandle(convId uint64, kcpRawSendChan chan *ProtoMsg) {
// 发送
k.connMapLock.RLock()
conn := k.connMap[convId]
k.connMapLock.RUnlock()
protoEnDecode := NewProtoEnDecode()
for {
protoMsg, ok := <-kcpRawSendChan
if !ok {
logger.LOG.Error("exit send loop, send chan close, convId: %v", convId)
k.closeKcpConn(convId, kcp.EnetServerKick)
break
}
kcpMsg := protoEnDecode.protoEncode(protoMsg)
if kcpMsg == nil {
logger.LOG.Error("decode kcp msg is nil, convId: %v", convId)
continue
}
bin := k.encodePayloadToBin(kcpMsg)
_ = conn.SetWriteDeadline(time.Now().Add(time.Second * 10))
_, err := conn.Write(bin)
if err != nil {
logger.LOG.Error("exit send loop, conn write err: %v, convId: %v", err, convId)
k.closeKcpConn(convId, kcp.EnetServerKick)
break
}
k.kcpSendListenMapLock.RLock()
flag := k.kcpSendListenMap[convId]
k.kcpSendListenMapLock.RUnlock()
if flag {
// 发包通知
k.kcpEventOutput <- &KcpEvent{
ConvId: convId,
EventId: KcpPacketSendNotify,
EventMessage: bin,
}
}
}
}
func (k *KcpConnectManager) rttMonitor(convId uint64) {
ticker := time.NewTicker(time.Second * 10)
for {
select {
case <-ticker.C:
k.connMapLock.RLock()
conn := k.connMap[convId]
k.connMapLock.RUnlock()
if conn == nil {
break
}
logger.LOG.Debug("convId: %v, RTO: %v, SRTT: %v, RTTVar: %v", convId, conn.GetRTO(), conn.GetSRTT(), conn.GetSRTTVar())
k.kcpEventOutput <- &KcpEvent{
ConvId: convId,
EventId: KcpConnRttNotify,
EventMessage: conn.GetSRTT(),
}
}
}
}
func (k *KcpConnectManager) closeKcpConn(convId uint64, enetType uint32) {
k.connMapLock.RLock()
conn, exist := k.connMap[convId]
k.connMapLock.RUnlock()
if !exist {
return
}
// 获取待关闭的发送管道
k.kcpRawSendChanMapLock.RLock()
kcpRawSendChan := k.kcpRawSendChanMap[convId]
k.kcpRawSendChanMapLock.RUnlock()
// 清理数据
k.connMapLock.Lock()
delete(k.connMap, convId)
k.connMapLock.Unlock()
k.kcpRawSendChanMapLock.Lock()
delete(k.kcpRawSendChanMap, convId)
k.kcpRawSendChanMapLock.Unlock()
k.kcpRecvListenMapLock.Lock()
delete(k.kcpRecvListenMap, convId)
k.kcpRecvListenMapLock.Unlock()
k.kcpSendListenMapLock.Lock()
delete(k.kcpSendListenMap, convId)
k.kcpSendListenMapLock.Unlock()
k.kcpKeyMapLock.Lock()
delete(k.kcpKeyMap, convId)
k.kcpKeyMapLock.Unlock()
// 关闭连接
conn.SendEnetNotify(&kcp.Enet{
ConnType: kcp.ConnEnetFin,
EnetType: enetType,
})
_ = conn.Close()
// 关闭发送管道
close(kcpRawSendChan)
// 连接关闭通知
k.kcpEventOutput <- &KcpEvent{
ConvId: convId,
EventId: KcpConnCloseNotify,
}
}
func (k *KcpConnectManager) closeAllKcpConn() {
closeConnList := make([]*kcp.UDPSession, 0)
k.connMapLock.RLock()
for _, v := range k.connMap {
closeConnList = append(closeConnList, v)
}
k.connMapLock.RUnlock()
for _, v := range closeConnList {
k.closeKcpConn(v.GetConv(), kcp.EnetServerShutdown)
}
}
+187
View File
@@ -0,0 +1,187 @@
package net
import (
"bytes"
"encoding/binary"
"hk4e/common/utils/endec"
"hk4e/logger"
)
/*
原神KCP协议(带*为xor加密数据)
0 1 2 4 8(字节)
+---------------------------------------------------------------------------------------+
| conv |
+---------------------------------------------------------------------------------------+
| cmd | frg | wnd | ts |
+---------------------------------------------------------------------------------------+
| sn | una |
+---------------------------------------------------------------------------------------+
| len | 0X4567* | cmdId* |
+---------------------------------------------------------------------------------------+
| headLen* | payloadLen* | head* |
+---------------------------------------------------------------------------------------+
| payload* | 0X89AB* |
+---------------------------------------------------------------------------------------+
*/
type KcpMsg struct {
ConvId uint64
CmdId uint16
HeadData []byte
ProtoData []byte
}
func (k *KcpConnectManager) decodeBinToPayload(data []byte, convId uint64, kcpMsgList *[]*KcpMsg) {
// xor解密
k.kcpKeyMapLock.RLock()
xorKey, exist := k.kcpKeyMap[convId]
k.kcpKeyMapLock.RUnlock()
if !exist {
logger.LOG.Error("kcp xor key not exist, convId: %v", convId)
return
}
endec.Xor(data, xorKey.decKey)
k.decodeRecur(data, convId, kcpMsgList)
}
func (k *KcpConnectManager) decodeRecur(data []byte, convId uint64, kcpMsgList *[]*KcpMsg) {
// 长度太短
if len(data) < 12 {
logger.LOG.Debug("packet len less 12 byte")
return
}
// 头部标志错误
if data[0] != 0x45 || data[1] != 0x67 {
logger.LOG.Error("packet head magic 0x4567 error")
return
}
// 协议号
cmdIdByteSlice := make([]byte, 8)
cmdIdByteSlice[6] = data[2]
cmdIdByteSlice[7] = data[3]
cmdIdBuffer := bytes.NewBuffer(cmdIdByteSlice)
var cmdId int64
err := binary.Read(cmdIdBuffer, binary.BigEndian, &cmdId)
if err != nil {
logger.LOG.Error("packet cmd id parse fail: %v", err)
return
}
// 头部长度
headLenByteSlice := make([]byte, 8)
headLenByteSlice[6] = data[4]
headLenByteSlice[7] = data[5]
headLenBuffer := bytes.NewBuffer(headLenByteSlice)
var headLen int64
err = binary.Read(headLenBuffer, binary.BigEndian, &headLen)
if err != nil {
logger.LOG.Error("packet head len parse fail: %v", err)
return
}
// proto长度
protoLenByteSlice := make([]byte, 8)
protoLenByteSlice[4] = data[6]
protoLenByteSlice[5] = data[7]
protoLenByteSlice[6] = data[8]
protoLenByteSlice[7] = data[9]
protoLenBuffer := bytes.NewBuffer(protoLenByteSlice)
var protoLen int64
err = binary.Read(protoLenBuffer, binary.BigEndian, &protoLen)
if err != nil {
logger.LOG.Error("packet proto len parse fail: %v", err)
return
}
// 检查最小长度
if len(data) < int(headLen+protoLen)+12 {
logger.LOG.Error("packet len error")
return
}
// 尾部标志错误
if data[headLen+protoLen+10] != 0x89 || data[headLen+protoLen+11] != 0xAB {
logger.LOG.Error("packet tail magic 0x89AB error")
return
}
// 判断是否有不止一个包
haveMoreData := false
if len(data) > int(headLen+protoLen)+12 {
haveMoreData = true
}
// 头部数据
headData := data[10 : 10+headLen]
// proto数据
protoData := data[10+headLen : 10+headLen+protoLen]
// 返回数据
kcpMsg := new(KcpMsg)
kcpMsg.ConvId = convId
kcpMsg.CmdId = uint16(cmdId)
//kcpMsg.HeadData = make([]byte, len(headData))
//copy(kcpMsg.HeadData, headData)
//kcpMsg.ProtoData = make([]byte, len(protoData))
//copy(kcpMsg.ProtoData, protoData)
kcpMsg.HeadData = headData
kcpMsg.ProtoData = protoData
*kcpMsgList = append(*kcpMsgList, kcpMsg)
// 递归解析
if haveMoreData {
k.decodeRecur(data[int(headLen+protoLen)+12:], convId, kcpMsgList)
}
}
func (k *KcpConnectManager) encodePayloadToBin(kcpMsg *KcpMsg) (bin []byte) {
if kcpMsg.HeadData == nil {
kcpMsg.HeadData = make([]byte, 0)
}
if kcpMsg.ProtoData == nil {
kcpMsg.ProtoData = make([]byte, 0)
}
bin = make([]byte, len(kcpMsg.HeadData)+len(kcpMsg.ProtoData)+12)
// 头部标志
bin[0] = 0x45
bin[1] = 0x67
// 协议号
cmdIdBuffer := bytes.NewBuffer([]byte{})
err := binary.Write(cmdIdBuffer, binary.BigEndian, kcpMsg.CmdId)
if err != nil {
logger.LOG.Error("cmd id encode err: %v", err)
return nil
}
bin[2] = (cmdIdBuffer.Bytes())[0]
bin[3] = (cmdIdBuffer.Bytes())[1]
// 头部长度
headLenBuffer := bytes.NewBuffer([]byte{})
err = binary.Write(headLenBuffer, binary.BigEndian, uint16(len(kcpMsg.HeadData)))
if err != nil {
logger.LOG.Error("head len encode err: %v", err)
return nil
}
bin[4] = (headLenBuffer.Bytes())[0]
bin[5] = (headLenBuffer.Bytes())[1]
// proto长度
protoLenBuffer := bytes.NewBuffer([]byte{})
err = binary.Write(protoLenBuffer, binary.BigEndian, uint32(len(kcpMsg.ProtoData)))
if err != nil {
logger.LOG.Error("proto len encode err: %v", err)
return nil
}
bin[6] = (protoLenBuffer.Bytes())[0]
bin[7] = (protoLenBuffer.Bytes())[1]
bin[8] = (protoLenBuffer.Bytes())[2]
bin[9] = (protoLenBuffer.Bytes())[3]
// 头部数据
copy(bin[10:], kcpMsg.HeadData)
// proto数据
copy(bin[10+len(kcpMsg.HeadData):], kcpMsg.ProtoData)
// 尾部标志
bin[len(bin)-2] = 0x89
bin[len(bin)-1] = 0xAB
// xor加密
k.kcpKeyMapLock.RLock()
xorKey, exist := k.kcpKeyMap[kcpMsg.ConvId]
k.kcpKeyMapLock.RUnlock()
if !exist {
logger.LOG.Error("kcp xor key not exist, convId: %v", kcpMsg.ConvId)
return
}
endec.Xor(bin, xorKey.encKey)
return bin
}
+134
View File
@@ -0,0 +1,134 @@
package net
import "hk4e/logger"
const (
KcpXorKeyChange = iota
KcpPacketRecvListen
KcpPacketSendListen
KcpConnForceClose
KcpAllConnForceClose
KcpGateOpenState
KcpPacketRecvNotify
KcpPacketSendNotify
KcpConnCloseNotify
KcpConnEstNotify
KcpConnRttNotify
KcpConnAddrChangeNotify
)
type KcpEvent struct {
ConvId uint64
EventId int
EventMessage any
}
func (k *KcpConnectManager) eventHandle() {
// 事件处理
for {
event := <-k.kcpEventInput
logger.LOG.Info("kcp manager recv event, ConvId: %v, EventId: %v, EventMessage: %v", event.ConvId, event.EventId, event.EventMessage)
switch event.EventId {
case KcpXorKeyChange:
// XOR密钥切换
k.connMapLock.RLock()
_, exist := k.connMap[event.ConvId]
k.connMapLock.RUnlock()
if !exist {
logger.LOG.Error("conn not exist, convId: %v", event.ConvId)
continue
}
flag, ok := event.EventMessage.(string)
if !ok {
logger.LOG.Error("event KcpXorKeyChange msg type error")
continue
}
if flag == "ENC" {
k.kcpKeyMapLock.Lock()
k.kcpKeyMap[event.ConvId].encKey = k.secretKey
k.kcpKeyMapLock.Unlock()
} else if flag == "DEC" {
k.kcpKeyMapLock.Lock()
k.kcpKeyMap[event.ConvId].decKey = k.secretKey
k.kcpKeyMapLock.Unlock()
}
case KcpPacketRecvListen:
// 收包监听
k.connMapLock.RLock()
_, exist := k.connMap[event.ConvId]
k.connMapLock.RUnlock()
if !exist {
logger.LOG.Error("conn not exist, convId: %v", event.ConvId)
continue
}
flag, ok := event.EventMessage.(string)
if !ok {
logger.LOG.Error("event KcpXorKeyChange msg type error")
continue
}
if flag == "Enable" {
k.kcpRecvListenMapLock.Lock()
k.kcpRecvListenMap[event.ConvId] = true
k.kcpRecvListenMapLock.Unlock()
} else if flag == "Disable" {
k.kcpRecvListenMapLock.Lock()
k.kcpRecvListenMap[event.ConvId] = false
k.kcpRecvListenMapLock.Unlock()
}
case KcpPacketSendListen:
// 发包监听
k.connMapLock.RLock()
_, exist := k.connMap[event.ConvId]
k.connMapLock.RUnlock()
if !exist {
logger.LOG.Error("conn not exist, convId: %v", event.ConvId)
continue
}
flag, ok := event.EventMessage.(string)
if !ok {
logger.LOG.Error("event KcpXorKeyChange msg type error")
continue
}
if flag == "Enable" {
k.kcpSendListenMapLock.Lock()
k.kcpSendListenMap[event.ConvId] = true
k.kcpSendListenMapLock.Unlock()
} else if flag == "Disable" {
k.kcpSendListenMapLock.Lock()
k.kcpSendListenMap[event.ConvId] = false
k.kcpSendListenMapLock.Unlock()
}
case KcpConnForceClose:
// 强制关闭某个连接
k.connMapLock.RLock()
_, exist := k.connMap[event.ConvId]
k.connMapLock.RUnlock()
if !exist {
logger.LOG.Error("conn not exist, convId: %v", event.ConvId)
continue
}
reason, ok := event.EventMessage.(uint32)
if !ok {
logger.LOG.Error("event KcpConnForceClose msg type error")
continue
}
k.closeKcpConn(event.ConvId, reason)
logger.LOG.Info("conn has been force close, convId: %v", event.ConvId)
case KcpAllConnForceClose:
// 强制关闭所有连接
k.closeAllKcpConn()
logger.LOG.Info("all conn has been force close")
case KcpGateOpenState:
// 改变网关开放状态
openState, ok := event.EventMessage.(bool)
if !ok {
logger.LOG.Error("event KcpGateOpenState msg type error")
continue
}
k.openState = openState
if openState == false {
k.closeAllKcpConn()
}
}
}
}
+161
View File
@@ -0,0 +1,161 @@
package net
import (
pb "google.golang.org/protobuf/proto"
"hk4e/logger"
"hk4e/protocol/cmd"
"hk4e/protocol/proto"
)
type ProtoEnDecode struct {
cmdProtoMap *cmd.CmdProtoMap
}
func NewProtoEnDecode() (r *ProtoEnDecode) {
r = new(ProtoEnDecode)
r.cmdProtoMap = cmd.NewCmdProtoMap()
return r
}
type ProtoMsg struct {
ConvId uint64
CmdId uint16
HeadMessage *proto.PacketHead
PayloadMessage pb.Message
}
type ProtoMessage struct {
cmdId uint16
message pb.Message
}
func (p *ProtoEnDecode) protoDecode(kcpMsg *KcpMsg) (protoMsgList []*ProtoMsg) {
protoMsgList = make([]*ProtoMsg, 0)
protoMsg := new(ProtoMsg)
protoMsg.ConvId = kcpMsg.ConvId
protoMsg.CmdId = kcpMsg.CmdId
// head msg
if kcpMsg.HeadData != nil && len(kcpMsg.HeadData) != 0 {
headMsg := new(proto.PacketHead)
err := pb.Unmarshal(kcpMsg.HeadData, headMsg)
if err != nil {
logger.LOG.Error("unmarshal head data err: %v", err)
return protoMsgList
}
protoMsg.HeadMessage = headMsg
} else {
protoMsg.HeadMessage = nil
}
// payload msg
protoMessageList := make([]*ProtoMessage, 0)
p.protoDecodePayloadCore(kcpMsg.CmdId, kcpMsg.ProtoData, &protoMessageList)
if len(protoMessageList) == 0 {
logger.LOG.Error("decode proto object is nil")
return protoMsgList
}
if kcpMsg.CmdId == cmd.UnionCmdNotify {
for _, protoMessage := range protoMessageList {
msg := new(ProtoMsg)
msg.ConvId = kcpMsg.ConvId
msg.CmdId = protoMessage.cmdId
msg.HeadMessage = protoMsg.HeadMessage
msg.PayloadMessage = protoMessage.message
logger.LOG.Debug("[recv] union proto msg, convId: %v, cmdId: %v", msg.ConvId, msg.CmdId)
if protoMessage.cmdId == cmd.UnionCmdNotify {
// 聚合消息自身不再往后发送
continue
}
logger.LOG.Debug("[recv] proto msg, convId: %v, cmdId: %v, headMsg: %v", protoMsg.ConvId, protoMsg.CmdId, protoMsg.HeadMessage)
protoMsgList = append(protoMsgList, msg)
}
// 聚合消息自身不再往后发送
return protoMsgList
} else {
protoMsg.PayloadMessage = protoMessageList[0].message
}
logger.LOG.Debug("[recv] proto msg, convId: %v, cmdId: %v, headMsg: %v", protoMsg.ConvId, protoMsg.CmdId, protoMsg.HeadMessage)
protoMsgList = append(protoMsgList, protoMsg)
return protoMsgList
}
func (p *ProtoEnDecode) protoDecodePayloadCore(cmdId uint16, protoData []byte, protoMessageList *[]*ProtoMessage) {
protoObj := p.decodePayloadToProto(cmdId, protoData)
if protoObj == nil {
logger.LOG.Error("decode proto object is nil")
return
}
if cmdId == cmd.UnionCmdNotify {
// 处理聚合消息
unionCmdNotify, ok := protoObj.(*proto.UnionCmdNotify)
if !ok {
logger.LOG.Error("parse union cmd error")
return
}
for _, unionCmd := range unionCmdNotify.GetCmdList() {
p.protoDecodePayloadCore(uint16(unionCmd.MessageId), unionCmd.Body, protoMessageList)
}
}
*protoMessageList = append(*protoMessageList, &ProtoMessage{
cmdId: cmdId,
message: protoObj,
})
}
func (p *ProtoEnDecode) protoEncode(protoMsg *ProtoMsg) (kcpMsg *KcpMsg) {
logger.LOG.Debug("[send] proto msg, convId: %v, cmdId: %v, headMsg: %v", protoMsg.ConvId, protoMsg.CmdId, protoMsg.HeadMessage)
kcpMsg = new(KcpMsg)
kcpMsg.ConvId = protoMsg.ConvId
kcpMsg.CmdId = protoMsg.CmdId
// head msg
if protoMsg.HeadMessage != nil {
headData, err := pb.Marshal(protoMsg.HeadMessage)
if err != nil {
logger.LOG.Error("marshal head data err: %v", err)
return nil
}
kcpMsg.HeadData = headData
} else {
kcpMsg.HeadData = nil
}
// payload msg
if protoMsg.PayloadMessage != nil {
cmdId, protoData := p.encodeProtoToPayload(protoMsg.PayloadMessage)
if cmdId == 0 || protoData == nil {
logger.LOG.Error("encode proto data is nil")
return nil
}
if cmdId != 65535 && cmdId != protoMsg.CmdId {
logger.LOG.Error("cmd id is not match with proto obj, src cmd id: %v, found cmd id: %v", protoMsg.CmdId, cmdId)
return nil
}
kcpMsg.ProtoData = protoData
} else {
kcpMsg.ProtoData = nil
}
return kcpMsg
}
func (p *ProtoEnDecode) decodePayloadToProto(cmdId uint16, protoData []byte) (protoObj pb.Message) {
protoObj = p.cmdProtoMap.GetProtoObjByCmdId(cmdId)
if protoObj == nil {
logger.LOG.Error("get new proto object is nil")
return nil
}
err := pb.Unmarshal(protoData, protoObj)
if err != nil {
logger.LOG.Error("unmarshal proto data err: %v", err)
return nil
}
return protoObj
}
func (p *ProtoEnDecode) encodeProtoToPayload(protoObj pb.Message) (cmdId uint16, protoData []byte) {
cmdId = p.cmdProtoMap.GetCmdIdByProtoObj(protoObj)
var err error = nil
protoData, err = pb.Marshal(protoObj)
if err != nil {
logger.LOG.Error("marshal proto object err: %v", err)
return 0, nil
}
return cmdId, protoData
}