From 8d2ad4e993244ca35b3800b328faa2e34cb49ba4 Mon Sep 17 00:00:00 2001 From: Sakurasan <1173092237@qq.com> Date: Fri, 28 Apr 2023 03:26:02 +0800 Subject: [PATCH] =?UTF-8?q?=E4=BC=98=E5=8C=96=E5=8D=A1=E9=A1=BF&bug?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- opencat.go | 3 +- router/router.go | 92 ++++++++++++++++++++++++++++-------------------- 2 files changed, 55 insertions(+), 40 deletions(-) diff --git a/opencat.go b/opencat.go index 4cb2d53..0ffebbe 100644 --- a/opencat.go +++ b/opencat.go @@ -59,7 +59,8 @@ func main() { // 初始化用户 r.POST("/1/users/init", router.Handleinit) - r.Any("/v1/*proxypath", router.HandleProy) + // r.Any("/v1/*proxypath", router.HandleProy) + r.Match([]string{http.MethodGet, http.MethodPost}, "/v1/*proxypath", router.HandleProy) // r.POST("/v1/chat/completions", router.HandleProy) // r.GET("/v1/models", router.HandleProy) diff --git a/router/router.go b/router/router.go index 143053e..2fc507b 100644 --- a/router/router.go +++ b/router/router.go @@ -14,7 +14,6 @@ import ( "net/http/httputil" "opencatd-open/store" "strings" - "sync" "time" "github.com/Sakurasan/to" @@ -324,7 +323,9 @@ func HandleProy(c *gin.Context) { chatres = openai.ChatCompletionResponse{} chatlog store.Tokens pre_prompt string - wg sync.WaitGroup + req *http.Request + err error + // wg sync.WaitGroup ) auth := c.Request.Header.Get("Authorization") if len(auth) > 7 && auth[:7] == "Bearer " { @@ -345,20 +346,31 @@ func HandleProy(c *gin.Context) { chatlog.PromptCount = NumTokensFromMessages(chatreq.Messages, chatreq.Model) isStream = chatreq.Stream chatlog.UserID, _ = store.GetUserID(auth[7:]) + + var body bytes.Buffer + json.NewEncoder(&body).Encode(chatreq) + // 创建 API 请求 + req, err = http.NewRequest(c.Request.Method, baseUrl+c.Request.RequestURI, &body) + if err != nil { + log.Println(err) + c.JSON(http.StatusOK, gin.H{"error": err.Error()}) + return + } + } else { + req, err = http.NewRequest(c.Request.Method, baseUrl+c.Request.RequestURI, c.Request.Body) + if err != nil { + log.Println(err) + c.JSON(http.StatusOK, gin.H{"error": err.Error()}) + return + } } - var body bytes.Buffer - json.NewEncoder(&body).Encode(chatreq) - // 创建 API 请求 - req, err := http.NewRequest(c.Request.Method, baseUrl+c.Request.RequestURI, &body) - if err != nil { - log.Println(err) - c.JSON(http.StatusOK, gin.H{"error": err.Error()}) - return - } + req.Header = c.Request.Header if localuser { if store.KeysCache.ItemCount() == 0 { - c.JSON(http.StatusOK, gin.H{"error": "No Api-Key Available"}) + c.JSON(http.StatusBadGateway, gin.H{"error": gin.H{ + "message": "No Api-Key Available", + }}) return } req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", store.FromKeyCacheRandomItem())) @@ -397,25 +409,38 @@ func HandleProy(c *gin.Context) { defer writer.Flush() reader := bufio.NewReader(resp.Body) - var resbuf = bytes.NewBuffer(nil) if resp.StatusCode == 200 && localuser { - wg.Add(1) + if isStream { - contentCh := fetchResponseContent(resbuf, reader) + contentCh := fetchResponseContent(writer, reader) var buffer bytes.Buffer for content := range contentCh { buffer.WriteString(content) } chatlog.CompletionCount = NumTokensFromStr(buffer.String(), chatreq.Model) chatlog.TotalTokens = chatlog.PromptCount + chatlog.CompletionCount - } else { - reader.WriteTo(resbuf) - json.NewDecoder(resbuf).Decode(&chatres) - chatlog.PromptCount = chatres.Usage.PromptTokens - chatlog.CompletionCount = chatres.Usage.CompletionTokens - chatlog.TotalTokens = chatres.Usage.TotalTokens + chatlog.Cost = fmt.Sprintf("%.6f", Cost(chatlog.Model, chatlog.PromptCount, chatlog.CompletionCount)) + if err := store.Record(&chatlog); err != nil { + log.Println(err) + } + if err := store.SumDaily(chatlog.UserID); err != nil { + log.Println(err) + } + return } + res, err := io.ReadAll(reader) + if err != nil { + c.JSON(http.StatusUnauthorized, gin.H{"error": gin.H{ + "message": err.Error(), + }}) + return + } + reader = bufio.NewReader(bytes.NewBuffer(res)) + json.NewDecoder(bytes.NewBuffer(res)).Decode(&chatres) + chatlog.PromptCount = chatres.Usage.PromptTokens + chatlog.CompletionCount = chatres.Usage.CompletionTokens + chatlog.TotalTokens = chatres.Usage.TotalTokens chatlog.Cost = fmt.Sprintf("%.6f", Cost(chatlog.Model, chatlog.PromptCount, chatlog.CompletionCount)) if err := store.Record(&chatlog); err != nil { log.Println(err) @@ -423,26 +448,14 @@ func HandleProy(c *gin.Context) { if err := store.SumDaily(chatlog.UserID); err != nil { log.Println(err) } - // 返回 API 响应主体 - if _, err := io.Copy(writer, resbuf); err != nil { - c.JSON(http.StatusUnauthorized, gin.H{"error": err.Error()}) - return - } - go func() { - defer wg.Done() - _, err = io.Copy(c.Writer, resp.Body) - if err != nil { - c.JSON(http.StatusUnauthorized, gin.H{"error": err.Error()}) - return - } - }() - wg.Wait() - return + } // 返回 API 响应主体 if _, err := io.Copy(writer, reader); err != nil { log.Println(err) - c.JSON(http.StatusUnauthorized, gin.H{"error": err.Error()}) + c.JSON(http.StatusUnauthorized, gin.H{"error": gin.H{ + "message": err.Error(), + }}) return } } @@ -523,14 +536,15 @@ func HandleUsage(c *gin.Context) { c.JSON(200, usage) } -func fetchResponseContent(buf *bytes.Buffer, responseBody *bufio.Reader) <-chan string { +func fetchResponseContent(w *bufio.Writer, responseBody *bufio.Reader) <-chan string { contentCh := make(chan string) go func() { defer close(contentCh) for { line, err := responseBody.ReadString('\n') if err == nil { - buf.WriteString(line) + w.WriteString(line) + w.Flush() if line == "\n" { continue }