From 33034a9db0edfc45b4f60306ecc1cd52058709ef Mon Sep 17 00:00:00 2001 From: Sakurasan <1173092237@qq.com> Date: Thu, 27 Apr 2023 18:08:05 +0800 Subject: [PATCH 1/6] =?UTF-8?q?=E6=94=AF=E6=8C=81=E7=BB=9F=E8=AE=A1?= =?UTF-8?q?=E7=94=A8=E6=88=B7=E7=9A=84=E4=BD=BF=E7=94=A8=E9=87=8F?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- go.mod | 4 + go.sum | 10 +- opencat.go | 11 +- router/router.go | 273 +++++++++++++++++++++++++++++++++++++++++++---- store/db.go | 11 ++ store/usage.go | 149 ++++++++++++++++++++++++++ store/userdb.go | 9 ++ 7 files changed, 442 insertions(+), 25 deletions(-) create mode 100644 store/usage.go diff --git a/go.mod b/go.mod index a8e9a12..8e2dc49 100644 --- a/go.mod +++ b/go.mod @@ -4,16 +4,20 @@ go 1.19 require ( github.com/Sakurasan/to v0.0.0-20180919163141-e72657dd7c7d + github.com/duke-git/lancet/v2 v2.1.19 github.com/gin-gonic/gin v1.9.0 github.com/glebarez/sqlite v1.7.0 github.com/google/uuid v1.3.0 github.com/patrickmn/go-cache v2.1.0+incompatible + github.com/pkoukk/tiktoken-go v0.1.1-0.20230418101013-cae809389480 + github.com/sashabaranov/go-openai v1.9.0 gorm.io/gorm v1.24.6 ) require ( github.com/bytedance/sonic v1.8.0 // indirect github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311 // indirect + github.com/dlclark/regexp2 v1.8.1 // indirect github.com/dustin/go-humanize v1.0.1 // indirect github.com/gin-contrib/sse v0.1.0 // indirect github.com/glebarez/go-sqlite v1.20.3 // indirect diff --git a/go.sum b/go.sum index b4405a5..3aa27ab 100644 --- a/go.sum +++ b/go.sum @@ -9,6 +9,10 @@ github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311/go.mod h1:b583j github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/dlclark/regexp2 v1.8.1 h1:6Lcdwya6GjPUNsBct8Lg/yRPwMhABj269AAzdGSiR+0= +github.com/dlclark/regexp2 v1.8.1/go.mod h1:DHkYz0B9wPfa6wondMfaivmHpzrQ3v9q8cnmRbL6yW8= +github.com/duke-git/lancet/v2 v2.1.19 h1:dbRB1m6wOMV1I0ax/3S6ngop8SYM6I7sr+7D9IXjS2E= +github.com/duke-git/lancet/v2 v2.1.19/go.mod h1:hNcc06mV7qr+crH/0nP+rlC3TB0Q9g5OrVnO8/TGD4c= github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY= github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto= github.com/gin-contrib/sse v0.1.0 h1:Y/yl/+YNO8GZSjAhjMsSuLt29uWRFHdHYUb5lYOV9qE= @@ -57,12 +61,16 @@ github.com/patrickmn/go-cache v2.1.0+incompatible h1:HRMgzkcYKYpi3C8ajMPV8OFXaaR github.com/patrickmn/go-cache v2.1.0+incompatible/go.mod h1:3Qf8kWWT7OJRJbdiICTKqZju1ZixQ/KpMGzzAfe6+WQ= github.com/pelletier/go-toml/v2 v2.0.6 h1:nrzqCb7j9cDFj2coyLNLaZuJTLjWjlaz6nvTvIwycIU= github.com/pelletier/go-toml/v2 v2.0.6/go.mod h1:eumQOmlWiOPt5WriQQqoM5y18pDHwha2N+QD+EUNTek= +github.com/pkoukk/tiktoken-go v0.1.1-0.20230418101013-cae809389480 h1:IFhPCcB0/HtnEN+ZoUGDT55YgFCymbFJ15kXqs3nv5w= +github.com/pkoukk/tiktoken-go v0.1.1-0.20230418101013-cae809389480/go.mod h1:BijIqAP84FMYC4XbdJgjyMpiSjusU8x0Y0W9K2t0QtU= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/remyoudompheng/bigfft v0.0.0-20200410134404-eec4a21b6bb0/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo= github.com/remyoudompheng/bigfft v0.0.0-20230126093431-47fa9a501578 h1:VstopitMQi3hZP0fzvnsLmzXZdQGc4bEcgu24cp+d4M= github.com/remyoudompheng/bigfft v0.0.0-20230126093431-47fa9a501578/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo= github.com/rogpeppe/go-internal v1.8.0 h1:FCbCCtXNOY3UtUuHUYaghJg4y7Fd14rXifAYUAtL9R8= +github.com/sashabaranov/go-openai v1.9.0 h1:NoiO++IISxxJ1pRc0n7uZvMGMake0G+FJ1XPwXtprsA= +github.com/sashabaranov/go-openai v1.9.0/go.mod h1:lj5b/K+zjTSFxVLijLSTDZuP7adOgerWeFyZLUhAKRg= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= @@ -71,8 +79,8 @@ github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/ github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= -github.com/stretchr/testify v1.8.1 h1:w7B6lhMri9wdJUVmEZPGGhZzrYTPvgJArz7wNPgYKsk= github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= +github.com/stretchr/testify v1.8.2 h1:+h33VjcLVPDHtOdpUCuF+7gSuG3yGIftsP1YvFihtJ8= github.com/twitchyliquid64/golang-asm v0.15.1 h1:SU5vSMR7hnwNxj24w34ZyCi/FmDZTkS4MhqMhdFk5YI= github.com/twitchyliquid64/golang-asm v0.15.1/go.mod h1:a1lVb/DtPvCB8fslRZhAngC2+aY1QWCk3Cedj/Gdt08= github.com/ugorji/go/codec v1.2.9 h1:rmenucSohSTiyL09Y+l2OCk+FrMxGMzho2+tjr5ticU= diff --git a/opencat.go b/opencat.go index 8dbba0b..4cb2d53 100644 --- a/opencat.go +++ b/opencat.go @@ -38,6 +38,8 @@ func main() { // 获取所有用户信息 group.GET("/users", router.HandleUsers) + group.GET("/usages", router.HandleUsage) + // 添加Key group.POST("/keys", router.HandleAddKey) @@ -57,9 +59,12 @@ func main() { // 初始化用户 r.POST("/1/users/init", router.Handleinit) - r.POST("/v1/chat/completions", router.HandleProy) - r.GET("/v1/models", router.HandleProy) - r.GET("/v1/dashboard/billing/subscription", router.HandleProy) + r.Any("/v1/*proxypath", router.HandleProy) + + // r.POST("/v1/chat/completions", router.HandleProy) + // r.GET("/v1/models", router.HandleProy) + // r.GET("/v1/dashboard/billing/subscription", router.HandleProy) + r.GET("/", func(c *gin.Context) { c.Writer.WriteHeader(http.StatusOK) c.Writer.WriteString(`

opencatd-open available

Api-Keys:https://platform.openai.com/account/api-keys`) diff --git a/router/router.go b/router/router.go index a18ccd9..143053e 100644 --- a/router/router.go +++ b/router/router.go @@ -1,8 +1,10 @@ package router import ( + "bufio" "bytes" "crypto/tls" + "encoding/json" "errors" "fmt" "io" @@ -11,17 +13,25 @@ import ( "net/http" "net/http/httputil" "opencatd-open/store" + "strings" + "sync" "time" "github.com/Sakurasan/to" + "github.com/duke-git/lancet/v2/cryptor" "github.com/gin-gonic/gin" "github.com/google/uuid" + "github.com/pkoukk/tiktoken-go" + "github.com/sashabaranov/go-openai" "gorm.io/gorm" ) var ( - rootToken string - baseUrl = "https://api.openai.com" + rootToken string + baseUrl = "https://api.openai.com" + GPT3Dot5Turbo = "gpt-3.5-turbo" + GPT4 = "gpt-4" + client = getHttpClient() ) type User struct { @@ -41,6 +51,46 @@ type Key struct { CreatedAt string `json:"createdAt,omitempty"` } +type ChatCompletionMessage struct { + Role string `json:"role"` + Content string `json:"content"` + Name string `json:"name,omitempty"` +} + +type ChatCompletionRequest struct { + Model string `json:"model"` + Messages []ChatCompletionMessage `json:"messages"` + MaxTokens int `json:"max_tokens,omitempty"` + Temperature float32 `json:"temperature,omitempty"` + TopP float32 `json:"top_p,omitempty"` + N int `json:"n,omitempty"` + Stream bool `json:"stream,omitempty"` + Stop []string `json:"stop,omitempty"` + PresencePenalty float32 `json:"presence_penalty,omitempty"` + FrequencyPenalty float32 `json:"frequency_penalty,omitempty"` + LogitBias map[string]int `json:"logit_bias,omitempty"` + User string `json:"user,omitempty"` +} + +type ChatCompletionChoice struct { + Index int `json:"index"` + Message ChatCompletionMessage `json:"message"` + FinishReason string `json:"finish_reason"` +} + +type ChatCompletionResponse struct { + ID string `json:"id"` + Object string `json:"object"` + Created int64 `json:"created"` + Model string `json:"model"` + Choices []ChatCompletionChoice `json:"choices"` + Usage struct { + PromptTokens int `json:"prompt_tokens"` + CompletionTokens int `json:"completion_tokens"` + TotalTokens int `json:"total_tokens"` + } `json:"usage"` +} + func AuthMiddleware() gin.HandlerFunc { return func(c *gin.Context) { if rootToken == "" { @@ -249,13 +299,7 @@ func GenerateToken() string { return token.String() } -func HandleProy(c *gin.Context) { - var localuser bool - auth := c.Request.Header.Get("Authorization") - if len(auth) > 7 && auth[:7] == "Bearer " { - localuser = store.IsExistAuthCache(auth[7:]) - } - client := http.DefaultClient +func getHttpClient() *http.Client { tr := &http.Transport{ Proxy: http.ProxyFromEnvironment, DialContext: (&net.Dialer{ @@ -269,10 +313,43 @@ func HandleProy(c *gin.Context) { ExpectContinueTimeout: 1 * time.Second, TLSClientConfig: &tls.Config{InsecureSkipVerify: true}, } - client.Transport = tr + return &http.Client{Transport: tr} +} +func HandleProy(c *gin.Context) { + var ( + localuser bool + isStream bool + chatreq = openai.ChatCompletionRequest{} + chatres = openai.ChatCompletionResponse{} + chatlog store.Tokens + pre_prompt string + wg sync.WaitGroup + ) + auth := c.Request.Header.Get("Authorization") + if len(auth) > 7 && auth[:7] == "Bearer " { + localuser = store.IsExistAuthCache(auth[7:]) + } + + if c.Request.URL.Path == "/v1/chat/completions" && localuser { + + if err := c.BindJSON(&chatreq); err != nil { + c.AbortWithError(http.StatusBadRequest, err) + return + } + chatlog.Model = chatreq.Model + for _, m := range chatreq.Messages { + pre_prompt += m.Content + "\n" + } + chatlog.PromptHash = cryptor.Md5String(pre_prompt) + chatlog.PromptCount = NumTokensFromMessages(chatreq.Messages, chatreq.Model) + isStream = chatreq.Stream + chatlog.UserID, _ = store.GetUserID(auth[7:]) + } + var body bytes.Buffer + json.NewEncoder(&body).Encode(chatreq) // 创建 API 请求 - req, err := http.NewRequest(c.Request.Method, baseUrl+c.Request.URL.Path, c.Request.Body) + req, err := http.NewRequest(c.Request.Method, baseUrl+c.Request.RequestURI, &body) if err != nil { log.Println(err) c.JSON(http.StatusOK, gin.H{"error": err.Error()}) @@ -315,19 +392,55 @@ func HandleProy(c *gin.Context) { resp.Header.Del("content-security-policy-report-only") resp.Header.Del("clear-site-data") - bodyRes, err := io.ReadAll(resp.Body) - if err != nil { - log.Println(err) - c.JSON(http.StatusUnauthorized, gin.H{"error": err.Error()}) + c.Writer.WriteHeader(resp.StatusCode) + writer := bufio.NewWriter(c.Writer) + defer writer.Flush() + + reader := bufio.NewReader(resp.Body) + var resbuf = bytes.NewBuffer(nil) + + if resp.StatusCode == 200 && localuser { + wg.Add(1) + if isStream { + contentCh := fetchResponseContent(resbuf, reader) + var buffer bytes.Buffer + for content := range contentCh { + buffer.WriteString(content) + } + chatlog.CompletionCount = NumTokensFromStr(buffer.String(), chatreq.Model) + chatlog.TotalTokens = chatlog.PromptCount + chatlog.CompletionCount + } else { + reader.WriteTo(resbuf) + json.NewDecoder(resbuf).Decode(&chatres) + chatlog.PromptCount = chatres.Usage.PromptTokens + chatlog.CompletionCount = chatres.Usage.CompletionTokens + chatlog.TotalTokens = chatres.Usage.TotalTokens + } + chatlog.Cost = fmt.Sprintf("%.6f", Cost(chatlog.Model, chatlog.PromptCount, chatlog.CompletionCount)) + if err := store.Record(&chatlog); err != nil { + log.Println(err) + } + if err := store.SumDaily(chatlog.UserID); err != nil { + log.Println(err) + } + // 返回 API 响应主体 + if _, err := io.Copy(writer, resbuf); err != nil { + c.JSON(http.StatusUnauthorized, gin.H{"error": err.Error()}) + return + } + go func() { + defer wg.Done() + _, err = io.Copy(c.Writer, resp.Body) + if err != nil { + c.JSON(http.StatusUnauthorized, gin.H{"error": err.Error()}) + return + } + }() + wg.Wait() return } - if resp.StatusCode == 200 { - // todo - } - resbody := io.NopCloser(bytes.NewReader(bodyRes)) // 返回 API 响应主体 - c.Writer.WriteHeader(resp.StatusCode) - if _, err := io.Copy(c.Writer, resbody); err != nil { + if _, err := io.Copy(writer, reader); err != nil { log.Println(err) c.JSON(http.StatusUnauthorized, gin.H{"error": err.Error()}) return @@ -381,3 +494,121 @@ func HandleReverseProxy(c *gin.Context) { proxy.ServeHTTP(c.Writer, req) } +func Cost(model string, promptCount, completionCount int) float64 { + var cost, prompt, completion float64 + prompt = float64(promptCount) + completion = float64(completionCount) + + switch model { + case "gpt-3.5-turbo", "gpt-3.5-turbo-0301": + cost = 0.002 * float64((prompt+completion)/1000) + case "gpt-4", "gpt-4-0314": + cost = 0.03*float64(prompt/1000) + 0.06*float64(completion/1000) + case "gpt-4-32k", "gpt-4-32k-0314": + cost = 0.06*float64(prompt/1000) + 0.12*float64(completion/1000) + } + return cost +} + +func HandleUsage(c *gin.Context) { + fromStr := c.Query("from") + toStr := c.Query("to") + + usage, err := store.QueryUsage(fromStr, toStr) + if err != nil { + c.JSON(http.StatusForbidden, gin.H{"error": err.Error()}) + return + } + + c.JSON(200, usage) +} + +func fetchResponseContent(buf *bytes.Buffer, responseBody *bufio.Reader) <-chan string { + contentCh := make(chan string) + go func() { + defer close(contentCh) + for { + line, err := responseBody.ReadString('\n') + if err == nil { + buf.WriteString(line) + if line == "\n" { + continue + } + if strings.HasPrefix(line, "data:") { + line = strings.TrimSpace(strings.TrimPrefix(line, "data:")) + if strings.HasSuffix(line, "[DONE]") { + break + } + line = strings.TrimSpace(line) + } + + dec := json.NewDecoder(strings.NewReader(line)) + var data map[string]interface{} + if err := dec.Decode(&data); err == io.EOF { + log.Println("EOF:", err) + break + } else if err != nil { + fmt.Println("Error decoding response:", err) + return + } + if choices, ok := data["choices"].([]interface{}); ok { + for _, choice := range choices { + choiceMap := choice.(map[string]interface{}) + if content, ok := choiceMap["delta"].(map[string]interface{})["content"]; ok { + contentCh <- content.(string) + } + } + } + } else { + break + } + } + }() + return contentCh +} + +func NumTokensFromMessages(messages []openai.ChatCompletionMessage, model string) (num_tokens int) { + tkm, err := tiktoken.EncodingForModel(model) + if err != nil { + err = fmt.Errorf("EncodingForModel: %v", err) + fmt.Println(err) + return + } + + var tokens_per_message int + var tokens_per_name int + if model == "gpt-3.5-turbo-0301" || model == "gpt-3.5-turbo" { + tokens_per_message = 4 + tokens_per_name = -1 + } else if model == "gpt-4-0314" || model == "gpt-4" { + tokens_per_message = 3 + tokens_per_name = 1 + } else { + fmt.Println("Warning: model not found. Using cl100k_base encoding.") + tokens_per_message = 3 + tokens_per_name = 1 + } + + for _, message := range messages { + num_tokens += tokens_per_message + num_tokens += len(tkm.Encode(message.Content, nil, nil)) + // num_tokens += len(tkm.Encode(message.Role, nil, nil)) + if message.Name != "" { + num_tokens += tokens_per_name + } + } + num_tokens += 3 + return num_tokens +} + +func NumTokensFromStr(messages string, model string) (num_tokens int) { + tkm, err := tiktoken.EncodingForModel(model) + if err != nil { + err = fmt.Errorf("EncodingForModel: %v", err) + fmt.Println(err) + return + } + + num_tokens += len(tkm.Encode(messages, nil, nil)) + return num_tokens +} diff --git a/store/db.go b/store/db.go index 88b1294..65027bb 100644 --- a/store/db.go +++ b/store/db.go @@ -11,6 +11,8 @@ import ( var db *gorm.DB +var usage *gorm.DB + func init() { if _, err := os.Stat("db"); os.IsNotExist(err) { errDir := os.MkdirAll("db", 0755) @@ -31,4 +33,13 @@ func init() { } LoadKeysCache() LoadAuthCache() + + usage, err = gorm.Open(sqlite.Open("./db/usage.db"), &gorm.Config{}) + if err != nil { + panic(err) + } + err = usage.AutoMigrate(&DailyUsage{}, &Usage{}) + if err != nil { + panic(err) + } } diff --git a/store/usage.go b/store/usage.go new file mode 100644 index 0000000..d0c3bc3 --- /dev/null +++ b/store/usage.go @@ -0,0 +1,149 @@ +package store + +import ( + "errors" + "time" + + "github.com/Sakurasan/to" + "gorm.io/gorm" +) + +type DailyUsage struct { + ID int `gorm:"column:id"` + UserID int `gorm:"column:user_id";primaryKey` + Date time.Time `gorm:"column:date"` + SKU string `gorm:"column:sku"` + PromptUnits int `gorm:"column:prompt_units"` + CompletionUnits int `gorm:"column:completion_units"` + TotalUnit int `gorm:"column:total_unit"` + Cost string `gorm:"column:cost"` +} + +func (DailyUsage) TableName() string { + return "daily_usages" +} + +type Usage struct { + ID int `gorm:"column:id"` + PromptHash string `gorm:"column:prompt_hash"` + UserID int `gorm:"column:user_id"` + SKU string `gorm:"column:sku"` + PromptUnits int `gorm:"column:prompt_units"` + CompletionUnits int `gorm:"column:completion_units"` + TotalUnit int `gorm:"column:total_unit"` + Cost string `gorm:"column:cost"` + Date time.Time `gorm:"column:date"` +} + +func (Usage) TableName() string { + return "usages" +} + +type Summary struct { + UserId int `gorm:"column:user_id"` + SumPromptUnits int `gorm:"column:sum_prompt_units"` + SumCompletionUnits int `gorm:"column:sum_completion_units"` + SumTotalUnit int `gorm:"column:sum_total_unit"` + SumCost float64 `gorm:"column:sum_cost"` +} +type CalcUsage struct { + UserID int `json:"userId,omitempty"` + TotalUnit int `json:"totalUnit,omitempty"` + Cost string `json:"cost,omitempty"` +} + +func QueryUsage(from, to string) ([]CalcUsage, error) { + var results = []CalcUsage{} + err := usage.Model(&DailyUsage{}).Select(`user_id, + --SUM(prompt_units) AS prompt_units, + -- SUM(completion_units) AS completion_units, + SUM(total_unit) AS total_unit, + printf('%.6f', SUM(cost)) AS cost`). + Group("user_id"). + Where("date >= ? AND date < ?", from, to). + Find(&results).Error + if err != nil { + return nil, err + } + return results, nil +} + +type Tokens struct { + UserID int + PromptCount int + CompletionCount int + TotalTokens int + Cost string + Model string + PromptHash string +} + +func Record(chatlog *Tokens) (err error) { + u := &Usage{ + UserID: chatlog.UserID, + SKU: chatlog.Model, + PromptHash: chatlog.PromptHash, + PromptUnits: chatlog.PromptCount, + CompletionUnits: chatlog.CompletionCount, + TotalUnit: chatlog.TotalTokens, + Cost: to.String(chatlog.Cost), + Date: time.Now(), + } + err = usage.Create(u).Error + return + +} + +func SumDaily(userid int) error { + var count int64 + err := usage.Model(&DailyUsage{}).Where("user_id = ? and date = ?", userid, time.Date(time.Now().Year(), time.Now().Month(), time.Now().Day(), 0, 0, 0, 0, time.UTC)).Count(&count).Error + if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) { + return err + } + if count == 0 { + if err := insertSumDaily(userid); err != nil { + return err + } + } else { + if err := updateSumDaily(userid, time.Date(time.Now().Year(), time.Now().Month(), time.Now().Day(), 0, 0, 0, 0, time.UTC)); err != nil { + return err + } + } + return nil +} + +func insertSumDaily(uid int) error { + nowstr := time.Date(time.Now().Year(), time.Now().Month(), time.Now().Day(), 0, 0, 0, 0, time.UTC) + err := usage.Exec(`INSERT INTO daily_usages + (user_id, date, sku, prompt_units, completion_units, total_unit, cost) + SELECT + user_id, + ?, + sku, + SUM(prompt_units) AS sum_prompt_units, + SUM(completion_units) AS sum_completion_units, + SUM(total_unit) AS sum_total_unit, + SUM(cost) AS sum_cost + FROM usages + WHERE date >= ? + AND user_id = ?`, nowstr, nowstr, uid).Error + if err != nil { + return err + } + return nil +} + +func updateSumDaily(uid int, date time.Time) error { + // var u = Summary{} + err := usage.Model(&Usage{}).Exec(`UPDATE daily_usages + SET + prompt_units = (SELECT SUM(prompt_units) FROM usages WHERE user_id = daily_usages.user_id AND date >= daily_usages.date), + completion_units = (SELECT SUM(completion_units) FROM usages WHERE user_id = daily_usages.user_id AND date >= daily_usages.date), + total_unit = (SELECT SUM(total_unit) FROM usages WHERE user_id = daily_usages.user_id AND date >= daily_usages.date), + cost = (SELECT SUM(cost) FROM usages WHERE user_id = daily_usages.user_id AND date >= daily_usages.date) + WHERE user_id = ? AND date >= ?`, uid, date).Error + if err != nil { + return err + } + return nil +} diff --git a/store/userdb.go b/store/userdb.go index e4577c4..1b25c6d 100644 --- a/store/userdb.go +++ b/store/userdb.go @@ -73,6 +73,15 @@ func GetUserByName(name string) (*User, error) { return &user, nil } +func GetUserID(authkey string) (int, error) { + var user User + result := db.Where(&User{Token: authkey}).First(&user) + if result.Error != nil { + return 0, result.Error + } + return int(user.ID), nil +} + func GetAllUsers() ([]*User, error) { var users []*User result := db.Find(&users) From 8d2ad4e993244ca35b3800b328faa2e34cb49ba4 Mon Sep 17 00:00:00 2001 From: Sakurasan <1173092237@qq.com> Date: Fri, 28 Apr 2023 03:26:02 +0800 Subject: [PATCH 2/6] =?UTF-8?q?=E4=BC=98=E5=8C=96=E5=8D=A1=E9=A1=BF&bug?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- opencat.go | 3 +- router/router.go | 92 ++++++++++++++++++++++++++++-------------------- 2 files changed, 55 insertions(+), 40 deletions(-) diff --git a/opencat.go b/opencat.go index 4cb2d53..0ffebbe 100644 --- a/opencat.go +++ b/opencat.go @@ -59,7 +59,8 @@ func main() { // 初始化用户 r.POST("/1/users/init", router.Handleinit) - r.Any("/v1/*proxypath", router.HandleProy) + // r.Any("/v1/*proxypath", router.HandleProy) + r.Match([]string{http.MethodGet, http.MethodPost}, "/v1/*proxypath", router.HandleProy) // r.POST("/v1/chat/completions", router.HandleProy) // r.GET("/v1/models", router.HandleProy) diff --git a/router/router.go b/router/router.go index 143053e..2fc507b 100644 --- a/router/router.go +++ b/router/router.go @@ -14,7 +14,6 @@ import ( "net/http/httputil" "opencatd-open/store" "strings" - "sync" "time" "github.com/Sakurasan/to" @@ -324,7 +323,9 @@ func HandleProy(c *gin.Context) { chatres = openai.ChatCompletionResponse{} chatlog store.Tokens pre_prompt string - wg sync.WaitGroup + req *http.Request + err error + // wg sync.WaitGroup ) auth := c.Request.Header.Get("Authorization") if len(auth) > 7 && auth[:7] == "Bearer " { @@ -345,20 +346,31 @@ func HandleProy(c *gin.Context) { chatlog.PromptCount = NumTokensFromMessages(chatreq.Messages, chatreq.Model) isStream = chatreq.Stream chatlog.UserID, _ = store.GetUserID(auth[7:]) + + var body bytes.Buffer + json.NewEncoder(&body).Encode(chatreq) + // 创建 API 请求 + req, err = http.NewRequest(c.Request.Method, baseUrl+c.Request.RequestURI, &body) + if err != nil { + log.Println(err) + c.JSON(http.StatusOK, gin.H{"error": err.Error()}) + return + } + } else { + req, err = http.NewRequest(c.Request.Method, baseUrl+c.Request.RequestURI, c.Request.Body) + if err != nil { + log.Println(err) + c.JSON(http.StatusOK, gin.H{"error": err.Error()}) + return + } } - var body bytes.Buffer - json.NewEncoder(&body).Encode(chatreq) - // 创建 API 请求 - req, err := http.NewRequest(c.Request.Method, baseUrl+c.Request.RequestURI, &body) - if err != nil { - log.Println(err) - c.JSON(http.StatusOK, gin.H{"error": err.Error()}) - return - } + req.Header = c.Request.Header if localuser { if store.KeysCache.ItemCount() == 0 { - c.JSON(http.StatusOK, gin.H{"error": "No Api-Key Available"}) + c.JSON(http.StatusBadGateway, gin.H{"error": gin.H{ + "message": "No Api-Key Available", + }}) return } req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", store.FromKeyCacheRandomItem())) @@ -397,25 +409,38 @@ func HandleProy(c *gin.Context) { defer writer.Flush() reader := bufio.NewReader(resp.Body) - var resbuf = bytes.NewBuffer(nil) if resp.StatusCode == 200 && localuser { - wg.Add(1) + if isStream { - contentCh := fetchResponseContent(resbuf, reader) + contentCh := fetchResponseContent(writer, reader) var buffer bytes.Buffer for content := range contentCh { buffer.WriteString(content) } chatlog.CompletionCount = NumTokensFromStr(buffer.String(), chatreq.Model) chatlog.TotalTokens = chatlog.PromptCount + chatlog.CompletionCount - } else { - reader.WriteTo(resbuf) - json.NewDecoder(resbuf).Decode(&chatres) - chatlog.PromptCount = chatres.Usage.PromptTokens - chatlog.CompletionCount = chatres.Usage.CompletionTokens - chatlog.TotalTokens = chatres.Usage.TotalTokens + chatlog.Cost = fmt.Sprintf("%.6f", Cost(chatlog.Model, chatlog.PromptCount, chatlog.CompletionCount)) + if err := store.Record(&chatlog); err != nil { + log.Println(err) + } + if err := store.SumDaily(chatlog.UserID); err != nil { + log.Println(err) + } + return } + res, err := io.ReadAll(reader) + if err != nil { + c.JSON(http.StatusUnauthorized, gin.H{"error": gin.H{ + "message": err.Error(), + }}) + return + } + reader = bufio.NewReader(bytes.NewBuffer(res)) + json.NewDecoder(bytes.NewBuffer(res)).Decode(&chatres) + chatlog.PromptCount = chatres.Usage.PromptTokens + chatlog.CompletionCount = chatres.Usage.CompletionTokens + chatlog.TotalTokens = chatres.Usage.TotalTokens chatlog.Cost = fmt.Sprintf("%.6f", Cost(chatlog.Model, chatlog.PromptCount, chatlog.CompletionCount)) if err := store.Record(&chatlog); err != nil { log.Println(err) @@ -423,26 +448,14 @@ func HandleProy(c *gin.Context) { if err := store.SumDaily(chatlog.UserID); err != nil { log.Println(err) } - // 返回 API 响应主体 - if _, err := io.Copy(writer, resbuf); err != nil { - c.JSON(http.StatusUnauthorized, gin.H{"error": err.Error()}) - return - } - go func() { - defer wg.Done() - _, err = io.Copy(c.Writer, resp.Body) - if err != nil { - c.JSON(http.StatusUnauthorized, gin.H{"error": err.Error()}) - return - } - }() - wg.Wait() - return + } // 返回 API 响应主体 if _, err := io.Copy(writer, reader); err != nil { log.Println(err) - c.JSON(http.StatusUnauthorized, gin.H{"error": err.Error()}) + c.JSON(http.StatusUnauthorized, gin.H{"error": gin.H{ + "message": err.Error(), + }}) return } } @@ -523,14 +536,15 @@ func HandleUsage(c *gin.Context) { c.JSON(200, usage) } -func fetchResponseContent(buf *bytes.Buffer, responseBody *bufio.Reader) <-chan string { +func fetchResponseContent(w *bufio.Writer, responseBody *bufio.Reader) <-chan string { contentCh := make(chan string) go func() { defer close(contentCh) for { line, err := responseBody.ReadString('\n') if err == nil { - buf.WriteString(line) + w.WriteString(line) + w.Flush() if line == "\n" { continue } From ea9d9d532a6f9b7047aaa2875b455303e9e1d582 Mon Sep 17 00:00:00 2001 From: Sakurasan <1173092237@qq.com> Date: Fri, 28 Apr 2023 14:49:04 +0800 Subject: [PATCH 3/6] Squashed commit of the following: MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit commit 379586cd1bc6b676473c6d0058e7dd0a9936e623 Author: Sakurasan <1173092237@qq.com> Date: Fri Apr 28 14:46:46 2023 +0800 up commit 5200171dc08a57c29d586d30426532458dbe2508 Author: Sakurasan <1173092237@qq.com> Date: Fri Apr 28 03:05:23 2023 +0800 优化代码 commit d2f07a824e9bb990b995900e8dce84c155c11032 Author: Sakurasan <1173092237@qq.com> Date: Thu Apr 27 22:23:22 2023 +0800 更新错误提示 commit 8a2ebb677879d441835686dc254fdf4d4fe9c557 Author: Sakurasan <1173092237@qq.com> Date: Thu Apr 27 17:27:43 2023 +0800 up commit a202dfadcacb008bba0cd8064d449230e393dbd7 Author: Sakurasan <1173092237@qq.com> Date: Thu Apr 27 02:31:35 2023 +0800 优化代码 commit a100d75c008ac208ea97982d0687503d64b72489 Author: Sakurasan <1173092237@qq.com> Date: Wed Apr 26 03:56:29 2023 +0800 fix bug commit a34e579aa8213aa7d2cf38c26c89d5ea6734a9b6 Author: Sakurasan <1173092237@qq.com> Date: Wed Apr 26 03:18:49 2023 +0800 fix db type commit 3dae04afcdc5f43a3b5e44fe8c4004594e91cec9 Author: Sakurasan <1173092237@qq.com> Date: Wed Apr 26 03:03:43 2023 +0800 fix null data commit 56b56c703d6ba2dc2b89f5b9ed65698e87e00111 Author: Sakurasan <1173092237@qq.com> Date: Tue Apr 25 22:27:19 2023 +0800 up commit 97a13a5f7978cdd870b602d2f7480998f84d1a74 Author: Sakurasan <1173092237@qq.com> Date: Tue Apr 25 21:30:18 2023 +0800 usage commit 5ce546672319c0aa9f2475ab2523ac25ad89126a Author: Sakurasan <1173092237@qq.com> Date: Mon Apr 24 22:44:03 2023 +0800 usage commit 365bc374873ccf5db2ac73d83bfbdb64300f001e Author: Sakurasan <1173092237@qq.com> Date: Sat Apr 22 22:40:47 2023 +0800 up commit 1cd06ea1dfebbc215da3c05d3a9da8bd47b81bc2 Author: Sakurasan <1173092237@qq.com> Date: Wed Apr 19 22:50:39 2023 +0800 up commit 499edbb0fd74a71c40df2e7da4498a195cfa7fac Author: Sakurasan <1173092237@qq.com> Date: Wed Apr 19 02:06:09 2023 +0800 add Completion commit 11dbf4376efb270914ab31233f01d0eb76d3b731 Author: Sakurasan <1173092237@qq.com> Date: Tue Apr 18 22:46:11 2023 +0800 up commit eb22de912a54873b3f80ed2225e596d87851ce48 Author: Sakurasan <1173092237@qq.com> Date: Mon Apr 17 22:48:37 2023 +0800 add usage --- opencat.go | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/opencat.go b/opencat.go index 0ffebbe..4cb2d53 100644 --- a/opencat.go +++ b/opencat.go @@ -59,8 +59,7 @@ func main() { // 初始化用户 r.POST("/1/users/init", router.Handleinit) - // r.Any("/v1/*proxypath", router.HandleProy) - r.Match([]string{http.MethodGet, http.MethodPost}, "/v1/*proxypath", router.HandleProy) + r.Any("/v1/*proxypath", router.HandleProy) // r.POST("/v1/chat/completions", router.HandleProy) // r.GET("/v1/models", router.HandleProy) From 38cf469852acfb3138c03dac0f55346497dd96e7 Mon Sep 17 00:00:00 2001 From: Sakurasan <1173092237@qq.com> Date: Tue, 2 May 2023 02:05:06 +0800 Subject: [PATCH 4/6] =?UTF-8?q?=E5=A2=9E=E5=8A=A0root=5Ftoken=E6=9F=A5?= =?UTF-8?q?=E8=AF=A2?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- README.md | 11 +++++++---- docker/Dockerfile | 5 +++-- makefile | 5 +++-- opencat.go | 39 ++++++++++++++++++++++++++++++++------- 4 files changed, 45 insertions(+), 15 deletions(-) diff --git a/README.md b/README.md index 3d6143a..119e075 100644 --- a/README.md +++ b/README.md @@ -30,10 +30,13 @@ or ``` wget https://github.com/mirrors2/opencatd-open/raw/main/docker/docker-compose.yml ``` -## reset root token -``` -docker exec -it opencatd-open ./opencatd reset_root -``` +## 支持的命令 +>获取 root 的 token + - `docker exec opencatd-open opencatd root_token` + +>重置 root 的 token + - `docker exec opencatd-open opencatd reset_root` + ## Q&A 关于证书? diff --git a/docker/Dockerfile b/docker/Dockerfile index 67013c4..754d71c 100644 --- a/docker/Dockerfile +++ b/docker/Dockerfile @@ -1,6 +1,6 @@ FROM golang:1.19.7-alpine as builder LABEL anther="github.com/Sakurasan" -RUN sed -i 's/dl-cdn.alpinelinux.org/mirrors.aliyun.com/g' /etc/apk/repositories && apk update && apk --no-cache add openssl make cmake upx +RUN sed -i 's/dl-cdn.alpinelinux.org/mirrors.aliyun.com/g' /etc/apk/repositories && apk --no-cache add make cmake upx WORKDIR /build COPY . /build ENV GO111MODULE=on @@ -12,9 +12,10 @@ FROM alpine:latest AS runner # 设置alpine 时间为上海时间 RUN sed -i 's/dl-cdn.alpinelinux.org/mirrors.aliyun.com/g' /etc/apk/repositories && apk update && apk --no-cache add tzdata && cp /usr/share/zoneinfo/Asia/Shanghai /etc/localtime \ && echo "Asia/Shanghai" > /etc/timezone \ - && apk del tzdata && export GIN_MODE=release # RUN apk update && apk --no-cache add openssl libgcc libstdc++ binutils WORKDIR /app COPY --from=builder /build/bin/opencatd /app/opencatd +ENV GIN_MODE=release +ENV PATH=$PATH:/app EXPOSE 80 ENTRYPOINT ["/app/opencatd"] \ No newline at end of file diff --git a/makefile b/makefile index 111700e..565939d 100644 --- a/makefile +++ b/makefile @@ -24,11 +24,12 @@ build: upx -9 bin/opencatd .PHONY:docker +# build docker images docker: docker run --privileged --rm tonistiigi/binfmt --install all - docker buildx create --use --name xbuilder + docker buildx create --use --name xbuilder --driver docker-container docker buildx inspect xbuilder --bootstrap - docker buildx build --platform linux/amd64,linux/arm64 -t mirrors2/opencatd:latest . --push + docker buildx build --platform linux/amd64,linux/arm64 -t mirrors2/opencatd:latest -f docker/Dockerfile . --push .PHONY: clean # clean diff --git a/opencat.go b/opencat.go index 4cb2d53..749b600 100644 --- a/opencat.go +++ b/opencat.go @@ -9,19 +9,44 @@ import ( "github.com/gin-gonic/gin" "github.com/google/uuid" + "gorm.io/gorm" ) func main() { args := os.Args[1:] - if len(args) > 0 && args[0] == "reset_root" { - log.Println("reset root token...") - ntoken := uuid.NewString() - if err := store.UpdateUser(uint(1), ntoken); err != nil { - log.Fatalln(err) + if len(args) > 0 { + switch args[0] { + case "reset_root": + log.Println("reset root token...") + if _, err := store.GetUserByID(uint(1)); err != nil { + if err == gorm.ErrRecordNotFound { + log.Println("请在opencat(或其他APP)客户端完成team初始化") + return + } else { + log.Fatalln(err) + return + } + } + ntoken := uuid.NewString() + if err := store.UpdateUser(uint(1), ntoken); err != nil { + log.Fatalln(err) + return + } + log.Println("new root token:", ntoken) + return + case "root_token": + log.Println("reset root token...") + if user, err := store.GetUserByID(uint(1)); err != nil { + log.Fatalln(err) + return + } else { + log.Println("root token:", user.Token) + return + } + default: return } - log.Println("new root token:", ntoken) - return + } port := os.Getenv("PORT") r := gin.Default() From c379db6a73b900ffc9714e57e85b0be231757488 Mon Sep 17 00:00:00 2001 From: Sakurasan <1173092237@qq.com> Date: Thu, 4 May 2023 19:49:33 +0800 Subject: [PATCH 5/6] add --- assets/logo.svg | 1 + 1 file changed, 1 insertion(+) create mode 100644 assets/logo.svg diff --git a/assets/logo.svg b/assets/logo.svg new file mode 100644 index 0000000..a45bb6b --- /dev/null +++ b/assets/logo.svg @@ -0,0 +1 @@ + \ No newline at end of file From 9339cab328e4bf339a99274cb4377c2b867fb88b Mon Sep 17 00:00:00 2001 From: Sakurasan <1173092237@qq.com> Date: Sun, 7 May 2023 22:40:47 +0800 Subject: [PATCH 6/6] up --- store/userdb.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/store/userdb.go b/store/userdb.go index 1b25c6d..b621e7e 100644 --- a/store/userdb.go +++ b/store/userdb.go @@ -6,7 +6,7 @@ import ( type User struct { IsDelete bool `gorm:"default:false" json:"IsDelete"` - ID uint `gorm:"primarykey" json:"id,omitempty"` + ID uint `gorm:"primaryKey;autoIncrement" json:"id,omitempty"` Name string `gorm:"unique;not null" json:"name,omitempty"` Token string `gorm:"unique;not null" json:"token,omitempty"` CreatedAt time.Time `json:"createdAt,omitempty"`