diff --git a/context/context.go b/context/context.go new file mode 100644 index 0000000..1fb2227 --- /dev/null +++ b/context/context.go @@ -0,0 +1,39 @@ +package context + +import "net/http" + +//Context struct +type Context struct { + AppID string + AppSecret string + Token string + EncodingAESKey string + + Writer http.ResponseWriter + Request *http.Request +} + +func (ctx *Context) getAccessToken() { + +} + +func (ctx *Context) String(str string) error { + ctx.Writer.WriteHeader(200) + _, err := ctx.Writer.Write([]byte(str)) + return err +} + +// Query returns the keyed url query value if it exists +func (ctx *Context) Query(key string) string { + value, _ := ctx.GetQuery(key) + return value +} + +// GetQuery is like Query(), it returns the keyed url query value +func (ctx *Context) GetQuery(key string) (string, bool) { + req := ctx.Request + if values, ok := req.URL.Query()[key]; ok && len(values) > 0 { + return values[0], true + } + return "", false +} diff --git a/log/log.go b/log/log.go new file mode 100644 index 0000000..97c5cf7 --- /dev/null +++ b/log/log.go @@ -0,0 +1,61 @@ +package log + +import "github.com/astaxie/beego/logs" + +const ( + LevelEmergency = iota + LevelAlert + LevelCritical + LevelError + LevelWarning + LevelNotice + LevelInformational + LevelDebug +) + +type Logger struct { + *logs.BeeLogger +} + +func NewLogger(channelLen int64, adapterName string, config string, logLevel int) *Logger { + logger := logs.NewLogger(channelLen) + logger.SetLogger(adapterName, config) + logger.SetLevel(logLevel) + logger.EnableFuncCallDepth(true) + logger.SetLogFuncCallDepth(3) + return &Logger{logger} +} + +func (logger *Logger) Printf(format string, v ...interface{}) { + logger.Trace(format, v...) +} + +var l *Logger + +func InitLogger(channelLen int64, adapterName string, config string, logLevel int) { + l = NewLogger(channelLen, adapterName, config, logLevel) +} + +func Criticalf(format string, v ...interface{}) { + l.Critical(format, v...) +} + +func Errorf(format string, v ...interface{}) { + l.Error(format, v...) +} + +func Warnf(format string, v ...interface{}) { + l.Warn(format, v...) +} + +func Infof(format string, v ...interface{}) { + l.Info(format, v...) +} + +func Tracef(format string, v ...interface{}) { + l.Trace(format, v...) +} + +func Debugf(format string, v ...interface{}) { + l.Debug(format, v...) +} diff --git a/message/message.go b/message/message.go new file mode 100644 index 0000000..d43e37d --- /dev/null +++ b/message/message.go @@ -0,0 +1,8 @@ +package message + +//EncryptedXMLMsg 安全模式下的消息体 +type EncryptedXMLMsg struct { + XMLName struct{} `xml:"xml" json:"-"` + ToUserName string `xml:"ToUserName" json:"ToUserName"` + EncryptedMsg string `xml:"Encrypt" json:"Encrypt"` +} diff --git a/server/server.go b/server/server.go new file mode 100644 index 0000000..c2a6ef9 --- /dev/null +++ b/server/server.go @@ -0,0 +1,99 @@ +package server + +import ( + "encoding/xml" + "fmt" + "io/ioutil" + + "github.com/silenceper/wechat/context" + "github.com/silenceper/wechat/message" + "github.com/silenceper/wechat/util" +) + +//Server struct +type Server struct { + *context.Context + isSafeMode bool + rawXMLMsg string +} + +//NewServer init +func NewServer(context *context.Context) *Server { + srv := new(Server) + srv.Context = context + return srv +} + +//Serve 处理微信的请求消息 +func (srv *Server) Serve() error { + + if !srv.Validate() { + return fmt.Errorf("请求校验失败") + } + + echostr, exists := srv.GetQuery("echostr") + if exists { + return srv.String(echostr) + } + + srv.handleRequest() + + return nil +} + +//Validate 校验请求是否合法 +func (srv *Server) Validate() bool { + timestamp := srv.Query("timestamp") + nonce := srv.Query("nonce") + signature := srv.Query("signature") + return signature == util.Signature(srv.Token, timestamp, nonce) +} + +//HandleRequest 处理微信的请求 +func (srv *Server) handleRequest() { + srv.isSafeMode = false + encryptType := srv.Query("encrypt_type") + if encryptType == "aes" { + srv.isSafeMode = true + } + + _, err := srv.getMessage() + if err != nil { + fmt.Printf("%v", err) + } +} + +func (srv *Server) getMessage() (interface{}, error) { + var rawXMLMsgBytes []byte + var err error + if srv.isSafeMode { + var encryptedXMLMsg message.EncryptedXMLMsg + if err := xml.NewDecoder(srv.Request.Body).Decode(&encryptedXMLMsg); err != nil { + return nil, fmt.Errorf("从body中解析xml失败,err=%v", err) + } + + //验证消息签名 + timestamp := srv.Query("timestamp") + nonce := srv.Query("nonce") + msgSignature := srv.Query("msg_signature") + msgSignatureCreate := util.Signature(srv.Token, timestamp, nonce, encryptedXMLMsg.EncryptedMsg) + if msgSignature != msgSignatureCreate { + return nil, fmt.Errorf("消息不合法,验证签名失败") + } + + //解密 + rawXMLMsgBytes, err = util.DecryptMsg(srv.AppID, encryptedXMLMsg.EncryptedMsg, srv.EncodingAESKey) + if err != nil { + return nil, fmt.Errorf("消息解密失败,err=%v", err) + } + } else { + rawXMLMsgBytes, err = ioutil.ReadAll(srv.Request.Body) + if err != nil { + return nil, fmt.Errorf("从body中解析xml失败,err=%v", err) + } + } + + srv.rawXMLMsg = string(rawXMLMsgBytes) + fmt.Println(srv.rawXMLMsg) + return nil, nil +} diff --git a/util/crypto.go b/util/crypto.go new file mode 100644 index 0000000..928048d --- /dev/null +++ b/util/crypto.go @@ -0,0 +1,113 @@ +package util + +import ( + "crypto/aes" + "crypto/cipher" + "encoding/base64" + "errors" + "fmt" +) + +//DecryptMsg 消息解密 +func DecryptMsg(appID, encryptedMsg, aesKey string) (rawMsgXMLBytes []byte, err error) { + var encryptedMsgBytes, key, getAppIDBytes []byte + encryptedMsgBytes, err = base64.StdEncoding.DecodeString(encryptedMsg) + if err != nil { + return + } + key, err = aesKeyDecode(aesKey) + if err != nil { + return + } + _, rawMsgXMLBytes, getAppIDBytes, err = AESDecryptMsg(encryptedMsgBytes, key) + if err != nil { + err = fmt.Errorf("消息解密失败,%v", err) + return + } + if appID != string(getAppIDBytes) { + err = fmt.Errorf("消息解密校验APPID失败") + return + } + return +} + +func aesKeyDecode(encodedAESKey string) (key []byte, err error) { + if len(encodedAESKey) != 43 { + err = errors.New("the length of encodedAESKey must be equal to 43") + return + } + key, err = base64.StdEncoding.DecodeString(encodedAESKey + "=") + if err != nil { + return + } + if len(key) != 32 { + err = errors.New("encodingAESKey invalid") + return + } + return +} + +// AESDecryptMsg ciphertext = AES_Encrypt[random(16B) + msg_len(4B) + rawXMLMsg + appId] +func AESDecryptMsg(ciphertext []byte, aesKey []byte) (random, rawXMLMsg, appID []byte, err error) { + const ( + BlockSize = 32 // PKCS#7 + BlockMask = BlockSize - 1 // BLOCK_SIZE 为 2^n 时, 可以用 mask 获取针对 BLOCK_SIZE 的余数 + ) + + if len(ciphertext) < BlockSize { + err = fmt.Errorf("the length of ciphertext too short: %d", len(ciphertext)) + return + } + if len(ciphertext)&BlockMask != 0 { + err = fmt.Errorf("ciphertext is not a multiple of the block size, the length is %d", len(ciphertext)) + return + } + + plaintext := make([]byte, len(ciphertext)) // len(plaintext) >= BLOCK_SIZE + + // 解密 + block, err := aes.NewCipher(aesKey) + if err != nil { + panic(err) + } + mode := cipher.NewCBCDecrypter(block, aesKey[:16]) + mode.CryptBlocks(plaintext, ciphertext) + + // PKCS#7 去除补位 + amountToPad := int(plaintext[len(plaintext)-1]) + if amountToPad < 1 || amountToPad > BlockSize { + err = fmt.Errorf("the amount to pad is incorrect: %d", amountToPad) + return + } + plaintext = plaintext[:len(plaintext)-amountToPad] + + // 反拼接 + // len(plaintext) == 16+4+len(rawXMLMsg)+len(appId) + if len(plaintext) <= 20 { + err = fmt.Errorf("plaintext too short, the length is %d", len(plaintext)) + return + } + rawXMLMsgLen := int(decodeNetworkByteOrder(plaintext[16:20])) + if rawXMLMsgLen < 0 { + err = fmt.Errorf("incorrect msg length: %d", rawXMLMsgLen) + return + } + appIDOffset := 20 + rawXMLMsgLen + if len(plaintext) <= appIDOffset { + err = fmt.Errorf("msg length too large: %d", rawXMLMsgLen) + return + } + + random = plaintext[:16:20] + rawXMLMsg = plaintext[20:appIDOffset:appIDOffset] + appID = plaintext[appIDOffset:] + return +} + +// 从 4 字节的网络字节序里解析出整数 +func decodeNetworkByteOrder(orderBytes []byte) (n uint32) { + return uint32(orderBytes[0])<<24 | + uint32(orderBytes[1])<<16 | + uint32(orderBytes[2])<<8 | + uint32(orderBytes[3]) +} diff --git a/util/signature.go b/util/signature.go new file mode 100644 index 0000000..a65ee23 --- /dev/null +++ b/util/signature.go @@ -0,0 +1,18 @@ +package util + +import ( + "crypto/sha1" + "fmt" + "io" + "sort" +) + +//Signature sha1签名 +func Signature(params ...string) string { + sort.Strings(params) + h := sha1.New() + for _, s := range params { + io.WriteString(h, s) + } + return fmt.Sprintf("%x", h.Sum(nil)) +} diff --git a/wechat.go b/wechat.go new file mode 100644 index 0000000..a4af8d9 --- /dev/null +++ b/wechat.go @@ -0,0 +1,50 @@ +package wechat + +import ( + "net/http" + + "github.com/silenceper/wechat/context" + "github.com/silenceper/wechat/log" + "github.com/silenceper/wechat/server" +) + +//Wechat struct +type Wechat struct { + Context *context.Context +} + +//Config for user +type Config struct { + AppID string + AppSecret string + Token string + EncodingAESKey string +} + +//NewWechat init +func NewWechat(cfg *Config) *Wechat { + + channelLen := int64(10000) + adapterName := "console" + config := "" + logLevel := log.LevelDebug + log.InitLogger(channelLen, adapterName, config, logLevel) + + context := new(context.Context) + copyConfigToContext(cfg, context) + return &Wechat{context} +} + +func copyConfigToContext(cfg *Config, context *context.Context) { + context.AppID = cfg.AppID + context.AppSecret = cfg.AppSecret + context.Token = cfg.Token + context.EncodingAESKey = cfg.EncodingAESKey +} + +//GetServer init +func (wc *Wechat) GetServer(req *http.Request, writer http.ResponseWriter) *server.Server { + wc.Context.Request = req + wc.Context.Writer = writer + return server.NewServer(wc.Context) +}