diff --git a/context/access_token.go b/context/access_token.go index a7f7810..7405771 100644 --- a/context/access_token.go +++ b/context/access_token.go @@ -22,16 +22,27 @@ type ResAccessToken struct { ExpiresIn int64 `json:"expires_in"` } +//GetAccessTokenFunc 获取 access token 的函数签名 +type GetAccessTokenFunc func(ctx *Context) (accessToken string, err error) + //SetAccessTokenLock 设置读写锁(一个appID一个读写锁) func (ctx *Context) SetAccessTokenLock(l *sync.RWMutex) { ctx.accessTokenLock = l } +//SetGetAccessTokenFunc 设置自定义获取accessToken的方式, 需要自己实现缓存 +func (ctx *Context) SetGetAccessTokenFunc(f GetAccessTokenFunc) { + ctx.accessTokenFunc = f +} + //GetAccessToken 获取access_token func (ctx *Context) GetAccessToken() (accessToken string, err error) { ctx.accessTokenLock.Lock() defer ctx.accessTokenLock.Unlock() + if ctx.accessTokenFunc != nil { + return ctx.accessTokenFunc(ctx) + } accessTokenCacheKey := fmt.Sprintf("access_token_%s", ctx.AppID) val := ctx.Cache.Get(accessTokenCacheKey) if val != nil { diff --git a/context/access_token_test.go b/context/access_token_test.go new file mode 100644 index 0000000..fdae218 --- /dev/null +++ b/context/access_token_test.go @@ -0,0 +1,30 @@ +package context + +import ( + "sync" + "testing" +) + +func TestContext_SetCustomAccessTokenFunc(t *testing.T) { + ctx := Context{ + accessTokenLock: new(sync.RWMutex), + } + f := func(ctx *Context) (accessToken string, err error) { + return "fake token", nil + } + ctx.SetGetAccessTokenFunc(f) + res, err := ctx.GetAccessToken() + if res != "fake token" || err != nil { + t.Error("expect fake token but error") + } +} + +func TestContext_NoSetCustomAccessTokenFunc(t *testing.T) { + ctx := Context{ + accessTokenLock: new(sync.RWMutex), + } + + if ctx.accessTokenFunc != nil { + t.Error("error accessTokenFunc") + } +} diff --git a/context/context.go b/context/context.go index 45bcf50..07e42c3 100644 --- a/context/context.go +++ b/context/context.go @@ -27,6 +27,9 @@ type Context struct { //jsAPITicket 读写锁 同一个AppID一个 jsAPITicketLock *sync.RWMutex + + //accessTokenFunc 自定义获取 access token 的方法 + accessTokenFunc GetAccessTokenFunc } // Query returns the keyed url query value if it exists