package dao

import (
	"context"
	"errors"
	"fmt"
	"time"

	dto "opencatd-open/internal/dto/team"
	"opencatd-open/internal/model"
	"opencatd-open/pkg/config"

	"gorm.io/gorm"
	"gorm.io/gorm/clause"
)

// Compile-time checks that the DAOs satisfy their repository interfaces.
var (
	_ UsageRepository      = (*UsageDAO)(nil)
	_ DailyUsageRepository = (*DailyUsageDAO)(nil)
)

// UsageRepository is the persistence contract for individual usage records.
type UsageRepository interface {
	// Create
	Create(ctx context.Context, usage *model.Usage) error
	BatchCreate(ctx context.Context, usages []*model.Usage) error

	// Read
	ListByUserID(ctx context.Context, userID int64, limit, offset int) ([]*model.Usage, error)
	ListByTokenID(ctx context.Context, tokenID int64, limit, offset int) ([]*model.Usage, error)
	ListByDateRange(ctx context.Context, start, end time.Time) ([]*model.Usage, error)
	ListByCapability(ctx context.Context, capability string, limit, offset int) ([]*model.Usage, error)

	// Delete
	Delete(ctx context.Context, id int64) error

	// Statistics
	CountByUserID(ctx context.Context, userID int64) (int64, error)
}

// DailyUsageRepository is the persistence contract for per-day aggregated
// usage records.
type DailyUsageRepository interface {
	// Create
	Create(ctx context.Context, usage *model.DailyUsage) error
	BatchCreate(ctx context.Context, usages []*model.DailyUsage) error

	// Read
	ListByUserID(ctx context.Context, userID int64, limit, offset int) ([]*model.DailyUsage, error)
	ListByTokenID(ctx context.Context, tokenID int64, limit, offset int) ([]*model.DailyUsage, error)
	ListByDateRange(ctx context.Context, start, end time.Time) ([]*model.DailyUsage, error)
	GetByDate(ctx context.Context, userID int64, date time.Time) (*model.DailyUsage, error)

	// Delete
	Delete(ctx context.Context, id int64) error

	// Statistics
	CountByUserID(ctx context.Context, userID int64) (int64, error)
	StatUserUsages(ctx context.Context, start, end time.Time, filters map[string]interface{}) ([]*dto.UsageInfo, error)
}

// UsageDAO implements UsageRepository on top of a GORM handle.
type UsageDAO struct {
	db *gorm.DB
}

// DailyUsageDAO implements DailyUsageRepository on top of a GORM handle.
// cfg is consulted by UpsertDailyUsage to pick the SQL dialect.
type DailyUsageDAO struct {
	cfg *config.Config
	db  *gorm.DB
}

// NewUsageDAO returns a UsageDAO bound to db. cfg is accepted for signature
// parity with NewDailyUsageDAO and is not used.
func NewUsageDAO(cfg *config.Config, db *gorm.DB) *UsageDAO {
	return &UsageDAO{db: db}
}

// NewDailyUsageDAO returns a DailyUsageDAO bound to db and cfg.
//
// cfg must be stored: UpsertDailyUsage reads cfg.DB_Type, and leaving the
// field nil made that method panic on a nil pointer dereference.
func NewDailyUsageDAO(cfg *config.Config, db *gorm.DB) *DailyUsageDAO {
	return &DailyUsageDAO{cfg: cfg, db: db}
}

// Usage DAO implementations

// Create inserts a single usage record.
func (d *UsageDAO) Create(ctx context.Context, usage *model.Usage) error {
	return d.db.WithContext(ctx).Create(usage).Error
}

// BatchCreate inserts all given usage records with one Create call.
// An empty slice is a no-op; GORM's Create reports an error for empty slices.
func (d *UsageDAO) BatchCreate(ctx context.Context, usages []*model.Usage) error {
	if len(usages) == 0 {
		return nil
	}
	return d.db.WithContext(ctx).Create(usages).Error
}

// GetByID loads the usage record with the given primary key. The error
// reported by GORM's First (including the not-found case) is returned as is.
func (d *UsageDAO) GetByID(ctx context.Context, id int64) (*model.Usage, error) {
	var usage model.Usage
	if err := d.db.WithContext(ctx).First(&usage, id).Error; err != nil {
		return nil, err
	}
	return &usage, nil
}

// ListByUserID returns one page (limit/offset) of usage records whose
// user_id equals userID.
func (d *UsageDAO) ListByUserID(ctx context.Context, userID int64, limit, offset int) ([]*model.Usage, error) {
	var usages []*model.Usage
	err := d.db.WithContext(ctx).
		Where("user_id = ?", userID).
		Limit(limit).
		Offset(offset).
		Find(&usages).Error
	return usages, err
}

// ListByTokenID returns one page (limit/offset) of usage records whose
// token_id equals tokenID.
func (d *UsageDAO) ListByTokenID(ctx context.Context, tokenID int64, limit, offset int) ([]*model.Usage, error) {
	var usages []*model.Usage
	err := d.db.WithContext(ctx).
		Where("token_id = ?", tokenID).
		Limit(limit).
		Offset(offset).
		Find(&usages).Error
	return usages, err
}

// ListByDateRange returns every usage record whose date lies between start
// and end (SQL BETWEEN, so both bounds are inclusive). No paging is applied.
func (d *UsageDAO) ListByDateRange(ctx context.Context, start, end time.Time) ([]*model.Usage, error) {
	var usages []*model.Usage
	err := d.db.WithContext(ctx).
		Where("date BETWEEN ? AND ?", start, end).
		Find(&usages).Error
	return usages, err
}

// ListByCapability returns one page (limit/offset) of usage records whose
// capability equals the given value.
func (d *UsageDAO) ListByCapability(ctx context.Context, capability string, limit, offset int) ([]*model.Usage, error) {
	var usages []*model.Usage
	err := d.db.WithContext(ctx).
		Where("capability = ?", capability).
		Limit(limit).
		Offset(offset).
		Find(&usages).Error
	return usages, err
}

// Delete removes the usage record with the given primary key.
func (d *UsageDAO) Delete(ctx context.Context, id int64) error {
	return d.db.WithContext(ctx).Delete(&model.Usage{}, id).Error
}

// CountByUserID returns the number of usage records belonging to userID.
func (d *UsageDAO) CountByUserID(ctx context.Context, userID int64) (int64, error) {
	var count int64
	err := d.db.WithContext(ctx).
		Model(&model.Usage{}).
		Where("user_id = ?", userID).
		Count(&count).Error
	return count, err
}

// DailyUsage DAO implementations

// Create inserts a single daily usage record.
func (d *DailyUsageDAO) Create(ctx context.Context, usage *model.DailyUsage) error {
	return d.db.WithContext(ctx).Create(usage).Error
}

// BatchCreate inserts all given daily usage records with one Create call.
// An empty slice is a no-op; GORM's Create reports an error for empty slices.
func (d *DailyUsageDAO) BatchCreate(ctx context.Context, usages []*model.DailyUsage) error {
	if len(usages) == 0 {
		return nil
	}
	return d.db.WithContext(ctx).Create(usages).Error
}

// ListByUserID returns one page (limit/offset) of daily usage records whose
// user_id equals userID.
func (d *DailyUsageDAO) ListByUserID(ctx context.Context, userID int64, limit, offset int) ([]*model.DailyUsage, error) {
	var usages []*model.DailyUsage
	err := d.db.WithContext(ctx).
		Where("user_id = ?", userID).
		Limit(limit).
		Offset(offset).
		Find(&usages).Error
	return usages, err
}

// ListByTokenID returns one page (limit/offset) of daily usage records whose
// token_id equals tokenID.
func (d *DailyUsageDAO) ListByTokenID(ctx context.Context, tokenID int64, limit, offset int) ([]*model.DailyUsage, error) {
	var usages []*model.DailyUsage
	err := d.db.WithContext(ctx).
		Where("token_id = ?", tokenID).
		Limit(limit).
		Offset(offset).
		Find(&usages).Error
	return usages, err
}

// ListByDateRange returns every daily usage record whose date lies between
// start and end (SQL BETWEEN, so both bounds are inclusive). No paging is
// applied.
func (d *DailyUsageDAO) ListByDateRange(ctx context.Context, start, end time.Time) ([]*model.DailyUsage, error) {
	var usages []*model.DailyUsage
	err := d.db.WithContext(ctx).
		Where("date BETWEEN ? AND ?", start, end).
		Find(&usages).Error
	return usages, err
}

// GetByDate loads the first daily usage record matching userID and the exact
// date value. The error reported by GORM's First (including the not-found
// case) is returned as is.
func (d *DailyUsageDAO) GetByDate(ctx context.Context, userID int64, date time.Time) (*model.DailyUsage, error) {
	var usage model.DailyUsage
	err := d.db.WithContext(ctx).
		Where("user_id = ? AND date = ?", userID, date).
		First(&usage).Error
	if err != nil {
		return nil, err
	}
	return &usage, nil
}

// UpsertDailyUsage folds a single usage record into the daily aggregate row
// identified by (user_id, date), choosing the upsert strategy from the
// configured database type:
//
//   - "mysql":    INSERT ... ON DUPLICATE KEY UPDATE
//   - "postgres": INSERT ... ON CONFLICT DO UPDATE
//   - otherwise (including "sqlite" or a nil config): a transaction that
//     looks the row up and then creates or updates it.
//
// On conflict only prompt_tokens, completion_tokens and total_tokens are
// accumulated.
//
// NOTE(review): cost is written on first insert but not accumulated on
// conflict; confirm whether that is intended before relying on daily cost.
func (d *DailyUsageDAO) UpsertDailyUsage(ctx context.Context, usage *model.Usage) error {
	// Truncate operates on absolute time, so the day boundary is the UTC one.
	date := usage.Date.Truncate(24 * time.Hour)

	dailyUsage := &model.DailyUsage{
		UserID:     usage.UserID,
		TokenID:    usage.TokenID,
		Capability: usage.Capability,
		// The truncated date must be stored on the row; otherwise neither the
		// (user_id, date) conflict key nor the fallback lookup can match it.
		Date:             date,
		Model:            usage.Model,
		Stream:           usage.Stream,
		PromptTokens:     usage.PromptTokens,
		CompletionTokens: usage.CompletionTokens,
		TotalTokens:      usage.TotalTokens,
		Cost:             usage.Cost,
	}

	conflictKey := []clause.Column{{Name: "user_id"}, {Name: "date"}}
	db := d.db.WithContext(ctx)

	switch {
	case d.cfg != nil && d.cfg.DB_Type == "mysql":
		// MySQL: INSERT ... ON DUPLICATE KEY UPDATE.
		// VALUES(col) refers to the value the INSERT attempted to write.
		updates := map[string]interface{}{
			"prompt_tokens":     gorm.Expr("prompt_tokens + VALUES(prompt_tokens)"),
			"completion_tokens": gorm.Expr("completion_tokens + VALUES(completion_tokens)"),
			"total_tokens":      gorm.Expr("total_tokens + VALUES(total_tokens)"),
		}
		return db.Clauses(clause.OnConflict{
			Columns:   conflictKey,
			DoUpdates: clause.Assignments(updates),
		}).Create(dailyUsage).Error

	case d.cfg != nil && d.cfg.DB_Type == "postgres":
		// PostgreSQL: INSERT ... ON CONFLICT DO UPDATE.
		// EXCLUDED refers to the row proposed for insertion.
		updates := map[string]interface{}{
			"prompt_tokens":     gorm.Expr("daily_usages.prompt_tokens + EXCLUDED.prompt_tokens"),
			"completion_tokens": gorm.Expr("daily_usages.completion_tokens + EXCLUDED.completion_tokens"),
			"total_tokens":      gorm.Expr("daily_usages.total_tokens + EXCLUDED.total_tokens"),
		}
		return db.Clauses(clause.OnConflict{
			Columns:   conflictKey,
			DoUpdates: clause.Assignments(updates),
		}).Create(dailyUsage).Error

	default:
		// "sqlite" and any unknown or unset database type: emulate the upsert
		// with a lookup followed by create-or-update inside one transaction.
		return db.Transaction(func(tx *gorm.DB) error {
			var existing model.DailyUsage
			err := tx.Where("user_id = ? AND date = ?", usage.UserID, date).
				First(&existing).Error
			if errors.Is(err, gorm.ErrRecordNotFound) {
				// No row for this user and day yet: create it.
				return tx.Create(dailyUsage).Error
			}
			if err != nil {
				// Any other lookup failure aborts the transaction.
				return err
			}

			// Row exists: add this usage's token counts to the stored totals.
			return tx.Model(&existing).Updates(map[string]interface{}{
				"prompt_tokens":     gorm.Expr("prompt_tokens + ?", usage.PromptTokens),
				"completion_tokens": gorm.Expr("completion_tokens + ?", usage.CompletionTokens),
				"total_tokens":      gorm.Expr("total_tokens + ?", usage.TotalTokens),
			}).Error
		})
	}
}

// Delete removes the daily usage record with the given primary key.
func (d *DailyUsageDAO) Delete(ctx context.Context, id int64) error {
	return d.db.WithContext(ctx).Delete(&model.DailyUsage{}, id).Error
}

// CountByUserID returns the number of daily usage records belonging to userID.
func (d *DailyUsageDAO) CountByUserID(ctx context.Context, userID int64) (int64, error) {
	var count int64
	err := d.db.WithContext(ctx).
		Model(&model.DailyUsage{}).
		Where("user_id = ?", userID).
		Count(&count).Error
	return count, err
}

// StatUserUsages aggregates daily usage per user for rows whose date lies in
// [from, to] (both inclusive), summing total_tokens and cost.
//
// Each entry in filters adds a "column = value" condition. Keys are treated
// as column names and quoted by GORM rather than spliced into the SQL text,
// so a key cannot inject arbitrary SQL.
func (d *DailyUsageDAO) StatUserUsages(ctx context.Context, from, to time.Time, filters map[string]interface{}) ([]*dto.UsageInfo, error) {
	query := d.db.WithContext(ctx).
		Model(&model.DailyUsage{}).
		Select("user_id as userId, sum(total_tokens) as totalUnit, sum(cast(cost as decimal(20,6))) as cost").
		Where("date >= ? AND date <= ?", from, to)

	for key, value := range filters {
		query = query.Where(clause.Eq{Column: clause.Column{Name: key}, Value: value})
	}

	var usages []*dto.UsageInfo
	if err := query.Group("user_id").Find(&usages).Error; err != nil {
		return nil, fmt.Errorf("failed to list usages: %w", err)
	}
	return usages, nil
}