package dao import ( "context" "errors" "opencatd-open/team/consts" "opencatd-open/team/model" "time" "gorm.io/gorm" ) // 确保 TokenDAO 实现了 TokenRepository 接口 var _ TokenRepository = (*TokenDAO)(nil) type TokenRepository interface { Create(ctx context.Context, token *model.Token) error GetByID(ctx context.Context, id int) (*model.Token, error) GetByKey(ctx context.Context, key string) (*model.Token, error) GetByUserID(ctx context.Context, userID int) (*model.Token, error) Update(ctx context.Context, token *model.Token) error UpdateWithCondition(ctx context.Context, token *model.Token, filters map[string]interface{}, updates map[string]interface{}) error Delete(ctx context.Context, id int) error List(ctx context.Context, offset, limit int) ([]model.Token, error) ListWithFilters(ctx context.Context, offset, limit int, filters map[string]interface{}) ([]model.Token, int64, error) Disable(ctx context.Context, id int) error Enable(ctx context.Context, id int) error BatchDisable(ctx context.Context, ids []int) error BatchEnable(ctx context.Context, ids []int) error BatchDelete(ctx context.Context, ids []int) error } type TokenDAO struct { db *gorm.DB } func NewTokenDAO(db *gorm.DB) *TokenDAO { return &TokenDAO{db: db} } // CreateToken 创建 Token func (dao *TokenDAO) Create(ctx context.Context, token *model.Token) error { if token == nil { return errors.New("token is nil") } return dao.db.WithContext(ctx).Create(token).Error } // 根据 ID 获取 Token func (dao *TokenDAO) GetByID(ctx context.Context, id int) (*model.Token, error) { var token model.Token err := dao.db.WithContext(ctx).First(&token, id).Error if err != nil { return nil, err } return &token, nil } // 根据 Key 获取 Token func (dao *TokenDAO) GetByKey(ctx context.Context, key string) (*model.Token, error) { var token model.Token // err := dao.db.Where("key = ?", key).First(&token).Error err := dao.db.WithContext(ctx).Preload("User").Where("key = ?", key).First(&token).Error if err != nil { return nil, err } return &token, nil } // 根据 UserID 获取 Token func (dao *TokenDAO) GetByUserID(ctx context.Context, userID int) (*model.Token, error) { var token model.Token err := dao.db.WithContext(ctx).Preload("User").Where("user_id = ?", userID).Find(&token).Error if err != nil { return nil, err } return &token, nil } // UpdateToken 更新 Token 信息 func (dao *TokenDAO) Update(ctx context.Context, token *model.Token) error { if token == nil { return errors.New("token is nil") } return dao.db.WithContext(ctx).Save(token).Error } // UpdateTokenWithFilters 更新 Token 信息,支持过滤 func (dao *TokenDAO) UpdateWithCondition(ctx context.Context, token *model.Token, filters map[string]interface{}, updates map[string]interface{}) error { if token == nil { return errors.New("token is nil") } db := dao.db.WithContext(ctx) for key, value := range filters { db = db.Where(key+" = ?", value) } return db.Model(&model.Token{}).Updates(updates).Error } // DeleteToken 删除 Token func (dao *TokenDAO) Delete(ctx context.Context, id int) error { return dao.db.WithContext(ctx).Delete(&model.Token{}, id).Error } // ListTokens 获取 Token 列表 func (dao *TokenDAO) List(ctx context.Context, offset, limit int) ([]model.Token, error) { var tokens []model.Token err := dao.db.WithContext(ctx).Offset(offset).Limit(limit).Find(&tokens).Error if err != nil { return nil, err } return tokens, nil } // ListTokensWithFilters 获取 Token 列表,支持过滤 func (dao *TokenDAO) ListWithFilters(ctx context.Context, offset, limit int, filters map[string]interface{}) ([]model.Token, int64, error) { var tokens []model.Token var count int64 db := dao.db.WithContext(ctx) for key, value := range filters { db = db.Where(key+" = ?", value) } if err := db.Offset(offset).Limit(limit).Find(&tokens).Error; err != nil { return nil, 0, err } if err := db.Model(&model.Token{}).Count(&count).Error; err != nil { return nil, 0, err } return tokens, count, nil } // DisableToken 禁用 Token func (dao *TokenDAO) Disable(ctx context.Context, id int) error { return dao.db.WithContext(ctx).Model(&model.Token{}).Where("id = ?", id).Update("status", false).Error } // EnableToken 启用 Token func (dao *TokenDAO) Enable(ctx context.Context, id int) error { return dao.db.WithContext(ctx).Model(&model.Token{}).Where("id = ?", id).Update("status", true).Error } // BatchDisableTokens 批量禁用 Token func (dao *TokenDAO) BatchDisable(ctx context.Context, ids []int) error { return dao.db.WithContext(ctx).Model(&model.Token{}).Where("id IN ?", ids).Update("status", false).Error } // BatchEnableTokens 批量启用 Token func (dao *TokenDAO) BatchEnable(ctx context.Context, ids []int) error { return dao.db.WithContext(ctx).Model(&model.Token{}).Where("id IN ?", ids).Update("status", true).Error } // BatchDeleteTokens 批量删除 Token func (dao *TokenDAO) BatchDelete(ctx context.Context, ids []int) error { return dao.db.WithContext(ctx).Where("id IN ?", ids).Delete(&model.Token{}).Error } // 检查 token 是否有效 func (dao *TokenDAO) IsValid(ctx context.Context, key string) (bool, error) { var token model.Token err := dao.db.WithContext(ctx).Where("key = ? AND status = ? AND (expired_time = -1 OR expired_time > ?)", key, consts.StatusEnabled, time.Now().Unix()).First(&token).Error if err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { return false, nil } return false, err } if token.User.Status != consts.StatusEnabled || (token.User.UnlimitedQuota == 1 && token.User.Quota <= 0) { return false, nil } return true, nil }