fix record usage

This commit is contained in:
Sakurasan
2025-04-22 00:48:24 +08:00
parent 5789d50e9e
commit 6662ea5e04
4 changed files with 89 additions and 46 deletions

View File

@@ -10,12 +10,17 @@ import (
"opencatd-open/llm/google/v2" "opencatd-open/llm/google/v2"
"opencatd-open/llm/openai_compatible" "opencatd-open/llm/openai_compatible"
"opencatd-open/pkg/tokenizer" "opencatd-open/pkg/tokenizer"
"strconv"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
) )
func (h *Proxy) ChatHandler(c *gin.Context) { func (h *Proxy) ChatHandler(c *gin.Context) {
user := c.MustGet("user").(*model.User)
if user == nil {
dto.WrapErrorAsOpenAI(c, 401, "Unauthorized")
return
}
var chatreq llm.ChatRequest var chatreq llm.ChatRequest
if err := c.ShouldBindJSON(&chatreq); err != nil { if err := c.ShouldBindJSON(&chatreq); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
@@ -38,10 +43,10 @@ func (h *Proxy) ChatHandler(c *gin.Context) {
fallthrough fallthrough
default: default:
llm, err = openai_compatible.NewOpenAICompatible(h.apikey) llm, err = openai_compatible.NewOpenAICompatible(h.apikey)
if err != nil { }
dto.WrapErrorAsOpenAI(c, 500, fmt.Errorf("create llm client error: %w", err).Error()) if err != nil {
return dto.WrapErrorAsOpenAI(c, 500, fmt.Errorf("create llm client error: %w", err).Error())
} return
} }
if !chatreq.Stream { if !chatreq.Stream {
@@ -62,19 +67,11 @@ func (h *Proxy) ChatHandler(c *gin.Context) {
} }
llmusage := llm.GetTokenUsage() llmusage := llm.GetTokenUsage()
llmusage.User = user
llmusage.TokenID = c.GetInt64("token_id")
cost := tokenizer.Cost(llmusage.Model, llmusage.PromptTokens+llmusage.ToolsTokens, llmusage.CompletionTokens) cost := tokenizer.Cost(llmusage.Model, llmusage.PromptTokens+llmusage.ToolsTokens, llmusage.CompletionTokens)
userid, _ := strconv.ParseInt(c.GetString("user_id"), 10, 64)
usage := model.Usage{ h.SendUsage(llmusage)
UserID: userid,
Model: llmusage.Model,
Stream: chatreq.Stream,
PromptTokens: llmusage.PromptTokens + llmusage.ToolsTokens,
CompletionTokens: llmusage.CompletionTokens,
TotalTokens: llmusage.TotalTokens,
Cost: fmt.Sprintf("%f", cost),
}
h.SendUsage(&usage)
defer fmt.Println("cost:", cost, "prompt_tokens:", llmusage.PromptTokens, "completion_tokens:", llmusage.CompletionTokens, "total_tokens:", llmusage.TotalTokens) defer fmt.Println("cost:", cost, "prompt_tokens:", llmusage.PromptTokens, "completion_tokens:", llmusage.CompletionTokens, "total_tokens:", llmusage.TotalTokens)
} }

View File

@@ -13,7 +13,9 @@ import (
"opencatd-open/internal/dao" "opencatd-open/internal/dao"
"opencatd-open/internal/model" "opencatd-open/internal/model"
"opencatd-open/internal/utils" "opencatd-open/internal/utils"
"opencatd-open/llm"
"opencatd-open/pkg/config" "opencatd-open/pkg/config"
"opencatd-open/pkg/tokenizer"
"os" "os"
"strings" "strings"
"sync" "sync"
@@ -32,7 +34,7 @@ type Proxy struct {
cfg *config.Config cfg *config.Config
db *gorm.DB db *gorm.DB
wg *sync.WaitGroup wg *sync.WaitGroup
usageChan chan *model.Usage // 用于异步处理的channel usageChan chan *llm.TokenUsage // 用于异步处理的channel
apikey *model.ApiKey apikey *model.ApiKey
httpClient *http.Client httpClient *http.Client
cache gcache.Cache cache gcache.Cache
@@ -63,7 +65,7 @@ func NewProxy(ctx context.Context, cfg *config.Config, db *gorm.DB, wg *sync.Wai
wg: wg, wg: wg,
httpClient: client, httpClient: client,
cache: gcache.New(1).Build(), cache: gcache.New(1).Build(),
usageChan: make(chan *model.Usage, cfg.UsageChanSize), usageChan: make(chan *llm.TokenUsage, cfg.UsageChanSize),
userDAO: userDAO, userDAO: userDAO,
apiKeyDao: apiKeyDAO, apiKeyDao: apiKeyDAO,
tokenDAO: tokenDAO, tokenDAO: tokenDAO,
@@ -84,7 +86,7 @@ func (p *Proxy) HandleProxy(c *gin.Context) {
} }
} }
func (p *Proxy) SendUsage(usage *model.Usage) { func (p *Proxy) SendUsage(usage *llm.TokenUsage) {
select { select {
case p.usageChan <- usage: case p.usageChan <- usage:
default: default:
@@ -140,46 +142,85 @@ func (p *Proxy) ProcessUsage() {
} }
} }
func (p *Proxy) Do(usage *model.Usage) error { func (p *Proxy) Do(llmusage *llm.TokenUsage) error {
err := p.db.Transaction(func(tx *gorm.DB) error { err := p.db.Transaction(func(tx *gorm.DB) error {
now := time.Now()
cost := tokenizer.Cost(llmusage.Model, llmusage.PromptTokens, llmusage.CompletionTokens)
token, err := p.tokenDAO.GetByID(p.ctx, llmusage.TokenID)
if err != nil {
return err
}
// 1. 记录使用记录 // 1. 记录使用记录
if err := tx.WithContext(p.ctx).Create(usage).Error; err != nil { if err := tx.WithContext(p.ctx).Create(&model.Usage{
UserID: llmusage.User.ID,
TokenID: llmusage.TokenID,
Date: now,
Model: llmusage.Model,
Stream: llmusage.Stream,
PromptTokens: llmusage.PromptTokens,
CompletionTokens: llmusage.CompletionTokens,
TotalTokens: llmusage.TotalTokens,
Cost: fmt.Sprintf("%.8f", cost),
}).Error; err != nil {
return fmt.Errorf("create usage error: %w", err) return fmt.Errorf("create usage error: %w", err)
} }
// 2. 更新每日统计upsert 操作) // 2. 更新每日统计upsert 操作)
dailyUsage := model.DailyUsage{ dailyUsage := model.DailyUsage{
UserID: usage.UserID, UserID: llmusage.User.ID,
TokenID: usage.TokenID, TokenID: llmusage.TokenID,
Capability: usage.Capability, Date: time.Date(now.Year(), now.Month(), now.Day(), 0, 0, 0, 0, now.Location()),
Date: time.Date(usage.Date.Year(), usage.Date.Month(), usage.Date.Day(), 0, 0, 0, 0, usage.Date.Location()), Model: llmusage.Model,
Model: usage.Model, Stream: llmusage.Stream,
Stream: usage.Stream, PromptTokens: llmusage.PromptTokens,
PromptTokens: usage.PromptTokens, CompletionTokens: llmusage.CompletionTokens,
CompletionTokens: usage.CompletionTokens, TotalTokens: llmusage.TotalTokens,
TotalTokens: usage.TotalTokens, Cost: fmt.Sprintf("%.8f", cost),
Cost: usage.Cost,
} }
// 使用 OnConflict 实现 upsert // 使用 OnConflict 实现 upsert
if err := tx.WithContext(p.ctx).Clauses(clause.OnConflict{ if err := tx.WithContext(p.ctx).Clauses(clause.OnConflict{
Columns: []clause.Column{{Name: "user_id"}, {Name: "token_id"}, {Name: "capability"}, {Name: "date"}}, // 唯一键 Columns: []clause.Column{{Name: "user_id"}, {Name: "date"}}, // 唯一键
DoUpdates: clause.Assignments(map[string]interface{}{ DoUpdates: clause.Assignments(map[string]interface{}{
"prompt_tokens": gorm.Expr("prompt_tokens + ?", usage.PromptTokens), "prompt_tokens": gorm.Expr("prompt_tokens + ?", llmusage.PromptTokens),
"completion_tokens": gorm.Expr("completion_tokens + ?", usage.CompletionTokens), "completion_tokens": gorm.Expr("completion_tokens + ?", llmusage.CompletionTokens),
"total_tokens": gorm.Expr("total_tokens + ?", usage.TotalTokens), "total_tokens": gorm.Expr("total_tokens + ?", llmusage.TotalTokens),
"cost": gorm.Expr("cost + ?", usage.Cost), "cost": gorm.Expr("cost + ?", fmt.Sprintf("%.8f", cost)),
}), }),
}).Create(&dailyUsage).Error; err != nil { }).Create(&dailyUsage).Error; err != nil {
return fmt.Errorf("upsert daily usage error: %w", err) return fmt.Errorf("upsert daily usage error: %w", err)
} }
// 3. 更新用户额度 // 3. 更新用户额度
if err := tx.WithContext(p.ctx).Model(&model.User{}).Where("id = ?", usage.UserID).Updates(map[string]interface{}{ if *llmusage.User.UnlimitedQuota {
"quota": gorm.Expr("quota - ?", usage.Cost), if err := tx.WithContext(p.ctx).Model(&model.User{}).Where("id = ?", llmusage.User.ID).Updates(map[string]interface{}{
"used_quota": gorm.Expr("used_quota + ?", usage.Cost), "used_quota": gorm.Expr("used_quota + ?", fmt.Sprintf("%.8f", cost)),
}).Error; err != nil { }).Error; err != nil {
return fmt.Errorf("update user quota and used_quota error: %w", err) return fmt.Errorf("update user quota and used_quota error: %w", err)
}
} else {
if err := tx.WithContext(p.ctx).Model(&model.User{}).Where("id = ?", llmusage.User.ID).Updates(map[string]interface{}{
"quota": gorm.Expr("quota - ?", fmt.Sprintf("%.8f", cost)),
"used_quota": gorm.Expr("used_quota + ?", fmt.Sprintf("%.8f", cost)),
}).Error; err != nil {
return fmt.Errorf("update user quota and used_quota error: %w", err)
}
}
//4 . 更新token额度
if *token.UnlimitedQuota {
if err := tx.WithContext(p.ctx).Model(&model.Token{}).Where("id = ?", llmusage.TokenID).Updates(map[string]interface{}{
"used_quota": gorm.Expr("used_quota + ?", fmt.Sprintf("%.8f", cost)),
}).Error; err != nil {
return fmt.Errorf("update token quota and used_quota error: %w", err)
}
} else {
if err := tx.WithContext(p.ctx).Model(&model.Token{}).Where("id = ?", llmusage.TokenID).Updates(map[string]interface{}{
"quota": gorm.Expr("quota - ?", fmt.Sprintf("%.8f", cost)),
"used_quota": gorm.Expr("used_quota + ?", fmt.Sprintf("%.8f", cost)),
}).Error; err != nil {
return fmt.Errorf("update token quota and used_quota error: %w", err)
}
} }
return nil return nil

View File

@@ -2,6 +2,7 @@ package llm
import ( import (
"fmt" "fmt"
"opencatd-open/internal/model"
"github.com/sashabaranov/go-openai" "github.com/sashabaranov/go-openai"
) )
@@ -15,11 +16,14 @@ type StreamChatResponse openai.ChatCompletionStreamResponse
type ChatMessage openai.ChatCompletionMessage type ChatMessage openai.ChatCompletionMessage
type TokenUsage struct { type TokenUsage struct {
User *model.User
TokenID int64
Model string `json:"model"` Model string `json:"model"`
PromptTokens int `json:"prompt_tokens"` Stream bool
CompletionTokens int `json:"completion_tokens"` PromptTokens int `json:"prompt_tokens"`
ToolsTokens int `json:"tools_tokens"` CompletionTokens int `json:"completion_tokens"`
TotalTokens int `json:"total_tokens"` ToolsTokens int `json:"tools_tokens"`
TotalTokens int `json:"total_tokens"`
} }
type ErrorResponse struct { type ErrorResponse struct {

View File

@@ -95,6 +95,7 @@ func AuthLLM(db *gorm.DB) gin.HandlerFunc {
c.Set("user", token.User) c.Set("user", token.User)
c.Set("user_id", token.User.ID) c.Set("user_id", token.User.ID)
c.Set("token_id", token.ID)
c.Set("authed", true) c.Set("authed", true)
// 可以在这里对 token 进行验证并检查权限 // 可以在这里对 token 进行验证并检查权限