package dao import ( "errors" "opencatd-open/internal/model" "opencatd-open/internal/utils" "opencatd-open/pkg/config" "gorm.io/gorm" ) var _ ApiKeyRepository = (*ApiKeyDAO)(nil) type ApiKeyRepository interface { Create(apiKey *model.ApiKey) error GetByID(id int64) (*model.ApiKey, error) GetByName(name string) (*model.ApiKey, error) GetByApiKey(apiKeyValue string) (*model.ApiKey, error) Update(apiKey *model.ApiKey) error List(limit, offset int, status string) ([]*model.ApiKey, error) ListWithFilters(limit, offset int, filters map[string]interface{}) ([]*model.ApiKey, int64, error) BatchEnable(ids []int64) error BatchDisable(ids []int64) error BatchDelete(ids []int64) error Count() (int64, error) } type ApiKeyDAO struct { cfg *config.Config db *gorm.DB } func NewApiKeyDAO(cfg *config.Config, db *gorm.DB) *ApiKeyDAO { return &ApiKeyDAO{cfg: cfg, db: db} } // CreateApiKey 创建ApiKey func (dao *ApiKeyDAO) Create(apiKey *model.ApiKey) error { if apiKey == nil { return errors.New("apiKey is nil") } if len(*apiKey.SupportModels) < 2 { apiKey.SupportModels = utils.ToPtr("[]") } return dao.db.Create(apiKey).Error } // GetApiKeyByID 根据ID获取ApiKey func (dao *ApiKeyDAO) GetByID(id int64) (*model.ApiKey, error) { var apiKey model.ApiKey err := dao.db.First(&apiKey, id).Error if err != nil { return nil, err } return &apiKey, nil } // GetApiKeyByName 根据名称获取ApiKey func (dao *ApiKeyDAO) GetByName(name string) (*model.ApiKey, error) { var apiKey model.ApiKey err := dao.db.Where("name = ?", name).First(&apiKey).Error if err != nil { return nil, err } return &apiKey, nil } // GetApiKeyByApiKey 根据ApiKey值获取ApiKey func (dao *ApiKeyDAO) GetByApiKey(apiKeyValue string) (*model.ApiKey, error) { var apiKey model.ApiKey err := dao.db.Where("api_key = ?", apiKeyValue).First(&apiKey).Error if err != nil { return nil, err } return &apiKey, nil } func (dao *ApiKeyDAO) FindKeys(condition map[string]any) ([]model.ApiKey, error) { var apiKeys []model.ApiKey query := dao.db.Model(&model.ApiKey{}) for k, v := range condition { query = query.Where(k, v) } err := query.Find(&apiKeys).Error return apiKeys, err } func (dao *ApiKeyDAO) FindApiKeysBySupportModel(db *gorm.DB, modelName string) ([]model.ApiKey, error) { var apiKeys []model.ApiKey switch dao.cfg.DB_Type { case "mysql": err := db.Raw(` SELECT * FROM apikeys WHERE active = true AND JSON_CONTAINS(support_models, ?, '$')`, modelName). Scan(&apiKeys).Error return apiKeys, err case "postgres": return nil, errors.New("not support") } err := db.Raw(` SELECT a.* FROM apikeys a JOIN json_each(a.support_models) AS je ON je.value = ? WHERE a.active = true`, modelName).Scan(&apiKeys).Error return apiKeys, err } // UpdateApiKey 更新ApiKey信息 func (dao *ApiKeyDAO) Update(apiKey *model.ApiKey) error { if apiKey == nil { return errors.New("apiKey is nil") } // return dao.db.Model(&model.ApiKey{}). // Select("name", "apitype", "apikey", "status", "endpoint", "resource_name", "deployment_name").Updates(apiKey).Error return dao.db.Save(apiKey).Error } // DeleteApiKey 删除ApiKey func (dao *ApiKeyDAO) Delete(id int64) error { return dao.db.Unscoped().Delete(&model.ApiKey{}, id).Error } // ListApiKeys 获取ApiKey列表 func (dao *ApiKeyDAO) List(limit, offset int, status string) ([]*model.ApiKey, error) { var apiKeys []*model.ApiKey db := dao.db.Limit(limit).Offset(offset) if status != "" { db = db.Where("status = ?", status) } err := db.Find(&apiKeys).Error if err != nil { return nil, err } return apiKeys, nil } // ListApiKeysWithFilters 根据条件获取ApiKey列表 func (dao *ApiKeyDAO) ListWithFilters(limit, offset int, filters map[string]interface{}) ([]*model.ApiKey, int64, error) { var apiKeys []*model.ApiKey db := dao.db.Limit(limit).Offset(offset) for k, v := range filters { db = db.Where(k, v) } err := db.Find(&apiKeys).Error if err != nil { return nil, 0, err } var total int64 db.Model(&model.ApiKey{}).Count(&total) return apiKeys, total, nil } // BatchEnableApiKeys 批量启用ApiKey func (dao *ApiKeyDAO) BatchEnable(ids []int64) error { if len(ids) == 0 { return errors.New("ids is empty") } return dao.db.Model(&model.ApiKey{}).Where("id IN ?", ids).Update("active", true).Error } // BatchDisableApiKeys 批量禁用ApiKey func (dao *ApiKeyDAO) BatchDisable(ids []int64) error { if len(ids) == 0 { return errors.New("ids is empty") } return dao.db.Model(&model.ApiKey{}).Where("id IN ?", ids).Update("active", false).Error } // BatchDeleteApiKey 批量删除ApiKey func (dao *ApiKeyDAO) BatchDelete(ids []int64) error { if len(ids) == 0 { return errors.New("ids is empty") } return dao.db.Unscoped().Delete(&model.ApiKey{}, ids).Error } // CountApiKeys 获取ApiKey总数 func (dao *ApiKeyDAO) Count() (int64, error) { var count int64 err := dao.db.Model(&model.ApiKey{}).Count(&count).Error if err != nil { return 0, err } return count, nil }