优化代码

This commit is contained in:
Sakurasan
2023-04-28 03:05:23 +08:00
parent d2f07a824e
commit 5200171dc0

View File

@@ -14,7 +14,6 @@ import (
"net/http/httputil" "net/http/httputil"
"opencatd-open/store" "opencatd-open/store"
"strings" "strings"
"sync"
"time" "time"
"github.com/Sakurasan/to" "github.com/Sakurasan/to"
@@ -324,7 +323,9 @@ func HandleProy(c *gin.Context) {
chatres = openai.ChatCompletionResponse{} chatres = openai.ChatCompletionResponse{}
chatlog store.Tokens chatlog store.Tokens
pre_prompt string pre_prompt string
wg sync.WaitGroup req *http.Request
err error
// wg sync.WaitGroup
) )
auth := c.Request.Header.Get("Authorization") auth := c.Request.Header.Get("Authorization")
if len(auth) > 7 && auth[:7] == "Bearer " { if len(auth) > 7 && auth[:7] == "Bearer " {
@@ -345,16 +346,25 @@ func HandleProy(c *gin.Context) {
chatlog.PromptCount = NumTokensFromMessages(chatreq.Messages, chatreq.Model) chatlog.PromptCount = NumTokensFromMessages(chatreq.Messages, chatreq.Model)
isStream = chatreq.Stream isStream = chatreq.Stream
chatlog.UserID, _ = store.GetUserID(auth[7:]) chatlog.UserID, _ = store.GetUserID(auth[7:])
var body bytes.Buffer
json.NewEncoder(&body).Encode(chatreq)
// 创建 API 请求
req, err = http.NewRequest(c.Request.Method, baseUrl+c.Request.RequestURI, &body)
if err != nil {
log.Println(err)
c.JSON(http.StatusOK, gin.H{"error": err.Error()})
return
}
} else {
req, err = http.NewRequest(c.Request.Method, baseUrl+c.Request.RequestURI, c.Request.Body)
if err != nil {
log.Println(err)
c.JSON(http.StatusOK, gin.H{"error": err.Error()})
return
}
} }
var body bytes.Buffer
json.NewEncoder(&body).Encode(chatreq)
// 创建 API 请求
req, err := http.NewRequest(c.Request.Method, baseUrl+c.Request.RequestURI, &body)
if err != nil {
log.Println(err)
c.JSON(http.StatusOK, gin.H{"error": err.Error()})
return
}
req.Header = c.Request.Header req.Header = c.Request.Header
if localuser { if localuser {
if store.KeysCache.ItemCount() == 0 { if store.KeysCache.ItemCount() == 0 {
@@ -399,25 +409,38 @@ func HandleProy(c *gin.Context) {
defer writer.Flush() defer writer.Flush()
reader := bufio.NewReader(resp.Body) reader := bufio.NewReader(resp.Body)
var resbuf = bytes.NewBuffer(nil)
if resp.StatusCode == 200 && localuser { if resp.StatusCode == 200 && localuser {
wg.Add(1)
if isStream { if isStream {
contentCh := fetchResponseContent(resbuf, reader) contentCh := fetchResponseContent(writer, reader)
var buffer bytes.Buffer var buffer bytes.Buffer
for content := range contentCh { for content := range contentCh {
buffer.WriteString(content) buffer.WriteString(content)
} }
chatlog.CompletionCount = NumTokensFromStr(buffer.String(), chatreq.Model) chatlog.CompletionCount = NumTokensFromStr(buffer.String(), chatreq.Model)
chatlog.TotalTokens = chatlog.PromptCount + chatlog.CompletionCount chatlog.TotalTokens = chatlog.PromptCount + chatlog.CompletionCount
} else { chatlog.Cost = fmt.Sprintf("%.6f", Cost(chatlog.Model, chatlog.PromptCount, chatlog.CompletionCount))
reader.WriteTo(resbuf) if err := store.Record(&chatlog); err != nil {
json.NewDecoder(resbuf).Decode(&chatres) log.Println(err)
chatlog.PromptCount = chatres.Usage.PromptTokens }
chatlog.CompletionCount = chatres.Usage.CompletionTokens if err := store.SumDaily(chatlog.UserID); err != nil {
chatlog.TotalTokens = chatres.Usage.TotalTokens log.Println(err)
}
return
} }
res, err := io.ReadAll(reader)
if err != nil {
c.JSON(http.StatusUnauthorized, gin.H{"error": gin.H{
"message": err.Error(),
}})
return
}
reader = bufio.NewReader(bytes.NewBuffer(res))
json.NewDecoder(bytes.NewBuffer(res)).Decode(&chatres)
chatlog.PromptCount = chatres.Usage.PromptTokens
chatlog.CompletionCount = chatres.Usage.CompletionTokens
chatlog.TotalTokens = chatres.Usage.TotalTokens
chatlog.Cost = fmt.Sprintf("%.6f", Cost(chatlog.Model, chatlog.PromptCount, chatlog.CompletionCount)) chatlog.Cost = fmt.Sprintf("%.6f", Cost(chatlog.Model, chatlog.PromptCount, chatlog.CompletionCount))
if err := store.Record(&chatlog); err != nil { if err := store.Record(&chatlog); err != nil {
log.Println(err) log.Println(err)
@@ -425,26 +448,14 @@ func HandleProy(c *gin.Context) {
if err := store.SumDaily(chatlog.UserID); err != nil { if err := store.SumDaily(chatlog.UserID); err != nil {
log.Println(err) log.Println(err)
} }
// 返回 API 响应主体
if _, err := io.Copy(writer, resbuf); err != nil {
c.JSON(http.StatusUnauthorized, gin.H{"error": err.Error()})
return
}
go func() {
defer wg.Done()
_, err = io.Copy(c.Writer, resp.Body)
if err != nil {
c.JSON(http.StatusUnauthorized, gin.H{"error": err.Error()})
return
}
}()
wg.Wait()
return
} }
// 返回 API 响应主体 // 返回 API 响应主体
if _, err := io.Copy(writer, reader); err != nil { if _, err := io.Copy(writer, reader); err != nil {
log.Println(err) log.Println(err)
c.JSON(http.StatusUnauthorized, gin.H{"error": err.Error()}) c.JSON(http.StatusUnauthorized, gin.H{"error": gin.H{
"message": err.Error(),
}})
return return
} }
} }
@@ -525,14 +536,15 @@ func HandleUsage(c *gin.Context) {
c.JSON(200, usage) c.JSON(200, usage)
} }
func fetchResponseContent(buf *bytes.Buffer, responseBody *bufio.Reader) <-chan string { func fetchResponseContent(w *bufio.Writer, responseBody *bufio.Reader) <-chan string {
contentCh := make(chan string) contentCh := make(chan string)
go func() { go func() {
defer close(contentCh) defer close(contentCh)
for { for {
line, err := responseBody.ReadString('\n') line, err := responseBody.ReadString('\n')
if err == nil { if err == nil {
buf.WriteString(line) w.WriteString(line)
w.Flush()
if line == "\n" { if line == "\n" {
continue continue
} }