diff --git a/router/router.go b/router/router.go index ceb1c08..2fc507b 100644 --- a/router/router.go +++ b/router/router.go @@ -14,7 +14,6 @@ import ( "net/http/httputil" "opencatd-open/store" "strings" - "sync" "time" "github.com/Sakurasan/to" @@ -324,7 +323,9 @@ func HandleProy(c *gin.Context) { chatres = openai.ChatCompletionResponse{} chatlog store.Tokens pre_prompt string - wg sync.WaitGroup + req *http.Request + err error + // wg sync.WaitGroup ) auth := c.Request.Header.Get("Authorization") if len(auth) > 7 && auth[:7] == "Bearer " { @@ -345,16 +346,25 @@ func HandleProy(c *gin.Context) { chatlog.PromptCount = NumTokensFromMessages(chatreq.Messages, chatreq.Model) isStream = chatreq.Stream chatlog.UserID, _ = store.GetUserID(auth[7:]) + + var body bytes.Buffer + json.NewEncoder(&body).Encode(chatreq) + // 创建 API 请求 + req, err = http.NewRequest(c.Request.Method, baseUrl+c.Request.RequestURI, &body) + if err != nil { + log.Println(err) + c.JSON(http.StatusOK, gin.H{"error": err.Error()}) + return + } + } else { + req, err = http.NewRequest(c.Request.Method, baseUrl+c.Request.RequestURI, c.Request.Body) + if err != nil { + log.Println(err) + c.JSON(http.StatusOK, gin.H{"error": err.Error()}) + return + } } - var body bytes.Buffer - json.NewEncoder(&body).Encode(chatreq) - // 创建 API 请求 - req, err := http.NewRequest(c.Request.Method, baseUrl+c.Request.RequestURI, &body) - if err != nil { - log.Println(err) - c.JSON(http.StatusOK, gin.H{"error": err.Error()}) - return - } + req.Header = c.Request.Header if localuser { if store.KeysCache.ItemCount() == 0 { @@ -399,25 +409,38 @@ func HandleProy(c *gin.Context) { defer writer.Flush() reader := bufio.NewReader(resp.Body) - var resbuf = bytes.NewBuffer(nil) if resp.StatusCode == 200 && localuser { - wg.Add(1) + if isStream { - contentCh := fetchResponseContent(resbuf, reader) + contentCh := fetchResponseContent(writer, reader) var buffer bytes.Buffer for content := range contentCh { buffer.WriteString(content) } chatlog.CompletionCount = NumTokensFromStr(buffer.String(), chatreq.Model) chatlog.TotalTokens = chatlog.PromptCount + chatlog.CompletionCount - } else { - reader.WriteTo(resbuf) - json.NewDecoder(resbuf).Decode(&chatres) - chatlog.PromptCount = chatres.Usage.PromptTokens - chatlog.CompletionCount = chatres.Usage.CompletionTokens - chatlog.TotalTokens = chatres.Usage.TotalTokens + chatlog.Cost = fmt.Sprintf("%.6f", Cost(chatlog.Model, chatlog.PromptCount, chatlog.CompletionCount)) + if err := store.Record(&chatlog); err != nil { + log.Println(err) + } + if err := store.SumDaily(chatlog.UserID); err != nil { + log.Println(err) + } + return } + res, err := io.ReadAll(reader) + if err != nil { + c.JSON(http.StatusUnauthorized, gin.H{"error": gin.H{ + "message": err.Error(), + }}) + return + } + reader = bufio.NewReader(bytes.NewBuffer(res)) + json.NewDecoder(bytes.NewBuffer(res)).Decode(&chatres) + chatlog.PromptCount = chatres.Usage.PromptTokens + chatlog.CompletionCount = chatres.Usage.CompletionTokens + chatlog.TotalTokens = chatres.Usage.TotalTokens chatlog.Cost = fmt.Sprintf("%.6f", Cost(chatlog.Model, chatlog.PromptCount, chatlog.CompletionCount)) if err := store.Record(&chatlog); err != nil { log.Println(err) @@ -425,26 +448,14 @@ func HandleProy(c *gin.Context) { if err := store.SumDaily(chatlog.UserID); err != nil { log.Println(err) } - // 返回 API 响应主体 - if _, err := io.Copy(writer, resbuf); err != nil { - c.JSON(http.StatusUnauthorized, gin.H{"error": err.Error()}) - return - } - go func() { - defer wg.Done() - _, err = io.Copy(c.Writer, resp.Body) - if err != nil { - c.JSON(http.StatusUnauthorized, gin.H{"error": err.Error()}) - return - } - }() - wg.Wait() - return + } // 返回 API 响应主体 if _, err := io.Copy(writer, reader); err != nil { log.Println(err) - c.JSON(http.StatusUnauthorized, gin.H{"error": err.Error()}) + c.JSON(http.StatusUnauthorized, gin.H{"error": gin.H{ + "message": err.Error(), + }}) return } } @@ -525,14 +536,15 @@ func HandleUsage(c *gin.Context) { c.JSON(200, usage) } -func fetchResponseContent(buf *bytes.Buffer, responseBody *bufio.Reader) <-chan string { +func fetchResponseContent(w *bufio.Writer, responseBody *bufio.Reader) <-chan string { contentCh := make(chan string) go func() { defer close(contentCh) for { line, err := responseBody.ReadString('\n') if err == nil { - buf.WriteString(line) + w.WriteString(line) + w.Flush() if line == "\n" { continue }