1
0
mirror of https://github.com/silenceper/wechat.git synced 2025-12-19 16:52:24 +08:00

fix: improve type safety in httpWithTLS for custom RoundTripper (#861)

* fix: improve type safety in httpWithTLS for custom RoundTripper

Add type assertion check to handle cases where DefaultHTTPClient.Transport
is a custom http.RoundTripper implementation (not *http.Transport).

This improves upon the fix in PR #844 which only handled nil Transport.
The previous code would still panic if users set a custom RoundTripper:

  trans := baseTransport.(*http.Transport).Clone()  // panic if not *http.Transport

Now safely handles three scenarios:
1. Transport is nil -> use http.DefaultTransport
2. Transport is *http.Transport -> clone it
3. Transport is custom RoundTripper -> use http.DefaultTransport

Added comprehensive test cases:
- TestHttpWithTLS_NilTransport
- TestHttpWithTLS_CustomTransport
- TestHttpWithTLS_CustomRoundTripper

Related to #803

* refactor: reduce code duplication and complexity in httpWithTLS

- Eliminate duplicate http.DefaultTransport.Clone() calls
- Reduce cyclomatic complexity by simplifying conditional logic
- Use nil check pattern instead of nested else branches
- Maintain same functionality with cleaner code structure

This addresses golangci-lint warnings for dupl and gocyclo.

* fix: add newline at end of http_test.go

Fix gofmt -s compliance issue:
- File must end with newline character
- Addresses golangci-lint gofmt error on line 81

This fixes CI check failure.
This commit is contained in:
is-Xiaoen
2025-10-27 14:24:24 +08:00
committed by GitHub
parent 6f6e95cfdb
commit 30c8e77246
2 changed files with 92 additions and 5 deletions

View File

@@ -292,13 +292,19 @@ func httpWithTLS(rootCa, key string) (*http.Client, error) {
Certificates: []tls.Certificate{cert},
}
var baseTransport http.RoundTripper
// 安全地获取 *http.Transport
var trans *http.Transport
// 尝试从 DefaultHTTPClient 获取 Transport如果失败则使用默认值
if DefaultHTTPClient.Transport != nil {
baseTransport = DefaultHTTPClient.Transport
} else {
baseTransport = http.DefaultTransport
if t, ok := DefaultHTTPClient.Transport.(*http.Transport); ok {
trans = t.Clone()
}
}
trans := baseTransport.(*http.Transport).Clone()
// 如果无法获取有效的 Transport使用默认值
if trans == nil {
trans = http.DefaultTransport.(*http.Transport).Clone()
}
trans.TLSClientConfig = config
trans.DisableCompression = true
client = &http.Client{Transport: trans}

81
util/http_test.go Normal file
View File

@@ -0,0 +1,81 @@
package util
import (
"net/http"
"testing"
)
// TestHttpWithTLS_NilTransport tests the scenario where DefaultHTTPClient.Transport is nil
func TestHttpWithTLS_NilTransport(t *testing.T) {
// Save original transport
originalTransport := DefaultHTTPClient.Transport
defer func() {
DefaultHTTPClient.Transport = originalTransport
}()
// Set Transport to nil to simulate the bug scenario
DefaultHTTPClient.Transport = nil
// This should not panic after the fix
// Note: This will fail due to invalid cert path, but shouldn't panic on type assertion
_, err := httpWithTLS("./testdata/invalid_cert.p12", "password")
// We expect an error (cert file not found), but NOT a panic
if err == nil {
t.Error("Expected error due to invalid cert path, but got nil")
}
}
// TestHttpWithTLS_CustomTransport tests the scenario where DefaultHTTPClient has a custom Transport
func TestHttpWithTLS_CustomTransport(t *testing.T) {
// Save original transport
originalTransport := DefaultHTTPClient.Transport
defer func() {
DefaultHTTPClient.Transport = originalTransport
}()
// Set a custom http.Transport
customTransport := &http.Transport{
MaxIdleConns: 100,
}
DefaultHTTPClient.Transport = customTransport
// This should not panic
_, err := httpWithTLS("./testdata/invalid_cert.p12", "password")
// We expect an error (cert file not found), but NOT a panic
if err == nil {
t.Error("Expected error due to invalid cert path, but got nil")
}
}
// CustomRoundTripper is a custom implementation of http.RoundTripper
type CustomRoundTripper struct{}
func (c *CustomRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
return http.DefaultTransport.RoundTrip(req)
}
// TestHttpWithTLS_CustomRoundTripper tests the edge case where DefaultHTTPClient has a custom RoundTripper
// that is NOT *http.Transport
func TestHttpWithTLS_CustomRoundTripper(t *testing.T) {
// Save original transport
originalTransport := DefaultHTTPClient.Transport
defer func() {
DefaultHTTPClient.Transport = originalTransport
}()
// Set a custom RoundTripper that is NOT *http.Transport
customRoundTripper := &CustomRoundTripper{}
DefaultHTTPClient.Transport = customRoundTripper
// Create a recovery handler to catch potential panic
defer func() {
if r := recover(); r != nil {
t.Errorf("httpWithTLS panicked with custom RoundTripper: %v", r)
}
}()
// This might panic if the code doesn't handle non-*http.Transport RoundTripper properly
_, _ = httpWithTLS("./testdata/invalid_cert.p12", "password")
}