From 4333691b37aa469797c9fdf642dc44cf19e53de4 Mon Sep 17 00:00:00 2001 From: wenzl Date: Thu, 15 Sep 2016 01:27:28 +0800 Subject: [PATCH] oauth2,jssdk --- cache/cache.go | 2 +- context/access_token.go | 12 +-- context/context.go | 12 +++ context/render.go | 2 + js/js.go | 109 ++++++++++++++++++++++++++ material/material.go | 167 ++++++++++++++++++++++++++++++++++++++++ material/media.go | 109 ++++++++++++++++++++++++++ message/image.go | 17 ++++ message/message.go | 56 +++++++------- message/music.go | 24 ++++++ message/news.go | 35 +++++++++ message/text.go | 7 ++ message/video.go | 21 +++++ message/voice.go | 17 ++++ oauth/oauth.go | 152 ++++++++++++++++++++++++++++++++++++ server/server.go | 13 +++- util/error.go | 6 ++ util/http.go | 143 +++++++++++++++++++++++++++++++++- util/signature_test.go | 11 +++ util/string.go | 18 +++++ wechat.go | 23 ++++++ 21 files changed, 913 insertions(+), 43 deletions(-) create mode 100644 js/js.go create mode 100644 material/material.go create mode 100644 material/media.go create mode 100644 message/image.go create mode 100644 message/music.go create mode 100644 message/news.go create mode 100644 message/video.go create mode 100644 message/voice.go create mode 100644 oauth/oauth.go create mode 100644 util/error.go create mode 100644 util/signature_test.go create mode 100644 util/string.go diff --git a/cache/cache.go b/cache/cache.go index 614999f..d144b8f 100644 --- a/cache/cache.go +++ b/cache/cache.go @@ -5,7 +5,7 @@ import "time" //Cache interface type Cache interface { Get(key string) interface{} - Set(key string, val interface{}, timeput time.Duration) error + Set(key string, val interface{}, timeout time.Duration) error IsExist(key string) bool Delete(key string) error } diff --git a/context/access_token.go b/context/access_token.go index ae55564..3b3d263 100644 --- a/context/access_token.go +++ b/context/access_token.go @@ -16,11 +16,10 @@ const ( //ResAccessToken struct type ResAccessToken struct { - ErrorCode int32 `json:"errcode"` - ErrorMsg string `json:"errmsg"` + util.CommonError AccessToken string `json:"access_token"` - ExpiresIn int32 `json:"expires_in"` + ExpiresIn int64 `json:"expires_in"` } //SetAccessTokenLock 设置读写锁(一个appID一个读写锁) @@ -60,12 +59,13 @@ func (ctx *Context) GetAccessTokenFromServer() (resAccessToken ResAccessToken, e if err != nil { return } - if resAccessToken.ErrorMsg != "" { - err = fmt.Errorf("get access_token error : errcode=%v , errormsg=%v", resAccessToken.ErrorCode, resAccessToken.ErrorMsg) + if resAccessToken.ErrMsg != "" { + err = fmt.Errorf("get access_token error : errcode=%v , errormsg=%v", resAccessToken.ErrCode, resAccessToken.ErrMsg) return } accessTokenCacheKey := fmt.Sprintf("access_token_%s", ctx.AppID) - err = ctx.Cache.Set(accessTokenCacheKey, resAccessToken.AccessToken, time.Duration(resAccessToken.ExpiresIn)-1500/time.Second) + expires := resAccessToken.ExpiresIn - 1500 + err = ctx.Cache.Set(accessTokenCacheKey, resAccessToken.AccessToken, time.Duration(expires)*time.Second) return } diff --git a/context/context.go b/context/context.go index 75bd4a8..ce10617 100644 --- a/context/context.go +++ b/context/context.go @@ -21,6 +21,9 @@ type Context struct { //accessTokenLock 读写锁 同一个AppID一个 accessTokenLock *sync.RWMutex + + //jsapiTicket 读写锁 同一个AppID一个 + jsApiTicketLock *sync.RWMutex } // Query returns the keyed url query value if it exists @@ -37,3 +40,12 @@ func (ctx *Context) GetQuery(key string) (string, bool) { } return "", false } + +//SetJsApiTicket 设置jsApiTicket的lock +func (ctx *Context) SetJsApiTicketLock(lock *sync.RWMutex) { + ctx.jsApiTicketLock = lock +} + +func (ctx *Context) GetJsApiTicketLock() *sync.RWMutex { + return ctx.jsApiTicketLock +} diff --git a/context/render.go b/context/render.go index 35478dd..e2fb825 100644 --- a/context/render.go +++ b/context/render.go @@ -10,6 +10,8 @@ var plainContentType = []string{"text/plain; charset=utf-8"} //Render render from bytes func (ctx *Context) Render(bytes []byte) { + //debug + //fmt.Println("response msg = ", string(bytes)) ctx.Writer.WriteHeader(200) _, err := ctx.Writer.Write(bytes) if err != nil { diff --git a/js/js.go b/js/js.go new file mode 100644 index 0000000..44edae6 --- /dev/null +++ b/js/js.go @@ -0,0 +1,109 @@ +package js + +import ( + "encoding/json" + "fmt" + "time" + + "github.com/silenceper/wechat/context" + "github.com/silenceper/wechat/util" +) + +const getTicketURL = "https://api.weixin.qq.com/cgi-bin/ticket/getticket?access_token=%s&type=jsapi" + +//Js struct +type Js struct { + *context.Context +} + +//Config 返回给用户jssdk配置信息 +type Config struct { + AppID string + TimeStamp int64 + NonceStr string + Signature string +} + +//resTicket 请求jsapi_tikcet返回结果 +type resTicket struct { + util.CommonError + + Ticket string `json:"ticket"` + ExpiresIn int64 `json:"expires_in"` +} + +//NewJs init +func NewJs(context *context.Context) *Js { + js := new(Js) + js.Context = context + return js +} + +//GetConfig 获取jssdk需要的配置参数 +//uri 为当前网页地址 +func (js *Js) GetConfig(uri string) (config *Config, err error) { + config = new(Config) + var ticketStr string + ticketStr, err = js.getTicket() + if err != nil { + return + } + + nonceStr := util.RandomStr(16) + timestamp := util.GetCurrTs() + str := fmt.Sprintf("jsapi_ticket=%s&noncestr=%s×tamp=%d&url=%s", ticketStr, nonceStr, timestamp, uri) + sigStr := util.Signature(str) + + config.AppID = js.AppID + config.NonceStr = nonceStr + config.TimeStamp = timestamp + config.Signature = sigStr + return +} + +//getTicket 获取jsapi_tocket全局缓存 +func (js *Js) getTicket() (ticketStr string, err error) { + js.GetJsApiTicketLock().Lock() + defer js.GetJsApiTicketLock().Unlock() + + //先从cache中取 + jsAPITicketCacheKey := fmt.Sprintf("jsapi_ticket_%s", js.AppID) + val := js.Cache.Get(jsAPITicketCacheKey) + if val != nil { + ticketStr = val.(string) + return + } + var ticket resTicket + ticket, err = js.getTicketFromServer() + if err != nil { + return + } + ticketStr = ticket.Ticket + return +} + +//getTicketFromServer 强制从服务器中获取ticket +func (js *Js) getTicketFromServer() (ticket resTicket, err error) { + var accessToken string + accessToken, err = js.GetAccessToken() + if err != nil { + return + } + + var response []byte + url := fmt.Sprintf(getTicketURL, accessToken) + response, err = util.HTTPGet(url) + err = json.Unmarshal(response, &ticket) + if err != nil { + return + } + if ticket.ErrCode != 0 { + err = fmt.Errorf("getTicket Error : errcode=%s , errmsg=%s", ticket.ErrCode, ticket.ErrMsg) + return + } + + jsAPITicketCacheKey := fmt.Sprintf("jsapi_ticket_%s", js.AppID) + expires := ticket.ExpiresIn - 1500 + err = js.Cache.Set(jsAPITicketCacheKey, ticket.Ticket, time.Duration(expires)*time.Second) + return +} diff --git a/material/material.go b/material/material.go new file mode 100644 index 0000000..60af0e3 --- /dev/null +++ b/material/material.go @@ -0,0 +1,167 @@ +package material + +import ( + "encoding/json" + "errors" + "fmt" + + "github.com/silenceper/wechat/context" + "github.com/silenceper/wechat/util" +) + +const ( + addNewsURL = "https://api.weixin.qq.com/cgi-bin/material/add_news" + addMaterialURL = "https://api.weixin.qq.com/cgi-bin/material/add_material" +) + +//Material 素材管理 +type Material struct { + *context.Context +} + +//NewMaterial init +func NewMaterial(context *context.Context) *Material { + material := new(Material) + material.Context = context + return material +} + +//Article 永久图文素材 +type Article struct { + ThumbMediaID string `json:"thumb_media_id"` + Author string `json:"author"` + Digest string `json:"digest"` + ShowCoverPic int `json:"show_cover_pic"` + Content string `json:"content"` + ContentSourceURL string `json:"content_source_url"` +} + +//reqArticles 永久性图文素材请求信息 +type reqArticles struct { + articles []*Article `json:"articles"` +} + +//resArticles 永久性图文素材返回结果 +type resArticles struct { + util.CommonError + + MediaID string `json:"media_id"` +} + +//AddNews 新增永久图文素材 +func (material *Material) AddNews(articles []*Article) (mediaID string, err error) { + req := &reqArticles{articles} + + var accessToken string + accessToken, err = material.GetAccessToken() + if err != nil { + return + } + + uri := fmt.Sprintf("%s?access_token=%s", addNewsURL, accessToken) + responseBytes, err := util.PostJSON(uri, req) + var res resArticles + err = json.Unmarshal(responseBytes, res) + if err != nil { + return + } + mediaID = res.MediaID + return +} + +//resAddMaterial 永久性素材上传返回的结果 +type resAddMaterial struct { + util.CommonError + + MediaID string `json:"media_id"` + URL string `json:"url"` +} + +//AddMaterial 上传永久性素材(处理视频需要单独上传) +func (material *Material) AddMaterial(mediaType MediaType, filename string) (mediaID string, url string, err error) { + if mediaType == MediaTypeVideo { + err = errors.New("永久视频素材上传使用 AddVideo 方法") + } + var accessToken string + accessToken, err = material.GetAccessToken() + if err != nil { + return + } + + uri := fmt.Sprintf("%s?access_token=%s&type=%s", addMaterialURL, accessToken, mediaType) + var response []byte + response, err = util.PostFile("media", filename, uri) + if err != nil { + return + } + var resMaterial resAddMaterial + err = json.Unmarshal(response, &resMaterial) + if err != nil { + return + } + if resMaterial.ErrCode != 0 { + err = fmt.Errorf("AddMaterial error : errcode=%v , errmsg=%v", resMaterial.ErrCode, resMaterial.ErrMsg) + return + } + mediaID = resMaterial.MediaID + url = resMaterial.URL + return +} + +type reqVideo struct { + Title string `json:"title"` + Introduction string `json:"introduction"` +} + +//AddVideo 永久视频素材文件上传 +func (material *Material) AddVideo(filename, title, introduction string) (mediaID string, url string, err error) { + var accessToken string + accessToken, err = material.GetAccessToken() + if err != nil { + return + } + + uri := fmt.Sprintf("%s?access_token=%s&type=video", addMaterialURL, accessToken) + + videoDesc := &reqVideo{ + Title: title, + Introduction: introduction, + } + var fieldValue []byte + fieldValue, err = json.Marshal(videoDesc) + if err != nil { + return + } + + fields := []util.MultipartFormField{ + { + IsFile: true, + Fieldname: "video", + Filename: filename, + }, + { + IsFile: true, + Fieldname: "description", + Value: fieldValue, + }, + } + + var response []byte + response, err = util.PostMultipartForm(fields, uri) + if err != nil { + return + } + + var resMaterial resAddMaterial + err = json.Unmarshal(response, &resMaterial) + if err != nil { + return + } + if resMaterial.ErrCode != 0 { + err = fmt.Errorf("AddMaterial error : errcode=%v , errmsg=%v", resMaterial.ErrCode, resMaterial.ErrMsg) + return + } + mediaID = resMaterial.MediaID + url = resMaterial.URL + return +} diff --git a/material/media.go b/material/media.go new file mode 100644 index 0000000..aca5f92 --- /dev/null +++ b/material/media.go @@ -0,0 +1,109 @@ +package material + +import ( + "encoding/json" + "fmt" + + "github.com/silenceper/wechat/util" +) + +//MediaType 媒体文件类型 +type MediaType string + +const ( + //MediaTypeImage 媒体文件:图片 + MediaTypeImage MediaType = "image" + //MediaTypeVoice 媒体文件:声音 + MediaTypeVoice = "voice" + //MediaTypeVideo 媒体文件:视频 + MediaTypeVideo = "video" + //MediaTypeThumb 媒体文件:缩略图 + MediaTypeThumb = "thumb" +) + +const ( + mediaUploadURL = "https://api.weixin.qq.com/cgi-bin/media/upload" + mediaUploadImageURL = "https://api.weixin.qq.com/cgi-bin/media/uploadimg" + mediaGetURL = "https://api.weixin.qq.com/cgi-bin/media/get" +) + +//Media 临时素材上传返回信息 +type Media struct { + util.CommonError + + Type MediaType `json:"type"` + MediaID string `json:"media_id"` + CreatedAt int64 `json:"created_at"` +} + +//MediaUpload 临时素材上传 +func (material *Material) MediaUpload(mediaType MediaType, filename string) (media Media, err error) { + var accessToken string + accessToken, err = material.GetAccessToken() + if err != nil { + return + } + + uri := fmt.Sprintf("%s?access_token=%s&type=%s", mediaUploadURL, accessToken, mediaType) + var response []byte + response, err = util.PostFile("media", filename, uri) + if err != nil { + return + } + err = json.Unmarshal(response, &media) + if err != nil { + return + } + if media.ErrCode != 0 { + err = fmt.Errorf("MediaUpload error : errcode=%v , errmsg=%v", media.ErrCode, media.ErrMsg) + return + } + return +} + +//GetMediaURL 返回临时素材的下载地址供用户自己处理 +//NOTICE: URL 不可公开,因为含access_token 需要立即另存文件 +func (material *Material) GetMediaURL(mediaID string) (mediaURL string, err error) { + var accessToken string + accessToken, err = material.GetAccessToken() + if err != nil { + return + } + mediaURL = fmt.Sprintf("%s?access_token=%s&media_id=%s", mediaGetURL, accessToken, mediaID) + return +} + +//resMediaImage 图片上传返回结果 +type resMediaImage struct { + util.CommonError + + URL string `json:"url"` +} + +//ImageUpload 图片上传 +func (material *Material) ImageUpload(filename string) (url string, err error) { + var accessToken string + accessToken, err = material.GetAccessToken() + if err != nil { + return + } + + uri := fmt.Sprintf("%s?access_token=%s", mediaUploadImageURL, accessToken) + var response []byte + response, err = util.PostFile("media", filename, uri) + if err != nil { + return + } + var image resMediaImage + err = json.Unmarshal(response, &image) + if err != nil { + return + } + if image.ErrCode != 0 { + err = fmt.Errorf("UploadImage error : errcode=%v , errmsg=%v", image.ErrCode, image.ErrMsg) + return + } + url = image.URL + return + +} diff --git a/message/image.go b/message/image.go new file mode 100644 index 0000000..93e6bc0 --- /dev/null +++ b/message/image.go @@ -0,0 +1,17 @@ +package message + +//Image 图片消息 +type Image struct { + CommonToken + + Image struct { + MediaID string `xml:"MediaId"` + } `xml:"Image"` +} + +//NewImage 回复图片消息 +func NewImage(mediaID string) *Image { + image := new(Image) + image.Image.MediaID = mediaID + return image +} diff --git a/message/message.go b/message/message.go index 47c8781..b64b0e3 100644 --- a/message/message.go +++ b/message/message.go @@ -46,6 +46,34 @@ const ( EventView EventType = "VIEW" ) +//MixMessage 存放所有微信发送过来的消息和事件 +type MixMessage struct { + CommonToken + + //基本消息 + MsgID int64 `xml:"MsgId"` + Content string `xml:"Content"` + PicURL string `xml:"PicUrl"` + MediaID string `xml:"MediaId"` + Format string `xml:"Format"` + ThumbMediaID string `xml:"ThumbMediaId"` + LocationX float64 `xml:"Location_X"` + LocationY float64 `xml:"Location_Y"` + Scale float64 `xml:"Scale"` + Label string `xml:"Label"` + Title string `xml:"Title"` + Description string `xml:"Description"` + URL string `xml:"Url"` + + //事件相关 + Event string `xml:"Event"` + EventKey string `xml:"EventKey"` + Ticket string `xml:"Ticket"` + Latitude string `xml:"Latitude"` + Longitude string `xml:"Longitude"` + Precision string `xml:"Precision"` +} + //EncryptedXMLMsg 安全模式下的消息体 type EncryptedXMLMsg struct { XMLName struct{} `xml:"xml" json:"-"` @@ -90,31 +118,3 @@ func (msg *CommonToken) SetCreateTime(createTime int64) { func (msg *CommonToken) SetMsgType(msgType MsgType) { msg.MsgType = msgType } - -//MixMessage 存放所有微信发送过来的消息和事件 -type MixMessage struct { - CommonToken - - //基本消息 - MsgID int64 `xml:"MsgId"` - Content string `xml:"Content"` - PicURL string `xml:"PicUrl"` - MediaID string `xml:"MediaId"` - Format string `xml:"Format"` - ThumbMediaID string `xml:"ThumbMediaId"` - LocationX float64 `xml:"Location_X"` - LocationY float64 `xml:"Location_Y"` - Scale float64 `xml:"Scale"` - Label string `xml:"Label"` - Title string `xml:"Title"` - Description string `xml:"Description"` - URL string `xml:"Url"` - - //事件相关 - Event string `xml:"Event"` - EventKey string `xml:"EventKey"` - Ticket string `xml:"Ticket"` - Latitude string `xml:"Latitude"` - Longitude string `xml:"Longitude"` - Precision string `xml:"Precision"` -} diff --git a/message/music.go b/message/music.go new file mode 100644 index 0000000..3e010ed --- /dev/null +++ b/message/music.go @@ -0,0 +1,24 @@ +package message + +//Music 音乐消息 +type Music struct { + CommonToken + + Music struct { + Title string `xml:"Title" ` + Description string `xml:"Description" ` + MusicURL string `xml:"MusicUrl" ` + HQMusicURL string `xml:"HQMusicUrl" ` + ThumbMediaID string `xml:"ThumbMediaId"` + } `xml:"Music"` +} + +//NewMusic 回复音乐消息 +func NewMusic(title, description, musicURL, hQMusicURL, thumbMediaID string) *Music { + music := new(Music) + music.Music.Title = title + music.Music.Description = description + music.Music.MusicURL = musicURL + music.Music.ThumbMediaID = thumbMediaID + return music +} diff --git a/message/news.go b/message/news.go new file mode 100644 index 0000000..ee28b0c --- /dev/null +++ b/message/news.go @@ -0,0 +1,35 @@ +package message + +//News 图文消息 +type News struct { + CommonToken + + ArticleCount int `xml:"ArticleCount"` + Articles []*Article `xml:"Articles>item,omitempty"` +} + +//NewNews 初始化图文消息 +func NewNews(articles []*Article) *News { + news := new(News) + news.ArticleCount = len(articles) + news.Articles = articles + return news +} + +//Article 单篇文章 +type Article struct { + Title string `xml:"Title,omitempty"` + Description string `xml:"Description,omitempty"` + PicURL string `xml:"PicUrl,omitempty"` + URL string `xml:"Url,omitempty"` +} + +//NewArticle 初始化文章 +func NewArticle(title, description, picURL, url string) *Article { + article := new(Article) + article.Title = title + article.Description = description + article.PicURL = picURL + article.URL = url + return article +} diff --git a/message/text.go b/message/text.go index c34fb59..d981d96 100644 --- a/message/text.go +++ b/message/text.go @@ -5,3 +5,10 @@ type Text struct { CommonToken Content string `xml:"Content"` } + +//NewText 初始化文本消息 +func NewText(content string) *Text { + text := new(Text) + text.Content = content + return text +} diff --git a/message/video.go b/message/video.go new file mode 100644 index 0000000..a082065 --- /dev/null +++ b/message/video.go @@ -0,0 +1,21 @@ +package message + +//Video 视频消息 +type Video struct { + CommonToken + + Video struct { + MediaID string `xml:"MediaId"` + Title string `xml:"Title,omitempty"` + Description string `xml:"Description,omitempty"` + } `xml:"Video"` +} + +//NewVideo 回复图片消息 +func NewVideo(mediaID, title, description string) *Video { + video := new(Video) + video.Video.MediaID = mediaID + video.Video.Title = title + video.Video.Description = description + return video +} diff --git a/message/voice.go b/message/voice.go new file mode 100644 index 0000000..d76985c --- /dev/null +++ b/message/voice.go @@ -0,0 +1,17 @@ +package message + +//Voice 语音消息 +type Voice struct { + CommonToken + + Voice struct { + MediaID string `xml:"MediaId"` + } `xml:"Voice"` +} + +//NewVoice 回复语音消息 +func NewVoice(mediaID string) *Voice { + voice := new(Voice) + voice.Voice.MediaID = mediaID + return voice +} diff --git a/oauth/oauth.go b/oauth/oauth.go new file mode 100644 index 0000000..27fe17a --- /dev/null +++ b/oauth/oauth.go @@ -0,0 +1,152 @@ +package oauth + +import ( + "encoding/json" + "fmt" + "net/http" + "net/url" + + "github.com/silenceper/wechat/context" + "github.com/silenceper/wechat/util" +) + +const ( + redirectOauthURL = "https://open.weixin.qq.com/connect/oauth2/authorize?appid=%s&redirect_uri=%s&response_type=code&scope=%s&state=%s#wechat_redirect" + accessTokenURL = "https://api.weixin.qq.com/sns/oauth2/access_token?appid=%s&secret=%s&code=%s&grant_type=authorization_code" + refreshAccessTokenURL = "https://api.weixin.qq.com/sns/oauth2/refresh_token?appid=%s&grant_type=refresh_token&refresh_token=%s" + userInfoURL = "https://api.weixin.qq.com/sns/userinfo?access_token=%s&openid=%s&lang=zh_CN" + checkAccessTokenURL = "https://api.weixin.qq.com/sns/auth?access_token=%s&openid=%s" +) + +//Oauth 保存用户授权信息 +type Oauth struct { + *context.Context +} + +//NewOauth 实例化授权信息 +func NewOauth(context *context.Context) *Oauth { + auth := new(Oauth) + auth.Context = context + return auth +} + +//GetRedirectURL 获取跳转的url地址 +func (oauth *Oauth) GetRedirectURL(redirectURI, scope, state string) (string, error) { + //url encode + urlStr := url.QueryEscape(redirectURI) + return fmt.Sprintf(redirectOauthURL, oauth.AppID, urlStr, scope, state), nil +} + +//Redirect 跳转到网页授权 +func (oauth *Oauth) Redirect(redirectURI, scope, state string) error { + location, err := oauth.GetRedirectURL(redirectURI, scope, state) + if err != nil { + return err + } + http.Redirect(oauth.Writer, oauth.Request, location, 302) + return nil +} + +// ResAccessToken 获取用户授权access_token的返回结果 +type ResAccessToken struct { + util.CommonError + + AccessToken string `json:"access_token"` + ExpiresIn int64 `json:"expires_in"` + RefreshToken string `json:"refresh_token"` + OpenID string `json:"openid"` + Scope string `json:"scope"` +} + +// GetUserAccessToken 通过网页授权的code 换取access_token(区别于context中的access_token) +func (oauth *Oauth) GetUserAccessToken(code string) (result ResAccessToken, err error) { + urlStr := fmt.Sprintf(accessTokenURL, oauth.AppID, oauth.AppSecret, code) + var response []byte + response, err = util.HTTPGet(urlStr) + if err != nil { + return + } + err = json.Unmarshal(response, &result) + if err != nil { + return + } + if result.ErrCode != 0 { + err = fmt.Errorf("GetUserAccessToken error : errcode=%v , errmsg=%v", result.ErrCode, result.ErrMsg) + return + } + return +} + +//RefreshAccessToken 刷新access_token +func (oauth *Oauth) RefreshAccessToken(refreshToken string) (result ResAccessToken, err error) { + urlStr := fmt.Sprintf(refreshAccessTokenURL, oauth.AppID, refreshToken) + var response []byte + response, err = util.HTTPGet(urlStr) + if err != nil { + return + } + err = json.Unmarshal(response, &result) + if err != nil { + return + } + if result.ErrCode != 0 { + err = fmt.Errorf("GetUserAccessToken error : errcode=%v , errmsg=%v", result.ErrCode, result.ErrMsg) + return + } + return +} + +//CheckAccessToken 检验access_token是否有效 +func (oauth *Oauth) CheckAccessToken(accessToken, openID string) (b bool, err error) { + urlStr := fmt.Sprintf(checkAccessTokenURL, accessToken, openID) + var response []byte + response, err = util.HTTPGet(urlStr) + if err != nil { + return + } + var result util.CommonError + err = json.Unmarshal(response, &result) + if err != nil { + return + } + if result.ErrCode != 0 { + b = false + return + } + b = true + return +} + +//UserInfo 用户授权获取到用户信息 +type UserInfo struct { + util.CommonError + + OpenID string `json:"openid"` + Nickname string `json:"nickname"` + Sex int32 `json:"sex"` + Province string `json:"province"` + City string `json:"city"` + Country string `json:"country"` + HeadImgURL string `json:"headimgurl"` + Privilege []string `json:"privilege"` + Unionid string `json:"unionid"` +} + +//GetUserInfo 如果scope为 snsapi_userinfo 则可以通过此方法获取到用户基本信息 +func (oauth *Oauth) GetUserInfo(accessToken, openID string) (result UserInfo, err error) { + urlStr := fmt.Sprintf(userInfoURL, accessToken, openID) + var response []byte + response, err = util.HTTPGet(urlStr) + if err != nil { + return + } + err = json.Unmarshal(response, &result) + if err != nil { + return + } + if result.ErrCode != 0 { + err = fmt.Errorf("GetUserInfo error : errcode=%v , errmsg=%v", result.ErrCode, result.ErrMsg) + return + } + return +} diff --git a/server/server.go b/server/server.go index 068ba31..8e6b882 100644 --- a/server/server.go +++ b/server/server.go @@ -6,6 +6,7 @@ import ( "fmt" "io/ioutil" "reflect" + "runtime/debug" "strconv" "strings" @@ -57,8 +58,10 @@ func (srv *Server) Serve() error { return err } - srv.buildResponse(response) - return nil + //debug + //fmt.Println("request msg = ", string(srv.requestRawXMLMsg)) + + return srv.buildResponse(response) } //Validate 校验请求是否合法 @@ -155,7 +158,7 @@ func (srv *Server) SetMessageHandler(handler func(message.MixMessage) *message.R func (srv *Server) buildResponse(reply *message.Reply) (err error) { defer func() { if e := recover(); e != nil { - err = fmt.Errorf("panic error: %v", err) + err = fmt.Errorf("panic error: %v\n%s", e, debug.Stack()) } }() if reply == nil { @@ -223,6 +226,8 @@ func (srv *Server) Send() (err error) { Nonce: srv.nonce, } } - srv.XML(replyMsg) + if replyMsg != nil { + srv.XML(replyMsg) + } return } diff --git a/util/error.go b/util/error.go new file mode 100644 index 0000000..b8f4211 --- /dev/null +++ b/util/error.go @@ -0,0 +1,6 @@ +package util + +type CommonError struct { + ErrCode int64 `json:"errcode"` + ErrMsg string `json:"errmsg"` +} diff --git a/util/http.go b/util/http.go index 203fa4a..b4f195d 100644 --- a/util/http.go +++ b/util/http.go @@ -1,19 +1,154 @@ package util import ( + "bytes" + "encoding/json" + "fmt" + "io" "io/ioutil" + "mime/multipart" "net/http" + "os" ) //HTTPGet get 请求 -func HTTPGet(url string) ([]byte, error) { - response, err := http.Get(url) +func HTTPGet(uri string) ([]byte, error) { + response, err := http.Get(uri) if err != nil { return nil, err } + + defer response.Body.Close() + if response.StatusCode != http.StatusOK { + return nil, fmt.Errorf("http get error : uri=%v , statusCode=%v", uri, response.StatusCode) + } return ioutil.ReadAll(response.Body) } -//HTTPPost post 请求 -func HTTPPost() { +//PostJSON post json 数据请求 +func PostJSON(uri string, obj interface{}) ([]byte, error) { + jsonData, err := json.Marshal(obj) + if err != nil { + return nil, err + } + body := bytes.NewBuffer(jsonData) + response, err := http.Post(uri, "application/json;charset=utf-8", body) + if err != nil { + return nil, err + } + defer response.Body.Close() + + if response.StatusCode != http.StatusOK { + return nil, fmt.Errorf("http get error : uri=%v , statusCode=%v", uri, response.StatusCode) + } + return ioutil.ReadAll(response.Body) +} + +//PostFile 上传文件 +func PostFile(fieldname, filename, uri string) ([]byte, error) { + /* + bodyBuf := &bytes.Buffer{} + bodyWriter := multipart.NewWriter(bodyBuf) + + fileWriter, err := bodyWriter.CreateFormFile(fieldname, filename) + if err != nil { + return nil, fmt.Errorf("error writing to buffer") + } + + fh, err := os.Open(filename) + if err != nil { + return nil, fmt.Errorf("error opening file") + } + defer fh.Close() + + _, err = io.Copy(fileWriter, fh) + if err != nil { + return nil, err + } + + contentType := bodyWriter.FormDataContentType() + bodyWriter.Close() + + resp, err := http.Post(uri, contentType, bodyBuf) + if err != nil { + return nil, err + } + defer resp.Body.Close() + if resp.StatusCode != http.StatusOK { + return nil, err + } + respBody, err := ioutil.ReadAll(resp.Body) + if err != nil { + return nil, err + } + return respBody, nil + */ + fields := []MultipartFormField{ + { + IsFile: true, + Fieldname: fieldname, + Filename: filename, + }, + } + return PostMultipartForm(fields, uri) +} + +//MultipartFormField 保存文件或其他字段信息 +type MultipartFormField struct { + IsFile bool + Fieldname string + Value []byte + Filename string +} + +//PostMultipartForm 上传文件或其他多个字段 +func PostMultipartForm(fields []MultipartFormField, uri string) (respBody []byte, err error) { + bodyBuf := &bytes.Buffer{} + bodyWriter := multipart.NewWriter(bodyBuf) + + for _, field := range fields { + if field.IsFile { + fileWriter, e := bodyWriter.CreateFormFile(field.Fieldname, field.Filename) + if e != nil { + err = fmt.Errorf("error writing to buffer , err=%v", e) + return + } + + fh, e := os.Open(field.Filename) + if e != nil { + err = fmt.Errorf("error opening file , err=%v", e) + return + } + defer fh.Close() + + if _, err = io.Copy(fileWriter, fh); err != nil { + return + } + } else { + partWriter, e := bodyWriter.CreateFormField(field.Fieldname) + if e != nil { + err = e + return + } + valueReader := bytes.NewReader(field.Value) + if _, err = io.Copy(partWriter, valueReader); err != nil { + return + } + } + } + + contentType := bodyWriter.FormDataContentType() + bodyWriter.Close() + + resp, e := http.Post(uri, contentType, bodyBuf) + if e != nil { + err = e + return + } + defer resp.Body.Close() + if resp.StatusCode != http.StatusOK { + return nil, err + } + respBody, err = ioutil.ReadAll(resp.Body) + return } diff --git a/util/signature_test.go b/util/signature_test.go new file mode 100644 index 0000000..9aa2f7f --- /dev/null +++ b/util/signature_test.go @@ -0,0 +1,11 @@ +package util + +import "testing" + +func TestSignature(t *testing.T) { + //abc sig + abc := "a9993e364706816aba3e25717850c26c9cd0d89d" + if abc != Signature("a", "b", "c") { + t.Error("test Signature Error") + } +} diff --git a/util/string.go b/util/string.go new file mode 100644 index 0000000..62b5c13 --- /dev/null +++ b/util/string.go @@ -0,0 +1,18 @@ +package util + +import ( + "math/rand" + "time" +) + +//RandomStr 随机生成字符串 +func RandomStr(length int) string { + str := "0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ" + bytes := []byte(str) + result := []byte{} + r := rand.New(rand.NewSource(time.Now().UnixNano())) + for i := 0; i < length; i++ { + result = append(result, bytes[r.Intn(len(bytes))]) + } + return string(result) +} diff --git a/wechat.go b/wechat.go index 1067169..48f207e 100644 --- a/wechat.go +++ b/wechat.go @@ -6,6 +6,9 @@ import ( "github.com/silenceper/wechat/cache" "github.com/silenceper/wechat/context" + "github.com/silenceper/wechat/js" + "github.com/silenceper/wechat/material" + "github.com/silenceper/wechat/oauth" "github.com/silenceper/wechat/server" ) @@ -37,6 +40,7 @@ func copyConfigToContext(cfg *Config, context *context.Context) { context.EncodingAESKey = cfg.EncodingAESKey context.Cache = cfg.Cache context.SetAccessTokenLock(new(sync.RWMutex)) + context.SetJsApiTicketLock(new(sync.RWMutex)) } //GetServer init @@ -45,3 +49,22 @@ func (wc *Wechat) GetServer(req *http.Request, writer http.ResponseWriter) *serv wc.Context.Writer = writer return server.NewServer(wc.Context) } + +//GetMaterial init +func (wc *Wechat) GetMaterial() *material.Material { + return material.NewMaterial(wc.Context) +} + +//GetOauth init +func (wc *Wechat) GetOauth(req *http.Request, writer http.ResponseWriter) *oauth.Oauth { + wc.Context.Request = req + wc.Context.Writer = writer + return oauth.NewOauth(wc.Context) +} + +//GetJs init +func (wc *Wechat) GetJs(req *http.Request, writer http.ResponseWriter) *js.Js { + wc.Context.Request = req + wc.Context.Writer = writer + return js.NewJs(wc.Context) +}