diff --git a/cache/cache.go b/cache/cache.go new file mode 100644 index 0000000..d144b8f --- /dev/null +++ b/cache/cache.go @@ -0,0 +1,11 @@ +package cache + +import "time" + +//Cache interface +type Cache interface { + Get(key string) interface{} + Set(key string, val interface{}, timeout time.Duration) error + IsExist(key string) bool + Delete(key string) error +} diff --git a/cache/memcache.go b/cache/memcache.go new file mode 100644 index 0000000..010b2c6 --- /dev/null +++ b/cache/memcache.go @@ -0,0 +1,51 @@ +package cache + +import ( + "errors" + "time" + + "github.com/bradfitz/gomemcache/memcache" +) + +//Memcache struct contains *memcache.Client +type Memcache struct { + conn *memcache.Client +} + +//NewMemcache create new memcache +func NewMemcache(server ...string) *Memcache { + mc := memcache.New(server...) + return &Memcache{mc} +} + +//Get return cached value +func (mem *Memcache) Get(key string) interface{} { + if item, err := mem.conn.Get(key); err == nil { + return string(item.Value) + } + return nil +} + +// IsExist check value exists in memcache. +func (mem *Memcache) IsExist(key string) bool { + _, err := mem.conn.Get(key) + if err != nil { + return false + } + return true +} + +//Set cached value with key and expire time. +func (mem *Memcache) Set(key string, val interface{}, timeout time.Duration) error { + v, ok := val.(string) + if !ok { + return errors.New("val must string") + } + item := &memcache.Item{Key: key, Value: []byte(v), Expiration: int32(timeout / time.Second)} + return mem.conn.Set(item) +} + +//Delete delete value in memcache. +func (mem *Memcache) Delete(key string) error { + return mem.conn.Delete(key) +} diff --git a/cache/memcache_test.go b/cache/memcache_test.go new file mode 100644 index 0000000..6b4ea02 --- /dev/null +++ b/cache/memcache_test.go @@ -0,0 +1,28 @@ +package cache + +import ( + "testing" + "time" +) + +func TestMemcache(t *testing.T) { + mem := NewMemcache("127.0.0.1:11211") + var err error + timeoutDuration := 10 * time.Second + if err = mem.Set("username", "silenceper", timeoutDuration); err != nil { + t.Error("set Error", err) + } + + if !mem.IsExist("username") { + t.Error("IsExist Error") + } + + name := mem.Get("username").(string) + if name != "silenceper" { + t.Error("get Error") + } + + if err = mem.Delete("username"); err != nil { + t.Errorf("delete Error , err=%v", err) + } +} diff --git a/context/access_token.go b/context/access_token.go new file mode 100644 index 0000000..3b3d263 --- /dev/null +++ b/context/access_token.go @@ -0,0 +1,71 @@ +package context + +import ( + "encoding/json" + "fmt" + "sync" + "time" + + "github.com/silenceper/wechat/util" +) + +const ( + //AccessTokenURL 获取access_token的接口 + AccessTokenURL = "https://api.weixin.qq.com/cgi-bin/token" +) + +//ResAccessToken struct +type ResAccessToken struct { + util.CommonError + + AccessToken string `json:"access_token"` + ExpiresIn int64 `json:"expires_in"` +} + +//SetAccessTokenLock 设置读写锁(一个appID一个读写锁) +func (ctx *Context) SetAccessTokenLock(l *sync.RWMutex) { + ctx.accessTokenLock = l +} + +//GetAccessToken 获取access_token +func (ctx *Context) GetAccessToken() (accessToken string, err error) { + ctx.accessTokenLock.Lock() + defer ctx.accessTokenLock.Unlock() + + accessTokenCacheKey := fmt.Sprintf("access_token_%s", ctx.AppID) + val := ctx.Cache.Get(accessTokenCacheKey) + if val != nil { + accessToken = val.(string) + return + } + + //从微信服务器获取 + var resAccessToken ResAccessToken + resAccessToken, err = ctx.GetAccessTokenFromServer() + if err != nil { + return + } + + accessToken = resAccessToken.AccessToken + return +} + +//GetAccessTokenFromServer 强制从微信服务器获取token +func (ctx *Context) GetAccessTokenFromServer() (resAccessToken ResAccessToken, err error) { + url := fmt.Sprintf("%s?grant_type=client_credential&appid=%s&secret=%s", AccessTokenURL, ctx.AppID, ctx.AppSecret) + var body []byte + body, err = util.HTTPGet(url) + err = json.Unmarshal(body, &resAccessToken) + if err != nil { + return + } + if resAccessToken.ErrMsg != "" { + err = fmt.Errorf("get access_token error : errcode=%v , errormsg=%v", resAccessToken.ErrCode, resAccessToken.ErrMsg) + return + } + + accessTokenCacheKey := fmt.Sprintf("access_token_%s", ctx.AppID) + expires := resAccessToken.ExpiresIn - 1500 + err = ctx.Cache.Set(accessTokenCacheKey, resAccessToken.AccessToken, time.Duration(expires)*time.Second) + return +} diff --git a/context/context.go b/context/context.go new file mode 100644 index 0000000..ce10617 --- /dev/null +++ b/context/context.go @@ -0,0 +1,51 @@ +package context + +import ( + "net/http" + "sync" + + "github.com/silenceper/wechat/cache" +) + +//Context struct +type Context struct { + AppID string + AppSecret string + Token string + EncodingAESKey string + + Cache cache.Cache + + Writer http.ResponseWriter + Request *http.Request + + //accessTokenLock 读写锁 同一个AppID一个 + accessTokenLock *sync.RWMutex + + //jsapiTicket 读写锁 同一个AppID一个 + jsApiTicketLock *sync.RWMutex +} + +// Query returns the keyed url query value if it exists +func (ctx *Context) Query(key string) string { + value, _ := ctx.GetQuery(key) + return value +} + +// GetQuery is like Query(), it returns the keyed url query value +func (ctx *Context) GetQuery(key string) (string, bool) { + req := ctx.Request + if values, ok := req.URL.Query()[key]; ok && len(values) > 0 { + return values[0], true + } + return "", false +} + +//SetJsApiTicket 设置jsApiTicket的lock +func (ctx *Context) SetJsApiTicketLock(lock *sync.RWMutex) { + ctx.jsApiTicketLock = lock +} + +func (ctx *Context) GetJsApiTicketLock() *sync.RWMutex { + return ctx.jsApiTicketLock +} diff --git a/context/render.go b/context/render.go new file mode 100644 index 0000000..e2fb825 --- /dev/null +++ b/context/render.go @@ -0,0 +1,43 @@ +package context + +import ( + "encoding/xml" + "net/http" +) + +var xmlContentType = []string{"application/xml; charset=utf-8"} +var plainContentType = []string{"text/plain; charset=utf-8"} + +//Render render from bytes +func (ctx *Context) Render(bytes []byte) { + //debug + //fmt.Println("response msg = ", string(bytes)) + ctx.Writer.WriteHeader(200) + _, err := ctx.Writer.Write(bytes) + if err != nil { + panic(err) + } +} + +//String render from string +func (ctx *Context) String(str string) { + writeContextType(ctx.Writer, plainContentType) + ctx.Render([]byte(str)) +} + +//XML render to xml +func (ctx *Context) XML(obj interface{}) { + writeContextType(ctx.Writer, xmlContentType) + bytes, err := xml.Marshal(obj) + if err != nil { + panic(err) + } + ctx.Render(bytes) +} + +func writeContextType(w http.ResponseWriter, value []string) { + header := w.Header() + if val := header["Content-Type"]; len(val) == 0 { + header["Content-Type"] = value + } +} diff --git a/js/js.go b/js/js.go new file mode 100644 index 0000000..44edae6 --- /dev/null +++ b/js/js.go @@ -0,0 +1,109 @@ +package js + +import ( + "encoding/json" + "fmt" + "time" + + "github.com/silenceper/wechat/context" + "github.com/silenceper/wechat/util" +) + +const getTicketURL = "https://api.weixin.qq.com/cgi-bin/ticket/getticket?access_token=%s&type=jsapi" + +//Js struct +type Js struct { + *context.Context +} + +//Config 返回给用户jssdk配置信息 +type Config struct { + AppID string + TimeStamp int64 + NonceStr string + Signature string +} + +//resTicket 请求jsapi_tikcet返回结果 +type resTicket struct { + util.CommonError + + Ticket string `json:"ticket"` + ExpiresIn int64 `json:"expires_in"` +} + +//NewJs init +func NewJs(context *context.Context) *Js { + js := new(Js) + js.Context = context + return js +} + +//GetConfig 获取jssdk需要的配置参数 +//uri 为当前网页地址 +func (js *Js) GetConfig(uri string) (config *Config, err error) { + config = new(Config) + var ticketStr string + ticketStr, err = js.getTicket() + if err != nil { + return + } + + nonceStr := util.RandomStr(16) + timestamp := util.GetCurrTs() + str := fmt.Sprintf("jsapi_ticket=%s&noncestr=%s×tamp=%d&url=%s", ticketStr, nonceStr, timestamp, uri) + sigStr := util.Signature(str) + + config.AppID = js.AppID + config.NonceStr = nonceStr + config.TimeStamp = timestamp + config.Signature = sigStr + return +} + +//getTicket 获取jsapi_tocket全局缓存 +func (js *Js) getTicket() (ticketStr string, err error) { + js.GetJsApiTicketLock().Lock() + defer js.GetJsApiTicketLock().Unlock() + + //先从cache中取 + jsAPITicketCacheKey := fmt.Sprintf("jsapi_ticket_%s", js.AppID) + val := js.Cache.Get(jsAPITicketCacheKey) + if val != nil { + ticketStr = val.(string) + return + } + var ticket resTicket + ticket, err = js.getTicketFromServer() + if err != nil { + return + } + ticketStr = ticket.Ticket + return +} + +//getTicketFromServer 强制从服务器中获取ticket +func (js *Js) getTicketFromServer() (ticket resTicket, err error) { + var accessToken string + accessToken, err = js.GetAccessToken() + if err != nil { + return + } + + var response []byte + url := fmt.Sprintf(getTicketURL, accessToken) + response, err = util.HTTPGet(url) + err = json.Unmarshal(response, &ticket) + if err != nil { + return + } + if ticket.ErrCode != 0 { + err = fmt.Errorf("getTicket Error : errcode=%s , errmsg=%s", ticket.ErrCode, ticket.ErrMsg) + return + } + + jsAPITicketCacheKey := fmt.Sprintf("jsapi_ticket_%s", js.AppID) + expires := ticket.ExpiresIn - 1500 + err = js.Cache.Set(jsAPITicketCacheKey, ticket.Ticket, time.Duration(expires)*time.Second) + return +} diff --git a/material/material.go b/material/material.go new file mode 100644 index 0000000..60af0e3 --- /dev/null +++ b/material/material.go @@ -0,0 +1,167 @@ +package material + +import ( + "encoding/json" + "errors" + "fmt" + + "github.com/silenceper/wechat/context" + "github.com/silenceper/wechat/util" +) + +const ( + addNewsURL = "https://api.weixin.qq.com/cgi-bin/material/add_news" + addMaterialURL = "https://api.weixin.qq.com/cgi-bin/material/add_material" +) + +//Material 素材管理 +type Material struct { + *context.Context +} + +//NewMaterial init +func NewMaterial(context *context.Context) *Material { + material := new(Material) + material.Context = context + return material +} + +//Article 永久图文素材 +type Article struct { + ThumbMediaID string `json:"thumb_media_id"` + Author string `json:"author"` + Digest string `json:"digest"` + ShowCoverPic int `json:"show_cover_pic"` + Content string `json:"content"` + ContentSourceURL string `json:"content_source_url"` +} + +//reqArticles 永久性图文素材请求信息 +type reqArticles struct { + articles []*Article `json:"articles"` +} + +//resArticles 永久性图文素材返回结果 +type resArticles struct { + util.CommonError + + MediaID string `json:"media_id"` +} + +//AddNews 新增永久图文素材 +func (material *Material) AddNews(articles []*Article) (mediaID string, err error) { + req := &reqArticles{articles} + + var accessToken string + accessToken, err = material.GetAccessToken() + if err != nil { + return + } + + uri := fmt.Sprintf("%s?access_token=%s", addNewsURL, accessToken) + responseBytes, err := util.PostJSON(uri, req) + var res resArticles + err = json.Unmarshal(responseBytes, res) + if err != nil { + return + } + mediaID = res.MediaID + return +} + +//resAddMaterial 永久性素材上传返回的结果 +type resAddMaterial struct { + util.CommonError + + MediaID string `json:"media_id"` + URL string `json:"url"` +} + +//AddMaterial 上传永久性素材(处理视频需要单独上传) +func (material *Material) AddMaterial(mediaType MediaType, filename string) (mediaID string, url string, err error) { + if mediaType == MediaTypeVideo { + err = errors.New("永久视频素材上传使用 AddVideo 方法") + } + var accessToken string + accessToken, err = material.GetAccessToken() + if err != nil { + return + } + + uri := fmt.Sprintf("%s?access_token=%s&type=%s", addMaterialURL, accessToken, mediaType) + var response []byte + response, err = util.PostFile("media", filename, uri) + if err != nil { + return + } + var resMaterial resAddMaterial + err = json.Unmarshal(response, &resMaterial) + if err != nil { + return + } + if resMaterial.ErrCode != 0 { + err = fmt.Errorf("AddMaterial error : errcode=%v , errmsg=%v", resMaterial.ErrCode, resMaterial.ErrMsg) + return + } + mediaID = resMaterial.MediaID + url = resMaterial.URL + return +} + +type reqVideo struct { + Title string `json:"title"` + Introduction string `json:"introduction"` +} + +//AddVideo 永久视频素材文件上传 +func (material *Material) AddVideo(filename, title, introduction string) (mediaID string, url string, err error) { + var accessToken string + accessToken, err = material.GetAccessToken() + if err != nil { + return + } + + uri := fmt.Sprintf("%s?access_token=%s&type=video", addMaterialURL, accessToken) + + videoDesc := &reqVideo{ + Title: title, + Introduction: introduction, + } + var fieldValue []byte + fieldValue, err = json.Marshal(videoDesc) + if err != nil { + return + } + + fields := []util.MultipartFormField{ + { + IsFile: true, + Fieldname: "video", + Filename: filename, + }, + { + IsFile: true, + Fieldname: "description", + Value: fieldValue, + }, + } + + var response []byte + response, err = util.PostMultipartForm(fields, uri) + if err != nil { + return + } + + var resMaterial resAddMaterial + err = json.Unmarshal(response, &resMaterial) + if err != nil { + return + } + if resMaterial.ErrCode != 0 { + err = fmt.Errorf("AddMaterial error : errcode=%v , errmsg=%v", resMaterial.ErrCode, resMaterial.ErrMsg) + return + } + mediaID = resMaterial.MediaID + url = resMaterial.URL + return +} diff --git a/material/media.go b/material/media.go new file mode 100644 index 0000000..aca5f92 --- /dev/null +++ b/material/media.go @@ -0,0 +1,109 @@ +package material + +import ( + "encoding/json" + "fmt" + + "github.com/silenceper/wechat/util" +) + +//MediaType 媒体文件类型 +type MediaType string + +const ( + //MediaTypeImage 媒体文件:图片 + MediaTypeImage MediaType = "image" + //MediaTypeVoice 媒体文件:声音 + MediaTypeVoice = "voice" + //MediaTypeVideo 媒体文件:视频 + MediaTypeVideo = "video" + //MediaTypeThumb 媒体文件:缩略图 + MediaTypeThumb = "thumb" +) + +const ( + mediaUploadURL = "https://api.weixin.qq.com/cgi-bin/media/upload" + mediaUploadImageURL = "https://api.weixin.qq.com/cgi-bin/media/uploadimg" + mediaGetURL = "https://api.weixin.qq.com/cgi-bin/media/get" +) + +//Media 临时素材上传返回信息 +type Media struct { + util.CommonError + + Type MediaType `json:"type"` + MediaID string `json:"media_id"` + CreatedAt int64 `json:"created_at"` +} + +//MediaUpload 临时素材上传 +func (material *Material) MediaUpload(mediaType MediaType, filename string) (media Media, err error) { + var accessToken string + accessToken, err = material.GetAccessToken() + if err != nil { + return + } + + uri := fmt.Sprintf("%s?access_token=%s&type=%s", mediaUploadURL, accessToken, mediaType) + var response []byte + response, err = util.PostFile("media", filename, uri) + if err != nil { + return + } + err = json.Unmarshal(response, &media) + if err != nil { + return + } + if media.ErrCode != 0 { + err = fmt.Errorf("MediaUpload error : errcode=%v , errmsg=%v", media.ErrCode, media.ErrMsg) + return + } + return +} + +//GetMediaURL 返回临时素材的下载地址供用户自己处理 +//NOTICE: URL 不可公开,因为含access_token 需要立即另存文件 +func (material *Material) GetMediaURL(mediaID string) (mediaURL string, err error) { + var accessToken string + accessToken, err = material.GetAccessToken() + if err != nil { + return + } + mediaURL = fmt.Sprintf("%s?access_token=%s&media_id=%s", mediaGetURL, accessToken, mediaID) + return +} + +//resMediaImage 图片上传返回结果 +type resMediaImage struct { + util.CommonError + + URL string `json:"url"` +} + +//ImageUpload 图片上传 +func (material *Material) ImageUpload(filename string) (url string, err error) { + var accessToken string + accessToken, err = material.GetAccessToken() + if err != nil { + return + } + + uri := fmt.Sprintf("%s?access_token=%s", mediaUploadImageURL, accessToken) + var response []byte + response, err = util.PostFile("media", filename, uri) + if err != nil { + return + } + var image resMediaImage + err = json.Unmarshal(response, &image) + if err != nil { + return + } + if image.ErrCode != 0 { + err = fmt.Errorf("UploadImage error : errcode=%v , errmsg=%v", image.ErrCode, image.ErrMsg) + return + } + url = image.URL + return + +} diff --git a/message/image.go b/message/image.go new file mode 100644 index 0000000..93e6bc0 --- /dev/null +++ b/message/image.go @@ -0,0 +1,17 @@ +package message + +//Image 图片消息 +type Image struct { + CommonToken + + Image struct { + MediaID string `xml:"MediaId"` + } `xml:"Image"` +} + +//NewImage 回复图片消息 +func NewImage(mediaID string) *Image { + image := new(Image) + image.Image.MediaID = mediaID + return image +} diff --git a/message/message.go b/message/message.go new file mode 100644 index 0000000..b64b0e3 --- /dev/null +++ b/message/message.go @@ -0,0 +1,120 @@ +package message + +import "encoding/xml" + +//MsgType 基本消息类型 +type MsgType string + +//EventType 事件类型 +type EventType string + +const ( + //MsgTypeText 表示文本消息 + MsgTypeText MsgType = "text" + //MsgTypeImage 表示图片消息 + MsgTypeImage = "image" + //MsgTypeVoice 表示语音消息 + MsgTypeVoice = "voice" + //MsgTypeVideo 表示视频消息 + MsgTypeVideo = "video" + //MsgTypeShortVideo 表示短视频消息[限接收] + MsgTypeShortVideo = "shortvideo" + //MsgTypeLocation 表示坐标消息[限接收] + MsgTypeLocation = "location" + //MsgTypeLink 表示链接消息[限接收] + MsgTypeLink = "link" + //MsgTypeMusic 表示音乐消息[限回复] + MsgTypeMusic = "music" + //MsgTypeNews 表示图文消息[限回复] + MsgTypeNews = "news" + //MsgTypeTransfer 表示消息消息转发到客服 + MsgTypeTransfer = "transfer_customer_service" +) + +const ( + //EventSubscribe 订阅 + EventSubscribe EventType = "subscribe" + //EventUnsubscribe 取消订阅 + EventUnsubscribe EventType = "unsubscribe" + //EventScan 用户已经关注公众号,则微信会将带场景值扫描事件推送给开发者 + EventScan EventType = "SCAN" + //EventLocation 上报地理位置事件 + EventLocation EventType = "LOCATION" + //EventClick 点击菜单拉取消息时的事件推送 + EventClick EventType = "CLICK" + //EventView 点击菜单跳转链接时的事件推送 + EventView EventType = "VIEW" +) + +//MixMessage 存放所有微信发送过来的消息和事件 +type MixMessage struct { + CommonToken + + //基本消息 + MsgID int64 `xml:"MsgId"` + Content string `xml:"Content"` + PicURL string `xml:"PicUrl"` + MediaID string `xml:"MediaId"` + Format string `xml:"Format"` + ThumbMediaID string `xml:"ThumbMediaId"` + LocationX float64 `xml:"Location_X"` + LocationY float64 `xml:"Location_Y"` + Scale float64 `xml:"Scale"` + Label string `xml:"Label"` + Title string `xml:"Title"` + Description string `xml:"Description"` + URL string `xml:"Url"` + + //事件相关 + Event string `xml:"Event"` + EventKey string `xml:"EventKey"` + Ticket string `xml:"Ticket"` + Latitude string `xml:"Latitude"` + Longitude string `xml:"Longitude"` + Precision string `xml:"Precision"` +} + +//EncryptedXMLMsg 安全模式下的消息体 +type EncryptedXMLMsg struct { + XMLName struct{} `xml:"xml" json:"-"` + ToUserName string `xml:"ToUserName" json:"ToUserName"` + EncryptedMsg string `xml:"Encrypt" json:"Encrypt"` +} + +//ResponseEncryptedXMLMsg 需要返回的消息体 +type ResponseEncryptedXMLMsg struct { + XMLName struct{} `xml:"xml" json:"-"` + EncryptedMsg string `xml:"Encrypt" json:"Encrypt"` + MsgSignature string `xml:"MsgSignature" json:"MsgSignature"` + Timestamp int64 `xml:"TimeStamp" json:"TimeStamp"` + Nonce string `xml:"Nonce" json:"Nonce"` +} + +// CommonToken 消息中通用的结构 +type CommonToken struct { + XMLName xml.Name `xml:"xml"` + ToUserName string `xml:"ToUserName"` + FromUserName string `xml:"FromUserName"` + CreateTime int64 `xml:"CreateTime"` + MsgType MsgType `xml:"MsgType"` +} + +//SetToUserName set ToUserName +func (msg *CommonToken) SetToUserName(toUserName string) { + msg.ToUserName = toUserName +} + +//SetFromUserName set FromUserName +func (msg *CommonToken) SetFromUserName(fromUserName string) { + msg.FromUserName = fromUserName +} + +//SetCreateTime set createTime +func (msg *CommonToken) SetCreateTime(createTime int64) { + msg.CreateTime = createTime +} + +//SetMsgType set MsgType +func (msg *CommonToken) SetMsgType(msgType MsgType) { + msg.MsgType = msgType +} diff --git a/message/music.go b/message/music.go new file mode 100644 index 0000000..3e010ed --- /dev/null +++ b/message/music.go @@ -0,0 +1,24 @@ +package message + +//Music 音乐消息 +type Music struct { + CommonToken + + Music struct { + Title string `xml:"Title" ` + Description string `xml:"Description" ` + MusicURL string `xml:"MusicUrl" ` + HQMusicURL string `xml:"HQMusicUrl" ` + ThumbMediaID string `xml:"ThumbMediaId"` + } `xml:"Music"` +} + +//NewMusic 回复音乐消息 +func NewMusic(title, description, musicURL, hQMusicURL, thumbMediaID string) *Music { + music := new(Music) + music.Music.Title = title + music.Music.Description = description + music.Music.MusicURL = musicURL + music.Music.ThumbMediaID = thumbMediaID + return music +} diff --git a/message/news.go b/message/news.go new file mode 100644 index 0000000..ee28b0c --- /dev/null +++ b/message/news.go @@ -0,0 +1,35 @@ +package message + +//News 图文消息 +type News struct { + CommonToken + + ArticleCount int `xml:"ArticleCount"` + Articles []*Article `xml:"Articles>item,omitempty"` +} + +//NewNews 初始化图文消息 +func NewNews(articles []*Article) *News { + news := new(News) + news.ArticleCount = len(articles) + news.Articles = articles + return news +} + +//Article 单篇文章 +type Article struct { + Title string `xml:"Title,omitempty"` + Description string `xml:"Description,omitempty"` + PicURL string `xml:"PicUrl,omitempty"` + URL string `xml:"Url,omitempty"` +} + +//NewArticle 初始化文章 +func NewArticle(title, description, picURL, url string) *Article { + article := new(Article) + article.Title = title + article.Description = description + article.PicURL = picURL + article.URL = url + return article +} diff --git a/message/reply.go b/message/reply.go new file mode 100644 index 0000000..53592f0 --- /dev/null +++ b/message/reply.go @@ -0,0 +1,15 @@ +package message + +import "errors" + +//ErrInvalidReply 无效的回复 +var ErrInvalidReply = errors.New("无效的回复消息") + +//ErrUnsupportReply 不支持的回复类型 +var ErrUnsupportReply = errors.New("不支持的回复消息") + +//Reply 消息回复 +type Reply struct { + MsgType MsgType + MsgData interface{} +} diff --git a/message/text.go b/message/text.go new file mode 100644 index 0000000..d981d96 --- /dev/null +++ b/message/text.go @@ -0,0 +1,14 @@ +package message + +//Text 文本消息 +type Text struct { + CommonToken + Content string `xml:"Content"` +} + +//NewText 初始化文本消息 +func NewText(content string) *Text { + text := new(Text) + text.Content = content + return text +} diff --git a/message/video.go b/message/video.go new file mode 100644 index 0000000..a082065 --- /dev/null +++ b/message/video.go @@ -0,0 +1,21 @@ +package message + +//Video 视频消息 +type Video struct { + CommonToken + + Video struct { + MediaID string `xml:"MediaId"` + Title string `xml:"Title,omitempty"` + Description string `xml:"Description,omitempty"` + } `xml:"Video"` +} + +//NewVideo 回复图片消息 +func NewVideo(mediaID, title, description string) *Video { + video := new(Video) + video.Video.MediaID = mediaID + video.Video.Title = title + video.Video.Description = description + return video +} diff --git a/message/voice.go b/message/voice.go new file mode 100644 index 0000000..d76985c --- /dev/null +++ b/message/voice.go @@ -0,0 +1,17 @@ +package message + +//Voice 语音消息 +type Voice struct { + CommonToken + + Voice struct { + MediaID string `xml:"MediaId"` + } `xml:"Voice"` +} + +//NewVoice 回复语音消息 +func NewVoice(mediaID string) *Voice { + voice := new(Voice) + voice.Voice.MediaID = mediaID + return voice +} diff --git a/oauth/oauth.go b/oauth/oauth.go new file mode 100644 index 0000000..27fe17a --- /dev/null +++ b/oauth/oauth.go @@ -0,0 +1,152 @@ +package oauth + +import ( + "encoding/json" + "fmt" + "net/http" + "net/url" + + "github.com/silenceper/wechat/context" + "github.com/silenceper/wechat/util" +) + +const ( + redirectOauthURL = "https://open.weixin.qq.com/connect/oauth2/authorize?appid=%s&redirect_uri=%s&response_type=code&scope=%s&state=%s#wechat_redirect" + accessTokenURL = "https://api.weixin.qq.com/sns/oauth2/access_token?appid=%s&secret=%s&code=%s&grant_type=authorization_code" + refreshAccessTokenURL = "https://api.weixin.qq.com/sns/oauth2/refresh_token?appid=%s&grant_type=refresh_token&refresh_token=%s" + userInfoURL = "https://api.weixin.qq.com/sns/userinfo?access_token=%s&openid=%s&lang=zh_CN" + checkAccessTokenURL = "https://api.weixin.qq.com/sns/auth?access_token=%s&openid=%s" +) + +//Oauth 保存用户授权信息 +type Oauth struct { + *context.Context +} + +//NewOauth 实例化授权信息 +func NewOauth(context *context.Context) *Oauth { + auth := new(Oauth) + auth.Context = context + return auth +} + +//GetRedirectURL 获取跳转的url地址 +func (oauth *Oauth) GetRedirectURL(redirectURI, scope, state string) (string, error) { + //url encode + urlStr := url.QueryEscape(redirectURI) + return fmt.Sprintf(redirectOauthURL, oauth.AppID, urlStr, scope, state), nil +} + +//Redirect 跳转到网页授权 +func (oauth *Oauth) Redirect(redirectURI, scope, state string) error { + location, err := oauth.GetRedirectURL(redirectURI, scope, state) + if err != nil { + return err + } + http.Redirect(oauth.Writer, oauth.Request, location, 302) + return nil +} + +// ResAccessToken 获取用户授权access_token的返回结果 +type ResAccessToken struct { + util.CommonError + + AccessToken string `json:"access_token"` + ExpiresIn int64 `json:"expires_in"` + RefreshToken string `json:"refresh_token"` + OpenID string `json:"openid"` + Scope string `json:"scope"` +} + +// GetUserAccessToken 通过网页授权的code 换取access_token(区别于context中的access_token) +func (oauth *Oauth) GetUserAccessToken(code string) (result ResAccessToken, err error) { + urlStr := fmt.Sprintf(accessTokenURL, oauth.AppID, oauth.AppSecret, code) + var response []byte + response, err = util.HTTPGet(urlStr) + if err != nil { + return + } + err = json.Unmarshal(response, &result) + if err != nil { + return + } + if result.ErrCode != 0 { + err = fmt.Errorf("GetUserAccessToken error : errcode=%v , errmsg=%v", result.ErrCode, result.ErrMsg) + return + } + return +} + +//RefreshAccessToken 刷新access_token +func (oauth *Oauth) RefreshAccessToken(refreshToken string) (result ResAccessToken, err error) { + urlStr := fmt.Sprintf(refreshAccessTokenURL, oauth.AppID, refreshToken) + var response []byte + response, err = util.HTTPGet(urlStr) + if err != nil { + return + } + err = json.Unmarshal(response, &result) + if err != nil { + return + } + if result.ErrCode != 0 { + err = fmt.Errorf("GetUserAccessToken error : errcode=%v , errmsg=%v", result.ErrCode, result.ErrMsg) + return + } + return +} + +//CheckAccessToken 检验access_token是否有效 +func (oauth *Oauth) CheckAccessToken(accessToken, openID string) (b bool, err error) { + urlStr := fmt.Sprintf(checkAccessTokenURL, accessToken, openID) + var response []byte + response, err = util.HTTPGet(urlStr) + if err != nil { + return + } + var result util.CommonError + err = json.Unmarshal(response, &result) + if err != nil { + return + } + if result.ErrCode != 0 { + b = false + return + } + b = true + return +} + +//UserInfo 用户授权获取到用户信息 +type UserInfo struct { + util.CommonError + + OpenID string `json:"openid"` + Nickname string `json:"nickname"` + Sex int32 `json:"sex"` + Province string `json:"province"` + City string `json:"city"` + Country string `json:"country"` + HeadImgURL string `json:"headimgurl"` + Privilege []string `json:"privilege"` + Unionid string `json:"unionid"` +} + +//GetUserInfo 如果scope为 snsapi_userinfo 则可以通过此方法获取到用户基本信息 +func (oauth *Oauth) GetUserInfo(accessToken, openID string) (result UserInfo, err error) { + urlStr := fmt.Sprintf(userInfoURL, accessToken, openID) + var response []byte + response, err = util.HTTPGet(urlStr) + if err != nil { + return + } + err = json.Unmarshal(response, &result) + if err != nil { + return + } + if result.ErrCode != 0 { + err = fmt.Errorf("GetUserInfo error : errcode=%v , errmsg=%v", result.ErrCode, result.ErrMsg) + return + } + return +} diff --git a/server/server.go b/server/server.go new file mode 100644 index 0000000..8e6b882 --- /dev/null +++ b/server/server.go @@ -0,0 +1,233 @@ +package server + +import ( + "encoding/xml" + "errors" + "fmt" + "io/ioutil" + "reflect" + "runtime/debug" + "strconv" + "strings" + + "github.com/silenceper/wechat/context" + "github.com/silenceper/wechat/message" + "github.com/silenceper/wechat/util" +) + +//Server struct +type Server struct { + *context.Context + + openID string + + messageHandler func(message.MixMessage) *message.Reply + + requestRawXMLMsg []byte + requestMsg message.MixMessage + responseRawXMLMsg []byte + responseMsg interface{} + + isSafeMode bool + random []byte + nonce string + timestamp int64 +} + +//NewServer init +func NewServer(context *context.Context) *Server { + srv := new(Server) + srv.Context = context + return srv +} + +//Serve 处理微信的请求消息 +func (srv *Server) Serve() error { + if !srv.Validate() { + return fmt.Errorf("请求校验失败") + } + + echostr, exists := srv.GetQuery("echostr") + if exists { + srv.String(echostr) + return nil + } + + response, err := srv.handleRequest() + if err != nil { + return err + } + + //debug + //fmt.Println("request msg = ", string(srv.requestRawXMLMsg)) + + return srv.buildResponse(response) +} + +//Validate 校验请求是否合法 +func (srv *Server) Validate() bool { + timestamp := srv.Query("timestamp") + nonce := srv.Query("nonce") + signature := srv.Query("signature") + return signature == util.Signature(srv.Token, timestamp, nonce) +} + +//HandleRequest 处理微信的请求 +func (srv *Server) handleRequest() (reply *message.Reply, err error) { + //set isSafeMode + srv.isSafeMode = false + encryptType := srv.Query("encrypt_type") + if encryptType == "aes" { + srv.isSafeMode = true + } + + //set openID + srv.openID = srv.Query("openid") + + var msg interface{} + msg, err = srv.getMessage() + if err != nil { + return + } + mixMessage, success := msg.(message.MixMessage) + if !success { + err = errors.New("消息类型转换失败") + } + srv.requestMsg = mixMessage + reply = srv.messageHandler(mixMessage) + return +} + +//GetOpenID return openID +func (srv *Server) GetOpenID() string { + return srv.openID +} + +//getMessage 解析微信返回的消息 +func (srv *Server) getMessage() (interface{}, error) { + var rawXMLMsgBytes []byte + var err error + if srv.isSafeMode { + var encryptedXMLMsg message.EncryptedXMLMsg + if err := xml.NewDecoder(srv.Request.Body).Decode(&encryptedXMLMsg); err != nil { + return nil, fmt.Errorf("从body中解析xml失败,err=%v", err) + } + + //验证消息签名 + timestamp := srv.Query("timestamp") + srv.timestamp, err = strconv.ParseInt(timestamp, 10, 32) + if err != nil { + return nil, err + } + nonce := srv.Query("nonce") + srv.nonce = nonce + msgSignature := srv.Query("msg_signature") + msgSignatureGen := util.Signature(srv.Token, timestamp, nonce, encryptedXMLMsg.EncryptedMsg) + if msgSignature != msgSignatureGen { + return nil, fmt.Errorf("消息不合法,验证签名失败") + } + + //解密 + srv.random, rawXMLMsgBytes, err = util.DecryptMsg(srv.AppID, encryptedXMLMsg.EncryptedMsg, srv.EncodingAESKey) + if err != nil { + return nil, fmt.Errorf("消息解密失败, err=%v", err) + } + } else { + rawXMLMsgBytes, err = ioutil.ReadAll(srv.Request.Body) + if err != nil { + return nil, fmt.Errorf("从body中解析xml失败, err=%v", err) + } + } + + srv.requestRawXMLMsg = rawXMLMsgBytes + + return srv.parseRequestMessage(rawXMLMsgBytes) +} + +func (srv *Server) parseRequestMessage(rawXMLMsgBytes []byte) (msg message.MixMessage, err error) { + msg = message.MixMessage{} + err = xml.Unmarshal(rawXMLMsgBytes, &msg) + return +} + +//SetMessageHandler 设置用户自定义的回调方法 +func (srv *Server) SetMessageHandler(handler func(message.MixMessage) *message.Reply) { + srv.messageHandler = handler +} + +func (srv *Server) buildResponse(reply *message.Reply) (err error) { + defer func() { + if e := recover(); e != nil { + err = fmt.Errorf("panic error: %v\n%s", e, debug.Stack()) + } + }() + if reply == nil { + //do nothing + return nil + } + msgType := reply.MsgType + switch msgType { + case message.MsgTypeText: + case message.MsgTypeImage: + case message.MsgTypeVoice: + case message.MsgTypeVideo: + case message.MsgTypeMusic: + case message.MsgTypeNews: + case message.MsgTypeTransfer: + default: + err = message.ErrUnsupportReply + return + } + + msgData := reply.MsgData + value := reflect.ValueOf(msgData) + //msgData must be a ptr + kind := value.Kind().String() + if 0 != strings.Compare("ptr", kind) { + return message.ErrUnsupportReply + } + + params := make([]reflect.Value, 1) + params[0] = reflect.ValueOf(srv.requestMsg.FromUserName) + value.MethodByName("SetToUserName").Call(params) + + params[0] = reflect.ValueOf(srv.requestMsg.ToUserName) + value.MethodByName("SetFromUserName").Call(params) + + params[0] = reflect.ValueOf(msgType) + value.MethodByName("SetMsgType").Call(params) + + params[0] = reflect.ValueOf(util.GetCurrTs()) + value.MethodByName("SetCreateTime").Call(params) + + srv.responseMsg = msgData + srv.responseRawXMLMsg, err = xml.Marshal(msgData) + return +} + +//Send 将自定义的消息发送 +func (srv *Server) Send() (err error) { + replyMsg := srv.responseMsg + if srv.isSafeMode { + //安全模式下对消息进行加密 + var encryptedMsg []byte + encryptedMsg, err = util.EncryptMsg(srv.random, srv.responseRawXMLMsg, srv.AppID, srv.EncodingAESKey) + if err != nil { + return + } + //TODO 如果获取不到timestamp nonce 则自己生成 + timestamp := srv.timestamp + timestampStr := strconv.FormatInt(timestamp, 10) + msgSignature := util.Signature(srv.Token, timestampStr, srv.nonce, string(encryptedMsg)) + replyMsg = message.ResponseEncryptedXMLMsg{ + EncryptedMsg: string(encryptedMsg), + MsgSignature: msgSignature, + Timestamp: timestamp, + Nonce: srv.nonce, + } + } + if replyMsg != nil { + srv.XML(replyMsg) + } + return +} diff --git a/util/crypto.go b/util/crypto.go new file mode 100644 index 0000000..1ec7a4f --- /dev/null +++ b/util/crypto.go @@ -0,0 +1,183 @@ +package util + +import ( + "crypto/aes" + "crypto/cipher" + "encoding/base64" + "fmt" +) + +//EncryptMsg 加密消息 +func EncryptMsg(random, rawXMLMsg []byte, appID, aesKey string) (encrtptMsg []byte, err error) { + defer func() { + if e := recover(); e != nil { + err = fmt.Errorf("panic error: err=%v", e) + return + } + }() + var key []byte + key, err = aesKeyDecode(aesKey) + if err != nil { + panic(err) + } + ciphertext := AESEncryptMsg(random, rawXMLMsg, appID, key) + encrtptMsg = []byte(base64.StdEncoding.EncodeToString(ciphertext)) + return +} + +//AESEncryptMsg ciphertext = AES_Encrypt[random(16B) + msg_len(4B) + rawXMLMsg + appId] +//参考:github.com/chanxuehong/wechat.v2 +func AESEncryptMsg(random, rawXMLMsg []byte, appID string, aesKey []byte) (ciphertext []byte) { + const ( + BlockSize = 32 // PKCS#7 + BlockMask = BlockSize - 1 // BLOCK_SIZE 为 2^n 时, 可以用 mask 获取针对 BLOCK_SIZE 的余数 + ) + + appIDOffset := 20 + len(rawXMLMsg) + contentLen := appIDOffset + len(appID) + amountToPad := BlockSize - contentLen&BlockMask + plaintextLen := contentLen + amountToPad + + plaintext := make([]byte, plaintextLen) + + // 拼接 + copy(plaintext[:16], random) + encodeNetworkByteOrder(plaintext[16:20], uint32(len(rawXMLMsg))) + copy(plaintext[20:], rawXMLMsg) + copy(plaintext[appIDOffset:], appID) + + // PKCS#7 补位 + for i := contentLen; i < plaintextLen; i++ { + plaintext[i] = byte(amountToPad) + } + + // 加密 + block, err := aes.NewCipher(aesKey[:]) + if err != nil { + panic(err) + } + mode := cipher.NewCBCEncrypter(block, aesKey[:16]) + mode.CryptBlocks(plaintext, plaintext) + + ciphertext = plaintext + return +} + +//DecryptMsg 消息解密 +func DecryptMsg(appID, encryptedMsg, aesKey string) (random, rawMsgXMLBytes []byte, err error) { + defer func() { + if e := recover(); e != nil { + err = fmt.Errorf("panic error: err=%v", e) + return + } + }() + var encryptedMsgBytes, key, getAppIDBytes []byte + encryptedMsgBytes, err = base64.StdEncoding.DecodeString(encryptedMsg) + if err != nil { + return + } + key, err = aesKeyDecode(aesKey) + if err != nil { + panic(err) + } + random, rawMsgXMLBytes, getAppIDBytes, err = AESDecryptMsg(encryptedMsgBytes, key) + if err != nil { + err = fmt.Errorf("消息解密失败,%v", err) + return + } + if appID != string(getAppIDBytes) { + err = fmt.Errorf("消息解密校验APPID失败") + return + } + return +} + +func aesKeyDecode(encodedAESKey string) (key []byte, err error) { + if len(encodedAESKey) != 43 { + err = fmt.Errorf("the length of encodedAESKey must be equal to 43") + return + } + key, err = base64.StdEncoding.DecodeString(encodedAESKey + "=") + if err != nil { + return + } + if len(key) != 32 { + err = fmt.Errorf("encodingAESKey invalid") + return + } + return +} + +// AESDecryptMsg ciphertext = AES_Encrypt[random(16B) + msg_len(4B) + rawXMLMsg + appId] +//参考:github.com/chanxuehong/wechat.v2 +func AESDecryptMsg(ciphertext []byte, aesKey []byte) (random, rawXMLMsg, appID []byte, err error) { + const ( + BlockSize = 32 // PKCS#7 + BlockMask = BlockSize - 1 // BLOCK_SIZE 为 2^n 时, 可以用 mask 获取针对 BLOCK_SIZE 的余数 + ) + + if len(ciphertext) < BlockSize { + err = fmt.Errorf("the length of ciphertext too short: %d", len(ciphertext)) + return + } + if len(ciphertext)&BlockMask != 0 { + err = fmt.Errorf("ciphertext is not a multiple of the block size, the length is %d", len(ciphertext)) + return + } + + plaintext := make([]byte, len(ciphertext)) // len(plaintext) >= BLOCK_SIZE + + // 解密 + block, err := aes.NewCipher(aesKey) + if err != nil { + panic(err) + } + mode := cipher.NewCBCDecrypter(block, aesKey[:16]) + mode.CryptBlocks(plaintext, ciphertext) + + // PKCS#7 去除补位 + amountToPad := int(plaintext[len(plaintext)-1]) + if amountToPad < 1 || amountToPad > BlockSize { + err = fmt.Errorf("the amount to pad is incorrect: %d", amountToPad) + return + } + plaintext = plaintext[:len(plaintext)-amountToPad] + + // 反拼接 + // len(plaintext) == 16+4+len(rawXMLMsg)+len(appId) + if len(plaintext) <= 20 { + err = fmt.Errorf("plaintext too short, the length is %d", len(plaintext)) + return + } + rawXMLMsgLen := int(decodeNetworkByteOrder(plaintext[16:20])) + if rawXMLMsgLen < 0 { + err = fmt.Errorf("incorrect msg length: %d", rawXMLMsgLen) + return + } + appIDOffset := 20 + rawXMLMsgLen + if len(plaintext) <= appIDOffset { + err = fmt.Errorf("msg length too large: %d", rawXMLMsgLen) + return + } + + random = plaintext[:16:20] + rawXMLMsg = plaintext[20:appIDOffset:appIDOffset] + appID = plaintext[appIDOffset:] + return +} + +// 把整数 n 格式化成 4 字节的网络字节序 +func encodeNetworkByteOrder(orderBytes []byte, n uint32) { + orderBytes[0] = byte(n >> 24) + orderBytes[1] = byte(n >> 16) + orderBytes[2] = byte(n >> 8) + orderBytes[3] = byte(n) +} + +// 从 4 字节的网络字节序里解析出整数 +func decodeNetworkByteOrder(orderBytes []byte) (n uint32) { + return uint32(orderBytes[0])<<24 | + uint32(orderBytes[1])<<16 | + uint32(orderBytes[2])<<8 | + uint32(orderBytes[3]) +} diff --git a/util/error.go b/util/error.go new file mode 100644 index 0000000..b8f4211 --- /dev/null +++ b/util/error.go @@ -0,0 +1,6 @@ +package util + +type CommonError struct { + ErrCode int64 `json:"errcode"` + ErrMsg string `json:"errmsg"` +} diff --git a/util/http.go b/util/http.go new file mode 100644 index 0000000..b4f195d --- /dev/null +++ b/util/http.go @@ -0,0 +1,154 @@ +package util + +import ( + "bytes" + "encoding/json" + "fmt" + "io" + "io/ioutil" + "mime/multipart" + "net/http" + "os" +) + +//HTTPGet get 请求 +func HTTPGet(uri string) ([]byte, error) { + response, err := http.Get(uri) + if err != nil { + return nil, err + } + + defer response.Body.Close() + if response.StatusCode != http.StatusOK { + return nil, fmt.Errorf("http get error : uri=%v , statusCode=%v", uri, response.StatusCode) + } + return ioutil.ReadAll(response.Body) +} + +//PostJSON post json 数据请求 +func PostJSON(uri string, obj interface{}) ([]byte, error) { + jsonData, err := json.Marshal(obj) + if err != nil { + return nil, err + } + body := bytes.NewBuffer(jsonData) + response, err := http.Post(uri, "application/json;charset=utf-8", body) + if err != nil { + return nil, err + } + defer response.Body.Close() + + if response.StatusCode != http.StatusOK { + return nil, fmt.Errorf("http get error : uri=%v , statusCode=%v", uri, response.StatusCode) + } + return ioutil.ReadAll(response.Body) +} + +//PostFile 上传文件 +func PostFile(fieldname, filename, uri string) ([]byte, error) { + /* + bodyBuf := &bytes.Buffer{} + bodyWriter := multipart.NewWriter(bodyBuf) + + fileWriter, err := bodyWriter.CreateFormFile(fieldname, filename) + if err != nil { + return nil, fmt.Errorf("error writing to buffer") + } + + fh, err := os.Open(filename) + if err != nil { + return nil, fmt.Errorf("error opening file") + } + defer fh.Close() + + _, err = io.Copy(fileWriter, fh) + if err != nil { + return nil, err + } + + contentType := bodyWriter.FormDataContentType() + bodyWriter.Close() + + resp, err := http.Post(uri, contentType, bodyBuf) + if err != nil { + return nil, err + } + defer resp.Body.Close() + if resp.StatusCode != http.StatusOK { + return nil, err + } + respBody, err := ioutil.ReadAll(resp.Body) + if err != nil { + return nil, err + } + return respBody, nil + */ + fields := []MultipartFormField{ + { + IsFile: true, + Fieldname: fieldname, + Filename: filename, + }, + } + return PostMultipartForm(fields, uri) +} + +//MultipartFormField 保存文件或其他字段信息 +type MultipartFormField struct { + IsFile bool + Fieldname string + Value []byte + Filename string +} + +//PostMultipartForm 上传文件或其他多个字段 +func PostMultipartForm(fields []MultipartFormField, uri string) (respBody []byte, err error) { + bodyBuf := &bytes.Buffer{} + bodyWriter := multipart.NewWriter(bodyBuf) + + for _, field := range fields { + if field.IsFile { + fileWriter, e := bodyWriter.CreateFormFile(field.Fieldname, field.Filename) + if e != nil { + err = fmt.Errorf("error writing to buffer , err=%v", e) + return + } + + fh, e := os.Open(field.Filename) + if e != nil { + err = fmt.Errorf("error opening file , err=%v", e) + return + } + defer fh.Close() + + if _, err = io.Copy(fileWriter, fh); err != nil { + return + } + } else { + partWriter, e := bodyWriter.CreateFormField(field.Fieldname) + if e != nil { + err = e + return + } + valueReader := bytes.NewReader(field.Value) + if _, err = io.Copy(partWriter, valueReader); err != nil { + return + } + } + } + + contentType := bodyWriter.FormDataContentType() + bodyWriter.Close() + + resp, e := http.Post(uri, contentType, bodyBuf) + if e != nil { + err = e + return + } + defer resp.Body.Close() + if resp.StatusCode != http.StatusOK { + return nil, err + } + respBody, err = ioutil.ReadAll(resp.Body) + return +} diff --git a/util/signature.go b/util/signature.go new file mode 100644 index 0000000..a65ee23 --- /dev/null +++ b/util/signature.go @@ -0,0 +1,18 @@ +package util + +import ( + "crypto/sha1" + "fmt" + "io" + "sort" +) + +//Signature sha1签名 +func Signature(params ...string) string { + sort.Strings(params) + h := sha1.New() + for _, s := range params { + io.WriteString(h, s) + } + return fmt.Sprintf("%x", h.Sum(nil)) +} diff --git a/util/signature_test.go b/util/signature_test.go new file mode 100644 index 0000000..9aa2f7f --- /dev/null +++ b/util/signature_test.go @@ -0,0 +1,11 @@ +package util + +import "testing" + +func TestSignature(t *testing.T) { + //abc sig + abc := "a9993e364706816aba3e25717850c26c9cd0d89d" + if abc != Signature("a", "b", "c") { + t.Error("test Signature Error") + } +} diff --git a/util/string.go b/util/string.go new file mode 100644 index 0000000..62b5c13 --- /dev/null +++ b/util/string.go @@ -0,0 +1,18 @@ +package util + +import ( + "math/rand" + "time" +) + +//RandomStr 随机生成字符串 +func RandomStr(length int) string { + str := "0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ" + bytes := []byte(str) + result := []byte{} + r := rand.New(rand.NewSource(time.Now().UnixNano())) + for i := 0; i < length; i++ { + result = append(result, bytes[r.Intn(len(bytes))]) + } + return string(result) +} diff --git a/util/time.go b/util/time.go new file mode 100644 index 0000000..024839c --- /dev/null +++ b/util/time.go @@ -0,0 +1,8 @@ +package util + +import "time" + +//GetCurrTs return current timestamps +func GetCurrTs() int64 { + return time.Now().Unix() +} diff --git a/wechat.go b/wechat.go new file mode 100644 index 0000000..48f207e --- /dev/null +++ b/wechat.go @@ -0,0 +1,70 @@ +package wechat + +import ( + "net/http" + "sync" + + "github.com/silenceper/wechat/cache" + "github.com/silenceper/wechat/context" + "github.com/silenceper/wechat/js" + "github.com/silenceper/wechat/material" + "github.com/silenceper/wechat/oauth" + "github.com/silenceper/wechat/server" +) + +//Wechat struct +type Wechat struct { + Context *context.Context +} + +//Config for user +type Config struct { + AppID string + AppSecret string + Token string + EncodingAESKey string + Cache cache.Cache +} + +//NewWechat init +func NewWechat(cfg *Config) *Wechat { + context := new(context.Context) + copyConfigToContext(cfg, context) + return &Wechat{context} +} + +func copyConfigToContext(cfg *Config, context *context.Context) { + context.AppID = cfg.AppID + context.AppSecret = cfg.AppSecret + context.Token = cfg.Token + context.EncodingAESKey = cfg.EncodingAESKey + context.Cache = cfg.Cache + context.SetAccessTokenLock(new(sync.RWMutex)) + context.SetJsApiTicketLock(new(sync.RWMutex)) +} + +//GetServer init +func (wc *Wechat) GetServer(req *http.Request, writer http.ResponseWriter) *server.Server { + wc.Context.Request = req + wc.Context.Writer = writer + return server.NewServer(wc.Context) +} + +//GetMaterial init +func (wc *Wechat) GetMaterial() *material.Material { + return material.NewMaterial(wc.Context) +} + +//GetOauth init +func (wc *Wechat) GetOauth(req *http.Request, writer http.ResponseWriter) *oauth.Oauth { + wc.Context.Request = req + wc.Context.Writer = writer + return oauth.NewOauth(wc.Context) +} + +//GetJs init +func (wc *Wechat) GetJs(req *http.Request, writer http.ResponseWriter) *js.Js { + wc.Context.Request = req + wc.Context.Writer = writer + return js.NewJs(wc.Context) +}