This commit is contained in:
Sakurasan
2025-02-01 23:52:55 +08:00
parent 65d6d12972
commit bc223d6530
30 changed files with 2683 additions and 242 deletions

152
team/service/apikey.go Normal file
View File

@@ -0,0 +1,152 @@
package service
import (
"errors"
"opencatd-open/team/dao"
"opencatd-open/team/model"
"time"
"gorm.io/gorm"
)
var _ ApiKeyService = (*ApiKeyServiceImpl)(nil)
type ApiKeyService interface {
Create(apiKey *model.ApiKey) error
GetByID(id int64) (*model.ApiKey, error)
GetByName(name string) (*model.ApiKey, error)
GetByApiKey(apiKeyValue string) (*model.ApiKey, error)
Update(apiKey *model.ApiKey) error
Delete(id int64) error
List(offset, limit int, status *int) ([]model.ApiKey, error)
ListWithFilters(offset, limit int, filters map[string]interface{}) ([]model.ApiKey, int64, error)
Enable(id int64) error
Disable(id int64) error
BatchEnable(ids []int64) error
BatchDisable(ids []int64) error
BatchDelete(ids []int64) error
Count() (int64, error)
}
type ApiKeyServiceImpl struct {
apiKeyRepo dao.ApiKeyRepository
db *gorm.DB
}
func NewApiKeyService(apiKeyDao dao.ApiKeyRepository, db *gorm.DB) ApiKeyService {
return &ApiKeyServiceImpl{apiKeyRepo: apiKeyDao, db: db}
}
func (s *ApiKeyServiceImpl) Create(apiKey *model.ApiKey) error {
if apiKey == nil {
return errors.New("apiKey不能为空")
}
if apiKey.Name == "" {
return errors.New("apiKey名称不能为空")
}
if apiKey.ApiKey == "" {
return errors.New("apiKey值不能为空")
}
apiKey.CreatedAt = time.Now().Unix()
apiKey.UpdatedAt = time.Now().Unix()
return s.apiKeyRepo.Create(apiKey)
}
func (s *ApiKeyServiceImpl) GetByID(id int64) (*model.ApiKey, error) {
if id <= 0 {
return nil, errors.New("id 必须大于 0")
}
return s.apiKeyRepo.GetByID(id)
}
func (s *ApiKeyServiceImpl) GetByName(name string) (*model.ApiKey, error) {
if name == "" {
return nil, errors.New("name 不能为空")
}
return s.apiKeyRepo.GetByName(name)
}
func (s *ApiKeyServiceImpl) GetByApiKey(apiKeyValue string) (*model.ApiKey, error) {
if apiKeyValue == "" {
return nil, errors.New("apiKeyValue 不能为空")
}
return s.apiKeyRepo.GetByApiKey(apiKeyValue)
}
func (s *ApiKeyServiceImpl) Update(apiKey *model.ApiKey) error {
if apiKey == nil {
return errors.New("apiKey不能为空")
}
if apiKey.ID <= 0 {
return errors.New("apiKey ID 必须大于 0")
}
return s.apiKeyRepo.Update(apiKey)
}
func (s *ApiKeyServiceImpl) Delete(id int64) error {
if id <= 0 {
return errors.New("id 必须大于 0")
}
return s.apiKeyRepo.Delete(id)
}
func (s *ApiKeyServiceImpl) List(offset, limit int, status *int) ([]model.ApiKey, error) {
if offset < 0 {
offset = 0
}
if limit <= 0 {
limit = 10 // 设置默认值
}
return s.apiKeyRepo.List(offset, limit, status)
}
func (s *ApiKeyServiceImpl) ListWithFilters(offset, limit int, filters map[string]interface{}) ([]model.ApiKey, int64, error) {
if offset < 0 {
offset = 0
}
if limit <= 0 {
limit = 10 // 设置默认值
}
return s.apiKeyRepo.ListWithFilters(offset, limit, filters)
}
func (s *ApiKeyServiceImpl) Enable(id int64) error {
if id <= 0 {
return errors.New("id 必须大于 0")
}
return s.apiKeyRepo.Enable(id)
}
func (s *ApiKeyServiceImpl) Disable(id int64) error {
if id <= 0 {
return errors.New("id 必须大于 0")
}
return s.apiKeyRepo.Disable(id)
}
func (s *ApiKeyServiceImpl) BatchEnable(ids []int64) error {
if len(ids) == 0 {
return errors.New("ids 不能为空")
}
return s.apiKeyRepo.BatchEnable(ids)
}
func (s *ApiKeyServiceImpl) BatchDisable(ids []int64) error {
if len(ids) == 0 {
return errors.New("ids 不能为空")
}
return s.apiKeyRepo.BatchDisable(ids)
}
func (s *ApiKeyServiceImpl) BatchDelete(ids []int64) error {
if len(ids) == 0 {
return errors.New("ids 不能为空")
}
return s.apiKeyRepo.BatchDelete(ids)
}
func (s *ApiKeyServiceImpl) Count() (int64, error) {
return s.apiKeyRepo.Count()
}

97
team/service/token.go Normal file
View File

@@ -0,0 +1,97 @@
package service
import (
"context"
"opencatd-open/team/dao"
"opencatd-open/team/model"
"strings"
"github.com/google/uuid"
)
// 确保 TokenService 实现了 TokenServiceInterface 接口
var _ TokenService = (*TokenServiceImpl)(nil)
type TokenService interface {
Create(ctx context.Context, token *model.Token) error
GetByID(ctx context.Context, id int) (*model.Token, error)
GetByKey(ctx context.Context, key string) (*model.Token, error)
GetByUserID(ctx context.Context, userID int) (*model.Token, error)
Update(ctx context.Context, token *model.Token) error
UpdateWithCondition(ctx context.Context, token *model.Token, filters map[string]interface{}, updates map[string]interface{}) error
Delete(ctx context.Context, id int) error
Lists(ctx context.Context, offset, limit int) ([]model.Token, error)
ListsWithFilters(ctx context.Context, offset, limit int, filters map[string]interface{}) ([]model.Token, int64, error)
Disable(ctx context.Context, id int) error
Enable(ctx context.Context, id int) error
BatchDisable(ctx context.Context, ids []int) error
BatchEnable(ctx context.Context, ids []int) error
BatchDelete(ctx context.Context, ids []int) error
}
type TokenServiceImpl struct {
tokenRepo dao.TokenRepository
}
func NewTokenService(tokenRepo dao.TokenRepository) TokenService {
return &TokenServiceImpl{tokenRepo: tokenRepo}
}
func (s *TokenServiceImpl) Create(ctx context.Context, token *model.Token) error {
if token.Key == "" {
token.Key = "team-" + strings.ReplaceAll(uuid.New().String(), "-", "")
}
return s.tokenRepo.Create(ctx, token)
}
func (s *TokenServiceImpl) GetByID(ctx context.Context, id int) (*model.Token, error) {
return s.tokenRepo.GetByID(ctx, id)
}
func (s *TokenServiceImpl) GetByKey(ctx context.Context, key string) (*model.Token, error) {
return s.tokenRepo.GetByKey(ctx, key)
}
func (s *TokenServiceImpl) GetByUserID(ctx context.Context, userID int) (*model.Token, error) {
return s.tokenRepo.GetByUserID(ctx, userID)
}
func (s *TokenServiceImpl) Update(ctx context.Context, token *model.Token) error {
return s.tokenRepo.Update(ctx, token)
}
func (s *TokenServiceImpl) UpdateWithCondition(ctx context.Context, token *model.Token, filters map[string]interface{}, updates map[string]interface{}) error {
return s.tokenRepo.UpdateWithCondition(ctx, token, filters, updates)
}
func (s *TokenServiceImpl) Delete(ctx context.Context, id int) error {
return s.tokenRepo.Delete(ctx, id)
}
func (s *TokenServiceImpl) Lists(ctx context.Context, offset, limit int) ([]model.Token, error) {
return s.tokenRepo.List(ctx, offset, limit)
}
func (s *TokenServiceImpl) ListsWithFilters(ctx context.Context, offset, limit int, filters map[string]interface{}) ([]model.Token, int64, error) {
return s.tokenRepo.ListWithFilters(ctx, offset, limit, filters)
}
func (s *TokenServiceImpl) Disable(ctx context.Context, id int) error {
return s.tokenRepo.Disable(ctx, id)
}
func (s *TokenServiceImpl) Enable(ctx context.Context, id int) error {
return s.tokenRepo.Enable(ctx, id)
}
func (s *TokenServiceImpl) BatchDisable(ctx context.Context, ids []int) error {
return s.tokenRepo.BatchDisable(ctx, ids)
}
func (s *TokenServiceImpl) BatchEnable(ctx context.Context, ids []int) error {
return s.tokenRepo.BatchEnable(ctx, ids)
}
func (s *TokenServiceImpl) BatchDelete(ctx context.Context, ids []int) error {
return s.tokenRepo.BatchDelete(ctx, ids)
}

137
team/service/usage.go Normal file
View File

@@ -0,0 +1,137 @@
package service
import (
"context"
"fmt"
"opencatd-open/team/dao"
dto "opencatd-open/team/dto/team"
"opencatd-open/team/model"
"time"
"log"
"gorm.io/gorm"
"gorm.io/gorm/clause"
)
var _ UsageService = (*usageService)(nil)
type UsageService interface {
// AsyncProcessUsage 异步处理使用记录
AsyncProcessUsage(usage *model.Usage)
ListByUserID(ctx context.Context, userID int64, limit, offset int) ([]*model.Usage, error)
ListByCapability(ctx context.Context, capability string, limit, offset int) ([]*model.Usage, error)
ListByDateRange(ctx context.Context, start, end time.Time, filters map[string]interface{}) ([]*dto.UsageInfo, error)
Delete(ctx context.Context, id int64) error
}
type usageService struct {
db *gorm.DB
usageDAO dao.UsageRepository
dailyUsageDAO dao.DailyUsageRepository
usageChan chan *model.Usage // 用于异步处理的channel
ctx context.Context
}
func NewUsageService(ctx context.Context, db *gorm.DB, usageRepo dao.UsageRepository, dailyUsageRepo dao.DailyUsageRepository) UsageService {
srv := &usageService{
db: db,
usageDAO: usageRepo,
dailyUsageDAO: dailyUsageRepo,
usageChan: make(chan *model.Usage, 1000), // 设置合适的缓冲区大小
ctx: ctx,
}
// 启动异步处理goroutine
go srv.processUsageWorker()
return srv
}
func (s *usageService) AsyncProcessUsage(usage *model.Usage) {
select {
case s.usageChan <- usage:
// 成功发送到channel
default:
// channel已满记录错误日志
log.Println("usage channel is full, skip processing")
}
}
func (s *usageService) processUsageWorker() {
for {
select {
case usage := <-s.usageChan:
err := s.processUsage(usage)
if err != nil {
log.Println("processUsage error:", err)
}
case <-s.ctx.Done():
log.Println("processUsageWorker is exiting")
return
}
}
}
// processUsageWorker 异步处理worker
func (s *usageService) processUsage(usage *model.Usage) error {
err := s.db.Transaction(func(tx *gorm.DB) error {
// 1. 记录使用记录
if err := tx.WithContext(s.ctx).Create(usage).Error; err != nil {
return fmt.Errorf("create usage error: %w", err)
}
// 2. 更新每日统计upsert 操作)
dailyUsage := model.DailyUsage{
UserID: usage.UserID,
TokenID: usage.TokenID,
Capability: usage.Capability,
Date: time.Date(usage.Date.Year(), usage.Date.Month(), usage.Date.Day(), 0, 0, 0, 0, usage.Date.Location()),
Model: usage.Model,
Stream: usage.Stream,
PromptTokens: usage.PromptTokens,
CompletionTokens: usage.CompletionTokens,
TotalTokens: usage.TotalTokens,
Cost: usage.Cost,
}
// 使用 OnConflict 实现 upsert
if err := tx.WithContext(s.ctx).Clauses(clause.OnConflict{
Columns: []clause.Column{{Name: "user_id"}, {Name: "token_id"}, {Name: "capability"}, {Name: "date"}}, // 唯一键
DoUpdates: clause.Assignments(map[string]interface{}{
"prompt_tokens": gorm.Expr("prompt_tokens + ?", usage.PromptTokens),
"completion_tokens": gorm.Expr("completion_tokens + ?", usage.CompletionTokens),
"total_tokens": gorm.Expr("total_tokens + ?", usage.TotalTokens),
"cost": gorm.Expr("cost + ?", usage.Cost),
}),
}).Create(&dailyUsage).Error; err != nil {
return fmt.Errorf("upsert daily usage error: %w", err)
}
// 3. 更新用户额度
if err := tx.WithContext(s.ctx).Model(&model.User{}).Where("id = ?", usage.UserID).Update("quota", gorm.Expr("quota - ?", usage.Cost)).Error; err != nil {
return fmt.Errorf("update user quota error: %w", err)
}
return nil
})
return err
}
func (s *usageService) ListByUserID(ctx context.Context, userID int64, limit int, offset int) ([]*model.Usage, error) {
return s.usageDAO.ListByUserID(ctx, userID, limit, offset)
}
func (s *usageService) ListByCapability(ctx context.Context, capability string, limit, offset int) ([]*model.Usage, error) {
return s.usageDAO.ListByCapability(ctx, capability, limit, offset)
}
func (s *usageService) ListByDateRange(ctx context.Context, start, end time.Time, filters map[string]interface{}) ([]*dto.UsageInfo, error) {
return s.dailyUsageDAO.StatUserUsages(ctx, start, end, filters)
}
func (s *usageService) Delete(ctx context.Context, id int64) error {
return s.usageDAO.Delete(ctx, id)
}

702
team/service/user.go Normal file
View File

@@ -0,0 +1,702 @@
package service
import (
"context"
"errors"
"opencatd-open/team/consts"
"opencatd-open/team/dao"
"opencatd-open/team/model"
"regexp"
"strings"
"time"
"golang.org/x/crypto/bcrypt"
"golang.org/x/exp/rand"
"gorm.io/gorm"
)
var (
ErrUserNotFound = errors.New("user not found")
ErrInvalidUserInput = errors.New("invalid user input")
ErrUserExists = errors.New("user already exists")
ErrInvalidPassword = errors.New("invalid password format")
ErrPermissionDenied = errors.New("permission denied")
ErrInvalidOperation = errors.New("invalid operation")
ErrTransactionFailed = errors.New("transaction failed")
)
// PasswordPolicy 定义密码策略
type PasswordPolicy struct {
MinLength int
MaxLength int
NeedNumber bool
NeedUpper bool
NeedLower bool
NeedSymbol bool
}
// 生成随机数字
func generateNumber() string {
rand.Seed(uint64(time.Now().UnixNano()))
return string('0' + rand.Intn(10))
}
// 生成随机大写字母
func generateUpper() string {
rand.Seed(uint64(time.Now().UnixNano()))
return string('A' + rand.Intn(26))
}
// 生成随机小写字母
func generateLower() string {
rand.Seed(uint64(time.Now().UnixNano()))
return string('a' + rand.Intn(26))
}
// 生成随机特殊符号
func generateSymbol() string {
rand.Seed(uint64(time.Now().UnixNano()))
symbols := "!@#$%^&*"
return string(symbols[rand.Intn(len(symbols))])
}
// GeneratePassword 根据密码策略生成密码
func GeneratePassword(policy PasswordPolicy) string {
rand.Seed(uint64(time.Now().UnixNano()))
// 确保满足所有必须的字符类型
var password string
if policy.NeedNumber {
password += generateNumber()
}
if policy.NeedUpper {
password += generateUpper()
}
if policy.NeedLower {
password += generateLower()
}
if policy.NeedSymbol {
password += generateSymbol()
}
// 计算还需要多少个字符
remainingLength := policy.MinLength - len(password)
if remainingLength < 0 {
remainingLength = 0
}
// 剩余长度随机生成密码字符
for i := 0; i < remainingLength; i++ {
randType := rand.Intn(4) // 0:数字, 1:大写, 2:小写, 3:符号
switch randType {
case 0:
password += generateNumber()
case 1:
password += generateUpper()
case 2:
password += generateLower()
case 3:
password += generateSymbol()
}
}
// 如果密码长度超过最大值,则截断
if len(password) > policy.MaxLength {
password = password[:policy.MaxLength]
}
// 将密码打乱
passwordRune := []rune(password)
rand.Shuffle(len(passwordRune), func(i, j int) {
passwordRune[i], passwordRune[j] = passwordRune[j], passwordRune[i]
})
return string(passwordRune)
}
var _ UserService = (*userService)(nil)
// UserService 定义用户服务的接口
type UserService interface {
CreateUser(ctx context.Context, user *model.User) error
GetUser(ctx context.Context, id int64) (*model.User, error)
GetUserByUsername(ctx context.Context, username string) (*model.User, error)
UpdateUser(ctx context.Context, user *model.User, operatorID int64) error
DeleteUser(ctx context.Context, id int64, operatorID int64) error
ListUsers(ctx context.Context, page, pageSize int) ([]model.User, int64, error)
ListUsersWithFilters(ctx context.Context, page, pageSize int, filters map[string]interface{}) ([]model.User, int64, error)
EnableUser(ctx context.Context, id int64, operatorID int64) error
DisableUser(ctx context.Context, id int64, operatorID int64) error
BatchEnableUsers(ctx context.Context, ids []int64, operatorID int64) error
BatchDisableUsers(ctx context.Context, ids []int64, operatorID int64) error
BatchDeleteUsers(ctx context.Context, ids []int64, operatorID int64) error
ChangePassword(ctx context.Context, userID int64, oldPassword, newPassword string) error
ResetPassword(ctx context.Context, userID int64, operatorID int64) error
ValidatePassword(password string) error
CheckPermission(ctx context.Context, requiredRole consts.UserRole) error
}
// userService 实现 UserService 接口
type userService struct {
userRepo dao.UserRepository
db *gorm.DB
pwdPolicy PasswordPolicy
}
// NewUserService 创建 UserService 实例
func NewUserService(userRepo dao.UserRepository, db *gorm.DB) UserService {
return &userService{
userRepo: userRepo,
db: db,
pwdPolicy: PasswordPolicy{
MinLength: 8,
MaxLength: 32,
NeedNumber: true,
NeedUpper: true,
NeedLower: true,
NeedSymbol: true,
},
}
}
// hashPassword 使用 bcrypt 加密密码
func (s *userService) hashPassword(password string) (string, error) {
hashedBytes, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost)
if err != nil {
return "", err
}
return string(hashedBytes), nil
}
// comparePasswords 比较密码
func (s *userService) comparePasswords(hashedPassword, password string) bool {
err := bcrypt.CompareHashAndPassword([]byte(hashedPassword), []byte(password))
return err == nil
}
// ValidatePassword 验证密码是否符合策略
func (s *userService) ValidatePassword(password string) error {
if len(password) < s.pwdPolicy.MinLength || len(password) > s.pwdPolicy.MaxLength {
return ErrInvalidPassword
}
if s.pwdPolicy.NeedNumber && !regexp.MustCompile(`[0-9]`).MatchString(password) {
return ErrInvalidPassword
}
if s.pwdPolicy.NeedUpper && !regexp.MustCompile(`[A-Z]`).MatchString(password) {
return ErrInvalidPassword
}
if s.pwdPolicy.NeedLower && !regexp.MustCompile(`[a-z]`).MatchString(password) {
return ErrInvalidPassword
}
if s.pwdPolicy.NeedSymbol && !regexp.MustCompile(`[!@#$%^&*]`).MatchString(password) {
return ErrInvalidPassword
}
return nil
}
// CheckPermission 检查用户权限
func (s *userService) CheckPermission(ctx context.Context, requiredRole consts.UserRole) error {
userToken := ctx.Value("Token").(*model.Token)
// 检查用户角色
if userToken.User.Role < int(requiredRole) {
return ErrPermissionDenied
}
return nil
}
// withTransaction 事务处理封装
func (s *userService) withTransaction(ctx context.Context, fn func(tx *gorm.DB) error) error {
tx := s.db.WithContext(ctx).Begin()
if tx.Error != nil {
return ErrTransactionFailed
}
defer func() {
if r := recover(); r != nil {
tx.Rollback()
panic(r)
}
}()
if err := fn(tx); err != nil {
tx.Rollback()
return err
}
if err := tx.Commit().Error; err != nil {
return ErrTransactionFailed
}
return nil
}
// CreateUser 创建用户
func (s *userService) CreateUser(ctx context.Context, user *model.User) error {
if user == nil {
return ErrInvalidUserInput
}
if user.Password == "" {
user.Password = GeneratePassword(s.pwdPolicy)
}
// 使用事务处理
return s.withTransaction(ctx, func(tx *gorm.DB) error {
// 检查用户名是否已存在
// _, err := s.userRepo.GetByID(user.ID)
// if err != nil {
// return err
// }
// 加密密码
hashedPassword, err := s.hashPassword(user.Password)
if err != nil {
return err
}
user.Password = hashedPassword
return s.userRepo.Create(user)
})
}
// GetUser 根据 ID 获取用户
func (s *userService) GetUser(ctx context.Context, id int64) (*model.User, error) {
if id <= 0 {
return nil, ErrInvalidUserInput
}
user, err := s.userRepo.GetByID(id)
if err != nil {
return nil, err // 返回其他数据库错误
}
// 处理返回结果,清除敏感信息
user.Password = "" // 清除密码信息
return user, nil
}
// GetUserByUsername 根据用户名获取用户
func (s *userService) GetUserByUsername(ctx context.Context, username string) (*model.User, error) {
if username == "" {
return nil, ErrInvalidUserInput
}
user, err := s.userRepo.GetByUsername(username)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, ErrUserNotFound
}
return nil, err // 返回其他数据库错误
}
// 处理返回结果,清除敏感信息
user.Password = "" // 清除密码信息
return user, nil
}
// UpdateUser 更新用户信息
func (s *userService) UpdateUser(ctx context.Context, user *model.User, operatorID int64) error {
if user == nil || user.ID <= 0 {
return ErrInvalidUserInput
}
return s.withTransaction(ctx, func(tx *gorm.DB) error {
// 检查用户是否存在
existingUser, err := s.userRepo.GetByID(user.ID)
if err != nil {
return err
}
// 如果修改了用户名,检查新用户名是否已存在
if user.Username != existingUser.Username {
tmpUser, err := s.userRepo.GetByUsername(user.Username)
if err == nil && tmpUser != nil && tmpUser.ID != user.ID {
return ErrUserExists
}
}
// 保持原有密码
user.Password = existingUser.Password
user.UpdatedAt = time.Now().Unix()
return s.userRepo.Update(user)
})
}
// ChangePassword 修改密码
func (s *userService) ChangePassword(ctx context.Context, userID int64, oldPassword, newPassword string) error {
// 验证新密码
if err := s.ValidatePassword(newPassword); err != nil {
return err
}
return s.withTransaction(ctx, func(tx *gorm.DB) error {
user, err := s.userRepo.GetByID(userID)
if err != nil {
return ErrUserNotFound
}
// 验证旧密码
if !s.comparePasswords(user.Password, oldPassword) {
return ErrInvalidPassword
}
// 加密新密码
hashedPassword, err := s.hashPassword(newPassword)
if err != nil {
return err
}
user.Password = hashedPassword
user.UpdatedAt = time.Now().Unix()
return s.userRepo.Update(user)
})
}
// ResetPassword 重置密码
func (s *userService) ResetPassword(ctx context.Context, userID int64, operatorID int64) error {
return s.withTransaction(ctx, func(tx *gorm.DB) error {
user, err := s.userRepo.GetByID(userID)
if err != nil {
return ErrUserNotFound
}
// 生成随机密码
newPassword := generateRandomPassword()
hashedPassword, err := s.hashPassword(newPassword)
if err != nil {
return err
}
user.Password = hashedPassword
user.UpdatedAt = time.Now().Unix()
// TODO: 发送新密码给用户邮箱
return s.userRepo.Update(user)
})
}
// ListUsers 获取用户列表(增加过滤功能)
func (s *userService) ListUsers(ctx context.Context, page, pageSize int) ([]model.User, int64, error) {
if page < 1 {
page = 1
}
if pageSize < 1 {
pageSize = 10
}
offset := (page - 1) * pageSize
users, err := s.userRepo.List(offset, pageSize)
if err != nil {
return nil, 0, err
}
var total int64 = 0
return users, total, nil
}
// ListUsers 获取用户列表(增加过滤功能)
func (s *userService) ListUsersWithFilters(ctx context.Context, page, pageSize int, filters map[string]interface{}) ([]model.User, int64, error) {
if page < 1 {
page = 1
}
if pageSize < 1 {
pageSize = 10
}
offset := (page - 1) * pageSize
// 使用新的 ListWithFilters 方法
users, total, err := s.userRepo.ListWithFilters(offset, pageSize, filters)
if err != nil {
return nil, 0, err
}
return users, total, nil
}
// generateRandomPassword 生成随机密码
func generateRandomPassword() string {
const (
lowerChars = "abcdefghijklmnopqrstuvwxyz"
upperChars = "ABCDEFGHIJKLMNOPQRSTUVWXYZ"
numberChars = "0123456789"
specialChars = "!@#$%^&*"
)
// rand.NewSource(uint64(time.Now().UnixNano()))
// 确保每种字符都至少出现一次
password := []string{
string(lowerChars[rand.Intn(len(lowerChars))]),
string(upperChars[rand.Intn(len(upperChars))]),
string(numberChars[rand.Intn(len(numberChars))]),
string(specialChars[rand.Intn(len(specialChars))]),
}
// 所有可用字符
allChars := lowerChars + upperChars + numberChars + specialChars
// 生成剩余的12个字符
for i := 0; i < 12; i++ {
password = append(password, string(allChars[rand.Intn(len(allChars))]))
}
// 打乱密码字符顺序
rand.Shuffle(len(password), func(i, j int) {
password[i], password[j] = password[j], password[i]
})
return strings.Join(password, "")
}
// DeleteUser 删除用户
func (s *userService) DeleteUser(ctx context.Context, id int64, operatorID int64) error {
// 检查参数
if id <= 0 {
return ErrInvalidUserInput
}
if err := s.CheckPermission(ctx, consts.RoleAdmin); err != nil {
return err
}
// 不允许删除自己
if id == operatorID {
return ErrInvalidOperation
}
return s.withTransaction(ctx, func(tx *gorm.DB) error {
// 检查用户是否存在
user, err := s.userRepo.GetByID(id)
if err != nil {
return err
}
// 检查是否试图删除管理员
if user.Role == int(consts.RoleAdmin) {
return ErrPermissionDenied
}
return s.userRepo.Delete(id)
})
}
// EnableUser 启用用户
func (s *userService) EnableUser(ctx context.Context, id int64, operatorID int64) error {
// 检查参数
if id <= 0 {
return ErrInvalidUserInput
}
// 检查操作者权限
if err := s.CheckPermission(ctx, consts.RoleAdmin); err != nil {
return err
}
return s.withTransaction(ctx, func(tx *gorm.DB) error {
// 检查用户是否存在
user, err := s.userRepo.GetByID(id)
if err != nil {
return ErrUserNotFound
}
// 如果用户已经是启用状态,返回成功
if user.Status == consts.StatusEnabled {
return nil
}
return s.userRepo.Enable(id)
})
}
// DisableUser 禁用用户
func (s *userService) DisableUser(ctx context.Context, id int64, operatorID int64) error {
// 检查参数
if id <= 0 {
return ErrInvalidUserInput
}
// 检查操作者权限
if err := s.CheckPermission(ctx, consts.RoleAdmin); err != nil {
return err
}
// 不允许禁用自己
if id == operatorID {
return ErrInvalidOperation
}
return s.withTransaction(ctx, func(tx *gorm.DB) error {
// 检查用户是否存在
user, err := s.userRepo.GetByID(id)
if err != nil {
return ErrUserNotFound
}
// 检查是否试图禁用超级管理员
if user.Role == int(consts.RoleAdmin) {
return ErrPermissionDenied
}
// 如果用户已经是禁用状态,返回成功
if user.Status == consts.StatusDisabled {
return nil
}
return s.userRepo.Disable(id)
})
}
// BatchEnableUsers 批量启用用户
func (s *userService) BatchEnableUsers(ctx context.Context, ids []int64, operatorID int64) error {
// 检查参数
if len(ids) == 0 {
return ErrInvalidUserInput
}
// 检查操作者权限
if err := s.CheckPermission(ctx, consts.RoleAdmin); err != nil {
return err
}
return s.withTransaction(ctx, func(tx *gorm.DB) error {
// 检查所有用户是否存在,并收集当前状态
enabledUsers := make([]int64, 0)
for _, id := range ids {
user, err := s.userRepo.GetByID(id)
if err != nil {
return ErrUserNotFound
}
if user.Status == consts.StatusEnabled {
enabledUsers = append(enabledUsers, id)
}
}
// 如果所有用户都已经是启用状态,返回成功
if len(enabledUsers) == len(ids) {
return nil
}
// 过滤掉已经启用的用户,只处理需要启用的用户
toEnableIds := make([]int64, 0)
for _, id := range ids {
if !contains(enabledUsers, id) {
toEnableIds = append(toEnableIds, id)
}
}
if len(toEnableIds) > 0 {
return s.userRepo.BatchEnable(toEnableIds)
}
return nil
})
}
// BatchDisableUsers 批量禁用用户
func (s *userService) BatchDisableUsers(ctx context.Context, ids []int64, operatorID int64) error {
// 检查参数
if len(ids) == 0 {
return ErrInvalidUserInput
}
// 检查操作者权限
if err := s.CheckPermission(ctx, consts.RoleAdmin); err != nil {
return err
}
// 不允许包含自己
if contains(ids, operatorID) {
return ErrInvalidOperation
}
return s.withTransaction(ctx, func(tx *gorm.DB) error {
// 检查所有用户是否存在
disabledUsers := make([]int64, 0)
for _, id := range ids {
user, err := s.userRepo.GetByID(id)
if err != nil {
return ErrUserNotFound
}
// 不允许禁用管理员
if user.Role == int(consts.RoleAdmin) {
return ErrPermissionDenied
}
if user.Status == consts.StatusDisabled {
disabledUsers = append(disabledUsers, id)
}
}
// 如果所有用户都已经是禁用状态,返回成功
if len(disabledUsers) == len(ids) {
return nil
}
// 过滤掉已经禁用的用户,只处理需要禁用的用户
toDisableIds := make([]int64, 0)
for _, id := range ids {
if !contains(disabledUsers, id) {
toDisableIds = append(toDisableIds, id)
}
}
if len(toDisableIds) > 0 {
return s.userRepo.BatchDisable(toDisableIds)
}
return nil
})
}
// BatchDeleteUsers 批量删除用户
func (s *userService) BatchDeleteUsers(ctx context.Context, ids []int64, operatorID int64) error {
// 检查参数
if len(ids) == 0 {
return ErrInvalidUserInput
}
// 检查操作者权限
if err := s.CheckPermission(ctx, consts.RoleAdmin); err != nil {
return err
}
// 不允许包含自己
if contains(ids, operatorID) {
return ErrInvalidOperation
}
return s.withTransaction(ctx, func(tx *gorm.DB) error {
// 检查所有用户是否存在,并确保不会删除管理员
for _, id := range ids {
user, err := s.userRepo.GetByID(id)
if err != nil {
return ErrUserNotFound
}
if user.Role == int(consts.RoleAdmin) {
return ErrPermissionDenied
}
}
return s.userRepo.BatchDelete(ids)
})
}
// contains 检查切片中是否包含特定值
func contains(slice []int64, item int64) bool {
for _, s := range slice {
if s == item {
return true
}
}
return false
}