优化代码
This commit is contained in:
@@ -14,7 +14,6 @@ import (
|
|||||||
"net/http/httputil"
|
"net/http/httputil"
|
||||||
"opencatd-open/store"
|
"opencatd-open/store"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/Sakurasan/to"
|
"github.com/Sakurasan/to"
|
||||||
@@ -324,7 +323,9 @@ func HandleProy(c *gin.Context) {
|
|||||||
chatres = openai.ChatCompletionResponse{}
|
chatres = openai.ChatCompletionResponse{}
|
||||||
chatlog store.Tokens
|
chatlog store.Tokens
|
||||||
pre_prompt string
|
pre_prompt string
|
||||||
wg sync.WaitGroup
|
req *http.Request
|
||||||
|
err error
|
||||||
|
// wg sync.WaitGroup
|
||||||
)
|
)
|
||||||
auth := c.Request.Header.Get("Authorization")
|
auth := c.Request.Header.Get("Authorization")
|
||||||
if len(auth) > 7 && auth[:7] == "Bearer " {
|
if len(auth) > 7 && auth[:7] == "Bearer " {
|
||||||
@@ -345,16 +346,25 @@ func HandleProy(c *gin.Context) {
|
|||||||
chatlog.PromptCount = NumTokensFromMessages(chatreq.Messages, chatreq.Model)
|
chatlog.PromptCount = NumTokensFromMessages(chatreq.Messages, chatreq.Model)
|
||||||
isStream = chatreq.Stream
|
isStream = chatreq.Stream
|
||||||
chatlog.UserID, _ = store.GetUserID(auth[7:])
|
chatlog.UserID, _ = store.GetUserID(auth[7:])
|
||||||
|
|
||||||
|
var body bytes.Buffer
|
||||||
|
json.NewEncoder(&body).Encode(chatreq)
|
||||||
|
// 创建 API 请求
|
||||||
|
req, err = http.NewRequest(c.Request.Method, baseUrl+c.Request.RequestURI, &body)
|
||||||
|
if err != nil {
|
||||||
|
log.Println(err)
|
||||||
|
c.JSON(http.StatusOK, gin.H{"error": err.Error()})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
req, err = http.NewRequest(c.Request.Method, baseUrl+c.Request.RequestURI, c.Request.Body)
|
||||||
|
if err != nil {
|
||||||
|
log.Println(err)
|
||||||
|
c.JSON(http.StatusOK, gin.H{"error": err.Error()})
|
||||||
|
return
|
||||||
|
}
|
||||||
}
|
}
|
||||||
var body bytes.Buffer
|
|
||||||
json.NewEncoder(&body).Encode(chatreq)
|
|
||||||
// 创建 API 请求
|
|
||||||
req, err := http.NewRequest(c.Request.Method, baseUrl+c.Request.RequestURI, &body)
|
|
||||||
if err != nil {
|
|
||||||
log.Println(err)
|
|
||||||
c.JSON(http.StatusOK, gin.H{"error": err.Error()})
|
|
||||||
return
|
|
||||||
}
|
|
||||||
req.Header = c.Request.Header
|
req.Header = c.Request.Header
|
||||||
if localuser {
|
if localuser {
|
||||||
if store.KeysCache.ItemCount() == 0 {
|
if store.KeysCache.ItemCount() == 0 {
|
||||||
@@ -399,25 +409,38 @@ func HandleProy(c *gin.Context) {
|
|||||||
defer writer.Flush()
|
defer writer.Flush()
|
||||||
|
|
||||||
reader := bufio.NewReader(resp.Body)
|
reader := bufio.NewReader(resp.Body)
|
||||||
var resbuf = bytes.NewBuffer(nil)
|
|
||||||
|
|
||||||
if resp.StatusCode == 200 && localuser {
|
if resp.StatusCode == 200 && localuser {
|
||||||
wg.Add(1)
|
|
||||||
if isStream {
|
if isStream {
|
||||||
contentCh := fetchResponseContent(resbuf, reader)
|
contentCh := fetchResponseContent(writer, reader)
|
||||||
var buffer bytes.Buffer
|
var buffer bytes.Buffer
|
||||||
for content := range contentCh {
|
for content := range contentCh {
|
||||||
buffer.WriteString(content)
|
buffer.WriteString(content)
|
||||||
}
|
}
|
||||||
chatlog.CompletionCount = NumTokensFromStr(buffer.String(), chatreq.Model)
|
chatlog.CompletionCount = NumTokensFromStr(buffer.String(), chatreq.Model)
|
||||||
chatlog.TotalTokens = chatlog.PromptCount + chatlog.CompletionCount
|
chatlog.TotalTokens = chatlog.PromptCount + chatlog.CompletionCount
|
||||||
} else {
|
chatlog.Cost = fmt.Sprintf("%.6f", Cost(chatlog.Model, chatlog.PromptCount, chatlog.CompletionCount))
|
||||||
reader.WriteTo(resbuf)
|
if err := store.Record(&chatlog); err != nil {
|
||||||
json.NewDecoder(resbuf).Decode(&chatres)
|
log.Println(err)
|
||||||
chatlog.PromptCount = chatres.Usage.PromptTokens
|
}
|
||||||
chatlog.CompletionCount = chatres.Usage.CompletionTokens
|
if err := store.SumDaily(chatlog.UserID); err != nil {
|
||||||
chatlog.TotalTokens = chatres.Usage.TotalTokens
|
log.Println(err)
|
||||||
|
}
|
||||||
|
return
|
||||||
}
|
}
|
||||||
|
res, err := io.ReadAll(reader)
|
||||||
|
if err != nil {
|
||||||
|
c.JSON(http.StatusUnauthorized, gin.H{"error": gin.H{
|
||||||
|
"message": err.Error(),
|
||||||
|
}})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
reader = bufio.NewReader(bytes.NewBuffer(res))
|
||||||
|
json.NewDecoder(bytes.NewBuffer(res)).Decode(&chatres)
|
||||||
|
chatlog.PromptCount = chatres.Usage.PromptTokens
|
||||||
|
chatlog.CompletionCount = chatres.Usage.CompletionTokens
|
||||||
|
chatlog.TotalTokens = chatres.Usage.TotalTokens
|
||||||
chatlog.Cost = fmt.Sprintf("%.6f", Cost(chatlog.Model, chatlog.PromptCount, chatlog.CompletionCount))
|
chatlog.Cost = fmt.Sprintf("%.6f", Cost(chatlog.Model, chatlog.PromptCount, chatlog.CompletionCount))
|
||||||
if err := store.Record(&chatlog); err != nil {
|
if err := store.Record(&chatlog); err != nil {
|
||||||
log.Println(err)
|
log.Println(err)
|
||||||
@@ -425,26 +448,14 @@ func HandleProy(c *gin.Context) {
|
|||||||
if err := store.SumDaily(chatlog.UserID); err != nil {
|
if err := store.SumDaily(chatlog.UserID); err != nil {
|
||||||
log.Println(err)
|
log.Println(err)
|
||||||
}
|
}
|
||||||
// 返回 API 响应主体
|
|
||||||
if _, err := io.Copy(writer, resbuf); err != nil {
|
|
||||||
c.JSON(http.StatusUnauthorized, gin.H{"error": err.Error()})
|
|
||||||
return
|
|
||||||
}
|
|
||||||
go func() {
|
|
||||||
defer wg.Done()
|
|
||||||
_, err = io.Copy(c.Writer, resp.Body)
|
|
||||||
if err != nil {
|
|
||||||
c.JSON(http.StatusUnauthorized, gin.H{"error": err.Error()})
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
wg.Wait()
|
|
||||||
return
|
|
||||||
}
|
}
|
||||||
// 返回 API 响应主体
|
// 返回 API 响应主体
|
||||||
if _, err := io.Copy(writer, reader); err != nil {
|
if _, err := io.Copy(writer, reader); err != nil {
|
||||||
log.Println(err)
|
log.Println(err)
|
||||||
c.JSON(http.StatusUnauthorized, gin.H{"error": err.Error()})
|
c.JSON(http.StatusUnauthorized, gin.H{"error": gin.H{
|
||||||
|
"message": err.Error(),
|
||||||
|
}})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -525,14 +536,15 @@ func HandleUsage(c *gin.Context) {
|
|||||||
c.JSON(200, usage)
|
c.JSON(200, usage)
|
||||||
}
|
}
|
||||||
|
|
||||||
func fetchResponseContent(buf *bytes.Buffer, responseBody *bufio.Reader) <-chan string {
|
func fetchResponseContent(w *bufio.Writer, responseBody *bufio.Reader) <-chan string {
|
||||||
contentCh := make(chan string)
|
contentCh := make(chan string)
|
||||||
go func() {
|
go func() {
|
||||||
defer close(contentCh)
|
defer close(contentCh)
|
||||||
for {
|
for {
|
||||||
line, err := responseBody.ReadString('\n')
|
line, err := responseBody.ReadString('\n')
|
||||||
if err == nil {
|
if err == nil {
|
||||||
buf.WriteString(line)
|
w.WriteString(line)
|
||||||
|
w.Flush()
|
||||||
if line == "\n" {
|
if line == "\n" {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user