diff --git a/cmd/openteam/main.go b/cmd/openteam/main.go new file mode 100644 index 0000000..f11a165 --- /dev/null +++ b/cmd/openteam/main.go @@ -0,0 +1,92 @@ +package main + +import ( + "context" + "log" + "net/http" + "opencatd-open/pkg/store" + "opencatd-open/team/dashboard" + "opencatd-open/wire" + "os" + "os/signal" + "syscall" + "time" + + "github.com/gin-gonic/gin" +) + +func main() { + ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second) + + _, err := store.InitDB() + if err != nil { + panic(err) + } + + team, err := wire.InitTeamHandler(ctx, store.DB) + if err != nil { + panic(err) + } + + r := gin.Default() + teamGroup := r.Group("/1") + teamGroup.Use(team.AuthMiddleware()) + { + teamGroup.POST("/users/init", team.InitAdmin) + // 获取当前用户信息 + teamGroup.GET("/me", team.Me) + //// team.GET("/me/usages", team.HandleMeUsage) + + teamGroup.POST("/keys", team.CreateKey) + teamGroup.GET("/keys", team.ListKeys) + teamGroup.POST("/keys/:id", team.UpdateKey) + teamGroup.DELETE("/keys/:id", team.DeleteKey) + + teamGroup.POST("/users", team.CreateUser) + teamGroup.GET("/users", team.ListUsers) + teamGroup.POST("/users/:id/reset", team.ResetUserToken) + teamGroup.DELETE("/users/:id", team.DeleteUser) + + teamGroup.GET("/1/usages", team.ListUsages) + } + + api := r.Group("/api") + { + api.POST("/login", dashboard.HandleLogin) + } + + srv := &http.Server{ + Addr: ":8080", + Handler: r, + } + + go func() { + // 服务启动 + if err := srv.ListenAndServe(); err != nil && err != http.ErrServerClosed { + log.Fatalf("listen: %s\n", err) + } + }() + + // 等待中断信号来优雅地关闭服务器 + quit := make(chan os.Signal, 1) + // kill (no param) default send syscall.SIGTERM + // kill -2 is syscall.SIGINT + // kill -9 is syscall.SIGKILL but can't be catch + signal.Notify(quit, syscall.SIGINT, syscall.SIGTERM) + <-quit + log.Println("Shutdown Server ...") + + defer cancel() + if err := srv.Shutdown(ctx); err != nil { + log.Fatal("Server Shutdown:", err) + } + db, _ := store.DB.DB() + db.Close() + // catching ctx.Done(). timeout of 1 seconds. + select { + case <-ctx.Done(): + log.Println("timeout of 5 seconds.") + } + log.Println("Server exiting") + +} diff --git a/internal/utils/pointer.go b/internal/utils/pointer.go new file mode 100644 index 0000000..7847c92 --- /dev/null +++ b/internal/utils/pointer.go @@ -0,0 +1,5 @@ +package utils + +func ToPtr[T any](v T) *T { + return &v +} diff --git a/opencat.go b/opencat.go index c102cab..aa49857 100644 --- a/opencat.go +++ b/opencat.go @@ -8,9 +8,10 @@ import ( "io/fs" "log" "net/http" - "opencatd-open/pkg/team" "opencatd-open/router" "opencatd-open/store" + "opencatd-open/team" + "opencatd-open/team/dashboard" "os" "github.com/duke-git/lancet/v2/fileutil" @@ -35,7 +36,7 @@ func main() { args := os.Args[1:] if len(args) > 0 { type user struct { - ID uint + ID int64 Name string Token string } @@ -155,20 +156,25 @@ func main() { group.POST("/keys", team.HandleAddKey) // 添加Key group.DELETE("/keys/:id", team.HandleDelKey) // 删除Key - group.GET("/users", team.HandleUsers) // 获取所有用户信息 - group.POST("/users", team.HandleAddUser) // 添加用户 - group.DELETE("/users/:id", team.HandleDelUser) // 删除用户 + // group.GET("/users", team.HandleUsers) // 获取所有用户信息 + // group.POST("/users", team.HandleAddUser) // 添加用户 + // group.DELETE("/users/:id", team.HandleDelUser) // 删除用户 - group.GET("/usages", team.HandleUsage) + // group.GET("/usages", team.HandleUsage) - // 重置用户Token - group.POST("/users/:id/reset", team.HandleResetUserToken) + // // 重置用户Token + // group.POST("/users/:id/reset", team.HandleResetUserToken) } // 初始化用户 r.POST("/1/users/init", team.Handleinit) r.Any("/v1/*proxypath", router.HandleProxy) + api := r.Group("/api") + { + api.POST("/login", dashboard.HandleLogin) + } + // r.POST("/v1/chat/completions", router.HandleProy) // r.GET("/v1/models", router.HandleProy) // r.GET("/v1/dashboard/billing/subscription", router.HandleProy) diff --git a/pkg/store/db.go b/pkg/store/db.go index 862b90f..dcd337e 100644 --- a/pkg/store/db.go +++ b/pkg/store/db.go @@ -3,6 +3,7 @@ package store import ( "fmt" "log" + "opencatd-open/team/consts" "opencatd-open/team/model" "os" "strings" @@ -10,13 +11,23 @@ import ( // "gocloud.dev/mysql" // "gocloud.dev/postgres" "github.com/glebarez/sqlite" + "github.com/google/wire" "gorm.io/driver/mysql" "gorm.io/driver/postgres" + + // "gorm.io/driver/sqlite" "gorm.io/gorm" ) +var DB *gorm.DB + +var DBType consts.DBType var IsPostgres bool +var DBSet = wire.NewSet( + InitDB, +) + // InitDB 初始化数据库连接 func InitDB() (*gorm.DB, error) { var db *gorm.DB @@ -32,21 +43,35 @@ func InitDB() (*gorm.DB, error) { // 解析DSN来确定数据库类型 if strings.HasPrefix(dsn, "postgres://") { IsPostgres = true + DBType = consts.DBTypePostgreSQL db, err = initPostgres(dsn) } else if strings.HasPrefix(dsn, "mysql://") { + DBType = consts.DBTypeMySQL db, err = initMySQL(dsn) + } else { + if dsn != "" { + return nil, fmt.Errorf("unsupported database type in DSN: %s", dsn) + } } if err != nil { return nil, err } + + DB = db + if IsPostgres { - err = db.AutoMigrate(&model.User{}, &model.ApiKey_PG{}, &model.Token{}, &model.Session{}, &model.Usage{}, &model.DailyUsage{}) + err = db.AutoMigrate(&model.User{}, &model.Token{}, &model.ApiKey_PG{}, &model.Usage{}, &model.DailyUsage{}) + if err != nil { + return nil, err + } + } else { + err = db.AutoMigrate(&model.User{}, &model.Token{}, &model.ApiKey{}, &model.Usage{}, &model.DailyUsage{}) if err != nil { return nil, err } } - return nil, fmt.Errorf("unsupported database type in DSN: %s", dsn) + return db, nil } // initSQLite 初始化 SQLite 数据库 @@ -55,6 +80,7 @@ func initSQLite() (*gorm.DB, error) { if err != nil { return nil, fmt.Errorf("failed to connect to SQLite: %v", err) } + // db, err := gorm.Open(sqlite.Open("file::memory:?cache=shared"), &gorm.Config{}) return db, nil } diff --git a/pkg/team/me.go b/pkg/team/me.go index 146ce7c..20e742a 100644 --- a/pkg/team/me.go +++ b/pkg/team/me.go @@ -42,7 +42,7 @@ func Handleinit(c *gin.Context) { }) return } - if user.ID == uint(1) { + if user.ID == 1 { c.JSON(http.StatusForbidden, gin.H{ "error": "super user already exists, use cli to reset password", }) diff --git a/pkg/team/user.go b/pkg/team/user.go index b848be1..c425f30 100644 --- a/pkg/team/user.go +++ b/pkg/team/user.go @@ -82,7 +82,7 @@ func HandleResetUserToken(c *gin.Context) { c.JSON(http.StatusForbidden, gin.H{"error": err.Error()}) return } - if u.ID == uint(1) { + if u.ID == 1 { rootToken = u.Token } c.JSON(http.StatusOK, u) diff --git a/pkg/tokenizer/tokenizer.go b/pkg/tokenizer/tokenizer.go index 9579948..f61c0f2 100644 --- a/pkg/tokenizer/tokenizer.go +++ b/pkg/tokenizer/tokenizer.go @@ -183,7 +183,7 @@ func Cost(model string, promptCount, completionCount int) float64 { cost = (0.00035/1000)*float64(prompt) + (0.00053/1000)*float64(completion) case "gemini-2.0-flash-exp": cost = (0.00035/1000)*float64(prompt) + (0.00053/1000)*float64(completion) - case "gemini-2.0-flash-thinking-exp-1219": + case "gemini-2.0-flash-thinking-exp-1219", "gemini-2.0-flash-thinking-exp-01-21": cost = (0.00035/1000)*float64(prompt) + (0.00053/1000)*float64(completion) case "learnlm-1.5-pro-experimental", " gemini-exp-1114", "gemini-exp-1121", "gemini-exp-1206": cost = (0.00035/1000)*float64(prompt) + (0.00053/1000)*float64(completion) diff --git a/store/cache.go b/store/cache.go index 37aab6c..5f8376b 100644 --- a/store/cache.go +++ b/store/cache.go @@ -82,7 +82,7 @@ func SelectKeyCacheByModel(model string) (Key, error) { } items := KeysCache.Items() for _, item := range items { - if strings.Contains(model, "realtime") { + if strings.Contains(model, "realtime") || strings.HasPrefix(model, "o1-") { if item.Object.(Key).ApiType == "openai" { keys = append(keys, item.Object.(Key)) } @@ -101,7 +101,7 @@ func SelectKeyCacheByModel(model string) (Key, error) { keys = append(keys, item.Object.(Key)) } } - if strings.HasPrefix(model, "o1-") || strings.HasPrefix(model, "chatgpt-") { + if strings.HasPrefix(model, "chatgpt-") { if item.Object.(Key).ApiType == "openai" { keys = append(keys, item.Object.(Key)) } diff --git a/team/consts/consts.go b/team/consts/consts.go index d74ffd9..27a0671 100644 --- a/team/consts/consts.go +++ b/team/consts/consts.go @@ -1,17 +1,44 @@ package consts +import "gorm.io/gorm" + +type UserRole int + const ( - RoleGuest = iota * 10 - RoleUser + RoleUser UserRole = iota * 10 RoleAdmin RoleSuperAdmin ) const ( - StatusEnabled = iota - StatusDisabled + StatusDisabled = iota + StatusEnabled StatusExpired // 过期 StatusExhausted // 耗尽 - StatusDeleted + + StatusDeleted = -1 +) +const ( + Limited = iota + Unlimited + UnlimitedQuota = 999999 +) + +var ( + ErrUserNotFound = gorm.ErrRecordNotFound +) + +func OpenOrClose(status bool) int { + if status { + return StatusEnabled + } + return StatusDisabled +} + +type DBType int + +const ( + DBTypeMySQL DBType = iota + DBTypePostgreSQL + DBTypeSQLite ) -const UnlimitedQuota = -1 diff --git a/team/dao/apikey.go b/team/dao/apikey.go index fbf2d0a..88b0ccf 100644 --- a/team/dao/apikey.go +++ b/team/dao/apikey.go @@ -18,6 +18,7 @@ type ApiKeyRepository interface { Update(apiKey *model.ApiKey) error Delete(id int64) error List(offset, limit int, status *int) ([]model.ApiKey, error) + ListWithFilters(offset, limit int, filters map[string]interface{}) ([]model.ApiKey, int64, error) Enable(id int64) error Disable(id int64) error BatchEnable(ids []int64) error @@ -77,7 +78,7 @@ func (dao *ApiKeyDAO) Update(apiKey *model.ApiKey) error { if apiKey == nil { return errors.New("apiKey is nil") } - apiKey.UpdatedAt = time.Now() + apiKey.UpdatedAt = time.Now().Unix() return dao.db.Save(apiKey).Error } @@ -100,6 +101,21 @@ func (dao *ApiKeyDAO) List(offset, limit int, status *int) ([]model.ApiKey, erro return apiKeys, nil } +// ListApiKeysWithFilters 根据条件获取ApiKey列表 +func (dao *ApiKeyDAO) ListWithFilters(offset, limit int, filters map[string]interface{}) ([]model.ApiKey, int64, error) { + var apiKeys []model.ApiKey + db := dao.db.Offset(offset).Limit(limit) + for key, value := range filters { + db = db.Where(key+" = ?", value) + } + var count int64 + err := db.Find(&apiKeys).Count(&count).Error + if err != nil { + return nil, 0, err + } + return apiKeys, count, nil +} + // EnableApiKey 启用ApiKey func (dao *ApiKeyDAO) Enable(id int64) error { return dao.db.Model(&model.ApiKey{}).Where("id = ?", id).Update("status", 0).Error diff --git a/team/dao/token.go b/team/dao/token.go index ec10335..2f1deb8 100644 --- a/team/dao/token.go +++ b/team/dao/token.go @@ -1,28 +1,33 @@ package dao import ( + "context" "errors" + "opencatd-open/team/consts" "opencatd-open/team/model" + "time" "gorm.io/gorm" ) -// 确保 TokenDAO 实现了 TokenDAOInterface 接口 -var _ TokenDAOInterface = (*TokenDAO)(nil) +// 确保 TokenDAO 实现了 TokenRepository 接口 +var _ TokenRepository = (*TokenDAO)(nil) -type TokenDAOInterface interface { - Create(token *model.Token) error - GetByID(id int) (*model.Token, error) - GetByKey(key string) (*model.Token, error) - GetByUserID(userID int) ([]model.Token, error) - Update(token *model.Token) error - Delete(id int) error - List(offset, limit int) ([]model.Token, error) - Disable(id int) error - Enable(id int) error - BatchDisable(ids []int) error - BatchEnable(ids []int) error - BatchDelete(ids []int) error +type TokenRepository interface { + Create(ctx context.Context, token *model.Token) error + GetByID(ctx context.Context, id int) (*model.Token, error) + GetByKey(ctx context.Context, key string) (*model.Token, error) + GetByUserID(ctx context.Context, userID int) (*model.Token, error) + Update(ctx context.Context, token *model.Token) error + UpdateWithCondition(ctx context.Context, token *model.Token, filters map[string]interface{}, updates map[string]interface{}) error + Delete(ctx context.Context, id int) error + List(ctx context.Context, offset, limit int) ([]model.Token, error) + ListWithFilters(ctx context.Context, offset, limit int, filters map[string]interface{}) ([]model.Token, int64, error) + Disable(ctx context.Context, id int) error + Enable(ctx context.Context, id int) error + BatchDisable(ctx context.Context, ids []int) error + BatchEnable(ctx context.Context, ids []int) error + BatchDelete(ctx context.Context, ids []int) error } type TokenDAO struct { @@ -34,87 +39,140 @@ func NewTokenDAO(db *gorm.DB) *TokenDAO { } // CreateToken 创建 Token -func (dao *TokenDAO) Create(token *model.Token) error { +func (dao *TokenDAO) Create(ctx context.Context, token *model.Token) error { if token == nil { return errors.New("token is nil") } - return dao.db.Create(token).Error + return dao.db.WithContext(ctx).Create(token).Error } -// GetTokenByID 根据 ID 获取 Token -func (dao *TokenDAO) GetByID(id int) (*model.Token, error) { +// 根据 ID 获取 Token +func (dao *TokenDAO) GetByID(ctx context.Context, id int) (*model.Token, error) { var token model.Token - err := dao.db.First(&token, id).Error + err := dao.db.WithContext(ctx).First(&token, id).Error if err != nil { return nil, err } return &token, nil } -// GetTokenByKey 根据 Key 获取 Token -func (dao *TokenDAO) GetByKey(key string) (*model.Token, error) { +// 根据 Key 获取 Token +func (dao *TokenDAO) GetByKey(ctx context.Context, key string) (*model.Token, error) { var token model.Token - err := dao.db.Where("key = ?", key).First(&token).Error + // err := dao.db.Where("key = ?", key).First(&token).Error + err := dao.db.WithContext(ctx).Preload("User").Where("key = ?", key).First(&token).Error if err != nil { return nil, err } return &token, nil } -// GetTokensByUserID 根据 UserID 获取 Token 列表 -func (dao *TokenDAO) GetByUserID(userID int) ([]model.Token, error) { - var tokens []model.Token - err := dao.db.Where("user_id = ?", userID).Find(&tokens).Error +// 根据 UserID 获取 Token +func (dao *TokenDAO) GetByUserID(ctx context.Context, userID int) (*model.Token, error) { + var token model.Token + err := dao.db.WithContext(ctx).Preload("User").Where("user_id = ?", userID).Find(&token).Error if err != nil { return nil, err } - return tokens, nil + return &token, nil } // UpdateToken 更新 Token 信息 -func (dao *TokenDAO) Update(token *model.Token) error { +func (dao *TokenDAO) Update(ctx context.Context, token *model.Token) error { if token == nil { return errors.New("token is nil") } - return dao.db.Save(token).Error + return dao.db.WithContext(ctx).Save(token).Error +} + +// UpdateTokenWithFilters 更新 Token 信息,支持过滤 +func (dao *TokenDAO) UpdateWithCondition(ctx context.Context, token *model.Token, filters map[string]interface{}, updates map[string]interface{}) error { + if token == nil { + return errors.New("token is nil") + } + db := dao.db.WithContext(ctx) + for key, value := range filters { + db = db.Where(key+" = ?", value) + } + return db.Model(&model.Token{}).Updates(updates).Error } // DeleteToken 删除 Token -func (dao *TokenDAO) Delete(id int) error { - return dao.db.Delete(&model.Token{}, id).Error +func (dao *TokenDAO) Delete(ctx context.Context, id int) error { + return dao.db.WithContext(ctx).Delete(&model.Token{}, id).Error } // ListTokens 获取 Token 列表 -func (dao *TokenDAO) List(offset, limit int) ([]model.Token, error) { +func (dao *TokenDAO) List(ctx context.Context, offset, limit int) ([]model.Token, error) { var tokens []model.Token - err := dao.db.Offset(offset).Limit(limit).Find(&tokens).Error + err := dao.db.WithContext(ctx).Offset(offset).Limit(limit).Find(&tokens).Error if err != nil { return nil, err } return tokens, nil } +// ListTokensWithFilters 获取 Token 列表,支持过滤 +func (dao *TokenDAO) ListWithFilters(ctx context.Context, offset, limit int, filters map[string]interface{}) ([]model.Token, int64, error) { + var tokens []model.Token + var count int64 + + db := dao.db.WithContext(ctx) + for key, value := range filters { + db = db.Where(key+" = ?", value) + } + + if err := db.Offset(offset).Limit(limit).Find(&tokens).Error; err != nil { + return nil, 0, err + } + + if err := db.Model(&model.Token{}).Count(&count).Error; err != nil { + return nil, 0, err + } + + return tokens, count, nil +} + // DisableToken 禁用 Token -func (dao *TokenDAO) Disable(id int) error { - return dao.db.Model(&model.Token{}).Where("id = ?", id).Update("status", false).Error +func (dao *TokenDAO) Disable(ctx context.Context, id int) error { + return dao.db.WithContext(ctx).Model(&model.Token{}).Where("id = ?", id).Update("status", false).Error } // EnableToken 启用 Token -func (dao *TokenDAO) Enable(id int) error { - return dao.db.Model(&model.Token{}).Where("id = ?", id).Update("status", true).Error +func (dao *TokenDAO) Enable(ctx context.Context, id int) error { + return dao.db.WithContext(ctx).Model(&model.Token{}).Where("id = ?", id).Update("status", true).Error } // BatchDisableTokens 批量禁用 Token -func (dao *TokenDAO) BatchDisable(ids []int) error { - return dao.db.Model(&model.Token{}).Where("id IN ?", ids).Update("status", false).Error +func (dao *TokenDAO) BatchDisable(ctx context.Context, ids []int) error { + return dao.db.WithContext(ctx).Model(&model.Token{}).Where("id IN ?", ids).Update("status", false).Error } // BatchEnableTokens 批量启用 Token -func (dao *TokenDAO) BatchEnable(ids []int) error { - return dao.db.Model(&model.Token{}).Where("id IN ?", ids).Update("status", true).Error +func (dao *TokenDAO) BatchEnable(ctx context.Context, ids []int) error { + return dao.db.WithContext(ctx).Model(&model.Token{}).Where("id IN ?", ids).Update("status", true).Error } // BatchDeleteTokens 批量删除 Token -func (dao *TokenDAO) BatchDelete(ids []int) error { - return dao.db.Where("id IN ?", ids).Delete(&model.Token{}).Error +func (dao *TokenDAO) BatchDelete(ctx context.Context, ids []int) error { + return dao.db.WithContext(ctx).Where("id IN ?", ids).Delete(&model.Token{}).Error +} + +// 检查 token 是否有效 +func (dao *TokenDAO) IsValid(ctx context.Context, key string) (bool, error) { + var token model.Token + err := dao.db.WithContext(ctx).Where("key = ? AND status = ? AND (expired_time = -1 OR expired_time > ?)", + key, consts.StatusEnabled, time.Now().Unix()).First(&token).Error + + if err != nil { + if errors.Is(err, gorm.ErrRecordNotFound) { + return false, nil + } + return false, err + } + if token.User.Status != consts.StatusEnabled || (token.User.UnlimitedQuota == 1 && token.User.Quota <= 0) { + return false, nil + } + + return true, nil } diff --git a/team/dao/usage.go b/team/dao/usage.go new file mode 100644 index 0000000..9c5ea14 --- /dev/null +++ b/team/dao/usage.go @@ -0,0 +1,297 @@ +package dao + +import ( + "context" + "fmt" + "opencatd-open/pkg/store" + "opencatd-open/team/consts" + dto "opencatd-open/team/dto/team" + "opencatd-open/team/model" + "time" + + "gorm.io/gorm" + "gorm.io/gorm/clause" +) + +var _ UsageRepository = (*UsageDAO)(nil) +var _ DailyUsageRepository = (*DailyUsageDAO)(nil) + +type UsageRepository interface { + // Create + Create(ctx context.Context, usage *model.Usage) error + BatchCreate(ctx context.Context, usages []*model.Usage) error + + // Read + ListByUserID(ctx context.Context, userID int64, limit, offset int) ([]*model.Usage, error) + ListByTokenID(ctx context.Context, tokenID int64, limit, offset int) ([]*model.Usage, error) + ListByDateRange(ctx context.Context, start, end time.Time) ([]*model.Usage, error) + ListByCapability(ctx context.Context, capability string, limit, offset int) ([]*model.Usage, error) + + // Delete + Delete(ctx context.Context, id int64) error + + // Statistics + CountByUserID(ctx context.Context, userID int64) (int64, error) +} + +type DailyUsageRepository interface { + // Create + Create(ctx context.Context, usage *model.DailyUsage) error + BatchCreate(ctx context.Context, usages []*model.DailyUsage) error + + // Read + ListByUserID(ctx context.Context, userID int64, limit, offset int) ([]*model.DailyUsage, error) + ListByTokenID(ctx context.Context, tokenID int64, limit, offset int) ([]*model.DailyUsage, error) + ListByDateRange(ctx context.Context, start, end time.Time) ([]*model.DailyUsage, error) + GetByDate(ctx context.Context, userID int64, date time.Time) (*model.DailyUsage, error) + + // Delete + Delete(ctx context.Context, id int64) error + + // Statistics + CountByUserID(ctx context.Context, userID int64) (int64, error) + StatUserUsages(ctx context.Context, start, end time.Time, filters map[string]interface{}) ([]*dto.UsageInfo, error) +} + +type UsageDAO struct { + db *gorm.DB +} + +type DailyUsageDAO struct { + db *gorm.DB +} + +func NewUsageDAO(db *gorm.DB) *UsageDAO { + return &UsageDAO{db: db} +} + +func NewDailyUsageDAO(db *gorm.DB) *DailyUsageDAO { + return &DailyUsageDAO{db: db} +} + +// Usage DAO implementations +func (d *UsageDAO) Create(ctx context.Context, usage *model.Usage) error { + return d.db.WithContext(ctx).Create(usage).Error +} + +func (d *UsageDAO) BatchCreate(ctx context.Context, usages []*model.Usage) error { + return d.db.WithContext(ctx).Create(usages).Error +} + +func (d *UsageDAO) GetByID(ctx context.Context, id int64) (*model.Usage, error) { + var usage model.Usage + err := d.db.WithContext(ctx).First(&usage, id).Error + if err != nil { + return nil, err + } + return &usage, nil +} + +func (d *UsageDAO) ListByUserID(ctx context.Context, userID int64, limit, offset int) ([]*model.Usage, error) { + var usages []*model.Usage + err := d.db.WithContext(ctx). + Where("user_id = ?", userID). + Limit(limit). + Offset(offset). + Find(&usages).Error + return usages, err +} + +func (d *UsageDAO) ListByTokenID(ctx context.Context, tokenID int64, limit, offset int) ([]*model.Usage, error) { + var usages []*model.Usage + err := d.db.WithContext(ctx). + Where("token_id = ?", tokenID). + Limit(limit). + Offset(offset). + Find(&usages).Error + return usages, err +} + +func (d *UsageDAO) ListByDateRange(ctx context.Context, start, end time.Time) ([]*model.Usage, error) { + var usages []*model.Usage + err := d.db.WithContext(ctx). + Where("date BETWEEN ? AND ?", start, end). + Find(&usages).Error + return usages, err +} + +func (d *UsageDAO) ListByCapability(ctx context.Context, capability string, limit, offset int) ([]*model.Usage, error) { + var usages []*model.Usage + err := d.db.WithContext(ctx). + Where("capability = ?", capability). + Limit(limit). + Offset(offset). + Find(&usages).Error + return usages, err +} + +func (d *UsageDAO) Delete(ctx context.Context, id int64) error { + return d.db.WithContext(ctx).Delete(&model.Usage{}, id).Error +} + +func (d *UsageDAO) CountByUserID(ctx context.Context, userID int64) (int64, error) { + var count int64 + err := d.db.WithContext(ctx).Model(&model.Usage{}).Where("user_id = ?", userID).Count(&count).Error + return count, err +} + +// DailyUsage DAO implementations +func (d *DailyUsageDAO) Create(ctx context.Context, usage *model.DailyUsage) error { + return d.db.WithContext(ctx).Create(usage).Error +} + +func (d *DailyUsageDAO) BatchCreate(ctx context.Context, usages []*model.DailyUsage) error { + return d.db.WithContext(ctx).Create(usages).Error +} + +func (d *DailyUsageDAO) ListByUserID(ctx context.Context, userID int64, limit, offset int) ([]*model.DailyUsage, error) { + var usages []*model.DailyUsage + err := d.db.WithContext(ctx). + Where("user_id = ?", userID). + Limit(limit). + Offset(offset). + Find(&usages).Error + return usages, err +} + +func (d *DailyUsageDAO) ListByTokenID(ctx context.Context, tokenID int64, limit, offset int) ([]*model.DailyUsage, error) { + var usages []*model.DailyUsage + err := d.db.WithContext(ctx). + Where("token_id = ?", tokenID). + Limit(limit). + Offset(offset). + Find(&usages).Error + return usages, err +} + +func (d *DailyUsageDAO) ListByDateRange(ctx context.Context, start, end time.Time) ([]*model.DailyUsage, error) { + var usages []*model.DailyUsage + err := d.db.WithContext(ctx). + Where("date BETWEEN ? AND ?", start, end). + Find(&usages).Error + return usages, err +} + +func (d *DailyUsageDAO) GetByDate(ctx context.Context, userID int64, date time.Time) (*model.DailyUsage, error) { + var usage model.DailyUsage + err := d.db.WithContext(ctx). + Where("user_id = ? AND date = ?", userID, date). + First(&usage).Error + if err != nil { + return nil, err + } + return &usage, nil +} + +// UpsertDailyUsage 根据不同数据库类型执行 Upsert +func (d *DailyUsageDAO) UpsertDailyUsage(ctx context.Context, usage *model.Usage) error { + date := usage.Date.Truncate(24 * time.Hour) + dailyUsage := &model.DailyUsage{ + UserID: usage.UserID, + TokenID: usage.TokenID, + Capability: usage.Capability, + Model: usage.Model, + Stream: usage.Stream, + PromptTokens: usage.PromptTokens, + CompletionTokens: usage.CompletionTokens, + TotalTokens: usage.TotalTokens, + Cost: usage.Cost, + } + + updateColumns := map[string]interface{}{ + "prompt_tokens": gorm.Expr("prompt_tokens + VALUES(prompt_tokens)"), + "completion_tokens": gorm.Expr("completion_tokens + VALUES(completion_tokens)"), + "total_tokens": gorm.Expr("total_tokens + VALUES(total_tokens)"), + } + + db := d.db.WithContext(ctx) + + switch store.DBType { + case consts.DBTypeMySQL: + // MySQL: INSERT ... ON DUPLICATE KEY UPDATE + return db.Clauses(clause.OnConflict{ + Columns: []clause.Column{ + {Name: "user_id"}, + {Name: "token_id"}, + {Name: "capability"}, + {Name: "date"}, + {Name: "model"}, + {Name: "stream"}, + }, + DoUpdates: clause.Assignments(updateColumns), + }).Create(dailyUsage).Error + + case consts.DBTypePostgreSQL: + // PostgreSQL: INSERT ... ON CONFLICT DO UPDATE + updateColumns := map[string]interface{}{ + "prompt_tokens": gorm.Expr("daily_usages.prompt_tokens + EXCLUDED.prompt_tokens"), + "completion_tokens": gorm.Expr("daily_usages.completion_tokens + EXCLUDED.completion_tokens"), + "total_tokens": gorm.Expr("daily_usages.total_tokens + EXCLUDED.total_tokens"), + } + return db.Clauses(clause.OnConflict{ + Columns: []clause.Column{ + {Name: "user_id"}, + {Name: "token_id"}, + {Name: "capability"}, + {Name: "date"}, + {Name: "model"}, + {Name: "stream"}, + }, + DoUpdates: clause.Assignments(updateColumns), + }).Create(dailyUsage).Error + case consts.DBTypeSQLite: + // SQLite: 需要使用事务来模拟 upsert + return db.Transaction(func(tx *gorm.DB) error { + var existing model.DailyUsage + err := tx.Where("user_id = ? AND token_id = ? AND capability = ? AND date = ? AND model = ? AND stream = ?", + usage.UserID, usage.TokenID, usage.Capability, date, usage.Model, usage.Stream). + First(&existing).Error + + if err == gorm.ErrRecordNotFound { + // 记录不存在,创建新记录 + return tx.Create(dailyUsage).Error + } else if err != nil { + return err // 返回其他错误 + } + + // 记录存在,更新 + return tx.Model(&existing).Updates(map[string]interface{}{ + "prompt_tokens": gorm.Expr("prompt_tokens + ?", usage.PromptTokens), + "completion_tokens": gorm.Expr("completion_tokens + ?", usage.CompletionTokens), + "total_tokens": gorm.Expr("total_tokens + ?", usage.TotalTokens), + }).Error + }) + + default: + return fmt.Errorf("不支持的数据库类型: %s", store.DBType) + } +} + +func (d *DailyUsageDAO) Delete(ctx context.Context, id int64) error { + return d.db.WithContext(ctx).Delete(&model.DailyUsage{}, id).Error +} + +func (d *DailyUsageDAO) CountByUserID(ctx context.Context, userID int64) (int64, error) { + var count int64 + err := d.db.WithContext(ctx).Model(&model.DailyUsage{}).Where("user_id = ?", userID).Count(&count).Error + return count, err +} + +func (d *DailyUsageDAO) StatUserUsages(ctx context.Context, from, to time.Time, filters map[string]interface{}) ([]*dto.UsageInfo, error) { + var usages []*dto.UsageInfo + + query := d.db.WithContext(ctx). + Model(&model.DailyUsage{}). + Select("user_id as userId, sum(total_tokens) as totalUnit, sum(cast(cost as decimal(20,6))) as cost") + for key, value := range filters { + query = query.Where(fmt.Sprintf("%s = ?", key), value) + } + query = query.Group("user_id").Where("date >= ? AND date <= ?", from, to) + + err := query.Group("user_id").Find(&usages).Error + if err != nil { + return nil, fmt.Errorf("failed to list usages: %w", err) + } + + return usages, nil +} diff --git a/team/dao/user.go b/team/dao/user.go index 93d3373..1d00060 100644 --- a/team/dao/user.go +++ b/team/dao/user.go @@ -2,6 +2,8 @@ package dao import ( "errors" + "fmt" + "opencatd-open/team/consts" "opencatd-open/team/model" "time" @@ -19,6 +21,7 @@ type UserRepository interface { Update(user *model.User) error Delete(id int64) error List(offset, limit int) ([]model.User, error) + ListWithFilters(offset, limit int, filters map[string]interface{}) ([]model.User, int64, error) Enable(id int64) error Disable(id int64) error BatchEnable(ids []int64) error @@ -34,70 +37,107 @@ func NewUserDAO(db *gorm.DB) *UserDAO { return &UserDAO{db: db} } -// CreateUser 创建用户 +// 创建用户 func (dao *UserDAO) Create(user *model.User) error { if user == nil { return errors.New("user is nil") } - return dao.db.Create(user).Error + + return dao.db.Transaction(func(tx *gorm.DB) error { + // 创建用户 + if err := tx.Create(user).Error; err != nil { + return fmt.Errorf("failed to create user: %w", err) + } + + return nil + }) } -// GetUserByID 根据ID获取用户 +// 根据ID获取用户 func (dao *UserDAO) GetByID(id int64) (*model.User, error) { var user model.User - err := dao.db.First(&user, id).Error + // err := dao.db.First(&user, id).Error + err := dao.db.Preload("Tokens").First(&user, id).Error if err != nil { return nil, err } return &user, nil } -// GetUserByUsername 根据用户名获取用户 +// 根据用户名获取用户 func (dao *UserDAO) GetByUsername(username string) (*model.User, error) { var user model.User - err := dao.db.Where("user_name = ?", username).First(&user).Error + // err := dao.db.Where("user_name = ?", username).First(&user).Error + err := dao.db.Preload("Tokens").Where("user_name = ?", username).First(&user).Error if err != nil { return nil, err } return &user, nil } -// UpdateUser 更新用户信息 +// 更新用户信息 func (dao *UserDAO) Update(user *model.User) error { if user == nil { return errors.New("user is nil") } - user.UpdatedAt = time.Now() + user.UpdatedAt = time.Now().Unix() return dao.db.Save(user).Error } -// DeleteUser 删除用户 +// 删除用户 func (dao *UserDAO) Delete(id int64) error { return dao.db.Delete(&model.User{}, id).Error // return dao.db.Model(&model.User{}).Where("id = ?", id).Update("status", 2).Error } -// ListUsers 获取用户列表 +// 获取用户列表 func (dao *UserDAO) List(offset, limit int) ([]model.User, error) { var users []model.User - err := dao.db.Offset(offset).Limit(limit).Find(&users).Error + err := dao.db.Preload("Tokens").Offset(offset).Limit(limit).Find(&users).Error if err != nil { return nil, err } return users, nil } -// EnableUser 启用User +// 获取用户列表,带过滤条件 +func (dao *UserDAO) ListWithFilters(offset, limit int, filters map[string]interface{}) ([]model.User, int64, error) { + var users []model.User + var total int64 + + // 构建查询 + query := dao.db.Model(&model.User{}) + + // 添加过滤条件 + for key, value := range filters { + query = query.Where(key+" = ?", value) + } + + // 查询总数 + if err := query.Count(&total).Error; err != nil { + return nil, 0, err + } + + // 分页查询 + err := query.Offset(offset).Limit(limit).Find(&users).Error + if err != nil { + return nil, 0, err + } + + return users, total, nil +} + +// 启用User func (dao *UserDAO) Enable(id int64) error { - return dao.db.Model(&model.User{}).Where("id = ?", id).Update("status", 0).Error + return dao.db.Model(&model.User{}).Where("id = ?", id).Update("status", consts.StatusEnabled).Error } -// DisableUser 禁用User +// 禁用User func (dao *UserDAO) Disable(id int64) error { - return dao.db.Model(&model.User{}).Where("id = ?", id).Update("status", 1).Error + return dao.db.Model(&model.User{}).Where("id = ?", id).Update("status", consts.StatusDisabled).Error } -// BatchEnableUsers 批量启用User +// 批量启用User func (dao *UserDAO) BatchEnable(ids []int64) error { if len(ids) == 0 { return errors.New("ids is empty") @@ -105,7 +145,7 @@ func (dao *UserDAO) BatchEnable(ids []int64) error { return dao.db.Model(&model.User{}).Where("id IN ?", ids).Update("status", 0).Error } -// BatchDisableUsers 批量禁用User +// 批量禁用User func (dao *UserDAO) BatchDisable(ids []int64) error { if len(ids) == 0 { return errors.New("ids is empty") @@ -113,7 +153,7 @@ func (dao *UserDAO) BatchDisable(ids []int64) error { return dao.db.Model(&model.User{}).Where("id IN ?", ids).Update("status", 1).Error } -// BatchDeleteUser 批量删除用户 +// 批量删除用户 func (dao *UserDAO) BatchDelete(ids []int64) error { if len(ids) == 0 { return errors.New("ids is empty") diff --git a/team/dashboard/dashboard.go b/team/dashboard/dashboard.go new file mode 100644 index 0000000..82b769a --- /dev/null +++ b/team/dashboard/dashboard.go @@ -0,0 +1,16 @@ +package dashboard + +import "github.com/gin-gonic/gin" + +func HandleTeam(c *gin.Context) { + c.JSON(200, gin.H{ + "code": 200, + "data": gin.H{ + "team": gin.H{ + "total_users": 10, + "total_keys": 20, + "total_projects": 30, + }, + }, + }) +} diff --git a/team/dashboard/login.go b/team/dashboard/login.go new file mode 100644 index 0000000..83a9c5b --- /dev/null +++ b/team/dashboard/login.go @@ -0,0 +1,20 @@ +package dashboard + +import ( + "fmt" + + "github.com/gin-gonic/gin" +) + +func HandleLogin(c *gin.Context) { + var user map[string]string + c.ShouldBind(&user) + fmt.Sprintf("%v", user) + c.JSON(200, gin.H{ + "code": 200, + "msg": "success", + "data": gin.H{ + "token": "token", + }, + }) +} diff --git a/team/dto/openai/err_resp.go b/team/dto/openai/err_resp.go new file mode 100644 index 0000000..3b82c23 --- /dev/null +++ b/team/dto/openai/err_resp.go @@ -0,0 +1,22 @@ +package dto + +import ( + "net/http" + + "github.com/gin-gonic/gin" +) + +type Error struct { + Message string `json:"message,omitempty"` + Code string `json:"code,omitempty"` +} + +func WarpErrAsOpenAI(c *gin.Context, msg string, code string) { + c.JSON(http.StatusForbidden, gin.H{ + "error": Error{ + Message: msg, + Code: code, + }, + }) + return +} diff --git a/team/dto/team/team.go b/team/dto/team/team.go new file mode 100644 index 0000000..c9994ab --- /dev/null +++ b/team/dto/team/team.go @@ -0,0 +1,83 @@ +package dto + +import ( + "opencatd-open/team/consts" + "opencatd-open/team/model" +) + +type UserInfo struct { + ID int `json:"id"` + Name string `json:"name"` + Token string `json:"token"` + Status *bool `json:"status,omitempty"` +} + +func (u UserInfo) HasNameUpdate() bool { + return u.Name != "" +} + +func (u UserInfo) HasTokenUpdate() bool { + return u.Token != "" +} + +func (u UserInfo) HasStatusUpdate() bool { + return u.Status != nil +} + +type KeyInfo struct { + ID int `json:"id,omitempty"` + Key string `json:"key,omitempty"` + Name string `json:"name,omitempty"` + ApiType string `json:"api_type,omitempty"` + Endpoint string `json:"endpoint,omitempty"` + Status *bool `json:"status,omitempty"` +} + +// 添加辅助方法判断字段是否需要更新 +func (k KeyInfo) HasNameUpdate() bool { + return k.Name != "" +} + +func (k KeyInfo) HasKeyUpdate() bool { + return k.Key != "" +} + +func (k KeyInfo) HasStatusUpdate() bool { + return k.Status != nil +} + +func (k KeyInfo) HasApiTypeUpdate() bool { + return k.ApiType != "" +} + +// 辅助函数:统一处理字段更新 +func (update *KeyInfo) UpdateFields(existing *model.ApiKey) *model.ApiKey { + result := &model.ApiKey{ + ID: existing.ID, + Name: existing.Name, // 默认保持原值 + ApiType: existing.ApiType, // 默认保持原值 + ApiKey: existing.ApiKey, // 默认保持原值 + Status: existing.Status, // 默认保持原值 + } + + if update.HasNameUpdate() { + result.Name = update.Name + } + if update.HasKeyUpdate() { + result.ApiKey = update.Key + } + if update.HasStatusUpdate() { + result.Status = consts.OpenOrClose(*update.Status) + } + if update.HasApiTypeUpdate() { + result.ApiType = update.ApiType + } + + return result +} + +type UsageInfo struct { + UserId int `json:"userId"` + TotalUnit int `json:"totalUnit"` + Cost string `json:"cost"` +} diff --git a/team/handler/team/middleware.go b/team/handler/team/middleware.go new file mode 100644 index 0000000..f6588c9 --- /dev/null +++ b/team/handler/team/middleware.go @@ -0,0 +1,53 @@ +package handler + +import ( + "fmt" + "net/http" + "opencatd-open/team/consts" + + "github.com/gin-contrib/cors" + "github.com/gin-gonic/gin" +) + +func (h *TeamHandler) AuthMiddleware() gin.HandlerFunc { + return func(c *gin.Context) { + if c.Request.URL.Path == "/1/users/init" { + c.Next() + return + } + authtoken := c.GetHeader("Authorization") + if authtoken == "" || len(authtoken) <= 7 || authtoken[:7] != "Bearer " { + c.JSON(http.StatusUnauthorized, gin.H{"error": "Unauthorized"}) + c.Abort() + return + } + authtoken = authtoken[7:] + token, err := h.tokenService.GetByKey(c, authtoken) + if err != nil { + fmt.Println(err) + } + if token.Name != "default" { + c.JSON(http.StatusForbidden, gin.H{"error": "only default token can access"}) + c.Abort() + } + if token.User.Status != consts.StatusEnabled { + c.JSON(http.StatusForbidden, gin.H{"error": "user is disabled"}) + c.Abort() + } + c.Set("local_user", true) + c.Set("token", token) + + // 可以在这里对 token 进行验证并检查权限 + + c.Next() + } +} + +func CORS() gin.HandlerFunc { + config := cors.DefaultConfig() + config.AllowAllOrigins = true + config.AllowCredentials = true + config.AllowMethods = []string{"GET", "POST", "PUT", "DELETE", "OPTIONS"} + config.AllowHeaders = []string{"*"} + return cors.New(config) +} diff --git a/team/handler/team/team.go b/team/handler/team/team.go new file mode 100644 index 0000000..c87aab1 --- /dev/null +++ b/team/handler/team/team.go @@ -0,0 +1,567 @@ +package handler + +import ( + "errors" + "net/http" + "strconv" + "strings" + "time" + + "opencatd-open/internal/utils" + "opencatd-open/team/consts" + dto "opencatd-open/team/dto/team" + "opencatd-open/team/model" + "opencatd-open/team/service" + + "github.com/gin-gonic/gin" + "github.com/google/uuid" + "gorm.io/gorm" +) + +type TeamHandler struct { + db *gorm.DB + userService service.UserService + tokenService service.TokenService + keyService service.ApiKeyService + usageService service.UsageService +} + +func NewTeamHandler(userService service.UserService, tokenService service.TokenService, keyService service.ApiKeyService, usageService service.UsageService) *TeamHandler { + return &TeamHandler{ + userService: userService, + tokenService: tokenService, + keyService: keyService, + usageService: usageService, + } +} + +// initadmin +func (h *TeamHandler) InitAdmin(c *gin.Context) { + admin, err := h.userService.GetUser(c, 1) + + if err != nil { + if errors.Is(err, gorm.ErrRecordNotFound) { + user := &model.User{ + Name: "root", + Username: "root", + Password: "openteam", + Role: int(consts.RoleSuperAdmin), + Tokens: []model.Token{ + { + Name: "default", + Key: "team-" + strings.ReplaceAll(uuid.New().String(), "-", ""), + UnlimitedQuota: true, + }, + }, + } + if err := h.userService.CreateUser(c, user); err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + var result = dto.UserInfo{ + ID: int(user.ID), + Name: user.Name, + Token: user.Tokens[0].Key, + Status: utils.ToPtr(user.Status == consts.StatusEnabled), + } + + c.JSON(http.StatusOK, result) + return + } else { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + } + if admin != nil { + c.JSON(http.StatusForbidden, gin.H{ + "error": "super user already exists, use cli to reset password", + }) + return + } +} + +func (h *TeamHandler) Me(c *gin.Context) { + // token := c.GetHeader("Authorization") + // token = strings.TrimPrefix(token, "Bearer ") + // userToken, err := h.tokenService.GetTokenByKey(token) + // if err != nil { + // c.JSON(http.StatusUnauthorized, gin.H{"error": err.Error()}) + // return + // } + // if userToken.ID != 1 { + // c.JSON(http.StatusForbidden, gin.H{"error": "only first user token can access"}) + // return + // } + token, exists := c.Get("token") + if !exists { + c.JSON(http.StatusNotFound, gin.H{"error": "token not found"}) + return + } + userToken := token.(*model.Token) + + c.JSON(http.StatusOK, dto.UserInfo{ + ID: int(userToken.UserID), + Name: userToken.User.Name, + Token: userToken.Key, + Status: utils.ToPtr(userToken.User.Status == consts.StatusEnabled), + }) + +} + +// CreateUser 创建用户 +func (h *TeamHandler) CreateUser(c *gin.Context) { + var userReq dto.UserInfo + if err := c.ShouldBindJSON(&userReq); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": "invalid input"}) + return + } + + token, exists := c.Get("token") + if !exists { + c.JSON(http.StatusNotFound, gin.H{"error": "Unauthorized"}) + return + } + userToken := token.(*model.Token) + if userToken.User.Role < int(consts.RoleAdmin) { + create := &model.Token{ + Name: userReq.Name, + Key: "team-" + strings.ReplaceAll(uuid.New().String(), "-", ""), + } + if err := h.tokenService.Create(c.Request.Context(), create); err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + + } else { + user := &model.User{ + Name: userReq.Name, + Username: userReq.Name, + Role: int(consts.RoleUser), + Tokens: []model.Token{ + { + Name: "default", + Key: "team-" + strings.ReplaceAll(uuid.New().String(), "-", ""), + }, + }, + } + + // 默认角色为普通用户 + if err := h.userService.CreateUser(c.Request.Context(), user); err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + } + + c.JSON(http.StatusOK, gin.H{"message": "ok"}) +} + +// GetUser 获取用户信息 +func (h *TeamHandler) GetUser(c *gin.Context) { + idStr := c.Param("id") + id, err := strconv.ParseInt(idStr, 10, 64) + if err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": "invalid user id"}) + return + } + + user, err := h.userService.GetUser(c.Request.Context(), id) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + + c.JSON(http.StatusOK, user) +} + +// UpdateUser 更新用户信息 +func (h *TeamHandler) UpdateUser(c *gin.Context) { + var user model.User + if err := c.ShouldBindJSON(&user); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": "invalid input"}) + return + } + token, exists := c.Get("token") + if !exists { + c.JSON(http.StatusNotFound, gin.H{"error": "Unauthorized"}) + return + } + userToken := token.(*model.Token) + + operatorID := userToken.UserID // 假设从上下文中获取操作者ID + if err := h.userService.UpdateUser(c.Request.Context(), &user, operatorID); err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + + c.JSON(http.StatusOK, gin.H{"message": "ok"}) +} + +// DeleteUser 删除用户 +func (h *TeamHandler) DeleteUser(c *gin.Context) { + idStr := c.Param("id") + id, err := strconv.ParseInt(idStr, 10, 64) + if err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": "invalid user id"}) + return + } + + token, exists := c.Get("token") + if !exists { + c.JSON(http.StatusNotFound, gin.H{"error": "Unauthorized"}) + return + } + userToken := token.(*model.Token) + + if userToken.User.Role < int(consts.RoleAdmin) { // 用户只能删除自己的token + err := h.tokenService.Delete(c.Request.Context(), int(id)) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + } else { + if err := h.userService.DeleteUser(c.Request.Context(), id, userToken.UserID); err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + } + + c.JSON(http.StatusOK, gin.H{"message": "ok"}) +} + +func (h *TeamHandler) ListUsages(c *gin.Context) { + fromStr := c.Query("from") + toStr := c.Query("to") + + var from, to time.Time + loc, _ := time.LoadLocation("Local") + + var listUsage []*dto.UsageInfo + var err error + + if fromStr != "" && toStr != "" { + + from, err = time.Parse("2006-01-02", fromStr) + if err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": "invalid from date"}) + return + } + to, err = time.Parse("2006-01-02", toStr) + if err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": "invalid to date"}) + return + } + } else { + year, month, _ := time.Now().In(loc).Date() + from = time.Date(year, month, 1, 0, 0, 0, 0, loc) + to = from.AddDate(0, 1, 0) + } + + token, _ := c.Get("token") + userToken := token.(*model.Token) + if userToken.User.Role < int(consts.RoleAdmin) { + listUsage, err = h.usageService.ListByDateRange(c.Request.Context(), from, to, map[string]interface{}{"user_id": userToken.UserID}) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + } else { + listUsage, err = h.usageService.ListByDateRange(c.Request.Context(), from, to, nil) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + } + + c.JSON(http.StatusOK, listUsage) + +} + +// ListUsers 获取用户列表 +func (h *TeamHandler) ListUsers(c *gin.Context) { + pageStr := c.DefaultQuery("page", "1") + pageSizeStr := c.DefaultQuery("pageSize", "100") + + page, err := strconv.Atoi(pageStr) + if err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": "invalid page number"}) + return + } + + pageSize, err := strconv.Atoi(pageSizeStr) + if err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": "invalid page size"}) + return + } + token, exists := c.Get("token") + if !exists { + c.JSON(http.StatusNotFound, gin.H{"error": "Unauthorized"}) + return + } + userToken := token.(*model.Token) + if userToken.User.Role < int(consts.RoleAdmin) { + tokens, _, err := h.tokenService.ListsWithFilters(c, 0, 100, map[string]interface{}{"user_id": userToken.UserID}) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + var userDTOs []dto.UserInfo + for _, token := range tokens { + userDTOs = append(userDTOs, dto.UserInfo{ + ID: int(token.User.ID), + Name: token.User.Name, + Token: token.Key, + Status: utils.ToPtr(token.User.Status == consts.StatusEnabled), + }) + } + c.JSON(http.StatusOK, userDTOs) + return + } + + users, _, err := h.userService.ListUsers(c.Request.Context(), page, pageSize) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + + var userDTOs []dto.UserInfo + for _, user := range users { + useres := dto.UserInfo{ + ID: int(user.ID), + Name: user.Name, + + Status: utils.ToPtr(user.Status == consts.StatusEnabled), + } + if len(user.Tokens) > 0 { + useres.Token = user.Tokens[0].Key + } + userDTOs = append(userDTOs, useres) + } + + c.JSON(http.StatusOK, userDTOs) +} + +func (h *TeamHandler) ResetUserToken(c *gin.Context) { + idstr := c.Param("id") + id, err := strconv.Atoi(idstr) + if err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": "invalid user id"}) + return + } + token, exists := c.Get("token") + if !exists { + c.JSON(http.StatusNotFound, gin.H{"error": "Unauthorized"}) + return + } + userToken := token.(*model.Token) + + findtoken, err := h.tokenService.GetByUserID(c, id) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + findtoken.Key = "team-" + strings.ReplaceAll(uuid.New().String(), "-", "") + + if userToken.User.Role < int(consts.RoleAdmin) { // 非管理员只能修改自己的token + if userToken.User.Role <= findtoken.User.Role || userToken.UserID != findtoken.UserID { + c.JSON(http.StatusForbidden, gin.H{"error": "forbidden"}) + return + } + err := h.tokenService.UpdateWithCondition(c, findtoken, map[string]interface{}{"user_id": userToken.UserID}, nil) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + } else { + if err := h.tokenService.Update(c, findtoken); err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + } + + c.JSON(http.StatusOK, dto.UserInfo{ + ID: int(findtoken.User.ID), + Name: findtoken.User.Name, + Token: findtoken.Key, + }) +} + +func (h *TeamHandler) CreateKey(c *gin.Context) { + token, exists := c.Get("token") + if !exists { + c.JSON(http.StatusNotFound, gin.H{"error": "token not found"}) + return + } + userToken := token.(*model.Token) + if userToken.User.Role < int(consts.RoleAdmin) { + c.JSON(http.StatusForbidden, gin.H{"error": "forbidden"}) + return + } + + var key dto.KeyInfo + if err := c.ShouldBindJSON(&key); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + return + } + err := h.keyService.Create(&model.ApiKey{ + Name: key.Name, + ApiType: key.ApiType, + ApiKey: key.Key, + Endpoint: key.Endpoint, + }) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + c.JSON(http.StatusOK, key) +} + +func (h *TeamHandler) ListKeys(c *gin.Context) { + keys, err := h.keyService.List(0, 100, nil) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + } + + var keysDTO []dto.KeyInfo + for _, key := range keys { + keylength := len(key.ApiKey) / 3 + if keylength < 1 { + c.JSON(http.StatusBadRequest, gin.H{"error": "invalid key length"}) + return + } + keysDTO = append(keysDTO, dto.KeyInfo{ + ID: int(key.ID), + Name: key.Name, + ApiType: key.ApiType, + Endpoint: key.Endpoint, + Key: key.ApiKey[:keylength] + "****" + key.ApiKey[len(key.ApiKey)-keylength:], + }) + } + c.JSON(http.StatusOK, keysDTO) +} + +func (h *TeamHandler) UpdateKey(c *gin.Context) { + // 1. 获取并验证ID + id, err := strconv.ParseInt(c.Param("id"), 10, 64) + if err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": "invalid key id"}) + return + } + + // 2. 解析请求体 + var updateKey dto.KeyInfo // 更明确的命名 + if err := c.ShouldBindJSON(&updateKey); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + return + } + + // 3. 获取现有记录 + existingKey, err := h.keyService.GetByID(id) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + + // 4. 使用 UpdateFields 方法统一处理字段更新 + updatedKey := updateKey.UpdateFields(existingKey) + + // 5. 保存更新 + if err := h.keyService.Update(updatedKey); err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + + c.JSON(http.StatusOK, updatedKey) +} + +func (h *TeamHandler) DeleteKey(c *gin.Context) { + // 1. 获取并验证ID + id, err := strconv.ParseInt(c.Param("id"), 10, 64) + if err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": "invalid key id"}) + return + } + + // 2. 删除记录 + if err := h.keyService.Delete(id); err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + + c.JSON(http.StatusOK, gin.H{"message": "ok"}) +} + +// ChangePassword 修改密码 +func (h *TeamHandler) ChangePassword(c *gin.Context) { + userID := c.GetInt64("userID") // 假设从上下文中获取用户ID + + var req struct { + OldPassword string `json:"oldPassword"` + NewPassword string `json:"newPassword"` + } + + if err := c.ShouldBindJSON(&req); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": "invalid input"}) + return + } + + if err := h.userService.ChangePassword(c.Request.Context(), userID, req.OldPassword, req.NewPassword); err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + + c.JSON(http.StatusOK, gin.H{"message": "ok"}) +} + +// ResetPassword 重置密码 +func (h *TeamHandler) ResetPassword(c *gin.Context) { + idStr := c.Param("id") + id, err := strconv.ParseInt(idStr, 10, 64) + if err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": "invalid user id"}) + return + } + + operatorID := c.GetInt64("userID") // 假设从上下文中获取操作者ID + if err := h.userService.ResetPassword(c.Request.Context(), id, operatorID); err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + + c.JSON(http.StatusOK, gin.H{"message": "password reset successfully"}) +} + +// EnableUser 启用用户 +func (h *TeamHandler) EnableUser(c *gin.Context) { + idStr := c.Param("id") + id, err := strconv.ParseInt(idStr, 10, 64) + if err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": "invalid user id"}) + return + } + + operatorID := c.GetInt64("userID") // 假设从上下文中获取操作者ID + if err := h.userService.EnableUser(c.Request.Context(), id, operatorID); err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + + c.JSON(http.StatusOK, gin.H{"message": "user enabled successfully"}) +} + +// DisableUser 禁用用户 +func (h *TeamHandler) DisableUser(c *gin.Context) { + idStr := c.Param("id") + id, err := strconv.ParseInt(idStr, 10, 64) + if err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": "invalid user id"}) + return + } + + operatorID := c.GetInt64("userID") // 假设从上下文中获取操作者ID + if err := h.userService.DisableUser(c.Request.Context(), id, operatorID); err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + + c.JSON(http.StatusOK, gin.H{"message": "user disabled successfully"}) +} diff --git a/team/model/Token.go b/team/model/Token.go deleted file mode 100644 index 7423a12..0000000 --- a/team/model/Token.go +++ /dev/null @@ -1,19 +0,0 @@ -package model - -// 用户的token -type Token struct { - Id int64 `gorm:"column:id;primaryKey;autoIncrement"` - UserId int64 `gorm:"column:user_id;not null;index:idx_token_user_id"` - Name string `gorm:"column:name;index:idx_token_name"` - Key string `gorm:"column:key;type:char(48);uniqueIndex:idx_token_key"` - Status bool `gorm:"column:status;default:true"` // enabled 0, disabled 1 - Quota int64 `gorm:"column:quota;type:bigint;default:0"` // -1 means unlimited - UnlimitedQuota bool `gorm:"column:unlimited_quota;default:true"` - UsedQuota int64 `gorm:"column:used_quota;type:bigint;default:0"` - CreatedTime int64 `gorm:"column:created_time;type:bigint"` - ExpiredTime int64 `gorm:"column:expired_time;type:bigint;default:-1"` // -1 means never expired -} - -func (Token) TableName() string { - return "token" -} diff --git a/team/model/apikey.go b/team/model/apikey.go index 503d1f6..40e84bd 100644 --- a/team/model/apikey.go +++ b/team/model/apikey.go @@ -1,26 +1,22 @@ package model -import ( - "time" - - "github.com/lib/pq" -) +import "github.com/lib/pq" //pq.StringArray type ApiKey_PG struct { ID int64 `gorm:"column:id;primaryKey;autoIncrement"` Name string `gorm:"column:name;not null;unique;index:idx_apikey_name"` ApiType string `gorm:"column:apitype;not null;unique;index:idx_apikey_apitype"` - ApiKey string `gorm:"column:apikey;not null;unique;index:idx_apikey_apikey"` - Status int `gorm:"type:int;default:0"` // enabled 0, disabled 1 - Endpoint string `gorm:"column:endpoint"` - ResourceNmae string `gorm:"column:resource_name"` - DeploymentName string `gorm:"column:deployment_name"` + ApiKey string `gorm:"column:apikey;not null;unique;uniqueIndex:idx_apikey"` + Status int `gorm:"type:int;default:1"` // enabled 1, disabled 0 + Endpoint string `gorm:"column:endpoint;comment:接入点"` + ResourceNmae string `gorm:"column:resource_name;comment:azure资源名称"` + DeploymentName string `gorm:"column:deployment_name;comment:azure部署名称"` ApiSecret string `gorm:"column:api_secret"` - ModelPrefix string `gorm:"column:model_prefix"` - ModelAlias string `gorm:"column:model_alias"` - SupportModels pq.StringArray `gorm:"type:text[]"` - UpdatedAt time.Time `json:"updatedAt,omitempty"` - CreatedAt time.Time `json:"createdAt,omitempty"` + ModelPrefix string `gorm:"column:model_prefix;comment:模型前缀"` + ModelAlias string `gorm:"column:model_alias;comment:模型别名"` + SupportModels pq.StringArray `gorm:"column:support_models;type:text[]"` + CreatedAt int64 `gorm:"column:created_at;autoUpdateTime" json:"created_at,omitempty"` + UpdatedAt int64 `gorm:"column:updated_at;autoCreateTime" json:"updated_at,omitempty"` } func (ApiKey_PG) TableName() string { @@ -28,20 +24,20 @@ func (ApiKey_PG) TableName() string { } type ApiKey struct { - ID int64 `gorm:"column:id;primaryKey;autoIncrement"` - Name string `gorm:"column:name;not null;unique;index:idx_apikey_name"` - ApiType string `gorm:"column:apitype;not null;unique;index:idx_apikey_apitype"` - ApiKey string `gorm:"column:apikey;not null;unique;index:idx_apikey_apikey"` - Status int `json:"status" gorm:"type:int;default:0"` // enabled 0, disabled 1 - Endpoint string `gorm:"column:endpoint"` - ResourceNmae string `gorm:"column:resource_name"` - DeploymentName string `gorm:"column:deployment_name"` - ApiSecret string `gorm:"column:api_secret"` - ModelPrefix string `gorm:"column:model_prefix"` - ModelAlias string `gorm:"column:model_alias"` - SupportModels []string `gorm:"type:json"` - CreatedAt time.Time `json:"created_at,omitempty" gorm:"autoUpdateTime"` - UpdatedAt time.Time `json:"updated_at,omitempty" gorm:"autoCreateTime"` + ID int64 `gorm:"column:id;primaryKey;autoIncrement"` + Name string `gorm:"column:name;not null;unique;index:idx_apikey_name"` + ApiType string `gorm:"column:apitype;not null;unique;index:idx_apikey_apitype"` + ApiKey string `gorm:"column:apikey;not null;unique;index:idx_apikey_apikey"` + Status int `gorm:"type:int;default:1"` // enabled 1, disabled 0 + Endpoint string `gorm:"column:endpoint"` + ResourceNmae string `gorm:"column:resource_name"` + DeploymentName string `gorm:"column:deployment_name"` + ApiSecret string `gorm:"column:api_secret"` + ModelPrefix string `gorm:"column:model_prefix"` + ModelAlias string `gorm:"column:model_alias"` + SupportModels []string `gorm:"column:support_models;type:json"` + CreatedAt int64 `gorm:"column:created_at;autoUpdateTime" json:"created_at,omitempty"` + UpdatedAt int64 `gorm:"column:updated_at;autoCreateTime" json:"updated_at,omitempty"` } func (ApiKey) TableName() string { diff --git a/team/model/token.go b/team/model/token.go new file mode 100644 index 0000000..6a021ff --- /dev/null +++ b/team/model/token.go @@ -0,0 +1,20 @@ +package model + +// 用户的token +type Token struct { + ID int64 `gorm:"column:id;primaryKey;autoIncrement"` + UserID int64 `gorm:"column:user_id;not null;index:idx_token_user_id"` + Name string `gorm:"column:name;index:idx_token_name"` + Key string `gorm:"column:key;not null;uniqueIndex:idx_token_key;comment:token key"` + Status int64 `gorm:"column:status;default:1;check:status IN (0,1)"` // enabled 1, disabled 0 + Quota int64 `gorm:"column:quota;type:bigint;default:0"` // default 0 + UnlimitedQuota bool `gorm:"column:unlimited_quota;default:true"` // set Quota 1 unlimited + UsedQuota int64 `gorm:"column:used_quota;type:bigint;default:0"` + CreatedAt int64 `gorm:"column:created_at;type:bigint;autoCreateTime"` + ExpiredAt int64 `gorm:"column:expired_at;type:bigint;default:-1"` // -1 means never expired + User User `gorm:"foreignKey:UserID;constraint:OnUpdate:CASCADE,OnDelete:CASCADE" json:"user"` +} + +func (Token) TableName() string { + return "tokens" +} diff --git a/team/model/usage.go b/team/model/usage.go index dbd5258..f7204ca 100644 --- a/team/model/usage.go +++ b/team/model/usage.go @@ -8,40 +8,42 @@ import ( "github.com/gin-gonic/gin" ) -type DailyUsage struct { +type Usage struct { ID int64 `gorm:"column:id;primaryKey;autoIncrement"` UserID int64 `gorm:"column:user_id;index:idx_user_id"` - TokenId int64 `gorm:"column:token_id;index:idx_token_id"` - Capability string `gorm:"column:capability;index:idx_capability;comment:模型能力"` + TokenID int64 `gorm:"column:token_id;index:idx_token_id"` + Capability string `gorm:"column:capability;index:idx_usage_capability;comment:模型能力"` Date time.Time `gorm:"column:date;autoCreateTime;index:idx_date"` Model string `gorm:"column:model"` Stream bool `gorm:"column:stream"` - PromptTokens int `gorm:"column:prompt_tokens"` - CompletionTokens int `gorm:"column:completion_tokens"` - TotalTokens int `gorm:"column:total_tokens"` + PromptTokens float64 `gorm:"column:prompt_tokens"` + CompletionTokens float64 `gorm:"column:completion_tokens"` + TotalTokens float64 `gorm:"column:total_tokens"` Cost string `gorm:"column:cost"` - CreatedAt time.Time `gorm:"column:created_at;autoCreateTime"` -} - -func (DailyUsage) TableName() string { - return "daily_usages" -} - -type Usage struct { - ID int `gorm:"column:id"` - UserID int `gorm:"column:user_id"` - SKU string `gorm:"column:sku"` - PromptUnits int `gorm:"column:prompt_units"` - CompletionUnits int `gorm:"column:completion_units"` - TotalUnit int `gorm:"column:total_unit"` - Cost string `gorm:"column:cost"` - Date time.Time `gorm:"column:date"` } func (Usage) TableName() string { return "usages" } +type DailyUsage struct { + ID int64 `gorm:"column:id;primaryKey;autoIncrement"` + UserID int64 `gorm:"column:user_id;uniqueIndex:idx_daily_unique,priority:1"` + TokenID int64 `gorm:"column:token_id;index:idx_daily_token_id"` + Capability string `gorm:"column:capability;uniqueIndex:idx_daily_unique,priority:2;comment:模型能力"` + Date time.Time `gorm:"column:date;autoCreateTime;uniqueIndex:idx_daily_unique,priority:3"` + Model string `gorm:"column:model"` + Stream bool `gorm:"column:stream"` + PromptTokens float64 `gorm:"column:prompt_tokens"` + CompletionTokens float64 `gorm:"column:completion_tokens"` + TotalTokens float64 `gorm:"column:total_tokens"` + Cost string `gorm:"column:cost"` +} + +func (DailyUsage) TableName() string { + return "daily_usages" +} + func HandleUsage(c *gin.Context) { fromStr := c.Query("from") toStr := c.Query("to") diff --git a/team/model/user.go b/team/model/user.go index 170b3b5..3bf44d6 100644 --- a/team/model/user.go +++ b/team/model/user.go @@ -1,99 +1,34 @@ package model import ( - "net/http" - "opencatd-open/store" "time" - - "github.com/Sakurasan/to" - "github.com/gin-gonic/gin" - "github.com/google/uuid" ) type User struct { - ID int64 `json:"id" gorm:"primaryKey;autoIncrement"` - Name string `json:"name" gorm:"unique;index" validate:"max=12"` - UserName string `json:"username" gorm:"unique;index" validate:"max=12"` - Password string `json:"password" gorm:"not null;" validate:"min=8,max=20"` - Role int `json:"role" gorm:"type:int;default:1"` // default user - Status int `json:"status" gorm:"type:int;default:0"` // enabled 0, disabled 1 deleted 2 - Nickname string `json:"nickname" gorm:"type:varchar(50)"` - AvatarURL string `json:"avatar_url" gorm:"type:varchar(255)"` - Email string `json:"email" gorm:"type:varchar(255);unique;index"` - Quota int64 `json:"quota" gorm:"bigint;default:-1"` // default unlimited - Token string `json:"token,omitempty"` - Timezone string `json:"timezone" gorm:"type:varchar(50)"` - Language string `json:"language" gorm:"type:varchar(50)"` + ID int64 `json:"id" gorm:"column:id;primaryKey;autoIncrement"` + Name string `json:"name" gorm:"column:name;not null;unique;index"` + Username string `json:"username" gorm:"column:username;unique;index"` + Password string `json:"password" gorm:"column:password;"` + Role int `json:"role" gorm:"column:role;type:int;default:0"` // default user 0-10-20 + Status int `json:"status" gorm:"column:status;type:int;default:1"` // disabled 0, enabled 1, deleted 2 + Nickname string `json:"nickname" gorm:"column:nickname;type:varchar(50)"` + AvatarURL string `json:"avatar_url" gorm:"column:avatar_url;type:varchar(255)"` + Email string `json:"email" gorm:"column:email;type:varchar(255);index"` + Quota int64 `json:"quota" gorm:"column:quota;bigint;default:0"` // default unlimited + UnlimitedQuota int `json:"unlimited_quota" gorm:"column:unlimited_quota;default:1;check:(unlimited_quota IN (0,1))"` // 0 limited , 1 unlimited + Timezone string `json:"timezone" gorm:"column:timezone;type:varchar(50)"` + Language string `json:"language" gorm:"column:language;type:varchar(50)"` - CreatedAt time.Time `json:"created_at,omitempty" gorm:"autoCreateTime"` - UpdatedAt time.Time `json:"updated_at,omitempty" gorm:"autoUpdateTime"` + // 添加一对多关系 + // Token string `json:"-" gorm:"column:token;type:varchar(64);unique;index"` + Tokens []Token `json:"-" gorm:"foreignKey:UserID;references:ID;constraint:OnUpdate:CASCADE,OnDelete:CASCADE"` + + CreatedAt int64 `json:"created_at,omitempty" gorm:"autoCreateTime"` + UpdatedAt int64 `json:"updated_at,omitempty" gorm:"autoUpdateTime"` } -func HandleUsers(c *gin.Context) { - users, err := store.GetAllUsers() - if err != nil { - c.JSON(http.StatusOK, gin.H{ - "error": err.Error(), - }) - } - - c.JSON(http.StatusOK, users) -} - -func HandleAddUser(c *gin.Context) { - var body User - if err := c.BindJSON(&body); err != nil { - c.JSON(http.StatusOK, gin.H{"error": err.Error()}) - return - } - if len(body.Name) == 0 { - c.JSON(http.StatusOK, gin.H{"error": "invalid user name"}) - return - } - - if err := store.AddUser(body.Name, uuid.NewString()); err != nil { - c.JSON(http.StatusOK, gin.H{"error": err.Error()}) - return - } - u, err := store.GetUserByName(body.Name) - if err != nil { - c.JSON(http.StatusOK, gin.H{"error": err.Error()}) - return - } - c.JSON(http.StatusOK, u) -} - -func HandleDelUser(c *gin.Context) { - id := to.Int(c.Param("id")) - if id <= 1 { - c.JSON(http.StatusOK, gin.H{"error": "invalid user id"}) - return - } - if err := store.DeleteUser(uint(id)); err != nil { - c.JSON(http.StatusOK, gin.H{"error": err.Error()}) - return - } - - c.JSON(http.StatusOK, gin.H{"message": "ok"}) -} - -func HandleResetUserToken(c *gin.Context) { - id := to.Int(c.Param("id")) - newtoken := c.Query("token") - if newtoken == "" { - newtoken = uuid.NewString() - } - - if err := store.UpdateUser(uint(id), newtoken); err != nil { - c.JSON(http.StatusForbidden, gin.H{"error": err.Error()}) - return - } - u, err := store.GetUserByID(uint(id)) - if err != nil { - c.JSON(http.StatusForbidden, gin.H{"error": err.Error()}) - return - } - c.JSON(http.StatusOK, u) +func (User) TableName() string { + return "users" } type Session struct { @@ -104,8 +39,9 @@ type Session struct { DeviceName string `json:"device_name" gorm:"type:varchar(100);default:''"` LastActiveAt time.Time `json:"last_active_at" gorm:"type:timestamp;default:CURRENT_TIMESTAMP"` LogoutAt time.Time `json:"logout_at" gorm:"type:timestamp;null"` - CreatedAt time.Time `json:"created_at" gorm:"type:timestamp;not null;default:CURRENT_TIMESTAMP"` - UpdatedAt time.Time `json:"updated_at" gorm:"type:timestamp;not null;default:CURRENT_TIMESTAMP;update:CURRENT_TIMESTAMP"` + + CreatedAt time.Time `json:"created_at" gorm:"type:timestamp;not null;default:CURRENT_TIMESTAMP"` + UpdatedAt time.Time `json:"updated_at" gorm:"type:timestamp;not null;default:CURRENT_TIMESTAMP;update:CURRENT_TIMESTAMP"` } func (Session) TableName() string { diff --git a/team/service/apikey.go b/team/service/apikey.go new file mode 100644 index 0000000..543b0f7 --- /dev/null +++ b/team/service/apikey.go @@ -0,0 +1,152 @@ +package service + +import ( + "errors" + "opencatd-open/team/dao" + "opencatd-open/team/model" + "time" + + "gorm.io/gorm" +) + +var _ ApiKeyService = (*ApiKeyServiceImpl)(nil) + +type ApiKeyService interface { + Create(apiKey *model.ApiKey) error + GetByID(id int64) (*model.ApiKey, error) + GetByName(name string) (*model.ApiKey, error) + GetByApiKey(apiKeyValue string) (*model.ApiKey, error) + Update(apiKey *model.ApiKey) error + Delete(id int64) error + List(offset, limit int, status *int) ([]model.ApiKey, error) + ListWithFilters(offset, limit int, filters map[string]interface{}) ([]model.ApiKey, int64, error) + Enable(id int64) error + Disable(id int64) error + BatchEnable(ids []int64) error + BatchDisable(ids []int64) error + BatchDelete(ids []int64) error + Count() (int64, error) +} + +type ApiKeyServiceImpl struct { + apiKeyRepo dao.ApiKeyRepository + db *gorm.DB +} + +func NewApiKeyService(apiKeyDao dao.ApiKeyRepository, db *gorm.DB) ApiKeyService { + return &ApiKeyServiceImpl{apiKeyRepo: apiKeyDao, db: db} +} + +func (s *ApiKeyServiceImpl) Create(apiKey *model.ApiKey) error { + if apiKey == nil { + return errors.New("apiKey不能为空") + } + if apiKey.Name == "" { + return errors.New("apiKey名称不能为空") + } + if apiKey.ApiKey == "" { + return errors.New("apiKey值不能为空") + } + apiKey.CreatedAt = time.Now().Unix() + apiKey.UpdatedAt = time.Now().Unix() + + return s.apiKeyRepo.Create(apiKey) +} + +func (s *ApiKeyServiceImpl) GetByID(id int64) (*model.ApiKey, error) { + if id <= 0 { + return nil, errors.New("id 必须大于 0") + } + return s.apiKeyRepo.GetByID(id) +} + +func (s *ApiKeyServiceImpl) GetByName(name string) (*model.ApiKey, error) { + if name == "" { + return nil, errors.New("name 不能为空") + } + return s.apiKeyRepo.GetByName(name) +} + +func (s *ApiKeyServiceImpl) GetByApiKey(apiKeyValue string) (*model.ApiKey, error) { + if apiKeyValue == "" { + return nil, errors.New("apiKeyValue 不能为空") + } + return s.apiKeyRepo.GetByApiKey(apiKeyValue) +} + +func (s *ApiKeyServiceImpl) Update(apiKey *model.ApiKey) error { + if apiKey == nil { + return errors.New("apiKey不能为空") + } + if apiKey.ID <= 0 { + return errors.New("apiKey ID 必须大于 0") + } + return s.apiKeyRepo.Update(apiKey) +} + +func (s *ApiKeyServiceImpl) Delete(id int64) error { + if id <= 0 { + return errors.New("id 必须大于 0") + } + return s.apiKeyRepo.Delete(id) +} + +func (s *ApiKeyServiceImpl) List(offset, limit int, status *int) ([]model.ApiKey, error) { + if offset < 0 { + offset = 0 + } + if limit <= 0 { + limit = 10 // 设置默认值 + } + return s.apiKeyRepo.List(offset, limit, status) +} + +func (s *ApiKeyServiceImpl) ListWithFilters(offset, limit int, filters map[string]interface{}) ([]model.ApiKey, int64, error) { + if offset < 0 { + offset = 0 + } + if limit <= 0 { + limit = 10 // 设置默认值 + } + + return s.apiKeyRepo.ListWithFilters(offset, limit, filters) +} + +func (s *ApiKeyServiceImpl) Enable(id int64) error { + if id <= 0 { + return errors.New("id 必须大于 0") + } + return s.apiKeyRepo.Enable(id) +} + +func (s *ApiKeyServiceImpl) Disable(id int64) error { + if id <= 0 { + return errors.New("id 必须大于 0") + } + return s.apiKeyRepo.Disable(id) +} + +func (s *ApiKeyServiceImpl) BatchEnable(ids []int64) error { + if len(ids) == 0 { + return errors.New("ids 不能为空") + } + return s.apiKeyRepo.BatchEnable(ids) +} + +func (s *ApiKeyServiceImpl) BatchDisable(ids []int64) error { + if len(ids) == 0 { + return errors.New("ids 不能为空") + } + return s.apiKeyRepo.BatchDisable(ids) +} + +func (s *ApiKeyServiceImpl) BatchDelete(ids []int64) error { + if len(ids) == 0 { + return errors.New("ids 不能为空") + } + return s.apiKeyRepo.BatchDelete(ids) +} + +func (s *ApiKeyServiceImpl) Count() (int64, error) { + return s.apiKeyRepo.Count() +} diff --git a/team/service/token.go b/team/service/token.go new file mode 100644 index 0000000..8a4491b --- /dev/null +++ b/team/service/token.go @@ -0,0 +1,97 @@ +package service + +import ( + "context" + "opencatd-open/team/dao" + "opencatd-open/team/model" + "strings" + + "github.com/google/uuid" +) + +// 确保 TokenService 实现了 TokenServiceInterface 接口 +var _ TokenService = (*TokenServiceImpl)(nil) + +type TokenService interface { + Create(ctx context.Context, token *model.Token) error + GetByID(ctx context.Context, id int) (*model.Token, error) + GetByKey(ctx context.Context, key string) (*model.Token, error) + GetByUserID(ctx context.Context, userID int) (*model.Token, error) + Update(ctx context.Context, token *model.Token) error + UpdateWithCondition(ctx context.Context, token *model.Token, filters map[string]interface{}, updates map[string]interface{}) error + Delete(ctx context.Context, id int) error + Lists(ctx context.Context, offset, limit int) ([]model.Token, error) + ListsWithFilters(ctx context.Context, offset, limit int, filters map[string]interface{}) ([]model.Token, int64, error) + Disable(ctx context.Context, id int) error + Enable(ctx context.Context, id int) error + BatchDisable(ctx context.Context, ids []int) error + BatchEnable(ctx context.Context, ids []int) error + BatchDelete(ctx context.Context, ids []int) error +} + +type TokenServiceImpl struct { + tokenRepo dao.TokenRepository +} + +func NewTokenService(tokenRepo dao.TokenRepository) TokenService { + return &TokenServiceImpl{tokenRepo: tokenRepo} +} + +func (s *TokenServiceImpl) Create(ctx context.Context, token *model.Token) error { + if token.Key == "" { + token.Key = "team-" + strings.ReplaceAll(uuid.New().String(), "-", "") + } + return s.tokenRepo.Create(ctx, token) +} + +func (s *TokenServiceImpl) GetByID(ctx context.Context, id int) (*model.Token, error) { + return s.tokenRepo.GetByID(ctx, id) +} + +func (s *TokenServiceImpl) GetByKey(ctx context.Context, key string) (*model.Token, error) { + return s.tokenRepo.GetByKey(ctx, key) +} + +func (s *TokenServiceImpl) GetByUserID(ctx context.Context, userID int) (*model.Token, error) { + return s.tokenRepo.GetByUserID(ctx, userID) +} + +func (s *TokenServiceImpl) Update(ctx context.Context, token *model.Token) error { + return s.tokenRepo.Update(ctx, token) +} + +func (s *TokenServiceImpl) UpdateWithCondition(ctx context.Context, token *model.Token, filters map[string]interface{}, updates map[string]interface{}) error { + return s.tokenRepo.UpdateWithCondition(ctx, token, filters, updates) +} + +func (s *TokenServiceImpl) Delete(ctx context.Context, id int) error { + return s.tokenRepo.Delete(ctx, id) +} + +func (s *TokenServiceImpl) Lists(ctx context.Context, offset, limit int) ([]model.Token, error) { + return s.tokenRepo.List(ctx, offset, limit) +} + +func (s *TokenServiceImpl) ListsWithFilters(ctx context.Context, offset, limit int, filters map[string]interface{}) ([]model.Token, int64, error) { + return s.tokenRepo.ListWithFilters(ctx, offset, limit, filters) +} + +func (s *TokenServiceImpl) Disable(ctx context.Context, id int) error { + return s.tokenRepo.Disable(ctx, id) +} + +func (s *TokenServiceImpl) Enable(ctx context.Context, id int) error { + return s.tokenRepo.Enable(ctx, id) +} + +func (s *TokenServiceImpl) BatchDisable(ctx context.Context, ids []int) error { + return s.tokenRepo.BatchDisable(ctx, ids) +} + +func (s *TokenServiceImpl) BatchEnable(ctx context.Context, ids []int) error { + return s.tokenRepo.BatchEnable(ctx, ids) +} + +func (s *TokenServiceImpl) BatchDelete(ctx context.Context, ids []int) error { + return s.tokenRepo.BatchDelete(ctx, ids) +} diff --git a/team/service/usage.go b/team/service/usage.go new file mode 100644 index 0000000..5e131f6 --- /dev/null +++ b/team/service/usage.go @@ -0,0 +1,137 @@ +package service + +import ( + "context" + "fmt" + "opencatd-open/team/dao" + dto "opencatd-open/team/dto/team" + "opencatd-open/team/model" + "time" + + "log" + + "gorm.io/gorm" + "gorm.io/gorm/clause" +) + +var _ UsageService = (*usageService)(nil) + +type UsageService interface { + // AsyncProcessUsage 异步处理使用记录 + AsyncProcessUsage(usage *model.Usage) + + ListByUserID(ctx context.Context, userID int64, limit, offset int) ([]*model.Usage, error) + ListByCapability(ctx context.Context, capability string, limit, offset int) ([]*model.Usage, error) + ListByDateRange(ctx context.Context, start, end time.Time, filters map[string]interface{}) ([]*dto.UsageInfo, error) + + Delete(ctx context.Context, id int64) error +} + +type usageService struct { + db *gorm.DB + usageDAO dao.UsageRepository + dailyUsageDAO dao.DailyUsageRepository + usageChan chan *model.Usage // 用于异步处理的channel + ctx context.Context +} + +func NewUsageService(ctx context.Context, db *gorm.DB, usageRepo dao.UsageRepository, dailyUsageRepo dao.DailyUsageRepository) UsageService { + srv := &usageService{ + db: db, + usageDAO: usageRepo, + dailyUsageDAO: dailyUsageRepo, + usageChan: make(chan *model.Usage, 1000), // 设置合适的缓冲区大小 + ctx: ctx, + } + + // 启动异步处理goroutine + go srv.processUsageWorker() + + return srv +} + +func (s *usageService) AsyncProcessUsage(usage *model.Usage) { + select { + case s.usageChan <- usage: + // 成功发送到channel + default: + // channel已满,记录错误日志 + log.Println("usage channel is full, skip processing") + } +} + +func (s *usageService) processUsageWorker() { + for { + select { + case usage := <-s.usageChan: + err := s.processUsage(usage) + if err != nil { + log.Println("processUsage error:", err) + } + case <-s.ctx.Done(): + log.Println("processUsageWorker is exiting") + return + } + } +} + +// processUsageWorker 异步处理worker +func (s *usageService) processUsage(usage *model.Usage) error { + err := s.db.Transaction(func(tx *gorm.DB) error { + // 1. 记录使用记录 + if err := tx.WithContext(s.ctx).Create(usage).Error; err != nil { + return fmt.Errorf("create usage error: %w", err) + } + + // 2. 更新每日统计(upsert 操作) + dailyUsage := model.DailyUsage{ + UserID: usage.UserID, + TokenID: usage.TokenID, + Capability: usage.Capability, + Date: time.Date(usage.Date.Year(), usage.Date.Month(), usage.Date.Day(), 0, 0, 0, 0, usage.Date.Location()), + Model: usage.Model, + Stream: usage.Stream, + PromptTokens: usage.PromptTokens, + CompletionTokens: usage.CompletionTokens, + TotalTokens: usage.TotalTokens, + Cost: usage.Cost, + } + + // 使用 OnConflict 实现 upsert + if err := tx.WithContext(s.ctx).Clauses(clause.OnConflict{ + Columns: []clause.Column{{Name: "user_id"}, {Name: "token_id"}, {Name: "capability"}, {Name: "date"}}, // 唯一键 + DoUpdates: clause.Assignments(map[string]interface{}{ + "prompt_tokens": gorm.Expr("prompt_tokens + ?", usage.PromptTokens), + "completion_tokens": gorm.Expr("completion_tokens + ?", usage.CompletionTokens), + "total_tokens": gorm.Expr("total_tokens + ?", usage.TotalTokens), + "cost": gorm.Expr("cost + ?", usage.Cost), + }), + }).Create(&dailyUsage).Error; err != nil { + return fmt.Errorf("upsert daily usage error: %w", err) + } + + // 3. 更新用户额度 + if err := tx.WithContext(s.ctx).Model(&model.User{}).Where("id = ?", usage.UserID).Update("quota", gorm.Expr("quota - ?", usage.Cost)).Error; err != nil { + return fmt.Errorf("update user quota error: %w", err) + } + + return nil + }) + return err +} + +func (s *usageService) ListByUserID(ctx context.Context, userID int64, limit int, offset int) ([]*model.Usage, error) { + return s.usageDAO.ListByUserID(ctx, userID, limit, offset) +} + +func (s *usageService) ListByCapability(ctx context.Context, capability string, limit, offset int) ([]*model.Usage, error) { + return s.usageDAO.ListByCapability(ctx, capability, limit, offset) +} + +func (s *usageService) ListByDateRange(ctx context.Context, start, end time.Time, filters map[string]interface{}) ([]*dto.UsageInfo, error) { + return s.dailyUsageDAO.StatUserUsages(ctx, start, end, filters) +} + +func (s *usageService) Delete(ctx context.Context, id int64) error { + return s.usageDAO.Delete(ctx, id) +} diff --git a/team/service/user.go b/team/service/user.go new file mode 100644 index 0000000..9b809e7 --- /dev/null +++ b/team/service/user.go @@ -0,0 +1,702 @@ +package service + +import ( + "context" + "errors" + "opencatd-open/team/consts" + "opencatd-open/team/dao" + "opencatd-open/team/model" + "regexp" + "strings" + "time" + + "golang.org/x/crypto/bcrypt" + "golang.org/x/exp/rand" + + "gorm.io/gorm" +) + +var ( + ErrUserNotFound = errors.New("user not found") + ErrInvalidUserInput = errors.New("invalid user input") + ErrUserExists = errors.New("user already exists") + ErrInvalidPassword = errors.New("invalid password format") + ErrPermissionDenied = errors.New("permission denied") + ErrInvalidOperation = errors.New("invalid operation") + ErrTransactionFailed = errors.New("transaction failed") +) + +// PasswordPolicy 定义密码策略 +type PasswordPolicy struct { + MinLength int + MaxLength int + NeedNumber bool + NeedUpper bool + NeedLower bool + NeedSymbol bool +} + +// 生成随机数字 +func generateNumber() string { + rand.Seed(uint64(time.Now().UnixNano())) + return string('0' + rand.Intn(10)) +} + +// 生成随机大写字母 +func generateUpper() string { + rand.Seed(uint64(time.Now().UnixNano())) + return string('A' + rand.Intn(26)) +} + +// 生成随机小写字母 +func generateLower() string { + rand.Seed(uint64(time.Now().UnixNano())) + return string('a' + rand.Intn(26)) +} + +// 生成随机特殊符号 +func generateSymbol() string { + rand.Seed(uint64(time.Now().UnixNano())) + symbols := "!@#$%^&*" + return string(symbols[rand.Intn(len(symbols))]) +} + +// GeneratePassword 根据密码策略生成密码 +func GeneratePassword(policy PasswordPolicy) string { + rand.Seed(uint64(time.Now().UnixNano())) + + // 确保满足所有必须的字符类型 + var password string + if policy.NeedNumber { + password += generateNumber() + } + if policy.NeedUpper { + password += generateUpper() + } + if policy.NeedLower { + password += generateLower() + } + if policy.NeedSymbol { + password += generateSymbol() + } + + // 计算还需要多少个字符 + remainingLength := policy.MinLength - len(password) + if remainingLength < 0 { + remainingLength = 0 + } + // 剩余长度随机生成密码字符 + for i := 0; i < remainingLength; i++ { + randType := rand.Intn(4) // 0:数字, 1:大写, 2:小写, 3:符号 + switch randType { + case 0: + password += generateNumber() + case 1: + password += generateUpper() + case 2: + password += generateLower() + case 3: + password += generateSymbol() + } + } + + // 如果密码长度超过最大值,则截断 + if len(password) > policy.MaxLength { + password = password[:policy.MaxLength] + } + + // 将密码打乱 + passwordRune := []rune(password) + rand.Shuffle(len(passwordRune), func(i, j int) { + passwordRune[i], passwordRune[j] = passwordRune[j], passwordRune[i] + }) + + return string(passwordRune) +} + +var _ UserService = (*userService)(nil) + +// UserService 定义用户服务的接口 +type UserService interface { + CreateUser(ctx context.Context, user *model.User) error + GetUser(ctx context.Context, id int64) (*model.User, error) + GetUserByUsername(ctx context.Context, username string) (*model.User, error) + UpdateUser(ctx context.Context, user *model.User, operatorID int64) error + DeleteUser(ctx context.Context, id int64, operatorID int64) error + ListUsers(ctx context.Context, page, pageSize int) ([]model.User, int64, error) + ListUsersWithFilters(ctx context.Context, page, pageSize int, filters map[string]interface{}) ([]model.User, int64, error) + EnableUser(ctx context.Context, id int64, operatorID int64) error + DisableUser(ctx context.Context, id int64, operatorID int64) error + BatchEnableUsers(ctx context.Context, ids []int64, operatorID int64) error + BatchDisableUsers(ctx context.Context, ids []int64, operatorID int64) error + BatchDeleteUsers(ctx context.Context, ids []int64, operatorID int64) error + ChangePassword(ctx context.Context, userID int64, oldPassword, newPassword string) error + ResetPassword(ctx context.Context, userID int64, operatorID int64) error + ValidatePassword(password string) error + CheckPermission(ctx context.Context, requiredRole consts.UserRole) error +} + +// userService 实现 UserService 接口 +type userService struct { + userRepo dao.UserRepository + db *gorm.DB + pwdPolicy PasswordPolicy +} + +// NewUserService 创建 UserService 实例 +func NewUserService(userRepo dao.UserRepository, db *gorm.DB) UserService { + return &userService{ + userRepo: userRepo, + db: db, + pwdPolicy: PasswordPolicy{ + MinLength: 8, + MaxLength: 32, + NeedNumber: true, + NeedUpper: true, + NeedLower: true, + NeedSymbol: true, + }, + } +} + +// hashPassword 使用 bcrypt 加密密码 +func (s *userService) hashPassword(password string) (string, error) { + hashedBytes, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost) + if err != nil { + return "", err + } + return string(hashedBytes), nil +} + +// comparePasswords 比较密码 +func (s *userService) comparePasswords(hashedPassword, password string) bool { + err := bcrypt.CompareHashAndPassword([]byte(hashedPassword), []byte(password)) + return err == nil +} + +// ValidatePassword 验证密码是否符合策略 +func (s *userService) ValidatePassword(password string) error { + if len(password) < s.pwdPolicy.MinLength || len(password) > s.pwdPolicy.MaxLength { + return ErrInvalidPassword + } + + if s.pwdPolicy.NeedNumber && !regexp.MustCompile(`[0-9]`).MatchString(password) { + return ErrInvalidPassword + } + + if s.pwdPolicy.NeedUpper && !regexp.MustCompile(`[A-Z]`).MatchString(password) { + return ErrInvalidPassword + } + + if s.pwdPolicy.NeedLower && !regexp.MustCompile(`[a-z]`).MatchString(password) { + return ErrInvalidPassword + } + + if s.pwdPolicy.NeedSymbol && !regexp.MustCompile(`[!@#$%^&*]`).MatchString(password) { + return ErrInvalidPassword + } + + return nil +} + +// CheckPermission 检查用户权限 +func (s *userService) CheckPermission(ctx context.Context, requiredRole consts.UserRole) error { + userToken := ctx.Value("Token").(*model.Token) + + // 检查用户角色 + if userToken.User.Role < int(requiredRole) { + return ErrPermissionDenied + } + + return nil +} + +// withTransaction 事务处理封装 +func (s *userService) withTransaction(ctx context.Context, fn func(tx *gorm.DB) error) error { + tx := s.db.WithContext(ctx).Begin() + if tx.Error != nil { + return ErrTransactionFailed + } + + defer func() { + if r := recover(); r != nil { + tx.Rollback() + panic(r) + } + }() + + if err := fn(tx); err != nil { + tx.Rollback() + return err + } + + if err := tx.Commit().Error; err != nil { + return ErrTransactionFailed + } + + return nil +} + +// CreateUser 创建用户 +func (s *userService) CreateUser(ctx context.Context, user *model.User) error { + if user == nil { + return ErrInvalidUserInput + } + + if user.Password == "" { + user.Password = GeneratePassword(s.pwdPolicy) + } + + // 使用事务处理 + return s.withTransaction(ctx, func(tx *gorm.DB) error { + // 检查用户名是否已存在 + // _, err := s.userRepo.GetByID(user.ID) + // if err != nil { + // return err + // } + + // 加密密码 + hashedPassword, err := s.hashPassword(user.Password) + if err != nil { + return err + } + user.Password = hashedPassword + + return s.userRepo.Create(user) + }) +} + +// GetUser 根据 ID 获取用户 +func (s *userService) GetUser(ctx context.Context, id int64) (*model.User, error) { + if id <= 0 { + return nil, ErrInvalidUserInput + } + + user, err := s.userRepo.GetByID(id) + if err != nil { + return nil, err // 返回其他数据库错误 + } + // 处理返回结果,清除敏感信息 + user.Password = "" // 清除密码信息 + + return user, nil +} + +// GetUserByUsername 根据用户名获取用户 +func (s *userService) GetUserByUsername(ctx context.Context, username string) (*model.User, error) { + if username == "" { + return nil, ErrInvalidUserInput + } + + user, err := s.userRepo.GetByUsername(username) + if err != nil { + if errors.Is(err, gorm.ErrRecordNotFound) { + return nil, ErrUserNotFound + } + return nil, err // 返回其他数据库错误 + } + + // 处理返回结果,清除敏感信息 + user.Password = "" // 清除密码信息 + + return user, nil +} + +// UpdateUser 更新用户信息 +func (s *userService) UpdateUser(ctx context.Context, user *model.User, operatorID int64) error { + if user == nil || user.ID <= 0 { + return ErrInvalidUserInput + } + + return s.withTransaction(ctx, func(tx *gorm.DB) error { + // 检查用户是否存在 + existingUser, err := s.userRepo.GetByID(user.ID) + if err != nil { + return err + } + + // 如果修改了用户名,检查新用户名是否已存在 + if user.Username != existingUser.Username { + tmpUser, err := s.userRepo.GetByUsername(user.Username) + if err == nil && tmpUser != nil && tmpUser.ID != user.ID { + return ErrUserExists + } + } + + // 保持原有密码 + user.Password = existingUser.Password + user.UpdatedAt = time.Now().Unix() + + return s.userRepo.Update(user) + }) +} + +// ChangePassword 修改密码 +func (s *userService) ChangePassword(ctx context.Context, userID int64, oldPassword, newPassword string) error { + // 验证新密码 + if err := s.ValidatePassword(newPassword); err != nil { + return err + } + + return s.withTransaction(ctx, func(tx *gorm.DB) error { + user, err := s.userRepo.GetByID(userID) + if err != nil { + return ErrUserNotFound + } + + // 验证旧密码 + if !s.comparePasswords(user.Password, oldPassword) { + return ErrInvalidPassword + } + + // 加密新密码 + hashedPassword, err := s.hashPassword(newPassword) + if err != nil { + return err + } + + user.Password = hashedPassword + user.UpdatedAt = time.Now().Unix() + + return s.userRepo.Update(user) + }) +} + +// ResetPassword 重置密码 +func (s *userService) ResetPassword(ctx context.Context, userID int64, operatorID int64) error { + return s.withTransaction(ctx, func(tx *gorm.DB) error { + user, err := s.userRepo.GetByID(userID) + if err != nil { + return ErrUserNotFound + } + + // 生成随机密码 + newPassword := generateRandomPassword() + hashedPassword, err := s.hashPassword(newPassword) + if err != nil { + return err + } + + user.Password = hashedPassword + user.UpdatedAt = time.Now().Unix() + + // TODO: 发送新密码给用户邮箱 + + return s.userRepo.Update(user) + }) +} + +// ListUsers 获取用户列表(增加过滤功能) +func (s *userService) ListUsers(ctx context.Context, page, pageSize int) ([]model.User, int64, error) { + if page < 1 { + page = 1 + } + if pageSize < 1 { + pageSize = 10 + } + + offset := (page - 1) * pageSize + + users, err := s.userRepo.List(offset, pageSize) + if err != nil { + return nil, 0, err + } + + var total int64 = 0 + + return users, total, nil +} + +// ListUsers 获取用户列表(增加过滤功能) +func (s *userService) ListUsersWithFilters(ctx context.Context, page, pageSize int, filters map[string]interface{}) ([]model.User, int64, error) { + if page < 1 { + page = 1 + } + if pageSize < 1 { + pageSize = 10 + } + + offset := (page - 1) * pageSize + + // 使用新的 ListWithFilters 方法 + users, total, err := s.userRepo.ListWithFilters(offset, pageSize, filters) + if err != nil { + return nil, 0, err + } + + return users, total, nil +} + +// generateRandomPassword 生成随机密码 +func generateRandomPassword() string { + const ( + lowerChars = "abcdefghijklmnopqrstuvwxyz" + upperChars = "ABCDEFGHIJKLMNOPQRSTUVWXYZ" + numberChars = "0123456789" + specialChars = "!@#$%^&*" + ) + // rand.NewSource(uint64(time.Now().UnixNano())) + + // 确保每种字符都至少出现一次 + password := []string{ + string(lowerChars[rand.Intn(len(lowerChars))]), + string(upperChars[rand.Intn(len(upperChars))]), + string(numberChars[rand.Intn(len(numberChars))]), + string(specialChars[rand.Intn(len(specialChars))]), + } + + // 所有可用字符 + allChars := lowerChars + upperChars + numberChars + specialChars + + // 生成剩余的12个字符 + for i := 0; i < 12; i++ { + password = append(password, string(allChars[rand.Intn(len(allChars))])) + } + + // 打乱密码字符顺序 + rand.Shuffle(len(password), func(i, j int) { + password[i], password[j] = password[j], password[i] + }) + + return strings.Join(password, "") +} + +// DeleteUser 删除用户 +func (s *userService) DeleteUser(ctx context.Context, id int64, operatorID int64) error { + // 检查参数 + if id <= 0 { + return ErrInvalidUserInput + } + + if err := s.CheckPermission(ctx, consts.RoleAdmin); err != nil { + return err + } + + // 不允许删除自己 + if id == operatorID { + return ErrInvalidOperation + } + + return s.withTransaction(ctx, func(tx *gorm.DB) error { + // 检查用户是否存在 + user, err := s.userRepo.GetByID(id) + if err != nil { + return err + } + + // 检查是否试图删除管理员 + if user.Role == int(consts.RoleAdmin) { + return ErrPermissionDenied + } + + return s.userRepo.Delete(id) + }) +} + +// EnableUser 启用用户 +func (s *userService) EnableUser(ctx context.Context, id int64, operatorID int64) error { + // 检查参数 + if id <= 0 { + return ErrInvalidUserInput + } + + // 检查操作者权限 + if err := s.CheckPermission(ctx, consts.RoleAdmin); err != nil { + return err + } + + return s.withTransaction(ctx, func(tx *gorm.DB) error { + // 检查用户是否存在 + user, err := s.userRepo.GetByID(id) + if err != nil { + return ErrUserNotFound + } + + // 如果用户已经是启用状态,返回成功 + if user.Status == consts.StatusEnabled { + return nil + } + + return s.userRepo.Enable(id) + }) +} + +// DisableUser 禁用用户 +func (s *userService) DisableUser(ctx context.Context, id int64, operatorID int64) error { + // 检查参数 + if id <= 0 { + return ErrInvalidUserInput + } + + // 检查操作者权限 + if err := s.CheckPermission(ctx, consts.RoleAdmin); err != nil { + return err + } + + // 不允许禁用自己 + if id == operatorID { + return ErrInvalidOperation + } + + return s.withTransaction(ctx, func(tx *gorm.DB) error { + // 检查用户是否存在 + user, err := s.userRepo.GetByID(id) + if err != nil { + return ErrUserNotFound + } + + // 检查是否试图禁用超级管理员 + if user.Role == int(consts.RoleAdmin) { + return ErrPermissionDenied + } + + // 如果用户已经是禁用状态,返回成功 + if user.Status == consts.StatusDisabled { + return nil + } + + return s.userRepo.Disable(id) + }) +} + +// BatchEnableUsers 批量启用用户 +func (s *userService) BatchEnableUsers(ctx context.Context, ids []int64, operatorID int64) error { + // 检查参数 + if len(ids) == 0 { + return ErrInvalidUserInput + } + + // 检查操作者权限 + if err := s.CheckPermission(ctx, consts.RoleAdmin); err != nil { + return err + } + + return s.withTransaction(ctx, func(tx *gorm.DB) error { + // 检查所有用户是否存在,并收集当前状态 + enabledUsers := make([]int64, 0) + for _, id := range ids { + user, err := s.userRepo.GetByID(id) + if err != nil { + return ErrUserNotFound + } + if user.Status == consts.StatusEnabled { + enabledUsers = append(enabledUsers, id) + } + } + + // 如果所有用户都已经是启用状态,返回成功 + if len(enabledUsers) == len(ids) { + return nil + } + + // 过滤掉已经启用的用户,只处理需要启用的用户 + toEnableIds := make([]int64, 0) + for _, id := range ids { + if !contains(enabledUsers, id) { + toEnableIds = append(toEnableIds, id) + } + } + + if len(toEnableIds) > 0 { + return s.userRepo.BatchEnable(toEnableIds) + } + return nil + }) +} + +// BatchDisableUsers 批量禁用用户 +func (s *userService) BatchDisableUsers(ctx context.Context, ids []int64, operatorID int64) error { + // 检查参数 + if len(ids) == 0 { + return ErrInvalidUserInput + } + + // 检查操作者权限 + if err := s.CheckPermission(ctx, consts.RoleAdmin); err != nil { + return err + } + + // 不允许包含自己 + if contains(ids, operatorID) { + return ErrInvalidOperation + } + + return s.withTransaction(ctx, func(tx *gorm.DB) error { + // 检查所有用户是否存在 + disabledUsers := make([]int64, 0) + for _, id := range ids { + user, err := s.userRepo.GetByID(id) + if err != nil { + return ErrUserNotFound + } + // 不允许禁用管理员 + if user.Role == int(consts.RoleAdmin) { + return ErrPermissionDenied + } + if user.Status == consts.StatusDisabled { + disabledUsers = append(disabledUsers, id) + } + } + + // 如果所有用户都已经是禁用状态,返回成功 + if len(disabledUsers) == len(ids) { + return nil + } + + // 过滤掉已经禁用的用户,只处理需要禁用的用户 + toDisableIds := make([]int64, 0) + for _, id := range ids { + if !contains(disabledUsers, id) { + toDisableIds = append(toDisableIds, id) + } + } + + if len(toDisableIds) > 0 { + return s.userRepo.BatchDisable(toDisableIds) + } + return nil + }) +} + +// BatchDeleteUsers 批量删除用户 +func (s *userService) BatchDeleteUsers(ctx context.Context, ids []int64, operatorID int64) error { + // 检查参数 + if len(ids) == 0 { + return ErrInvalidUserInput + } + + // 检查操作者权限 + if err := s.CheckPermission(ctx, consts.RoleAdmin); err != nil { + return err + } + + // 不允许包含自己 + if contains(ids, operatorID) { + return ErrInvalidOperation + } + + return s.withTransaction(ctx, func(tx *gorm.DB) error { + // 检查所有用户是否存在,并确保不会删除管理员 + for _, id := range ids { + user, err := s.userRepo.GetByID(id) + if err != nil { + return ErrUserNotFound + } + if user.Role == int(consts.RoleAdmin) { + return ErrPermissionDenied + } + } + + return s.userRepo.BatchDelete(ids) + }) +} + +// contains 检查切片中是否包含特定值 +func contains(slice []int64, item int64) bool { + for _, s := range slice { + if s == item { + return true + } + } + return false +} diff --git a/wire/wire.go b/wire/wire.go new file mode 100644 index 0000000..c2f9256 --- /dev/null +++ b/wire/wire.go @@ -0,0 +1,47 @@ +//go:build wireinject +// +build wireinject + +package wire + +import ( + "context" + "opencatd-open/team/dao" + handler "opencatd-open/team/handler/team" + "opencatd-open/team/service" + + "github.com/google/wire" + "gorm.io/gorm" +) + +// 定义 provider set +var userSet = wire.NewSet( + dao.NewUserDAO, + wire.Bind(new(dao.UserRepository), new(*dao.UserDAO)), + service.NewUserService, +) + +var keySet = wire.NewSet( + dao.NewApiKeyDAO, + wire.Bind(new(dao.ApiKeyRepository), new(*dao.ApiKeyDAO)), + service.NewApiKeyService, +) + +var tokenSet = wire.NewSet( + dao.NewTokenDAO, + wire.Bind(new(dao.TokenRepository), new(*dao.TokenDAO)), + service.NewTokenService, +) + +var usageSet = wire.NewSet( + dao.NewUsageDAO, + wire.Bind(new(dao.UsageRepository), new(*dao.UsageDAO)), + dao.NewDailyUsageDAO, + wire.Bind(new(dao.DailyUsageRepository), new(*dao.DailyUsageDAO)), + service.NewUsageService, +) + +// 初始化 TeamHandler +func InitTeamHandler(ctx context.Context, db *gorm.DB) (*handler.TeamHandler, error) { + wire.Build(userSet, keySet, tokenSet, usageSet, handler.NewTeamHandler) + return nil, nil +} diff --git a/wire/wire_gen.go b/wire/wire_gen.go new file mode 100644 index 0000000..93b6b5e --- /dev/null +++ b/wire/wire_gen.go @@ -0,0 +1,43 @@ +// Code generated by Wire. DO NOT EDIT. + +//go:generate go run github.com/google/wire/cmd/wire +//+build !wireinject + +package wire + +import ( + "context" + "github.com/google/wire" + "gorm.io/gorm" + "opencatd-open/team/dao" + "opencatd-open/team/handler/team" + "opencatd-open/team/service" +) + +// Injectors from wire.go: + +// 初始化 TeamHandler +func InitTeamHandler(ctx context.Context, db *gorm.DB) (*handler.TeamHandler, error) { + userDAO := dao.NewUserDAO(db) + userService := service.NewUserService(userDAO, db) + tokenDAO := dao.NewTokenDAO(db) + tokenService := service.NewTokenService(tokenDAO) + apiKeyDAO := dao.NewApiKeyDAO(db) + apiKeyService := service.NewApiKeyService(apiKeyDAO, db) + usageDAO := dao.NewUsageDAO(db) + dailyUsageDAO := dao.NewDailyUsageDAO(db) + usageService := service.NewUsageService(ctx, db, usageDAO, dailyUsageDAO) + teamHandler := handler.NewTeamHandler(userService, tokenService, apiKeyService, usageService) + return teamHandler, nil +} + +// wire.go: + +// 定义 provider set +var userSet = wire.NewSet(dao.NewUserDAO, wire.Bind(new(dao.UserRepository), new(*dao.UserDAO)), service.NewUserService) + +var keySet = wire.NewSet(dao.NewApiKeyDAO, wire.Bind(new(dao.ApiKeyRepository), new(*dao.ApiKeyDAO)), service.NewApiKeyService) + +var tokenSet = wire.NewSet(dao.NewTokenDAO, wire.Bind(new(dao.TokenRepository), new(*dao.TokenDAO)), service.NewTokenService) + +var usageSet = wire.NewSet(dao.NewUsageDAO, wire.Bind(new(dao.UsageRepository), new(*dao.UsageDAO)), dao.NewDailyUsageDAO, wire.Bind(new(dao.DailyUsageRepository), new(*dao.DailyUsageDAO)), service.NewUsageService)