diff --git a/go.mod b/go.mod index a8e9a12..8e2dc49 100644 --- a/go.mod +++ b/go.mod @@ -4,16 +4,20 @@ go 1.19 require ( github.com/Sakurasan/to v0.0.0-20180919163141-e72657dd7c7d + github.com/duke-git/lancet/v2 v2.1.19 github.com/gin-gonic/gin v1.9.0 github.com/glebarez/sqlite v1.7.0 github.com/google/uuid v1.3.0 github.com/patrickmn/go-cache v2.1.0+incompatible + github.com/pkoukk/tiktoken-go v0.1.1-0.20230418101013-cae809389480 + github.com/sashabaranov/go-openai v1.9.0 gorm.io/gorm v1.24.6 ) require ( github.com/bytedance/sonic v1.8.0 // indirect github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311 // indirect + github.com/dlclark/regexp2 v1.8.1 // indirect github.com/dustin/go-humanize v1.0.1 // indirect github.com/gin-contrib/sse v0.1.0 // indirect github.com/glebarez/go-sqlite v1.20.3 // indirect diff --git a/go.sum b/go.sum index b4405a5..3aa27ab 100644 --- a/go.sum +++ b/go.sum @@ -9,6 +9,10 @@ github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311/go.mod h1:b583j github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/dlclark/regexp2 v1.8.1 h1:6Lcdwya6GjPUNsBct8Lg/yRPwMhABj269AAzdGSiR+0= +github.com/dlclark/regexp2 v1.8.1/go.mod h1:DHkYz0B9wPfa6wondMfaivmHpzrQ3v9q8cnmRbL6yW8= +github.com/duke-git/lancet/v2 v2.1.19 h1:dbRB1m6wOMV1I0ax/3S6ngop8SYM6I7sr+7D9IXjS2E= +github.com/duke-git/lancet/v2 v2.1.19/go.mod h1:hNcc06mV7qr+crH/0nP+rlC3TB0Q9g5OrVnO8/TGD4c= github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY= github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto= github.com/gin-contrib/sse v0.1.0 h1:Y/yl/+YNO8GZSjAhjMsSuLt29uWRFHdHYUb5lYOV9qE= @@ -57,12 +61,16 @@ github.com/patrickmn/go-cache v2.1.0+incompatible h1:HRMgzkcYKYpi3C8ajMPV8OFXaaR github.com/patrickmn/go-cache v2.1.0+incompatible/go.mod h1:3Qf8kWWT7OJRJbdiICTKqZju1ZixQ/KpMGzzAfe6+WQ= github.com/pelletier/go-toml/v2 v2.0.6 h1:nrzqCb7j9cDFj2coyLNLaZuJTLjWjlaz6nvTvIwycIU= github.com/pelletier/go-toml/v2 v2.0.6/go.mod h1:eumQOmlWiOPt5WriQQqoM5y18pDHwha2N+QD+EUNTek= +github.com/pkoukk/tiktoken-go v0.1.1-0.20230418101013-cae809389480 h1:IFhPCcB0/HtnEN+ZoUGDT55YgFCymbFJ15kXqs3nv5w= +github.com/pkoukk/tiktoken-go v0.1.1-0.20230418101013-cae809389480/go.mod h1:BijIqAP84FMYC4XbdJgjyMpiSjusU8x0Y0W9K2t0QtU= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/remyoudompheng/bigfft v0.0.0-20200410134404-eec4a21b6bb0/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo= github.com/remyoudompheng/bigfft v0.0.0-20230126093431-47fa9a501578 h1:VstopitMQi3hZP0fzvnsLmzXZdQGc4bEcgu24cp+d4M= github.com/remyoudompheng/bigfft v0.0.0-20230126093431-47fa9a501578/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo= github.com/rogpeppe/go-internal v1.8.0 h1:FCbCCtXNOY3UtUuHUYaghJg4y7Fd14rXifAYUAtL9R8= +github.com/sashabaranov/go-openai v1.9.0 h1:NoiO++IISxxJ1pRc0n7uZvMGMake0G+FJ1XPwXtprsA= +github.com/sashabaranov/go-openai v1.9.0/go.mod h1:lj5b/K+zjTSFxVLijLSTDZuP7adOgerWeFyZLUhAKRg= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= @@ -71,8 +79,8 @@ github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/ github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= -github.com/stretchr/testify v1.8.1 h1:w7B6lhMri9wdJUVmEZPGGhZzrYTPvgJArz7wNPgYKsk= github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= +github.com/stretchr/testify v1.8.2 h1:+h33VjcLVPDHtOdpUCuF+7gSuG3yGIftsP1YvFihtJ8= github.com/twitchyliquid64/golang-asm v0.15.1 h1:SU5vSMR7hnwNxj24w34ZyCi/FmDZTkS4MhqMhdFk5YI= github.com/twitchyliquid64/golang-asm v0.15.1/go.mod h1:a1lVb/DtPvCB8fslRZhAngC2+aY1QWCk3Cedj/Gdt08= github.com/ugorji/go/codec v1.2.9 h1:rmenucSohSTiyL09Y+l2OCk+FrMxGMzho2+tjr5ticU= diff --git a/router/router.go b/router/router.go index 9484583..28f1dbf 100644 --- a/router/router.go +++ b/router/router.go @@ -17,10 +17,12 @@ import ( "time" "github.com/Sakurasan/to" + "github.com/duke-git/lancet/v2/cryptor" "github.com/gin-gonic/gin" "github.com/google/uuid" + "github.com/pkoukk/tiktoken-go" + "github.com/sashabaranov/go-openai" "gorm.io/gorm" - // "github.com/pkoukk/tiktoken-go" ) var ( @@ -295,15 +297,24 @@ func GenerateToken() string { return token.String() } -type _ struct { - model string - promptCount int - completionCount int -} +// type Tokens struct { +// UserID int +// PromptCount int +// CompletionCount int +// TotalTokens int +// Model string +// PromptHash string +// } func HandleProy(c *gin.Context) { - var localuser bool - var isStream bool + var ( + localuser bool + isStream bool + chatreq = openai.ChatCompletionRequest{} + chatres = openai.ChatCompletionResponse{} + chatlog store.Tokens + pre_prompt string + ) auth := c.Request.Header.Get("Authorization") if len(auth) > 7 && auth[:7] == "Bearer " { localuser = store.IsExistAuthCache(auth[7:]) @@ -324,14 +335,20 @@ func HandleProy(c *gin.Context) { } client.Transport = tr - if c.Request.URL.Path == "/v1/chat/completions" { - var chatreq = ChatCompletionRequest{} + if c.Request.URL.Path == "/v1/chat/completions" && localuser { + if err := c.BindJSON(&chatreq); err != nil { return - // c.AbortWithError(http.StatusBadRequest,) + // c.AbortWithError(http.StatusBadRequest, err) } + chatlog.Model = chatreq.Model + for _, m := range chatreq.Messages { + pre_prompt += m.Content + "\n" + } + chatlog.PromptHash = cryptor.Md5String(pre_prompt) + chatlog.PromptCount = NumTokensFromMessages(chatreq.Messages, chatreq.Model) isStream = chatreq.Stream - + chatlog.UserID, _ = store.GetUserID(auth[7:]) } // 创建 API 请求 @@ -378,17 +395,24 @@ func HandleProy(c *gin.Context) { resp.Header.Del("content-security-policy-report-only") resp.Header.Del("clear-site-data") - // bodyRes, err := io.ReadAll(resp.Body) - // if err != nil { - // log.Println(err) - // c.JSON(http.StatusUnauthorized, gin.H{"error": err.Error()}) - // return - // } reader := bufio.NewReader(resp.Body) - // var resbuf = bytes.NewBuffer(nil) + var resbuf = bytes.NewBuffer(nil) + + if resp.StatusCode == 200 && localuser { + if isStream { + chatlog.CompletionCount = NumTokensFromStr(<-fetchResponseContent(resbuf, reader), chatreq.Model) + chatlog.TotalTokens = chatlog.PromptCount + chatlog.CompletionCount + } else { + reader.WriteTo(resbuf) + json.NewDecoder(resbuf).Decode(&chatres) + chatlog.PromptCount = chatres.Usage.PromptTokens + chatlog.CompletionCount = chatres.Usage.CompletionTokens + chatlog.TotalTokens = chatres.Usage.TotalTokens + } + chatlog.Cost = Cost(chatlog.Model, chatlog.PromptCount, chatlog.CompletionCount) + store.Record(&chatlog) + // todo insert usage && calc daily_usage - if resp.StatusCode == 200 && isStream { - //todo } resbody := io.NopCloser(reader) // 返回 API 响应主体 @@ -450,12 +474,12 @@ func HandleReverseProxy(c *gin.Context) { func Cost(model string, promptCount, completionCount int) float64 { var cost float64 switch model { - case "gpt-3.5": + case "gpt-3.5-turbo", "gpt-3.5-turbo-0301": cost = 0.002 * float64((promptCount+completionCount)/1000) - case "gpt-4-32k": - cost = 0.06*float64(promptCount/1000) + 0.12*float64(completionCount/1000) - case "gpt-4": + case "gpt-4", "gpt-4-0314": cost = 0.03*float64(promptCount/1000) + 0.06*float64(completionCount/1000) + case "gpt-4-32k", "gpt-4-32k-0314": + cost = 0.06*float64(promptCount/1000) + 0.12*float64(completionCount/1000) } return cost } @@ -528,3 +552,49 @@ func fetchResponseContent(buf *bytes.Buffer, responseBody *bufio.Reader) <-chan }() return contentCh } + +func NumTokensFromMessages(messages []openai.ChatCompletionMessage, model string) (num_tokens int) { + tkm, err := tiktoken.EncodingForModel(model) + if err != nil { + err = fmt.Errorf("EncodingForModel: %v", err) + fmt.Println(err) + return + } + + var tokens_per_message int + var tokens_per_name int + if model == "gpt-3.5-turbo-0301" || model == "gpt-3.5-turbo" { + tokens_per_message = 4 + tokens_per_name = -1 + } else if model == "gpt-4-0314" || model == "gpt-4" { + tokens_per_message = 3 + tokens_per_name = 1 + } else { + fmt.Println("Warning: model not found. Using cl100k_base encoding.") + tokens_per_message = 3 + tokens_per_name = 1 + } + + for _, message := range messages { + num_tokens += tokens_per_message + num_tokens += len(tkm.Encode(message.Content, nil, nil)) + // num_tokens += len(tkm.Encode(message.Role, nil, nil)) + if message.Name != "" { + num_tokens += tokens_per_name + } + } + num_tokens += 3 + return num_tokens +} + +func NumTokensFromStr(messages string, model string) (num_tokens int) { + tkm, err := tiktoken.EncodingForModel(model) + if err != nil { + err = fmt.Errorf("EncodingForModel: %v", err) + fmt.Println(err) + return + } + + num_tokens += len(tkm.Encode(messages, nil, nil)) + return num_tokens +} diff --git a/store/usage.go b/store/usage.go index 89810c9..6241011 100644 --- a/store/usage.go +++ b/store/usage.go @@ -1,7 +1,10 @@ package store import ( + "log" "time" + + "github.com/Sakurasan/to" ) type DailyUsage struct { @@ -23,12 +26,12 @@ type Usage struct { ID int `gorm:"column:id"` PromptHash string `gorm:"column:prompt_hash"` UserID int `gorm:"column:user_id"` - Date time.Time `gorm:"column:date"` SKU string `gorm:"column:sku"` PromptUnits int `gorm:"column:prompt_units"` CompletionUnits int `gorm:"column:completion_units"` TotalUnit int `gorm:"column:total_unit"` Cost string `gorm:"column:cost"` + Date time.Time `gorm:"column:date"` } func (Usage) TableName() string { @@ -64,18 +67,49 @@ func QueryUsage(from, to string) ([]CalcUsage, error) { return results, nil } +type Tokens struct { + UserID int + PromptCount int + CompletionCount int + TotalTokens int + Cost float64 + Model string + PromptHash string +} + +func Record(chatlog *Tokens) (err error) { + u := &Usage{ + UserID: chatlog.UserID, + SKU: chatlog.Model, + PromptHash: chatlog.PromptHash, + PromptUnits: chatlog.PromptCount, + CompletionUnits: chatlog.CompletionCount, + TotalUnit: chatlog.TotalTokens, + Cost: to.String(chatlog.Cost), + Date: time.Now(), + } + err = usage.Create(u).Error + return + +} + func SumDaily(userid string) ([]Summary, error) { + var count int64 + err := usage.Model(&DailyUsage{}).Where("user_id = ? and date = ?", userid, time.Date(time.Now().Year(), time.Now().Month(), time.Now().Day(), 0, 0, 0, 0, time.UTC)).Count(&count).Error + if err != nil { + log.Println(err) + } + if count == 0 { + + } else { + + } + return nil, nil } -func SumDailyV2(uid string) error { - - // err := usage.Model(&DailyUsage{}). - // Select("user_id, '2023-04-18' as date, sku, SUM(prompt_units) as sum_prompt_units, SUM(completion_units) as sum_completion_units, SUM(total_unit) as sum_total_unit, SUM(cost) as sum_cost"). - // Where("date >= ?", "2023-04-18"). - // Where("user_id = ?", 2). - // Create(&DailyUsage{}).Error - nowstr := time.Now().Format("2006-01-02") +func insertSumDaily(uid string) error { + nowstr := time.Date(time.Now().Year(), time.Now().Month(), time.Now().Day(), 0, 0, 0, 0, time.UTC) err := usage.Exec(`INSERT INTO daily_usages (user_id, date, sku, prompt_units, completion_units, total_unit, cost) SELECT @@ -94,3 +128,26 @@ func SumDailyV2(uid string) error { } return nil } + +func updateSumDaily(uid string, date time.Time) error { + var u = Usage{} + err := usage.Exec(`SELECT + user_id, + ?, + sku, + SUM(prompt_units) AS prompt_units, + SUM(completion_units) AS completion_units, + SUM(total_unit) AS total_unit, + SUM(cost) AS cost + FROM usages + WHERE date >= ? + AND user_id = ?`, date, date, uid).First(&u).Error + if err != nil { + return err + } + err = usage.Model(&DailyUsage{}).Where("user_id = ? and date = ?", uid, date).Updates(u).Error + if err != nil { + return err + } + return nil +} diff --git a/store/userdb.go b/store/userdb.go index e4577c4..1b25c6d 100644 --- a/store/userdb.go +++ b/store/userdb.go @@ -73,6 +73,15 @@ func GetUserByName(name string) (*User, error) { return &user, nil } +func GetUserID(authkey string) (int, error) { + var user User + result := db.Where(&User{Token: authkey}).First(&user) + if result.Error != nil { + return 0, result.Error + } + return int(user.ID), nil +} + func GetAllUsers() ([]*User, error) { var users []*User result := db.Find(&users)