diff --git a/credential/access_token.go b/credential/access_token.go index fcc4668..6094a02 100644 --- a/credential/access_token.go +++ b/credential/access_token.go @@ -1,6 +1,14 @@ package credential +import "context" + // AccessTokenHandle AccessToken 接口 type AccessTokenHandle interface { GetAccessToken() (accessToken string, err error) } + +// AccessTokenContextHandle AccessToken 接口 +type AccessTokenContextHandle interface { + AccessTokenHandle + GetAccessTokenContext(ctx context.Context) (accessToken string, err error) +} diff --git a/credential/default_access_token.go b/credential/default_access_token.go index 7c91544..d58efe6 100644 --- a/credential/default_access_token.go +++ b/credential/default_access_token.go @@ -1,6 +1,7 @@ package credential import ( + "context" "encoding/json" "fmt" "sync" @@ -33,7 +34,7 @@ type DefaultAccessToken struct { } // NewDefaultAccessToken new DefaultAccessToken -func NewDefaultAccessToken(appID, appSecret, cacheKeyPrefix string, cache cache.Cache) AccessTokenHandle { +func NewDefaultAccessToken(appID, appSecret, cacheKeyPrefix string, cache cache.Cache) AccessTokenContextHandle { if cache == nil { panic("cache is ineed") } @@ -56,6 +57,11 @@ type ResAccessToken struct { // GetAccessToken 获取access_token,先从cache中获取,没有则从服务端获取 func (ak *DefaultAccessToken) GetAccessToken() (accessToken string, err error) { + return ak.GetAccessTokenContext(context.Background()) +} + +// GetAccessTokenContext 获取access_token,先从cache中获取,没有则从服务端获取 +func (ak *DefaultAccessToken) GetAccessTokenContext(ctx context.Context) (accessToken string, err error) { // 先从cache中取 accessTokenCacheKey := fmt.Sprintf("%s_access_token_%s", ak.cacheKeyPrefix, ak.appID) if val := ak.cache.Get(accessTokenCacheKey); val != nil { @@ -73,7 +79,7 @@ func (ak *DefaultAccessToken) GetAccessToken() (accessToken string, err error) { // cache失效,从微信服务器获取 var resAccessToken ResAccessToken - resAccessToken, err = GetTokenFromServer(fmt.Sprintf(accessTokenURL, ak.appID, ak.appSecret)) + resAccessToken, err = GetTokenFromServerContext(ctx, fmt.Sprintf(accessTokenURL, ak.appID, ak.appSecret)) if err != nil { return } @@ -97,7 +103,7 @@ type WorkAccessToken struct { } // NewWorkAccessToken new WorkAccessToken -func NewWorkAccessToken(corpID, corpSecret, cacheKeyPrefix string, cache cache.Cache) AccessTokenHandle { +func NewWorkAccessToken(corpID, corpSecret, cacheKeyPrefix string, cache cache.Cache) AccessTokenContextHandle { if cache == nil { panic("cache the not exist") } @@ -112,6 +118,11 @@ func NewWorkAccessToken(corpID, corpSecret, cacheKeyPrefix string, cache cache.C // GetAccessToken 企业微信获取access_token,先从cache中获取,没有则从服务端获取 func (ak *WorkAccessToken) GetAccessToken() (accessToken string, err error) { + return ak.GetAccessTokenContext(context.Background()) +} + +// GetAccessTokenContext 企业微信获取access_token,先从cache中获取,没有则从服务端获取 +func (ak *WorkAccessToken) GetAccessTokenContext(ctx context.Context) (accessToken string, err error) { // 加上lock,是为了防止在并发获取token时,cache刚好失效,导致从微信服务器上获取到不同token ak.accessTokenLock.Lock() defer ak.accessTokenLock.Unlock() @@ -124,7 +135,7 @@ func (ak *WorkAccessToken) GetAccessToken() (accessToken string, err error) { // cache失效,从微信服务器获取 var resAccessToken ResAccessToken - resAccessToken, err = GetTokenFromServer(fmt.Sprintf(workAccessTokenURL, ak.CorpID, ak.CorpSecret)) + resAccessToken, err = GetTokenFromServerContext(ctx, fmt.Sprintf(workAccessTokenURL, ak.CorpID, ak.CorpSecret)) if err != nil { return } @@ -140,8 +151,13 @@ func (ak *WorkAccessToken) GetAccessToken() (accessToken string, err error) { // GetTokenFromServer 强制从微信服务器获取token func GetTokenFromServer(url string) (resAccessToken ResAccessToken, err error) { + return GetTokenFromServerContext(context.Background(), url) +} + +// GetTokenFromServerContext 强制从微信服务器获取token +func GetTokenFromServerContext(ctx context.Context, url string) (resAccessToken ResAccessToken, err error) { var body []byte - body, err = util.HTTPGet(url) + body, err = util.HTTPGetContext(ctx, url) if err != nil { return } diff --git a/go.sum b/go.sum index df307bb..f9efc27 100644 --- a/go.sum +++ b/go.sum @@ -26,12 +26,10 @@ github.com/golang/protobuf v1.4.0-rc.4.0.20200313231945-b860323f09d0/go.mod h1:W github.com/golang/protobuf v1.4.0/go.mod h1:jodUvKwWbYaEsadDk5Fwe5c77LiNKVO9IDvqG2KuDX0= github.com/golang/protobuf v1.4.2/go.mod h1:oDoupMAO8OvCJWAcko0GGGIgR6R6ocIYbsSw735rRwI= github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk= -github.com/golang/protobuf v1.5.2 h1:ROPKBNFfQgOUMifHyP+KYbvpjbdoFNs+aK7DXlji0Tw= github.com/golang/protobuf v1.5.2/go.mod h1:XVQd3VNwM+JqD3oG2Ue2ip4fOMUkwXdXDdiuN0vRsmY= github.com/google/go-cmp v0.3.0/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU= github.com/google/go-cmp v0.3.1/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU= github.com/google/go-cmp v0.4.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= -github.com/google/go-cmp v0.5.5 h1:Khx7svrCpmxxtHBq5j2mp/xVjsi8hQMfNLvJFAlrGgU= github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/pprof v0.0.0-20210407192527-94a9f03dee38/go.mod h1:kpwsk12EmLew5upagYY7GY0pfYCcupk39gWOCRROcvE= github.com/h2non/parth v0.0.0-20190131123155-b4df798d6542 h1:2VTzZjLZBgl62/EtslCrtky5vbi9dd7HrQPQIx6wqiw= @@ -48,7 +46,6 @@ github.com/onsi/ginkgo v1.12.1/go.mod h1:zj2OWP4+oCPe1qIXoGWkgMRwljMUYCdkwsT2108 github.com/onsi/ginkgo v1.16.4/go.mod h1:dX+/inL/fNMqNlz0e9LfyB9TswhZpCVdJM/Z6Vvnwo0= github.com/onsi/ginkgo v1.16.5 h1:8xi0RTUf59SOSfEtZMvwTvXYMzG4gV23XVHOZiXNtnE= github.com/onsi/ginkgo v1.16.5/go.mod h1:+E8gABHa3K6zRBolWtd+ROzc/U5bkGt0FwiG042wbpU= -github.com/onsi/ginkgo/v2 v2.0.0 h1:CcuG/HvWNkkaqCUpJifQY8z7qEMBJya6aLPx6ftGyjQ= github.com/onsi/ginkgo/v2 v2.0.0/go.mod h1:vw5CSIxN1JObi/U8gcbwft7ZxR2dgaR70JSE3/PpL4c= github.com/onsi/gomega v1.7.1/go.mod h1:XdKZgCCFLUoM/7CFJVPcG8C1xQ1AJ0vpAezJrB7JYyY= github.com/onsi/gomega v1.10.1/go.mod h1:iN09h71vgCQne3DLsj+A5owkum+a2tYe+TOCB1ybHNo= @@ -118,7 +115,6 @@ golang.org/x/tools v0.0.0-20201224043029-2b0845dc783e/go.mod h1:emZCQorbCU4vsT4f golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= -golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1 h1:go1bK/D/BFZV2I8cIQd1NKEZ+0owSTG1fDTci4IqFcE= golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= google.golang.org/protobuf v0.0.0-20200109180630-ec00e32a8dfd/go.mod h1:DFci5gLYBciE7Vtevhsrf46CRTquxDuWsQurQQe4oz8= google.golang.org/protobuf v0.0.0-20200221191635-4d8936d0db64/go.mod h1:kwYJMbMJ01Woi6D6+Kah6886xMZcty6N08ah7+eCXa0= @@ -127,7 +123,6 @@ google.golang.org/protobuf v1.20.1-0.20200309200217-e05f789c0967/go.mod h1:A+miE google.golang.org/protobuf v1.21.0/go.mod h1:47Nbq4nVaFHyn7ilMalzfO3qCViNmqZ2kzikPIcrTAo= google.golang.org/protobuf v1.23.0/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2l/sGQquU= google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw= -google.golang.org/protobuf v1.26.0 h1:bxAC2xTBsZGibn2RTntX0oH50xLsqy1OxA9tTL3p/lk= google.golang.org/protobuf v1.26.0/go.mod h1:9q0QmTI4eRPtz6boOQmLYwt+qCgq0jsYwAQnmE0givc= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= diff --git a/officialaccount/officialaccount.go b/officialaccount/officialaccount.go index ed3dc71..3444284 100644 --- a/officialaccount/officialaccount.go +++ b/officialaccount/officialaccount.go @@ -1,6 +1,7 @@ package officialaccount import ( + stdcontext "context" "net/http" "github.com/silenceper/wechat/v2/officialaccount/draft" @@ -94,6 +95,14 @@ func (officialAccount *OfficialAccount) GetAccessToken() (string, error) { return officialAccount.ctx.GetAccessToken() } +// GetAccessTokenContext 获取access_token +func (officialAccount *OfficialAccount) GetAccessTokenContext(ctx stdcontext.Context) (string, error) { + if c, ok := officialAccount.ctx.AccessTokenHandle.(credential.AccessTokenContextHandle); ok { + return c.GetAccessTokenContext(ctx) + } + return officialAccount.ctx.GetAccessToken() +} + // GetOauth oauth2网页授权 func (officialAccount *OfficialAccount) GetOauth() *oauth.Oauth { if officialAccount.oauth == nil {