fix record usage

This commit is contained in:
Sakurasan
2025-04-22 00:48:24 +08:00
parent 5789d50e9e
commit 6662ea5e04
4 changed files with 89 additions and 46 deletions

View File

@@ -10,12 +10,17 @@ import (
"opencatd-open/llm/google/v2"
"opencatd-open/llm/openai_compatible"
"opencatd-open/pkg/tokenizer"
"strconv"
"github.com/gin-gonic/gin"
)
func (h *Proxy) ChatHandler(c *gin.Context) {
user := c.MustGet("user").(*model.User)
if user == nil {
dto.WrapErrorAsOpenAI(c, 401, "Unauthorized")
return
}
var chatreq llm.ChatRequest
if err := c.ShouldBindJSON(&chatreq); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
@@ -38,10 +43,10 @@ func (h *Proxy) ChatHandler(c *gin.Context) {
fallthrough
default:
llm, err = openai_compatible.NewOpenAICompatible(h.apikey)
if err != nil {
dto.WrapErrorAsOpenAI(c, 500, fmt.Errorf("create llm client error: %w", err).Error())
return
}
}
if err != nil {
dto.WrapErrorAsOpenAI(c, 500, fmt.Errorf("create llm client error: %w", err).Error())
return
}
if !chatreq.Stream {
@@ -62,19 +67,11 @@ func (h *Proxy) ChatHandler(c *gin.Context) {
}
llmusage := llm.GetTokenUsage()
llmusage.User = user
llmusage.TokenID = c.GetInt64("token_id")
cost := tokenizer.Cost(llmusage.Model, llmusage.PromptTokens+llmusage.ToolsTokens, llmusage.CompletionTokens)
userid, _ := strconv.ParseInt(c.GetString("user_id"), 10, 64)
usage := model.Usage{
UserID: userid,
Model: llmusage.Model,
Stream: chatreq.Stream,
PromptTokens: llmusage.PromptTokens + llmusage.ToolsTokens,
CompletionTokens: llmusage.CompletionTokens,
TotalTokens: llmusage.TotalTokens,
Cost: fmt.Sprintf("%f", cost),
}
h.SendUsage(&usage)
h.SendUsage(llmusage)
defer fmt.Println("cost:", cost, "prompt_tokens:", llmusage.PromptTokens, "completion_tokens:", llmusage.CompletionTokens, "total_tokens:", llmusage.TotalTokens)
}

View File

@@ -13,7 +13,9 @@ import (
"opencatd-open/internal/dao"
"opencatd-open/internal/model"
"opencatd-open/internal/utils"
"opencatd-open/llm"
"opencatd-open/pkg/config"
"opencatd-open/pkg/tokenizer"
"os"
"strings"
"sync"
@@ -32,7 +34,7 @@ type Proxy struct {
cfg *config.Config
db *gorm.DB
wg *sync.WaitGroup
usageChan chan *model.Usage // 用于异步处理的channel
usageChan chan *llm.TokenUsage // 用于异步处理的channel
apikey *model.ApiKey
httpClient *http.Client
cache gcache.Cache
@@ -63,7 +65,7 @@ func NewProxy(ctx context.Context, cfg *config.Config, db *gorm.DB, wg *sync.Wai
wg: wg,
httpClient: client,
cache: gcache.New(1).Build(),
usageChan: make(chan *model.Usage, cfg.UsageChanSize),
usageChan: make(chan *llm.TokenUsage, cfg.UsageChanSize),
userDAO: userDAO,
apiKeyDao: apiKeyDAO,
tokenDAO: tokenDAO,
@@ -84,7 +86,7 @@ func (p *Proxy) HandleProxy(c *gin.Context) {
}
}
func (p *Proxy) SendUsage(usage *model.Usage) {
func (p *Proxy) SendUsage(usage *llm.TokenUsage) {
select {
case p.usageChan <- usage:
default:
@@ -140,46 +142,85 @@ func (p *Proxy) ProcessUsage() {
}
}
func (p *Proxy) Do(usage *model.Usage) error {
func (p *Proxy) Do(llmusage *llm.TokenUsage) error {
err := p.db.Transaction(func(tx *gorm.DB) error {
now := time.Now()
cost := tokenizer.Cost(llmusage.Model, llmusage.PromptTokens, llmusage.CompletionTokens)
token, err := p.tokenDAO.GetByID(p.ctx, llmusage.TokenID)
if err != nil {
return err
}
// 1. 记录使用记录
if err := tx.WithContext(p.ctx).Create(usage).Error; err != nil {
if err := tx.WithContext(p.ctx).Create(&model.Usage{
UserID: llmusage.User.ID,
TokenID: llmusage.TokenID,
Date: now,
Model: llmusage.Model,
Stream: llmusage.Stream,
PromptTokens: llmusage.PromptTokens,
CompletionTokens: llmusage.CompletionTokens,
TotalTokens: llmusage.TotalTokens,
Cost: fmt.Sprintf("%.8f", cost),
}).Error; err != nil {
return fmt.Errorf("create usage error: %w", err)
}
// 2. 更新每日统计upsert 操作)
dailyUsage := model.DailyUsage{
UserID: usage.UserID,
TokenID: usage.TokenID,
Capability: usage.Capability,
Date: time.Date(usage.Date.Year(), usage.Date.Month(), usage.Date.Day(), 0, 0, 0, 0, usage.Date.Location()),
Model: usage.Model,
Stream: usage.Stream,
PromptTokens: usage.PromptTokens,
CompletionTokens: usage.CompletionTokens,
TotalTokens: usage.TotalTokens,
Cost: usage.Cost,
UserID: llmusage.User.ID,
TokenID: llmusage.TokenID,
Date: time.Date(now.Year(), now.Month(), now.Day(), 0, 0, 0, 0, now.Location()),
Model: llmusage.Model,
Stream: llmusage.Stream,
PromptTokens: llmusage.PromptTokens,
CompletionTokens: llmusage.CompletionTokens,
TotalTokens: llmusage.TotalTokens,
Cost: fmt.Sprintf("%.8f", cost),
}
// 使用 OnConflict 实现 upsert
if err := tx.WithContext(p.ctx).Clauses(clause.OnConflict{
Columns: []clause.Column{{Name: "user_id"}, {Name: "token_id"}, {Name: "capability"}, {Name: "date"}}, // 唯一键
Columns: []clause.Column{{Name: "user_id"}, {Name: "date"}}, // 唯一键
DoUpdates: clause.Assignments(map[string]interface{}{
"prompt_tokens": gorm.Expr("prompt_tokens + ?", usage.PromptTokens),
"completion_tokens": gorm.Expr("completion_tokens + ?", usage.CompletionTokens),
"total_tokens": gorm.Expr("total_tokens + ?", usage.TotalTokens),
"cost": gorm.Expr("cost + ?", usage.Cost),
"prompt_tokens": gorm.Expr("prompt_tokens + ?", llmusage.PromptTokens),
"completion_tokens": gorm.Expr("completion_tokens + ?", llmusage.CompletionTokens),
"total_tokens": gorm.Expr("total_tokens + ?", llmusage.TotalTokens),
"cost": gorm.Expr("cost + ?", fmt.Sprintf("%.8f", cost)),
}),
}).Create(&dailyUsage).Error; err != nil {
return fmt.Errorf("upsert daily usage error: %w", err)
}
// 3. 更新用户额度
if err := tx.WithContext(p.ctx).Model(&model.User{}).Where("id = ?", usage.UserID).Updates(map[string]interface{}{
"quota": gorm.Expr("quota - ?", usage.Cost),
"used_quota": gorm.Expr("used_quota + ?", usage.Cost),
}).Error; err != nil {
return fmt.Errorf("update user quota and used_quota error: %w", err)
if *llmusage.User.UnlimitedQuota {
if err := tx.WithContext(p.ctx).Model(&model.User{}).Where("id = ?", llmusage.User.ID).Updates(map[string]interface{}{
"used_quota": gorm.Expr("used_quota + ?", fmt.Sprintf("%.8f", cost)),
}).Error; err != nil {
return fmt.Errorf("update user quota and used_quota error: %w", err)
}
} else {
if err := tx.WithContext(p.ctx).Model(&model.User{}).Where("id = ?", llmusage.User.ID).Updates(map[string]interface{}{
"quota": gorm.Expr("quota - ?", fmt.Sprintf("%.8f", cost)),
"used_quota": gorm.Expr("used_quota + ?", fmt.Sprintf("%.8f", cost)),
}).Error; err != nil {
return fmt.Errorf("update user quota and used_quota error: %w", err)
}
}
// 4. 更新token额度
if *token.UnlimitedQuota {
if err := tx.WithContext(p.ctx).Model(&model.Token{}).Where("id = ?", llmusage.TokenID).Updates(map[string]interface{}{
"used_quota": gorm.Expr("used_quota + ?", fmt.Sprintf("%.8f", cost)),
}).Error; err != nil {
return fmt.Errorf("update token quota and used_quota error: %w", err)
}
} else {
if err := tx.WithContext(p.ctx).Model(&model.Token{}).Where("id = ?", llmusage.TokenID).Updates(map[string]interface{}{
"quota": gorm.Expr("quota - ?", fmt.Sprintf("%.8f", cost)),
"used_quota": gorm.Expr("used_quota + ?", fmt.Sprintf("%.8f", cost)),
}).Error; err != nil {
return fmt.Errorf("update token quota and used_quota error: %w", err)
}
}
return nil

View File

@@ -2,6 +2,7 @@ package llm
import (
"fmt"
"opencatd-open/internal/model"
"github.com/sashabaranov/go-openai"
)
@@ -15,11 +16,14 @@ type StreamChatResponse openai.ChatCompletionStreamResponse
type ChatMessage openai.ChatCompletionMessage
type TokenUsage struct {
User *model.User
TokenID int64
Model string `json:"model"`
PromptTokens int `json:"prompt_tokens"`
CompletionTokens int `json:"completion_tokens"`
ToolsTokens int `json:"tools_tokens"`
TotalTokens int `json:"total_tokens"`
Stream bool
PromptTokens int `json:"prompt_tokens"`
CompletionTokens int `json:"completion_tokens"`
ToolsTokens int `json:"tools_tokens"`
TotalTokens int `json:"total_tokens"`
}
type ErrorResponse struct {

View File

@@ -95,6 +95,7 @@ func AuthLLM(db *gorm.DB) gin.HandlerFunc {
c.Set("user", token.User)
c.Set("user_id", token.User.ID)
c.Set("token_id", token.ID)
c.Set("authed", true)
// 可以在这里对 token 进行验证并检查权限