package service import ( "context" "errors" "opencatd-open/internal/consts" "opencatd-open/internal/dao" "opencatd-open/internal/model" "regexp" "strings" "time" "golang.org/x/crypto/bcrypt" "golang.org/x/exp/rand" "gorm.io/gorm" ) var ( ErrUserNotFound = errors.New("user not found") ErrInvalidUserInput = errors.New("invalid user input") ErrUserExists = errors.New("user already exists") ErrInvalidPassword = errors.New("invalid password format") ErrPermissionDenied = errors.New("permission denied") ErrInvalidOperation = errors.New("invalid operation") ErrTransactionFailed = errors.New("transaction failed") ) // PasswordPolicy 定义密码策略 type PasswordPolicy struct { MinLength int MaxLength int NeedNumber bool NeedUpper bool NeedLower bool NeedSymbol bool } // 生成随机数字 func generateNumber() string { rand.Seed(uint64(time.Now().UnixNano())) return string('0' + rand.Intn(10)) } // 生成随机大写字母 func generateUpper() string { rand.Seed(uint64(time.Now().UnixNano())) return string('A' + rand.Intn(26)) } // 生成随机小写字母 func generateLower() string { rand.Seed(uint64(time.Now().UnixNano())) return string('a' + rand.Intn(26)) } // 生成随机特殊符号 func generateSymbol() string { rand.Seed(uint64(time.Now().UnixNano())) symbols := "!@#$%^&*" return string(symbols[rand.Intn(len(symbols))]) } // GeneratePassword 根据密码策略生成密码 func GeneratePassword(policy PasswordPolicy) string { rand.Seed(uint64(time.Now().UnixNano())) // 确保满足所有必须的字符类型 var password string if policy.NeedNumber { password += generateNumber() } if policy.NeedUpper { password += generateUpper() } if policy.NeedLower { password += generateLower() } if policy.NeedSymbol { password += generateSymbol() } // 计算还需要多少个字符 remainingLength := policy.MinLength - len(password) if remainingLength < 0 { remainingLength = 0 } // 剩余长度随机生成密码字符 for i := 0; i < remainingLength; i++ { randType := rand.Intn(4) // 0:数字, 1:大写, 2:小写, 3:符号 switch randType { case 0: password += generateNumber() case 1: password += generateUpper() case 2: password += generateLower() case 3: password += generateSymbol() } } // 如果密码长度超过最大值,则截断 if len(password) > policy.MaxLength { password = password[:policy.MaxLength] } // 将密码打乱 passwordRune := []rune(password) rand.Shuffle(len(passwordRune), func(i, j int) { passwordRune[i], passwordRune[j] = passwordRune[j], passwordRune[i] }) return string(passwordRune) } var _ UserService = (*userService)(nil) // UserService 定义用户服务的接口 type UserService interface { CreateUser(ctx context.Context, user *model.User) error GetUser(ctx context.Context, id int64) (*model.User, error) GetUserByUsername(ctx context.Context, username string) (*model.User, error) UpdateUser(ctx context.Context, user *model.User, operatorID int64) error DeleteUser(ctx context.Context, id int64, operatorID int64) error ListUsers(ctx context.Context, limit, offset int, active string) ([]model.User, error) // EnableUser(ctx context.Context, id int64, operatorID int64) error // DisableUser(ctx context.Context, id int64, operatorID int64) error BatchEnableUsers(ctx context.Context, ids []int64, operatorID int64) error BatchDisableUsers(ctx context.Context, ids []int64, operatorID int64) error BatchDeleteUsers(ctx context.Context, ids []int64, operatorID int64) error ChangePassword(ctx context.Context, userID int64, oldPassword, newPassword string) error ResetPassword(ctx context.Context, userID int64, operatorID int64) error ValidatePassword(password string) error CheckPermission(ctx context.Context, requiredRole consts.UserRole) error } // userService 实现 UserService 接口 type userService struct { userRepo dao.UserRepository db *gorm.DB pwdPolicy PasswordPolicy } // NewUserService 创建 UserService 实例 func NewUserService(db *gorm.DB, userRepo dao.UserRepository) UserService { return &userService{ userRepo: userRepo, db: db, pwdPolicy: PasswordPolicy{ MinLength: 8, MaxLength: 32, NeedNumber: true, NeedUpper: true, NeedLower: true, NeedSymbol: true, }, } } // hashPassword 使用 bcrypt 加密密码 func (s *userService) hashPassword(password string) (string, error) { hashedBytes, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost) if err != nil { return "", err } return string(hashedBytes), nil } // comparePasswords 比较密码 func (s *userService) comparePasswords(hashedPassword, password string) bool { err := bcrypt.CompareHashAndPassword([]byte(hashedPassword), []byte(password)) return err == nil } // ValidatePassword 验证密码是否符合策略 func (s *userService) ValidatePassword(password string) error { if len(password) < s.pwdPolicy.MinLength || len(password) > s.pwdPolicy.MaxLength { return ErrInvalidPassword } if s.pwdPolicy.NeedNumber && !regexp.MustCompile(`[0-9]`).MatchString(password) { return ErrInvalidPassword } if s.pwdPolicy.NeedUpper && !regexp.MustCompile(`[A-Z]`).MatchString(password) { return ErrInvalidPassword } if s.pwdPolicy.NeedLower && !regexp.MustCompile(`[a-z]`).MatchString(password) { return ErrInvalidPassword } if s.pwdPolicy.NeedSymbol && !regexp.MustCompile(`[!@#$%^&*]`).MatchString(password) { return ErrInvalidPassword } return nil } // CheckPermission 检查用户权限 func (s *userService) CheckPermission(ctx context.Context, requiredRole consts.UserRole) error { userToken := ctx.Value("Token").(*model.Token) // 检查用户角色 if *userToken.User.Role < requiredRole { return ErrPermissionDenied } return nil } // withTransaction 事务处理封装 func (s *userService) withTransaction(ctx context.Context, fn func(tx *gorm.DB) error) error { tx := s.db.WithContext(ctx).Begin() if tx.Error != nil { return ErrTransactionFailed } defer func() { if r := recover(); r != nil { tx.Rollback() panic(r) } }() if err := fn(tx); err != nil { tx.Rollback() return err } if err := tx.Commit().Error; err != nil { return ErrTransactionFailed } return nil } // CreateUser 创建用户 func (s *userService) CreateUser(ctx context.Context, user *model.User) error { if user == nil { return ErrInvalidUserInput } if user.Password == "" { user.Password = GeneratePassword(s.pwdPolicy) } // 使用事务处理 return s.withTransaction(ctx, func(tx *gorm.DB) error { // 检查用户名是否已存在 // _, err := s.userRepo.GetByID(user.ID) // if err != nil { // return err // } // 加密密码 hashedPassword, err := s.hashPassword(user.Password) if err != nil { return err } user.Password = hashedPassword return s.userRepo.Create(user) }) } // GetUser 根据 ID 获取用户 func (s *userService) GetUser(ctx context.Context, id int64) (*model.User, error) { if id <= 0 { return nil, ErrInvalidUserInput } user, err := s.userRepo.GetByID(id) if err != nil { return nil, err // 返回其他数据库错误 } // 处理返回结果,清除敏感信息 user.Password = "" // 清除密码信息 return user, nil } // GetUserByUsername 根据用户名获取用户 func (s *userService) GetUserByUsername(ctx context.Context, username string) (*model.User, error) { if username == "" { return nil, ErrInvalidUserInput } user, err := s.userRepo.GetByUsername(username) if err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { return nil, ErrUserNotFound } return nil, err // 返回其他数据库错误 } // 处理返回结果,清除敏感信息 user.Password = "" // 清除密码信息 return user, nil } // UpdateUser 更新用户信息 func (s *userService) UpdateUser(ctx context.Context, user *model.User, operatorID int64) error { if user == nil || user.ID <= 0 { return ErrInvalidUserInput } return s.withTransaction(ctx, func(tx *gorm.DB) error { // 检查用户是否存在 existingUser, err := s.userRepo.GetByID(user.ID) if err != nil { return err } // 如果修改了用户名,检查新用户名是否已存在 if user.Username != existingUser.Username { tmpUser, err := s.userRepo.GetByUsername(user.Username) if err == nil && tmpUser != nil && tmpUser.ID != user.ID { return ErrUserExists } } // 保持原有密码 user.Password = existingUser.Password user.UpdatedAt = time.Now().Unix() return s.userRepo.Update(user) }) } // ChangePassword 修改密码 func (s *userService) ChangePassword(ctx context.Context, userID int64, oldPassword, newPassword string) error { // 验证新密码 if err := s.ValidatePassword(newPassword); err != nil { return err } return s.withTransaction(ctx, func(tx *gorm.DB) error { user, err := s.userRepo.GetByID(userID) if err != nil { return ErrUserNotFound } // 验证旧密码 if !s.comparePasswords(user.Password, oldPassword) { return ErrInvalidPassword } // 加密新密码 hashedPassword, err := s.hashPassword(newPassword) if err != nil { return err } user.Password = hashedPassword user.UpdatedAt = time.Now().Unix() return s.userRepo.Update(user) }) } // ResetPassword 重置密码 func (s *userService) ResetPassword(ctx context.Context, userID int64, operatorID int64) error { return s.withTransaction(ctx, func(tx *gorm.DB) error { user, err := s.userRepo.GetByID(userID) if err != nil { return ErrUserNotFound } // 生成随机密码 newPassword := generateRandomPassword() hashedPassword, err := s.hashPassword(newPassword) if err != nil { return err } user.Password = hashedPassword user.UpdatedAt = time.Now().Unix() // TODO: 发送新密码给用户邮箱 return s.userRepo.Update(user) }) } // ListUsers 获取用户列表(增加过滤功能) func (s *userService) ListUsers(ctx context.Context, limit, offset int, active string) ([]model.User, error) { if limit < 0 { limit = 20 } if offset < 0 { offset = 0 } var users []model.User var err error if active != "" { users, _, err = s.userRepo.List(limit, offset, map[string]interface{}{"active in ?": strings.Split(active, ",")}) } else { users, _, err = s.userRepo.List(limit, offset, nil) } return users, err } // generateRandomPassword 生成随机密码 func generateRandomPassword() string { const ( lowerChars = "abcdefghijklmnopqrstuvwxyz" upperChars = "ABCDEFGHIJKLMNOPQRSTUVWXYZ" numberChars = "0123456789" specialChars = "!@#$%^&*" ) // rand.NewSource(uint64(time.Now().UnixNano())) // 确保每种字符都至少出现一次 password := []string{ string(lowerChars[rand.Intn(len(lowerChars))]), string(upperChars[rand.Intn(len(upperChars))]), string(numberChars[rand.Intn(len(numberChars))]), string(specialChars[rand.Intn(len(specialChars))]), } // 所有可用字符 allChars := lowerChars + upperChars + numberChars + specialChars // 生成剩余的12个字符 for i := 0; i < 12; i++ { password = append(password, string(allChars[rand.Intn(len(allChars))])) } // 打乱密码字符顺序 rand.Shuffle(len(password), func(i, j int) { password[i], password[j] = password[j], password[i] }) return strings.Join(password, "") } // DeleteUser 删除用户 func (s *userService) DeleteUser(ctx context.Context, id int64, operatorID int64) error { // 检查参数 if id <= 0 { return ErrInvalidUserInput } if err := s.CheckPermission(ctx, consts.RoleAdmin); err != nil { return err } // 不允许删除自己 if id == operatorID { return ErrInvalidOperation } return s.withTransaction(ctx, func(tx *gorm.DB) error { // 检查用户是否存在 user, err := s.userRepo.GetByID(id) if err != nil { return err } // 检查是否试图删除管理员 if *user.Role == consts.RoleAdmin { return ErrPermissionDenied } return s.userRepo.Delete(id) }) } // EnableUser 启用用户 // func (s *userService) EnableUser(ctx context.Context, id int64, operatorID int64) error { // // 检查参数 // if id <= 0 { // return ErrInvalidUserInput // } // // 检查操作者权限 // if err := s.CheckPermission(ctx, consts.RoleAdmin); err != nil { // return err // } // return s.withTransaction(ctx, func(tx *gorm.DB) error { // // 检查用户是否存在 // user, err := s.userRepo.GetByID(id) // if err != nil { // return ErrUserNotFound // } // // 如果用户已经是启用状态,返回成功 // if user.Status == consts.StatusEnabled { // return nil // } // return s.userRepo.Enable(id) // }) // } // DisableUser 禁用用户 // func (s *userService) DisableUser(ctx context.Context, id int64, operatorID int64) error { // // 检查参数 // if id <= 0 { // return ErrInvalidUserInput // } // // 检查操作者权限 // if err := s.CheckPermission(ctx, consts.RoleAdmin); err != nil { // return err // } // // 不允许禁用自己 // if id == operatorID { // return ErrInvalidOperation // } // return s.withTransaction(ctx, func(tx *gorm.DB) error { // // 检查用户是否存在 // user, err := s.userRepo.GetByID(id) // if err != nil { // return ErrUserNotFound // } // // 检查是否试图禁用超级管理员 // if user.Role == consts.RoleAdmin { // return ErrPermissionDenied // } // // 如果用户已经是禁用状态,返回成功 // if user.Status == consts.StatusDisabled { // return nil // } // return s.userRepo.Disable(id) // }) // } // BatchEnableUsers 批量启用用户 func (s *userService) BatchEnableUsers(ctx context.Context, ids []int64, operatorID int64) error { // 检查参数 if len(ids) == 0 { return ErrInvalidUserInput } // 检查操作者权限 if err := s.CheckPermission(ctx, consts.RoleAdmin); err != nil { return err } return s.withTransaction(ctx, func(tx *gorm.DB) error { // 检查所有用户是否存在,并收集当前状态 enabledUsers := make([]int64, 0) for _, id := range ids { user, err := s.userRepo.GetByID(id) if err != nil { return ErrUserNotFound } if *user.Active == true { enabledUsers = append(enabledUsers, id) } } // 如果所有用户都已经是启用状态,返回成功 if len(enabledUsers) == len(ids) { return nil } // 过滤掉已经启用的用户,只处理需要启用的用户 toEnableIds := make([]int64, 0) for _, id := range ids { if !contains(enabledUsers, id) { toEnableIds = append(toEnableIds, id) } } if len(toEnableIds) > 0 { return s.userRepo.BatchEnable(toEnableIds, nil) } return nil }) } // BatchDisableUsers 批量禁用用户 func (s *userService) BatchDisableUsers(ctx context.Context, ids []int64, operatorID int64) error { // 检查参数 if len(ids) == 0 { return ErrInvalidUserInput } // 检查操作者权限 if err := s.CheckPermission(ctx, consts.RoleAdmin); err != nil { return err } // 不允许包含自己 if contains(ids, operatorID) { return ErrInvalidOperation } return s.withTransaction(ctx, func(tx *gorm.DB) error { // 检查所有用户是否存在 disabledUsers := make([]int64, 0) for _, id := range ids { user, err := s.userRepo.GetByID(id) if err != nil { return ErrUserNotFound } // 不允许禁用管理员 if *user.Role == consts.RoleAdmin { return ErrPermissionDenied } if user.Status == consts.StatusDisabled { disabledUsers = append(disabledUsers, id) } } // 如果所有用户都已经是禁用状态,返回成功 if len(disabledUsers) == len(ids) { return nil } // 过滤掉已经禁用的用户,只处理需要禁用的用户 toDisableIds := make([]int64, 0) for _, id := range ids { if !contains(disabledUsers, id) { toDisableIds = append(toDisableIds, id) } } if len(toDisableIds) > 0 { return s.userRepo.BatchDisable(toDisableIds, nil) } return nil }) } // BatchDeleteUsers 批量删除用户 func (s *userService) BatchDeleteUsers(ctx context.Context, ids []int64, operatorID int64) error { // 检查参数 if len(ids) == 0 { return ErrInvalidUserInput } // 检查操作者权限 if err := s.CheckPermission(ctx, consts.RoleAdmin); err != nil { return err } // 不允许包含自己 if contains(ids, operatorID) { return ErrInvalidOperation } return s.withTransaction(ctx, func(tx *gorm.DB) error { // 检查所有用户是否存在,并确保不会删除管理员 for _, id := range ids { user, err := s.userRepo.GetByID(id) if err != nil { return ErrUserNotFound } if *user.Role == consts.RoleAdmin { return ErrPermissionDenied } } return s.userRepo.BatchDelete(ids, nil) }) } // contains 检查切片中是否包含特定值 func contains[T comparable](slice []T, item T) bool { for _, s := range slice { if s == item { return true } } return false }