From e839af3ef977f484104d139ba6cc25b5d67e2b16 Mon Sep 17 00:00:00 2001 From: Mickls <914541185@qq.com> Date: Tue, 13 Jun 2023 13:55:47 +0800 Subject: [PATCH] feat: Add support for uploading files in SendRequest (#111) --- netutil/http.go | 87 ++++++++++++++++++++++++++++++++-- netutil/http_test.go | 109 +++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 191 insertions(+), 5 deletions(-) diff --git a/netutil/http.go b/netutil/http.go index 70e0bb9..4c9d1c1 100644 --- a/netutil/http.go +++ b/netutil/http.go @@ -19,8 +19,10 @@ import ( "errors" "fmt" "io" + "mime/multipart" "net/http" "net/url" + "os" "sort" "strings" "time" @@ -94,6 +96,7 @@ type HttpRequest struct { Headers http.Header QueryParams url.Values FormData url.Values + File *File Body []byte } @@ -186,7 +189,11 @@ func (client *HttpClient) SendRequest(request *HttpRequest) (*http.Response, err } if request.FormData != nil { - client.setFormData(req, request.FormData) + if request.File != nil { + err = client.setFormData(req, request.FormData, setFile(request.File)) + } else { + err = client.setFormData(req, request.FormData, nil) + } } client.Request = req @@ -251,10 +258,80 @@ func (client *HttpClient) setQueryParam(req *http.Request, reqUrl string, queryP return nil } -func (client *HttpClient) setFormData(req *http.Request, values url.Values) { - formData := []byte(values.Encode()) - req.Body = io.NopCloser(bytes.NewReader(formData)) - req.ContentLength = int64(len(formData)) +// setFormData set http request FormData param +func (client *HttpClient) setFormData(req *http.Request, values url.Values, setFile SetFileFunc) error { + if setFile != nil { + err := setFile(req, values) + if err != nil { + return err + } + } else { + formData := []byte(values.Encode()) + req.Body = io.NopCloser(bytes.NewReader(formData)) + req.ContentLength = int64(len(formData)) + } + return nil +} + +type SetFileFunc func(req *http.Request, values url.Values) error + +// File struct is a combination of file attributes +type File struct { + Content []byte + Path string + FieldName string + FileName string +} + +// setFile set parameters for http request formdata file upload +func setFile(f *File) SetFileFunc { + return func(req *http.Request, values url.Values) error { + body := &bytes.Buffer{} + writer := multipart.NewWriter(body) + + for key, vals := range values { + for _, val := range vals { + err := writer.WriteField(key, val) + if err != nil { + return err + } + } + } + + if f.Content != nil { + part, err := writer.CreateFormFile(f.FieldName, f.FileName) + if err != nil { + return err + } + part.Write(f.Content) + } else if f.Path != "" { + file, err := os.Open(f.Path) + if err != nil { + return err + } + defer file.Close() + + part, err := writer.CreateFormFile(f.FieldName, f.FileName) + if err != nil { + return err + } + _, err = io.Copy(part, file) + if err != nil { + return err + } + } + + err := writer.Close() + if err != nil { + return err + } + + req.Body = io.NopCloser(body) + req.Header.Set("Content-Type", writer.FormDataContentType()) + req.ContentLength = int64(body.Len()) + + return nil + } } // validateRequest check if a request has url, and valid method. diff --git a/netutil/http_test.go b/netutil/http_test.go index d1b786d..6071b26 100644 --- a/netutil/http_test.go +++ b/netutil/http_test.go @@ -1,11 +1,15 @@ package netutil import ( + "bytes" "encoding/json" "io" + "io/ioutil" "log" "net/http" + "net/http/httptest" "net/url" + "os" "testing" "github.com/duke-git/lancet/v2/internal" @@ -245,3 +249,108 @@ func TestStructToUrlValues(t *testing.T) { assert.Equal("456", queryValues2.Get("userId")) assert.Equal("", queryValues2.Get("name")) } + +func handleFileRequest(t *testing.T, w http.ResponseWriter, r *http.Request) { + err := r.ParseMultipartForm(1024) + if err != nil { + t.Fatal(err) + } + + key1 := r.FormValue("key1") + expectedKey1 := "value1" + if key1 != expectedKey1 { + t.Fatalf("expected %s, got %s", expectedKey1, key1) + } + + key2 := r.FormValue("key2") + expectedKey2 := "value2" + if key2 != expectedKey2 { + t.Fatalf("expected %s, got %s", expectedKey2, key2) + } + + file, header, err := r.FormFile("image") + if err != nil { + t.Fatal(err) + } + + expectedFileName := "testImage.jpg" + if header.Filename != expectedFileName { + t.Fatalf("expected %s, got %s", expectedFileName, header.Filename) + } + + defer file.Close() + + content, err := ioutil.ReadAll(file) + if err != nil { + t.Fatal(err) + } + + expectedContent := []byte("file content") + if !bytes.Equal(content, expectedContent) { + t.Fatalf("expected %s, got %s", string(expectedContent), string(content)) + } +} + +func TestSendRequestWithFileContent(t *testing.T) { + handler := http.HandlerFunc(func(writer http.ResponseWriter, request *http.Request) { + handleFileRequest(t, writer, request) + }) + + server := httptest.NewServer(handler) + defer server.Close() + + client := NewHttpClient() + request := &HttpRequest{ + RawURL: server.URL, + Method: "POST", + File: &File{Content: []byte("file content"), FieldName: "image", FileName: "testImage.jpg"}, + FormData: url.Values{"key1": {"value1"}, "key2": {"value2"}}, + } + + resp, err := client.SendRequest(request) + if err != nil { + t.Fatal(err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + t.Fatalf("expected %d, got %d", http.StatusOK, resp.StatusCode) + } +} + +func TestSendRequestWithFilePath(t *testing.T) { + handler := http.HandlerFunc(func(writer http.ResponseWriter, request *http.Request) { + handleFileRequest(t, writer, request) + }) + + server := httptest.NewServer(handler) + defer server.Close() + + tmpFile, err := ioutil.TempFile("", "testImage.jpg") + if err != nil { + t.Fatal(err) + } + + defer os.Remove(tmpFile.Name()) + + tmpFile.Write([]byte("file content")) + tmpFile.Close() + + client := NewHttpClient() + request := &HttpRequest{ + RawURL: server.URL, + Method: "POST", + File: &File{Path: tmpFile.Name(), FieldName: "image", FileName: "testImage.jpg"}, + FormData: url.Values{"key1": {"value1"}, "key2": {"value2"}}, + } + + resp, err := client.SendRequest(request) + if err != nil { + t.Fatal(err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + t.Fatalf("expected %d, got %d", http.StatusOK, resp.StatusCode) + } +}