// Source: hk4e/gate/net/proto_endecode.go
// Snapshot: 2023-01-05 01:29:00 +08:00 (308 lines, 9.6 KiB, Go)

package net
import (
"reflect"
"hk4e/common/config"
"hk4e/pkg/logger"
"hk4e/pkg/object"
"hk4e/protocol/cmd"
"hk4e/protocol/proto"
pb "google.golang.org/protobuf/proto"
)
type (
	// ProtoMsg is one decoded packet exchanged between the gate and the
	// server side: a cmd id, an optional head message and the payload proto.
	ProtoMsg struct {
		ConvId         uint64            // connection id, copied from KcpMsg.ConvId
		CmdId          uint16            // cmd id that identifies PayloadMessage
		HeadMessage    *proto.PacketHead // packet head; nil when the packet carried none
		PayloadMessage pb.Message        // decoded payload object
	}

	// ProtoMessage pairs a cmd id with its decoded message while a payload
	// (possibly a UnionCmdNotify holding several cmds) is being unpacked.
	ProtoMessage struct {
		cmdId   uint16
		message pb.Message
	}
)
// maxUnionCmdNestDepth caps how many UnionCmdNotify levels may be nested inside
// one another in a single client packet. Each level unmarshals the remaining
// body again, so without a cap a client could make the gateway do work that
// grows quadratically with the nesting depth.
// NOTE(review): real clients are expected to send a single level — confirm.
const maxUnionCmdNestDepth = 8

// protoDecode converts one raw KcpMsg received from a client into the ProtoMsg
// values that should be dispatched.
//
// When config.CONF.Hk4e.ClientProtoProxyEnable is set, the client cmd id and
// payload are first translated to the server-side definitions, and
// kcpMsg.CmdId / kcpMsg.ProtoData are overwritten in place with the result.
//
// For a UnionCmdNotify every decoded sub-message becomes its own ProtoMsg (all
// sharing the packet's head message), followed by the UnionCmdNotify itself;
// any other cmd yields exactly one ProtoMsg. On failure the reason is logged
// and an empty, non-nil list is returned.
func (k *KcpConnectManager) protoDecode(kcpMsg *KcpMsg) (protoMsgList []*ProtoMsg) {
	protoMsgList = make([]*ProtoMsg, 0)
	if config.CONF.Hk4e.ClientProtoProxyEnable {
		serverCmdId, serverProtoData, ok := k.clientPayloadToServer(kcpMsg.CmdId, kcpMsg.ProtoData)
		if !ok {
			return protoMsgList
		}
		kcpMsg.CmdId = serverCmdId
		kcpMsg.ProtoData = serverProtoData
	}
	// head msg: optional. len of a nil slice is 0, so this also covers nil.
	var headMsg *proto.PacketHead
	if len(kcpMsg.HeadData) != 0 {
		headMsg = new(proto.PacketHead)
		if err := pb.Unmarshal(kcpMsg.HeadData, headMsg); err != nil {
			logger.Error("unmarshal head data err: %v", err)
			return protoMsgList
		}
	}
	// payload msg
	protoMessageList := make([]*ProtoMessage, 0)
	k.protoDecodePayloadLoop(kcpMsg.CmdId, kcpMsg.ProtoData, &protoMessageList)
	if len(protoMessageList) == 0 {
		logger.Error("decode proto object is nil")
		return protoMsgList
	}
	if kcpMsg.CmdId != cmd.UnionCmdNotify {
		protoMsg := &ProtoMsg{
			ConvId:         kcpMsg.ConvId,
			CmdId:          kcpMsg.CmdId,
			HeadMessage:    headMsg,
			PayloadMessage: protoMessageList[0].message,
		}
		protoMsgList = append(protoMsgList, protoMsg)
		cmdName := ""
		if protoMsg.PayloadMessage != nil {
			cmdName = string(protoMsg.PayloadMessage.ProtoReflect().Descriptor().FullName())
		}
		logger.Debug("[RECV], cmdId: %v, cmdName: %v, convId: %v, headMsg: %v", protoMsg.CmdId, cmdName, protoMsg.ConvId, protoMsg.HeadMessage)
		return protoMsgList
	}
	// Union cmd: one ProtoMsg per decoded entry, each reusing the packet head.
	for _, protoMessage := range protoMessageList {
		msg := &ProtoMsg{
			ConvId:         kcpMsg.ConvId,
			CmdId:          protoMessage.cmdId,
			HeadMessage:    headMsg,
			PayloadMessage: protoMessage.message,
		}
		protoMsgList = append(protoMsgList, msg)
		cmdName := "???"
		if msg.PayloadMessage != nil {
			cmdName = string(msg.PayloadMessage.ProtoReflect().Descriptor().FullName())
		}
		logger.Debug("[RECV UNION CMD], cmdId: %v, cmdName: %v, convId: %v, headMsg: %v", msg.CmdId, cmdName, msg.ConvId, msg.HeadMessage)
	}
	return protoMsgList
}

// clientPayloadToServer translates a client-side (cmd id, payload) pair into the
// equivalent server-side pair. It resolves the cmd name from the client cmd id,
// decodes the payload with the client proto definition, copies same-named
// fields into the server proto object and re-encodes it. Field names reported
// in CopyProtoBufSameField's delete list are logged. The bool result is false,
// with the reason logged, when any step fails.
func (k *KcpConnectManager) clientPayloadToServer(clientCmdId uint16, clientProtoData []byte) (uint16, []byte, bool) {
	cmdName := k.clientCmdProtoMap.GetClientCmdNameByCmdId(clientCmdId)
	if cmdName == "" {
		logger.Error("get cmdName is nil, clientCmdId: %v", clientCmdId)
		return 0, nil, false
	}
	clientProtoObj := k.getClientProtoObjByName(cmdName)
	if clientProtoObj == nil {
		logger.Error("get client proto obj is nil, cmdName: %v", cmdName)
		return 0, nil, false
	}
	err := pb.Unmarshal(clientProtoData, clientProtoObj)
	if err != nil {
		logger.Error("unmarshal client proto error: %v", err)
		return 0, nil, false
	}
	serverCmdId := k.serverCmdProtoMap.GetCmdIdByCmdName(cmdName)
	if serverCmdId == 0 {
		logger.Error("get server cmdId is nil, cmdName: %v", cmdName)
		return 0, nil, false
	}
	serverProtoObj := k.serverCmdProtoMap.GetProtoObjByCmdId(serverCmdId)
	if serverProtoObj == nil {
		logger.Error("get server proto obj is nil, serverCmdId: %v", serverCmdId)
		return 0, nil, false
	}
	delList, err := object.CopyProtoBufSameField(serverProtoObj, clientProtoObj)
	if err != nil {
		logger.Error("copy proto obj error: %v", err)
		return 0, nil, false
	}
	if len(delList) != 0 {
		logger.Error("delete field name list: %v, cmdName: %v", delList, cmdName)
	}
	serverProtoData, err := pb.Marshal(serverProtoObj)
	if err != nil {
		logger.Error("marshal server proto error: %v", err)
		return 0, nil, false
	}
	return serverCmdId, serverProtoData, true
}

// protoDecodePayloadLoop decodes one server-side payload and appends the result
// to *protoMessageList. For a UnionCmdNotify each embedded cmd is decoded
// recursively first (after client→server translation when the proxy is
// enabled), so the children precede the UnionCmdNotify itself in the list.
// Entries that fail to translate or decode are logged and skipped.
func (k *KcpConnectManager) protoDecodePayloadLoop(cmdId uint16, protoData []byte, protoMessageList *[]*ProtoMessage) {
	k.protoDecodePayloadAtDepth(cmdId, protoData, protoMessageList, 0)
}

// protoDecodePayloadAtDepth implements protoDecodePayloadLoop while tracking how
// many UnionCmdNotify levels enclose the current payload, so that
// maxUnionCmdNestDepth can be enforced before any further decoding work.
func (k *KcpConnectManager) protoDecodePayloadAtDepth(cmdId uint16, protoData []byte, protoMessageList *[]*ProtoMessage, depth int) {
	if cmdId == cmd.UnionCmdNotify && depth >= maxUnionCmdNestDepth {
		logger.Error("union cmd nest too deep, depth: %v, cmdId: %v", depth, cmdId)
		return
	}
	protoObj := k.decodePayloadToProto(cmdId, protoData)
	if protoObj == nil {
		logger.Error("decode proto object is nil")
		return
	}
	if cmdId == cmd.UnionCmdNotify {
		// Aggregated message: every entry carries its own cmd id and body.
		unionCmdNotify, ok := protoObj.(*proto.UnionCmdNotify)
		if !ok {
			logger.Error("parse union cmd error")
			return
		}
		for _, unionCmd := range unionCmdNotify.GetCmdList() {
			if config.CONF.Hk4e.ClientProtoProxyEnable {
				serverCmdId, serverProtoData, converted := k.clientPayloadToServer(uint16(unionCmd.MessageId), unionCmd.Body)
				if !converted {
					continue
				}
				unionCmd.MessageId = uint32(serverCmdId)
				unionCmd.Body = serverProtoData
			}
			k.protoDecodePayloadAtDepth(uint16(unionCmd.MessageId), unionCmd.Body, protoMessageList, depth+1)
		}
	}
	*protoMessageList = append(*protoMessageList, &ProtoMessage{
		cmdId:   cmdId,
		message: protoObj,
	})
}
// protoEncode converts a ProtoMsg coming from the server side into a KcpMsg to
// be sent to the client. It returns nil, after logging, when the head or
// payload cannot be encoded, when the payload's registered cmd id does not
// match protoMsg.CmdId, or (with ClientProtoProxyEnable) when the
// server→client translation fails. A nil HeadMessage / PayloadMessage leaves
// HeadData / ProtoData nil.
func (k *KcpConnectManager) protoEncode(protoMsg *ProtoMsg) (kcpMsg *KcpMsg) {
	cmdName := ""
	if protoMsg.PayloadMessage != nil {
		cmdName = string(protoMsg.PayloadMessage.ProtoReflect().Descriptor().FullName())
	}
	logger.Debug("[SEND], cmdId: %v, cmdName: %v, convId: %v, headMsg: %v", protoMsg.CmdId, cmdName, protoMsg.ConvId, protoMsg.HeadMessage)
	kcpMsg = new(KcpMsg)
	kcpMsg.ConvId = protoMsg.ConvId
	kcpMsg.CmdId = protoMsg.CmdId
	// head msg (optional; HeadData stays nil otherwise)
	if protoMsg.HeadMessage != nil {
		headData, err := pb.Marshal(protoMsg.HeadMessage)
		if err != nil {
			logger.Error("marshal head data err: %v", err)
			return nil
		}
		kcpMsg.HeadData = headData
	}
	// payload msg (optional; ProtoData stays nil otherwise)
	if protoMsg.PayloadMessage != nil {
		cmdId, protoData := k.encodeProtoToPayload(protoMsg.PayloadMessage)
		if cmdId == 0 || protoData == nil {
			logger.Error("encode proto data is nil")
			return nil
		}
		// NOTE(review): cmd id 65535 is exempt from the match check; its meaning
		// is defined by serverCmdProtoMap.GetCmdIdByProtoObj — confirm.
		if cmdId != 65535 && cmdId != protoMsg.CmdId {
			logger.Error("cmd id is not match with proto obj, src cmd id: %v, found cmd id: %v", protoMsg.CmdId, cmdId)
			return nil
		}
		kcpMsg.ProtoData = protoData
	}
	if config.CONF.Hk4e.ClientProtoProxyEnable {
		clientCmdId, clientProtoData, ok := k.serverPayloadToClient(kcpMsg.CmdId, kcpMsg.ProtoData)
		if !ok {
			return nil
		}
		kcpMsg.CmdId = clientCmdId
		kcpMsg.ProtoData = clientProtoData
	}
	return kcpMsg
}

// serverPayloadToClient translates a server-side (cmd id, payload) pair into the
// equivalent client-side pair. It decodes the payload with the server proto
// definition, resolves the cmd name, copies same-named fields into the client
// proto object, re-encodes it and looks up the client cmd id. Field names
// reported in CopyProtoBufSameField's delete list are logged. The bool result
// is false, with the reason logged, when any step fails.
func (k *KcpConnectManager) serverPayloadToClient(serverCmdId uint16, serverProtoData []byte) (uint16, []byte, bool) {
	serverProtoObj := k.serverCmdProtoMap.GetProtoObjByCmdId(serverCmdId)
	if serverProtoObj == nil {
		logger.Error("get server proto obj is nil, serverCmdId: %v", serverCmdId)
		return 0, nil, false
	}
	err := pb.Unmarshal(serverProtoData, serverProtoObj)
	if err != nil {
		logger.Error("unmarshal server proto error: %v", err)
		return 0, nil, false
	}
	cmdName := k.serverCmdProtoMap.GetCmdNameByCmdId(serverCmdId)
	if cmdName == "" {
		logger.Error("get cmdName is nil, serverCmdId: %v", serverCmdId)
		return 0, nil, false
	}
	clientProtoObj := k.getClientProtoObjByName(cmdName)
	if clientProtoObj == nil {
		logger.Error("get client proto obj is nil, cmdName: %v", cmdName)
		return 0, nil, false
	}
	delList, err := object.CopyProtoBufSameField(clientProtoObj, serverProtoObj)
	if err != nil {
		logger.Error("copy proto obj error: %v", err)
		return 0, nil, false
	}
	if len(delList) != 0 {
		logger.Error("delete field name list: %v, cmdName: %v", delList, cmdName)
	}
	clientProtoData, err := pb.Marshal(clientProtoObj)
	if err != nil {
		logger.Error("marshal client proto error: %v", err)
		return 0, nil, false
	}
	clientCmdId := k.clientCmdProtoMap.GetClientCmdIdByCmdName(cmdName)
	if clientCmdId == 0 {
		logger.Error("get client cmdId is nil, cmdName: %v", cmdName)
		return 0, nil, false
	}
	return clientCmdId, clientProtoData, true
}
// decodePayloadToProto unmarshals protoData into a fresh server-side proto
// object registered for cmdId. It returns nil, after logging, when no object
// is registered for cmdId or when unmarshalling fails.
func (k *KcpConnectManager) decodePayloadToProto(cmdId uint16, protoData []byte) (protoObj pb.Message) {
	var msg pb.Message = k.serverCmdProtoMap.GetProtoObjByCmdId(cmdId)
	if msg == nil {
		logger.Error("get new proto object is nil")
		return nil
	}
	if err := pb.Unmarshal(protoData, msg); err != nil {
		logger.Error("unmarshal proto data err: %v", err)
		return nil
	}
	return msg
}
// encodeProtoToPayload looks up the server-side cmd id registered for
// protoObj's type and marshals protoObj. On marshal failure it logs and returns
// (0, nil); otherwise it returns the looked-up cmd id (not validated here)
// together with the encoded bytes.
func (k *KcpConnectManager) encodeProtoToPayload(protoObj pb.Message) (cmdId uint16, protoData []byte) {
	cmdId = k.serverCmdProtoMap.GetCmdIdByProtoObj(protoObj)
	encoded, err := pb.Marshal(protoObj)
	if err != nil {
		logger.Error("marshal proto object err: %v", err)
		return 0, nil
	}
	return cmdId, encoded
}
// getClientProtoObjByName obtains a client-side proto object by invoking the
// GetClientProtoObjByName method, via reflection, on the value held in
// k.clientCmdProtoMapRefValue. It returns nil, after logging, instead of
// panicking when:
//   - the reflected value is invalid;
//   - the method does not exist;
//   - the method returns nothing;
//   - the returned value is nil;
//   - the returned value is not a pb.Message.
func (k *KcpConnectManager) getClientProtoObjByName(protoObjName string) pb.Message {
	// MethodByName panics on a zero reflect.Value, so check validity first.
	if !k.clientCmdProtoMapRefValue.IsValid() {
		logger.Error("client cmd proto map ref value is invalid, protoObjName: %v", protoObjName)
		return nil
	}
	fn := k.clientCmdProtoMapRefValue.MethodByName("GetClientProtoObjByName")
	// MethodByName returns an invalid Value (and Call would panic) when the method is missing.
	if !fn.IsValid() {
		logger.Error("client cmd proto map has no method GetClientProtoObjByName, protoObjName: %v", protoObjName)
		return nil
	}
	ret := fn.Call([]reflect.Value{reflect.ValueOf(protoObjName)})
	if len(ret) == 0 {
		logger.Error("GetClientProtoObjByName returned no value, protoObjName: %v", protoObjName)
		return nil
	}
	obj := ret[0].Interface()
	if obj == nil {
		logger.Error("try to get a not exist proto obj, protoObjName: %v", protoObjName)
		return nil
	}
	clientProtoObj, ok := obj.(pb.Message)
	if !ok {
		logger.Error("client proto obj is not a proto message, protoObjName: %v, type: %T", protoObjName, obj)
		return nil
	}
	return clientProtoObj
}