Files
2025-04-18 02:47:10 +08:00

318 lines
9.4 KiB
Go
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package service
import (
"encoding/base64"
"fmt"
"net/http"
"opencatd-open/internal/model"
"opencatd-open/pkg/config"
"opencatd-open/pkg/store"
"strconv"
"strings"
"time"
"github.com/go-webauthn/webauthn/protocol"
"github.com/go-webauthn/webauthn/webauthn"
"github.com/mileusna/useragent"
"gorm.io/gorm"
)
// Compile-time assertion that *WebAuthnUser satisfies webauthn.User.
var _ webauthn.User = (*WebAuthnUser)(nil)

// WebAuthnUser adapts a model.User plus its credentials to the
// webauthn.User interface required by the go-webauthn library.
type WebAuthnUser struct {
	// User is the underlying account record.
	User *model.User
	// Credentials holds the user's passkeys in go-webauthn form.
	Credentials []webauthn.Credential
}
// WebAuthnID returns the user handle: the user's numeric ID rendered as a
// decimal string. GetWebAuthnUser parses it back with strconv.ParseInt, so the
// two must stay in sync.
//
// FormatInt is used instead of Itoa(int(...)) so that IDs beyond the 32-bit
// range are not truncated on platforms where int is 32 bits wide.
func (u *WebAuthnUser) WebAuthnID() []byte {
	return []byte(strconv.FormatInt(int64(u.User.ID), 10))
}
// WebAuthnName reports the account's login name (model.User.Username) to the
// WebAuthn library.
func (u *WebAuthnUser) WebAuthnName() string {
	account := u.User
	return account.Username
}
// WebAuthnDisplayName reports the human-friendly name (model.User.Name) that
// authenticators may show during a ceremony.
func (u *WebAuthnUser) WebAuthnDisplayName() string {
	account := u.User
	return account.Name
}
// WebAuthnCredentials exposes every credential attached to this user so the
// library can match an assertion against them.
func (u *WebAuthnUser) WebAuthnCredentials() []webauthn.Credential {
	creds := u.Credentials
	return creds
}
// WebAuthnCredentialDescriptors converts each of the user's credentials into
// a protocol.CredentialDescriptor. The result always has exactly one entry per
// credential, in the same order as WebAuthnCredentials.
func (u *WebAuthnUser) WebAuthnCredentialDescriptors() (descriptors []protocol.CredentialDescriptor) {
	creds := u.WebAuthnCredentials()
	descriptors = make([]protocol.CredentialDescriptor, len(creds))
	// Index into the slice rather than ranging by value, so each (fairly
	// large) Credential is not copied just to build its descriptor.
	for idx := range creds {
		descriptors[idx] = creds[idx].Descriptor()
	}
	return descriptors
}
// WebAuthnService groups the dependencies used for passkey registration and
// login: application config, database handle, the go-webauthn instance, and
// the store that holds session data between the begin and finish steps.
type WebAuthnService struct {
	cfg      *config.Config
	DB       *gorm.DB
	WebAuthn *webauthn.WebAuthn
	// Sessions holds pending ceremony state. Registration keys it by user ID;
	// login keys it by challenge.
	Sessions *store.WebAuthnSessionStore
}
// NewWebAuthnService builds a WebAuthnService for the relying party described
// by cfg (display name, RP ID, allowed origins).
//
// Authenticators are required to create resident (discoverable) credentials,
// which is what lets BeginLogin use a discoverable, user-less login. User
// verification is preferred, not required, and any authenticator attachment
// (platform or cross-platform) is accepted. An error is returned if go-webauthn
// rejects the configuration.
func NewWebAuthnService(cfg *config.Config, db *gorm.DB) (*WebAuthnService, error) {
	selection := protocol.AuthenticatorSelection{
		AuthenticatorAttachment: "", // empty: no restriction on attachment
		RequireResidentKey:      protocol.ResidentKeyRequired(),
		ResidentKey:             protocol.ResidentKeyRequirementRequired,
		UserVerification:        protocol.VerificationPreferred,
	}

	wa, err := webauthn.New(&webauthn.Config{
		RPDisplayName:          cfg.AppName,
		RPID:                   cfg.RPID,
		RPOrigins:              cfg.RPOrigins,
		AuthenticatorSelection: selection,
	})
	if err != nil {
		return nil, err
	}

	svc := &WebAuthnService{
		cfg:      cfg,
		DB:       db,
		WebAuthn: wa,
		Sessions: store.NewWebAuthnSessionStore(),
	}
	return svc, nil
}
// GetUserWithCredentials loads the user with the given ID, preloading the
// "Passkeys" association, and converts each stored passkey into a
// webauthn.Credential.
//
// Stored CredentialID, PublicKey and AAGUID are expected to be standard
// base64. If any passkey fails to decode, the whole call fails with an error
// naming that passkey's ID and the offending field. The database error from
// the user lookup is returned unwrapped, so callers can compare it directly.
//
// UserPresent/UserVerified are set to true for every credential because those
// flags are not persisted. The backup flags and sign count come from the row.
func (s *WebAuthnService) GetUserWithCredentials(userID int64) (*WebAuthnUser, error) {
	var user model.User
	if err := s.DB.Model(&model.User{}).Preload("Passkeys").First(&user, userID).Error; err != nil {
		return nil, err
	}

	credentials := make([]webauthn.Credential, len(user.Passkeys))
	for i, pk := range user.Passkeys {
		// decode reports which passkey row and which column is corrupt, which
		// the previous per-field error messages did not.
		decode := func(field, value string) ([]byte, error) {
			raw, err := base64.StdEncoding.DecodeString(value)
			if err != nil {
				return nil, fmt.Errorf("passkey %v: failed to decode %s: %w", pk.ID, field, err)
			}
			return raw, nil
		}

		credentialID, err := decode("CredentialID", pk.CredentialID)
		if err != nil {
			return nil, err
		}
		publicKey, err := decode("PublicKey", pk.PublicKey)
		if err != nil {
			return nil, err
		}
		aaguid, err := decode("AAGUID", pk.AAGUID)
		if err != nil {
			return nil, err
		}

		// Only one transport is persisted per passkey; an empty value means
		// "unknown", which is left as a nil slice.
		var transport []protocol.AuthenticatorTransport
		if pk.Transport != "" {
			transport = []protocol.AuthenticatorTransport{protocol.AuthenticatorTransport(pk.Transport)}
		}

		credentials[i] = webauthn.Credential{
			ID:              credentialID,
			PublicKey:       publicKey,
			AttestationType: pk.AttestationType,
			Transport:       transport,
			Flags: webauthn.CredentialFlags{
				UserPresent:    true,
				UserVerified:   true,
				BackupEligible: pk.BackupEligible,
				BackupState:    pk.BackupState,
			},
			Authenticator: webauthn.Authenticator{
				AAGUID:       aaguid,
				SignCount:    pk.SignCount,
				CloneWarning: false,
			},
		}
	}

	return &WebAuthnUser{
		User:        &user,
		Credentials: credentials,
	}, nil
}
// BeginRegistration starts a passkey registration ceremony for userID.
//
// It returns the credential-creation options to send to the client. The
// matching session data is stored under the user's ID (decimal string) so
// FinishRegistration can verify the response.
//
// The user's already-registered credentials are passed as exclusions, so an
// authenticator that already holds a credential for this user refuses to
// create a second one. Without this, re-registering the same device would
// leave a duplicate or orphaned passkey row.
func (s *WebAuthnService) BeginRegistration(userID int64) (*protocol.CredentialCreation, error) {
	user, err := s.GetUserWithCredentials(userID)
	if err != nil {
		return nil, err
	}

	options, sessionData, err := s.WebAuthn.BeginRegistration(
		user,
		webauthn.WithExclusions(user.WebAuthnCredentialDescriptors()),
	)
	if err != nil {
		return nil, err
	}

	// FormatInt avoids truncating 64-bit IDs on 32-bit platforms; it must
	// produce the same key that FinishRegistration looks up.
	s.Sessions.SaveWebauthnSession(strconv.FormatInt(userID, 10), sessionData)
	return options, nil
}
// FinishRegistration completes the registration ceremony started by
// BeginRegistration for userID and persists the new passkey.
//
// response is the client's HTTP request carrying the attestation; its
// User-Agent is also used to build a descriptive DeviceType. deviceName is
// stored as the passkey's Name. The pending session is deleted before
// verification, so a failed attempt cannot be retried against the same
// challenge. Binary fields are stored as standard base64, matching what
// GetUserWithCredentials decodes. Only the first reported transport is kept.
func (s *WebAuthnService) FinishRegistration(userID int64, response *http.Request, deviceName string) (*model.Passkey, error) {
	user, err := s.GetUserWithCredentials(userID)
	if err != nil {
		return nil, err
	}

	// Same key format as BeginRegistration; FormatInt avoids 32-bit truncation.
	sessionKey := strconv.FormatInt(userID, 10)
	sessionData, err := s.Sessions.GetWebauthnSession(sessionKey)
	if err != nil {
		return nil, err
	}
	s.Sessions.DeleteWebauthnSession(sessionKey)

	credential, err := s.WebAuthn.FinishRegistration(user, *sessionData, response)
	if err != nil {
		return nil, err
	}

	var transport string
	if len(credential.Transport) > 0 {
		transport = string(credential.Transport[0]) // only the first transport is persisted
	}

	ua := useragent.Parse(response.UserAgent())
	// Any user-agent component may be empty. Fields+Join collapses the
	// resulting runs of spaces, which TrimSpace (ends only) did not.
	deviceType := strings.Join(strings.Fields(fmt.Sprintf("%s %s %s %s %s",
		ua.Device, ua.OS, ua.OSVersionNoFull(), ua.Name, ua.VersionNoFull())), " ")

	passkey := &model.Passkey{
		UserID:          userID,
		CredentialID:    base64.StdEncoding.EncodeToString(credential.ID),
		PublicKey:       base64.StdEncoding.EncodeToString(credential.PublicKey),
		AttestationType: credential.AttestationType,
		AAGUID:          base64.StdEncoding.EncodeToString(credential.Authenticator.AAGUID),
		SignCount:       credential.Authenticator.SignCount,
		Name:            deviceName,
		DeviceType:      deviceType,
		LastUsedAt:      time.Now().Unix(),
		BackupEligible:  credential.Flags.BackupEligible,
		BackupState:     credential.Flags.BackupState,
		Transport:       transport,
	}

	if err := s.DB.Create(passkey).Error; err != nil {
		return nil, err
	}
	return passkey, nil
}
// BeginLogin starts a discoverable (username-less) login. No user is
// specified; the client's authenticator picks which credential to use. User
// verification is preferred. The session is stored under its own challenge,
// which FinishLogin must be given to look it up again.
func (s *WebAuthnService) BeginLogin() (*protocol.CredentialAssertion, error) {
	loginOpts := []webauthn.LoginOption{
		webauthn.WithUserVerification(protocol.VerificationPreferred),
	}

	assertion, sessionData, err := s.WebAuthn.BeginDiscoverableLogin(loginOpts...)
	if err != nil {
		return nil, err
	}

	s.Sessions.SaveWebauthnSession(sessionData.Challenge, sessionData)
	return assertion, nil
}
// FinishLogin completes a discoverable login started by BeginLogin.
//
// challenge identifies the session saved by BeginLogin. That session is
// deleted before verification, so it cannot be reused. The user is resolved
// from the assertion's user handle via GetWebAuthnUser.
//
// On success, the used passkey's sign_count is incremented and last_used_at
// is set. This is best-effort: a database error there is printed and does not
// fail the login.
func (s *WebAuthnService) FinishLogin(challenge string, response *http.Request) (*WebAuthnUser, error) {
	sessionData, err := s.Sessions.GetWebauthnSession(challenge)
	if err != nil {
		return nil, err
	}
	s.Sessions.DeleteWebauthnSession(challenge)

	// The handler fills in user while the library resolves the user handle.
	var user *WebAuthnUser
	credential, err := s.WebAuthn.FinishDiscoverableLogin(s.GetWebAuthnUser(&user), *sessionData, response)
	if err != nil {
		return nil, err
	}

	// sign_count is bumped by one per login instead of copying the
	// authenticator's counter, because that counter may be 0.
	// A single UPDATE with an SQL expression replaces the previous
	// read-then-write, which could lose increments when two logins with the
	// same passkey overlap. It also drops the stray extra bind argument the
	// old Where("id = ?", ...) call received.
	err = s.DB.Model(&model.Passkey{}).
		Where("credential_id = ?", base64.StdEncoding.EncodeToString(credential.ID)).
		Updates(map[string]interface{}{
			"sign_count":   gorm.Expr("sign_count + 1"),
			"last_used_at": time.Now().Unix(),
		}).Error
	if err != nil {
		fmt.Printf("webauthn: updating passkey usage after login: %v\n", err)
	}

	return user, nil
}
// GetWebAuthnUser returns a webauthn.DiscoverableUserHandler for use with
// FinishDiscoverableLogin.
//
// The handler parses userHandle as the decimal user ID produced by
// WebAuthnUser.WebAuthnID and loads that user with their credentials. The
// loaded user is also written to *wau (nil on lookup failure), so the caller
// can retrieve it after the library call returns. rawID is not used here.
func (s *WebAuthnService) GetWebAuthnUser(wau **WebAuthnUser) webauthn.DiscoverableUserHandler {
	return func(rawID, userHandle []byte) (webauthn.User, error) {
		userID, err := strconv.ParseInt(string(userHandle), 10, 64)
		if err != nil {
			return nil, fmt.Errorf("invalid webauthn user handle %q: %w", userHandle, err)
		}

		user, err := s.GetUserWithCredentials(userID)
		*wau = user
		if err != nil {
			// Return a literal nil: a nil *WebAuthnUser stored in the
			// webauthn.User interface would compare as non-nil.
			return nil, err
		}
		return user, nil
	}
}
// ListPasskeys returns every passkey row whose user_id equals userID, or the
// database error if the query fails.
func (s *WebAuthnService) ListPasskeys(userID int64) ([]model.Passkey, error) {
	var found []model.Passkey
	query := s.DB.Where("user_id = ?", userID).Find(&found)
	if query.Error != nil {
		return nil, query.Error
	}
	return found, nil
}
// DeletePasskey deletes the passkey with passkeyID, but only if it belongs to
// userID. Filtering on both columns stops one user from deleting another
// user's passkey. Deleting nothing is not reported as an error.
func (s *WebAuthnService) DeletePasskey(userID int64, passkeyID int64) error {
	owned := s.DB.Where("id = ? AND user_id = ?", passkeyID, userID)
	result := owned.Delete(&model.Passkey{})
	return result.Error
}