package dao import ( "context" "errors" "opencatd-open/internal/consts" "opencatd-open/internal/model" "time" "gorm.io/gorm" ) // 确保 TokenDAO 实现了 TokenRepository 接口 var _ TokenRepository = (*TokenDAO)(nil) type TokenRepository interface { Create(ctx context.Context, token *model.Token) error GetByID(ctx context.Context, id int64) (*model.Token, error) GetByKey(ctx context.Context, key string) (*model.Token, error) GetByUserID(ctx context.Context, userID int64) (*model.Token, error) Update(ctx context.Context, token *model.Token) error UpdateWithCondition(ctx context.Context, token *model.Token, filters map[string]interface{}, updates map[string]interface{}) error Delete(ctx context.Context, id int64, condition map[string]interface{}) error List(ctx context.Context, limit, offset int) ([]*model.Token, error) ListWithFilters(ctx context.Context, limit, offset int, filters map[string]interface{}) ([]*model.Token, int64, error) Disable(ctx context.Context, id int) error Enable(ctx context.Context, id int) error BatchDisable(ctx context.Context, ids []int64, filters map[string]interface{}) error BatchEnable(ctx context.Context, ids []int64, filters map[string]interface{}) error BatchDelete(ctx context.Context, ids []int64, filters map[string]interface{}) error } type TokenDAO struct { db *gorm.DB } func NewTokenDAO(db *gorm.DB) *TokenDAO { return &TokenDAO{db: db} } // CreateToken 创建 Token func (dao *TokenDAO) Create(ctx context.Context, token *model.Token) error { if token == nil { return errors.New("token is nil") } return dao.db.WithContext(ctx).Create(token).Error } // 根据 ID 获取 Token func (dao *TokenDAO) GetByID(ctx context.Context, id int64) (*model.Token, error) { var token model.Token err := dao.db.WithContext(ctx).Preload("User").First(&token, id).Error if err != nil { return nil, err } return &token, nil } // 根据 Key 获取 Token func (dao *TokenDAO) GetByKey(ctx context.Context, key string) (*model.Token, error) { var token model.Token // err := dao.db.Where("key = ?", key).First(&token).Error err := dao.db.WithContext(ctx).Preload("User").Where("key = ?", key).First(&token).Error if err != nil { return nil, err } return &token, nil } // 根据 UserID 获取 Token func (dao *TokenDAO) GetByUserID(ctx context.Context, userID int64) (*model.Token, error) { var token model.Token err := dao.db.WithContext(ctx).Preload("User").Where("user_id = ?", userID).Find(&token).Error if err != nil { return nil, err } return &token, nil } // UpdateToken 更新 Token 信息 func (dao *TokenDAO) Update(ctx context.Context, token *model.Token) error { if token == nil { return errors.New("token is nil") } return dao.db.WithContext(ctx).Save(token).Error } // UpdateTokenWithFilters 更新 Token 信息,支持过滤 func (dao *TokenDAO) UpdateWithCondition(ctx context.Context, token *model.Token, filters map[string]interface{}, updates map[string]interface{}) error { if token == nil { return errors.New("token is nil") } db := dao.db.WithContext(ctx) for key, value := range filters { db = db.Where(key+" = ?", value) } return db.Model(&model.Token{}).Updates(updates).Error } // DeleteToken 删除 Token func (dao *TokenDAO) Delete(ctx context.Context, id int64, condition map[string]interface{}) error { if id <= 0 { return errors.New("id is invalid") } query := dao.db.WithContext(ctx).Where("id = ?", id) for key, value := range condition { query = query.Where(key, value) } return query.Unscoped().Delete(&model.Token{}).Error } // ListTokens 获取 Token 列表 func (dao *TokenDAO) List(ctx context.Context, limit, offset int) ([]*model.Token, error) { var tokens []*model.Token err := dao.db.WithContext(ctx).Limit(limit).Offset(offset).Find(&tokens).Error if err != nil { return nil, err } return tokens, nil } // ListTokensWithFilters 获取 Token 列表,支持过滤 func (dao *TokenDAO) ListWithFilters(ctx context.Context, limit, offset int, filters map[string]interface{}) ([]*model.Token, int64, error) { var tokens []*model.Token var count int64 db := dao.db.WithContext(ctx) if filters != nil { for k, v := range filters { db = db.Where(k, v) } } if err := db.Limit(limit).Offset(offset).Find(&tokens).Error; err != nil { return nil, 0, err } db.Model(&model.Token{}).Count(&count) return tokens, count, nil } // DisableToken 禁用 Token func (dao *TokenDAO) Disable(ctx context.Context, id int) error { return dao.db.WithContext(ctx).Model(&model.Token{}).Where("id = ?", id).Update("status", false).Error } // EnableToken 启用 Token func (dao *TokenDAO) Enable(ctx context.Context, id int) error { return dao.db.WithContext(ctx).Model(&model.Token{}).Where("id = ?", id).Update("status", true).Error } // BatchDisableTokens 批量禁用 Token func (dao *TokenDAO) BatchDisable(ctx context.Context, ids []int64, filters map[string]interface{}) error { query := dao.db.WithContext(ctx).Model(&model.Token{}).Where("id IN ?", ids) for key, value := range filters { query = query.Where(key, value) } return query.Update("active", false).Error } // BatchEnableTokens 批量启用 Token func (dao *TokenDAO) BatchEnable(ctx context.Context, ids []int64, filters map[string]interface{}) error { query := dao.db.WithContext(ctx).Model(&model.Token{}).Where("id IN ?", ids) for key, value := range filters { query = query.Where(key, value) } return query.Update("active", true).Error } // BatchDeleteTokens 批量删除 Token func (dao *TokenDAO) BatchDelete(ctx context.Context, ids []int64, filters map[string]interface{}) error { query := dao.db.Unscoped().WithContext(ctx).Where("id IN ?", ids) for key, value := range filters { query = query.Where(key, value) } return query.Delete(&model.Token{}).Error // return dao.db.WithContext(ctx).Where("name != 'default' AND id IN ?", ids).Delete(&model.Token{}).Error } // 检查 token 是否有效 func (dao *TokenDAO) IsValid(ctx context.Context, key string) (bool, error) { var token model.Token err := dao.db.WithContext(ctx).Where("key = ? AND status = ? AND (expired_time = -1 OR expired_time > ?)", key, consts.StatusEnabled, time.Now().Unix()).First(&token).Error if err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { return false, nil } return false, err } if token.User.Status != consts.StatusEnabled || (*token.User.UnlimitedQuota && *token.User.Quota <= 0) { return false, nil } return true, nil }