mirror of
https://github.com/silenceper/wechat.git
synced 2026-02-04 12:52:27 +08:00
Add JSSDK context method functionality (#828)
* Add JSSDK context method functionality * 善JSSDK上下文方法,并添加测试文件 * feat: 完善JSSDK上下文方法,保证协程安全,并添加测试文件 * 修改 import 包分组处理 * feat: 修改测试文件中 fmt.Print -> t.Log * 删除空行
This commit is contained in:
@@ -1,6 +1,7 @@
|
|||||||
package credential
|
package credential
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
context2 "context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"sync"
|
"sync"
|
||||||
@@ -42,6 +43,16 @@ type ResTicket struct {
|
|||||||
|
|
||||||
// GetTicket 获取jsapi_ticket
|
// GetTicket 获取jsapi_ticket
|
||||||
func (js *DefaultJsTicket) GetTicket(accessToken string) (ticketStr string, err error) {
|
func (js *DefaultJsTicket) GetTicket(accessToken string) (ticketStr string, err error) {
|
||||||
|
return js.GetTicketContext(context2.Background(), accessToken)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetTicketFromServer 从服务器中获取ticket
|
||||||
|
func GetTicketFromServer(accessToken string) (ticket ResTicket, err error) {
|
||||||
|
return GetTicketFromServerContext(context2.Background(), accessToken)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetTicketContext 获取jsapi_ticket
|
||||||
|
func (js *DefaultJsTicket) GetTicketContext(ctx context2.Context, accessToken string) (ticketStr string, err error) {
|
||||||
// 先从cache中取
|
// 先从cache中取
|
||||||
jsAPITicketCacheKey := fmt.Sprintf("%s_jsapi_ticket_%s", js.cacheKeyPrefix, js.appID)
|
jsAPITicketCacheKey := fmt.Sprintf("%s_jsapi_ticket_%s", js.cacheKeyPrefix, js.appID)
|
||||||
if val := js.cache.Get(jsAPITicketCacheKey); val != nil {
|
if val := js.cache.Get(jsAPITicketCacheKey); val != nil {
|
||||||
@@ -57,7 +68,7 @@ func (js *DefaultJsTicket) GetTicket(accessToken string) (ticketStr string, err
|
|||||||
}
|
}
|
||||||
|
|
||||||
var ticket ResTicket
|
var ticket ResTicket
|
||||||
ticket, err = GetTicketFromServer(accessToken)
|
ticket, err = GetTicketFromServerContext(ctx, accessToken)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -67,11 +78,11 @@ func (js *DefaultJsTicket) GetTicket(accessToken string) (ticketStr string, err
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetTicketFromServer 从服务器中获取ticket
|
// GetTicketFromServerContext 从服务器中获取ticket
|
||||||
func GetTicketFromServer(accessToken string) (ticket ResTicket, err error) {
|
func GetTicketFromServerContext(ctx context2.Context, accessToken string) (ticket ResTicket, err error) {
|
||||||
var response []byte
|
var response []byte
|
||||||
url := fmt.Sprintf(getTicketURL, accessToken)
|
url := fmt.Sprintf(getTicketURL, accessToken)
|
||||||
response, err = util.HTTPGet(url)
|
response, err = util.HTTPGetContext(ctx, url)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|||||||
22
credential/default_js_ticket_test.go
Normal file
22
credential/default_js_ticket_test.go
Normal file
@@ -0,0 +1,22 @@
|
|||||||
|
package credential
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"gopkg.in/h2non/gock.v1"
|
||||||
|
)
|
||||||
|
|
||||||
|
// TestGetTicketFromServerContext 测试 GetTicketFromServerContext 函数
|
||||||
|
func TestGetTicketFromServerContext(t *testing.T) {
|
||||||
|
defer gock.Off()
|
||||||
|
gock.New(fmt.Sprintf(getTicketURL, "arg-ak")).Reply(200).JSON(&ResTicket{Ticket: "mock-ticket", ExpiresIn: 10})
|
||||||
|
|
||||||
|
ticket, err := GetTicketFromServerContext(context.Background(), "arg-ak")
|
||||||
|
assert.Nil(t, err)
|
||||||
|
assert.Equal(t, int64(0), ticket.ErrCode)
|
||||||
|
assert.Equal(t, "mock-ticket", ticket.Ticket, "they should be equal")
|
||||||
|
assert.Equal(t, int64(10), ticket.ExpiresIn, "they should be equal")
|
||||||
|
}
|
||||||
@@ -1,7 +1,15 @@
|
|||||||
package credential
|
package credential
|
||||||
|
|
||||||
|
import context2 "context"
|
||||||
|
|
||||||
// JsTicketHandle js ticket获取
|
// JsTicketHandle js ticket获取
|
||||||
type JsTicketHandle interface {
|
type JsTicketHandle interface {
|
||||||
// GetTicket 获取ticket
|
// GetTicket 获取ticket
|
||||||
GetTicket(accessToken string) (ticket string, err error)
|
GetTicket(accessToken string) (ticket string, err error)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// JsTicketContextHandle js ticket获取
|
||||||
|
type JsTicketContextHandle interface {
|
||||||
|
JsTicketHandle
|
||||||
|
GetTicketContext(ctx context2.Context, accessToken string) (ticket string, err error)
|
||||||
|
}
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
package js
|
package js
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
context2 "context"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
|
||||||
"github.com/silenceper/wechat/v2/credential"
|
"github.com/silenceper/wechat/v2/credential"
|
||||||
@@ -39,20 +40,40 @@ func (js *Js) SetJsTicketHandle(ticketHandle credential.JsTicketHandle) {
|
|||||||
// GetConfig 获取jssdk需要的配置参数
|
// GetConfig 获取jssdk需要的配置参数
|
||||||
// uri 为当前网页地址
|
// uri 为当前网页地址
|
||||||
func (js *Js) GetConfig(uri string) (config *Config, err error) {
|
func (js *Js) GetConfig(uri string) (config *Config, err error) {
|
||||||
|
return js.GetConfigContext(context2.Background(), uri)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetConfigContext 新方法,允许传入上下文,避免协程泄漏
|
||||||
|
func (js *Js) GetConfigContext(ctx context2.Context, uri string) (config *Config, err error) {
|
||||||
var accessToken string
|
var accessToken string
|
||||||
accessToken, err = js.GetAccessToken()
|
// 类型断言,如果断言成功,调用安全的 GetAccessTokenContext 方法
|
||||||
|
if ctxHandle, ok := js.Context.AccessTokenHandle.(credential.AccessTokenContextHandle); ok {
|
||||||
|
accessToken, err = ctxHandle.GetAccessTokenContext(ctx)
|
||||||
|
} else {
|
||||||
|
// 如果没有实现 AccessTokenContextHandle 接口,调用旧的 GetAccessToken 方法
|
||||||
|
accessToken, err = js.Context.GetAccessToken()
|
||||||
|
}
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
var ticketStr string
|
var ticketStr string
|
||||||
ticketStr, err = js.GetTicket(accessToken)
|
// 类型断言 jsTicket
|
||||||
|
if ticketCtxHandle, ok := js.JsTicketHandle.(credential.JsTicketContextHandle); ok {
|
||||||
|
ticketStr, err = ticketCtxHandle.GetTicketContext(ctx, accessToken)
|
||||||
|
} else {
|
||||||
|
// 如果没有实现 JsTicketContextHandle 接口,调用旧的 GetTicket 方法
|
||||||
|
ticketStr, err = js.GetTicket(accessToken)
|
||||||
|
}
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
nonceStr := util.RandomStr(16)
|
nonceStr := util.RandomStr(16)
|
||||||
timestamp := util.GetCurrTS()
|
timestamp := util.GetCurrTS()
|
||||||
str := fmt.Sprintf("jsapi_ticket=%s&noncestr=%s×tamp=%d&url=%s", ticketStr, nonceStr, timestamp, uri)
|
str := fmt.Sprintf("jsapi_ticket=%s&noncestr=%s×tamp=%d&url=%s", ticketStr, nonceStr, timestamp, uri)
|
||||||
sigStr := util.Signature(str)
|
sigStr := util.Signature(str)
|
||||||
|
|
||||||
config = new(Config)
|
config = new(Config)
|
||||||
config.AppID = js.AppID
|
config.AppID = js.AppID
|
||||||
config.NonceStr = nonceStr
|
config.NonceStr = nonceStr
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
package js
|
package js
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
context2 "context"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
|
||||||
"github.com/silenceper/wechat/v2/credential"
|
"github.com/silenceper/wechat/v2/credential"
|
||||||
@@ -32,14 +33,31 @@ func (js *Js) SetJsTicketHandle(ticketHandle credential.JsTicketHandle) {
|
|||||||
// GetConfig 第三方平台 - 获取jssdk需要的配置参数
|
// GetConfig 第三方平台 - 获取jssdk需要的配置参数
|
||||||
// uri 为当前网页地址
|
// uri 为当前网页地址
|
||||||
func (js *Js) GetConfig(uri, appid string) (config *officialJs.Config, err error) {
|
func (js *Js) GetConfig(uri, appid string) (config *officialJs.Config, err error) {
|
||||||
config = new(officialJs.Config)
|
return js.GetConfigContext(context2.Background(), uri, appid)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetConfigContext 新方法,允许传入上下文,避免协程泄漏
|
||||||
|
func (js *Js) GetConfigContext(ctx context2.Context, uri, appid string) (config *officialJs.Config, err error) {
|
||||||
var accessToken string
|
var accessToken string
|
||||||
accessToken, err = js.GetAccessToken()
|
// 类型断言,如果断言成功,调用安全的 GetAccessTokenContext 方法
|
||||||
|
if ctxHandle, ok := js.Context.AccessTokenHandle.(credential.AccessTokenContextHandle); ok {
|
||||||
|
accessToken, err = ctxHandle.GetAccessTokenContext(ctx)
|
||||||
|
} else {
|
||||||
|
// 如果没有实现 AccessTokenContextHandle 接口,调用旧的 GetAccessToken 方法
|
||||||
|
accessToken, err = js.Context.GetAccessToken()
|
||||||
|
}
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
var ticketStr string
|
var ticketStr string
|
||||||
ticketStr, err = js.GetTicket(accessToken)
|
// 类型断言 jsTicket
|
||||||
|
if ticketCtxHandle, ok := js.JsTicketHandle.(credential.JsTicketContextHandle); ok {
|
||||||
|
ticketStr, err = ticketCtxHandle.GetTicketContext(ctx, accessToken)
|
||||||
|
} else {
|
||||||
|
// 如果没有实现 JsTicketContextHandle 接口,调用旧的 GetTicket 方法
|
||||||
|
ticketStr, err = js.GetTicket(accessToken)
|
||||||
|
}
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -49,6 +67,7 @@ func (js *Js) GetConfig(uri, appid string) (config *officialJs.Config, err error
|
|||||||
str := fmt.Sprintf("jsapi_ticket=%s&noncestr=%s×tamp=%d&url=%s", ticketStr, nonceStr, timestamp, uri)
|
str := fmt.Sprintf("jsapi_ticket=%s&noncestr=%s×tamp=%d&url=%s", ticketStr, nonceStr, timestamp, uri)
|
||||||
sigStr := util.Signature(str)
|
sigStr := util.Signature(str)
|
||||||
|
|
||||||
|
config = new(officialJs.Config)
|
||||||
config.AppID = appid
|
config.AppID = appid
|
||||||
config.NonceStr = nonceStr
|
config.NonceStr = nonceStr
|
||||||
config.Timestamp = timestamp
|
config.Timestamp = timestamp
|
||||||
|
|||||||
147
openplatform/officialaccount/js/js_test.go
Normal file
147
openplatform/officialaccount/js/js_test.go
Normal file
@@ -0,0 +1,147 @@
|
|||||||
|
// 验证 js.GetConfigContext 是否能正确传递上下文到 HTTP 请求,确保上下文正确传播,防止在获取 JSSDK 配置时发生协程泄露。
|
||||||
|
package js
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
context2 "context"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/silenceper/wechat/v2/cache"
|
||||||
|
"github.com/silenceper/wechat/v2/credential"
|
||||||
|
"github.com/silenceper/wechat/v2/officialaccount/config"
|
||||||
|
"github.com/silenceper/wechat/v2/officialaccount/context"
|
||||||
|
"github.com/silenceper/wechat/v2/util"
|
||||||
|
)
|
||||||
|
|
||||||
|
// mockAccessTokenHandle 模拟 AccessTokenHandle
|
||||||
|
type mockAccessTokenHandle struct{}
|
||||||
|
|
||||||
|
func (m *mockAccessTokenHandle) GetAccessToken() (string, error) {
|
||||||
|
return "mock-access-token", nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockAccessTokenHandle) GetAccessTokenContext(_ context2.Context) (string, error) {
|
||||||
|
return "mock-access-token", nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// contextCheckingRoundTripper 自定义 RoundTripper 用于检查 context
|
||||||
|
type contextCheckingRoundTripper struct {
|
||||||
|
originalCtx context2.Context
|
||||||
|
t *testing.T
|
||||||
|
key interface{}
|
||||||
|
expectedVal interface{}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (rt *contextCheckingRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||||
|
// 获取请求中的 context
|
||||||
|
reqCtx := req.Context()
|
||||||
|
|
||||||
|
// 打印 context 比较结果
|
||||||
|
rt.t.Logf("比较上下文的内存地址:\n")
|
||||||
|
if reqCtx == rt.originalCtx {
|
||||||
|
rt.t.Logf("上下文具有相同的内存地址。原始上下文: %p, 请求上下文: %p\n", rt.originalCtx, reqCtx)
|
||||||
|
} else {
|
||||||
|
rt.t.Logf("上下文具有不同的内存地址。原始上下文: %p, 请求上下文: %p\n", rt.originalCtx, reqCtx)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 检查 context 中的键值对
|
||||||
|
if rt.key != nil {
|
||||||
|
value := reqCtx.Value(rt.key)
|
||||||
|
rt.t.Logf("检查请求上下文中的键 %v:\n", rt.key)
|
||||||
|
if value != rt.expectedVal {
|
||||||
|
rt.t.Errorf("上下文键 %v 的值不匹配: 预期 %v, 实际 %v\n", rt.key, rt.expectedVal, value)
|
||||||
|
} else {
|
||||||
|
rt.t.Logf("上下文键 %v 的值匹配: 预期 %v, 实际 %v\n", rt.key, rt.expectedVal, value)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 检查上下文是否已取消
|
||||||
|
select {
|
||||||
|
case <-reqCtx.Done():
|
||||||
|
return nil, reqCtx.Err() // 返回上下文取消错误
|
||||||
|
default:
|
||||||
|
// 返回模拟的 HTTP 响应,包含有效的 JSON
|
||||||
|
responseBody := `{"ticket":"mock-ticket","expires_in":7200}`
|
||||||
|
response := &http.Response{
|
||||||
|
Status: "200 OK",
|
||||||
|
StatusCode: http.StatusOK,
|
||||||
|
Proto: "HTTP/1.1",
|
||||||
|
ProtoMajor: 1,
|
||||||
|
ProtoMinor: 1,
|
||||||
|
Body: io.NopCloser(bytes.NewReader([]byte(responseBody))),
|
||||||
|
ContentLength: int64(len(responseBody)),
|
||||||
|
Header: make(http.Header),
|
||||||
|
}
|
||||||
|
response.Header.Set("Content-Type", "application/json")
|
||||||
|
return response, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// contextKey 定义自定义上下文键类型,避免使用内置 string 类型
|
||||||
|
type contextKey string
|
||||||
|
|
||||||
|
// setupJsInstance 初始化 Js 实例和 HTTP 客户端
|
||||||
|
func setupJsInstance(t *testing.T, ctx context2.Context, key, val interface{}) (*Js, func()) {
|
||||||
|
cfg := &config.Config{
|
||||||
|
AppID: "test-app-id",
|
||||||
|
AppSecret: "test-app-secret",
|
||||||
|
Cache: cache.NewMemory(),
|
||||||
|
}
|
||||||
|
cacheKey := fmt.Sprintf("%s_jsapi_ticket_%s", credential.CacheKeyOfficialAccountPrefix, cfg.AppID)
|
||||||
|
if err := cfg.Cache.Delete(cacheKey); err != nil {
|
||||||
|
t.Fatalf("清除缓存失败: %v", err)
|
||||||
|
}
|
||||||
|
t.Log("清除 jsapi_ticket 的缓存:", cacheKey)
|
||||||
|
|
||||||
|
ctxHandle := &context.Context{Config: cfg, AccessTokenHandle: &mockAccessTokenHandle{}}
|
||||||
|
jsInstance := NewJs(ctxHandle, cfg.AppID)
|
||||||
|
jsInstance.SetJsTicketHandle(credential.NewDefaultJsTicket(cfg.AppID, credential.CacheKeyOfficialAccountPrefix, cfg.Cache))
|
||||||
|
|
||||||
|
originalClient := util.DefaultHTTPClient
|
||||||
|
util.DefaultHTTPClient = &http.Client{
|
||||||
|
Transport: &contextCheckingRoundTripper{originalCtx: ctx, t: t, key: key, expectedVal: val},
|
||||||
|
}
|
||||||
|
return jsInstance, func() { util.DefaultHTTPClient = originalClient }
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestGetConfigContext 测试GetConfigContext的上下文传递和取消行为。
|
||||||
|
func TestGetConfigContext(t *testing.T) {
|
||||||
|
t.Run("ContextPassing", func(t *testing.T) {
|
||||||
|
ctxKey := contextKey("testKey111") // 使用自定义类型 contextKey
|
||||||
|
ctxValue := "testValue222"
|
||||||
|
ctx := context2.WithValue(context2.Background(), ctxKey, ctxValue)
|
||||||
|
t.Logf("创建的测试上下文: %p, 添加的键值对: %v=%v\n", ctx, ctxKey, ctxValue)
|
||||||
|
|
||||||
|
jsInstance, cleanup := setupJsInstance(t, ctx, ctxKey, ctxValue)
|
||||||
|
defer cleanup()
|
||||||
|
t.Log("调用 GetConfigContext")
|
||||||
|
config2, err := jsInstance.GetConfigContext(ctx, "https://www.baidu.com", "test-app-id")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("GetConfigContext 失败: %v", err)
|
||||||
|
}
|
||||||
|
if config2.AppID != "test-app-id" {
|
||||||
|
t.Errorf("预期 AppID 为 %s,实际为 %s", "test-app-id", config2.AppID)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("ContextCancellation", func(t *testing.T) {
|
||||||
|
ctx, cancel := context2.WithCancel(context2.Background())
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
jsInstance, cleanup := setupJsInstance(t, ctx, nil, nil)
|
||||||
|
defer cleanup()
|
||||||
|
|
||||||
|
cancel()
|
||||||
|
t.Log("调用 GetConfigContext(已取消上下文)")
|
||||||
|
_, err := jsInstance.GetConfigContext(ctx, "https://www.baidu.com", "test-app-id")
|
||||||
|
if err == nil {
|
||||||
|
t.Error("预期上下文取消错误,但 GetConfigContext 未返回错误")
|
||||||
|
} else if !errors.Is(err, context2.Canceled) {
|
||||||
|
t.Errorf("预期错误为 context.Canceled,实际为: %v", err)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user