1
0
mirror of https://github.com/silenceper/wechat.git synced 2026-02-04 12:52:27 +08:00

oauth2,jssdk

This commit is contained in:
wenzl
2016-09-15 01:27:28 +08:00
parent 3854bb0487
commit 4333691b37
21 changed files with 913 additions and 43 deletions

2
cache/cache.go vendored
View File

@@ -5,7 +5,7 @@ import "time"
//Cache interface
type Cache interface {
Get(key string) interface{}
Set(key string, val interface{}, timeput time.Duration) error
Set(key string, val interface{}, timeout time.Duration) error
IsExist(key string) bool
Delete(key string) error
}

View File

@@ -16,11 +16,10 @@ const (
//ResAccessToken struct
type ResAccessToken struct {
ErrorCode int32 `json:"errcode"`
ErrorMsg string `json:"errmsg"`
util.CommonError
AccessToken string `json:"access_token"`
ExpiresIn int32 `json:"expires_in"`
ExpiresIn int64 `json:"expires_in"`
}
//SetAccessTokenLock 设置读写锁一个appID一个读写锁
@@ -60,12 +59,13 @@ func (ctx *Context) GetAccessTokenFromServer() (resAccessToken ResAccessToken, e
if err != nil {
return
}
if resAccessToken.ErrorMsg != "" {
err = fmt.Errorf("get access_token error : errcode=%v , errormsg=%v", resAccessToken.ErrorCode, resAccessToken.ErrorMsg)
if resAccessToken.ErrMsg != "" {
err = fmt.Errorf("get access_token error : errcode=%v , errormsg=%v", resAccessToken.ErrCode, resAccessToken.ErrMsg)
return
}
accessTokenCacheKey := fmt.Sprintf("access_token_%s", ctx.AppID)
err = ctx.Cache.Set(accessTokenCacheKey, resAccessToken.AccessToken, time.Duration(resAccessToken.ExpiresIn)-1500/time.Second)
expires := resAccessToken.ExpiresIn - 1500
err = ctx.Cache.Set(accessTokenCacheKey, resAccessToken.AccessToken, time.Duration(expires)*time.Second)
return
}

View File

@@ -21,6 +21,9 @@ type Context struct {
//accessTokenLock 读写锁 同一个AppID一个
accessTokenLock *sync.RWMutex
//jsapiTicket 读写锁 同一个AppID一个
jsApiTicketLock *sync.RWMutex
}
// Query returns the keyed url query value if it exists
@@ -37,3 +40,12 @@ func (ctx *Context) GetQuery(key string) (string, bool) {
}
return "", false
}
//SetJsApiTicket 设置jsApiTicket的lock
func (ctx *Context) SetJsApiTicketLock(lock *sync.RWMutex) {
ctx.jsApiTicketLock = lock
}
func (ctx *Context) GetJsApiTicketLock() *sync.RWMutex {
return ctx.jsApiTicketLock
}

View File

@@ -10,6 +10,8 @@ var plainContentType = []string{"text/plain; charset=utf-8"}
//Render render from bytes
func (ctx *Context) Render(bytes []byte) {
//debug
//fmt.Println("response msg = ", string(bytes))
ctx.Writer.WriteHeader(200)
_, err := ctx.Writer.Write(bytes)
if err != nil {

109
js/js.go Normal file
View File

@@ -0,0 +1,109 @@
package js
import (
"encoding/json"
"fmt"
"time"
"github.com/silenceper/wechat/context"
"github.com/silenceper/wechat/util"
)
const getTicketURL = "https://api.weixin.qq.com/cgi-bin/ticket/getticket?access_token=%s&type=jsapi"
//Js struct
type Js struct {
*context.Context
}
//Config 返回给用户jssdk配置信息
type Config struct {
AppID string
TimeStamp int64
NonceStr string
Signature string
}
//resTicket 请求jsapi_tikcet返回结果
type resTicket struct {
util.CommonError
Ticket string `json:"ticket"`
ExpiresIn int64 `json:"expires_in"`
}
//NewJs init
func NewJs(context *context.Context) *Js {
js := new(Js)
js.Context = context
return js
}
//GetConfig 获取jssdk需要的配置参数
//uri 为当前网页地址
func (js *Js) GetConfig(uri string) (config *Config, err error) {
config = new(Config)
var ticketStr string
ticketStr, err = js.getTicket()
if err != nil {
return
}
nonceStr := util.RandomStr(16)
timestamp := util.GetCurrTs()
str := fmt.Sprintf("jsapi_ticket=%s&noncestr=%s&timestamp=%d&url=%s", ticketStr, nonceStr, timestamp, uri)
sigStr := util.Signature(str)
config.AppID = js.AppID
config.NonceStr = nonceStr
config.TimeStamp = timestamp
config.Signature = sigStr
return
}
//getTicket 获取jsapi_tocket全局缓存
func (js *Js) getTicket() (ticketStr string, err error) {
js.GetJsApiTicketLock().Lock()
defer js.GetJsApiTicketLock().Unlock()
//先从cache中取
jsAPITicketCacheKey := fmt.Sprintf("jsapi_ticket_%s", js.AppID)
val := js.Cache.Get(jsAPITicketCacheKey)
if val != nil {
ticketStr = val.(string)
return
}
var ticket resTicket
ticket, err = js.getTicketFromServer()
if err != nil {
return
}
ticketStr = ticket.Ticket
return
}
//getTicketFromServer 强制从服务器中获取ticket
func (js *Js) getTicketFromServer() (ticket resTicket, err error) {
var accessToken string
accessToken, err = js.GetAccessToken()
if err != nil {
return
}
var response []byte
url := fmt.Sprintf(getTicketURL, accessToken)
response, err = util.HTTPGet(url)
err = json.Unmarshal(response, &ticket)
if err != nil {
return
}
if ticket.ErrCode != 0 {
err = fmt.Errorf("getTicket Error : errcode=%s , errmsg=%s", ticket.ErrCode, ticket.ErrMsg)
return
}
jsAPITicketCacheKey := fmt.Sprintf("jsapi_ticket_%s", js.AppID)
expires := ticket.ExpiresIn - 1500
err = js.Cache.Set(jsAPITicketCacheKey, ticket.Ticket, time.Duration(expires)*time.Second)
return
}

167
material/material.go Normal file
View File

@@ -0,0 +1,167 @@
package material
import (
"encoding/json"
"errors"
"fmt"
"github.com/silenceper/wechat/context"
"github.com/silenceper/wechat/util"
)
const (
addNewsURL = "https://api.weixin.qq.com/cgi-bin/material/add_news"
addMaterialURL = "https://api.weixin.qq.com/cgi-bin/material/add_material"
)
//Material 素材管理
type Material struct {
*context.Context
}
//NewMaterial init
func NewMaterial(context *context.Context) *Material {
material := new(Material)
material.Context = context
return material
}
//Article 永久图文素材
type Article struct {
ThumbMediaID string `json:"thumb_media_id"`
Author string `json:"author"`
Digest string `json:"digest"`
ShowCoverPic int `json:"show_cover_pic"`
Content string `json:"content"`
ContentSourceURL string `json:"content_source_url"`
}
//reqArticles 永久性图文素材请求信息
type reqArticles struct {
articles []*Article `json:"articles"`
}
//resArticles 永久性图文素材返回结果
type resArticles struct {
util.CommonError
MediaID string `json:"media_id"`
}
//AddNews 新增永久图文素材
func (material *Material) AddNews(articles []*Article) (mediaID string, err error) {
req := &reqArticles{articles}
var accessToken string
accessToken, err = material.GetAccessToken()
if err != nil {
return
}
uri := fmt.Sprintf("%s?access_token=%s", addNewsURL, accessToken)
responseBytes, err := util.PostJSON(uri, req)
var res resArticles
err = json.Unmarshal(responseBytes, res)
if err != nil {
return
}
mediaID = res.MediaID
return
}
//resAddMaterial 永久性素材上传返回的结果
type resAddMaterial struct {
util.CommonError
MediaID string `json:"media_id"`
URL string `json:"url"`
}
//AddMaterial 上传永久性素材(处理视频需要单独上传)
func (material *Material) AddMaterial(mediaType MediaType, filename string) (mediaID string, url string, err error) {
if mediaType == MediaTypeVideo {
err = errors.New("永久视频素材上传使用 AddVideo 方法")
}
var accessToken string
accessToken, err = material.GetAccessToken()
if err != nil {
return
}
uri := fmt.Sprintf("%s?access_token=%s&type=%s", addMaterialURL, accessToken, mediaType)
var response []byte
response, err = util.PostFile("media", filename, uri)
if err != nil {
return
}
var resMaterial resAddMaterial
err = json.Unmarshal(response, &resMaterial)
if err != nil {
return
}
if resMaterial.ErrCode != 0 {
err = fmt.Errorf("AddMaterial error : errcode=%v , errmsg=%v", resMaterial.ErrCode, resMaterial.ErrMsg)
return
}
mediaID = resMaterial.MediaID
url = resMaterial.URL
return
}
type reqVideo struct {
Title string `json:"title"`
Introduction string `json:"introduction"`
}
//AddVideo 永久视频素材文件上传
func (material *Material) AddVideo(filename, title, introduction string) (mediaID string, url string, err error) {
var accessToken string
accessToken, err = material.GetAccessToken()
if err != nil {
return
}
uri := fmt.Sprintf("%s?access_token=%s&type=video", addMaterialURL, accessToken)
videoDesc := &reqVideo{
Title: title,
Introduction: introduction,
}
var fieldValue []byte
fieldValue, err = json.Marshal(videoDesc)
if err != nil {
return
}
fields := []util.MultipartFormField{
{
IsFile: true,
Fieldname: "video",
Filename: filename,
},
{
IsFile: true,
Fieldname: "description",
Value: fieldValue,
},
}
var response []byte
response, err = util.PostMultipartForm(fields, uri)
if err != nil {
return
}
var resMaterial resAddMaterial
err = json.Unmarshal(response, &resMaterial)
if err != nil {
return
}
if resMaterial.ErrCode != 0 {
err = fmt.Errorf("AddMaterial error : errcode=%v , errmsg=%v", resMaterial.ErrCode, resMaterial.ErrMsg)
return
}
mediaID = resMaterial.MediaID
url = resMaterial.URL
return
}

109
material/media.go Normal file
View File

@@ -0,0 +1,109 @@
package material
import (
"encoding/json"
"fmt"
"github.com/silenceper/wechat/util"
)
//MediaType 媒体文件类型
type MediaType string
const (
//MediaTypeImage 媒体文件:图片
MediaTypeImage MediaType = "image"
//MediaTypeVoice 媒体文件:声音
MediaTypeVoice = "voice"
//MediaTypeVideo 媒体文件:视频
MediaTypeVideo = "video"
//MediaTypeThumb 媒体文件:缩略图
MediaTypeThumb = "thumb"
)
const (
mediaUploadURL = "https://api.weixin.qq.com/cgi-bin/media/upload"
mediaUploadImageURL = "https://api.weixin.qq.com/cgi-bin/media/uploadimg"
mediaGetURL = "https://api.weixin.qq.com/cgi-bin/media/get"
)
//Media 临时素材上传返回信息
type Media struct {
util.CommonError
Type MediaType `json:"type"`
MediaID string `json:"media_id"`
CreatedAt int64 `json:"created_at"`
}
//MediaUpload 临时素材上传
func (material *Material) MediaUpload(mediaType MediaType, filename string) (media Media, err error) {
var accessToken string
accessToken, err = material.GetAccessToken()
if err != nil {
return
}
uri := fmt.Sprintf("%s?access_token=%s&type=%s", mediaUploadURL, accessToken, mediaType)
var response []byte
response, err = util.PostFile("media", filename, uri)
if err != nil {
return
}
err = json.Unmarshal(response, &media)
if err != nil {
return
}
if media.ErrCode != 0 {
err = fmt.Errorf("MediaUpload error : errcode=%v , errmsg=%v", media.ErrCode, media.ErrMsg)
return
}
return
}
//GetMediaURL 返回临时素材的下载地址供用户自己处理
//NOTICE: URL 不可公开因为含access_token 需要立即另存文件
func (material *Material) GetMediaURL(mediaID string) (mediaURL string, err error) {
var accessToken string
accessToken, err = material.GetAccessToken()
if err != nil {
return
}
mediaURL = fmt.Sprintf("%s?access_token=%s&media_id=%s", mediaGetURL, accessToken, mediaID)
return
}
//resMediaImage 图片上传返回结果
type resMediaImage struct {
util.CommonError
URL string `json:"url"`
}
//ImageUpload 图片上传
func (material *Material) ImageUpload(filename string) (url string, err error) {
var accessToken string
accessToken, err = material.GetAccessToken()
if err != nil {
return
}
uri := fmt.Sprintf("%s?access_token=%s", mediaUploadImageURL, accessToken)
var response []byte
response, err = util.PostFile("media", filename, uri)
if err != nil {
return
}
var image resMediaImage
err = json.Unmarshal(response, &image)
if err != nil {
return
}
if image.ErrCode != 0 {
err = fmt.Errorf("UploadImage error : errcode=%v , errmsg=%v", image.ErrCode, image.ErrMsg)
return
}
url = image.URL
return
}

17
message/image.go Normal file
View File

@@ -0,0 +1,17 @@
package message
//Image 图片消息
type Image struct {
CommonToken
Image struct {
MediaID string `xml:"MediaId"`
} `xml:"Image"`
}
//NewImage 回复图片消息
func NewImage(mediaID string) *Image {
image := new(Image)
image.Image.MediaID = mediaID
return image
}

View File

@@ -46,6 +46,34 @@ const (
EventView EventType = "VIEW"
)
//MixMessage 存放所有微信发送过来的消息和事件
type MixMessage struct {
CommonToken
//基本消息
MsgID int64 `xml:"MsgId"`
Content string `xml:"Content"`
PicURL string `xml:"PicUrl"`
MediaID string `xml:"MediaId"`
Format string `xml:"Format"`
ThumbMediaID string `xml:"ThumbMediaId"`
LocationX float64 `xml:"Location_X"`
LocationY float64 `xml:"Location_Y"`
Scale float64 `xml:"Scale"`
Label string `xml:"Label"`
Title string `xml:"Title"`
Description string `xml:"Description"`
URL string `xml:"Url"`
//事件相关
Event string `xml:"Event"`
EventKey string `xml:"EventKey"`
Ticket string `xml:"Ticket"`
Latitude string `xml:"Latitude"`
Longitude string `xml:"Longitude"`
Precision string `xml:"Precision"`
}
//EncryptedXMLMsg 安全模式下的消息体
type EncryptedXMLMsg struct {
XMLName struct{} `xml:"xml" json:"-"`
@@ -90,31 +118,3 @@ func (msg *CommonToken) SetCreateTime(createTime int64) {
func (msg *CommonToken) SetMsgType(msgType MsgType) {
msg.MsgType = msgType
}
//MixMessage 存放所有微信发送过来的消息和事件
type MixMessage struct {
CommonToken
//基本消息
MsgID int64 `xml:"MsgId"`
Content string `xml:"Content"`
PicURL string `xml:"PicUrl"`
MediaID string `xml:"MediaId"`
Format string `xml:"Format"`
ThumbMediaID string `xml:"ThumbMediaId"`
LocationX float64 `xml:"Location_X"`
LocationY float64 `xml:"Location_Y"`
Scale float64 `xml:"Scale"`
Label string `xml:"Label"`
Title string `xml:"Title"`
Description string `xml:"Description"`
URL string `xml:"Url"`
//事件相关
Event string `xml:"Event"`
EventKey string `xml:"EventKey"`
Ticket string `xml:"Ticket"`
Latitude string `xml:"Latitude"`
Longitude string `xml:"Longitude"`
Precision string `xml:"Precision"`
}

24
message/music.go Normal file
View File

@@ -0,0 +1,24 @@
package message
//Music 音乐消息
type Music struct {
CommonToken
Music struct {
Title string `xml:"Title" `
Description string `xml:"Description" `
MusicURL string `xml:"MusicUrl" `
HQMusicURL string `xml:"HQMusicUrl" `
ThumbMediaID string `xml:"ThumbMediaId"`
} `xml:"Music"`
}
//NewMusic 回复音乐消息
func NewMusic(title, description, musicURL, hQMusicURL, thumbMediaID string) *Music {
music := new(Music)
music.Music.Title = title
music.Music.Description = description
music.Music.MusicURL = musicURL
music.Music.ThumbMediaID = thumbMediaID
return music
}

35
message/news.go Normal file
View File

@@ -0,0 +1,35 @@
package message
//News 图文消息
type News struct {
CommonToken
ArticleCount int `xml:"ArticleCount"`
Articles []*Article `xml:"Articles>item,omitempty"`
}
//NewNews 初始化图文消息
func NewNews(articles []*Article) *News {
news := new(News)
news.ArticleCount = len(articles)
news.Articles = articles
return news
}
//Article 单篇文章
type Article struct {
Title string `xml:"Title,omitempty"`
Description string `xml:"Description,omitempty"`
PicURL string `xml:"PicUrl,omitempty"`
URL string `xml:"Url,omitempty"`
}
//NewArticle 初始化文章
func NewArticle(title, description, picURL, url string) *Article {
article := new(Article)
article.Title = title
article.Description = description
article.PicURL = picURL
article.URL = url
return article
}

View File

@@ -5,3 +5,10 @@ type Text struct {
CommonToken
Content string `xml:"Content"`
}
//NewText 初始化文本消息
func NewText(content string) *Text {
text := new(Text)
text.Content = content
return text
}

21
message/video.go Normal file
View File

@@ -0,0 +1,21 @@
package message
//Video 视频消息
type Video struct {
CommonToken
Video struct {
MediaID string `xml:"MediaId"`
Title string `xml:"Title,omitempty"`
Description string `xml:"Description,omitempty"`
} `xml:"Video"`
}
//NewVideo 回复图片消息
func NewVideo(mediaID, title, description string) *Video {
video := new(Video)
video.Video.MediaID = mediaID
video.Video.Title = title
video.Video.Description = description
return video
}

17
message/voice.go Normal file
View File

@@ -0,0 +1,17 @@
package message
//Voice 语音消息
type Voice struct {
CommonToken
Voice struct {
MediaID string `xml:"MediaId"`
} `xml:"Voice"`
}
//NewVoice 回复语音消息
func NewVoice(mediaID string) *Voice {
voice := new(Voice)
voice.Voice.MediaID = mediaID
return voice
}

152
oauth/oauth.go Normal file
View File

@@ -0,0 +1,152 @@
package oauth
import (
"encoding/json"
"fmt"
"net/http"
"net/url"
"github.com/silenceper/wechat/context"
"github.com/silenceper/wechat/util"
)
const (
redirectOauthURL = "https://open.weixin.qq.com/connect/oauth2/authorize?appid=%s&redirect_uri=%s&response_type=code&scope=%s&state=%s#wechat_redirect"
accessTokenURL = "https://api.weixin.qq.com/sns/oauth2/access_token?appid=%s&secret=%s&code=%s&grant_type=authorization_code"
refreshAccessTokenURL = "https://api.weixin.qq.com/sns/oauth2/refresh_token?appid=%s&grant_type=refresh_token&refresh_token=%s"
userInfoURL = "https://api.weixin.qq.com/sns/userinfo?access_token=%s&openid=%s&lang=zh_CN"
checkAccessTokenURL = "https://api.weixin.qq.com/sns/auth?access_token=%s&openid=%s"
)
//Oauth 保存用户授权信息
type Oauth struct {
*context.Context
}
//NewOauth 实例化授权信息
func NewOauth(context *context.Context) *Oauth {
auth := new(Oauth)
auth.Context = context
return auth
}
//GetRedirectURL 获取跳转的url地址
func (oauth *Oauth) GetRedirectURL(redirectURI, scope, state string) (string, error) {
//url encode
urlStr := url.QueryEscape(redirectURI)
return fmt.Sprintf(redirectOauthURL, oauth.AppID, urlStr, scope, state), nil
}
//Redirect 跳转到网页授权
func (oauth *Oauth) Redirect(redirectURI, scope, state string) error {
location, err := oauth.GetRedirectURL(redirectURI, scope, state)
if err != nil {
return err
}
http.Redirect(oauth.Writer, oauth.Request, location, 302)
return nil
}
// ResAccessToken 获取用户授权access_token的返回结果
type ResAccessToken struct {
util.CommonError
AccessToken string `json:"access_token"`
ExpiresIn int64 `json:"expires_in"`
RefreshToken string `json:"refresh_token"`
OpenID string `json:"openid"`
Scope string `json:"scope"`
}
// GetUserAccessToken 通过网页授权的code 换取access_token(区别于context中的access_token)
func (oauth *Oauth) GetUserAccessToken(code string) (result ResAccessToken, err error) {
urlStr := fmt.Sprintf(accessTokenURL, oauth.AppID, oauth.AppSecret, code)
var response []byte
response, err = util.HTTPGet(urlStr)
if err != nil {
return
}
err = json.Unmarshal(response, &result)
if err != nil {
return
}
if result.ErrCode != 0 {
err = fmt.Errorf("GetUserAccessToken error : errcode=%v , errmsg=%v", result.ErrCode, result.ErrMsg)
return
}
return
}
//RefreshAccessToken 刷新access_token
func (oauth *Oauth) RefreshAccessToken(refreshToken string) (result ResAccessToken, err error) {
urlStr := fmt.Sprintf(refreshAccessTokenURL, oauth.AppID, refreshToken)
var response []byte
response, err = util.HTTPGet(urlStr)
if err != nil {
return
}
err = json.Unmarshal(response, &result)
if err != nil {
return
}
if result.ErrCode != 0 {
err = fmt.Errorf("GetUserAccessToken error : errcode=%v , errmsg=%v", result.ErrCode, result.ErrMsg)
return
}
return
}
//CheckAccessToken 检验access_token是否有效
func (oauth *Oauth) CheckAccessToken(accessToken, openID string) (b bool, err error) {
urlStr := fmt.Sprintf(checkAccessTokenURL, accessToken, openID)
var response []byte
response, err = util.HTTPGet(urlStr)
if err != nil {
return
}
var result util.CommonError
err = json.Unmarshal(response, &result)
if err != nil {
return
}
if result.ErrCode != 0 {
b = false
return
}
b = true
return
}
//UserInfo 用户授权获取到用户信息
type UserInfo struct {
util.CommonError
OpenID string `json:"openid"`
Nickname string `json:"nickname"`
Sex int32 `json:"sex"`
Province string `json:"province"`
City string `json:"city"`
Country string `json:"country"`
HeadImgURL string `json:"headimgurl"`
Privilege []string `json:"privilege"`
Unionid string `json:"unionid"`
}
//GetUserInfo 如果scope为 snsapi_userinfo 则可以通过此方法获取到用户基本信息
func (oauth *Oauth) GetUserInfo(accessToken, openID string) (result UserInfo, err error) {
urlStr := fmt.Sprintf(userInfoURL, accessToken, openID)
var response []byte
response, err = util.HTTPGet(urlStr)
if err != nil {
return
}
err = json.Unmarshal(response, &result)
if err != nil {
return
}
if result.ErrCode != 0 {
err = fmt.Errorf("GetUserInfo error : errcode=%v , errmsg=%v", result.ErrCode, result.ErrMsg)
return
}
return
}

View File

@@ -6,6 +6,7 @@ import (
"fmt"
"io/ioutil"
"reflect"
"runtime/debug"
"strconv"
"strings"
@@ -57,8 +58,10 @@ func (srv *Server) Serve() error {
return err
}
srv.buildResponse(response)
return nil
//debug
//fmt.Println("request msg = ", string(srv.requestRawXMLMsg))
return srv.buildResponse(response)
}
//Validate 校验请求是否合法
@@ -155,7 +158,7 @@ func (srv *Server) SetMessageHandler(handler func(message.MixMessage) *message.R
func (srv *Server) buildResponse(reply *message.Reply) (err error) {
defer func() {
if e := recover(); e != nil {
err = fmt.Errorf("panic error: %v", err)
err = fmt.Errorf("panic error: %v\n%s", e, debug.Stack())
}
}()
if reply == nil {
@@ -223,6 +226,8 @@ func (srv *Server) Send() (err error) {
Nonce: srv.nonce,
}
}
srv.XML(replyMsg)
if replyMsg != nil {
srv.XML(replyMsg)
}
return
}

6
util/error.go Normal file
View File

@@ -0,0 +1,6 @@
package util
type CommonError struct {
ErrCode int64 `json:"errcode"`
ErrMsg string `json:"errmsg"`
}

View File

@@ -1,19 +1,154 @@
package util
import (
"bytes"
"encoding/json"
"fmt"
"io"
"io/ioutil"
"mime/multipart"
"net/http"
"os"
)
//HTTPGet get 请求
func HTTPGet(url string) ([]byte, error) {
response, err := http.Get(url)
func HTTPGet(uri string) ([]byte, error) {
response, err := http.Get(uri)
if err != nil {
return nil, err
}
defer response.Body.Close()
if response.StatusCode != http.StatusOK {
return nil, fmt.Errorf("http get error : uri=%v , statusCode=%v", uri, response.StatusCode)
}
return ioutil.ReadAll(response.Body)
}
//HTTPPost post 请求
func HTTPPost() {
//PostJSON post json 数据请求
func PostJSON(uri string, obj interface{}) ([]byte, error) {
jsonData, err := json.Marshal(obj)
if err != nil {
return nil, err
}
body := bytes.NewBuffer(jsonData)
response, err := http.Post(uri, "application/json;charset=utf-8", body)
if err != nil {
return nil, err
}
defer response.Body.Close()
if response.StatusCode != http.StatusOK {
return nil, fmt.Errorf("http get error : uri=%v , statusCode=%v", uri, response.StatusCode)
}
return ioutil.ReadAll(response.Body)
}
//PostFile 上传文件
func PostFile(fieldname, filename, uri string) ([]byte, error) {
/*
bodyBuf := &bytes.Buffer{}
bodyWriter := multipart.NewWriter(bodyBuf)
fileWriter, err := bodyWriter.CreateFormFile(fieldname, filename)
if err != nil {
return nil, fmt.Errorf("error writing to buffer")
}
fh, err := os.Open(filename)
if err != nil {
return nil, fmt.Errorf("error opening file")
}
defer fh.Close()
_, err = io.Copy(fileWriter, fh)
if err != nil {
return nil, err
}
contentType := bodyWriter.FormDataContentType()
bodyWriter.Close()
resp, err := http.Post(uri, contentType, bodyBuf)
if err != nil {
return nil, err
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
return nil, err
}
respBody, err := ioutil.ReadAll(resp.Body)
if err != nil {
return nil, err
}
return respBody, nil
*/
fields := []MultipartFormField{
{
IsFile: true,
Fieldname: fieldname,
Filename: filename,
},
}
return PostMultipartForm(fields, uri)
}
//MultipartFormField 保存文件或其他字段信息
type MultipartFormField struct {
IsFile bool
Fieldname string
Value []byte
Filename string
}
//PostMultipartForm 上传文件或其他多个字段
func PostMultipartForm(fields []MultipartFormField, uri string) (respBody []byte, err error) {
bodyBuf := &bytes.Buffer{}
bodyWriter := multipart.NewWriter(bodyBuf)
for _, field := range fields {
if field.IsFile {
fileWriter, e := bodyWriter.CreateFormFile(field.Fieldname, field.Filename)
if e != nil {
err = fmt.Errorf("error writing to buffer , err=%v", e)
return
}
fh, e := os.Open(field.Filename)
if e != nil {
err = fmt.Errorf("error opening file , err=%v", e)
return
}
defer fh.Close()
if _, err = io.Copy(fileWriter, fh); err != nil {
return
}
} else {
partWriter, e := bodyWriter.CreateFormField(field.Fieldname)
if e != nil {
err = e
return
}
valueReader := bytes.NewReader(field.Value)
if _, err = io.Copy(partWriter, valueReader); err != nil {
return
}
}
}
contentType := bodyWriter.FormDataContentType()
bodyWriter.Close()
resp, e := http.Post(uri, contentType, bodyBuf)
if e != nil {
err = e
return
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
return nil, err
}
respBody, err = ioutil.ReadAll(resp.Body)
return
}

11
util/signature_test.go Normal file
View File

@@ -0,0 +1,11 @@
package util
import "testing"
func TestSignature(t *testing.T) {
//abc sig
abc := "a9993e364706816aba3e25717850c26c9cd0d89d"
if abc != Signature("a", "b", "c") {
t.Error("test Signature Error")
}
}

18
util/string.go Normal file
View File

@@ -0,0 +1,18 @@
package util
import (
"math/rand"
"time"
)
//RandomStr 随机生成字符串
func RandomStr(length int) string {
str := "0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ"
bytes := []byte(str)
result := []byte{}
r := rand.New(rand.NewSource(time.Now().UnixNano()))
for i := 0; i < length; i++ {
result = append(result, bytes[r.Intn(len(bytes))])
}
return string(result)
}

View File

@@ -6,6 +6,9 @@ import (
"github.com/silenceper/wechat/cache"
"github.com/silenceper/wechat/context"
"github.com/silenceper/wechat/js"
"github.com/silenceper/wechat/material"
"github.com/silenceper/wechat/oauth"
"github.com/silenceper/wechat/server"
)
@@ -37,6 +40,7 @@ func copyConfigToContext(cfg *Config, context *context.Context) {
context.EncodingAESKey = cfg.EncodingAESKey
context.Cache = cfg.Cache
context.SetAccessTokenLock(new(sync.RWMutex))
context.SetJsApiTicketLock(new(sync.RWMutex))
}
//GetServer init
@@ -45,3 +49,22 @@ func (wc *Wechat) GetServer(req *http.Request, writer http.ResponseWriter) *serv
wc.Context.Writer = writer
return server.NewServer(wc.Context)
}
//GetMaterial init
func (wc *Wechat) GetMaterial() *material.Material {
return material.NewMaterial(wc.Context)
}
//GetOauth init
func (wc *Wechat) GetOauth(req *http.Request, writer http.ResponseWriter) *oauth.Oauth {
wc.Context.Request = req
wc.Context.Writer = writer
return oauth.NewOauth(wc.Context)
}
//GetJs init
func (wc *Wechat) GetJs(req *http.Request, writer http.ResponseWriter) *js.Js {
wc.Context.Request = req
wc.Context.Writer = writer
return js.NewJs(wc.Context)
}