mirror of
https://github.com/silenceper/wechat.git
synced 2026-02-13 01:02:27 +08:00
实现消息解密
This commit is contained in:
39
context/context.go
Normal file
39
context/context.go
Normal file
@@ -0,0 +1,39 @@
|
|||||||
|
package context
|
||||||
|
|
||||||
|
import "net/http"
|
||||||
|
|
||||||
|
//Context struct
|
||||||
|
type Context struct {
|
||||||
|
AppID string
|
||||||
|
AppSecret string
|
||||||
|
Token string
|
||||||
|
EncodingAESKey string
|
||||||
|
|
||||||
|
Writer http.ResponseWriter
|
||||||
|
Request *http.Request
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ctx *Context) getAccessToken() {
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ctx *Context) String(str string) error {
|
||||||
|
ctx.Writer.WriteHeader(200)
|
||||||
|
_, err := ctx.Writer.Write([]byte(str))
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Query returns the keyed url query value if it exists
|
||||||
|
func (ctx *Context) Query(key string) string {
|
||||||
|
value, _ := ctx.GetQuery(key)
|
||||||
|
return value
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetQuery is like Query(), it returns the keyed url query value
|
||||||
|
func (ctx *Context) GetQuery(key string) (string, bool) {
|
||||||
|
req := ctx.Request
|
||||||
|
if values, ok := req.URL.Query()[key]; ok && len(values) > 0 {
|
||||||
|
return values[0], true
|
||||||
|
}
|
||||||
|
return "", false
|
||||||
|
}
|
||||||
61
log/log.go
Normal file
61
log/log.go
Normal file
@@ -0,0 +1,61 @@
|
|||||||
|
package log
|
||||||
|
|
||||||
|
import "github.com/astaxie/beego/logs"
|
||||||
|
|
||||||
|
const (
|
||||||
|
LevelEmergency = iota
|
||||||
|
LevelAlert
|
||||||
|
LevelCritical
|
||||||
|
LevelError
|
||||||
|
LevelWarning
|
||||||
|
LevelNotice
|
||||||
|
LevelInformational
|
||||||
|
LevelDebug
|
||||||
|
)
|
||||||
|
|
||||||
|
type Logger struct {
|
||||||
|
*logs.BeeLogger
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewLogger(channelLen int64, adapterName string, config string, logLevel int) *Logger {
|
||||||
|
logger := logs.NewLogger(channelLen)
|
||||||
|
logger.SetLogger(adapterName, config)
|
||||||
|
logger.SetLevel(logLevel)
|
||||||
|
logger.EnableFuncCallDepth(true)
|
||||||
|
logger.SetLogFuncCallDepth(3)
|
||||||
|
return &Logger{logger}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (logger *Logger) Printf(format string, v ...interface{}) {
|
||||||
|
logger.Trace(format, v...)
|
||||||
|
}
|
||||||
|
|
||||||
|
var l *Logger
|
||||||
|
|
||||||
|
func InitLogger(channelLen int64, adapterName string, config string, logLevel int) {
|
||||||
|
l = NewLogger(channelLen, adapterName, config, logLevel)
|
||||||
|
}
|
||||||
|
|
||||||
|
func Criticalf(format string, v ...interface{}) {
|
||||||
|
l.Critical(format, v...)
|
||||||
|
}
|
||||||
|
|
||||||
|
func Errorf(format string, v ...interface{}) {
|
||||||
|
l.Error(format, v...)
|
||||||
|
}
|
||||||
|
|
||||||
|
func Warnf(format string, v ...interface{}) {
|
||||||
|
l.Warn(format, v...)
|
||||||
|
}
|
||||||
|
|
||||||
|
func Infof(format string, v ...interface{}) {
|
||||||
|
l.Info(format, v...)
|
||||||
|
}
|
||||||
|
|
||||||
|
func Tracef(format string, v ...interface{}) {
|
||||||
|
l.Trace(format, v...)
|
||||||
|
}
|
||||||
|
|
||||||
|
func Debugf(format string, v ...interface{}) {
|
||||||
|
l.Debug(format, v...)
|
||||||
|
}
|
||||||
8
message/message.go
Normal file
8
message/message.go
Normal file
@@ -0,0 +1,8 @@
|
|||||||
|
package message
|
||||||
|
|
||||||
|
//EncryptedXMLMsg 安全模式下的消息体
|
||||||
|
type EncryptedXMLMsg struct {
|
||||||
|
XMLName struct{} `xml:"xml" json:"-"`
|
||||||
|
ToUserName string `xml:"ToUserName" json:"ToUserName"`
|
||||||
|
EncryptedMsg string `xml:"Encrypt" json:"Encrypt"`
|
||||||
|
}
|
||||||
99
server/server.go
Normal file
99
server/server.go
Normal file
@@ -0,0 +1,99 @@
|
|||||||
|
package server
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/xml"
|
||||||
|
"fmt"
|
||||||
|
"io/ioutil"
|
||||||
|
|
||||||
|
"github.com/silenceper/wechat/context"
|
||||||
|
"github.com/silenceper/wechat/message"
|
||||||
|
"github.com/silenceper/wechat/util"
|
||||||
|
)
|
||||||
|
|
||||||
|
//Server struct
|
||||||
|
type Server struct {
|
||||||
|
*context.Context
|
||||||
|
isSafeMode bool
|
||||||
|
rawXMLMsg string
|
||||||
|
}
|
||||||
|
|
||||||
|
//NewServer init
|
||||||
|
func NewServer(context *context.Context) *Server {
|
||||||
|
srv := new(Server)
|
||||||
|
srv.Context = context
|
||||||
|
return srv
|
||||||
|
}
|
||||||
|
|
||||||
|
//Serve 处理微信的请求消息
|
||||||
|
func (srv *Server) Serve() error {
|
||||||
|
|
||||||
|
if !srv.Validate() {
|
||||||
|
return fmt.Errorf("请求校验失败")
|
||||||
|
}
|
||||||
|
|
||||||
|
echostr, exists := srv.GetQuery("echostr")
|
||||||
|
if exists {
|
||||||
|
return srv.String(echostr)
|
||||||
|
}
|
||||||
|
|
||||||
|
srv.handleRequest()
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
//Validate 校验请求是否合法
|
||||||
|
func (srv *Server) Validate() bool {
|
||||||
|
timestamp := srv.Query("timestamp")
|
||||||
|
nonce := srv.Query("nonce")
|
||||||
|
signature := srv.Query("signature")
|
||||||
|
return signature == util.Signature(srv.Token, timestamp, nonce)
|
||||||
|
}
|
||||||
|
|
||||||
|
//HandleRequest 处理微信的请求
|
||||||
|
func (srv *Server) handleRequest() {
|
||||||
|
srv.isSafeMode = false
|
||||||
|
encryptType := srv.Query("encrypt_type")
|
||||||
|
if encryptType == "aes" {
|
||||||
|
srv.isSafeMode = true
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err := srv.getMessage()
|
||||||
|
if err != nil {
|
||||||
|
fmt.Printf("%v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (srv *Server) getMessage() (interface{}, error) {
|
||||||
|
var rawXMLMsgBytes []byte
|
||||||
|
var err error
|
||||||
|
if srv.isSafeMode {
|
||||||
|
var encryptedXMLMsg message.EncryptedXMLMsg
|
||||||
|
if err := xml.NewDecoder(srv.Request.Body).Decode(&encryptedXMLMsg); err != nil {
|
||||||
|
return nil, fmt.Errorf("从body中解析xml失败,err=%v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
//验证消息签名
|
||||||
|
timestamp := srv.Query("timestamp")
|
||||||
|
nonce := srv.Query("nonce")
|
||||||
|
msgSignature := srv.Query("msg_signature")
|
||||||
|
msgSignatureCreate := util.Signature(srv.Token, timestamp, nonce, encryptedXMLMsg.EncryptedMsg)
|
||||||
|
if msgSignature != msgSignatureCreate {
|
||||||
|
return nil, fmt.Errorf("消息不合法,验证签名失败")
|
||||||
|
}
|
||||||
|
|
||||||
|
//解密
|
||||||
|
rawXMLMsgBytes, err = util.DecryptMsg(srv.AppID, encryptedXMLMsg.EncryptedMsg, srv.EncodingAESKey)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("消息解密失败,err=%v", err)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
rawXMLMsgBytes, err = ioutil.ReadAll(srv.Request.Body)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("从body中解析xml失败,err=%v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
srv.rawXMLMsg = string(rawXMLMsgBytes)
|
||||||
|
fmt.Println(srv.rawXMLMsg)
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
113
util/crypto.go
Normal file
113
util/crypto.go
Normal file
@@ -0,0 +1,113 @@
|
|||||||
|
package util
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/aes"
|
||||||
|
"crypto/cipher"
|
||||||
|
"encoding/base64"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
)
|
||||||
|
|
||||||
|
//DecryptMsg 消息解密
|
||||||
|
func DecryptMsg(appID, encryptedMsg, aesKey string) (rawMsgXMLBytes []byte, err error) {
|
||||||
|
var encryptedMsgBytes, key, getAppIDBytes []byte
|
||||||
|
encryptedMsgBytes, err = base64.StdEncoding.DecodeString(encryptedMsg)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
key, err = aesKeyDecode(aesKey)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
_, rawMsgXMLBytes, getAppIDBytes, err = AESDecryptMsg(encryptedMsgBytes, key)
|
||||||
|
if err != nil {
|
||||||
|
err = fmt.Errorf("消息解密失败,%v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if appID != string(getAppIDBytes) {
|
||||||
|
err = fmt.Errorf("消息解密校验APPID失败")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func aesKeyDecode(encodedAESKey string) (key []byte, err error) {
|
||||||
|
if len(encodedAESKey) != 43 {
|
||||||
|
err = errors.New("the length of encodedAESKey must be equal to 43")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
key, err = base64.StdEncoding.DecodeString(encodedAESKey + "=")
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if len(key) != 32 {
|
||||||
|
err = errors.New("encodingAESKey invalid")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// AESDecryptMsg ciphertext = AES_Encrypt[random(16B) + msg_len(4B) + rawXMLMsg + appId]
|
||||||
|
func AESDecryptMsg(ciphertext []byte, aesKey []byte) (random, rawXMLMsg, appID []byte, err error) {
|
||||||
|
const (
|
||||||
|
BlockSize = 32 // PKCS#7
|
||||||
|
BlockMask = BlockSize - 1 // BLOCK_SIZE 为 2^n 时, 可以用 mask 获取针对 BLOCK_SIZE 的余数
|
||||||
|
)
|
||||||
|
|
||||||
|
if len(ciphertext) < BlockSize {
|
||||||
|
err = fmt.Errorf("the length of ciphertext too short: %d", len(ciphertext))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if len(ciphertext)&BlockMask != 0 {
|
||||||
|
err = fmt.Errorf("ciphertext is not a multiple of the block size, the length is %d", len(ciphertext))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
plaintext := make([]byte, len(ciphertext)) // len(plaintext) >= BLOCK_SIZE
|
||||||
|
|
||||||
|
// 解密
|
||||||
|
block, err := aes.NewCipher(aesKey)
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
mode := cipher.NewCBCDecrypter(block, aesKey[:16])
|
||||||
|
mode.CryptBlocks(plaintext, ciphertext)
|
||||||
|
|
||||||
|
// PKCS#7 去除补位
|
||||||
|
amountToPad := int(plaintext[len(plaintext)-1])
|
||||||
|
if amountToPad < 1 || amountToPad > BlockSize {
|
||||||
|
err = fmt.Errorf("the amount to pad is incorrect: %d", amountToPad)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
plaintext = plaintext[:len(plaintext)-amountToPad]
|
||||||
|
|
||||||
|
// 反拼接
|
||||||
|
// len(plaintext) == 16+4+len(rawXMLMsg)+len(appId)
|
||||||
|
if len(plaintext) <= 20 {
|
||||||
|
err = fmt.Errorf("plaintext too short, the length is %d", len(plaintext))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
rawXMLMsgLen := int(decodeNetworkByteOrder(plaintext[16:20]))
|
||||||
|
if rawXMLMsgLen < 0 {
|
||||||
|
err = fmt.Errorf("incorrect msg length: %d", rawXMLMsgLen)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
appIDOffset := 20 + rawXMLMsgLen
|
||||||
|
if len(plaintext) <= appIDOffset {
|
||||||
|
err = fmt.Errorf("msg length too large: %d", rawXMLMsgLen)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
random = plaintext[:16:20]
|
||||||
|
rawXMLMsg = plaintext[20:appIDOffset:appIDOffset]
|
||||||
|
appID = plaintext[appIDOffset:]
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 从 4 字节的网络字节序里解析出整数
|
||||||
|
func decodeNetworkByteOrder(orderBytes []byte) (n uint32) {
|
||||||
|
return uint32(orderBytes[0])<<24 |
|
||||||
|
uint32(orderBytes[1])<<16 |
|
||||||
|
uint32(orderBytes[2])<<8 |
|
||||||
|
uint32(orderBytes[3])
|
||||||
|
}
|
||||||
18
util/signature.go
Normal file
18
util/signature.go
Normal file
@@ -0,0 +1,18 @@
|
|||||||
|
package util
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/sha1"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"sort"
|
||||||
|
)
|
||||||
|
|
||||||
|
//Signature sha1签名
|
||||||
|
func Signature(params ...string) string {
|
||||||
|
sort.Strings(params)
|
||||||
|
h := sha1.New()
|
||||||
|
for _, s := range params {
|
||||||
|
io.WriteString(h, s)
|
||||||
|
}
|
||||||
|
return fmt.Sprintf("%x", h.Sum(nil))
|
||||||
|
}
|
||||||
50
wechat.go
Normal file
50
wechat.go
Normal file
@@ -0,0 +1,50 @@
|
|||||||
|
package wechat
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
|
||||||
|
"github.com/silenceper/wechat/context"
|
||||||
|
"github.com/silenceper/wechat/log"
|
||||||
|
"github.com/silenceper/wechat/server"
|
||||||
|
)
|
||||||
|
|
||||||
|
//Wechat struct
|
||||||
|
type Wechat struct {
|
||||||
|
Context *context.Context
|
||||||
|
}
|
||||||
|
|
||||||
|
//Config for user
|
||||||
|
type Config struct {
|
||||||
|
AppID string
|
||||||
|
AppSecret string
|
||||||
|
Token string
|
||||||
|
EncodingAESKey string
|
||||||
|
}
|
||||||
|
|
||||||
|
//NewWechat init
|
||||||
|
func NewWechat(cfg *Config) *Wechat {
|
||||||
|
|
||||||
|
channelLen := int64(10000)
|
||||||
|
adapterName := "console"
|
||||||
|
config := ""
|
||||||
|
logLevel := log.LevelDebug
|
||||||
|
log.InitLogger(channelLen, adapterName, config, logLevel)
|
||||||
|
|
||||||
|
context := new(context.Context)
|
||||||
|
copyConfigToContext(cfg, context)
|
||||||
|
return &Wechat{context}
|
||||||
|
}
|
||||||
|
|
||||||
|
func copyConfigToContext(cfg *Config, context *context.Context) {
|
||||||
|
context.AppID = cfg.AppID
|
||||||
|
context.AppSecret = cfg.AppSecret
|
||||||
|
context.Token = cfg.Token
|
||||||
|
context.EncodingAESKey = cfg.EncodingAESKey
|
||||||
|
}
|
||||||
|
|
||||||
|
//GetServer init
|
||||||
|
func (wc *Wechat) GetServer(req *http.Request, writer http.ResponseWriter) *server.Server {
|
||||||
|
wc.Context.Request = req
|
||||||
|
wc.Context.Writer = writer
|
||||||
|
return server.NewServer(wc.Context)
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user