diff --git a/netutil/net.go b/netutil/net.go index 5d41c7f..e0b498b 100644 --- a/netutil/net.go +++ b/netutil/net.go @@ -345,7 +345,6 @@ func BuildUrl(scheme, host, path string, query map[string]string) (string, error return parsedUrl.String(), nil } -// 支持的 Scheme 列表 var supportedSchemes = map[string]bool{ "http": true, "https": true, @@ -366,3 +365,35 @@ func validateScheme(scheme string) error { var hostRegex = regexp.MustCompile(`^([a-zA-Z0-9]|[a-zA-Z0-9][a-zA-Z0-9\-]*[a-zA-Z0-9])(\.[a-zA-Z0-9]|[a-zA-Z0-9][a-zA-Z0-9\-]*[a-zA-Z0-9])*$`) var pathRegex = regexp.MustCompile(`^\/([a-zA-Z0-9%_-]+(?:\/[a-zA-Z0-9%_-]+)*)$`) + +var alphaNumericRegex = regexp.MustCompile(`^[a-zA-Z0-9]+$`) + +// AddQueryParams adds query parameters to the given URL. +// Play: todoå +func AddQueryParams(urlStr string, params map[string]string) (string, error) { + parsedUrl, err := url.Parse(urlStr) + if err != nil { + return "", err + } + + queryParams := parsedUrl.Query() + for k, v := range params { + if k == "" { + return "", errors.New("empty key is not allowed") + } + + if !alphaNumericRegex.MatchString(k) { + return "", fmt.Errorf("query parameter key %s must be alphanumeric", k) + } + + if !alphaNumericRegex.MatchString(v) { + return "", fmt.Errorf("query parameter value %s must be alphanumeric", v) + } + + queryParams.Add(k, v) + } + + parsedUrl.RawQuery = queryParams.Encode() + + return parsedUrl.String(), nil +} diff --git a/netutil/net_test.go b/netutil/net_test.go index 5912289..7401337 100644 --- a/netutil/net_test.go +++ b/netutil/net_test.go @@ -192,13 +192,46 @@ func TestBuildUrl(t *testing.T) { for _, tt := range tests { got, err := BuildUrl(tt.scheme, tt.host, tt.path, tt.query) - // if (err != nil) != tt.wantErr { - // t.Errorf("BuildUrl() error = %v, wantErr %v", err, tt.wantErr) - // return - // } - assert.Equal(tt.want, got) assert.Equal(tt.wantErr, err != nil) } } + +func TestAddQueryParams(t *testing.T) { + t.Parallel() + + assert := internal.NewAssert(t, "TestAddQueryParams") + + tests := []struct { + url string + query map[string]string + want string + wantErr bool + }{ + { + url: "http://www.test.com", + query: map[string]string{"a": "1", "b": "2"}, + want: "http://www.test.com?a=1&b=2", + wantErr: false, + }, + { + url: "http://www.test.com", + query: map[string]string{}, + want: "http://www.test.com", + wantErr: false, + }, + { + url: "http://www.test.com", + query: map[string]string{"a": "$%"}, + want: "", + wantErr: true, + }, + } + + for _, tt := range tests { + got, err := AddQueryParams(tt.url, tt.query) + assert.Equal(tt.want, got) + assert.Equal(tt.wantErr, err != nil) + } +}