mirror of
https://github.com/silenceper/wechat.git
synced 2026-02-06 05:32:26 +08:00
11
cache/cache.go
vendored
Normal file
11
cache/cache.go
vendored
Normal file
@@ -0,0 +1,11 @@
|
||||
package cache
|
||||
|
||||
import "time"
|
||||
|
||||
//Cache interface
|
||||
type Cache interface {
|
||||
Get(key string) interface{}
|
||||
Set(key string, val interface{}, timeout time.Duration) error
|
||||
IsExist(key string) bool
|
||||
Delete(key string) error
|
||||
}
|
||||
51
cache/memcache.go
vendored
Normal file
51
cache/memcache.go
vendored
Normal file
@@ -0,0 +1,51 @@
|
||||
package cache
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"time"
|
||||
|
||||
"github.com/bradfitz/gomemcache/memcache"
|
||||
)
|
||||
|
||||
//Memcache struct contains *memcache.Client
|
||||
type Memcache struct {
|
||||
conn *memcache.Client
|
||||
}
|
||||
|
||||
//NewMemcache create new memcache
|
||||
func NewMemcache(server ...string) *Memcache {
|
||||
mc := memcache.New(server...)
|
||||
return &Memcache{mc}
|
||||
}
|
||||
|
||||
//Get return cached value
|
||||
func (mem *Memcache) Get(key string) interface{} {
|
||||
if item, err := mem.conn.Get(key); err == nil {
|
||||
return string(item.Value)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// IsExist check value exists in memcache.
|
||||
func (mem *Memcache) IsExist(key string) bool {
|
||||
_, err := mem.conn.Get(key)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
//Set cached value with key and expire time.
|
||||
func (mem *Memcache) Set(key string, val interface{}, timeout time.Duration) error {
|
||||
v, ok := val.(string)
|
||||
if !ok {
|
||||
return errors.New("val must string")
|
||||
}
|
||||
item := &memcache.Item{Key: key, Value: []byte(v), Expiration: int32(timeout / time.Second)}
|
||||
return mem.conn.Set(item)
|
||||
}
|
||||
|
||||
//Delete delete value in memcache.
|
||||
func (mem *Memcache) Delete(key string) error {
|
||||
return mem.conn.Delete(key)
|
||||
}
|
||||
28
cache/memcache_test.go
vendored
Normal file
28
cache/memcache_test.go
vendored
Normal file
@@ -0,0 +1,28 @@
|
||||
package cache
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestMemcache(t *testing.T) {
|
||||
mem := NewMemcache("127.0.0.1:11211")
|
||||
var err error
|
||||
timeoutDuration := 10 * time.Second
|
||||
if err = mem.Set("username", "silenceper", timeoutDuration); err != nil {
|
||||
t.Error("set Error", err)
|
||||
}
|
||||
|
||||
if !mem.IsExist("username") {
|
||||
t.Error("IsExist Error")
|
||||
}
|
||||
|
||||
name := mem.Get("username").(string)
|
||||
if name != "silenceper" {
|
||||
t.Error("get Error")
|
||||
}
|
||||
|
||||
if err = mem.Delete("username"); err != nil {
|
||||
t.Errorf("delete Error , err=%v", err)
|
||||
}
|
||||
}
|
||||
71
context/access_token.go
Normal file
71
context/access_token.go
Normal file
@@ -0,0 +1,71 @@
|
||||
package context
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/silenceper/wechat/util"
|
||||
)
|
||||
|
||||
const (
|
||||
//AccessTokenURL 获取access_token的接口
|
||||
AccessTokenURL = "https://api.weixin.qq.com/cgi-bin/token"
|
||||
)
|
||||
|
||||
//ResAccessToken struct
|
||||
type ResAccessToken struct {
|
||||
util.CommonError
|
||||
|
||||
AccessToken string `json:"access_token"`
|
||||
ExpiresIn int64 `json:"expires_in"`
|
||||
}
|
||||
|
||||
//SetAccessTokenLock 设置读写锁(一个appID一个读写锁)
|
||||
func (ctx *Context) SetAccessTokenLock(l *sync.RWMutex) {
|
||||
ctx.accessTokenLock = l
|
||||
}
|
||||
|
||||
//GetAccessToken 获取access_token
|
||||
func (ctx *Context) GetAccessToken() (accessToken string, err error) {
|
||||
ctx.accessTokenLock.Lock()
|
||||
defer ctx.accessTokenLock.Unlock()
|
||||
|
||||
accessTokenCacheKey := fmt.Sprintf("access_token_%s", ctx.AppID)
|
||||
val := ctx.Cache.Get(accessTokenCacheKey)
|
||||
if val != nil {
|
||||
accessToken = val.(string)
|
||||
return
|
||||
}
|
||||
|
||||
//从微信服务器获取
|
||||
var resAccessToken ResAccessToken
|
||||
resAccessToken, err = ctx.GetAccessTokenFromServer()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
accessToken = resAccessToken.AccessToken
|
||||
return
|
||||
}
|
||||
|
||||
//GetAccessTokenFromServer 强制从微信服务器获取token
|
||||
func (ctx *Context) GetAccessTokenFromServer() (resAccessToken ResAccessToken, err error) {
|
||||
url := fmt.Sprintf("%s?grant_type=client_credential&appid=%s&secret=%s", AccessTokenURL, ctx.AppID, ctx.AppSecret)
|
||||
var body []byte
|
||||
body, err = util.HTTPGet(url)
|
||||
err = json.Unmarshal(body, &resAccessToken)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
if resAccessToken.ErrMsg != "" {
|
||||
err = fmt.Errorf("get access_token error : errcode=%v , errormsg=%v", resAccessToken.ErrCode, resAccessToken.ErrMsg)
|
||||
return
|
||||
}
|
||||
|
||||
accessTokenCacheKey := fmt.Sprintf("access_token_%s", ctx.AppID)
|
||||
expires := resAccessToken.ExpiresIn - 1500
|
||||
err = ctx.Cache.Set(accessTokenCacheKey, resAccessToken.AccessToken, time.Duration(expires)*time.Second)
|
||||
return
|
||||
}
|
||||
51
context/context.go
Normal file
51
context/context.go
Normal file
@@ -0,0 +1,51 @@
|
||||
package context
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"sync"
|
||||
|
||||
"github.com/silenceper/wechat/cache"
|
||||
)
|
||||
|
||||
//Context struct
|
||||
type Context struct {
|
||||
AppID string
|
||||
AppSecret string
|
||||
Token string
|
||||
EncodingAESKey string
|
||||
|
||||
Cache cache.Cache
|
||||
|
||||
Writer http.ResponseWriter
|
||||
Request *http.Request
|
||||
|
||||
//accessTokenLock 读写锁 同一个AppID一个
|
||||
accessTokenLock *sync.RWMutex
|
||||
|
||||
//jsapiTicket 读写锁 同一个AppID一个
|
||||
jsApiTicketLock *sync.RWMutex
|
||||
}
|
||||
|
||||
// Query returns the keyed url query value if it exists
|
||||
func (ctx *Context) Query(key string) string {
|
||||
value, _ := ctx.GetQuery(key)
|
||||
return value
|
||||
}
|
||||
|
||||
// GetQuery is like Query(), it returns the keyed url query value
|
||||
func (ctx *Context) GetQuery(key string) (string, bool) {
|
||||
req := ctx.Request
|
||||
if values, ok := req.URL.Query()[key]; ok && len(values) > 0 {
|
||||
return values[0], true
|
||||
}
|
||||
return "", false
|
||||
}
|
||||
|
||||
//SetJsApiTicket 设置jsApiTicket的lock
|
||||
func (ctx *Context) SetJsApiTicketLock(lock *sync.RWMutex) {
|
||||
ctx.jsApiTicketLock = lock
|
||||
}
|
||||
|
||||
func (ctx *Context) GetJsApiTicketLock() *sync.RWMutex {
|
||||
return ctx.jsApiTicketLock
|
||||
}
|
||||
43
context/render.go
Normal file
43
context/render.go
Normal file
@@ -0,0 +1,43 @@
|
||||
package context
|
||||
|
||||
import (
|
||||
"encoding/xml"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
var xmlContentType = []string{"application/xml; charset=utf-8"}
|
||||
var plainContentType = []string{"text/plain; charset=utf-8"}
|
||||
|
||||
//Render render from bytes
|
||||
func (ctx *Context) Render(bytes []byte) {
|
||||
//debug
|
||||
//fmt.Println("response msg = ", string(bytes))
|
||||
ctx.Writer.WriteHeader(200)
|
||||
_, err := ctx.Writer.Write(bytes)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
|
||||
//String render from string
|
||||
func (ctx *Context) String(str string) {
|
||||
writeContextType(ctx.Writer, plainContentType)
|
||||
ctx.Render([]byte(str))
|
||||
}
|
||||
|
||||
//XML render to xml
|
||||
func (ctx *Context) XML(obj interface{}) {
|
||||
writeContextType(ctx.Writer, xmlContentType)
|
||||
bytes, err := xml.Marshal(obj)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
ctx.Render(bytes)
|
||||
}
|
||||
|
||||
func writeContextType(w http.ResponseWriter, value []string) {
|
||||
header := w.Header()
|
||||
if val := header["Content-Type"]; len(val) == 0 {
|
||||
header["Content-Type"] = value
|
||||
}
|
||||
}
|
||||
109
js/js.go
Normal file
109
js/js.go
Normal file
@@ -0,0 +1,109 @@
|
||||
package js
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/silenceper/wechat/context"
|
||||
"github.com/silenceper/wechat/util"
|
||||
)
|
||||
|
||||
const getTicketURL = "https://api.weixin.qq.com/cgi-bin/ticket/getticket?access_token=%s&type=jsapi"
|
||||
|
||||
//Js struct
|
||||
type Js struct {
|
||||
*context.Context
|
||||
}
|
||||
|
||||
//Config 返回给用户jssdk配置信息
|
||||
type Config struct {
|
||||
AppID string
|
||||
TimeStamp int64
|
||||
NonceStr string
|
||||
Signature string
|
||||
}
|
||||
|
||||
//resTicket 请求jsapi_tikcet返回结果
|
||||
type resTicket struct {
|
||||
util.CommonError
|
||||
|
||||
Ticket string `json:"ticket"`
|
||||
ExpiresIn int64 `json:"expires_in"`
|
||||
}
|
||||
|
||||
//NewJs init
|
||||
func NewJs(context *context.Context) *Js {
|
||||
js := new(Js)
|
||||
js.Context = context
|
||||
return js
|
||||
}
|
||||
|
||||
//GetConfig 获取jssdk需要的配置参数
|
||||
//uri 为当前网页地址
|
||||
func (js *Js) GetConfig(uri string) (config *Config, err error) {
|
||||
config = new(Config)
|
||||
var ticketStr string
|
||||
ticketStr, err = js.getTicket()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
nonceStr := util.RandomStr(16)
|
||||
timestamp := util.GetCurrTs()
|
||||
str := fmt.Sprintf("jsapi_ticket=%s&noncestr=%s×tamp=%d&url=%s", ticketStr, nonceStr, timestamp, uri)
|
||||
sigStr := util.Signature(str)
|
||||
|
||||
config.AppID = js.AppID
|
||||
config.NonceStr = nonceStr
|
||||
config.TimeStamp = timestamp
|
||||
config.Signature = sigStr
|
||||
return
|
||||
}
|
||||
|
||||
//getTicket 获取jsapi_tocket全局缓存
|
||||
func (js *Js) getTicket() (ticketStr string, err error) {
|
||||
js.GetJsApiTicketLock().Lock()
|
||||
defer js.GetJsApiTicketLock().Unlock()
|
||||
|
||||
//先从cache中取
|
||||
jsAPITicketCacheKey := fmt.Sprintf("jsapi_ticket_%s", js.AppID)
|
||||
val := js.Cache.Get(jsAPITicketCacheKey)
|
||||
if val != nil {
|
||||
ticketStr = val.(string)
|
||||
return
|
||||
}
|
||||
var ticket resTicket
|
||||
ticket, err = js.getTicketFromServer()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
ticketStr = ticket.Ticket
|
||||
return
|
||||
}
|
||||
|
||||
//getTicketFromServer 强制从服务器中获取ticket
|
||||
func (js *Js) getTicketFromServer() (ticket resTicket, err error) {
|
||||
var accessToken string
|
||||
accessToken, err = js.GetAccessToken()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
var response []byte
|
||||
url := fmt.Sprintf(getTicketURL, accessToken)
|
||||
response, err = util.HTTPGet(url)
|
||||
err = json.Unmarshal(response, &ticket)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
if ticket.ErrCode != 0 {
|
||||
err = fmt.Errorf("getTicket Error : errcode=%s , errmsg=%s", ticket.ErrCode, ticket.ErrMsg)
|
||||
return
|
||||
}
|
||||
|
||||
jsAPITicketCacheKey := fmt.Sprintf("jsapi_ticket_%s", js.AppID)
|
||||
expires := ticket.ExpiresIn - 1500
|
||||
err = js.Cache.Set(jsAPITicketCacheKey, ticket.Ticket, time.Duration(expires)*time.Second)
|
||||
return
|
||||
}
|
||||
167
material/material.go
Normal file
167
material/material.go
Normal file
@@ -0,0 +1,167 @@
|
||||
package material
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
|
||||
"github.com/silenceper/wechat/context"
|
||||
"github.com/silenceper/wechat/util"
|
||||
)
|
||||
|
||||
const (
|
||||
addNewsURL = "https://api.weixin.qq.com/cgi-bin/material/add_news"
|
||||
addMaterialURL = "https://api.weixin.qq.com/cgi-bin/material/add_material"
|
||||
)
|
||||
|
||||
//Material 素材管理
|
||||
type Material struct {
|
||||
*context.Context
|
||||
}
|
||||
|
||||
//NewMaterial init
|
||||
func NewMaterial(context *context.Context) *Material {
|
||||
material := new(Material)
|
||||
material.Context = context
|
||||
return material
|
||||
}
|
||||
|
||||
//Article 永久图文素材
|
||||
type Article struct {
|
||||
ThumbMediaID string `json:"thumb_media_id"`
|
||||
Author string `json:"author"`
|
||||
Digest string `json:"digest"`
|
||||
ShowCoverPic int `json:"show_cover_pic"`
|
||||
Content string `json:"content"`
|
||||
ContentSourceURL string `json:"content_source_url"`
|
||||
}
|
||||
|
||||
//reqArticles 永久性图文素材请求信息
|
||||
type reqArticles struct {
|
||||
articles []*Article `json:"articles"`
|
||||
}
|
||||
|
||||
//resArticles 永久性图文素材返回结果
|
||||
type resArticles struct {
|
||||
util.CommonError
|
||||
|
||||
MediaID string `json:"media_id"`
|
||||
}
|
||||
|
||||
//AddNews 新增永久图文素材
|
||||
func (material *Material) AddNews(articles []*Article) (mediaID string, err error) {
|
||||
req := &reqArticles{articles}
|
||||
|
||||
var accessToken string
|
||||
accessToken, err = material.GetAccessToken()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
uri := fmt.Sprintf("%s?access_token=%s", addNewsURL, accessToken)
|
||||
responseBytes, err := util.PostJSON(uri, req)
|
||||
var res resArticles
|
||||
err = json.Unmarshal(responseBytes, res)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
mediaID = res.MediaID
|
||||
return
|
||||
}
|
||||
|
||||
//resAddMaterial 永久性素材上传返回的结果
|
||||
type resAddMaterial struct {
|
||||
util.CommonError
|
||||
|
||||
MediaID string `json:"media_id"`
|
||||
URL string `json:"url"`
|
||||
}
|
||||
|
||||
//AddMaterial 上传永久性素材(处理视频需要单独上传)
|
||||
func (material *Material) AddMaterial(mediaType MediaType, filename string) (mediaID string, url string, err error) {
|
||||
if mediaType == MediaTypeVideo {
|
||||
err = errors.New("永久视频素材上传使用 AddVideo 方法")
|
||||
}
|
||||
var accessToken string
|
||||
accessToken, err = material.GetAccessToken()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
uri := fmt.Sprintf("%s?access_token=%s&type=%s", addMaterialURL, accessToken, mediaType)
|
||||
var response []byte
|
||||
response, err = util.PostFile("media", filename, uri)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
var resMaterial resAddMaterial
|
||||
err = json.Unmarshal(response, &resMaterial)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
if resMaterial.ErrCode != 0 {
|
||||
err = fmt.Errorf("AddMaterial error : errcode=%v , errmsg=%v", resMaterial.ErrCode, resMaterial.ErrMsg)
|
||||
return
|
||||
}
|
||||
mediaID = resMaterial.MediaID
|
||||
url = resMaterial.URL
|
||||
return
|
||||
}
|
||||
|
||||
type reqVideo struct {
|
||||
Title string `json:"title"`
|
||||
Introduction string `json:"introduction"`
|
||||
}
|
||||
|
||||
//AddVideo 永久视频素材文件上传
|
||||
func (material *Material) AddVideo(filename, title, introduction string) (mediaID string, url string, err error) {
|
||||
var accessToken string
|
||||
accessToken, err = material.GetAccessToken()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
uri := fmt.Sprintf("%s?access_token=%s&type=video", addMaterialURL, accessToken)
|
||||
|
||||
videoDesc := &reqVideo{
|
||||
Title: title,
|
||||
Introduction: introduction,
|
||||
}
|
||||
var fieldValue []byte
|
||||
fieldValue, err = json.Marshal(videoDesc)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
fields := []util.MultipartFormField{
|
||||
{
|
||||
IsFile: true,
|
||||
Fieldname: "video",
|
||||
Filename: filename,
|
||||
},
|
||||
{
|
||||
IsFile: true,
|
||||
Fieldname: "description",
|
||||
Value: fieldValue,
|
||||
},
|
||||
}
|
||||
|
||||
var response []byte
|
||||
response, err = util.PostMultipartForm(fields, uri)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
var resMaterial resAddMaterial
|
||||
err = json.Unmarshal(response, &resMaterial)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
if resMaterial.ErrCode != 0 {
|
||||
err = fmt.Errorf("AddMaterial error : errcode=%v , errmsg=%v", resMaterial.ErrCode, resMaterial.ErrMsg)
|
||||
return
|
||||
}
|
||||
mediaID = resMaterial.MediaID
|
||||
url = resMaterial.URL
|
||||
return
|
||||
}
|
||||
109
material/media.go
Normal file
109
material/media.go
Normal file
@@ -0,0 +1,109 @@
|
||||
package material
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
|
||||
"github.com/silenceper/wechat/util"
|
||||
)
|
||||
|
||||
//MediaType 媒体文件类型
|
||||
type MediaType string
|
||||
|
||||
const (
|
||||
//MediaTypeImage 媒体文件:图片
|
||||
MediaTypeImage MediaType = "image"
|
||||
//MediaTypeVoice 媒体文件:声音
|
||||
MediaTypeVoice = "voice"
|
||||
//MediaTypeVideo 媒体文件:视频
|
||||
MediaTypeVideo = "video"
|
||||
//MediaTypeThumb 媒体文件:缩略图
|
||||
MediaTypeThumb = "thumb"
|
||||
)
|
||||
|
||||
const (
|
||||
mediaUploadURL = "https://api.weixin.qq.com/cgi-bin/media/upload"
|
||||
mediaUploadImageURL = "https://api.weixin.qq.com/cgi-bin/media/uploadimg"
|
||||
mediaGetURL = "https://api.weixin.qq.com/cgi-bin/media/get"
|
||||
)
|
||||
|
||||
//Media 临时素材上传返回信息
|
||||
type Media struct {
|
||||
util.CommonError
|
||||
|
||||
Type MediaType `json:"type"`
|
||||
MediaID string `json:"media_id"`
|
||||
CreatedAt int64 `json:"created_at"`
|
||||
}
|
||||
|
||||
//MediaUpload 临时素材上传
|
||||
func (material *Material) MediaUpload(mediaType MediaType, filename string) (media Media, err error) {
|
||||
var accessToken string
|
||||
accessToken, err = material.GetAccessToken()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
uri := fmt.Sprintf("%s?access_token=%s&type=%s", mediaUploadURL, accessToken, mediaType)
|
||||
var response []byte
|
||||
response, err = util.PostFile("media", filename, uri)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
err = json.Unmarshal(response, &media)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
if media.ErrCode != 0 {
|
||||
err = fmt.Errorf("MediaUpload error : errcode=%v , errmsg=%v", media.ErrCode, media.ErrMsg)
|
||||
return
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
//GetMediaURL 返回临时素材的下载地址供用户自己处理
|
||||
//NOTICE: URL 不可公开,因为含access_token 需要立即另存文件
|
||||
func (material *Material) GetMediaURL(mediaID string) (mediaURL string, err error) {
|
||||
var accessToken string
|
||||
accessToken, err = material.GetAccessToken()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
mediaURL = fmt.Sprintf("%s?access_token=%s&media_id=%s", mediaGetURL, accessToken, mediaID)
|
||||
return
|
||||
}
|
||||
|
||||
//resMediaImage 图片上传返回结果
|
||||
type resMediaImage struct {
|
||||
util.CommonError
|
||||
|
||||
URL string `json:"url"`
|
||||
}
|
||||
|
||||
//ImageUpload 图片上传
|
||||
func (material *Material) ImageUpload(filename string) (url string, err error) {
|
||||
var accessToken string
|
||||
accessToken, err = material.GetAccessToken()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
uri := fmt.Sprintf("%s?access_token=%s", mediaUploadImageURL, accessToken)
|
||||
var response []byte
|
||||
response, err = util.PostFile("media", filename, uri)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
var image resMediaImage
|
||||
err = json.Unmarshal(response, &image)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
if image.ErrCode != 0 {
|
||||
err = fmt.Errorf("UploadImage error : errcode=%v , errmsg=%v", image.ErrCode, image.ErrMsg)
|
||||
return
|
||||
}
|
||||
url = image.URL
|
||||
return
|
||||
|
||||
}
|
||||
17
message/image.go
Normal file
17
message/image.go
Normal file
@@ -0,0 +1,17 @@
|
||||
package message
|
||||
|
||||
//Image 图片消息
|
||||
type Image struct {
|
||||
CommonToken
|
||||
|
||||
Image struct {
|
||||
MediaID string `xml:"MediaId"`
|
||||
} `xml:"Image"`
|
||||
}
|
||||
|
||||
//NewImage 回复图片消息
|
||||
func NewImage(mediaID string) *Image {
|
||||
image := new(Image)
|
||||
image.Image.MediaID = mediaID
|
||||
return image
|
||||
}
|
||||
120
message/message.go
Normal file
120
message/message.go
Normal file
@@ -0,0 +1,120 @@
|
||||
package message
|
||||
|
||||
import "encoding/xml"
|
||||
|
||||
//MsgType 基本消息类型
|
||||
type MsgType string
|
||||
|
||||
//EventType 事件类型
|
||||
type EventType string
|
||||
|
||||
const (
|
||||
//MsgTypeText 表示文本消息
|
||||
MsgTypeText MsgType = "text"
|
||||
//MsgTypeImage 表示图片消息
|
||||
MsgTypeImage = "image"
|
||||
//MsgTypeVoice 表示语音消息
|
||||
MsgTypeVoice = "voice"
|
||||
//MsgTypeVideo 表示视频消息
|
||||
MsgTypeVideo = "video"
|
||||
//MsgTypeShortVideo 表示短视频消息[限接收]
|
||||
MsgTypeShortVideo = "shortvideo"
|
||||
//MsgTypeLocation 表示坐标消息[限接收]
|
||||
MsgTypeLocation = "location"
|
||||
//MsgTypeLink 表示链接消息[限接收]
|
||||
MsgTypeLink = "link"
|
||||
//MsgTypeMusic 表示音乐消息[限回复]
|
||||
MsgTypeMusic = "music"
|
||||
//MsgTypeNews 表示图文消息[限回复]
|
||||
MsgTypeNews = "news"
|
||||
//MsgTypeTransfer 表示消息消息转发到客服
|
||||
MsgTypeTransfer = "transfer_customer_service"
|
||||
)
|
||||
|
||||
const (
|
||||
//EventSubscribe 订阅
|
||||
EventSubscribe EventType = "subscribe"
|
||||
//EventUnsubscribe 取消订阅
|
||||
EventUnsubscribe EventType = "unsubscribe"
|
||||
//EventScan 用户已经关注公众号,则微信会将带场景值扫描事件推送给开发者
|
||||
EventScan EventType = "SCAN"
|
||||
//EventLocation 上报地理位置事件
|
||||
EventLocation EventType = "LOCATION"
|
||||
//EventClick 点击菜单拉取消息时的事件推送
|
||||
EventClick EventType = "CLICK"
|
||||
//EventView 点击菜单跳转链接时的事件推送
|
||||
EventView EventType = "VIEW"
|
||||
)
|
||||
|
||||
//MixMessage 存放所有微信发送过来的消息和事件
|
||||
type MixMessage struct {
|
||||
CommonToken
|
||||
|
||||
//基本消息
|
||||
MsgID int64 `xml:"MsgId"`
|
||||
Content string `xml:"Content"`
|
||||
PicURL string `xml:"PicUrl"`
|
||||
MediaID string `xml:"MediaId"`
|
||||
Format string `xml:"Format"`
|
||||
ThumbMediaID string `xml:"ThumbMediaId"`
|
||||
LocationX float64 `xml:"Location_X"`
|
||||
LocationY float64 `xml:"Location_Y"`
|
||||
Scale float64 `xml:"Scale"`
|
||||
Label string `xml:"Label"`
|
||||
Title string `xml:"Title"`
|
||||
Description string `xml:"Description"`
|
||||
URL string `xml:"Url"`
|
||||
|
||||
//事件相关
|
||||
Event string `xml:"Event"`
|
||||
EventKey string `xml:"EventKey"`
|
||||
Ticket string `xml:"Ticket"`
|
||||
Latitude string `xml:"Latitude"`
|
||||
Longitude string `xml:"Longitude"`
|
||||
Precision string `xml:"Precision"`
|
||||
}
|
||||
|
||||
//EncryptedXMLMsg 安全模式下的消息体
|
||||
type EncryptedXMLMsg struct {
|
||||
XMLName struct{} `xml:"xml" json:"-"`
|
||||
ToUserName string `xml:"ToUserName" json:"ToUserName"`
|
||||
EncryptedMsg string `xml:"Encrypt" json:"Encrypt"`
|
||||
}
|
||||
|
||||
//ResponseEncryptedXMLMsg 需要返回的消息体
|
||||
type ResponseEncryptedXMLMsg struct {
|
||||
XMLName struct{} `xml:"xml" json:"-"`
|
||||
EncryptedMsg string `xml:"Encrypt" json:"Encrypt"`
|
||||
MsgSignature string `xml:"MsgSignature" json:"MsgSignature"`
|
||||
Timestamp int64 `xml:"TimeStamp" json:"TimeStamp"`
|
||||
Nonce string `xml:"Nonce" json:"Nonce"`
|
||||
}
|
||||
|
||||
// CommonToken 消息中通用的结构
|
||||
type CommonToken struct {
|
||||
XMLName xml.Name `xml:"xml"`
|
||||
ToUserName string `xml:"ToUserName"`
|
||||
FromUserName string `xml:"FromUserName"`
|
||||
CreateTime int64 `xml:"CreateTime"`
|
||||
MsgType MsgType `xml:"MsgType"`
|
||||
}
|
||||
|
||||
//SetToUserName set ToUserName
|
||||
func (msg *CommonToken) SetToUserName(toUserName string) {
|
||||
msg.ToUserName = toUserName
|
||||
}
|
||||
|
||||
//SetFromUserName set FromUserName
|
||||
func (msg *CommonToken) SetFromUserName(fromUserName string) {
|
||||
msg.FromUserName = fromUserName
|
||||
}
|
||||
|
||||
//SetCreateTime set createTime
|
||||
func (msg *CommonToken) SetCreateTime(createTime int64) {
|
||||
msg.CreateTime = createTime
|
||||
}
|
||||
|
||||
//SetMsgType set MsgType
|
||||
func (msg *CommonToken) SetMsgType(msgType MsgType) {
|
||||
msg.MsgType = msgType
|
||||
}
|
||||
24
message/music.go
Normal file
24
message/music.go
Normal file
@@ -0,0 +1,24 @@
|
||||
package message
|
||||
|
||||
//Music 音乐消息
|
||||
type Music struct {
|
||||
CommonToken
|
||||
|
||||
Music struct {
|
||||
Title string `xml:"Title" `
|
||||
Description string `xml:"Description" `
|
||||
MusicURL string `xml:"MusicUrl" `
|
||||
HQMusicURL string `xml:"HQMusicUrl" `
|
||||
ThumbMediaID string `xml:"ThumbMediaId"`
|
||||
} `xml:"Music"`
|
||||
}
|
||||
|
||||
//NewMusic 回复音乐消息
|
||||
func NewMusic(title, description, musicURL, hQMusicURL, thumbMediaID string) *Music {
|
||||
music := new(Music)
|
||||
music.Music.Title = title
|
||||
music.Music.Description = description
|
||||
music.Music.MusicURL = musicURL
|
||||
music.Music.ThumbMediaID = thumbMediaID
|
||||
return music
|
||||
}
|
||||
35
message/news.go
Normal file
35
message/news.go
Normal file
@@ -0,0 +1,35 @@
|
||||
package message
|
||||
|
||||
//News 图文消息
|
||||
type News struct {
|
||||
CommonToken
|
||||
|
||||
ArticleCount int `xml:"ArticleCount"`
|
||||
Articles []*Article `xml:"Articles>item,omitempty"`
|
||||
}
|
||||
|
||||
//NewNews 初始化图文消息
|
||||
func NewNews(articles []*Article) *News {
|
||||
news := new(News)
|
||||
news.ArticleCount = len(articles)
|
||||
news.Articles = articles
|
||||
return news
|
||||
}
|
||||
|
||||
//Article 单篇文章
|
||||
type Article struct {
|
||||
Title string `xml:"Title,omitempty"`
|
||||
Description string `xml:"Description,omitempty"`
|
||||
PicURL string `xml:"PicUrl,omitempty"`
|
||||
URL string `xml:"Url,omitempty"`
|
||||
}
|
||||
|
||||
//NewArticle 初始化文章
|
||||
func NewArticle(title, description, picURL, url string) *Article {
|
||||
article := new(Article)
|
||||
article.Title = title
|
||||
article.Description = description
|
||||
article.PicURL = picURL
|
||||
article.URL = url
|
||||
return article
|
||||
}
|
||||
15
message/reply.go
Normal file
15
message/reply.go
Normal file
@@ -0,0 +1,15 @@
|
||||
package message
|
||||
|
||||
import "errors"
|
||||
|
||||
//ErrInvalidReply 无效的回复
|
||||
var ErrInvalidReply = errors.New("无效的回复消息")
|
||||
|
||||
//ErrUnsupportReply 不支持的回复类型
|
||||
var ErrUnsupportReply = errors.New("不支持的回复消息")
|
||||
|
||||
//Reply 消息回复
|
||||
type Reply struct {
|
||||
MsgType MsgType
|
||||
MsgData interface{}
|
||||
}
|
||||
14
message/text.go
Normal file
14
message/text.go
Normal file
@@ -0,0 +1,14 @@
|
||||
package message
|
||||
|
||||
//Text 文本消息
|
||||
type Text struct {
|
||||
CommonToken
|
||||
Content string `xml:"Content"`
|
||||
}
|
||||
|
||||
//NewText 初始化文本消息
|
||||
func NewText(content string) *Text {
|
||||
text := new(Text)
|
||||
text.Content = content
|
||||
return text
|
||||
}
|
||||
21
message/video.go
Normal file
21
message/video.go
Normal file
@@ -0,0 +1,21 @@
|
||||
package message
|
||||
|
||||
//Video 视频消息
|
||||
type Video struct {
|
||||
CommonToken
|
||||
|
||||
Video struct {
|
||||
MediaID string `xml:"MediaId"`
|
||||
Title string `xml:"Title,omitempty"`
|
||||
Description string `xml:"Description,omitempty"`
|
||||
} `xml:"Video"`
|
||||
}
|
||||
|
||||
//NewVideo 回复图片消息
|
||||
func NewVideo(mediaID, title, description string) *Video {
|
||||
video := new(Video)
|
||||
video.Video.MediaID = mediaID
|
||||
video.Video.Title = title
|
||||
video.Video.Description = description
|
||||
return video
|
||||
}
|
||||
17
message/voice.go
Normal file
17
message/voice.go
Normal file
@@ -0,0 +1,17 @@
|
||||
package message
|
||||
|
||||
//Voice 语音消息
|
||||
type Voice struct {
|
||||
CommonToken
|
||||
|
||||
Voice struct {
|
||||
MediaID string `xml:"MediaId"`
|
||||
} `xml:"Voice"`
|
||||
}
|
||||
|
||||
//NewVoice 回复语音消息
|
||||
func NewVoice(mediaID string) *Voice {
|
||||
voice := new(Voice)
|
||||
voice.Voice.MediaID = mediaID
|
||||
return voice
|
||||
}
|
||||
152
oauth/oauth.go
Normal file
152
oauth/oauth.go
Normal file
@@ -0,0 +1,152 @@
|
||||
package oauth
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/url"
|
||||
|
||||
"github.com/silenceper/wechat/context"
|
||||
"github.com/silenceper/wechat/util"
|
||||
)
|
||||
|
||||
const (
|
||||
redirectOauthURL = "https://open.weixin.qq.com/connect/oauth2/authorize?appid=%s&redirect_uri=%s&response_type=code&scope=%s&state=%s#wechat_redirect"
|
||||
accessTokenURL = "https://api.weixin.qq.com/sns/oauth2/access_token?appid=%s&secret=%s&code=%s&grant_type=authorization_code"
|
||||
refreshAccessTokenURL = "https://api.weixin.qq.com/sns/oauth2/refresh_token?appid=%s&grant_type=refresh_token&refresh_token=%s"
|
||||
userInfoURL = "https://api.weixin.qq.com/sns/userinfo?access_token=%s&openid=%s&lang=zh_CN"
|
||||
checkAccessTokenURL = "https://api.weixin.qq.com/sns/auth?access_token=%s&openid=%s"
|
||||
)
|
||||
|
||||
//Oauth 保存用户授权信息
|
||||
type Oauth struct {
|
||||
*context.Context
|
||||
}
|
||||
|
||||
//NewOauth 实例化授权信息
|
||||
func NewOauth(context *context.Context) *Oauth {
|
||||
auth := new(Oauth)
|
||||
auth.Context = context
|
||||
return auth
|
||||
}
|
||||
|
||||
//GetRedirectURL 获取跳转的url地址
|
||||
func (oauth *Oauth) GetRedirectURL(redirectURI, scope, state string) (string, error) {
|
||||
//url encode
|
||||
urlStr := url.QueryEscape(redirectURI)
|
||||
return fmt.Sprintf(redirectOauthURL, oauth.AppID, urlStr, scope, state), nil
|
||||
}
|
||||
|
||||
//Redirect 跳转到网页授权
|
||||
func (oauth *Oauth) Redirect(redirectURI, scope, state string) error {
|
||||
location, err := oauth.GetRedirectURL(redirectURI, scope, state)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
http.Redirect(oauth.Writer, oauth.Request, location, 302)
|
||||
return nil
|
||||
}
|
||||
|
||||
// ResAccessToken 获取用户授权access_token的返回结果
|
||||
type ResAccessToken struct {
|
||||
util.CommonError
|
||||
|
||||
AccessToken string `json:"access_token"`
|
||||
ExpiresIn int64 `json:"expires_in"`
|
||||
RefreshToken string `json:"refresh_token"`
|
||||
OpenID string `json:"openid"`
|
||||
Scope string `json:"scope"`
|
||||
}
|
||||
|
||||
// GetUserAccessToken 通过网页授权的code 换取access_token(区别于context中的access_token)
|
||||
func (oauth *Oauth) GetUserAccessToken(code string) (result ResAccessToken, err error) {
|
||||
urlStr := fmt.Sprintf(accessTokenURL, oauth.AppID, oauth.AppSecret, code)
|
||||
var response []byte
|
||||
response, err = util.HTTPGet(urlStr)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
err = json.Unmarshal(response, &result)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
if result.ErrCode != 0 {
|
||||
err = fmt.Errorf("GetUserAccessToken error : errcode=%v , errmsg=%v", result.ErrCode, result.ErrMsg)
|
||||
return
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
//RefreshAccessToken 刷新access_token
|
||||
func (oauth *Oauth) RefreshAccessToken(refreshToken string) (result ResAccessToken, err error) {
|
||||
urlStr := fmt.Sprintf(refreshAccessTokenURL, oauth.AppID, refreshToken)
|
||||
var response []byte
|
||||
response, err = util.HTTPGet(urlStr)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
err = json.Unmarshal(response, &result)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
if result.ErrCode != 0 {
|
||||
err = fmt.Errorf("GetUserAccessToken error : errcode=%v , errmsg=%v", result.ErrCode, result.ErrMsg)
|
||||
return
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
//CheckAccessToken 检验access_token是否有效
|
||||
func (oauth *Oauth) CheckAccessToken(accessToken, openID string) (b bool, err error) {
|
||||
urlStr := fmt.Sprintf(checkAccessTokenURL, accessToken, openID)
|
||||
var response []byte
|
||||
response, err = util.HTTPGet(urlStr)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
var result util.CommonError
|
||||
err = json.Unmarshal(response, &result)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
if result.ErrCode != 0 {
|
||||
b = false
|
||||
return
|
||||
}
|
||||
b = true
|
||||
return
|
||||
}
|
||||
|
||||
//UserInfo 用户授权获取到用户信息
|
||||
type UserInfo struct {
|
||||
util.CommonError
|
||||
|
||||
OpenID string `json:"openid"`
|
||||
Nickname string `json:"nickname"`
|
||||
Sex int32 `json:"sex"`
|
||||
Province string `json:"province"`
|
||||
City string `json:"city"`
|
||||
Country string `json:"country"`
|
||||
HeadImgURL string `json:"headimgurl"`
|
||||
Privilege []string `json:"privilege"`
|
||||
Unionid string `json:"unionid"`
|
||||
}
|
||||
|
||||
//GetUserInfo 如果scope为 snsapi_userinfo 则可以通过此方法获取到用户基本信息
|
||||
func (oauth *Oauth) GetUserInfo(accessToken, openID string) (result UserInfo, err error) {
|
||||
urlStr := fmt.Sprintf(userInfoURL, accessToken, openID)
|
||||
var response []byte
|
||||
response, err = util.HTTPGet(urlStr)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
err = json.Unmarshal(response, &result)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
if result.ErrCode != 0 {
|
||||
err = fmt.Errorf("GetUserInfo error : errcode=%v , errmsg=%v", result.ErrCode, result.ErrMsg)
|
||||
return
|
||||
}
|
||||
return
|
||||
}
|
||||
233
server/server.go
Normal file
233
server/server.go
Normal file
@@ -0,0 +1,233 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"encoding/xml"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"reflect"
|
||||
"runtime/debug"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/silenceper/wechat/context"
|
||||
"github.com/silenceper/wechat/message"
|
||||
"github.com/silenceper/wechat/util"
|
||||
)
|
||||
|
||||
//Server struct
|
||||
type Server struct {
|
||||
*context.Context
|
||||
|
||||
openID string
|
||||
|
||||
messageHandler func(message.MixMessage) *message.Reply
|
||||
|
||||
requestRawXMLMsg []byte
|
||||
requestMsg message.MixMessage
|
||||
responseRawXMLMsg []byte
|
||||
responseMsg interface{}
|
||||
|
||||
isSafeMode bool
|
||||
random []byte
|
||||
nonce string
|
||||
timestamp int64
|
||||
}
|
||||
|
||||
//NewServer init
|
||||
func NewServer(context *context.Context) *Server {
|
||||
srv := new(Server)
|
||||
srv.Context = context
|
||||
return srv
|
||||
}
|
||||
|
||||
//Serve 处理微信的请求消息
|
||||
func (srv *Server) Serve() error {
|
||||
if !srv.Validate() {
|
||||
return fmt.Errorf("请求校验失败")
|
||||
}
|
||||
|
||||
echostr, exists := srv.GetQuery("echostr")
|
||||
if exists {
|
||||
srv.String(echostr)
|
||||
return nil
|
||||
}
|
||||
|
||||
response, err := srv.handleRequest()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
//debug
|
||||
//fmt.Println("request msg = ", string(srv.requestRawXMLMsg))
|
||||
|
||||
return srv.buildResponse(response)
|
||||
}
|
||||
|
||||
//Validate 校验请求是否合法
|
||||
func (srv *Server) Validate() bool {
|
||||
timestamp := srv.Query("timestamp")
|
||||
nonce := srv.Query("nonce")
|
||||
signature := srv.Query("signature")
|
||||
return signature == util.Signature(srv.Token, timestamp, nonce)
|
||||
}
|
||||
|
||||
//HandleRequest 处理微信的请求
|
||||
func (srv *Server) handleRequest() (reply *message.Reply, err error) {
|
||||
//set isSafeMode
|
||||
srv.isSafeMode = false
|
||||
encryptType := srv.Query("encrypt_type")
|
||||
if encryptType == "aes" {
|
||||
srv.isSafeMode = true
|
||||
}
|
||||
|
||||
//set openID
|
||||
srv.openID = srv.Query("openid")
|
||||
|
||||
var msg interface{}
|
||||
msg, err = srv.getMessage()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
mixMessage, success := msg.(message.MixMessage)
|
||||
if !success {
|
||||
err = errors.New("消息类型转换失败")
|
||||
}
|
||||
srv.requestMsg = mixMessage
|
||||
reply = srv.messageHandler(mixMessage)
|
||||
return
|
||||
}
|
||||
|
||||
//GetOpenID return openID
|
||||
func (srv *Server) GetOpenID() string {
|
||||
return srv.openID
|
||||
}
|
||||
|
||||
//getMessage 解析微信返回的消息
|
||||
func (srv *Server) getMessage() (interface{}, error) {
|
||||
var rawXMLMsgBytes []byte
|
||||
var err error
|
||||
if srv.isSafeMode {
|
||||
var encryptedXMLMsg message.EncryptedXMLMsg
|
||||
if err := xml.NewDecoder(srv.Request.Body).Decode(&encryptedXMLMsg); err != nil {
|
||||
return nil, fmt.Errorf("从body中解析xml失败,err=%v", err)
|
||||
}
|
||||
|
||||
//验证消息签名
|
||||
timestamp := srv.Query("timestamp")
|
||||
srv.timestamp, err = strconv.ParseInt(timestamp, 10, 32)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
nonce := srv.Query("nonce")
|
||||
srv.nonce = nonce
|
||||
msgSignature := srv.Query("msg_signature")
|
||||
msgSignatureGen := util.Signature(srv.Token, timestamp, nonce, encryptedXMLMsg.EncryptedMsg)
|
||||
if msgSignature != msgSignatureGen {
|
||||
return nil, fmt.Errorf("消息不合法,验证签名失败")
|
||||
}
|
||||
|
||||
//解密
|
||||
srv.random, rawXMLMsgBytes, err = util.DecryptMsg(srv.AppID, encryptedXMLMsg.EncryptedMsg, srv.EncodingAESKey)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("消息解密失败, err=%v", err)
|
||||
}
|
||||
} else {
|
||||
rawXMLMsgBytes, err = ioutil.ReadAll(srv.Request.Body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("从body中解析xml失败, err=%v", err)
|
||||
}
|
||||
}
|
||||
|
||||
srv.requestRawXMLMsg = rawXMLMsgBytes
|
||||
|
||||
return srv.parseRequestMessage(rawXMLMsgBytes)
|
||||
}
|
||||
|
||||
func (srv *Server) parseRequestMessage(rawXMLMsgBytes []byte) (msg message.MixMessage, err error) {
|
||||
msg = message.MixMessage{}
|
||||
err = xml.Unmarshal(rawXMLMsgBytes, &msg)
|
||||
return
|
||||
}
|
||||
|
||||
//SetMessageHandler 设置用户自定义的回调方法
|
||||
func (srv *Server) SetMessageHandler(handler func(message.MixMessage) *message.Reply) {
|
||||
srv.messageHandler = handler
|
||||
}
|
||||
|
||||
func (srv *Server) buildResponse(reply *message.Reply) (err error) {
|
||||
defer func() {
|
||||
if e := recover(); e != nil {
|
||||
err = fmt.Errorf("panic error: %v\n%s", e, debug.Stack())
|
||||
}
|
||||
}()
|
||||
if reply == nil {
|
||||
//do nothing
|
||||
return nil
|
||||
}
|
||||
msgType := reply.MsgType
|
||||
switch msgType {
|
||||
case message.MsgTypeText:
|
||||
case message.MsgTypeImage:
|
||||
case message.MsgTypeVoice:
|
||||
case message.MsgTypeVideo:
|
||||
case message.MsgTypeMusic:
|
||||
case message.MsgTypeNews:
|
||||
case message.MsgTypeTransfer:
|
||||
default:
|
||||
err = message.ErrUnsupportReply
|
||||
return
|
||||
}
|
||||
|
||||
msgData := reply.MsgData
|
||||
value := reflect.ValueOf(msgData)
|
||||
//msgData must be a ptr
|
||||
kind := value.Kind().String()
|
||||
if 0 != strings.Compare("ptr", kind) {
|
||||
return message.ErrUnsupportReply
|
||||
}
|
||||
|
||||
params := make([]reflect.Value, 1)
|
||||
params[0] = reflect.ValueOf(srv.requestMsg.FromUserName)
|
||||
value.MethodByName("SetToUserName").Call(params)
|
||||
|
||||
params[0] = reflect.ValueOf(srv.requestMsg.ToUserName)
|
||||
value.MethodByName("SetFromUserName").Call(params)
|
||||
|
||||
params[0] = reflect.ValueOf(msgType)
|
||||
value.MethodByName("SetMsgType").Call(params)
|
||||
|
||||
params[0] = reflect.ValueOf(util.GetCurrTs())
|
||||
value.MethodByName("SetCreateTime").Call(params)
|
||||
|
||||
srv.responseMsg = msgData
|
||||
srv.responseRawXMLMsg, err = xml.Marshal(msgData)
|
||||
return
|
||||
}
|
||||
|
||||
//Send 将自定义的消息发送
|
||||
func (srv *Server) Send() (err error) {
|
||||
replyMsg := srv.responseMsg
|
||||
if srv.isSafeMode {
|
||||
//安全模式下对消息进行加密
|
||||
var encryptedMsg []byte
|
||||
encryptedMsg, err = util.EncryptMsg(srv.random, srv.responseRawXMLMsg, srv.AppID, srv.EncodingAESKey)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
//TODO 如果获取不到timestamp nonce 则自己生成
|
||||
timestamp := srv.timestamp
|
||||
timestampStr := strconv.FormatInt(timestamp, 10)
|
||||
msgSignature := util.Signature(srv.Token, timestampStr, srv.nonce, string(encryptedMsg))
|
||||
replyMsg = message.ResponseEncryptedXMLMsg{
|
||||
EncryptedMsg: string(encryptedMsg),
|
||||
MsgSignature: msgSignature,
|
||||
Timestamp: timestamp,
|
||||
Nonce: srv.nonce,
|
||||
}
|
||||
}
|
||||
if replyMsg != nil {
|
||||
srv.XML(replyMsg)
|
||||
}
|
||||
return
|
||||
}
|
||||
183
util/crypto.go
Normal file
183
util/crypto.go
Normal file
@@ -0,0 +1,183 @@
|
||||
package util
|
||||
|
||||
import (
|
||||
"crypto/aes"
|
||||
"crypto/cipher"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
)
|
||||
|
||||
//EncryptMsg 加密消息
|
||||
func EncryptMsg(random, rawXMLMsg []byte, appID, aesKey string) (encrtptMsg []byte, err error) {
|
||||
defer func() {
|
||||
if e := recover(); e != nil {
|
||||
err = fmt.Errorf("panic error: err=%v", e)
|
||||
return
|
||||
}
|
||||
}()
|
||||
var key []byte
|
||||
key, err = aesKeyDecode(aesKey)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
ciphertext := AESEncryptMsg(random, rawXMLMsg, appID, key)
|
||||
encrtptMsg = []byte(base64.StdEncoding.EncodeToString(ciphertext))
|
||||
return
|
||||
}
|
||||
|
||||
//AESEncryptMsg ciphertext = AES_Encrypt[random(16B) + msg_len(4B) + rawXMLMsg + appId]
|
||||
//参考:github.com/chanxuehong/wechat.v2
|
||||
func AESEncryptMsg(random, rawXMLMsg []byte, appID string, aesKey []byte) (ciphertext []byte) {
|
||||
const (
|
||||
BlockSize = 32 // PKCS#7
|
||||
BlockMask = BlockSize - 1 // BLOCK_SIZE 为 2^n 时, 可以用 mask 获取针对 BLOCK_SIZE 的余数
|
||||
)
|
||||
|
||||
appIDOffset := 20 + len(rawXMLMsg)
|
||||
contentLen := appIDOffset + len(appID)
|
||||
amountToPad := BlockSize - contentLen&BlockMask
|
||||
plaintextLen := contentLen + amountToPad
|
||||
|
||||
plaintext := make([]byte, plaintextLen)
|
||||
|
||||
// 拼接
|
||||
copy(plaintext[:16], random)
|
||||
encodeNetworkByteOrder(plaintext[16:20], uint32(len(rawXMLMsg)))
|
||||
copy(plaintext[20:], rawXMLMsg)
|
||||
copy(plaintext[appIDOffset:], appID)
|
||||
|
||||
// PKCS#7 补位
|
||||
for i := contentLen; i < plaintextLen; i++ {
|
||||
plaintext[i] = byte(amountToPad)
|
||||
}
|
||||
|
||||
// 加密
|
||||
block, err := aes.NewCipher(aesKey[:])
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
mode := cipher.NewCBCEncrypter(block, aesKey[:16])
|
||||
mode.CryptBlocks(plaintext, plaintext)
|
||||
|
||||
ciphertext = plaintext
|
||||
return
|
||||
}
|
||||
|
||||
//DecryptMsg 消息解密
|
||||
func DecryptMsg(appID, encryptedMsg, aesKey string) (random, rawMsgXMLBytes []byte, err error) {
|
||||
defer func() {
|
||||
if e := recover(); e != nil {
|
||||
err = fmt.Errorf("panic error: err=%v", e)
|
||||
return
|
||||
}
|
||||
}()
|
||||
var encryptedMsgBytes, key, getAppIDBytes []byte
|
||||
encryptedMsgBytes, err = base64.StdEncoding.DecodeString(encryptedMsg)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
key, err = aesKeyDecode(aesKey)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
random, rawMsgXMLBytes, getAppIDBytes, err = AESDecryptMsg(encryptedMsgBytes, key)
|
||||
if err != nil {
|
||||
err = fmt.Errorf("消息解密失败,%v", err)
|
||||
return
|
||||
}
|
||||
if appID != string(getAppIDBytes) {
|
||||
err = fmt.Errorf("消息解密校验APPID失败")
|
||||
return
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func aesKeyDecode(encodedAESKey string) (key []byte, err error) {
|
||||
if len(encodedAESKey) != 43 {
|
||||
err = fmt.Errorf("the length of encodedAESKey must be equal to 43")
|
||||
return
|
||||
}
|
||||
key, err = base64.StdEncoding.DecodeString(encodedAESKey + "=")
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
if len(key) != 32 {
|
||||
err = fmt.Errorf("encodingAESKey invalid")
|
||||
return
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// AESDecryptMsg ciphertext = AES_Encrypt[random(16B) + msg_len(4B) + rawXMLMsg + appId]
|
||||
//参考:github.com/chanxuehong/wechat.v2
|
||||
func AESDecryptMsg(ciphertext []byte, aesKey []byte) (random, rawXMLMsg, appID []byte, err error) {
|
||||
const (
|
||||
BlockSize = 32 // PKCS#7
|
||||
BlockMask = BlockSize - 1 // BLOCK_SIZE 为 2^n 时, 可以用 mask 获取针对 BLOCK_SIZE 的余数
|
||||
)
|
||||
|
||||
if len(ciphertext) < BlockSize {
|
||||
err = fmt.Errorf("the length of ciphertext too short: %d", len(ciphertext))
|
||||
return
|
||||
}
|
||||
if len(ciphertext)&BlockMask != 0 {
|
||||
err = fmt.Errorf("ciphertext is not a multiple of the block size, the length is %d", len(ciphertext))
|
||||
return
|
||||
}
|
||||
|
||||
plaintext := make([]byte, len(ciphertext)) // len(plaintext) >= BLOCK_SIZE
|
||||
|
||||
// 解密
|
||||
block, err := aes.NewCipher(aesKey)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
mode := cipher.NewCBCDecrypter(block, aesKey[:16])
|
||||
mode.CryptBlocks(plaintext, ciphertext)
|
||||
|
||||
// PKCS#7 去除补位
|
||||
amountToPad := int(plaintext[len(plaintext)-1])
|
||||
if amountToPad < 1 || amountToPad > BlockSize {
|
||||
err = fmt.Errorf("the amount to pad is incorrect: %d", amountToPad)
|
||||
return
|
||||
}
|
||||
plaintext = plaintext[:len(plaintext)-amountToPad]
|
||||
|
||||
// 反拼接
|
||||
// len(plaintext) == 16+4+len(rawXMLMsg)+len(appId)
|
||||
if len(plaintext) <= 20 {
|
||||
err = fmt.Errorf("plaintext too short, the length is %d", len(plaintext))
|
||||
return
|
||||
}
|
||||
rawXMLMsgLen := int(decodeNetworkByteOrder(plaintext[16:20]))
|
||||
if rawXMLMsgLen < 0 {
|
||||
err = fmt.Errorf("incorrect msg length: %d", rawXMLMsgLen)
|
||||
return
|
||||
}
|
||||
appIDOffset := 20 + rawXMLMsgLen
|
||||
if len(plaintext) <= appIDOffset {
|
||||
err = fmt.Errorf("msg length too large: %d", rawXMLMsgLen)
|
||||
return
|
||||
}
|
||||
|
||||
random = plaintext[:16:20]
|
||||
rawXMLMsg = plaintext[20:appIDOffset:appIDOffset]
|
||||
appID = plaintext[appIDOffset:]
|
||||
return
|
||||
}
|
||||
|
||||
// 把整数 n 格式化成 4 字节的网络字节序
|
||||
func encodeNetworkByteOrder(orderBytes []byte, n uint32) {
|
||||
orderBytes[0] = byte(n >> 24)
|
||||
orderBytes[1] = byte(n >> 16)
|
||||
orderBytes[2] = byte(n >> 8)
|
||||
orderBytes[3] = byte(n)
|
||||
}
|
||||
|
||||
// 从 4 字节的网络字节序里解析出整数
|
||||
func decodeNetworkByteOrder(orderBytes []byte) (n uint32) {
|
||||
return uint32(orderBytes[0])<<24 |
|
||||
uint32(orderBytes[1])<<16 |
|
||||
uint32(orderBytes[2])<<8 |
|
||||
uint32(orderBytes[3])
|
||||
}
|
||||
6
util/error.go
Normal file
6
util/error.go
Normal file
@@ -0,0 +1,6 @@
|
||||
package util
|
||||
|
||||
type CommonError struct {
|
||||
ErrCode int64 `json:"errcode"`
|
||||
ErrMsg string `json:"errmsg"`
|
||||
}
|
||||
154
util/http.go
Normal file
154
util/http.go
Normal file
@@ -0,0 +1,154 @@
|
||||
package util
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"mime/multipart"
|
||||
"net/http"
|
||||
"os"
|
||||
)
|
||||
|
||||
//HTTPGet get 请求
|
||||
func HTTPGet(uri string) ([]byte, error) {
|
||||
response, err := http.Get(uri)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
defer response.Body.Close()
|
||||
if response.StatusCode != http.StatusOK {
|
||||
return nil, fmt.Errorf("http get error : uri=%v , statusCode=%v", uri, response.StatusCode)
|
||||
}
|
||||
return ioutil.ReadAll(response.Body)
|
||||
}
|
||||
|
||||
//PostJSON post json 数据请求
|
||||
func PostJSON(uri string, obj interface{}) ([]byte, error) {
|
||||
jsonData, err := json.Marshal(obj)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
body := bytes.NewBuffer(jsonData)
|
||||
response, err := http.Post(uri, "application/json;charset=utf-8", body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer response.Body.Close()
|
||||
|
||||
if response.StatusCode != http.StatusOK {
|
||||
return nil, fmt.Errorf("http get error : uri=%v , statusCode=%v", uri, response.StatusCode)
|
||||
}
|
||||
return ioutil.ReadAll(response.Body)
|
||||
}
|
||||
|
||||
//PostFile 上传文件
|
||||
func PostFile(fieldname, filename, uri string) ([]byte, error) {
|
||||
/*
|
||||
bodyBuf := &bytes.Buffer{}
|
||||
bodyWriter := multipart.NewWriter(bodyBuf)
|
||||
|
||||
fileWriter, err := bodyWriter.CreateFormFile(fieldname, filename)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error writing to buffer")
|
||||
}
|
||||
|
||||
fh, err := os.Open(filename)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error opening file")
|
||||
}
|
||||
defer fh.Close()
|
||||
|
||||
_, err = io.Copy(fileWriter, fh)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
contentType := bodyWriter.FormDataContentType()
|
||||
bodyWriter.Close()
|
||||
|
||||
resp, err := http.Post(uri, contentType, bodyBuf)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, err
|
||||
}
|
||||
respBody, err := ioutil.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return respBody, nil
|
||||
*/
|
||||
fields := []MultipartFormField{
|
||||
{
|
||||
IsFile: true,
|
||||
Fieldname: fieldname,
|
||||
Filename: filename,
|
||||
},
|
||||
}
|
||||
return PostMultipartForm(fields, uri)
|
||||
}
|
||||
|
||||
//MultipartFormField 保存文件或其他字段信息
|
||||
type MultipartFormField struct {
|
||||
IsFile bool
|
||||
Fieldname string
|
||||
Value []byte
|
||||
Filename string
|
||||
}
|
||||
|
||||
//PostMultipartForm 上传文件或其他多个字段
|
||||
func PostMultipartForm(fields []MultipartFormField, uri string) (respBody []byte, err error) {
|
||||
bodyBuf := &bytes.Buffer{}
|
||||
bodyWriter := multipart.NewWriter(bodyBuf)
|
||||
|
||||
for _, field := range fields {
|
||||
if field.IsFile {
|
||||
fileWriter, e := bodyWriter.CreateFormFile(field.Fieldname, field.Filename)
|
||||
if e != nil {
|
||||
err = fmt.Errorf("error writing to buffer , err=%v", e)
|
||||
return
|
||||
}
|
||||
|
||||
fh, e := os.Open(field.Filename)
|
||||
if e != nil {
|
||||
err = fmt.Errorf("error opening file , err=%v", e)
|
||||
return
|
||||
}
|
||||
defer fh.Close()
|
||||
|
||||
if _, err = io.Copy(fileWriter, fh); err != nil {
|
||||
return
|
||||
}
|
||||
} else {
|
||||
partWriter, e := bodyWriter.CreateFormField(field.Fieldname)
|
||||
if e != nil {
|
||||
err = e
|
||||
return
|
||||
}
|
||||
valueReader := bytes.NewReader(field.Value)
|
||||
if _, err = io.Copy(partWriter, valueReader); err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
contentType := bodyWriter.FormDataContentType()
|
||||
bodyWriter.Close()
|
||||
|
||||
resp, e := http.Post(uri, contentType, bodyBuf)
|
||||
if e != nil {
|
||||
err = e
|
||||
return
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, err
|
||||
}
|
||||
respBody, err = ioutil.ReadAll(resp.Body)
|
||||
return
|
||||
}
|
||||
18
util/signature.go
Normal file
18
util/signature.go
Normal file
@@ -0,0 +1,18 @@
|
||||
package util
|
||||
|
||||
import (
|
||||
"crypto/sha1"
|
||||
"fmt"
|
||||
"io"
|
||||
"sort"
|
||||
)
|
||||
|
||||
//Signature sha1签名
|
||||
func Signature(params ...string) string {
|
||||
sort.Strings(params)
|
||||
h := sha1.New()
|
||||
for _, s := range params {
|
||||
io.WriteString(h, s)
|
||||
}
|
||||
return fmt.Sprintf("%x", h.Sum(nil))
|
||||
}
|
||||
11
util/signature_test.go
Normal file
11
util/signature_test.go
Normal file
@@ -0,0 +1,11 @@
|
||||
package util
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestSignature(t *testing.T) {
|
||||
//abc sig
|
||||
abc := "a9993e364706816aba3e25717850c26c9cd0d89d"
|
||||
if abc != Signature("a", "b", "c") {
|
||||
t.Error("test Signature Error")
|
||||
}
|
||||
}
|
||||
18
util/string.go
Normal file
18
util/string.go
Normal file
@@ -0,0 +1,18 @@
|
||||
package util
|
||||
|
||||
import (
|
||||
"math/rand"
|
||||
"time"
|
||||
)
|
||||
|
||||
//RandomStr 随机生成字符串
|
||||
func RandomStr(length int) string {
|
||||
str := "0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ"
|
||||
bytes := []byte(str)
|
||||
result := []byte{}
|
||||
r := rand.New(rand.NewSource(time.Now().UnixNano()))
|
||||
for i := 0; i < length; i++ {
|
||||
result = append(result, bytes[r.Intn(len(bytes))])
|
||||
}
|
||||
return string(result)
|
||||
}
|
||||
8
util/time.go
Normal file
8
util/time.go
Normal file
@@ -0,0 +1,8 @@
|
||||
package util
|
||||
|
||||
import "time"
|
||||
|
||||
//GetCurrTs return current timestamps
|
||||
func GetCurrTs() int64 {
|
||||
return time.Now().Unix()
|
||||
}
|
||||
70
wechat.go
Normal file
70
wechat.go
Normal file
@@ -0,0 +1,70 @@
|
||||
package wechat
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"sync"
|
||||
|
||||
"github.com/silenceper/wechat/cache"
|
||||
"github.com/silenceper/wechat/context"
|
||||
"github.com/silenceper/wechat/js"
|
||||
"github.com/silenceper/wechat/material"
|
||||
"github.com/silenceper/wechat/oauth"
|
||||
"github.com/silenceper/wechat/server"
|
||||
)
|
||||
|
||||
//Wechat struct
|
||||
type Wechat struct {
|
||||
Context *context.Context
|
||||
}
|
||||
|
||||
//Config for user
|
||||
type Config struct {
|
||||
AppID string
|
||||
AppSecret string
|
||||
Token string
|
||||
EncodingAESKey string
|
||||
Cache cache.Cache
|
||||
}
|
||||
|
||||
//NewWechat init
|
||||
func NewWechat(cfg *Config) *Wechat {
|
||||
context := new(context.Context)
|
||||
copyConfigToContext(cfg, context)
|
||||
return &Wechat{context}
|
||||
}
|
||||
|
||||
func copyConfigToContext(cfg *Config, context *context.Context) {
|
||||
context.AppID = cfg.AppID
|
||||
context.AppSecret = cfg.AppSecret
|
||||
context.Token = cfg.Token
|
||||
context.EncodingAESKey = cfg.EncodingAESKey
|
||||
context.Cache = cfg.Cache
|
||||
context.SetAccessTokenLock(new(sync.RWMutex))
|
||||
context.SetJsApiTicketLock(new(sync.RWMutex))
|
||||
}
|
||||
|
||||
//GetServer init
|
||||
func (wc *Wechat) GetServer(req *http.Request, writer http.ResponseWriter) *server.Server {
|
||||
wc.Context.Request = req
|
||||
wc.Context.Writer = writer
|
||||
return server.NewServer(wc.Context)
|
||||
}
|
||||
|
||||
//GetMaterial init
|
||||
func (wc *Wechat) GetMaterial() *material.Material {
|
||||
return material.NewMaterial(wc.Context)
|
||||
}
|
||||
|
||||
//GetOauth init
|
||||
func (wc *Wechat) GetOauth(req *http.Request, writer http.ResponseWriter) *oauth.Oauth {
|
||||
wc.Context.Request = req
|
||||
wc.Context.Writer = writer
|
||||
return oauth.NewOauth(wc.Context)
|
||||
}
|
||||
|
||||
//GetJs init
|
||||
func (wc *Wechat) GetJs(req *http.Request, writer http.ResponseWriter) *js.Js {
|
||||
wc.Context.Request = req
|
||||
wc.Context.Writer = writer
|
||||
return js.NewJs(wc.Context)
|
||||
}
|
||||
Reference in New Issue
Block a user