From a202dfadcacb008bba0cd8064d449230e393dbd7 Mon Sep 17 00:00:00 2001 From: Sakurasan <1173092237@qq.com> Date: Thu, 27 Apr 2023 02:31:35 +0800 Subject: [PATCH] =?UTF-8?q?=E4=BC=98=E5=8C=96=E4=BB=A3=E7=A0=81?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- router/router.go | 77 +++++++++++++++++++++++++++--------------------- store/usage.go | 2 +- 2 files changed, 44 insertions(+), 35 deletions(-) diff --git a/router/router.go b/router/router.go index 34756d8..c7436ba 100644 --- a/router/router.go +++ b/router/router.go @@ -14,6 +14,7 @@ import ( "net/http/httputil" "opencatd-open/store" "strings" + "sync" "time" "github.com/Sakurasan/to" @@ -297,29 +298,7 @@ func GenerateToken() string { return token.String() } -// type Tokens struct { -// UserID int -// PromptCount int -// CompletionCount int -// TotalTokens int -// Model string -// PromptHash string -// } - -func HandleProy(c *gin.Context) { - var ( - localuser bool - isStream bool - chatreq = openai.ChatCompletionRequest{} - chatres = openai.ChatCompletionResponse{} - chatlog store.Tokens - pre_prompt string - ) - auth := c.Request.Header.Get("Authorization") - if len(auth) > 7 && auth[:7] == "Bearer " { - localuser = store.IsExistAuthCache(auth[7:]) - } - client := http.DefaultClient +func getHttpClient() *http.Client { tr := &http.Transport{ Proxy: http.ProxyFromEnvironment, DialContext: (&net.Dialer{ @@ -333,7 +312,23 @@ func HandleProy(c *gin.Context) { ExpectContinueTimeout: 1 * time.Second, TLSClientConfig: &tls.Config{InsecureSkipVerify: true}, } - client.Transport = tr + return &http.Client{Transport: tr} +} + +func HandleProy(c *gin.Context) { + var ( + localuser bool + isStream bool + chatreq = openai.ChatCompletionRequest{} + chatres = openai.ChatCompletionResponse{} + chatlog store.Tokens + pre_prompt string + wg sync.WaitGroup + ) + auth := c.Request.Header.Get("Authorization") + if len(auth) > 7 && auth[:7] == "Bearer " { + localuser = store.IsExistAuthCache(auth[7:]) + } if c.Request.URL.Path == "/v1/chat/completions" && localuser { @@ -367,7 +362,7 @@ func HandleProy(c *gin.Context) { } req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", store.FromKeyCacheRandomItem())) } - + client := getHttpClient() resp, err := client.Do(req) if err != nil { log.Println(err) @@ -396,13 +391,22 @@ func HandleProy(c *gin.Context) { resp.Header.Del("content-security-policy-report-only") resp.Header.Del("clear-site-data") + c.Writer.WriteHeader(resp.StatusCode) + writer := bufio.NewWriter(c.Writer) + defer writer.Flush() + reader := bufio.NewReader(resp.Body) var resbuf = bytes.NewBuffer(nil) if resp.StatusCode == 200 && localuser { + wg.Add(1) if isStream { - chatdata := <-fetchResponseContent(resbuf, reader) - chatlog.CompletionCount = NumTokensFromStr(chatdata, chatreq.Model) + contentCh := fetchResponseContent(resbuf, reader) + var buffer bytes.Buffer + for content := range contentCh { + buffer.WriteString(content) + } + chatlog.CompletionCount = NumTokensFromStr(buffer.String(), chatreq.Model) chatlog.TotalTokens = chatlog.PromptCount + chatlog.CompletionCount } else { reader.WriteTo(resbuf) @@ -418,19 +422,24 @@ func HandleProy(c *gin.Context) { if err := store.SumDaily(chatlog.UserID); err != nil { log.Println(err) } - - } - c.Writer.WriteHeader(resp.StatusCode) - if localuser { // 返回 API 响应主体 - if _, err := io.Copy(c.Writer, resbuf); err != nil { - log.Println(err) + if _, err := io.Copy(writer, resbuf); err != nil { c.JSON(http.StatusUnauthorized, gin.H{"error": err.Error()}) return } + go func() { + defer wg.Done() + _, err = io.Copy(c.Writer, resp.Body) + if err != nil { + c.JSON(http.StatusUnauthorized, gin.H{"error": err.Error()}) + return + } + }() + wg.Wait() + return } // 返回 API 响应主体 - if _, err := io.Copy(c.Writer, io.NopCloser(reader)); err != nil { + if _, err := io.Copy(writer, reader); err != nil { log.Println(err) c.JSON(http.StatusUnauthorized, gin.H{"error": err.Error()}) return diff --git a/store/usage.go b/store/usage.go index e2f33d9..d0c3bc3 100644 --- a/store/usage.go +++ b/store/usage.go @@ -58,7 +58,7 @@ func QueryUsage(from, to string) ([]CalcUsage, error) { --SUM(prompt_units) AS prompt_units, -- SUM(completion_units) AS completion_units, SUM(total_unit) AS total_unit, - SUM(cost) AS cost`). + printf('%.6f', SUM(cost)) AS cost`). Group("user_id"). Where("date >= ? AND date < ?", from, to). Find(&results).Error