680 lines
17 KiB
Go
680 lines
17 KiB
Go
package service
|
|
|
|
import (
	"context"
	crand "crypto/rand"
	"errors"
	"math/big"
	"regexp"
	"strings"
	"time"

	"opencatd-open/internal/consts"
	"opencatd-open/internal/dao"
	"opencatd-open/internal/model"

	"golang.org/x/crypto/bcrypt"
	"golang.org/x/exp/rand"
	"gorm.io/gorm"
)
|
|
|
|
// Sentinel errors returned by UserService implementations. Callers should
// compare against them with errors.Is.

// ErrUserNotFound is returned when the requested user does not exist.
var ErrUserNotFound = errors.New("user not found")

// ErrInvalidUserInput is returned for missing or malformed arguments
// (nil user, non-positive ID, empty username, empty ID list).
var ErrInvalidUserInput = errors.New("invalid user input")

// ErrUserExists is returned when a username is already taken.
var ErrUserExists = errors.New("user already exists")

// ErrInvalidPassword is returned when a password violates the password
// policy; ChangePassword also returns it when the old password does not match.
var ErrInvalidPassword = errors.New("invalid password format")

// ErrPermissionDenied is returned when the caller lacks the required role or
// targets a protected (admin) account.
var ErrPermissionDenied = errors.New("permission denied")

// ErrInvalidOperation is returned for disallowed operations such as an
// operator deleting or disabling their own account.
var ErrInvalidOperation = errors.New("invalid operation")

// ErrTransactionFailed is returned when a database transaction cannot be
// started or committed.
var ErrTransactionFailed = errors.New("transaction failed")
|
|
|
|
// PasswordPolicy describes the rules a password must satisfy.
//
// MinLength and MaxLength bound the password length; ValidatePassword
// compares them against len(password), i.e. a byte count. Each Need* flag
// requires at least one character of the matching class: ASCII digit, ASCII
// upper-case letter, ASCII lower-case letter, or one of the symbols "!@#$%^&*".
type PasswordPolicy struct {
	MinLength, MaxLength                         int
	NeedNumber, NeedUpper, NeedLower, NeedSymbol bool
}
|
|
|
|
// 生成随机数字
|
|
func generateNumber() string {
|
|
rand.Seed(uint64(time.Now().UnixNano()))
|
|
return string('0' + rand.Intn(10))
|
|
}
|
|
|
|
// 生成随机大写字母
|
|
func generateUpper() string {
|
|
rand.Seed(uint64(time.Now().UnixNano()))
|
|
return string('A' + rand.Intn(26))
|
|
}
|
|
|
|
// 生成随机小写字母
|
|
func generateLower() string {
|
|
rand.Seed(uint64(time.Now().UnixNano()))
|
|
return string('a' + rand.Intn(26))
|
|
}
|
|
|
|
// 生成随机特殊符号
|
|
func generateSymbol() string {
|
|
rand.Seed(uint64(time.Now().UnixNano()))
|
|
symbols := "!@#$%^&*"
|
|
return string(symbols[rand.Intn(len(symbols))])
|
|
}
|
|
|
|
// GeneratePassword 根据密码策略生成密码
|
|
func GeneratePassword(policy PasswordPolicy) string {
|
|
rand.Seed(uint64(time.Now().UnixNano()))
|
|
|
|
// 确保满足所有必须的字符类型
|
|
var password string
|
|
if policy.NeedNumber {
|
|
password += generateNumber()
|
|
}
|
|
if policy.NeedUpper {
|
|
password += generateUpper()
|
|
}
|
|
if policy.NeedLower {
|
|
password += generateLower()
|
|
}
|
|
if policy.NeedSymbol {
|
|
password += generateSymbol()
|
|
}
|
|
|
|
// 计算还需要多少个字符
|
|
remainingLength := policy.MinLength - len(password)
|
|
if remainingLength < 0 {
|
|
remainingLength = 0
|
|
}
|
|
// 剩余长度随机生成密码字符
|
|
for i := 0; i < remainingLength; i++ {
|
|
randType := rand.Intn(4) // 0:数字, 1:大写, 2:小写, 3:符号
|
|
switch randType {
|
|
case 0:
|
|
password += generateNumber()
|
|
case 1:
|
|
password += generateUpper()
|
|
case 2:
|
|
password += generateLower()
|
|
case 3:
|
|
password += generateSymbol()
|
|
}
|
|
}
|
|
|
|
// 如果密码长度超过最大值,则截断
|
|
if len(password) > policy.MaxLength {
|
|
password = password[:policy.MaxLength]
|
|
}
|
|
|
|
// 将密码打乱
|
|
passwordRune := []rune(password)
|
|
rand.Shuffle(len(passwordRune), func(i, j int) {
|
|
passwordRune[i], passwordRune[j] = passwordRune[j], passwordRune[i]
|
|
})
|
|
|
|
return string(passwordRune)
|
|
}
|
|
|
|
var _ UserService = (*userService)(nil)
|
|
|
|
// UserService 定义用户服务的接口
|
|
type UserService interface {
|
|
CreateUser(ctx context.Context, user *model.User) error
|
|
GetUser(ctx context.Context, id int64) (*model.User, error)
|
|
GetUserByUsername(ctx context.Context, username string) (*model.User, error)
|
|
UpdateUser(ctx context.Context, user *model.User, operatorID int64) error
|
|
DeleteUser(ctx context.Context, id int64, operatorID int64) error
|
|
ListUsers(ctx context.Context, limit, offset int, active string) ([]model.User, error)
|
|
// EnableUser(ctx context.Context, id int64, operatorID int64) error
|
|
// DisableUser(ctx context.Context, id int64, operatorID int64) error
|
|
BatchEnableUsers(ctx context.Context, ids []int64, operatorID int64) error
|
|
BatchDisableUsers(ctx context.Context, ids []int64, operatorID int64) error
|
|
BatchDeleteUsers(ctx context.Context, ids []int64, operatorID int64) error
|
|
ChangePassword(ctx context.Context, userID int64, oldPassword, newPassword string) error
|
|
ResetPassword(ctx context.Context, userID int64, operatorID int64) error
|
|
ValidatePassword(password string) error
|
|
CheckPermission(ctx context.Context, requiredRole consts.UserRole) error
|
|
}
|
|
|
|
// userService 实现 UserService 接口
|
|
type userService struct {
|
|
userRepo dao.UserRepository
|
|
db *gorm.DB
|
|
pwdPolicy PasswordPolicy
|
|
}
|
|
|
|
// NewUserService 创建 UserService 实例
|
|
func NewUserService(db *gorm.DB, userRepo dao.UserRepository) UserService {
|
|
return &userService{
|
|
userRepo: userRepo,
|
|
db: db,
|
|
pwdPolicy: PasswordPolicy{
|
|
MinLength: 8,
|
|
MaxLength: 32,
|
|
NeedNumber: true,
|
|
NeedUpper: true,
|
|
NeedLower: true,
|
|
NeedSymbol: true,
|
|
},
|
|
}
|
|
}
|
|
|
|
// hashPassword 使用 bcrypt 加密密码
|
|
func (s *userService) hashPassword(password string) (string, error) {
|
|
hashedBytes, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost)
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
return string(hashedBytes), nil
|
|
}
|
|
|
|
// comparePasswords 比较密码
|
|
func (s *userService) comparePasswords(hashedPassword, password string) bool {
|
|
err := bcrypt.CompareHashAndPassword([]byte(hashedPassword), []byte(password))
|
|
return err == nil
|
|
}
|
|
|
|
// ValidatePassword 验证密码是否符合策略
|
|
func (s *userService) ValidatePassword(password string) error {
|
|
if len(password) < s.pwdPolicy.MinLength || len(password) > s.pwdPolicy.MaxLength {
|
|
return ErrInvalidPassword
|
|
}
|
|
|
|
if s.pwdPolicy.NeedNumber && !regexp.MustCompile(`[0-9]`).MatchString(password) {
|
|
return ErrInvalidPassword
|
|
}
|
|
|
|
if s.pwdPolicy.NeedUpper && !regexp.MustCompile(`[A-Z]`).MatchString(password) {
|
|
return ErrInvalidPassword
|
|
}
|
|
|
|
if s.pwdPolicy.NeedLower && !regexp.MustCompile(`[a-z]`).MatchString(password) {
|
|
return ErrInvalidPassword
|
|
}
|
|
|
|
if s.pwdPolicy.NeedSymbol && !regexp.MustCompile(`[!@#$%^&*]`).MatchString(password) {
|
|
return ErrInvalidPassword
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// CheckPermission 检查用户权限
|
|
func (s *userService) CheckPermission(ctx context.Context, requiredRole consts.UserRole) error {
|
|
userToken := ctx.Value("Token").(*model.Token)
|
|
|
|
// 检查用户角色
|
|
if *userToken.User.Role < requiredRole {
|
|
return ErrPermissionDenied
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// withTransaction 事务处理封装
|
|
func (s *userService) withTransaction(ctx context.Context, fn func(tx *gorm.DB) error) error {
|
|
tx := s.db.WithContext(ctx).Begin()
|
|
if tx.Error != nil {
|
|
return ErrTransactionFailed
|
|
}
|
|
|
|
defer func() {
|
|
if r := recover(); r != nil {
|
|
tx.Rollback()
|
|
panic(r)
|
|
}
|
|
}()
|
|
|
|
if err := fn(tx); err != nil {
|
|
tx.Rollback()
|
|
return err
|
|
}
|
|
|
|
if err := tx.Commit().Error; err != nil {
|
|
return ErrTransactionFailed
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// CreateUser 创建用户
|
|
func (s *userService) CreateUser(ctx context.Context, user *model.User) error {
|
|
if user == nil {
|
|
return ErrInvalidUserInput
|
|
}
|
|
|
|
if user.Password == "" {
|
|
user.Password = GeneratePassword(s.pwdPolicy)
|
|
}
|
|
|
|
// 使用事务处理
|
|
return s.withTransaction(ctx, func(tx *gorm.DB) error {
|
|
// 检查用户名是否已存在
|
|
// _, err := s.userRepo.GetByID(user.ID)
|
|
// if err != nil {
|
|
// return err
|
|
// }
|
|
|
|
// 加密密码
|
|
hashedPassword, err := s.hashPassword(user.Password)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
user.Password = hashedPassword
|
|
|
|
return s.userRepo.Create(user)
|
|
})
|
|
}
|
|
|
|
// GetUser 根据 ID 获取用户
|
|
func (s *userService) GetUser(ctx context.Context, id int64) (*model.User, error) {
|
|
if id <= 0 {
|
|
return nil, ErrInvalidUserInput
|
|
}
|
|
|
|
user, err := s.userRepo.GetByID(id)
|
|
if err != nil {
|
|
return nil, err // 返回其他数据库错误
|
|
}
|
|
// 处理返回结果,清除敏感信息
|
|
user.Password = "" // 清除密码信息
|
|
|
|
return user, nil
|
|
}
|
|
|
|
// GetUserByUsername 根据用户名获取用户
|
|
func (s *userService) GetUserByUsername(ctx context.Context, username string) (*model.User, error) {
|
|
if username == "" {
|
|
return nil, ErrInvalidUserInput
|
|
}
|
|
|
|
user, err := s.userRepo.GetByUsername(username)
|
|
if err != nil {
|
|
if errors.Is(err, gorm.ErrRecordNotFound) {
|
|
return nil, ErrUserNotFound
|
|
}
|
|
return nil, err // 返回其他数据库错误
|
|
}
|
|
|
|
// 处理返回结果,清除敏感信息
|
|
user.Password = "" // 清除密码信息
|
|
|
|
return user, nil
|
|
}
|
|
|
|
// UpdateUser 更新用户信息
|
|
func (s *userService) UpdateUser(ctx context.Context, user *model.User, operatorID int64) error {
|
|
if user == nil || user.ID <= 0 {
|
|
return ErrInvalidUserInput
|
|
}
|
|
|
|
return s.withTransaction(ctx, func(tx *gorm.DB) error {
|
|
// 检查用户是否存在
|
|
existingUser, err := s.userRepo.GetByID(user.ID)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
// 如果修改了用户名,检查新用户名是否已存在
|
|
if user.Username != existingUser.Username {
|
|
tmpUser, err := s.userRepo.GetByUsername(user.Username)
|
|
if err == nil && tmpUser != nil && tmpUser.ID != user.ID {
|
|
return ErrUserExists
|
|
}
|
|
}
|
|
|
|
// 保持原有密码
|
|
user.Password = existingUser.Password
|
|
user.UpdatedAt = time.Now().Unix()
|
|
|
|
return s.userRepo.Update(user)
|
|
})
|
|
}
|
|
|
|
// ChangePassword 修改密码
|
|
func (s *userService) ChangePassword(ctx context.Context, userID int64, oldPassword, newPassword string) error {
|
|
// 验证新密码
|
|
if err := s.ValidatePassword(newPassword); err != nil {
|
|
return err
|
|
}
|
|
|
|
return s.withTransaction(ctx, func(tx *gorm.DB) error {
|
|
user, err := s.userRepo.GetByID(userID)
|
|
if err != nil {
|
|
return ErrUserNotFound
|
|
}
|
|
|
|
// 验证旧密码
|
|
if !s.comparePasswords(user.Password, oldPassword) {
|
|
return ErrInvalidPassword
|
|
}
|
|
|
|
// 加密新密码
|
|
hashedPassword, err := s.hashPassword(newPassword)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
user.Password = hashedPassword
|
|
user.UpdatedAt = time.Now().Unix()
|
|
|
|
return s.userRepo.Update(user)
|
|
})
|
|
}
|
|
|
|
// ResetPassword 重置密码
|
|
func (s *userService) ResetPassword(ctx context.Context, userID int64, operatorID int64) error {
|
|
return s.withTransaction(ctx, func(tx *gorm.DB) error {
|
|
user, err := s.userRepo.GetByID(userID)
|
|
if err != nil {
|
|
return ErrUserNotFound
|
|
}
|
|
|
|
// 生成随机密码
|
|
newPassword := generateRandomPassword()
|
|
hashedPassword, err := s.hashPassword(newPassword)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
user.Password = hashedPassword
|
|
user.UpdatedAt = time.Now().Unix()
|
|
|
|
// TODO: 发送新密码给用户邮箱
|
|
|
|
return s.userRepo.Update(user)
|
|
})
|
|
}
|
|
|
|
// ListUsers 获取用户列表(增加过滤功能)
|
|
func (s *userService) ListUsers(ctx context.Context, limit, offset int, active string) ([]model.User, error) {
|
|
if limit < 0 {
|
|
limit = 20
|
|
}
|
|
if offset < 0 {
|
|
offset = 0
|
|
}
|
|
var users []model.User
|
|
var err error
|
|
if active != "" {
|
|
users, _, err = s.userRepo.List(limit, offset, map[string]interface{}{"active in ?": strings.Split(active, ",")})
|
|
} else {
|
|
users, _, err = s.userRepo.List(limit, offset, nil)
|
|
}
|
|
|
|
return users, err
|
|
}
|
|
|
|
// generateRandomPassword returns a 16-character random password, as used by
// ResetPassword.
//
// The result always contains at least one ASCII lower-case letter, one
// upper-case letter, one digit and one symbol from "!@#$%^&*". The remaining
// 12 characters are drawn uniformly from the union of those sets, and the
// whole password is then shuffled.
//
// All randomness comes from crypto/rand, so the output cannot be predicted
// from a seed or the process start time. It panics if the system entropy
// source fails; returning a predictable password would be worse.
func generateRandomPassword() string {
	const (
		lowerChars   = "abcdefghijklmnopqrstuvwxyz"
		upperChars   = "ABCDEFGHIJKLMNOPQRSTUVWXYZ"
		numberChars  = "0123456789"
		specialChars = "!@#$%^&*"
		allChars     = lowerChars + upperChars + numberChars + specialChars
		extraChars   = 12
	)

	// randIndex returns a uniformly distributed integer in [0, n).
	randIndex := func(n int) int {
		v, err := crand.Int(crand.Reader, big.NewInt(int64(n)))
		if err != nil {
			panic("service: reading crypto/rand: " + err.Error())
		}
		return int(v.Int64())
	}

	password := make([]byte, 0, 4+extraChars)

	// Guarantee one character from every class.
	for _, set := range [...]string{lowerChars, upperChars, numberChars, specialChars} {
		password = append(password, set[randIndex(len(set))])
	}

	for i := 0; i < extraChars; i++ {
		password = append(password, allChars[randIndex(len(allChars))])
	}

	// Fisher–Yates shuffle, so the guaranteed characters are not always first.
	for i := len(password) - 1; i > 0; i-- {
		j := randIndex(i + 1)
		password[i], password[j] = password[j], password[i]
	}

	return string(password)
}
|
|
|
|
// DeleteUser 删除用户
|
|
func (s *userService) DeleteUser(ctx context.Context, id int64, operatorID int64) error {
|
|
// 检查参数
|
|
if id <= 0 {
|
|
return ErrInvalidUserInput
|
|
}
|
|
|
|
if err := s.CheckPermission(ctx, consts.RoleAdmin); err != nil {
|
|
return err
|
|
}
|
|
|
|
// 不允许删除自己
|
|
if id == operatorID {
|
|
return ErrInvalidOperation
|
|
}
|
|
|
|
return s.withTransaction(ctx, func(tx *gorm.DB) error {
|
|
// 检查用户是否存在
|
|
user, err := s.userRepo.GetByID(id)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
// 检查是否试图删除管理员
|
|
if *user.Role == consts.RoleAdmin {
|
|
return ErrPermissionDenied
|
|
}
|
|
|
|
return s.userRepo.Delete(id)
|
|
})
|
|
}
|
|
|
|
// EnableUser 启用用户
|
|
// func (s *userService) EnableUser(ctx context.Context, id int64, operatorID int64) error {
|
|
// // 检查参数
|
|
// if id <= 0 {
|
|
// return ErrInvalidUserInput
|
|
// }
|
|
|
|
// // 检查操作者权限
|
|
// if err := s.CheckPermission(ctx, consts.RoleAdmin); err != nil {
|
|
// return err
|
|
// }
|
|
|
|
// return s.withTransaction(ctx, func(tx *gorm.DB) error {
|
|
// // 检查用户是否存在
|
|
// user, err := s.userRepo.GetByID(id)
|
|
// if err != nil {
|
|
// return ErrUserNotFound
|
|
// }
|
|
|
|
// // 如果用户已经是启用状态,返回成功
|
|
// if user.Status == consts.StatusEnabled {
|
|
// return nil
|
|
// }
|
|
|
|
// return s.userRepo.Enable(id)
|
|
// })
|
|
// }
|
|
|
|
// DisableUser 禁用用户
|
|
// func (s *userService) DisableUser(ctx context.Context, id int64, operatorID int64) error {
|
|
// // 检查参数
|
|
// if id <= 0 {
|
|
// return ErrInvalidUserInput
|
|
// }
|
|
|
|
// // 检查操作者权限
|
|
// if err := s.CheckPermission(ctx, consts.RoleAdmin); err != nil {
|
|
// return err
|
|
// }
|
|
|
|
// // 不允许禁用自己
|
|
// if id == operatorID {
|
|
// return ErrInvalidOperation
|
|
// }
|
|
|
|
// return s.withTransaction(ctx, func(tx *gorm.DB) error {
|
|
// // 检查用户是否存在
|
|
// user, err := s.userRepo.GetByID(id)
|
|
// if err != nil {
|
|
// return ErrUserNotFound
|
|
// }
|
|
|
|
// // 检查是否试图禁用超级管理员
|
|
// if user.Role == consts.RoleAdmin {
|
|
// return ErrPermissionDenied
|
|
// }
|
|
|
|
// // 如果用户已经是禁用状态,返回成功
|
|
// if user.Status == consts.StatusDisabled {
|
|
// return nil
|
|
// }
|
|
|
|
// return s.userRepo.Disable(id)
|
|
// })
|
|
// }
|
|
|
|
// BatchEnableUsers 批量启用用户
|
|
func (s *userService) BatchEnableUsers(ctx context.Context, ids []int64, operatorID int64) error {
|
|
// 检查参数
|
|
if len(ids) == 0 {
|
|
return ErrInvalidUserInput
|
|
}
|
|
|
|
// 检查操作者权限
|
|
if err := s.CheckPermission(ctx, consts.RoleAdmin); err != nil {
|
|
return err
|
|
}
|
|
|
|
return s.withTransaction(ctx, func(tx *gorm.DB) error {
|
|
// 检查所有用户是否存在,并收集当前状态
|
|
enabledUsers := make([]int64, 0)
|
|
for _, id := range ids {
|
|
user, err := s.userRepo.GetByID(id)
|
|
if err != nil {
|
|
return ErrUserNotFound
|
|
}
|
|
if *user.Active == true {
|
|
enabledUsers = append(enabledUsers, id)
|
|
}
|
|
}
|
|
|
|
// 如果所有用户都已经是启用状态,返回成功
|
|
if len(enabledUsers) == len(ids) {
|
|
return nil
|
|
}
|
|
|
|
// 过滤掉已经启用的用户,只处理需要启用的用户
|
|
toEnableIds := make([]int64, 0)
|
|
for _, id := range ids {
|
|
if !contains(enabledUsers, id) {
|
|
toEnableIds = append(toEnableIds, id)
|
|
}
|
|
}
|
|
|
|
if len(toEnableIds) > 0 {
|
|
return s.userRepo.BatchEnable(toEnableIds, nil)
|
|
}
|
|
return nil
|
|
})
|
|
}
|
|
|
|
// BatchDisableUsers 批量禁用用户
|
|
func (s *userService) BatchDisableUsers(ctx context.Context, ids []int64, operatorID int64) error {
|
|
// 检查参数
|
|
if len(ids) == 0 {
|
|
return ErrInvalidUserInput
|
|
}
|
|
|
|
// 检查操作者权限
|
|
if err := s.CheckPermission(ctx, consts.RoleAdmin); err != nil {
|
|
return err
|
|
}
|
|
|
|
// 不允许包含自己
|
|
if contains(ids, operatorID) {
|
|
return ErrInvalidOperation
|
|
}
|
|
|
|
return s.withTransaction(ctx, func(tx *gorm.DB) error {
|
|
// 检查所有用户是否存在
|
|
disabledUsers := make([]int64, 0)
|
|
for _, id := range ids {
|
|
user, err := s.userRepo.GetByID(id)
|
|
if err != nil {
|
|
return ErrUserNotFound
|
|
}
|
|
// 不允许禁用管理员
|
|
if *user.Role == consts.RoleAdmin {
|
|
return ErrPermissionDenied
|
|
}
|
|
if user.Status == consts.StatusDisabled {
|
|
disabledUsers = append(disabledUsers, id)
|
|
}
|
|
}
|
|
|
|
// 如果所有用户都已经是禁用状态,返回成功
|
|
if len(disabledUsers) == len(ids) {
|
|
return nil
|
|
}
|
|
|
|
// 过滤掉已经禁用的用户,只处理需要禁用的用户
|
|
toDisableIds := make([]int64, 0)
|
|
for _, id := range ids {
|
|
if !contains(disabledUsers, id) {
|
|
toDisableIds = append(toDisableIds, id)
|
|
}
|
|
}
|
|
|
|
if len(toDisableIds) > 0 {
|
|
return s.userRepo.BatchDisable(toDisableIds, nil)
|
|
}
|
|
return nil
|
|
})
|
|
}
|
|
|
|
// BatchDeleteUsers 批量删除用户
|
|
func (s *userService) BatchDeleteUsers(ctx context.Context, ids []int64, operatorID int64) error {
|
|
// 检查参数
|
|
if len(ids) == 0 {
|
|
return ErrInvalidUserInput
|
|
}
|
|
|
|
// 检查操作者权限
|
|
if err := s.CheckPermission(ctx, consts.RoleAdmin); err != nil {
|
|
return err
|
|
}
|
|
|
|
// 不允许包含自己
|
|
if contains(ids, operatorID) {
|
|
return ErrInvalidOperation
|
|
}
|
|
|
|
return s.withTransaction(ctx, func(tx *gorm.DB) error {
|
|
// 检查所有用户是否存在,并确保不会删除管理员
|
|
for _, id := range ids {
|
|
user, err := s.userRepo.GetByID(id)
|
|
if err != nil {
|
|
return ErrUserNotFound
|
|
}
|
|
if *user.Role == consts.RoleAdmin {
|
|
return ErrPermissionDenied
|
|
}
|
|
}
|
|
|
|
return s.userRepo.BatchDelete(ids, nil)
|
|
})
|
|
}
|
|
|
|
// contains reports whether item occurs in slice, comparing elements with ==.
// A nil or empty slice contains nothing.
func contains[T comparable](slice []T, item T) bool {
	for i := range slice {
		if slice[i] == item {
			return true
		}
	}
	return false
}
|