package service

import (
	"encoding/base64"
	"fmt"
	"net/http"
	"strconv"
	"strings"
	"time"

	"opencatd-open/internal/model"
	"opencatd-open/pkg/config"
	"opencatd-open/pkg/store"

	"github.com/go-webauthn/webauthn/protocol"
	"github.com/go-webauthn/webauthn/webauthn"
	"github.com/mileusna/useragent"
	"gorm.io/gorm"
)

var _ webauthn.User = (*WebAuthnUser)(nil)

// WebAuthnUser adapts a model.User and its stored passkeys to the
// webauthn.User interface.
type WebAuthnUser struct {
	User        *model.User
	Credentials []webauthn.Credential
}

// WebAuthnID returns the WebAuthn user handle: the decimal string form of the
// user's database ID. GetWebAuthnUser parses the handle back with
// strconv.ParseInt, so the two encodings must stay in sync.
func (u *WebAuthnUser) WebAuthnID() []byte {
	// FormatInt avoids the truncation strconv.Itoa(int(...)) suffers on 32-bit builds.
	return []byte(strconv.FormatInt(int64(u.User.ID), 10))
}

// WebAuthnName returns the user's login name (model.User.Username).
func (u *WebAuthnUser) WebAuthnName() string {
	return u.User.Username
}

// WebAuthnDisplayName returns the user's display name (model.User.Name).
func (u *WebAuthnUser) WebAuthnDisplayName() string {
	return u.User.Name
}

// WebAuthnCredentials returns all credentials loaded for the user.
func (u *WebAuthnUser) WebAuthnCredentials() []webauthn.Credential {
	return u.Credentials
}

// WebAuthnCredentialDescriptors returns a descriptor for every credential the
// user owns, e.g. for use as an exclusion or allow list.
func (u *WebAuthnUser) WebAuthnCredentialDescriptors() (descriptors []protocol.CredentialDescriptor) {
	credentials := u.WebAuthnCredentials()
	descriptors = make([]protocol.CredentialDescriptor, len(credentials))
	for i, credential := range credentials {
		descriptors[i] = credential.Descriptor()
	}
	return descriptors
}

// WebAuthnService provides passkey registration, login and management.
type WebAuthnService struct {
	cfg      *config.Config
	DB       *gorm.DB
	WebAuthn *webauthn.WebAuthn
	// Sessions holds in-flight registration/login ceremony data. Registration
	// sessions are keyed by the decimal user ID; login sessions by challenge.
	Sessions *store.WebAuthnSessionStore
}

// NewWebAuthnService builds a WebAuthnService from cfg, configured to require
// discoverable (resident-key) credentials so users can log in without first
// entering a username.
func NewWebAuthnService(cfg *config.Config, db *gorm.DB) (*WebAuthnService, error) {
	wconfig := &webauthn.Config{
		RPDisplayName: cfg.AppName,    // Relying Party display name
		RPID:          cfg.RPID,       // Relying Party ID (usually the domain)
		RPOrigins:     cfg.RPOrigins,  // allowed Relying Party origins (URLs)
		AuthenticatorSelection: protocol.AuthenticatorSelection{
			RequireResidentKey:      protocol.ResidentKeyRequired(),           // authenticator must store the user handle
			ResidentKey:             protocol.ResidentKeyRequirementRequired, // discoverable-credential mode
			UserVerification:        protocol.VerificationPreferred,          // user verification preferred, not required
			AuthenticatorAttachment: "",                                      // allow platform and cross-platform authenticators
		},
	}

	wa, err := webauthn.New(wconfig)
	if err != nil {
		return nil, err
	}

	return &WebAuthnService{
		cfg:      cfg,
		DB:       db,
		WebAuthn: wa,
		Sessions: store.NewWebAuthnSessionStore(),
	}, nil
}

// GetUserWithCredentials loads the user with the given ID together with its
// passkeys, decoding each stored passkey (base64 columns) into a
// webauthn.Credential. It returns the GORM error if the user does not exist,
// or a wrapped error if any stored passkey field fails to decode.
func (s *WebAuthnService) GetUserWithCredentials(userID int64) (*WebAuthnUser, error) {
	var user model.User
	if err := s.DB.Model(&model.User{}).Preload("Passkeys").First(&user, userID).Error; err != nil {
		return nil, err
	}

	passkeys := user.Passkeys
	credentials := make([]webauthn.Credential, len(passkeys))
	for i, pk := range passkeys {
		credentialIDBytes, err := base64.StdEncoding.DecodeString(pk.CredentialID)
		if err != nil {
			return nil, fmt.Errorf("failed to decode CredentialID: %w", err)
		}

		publicKeyBytes, err := base64.StdEncoding.DecodeString(pk.PublicKey)
		if err != nil {
			return nil, fmt.Errorf("failed to decode PublicKey: %w", err)
		}

		aaguidBytes, err := base64.StdEncoding.DecodeString(pk.AAGUID)
		if err != nil {
			return nil, fmt.Errorf("failed to decode AAGUID: %w", err)
		}

		// Only a single transport is persisted per passkey (see FinishRegistration).
		var transport []protocol.AuthenticatorTransport
		if pk.Transport != "" {
			transport = []protocol.AuthenticatorTransport{protocol.AuthenticatorTransport(pk.Transport)}
		}

		credentials[i] = webauthn.Credential{
			ID:              credentialIDBytes,
			PublicKey:       publicKeyBytes,
			AttestationType: pk.AttestationType,
			Transport:       transport,
			Flags: webauthn.CredentialFlags{
				UserPresent:    true,
				UserVerified:   true,
				BackupEligible: pk.BackupEligible,
				BackupState:    pk.BackupState,
			},
			Authenticator: webauthn.Authenticator{
				AAGUID:       aaguidBytes,
				SignCount:    pk.SignCount,
				CloneWarning: false,
			},
		}
	}

	return &WebAuthnUser{
		User:        &user,
		Credentials: credentials,
	}, nil
}

// BeginRegistration starts a passkey registration ceremony for userID and
// returns the creation options to send to the client. The session data is
// stored under the decimal user ID for FinishRegistration to pick up.
func (s *WebAuthnService) BeginRegistration(userID int64) (*protocol.CredentialCreation, error) {
	user, err := s.GetUserWithCredentials(userID)
	if err != nil {
		return nil, err
	}

	// NOTE(review): webauthn.WithExclusions(user.WebAuthnCredentialDescriptors())
	// would stop the same authenticator from being registered twice; it is not
	// passed at the moment.
	options, sessionData, err := s.WebAuthn.BeginRegistration(user)
	if err != nil {
		return nil, err
	}

	s.Sessions.SaveWebauthnSession(strconv.FormatInt(userID, 10), sessionData)
	return options, nil
}

// FinishRegistration completes the registration ceremony started by
// BeginRegistration, verifies the client's response and persists the new
// passkey under deviceName. The pending session is consumed whether or not
// verification succeeds.
func (s *WebAuthnService) FinishRegistration(userID int64, response *http.Request, deviceName string) (*model.Passkey, error) {
	user, err := s.GetUserWithCredentials(userID)
	if err != nil {
		return nil, err
	}

	sessionKey := strconv.FormatInt(userID, 10)
	sessionData, err := s.Sessions.GetWebauthnSession(sessionKey)
	if err != nil {
		return nil, err
	}
	// Sessions are single-use: drop it before verification so it cannot be replayed.
	s.Sessions.DeleteWebauthnSession(sessionKey)

	credential, err := s.WebAuthn.FinishRegistration(user, *sessionData, response)
	if err != nil {
		return nil, err
	}

	ua := useragent.Parse(response.UserAgent())
	// Collapse whitespace so empty user-agent fields don't leave doubled spaces.
	deviceType := strings.Join(strings.Fields(fmt.Sprintf("%s %s %s %s %s",
		ua.Device, ua.OS, ua.OSVersionNoFull(), ua.Name, ua.VersionNoFull())), " ")

	var transport string
	if len(credential.Transport) > 0 {
		transport = string(credential.Transport[0]) // the model stores only one transport
	}

	passkey := &model.Passkey{
		UserID:          userID,
		CredentialID:    base64.StdEncoding.EncodeToString(credential.ID),
		PublicKey:       base64.StdEncoding.EncodeToString(credential.PublicKey),
		AttestationType: string(credential.AttestationType),
		AAGUID:          base64.StdEncoding.EncodeToString(credential.Authenticator.AAGUID),
		SignCount:       credential.Authenticator.SignCount,
		Name:            deviceName,
		DeviceType:      deviceType,
		LastUsedAt:      time.Now().Unix(),
		BackupEligible:  credential.Flags.BackupEligible,
		BackupState:     credential.Flags.BackupState,
		Transport:       transport,
	}

	if err := s.DB.Create(passkey).Error; err != nil {
		return nil, err
	}
	return passkey, nil
}

// BeginLogin starts a discoverable (usernameless) login ceremony for an
// unauthenticated client. The session is stored under its challenge, which
// the caller must pass back to FinishLogin.
func (s *WebAuthnService) BeginLogin() (*protocol.CredentialAssertion, error) {
	// No user is specified: the authenticator chooses which credential to use.
	options, session, err := s.WebAuthn.BeginDiscoverableLogin(
		webauthn.WithUserVerification(protocol.VerificationPreferred),
	)
	if err != nil {
		return nil, err
	}

	s.Sessions.SaveWebauthnSession(session.Challenge, session)
	return options, nil
}

// FinishLogin completes the discoverable login identified by challenge,
// resolving the user from the assertion's user handle. On success it persists
// the authenticator's new signature counter, backup state and last-used time,
// then returns the authenticated user. The session is consumed whether or not
// verification succeeds.
func (s *WebAuthnService) FinishLogin(challenge string, response *http.Request) (*WebAuthnUser, error) {
	sessionData, err := s.Sessions.GetWebauthnSession(challenge)
	if err != nil {
		return nil, err
	}
	s.Sessions.DeleteWebauthnSession(challenge)

	var user *WebAuthnUser
	credential, err := s.WebAuthn.FinishDiscoverableLogin(s.GetWebAuthnUser(&user), *sessionData, response)
	if err != nil {
		return nil, err
	}
	if user == nil {
		return nil, fmt.Errorf("webauthn: discoverable login resolved no user")
	}

	// NOTE(review): credential.Authenticator.CloneWarning is set by the library
	// when the counter did not increase; whether to reject such logins is a
	// policy decision left to the caller.
	if err := s.recordPasskeyUse(user, credential); err != nil {
		return nil, fmt.Errorf("updating passkey after login: %w", err)
	}
	return user, nil
}

// recordPasskeyUse writes the post-assertion authenticator state back to the
// passkey row matching user and credential.ID, so the next login's counter
// check compares against the latest value.
func (s *WebAuthnService) recordPasskeyUse(user *WebAuthnUser, credential *webauthn.Credential) error {
	return s.DB.Model(&model.Passkey{}).
		Where("user_id = ? AND credential_id = ?", user.User.ID, base64.StdEncoding.EncodeToString(credential.ID)).
		Updates(map[string]any{
			"sign_count":   credential.Authenticator.SignCount,
			"backup_state": credential.Flags.BackupState,
			"last_used_at": time.Now().Unix(),
		}).Error
}

// GetWebAuthnUser returns a DiscoverableUserHandler that resolves a user from
// the decimal user handle produced by WebAuthnUser.WebAuthnID. On success the
// resolved user is also stored in *wau so the caller can retrieve it after the
// ceremony.
func (s *WebAuthnService) GetWebAuthnUser(wau **WebAuthnUser) webauthn.DiscoverableUserHandler {
	return func(rawID, userHandle []byte) (webauthn.User, error) {
		userID, err := strconv.ParseInt(string(userHandle), 10, 64)
		if err != nil {
			return nil, err
		}
		u, err := s.GetUserWithCredentials(userID)
		if err != nil {
			// Return an untyped nil: a typed nil *WebAuthnUser would make a non-nil interface.
			return nil, err
		}
		*wau = u
		return u, nil
	}
}

// ListPasskeys returns all passkeys belonging to userID.
func (s *WebAuthnService) ListPasskeys(userID int64) ([]model.Passkey, error) {
	var passkeys []model.Passkey
	if err := s.DB.Where("user_id = ?", userID).Find(&passkeys).Error; err != nil {
		return nil, err
	}
	return passkeys, nil
}

// DeletePasskey deletes the passkey passkeyID, but only if it belongs to
// userID. Deleting a non-matching ID is not reported as an error.
func (s *WebAuthnService) DeletePasskey(userID int64, passkeyID int64) error {
	return s.DB.Where("id = ? AND user_id = ?", passkeyID, userID).Delete(&model.Passkey{}).Error
}