From 6662ea5e04dedd1cf9e4760772c33b567cc02a11 Mon Sep 17 00:00:00 2001
From: Sakurasan <26715255+Sakurasan@users.noreply.github.com>
Date: Tue, 22 Apr 2025 00:48:24 +0800
Subject: [PATCH] fix usage recording

Thread the authenticated user and token through llm.TokenUsage instead of
assembling model.Usage in the chat handler. Proxy.Do now computes the cost
and performs the usage insert, the daily-stats upsert, and the user/token
quota updates in one transaction, honoring unlimited-quota flags.
---
 internal/controller/proxy/chat_proxy.go | 31 ++++-----
 internal/controller/proxy/proxy.go      | 91 ++++++++++++++++++-------
 llm/types.go                            | 12 ++--
 middleware/auth_team.go                 |  1 +
 4 files changed, 89 insertions(+), 46 deletions(-)

diff --git a/internal/controller/proxy/chat_proxy.go b/internal/controller/proxy/chat_proxy.go
index 2e9220c..fbc567b 100644
--- a/internal/controller/proxy/chat_proxy.go
+++ b/internal/controller/proxy/chat_proxy.go
@@ -10,12 +10,17 @@ import (
 	"opencatd-open/llm/google/v2"
 	"opencatd-open/llm/openai_compatible"
 	"opencatd-open/pkg/tokenizer"
-	"strconv"
 
 	"github.com/gin-gonic/gin"
 )
 
 func (h *Proxy) ChatHandler(c *gin.Context) {
+	user := c.MustGet("user").(*model.User)
+	if user == nil {
+		dto.WrapErrorAsOpenAI(c, 401, "Unauthorized")
+		return
+	}
+
 	var chatreq llm.ChatRequest
 	if err := c.ShouldBindJSON(&chatreq); err != nil {
 		c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
@@ -38,10 +43,10 @@ func (h *Proxy) ChatHandler(c *gin.Context) {
 		fallthrough
 	default:
 		llm, err = openai_compatible.NewOpenAICompatible(h.apikey)
-		if err != nil {
-			dto.WrapErrorAsOpenAI(c, 500, fmt.Errorf("create llm client error: %w", err).Error())
-			return
-		}
+	}
+	if err != nil {
+		dto.WrapErrorAsOpenAI(c, 500, fmt.Errorf("create llm client error: %w", err).Error())
+		return
 	}
 
 	if !chatreq.Stream {
@@ -62,19 +67,11 @@ func (h *Proxy) ChatHandler(c *gin.Context) {
 	}
 
 	llmusage := llm.GetTokenUsage()
-
+	llmusage.User = user
+	llmusage.TokenID = c.GetInt64("token_id")
 	cost := tokenizer.Cost(llmusage.Model, llmusage.PromptTokens+llmusage.ToolsTokens, llmusage.CompletionTokens)
-	userid, _ := strconv.ParseInt(c.GetString("user_id"), 10, 64)
-	usage := model.Usage{
-		UserID:           userid,
-		Model:            llmusage.Model,
-		Stream:           chatreq.Stream,
-		PromptTokens:     llmusage.PromptTokens + llmusage.ToolsTokens,
-		CompletionTokens: llmusage.CompletionTokens,
-		TotalTokens:      llmusage.TotalTokens,
-		Cost:             fmt.Sprintf("%f", cost),
-	}
-	h.SendUsage(&usage)
+
+	h.SendUsage(llmusage)
 	defer fmt.Println("cost:", cost, "prompt_tokens:", llmusage.PromptTokens, "completion_tokens:", llmusage.CompletionTokens, "total_tokens:", llmusage.TotalTokens)
 }
 
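[Review note, not part of the patch] c.MustGet("user") panics if the middleware
never set the key, so the nil check after it only catches a typed nil stored by
AuthLLM. A non-panicking sketch of the same guard, using gin's two-value Get
(reuses the patch's dto helper; this variant is an assumption, not in the patch):

    // Get returns (value, exists); the two-value type assertion never panics.
    v, ok := c.Get("user")
    user, isUser := v.(*model.User)
    if !ok || !isUser || user == nil {
        dto.WrapErrorAsOpenAI(c, 401, "Unauthorized")
        return
    }

Also worth noting: the cost logged here includes ToolsTokens in the prompt
count, while Proxy.Do below recomputes cost from PromptTokens alone, so the
logged and recorded figures can differ.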
diff --git a/internal/controller/proxy/proxy.go b/internal/controller/proxy/proxy.go
index b11f81c..3c18750 100644
--- a/internal/controller/proxy/proxy.go
+++ b/internal/controller/proxy/proxy.go
@@ -13,7 +13,9 @@ import (
 	"opencatd-open/internal/dao"
 	"opencatd-open/internal/model"
 	"opencatd-open/internal/utils"
+	"opencatd-open/llm"
 	"opencatd-open/pkg/config"
+	"opencatd-open/pkg/tokenizer"
 	"os"
 	"strings"
 	"sync"
@@ -32,7 +34,7 @@ type Proxy struct {
 	cfg        *config.Config
 	db         *gorm.DB
 	wg         *sync.WaitGroup
-	usageChan  chan *model.Usage // channel used for asynchronous processing
+	usageChan  chan *llm.TokenUsage // channel used for asynchronous processing
 	apikey     *model.ApiKey
 	httpClient *http.Client
 	cache      gcache.Cache
@@ -63,7 +65,7 @@ func NewProxy(ctx context.Context, cfg *config.Config, db *gorm.DB, wg *sync.Wai
 		wg:         wg,
 		httpClient: client,
 		cache:      gcache.New(1).Build(),
-		usageChan:  make(chan *model.Usage, cfg.UsageChanSize),
+		usageChan:  make(chan *llm.TokenUsage, cfg.UsageChanSize),
 		userDAO:    userDAO,
 		apiKeyDao:  apiKeyDAO,
 		tokenDAO:   tokenDAO,
@@ -84,7 +86,7 @@ func (p *Proxy) HandleProxy(c *gin.Context) {
 	}
 }
 
-func (p *Proxy) SendUsage(usage *model.Usage) {
+func (p *Proxy) SendUsage(usage *llm.TokenUsage) {
 	select {
 	case p.usageChan <- usage:
 	default:
@@ -140,46 +142,85 @@ func (p *Proxy) ProcessUsage() {
 	}
 }
 
-func (p *Proxy) Do(usage *model.Usage) error {
+func (p *Proxy) Do(llmusage *llm.TokenUsage) error {
 	err := p.db.Transaction(func(tx *gorm.DB) error {
+		now := time.Now()
+		cost := tokenizer.Cost(llmusage.Model, llmusage.PromptTokens, llmusage.CompletionTokens)
+		token, err := p.tokenDAO.GetByID(p.ctx, llmusage.TokenID)
+		if err != nil {
+			return err
+		}
 		// 1. record the usage entry
-		if err := tx.WithContext(p.ctx).Create(usage).Error; err != nil {
+		if err := tx.WithContext(p.ctx).Create(&model.Usage{
+			UserID:           llmusage.User.ID,
+			TokenID:          llmusage.TokenID,
+			Date:             now,
+			Model:            llmusage.Model,
+			Stream:           llmusage.Stream,
+			PromptTokens:     llmusage.PromptTokens,
+			CompletionTokens: llmusage.CompletionTokens,
+			TotalTokens:      llmusage.TotalTokens,
+			Cost:             fmt.Sprintf("%.8f", cost),
+		}).Error; err != nil {
 			return fmt.Errorf("create usage error: %w", err)
 		}
 
 		// 2. update the daily statistics (upsert)
 		dailyUsage := model.DailyUsage{
-			UserID:           usage.UserID,
-			TokenID:          usage.TokenID,
-			Capability:       usage.Capability,
-			Date:             time.Date(usage.Date.Year(), usage.Date.Month(), usage.Date.Day(), 0, 0, 0, 0, usage.Date.Location()),
-			Model:            usage.Model,
-			Stream:           usage.Stream,
-			PromptTokens:     usage.PromptTokens,
-			CompletionTokens: usage.CompletionTokens,
-			TotalTokens:      usage.TotalTokens,
-			Cost:             usage.Cost,
+			UserID:           llmusage.User.ID,
+			TokenID:          llmusage.TokenID,
+			Date:             time.Date(now.Year(), now.Month(), now.Day(), 0, 0, 0, 0, now.Location()),
+			Model:            llmusage.Model,
+			Stream:           llmusage.Stream,
+			PromptTokens:     llmusage.PromptTokens,
+			CompletionTokens: llmusage.CompletionTokens,
+			TotalTokens:      llmusage.TotalTokens,
+			Cost:             fmt.Sprintf("%.8f", cost),
 		}
 
 		// use OnConflict to implement the upsert
 		if err := tx.WithContext(p.ctx).Clauses(clause.OnConflict{
-			Columns: []clause.Column{{Name: "user_id"}, {Name: "token_id"}, {Name: "capability"}, {Name: "date"}}, // unique key
+			Columns: []clause.Column{{Name: "user_id"}, {Name: "date"}}, // unique key
 			DoUpdates: clause.Assignments(map[string]interface{}{
-				"prompt_tokens":     gorm.Expr("prompt_tokens + ?", usage.PromptTokens),
-				"completion_tokens": gorm.Expr("completion_tokens + ?", usage.CompletionTokens),
-				"total_tokens":      gorm.Expr("total_tokens + ?", usage.TotalTokens),
-				"cost":              gorm.Expr("cost + ?", usage.Cost),
+				"prompt_tokens":     gorm.Expr("prompt_tokens + ?", llmusage.PromptTokens),
+				"completion_tokens": gorm.Expr("completion_tokens + ?", llmusage.CompletionTokens),
+				"total_tokens":      gorm.Expr("total_tokens + ?", llmusage.TotalTokens),
+				"cost":              gorm.Expr("cost + ?", fmt.Sprintf("%.8f", cost)),
 			}),
 		}).Create(&dailyUsage).Error; err != nil {
 			return fmt.Errorf("upsert daily usage error: %w", err)
 		}
 
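[Review note, not part of the patch] The OnConflict key shrinks from (user_id,
token_id, capability, date) to (user_id, date), so daily rows now aggregate
across tokens and models, and a row's token_id/model keep whatever the first
request of the day wrote. The upsert also only takes effect if a matching
unique constraint exists in the schema; a sketch of the constraint this code
now assumes (hypothetical GORM tags, not in this patch):

    // one daily row per (user_id, date); other columns are accumulated
    type DailyUsage struct {
        UserID int64     `gorm:"uniqueIndex:idx_daily_usage_user_date"`
        Date   time.Time `gorm:"uniqueIndex:idx_daily_usage_user_date"`
        // ... remaining fields as in model.DailyUsage
    }

Note too that gorm.Expr("cost + ?", fmt.Sprintf("%.8f", cost)) binds a
formatted string and relies on the database to cast it for the addition; if
the cost column is numeric, passing the float64 directly avoids the cast.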
 		// 3. update the user's quota
-		if err := tx.WithContext(p.ctx).Model(&model.User{}).Where("id = ?", usage.UserID).Updates(map[string]interface{}{
-			"quota":      gorm.Expr("quota - ?", usage.Cost),
-			"used_quota": gorm.Expr("used_quota + ?", usage.Cost),
-		}).Error; err != nil {
-			return fmt.Errorf("update user quota and used_quota error: %w", err)
+		if *llmusage.User.UnlimitedQuota {
+			if err := tx.WithContext(p.ctx).Model(&model.User{}).Where("id = ?", llmusage.User.ID).Updates(map[string]interface{}{
+				"used_quota": gorm.Expr("used_quota + ?", fmt.Sprintf("%.8f", cost)),
+			}).Error; err != nil {
+				return fmt.Errorf("update user quota and used_quota error: %w", err)
+			}
+		} else {
+			if err := tx.WithContext(p.ctx).Model(&model.User{}).Where("id = ?", llmusage.User.ID).Updates(map[string]interface{}{
+				"quota":      gorm.Expr("quota - ?", fmt.Sprintf("%.8f", cost)),
+				"used_quota": gorm.Expr("used_quota + ?", fmt.Sprintf("%.8f", cost)),
+			}).Error; err != nil {
+				return fmt.Errorf("update user quota and used_quota error: %w", err)
+			}
+		}
+
+		// 4. update the token's quota
+		if *token.UnlimitedQuota {
+			if err := tx.WithContext(p.ctx).Model(&model.Token{}).Where("id = ?", llmusage.TokenID).Updates(map[string]interface{}{
+				"used_quota": gorm.Expr("used_quota + ?", fmt.Sprintf("%.8f", cost)),
+			}).Error; err != nil {
+				return fmt.Errorf("update token quota and used_quota error: %w", err)
+			}
+		} else {
+			if err := tx.WithContext(p.ctx).Model(&model.Token{}).Where("id = ?", llmusage.TokenID).Updates(map[string]interface{}{
+				"quota":      gorm.Expr("quota - ?", fmt.Sprintf("%.8f", cost)),
+				"used_quota": gorm.Expr("used_quota + ?", fmt.Sprintf("%.8f", cost)),
+			}).Error; err != nil {
+				return fmt.Errorf("update token quota and used_quota error: %w", err)
+			}
 		}
 
 		return nil
diff --git a/llm/types.go b/llm/types.go
index 0b8bbe9..25278aa 100644
--- a/llm/types.go
+++ b/llm/types.go
@@ -2,6 +2,7 @@ package llm
 
 import (
 	"fmt"
+	"opencatd-open/internal/model"
 
 	"github.com/sashabaranov/go-openai"
 )
@@ -15,11 +16,14 @@ type StreamChatResponse openai.ChatCompletionStreamResponse
 type ChatMessage openai.ChatCompletionMessage
 
 type TokenUsage struct {
+	User             *model.User
+	TokenID          int64
 	Model            string `json:"model"`
-	PromptTokens     int    `json:"prompt_tokens"`
-	CompletionTokens int    `json:"completion_tokens"`
-	ToolsTokens      int    `json:"tools_tokens"`
-	TotalTokens      int    `json:"total_tokens"`
+	Stream           bool
+	PromptTokens     int `json:"prompt_tokens"`
+	CompletionTokens int `json:"completion_tokens"`
+	ToolsTokens      int `json:"tools_tokens"`
+	TotalTokens      int `json:"total_tokens"`
 }
 
 type ErrorResponse struct {
diff --git a/middleware/auth_team.go b/middleware/auth_team.go
index c87253c..69d9c58 100644
--- a/middleware/auth_team.go
+++ b/middleware/auth_team.go
@@ -95,6 +95,7 @@ func AuthLLM(db *gorm.DB) gin.HandlerFunc {
 
 		c.Set("user", token.User)
 		c.Set("user_id", token.User.ID)
+		c.Set("token_id", token.ID)
 		c.Set("authed", true)
 
 		// token validation and permission checks can be done here
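[Review note, not part of the patch] gin's GetInt64 type-asserts the stored
value rather than converting it, so ChatHandler's c.GetInt64("token_id")
returns 0 unless token.ID is stored as exactly int64. If the model's ID is a
different integer type (e.g. uint), converting at the Set site is safer
(sketch, assuming token.ID is convertible):

    // store as int64 so GetInt64's type assertion succeeds on the read side
    c.Set("token_id", int64(token.ID))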