Fix usage recording
This commit is contained in:
@@ -10,12 +10,17 @@ import (
|
||||
"opencatd-open/llm/google/v2"
|
||||
"opencatd-open/llm/openai_compatible"
|
||||
"opencatd-open/pkg/tokenizer"
|
||||
"strconv"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
func (h *Proxy) ChatHandler(c *gin.Context) {
|
||||
user := c.MustGet("user").(*model.User)
|
||||
if user == nil {
|
||||
dto.WrapErrorAsOpenAI(c, 401, "Unauthorized")
|
||||
return
|
||||
}
|
||||
|
||||
var chatreq llm.ChatRequest
|
||||
if err := c.ShouldBindJSON(&chatreq); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
@@ -38,10 +43,10 @@ func (h *Proxy) ChatHandler(c *gin.Context) {
|
||||
fallthrough
|
||||
default:
|
||||
llm, err = openai_compatible.NewOpenAICompatible(h.apikey)
|
||||
if err != nil {
|
||||
dto.WrapErrorAsOpenAI(c, 500, fmt.Errorf("create llm client error: %w", err).Error())
|
||||
return
|
||||
}
|
||||
}
|
||||
if err != nil {
|
||||
dto.WrapErrorAsOpenAI(c, 500, fmt.Errorf("create llm client error: %w", err).Error())
|
||||
return
|
||||
}
|
||||
|
||||
if !chatreq.Stream {
|
||||
@@ -62,19 +67,11 @@ func (h *Proxy) ChatHandler(c *gin.Context) {
|
||||
}
|
||||
|
||||
llmusage := llm.GetTokenUsage()
|
||||
|
||||
llmusage.User = user
|
||||
llmusage.TokenID = c.GetInt64("token_id")
|
||||
cost := tokenizer.Cost(llmusage.Model, llmusage.PromptTokens+llmusage.ToolsTokens, llmusage.CompletionTokens)
|
||||
userid, _ := strconv.ParseInt(c.GetString("user_id"), 10, 64)
|
||||
usage := model.Usage{
|
||||
UserID: userid,
|
||||
Model: llmusage.Model,
|
||||
Stream: chatreq.Stream,
|
||||
PromptTokens: llmusage.PromptTokens + llmusage.ToolsTokens,
|
||||
CompletionTokens: llmusage.CompletionTokens,
|
||||
TotalTokens: llmusage.TotalTokens,
|
||||
Cost: fmt.Sprintf("%f", cost),
|
||||
}
|
||||
h.SendUsage(&usage)
|
||||
|
||||
h.SendUsage(llmusage)
|
||||
defer fmt.Println("cost:", cost, "prompt_tokens:", llmusage.PromptTokens, "completion_tokens:", llmusage.CompletionTokens, "total_tokens:", llmusage.TotalTokens)
|
||||
|
||||
}
|
||||
|
||||
@@ -13,7 +13,9 @@ import (
|
||||
"opencatd-open/internal/dao"
|
||||
"opencatd-open/internal/model"
|
||||
"opencatd-open/internal/utils"
|
||||
"opencatd-open/llm"
|
||||
"opencatd-open/pkg/config"
|
||||
"opencatd-open/pkg/tokenizer"
|
||||
"os"
|
||||
"strings"
|
||||
"sync"
|
||||
@@ -32,7 +34,7 @@ type Proxy struct {
|
||||
cfg *config.Config
|
||||
db *gorm.DB
|
||||
wg *sync.WaitGroup
|
||||
usageChan chan *model.Usage // channel used for asynchronous processing
|
||||
usageChan chan *llm.TokenUsage // channel used for asynchronous processing
|
||||
apikey *model.ApiKey
|
||||
httpClient *http.Client
|
||||
cache gcache.Cache
|
||||
@@ -63,7 +65,7 @@ func NewProxy(ctx context.Context, cfg *config.Config, db *gorm.DB, wg *sync.Wai
|
||||
wg: wg,
|
||||
httpClient: client,
|
||||
cache: gcache.New(1).Build(),
|
||||
usageChan: make(chan *model.Usage, cfg.UsageChanSize),
|
||||
usageChan: make(chan *llm.TokenUsage, cfg.UsageChanSize),
|
||||
userDAO: userDAO,
|
||||
apiKeyDao: apiKeyDAO,
|
||||
tokenDAO: tokenDAO,
|
||||
@@ -84,7 +86,7 @@ func (p *Proxy) HandleProxy(c *gin.Context) {
|
||||
}
|
||||
}
|
||||
|
||||
func (p *Proxy) SendUsage(usage *model.Usage) {
|
||||
func (p *Proxy) SendUsage(usage *llm.TokenUsage) {
|
||||
select {
|
||||
case p.usageChan <- usage:
|
||||
default:
|
||||
@@ -140,46 +142,85 @@ func (p *Proxy) ProcessUsage() {
|
||||
}
|
||||
}
|
||||
|
||||
func (p *Proxy) Do(usage *model.Usage) error {
|
||||
func (p *Proxy) Do(llmusage *llm.TokenUsage) error {
|
||||
err := p.db.Transaction(func(tx *gorm.DB) error {
|
||||
now := time.Now()
|
||||
cost := tokenizer.Cost(llmusage.Model, llmusage.PromptTokens, llmusage.CompletionTokens)
|
||||
token, err := p.tokenDAO.GetByID(p.ctx, llmusage.TokenID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
// 1. Insert the usage record
|
||||
if err := tx.WithContext(p.ctx).Create(usage).Error; err != nil {
|
||||
if err := tx.WithContext(p.ctx).Create(&model.Usage{
|
||||
UserID: llmusage.User.ID,
|
||||
TokenID: llmusage.TokenID,
|
||||
Date: now,
|
||||
Model: llmusage.Model,
|
||||
Stream: llmusage.Stream,
|
||||
PromptTokens: llmusage.PromptTokens,
|
||||
CompletionTokens: llmusage.CompletionTokens,
|
||||
TotalTokens: llmusage.TotalTokens,
|
||||
Cost: fmt.Sprintf("%.8f", cost),
|
||||
}).Error; err != nil {
|
||||
return fmt.Errorf("create usage error: %w", err)
|
||||
}
|
||||
|
||||
// 2. Update the daily statistics (upsert)
|
||||
dailyUsage := model.DailyUsage{
|
||||
UserID: usage.UserID,
|
||||
TokenID: usage.TokenID,
|
||||
Capability: usage.Capability,
|
||||
Date: time.Date(usage.Date.Year(), usage.Date.Month(), usage.Date.Day(), 0, 0, 0, 0, usage.Date.Location()),
|
||||
Model: usage.Model,
|
||||
Stream: usage.Stream,
|
||||
PromptTokens: usage.PromptTokens,
|
||||
CompletionTokens: usage.CompletionTokens,
|
||||
TotalTokens: usage.TotalTokens,
|
||||
Cost: usage.Cost,
|
||||
UserID: llmusage.User.ID,
|
||||
TokenID: llmusage.TokenID,
|
||||
Date: time.Date(now.Year(), now.Month(), now.Day(), 0, 0, 0, 0, now.Location()),
|
||||
Model: llmusage.Model,
|
||||
Stream: llmusage.Stream,
|
||||
PromptTokens: llmusage.PromptTokens,
|
||||
CompletionTokens: llmusage.CompletionTokens,
|
||||
TotalTokens: llmusage.TotalTokens,
|
||||
Cost: fmt.Sprintf("%.8f", cost),
|
||||
}
|
||||
|
||||
// Use OnConflict to implement the upsert
|
||||
if err := tx.WithContext(p.ctx).Clauses(clause.OnConflict{
|
||||
Columns: []clause.Column{{Name: "user_id"}, {Name: "token_id"}, {Name: "capability"}, {Name: "date"}}, // unique key
|
||||
Columns: []clause.Column{{Name: "user_id"}, {Name: "date"}}, // unique key
|
||||
DoUpdates: clause.Assignments(map[string]interface{}{
|
||||
"prompt_tokens": gorm.Expr("prompt_tokens + ?", usage.PromptTokens),
|
||||
"completion_tokens": gorm.Expr("completion_tokens + ?", usage.CompletionTokens),
|
||||
"total_tokens": gorm.Expr("total_tokens + ?", usage.TotalTokens),
|
||||
"cost": gorm.Expr("cost + ?", usage.Cost),
|
||||
"prompt_tokens": gorm.Expr("prompt_tokens + ?", llmusage.PromptTokens),
|
||||
"completion_tokens": gorm.Expr("completion_tokens + ?", llmusage.CompletionTokens),
|
||||
"total_tokens": gorm.Expr("total_tokens + ?", llmusage.TotalTokens),
|
||||
"cost": gorm.Expr("cost + ?", fmt.Sprintf("%.8f", cost)),
|
||||
}),
|
||||
}).Create(&dailyUsage).Error; err != nil {
|
||||
return fmt.Errorf("upsert daily usage error: %w", err)
|
||||
}
|
||||
|
||||
// 3. Update the user's quota
|
||||
if err := tx.WithContext(p.ctx).Model(&model.User{}).Where("id = ?", usage.UserID).Updates(map[string]interface{}{
|
||||
"quota": gorm.Expr("quota - ?", usage.Cost),
|
||||
"used_quota": gorm.Expr("used_quota + ?", usage.Cost),
|
||||
}).Error; err != nil {
|
||||
return fmt.Errorf("update user quota and used_quota error: %w", err)
|
||||
if *llmusage.User.UnlimitedQuota {
|
||||
if err := tx.WithContext(p.ctx).Model(&model.User{}).Where("id = ?", llmusage.User.ID).Updates(map[string]interface{}{
|
||||
"used_quota": gorm.Expr("used_quota + ?", fmt.Sprintf("%.8f", cost)),
|
||||
}).Error; err != nil {
|
||||
return fmt.Errorf("update user quota and used_quota error: %w", err)
|
||||
}
|
||||
} else {
|
||||
if err := tx.WithContext(p.ctx).Model(&model.User{}).Where("id = ?", llmusage.User.ID).Updates(map[string]interface{}{
|
||||
"quota": gorm.Expr("quota - ?", fmt.Sprintf("%.8f", cost)),
|
||||
"used_quota": gorm.Expr("used_quota + ?", fmt.Sprintf("%.8f", cost)),
|
||||
}).Error; err != nil {
|
||||
return fmt.Errorf("update user quota and used_quota error: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
// 4. Update the token's quota
|
||||
if *token.UnlimitedQuota {
|
||||
if err := tx.WithContext(p.ctx).Model(&model.Token{}).Where("id = ?", llmusage.TokenID).Updates(map[string]interface{}{
|
||||
"used_quota": gorm.Expr("used_quota + ?", fmt.Sprintf("%.8f", cost)),
|
||||
}).Error; err != nil {
|
||||
return fmt.Errorf("update token quota and used_quota error: %w", err)
|
||||
}
|
||||
} else {
|
||||
if err := tx.WithContext(p.ctx).Model(&model.Token{}).Where("id = ?", llmusage.TokenID).Updates(map[string]interface{}{
|
||||
"quota": gorm.Expr("quota - ?", fmt.Sprintf("%.8f", cost)),
|
||||
"used_quota": gorm.Expr("used_quota + ?", fmt.Sprintf("%.8f", cost)),
|
||||
}).Error; err != nil {
|
||||
return fmt.Errorf("update token quota and used_quota error: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
|
||||
12
llm/types.go
12
llm/types.go
@@ -2,6 +2,7 @@ package llm
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"opencatd-open/internal/model"
|
||||
|
||||
"github.com/sashabaranov/go-openai"
|
||||
)
|
||||
@@ -15,11 +16,14 @@ type StreamChatResponse openai.ChatCompletionStreamResponse
|
||||
type ChatMessage openai.ChatCompletionMessage
|
||||
|
||||
type TokenUsage struct {
|
||||
User *model.User
|
||||
TokenID int64
|
||||
Model string `json:"model"`
|
||||
PromptTokens int `json:"prompt_tokens"`
|
||||
CompletionTokens int `json:"completion_tokens"`
|
||||
ToolsTokens int `json:"tools_tokens"`
|
||||
TotalTokens int `json:"total_tokens"`
|
||||
Stream bool
|
||||
PromptTokens int `json:"prompt_tokens"`
|
||||
CompletionTokens int `json:"completion_tokens"`
|
||||
ToolsTokens int `json:"tools_tokens"`
|
||||
TotalTokens int `json:"total_tokens"`
|
||||
}
|
||||
|
||||
type ErrorResponse struct {
|
||||
|
||||
@@ -95,6 +95,7 @@ func AuthLLM(db *gorm.DB) gin.HandlerFunc {
|
||||
|
||||
c.Set("user", token.User)
|
||||
c.Set("user_id", token.User.ID)
|
||||
c.Set("token_id", token.ID)
|
||||
c.Set("authed", true)
|
||||
// The token can be validated and its permissions checked here
|
||||
|
||||
|
||||
Reference in New Issue
Block a user