diff --git a/credential/access_token.go b/credential/access_token.go new file mode 100644 index 0000000..362e705 --- /dev/null +++ b/credential/access_token.go @@ -0,0 +1,6 @@ +package credential + +//AccessTokenHandle AccessToken 接口 +type AccessTokenHandle interface { + GetAccessToken() (accessToken string, err error) +} diff --git a/credential/default_access_token.go b/credential/default_access_token.go new file mode 100644 index 0000000..f346933 --- /dev/null +++ b/credential/default_access_token.go @@ -0,0 +1,97 @@ +package credential + +import ( + "encoding/json" + "fmt" + "sync" + "time" + + "github.com/silenceper/wechat/v2/cache" + "github.com/silenceper/wechat/v2/util" +) + +const ( + //AccessTokenURL 获取access_token的接口 + accessTokenURL = "https://api.weixin.qq.com/cgi-bin/token" + //CacheKeyOfficialAccountPrefix 微信公众号cache key前缀 + CacheKeyOfficialAccountPrefix = "gowechat_officialaccount_" +) + +//DefaultAccessToken 默认AccessToken 获取 +type DefaultAccessToken struct { + appID string + appSecret string + cacheKeyPrefix string + cache cache.Cache + accessTokenLock *sync.Mutex +} + +//NewDefaultAccessToken new DefaultAccessToken +func NewDefaultAccessToken(appID, appSecret, cacheKeyPrefix string, cache cache.Cache) AccessTokenHandle { + if cache == nil { + panic("cache is ineed") + } + return &DefaultAccessToken{ + appID: appID, + appSecret: appSecret, + cache: cache, + cacheKeyPrefix: cacheKeyPrefix, + accessTokenLock: new(sync.Mutex), + } +} + +//ResAccessToken struct +type ResAccessToken struct { + util.CommonError + + AccessToken string `json:"access_token"` + ExpiresIn int64 `json:"expires_in"` +} + +//GetAccessToken 获取access_token,先从cache中获取,没有则从服务端获取 +func (ak *DefaultAccessToken) GetAccessToken() (accessToken string, err error) { + //加上lock,是为了防止在并发获取token时,cache刚好失效,导致从微信服务器上获取到不同token + ak.accessTokenLock.Lock() + defer ak.accessTokenLock.Unlock() + + accessTokenCacheKey := fmt.Sprintf("%s_access_token_%s", ak.cacheKeyPrefix, ak.appID) + val := ak.cache.Get(accessTokenCacheKey) + if val != nil { + accessToken = val.(string) + return + } + + //cache失效,从微信服务器获取 + var resAccessToken ResAccessToken + resAccessToken, err = GetTokenFromServer(ak.appID, ak.appSecret) + if err != nil { + return + } + + expires := resAccessToken.ExpiresIn - 1500 + err = ak.cache.Set(accessTokenCacheKey, resAccessToken.AccessToken, time.Duration(expires)*time.Second) + if err != nil { + return + } + accessToken = resAccessToken.AccessToken + return +} + +//GetTokenFromServer 强制从微信服务器获取token +func GetTokenFromServer(appID, appSecret string) (resAccessToken ResAccessToken, err error) { + url := fmt.Sprintf("%s?grant_type=client_credential&appid=%s&secret=%s", accessTokenURL, appID, appSecret) + var body []byte + body, err = util.HTTPGet(url) + if err != nil { + return + } + err = json.Unmarshal(body, &resAccessToken) + if err != nil { + return + } + if resAccessToken.ErrMsg != "" { + err = fmt.Errorf("get access_token error : errcode=%v , errormsg=%v", resAccessToken.ErrCode, resAccessToken.ErrMsg) + return + } + return +} diff --git a/credential/default_js_ticket.go b/credential/default_js_ticket.go new file mode 100644 index 0000000..5567424 --- /dev/null +++ b/credential/default_js_ticket.go @@ -0,0 +1,80 @@ +package credential + +import ( + "encoding/json" + "fmt" + "sync" + "time" + + "github.com/silenceper/wechat/v2/cache" + "github.com/silenceper/wechat/v2/util" +) + +//获取ticket的url +const getTicketURL = "https://api.weixin.qq.com/cgi-bin/ticket/getticket?access_token=%s&type=jsapi" + +//DefaultJsTicket 默认获取js ticket方法 +type DefaultJsTicket struct { + appID string + cacheKeyPrefix string + cache cache.Cache + //jsAPITicket 读写锁 同一个AppID一个 + jsAPITicketLock *sync.Mutex +} + +//NewDefaultJsTicket new +func NewDefaultJsTicket(appID string, cacheKeyPrefix string, cache cache.Cache) JsTicketHandle { + return &DefaultJsTicket{ + appID: appID, + cache: cache, + cacheKeyPrefix: cacheKeyPrefix, + jsAPITicketLock: new(sync.Mutex), + } +} + +// ResTicket 请求jsapi_tikcet返回结果 +type ResTicket struct { + util.CommonError + + Ticket string `json:"ticket"` + ExpiresIn int64 `json:"expires_in"` +} + +//GetTicket 获取jsapi_ticket +func (js *DefaultJsTicket) GetTicket(accessToken string) (ticketStr string, err error) { + js.jsAPITicketLock.Lock() + defer js.jsAPITicketLock.Unlock() + + //先从cache中取 + jsAPITicketCacheKey := fmt.Sprintf("%s_jsapi_ticket_%s", js.cacheKeyPrefix, js.appID) + val := js.cache.Get(jsAPITicketCacheKey) + if val != nil { + ticketStr = val.(string) + return + } + var ticket ResTicket + ticket, err = GetTicketFromServer(accessToken) + if err != nil { + return + } + expires := ticket.ExpiresIn - 1500 + err = js.cache.Set(jsAPITicketCacheKey, ticket.Ticket, time.Duration(expires)*time.Second) + ticketStr = ticket.Ticket + return +} + +//GetTicketFromServer 从服务器中获取ticket +func GetTicketFromServer(accessToken string) (ticket ResTicket, err error) { + var response []byte + url := fmt.Sprintf(getTicketURL, accessToken) + response, err = util.HTTPGet(url) + err = json.Unmarshal(response, &ticket) + if err != nil { + return + } + if ticket.ErrCode != 0 { + err = fmt.Errorf("getTicket Error : errcode=%d , errmsg=%s", ticket.ErrCode, ticket.ErrMsg) + return + } + return +} diff --git a/credential/js_ticket.go b/credential/js_ticket.go new file mode 100644 index 0000000..e6f4ebc --- /dev/null +++ b/credential/js_ticket.go @@ -0,0 +1,7 @@ +package credential + +//JsTicketHandle js ticket获取 +type JsTicketHandle interface { + //GetTicket 获取ticket + GetTicket(accessToken string) (ticket string, err error) +} diff --git a/officialaccount/context/access_token.go b/officialaccount/context/access_token.go deleted file mode 100644 index 56069e7..0000000 --- a/officialaccount/context/access_token.go +++ /dev/null @@ -1,87 +0,0 @@ -package context - -import ( - "encoding/json" - "fmt" - "sync" - "time" - - "github.com/silenceper/wechat/v2/util" -) - -const ( - //AccessTokenURL 获取access_token的接口 - AccessTokenURL = "https://api.weixin.qq.com/cgi-bin/token" - //CacheKeyPrefix 微信公众号cache key前缀 - CacheKeyPrefix = "gowechat_officialaccount_" -) - -//ResAccessToken struct -type ResAccessToken struct { - util.CommonError - - AccessToken string `json:"access_token"` - ExpiresIn int64 `json:"expires_in"` -} - -//GetAccessTokenFunc 获取 access token 的函数签名 -type GetAccessTokenFunc func(ctx *Context) (accessToken string, err error) - -//SetAccessTokenLock 设置读写锁(一个appID一个读写锁) -func (ctx *Context) SetAccessTokenLock(l *sync.RWMutex) { - ctx.accessTokenLock = l -} - -//SetGetAccessTokenFunc 设置自定义获取accessToken的方式, 需要自己实现缓存 -func (ctx *Context) SetGetAccessTokenFunc(f GetAccessTokenFunc) { - ctx.accessTokenFunc = f -} - -//GetAccessToken 获取access_token -func (ctx *Context) GetAccessToken() (accessToken string, err error) { - ctx.accessTokenLock.Lock() - defer ctx.accessTokenLock.Unlock() - - if ctx.accessTokenFunc != nil { - return ctx.accessTokenFunc(ctx) - } - accessTokenCacheKey := fmt.Sprintf("%s_access_token_%s", CacheKeyPrefix, ctx.AppID) - val := ctx.Cache.Get(accessTokenCacheKey) - if val != nil { - accessToken = val.(string) - return - } - - //从微信服务器获取 - var resAccessToken ResAccessToken - resAccessToken, err = ctx.GetAccessTokenFromServer() - if err != nil { - return - } - - accessToken = resAccessToken.AccessToken - return -} - -//GetAccessTokenFromServer 强制从微信服务器获取token -func (ctx *Context) GetAccessTokenFromServer() (resAccessToken ResAccessToken, err error) { - url := fmt.Sprintf("%s?grant_type=client_credential&appid=%s&secret=%s", AccessTokenURL, ctx.AppID, ctx.AppSecret) - var body []byte - body, err = util.HTTPGet(url) - if err != nil { - return - } - err = json.Unmarshal(body, &resAccessToken) - if err != nil { - return - } - if resAccessToken.ErrMsg != "" { - err = fmt.Errorf("get access_token error : errcode=%v , errormsg=%v", resAccessToken.ErrCode, resAccessToken.ErrMsg) - return - } - - accessTokenCacheKey := fmt.Sprintf("%s_access_token_%s", CacheKeyPrefix, ctx.AppID) - expires := resAccessToken.ExpiresIn - 1500 - err = ctx.Cache.Set(accessTokenCacheKey, resAccessToken.AccessToken, time.Duration(expires)*time.Second) - return -} diff --git a/officialaccount/context/context.go b/officialaccount/context/context.go index a7d12a9..418b9ad 100644 --- a/officialaccount/context/context.go +++ b/officialaccount/context/context.go @@ -1,31 +1,12 @@ package context import ( - "sync" - + "github.com/silenceper/wechat/v2/credential" "github.com/silenceper/wechat/v2/officialaccount/config" ) // Context struct type Context struct { *config.Config - - //accessTokenLock 读写锁 同一个AppID一个 - accessTokenLock *sync.RWMutex - - //jsAPITicket 读写锁 同一个AppID一个 - jsAPITicketLock *sync.RWMutex - - //accessTokenFunc 自定义获取 access token 的方法 - accessTokenFunc GetAccessTokenFunc -} - -// SetJsAPITicketLock 设置jsAPITicket的lock -func (ctx *Context) SetJsAPITicketLock(lock *sync.RWMutex) { - ctx.jsAPITicketLock = lock -} - -// GetJsAPITicketLock 获取jsAPITicket 的lock -func (ctx *Context) GetJsAPITicketLock() *sync.RWMutex { - return ctx.jsAPITicketLock + credential.AccessTokenHandle } diff --git a/officialaccount/js/js.go b/officialaccount/js/js.go index 16b20f3..1c1b77b 100644 --- a/officialaccount/js/js.go +++ b/officialaccount/js/js.go @@ -1,19 +1,17 @@ package js import ( - "encoding/json" "fmt" - "time" + "github.com/silenceper/wechat/v2/credential" "github.com/silenceper/wechat/v2/officialaccount/context" "github.com/silenceper/wechat/v2/util" ) -const getTicketURL = "https://api.weixin.qq.com/cgi-bin/ticket/getticket?access_token=%s&type=jsapi" - // Js struct type Js struct { *context.Context + credential.JsTicketHandle } // Config 返回给用户jssdk配置信息 @@ -24,27 +22,31 @@ type Config struct { Signature string `json:"signature"` } -// resTicket 请求jsapi_tikcet返回结果 -type resTicket struct { - util.CommonError - - Ticket string `json:"ticket"` - ExpiresIn int64 `json:"expires_in"` -} - //NewJs init func NewJs(context *context.Context) *Js { js := new(Js) js.Context = context + jsTicketHandle := credential.NewDefaultJsTicket(context.AppID, credential.CacheKeyOfficialAccountPrefix, context.Cache) + js.SetJsTicketHandle(jsTicketHandle) return js } +//SetJsTicketHandle 自定义js ticket取值方式 +func (js *Js) SetJsTicketHandle(ticketHandle credential.JsTicketHandle) { + js.JsTicketHandle = ticketHandle +} + //GetConfig 获取jssdk需要的配置参数 //uri 为当前网页地址 func (js *Js) GetConfig(uri string) (config *Config, err error) { config = new(Config) + var accessToken string + accessToken, err = js.GetAccessToken() + if err != nil { + return + } var ticketStr string - ticketStr, err = js.GetTicket() + ticketStr, err = js.GetTicket(accessToken) if err != nil { return } @@ -60,50 +62,3 @@ func (js *Js) GetConfig(uri string) (config *Config, err error) { config.Signature = sigStr return } - -//GetTicket 获取jsapi_ticket -func (js *Js) GetTicket() (ticketStr string, err error) { - js.GetJsAPITicketLock().Lock() - defer js.GetJsAPITicketLock().Unlock() - - //先从cache中取 - jsAPITicketCacheKey := fmt.Sprintf("%s_jsapi_ticket_%s", context.CacheKeyPrefix, js.AppID) - val := js.Cache.Get(jsAPITicketCacheKey) - if val != nil { - ticketStr = val.(string) - return - } - var ticket resTicket - ticket, err = js.getTicketFromServer() - if err != nil { - return - } - ticketStr = ticket.Ticket - return -} - -//getTicketFromServer 强制从服务器中获取ticket -func (js *Js) getTicketFromServer() (ticket resTicket, err error) { - var accessToken string - accessToken, err = js.GetAccessToken() - if err != nil { - return - } - - var response []byte - url := fmt.Sprintf(getTicketURL, accessToken) - response, err = util.HTTPGet(url) - err = json.Unmarshal(response, &ticket) - if err != nil { - return - } - if ticket.ErrCode != 0 { - err = fmt.Errorf("getTicket Error : errcode=%d , errmsg=%s", ticket.ErrCode, ticket.ErrMsg) - return - } - - jsAPITicketCacheKey := fmt.Sprintf("%s_jsapi_ticket_%s", context.CacheKeyPrefix, js.AppID) - expires := ticket.ExpiresIn - 1500 - err = js.Cache.Set(jsAPITicketCacheKey, ticket.Ticket, time.Duration(expires)*time.Second) - return -} diff --git a/officialaccount/officialaccount.go b/officialaccount/officialaccount.go index d26d733..ac3fdf9 100644 --- a/officialaccount/officialaccount.go +++ b/officialaccount/officialaccount.go @@ -2,8 +2,8 @@ package officialaccount import ( "net/http" - "sync" + "github.com/silenceper/wechat/v2/credential" "github.com/silenceper/wechat/v2/officialaccount/basic" "github.com/silenceper/wechat/v2/officialaccount/config" "github.com/silenceper/wechat/v2/officialaccount/context" @@ -24,15 +24,17 @@ type OfficialAccount struct { //NewOfficialAccount 实例化公众号API func NewOfficialAccount(cfg *config.Config) *OfficialAccount { - //if cfg.Cache == nil { - // panic("cache未设置") - //} + defaultAK := credential.NewDefaultAccessToken(cfg.AppID, cfg.AppSecret, credential.CacheKeyOfficialAccountPrefix, cfg.Cache) ctx := &context.Context{ - Config: cfg, + Config: cfg, + AccessTokenHandle: defaultAK, } - ctx.SetAccessTokenLock(new(sync.RWMutex)) - ctx.SetJsAPITicketLock(new(sync.RWMutex)) - return &OfficialAccount{ctx} + return &OfficialAccount{ctx: ctx} +} + +//SetAccessTokenHandle 自定义access_token获取方式 +func (officialAccount *OfficialAccount) SetAccessTokenHandle(accessTokenHandle credential.AccessTokenHandle) { + officialAccount.ctx.AccessTokenHandle = accessTokenHandle } // GetContext get Context diff --git a/openplatform/officialaccount/officialaccount.go b/openplatform/officialaccount/officialaccount.go index c7d4c4f..a636ac8 100644 --- a/openplatform/officialaccount/officialaccount.go +++ b/openplatform/officialaccount/officialaccount.go @@ -1,9 +1,9 @@ package officialaccount import ( + "github.com/silenceper/wechat/v2/credential" "github.com/silenceper/wechat/v2/officialaccount" offConfig "github.com/silenceper/wechat/v2/officialaccount/config" - offContext "github.com/silenceper/wechat/v2/officialaccount/context" opContext "github.com/silenceper/wechat/v2/openplatform/context" ) @@ -25,9 +25,25 @@ func NewOfficialAccount(opCtx *opContext.Context, appID string) *OfficialAccount Cache: opCtx.Cache, }) //设置获取access_token的函数 - officialAccount.GetContext().SetGetAccessTokenFunc(func(offCtx *offContext.Context) (accessToken string, err error) { - // 获取授权方的access_token - return opCtx.GetAuthrAccessToken(appID) - }) + officialAccount.SetAccessTokenHandle(NewDefaultAuthrAccessToken(opCtx, appID)) return &OfficialAccount{appID: appID, OfficialAccount: officialAccount} } + +//DefaultAuthrAccessToken 默认获取授权ak的方法 +type DefaultAuthrAccessToken struct { + opCtx *opContext.Context + appID string +} + +//NewDefaultAuthrAccessToken New +func NewDefaultAuthrAccessToken(opCtx *opContext.Context, appID string) credential.AccessTokenHandle { + return &DefaultAuthrAccessToken{ + opCtx: opCtx, + appID: appID, + } +} + +//GetAccessToken 获取ak +func (ak *DefaultAuthrAccessToken) GetAccessToken() (string, error) { + return ak.opCtx.GetAuthrAccessToken(ak.appID) +}