优化代码

This commit is contained in:
Sakurasan
2023-04-27 02:31:35 +08:00
parent a100d75c00
commit a202dfadca
2 changed files with 44 additions and 35 deletions

View File

@@ -14,6 +14,7 @@ import (
"net/http/httputil"
"opencatd-open/store"
"strings"
"sync"
"time"
"github.com/Sakurasan/to"
@@ -297,29 +298,7 @@ func GenerateToken() string {
return token.String()
}
// type Tokens struct {
// UserID int
// PromptCount int
// CompletionCount int
// TotalTokens int
// Model string
// PromptHash string
// }
func HandleProy(c *gin.Context) {
var (
localuser bool
isStream bool
chatreq = openai.ChatCompletionRequest{}
chatres = openai.ChatCompletionResponse{}
chatlog store.Tokens
pre_prompt string
)
auth := c.Request.Header.Get("Authorization")
if len(auth) > 7 && auth[:7] == "Bearer " {
localuser = store.IsExistAuthCache(auth[7:])
}
client := http.DefaultClient
func getHttpClient() *http.Client {
tr := &http.Transport{
Proxy: http.ProxyFromEnvironment,
DialContext: (&net.Dialer{
@@ -333,7 +312,23 @@ func HandleProy(c *gin.Context) {
ExpectContinueTimeout: 1 * time.Second,
TLSClientConfig: &tls.Config{InsecureSkipVerify: true},
}
client.Transport = tr
return &http.Client{Transport: tr}
}
func HandleProy(c *gin.Context) {
var (
localuser bool
isStream bool
chatreq = openai.ChatCompletionRequest{}
chatres = openai.ChatCompletionResponse{}
chatlog store.Tokens
pre_prompt string
wg sync.WaitGroup
)
auth := c.Request.Header.Get("Authorization")
if len(auth) > 7 && auth[:7] == "Bearer " {
localuser = store.IsExistAuthCache(auth[7:])
}
if c.Request.URL.Path == "/v1/chat/completions" && localuser {
@@ -367,7 +362,7 @@ func HandleProy(c *gin.Context) {
}
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", store.FromKeyCacheRandomItem()))
}
client := getHttpClient()
resp, err := client.Do(req)
if err != nil {
log.Println(err)
@@ -396,13 +391,22 @@ func HandleProy(c *gin.Context) {
resp.Header.Del("content-security-policy-report-only")
resp.Header.Del("clear-site-data")
c.Writer.WriteHeader(resp.StatusCode)
writer := bufio.NewWriter(c.Writer)
defer writer.Flush()
reader := bufio.NewReader(resp.Body)
var resbuf = bytes.NewBuffer(nil)
if resp.StatusCode == 200 && localuser {
wg.Add(1)
if isStream {
chatdata := <-fetchResponseContent(resbuf, reader)
chatlog.CompletionCount = NumTokensFromStr(chatdata, chatreq.Model)
contentCh := fetchResponseContent(resbuf, reader)
var buffer bytes.Buffer
for content := range contentCh {
buffer.WriteString(content)
}
chatlog.CompletionCount = NumTokensFromStr(buffer.String(), chatreq.Model)
chatlog.TotalTokens = chatlog.PromptCount + chatlog.CompletionCount
} else {
reader.WriteTo(resbuf)
@@ -418,19 +422,24 @@ func HandleProy(c *gin.Context) {
if err := store.SumDaily(chatlog.UserID); err != nil {
log.Println(err)
}
}
c.Writer.WriteHeader(resp.StatusCode)
if localuser {
// 返回 API 响应主体
if _, err := io.Copy(c.Writer, resbuf); err != nil {
log.Println(err)
if _, err := io.Copy(writer, resbuf); err != nil {
c.JSON(http.StatusUnauthorized, gin.H{"error": err.Error()})
return
}
go func() {
defer wg.Done()
_, err = io.Copy(c.Writer, resp.Body)
if err != nil {
c.JSON(http.StatusUnauthorized, gin.H{"error": err.Error()})
return
}
}()
wg.Wait()
return
}
// 返回 API 响应主体
if _, err := io.Copy(c.Writer, io.NopCloser(reader)); err != nil {
if _, err := io.Copy(writer, reader); err != nil {
log.Println(err)
c.JSON(http.StatusUnauthorized, gin.H{"error": err.Error()})
return

View File

@@ -58,7 +58,7 @@ func QueryUsage(from, to string) ([]CalcUsage, error) {
--SUM(prompt_units) AS prompt_units,
-- SUM(completion_units) AS completion_units,
SUM(total_unit) AS total_unit,
SUM(cost) AS cost`).
printf('%.6f', SUM(cost)) AS cost`).
Group("user_id").
Where("date >= ? AND date < ?", from, to).
Find(&results).Error