reface to openteam
This commit is contained in:
92
internal/service/apikey.go
Normal file
92
internal/service/apikey.go
Normal file
@@ -0,0 +1,92 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"opencatd-open/internal/dao"
|
||||
"opencatd-open/internal/model"
|
||||
"opencatd-open/internal/utils"
|
||||
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
type ApiKeyServiceImpl struct {
|
||||
db *gorm.DB
|
||||
apiKeyRepo dao.ApiKeyRepository
|
||||
}
|
||||
|
||||
func NewApiKeyService(db *gorm.DB, apiKeyDao dao.ApiKeyRepository) *ApiKeyServiceImpl {
|
||||
return &ApiKeyServiceImpl{db: db, apiKeyRepo: apiKeyDao}
|
||||
}
|
||||
|
||||
func (s *ApiKeyServiceImpl) CreateApiKey(ctx context.Context, apikey *model.ApiKey) error {
|
||||
return s.apiKeyRepo.Create(apikey)
|
||||
}
|
||||
|
||||
func (s *ApiKeyServiceImpl) GetApiKey(ctx context.Context, id int64) (*model.ApiKey, error) {
|
||||
return s.apiKeyRepo.GetByID(id)
|
||||
}
|
||||
|
||||
func (s *ApiKeyServiceImpl) ListApiKey(ctx context.Context, limit, offset int, active []string) ([]*model.ApiKey, int64, error) {
|
||||
var conditions = make(map[string]interface{})
|
||||
if len(active) > 0 {
|
||||
conditions["active IN ?"] = utils.StringToBool(active)
|
||||
}
|
||||
return s.apiKeyRepo.ListWithFilters(limit, offset, conditions)
|
||||
}
|
||||
|
||||
func (s *ApiKeyServiceImpl) UpdateApiKey(ctx context.Context, apikey *model.ApiKey) error {
|
||||
_key, err := s.apiKeyRepo.GetByID(apikey.ID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("get apikey failed: %v", err)
|
||||
}
|
||||
if apikey.ApiKey != nil {
|
||||
_key.ApiKey = apikey.ApiKey
|
||||
}
|
||||
if apikey.Active != nil {
|
||||
_key.Active = apikey.Active
|
||||
}
|
||||
if apikey.Endpoint != nil {
|
||||
_key.Endpoint = apikey.Endpoint
|
||||
}
|
||||
if apikey.ResourceNmae != nil {
|
||||
_key.ResourceNmae = apikey.ResourceNmae
|
||||
}
|
||||
if apikey.DeploymentName != nil {
|
||||
_key.DeploymentName = apikey.DeploymentName
|
||||
}
|
||||
if apikey.AccessKey != nil {
|
||||
_key.AccessKey = apikey.AccessKey
|
||||
}
|
||||
if apikey.SecretKey != nil {
|
||||
_key.SecretKey = apikey.SecretKey
|
||||
}
|
||||
if apikey.ModelAlias != nil {
|
||||
_key.ModelAlias = apikey.ModelAlias
|
||||
}
|
||||
if apikey.ModelPrefix != nil {
|
||||
_key.ModelPrefix = apikey.ModelPrefix
|
||||
}
|
||||
if apikey.Parameters != nil {
|
||||
_key.Parameters = apikey.Parameters
|
||||
}
|
||||
if apikey.SupportModels != nil {
|
||||
_key.SupportModels = apikey.SupportModels
|
||||
}
|
||||
if apikey.SupportModelsArray != nil {
|
||||
_key.SupportModelsArray = apikey.SupportModelsArray
|
||||
}
|
||||
|
||||
return s.apiKeyRepo.Update(apikey)
|
||||
}
|
||||
|
||||
func (s *ApiKeyServiceImpl) DeleteApiKey(ctx context.Context, ids []int64) error {
|
||||
return s.apiKeyRepo.BatchDelete(ids)
|
||||
}
|
||||
|
||||
func (s *ApiKeyServiceImpl) EnableApiKey(ctx context.Context, ids []int64) error {
|
||||
return s.apiKeyRepo.BatchEnable(ids)
|
||||
}
|
||||
func (s *ApiKeyServiceImpl) DisableApiKey(ctx context.Context, ids []int64) error {
|
||||
return s.apiKeyRepo.BatchDisable(ids)
|
||||
}
|
||||
150
internal/service/team/apikey.go
Normal file
150
internal/service/team/apikey.go
Normal file
@@ -0,0 +1,150 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"opencatd-open/internal/dao"
|
||||
"opencatd-open/internal/model"
|
||||
"time"
|
||||
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
var _ ApiKeyService = (*ApiKeyServiceImpl)(nil)
|
||||
|
||||
type ApiKeyService interface {
|
||||
Create(apiKey *model.ApiKey) error
|
||||
GetByID(id int64) (*model.ApiKey, error)
|
||||
GetByName(name string) (*model.ApiKey, error)
|
||||
GetByApiKey(apiKeyValue string) (*model.ApiKey, error)
|
||||
Update(apiKey *model.ApiKey) error
|
||||
Delete(id int64) error
|
||||
List(limit, offset int, status string) ([]*model.ApiKey, error)
|
||||
ListWithFilters(limit, offset int, filters map[string]interface{}) ([]*model.ApiKey, int64, error)
|
||||
BatchEnable(ids []int64) error
|
||||
BatchDisable(ids []int64) error
|
||||
BatchDelete(ids []int64) error
|
||||
Count() (int64, error)
|
||||
}
|
||||
|
||||
type ApiKeyServiceImpl struct {
|
||||
db *gorm.DB
|
||||
apiKeyRepo dao.ApiKeyRepository
|
||||
}
|
||||
|
||||
func NewApiKeyService(db *gorm.DB, apiKeyDao dao.ApiKeyRepository) ApiKeyService {
|
||||
return &ApiKeyServiceImpl{apiKeyRepo: apiKeyDao, db: db}
|
||||
}
|
||||
|
||||
func (s *ApiKeyServiceImpl) Create(apiKey *model.ApiKey) error {
|
||||
if apiKey == nil {
|
||||
return errors.New("apiKey不能为空")
|
||||
}
|
||||
if apiKey.Name == nil {
|
||||
return errors.New("apiKey名称不能为空")
|
||||
}
|
||||
if apiKey.ApiKey == nil {
|
||||
return errors.New("apiKey值不能为空")
|
||||
}
|
||||
apiKey.CreatedAt = time.Now().Unix()
|
||||
apiKey.UpdatedAt = time.Now().Unix()
|
||||
|
||||
return s.apiKeyRepo.Create(apiKey)
|
||||
}
|
||||
|
||||
func (s *ApiKeyServiceImpl) GetByID(id int64) (*model.ApiKey, error) {
|
||||
if id <= 0 {
|
||||
return nil, errors.New("id 必须大于 0")
|
||||
}
|
||||
return s.apiKeyRepo.GetByID(id)
|
||||
}
|
||||
|
||||
func (s *ApiKeyServiceImpl) GetByName(name string) (*model.ApiKey, error) {
|
||||
if name == "" {
|
||||
return nil, errors.New("name 不能为空")
|
||||
}
|
||||
return s.apiKeyRepo.GetByName(name)
|
||||
}
|
||||
|
||||
func (s *ApiKeyServiceImpl) GetByApiKey(apiKeyValue string) (*model.ApiKey, error) {
|
||||
if apiKeyValue == "" {
|
||||
return nil, errors.New("apiKeyValue 不能为空")
|
||||
}
|
||||
return s.apiKeyRepo.GetByApiKey(apiKeyValue)
|
||||
}
|
||||
|
||||
func (s *ApiKeyServiceImpl) Update(apiKey *model.ApiKey) error {
|
||||
if apiKey == nil {
|
||||
return errors.New("apiKey不能为空")
|
||||
}
|
||||
if apiKey.ID <= 0 {
|
||||
return errors.New("apiKey ID 必须大于 0")
|
||||
}
|
||||
return s.apiKeyRepo.Update(apiKey)
|
||||
}
|
||||
|
||||
func (s *ApiKeyServiceImpl) Delete(id int64) error {
|
||||
if id <= 0 {
|
||||
return errors.New("id 必须大于 0")
|
||||
}
|
||||
return s.apiKeyRepo.BatchDelete([]int64{id})
|
||||
}
|
||||
|
||||
func (s *ApiKeyServiceImpl) List(offset, limit int, status string) ([]*model.ApiKey, error) {
|
||||
if offset < 0 {
|
||||
offset = 0
|
||||
}
|
||||
if limit <= 0 {
|
||||
limit = 20 // 设置默认值
|
||||
}
|
||||
return s.apiKeyRepo.List(offset, limit, status)
|
||||
}
|
||||
|
||||
func (s *ApiKeyServiceImpl) ListWithFilters(offset, limit int, filters map[string]interface{}) ([]*model.ApiKey, int64, error) {
|
||||
if offset < 0 {
|
||||
offset = 0
|
||||
}
|
||||
if limit <= 0 {
|
||||
limit = 20 // 设置默认值
|
||||
}
|
||||
|
||||
return s.apiKeyRepo.ListWithFilters(offset, limit, filters)
|
||||
}
|
||||
|
||||
func (s *ApiKeyServiceImpl) Enable(id int64) error {
|
||||
if id <= 0 {
|
||||
return errors.New("id 必须大于 0")
|
||||
}
|
||||
return s.apiKeyRepo.BatchEnable([]int64{id})
|
||||
}
|
||||
|
||||
func (s *ApiKeyServiceImpl) Disable(id int64) error {
|
||||
if id <= 0 {
|
||||
return errors.New("id 必须大于 0")
|
||||
}
|
||||
return s.apiKeyRepo.BatchDisable([]int64{id})
|
||||
}
|
||||
|
||||
func (s *ApiKeyServiceImpl) BatchEnable(ids []int64) error {
|
||||
if len(ids) == 0 {
|
||||
return errors.New("ids 不能为空")
|
||||
}
|
||||
return s.apiKeyRepo.BatchEnable(ids)
|
||||
}
|
||||
|
||||
func (s *ApiKeyServiceImpl) BatchDisable(ids []int64) error {
|
||||
if len(ids) == 0 {
|
||||
return errors.New("ids 不能为空")
|
||||
}
|
||||
return s.apiKeyRepo.BatchDisable(ids)
|
||||
}
|
||||
|
||||
func (s *ApiKeyServiceImpl) BatchDelete(ids []int64) error {
|
||||
if len(ids) == 0 {
|
||||
return errors.New("ids 不能为空")
|
||||
}
|
||||
return s.apiKeyRepo.BatchDelete(ids)
|
||||
}
|
||||
|
||||
func (s *ApiKeyServiceImpl) Count() (int64, error) {
|
||||
return s.apiKeyRepo.Count()
|
||||
}
|
||||
77
internal/service/team/token.go
Normal file
77
internal/service/team/token.go
Normal file
@@ -0,0 +1,77 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"opencatd-open/internal/dao"
|
||||
"opencatd-open/internal/model"
|
||||
"strings"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
// 确保 TokenService 实现了 TokenServiceInterface 接口
|
||||
var _ TokenService = (*TokenServiceImpl)(nil)
|
||||
|
||||
type TokenService interface {
|
||||
Create(ctx context.Context, token *model.Token) error
|
||||
GetByID(ctx context.Context, id int64) (*model.Token, error)
|
||||
GetByKey(ctx context.Context, key string) (*model.Token, error)
|
||||
GetByUserID(ctx context.Context, userID int64) (*model.Token, error)
|
||||
Update(ctx context.Context, token *model.Token) error
|
||||
UpdateWithCondition(ctx context.Context, token *model.Token, filters map[string]interface{}, updates map[string]interface{}) error
|
||||
Delete(ctx context.Context, id int64) error
|
||||
Lists(ctx context.Context, limit, offset int) ([]*model.Token, int64, error)
|
||||
Disable(ctx context.Context, id int) error
|
||||
Enable(ctx context.Context, id int) error
|
||||
}
|
||||
|
||||
type TokenServiceImpl struct {
|
||||
tokenRepo dao.TokenRepository
|
||||
}
|
||||
|
||||
func NewTokenService(tokenRepo dao.TokenRepository) TokenService {
|
||||
return &TokenServiceImpl{tokenRepo: tokenRepo}
|
||||
}
|
||||
|
||||
func (s *TokenServiceImpl) Create(ctx context.Context, token *model.Token) error {
|
||||
if token.Key == "" {
|
||||
token.Key = "sk-team-" + strings.ReplaceAll(uuid.New().String(), "-", "")
|
||||
}
|
||||
return s.tokenRepo.Create(ctx, token)
|
||||
}
|
||||
|
||||
func (s *TokenServiceImpl) GetByID(ctx context.Context, id int64) (*model.Token, error) {
|
||||
return s.tokenRepo.GetByID(ctx, id)
|
||||
}
|
||||
|
||||
func (s *TokenServiceImpl) GetByKey(ctx context.Context, key string) (*model.Token, error) {
|
||||
return s.tokenRepo.GetByKey(ctx, key)
|
||||
}
|
||||
|
||||
func (s *TokenServiceImpl) GetByUserID(ctx context.Context, userID int64) (*model.Token, error) {
|
||||
return s.tokenRepo.GetByUserID(ctx, userID)
|
||||
}
|
||||
|
||||
func (s *TokenServiceImpl) Update(ctx context.Context, token *model.Token) error {
|
||||
return s.tokenRepo.Update(ctx, token)
|
||||
}
|
||||
|
||||
func (s *TokenServiceImpl) UpdateWithCondition(ctx context.Context, token *model.Token, filters map[string]interface{}, updates map[string]interface{}) error {
|
||||
return s.tokenRepo.UpdateWithCondition(ctx, token, filters, updates)
|
||||
}
|
||||
|
||||
func (s *TokenServiceImpl) Delete(ctx context.Context, id int64) error {
|
||||
return s.tokenRepo.Delete(ctx, id, nil)
|
||||
}
|
||||
|
||||
func (s *TokenServiceImpl) Lists(ctx context.Context, limit, offset int) ([]*model.Token, int64, error) {
|
||||
return s.tokenRepo.ListWithFilters(ctx, limit, offset, nil)
|
||||
}
|
||||
|
||||
func (s *TokenServiceImpl) Disable(ctx context.Context, id int) error {
|
||||
return s.tokenRepo.Disable(ctx, id)
|
||||
}
|
||||
|
||||
func (s *TokenServiceImpl) Enable(ctx context.Context, id int) error {
|
||||
return s.tokenRepo.Enable(ctx, id)
|
||||
}
|
||||
62
internal/service/team/usage.go
Normal file
62
internal/service/team/usage.go
Normal file
@@ -0,0 +1,62 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"opencatd-open/internal/dao"
|
||||
dto "opencatd-open/internal/dto/team"
|
||||
"opencatd-open/internal/model"
|
||||
"opencatd-open/pkg/config"
|
||||
"time"
|
||||
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
var _ UsageService = (*usageService)(nil)
|
||||
|
||||
type UsageService interface {
|
||||
ListByUserID(ctx context.Context, userID int64, limit, offset int) ([]*model.Usage, error)
|
||||
ListByCapability(ctx context.Context, capability string, limit, offset int) ([]*model.Usage, error)
|
||||
ListByDateRange(ctx context.Context, start, end time.Time, filters map[string]interface{}) ([]*dto.UsageInfo, error)
|
||||
|
||||
Delete(ctx context.Context, id int64) error
|
||||
}
|
||||
|
||||
type usageService struct {
|
||||
ctx context.Context
|
||||
cfg *config.Config
|
||||
db *gorm.DB
|
||||
|
||||
usageDAO dao.UsageRepository
|
||||
dailyUsageDAO dao.DailyUsageRepository
|
||||
}
|
||||
|
||||
func NewUsageService(ctx context.Context, cfg *config.Config, db *gorm.DB, usageRepo dao.UsageRepository, dailyUsageRepo dao.DailyUsageRepository) UsageService {
|
||||
srv := &usageService{
|
||||
ctx: ctx,
|
||||
cfg: cfg,
|
||||
db: db,
|
||||
|
||||
usageDAO: usageRepo,
|
||||
dailyUsageDAO: dailyUsageRepo,
|
||||
}
|
||||
|
||||
// 启动异步处理goroutine
|
||||
|
||||
return srv
|
||||
}
|
||||
|
||||
func (s *usageService) ListByUserID(ctx context.Context, userID int64, limit int, offset int) ([]*model.Usage, error) {
|
||||
return s.usageDAO.ListByUserID(ctx, userID, limit, offset)
|
||||
}
|
||||
|
||||
func (s *usageService) ListByCapability(ctx context.Context, capability string, limit, offset int) ([]*model.Usage, error) {
|
||||
return s.usageDAO.ListByCapability(ctx, capability, limit, offset)
|
||||
}
|
||||
|
||||
func (s *usageService) ListByDateRange(ctx context.Context, start, end time.Time, filters map[string]interface{}) ([]*dto.UsageInfo, error) {
|
||||
return s.dailyUsageDAO.StatUserUsages(ctx, start, end, filters)
|
||||
}
|
||||
|
||||
func (s *usageService) Delete(ctx context.Context, id int64) error {
|
||||
return s.usageDAO.Delete(ctx, id)
|
||||
}
|
||||
679
internal/service/team/user.go
Normal file
679
internal/service/team/user.go
Normal file
@@ -0,0 +1,679 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"opencatd-open/internal/consts"
|
||||
"opencatd-open/internal/dao"
|
||||
"opencatd-open/internal/model"
|
||||
"regexp"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
"golang.org/x/exp/rand"
|
||||
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrUserNotFound = errors.New("user not found")
|
||||
ErrInvalidUserInput = errors.New("invalid user input")
|
||||
ErrUserExists = errors.New("user already exists")
|
||||
ErrInvalidPassword = errors.New("invalid password format")
|
||||
ErrPermissionDenied = errors.New("permission denied")
|
||||
ErrInvalidOperation = errors.New("invalid operation")
|
||||
ErrTransactionFailed = errors.New("transaction failed")
|
||||
)
|
||||
|
||||
// PasswordPolicy 定义密码策略
|
||||
type PasswordPolicy struct {
|
||||
MinLength int
|
||||
MaxLength int
|
||||
NeedNumber bool
|
||||
NeedUpper bool
|
||||
NeedLower bool
|
||||
NeedSymbol bool
|
||||
}
|
||||
|
||||
// 生成随机数字
|
||||
func generateNumber() string {
|
||||
rand.Seed(uint64(time.Now().UnixNano()))
|
||||
return string('0' + rand.Intn(10))
|
||||
}
|
||||
|
||||
// 生成随机大写字母
|
||||
func generateUpper() string {
|
||||
rand.Seed(uint64(time.Now().UnixNano()))
|
||||
return string('A' + rand.Intn(26))
|
||||
}
|
||||
|
||||
// 生成随机小写字母
|
||||
func generateLower() string {
|
||||
rand.Seed(uint64(time.Now().UnixNano()))
|
||||
return string('a' + rand.Intn(26))
|
||||
}
|
||||
|
||||
// 生成随机特殊符号
|
||||
func generateSymbol() string {
|
||||
rand.Seed(uint64(time.Now().UnixNano()))
|
||||
symbols := "!@#$%^&*"
|
||||
return string(symbols[rand.Intn(len(symbols))])
|
||||
}
|
||||
|
||||
// GeneratePassword 根据密码策略生成密码
|
||||
func GeneratePassword(policy PasswordPolicy) string {
|
||||
rand.Seed(uint64(time.Now().UnixNano()))
|
||||
|
||||
// 确保满足所有必须的字符类型
|
||||
var password string
|
||||
if policy.NeedNumber {
|
||||
password += generateNumber()
|
||||
}
|
||||
if policy.NeedUpper {
|
||||
password += generateUpper()
|
||||
}
|
||||
if policy.NeedLower {
|
||||
password += generateLower()
|
||||
}
|
||||
if policy.NeedSymbol {
|
||||
password += generateSymbol()
|
||||
}
|
||||
|
||||
// 计算还需要多少个字符
|
||||
remainingLength := policy.MinLength - len(password)
|
||||
if remainingLength < 0 {
|
||||
remainingLength = 0
|
||||
}
|
||||
// 剩余长度随机生成密码字符
|
||||
for i := 0; i < remainingLength; i++ {
|
||||
randType := rand.Intn(4) // 0:数字, 1:大写, 2:小写, 3:符号
|
||||
switch randType {
|
||||
case 0:
|
||||
password += generateNumber()
|
||||
case 1:
|
||||
password += generateUpper()
|
||||
case 2:
|
||||
password += generateLower()
|
||||
case 3:
|
||||
password += generateSymbol()
|
||||
}
|
||||
}
|
||||
|
||||
// 如果密码长度超过最大值,则截断
|
||||
if len(password) > policy.MaxLength {
|
||||
password = password[:policy.MaxLength]
|
||||
}
|
||||
|
||||
// 将密码打乱
|
||||
passwordRune := []rune(password)
|
||||
rand.Shuffle(len(passwordRune), func(i, j int) {
|
||||
passwordRune[i], passwordRune[j] = passwordRune[j], passwordRune[i]
|
||||
})
|
||||
|
||||
return string(passwordRune)
|
||||
}
|
||||
|
||||
var _ UserService = (*userService)(nil)
|
||||
|
||||
// UserService 定义用户服务的接口
|
||||
type UserService interface {
|
||||
CreateUser(ctx context.Context, user *model.User) error
|
||||
GetUser(ctx context.Context, id int64) (*model.User, error)
|
||||
GetUserByUsername(ctx context.Context, username string) (*model.User, error)
|
||||
UpdateUser(ctx context.Context, user *model.User, operatorID int64) error
|
||||
DeleteUser(ctx context.Context, id int64, operatorID int64) error
|
||||
ListUsers(ctx context.Context, limit, offset int, active string) ([]model.User, error)
|
||||
// EnableUser(ctx context.Context, id int64, operatorID int64) error
|
||||
// DisableUser(ctx context.Context, id int64, operatorID int64) error
|
||||
BatchEnableUsers(ctx context.Context, ids []int64, operatorID int64) error
|
||||
BatchDisableUsers(ctx context.Context, ids []int64, operatorID int64) error
|
||||
BatchDeleteUsers(ctx context.Context, ids []int64, operatorID int64) error
|
||||
ChangePassword(ctx context.Context, userID int64, oldPassword, newPassword string) error
|
||||
ResetPassword(ctx context.Context, userID int64, operatorID int64) error
|
||||
ValidatePassword(password string) error
|
||||
CheckPermission(ctx context.Context, requiredRole consts.UserRole) error
|
||||
}
|
||||
|
||||
// userService 实现 UserService 接口
|
||||
type userService struct {
|
||||
userRepo dao.UserRepository
|
||||
db *gorm.DB
|
||||
pwdPolicy PasswordPolicy
|
||||
}
|
||||
|
||||
// NewUserService 创建 UserService 实例
|
||||
func NewUserService(db *gorm.DB, userRepo dao.UserRepository) UserService {
|
||||
return &userService{
|
||||
userRepo: userRepo,
|
||||
db: db,
|
||||
pwdPolicy: PasswordPolicy{
|
||||
MinLength: 8,
|
||||
MaxLength: 32,
|
||||
NeedNumber: true,
|
||||
NeedUpper: true,
|
||||
NeedLower: true,
|
||||
NeedSymbol: true,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// hashPassword 使用 bcrypt 加密密码
|
||||
func (s *userService) hashPassword(password string) (string, error) {
|
||||
hashedBytes, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return string(hashedBytes), nil
|
||||
}
|
||||
|
||||
// comparePasswords 比较密码
|
||||
func (s *userService) comparePasswords(hashedPassword, password string) bool {
|
||||
err := bcrypt.CompareHashAndPassword([]byte(hashedPassword), []byte(password))
|
||||
return err == nil
|
||||
}
|
||||
|
||||
// ValidatePassword 验证密码是否符合策略
|
||||
func (s *userService) ValidatePassword(password string) error {
|
||||
if len(password) < s.pwdPolicy.MinLength || len(password) > s.pwdPolicy.MaxLength {
|
||||
return ErrInvalidPassword
|
||||
}
|
||||
|
||||
if s.pwdPolicy.NeedNumber && !regexp.MustCompile(`[0-9]`).MatchString(password) {
|
||||
return ErrInvalidPassword
|
||||
}
|
||||
|
||||
if s.pwdPolicy.NeedUpper && !regexp.MustCompile(`[A-Z]`).MatchString(password) {
|
||||
return ErrInvalidPassword
|
||||
}
|
||||
|
||||
if s.pwdPolicy.NeedLower && !regexp.MustCompile(`[a-z]`).MatchString(password) {
|
||||
return ErrInvalidPassword
|
||||
}
|
||||
|
||||
if s.pwdPolicy.NeedSymbol && !regexp.MustCompile(`[!@#$%^&*]`).MatchString(password) {
|
||||
return ErrInvalidPassword
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// CheckPermission 检查用户权限
|
||||
func (s *userService) CheckPermission(ctx context.Context, requiredRole consts.UserRole) error {
|
||||
userToken := ctx.Value("Token").(*model.Token)
|
||||
|
||||
// 检查用户角色
|
||||
if *userToken.User.Role < requiredRole {
|
||||
return ErrPermissionDenied
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// withTransaction 事务处理封装
|
||||
func (s *userService) withTransaction(ctx context.Context, fn func(tx *gorm.DB) error) error {
|
||||
tx := s.db.WithContext(ctx).Begin()
|
||||
if tx.Error != nil {
|
||||
return ErrTransactionFailed
|
||||
}
|
||||
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
tx.Rollback()
|
||||
panic(r)
|
||||
}
|
||||
}()
|
||||
|
||||
if err := fn(tx); err != nil {
|
||||
tx.Rollback()
|
||||
return err
|
||||
}
|
||||
|
||||
if err := tx.Commit().Error; err != nil {
|
||||
return ErrTransactionFailed
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// CreateUser 创建用户
|
||||
func (s *userService) CreateUser(ctx context.Context, user *model.User) error {
|
||||
if user == nil {
|
||||
return ErrInvalidUserInput
|
||||
}
|
||||
|
||||
if user.Password == "" {
|
||||
user.Password = GeneratePassword(s.pwdPolicy)
|
||||
}
|
||||
|
||||
// 使用事务处理
|
||||
return s.withTransaction(ctx, func(tx *gorm.DB) error {
|
||||
// 检查用户名是否已存在
|
||||
// _, err := s.userRepo.GetByID(user.ID)
|
||||
// if err != nil {
|
||||
// return err
|
||||
// }
|
||||
|
||||
// 加密密码
|
||||
hashedPassword, err := s.hashPassword(user.Password)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
user.Password = hashedPassword
|
||||
|
||||
return s.userRepo.Create(user)
|
||||
})
|
||||
}
|
||||
|
||||
// GetUser 根据 ID 获取用户
|
||||
func (s *userService) GetUser(ctx context.Context, id int64) (*model.User, error) {
|
||||
if id <= 0 {
|
||||
return nil, ErrInvalidUserInput
|
||||
}
|
||||
|
||||
user, err := s.userRepo.GetByID(id)
|
||||
if err != nil {
|
||||
return nil, err // 返回其他数据库错误
|
||||
}
|
||||
// 处理返回结果,清除敏感信息
|
||||
user.Password = "" // 清除密码信息
|
||||
|
||||
return user, nil
|
||||
}
|
||||
|
||||
// GetUserByUsername 根据用户名获取用户
|
||||
func (s *userService) GetUserByUsername(ctx context.Context, username string) (*model.User, error) {
|
||||
if username == "" {
|
||||
return nil, ErrInvalidUserInput
|
||||
}
|
||||
|
||||
user, err := s.userRepo.GetByUsername(username)
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, ErrUserNotFound
|
||||
}
|
||||
return nil, err // 返回其他数据库错误
|
||||
}
|
||||
|
||||
// 处理返回结果,清除敏感信息
|
||||
user.Password = "" // 清除密码信息
|
||||
|
||||
return user, nil
|
||||
}
|
||||
|
||||
// UpdateUser 更新用户信息
|
||||
func (s *userService) UpdateUser(ctx context.Context, user *model.User, operatorID int64) error {
|
||||
if user == nil || user.ID <= 0 {
|
||||
return ErrInvalidUserInput
|
||||
}
|
||||
|
||||
return s.withTransaction(ctx, func(tx *gorm.DB) error {
|
||||
// 检查用户是否存在
|
||||
existingUser, err := s.userRepo.GetByID(user.ID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 如果修改了用户名,检查新用户名是否已存在
|
||||
if user.Username != existingUser.Username {
|
||||
tmpUser, err := s.userRepo.GetByUsername(user.Username)
|
||||
if err == nil && tmpUser != nil && tmpUser.ID != user.ID {
|
||||
return ErrUserExists
|
||||
}
|
||||
}
|
||||
|
||||
// 保持原有密码
|
||||
user.Password = existingUser.Password
|
||||
user.UpdatedAt = time.Now().Unix()
|
||||
|
||||
return s.userRepo.Update(user)
|
||||
})
|
||||
}
|
||||
|
||||
// ChangePassword 修改密码
|
||||
func (s *userService) ChangePassword(ctx context.Context, userID int64, oldPassword, newPassword string) error {
|
||||
// 验证新密码
|
||||
if err := s.ValidatePassword(newPassword); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return s.withTransaction(ctx, func(tx *gorm.DB) error {
|
||||
user, err := s.userRepo.GetByID(userID)
|
||||
if err != nil {
|
||||
return ErrUserNotFound
|
||||
}
|
||||
|
||||
// 验证旧密码
|
||||
if !s.comparePasswords(user.Password, oldPassword) {
|
||||
return ErrInvalidPassword
|
||||
}
|
||||
|
||||
// 加密新密码
|
||||
hashedPassword, err := s.hashPassword(newPassword)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
user.Password = hashedPassword
|
||||
user.UpdatedAt = time.Now().Unix()
|
||||
|
||||
return s.userRepo.Update(user)
|
||||
})
|
||||
}
|
||||
|
||||
// ResetPassword 重置密码
|
||||
func (s *userService) ResetPassword(ctx context.Context, userID int64, operatorID int64) error {
|
||||
return s.withTransaction(ctx, func(tx *gorm.DB) error {
|
||||
user, err := s.userRepo.GetByID(userID)
|
||||
if err != nil {
|
||||
return ErrUserNotFound
|
||||
}
|
||||
|
||||
// 生成随机密码
|
||||
newPassword := generateRandomPassword()
|
||||
hashedPassword, err := s.hashPassword(newPassword)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
user.Password = hashedPassword
|
||||
user.UpdatedAt = time.Now().Unix()
|
||||
|
||||
// TODO: 发送新密码给用户邮箱
|
||||
|
||||
return s.userRepo.Update(user)
|
||||
})
|
||||
}
|
||||
|
||||
// ListUsers 获取用户列表(增加过滤功能)
|
||||
func (s *userService) ListUsers(ctx context.Context, limit, offset int, active string) ([]model.User, error) {
|
||||
if limit < 0 {
|
||||
limit = 20
|
||||
}
|
||||
if offset < 0 {
|
||||
offset = 0
|
||||
}
|
||||
var users []model.User
|
||||
var err error
|
||||
if active != "" {
|
||||
users, _, err = s.userRepo.List(limit, offset, map[string]interface{}{"active in ?": strings.Split(active, ",")})
|
||||
} else {
|
||||
users, _, err = s.userRepo.List(limit, offset, nil)
|
||||
}
|
||||
|
||||
return users, err
|
||||
}
|
||||
|
||||
// generateRandomPassword 生成随机密码
|
||||
func generateRandomPassword() string {
|
||||
const (
|
||||
lowerChars = "abcdefghijklmnopqrstuvwxyz"
|
||||
upperChars = "ABCDEFGHIJKLMNOPQRSTUVWXYZ"
|
||||
numberChars = "0123456789"
|
||||
specialChars = "!@#$%^&*"
|
||||
)
|
||||
// rand.NewSource(uint64(time.Now().UnixNano()))
|
||||
|
||||
// 确保每种字符都至少出现一次
|
||||
password := []string{
|
||||
string(lowerChars[rand.Intn(len(lowerChars))]),
|
||||
string(upperChars[rand.Intn(len(upperChars))]),
|
||||
string(numberChars[rand.Intn(len(numberChars))]),
|
||||
string(specialChars[rand.Intn(len(specialChars))]),
|
||||
}
|
||||
|
||||
// 所有可用字符
|
||||
allChars := lowerChars + upperChars + numberChars + specialChars
|
||||
|
||||
// 生成剩余的12个字符
|
||||
for i := 0; i < 12; i++ {
|
||||
password = append(password, string(allChars[rand.Intn(len(allChars))]))
|
||||
}
|
||||
|
||||
// 打乱密码字符顺序
|
||||
rand.Shuffle(len(password), func(i, j int) {
|
||||
password[i], password[j] = password[j], password[i]
|
||||
})
|
||||
|
||||
return strings.Join(password, "")
|
||||
}
|
||||
|
||||
// DeleteUser 删除用户
|
||||
func (s *userService) DeleteUser(ctx context.Context, id int64, operatorID int64) error {
|
||||
// 检查参数
|
||||
if id <= 0 {
|
||||
return ErrInvalidUserInput
|
||||
}
|
||||
|
||||
if err := s.CheckPermission(ctx, consts.RoleAdmin); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 不允许删除自己
|
||||
if id == operatorID {
|
||||
return ErrInvalidOperation
|
||||
}
|
||||
|
||||
return s.withTransaction(ctx, func(tx *gorm.DB) error {
|
||||
// 检查用户是否存在
|
||||
user, err := s.userRepo.GetByID(id)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 检查是否试图删除管理员
|
||||
if *user.Role == consts.RoleAdmin {
|
||||
return ErrPermissionDenied
|
||||
}
|
||||
|
||||
return s.userRepo.Delete(id)
|
||||
})
|
||||
}
|
||||
|
||||
// EnableUser 启用用户
|
||||
// func (s *userService) EnableUser(ctx context.Context, id int64, operatorID int64) error {
|
||||
// // 检查参数
|
||||
// if id <= 0 {
|
||||
// return ErrInvalidUserInput
|
||||
// }
|
||||
|
||||
// // 检查操作者权限
|
||||
// if err := s.CheckPermission(ctx, consts.RoleAdmin); err != nil {
|
||||
// return err
|
||||
// }
|
||||
|
||||
// return s.withTransaction(ctx, func(tx *gorm.DB) error {
|
||||
// // 检查用户是否存在
|
||||
// user, err := s.userRepo.GetByID(id)
|
||||
// if err != nil {
|
||||
// return ErrUserNotFound
|
||||
// }
|
||||
|
||||
// // 如果用户已经是启用状态,返回成功
|
||||
// if user.Status == consts.StatusEnabled {
|
||||
// return nil
|
||||
// }
|
||||
|
||||
// return s.userRepo.Enable(id)
|
||||
// })
|
||||
// }
|
||||
|
||||
// DisableUser 禁用用户
|
||||
// func (s *userService) DisableUser(ctx context.Context, id int64, operatorID int64) error {
|
||||
// // 检查参数
|
||||
// if id <= 0 {
|
||||
// return ErrInvalidUserInput
|
||||
// }
|
||||
|
||||
// // 检查操作者权限
|
||||
// if err := s.CheckPermission(ctx, consts.RoleAdmin); err != nil {
|
||||
// return err
|
||||
// }
|
||||
|
||||
// // 不允许禁用自己
|
||||
// if id == operatorID {
|
||||
// return ErrInvalidOperation
|
||||
// }
|
||||
|
||||
// return s.withTransaction(ctx, func(tx *gorm.DB) error {
|
||||
// // 检查用户是否存在
|
||||
// user, err := s.userRepo.GetByID(id)
|
||||
// if err != nil {
|
||||
// return ErrUserNotFound
|
||||
// }
|
||||
|
||||
// // 检查是否试图禁用超级管理员
|
||||
// if user.Role == consts.RoleAdmin {
|
||||
// return ErrPermissionDenied
|
||||
// }
|
||||
|
||||
// // 如果用户已经是禁用状态,返回成功
|
||||
// if user.Status == consts.StatusDisabled {
|
||||
// return nil
|
||||
// }
|
||||
|
||||
// return s.userRepo.Disable(id)
|
||||
// })
|
||||
// }
|
||||
|
||||
// BatchEnableUsers 批量启用用户
|
||||
func (s *userService) BatchEnableUsers(ctx context.Context, ids []int64, operatorID int64) error {
|
||||
// 检查参数
|
||||
if len(ids) == 0 {
|
||||
return ErrInvalidUserInput
|
||||
}
|
||||
|
||||
// 检查操作者权限
|
||||
if err := s.CheckPermission(ctx, consts.RoleAdmin); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return s.withTransaction(ctx, func(tx *gorm.DB) error {
|
||||
// 检查所有用户是否存在,并收集当前状态
|
||||
enabledUsers := make([]int64, 0)
|
||||
for _, id := range ids {
|
||||
user, err := s.userRepo.GetByID(id)
|
||||
if err != nil {
|
||||
return ErrUserNotFound
|
||||
}
|
||||
if *user.Active == true {
|
||||
enabledUsers = append(enabledUsers, id)
|
||||
}
|
||||
}
|
||||
|
||||
// 如果所有用户都已经是启用状态,返回成功
|
||||
if len(enabledUsers) == len(ids) {
|
||||
return nil
|
||||
}
|
||||
|
||||
// 过滤掉已经启用的用户,只处理需要启用的用户
|
||||
toEnableIds := make([]int64, 0)
|
||||
for _, id := range ids {
|
||||
if !contains(enabledUsers, id) {
|
||||
toEnableIds = append(toEnableIds, id)
|
||||
}
|
||||
}
|
||||
|
||||
if len(toEnableIds) > 0 {
|
||||
return s.userRepo.BatchEnable(toEnableIds, nil)
|
||||
}
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
// BatchDisableUsers 批量禁用用户
|
||||
func (s *userService) BatchDisableUsers(ctx context.Context, ids []int64, operatorID int64) error {
|
||||
// 检查参数
|
||||
if len(ids) == 0 {
|
||||
return ErrInvalidUserInput
|
||||
}
|
||||
|
||||
// 检查操作者权限
|
||||
if err := s.CheckPermission(ctx, consts.RoleAdmin); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 不允许包含自己
|
||||
if contains(ids, operatorID) {
|
||||
return ErrInvalidOperation
|
||||
}
|
||||
|
||||
return s.withTransaction(ctx, func(tx *gorm.DB) error {
|
||||
// 检查所有用户是否存在
|
||||
disabledUsers := make([]int64, 0)
|
||||
for _, id := range ids {
|
||||
user, err := s.userRepo.GetByID(id)
|
||||
if err != nil {
|
||||
return ErrUserNotFound
|
||||
}
|
||||
// 不允许禁用管理员
|
||||
if *user.Role == consts.RoleAdmin {
|
||||
return ErrPermissionDenied
|
||||
}
|
||||
if user.Status == consts.StatusDisabled {
|
||||
disabledUsers = append(disabledUsers, id)
|
||||
}
|
||||
}
|
||||
|
||||
// 如果所有用户都已经是禁用状态,返回成功
|
||||
if len(disabledUsers) == len(ids) {
|
||||
return nil
|
||||
}
|
||||
|
||||
// 过滤掉已经禁用的用户,只处理需要禁用的用户
|
||||
toDisableIds := make([]int64, 0)
|
||||
for _, id := range ids {
|
||||
if !contains(disabledUsers, id) {
|
||||
toDisableIds = append(toDisableIds, id)
|
||||
}
|
||||
}
|
||||
|
||||
if len(toDisableIds) > 0 {
|
||||
return s.userRepo.BatchDisable(toDisableIds, nil)
|
||||
}
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
// BatchDeleteUsers 批量删除用户
|
||||
func (s *userService) BatchDeleteUsers(ctx context.Context, ids []int64, operatorID int64) error {
|
||||
// 检查参数
|
||||
if len(ids) == 0 {
|
||||
return ErrInvalidUserInput
|
||||
}
|
||||
|
||||
// 检查操作者权限
|
||||
if err := s.CheckPermission(ctx, consts.RoleAdmin); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 不允许包含自己
|
||||
if contains(ids, operatorID) {
|
||||
return ErrInvalidOperation
|
||||
}
|
||||
|
||||
return s.withTransaction(ctx, func(tx *gorm.DB) error {
|
||||
// 检查所有用户是否存在,并确保不会删除管理员
|
||||
for _, id := range ids {
|
||||
user, err := s.userRepo.GetByID(id)
|
||||
if err != nil {
|
||||
return ErrUserNotFound
|
||||
}
|
||||
if *user.Role == consts.RoleAdmin {
|
||||
return ErrPermissionDenied
|
||||
}
|
||||
}
|
||||
|
||||
return s.userRepo.BatchDelete(ids, nil)
|
||||
})
|
||||
}
|
||||
|
||||
// contains 检查切片中是否包含特定值
|
||||
func contains[T comparable](slice []T, item T) bool {
|
||||
for _, s := range slice {
|
||||
if s == item {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
251
internal/service/token.go
Normal file
251
internal/service/token.go
Normal file
@@ -0,0 +1,251 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"opencatd-open/internal/consts"
|
||||
"opencatd-open/internal/dao"
|
||||
"opencatd-open/internal/model"
|
||||
"opencatd-open/internal/utils"
|
||||
"strings"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// var _ TokenService = (*TokenServiceImpl)(nil)
|
||||
|
||||
// type TokenService interface {
|
||||
// }
|
||||
|
||||
type TokenServiceImpl struct {
|
||||
db *gorm.DB
|
||||
tokenRepo dao.TokenRepository
|
||||
}
|
||||
|
||||
func NewTokenService(db *gorm.DB, tokenRepo dao.TokenRepository) *TokenServiceImpl {
|
||||
return &TokenServiceImpl{
|
||||
db: db,
|
||||
tokenRepo: tokenRepo,
|
||||
}
|
||||
}
|
||||
|
||||
func (t *TokenServiceImpl) CreateToken(ctx context.Context, token *model.Token) error {
|
||||
if token.UserID == 0 {
|
||||
token.UserID = ctx.Value("user_id").(int64)
|
||||
}
|
||||
if token.Active == nil {
|
||||
token.Active = utils.ToPtr(true)
|
||||
}
|
||||
if token.UnlimitedQuota == nil {
|
||||
token.UnlimitedQuota = utils.ToPtr(true)
|
||||
}
|
||||
if token.ExpiredAt == nil {
|
||||
token.ExpiredAt = utils.ToPtr(int64(-1))
|
||||
}
|
||||
|
||||
if token.Key == "" {
|
||||
token.Key = "sk-team-" + strings.ReplaceAll(uuid.New().String(), "-", "")
|
||||
}
|
||||
if !strings.HasPrefix(token.Key, "sk-team-") {
|
||||
token.Key = "sk-team-" + strings.ReplaceAll(token.Key, " ", "")
|
||||
}
|
||||
return t.tokenRepo.Create(ctx, token)
|
||||
}
|
||||
|
||||
func (t *TokenServiceImpl) GetToken(ctx context.Context, id int64) (*model.Token, error) {
|
||||
userid := ctx.Value("user_id").(int64)
|
||||
tk := &model.Token{}
|
||||
return tk, t.db.Model(&model.Token{}).Where("user_id = ?", userid).Where("id = ?", id).First(tk).Error
|
||||
}
|
||||
|
||||
func (t *TokenServiceImpl) ListToken(ctx context.Context, limit, offset int, active []string) ([]*model.Token, int64, error) {
|
||||
userid := ctx.Value("user_id").(int64)
|
||||
condition := make(map[string]interface{})
|
||||
condition["user_id = ?"] = userid
|
||||
if len(active) > 0 {
|
||||
condition["active IN ?"] = utils.StringToBool(active)
|
||||
return t.tokenRepo.ListWithFilters(ctx, limit, offset, condition)
|
||||
}
|
||||
return t.tokenRepo.ListWithFilters(ctx, limit, offset, condition)
|
||||
}
|
||||
|
||||
func (t *TokenServiceImpl) UpdateToken(ctx context.Context, token *model.Token) error {
|
||||
userid := ctx.Value("user_id").(int64) // 操作者
|
||||
userRoleValue := ctx.Value("user_role")
|
||||
if userRoleValue == nil {
|
||||
return fmt.Errorf("user role not found in context")
|
||||
}
|
||||
|
||||
role, ok := userRoleValue.(*consts.UserRole) // 操作角色
|
||||
if !ok {
|
||||
return fmt.Errorf("user role in context is not an integer")
|
||||
}
|
||||
|
||||
switch {
|
||||
case *role < consts.RoleAdmin:
|
||||
if userid != token.UserID {
|
||||
return fmt.Errorf("Permission denied")
|
||||
}
|
||||
case *role == consts.RoleAdmin:
|
||||
if *role <= *token.User.Role {
|
||||
return fmt.Errorf("Permission denied")
|
||||
}
|
||||
}
|
||||
|
||||
return t.db.Model(&model.Token{}).Where("id = ?", token.ID).Updates(token).Error
|
||||
}
|
||||
|
||||
func (t *TokenServiceImpl) ResetToken(ctx context.Context, id int64) error {
|
||||
userid := ctx.Value("user_id").(int64) // 操作者
|
||||
userRoleValue := ctx.Value("user_role")
|
||||
if userRoleValue == nil {
|
||||
return fmt.Errorf("user role not found in context")
|
||||
}
|
||||
|
||||
role, ok := userRoleValue.(*consts.UserRole) // 操作角色
|
||||
if !ok {
|
||||
return fmt.Errorf("user role in context is not an integer")
|
||||
}
|
||||
switch {
|
||||
case *role < consts.RoleAdmin:
|
||||
if userid != id {
|
||||
return fmt.Errorf("Permission denied")
|
||||
}
|
||||
case *role == consts.RoleAdmin:
|
||||
var user = &model.User{}
|
||||
if err := t.db.Model(&model.User{}).Where("id = ?", id).First(user).Error; err != nil {
|
||||
return fmt.Errorf("User not found")
|
||||
}
|
||||
if *role <= *user.Role {
|
||||
return fmt.Errorf("Permission denied")
|
||||
}
|
||||
}
|
||||
|
||||
token := "sk-team-" + strings.ReplaceAll(uuid.New().String(), "-", "")
|
||||
return t.db.Model(&model.Token{}).Where("user_id = ?", userid).Where("id = ?", id).Update("token", token).Error
|
||||
}
|
||||
func (t *TokenServiceImpl) DeleteToken(ctx context.Context, id int64) error {
|
||||
token, err := t.tokenRepo.GetByID(ctx, id)
|
||||
if err != nil {
|
||||
return fmt.Errorf("Token not found")
|
||||
}
|
||||
if token.User == nil {
|
||||
return fmt.Errorf("Token user not found")
|
||||
}
|
||||
|
||||
role := ctx.Value("user_role").(*consts.UserRole) // 操作角色
|
||||
userid := ctx.Value("user_id").(int64) // 操作者
|
||||
|
||||
switch {
|
||||
case *role < consts.RoleAdmin:
|
||||
if userid != token.UserID {
|
||||
return fmt.Errorf("Permission denied")
|
||||
}
|
||||
case *role == consts.RoleAdmin:
|
||||
if *role <= *token.User.Role {
|
||||
return fmt.Errorf("Permission denied")
|
||||
}
|
||||
}
|
||||
|
||||
return t.db.Model(&model.Token{}).Where("id = ?", id).Delete(&model.Token{}).Error
|
||||
}
|
||||
|
||||
func (t *TokenServiceImpl) DeleteTokens(ctx context.Context, userid int64, ids []int64) error {
|
||||
operator_id := ctx.Value("user_id").(int64)
|
||||
|
||||
roleValue := ctx.Value("user_role")
|
||||
if roleValue == nil {
|
||||
return fmt.Errorf("user role not found in context")
|
||||
}
|
||||
operator_role, ok := roleValue.(*consts.UserRole) // 操作角色
|
||||
if !ok {
|
||||
return fmt.Errorf("user role in context is not an integer")
|
||||
}
|
||||
|
||||
switch {
|
||||
case *operator_role < consts.RoleAdmin:
|
||||
if operator_id != userid {
|
||||
return fmt.Errorf("Permission denied")
|
||||
}
|
||||
return t.tokenRepo.BatchDelete(ctx, ids, map[string]interface{}{"name != ?": "default", "user_id = ?": userid})
|
||||
case *operator_role == consts.RoleAdmin:
|
||||
var user = &model.User{}
|
||||
if err := t.db.Model(&model.User{}).Where("id = ?", userid).First(user).Error; err != nil {
|
||||
return fmt.Errorf("User not found")
|
||||
}
|
||||
if *operator_role <= *user.Role {
|
||||
return fmt.Errorf("Permission denied")
|
||||
}
|
||||
return t.tokenRepo.BatchDelete(ctx, ids, map[string]interface{}{"name != ?": "default", "user_id = ?": userid})
|
||||
default:
|
||||
return t.tokenRepo.BatchDelete(ctx, ids, map[string]interface{}{"name != ?": "default"})
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func (t *TokenServiceImpl) EnableTokens(ctx context.Context, userid int64, ids []int64) error {
|
||||
operator_id := ctx.Value("user_id").(int64)
|
||||
|
||||
roleValue := ctx.Value("user_role")
|
||||
if roleValue == nil {
|
||||
return fmt.Errorf("user role not found in context")
|
||||
}
|
||||
operator_role, ok := roleValue.(*consts.UserRole) // 操作角色
|
||||
if !ok {
|
||||
return fmt.Errorf("user role in context is not an integer")
|
||||
}
|
||||
|
||||
switch {
|
||||
case *operator_role < consts.RoleAdmin:
|
||||
if operator_id != userid {
|
||||
return fmt.Errorf("Permission denied")
|
||||
}
|
||||
return t.tokenRepo.BatchEnable(ctx, ids, map[string]interface{}{"user_id = ?": userid})
|
||||
case *operator_role == consts.RoleAdmin:
|
||||
var user = &model.User{}
|
||||
if err := t.db.Model(&model.User{}).Where("id = ?", userid).First(user).Error; err != nil {
|
||||
return fmt.Errorf("User not found")
|
||||
}
|
||||
if *operator_role <= *user.Role {
|
||||
return fmt.Errorf("Permission denied")
|
||||
}
|
||||
return t.tokenRepo.BatchEnable(ctx, ids, map[string]interface{}{"user_id = ?": userid})
|
||||
default:
|
||||
return t.tokenRepo.BatchEnable(ctx, ids, nil)
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func (t *TokenServiceImpl) DisableTokens(ctx context.Context, userid int64, ids []int64) error {
|
||||
operator_id := ctx.Value("user_id").(int64)
|
||||
|
||||
roleValue := ctx.Value("user_role")
|
||||
if roleValue == nil {
|
||||
return fmt.Errorf("user role not found in context")
|
||||
}
|
||||
operator_role, ok := roleValue.(*consts.UserRole) // 操作角色
|
||||
if !ok {
|
||||
return fmt.Errorf("user role in context is not an integer")
|
||||
}
|
||||
|
||||
switch {
|
||||
case *operator_role < consts.RoleAdmin:
|
||||
if operator_id != userid {
|
||||
return fmt.Errorf("Permission denied")
|
||||
}
|
||||
return t.tokenRepo.BatchDisable(ctx, ids, map[string]interface{}{"user_id =": userid})
|
||||
case *operator_role == consts.RoleAdmin:
|
||||
var user = &model.User{}
|
||||
if err := t.db.Model(&model.User{}).Where("id = ?", userid).First(user).Error; err != nil {
|
||||
return fmt.Errorf("User not found")
|
||||
}
|
||||
if *operator_role <= *user.Role {
|
||||
return fmt.Errorf("Permission denied")
|
||||
}
|
||||
return t.tokenRepo.BatchDisable(ctx, ids, map[string]interface{}{"user_id =": userid})
|
||||
default:
|
||||
return t.tokenRepo.BatchDisable(ctx, ids, nil)
|
||||
}
|
||||
|
||||
}
|
||||
22
internal/service/usage.go
Normal file
22
internal/service/usage.go
Normal file
@@ -0,0 +1,22 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"opencatd-open/pkg/config"
|
||||
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
type UsageService struct {
|
||||
Ctx context.Context
|
||||
Cfg *config.Config
|
||||
DB *gorm.DB
|
||||
}
|
||||
|
||||
func NewUsageService(ctx context.Context, cfg *config.Config, db *gorm.DB) *UsageService {
|
||||
return &UsageService{
|
||||
Ctx: ctx,
|
||||
Cfg: cfg,
|
||||
DB: db,
|
||||
}
|
||||
}
|
||||
320
internal/service/user.go
Normal file
320
internal/service/user.go
Normal file
@@ -0,0 +1,320 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"opencatd-open/internal/auth"
|
||||
"opencatd-open/internal/consts"
|
||||
"opencatd-open/internal/dao"
|
||||
"opencatd-open/internal/dto"
|
||||
"opencatd-open/internal/model"
|
||||
"opencatd-open/internal/utils"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
type UserServiceImpl struct {
|
||||
db *gorm.DB
|
||||
userRepo dao.UserRepository
|
||||
}
|
||||
|
||||
func NewUserService(db *gorm.DB, userRepo dao.UserRepository) *UserServiceImpl {
|
||||
return &UserServiceImpl{
|
||||
db: db,
|
||||
userRepo: userRepo,
|
||||
}
|
||||
}
|
||||
|
||||
func (s *UserServiceImpl) Register(ctx context.Context, req *model.User) error {
|
||||
var _user model.User
|
||||
var count int64
|
||||
err := s.db.Model(&model.User{}).Count(&count).Error
|
||||
if err != nil {
|
||||
return fmt.Errorf("username or email already exists")
|
||||
}
|
||||
if count == 0 {
|
||||
_user.Name = "root"
|
||||
_user.Role = utils.ToPtr(consts.RoleRoot)
|
||||
_user.Active = utils.ToPtr(true)
|
||||
_user.UnlimitedQuota = utils.ToPtr(true)
|
||||
}
|
||||
_user.Password, err = utils.HashPassword(req.Password)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_user.Username = req.Username
|
||||
_user.Email = req.Email
|
||||
_user.Tokens = []model.Token{
|
||||
{
|
||||
Name: "default",
|
||||
Key: "sk-team-" + strings.ReplaceAll(uuid.New().String(), "-", ""),
|
||||
},
|
||||
}
|
||||
|
||||
return s.userRepo.Create(&_user)
|
||||
}
|
||||
|
||||
func (s *UserServiceImpl) Login(ctx context.Context, req *dto.User) (*dto.Auth, error) {
|
||||
var _user model.User
|
||||
if err := s.db.Model(&model.User{}).Where("username = ?", req.Username).First(&_user).Error; err != nil {
|
||||
if err := s.db.Model(&model.User{}).Where("email = ?", req.Username).First(&_user).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
if utils.CheckPassword(_user.Password, req.Password) {
|
||||
day := 86400
|
||||
at, err := auth.GenerateTokenPair(&_user, consts.SecretKey, time.Duration(day)*time.Second, time.Duration(day*7)*time.Second)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &dto.Auth{
|
||||
Token: at.AccessToken,
|
||||
ExpiresIn: time.Now().Add(time.Duration(day) * time.Second).Unix(),
|
||||
}, nil
|
||||
}
|
||||
return nil, fmt.Errorf("密码错误")
|
||||
}
|
||||
|
||||
func (s *UserServiceImpl) Profile(ctx context.Context) (*model.User, error) {
|
||||
id := ctx.Value("user_id").(int64)
|
||||
return s.userRepo.GetByID(id)
|
||||
}
|
||||
|
||||
func (s *UserServiceImpl) List(ctx context.Context, limit, offset int, active []string) ([]model.User, int64, error) {
|
||||
userRoleValue := ctx.Value("user_role")
|
||||
if userRoleValue == nil {
|
||||
return nil, 0, fmt.Errorf("user role not found in context")
|
||||
}
|
||||
|
||||
role, ok := userRoleValue.(*consts.UserRole)
|
||||
if !ok {
|
||||
return nil, 0, fmt.Errorf("user role in context is not an integer")
|
||||
}
|
||||
|
||||
if *role < consts.RoleAdmin {
|
||||
return nil, 0, fmt.Errorf("Unauthorized")
|
||||
} else if *role < consts.RoleRoot { // 管理员只能查看普通用户
|
||||
var condition = map[string]interface{}{"role = ?": consts.RoleUser}
|
||||
if len(active) > 0 {
|
||||
boolCondition := utils.StringToBool(active)
|
||||
condition["active IN ?"] = boolCondition
|
||||
}
|
||||
return s.userRepo.List(limit, offset, condition)
|
||||
} else {
|
||||
var condition = make(map[string]interface{})
|
||||
if len(active) > 0 {
|
||||
boolCondition := utils.StringToBool(active)
|
||||
condition["active IN ?"] = boolCondition
|
||||
}
|
||||
return s.userRepo.List(limit, offset, condition)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *UserServiceImpl) Create(ctx context.Context, req *model.User) error {
|
||||
userRoleValue := ctx.Value("user_role")
|
||||
if userRoleValue == nil {
|
||||
return fmt.Errorf("user role not found in context")
|
||||
}
|
||||
|
||||
role, ok := userRoleValue.(*consts.UserRole)
|
||||
if !ok {
|
||||
return fmt.Errorf("user role in context is not an integer")
|
||||
}
|
||||
var _user model.User
|
||||
|
||||
if *role < consts.RoleAdmin {
|
||||
return fmt.Errorf("Forbidden")
|
||||
} else if *role < consts.RoleRoot {
|
||||
_user.Role = utils.ToPtr(consts.RoleRoot)
|
||||
} else {
|
||||
_user.Role = req.Role
|
||||
}
|
||||
_user.Username = req.Username
|
||||
_user.Name = req.Name
|
||||
_user.Email = req.Email
|
||||
_user.Active = req.Active
|
||||
_user.Quota = req.Quota
|
||||
_user.UnlimitedQuota = req.UnlimitedQuota
|
||||
_user.Language = req.Language
|
||||
if hashpass, err := utils.HashPassword(req.Password); err != nil {
|
||||
return err
|
||||
} else {
|
||||
_user.Password = hashpass
|
||||
}
|
||||
_user.Tokens = []model.Token{
|
||||
{
|
||||
Name: "default",
|
||||
Key: "sk-team-" + strings.ReplaceAll(uuid.New().String(), "-", ""),
|
||||
},
|
||||
}
|
||||
|
||||
return s.userRepo.Create(&_user)
|
||||
}
|
||||
func (s *UserServiceImpl) GetByID(ctx context.Context, id int64) (*model.User, error) {
|
||||
return s.userRepo.GetByID(id)
|
||||
}
|
||||
|
||||
func (s *UserServiceImpl) Update(ctx context.Context, user *model.User) error {
|
||||
_user := ctx.Value("user").(*model.User) // 被更新的用户
|
||||
if _user == nil {
|
||||
return fmt.Errorf("user not found in context")
|
||||
}
|
||||
userid := ctx.Value("user_id").(int64) // 操作者
|
||||
userRoleValue := ctx.Value("user_role")
|
||||
if userRoleValue == nil {
|
||||
return fmt.Errorf("user role not found in context")
|
||||
}
|
||||
|
||||
role, ok := userRoleValue.(*consts.UserRole) // 操作者角色
|
||||
if !ok {
|
||||
return fmt.Errorf("user role in context is not an integer")
|
||||
}
|
||||
switch {
|
||||
case *role < consts.RoleAdmin:
|
||||
if user.ID != userid {
|
||||
return fmt.Errorf("Permission denied")
|
||||
}
|
||||
case *role == consts.RoleAdmin:
|
||||
if *user.Role > *role { // 更新的用户角色不能高于操作者角色
|
||||
return fmt.Errorf("Permission denied")
|
||||
}
|
||||
if *_user.Role >= *role { // 管理员之间不能被修改
|
||||
return fmt.Errorf("Permission denied")
|
||||
}
|
||||
case *role > consts.RoleAdmin: // 根不能被修改
|
||||
if user.ID == userid {
|
||||
user.Role = role // root不能修改自己的角色
|
||||
} else {
|
||||
if user.Role != nil && user.Role == utils.ToPtr(consts.RoleRoot) {
|
||||
return fmt.Errorf("Root user Only one can exist")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if user.Name != "" {
|
||||
_user.Name = user.Name
|
||||
}
|
||||
if user.Username != "" {
|
||||
_user.Username = user.Username
|
||||
}
|
||||
if user.Email != "" {
|
||||
_user.Email = user.Email
|
||||
_user.EmailVerified = utils.ToPtr(false)
|
||||
}
|
||||
if user.Active != nil {
|
||||
_user.Active = user.Active
|
||||
}
|
||||
if user.Role != nil {
|
||||
_user.Role = user.Role
|
||||
}
|
||||
if user.Active != nil {
|
||||
_user.Active = user.Active
|
||||
}
|
||||
if user.Quota != nil {
|
||||
_user.Quota = user.Quota
|
||||
}
|
||||
if user.UsedQuota != nil {
|
||||
_user.UsedQuota = user.UsedQuota
|
||||
}
|
||||
if user.UnlimitedQuota != nil {
|
||||
_user.UnlimitedQuota = user.UnlimitedQuota
|
||||
}
|
||||
if user.Timezone != "" {
|
||||
_user.Timezone = user.Timezone
|
||||
}
|
||||
if user.Language != "" {
|
||||
_user.Language = user.Language
|
||||
}
|
||||
return s.userRepo.Update(_user)
|
||||
}
|
||||
|
||||
func (s *UserServiceImpl) Delete(ctx context.Context, id int64) error {
|
||||
_user, err := s.userRepo.GetByID(id) // 被更新的用户
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
userid := ctx.Value("user_id").(int64)
|
||||
userRoleValue := ctx.Value("user_role")
|
||||
if userRoleValue == nil {
|
||||
return fmt.Errorf("user role not found in context")
|
||||
}
|
||||
role, ok := userRoleValue.(*consts.UserRole) // 操作者
|
||||
if !ok {
|
||||
return fmt.Errorf("user role in context is not an integer")
|
||||
}
|
||||
|
||||
switch {
|
||||
case *role < consts.RoleAdmin:
|
||||
if _user.ID != userid {
|
||||
return fmt.Errorf("Permission denied")
|
||||
}
|
||||
case *role == consts.RoleAdmin:
|
||||
if *_user.Role >= *role { // 管理员之间不能被修改
|
||||
return fmt.Errorf("Permission denied")
|
||||
}
|
||||
case *_user.Role == consts.RoleRoot: // 根不能被修改
|
||||
return fmt.Errorf("Root user can not be modified")
|
||||
}
|
||||
|
||||
return s.userRepo.Delete(id)
|
||||
}
|
||||
|
||||
func (s *UserServiceImpl) BatchDelete(ctx context.Context, ids []int64) error {
|
||||
userRoleValue := ctx.Value("user_role")
|
||||
if userRoleValue == nil {
|
||||
return fmt.Errorf("user role not found in context")
|
||||
}
|
||||
role, ok := userRoleValue.(*consts.UserRole)
|
||||
if !ok {
|
||||
return fmt.Errorf("user role in context is not an integer")
|
||||
}
|
||||
|
||||
switch {
|
||||
case *role < consts.RoleAdmin:
|
||||
return fmt.Errorf("Unauthorized")
|
||||
case *role == consts.RoleAdmin:
|
||||
return s.userRepo.BatchDelete(ids, []string{fmt.Sprintf("role < %d", role)})
|
||||
}
|
||||
return s.userRepo.BatchDelete(ids, []string{fmt.Sprintf("role < %d", consts.RoleRoot)})
|
||||
}
|
||||
|
||||
func (s *UserServiceImpl) BatchEnable(ctx context.Context, ids []int64) error {
|
||||
userRoleValue := ctx.Value("user_role")
|
||||
if userRoleValue == nil {
|
||||
return fmt.Errorf("user role not found in context")
|
||||
}
|
||||
role, ok := userRoleValue.(*consts.UserRole)
|
||||
if !ok {
|
||||
return fmt.Errorf("user role in context is not an integer")
|
||||
}
|
||||
|
||||
switch {
|
||||
case *role < consts.RoleAdmin:
|
||||
return fmt.Errorf("Unauthorized")
|
||||
case *role == consts.RoleAdmin:
|
||||
return s.userRepo.BatchEnable(ids, []string{fmt.Sprintf("role < %d", role)})
|
||||
}
|
||||
return s.userRepo.BatchEnable(ids, nil)
|
||||
}
|
||||
|
||||
func (s *UserServiceImpl) BatchDisable(ctx context.Context, ids []int64) error {
|
||||
userRoleValue := ctx.Value("user_role")
|
||||
if userRoleValue == nil {
|
||||
return fmt.Errorf("user role not found in context")
|
||||
}
|
||||
role, ok := userRoleValue.(*consts.UserRole)
|
||||
if !ok {
|
||||
return fmt.Errorf("user role in context is not an integer")
|
||||
}
|
||||
|
||||
switch {
|
||||
case *role < consts.RoleAdmin:
|
||||
return fmt.Errorf("Unauthorized")
|
||||
case *role == consts.RoleAdmin:
|
||||
return s.userRepo.BatchDisable(ids, []string{fmt.Sprintf("role < %d", role)})
|
||||
}
|
||||
return s.userRepo.BatchDisable(ids, nil)
|
||||
}
|
||||
304
internal/service/webauth.go
Normal file
304
internal/service/webauth.go
Normal file
@@ -0,0 +1,304 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"opencatd-open/internal/model"
|
||||
"opencatd-open/pkg/config"
|
||||
"opencatd-open/pkg/store"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/go-webauthn/webauthn/protocol"
|
||||
"github.com/go-webauthn/webauthn/webauthn"
|
||||
"github.com/mileusna/useragent"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
var _ webauthn.User = (*WebAuthnUser)(nil)
|
||||
|
||||
// WebAuthnUser 实现webauthn.User接口的结构体
|
||||
type WebAuthnUser struct {
|
||||
User *model.User
|
||||
// ID int64
|
||||
// Name string
|
||||
// DisplayName string
|
||||
Credentials []webauthn.Credential
|
||||
}
|
||||
|
||||
// WebAuthnID 返回用户ID
|
||||
func (u *WebAuthnUser) WebAuthnID() []byte {
|
||||
return []byte(strconv.Itoa(int(u.User.ID)))
|
||||
}
|
||||
|
||||
// WebAuthnName 返回用户名
|
||||
func (u *WebAuthnUser) WebAuthnName() string {
|
||||
return u.User.Username
|
||||
}
|
||||
|
||||
// WebAuthnDisplayName 返回用户显示名
|
||||
func (u *WebAuthnUser) WebAuthnDisplayName() string {
|
||||
return u.User.Name
|
||||
}
|
||||
|
||||
// WebAuthnCredentials 返回用户所有凭证
|
||||
func (u *WebAuthnUser) WebAuthnCredentials() []webauthn.Credential {
|
||||
return u.Credentials
|
||||
}
|
||||
|
||||
func (u *WebAuthnUser) WebAuthnCredentialDescriptors() (descriptors []protocol.CredentialDescriptor) {
|
||||
credentials := u.WebAuthnCredentials()
|
||||
|
||||
descriptors = make([]protocol.CredentialDescriptor, len(credentials))
|
||||
|
||||
for i, credential := range credentials {
|
||||
descriptors[i] = credential.Descriptor()
|
||||
}
|
||||
|
||||
return descriptors
|
||||
}
|
||||
|
||||
// WebAuthnService 提供WebAuthn相关功能
|
||||
type WebAuthnService struct {
|
||||
DB *gorm.DB
|
||||
WebAuthn *webauthn.WebAuthn
|
||||
// Sessions map[string]webauthn.SessionData // 用于存储注册和认证过程中的会话数据
|
||||
Sessions *store.WebAuthnSessionStore
|
||||
}
|
||||
|
||||
// NewWebAuthnService 创建新的WebAuthn服务
|
||||
func NewWebAuthnService(db *gorm.DB, cfg *config.Config) (*WebAuthnService, error) {
|
||||
// 创建WebAuthn配置
|
||||
wconfig := &webauthn.Config{
|
||||
RPDisplayName: config.Cfg.AppName, // 依赖方(Relying Party)显示名称
|
||||
RPID: config.Cfg.Domain, // 依赖方ID(通常为域名)
|
||||
RPOrigins: []string{config.Cfg.AppURL}, // 依赖方源(URL)
|
||||
AuthenticatorSelection: protocol.AuthenticatorSelection{
|
||||
RequireResidentKey: protocol.ResidentKeyRequired(), // 要求认证器存储用户 ID (resident key)
|
||||
ResidentKey: protocol.ResidentKeyRequirementRequired, // 使用 Discoverable 模式
|
||||
UserVerification: protocol.VerificationPreferred, // 推荐用户验证
|
||||
AuthenticatorAttachment: "", // 允许任何认证器 (平台或跨平台)
|
||||
},
|
||||
// EncodeUserIDAsString: true, // 将用户ID编码为字符串
|
||||
}
|
||||
|
||||
wa, err := webauthn.New(wconfig)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &WebAuthnService{
|
||||
DB: db,
|
||||
WebAuthn: wa,
|
||||
// Sessions: make(map[string]webauthn.SessionData),
|
||||
Sessions: store.NewWebAuthnSessionStore(),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// GetUserWithCredentials 获取用户及其凭证
|
||||
func (s *WebAuthnService) GetUserWithCredentials(userID int64) (*WebAuthnUser, error) {
|
||||
var user model.User
|
||||
if err := s.DB.Model(&model.User{}).Preload("Passkeys").First(&user, userID).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 获取用户的所有Passkey
|
||||
passkeys := user.Passkeys
|
||||
|
||||
// 将Passkey转换为webauthn.Credential
|
||||
credentials := make([]webauthn.Credential, len(passkeys))
|
||||
for i, pk := range passkeys {
|
||||
credentialIDBytes, err := base64.StdEncoding.DecodeString(pk.CredentialID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to decode CredentialID: %w", err)
|
||||
}
|
||||
publicKeyBytes, err := base64.StdEncoding.DecodeString(pk.PublicKey)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to decode PublicKey: %w", err)
|
||||
}
|
||||
aaguidBytes, err := base64.StdEncoding.DecodeString(pk.AAGUID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to decode AAGUID: %w", err)
|
||||
}
|
||||
|
||||
var transport []protocol.AuthenticatorTransport
|
||||
if pk.Transport != "" {
|
||||
transport = []protocol.AuthenticatorTransport{protocol.AuthenticatorTransport(pk.Transport)}
|
||||
}
|
||||
|
||||
credentials[i] = webauthn.Credential{
|
||||
ID: credentialIDBytes,
|
||||
PublicKey: publicKeyBytes,
|
||||
AttestationType: pk.AttestationType,
|
||||
Transport: transport,
|
||||
Flags: webauthn.CredentialFlags{
|
||||
UserPresent: true,
|
||||
UserVerified: true,
|
||||
BackupEligible: pk.BackupEligible,
|
||||
BackupState: pk.BackupState,
|
||||
},
|
||||
Authenticator: webauthn.Authenticator{
|
||||
AAGUID: aaguidBytes,
|
||||
SignCount: pk.SignCount,
|
||||
CloneWarning: false,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// 创建WebAuthnUser
|
||||
return &WebAuthnUser{
|
||||
User: &user,
|
||||
Credentials: credentials,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// BeginRegistration 开始注册过程
|
||||
func (s *WebAuthnService) BeginRegistration(userID int64) (*protocol.CredentialCreation, error) {
|
||||
user, err := s.GetUserWithCredentials(userID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 获取注册选项
|
||||
options, sessionData, err := s.WebAuthn.BeginRegistration(user)
|
||||
// webauthn.WithResidentKeyRequirement(protocol.ResidentKeyRequirementRequired),
|
||||
// webauthn.WithExclusions(user.WebAuthnCredentialDescriptors()), // 排除已存在的凭证
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 保存会话数据
|
||||
userid := strconv.Itoa(int(userID))
|
||||
s.Sessions.SaveWebauthnSession(userid, sessionData)
|
||||
|
||||
return options, nil
|
||||
}
|
||||
|
||||
// FinishRegistration 完成注册过程
|
||||
func (s *WebAuthnService) FinishRegistration(userID int64, response *http.Request, deviceName string) (*model.Passkey, error) {
|
||||
user, err := s.GetUserWithCredentials(userID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
userid := strconv.Itoa(int(userID))
|
||||
// 获取并清除会话数据
|
||||
sessionData, err := s.Sessions.GetWebauthnSession(userid)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
s.Sessions.DeleteWebauthnSession(userid)
|
||||
|
||||
// 完成注册
|
||||
credential, err := s.WebAuthn.FinishRegistration(user, *sessionData, response)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
ua := useragent.Parse(response.UserAgent())
|
||||
|
||||
var transport string
|
||||
if len(credential.Transport) > 0 {
|
||||
transport = string(credential.Transport[0]) // 通常只取第一个传输方式
|
||||
}
|
||||
// 创建Passkey记录
|
||||
passkey := &model.Passkey{
|
||||
UserID: userID,
|
||||
CredentialID: base64.StdEncoding.EncodeToString(credential.ID),
|
||||
PublicKey: base64.StdEncoding.EncodeToString(credential.PublicKey),
|
||||
AttestationType: string(credential.AttestationType),
|
||||
AAGUID: base64.StdEncoding.EncodeToString(credential.Authenticator.AAGUID),
|
||||
SignCount: credential.Authenticator.SignCount,
|
||||
Name: deviceName,
|
||||
DeviceType: strings.TrimSpace(fmt.Sprintf("%s %s %s %s %s", ua.Device, ua.OS, ua.OSVersionNoFull(), ua.Name, ua.VersionNoFull())),
|
||||
LastUsedAt: time.Now().Unix(),
|
||||
BackupEligible: credential.Flags.BackupEligible,
|
||||
BackupState: credential.Flags.BackupState,
|
||||
Transport: transport,
|
||||
}
|
||||
|
||||
// 保存Passkey
|
||||
if err := s.DB.Create(passkey).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return passkey, nil
|
||||
}
|
||||
|
||||
// BeginLogin 开始登录过程 (无需用户ID,针对未认证用户)
|
||||
func (s *WebAuthnService) BeginLogin() (*protocol.CredentialAssertion, error) {
|
||||
// 不指定用户ID,让客户端决定使用哪个凭证
|
||||
options, session, err := s.WebAuthn.BeginDiscoverableLogin(
|
||||
webauthn.WithUserVerification(protocol.VerificationPreferred), // 推荐用户验证
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
s.Sessions.SaveWebauthnSession(session.Challenge, session)
|
||||
|
||||
return options, nil
|
||||
}
|
||||
|
||||
// FinishLogin 完成登录过程
|
||||
func (s *WebAuthnService) FinishLogin(challenge string, response *http.Request) (*WebAuthnUser, error) {
|
||||
// 获取并清除会话数据
|
||||
sessionData, err := s.Sessions.GetWebauthnSession(challenge)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
s.Sessions.DeleteWebauthnSession(challenge)
|
||||
|
||||
// 获取相应的用户
|
||||
// var user model.User
|
||||
// if err := s.DB.First(&user, passkey.UserID).Error; err != nil {
|
||||
// return nil, err
|
||||
// }
|
||||
|
||||
// 创建WebAuthnUser
|
||||
// webAuthnUser, err := s.GetUserWithCredentials(user.ID)
|
||||
// if err != nil {
|
||||
// return nil, err
|
||||
// }
|
||||
|
||||
// 完成登录
|
||||
// _, err = s.WebAuthn.FinishLogin(webAuthnUser, sessionData, response)
|
||||
// if err != nil {
|
||||
// return nil, err
|
||||
// }
|
||||
var user *WebAuthnUser
|
||||
_, err = s.WebAuthn.FinishDiscoverableLogin(s.GetWebAuthnUser(&user), *sessionData, response)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// 更新Passkey的LastUsedAt
|
||||
return user, nil
|
||||
}
|
||||
|
||||
func (s *WebAuthnService) GetWebAuthnUser(wau **WebAuthnUser) webauthn.DiscoverableUserHandler {
|
||||
return func(rawID, userHandle []byte) (webauthn.User, error) {
|
||||
userid, err := strconv.ParseInt(string(userHandle), 10, 64)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
*wau, err = s.GetUserWithCredentials(userid)
|
||||
return *wau, err
|
||||
}
|
||||
}
|
||||
|
||||
// ListPasskeys 列出用户所有Passkey
|
||||
func (s *WebAuthnService) ListPasskeys(userID int64) ([]model.Passkey, error) {
|
||||
var passkeys []model.Passkey
|
||||
if err := s.DB.Where("user_id = ?", userID).Find(&passkeys).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return passkeys, nil
|
||||
}
|
||||
|
||||
// DeletePasskey 删除用户Passkey
|
||||
func (s *WebAuthnService) DeletePasskey(userID int64, passkeyID int64) error {
|
||||
return s.DB.Where("id = ? AND user_id = ?", passkeyID, userID).Delete(&model.Passkey{}).Error
|
||||
}
|
||||
Reference in New Issue
Block a user