1
0
mirror of https://github.com/silenceper/wechat.git synced 2026-02-04 12:52:27 +08:00

支持文本消息的回复

This commit is contained in:
wenzl
2016-09-11 00:27:37 +08:00
parent e713b4ffb2
commit 33f2b2ef60
10 changed files with 403 additions and 96 deletions

View File

@@ -17,12 +17,6 @@ func (ctx *Context) getAccessToken() {
}
func (ctx *Context) String(str string) error {
ctx.Writer.WriteHeader(200)
_, err := ctx.Writer.Write([]byte(str))
return err
}
// Query returns the keyed url query value if it exists
func (ctx *Context) Query(key string) string {
value, _ := ctx.GetQuery(key)

41
context/render.go Normal file
View File

@@ -0,0 +1,41 @@
package context
import (
"encoding/xml"
"net/http"
)
var xmlContentType = []string{"application/xml; charset=utf-8"}
var plainContentType = []string{"text/plain; charset=utf-8"}
//Render render from bytes
func (ctx *Context) Render(bytes []byte) {
ctx.Writer.WriteHeader(200)
_, err := ctx.Writer.Write(bytes)
if err != nil {
panic(err)
}
}
//String render from string
func (ctx *Context) String(str string) {
writeContextType(ctx.Writer, plainContentType)
ctx.Render([]byte(str))
}
//XML render to xml
func (ctx *Context) XML(obj interface{}) {
writeContextType(ctx.Writer, xmlContentType)
bytes, err := xml.Marshal(obj)
if err != nil {
panic(err)
}
ctx.Render(bytes)
}
func writeContextType(w http.ResponseWriter, value []string) {
header := w.Header()
if val := header["Content-Type"]; len(val) == 0 {
header["Content-Type"] = value
}
}

View File

@@ -1,61 +0,0 @@
package log
import "github.com/astaxie/beego/logs"
const (
LevelEmergency = iota
LevelAlert
LevelCritical
LevelError
LevelWarning
LevelNotice
LevelInformational
LevelDebug
)
type Logger struct {
*logs.BeeLogger
}
func NewLogger(channelLen int64, adapterName string, config string, logLevel int) *Logger {
logger := logs.NewLogger(channelLen)
logger.SetLogger(adapterName, config)
logger.SetLevel(logLevel)
logger.EnableFuncCallDepth(true)
logger.SetLogFuncCallDepth(3)
return &Logger{logger}
}
func (logger *Logger) Printf(format string, v ...interface{}) {
logger.Trace(format, v...)
}
var l *Logger
func InitLogger(channelLen int64, adapterName string, config string, logLevel int) {
l = NewLogger(channelLen, adapterName, config, logLevel)
}
func Criticalf(format string, v ...interface{}) {
l.Critical(format, v...)
}
func Errorf(format string, v ...interface{}) {
l.Error(format, v...)
}
func Warnf(format string, v ...interface{}) {
l.Warn(format, v...)
}
func Infof(format string, v ...interface{}) {
l.Info(format, v...)
}
func Tracef(format string, v ...interface{}) {
l.Trace(format, v...)
}
func Debugf(format string, v ...interface{}) {
l.Debug(format, v...)
}

View File

@@ -1,8 +1,120 @@
package message
import "encoding/xml"
//MsgType 基本消息类型
type MsgType string
//EventType 事件类型
type EventType string
const (
//MsgTypeText 表示文本消息
MsgTypeText MsgType = "text"
//MsgTypeImage 表示图片消息
MsgTypeImage = "image"
//MsgTypeVoice 表示语音消息
MsgTypeVoice = "voice"
//MsgTypeVideo 表示视频消息
MsgTypeVideo = "video"
//MsgTypeShortVideo 表示短视频消息[限接收]
MsgTypeShortVideo = "shortvideo"
//MsgTypeLocation 表示坐标消息[限接收]
MsgTypeLocation = "location"
//MsgTypeLink 表示链接消息[限接收]
MsgTypeLink = "link"
//MsgTypeMusic 表示音乐消息[限回复]
MsgTypeMusic = "music"
//MsgTypeNews 表示图文消息[限回复]
MsgTypeNews = "news"
//MsgTypeTransfer 表示消息消息转发到客服
MsgTypeTransfer = "transfer_customer_service"
)
const (
//EventSubscribe 订阅
EventSubscribe EventType = "subscribe"
//EventUnsubscribe 取消订阅
EventUnsubscribe EventType = "unsubscribe"
//EventScan 用户已经关注公众号,则微信会将带场景值扫描事件推送给开发者
EventScan EventType = "SCAN"
//EventLocation 上报地理位置事件
EventLocation EventType = "LOCATION"
//EventClick 点击菜单拉取消息时的事件推送
EventClick EventType = "CLICK"
//EventView 点击菜单跳转链接时的事件推送
EventView EventType = "VIEW"
)
//EncryptedXMLMsg 安全模式下的消息体
type EncryptedXMLMsg struct {
XMLName struct{} `xml:"xml" json:"-"`
ToUserName string `xml:"ToUserName" json:"ToUserName"`
EncryptedMsg string `xml:"Encrypt" json:"Encrypt"`
}
//ResponseEncryptedXMLMsg 需要返回的消息体
type ResponseEncryptedXMLMsg struct {
XMLName struct{} `xml:"xml" json:"-"`
EncryptedMsg string `xml:"Encrypt" json:"Encrypt"`
MsgSignature string `xml:"MsgSignature" json:"MsgSignature"`
Timestamp int64 `xml:"TimeStamp" json:"TimeStamp"`
Nonce string `xml:"Nonce" json:"Nonce"`
}
// CommonToken 消息中通用的结构
type CommonToken struct {
XMLName xml.Name `xml:"xml"`
ToUserName string `xml:"ToUserName"`
FromUserName string `xml:"FromUserName"`
CreateTime int64 `xml:"CreateTime"`
MsgType MsgType `xml:"MsgType"`
}
//SetToUserName set ToUserName
func (msg *CommonToken) SetToUserName(toUserName string) {
msg.ToUserName = toUserName
}
//SetFromUserName set FromUserName
func (msg *CommonToken) SetFromUserName(fromUserName string) {
msg.FromUserName = fromUserName
}
//SetCreateTime set createTime
func (msg *CommonToken) SetCreateTime(createTime int64) {
msg.CreateTime = createTime
}
//SetMsgType set MsgType
func (msg *CommonToken) SetMsgType(msgType MsgType) {
msg.MsgType = msgType
}
//MixMessage 存放所有微信发送过来的消息和事件
type MixMessage struct {
CommonToken
//基本消息
MsgID int64 `xml:"MsgId"`
Content string `xml:"Content"`
PicURL string `xml:"PicUrl"`
MediaID string `xml:"MediaId"`
Format string `xml:"Format"`
ThumbMediaID string `xml:"ThumbMediaId"`
LocationX float64 `xml:"Location_X"`
LocationY float64 `xml:"Location_Y"`
Scale float64 `xml:"Scale"`
Label string `xml:"Label"`
Title string `xml:"Title"`
Description string `xml:"Description"`
URL string `xml:"Url"`
//事件相关
Event string `xml:"Event"`
EventKey string `xml:"EventKey"`
Ticket string `xml:"Ticket"`
Latitude string `xml:"Latitude"`
Longitude string `xml:"Longitude"`
Precision string `xml:"Precision"`
}

15
message/reply.go Normal file
View File

@@ -0,0 +1,15 @@
package message
import "errors"
//ErrInvalidReply 无效的回复
var ErrInvalidReply = errors.New("无效的回复消息")
//ErrUnsupportReply 不支持的回复类型
var ErrUnsupportReply = errors.New("不支持的回复消息")
//Reply 消息回复
type Reply struct {
MsgType MsgType
MsgData interface{}
}

7
message/text.go Normal file
View File

@@ -0,0 +1,7 @@
package message
//Text 文本消息
type Text struct {
CommonToken
Content string `xml:"Content"`
}

View File

@@ -2,8 +2,12 @@ package server
import (
"encoding/xml"
"errors"
"fmt"
"io/ioutil"
"reflect"
"strconv"
"strings"
"github.com/silenceper/wechat/context"
"github.com/silenceper/wechat/message"
@@ -13,8 +17,20 @@ import (
//Server struct
type Server struct {
*context.Context
openID string
messageHandler func(message.MixMessage) *message.Reply
requestRawXMLMsg []byte
requestMsg message.MixMessage
responseRawXMLMsg []byte
responseMsg interface{}
isSafeMode bool
rawXMLMsg string
random []byte
nonce string
timestamp int64
}
//NewServer init
@@ -26,18 +42,22 @@ func NewServer(context *context.Context) *Server {
//Serve 处理微信的请求消息
func (srv *Server) Serve() error {
if !srv.Validate() {
return fmt.Errorf("请求校验失败")
}
echostr, exists := srv.GetQuery("echostr")
if exists {
return srv.String(echostr)
srv.String(echostr)
return nil
}
srv.handleRequest()
response, err := srv.handleRequest()
if err != nil {
return err
}
srv.buildResponse(response)
return nil
}
@@ -50,19 +70,37 @@ func (srv *Server) Validate() bool {
}
//HandleRequest 处理微信的请求
func (srv *Server) handleRequest() {
func (srv *Server) handleRequest() (reply *message.Reply, err error) {
//set isSafeMode
srv.isSafeMode = false
encryptType := srv.Query("encrypt_type")
if encryptType == "aes" {
srv.isSafeMode = true
}
_, err := srv.getMessage()
//set openID
srv.openID = srv.Query("openid")
var msg interface{}
msg, err = srv.getMessage()
if err != nil {
fmt.Printf("%v", err)
return
}
mixMessage, success := msg.(message.MixMessage)
if !success {
err = errors.New("消息类型转换失败")
}
srv.requestMsg = mixMessage
reply = srv.messageHandler(mixMessage)
return
}
//GetOpenID return openID
func (srv *Server) GetOpenID() string {
return srv.openID
}
//getMessage 解析微信返回的消息
func (srv *Server) getMessage() (interface{}, error) {
var rawXMLMsgBytes []byte
var err error
@@ -74,26 +112,117 @@ func (srv *Server) getMessage() (interface{}, error) {
//验证消息签名
timestamp := srv.Query("timestamp")
srv.timestamp, err = strconv.ParseInt(timestamp, 10, 32)
if err != nil {
return nil, err
}
nonce := srv.Query("nonce")
srv.nonce = nonce
msgSignature := srv.Query("msg_signature")
msgSignatureCreate := util.Signature(srv.Token, timestamp, nonce, encryptedXMLMsg.EncryptedMsg)
if msgSignature != msgSignatureCreate {
msgSignatureGen := util.Signature(srv.Token, timestamp, nonce, encryptedXMLMsg.EncryptedMsg)
if msgSignature != msgSignatureGen {
return nil, fmt.Errorf("消息不合法,验证签名失败")
}
//解密
rawXMLMsgBytes, err = util.DecryptMsg(srv.AppID, encryptedXMLMsg.EncryptedMsg, srv.EncodingAESKey)
srv.random, rawXMLMsgBytes, err = util.DecryptMsg(srv.AppID, encryptedXMLMsg.EncryptedMsg, srv.EncodingAESKey)
if err != nil {
return nil, fmt.Errorf("消息解密失败,err=%v", err)
return nil, fmt.Errorf("消息解密失败, err=%v", err)
}
} else {
rawXMLMsgBytes, err = ioutil.ReadAll(srv.Request.Body)
if err != nil {
return nil, fmt.Errorf("从body中解析xml失败,err=%v", err)
return nil, fmt.Errorf("从body中解析xml失败, err=%v", err)
}
}
srv.rawXMLMsg = string(rawXMLMsgBytes)
fmt.Println(srv.rawXMLMsg)
return nil, nil
srv.requestRawXMLMsg = rawXMLMsgBytes
return srv.parseRequestMessage(rawXMLMsgBytes)
}
func (srv *Server) parseRequestMessage(rawXMLMsgBytes []byte) (msg message.MixMessage, err error) {
msg = message.MixMessage{}
err = xml.Unmarshal(rawXMLMsgBytes, &msg)
return
}
//SetMessageHandler 设置用户自定义的回调方法
func (srv *Server) SetMessageHandler(handler func(message.MixMessage) *message.Reply) {
srv.messageHandler = handler
}
func (srv *Server) buildResponse(reply *message.Reply) (err error) {
defer func() {
if e := recover(); e != nil {
err = fmt.Errorf("panic error: %v", err)
}
}()
if reply == nil {
//do nothing
return nil
}
msgType := reply.MsgType
switch msgType {
case message.MsgTypeText:
case message.MsgTypeImage:
case message.MsgTypeVoice:
case message.MsgTypeVideo:
case message.MsgTypeMusic:
case message.MsgTypeNews:
case message.MsgTypeTransfer:
default:
err = message.ErrUnsupportReply
return
}
msgData := reply.MsgData
value := reflect.ValueOf(msgData)
//msgData must be a ptr
kind := value.Kind().String()
if 0 != strings.Compare("ptr", kind) {
return message.ErrUnsupportReply
}
params := make([]reflect.Value, 1)
params[0] = reflect.ValueOf(srv.requestMsg.FromUserName)
value.MethodByName("SetToUserName").Call(params)
params[0] = reflect.ValueOf(srv.requestMsg.ToUserName)
value.MethodByName("SetFromUserName").Call(params)
params[0] = reflect.ValueOf(msgType)
value.MethodByName("SetMsgType").Call(params)
params[0] = reflect.ValueOf(util.GetCurrTs())
value.MethodByName("SetCreateTime").Call(params)
srv.responseMsg = msgData
srv.responseRawXMLMsg, err = xml.Marshal(msgData)
return
}
//Send 将自定义的消息发送
func (srv *Server) Send() (err error) {
replyMsg := srv.responseMsg
if srv.isSafeMode {
//安全模式下对消息进行加密
var encryptedMsg []byte
encryptedMsg, err = util.EncryptMsg(srv.random, srv.responseRawXMLMsg, srv.AppID, srv.EncodingAESKey)
if err != nil {
return
}
//TODO 如果获取不到timestamp nonce 则自己生成
timestamp := srv.timestamp
timestampStr := strconv.FormatInt(timestamp, 10)
msgSignature := util.Signature(srv.Token, timestampStr, srv.nonce, string(encryptedMsg))
replyMsg = message.ResponseEncryptedXMLMsg{
EncryptedMsg: string(encryptedMsg),
MsgSignature: msgSignature,
Timestamp: timestamp,
Nonce: srv.nonce,
}
}
srv.XML(replyMsg)
return
}

View File

@@ -4,12 +4,73 @@ import (
"crypto/aes"
"crypto/cipher"
"encoding/base64"
"errors"
"fmt"
)
//EncryptMsg 加密消息
func EncryptMsg(random, rawXMLMsg []byte, appID, aesKey string) (encrtptMsg []byte, err error) {
defer func() {
if e := recover(); e != nil {
err = fmt.Errorf("panic error: err=%v", e)
return
}
}()
var key []byte
key, err = aesKeyDecode(aesKey)
if err != nil {
panic(err)
}
ciphertext := AESEncryptMsg(random, rawXMLMsg, appID, key)
encrtptMsg = []byte(base64.StdEncoding.EncodeToString(ciphertext))
return
}
//AESEncryptMsg ciphertext = AES_Encrypt[random(16B) + msg_len(4B) + rawXMLMsg + appId]
//参考github.com/chanxuehong/wechat.v2
func AESEncryptMsg(random, rawXMLMsg []byte, appID string, aesKey []byte) (ciphertext []byte) {
const (
BlockSize = 32 // PKCS#7
BlockMask = BlockSize - 1 // BLOCK_SIZE 为 2^n 时, 可以用 mask 获取针对 BLOCK_SIZE 的余数
)
appIDOffset := 20 + len(rawXMLMsg)
contentLen := appIDOffset + len(appID)
amountToPad := BlockSize - contentLen&BlockMask
plaintextLen := contentLen + amountToPad
plaintext := make([]byte, plaintextLen)
// 拼接
copy(plaintext[:16], random)
encodeNetworkByteOrder(plaintext[16:20], uint32(len(rawXMLMsg)))
copy(plaintext[20:], rawXMLMsg)
copy(plaintext[appIDOffset:], appID)
// PKCS#7 补位
for i := contentLen; i < plaintextLen; i++ {
plaintext[i] = byte(amountToPad)
}
// 加密
block, err := aes.NewCipher(aesKey[:])
if err != nil {
panic(err)
}
mode := cipher.NewCBCEncrypter(block, aesKey[:16])
mode.CryptBlocks(plaintext, plaintext)
ciphertext = plaintext
return
}
//DecryptMsg 消息解密
func DecryptMsg(appID, encryptedMsg, aesKey string) (rawMsgXMLBytes []byte, err error) {
func DecryptMsg(appID, encryptedMsg, aesKey string) (random, rawMsgXMLBytes []byte, err error) {
defer func() {
if e := recover(); e != nil {
err = fmt.Errorf("panic error: err=%v", e)
return
}
}()
var encryptedMsgBytes, key, getAppIDBytes []byte
encryptedMsgBytes, err = base64.StdEncoding.DecodeString(encryptedMsg)
if err != nil {
@@ -17,9 +78,9 @@ func DecryptMsg(appID, encryptedMsg, aesKey string) (rawMsgXMLBytes []byte, err
}
key, err = aesKeyDecode(aesKey)
if err != nil {
return
panic(err)
}
_, rawMsgXMLBytes, getAppIDBytes, err = AESDecryptMsg(encryptedMsgBytes, key)
random, rawMsgXMLBytes, getAppIDBytes, err = AESDecryptMsg(encryptedMsgBytes, key)
if err != nil {
err = fmt.Errorf("消息解密失败,%v", err)
return
@@ -33,7 +94,7 @@ func DecryptMsg(appID, encryptedMsg, aesKey string) (rawMsgXMLBytes []byte, err
func aesKeyDecode(encodedAESKey string) (key []byte, err error) {
if len(encodedAESKey) != 43 {
err = errors.New("the length of encodedAESKey must be equal to 43")
err = fmt.Errorf("the length of encodedAESKey must be equal to 43")
return
}
key, err = base64.StdEncoding.DecodeString(encodedAESKey + "=")
@@ -41,13 +102,14 @@ func aesKeyDecode(encodedAESKey string) (key []byte, err error) {
return
}
if len(key) != 32 {
err = errors.New("encodingAESKey invalid")
err = fmt.Errorf("encodingAESKey invalid")
return
}
return
}
// AESDecryptMsg ciphertext = AES_Encrypt[random(16B) + msg_len(4B) + rawXMLMsg + appId]
//参考github.com/chanxuehong/wechat.v2
func AESDecryptMsg(ciphertext []byte, aesKey []byte) (random, rawXMLMsg, appID []byte, err error) {
const (
BlockSize = 32 // PKCS#7
@@ -104,6 +166,14 @@ func AESDecryptMsg(ciphertext []byte, aesKey []byte) (random, rawXMLMsg, appID [
return
}
// 把整数 n 格式化成 4 字节的网络字节序
func encodeNetworkByteOrder(orderBytes []byte, n uint32) {
orderBytes[0] = byte(n >> 24)
orderBytes[1] = byte(n >> 16)
orderBytes[2] = byte(n >> 8)
orderBytes[3] = byte(n)
}
// 从 4 字节的网络字节序里解析出整数
func decodeNetworkByteOrder(orderBytes []byte) (n uint32) {
return uint32(orderBytes[0])<<24 |

8
util/time.go Normal file
View File

@@ -0,0 +1,8 @@
package util
import "time"
//GetCurrTs return current timestamps
func GetCurrTs() int64 {
return time.Now().Unix()
}

View File

@@ -4,7 +4,6 @@ import (
"net/http"
"github.com/silenceper/wechat/context"
"github.com/silenceper/wechat/log"
"github.com/silenceper/wechat/server"
)
@@ -23,13 +22,6 @@ type Config struct {
//NewWechat init
func NewWechat(cfg *Config) *Wechat {
channelLen := int64(10000)
adapterName := "console"
config := ""
logLevel := log.LevelDebug
log.InitLogger(channelLen, adapterName, config, logLevel)
context := new(context.Context)
copyConfigToContext(cfg, context)
return &Wechat{context}