Files
opencatd-open/internal/service/team/user.go
2025-04-16 18:01:27 +08:00

680 lines
17 KiB
Go

package service
import (
	"context"
	crand "crypto/rand"
	"errors"
	"regexp"
	"strings"
	"time"

	"golang.org/x/crypto/bcrypt"
	"golang.org/x/exp/rand"
	"gorm.io/gorm"
	"opencatd-open/internal/consts"
	"opencatd-open/internal/dao"
	"opencatd-open/internal/model"
)
// Sentinel errors returned by the user service. Callers should compare with
// errors.Is.
var (
	// ErrUserNotFound reports that the requested user does not exist.
	ErrUserNotFound = errors.New("user not found")
	// ErrInvalidUserInput reports a malformed argument, such as a nil user or a
	// non-positive ID.
	ErrInvalidUserInput = errors.New("invalid user input")
	// ErrUserExists reports that the requested username is already taken.
	ErrUserExists = errors.New("user already exists")
	// ErrInvalidPassword reports a password that violates the password policy.
	// ChangePassword also returns it when the supplied old password is wrong.
	ErrInvalidPassword = errors.New("invalid password format")
	// ErrPermissionDenied reports that the caller's role is insufficient or that
	// the target user is protected (e.g. an admin).
	ErrPermissionDenied = errors.New("permission denied")
	// ErrInvalidOperation reports a disallowed operation, such as an operator
	// deleting or disabling their own account.
	ErrInvalidOperation = errors.New("invalid operation")
	// ErrTransactionFailed reports that a database transaction could not be
	// started or committed.
	ErrTransactionFailed = errors.New("transaction failed")
)
// PasswordPolicy describes the constraints a password must satisfy.
type PasswordPolicy struct {
	// MinLength is the minimum length in bytes.
	MinLength int
	// MaxLength is the maximum length in bytes. GeneratePassword treats a
	// value <= 0 as "no upper bound".
	MaxLength int
	// NeedNumber requires at least one ASCII digit.
	NeedNumber bool
	// NeedUpper requires at least one ASCII upper-case letter.
	NeedUpper bool
	// NeedLower requires at least one ASCII lower-case letter.
	NeedLower bool
	// NeedSymbol requires at least one of the symbols !@#$%^&*.
	NeedSymbol bool
}

// secureIntn returns a uniformly distributed integer in [0, n) drawn from
// crypto/rand. It panics if n <= 0 or if the system random source fails:
// generated passwords must never silently fall back to weak randomness.
func secureIntn(n int) int {
	if n <= 0 {
		panic("service: secureIntn called with non-positive n")
	}
	bound := uint64(n)
	// limit is the largest multiple of bound representable in a uint64; values
	// at or above it would map onto [0, n) unevenly, so they are rejected.
	limit := ^uint64(0) - ^uint64(0)%bound
	var buf [8]byte
	for {
		if _, err := crand.Read(buf[:]); err != nil {
			panic("service: crypto/rand read failed: " + err.Error())
		}
		var v uint64
		for _, b := range buf {
			v = v<<8 | uint64(b)
		}
		if v < limit {
			return int(v % bound)
		}
	}
}

// pickSecureChar returns one byte of set, chosen uniformly at random, as a string.
func pickSecureChar(set string) string {
	return string(set[secureIntn(len(set))])
}

// generateNumber returns one random ASCII digit.
func generateNumber() string {
	return pickSecureChar("0123456789")
}

// generateUpper returns one random ASCII upper-case letter.
func generateUpper() string {
	return pickSecureChar("ABCDEFGHIJKLMNOPQRSTUVWXYZ")
}

// generateLower returns one random ASCII lower-case letter.
func generateLower() string {
	return pickSecureChar("abcdefghijklmnopqrstuvwxyz")
}

// generateSymbol returns one random symbol from !@#$%^&*.
func generateSymbol() string {
	return pickSecureChar("!@#$%^&*")
}

// GeneratePassword builds a random password that satisfies policy.
//
// One character of every required class is emitted first (in the order
// number, upper, lower, symbol). The password is then padded with characters
// of uniformly chosen classes until it reaches policy.MinLength. If the result
// exceeds a positive policy.MaxLength it is truncated; the padding comes last,
// so it is cut first. Finally the characters are shuffled so the required ones
// do not sit at fixed positions. All randomness comes from crypto/rand.
func GeneratePassword(policy PasswordPolicy) string {
	var buf []byte

	// Guarantee at least one character from each required class.
	if policy.NeedNumber {
		buf = append(buf, generateNumber()...)
	}
	if policy.NeedUpper {
		buf = append(buf, generateUpper()...)
	}
	if policy.NeedLower {
		buf = append(buf, generateLower()...)
	}
	if policy.NeedSymbol {
		buf = append(buf, generateSymbol()...)
	}

	// Pad with characters of random classes up to the minimum length.
	for len(buf) < policy.MinLength {
		switch secureIntn(4) {
		case 0:
			buf = append(buf, generateNumber()...)
		case 1:
			buf = append(buf, generateUpper()...)
		case 2:
			buf = append(buf, generateLower()...)
		default:
			buf = append(buf, generateSymbol()...)
		}
	}

	if policy.MaxLength > 0 && len(buf) > policy.MaxLength {
		buf = buf[:policy.MaxLength]
	}

	// Fisher–Yates shuffle driven by the CSPRNG.
	for i := len(buf) - 1; i > 0; i-- {
		j := secureIntn(i + 1)
		buf[i], buf[j] = buf[j], buf[i]
	}
	return string(buf)
}
// Compile-time assertion that *userService implements UserService.
var _ UserService = (*userService)(nil)

// UserService is the user-management API offered by this package.
type UserService interface {
	// --- Lookup ---

	// GetUser returns the user with the given ID; the returned Password field is cleared.
	GetUser(ctx context.Context, id int64) (*model.User, error)
	// GetUserByUsername returns the user with the given username; the returned
	// Password field is cleared.
	GetUserByUsername(ctx context.Context, username string) (*model.User, error)
	// ListUsers returns up to limit users starting at offset. A non-empty active
	// value is split on commas and used as an "active in" filter.
	ListUsers(ctx context.Context, limit, offset int, active string) ([]model.User, error)

	// --- Single-user mutation ---

	// CreateUser stores user with a bcrypt-hashed password, generating a
	// policy-compliant password first when none is set.
	CreateUser(ctx context.Context, user *model.User) error
	// UpdateUser saves user's fields while keeping the stored password hash.
	UpdateUser(ctx context.Context, user *model.User, operatorID int64) error
	// DeleteUser removes one user; the caller must hold the admin role.
	DeleteUser(ctx context.Context, id int64, operatorID int64) error

	// --- Batch operations (admin only) ---

	// EnableUser(ctx context.Context, id int64, operatorID int64) error
	// DisableUser(ctx context.Context, id int64, operatorID int64) error

	// BatchEnableUsers enables every user in ids.
	BatchEnableUsers(ctx context.Context, ids []int64, operatorID int64) error
	// BatchDisableUsers disables every user in ids.
	BatchDisableUsers(ctx context.Context, ids []int64, operatorID int64) error
	// BatchDeleteUsers deletes every user in ids.
	BatchDeleteUsers(ctx context.Context, ids []int64, operatorID int64) error

	// --- Passwords and authorization ---

	// ChangePassword replaces the password after verifying oldPassword.
	ChangePassword(ctx context.Context, userID int64, oldPassword, newPassword string) error
	// ResetPassword replaces the password with a randomly generated one.
	ResetPassword(ctx context.Context, userID int64, operatorID int64) error
	// ValidatePassword checks password against the service's password policy.
	ValidatePassword(password string) error
	// CheckPermission verifies that the caller in ctx holds at least requiredRole.
	CheckPermission(ctx context.Context, requiredRole consts.UserRole) error
}
// userService is the UserService implementation backed by a
// dao.UserRepository, with a *gorm.DB used to open transactions.
type userService struct {
	userRepo  dao.UserRepository // persistence for users
	db        *gorm.DB           // handle withTransaction opens transactions on
	pwdPolicy PasswordPolicy     // checked by ValidatePassword; used when generating passwords
}

// NewUserService returns a UserService that persists users through userRepo
// and opens transactions on db. The password policy requires 8–32 bytes with at
// least one digit, upper-case letter, lower-case letter and symbol.
func NewUserService(db *gorm.DB, userRepo dao.UserRepository) UserService {
	defaultPolicy := PasswordPolicy{
		MinLength:  8,
		MaxLength:  32,
		NeedNumber: true,
		NeedUpper:  true,
		NeedLower:  true,
		NeedSymbol: true,
	}
	return &userService{
		db:        db,
		userRepo:  userRepo,
		pwdPolicy: defaultPolicy,
	}
}
// hashPassword returns the bcrypt hash of password, computed at
// bcrypt.DefaultCost, or the error bcrypt reports.
func (s *userService) hashPassword(password string) (string, error) {
	digest, hashErr := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost)
	if hashErr != nil {
		return "", hashErr
	}
	return string(digest), nil
}

// comparePasswords reports whether password matches the bcrypt hash
// hashedPassword. Any error from bcrypt (mismatch or malformed hash) yields false.
func (s *userService) comparePasswords(hashedPassword, password string) bool {
	return bcrypt.CompareHashAndPassword([]byte(hashedPassword), []byte(password)) == nil
}
// Character-class patterns used by ValidatePassword. They are compiled once at
// package initialisation instead of on every call.
var (
	passwordDigitRe  = regexp.MustCompile(`[0-9]`)
	passwordUpperRe  = regexp.MustCompile(`[A-Z]`)
	passwordLowerRe  = regexp.MustCompile(`[a-z]`)
	passwordSymbolRe = regexp.MustCompile(`[!@#$%^&*]`)
)

// ValidatePassword checks password against s.pwdPolicy.
//
// It returns ErrInvalidPassword in two cases:
//   - the length in bytes is outside [MinLength, MaxLength];
//   - a character class the policy requires is missing.
//
// Symbols are limited to !@#$%^&*, the same set generateSymbol draws from.
func (s *userService) ValidatePassword(password string) error {
	policy := s.pwdPolicy
	if len(password) < policy.MinLength || len(password) > policy.MaxLength {
		return ErrInvalidPassword
	}
	if policy.NeedNumber && !passwordDigitRe.MatchString(password) {
		return ErrInvalidPassword
	}
	if policy.NeedUpper && !passwordUpperRe.MatchString(password) {
		return ErrInvalidPassword
	}
	if policy.NeedLower && !passwordLowerRe.MatchString(password) {
		return ErrInvalidPassword
	}
	if policy.NeedSymbol && !passwordSymbolRe.MatchString(password) {
		return ErrInvalidPassword
	}
	return nil
}
// CheckPermission verifies that the caller's token belongs to a user whose role
// is at least requiredRole. The token is read from ctx under the "Token" key as
// a *model.Token.
//
// It returns ErrPermissionDenied in three cases:
//   - the role is too low;
//   - ctx carries no *model.Token;
//   - the token's user has no role set.
func (s *userService) CheckPermission(ctx context.Context, requiredRole consts.UserRole) error {
	// NOTE(review): a plain string context key collides easily (staticcheck
	// SA1029); it must match whatever middleware stores the token, so it is
	// kept as-is here.
	userToken, ok := ctx.Value("Token").(*model.Token)
	if !ok || userToken == nil {
		return ErrPermissionDenied
	}
	// NOTE(review): if model.Token.User is a pointer it may also need a nil
	// check — confirm against the model definition.
	role := userToken.User.Role
	if role == nil || *role < requiredRole {
		return ErrPermissionDenied
	}
	return nil
}
// withTransaction runs fn inside a database transaction opened on s.db with ctx.
//
// If fn returns an error the transaction is rolled back and that error is
// returned unchanged. If fn panics the transaction is rolled back and the panic
// is re-raised. Failures to begin or commit return an error matching
// ErrTransactionFailed (via errors.Is) that also carries the underlying cause.
//
// NOTE(review): the closures passed in by this file call s.userRepo directly
// and ignore tx. Unless the repository is bound to the same transaction, their
// writes are not actually transactional — confirm.
func (s *userService) withTransaction(ctx context.Context, fn func(tx *gorm.DB) error) error {
	tx := s.db.WithContext(ctx).Begin()
	if tx.Error != nil {
		return errors.Join(ErrTransactionFailed, tx.Error)
	}

	defer func() {
		if r := recover(); r != nil {
			tx.Rollback()
			panic(r)
		}
	}()

	if err := fn(tx); err != nil {
		tx.Rollback()
		return err
	}

	if err := tx.Commit().Error; err != nil {
		return errors.Join(ErrTransactionFailed, err)
	}
	return nil
}
// CreateUser stores user with a bcrypt-hashed password.
//
// When user.Password is empty, a password satisfying s.pwdPolicy is generated
// first. On success user.Password holds the hash, not the plaintext.
//
// It returns:
//   - ErrInvalidUserInput for a nil user;
//   - ErrUserExists if the username is already taken;
//   - any repository or bcrypt error otherwise.
//
// NOTE(review): an auto-generated plaintext password is overwritten by its hash
// and never returned, so nobody learns it — confirm this is intended.
func (s *userService) CreateUser(ctx context.Context, user *model.User) error {
	if user == nil {
		return ErrInvalidUserInput
	}
	if user.Password == "" {
		user.Password = GeneratePassword(s.pwdPolicy)
	}

	// bcrypt is deliberately slow; hash before opening the transaction so it is
	// not held open for the duration of the hash.
	hashedPassword, err := s.hashPassword(user.Password)
	if err != nil {
		return err
	}

	return s.withTransaction(ctx, func(tx *gorm.DB) error {
		// Reject duplicate usernames up front so callers get ErrUserExists
		// rather than a driver-specific constraint error.
		existing, lookupErr := s.userRepo.GetByUsername(user.Username)
		switch {
		case lookupErr == nil && existing != nil:
			return ErrUserExists
		case lookupErr != nil && !errors.Is(lookupErr, gorm.ErrRecordNotFound):
			return lookupErr
		}

		user.Password = hashedPassword
		return s.userRepo.Create(user)
	})
}
// GetUser returns the user with the given ID, with the Password field cleared.
//
// It returns:
//   - ErrInvalidUserInput for id <= 0;
//   - ErrUserNotFound if no such user exists (matching GetUserByUsername);
//   - any other repository error unchanged.
func (s *userService) GetUser(ctx context.Context, id int64) (*model.User, error) {
	if id <= 0 {
		return nil, ErrInvalidUserInput
	}

	user, err := s.userRepo.GetByID(id)
	if err != nil {
		if errors.Is(err, gorm.ErrRecordNotFound) {
			return nil, ErrUserNotFound
		}
		return nil, err
	}

	// Never expose the password hash to callers.
	user.Password = ""
	return user, nil
}
// GetUserByUsername looks a user up by username and returns it with the
// Password field cleared.
//
// An empty username yields ErrInvalidUserInput and a missing record yields
// ErrUserNotFound. Other repository errors are returned unchanged.
func (s *userService) GetUserByUsername(ctx context.Context, username string) (*model.User, error) {
	if len(username) == 0 {
		return nil, ErrInvalidUserInput
	}

	found, lookupErr := s.userRepo.GetByUsername(username)
	switch {
	case errors.Is(lookupErr, gorm.ErrRecordNotFound):
		return nil, ErrUserNotFound
	case lookupErr != nil:
		return nil, lookupErr
	}

	// Strip the password hash before handing the record out.
	found.Password = ""
	return found, nil
}
// UpdateUser saves user's fields, keeping the stored password hash and setting
// UpdatedAt to the current Unix time.
//
// It returns:
//   - ErrInvalidUserInput for a nil user or a non-positive ID;
//   - ErrUserNotFound if the user does not exist;
//   - ErrUserExists if the new username belongs to another user;
//   - any other repository error.
//
// NOTE(review): operatorID is unused and no permission check happens here —
// confirm callers enforce authorization (e.g. for Role changes).
func (s *userService) UpdateUser(ctx context.Context, user *model.User, operatorID int64) error {
	if user == nil || user.ID <= 0 {
		return ErrInvalidUserInput
	}

	return s.withTransaction(ctx, func(tx *gorm.DB) error {
		existingUser, err := s.userRepo.GetByID(user.ID)
		if err != nil {
			if errors.Is(err, gorm.ErrRecordNotFound) {
				return ErrUserNotFound
			}
			return err
		}

		// A rename must not collide with another user's username.
		if user.Username != existingUser.Username {
			other, lookupErr := s.userRepo.GetByUsername(user.Username)
			switch {
			case lookupErr == nil:
				if other != nil && other.ID != user.ID {
					return ErrUserExists
				}
			case !errors.Is(lookupErr, gorm.ErrRecordNotFound):
				// A failed lookup must not be mistaken for "name is free".
				return lookupErr
			}
		}

		// Password changes go through ChangePassword/ResetPassword only.
		user.Password = existingUser.Password
		user.UpdatedAt = time.Now().Unix()
		return s.userRepo.Update(user)
	})
}
// ChangePassword replaces userID's password with newPassword after verifying
// oldPassword against the stored bcrypt hash.
//
// It returns:
//   - the ValidatePassword error if newPassword violates the policy;
//   - ErrUserNotFound if the user does not exist;
//   - ErrInvalidPassword if oldPassword is wrong;
//   - any other repository or bcrypt error.
func (s *userService) ChangePassword(ctx context.Context, userID int64, oldPassword, newPassword string) error {
	if err := s.ValidatePassword(newPassword); err != nil {
		return err
	}

	return s.withTransaction(ctx, func(tx *gorm.DB) error {
		user, err := s.userRepo.GetByID(userID)
		if err != nil {
			// Only a missing row means "not found"; surface real DB failures.
			if errors.Is(err, gorm.ErrRecordNotFound) {
				return ErrUserNotFound
			}
			return err
		}

		// NOTE(review): a wrong old password is reported as ErrInvalidPassword,
		// whose message reads "invalid password format"; kept for compatibility.
		if !s.comparePasswords(user.Password, oldPassword) {
			return ErrInvalidPassword
		}

		hashedPassword, err := s.hashPassword(newPassword)
		if err != nil {
			return err
		}
		user.Password = hashedPassword
		user.UpdatedAt = time.Now().Unix()
		return s.userRepo.Update(user)
	})
}
// ResetPassword replaces userID's password with a freshly generated random one
// and stores its bcrypt hash.
//
// It returns ErrUserNotFound if the user does not exist; repository and bcrypt
// errors are returned unchanged.
//
// NOTE(review): the generated plaintext is discarded (see the TODO below), so
// after a reset nobody knows the new password. operatorID is unused and no
// permission check is performed here — confirm callers enforce authorization.
func (s *userService) ResetPassword(ctx context.Context, userID int64, operatorID int64) error {
	return s.withTransaction(ctx, func(tx *gorm.DB) error {
		user, err := s.userRepo.GetByID(userID)
		if err != nil {
			// Only a missing row means "not found"; surface real DB failures.
			if errors.Is(err, gorm.ErrRecordNotFound) {
				return ErrUserNotFound
			}
			return err
		}

		newPassword := generateRandomPassword()
		hashedPassword, err := s.hashPassword(newPassword)
		if err != nil {
			return err
		}
		user.Password = hashedPassword
		user.UpdatedAt = time.Now().Unix()
		// TODO: deliver newPassword to the user (e.g. by e-mail).
		return s.userRepo.Update(user)
	})
}
// ListUsers returns one page of users.
//
// A negative limit falls back to 20 and a negative offset to 0. When active is
// non-empty it is split on commas and passed as an "active in ?" filter.
// Otherwise no filter is passed. The repository's total count is discarded.
func (s *userService) ListUsers(ctx context.Context, limit, offset int, active string) ([]model.User, error) {
	if offset < 0 {
		offset = 0
	}
	if limit < 0 {
		limit = 20
	}

	if active == "" {
		page, _, listErr := s.userRepo.List(limit, offset, nil)
		return page, listErr
	}

	filter := map[string]interface{}{"active in ?": strings.Split(active, ",")}
	page, _, listErr := s.userRepo.List(limit, offset, filter)
	return page, listErr
}
// cryptoSource adapts crypto/rand to the golang.org/x/exp/rand Source
// interface, so the *rand.Rand helpers (Intn, Shuffle) draw from the operating
// system's CSPRNG rather than a seeded pseudo-random sequence.
type cryptoSource struct{}

// Uint64 returns 64 bits read from crypto/rand, assembled big-endian. It panics
// if the system random source fails: silently weakening password randomness
// would be worse than failing loudly.
func (cryptoSource) Uint64() uint64 {
	var buf [8]byte
	if _, err := crand.Read(buf[:]); err != nil {
		panic("service: crypto/rand read failed: " + err.Error())
	}
	var v uint64
	for _, b := range buf {
		v = v<<8 | uint64(b)
	}
	return v
}

// Seed is a no-op; a CSPRNG-backed source cannot be seeded.
func (cryptoSource) Seed(uint64) {}

// generateRandomPassword returns a 16-character password. It contains at least
// one lower-case letter, one upper-case letter, one digit and one of the
// symbols !@#$%^&*. The remaining 12 characters are drawn uniformly from the
// union of those sets, and the result is shuffled.
//
// Randomness comes from crypto/rand. The package-level golang.org/x/exp/rand
// functions are deliberately not used: their default source is deterministic
// unless explicitly seeded, which would make reset passwords predictable.
func generateRandomPassword() string {
	const (
		lowerChars   = "abcdefghijklmnopqrstuvwxyz"
		upperChars   = "ABCDEFGHIJKLMNOPQRSTUVWXYZ"
		numberChars  = "0123456789"
		specialChars = "!@#$%^&*"
		allChars     = lowerChars + upperChars + numberChars + specialChars
		length       = 16
	)

	rng := rand.New(cryptoSource{})
	pick := func(set string) byte { return set[rng.Intn(len(set))] }

	// One character from every class guarantees each class is present.
	password := make([]byte, 0, length)
	password = append(password, pick(lowerChars), pick(upperChars), pick(numberChars), pick(specialChars))
	for len(password) < length {
		password = append(password, pick(allChars))
	}

	// Shuffle so the guaranteed characters do not occupy fixed positions.
	rng.Shuffle(len(password), func(i, j int) {
		password[i], password[j] = password[j], password[i]
	})
	return string(password)
}
// DeleteUser deletes the user with the given id.
//
// The caller identified via ctx must hold at least consts.RoleAdmin. An
// operator cannot delete their own account, and accounts whose role is admin or
// higher cannot be deleted.
//
// It returns:
//   - ErrInvalidUserInput for id <= 0;
//   - ErrInvalidOperation for self-deletion;
//   - ErrUserNotFound if id does not exist;
//   - ErrPermissionDenied for insufficient rights or a protected target.
func (s *userService) DeleteUser(ctx context.Context, id int64, operatorID int64) error {
	if id <= 0 {
		return ErrInvalidUserInput
	}
	if err := s.CheckPermission(ctx, consts.RoleAdmin); err != nil {
		return err
	}
	if id == operatorID {
		return ErrInvalidOperation
	}

	return s.withTransaction(ctx, func(tx *gorm.DB) error {
		user, err := s.userRepo.GetByID(id)
		if err != nil {
			if errors.Is(err, gorm.ErrRecordNotFound) {
				return ErrUserNotFound
			}
			return err
		}

		// Roles are ordered (CheckPermission compares them with <), so protect
		// every role at or above admin. A nil role is treated as unprivileged
		// instead of panicking.
		if user.Role != nil && *user.Role >= consts.RoleAdmin {
			return ErrPermissionDenied
		}
		return s.userRepo.Delete(id)
	})
}
// EnableUser 启用用户
// func (s *userService) EnableUser(ctx context.Context, id int64, operatorID int64) error {
// // 检查参数
// if id <= 0 {
// return ErrInvalidUserInput
// }
// // 检查操作者权限
// if err := s.CheckPermission(ctx, consts.RoleAdmin); err != nil {
// return err
// }
// return s.withTransaction(ctx, func(tx *gorm.DB) error {
// // 检查用户是否存在
// user, err := s.userRepo.GetByID(id)
// if err != nil {
// return ErrUserNotFound
// }
// // 如果用户已经是启用状态,返回成功
// if user.Status == consts.StatusEnabled {
// return nil
// }
// return s.userRepo.Enable(id)
// })
// }
// DisableUser 禁用用户
// func (s *userService) DisableUser(ctx context.Context, id int64, operatorID int64) error {
// // 检查参数
// if id <= 0 {
// return ErrInvalidUserInput
// }
// // 检查操作者权限
// if err := s.CheckPermission(ctx, consts.RoleAdmin); err != nil {
// return err
// }
// // 不允许禁用自己
// if id == operatorID {
// return ErrInvalidOperation
// }
// return s.withTransaction(ctx, func(tx *gorm.DB) error {
// // 检查用户是否存在
// user, err := s.userRepo.GetByID(id)
// if err != nil {
// return ErrUserNotFound
// }
// // 检查是否试图禁用超级管理员
// if user.Role == consts.RoleAdmin {
// return ErrPermissionDenied
// }
// // 如果用户已经是禁用状态,返回成功
// if user.Status == consts.StatusDisabled {
// return nil
// }
// return s.userRepo.Disable(id)
// })
// }
// BatchEnableUsers enables every user in ids. The caller identified via ctx
// must hold at least consts.RoleAdmin.
//
// Users whose Active flag is already true are skipped. If nothing remains to be
// enabled, the repository is not called.
//
// It returns:
//   - ErrInvalidUserInput for an empty ids;
//   - ErrUserNotFound if any id does not exist;
//   - any other repository error.
func (s *userService) BatchEnableUsers(ctx context.Context, ids []int64, operatorID int64) error {
	if len(ids) == 0 {
		return ErrInvalidUserInput
	}
	if err := s.CheckPermission(ctx, consts.RoleAdmin); err != nil {
		return err
	}

	return s.withTransaction(ctx, func(tx *gorm.DB) error {
		// Verify every user exists and remember which are already enabled.
		alreadyEnabled := make(map[int64]bool, len(ids))
		for _, id := range ids {
			user, err := s.userRepo.GetByID(id)
			if err != nil {
				if errors.Is(err, gorm.ErrRecordNotFound) {
					return ErrUserNotFound
				}
				return err
			}
			// A nil Active flag is treated as "not enabled" rather than panicking.
			if user.Active != nil && *user.Active {
				alreadyEnabled[id] = true
			}
		}

		toEnable := make([]int64, 0, len(ids))
		for _, id := range ids {
			if !alreadyEnabled[id] {
				toEnable = append(toEnable, id)
			}
		}
		if len(toEnable) == 0 {
			return nil
		}
		return s.userRepo.BatchEnable(toEnable, nil)
	})
}
// BatchDisableUsers disables every user in ids. The caller identified via ctx
// must hold at least consts.RoleAdmin. ids must not contain the operator, and
// no user with an admin (or higher) role may be disabled.
//
// Users whose Active flag is already false are skipped. If nothing remains to
// be disabled, the repository is not called.
//
// It returns:
//   - ErrInvalidUserInput for an empty ids;
//   - ErrInvalidOperation if ids includes operatorID;
//   - ErrUserNotFound if any id does not exist;
//   - ErrPermissionDenied for insufficient rights or a protected target.
func (s *userService) BatchDisableUsers(ctx context.Context, ids []int64, operatorID int64) error {
	if len(ids) == 0 {
		return ErrInvalidUserInput
	}
	if err := s.CheckPermission(ctx, consts.RoleAdmin); err != nil {
		return err
	}
	if contains(ids, operatorID) {
		return ErrInvalidOperation
	}

	return s.withTransaction(ctx, func(tx *gorm.DB) error {
		// Verify every user exists, is not protected, and note who is already disabled.
		alreadyDisabled := make(map[int64]bool, len(ids))
		for _, id := range ids {
			user, err := s.userRepo.GetByID(id)
			if err != nil {
				if errors.Is(err, gorm.ErrRecordNotFound) {
					return ErrUserNotFound
				}
				return err
			}
			// Roles are ordered (CheckPermission compares them with <); protect
			// admin and above. A nil role is treated as unprivileged.
			if user.Role != nil && *user.Role >= consts.RoleAdmin {
				return ErrPermissionDenied
			}
			// Read the same Active flag that BatchEnableUsers checks and
			// ListUsers filters on.
			if user.Active != nil && !*user.Active {
				alreadyDisabled[id] = true
			}
		}

		toDisable := make([]int64, 0, len(ids))
		for _, id := range ids {
			if !alreadyDisabled[id] {
				toDisable = append(toDisable, id)
			}
		}
		if len(toDisable) == 0 {
			return nil
		}
		return s.userRepo.BatchDisable(toDisable, nil)
	})
}
// BatchDeleteUsers deletes every user in ids. The caller identified via ctx
// must hold at least consts.RoleAdmin. ids must not contain the operator, and
// no user with an admin (or higher) role may be deleted.
//
// It returns:
//   - ErrInvalidUserInput for an empty ids;
//   - ErrInvalidOperation if ids includes operatorID;
//   - ErrUserNotFound if any id does not exist;
//   - ErrPermissionDenied for insufficient rights or a protected target.
func (s *userService) BatchDeleteUsers(ctx context.Context, ids []int64, operatorID int64) error {
	if len(ids) == 0 {
		return ErrInvalidUserInput
	}
	if err := s.CheckPermission(ctx, consts.RoleAdmin); err != nil {
		return err
	}
	if contains(ids, operatorID) {
		return ErrInvalidOperation
	}

	return s.withTransaction(ctx, func(tx *gorm.DB) error {
		// Verify every user exists and none of them is protected.
		for _, id := range ids {
			user, err := s.userRepo.GetByID(id)
			if err != nil {
				if errors.Is(err, gorm.ErrRecordNotFound) {
					return ErrUserNotFound
				}
				return err
			}
			// Roles are ordered (CheckPermission compares them with <); protect
			// admin and above. A nil role is treated as unprivileged.
			if user.Role != nil && *user.Role >= consts.RoleAdmin {
				return ErrPermissionDenied
			}
		}
		return s.userRepo.BatchDelete(ids, nil)
	})
}
// contains reports whether item occurs anywhere in slice. A nil or empty slice
// contains nothing.
func contains[T comparable](slice []T, item T) bool {
	found := false
	for i := 0; i < len(slice) && !found; i++ {
		found = slice[i] == item
	}
	return found
}