reface to openteam
This commit is contained in:
179
internal/dao/apikey.go
Normal file
179
internal/dao/apikey.go
Normal file
@@ -0,0 +1,179 @@
|
||||
package dao
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"opencatd-open/internal/model"
|
||||
"opencatd-open/pkg/config"
|
||||
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
var _ ApiKeyRepository = (*ApiKeyDAO)(nil)
|
||||
|
||||
type ApiKeyRepository interface {
|
||||
Create(apiKey *model.ApiKey) error
|
||||
GetByID(id int64) (*model.ApiKey, error)
|
||||
GetByName(name string) (*model.ApiKey, error)
|
||||
GetByApiKey(apiKeyValue string) (*model.ApiKey, error)
|
||||
Update(apiKey *model.ApiKey) error
|
||||
List(limit, offset int, status string) ([]*model.ApiKey, error)
|
||||
ListWithFilters(limit, offset int, filters map[string]interface{}) ([]*model.ApiKey, int64, error)
|
||||
BatchEnable(ids []int64) error
|
||||
BatchDisable(ids []int64) error
|
||||
BatchDelete(ids []int64) error
|
||||
Count() (int64, error)
|
||||
}
|
||||
|
||||
type ApiKeyDAO struct {
|
||||
cfg *config.Config
|
||||
db *gorm.DB
|
||||
}
|
||||
|
||||
func NewApiKeyDAO(db *gorm.DB) *ApiKeyDAO {
|
||||
return &ApiKeyDAO{db: db}
|
||||
}
|
||||
|
||||
// CreateApiKey 创建ApiKey
|
||||
func (dao *ApiKeyDAO) Create(apiKey *model.ApiKey) error {
|
||||
if apiKey == nil {
|
||||
return errors.New("apiKey is nil")
|
||||
}
|
||||
return dao.db.Create(apiKey).Error
|
||||
}
|
||||
|
||||
// GetApiKeyByID 根据ID获取ApiKey
|
||||
func (dao *ApiKeyDAO) GetByID(id int64) (*model.ApiKey, error) {
|
||||
var apiKey model.ApiKey
|
||||
err := dao.db.First(&apiKey, id).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &apiKey, nil
|
||||
}
|
||||
|
||||
// GetApiKeyByName 根据名称获取ApiKey
|
||||
func (dao *ApiKeyDAO) GetByName(name string) (*model.ApiKey, error) {
|
||||
var apiKey model.ApiKey
|
||||
err := dao.db.Where("name = ?", name).First(&apiKey).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &apiKey, nil
|
||||
}
|
||||
|
||||
// GetApiKeyByApiKey 根据ApiKey值获取ApiKey
|
||||
func (dao *ApiKeyDAO) GetByApiKey(apiKeyValue string) (*model.ApiKey, error) {
|
||||
var apiKey model.ApiKey
|
||||
err := dao.db.Where("api_key = ?", apiKeyValue).First(&apiKey).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &apiKey, nil
|
||||
}
|
||||
|
||||
func (dao *ApiKeyDAO) FindKeys(condition map[string]any) ([]model.ApiKey, error) {
|
||||
var apiKeys []model.ApiKey
|
||||
|
||||
query := dao.db.Model(&model.ApiKey{})
|
||||
for k, v := range condition {
|
||||
query = query.Where(k, v)
|
||||
}
|
||||
err := query.Find(&apiKeys).Error
|
||||
|
||||
return apiKeys, err
|
||||
}
|
||||
|
||||
func (dao *ApiKeyDAO) FindApiKeysBySupportModel(db *gorm.DB, modelName string) ([]model.ApiKey, error) {
|
||||
var apiKeys []model.ApiKey
|
||||
switch dao.cfg.DB_Type {
|
||||
case "mysql":
|
||||
return nil, errors.New("not support")
|
||||
case "postgres":
|
||||
return nil, errors.New("not support")
|
||||
}
|
||||
err := db.Model(&model.ApiKey{}).
|
||||
Joins("CROSS JOIN JSON_EACH(apikeys.support_models)").
|
||||
Where("value = ?", modelName).
|
||||
Find(&apiKeys).Error
|
||||
return apiKeys, err
|
||||
}
|
||||
|
||||
// UpdateApiKey 更新ApiKey信息
|
||||
func (dao *ApiKeyDAO) Update(apiKey *model.ApiKey) error {
|
||||
if apiKey == nil {
|
||||
return errors.New("apiKey is nil")
|
||||
}
|
||||
// return dao.db.Model(&model.ApiKey{}).
|
||||
// Select("name", "apitype", "apikey", "status", "endpoint", "resource_name", "deployment_name").Updates(apiKey).Error
|
||||
return dao.db.Save(apiKey).Error
|
||||
}
|
||||
|
||||
// DeleteApiKey 删除ApiKey
|
||||
func (dao *ApiKeyDAO) Delete(id int64) error {
|
||||
return dao.db.Unscoped().Delete(&model.ApiKey{}, id).Error
|
||||
}
|
||||
|
||||
// ListApiKeys 获取ApiKey列表
|
||||
func (dao *ApiKeyDAO) List(limit, offset int, status string) ([]*model.ApiKey, error) {
|
||||
var apiKeys []*model.ApiKey
|
||||
db := dao.db.Limit(limit).Offset(offset)
|
||||
if status != "" {
|
||||
db = db.Where("status = ?", status)
|
||||
}
|
||||
err := db.Find(&apiKeys).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return apiKeys, nil
|
||||
}
|
||||
|
||||
// ListApiKeysWithFilters 根据条件获取ApiKey列表
|
||||
func (dao *ApiKeyDAO) ListWithFilters(limit, offset int, filters map[string]interface{}) ([]*model.ApiKey, int64, error) {
|
||||
var apiKeys []*model.ApiKey
|
||||
db := dao.db.Limit(limit).Offset(offset)
|
||||
for k, v := range filters {
|
||||
db = db.Where(k, v)
|
||||
}
|
||||
err := db.Find(&apiKeys).Error
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
var total int64
|
||||
db.Model(&model.ApiKey{}).Count(&total)
|
||||
|
||||
return apiKeys, total, nil
|
||||
}
|
||||
|
||||
// BatchEnableApiKeys 批量启用ApiKey
|
||||
func (dao *ApiKeyDAO) BatchEnable(ids []int64) error {
|
||||
if len(ids) == 0 {
|
||||
return errors.New("ids is empty")
|
||||
}
|
||||
return dao.db.Model(&model.ApiKey{}).Where("id IN ?", ids).Update("active", true).Error
|
||||
}
|
||||
|
||||
// BatchDisableApiKeys 批量禁用ApiKey
|
||||
func (dao *ApiKeyDAO) BatchDisable(ids []int64) error {
|
||||
if len(ids) == 0 {
|
||||
return errors.New("ids is empty")
|
||||
}
|
||||
return dao.db.Model(&model.ApiKey{}).Where("id IN ?", ids).Update("active", false).Error
|
||||
}
|
||||
|
||||
// BatchDeleteApiKey 批量删除ApiKey
|
||||
func (dao *ApiKeyDAO) BatchDelete(ids []int64) error {
|
||||
if len(ids) == 0 {
|
||||
return errors.New("ids is empty")
|
||||
}
|
||||
return dao.db.Unscoped().Delete(&model.ApiKey{}, ids).Error
|
||||
}
|
||||
|
||||
// CountApiKeys 获取ApiKey总数
|
||||
func (dao *ApiKeyDAO) Count() (int64, error) {
|
||||
var count int64
|
||||
err := dao.db.Model(&model.ApiKey{}).Count(&count).Error
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return count, nil
|
||||
}
|
||||
198
internal/dao/token.go
Normal file
198
internal/dao/token.go
Normal file
@@ -0,0 +1,198 @@
|
||||
package dao
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"opencatd-open/internal/consts"
|
||||
"opencatd-open/internal/model"
|
||||
"time"
|
||||
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// 确保 TokenDAO 实现了 TokenRepository 接口
|
||||
var _ TokenRepository = (*TokenDAO)(nil)
|
||||
|
||||
type TokenRepository interface {
|
||||
Create(ctx context.Context, token *model.Token) error
|
||||
GetByID(ctx context.Context, id int64) (*model.Token, error)
|
||||
GetByKey(ctx context.Context, key string) (*model.Token, error)
|
||||
GetByUserID(ctx context.Context, userID int64) (*model.Token, error)
|
||||
Update(ctx context.Context, token *model.Token) error
|
||||
UpdateWithCondition(ctx context.Context, token *model.Token, filters map[string]interface{}, updates map[string]interface{}) error
|
||||
Delete(ctx context.Context, id int64, condition map[string]interface{}) error
|
||||
List(ctx context.Context, limit, offset int) ([]*model.Token, error)
|
||||
ListWithFilters(ctx context.Context, limit, offset int, filters map[string]interface{}) ([]*model.Token, int64, error)
|
||||
Disable(ctx context.Context, id int) error
|
||||
Enable(ctx context.Context, id int) error
|
||||
BatchDisable(ctx context.Context, ids []int64, filters map[string]interface{}) error
|
||||
BatchEnable(ctx context.Context, ids []int64, filters map[string]interface{}) error
|
||||
BatchDelete(ctx context.Context, ids []int64, filters map[string]interface{}) error
|
||||
}
|
||||
|
||||
type TokenDAO struct {
|
||||
db *gorm.DB
|
||||
}
|
||||
|
||||
func NewTokenDAO(db *gorm.DB) *TokenDAO {
|
||||
return &TokenDAO{db: db}
|
||||
}
|
||||
|
||||
// CreateToken 创建 Token
|
||||
func (dao *TokenDAO) Create(ctx context.Context, token *model.Token) error {
|
||||
if token == nil {
|
||||
return errors.New("token is nil")
|
||||
}
|
||||
return dao.db.WithContext(ctx).Create(token).Error
|
||||
}
|
||||
|
||||
// 根据 ID 获取 Token
|
||||
func (dao *TokenDAO) GetByID(ctx context.Context, id int64) (*model.Token, error) {
|
||||
var token model.Token
|
||||
err := dao.db.WithContext(ctx).Preload("User").First(&token, id).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &token, nil
|
||||
}
|
||||
|
||||
// 根据 Key 获取 Token
|
||||
func (dao *TokenDAO) GetByKey(ctx context.Context, key string) (*model.Token, error) {
|
||||
var token model.Token
|
||||
// err := dao.db.Where("key = ?", key).First(&token).Error
|
||||
err := dao.db.WithContext(ctx).Preload("User").Where("key = ?", key).First(&token).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &token, nil
|
||||
}
|
||||
|
||||
// 根据 UserID 获取 Token
|
||||
func (dao *TokenDAO) GetByUserID(ctx context.Context, userID int64) (*model.Token, error) {
|
||||
var token model.Token
|
||||
err := dao.db.WithContext(ctx).Preload("User").Where("user_id = ?", userID).Find(&token).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &token, nil
|
||||
}
|
||||
|
||||
// UpdateToken 更新 Token 信息
|
||||
func (dao *TokenDAO) Update(ctx context.Context, token *model.Token) error {
|
||||
if token == nil {
|
||||
return errors.New("token is nil")
|
||||
}
|
||||
return dao.db.WithContext(ctx).Save(token).Error
|
||||
}
|
||||
|
||||
// UpdateTokenWithFilters 更新 Token 信息,支持过滤
|
||||
func (dao *TokenDAO) UpdateWithCondition(ctx context.Context, token *model.Token, filters map[string]interface{}, updates map[string]interface{}) error {
|
||||
if token == nil {
|
||||
return errors.New("token is nil")
|
||||
}
|
||||
db := dao.db.WithContext(ctx)
|
||||
for key, value := range filters {
|
||||
db = db.Where(key+" = ?", value)
|
||||
}
|
||||
return db.Model(&model.Token{}).Updates(updates).Error
|
||||
}
|
||||
|
||||
// DeleteToken 删除 Token
|
||||
func (dao *TokenDAO) Delete(ctx context.Context, id int64, condition map[string]interface{}) error {
|
||||
if id <= 0 {
|
||||
return errors.New("id is invalid")
|
||||
}
|
||||
query := dao.db.WithContext(ctx).Where("id = ?", id)
|
||||
for key, value := range condition {
|
||||
query = query.Where(key, value)
|
||||
}
|
||||
return query.Unscoped().Delete(&model.Token{}).Error
|
||||
}
|
||||
|
||||
// ListTokens 获取 Token 列表
|
||||
func (dao *TokenDAO) List(ctx context.Context, limit, offset int) ([]*model.Token, error) {
|
||||
var tokens []*model.Token
|
||||
err := dao.db.WithContext(ctx).Limit(limit).Offset(offset).Find(&tokens).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return tokens, nil
|
||||
}
|
||||
|
||||
// ListTokensWithFilters 获取 Token 列表,支持过滤
|
||||
func (dao *TokenDAO) ListWithFilters(ctx context.Context, limit, offset int, filters map[string]interface{}) ([]*model.Token, int64, error) {
|
||||
var tokens []*model.Token
|
||||
var count int64
|
||||
|
||||
db := dao.db.WithContext(ctx)
|
||||
if filters != nil {
|
||||
for k, v := range filters {
|
||||
db = db.Where(k, v)
|
||||
}
|
||||
}
|
||||
|
||||
if err := db.Limit(limit).Offset(offset).Find(&tokens).Error; err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
db.Model(&model.Token{}).Count(&count)
|
||||
|
||||
return tokens, count, nil
|
||||
}
|
||||
|
||||
// DisableToken 禁用 Token
|
||||
func (dao *TokenDAO) Disable(ctx context.Context, id int) error {
|
||||
return dao.db.WithContext(ctx).Model(&model.Token{}).Where("id = ?", id).Update("status", false).Error
|
||||
}
|
||||
|
||||
// EnableToken 启用 Token
|
||||
func (dao *TokenDAO) Enable(ctx context.Context, id int) error {
|
||||
return dao.db.WithContext(ctx).Model(&model.Token{}).Where("id = ?", id).Update("status", true).Error
|
||||
}
|
||||
|
||||
// BatchDisableTokens 批量禁用 Token
|
||||
func (dao *TokenDAO) BatchDisable(ctx context.Context, ids []int64, filters map[string]interface{}) error {
|
||||
query := dao.db.WithContext(ctx).Model(&model.Token{}).Where("id IN ?", ids)
|
||||
for key, value := range filters {
|
||||
query = query.Where(key, value)
|
||||
}
|
||||
return query.Update("active", false).Error
|
||||
}
|
||||
|
||||
// BatchEnableTokens 批量启用 Token
|
||||
func (dao *TokenDAO) BatchEnable(ctx context.Context, ids []int64, filters map[string]interface{}) error {
|
||||
query := dao.db.WithContext(ctx).Model(&model.Token{}).Where("id IN ?", ids)
|
||||
for key, value := range filters {
|
||||
query = query.Where(key, value)
|
||||
}
|
||||
return query.Update("active", true).Error
|
||||
}
|
||||
|
||||
// BatchDeleteTokens 批量删除 Token
|
||||
func (dao *TokenDAO) BatchDelete(ctx context.Context, ids []int64, filters map[string]interface{}) error {
|
||||
query := dao.db.Unscoped().WithContext(ctx).Where("id IN ?", ids)
|
||||
for key, value := range filters {
|
||||
query = query.Where(key, value)
|
||||
}
|
||||
return query.Delete(&model.Token{}).Error
|
||||
// return dao.db.WithContext(ctx).Where("name != 'default' AND id IN ?", ids).Delete(&model.Token{}).Error
|
||||
}
|
||||
|
||||
// 检查 token 是否有效
|
||||
func (dao *TokenDAO) IsValid(ctx context.Context, key string) (bool, error) {
|
||||
var token model.Token
|
||||
err := dao.db.WithContext(ctx).Where("key = ? AND status = ? AND (expired_time = -1 OR expired_time > ?)",
|
||||
key, consts.StatusEnabled, time.Now().Unix()).First(&token).Error
|
||||
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return false, nil
|
||||
}
|
||||
return false, err
|
||||
}
|
||||
if token.User.Status != consts.StatusEnabled || (*token.User.UnlimitedQuota && *token.User.Quota <= 0) {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
return true, nil
|
||||
}
|
||||
295
internal/dao/usage.go
Normal file
295
internal/dao/usage.go
Normal file
@@ -0,0 +1,295 @@
|
||||
package dao
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
dto "opencatd-open/internal/dto/team"
|
||||
"opencatd-open/internal/model"
|
||||
"opencatd-open/pkg/config"
|
||||
"time"
|
||||
|
||||
"gorm.io/gorm"
|
||||
"gorm.io/gorm/clause"
|
||||
)
|
||||
|
||||
var _ UsageRepository = (*UsageDAO)(nil)
|
||||
var _ DailyUsageRepository = (*DailyUsageDAO)(nil)
|
||||
|
||||
type UsageRepository interface {
|
||||
// Create
|
||||
Create(ctx context.Context, usage *model.Usage) error
|
||||
BatchCreate(ctx context.Context, usages []*model.Usage) error
|
||||
|
||||
// Read
|
||||
ListByUserID(ctx context.Context, userID int64, limit, offset int) ([]*model.Usage, error)
|
||||
ListByTokenID(ctx context.Context, tokenID int64, limit, offset int) ([]*model.Usage, error)
|
||||
ListByDateRange(ctx context.Context, start, end time.Time) ([]*model.Usage, error)
|
||||
ListByCapability(ctx context.Context, capability string, limit, offset int) ([]*model.Usage, error)
|
||||
|
||||
// Delete
|
||||
Delete(ctx context.Context, id int64) error
|
||||
|
||||
// Statistics
|
||||
CountByUserID(ctx context.Context, userID int64) (int64, error)
|
||||
}
|
||||
|
||||
type DailyUsageRepository interface {
|
||||
// Create
|
||||
Create(ctx context.Context, usage *model.DailyUsage) error
|
||||
BatchCreate(ctx context.Context, usages []*model.DailyUsage) error
|
||||
|
||||
// Read
|
||||
ListByUserID(ctx context.Context, userID int64, limit, offset int) ([]*model.DailyUsage, error)
|
||||
ListByTokenID(ctx context.Context, tokenID int64, limit, offset int) ([]*model.DailyUsage, error)
|
||||
ListByDateRange(ctx context.Context, start, end time.Time) ([]*model.DailyUsage, error)
|
||||
GetByDate(ctx context.Context, userID int64, date time.Time) (*model.DailyUsage, error)
|
||||
|
||||
// Delete
|
||||
Delete(ctx context.Context, id int64) error
|
||||
|
||||
// Statistics
|
||||
CountByUserID(ctx context.Context, userID int64) (int64, error)
|
||||
StatUserUsages(ctx context.Context, start, end time.Time, filters map[string]interface{}) ([]*dto.UsageInfo, error)
|
||||
}
|
||||
|
||||
type UsageDAO struct {
|
||||
db *gorm.DB
|
||||
}
|
||||
|
||||
type DailyUsageDAO struct {
|
||||
cfg *config.Config
|
||||
db *gorm.DB
|
||||
}
|
||||
|
||||
func NewUsageDAO(cfg *config.Config, db *gorm.DB) *UsageDAO {
|
||||
return &UsageDAO{db: db}
|
||||
}
|
||||
|
||||
func NewDailyUsageDAO(cfg *config.Config, db *gorm.DB) *DailyUsageDAO {
|
||||
return &DailyUsageDAO{db: db}
|
||||
}
|
||||
|
||||
// Usage DAO implementations
|
||||
func (d *UsageDAO) Create(ctx context.Context, usage *model.Usage) error {
|
||||
return d.db.WithContext(ctx).Create(usage).Error
|
||||
}
|
||||
|
||||
func (d *UsageDAO) BatchCreate(ctx context.Context, usages []*model.Usage) error {
|
||||
return d.db.WithContext(ctx).Create(usages).Error
|
||||
}
|
||||
|
||||
func (d *UsageDAO) GetByID(ctx context.Context, id int64) (*model.Usage, error) {
|
||||
var usage model.Usage
|
||||
err := d.db.WithContext(ctx).First(&usage, id).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &usage, nil
|
||||
}
|
||||
|
||||
func (d *UsageDAO) ListByUserID(ctx context.Context, userID int64, limit, offset int) ([]*model.Usage, error) {
|
||||
var usages []*model.Usage
|
||||
err := d.db.WithContext(ctx).
|
||||
Where("user_id = ?", userID).
|
||||
Limit(limit).
|
||||
Offset(offset).
|
||||
Find(&usages).Error
|
||||
return usages, err
|
||||
}
|
||||
|
||||
func (d *UsageDAO) ListByTokenID(ctx context.Context, tokenID int64, limit, offset int) ([]*model.Usage, error) {
|
||||
var usages []*model.Usage
|
||||
err := d.db.WithContext(ctx).
|
||||
Where("token_id = ?", tokenID).
|
||||
Limit(limit).
|
||||
Offset(offset).
|
||||
Find(&usages).Error
|
||||
return usages, err
|
||||
}
|
||||
|
||||
func (d *UsageDAO) ListByDateRange(ctx context.Context, start, end time.Time) ([]*model.Usage, error) {
|
||||
var usages []*model.Usage
|
||||
err := d.db.WithContext(ctx).
|
||||
Where("date BETWEEN ? AND ?", start, end).
|
||||
Find(&usages).Error
|
||||
return usages, err
|
||||
}
|
||||
|
||||
func (d *UsageDAO) ListByCapability(ctx context.Context, capability string, limit, offset int) ([]*model.Usage, error) {
|
||||
var usages []*model.Usage
|
||||
err := d.db.WithContext(ctx).
|
||||
Where("capability = ?", capability).
|
||||
Limit(limit).
|
||||
Offset(offset).
|
||||
Find(&usages).Error
|
||||
return usages, err
|
||||
}
|
||||
|
||||
func (d *UsageDAO) Delete(ctx context.Context, id int64) error {
|
||||
return d.db.WithContext(ctx).Delete(&model.Usage{}, id).Error
|
||||
}
|
||||
|
||||
func (d *UsageDAO) CountByUserID(ctx context.Context, userID int64) (int64, error) {
|
||||
var count int64
|
||||
err := d.db.WithContext(ctx).Model(&model.Usage{}).Where("user_id = ?", userID).Count(&count).Error
|
||||
return count, err
|
||||
}
|
||||
|
||||
// DailyUsage DAO implementations
|
||||
func (d *DailyUsageDAO) Create(ctx context.Context, usage *model.DailyUsage) error {
|
||||
return d.db.WithContext(ctx).Create(usage).Error
|
||||
}
|
||||
|
||||
func (d *DailyUsageDAO) BatchCreate(ctx context.Context, usages []*model.DailyUsage) error {
|
||||
return d.db.WithContext(ctx).Create(usages).Error
|
||||
}
|
||||
|
||||
func (d *DailyUsageDAO) ListByUserID(ctx context.Context, userID int64, limit, offset int) ([]*model.DailyUsage, error) {
|
||||
var usages []*model.DailyUsage
|
||||
err := d.db.WithContext(ctx).
|
||||
Where("user_id = ?", userID).
|
||||
Limit(limit).
|
||||
Offset(offset).
|
||||
Find(&usages).Error
|
||||
return usages, err
|
||||
}
|
||||
|
||||
func (d *DailyUsageDAO) ListByTokenID(ctx context.Context, tokenID int64, limit, offset int) ([]*model.DailyUsage, error) {
|
||||
var usages []*model.DailyUsage
|
||||
err := d.db.WithContext(ctx).
|
||||
Where("token_id = ?", tokenID).
|
||||
Limit(limit).
|
||||
Offset(offset).
|
||||
Find(&usages).Error
|
||||
return usages, err
|
||||
}
|
||||
|
||||
func (d *DailyUsageDAO) ListByDateRange(ctx context.Context, start, end time.Time) ([]*model.DailyUsage, error) {
|
||||
var usages []*model.DailyUsage
|
||||
err := d.db.WithContext(ctx).
|
||||
Where("date BETWEEN ? AND ?", start, end).
|
||||
Find(&usages).Error
|
||||
return usages, err
|
||||
}
|
||||
|
||||
func (d *DailyUsageDAO) GetByDate(ctx context.Context, userID int64, date time.Time) (*model.DailyUsage, error) {
|
||||
var usage model.DailyUsage
|
||||
err := d.db.WithContext(ctx).
|
||||
Where("user_id = ? AND date = ?", userID, date).
|
||||
First(&usage).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &usage, nil
|
||||
}
|
||||
|
||||
// UpsertDailyUsage 根据不同数据库类型执行 Upsert
|
||||
func (d *DailyUsageDAO) UpsertDailyUsage(ctx context.Context, usage *model.Usage) error {
|
||||
date := usage.Date.Truncate(24 * time.Hour)
|
||||
dailyUsage := &model.DailyUsage{
|
||||
UserID: usage.UserID,
|
||||
TokenID: usage.TokenID,
|
||||
Capability: usage.Capability,
|
||||
Model: usage.Model,
|
||||
Stream: usage.Stream,
|
||||
PromptTokens: usage.PromptTokens,
|
||||
CompletionTokens: usage.CompletionTokens,
|
||||
TotalTokens: usage.TotalTokens,
|
||||
Cost: usage.Cost,
|
||||
}
|
||||
|
||||
updateColumns := map[string]interface{}{
|
||||
"prompt_tokens": gorm.Expr("prompt_tokens + VALUES(prompt_tokens)"),
|
||||
"completion_tokens": gorm.Expr("completion_tokens + VALUES(completion_tokens)"),
|
||||
"total_tokens": gorm.Expr("total_tokens + VALUES(total_tokens)"),
|
||||
}
|
||||
|
||||
db := d.db.WithContext(ctx)
|
||||
|
||||
switch d.cfg.DB_Type {
|
||||
case "mysql":
|
||||
// MySQL: INSERT ... ON DUPLICATE KEY UPDATE
|
||||
return db.Clauses(clause.OnConflict{
|
||||
Columns: []clause.Column{
|
||||
{Name: "user_id"},
|
||||
{Name: "token_id"},
|
||||
{Name: "capability"},
|
||||
{Name: "date"},
|
||||
{Name: "model"},
|
||||
{Name: "stream"},
|
||||
},
|
||||
DoUpdates: clause.Assignments(updateColumns),
|
||||
}).Create(dailyUsage).Error
|
||||
|
||||
case "postgres":
|
||||
// PostgreSQL: INSERT ... ON CONFLICT DO UPDATE
|
||||
updateColumns := map[string]interface{}{
|
||||
"prompt_tokens": gorm.Expr("daily_usages.prompt_tokens + EXCLUDED.prompt_tokens"),
|
||||
"completion_tokens": gorm.Expr("daily_usages.completion_tokens + EXCLUDED.completion_tokens"),
|
||||
"total_tokens": gorm.Expr("daily_usages.total_tokens + EXCLUDED.total_tokens"),
|
||||
}
|
||||
return db.Clauses(clause.OnConflict{
|
||||
Columns: []clause.Column{
|
||||
{Name: "user_id"},
|
||||
{Name: "token_id"},
|
||||
{Name: "capability"},
|
||||
{Name: "date"},
|
||||
{Name: "model"},
|
||||
{Name: "stream"},
|
||||
},
|
||||
DoUpdates: clause.Assignments(updateColumns),
|
||||
}).Create(dailyUsage).Error
|
||||
case "sqlite":
|
||||
fallthrough
|
||||
default:
|
||||
return db.Transaction(func(tx *gorm.DB) error {
|
||||
var existing model.DailyUsage
|
||||
err := tx.Where("user_id = ? AND token_id = ? AND capability = ? AND date = ? AND model = ? AND stream = ?",
|
||||
usage.UserID, usage.TokenID, usage.Capability, date, usage.Model, usage.Stream).
|
||||
First(&existing).Error
|
||||
|
||||
if err == gorm.ErrRecordNotFound {
|
||||
// 记录不存在,创建新记录
|
||||
return tx.Create(dailyUsage).Error
|
||||
} else if err != nil {
|
||||
return err // 返回其他错误
|
||||
}
|
||||
|
||||
// 记录存在,更新
|
||||
return tx.Model(&existing).Updates(map[string]interface{}{
|
||||
"prompt_tokens": gorm.Expr("prompt_tokens + ?", usage.PromptTokens),
|
||||
"completion_tokens": gorm.Expr("completion_tokens + ?", usage.CompletionTokens),
|
||||
"total_tokens": gorm.Expr("total_tokens + ?", usage.TotalTokens),
|
||||
}).Error
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func (d *DailyUsageDAO) Delete(ctx context.Context, id int64) error {
|
||||
return d.db.WithContext(ctx).Delete(&model.DailyUsage{}, id).Error
|
||||
}
|
||||
|
||||
func (d *DailyUsageDAO) CountByUserID(ctx context.Context, userID int64) (int64, error) {
|
||||
var count int64
|
||||
err := d.db.WithContext(ctx).Model(&model.DailyUsage{}).Where("user_id = ?", userID).Count(&count).Error
|
||||
return count, err
|
||||
}
|
||||
|
||||
func (d *DailyUsageDAO) StatUserUsages(ctx context.Context, from, to time.Time, filters map[string]interface{}) ([]*dto.UsageInfo, error) {
|
||||
var usages []*dto.UsageInfo
|
||||
|
||||
query := d.db.WithContext(ctx).
|
||||
Model(&model.DailyUsage{}).
|
||||
Select("user_id as userId, sum(total_tokens) as totalUnit, sum(cast(cost as decimal(20,6))) as cost")
|
||||
for key, value := range filters {
|
||||
query = query.Where(fmt.Sprintf("%s = ?", key), value)
|
||||
}
|
||||
query = query.Group("user_id").Where("date >= ? AND date <= ?", from, to)
|
||||
|
||||
err := query.Group("user_id").Find(&usages).Error
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to list usages: %w", err)
|
||||
}
|
||||
|
||||
return usages, nil
|
||||
}
|
||||
163
internal/dao/user.go
Normal file
163
internal/dao/user.go
Normal file
@@ -0,0 +1,163 @@
|
||||
package dao
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"opencatd-open/internal/model"
|
||||
"time"
|
||||
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// 确保 UserDAO 实现了 UserRepository 接口
|
||||
var _ UserRepository = (*UserDAO)(nil)
|
||||
|
||||
// UserRepository 定义用户数据访问操作的接口
|
||||
type UserRepository interface {
|
||||
Create(user *model.User) error
|
||||
GetByID(id int64) (*model.User, error)
|
||||
GetByUsername(username string) (*model.User, error)
|
||||
Update(user *model.User) error
|
||||
Delete(id int64) error
|
||||
List(limit, offset int, condition map[string]interface{}) ([]model.User, int64, error)
|
||||
// Enable(id int64) error
|
||||
// Disable(id int64) error
|
||||
BatchEnable(ids []int64, condition []string) error
|
||||
BatchDisable(ids []int64, condition []string) error
|
||||
BatchDelete(ids []int64, condition []string) error
|
||||
}
|
||||
|
||||
type UserDAO struct {
|
||||
db *gorm.DB
|
||||
}
|
||||
|
||||
func NewUserDAO(db *gorm.DB) *UserDAO {
|
||||
return &UserDAO{db: db}
|
||||
}
|
||||
|
||||
// 创建用户
|
||||
func (dao *UserDAO) Create(user *model.User) error {
|
||||
if user == nil {
|
||||
return errors.New("user is nil")
|
||||
}
|
||||
fmt.Println(*user)
|
||||
|
||||
return dao.db.Transaction(func(tx *gorm.DB) error {
|
||||
// 创建用户
|
||||
if err := tx.Create(user).Error; err != nil {
|
||||
return fmt.Errorf("failed to create user: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
// 根据ID获取用户
|
||||
func (dao *UserDAO) GetByID(id int64) (*model.User, error) {
|
||||
var user model.User
|
||||
// err := dao.db.First(&user, id).Error
|
||||
err := dao.db.Preload("Tokens", "user_id = ?", id).First(&user, id).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &user, nil
|
||||
}
|
||||
|
||||
// 根据用户名获取用户
|
||||
func (dao *UserDAO) GetByUsername(username string) (*model.User, error) {
|
||||
var user model.User
|
||||
// err := dao.db.Where("user_name = ?", username).First(&user).Error
|
||||
err := dao.db.Preload("Tokens").Where("user_name = ?", username).First(&user).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &user, nil
|
||||
}
|
||||
|
||||
// 更新用户信息
|
||||
func (dao *UserDAO) Update(user *model.User) error {
|
||||
if user == nil {
|
||||
return errors.New("user is nil")
|
||||
}
|
||||
|
||||
user.UpdatedAt = time.Now().Unix()
|
||||
return dao.db.Save(user).Error
|
||||
}
|
||||
|
||||
// 删除用户
|
||||
func (dao *UserDAO) Delete(id int64) error {
|
||||
return dao.db.Unscoped().Delete(&model.User{}, id).Error
|
||||
// return dao.db.Model(&model.User{}).Where("id = ?", id).Update("status", 2).Error
|
||||
}
|
||||
|
||||
// 获取用户列表
|
||||
func (dao *UserDAO) List(limit, offset int, condition map[string]interface{}) ([]model.User, int64, error) {
|
||||
if offset < 0 {
|
||||
offset = 0
|
||||
}
|
||||
var users []model.User
|
||||
var total int64
|
||||
|
||||
query := dao.db.Preload("Tokens").Model(&model.User{})
|
||||
|
||||
for k, v := range condition {
|
||||
query = query.Where(k, v)
|
||||
}
|
||||
err := query.Limit(limit).Offset(offset).Find(&users).Error
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
query = dao.db.Model(&model.User{})
|
||||
for k, v := range condition {
|
||||
query = query.Where(k, v)
|
||||
}
|
||||
query.Count(&total)
|
||||
|
||||
return users, total, nil
|
||||
}
|
||||
|
||||
// 启用User
|
||||
func (dao *UserDAO) Enable(id uint) error {
|
||||
return dao.db.Model(&model.User{}).Where("id = ?", id).Update("active", true).Error
|
||||
}
|
||||
|
||||
// 禁用User
|
||||
func (dao *UserDAO) Disable(id uint) error {
|
||||
return dao.db.Model(&model.User{}).Where("id = ?", id).Update("active", false).Error
|
||||
}
|
||||
|
||||
// 批量启用User
|
||||
func (dao *UserDAO) BatchEnable(ids []int64, condition []string) error {
|
||||
if len(ids) == 0 {
|
||||
return errors.New("ids is empty")
|
||||
}
|
||||
query := dao.db.Model(&model.User{}).Where("id IN ?", ids)
|
||||
for _, value := range condition {
|
||||
query = query.Where(value)
|
||||
}
|
||||
return query.Update("active", true).Error
|
||||
}
|
||||
|
||||
// 批量禁用User
|
||||
func (dao *UserDAO) BatchDisable(ids []int64, condition []string) error {
|
||||
if len(ids) == 0 {
|
||||
return errors.New("ids is empty")
|
||||
}
|
||||
query := dao.db.Model(&model.User{}).Where("id IN ?", ids)
|
||||
for _, value := range condition {
|
||||
query = query.Where(value)
|
||||
}
|
||||
return query.Update("active", false).Error
|
||||
}
|
||||
|
||||
// 批量删除用户
|
||||
func (dao *UserDAO) BatchDelete(ids []int64, condition []string) error {
|
||||
if len(ids) == 0 {
|
||||
return errors.New("ids is empty")
|
||||
}
|
||||
query := dao.db.Unscoped().Where("id IN ?", ids)
|
||||
for _, value := range condition {
|
||||
query = query.Where(value)
|
||||
}
|
||||
return query.Delete(&model.User{}).Error
|
||||
}
|
||||
Reference in New Issue
Block a user