From b7190c6eb5824cf9d42f935561051f905af9fb47 Mon Sep 17 00:00:00 2001 From: Sakurasan <26715255+Sakurasan@users.noreply.github.com> Date: Mon, 18 Nov 2024 02:56:48 +0800 Subject: [PATCH] up --- pkg/openai/chat.go | 29 +++++++++++++++++++++-------- 1 file changed, 21 insertions(+), 8 deletions(-) diff --git a/pkg/openai/chat.go b/pkg/openai/chat.go index 8985ba2..669d8e6 100644 --- a/pkg/openai/chat.go +++ b/pkg/openai/chat.go @@ -273,16 +273,16 @@ func ChatProxy(c *gin.Context, chatReq *ChatCompletionRequest) { return } defer resp.Body.Close() - c.Writer.WriteHeader(resp.StatusCode) - for key, value := range resp.Header { - for _, v := range value { - c.Writer.Header().Add(key, v) - } - } - teeReader := io.TeeReader(resp.Body, c.Writer) var result string if chatReq.Stream { + for key, value := range resp.Header { + for _, v := range value { + c.Writer.Header().Add(key, v) + } + } + c.Writer.WriteHeader(resp.StatusCode) + teeReader := io.TeeReader(resp.Body, c.Writer) // 流式响应 scanner := bufio.NewScanner(teeReader) @@ -318,15 +318,18 @@ func ChatProxy(c *gin.Context, chatReq *ChatCompletionRequest) { } } else { + // 处理非流式响应 - body, err := io.ReadAll(teeReader) + body, err := io.ReadAll(resp.Body) if err != nil { fmt.Println("Error reading response body:", err) + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) return } var opiResp ChatCompletionResponse if err := json.Unmarshal(body, &opiResp); err != nil { log.Println("Error parsing JSON:", err) + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) return } if opiResp.Choices != nil && len(opiResp.Choices) > 0 { @@ -343,6 +346,16 @@ func ChatProxy(c *gin.Context, chatReq *ChatCompletionRequest) { } } + resp.Body = io.NopCloser(bytes.NewBuffer(body)) + + for k, v := range resp.Header { + c.Writer.Header().Set(k, v[0]) + } + c.Writer.WriteHeader(resp.StatusCode) + _, err = io.Copy(c.Writer, resp.Body) + if err != nil { + log.Println(err) + } } usagelog.CompletionCount = tokenizer.NumTokensFromStr(result, chatReq.Model) usagelog.Cost = fmt.Sprintf("%.6f", tokenizer.Cost(usagelog.Model, usagelog.PromptCount, usagelog.CompletionCount))