288 lines
8.9 KiB
Go
288 lines
8.9 KiB
Go
package dao
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
dto "opencatd-open/internal/dto/team"
|
|
"opencatd-open/internal/model"
|
|
"opencatd-open/pkg/config"
|
|
"time"
|
|
|
|
"gorm.io/gorm"
|
|
"gorm.io/gorm/clause"
|
|
)
|
|
|
|
var _ UsageRepository = (*UsageDAO)(nil)
|
|
var _ DailyUsageRepository = (*DailyUsageDAO)(nil)
|
|
|
|
type UsageRepository interface {
|
|
// Create
|
|
Create(ctx context.Context, usage *model.Usage) error
|
|
BatchCreate(ctx context.Context, usages []*model.Usage) error
|
|
|
|
// Read
|
|
ListByUserID(ctx context.Context, userID int64, limit, offset int) ([]*model.Usage, error)
|
|
ListByTokenID(ctx context.Context, tokenID int64, limit, offset int) ([]*model.Usage, error)
|
|
ListByDateRange(ctx context.Context, start, end time.Time) ([]*model.Usage, error)
|
|
ListByCapability(ctx context.Context, capability string, limit, offset int) ([]*model.Usage, error)
|
|
|
|
// Delete
|
|
Delete(ctx context.Context, id int64) error
|
|
|
|
// Statistics
|
|
CountByUserID(ctx context.Context, userID int64) (int64, error)
|
|
}
|
|
|
|
type DailyUsageRepository interface {
|
|
// Create
|
|
Create(ctx context.Context, usage *model.DailyUsage) error
|
|
BatchCreate(ctx context.Context, usages []*model.DailyUsage) error
|
|
|
|
// Read
|
|
ListByUserID(ctx context.Context, userID int64, limit, offset int) ([]*model.DailyUsage, error)
|
|
ListByTokenID(ctx context.Context, tokenID int64, limit, offset int) ([]*model.DailyUsage, error)
|
|
ListByDateRange(ctx context.Context, start, end time.Time) ([]*model.DailyUsage, error)
|
|
GetByDate(ctx context.Context, userID int64, date time.Time) (*model.DailyUsage, error)
|
|
|
|
// Delete
|
|
Delete(ctx context.Context, id int64) error
|
|
|
|
// Statistics
|
|
CountByUserID(ctx context.Context, userID int64) (int64, error)
|
|
StatUserUsages(ctx context.Context, start, end time.Time, filters map[string]interface{}) ([]*dto.UsageInfo, error)
|
|
}
|
|
|
|
type UsageDAO struct {
|
|
db *gorm.DB
|
|
}
|
|
|
|
type DailyUsageDAO struct {
|
|
cfg *config.Config
|
|
db *gorm.DB
|
|
}
|
|
|
|
func NewUsageDAO(cfg *config.Config, db *gorm.DB) *UsageDAO {
|
|
return &UsageDAO{db: db}
|
|
}
|
|
|
|
func NewDailyUsageDAO(cfg *config.Config, db *gorm.DB) *DailyUsageDAO {
|
|
return &DailyUsageDAO{db: db}
|
|
}
|
|
|
|
// Usage DAO implementations
|
|
func (d *UsageDAO) Create(ctx context.Context, usage *model.Usage) error {
|
|
return d.db.WithContext(ctx).Create(usage).Error
|
|
}
|
|
|
|
func (d *UsageDAO) BatchCreate(ctx context.Context, usages []*model.Usage) error {
|
|
return d.db.WithContext(ctx).Create(usages).Error
|
|
}
|
|
|
|
func (d *UsageDAO) GetByID(ctx context.Context, id int64) (*model.Usage, error) {
|
|
var usage model.Usage
|
|
err := d.db.WithContext(ctx).First(&usage, id).Error
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return &usage, nil
|
|
}
|
|
|
|
func (d *UsageDAO) ListByUserID(ctx context.Context, userID int64, limit, offset int) ([]*model.Usage, error) {
|
|
var usages []*model.Usage
|
|
err := d.db.WithContext(ctx).
|
|
Where("user_id = ?", userID).
|
|
Limit(limit).
|
|
Offset(offset).
|
|
Find(&usages).Error
|
|
return usages, err
|
|
}
|
|
|
|
func (d *UsageDAO) ListByTokenID(ctx context.Context, tokenID int64, limit, offset int) ([]*model.Usage, error) {
|
|
var usages []*model.Usage
|
|
err := d.db.WithContext(ctx).
|
|
Where("token_id = ?", tokenID).
|
|
Limit(limit).
|
|
Offset(offset).
|
|
Find(&usages).Error
|
|
return usages, err
|
|
}
|
|
|
|
func (d *UsageDAO) ListByDateRange(ctx context.Context, start, end time.Time) ([]*model.Usage, error) {
|
|
var usages []*model.Usage
|
|
err := d.db.WithContext(ctx).
|
|
Where("date BETWEEN ? AND ?", start, end).
|
|
Find(&usages).Error
|
|
return usages, err
|
|
}
|
|
|
|
func (d *UsageDAO) ListByCapability(ctx context.Context, capability string, limit, offset int) ([]*model.Usage, error) {
|
|
var usages []*model.Usage
|
|
err := d.db.WithContext(ctx).
|
|
Where("capability = ?", capability).
|
|
Limit(limit).
|
|
Offset(offset).
|
|
Find(&usages).Error
|
|
return usages, err
|
|
}
|
|
|
|
func (d *UsageDAO) Delete(ctx context.Context, id int64) error {
|
|
return d.db.WithContext(ctx).Delete(&model.Usage{}, id).Error
|
|
}
|
|
|
|
func (d *UsageDAO) CountByUserID(ctx context.Context, userID int64) (int64, error) {
|
|
var count int64
|
|
err := d.db.WithContext(ctx).Model(&model.Usage{}).Where("user_id = ?", userID).Count(&count).Error
|
|
return count, err
|
|
}
|
|
|
|
// DailyUsage DAO implementations
|
|
func (d *DailyUsageDAO) Create(ctx context.Context, usage *model.DailyUsage) error {
|
|
return d.db.WithContext(ctx).Create(usage).Error
|
|
}
|
|
|
|
func (d *DailyUsageDAO) BatchCreate(ctx context.Context, usages []*model.DailyUsage) error {
|
|
return d.db.WithContext(ctx).Create(usages).Error
|
|
}
|
|
|
|
func (d *DailyUsageDAO) ListByUserID(ctx context.Context, userID int64, limit, offset int) ([]*model.DailyUsage, error) {
|
|
var usages []*model.DailyUsage
|
|
err := d.db.WithContext(ctx).
|
|
Where("user_id = ?", userID).
|
|
Limit(limit).
|
|
Offset(offset).
|
|
Find(&usages).Error
|
|
return usages, err
|
|
}
|
|
|
|
func (d *DailyUsageDAO) ListByTokenID(ctx context.Context, tokenID int64, limit, offset int) ([]*model.DailyUsage, error) {
|
|
var usages []*model.DailyUsage
|
|
err := d.db.WithContext(ctx).
|
|
Where("token_id = ?", tokenID).
|
|
Limit(limit).
|
|
Offset(offset).
|
|
Find(&usages).Error
|
|
return usages, err
|
|
}
|
|
|
|
func (d *DailyUsageDAO) ListByDateRange(ctx context.Context, start, end time.Time) ([]*model.DailyUsage, error) {
|
|
var usages []*model.DailyUsage
|
|
err := d.db.WithContext(ctx).
|
|
Where("date BETWEEN ? AND ?", start, end).
|
|
Find(&usages).Error
|
|
return usages, err
|
|
}
|
|
|
|
func (d *DailyUsageDAO) GetByDate(ctx context.Context, userID int64, date time.Time) (*model.DailyUsage, error) {
|
|
var usage model.DailyUsage
|
|
err := d.db.WithContext(ctx).
|
|
Where("user_id = ? AND date = ?", userID, date).
|
|
First(&usage).Error
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return &usage, nil
|
|
}
|
|
|
|
// UpsertDailyUsage 根据不同数据库类型执行 Upsert
|
|
func (d *DailyUsageDAO) UpsertDailyUsage(ctx context.Context, usage *model.Usage) error {
|
|
date := usage.Date.Truncate(24 * time.Hour)
|
|
dailyUsage := &model.DailyUsage{
|
|
UserID: usage.UserID,
|
|
TokenID: usage.TokenID,
|
|
Capability: usage.Capability,
|
|
Model: usage.Model,
|
|
Stream: usage.Stream,
|
|
PromptTokens: usage.PromptTokens,
|
|
CompletionTokens: usage.CompletionTokens,
|
|
TotalTokens: usage.TotalTokens,
|
|
Cost: usage.Cost,
|
|
}
|
|
|
|
updateColumns := map[string]interface{}{
|
|
"prompt_tokens": gorm.Expr("prompt_tokens + VALUES(prompt_tokens)"),
|
|
"completion_tokens": gorm.Expr("completion_tokens + VALUES(completion_tokens)"),
|
|
"total_tokens": gorm.Expr("total_tokens + VALUES(total_tokens)"),
|
|
}
|
|
|
|
db := d.db.WithContext(ctx)
|
|
|
|
switch d.cfg.DB_Type {
|
|
case "mysql":
|
|
// MySQL: INSERT ... ON DUPLICATE KEY UPDATE
|
|
return db.Clauses(clause.OnConflict{
|
|
Columns: []clause.Column{
|
|
{Name: "user_id"},
|
|
{Name: "date"},
|
|
},
|
|
DoUpdates: clause.Assignments(updateColumns),
|
|
}).Create(dailyUsage).Error
|
|
|
|
case "postgres":
|
|
// PostgreSQL: INSERT ... ON CONFLICT DO UPDATE
|
|
updateColumns := map[string]interface{}{
|
|
"prompt_tokens": gorm.Expr("daily_usages.prompt_tokens + EXCLUDED.prompt_tokens"),
|
|
"completion_tokens": gorm.Expr("daily_usages.completion_tokens + EXCLUDED.completion_tokens"),
|
|
"total_tokens": gorm.Expr("daily_usages.total_tokens + EXCLUDED.total_tokens"),
|
|
}
|
|
return db.Clauses(clause.OnConflict{
|
|
Columns: []clause.Column{
|
|
{Name: "user_id"},
|
|
{Name: "date"},
|
|
},
|
|
DoUpdates: clause.Assignments(updateColumns),
|
|
}).Create(dailyUsage).Error
|
|
case "sqlite":
|
|
fallthrough
|
|
default:
|
|
return db.Transaction(func(tx *gorm.DB) error {
|
|
var existing model.DailyUsage
|
|
err := tx.Where("user_id = ? AND date = ?",
|
|
usage.UserID, date).
|
|
First(&existing).Error
|
|
|
|
if err == gorm.ErrRecordNotFound {
|
|
// 记录不存在,创建新记录
|
|
return tx.Create(dailyUsage).Error
|
|
} else if err != nil {
|
|
return err // 返回其他错误
|
|
}
|
|
|
|
// 记录存在,更新
|
|
return tx.Model(&existing).Updates(map[string]interface{}{
|
|
"prompt_tokens": gorm.Expr("prompt_tokens + ?", usage.PromptTokens),
|
|
"completion_tokens": gorm.Expr("completion_tokens + ?", usage.CompletionTokens),
|
|
"total_tokens": gorm.Expr("total_tokens + ?", usage.TotalTokens),
|
|
}).Error
|
|
})
|
|
}
|
|
}
|
|
|
|
func (d *DailyUsageDAO) Delete(ctx context.Context, id int64) error {
|
|
return d.db.WithContext(ctx).Delete(&model.DailyUsage{}, id).Error
|
|
}
|
|
|
|
func (d *DailyUsageDAO) CountByUserID(ctx context.Context, userID int64) (int64, error) {
|
|
var count int64
|
|
err := d.db.WithContext(ctx).Model(&model.DailyUsage{}).Where("user_id = ?", userID).Count(&count).Error
|
|
return count, err
|
|
}
|
|
|
|
func (d *DailyUsageDAO) StatUserUsages(ctx context.Context, from, to time.Time, filters map[string]interface{}) ([]*dto.UsageInfo, error) {
|
|
var usages []*dto.UsageInfo
|
|
|
|
query := d.db.WithContext(ctx).
|
|
Model(&model.DailyUsage{}).
|
|
Select("user_id as userId, sum(total_tokens) as totalUnit, sum(cast(cost as decimal(20,6))) as cost")
|
|
for key, value := range filters {
|
|
query = query.Where(fmt.Sprintf("%s = ?", key), value)
|
|
}
|
|
query = query.Group("user_id").Where("date >= ? AND date <= ?", from, to)
|
|
|
|
err := query.Group("user_id").Find(&usages).Error
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to list usages: %w", err)
|
|
}
|
|
|
|
return usages, nil
|
|
}
|