This commit is contained in:
Sakurasan
2024-11-18 02:56:48 +08:00
parent d5d87a9bb0
commit b7190c6eb5

View File

@@ -273,16 +273,16 @@ func ChatProxy(c *gin.Context, chatReq *ChatCompletionRequest) {
return
}
defer resp.Body.Close()
c.Writer.WriteHeader(resp.StatusCode)
for key, value := range resp.Header {
for _, v := range value {
c.Writer.Header().Add(key, v)
}
}
teeReader := io.TeeReader(resp.Body, c.Writer)
var result string
if chatReq.Stream {
for key, value := range resp.Header {
for _, v := range value {
c.Writer.Header().Add(key, v)
}
}
c.Writer.WriteHeader(resp.StatusCode)
teeReader := io.TeeReader(resp.Body, c.Writer)
// 流式响应
scanner := bufio.NewScanner(teeReader)
@@ -318,15 +318,18 @@ func ChatProxy(c *gin.Context, chatReq *ChatCompletionRequest) {
}
} else {
// 处理非流式响应
body, err := io.ReadAll(teeReader)
body, err := io.ReadAll(resp.Body)
if err != nil {
fmt.Println("Error reading response body:", err)
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
var opiResp ChatCompletionResponse
if err := json.Unmarshal(body, &opiResp); err != nil {
log.Println("Error parsing JSON:", err)
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
if opiResp.Choices != nil && len(opiResp.Choices) > 0 {
@@ -343,6 +346,16 @@ func ChatProxy(c *gin.Context, chatReq *ChatCompletionRequest) {
}
}
resp.Body = io.NopCloser(bytes.NewBuffer(body))
for k, v := range resp.Header {
c.Writer.Header().Set(k, v[0])
}
c.Writer.WriteHeader(resp.StatusCode)
_, err = io.Copy(c.Writer, resp.Body)
if err != nil {
log.Println(err)
}
}
usagelog.CompletionCount = tokenizer.NumTokensFromStr(result, chatReq.Model)
usagelog.Cost = fmt.Sprintf("%.6f", tokenizer.Cost(usagelog.Model, usagelog.PromptCount, usagelog.CompletionCount))