From 33f2b2ef60cd7cd6919c0e29b529950a86f63ef3 Mon Sep 17 00:00:00 2001 From: wenzl Date: Sun, 11 Sep 2016 00:27:37 +0800 Subject: [PATCH] =?UTF-8?q?=E6=94=AF=E6=8C=81=E6=96=87=E6=9C=AC=E6=B6=88?= =?UTF-8?q?=E6=81=AF=E7=9A=84=E5=9B=9E=E5=A4=8D?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- context/context.go | 6 -- context/render.go | 41 ++++++++++++ log/log.go | 61 ----------------- message/message.go | 112 +++++++++++++++++++++++++++++++ message/reply.go | 15 +++++ message/text.go | 7 ++ server/server.go | 159 ++++++++++++++++++++++++++++++++++++++++----- util/crypto.go | 82 +++++++++++++++++++++-- util/time.go | 8 +++ wechat.go | 8 --- 10 files changed, 403 insertions(+), 96 deletions(-) create mode 100644 context/render.go delete mode 100644 log/log.go create mode 100644 message/reply.go create mode 100644 message/text.go create mode 100644 util/time.go diff --git a/context/context.go b/context/context.go index 1fb2227..82288ae 100644 --- a/context/context.go +++ b/context/context.go @@ -17,12 +17,6 @@ func (ctx *Context) getAccessToken() { } -func (ctx *Context) String(str string) error { - ctx.Writer.WriteHeader(200) - _, err := ctx.Writer.Write([]byte(str)) - return err -} - // Query returns the keyed url query value if it exists func (ctx *Context) Query(key string) string { value, _ := ctx.GetQuery(key) diff --git a/context/render.go b/context/render.go new file mode 100644 index 0000000..35478dd --- /dev/null +++ b/context/render.go @@ -0,0 +1,41 @@ +package context + +import ( + "encoding/xml" + "net/http" +) + +var xmlContentType = []string{"application/xml; charset=utf-8"} +var plainContentType = []string{"text/plain; charset=utf-8"} + +//Render render from bytes +func (ctx *Context) Render(bytes []byte) { + ctx.Writer.WriteHeader(200) + _, err := ctx.Writer.Write(bytes) + if err != nil { + panic(err) + } +} + +//String render from string +func (ctx *Context) String(str string) { + writeContextType(ctx.Writer, plainContentType) + ctx.Render([]byte(str)) +} + +//XML render to xml +func (ctx *Context) XML(obj interface{}) { + writeContextType(ctx.Writer, xmlContentType) + bytes, err := xml.Marshal(obj) + if err != nil { + panic(err) + } + ctx.Render(bytes) +} + +func writeContextType(w http.ResponseWriter, value []string) { + header := w.Header() + if val := header["Content-Type"]; len(val) == 0 { + header["Content-Type"] = value + } +} diff --git a/log/log.go b/log/log.go deleted file mode 100644 index 97c5cf7..0000000 --- a/log/log.go +++ /dev/null @@ -1,61 +0,0 @@ -package log - -import "github.com/astaxie/beego/logs" - -const ( - LevelEmergency = iota - LevelAlert - LevelCritical - LevelError - LevelWarning - LevelNotice - LevelInformational - LevelDebug -) - -type Logger struct { - *logs.BeeLogger -} - -func NewLogger(channelLen int64, adapterName string, config string, logLevel int) *Logger { - logger := logs.NewLogger(channelLen) - logger.SetLogger(adapterName, config) - logger.SetLevel(logLevel) - logger.EnableFuncCallDepth(true) - logger.SetLogFuncCallDepth(3) - return &Logger{logger} -} - -func (logger *Logger) Printf(format string, v ...interface{}) { - logger.Trace(format, v...) -} - -var l *Logger - -func InitLogger(channelLen int64, adapterName string, config string, logLevel int) { - l = NewLogger(channelLen, adapterName, config, logLevel) -} - -func Criticalf(format string, v ...interface{}) { - l.Critical(format, v...) -} - -func Errorf(format string, v ...interface{}) { - l.Error(format, v...) -} - -func Warnf(format string, v ...interface{}) { - l.Warn(format, v...) -} - -func Infof(format string, v ...interface{}) { - l.Info(format, v...) -} - -func Tracef(format string, v ...interface{}) { - l.Trace(format, v...) -} - -func Debugf(format string, v ...interface{}) { - l.Debug(format, v...) -} diff --git a/message/message.go b/message/message.go index d43e37d..47c8781 100644 --- a/message/message.go +++ b/message/message.go @@ -1,8 +1,120 @@ package message +import "encoding/xml" + +//MsgType 基本消息类型 +type MsgType string + +//EventType 事件类型 +type EventType string + +const ( + //MsgTypeText 表示文本消息 + MsgTypeText MsgType = "text" + //MsgTypeImage 表示图片消息 + MsgTypeImage = "image" + //MsgTypeVoice 表示语音消息 + MsgTypeVoice = "voice" + //MsgTypeVideo 表示视频消息 + MsgTypeVideo = "video" + //MsgTypeShortVideo 表示短视频消息[限接收] + MsgTypeShortVideo = "shortvideo" + //MsgTypeLocation 表示坐标消息[限接收] + MsgTypeLocation = "location" + //MsgTypeLink 表示链接消息[限接收] + MsgTypeLink = "link" + //MsgTypeMusic 表示音乐消息[限回复] + MsgTypeMusic = "music" + //MsgTypeNews 表示图文消息[限回复] + MsgTypeNews = "news" + //MsgTypeTransfer 表示消息消息转发到客服 + MsgTypeTransfer = "transfer_customer_service" +) + +const ( + //EventSubscribe 订阅 + EventSubscribe EventType = "subscribe" + //EventUnsubscribe 取消订阅 + EventUnsubscribe EventType = "unsubscribe" + //EventScan 用户已经关注公众号,则微信会将带场景值扫描事件推送给开发者 + EventScan EventType = "SCAN" + //EventLocation 上报地理位置事件 + EventLocation EventType = "LOCATION" + //EventClick 点击菜单拉取消息时的事件推送 + EventClick EventType = "CLICK" + //EventView 点击菜单跳转链接时的事件推送 + EventView EventType = "VIEW" +) + //EncryptedXMLMsg 安全模式下的消息体 type EncryptedXMLMsg struct { XMLName struct{} `xml:"xml" json:"-"` ToUserName string `xml:"ToUserName" json:"ToUserName"` EncryptedMsg string `xml:"Encrypt" json:"Encrypt"` } + +//ResponseEncryptedXMLMsg 需要返回的消息体 +type ResponseEncryptedXMLMsg struct { + XMLName struct{} `xml:"xml" json:"-"` + EncryptedMsg string `xml:"Encrypt" json:"Encrypt"` + MsgSignature string `xml:"MsgSignature" json:"MsgSignature"` + Timestamp int64 `xml:"TimeStamp" json:"TimeStamp"` + Nonce string `xml:"Nonce" json:"Nonce"` +} + +// CommonToken 消息中通用的结构 +type CommonToken struct { + XMLName xml.Name `xml:"xml"` + ToUserName string `xml:"ToUserName"` + FromUserName string `xml:"FromUserName"` + CreateTime int64 `xml:"CreateTime"` + MsgType MsgType `xml:"MsgType"` +} + +//SetToUserName set ToUserName +func (msg *CommonToken) SetToUserName(toUserName string) { + msg.ToUserName = toUserName +} + +//SetFromUserName set FromUserName +func (msg *CommonToken) SetFromUserName(fromUserName string) { + msg.FromUserName = fromUserName +} + +//SetCreateTime set createTime +func (msg *CommonToken) SetCreateTime(createTime int64) { + msg.CreateTime = createTime +} + +//SetMsgType set MsgType +func (msg *CommonToken) SetMsgType(msgType MsgType) { + msg.MsgType = msgType +} + +//MixMessage 存放所有微信发送过来的消息和事件 +type MixMessage struct { + CommonToken + + //基本消息 + MsgID int64 `xml:"MsgId"` + Content string `xml:"Content"` + PicURL string `xml:"PicUrl"` + MediaID string `xml:"MediaId"` + Format string `xml:"Format"` + ThumbMediaID string `xml:"ThumbMediaId"` + LocationX float64 `xml:"Location_X"` + LocationY float64 `xml:"Location_Y"` + Scale float64 `xml:"Scale"` + Label string `xml:"Label"` + Title string `xml:"Title"` + Description string `xml:"Description"` + URL string `xml:"Url"` + + //事件相关 + Event string `xml:"Event"` + EventKey string `xml:"EventKey"` + Ticket string `xml:"Ticket"` + Latitude string `xml:"Latitude"` + Longitude string `xml:"Longitude"` + Precision string `xml:"Precision"` +} diff --git a/message/reply.go b/message/reply.go new file mode 100644 index 0000000..53592f0 --- /dev/null +++ b/message/reply.go @@ -0,0 +1,15 @@ +package message + +import "errors" + +//ErrInvalidReply 无效的回复 +var ErrInvalidReply = errors.New("无效的回复消息") + +//ErrUnsupportReply 不支持的回复类型 +var ErrUnsupportReply = errors.New("不支持的回复消息") + +//Reply 消息回复 +type Reply struct { + MsgType MsgType + MsgData interface{} +} diff --git a/message/text.go b/message/text.go new file mode 100644 index 0000000..c34fb59 --- /dev/null +++ b/message/text.go @@ -0,0 +1,7 @@ +package message + +//Text 文本消息 +type Text struct { + CommonToken + Content string `xml:"Content"` +} diff --git a/server/server.go b/server/server.go index c2a6ef9..068ba31 100644 --- a/server/server.go +++ b/server/server.go @@ -2,8 +2,12 @@ package server import ( "encoding/xml" + "errors" "fmt" "io/ioutil" + "reflect" + "strconv" + "strings" "github.com/silenceper/wechat/context" "github.com/silenceper/wechat/message" @@ -13,8 +17,20 @@ import ( //Server struct type Server struct { *context.Context + + openID string + + messageHandler func(message.MixMessage) *message.Reply + + requestRawXMLMsg []byte + requestMsg message.MixMessage + responseRawXMLMsg []byte + responseMsg interface{} + isSafeMode bool - rawXMLMsg string + random []byte + nonce string + timestamp int64 } //NewServer init @@ -26,18 +42,22 @@ func NewServer(context *context.Context) *Server { //Serve 处理微信的请求消息 func (srv *Server) Serve() error { - if !srv.Validate() { return fmt.Errorf("请求校验失败") } echostr, exists := srv.GetQuery("echostr") if exists { - return srv.String(echostr) + srv.String(echostr) + return nil } - srv.handleRequest() + response, err := srv.handleRequest() + if err != nil { + return err + } + srv.buildResponse(response) return nil } @@ -50,19 +70,37 @@ func (srv *Server) Validate() bool { } //HandleRequest 处理微信的请求 -func (srv *Server) handleRequest() { +func (srv *Server) handleRequest() (reply *message.Reply, err error) { + //set isSafeMode srv.isSafeMode = false encryptType := srv.Query("encrypt_type") if encryptType == "aes" { srv.isSafeMode = true } - _, err := srv.getMessage() + //set openID + srv.openID = srv.Query("openid") + + var msg interface{} + msg, err = srv.getMessage() if err != nil { - fmt.Printf("%v", err) + return } + mixMessage, success := msg.(message.MixMessage) + if !success { + err = errors.New("消息类型转换失败") + } + srv.requestMsg = mixMessage + reply = srv.messageHandler(mixMessage) + return } +//GetOpenID return openID +func (srv *Server) GetOpenID() string { + return srv.openID +} + +//getMessage 解析微信返回的消息 func (srv *Server) getMessage() (interface{}, error) { var rawXMLMsgBytes []byte var err error @@ -74,26 +112,117 @@ func (srv *Server) getMessage() (interface{}, error) { //验证消息签名 timestamp := srv.Query("timestamp") + srv.timestamp, err = strconv.ParseInt(timestamp, 10, 32) + if err != nil { + return nil, err + } nonce := srv.Query("nonce") + srv.nonce = nonce msgSignature := srv.Query("msg_signature") - msgSignatureCreate := util.Signature(srv.Token, timestamp, nonce, encryptedXMLMsg.EncryptedMsg) - if msgSignature != msgSignatureCreate { + msgSignatureGen := util.Signature(srv.Token, timestamp, nonce, encryptedXMLMsg.EncryptedMsg) + if msgSignature != msgSignatureGen { return nil, fmt.Errorf("消息不合法,验证签名失败") } //解密 - rawXMLMsgBytes, err = util.DecryptMsg(srv.AppID, encryptedXMLMsg.EncryptedMsg, srv.EncodingAESKey) + srv.random, rawXMLMsgBytes, err = util.DecryptMsg(srv.AppID, encryptedXMLMsg.EncryptedMsg, srv.EncodingAESKey) if err != nil { - return nil, fmt.Errorf("消息解密失败,err=%v", err) + return nil, fmt.Errorf("消息解密失败, err=%v", err) } } else { rawXMLMsgBytes, err = ioutil.ReadAll(srv.Request.Body) if err != nil { - return nil, fmt.Errorf("从body中解析xml失败,err=%v", err) + return nil, fmt.Errorf("从body中解析xml失败, err=%v", err) } } - srv.rawXMLMsg = string(rawXMLMsgBytes) - fmt.Println(srv.rawXMLMsg) - return nil, nil + srv.requestRawXMLMsg = rawXMLMsgBytes + + return srv.parseRequestMessage(rawXMLMsgBytes) +} + +func (srv *Server) parseRequestMessage(rawXMLMsgBytes []byte) (msg message.MixMessage, err error) { + msg = message.MixMessage{} + err = xml.Unmarshal(rawXMLMsgBytes, &msg) + return +} + +//SetMessageHandler 设置用户自定义的回调方法 +func (srv *Server) SetMessageHandler(handler func(message.MixMessage) *message.Reply) { + srv.messageHandler = handler +} + +func (srv *Server) buildResponse(reply *message.Reply) (err error) { + defer func() { + if e := recover(); e != nil { + err = fmt.Errorf("panic error: %v", err) + } + }() + if reply == nil { + //do nothing + return nil + } + msgType := reply.MsgType + switch msgType { + case message.MsgTypeText: + case message.MsgTypeImage: + case message.MsgTypeVoice: + case message.MsgTypeVideo: + case message.MsgTypeMusic: + case message.MsgTypeNews: + case message.MsgTypeTransfer: + default: + err = message.ErrUnsupportReply + return + } + + msgData := reply.MsgData + value := reflect.ValueOf(msgData) + //msgData must be a ptr + kind := value.Kind().String() + if 0 != strings.Compare("ptr", kind) { + return message.ErrUnsupportReply + } + + params := make([]reflect.Value, 1) + params[0] = reflect.ValueOf(srv.requestMsg.FromUserName) + value.MethodByName("SetToUserName").Call(params) + + params[0] = reflect.ValueOf(srv.requestMsg.ToUserName) + value.MethodByName("SetFromUserName").Call(params) + + params[0] = reflect.ValueOf(msgType) + value.MethodByName("SetMsgType").Call(params) + + params[0] = reflect.ValueOf(util.GetCurrTs()) + value.MethodByName("SetCreateTime").Call(params) + + srv.responseMsg = msgData + srv.responseRawXMLMsg, err = xml.Marshal(msgData) + return +} + +//Send 将自定义的消息发送 +func (srv *Server) Send() (err error) { + replyMsg := srv.responseMsg + if srv.isSafeMode { + //安全模式下对消息进行加密 + var encryptedMsg []byte + encryptedMsg, err = util.EncryptMsg(srv.random, srv.responseRawXMLMsg, srv.AppID, srv.EncodingAESKey) + if err != nil { + return + } + //TODO 如果获取不到timestamp nonce 则自己生成 + timestamp := srv.timestamp + timestampStr := strconv.FormatInt(timestamp, 10) + msgSignature := util.Signature(srv.Token, timestampStr, srv.nonce, string(encryptedMsg)) + replyMsg = message.ResponseEncryptedXMLMsg{ + EncryptedMsg: string(encryptedMsg), + MsgSignature: msgSignature, + Timestamp: timestamp, + Nonce: srv.nonce, + } + } + srv.XML(replyMsg) + return } diff --git a/util/crypto.go b/util/crypto.go index 928048d..1ec7a4f 100644 --- a/util/crypto.go +++ b/util/crypto.go @@ -4,12 +4,73 @@ import ( "crypto/aes" "crypto/cipher" "encoding/base64" - "errors" "fmt" ) +//EncryptMsg 加密消息 +func EncryptMsg(random, rawXMLMsg []byte, appID, aesKey string) (encrtptMsg []byte, err error) { + defer func() { + if e := recover(); e != nil { + err = fmt.Errorf("panic error: err=%v", e) + return + } + }() + var key []byte + key, err = aesKeyDecode(aesKey) + if err != nil { + panic(err) + } + ciphertext := AESEncryptMsg(random, rawXMLMsg, appID, key) + encrtptMsg = []byte(base64.StdEncoding.EncodeToString(ciphertext)) + return +} + +//AESEncryptMsg ciphertext = AES_Encrypt[random(16B) + msg_len(4B) + rawXMLMsg + appId] +//参考:github.com/chanxuehong/wechat.v2 +func AESEncryptMsg(random, rawXMLMsg []byte, appID string, aesKey []byte) (ciphertext []byte) { + const ( + BlockSize = 32 // PKCS#7 + BlockMask = BlockSize - 1 // BLOCK_SIZE 为 2^n 时, 可以用 mask 获取针对 BLOCK_SIZE 的余数 + ) + + appIDOffset := 20 + len(rawXMLMsg) + contentLen := appIDOffset + len(appID) + amountToPad := BlockSize - contentLen&BlockMask + plaintextLen := contentLen + amountToPad + + plaintext := make([]byte, plaintextLen) + + // 拼接 + copy(plaintext[:16], random) + encodeNetworkByteOrder(plaintext[16:20], uint32(len(rawXMLMsg))) + copy(plaintext[20:], rawXMLMsg) + copy(plaintext[appIDOffset:], appID) + + // PKCS#7 补位 + for i := contentLen; i < plaintextLen; i++ { + plaintext[i] = byte(amountToPad) + } + + // 加密 + block, err := aes.NewCipher(aesKey[:]) + if err != nil { + panic(err) + } + mode := cipher.NewCBCEncrypter(block, aesKey[:16]) + mode.CryptBlocks(plaintext, plaintext) + + ciphertext = plaintext + return +} + //DecryptMsg 消息解密 -func DecryptMsg(appID, encryptedMsg, aesKey string) (rawMsgXMLBytes []byte, err error) { +func DecryptMsg(appID, encryptedMsg, aesKey string) (random, rawMsgXMLBytes []byte, err error) { + defer func() { + if e := recover(); e != nil { + err = fmt.Errorf("panic error: err=%v", e) + return + } + }() var encryptedMsgBytes, key, getAppIDBytes []byte encryptedMsgBytes, err = base64.StdEncoding.DecodeString(encryptedMsg) if err != nil { @@ -17,9 +78,9 @@ func DecryptMsg(appID, encryptedMsg, aesKey string) (rawMsgXMLBytes []byte, err } key, err = aesKeyDecode(aesKey) if err != nil { - return + panic(err) } - _, rawMsgXMLBytes, getAppIDBytes, err = AESDecryptMsg(encryptedMsgBytes, key) + random, rawMsgXMLBytes, getAppIDBytes, err = AESDecryptMsg(encryptedMsgBytes, key) if err != nil { err = fmt.Errorf("消息解密失败,%v", err) return @@ -33,7 +94,7 @@ func DecryptMsg(appID, encryptedMsg, aesKey string) (rawMsgXMLBytes []byte, err func aesKeyDecode(encodedAESKey string) (key []byte, err error) { if len(encodedAESKey) != 43 { - err = errors.New("the length of encodedAESKey must be equal to 43") + err = fmt.Errorf("the length of encodedAESKey must be equal to 43") return } key, err = base64.StdEncoding.DecodeString(encodedAESKey + "=") @@ -41,13 +102,14 @@ func aesKeyDecode(encodedAESKey string) (key []byte, err error) { return } if len(key) != 32 { - err = errors.New("encodingAESKey invalid") + err = fmt.Errorf("encodingAESKey invalid") return } return } // AESDecryptMsg ciphertext = AES_Encrypt[random(16B) + msg_len(4B) + rawXMLMsg + appId] +//参考:github.com/chanxuehong/wechat.v2 func AESDecryptMsg(ciphertext []byte, aesKey []byte) (random, rawXMLMsg, appID []byte, err error) { const ( BlockSize = 32 // PKCS#7 @@ -104,6 +166,14 @@ func AESDecryptMsg(ciphertext []byte, aesKey []byte) (random, rawXMLMsg, appID [ return } +// 把整数 n 格式化成 4 字节的网络字节序 +func encodeNetworkByteOrder(orderBytes []byte, n uint32) { + orderBytes[0] = byte(n >> 24) + orderBytes[1] = byte(n >> 16) + orderBytes[2] = byte(n >> 8) + orderBytes[3] = byte(n) +} + // 从 4 字节的网络字节序里解析出整数 func decodeNetworkByteOrder(orderBytes []byte) (n uint32) { return uint32(orderBytes[0])<<24 | diff --git a/util/time.go b/util/time.go new file mode 100644 index 0000000..024839c --- /dev/null +++ b/util/time.go @@ -0,0 +1,8 @@ +package util + +import "time" + +//GetCurrTs return current timestamps +func GetCurrTs() int64 { + return time.Now().Unix() +} diff --git a/wechat.go b/wechat.go index a4af8d9..c81d562 100644 --- a/wechat.go +++ b/wechat.go @@ -4,7 +4,6 @@ import ( "net/http" "github.com/silenceper/wechat/context" - "github.com/silenceper/wechat/log" "github.com/silenceper/wechat/server" ) @@ -23,13 +22,6 @@ type Config struct { //NewWechat init func NewWechat(cfg *Config) *Wechat { - - channelLen := int64(10000) - adapterName := "console" - config := "" - logLevel := log.LevelDebug - log.InitLogger(channelLen, adapterName, config, logLevel) - context := new(context.Context) copyConfigToContext(cfg, context) return &Wechat{context}