package service

import (
	"context"
	"fmt"
	"strings"
	"time"

	"github.com/google/uuid"
	"gorm.io/gorm"

	"opencatd-open/internal/auth"
	"opencatd-open/internal/consts"
	"opencatd-open/internal/dao"
	"opencatd-open/internal/dto"
	"opencatd-open/internal/model"
	"opencatd-open/internal/utils"
	"opencatd-open/pkg/config"
)

// UserServiceImpl implements user management (registration, login, CRUD and
// batch enable/disable/delete) on top of a gorm handle and a UserRepository.
type UserServiceImpl struct {
	cfg      *config.Config
	db       *gorm.DB
	userRepo dao.UserRepository
}

// NewUserService builds a UserServiceImpl from its dependencies.
func NewUserService(cfg *config.Config, db *gorm.DB, userRepo dao.UserRepository) *UserServiceImpl {
	return &UserServiceImpl{
		cfg:      cfg,
		db:       db,
		userRepo: userRepo,
	}
}

// operatorRole returns the role of the user performing the request, which the
// request pipeline stores in ctx under "user_role" as a *consts.UserRole.
func (s *UserServiceImpl) operatorRole(ctx context.Context) (*consts.UserRole, error) {
	v := ctx.Value("user_role")
	if v == nil {
		return nil, fmt.Errorf("user role not found in context")
	}
	role, ok := v.(*consts.UserRole)
	if !ok || role == nil {
		return nil, fmt.Errorf("user role in context is not an integer")
	}
	return role, nil
}

// operatorID returns the ID of the user performing the request, stored in ctx
// under "user_id" as an int64. A missing or mistyped value is reported as an
// error instead of panicking.
func (s *UserServiceImpl) operatorID(ctx context.Context) (int64, error) {
	id, ok := ctx.Value("user_id").(int64)
	if !ok {
		return 0, fmt.Errorf("user id not found in context")
	}
	return id, nil
}

// defaultToken returns the "default" API token attached to every new account.
func (s *UserServiceImpl) defaultToken() model.Token {
	return model.Token{
		Name: "default",
		Key:  "sk-team-" + strings.ReplaceAll(uuid.New().String(), "-", ""),
	}
}

// Register creates a new account from req.
//
// The first account in an empty users table becomes the root user: it is
// named "root", always active and has unlimited quota. Later registrations
// are rejected unless cfg.AllowRegister is set, and they take their active
// and unlimited-quota flags from the configuration.
//
// NOTE(review): the "first user becomes root" decision (Count, then Create) is
// not atomic; two concurrent first registrations could both observe count == 0.
func (s *UserServiceImpl) Register(ctx context.Context, req *model.User) error {
	var count int64
	if err := s.db.Model(&model.User{}).Count(&count).Error; err != nil {
		return fmt.Errorf("counting users: %w", err)
	}

	var user model.User
	if count == 0 {
		user.Name = "root"
		user.Role = utils.ToPtr(consts.RoleRoot)
		user.Active = utils.ToPtr(true)
		user.UnlimitedQuota = utils.ToPtr(true)
	} else {
		if !s.cfg.AllowRegister {
			return fmt.Errorf("register is not allowed")
		}
		// Copy the config values so the stored user never aliases s.cfg.
		user.Active = utils.ToPtr(s.cfg.DefaultActive)
		user.UnlimitedQuota = utils.ToPtr(s.cfg.UnlimitedQuota)
	}

	hashed, err := utils.HashPassword(req.Password)
	if err != nil {
		return err
	}
	user.Password = hashed
	user.Username = req.Username
	user.Email = req.Email
	user.Tokens = []model.Token{s.defaultToken()}

	return s.userRepo.Create(&user)
}

// Login authenticates req.Username with req.Password. The identifier is first
// looked up as a username and, if that lookup fails, as an email address. On
// success it returns an access token valid for 24 hours. The refresh token
// generated alongside it is valid for 7 days and is not returned.
func (s *UserServiceImpl) Login(ctx context.Context, req *dto.User) (*dto.Auth, error) {
	const tokenLifetime = 24 * time.Hour

	var user model.User
	if err := s.db.Model(&model.User{}).Where("username = ?", req.Username).First(&user).Error; err != nil {
		if err := s.db.Model(&model.User{}).Where("email = ?", req.Username).First(&user).Error; err != nil {
			return nil, err
		}
	}

	if !utils.CheckPassword(user.Password, req.Password) {
		return nil, fmt.Errorf("密码错误") // "wrong password"
	}

	pair, err := auth.GenerateTokenPair(&user, consts.SecretKey, tokenLifetime, 7*tokenLifetime)
	if err != nil {
		return nil, err
	}
	return &dto.Auth{
		Token: pair.AccessToken,
		// Despite its name, ExpiresIn carries the absolute Unix expiry time.
		ExpiresIn: time.Now().Add(tokenLifetime).Unix(),
	}, nil
}

// Profile returns the record of the user performing the request.
func (s *UserServiceImpl) Profile(ctx context.Context) (*model.User, error) {
	id, err := s.operatorID(ctx)
	if err != nil {
		return nil, err
	}
	return s.userRepo.GetByID(id)
}

// List returns a page of users together with the total count.
//
// Only admins and root may list users. Admins see regular users only, while
// root sees everyone. When active is non-empty, its values are converted with
// utils.StringToBool and used as an "active IN ?" filter.
func (s *UserServiceImpl) List(ctx context.Context, limit, offset int, active []string) ([]model.User, int64, error) {
	role, err := s.operatorRole(ctx)
	if err != nil {
		return nil, 0, err
	}
	if *role < consts.RoleAdmin {
		return nil, 0, fmt.Errorf("Unauthorized")
	}

	condition := make(map[string]any)
	if *role < consts.RoleRoot {
		// Admins may only see regular users.
		condition["role = ?"] = consts.RoleUser
	}
	if len(active) > 0 {
		condition["active IN ?"] = utils.StringToBool(active)
	}
	return s.userRepo.List(limit, offset, condition)
}

// Create adds a user on behalf of the operator in ctx. A default API token is
// attached to the new account.
//
// Regular users are rejected. Admins can only create regular users. Root may
// choose the role, but cannot create a second root user.
func (s *UserServiceImpl) Create(ctx context.Context, req *model.User) error {
	role, err := s.operatorRole(ctx)
	if err != nil {
		return err
	}

	var user model.User
	switch {
	case *role < consts.RoleAdmin:
		return fmt.Errorf("Forbidden")
	case *role < consts.RoleRoot:
		// Admins may only create regular users; granting any higher role here
		// would let an admin mint privileged accounts.
		user.Role = utils.ToPtr(consts.RoleUser)
	default:
		if req.Role != nil && *req.Role == consts.RoleRoot {
			return fmt.Errorf("Root user Only one can exist")
		}
		user.Role = req.Role
	}

	user.Username = req.Username
	user.Name = req.Name
	user.Email = req.Email
	user.Active = req.Active
	user.Quota = req.Quota
	user.UnlimitedQuota = req.UnlimitedQuota
	user.Language = req.Language

	hashed, err := utils.HashPassword(req.Password)
	if err != nil {
		return err
	}
	user.Password = hashed
	user.Tokens = []model.Token{s.defaultToken()}

	return s.userRepo.Create(&user)
}

// GetByID returns the user with the given id.
func (s *UserServiceImpl) GetByID(ctx context.Context, id int64) (*model.User, error) {
	return s.userRepo.GetByID(id)
}

// Update applies the non-empty fields of user to the target user stored in ctx
// under "user", then persists the target.
//
// Permission rules, based on the operator role in ctx:
//   - Regular users may update only themselves, and only profile fields (name,
//     username, email, timezone, language). Role, activation and quota fields
//     in the request are ignored for them.
//   - Admins may not modify admins or root, nor grant a role above their own.
//   - Root cannot change its own role and cannot promote anyone else to root.
func (s *UserServiceImpl) Update(ctx context.Context, user *model.User) error {
	target, ok := ctx.Value("user").(*model.User) // the user being updated
	if !ok || target == nil {
		return fmt.Errorf("user not found in context")
	}
	operatorID, err := s.operatorID(ctx)
	if err != nil {
		return err
	}
	role, err := s.operatorRole(ctx)
	if err != nil {
		return err
	}

	switch {
	case *role < consts.RoleAdmin:
		// Compare against the record that will actually be saved, not the ID
		// supplied in the request body.
		if target.ID != operatorID {
			return fmt.Errorf("Permission denied")
		}
	case *role == consts.RoleAdmin:
		// The requested role may not exceed the operator's own role.
		if user.Role != nil && *user.Role > *role {
			return fmt.Errorf("Permission denied")
		}
		// Admins cannot modify other admins (or anyone above them).
		if target.Role != nil && *target.Role >= *role {
			return fmt.Errorf("Permission denied")
		}
	default: // root
		if target.ID == operatorID {
			user.Role = role // root cannot change its own role
		} else if user.Role != nil && *user.Role == consts.RoleRoot {
			return fmt.Errorf("Root user Only one can exist")
		}
	}

	if user.Name != "" {
		target.Name = user.Name
	}
	if user.Username != "" {
		target.Username = user.Username
	}
	if user.Email != "" && user.Email != target.Email {
		// A changed address must be verified again.
		target.Email = user.Email
		target.EmailVerified = utils.ToPtr(false)
	}
	if user.Timezone != "" {
		target.Timezone = user.Timezone
	}
	if user.Language != "" {
		target.Language = user.Language
	}

	// Role, activation and quota are administrative fields; a regular user
	// editing their own profile must not be able to escalate them.
	if *role >= consts.RoleAdmin {
		if user.Role != nil {
			target.Role = user.Role
		}
		if user.Active != nil {
			target.Active = user.Active
		}
		if user.Quota != nil {
			target.Quota = user.Quota
		}
		if user.UsedQuota != nil {
			target.UsedQuota = user.UsedQuota
		}
		if user.UnlimitedQuota != nil {
			target.UnlimitedQuota = user.UnlimitedQuota
		}
	}

	return s.userRepo.Update(target)
}

// Delete removes the user with the given id.
//
// Regular users may delete only themselves. Admins may not delete admins or
// root. Root can delete anyone except a root user.
func (s *UserServiceImpl) Delete(ctx context.Context, id int64) error {
	operatorID, err := s.operatorID(ctx)
	if err != nil {
		return err
	}
	role, err := s.operatorRole(ctx)
	if err != nil {
		return err
	}
	target, err := s.userRepo.GetByID(id) // the user being deleted
	if err != nil {
		return err
	}

	switch {
	case *role < consts.RoleAdmin:
		if target.ID != operatorID {
			return fmt.Errorf("Permission denied")
		}
	case *role == consts.RoleAdmin:
		// Admins cannot delete other admins (or anyone above them).
		if target.Role != nil && *target.Role >= *role {
			return fmt.Errorf("Permission denied")
		}
	case target.Role != nil && *target.Role == consts.RoleRoot:
		return fmt.Errorf("Root user can not be modified")
	}
	return s.userRepo.Delete(id)
}

// BatchDelete deletes the users in ids. Admins are restricted to users whose
// role is below their own, and root to users below root. Regular users are
// rejected.
func (s *UserServiceImpl) BatchDelete(ctx context.Context, ids []int64) error {
	role, err := s.operatorRole(ctx)
	if err != nil {
		return err
	}
	switch {
	case *role < consts.RoleAdmin:
		return fmt.Errorf("Unauthorized")
	case *role == consts.RoleAdmin:
		// Format the role value, not the pointer: %d on a pointer prints its
		// address and would produce a filter matching every user.
		return s.userRepo.BatchDelete(ids, []string{fmt.Sprintf("role < %d", *role)})
	}
	return s.userRepo.BatchDelete(ids, []string{fmt.Sprintf("role < %d", consts.RoleRoot)})
}

// BatchEnable enables the users in ids. Admins are restricted to users whose
// role is below their own; root is unrestricted. Regular users are rejected.
func (s *UserServiceImpl) BatchEnable(ctx context.Context, ids []int64) error {
	role, err := s.operatorRole(ctx)
	if err != nil {
		return err
	}
	switch {
	case *role < consts.RoleAdmin:
		return fmt.Errorf("Unauthorized")
	case *role == consts.RoleAdmin:
		return s.userRepo.BatchEnable(ids, []string{fmt.Sprintf("role < %d", *role)})
	}
	return s.userRepo.BatchEnable(ids, nil)
}

// BatchDisable disables the users in ids. Admins are restricted to users whose
// role is below their own; root is unrestricted. Regular users are rejected.
func (s *UserServiceImpl) BatchDisable(ctx context.Context, ids []int64) error {
	role, err := s.operatorRole(ctx)
	if err != nil {
		return err
	}
	switch {
	case *role < consts.RoleAdmin:
		return fmt.Errorf("Unauthorized")
	case *role == consts.RoleAdmin:
		return s.userRepo.BatchDisable(ids, []string{fmt.Sprintf("role < %d", *role)})
	}
	return s.userRepo.BatchDisable(ids, nil)
}