318 lines
9.4 KiB
Go
318 lines
9.4 KiB
Go
package service
|
||
|
||
import (
|
||
"encoding/base64"
|
||
"fmt"
|
||
"net/http"
|
||
"opencatd-open/internal/model"
|
||
"opencatd-open/pkg/config"
|
||
"opencatd-open/pkg/store"
|
||
"strconv"
|
||
"strings"
|
||
"time"
|
||
|
||
"github.com/go-webauthn/webauthn/protocol"
|
||
"github.com/go-webauthn/webauthn/webauthn"
|
||
"github.com/mileusna/useragent"
|
||
"gorm.io/gorm"
|
||
)
|
||
|
||
var _ webauthn.User = (*WebAuthnUser)(nil)
|
||
|
||
// WebAuthnUser 实现webauthn.User接口的结构体
|
||
type WebAuthnUser struct {
|
||
User *model.User
|
||
// ID int64
|
||
// Name string
|
||
// DisplayName string
|
||
Credentials []webauthn.Credential
|
||
}
|
||
|
||
// WebAuthnID 返回用户ID
|
||
func (u *WebAuthnUser) WebAuthnID() []byte {
|
||
return []byte(strconv.Itoa(int(u.User.ID)))
|
||
}
|
||
|
||
// WebAuthnName 返回用户名
|
||
func (u *WebAuthnUser) WebAuthnName() string {
|
||
return u.User.Username
|
||
}
|
||
|
||
// WebAuthnDisplayName 返回用户显示名
|
||
func (u *WebAuthnUser) WebAuthnDisplayName() string {
|
||
return u.User.Name
|
||
}
|
||
|
||
// WebAuthnCredentials 返回用户所有凭证
|
||
func (u *WebAuthnUser) WebAuthnCredentials() []webauthn.Credential {
|
||
return u.Credentials
|
||
}
|
||
|
||
func (u *WebAuthnUser) WebAuthnCredentialDescriptors() (descriptors []protocol.CredentialDescriptor) {
|
||
credentials := u.WebAuthnCredentials()
|
||
|
||
descriptors = make([]protocol.CredentialDescriptor, len(credentials))
|
||
|
||
for i, credential := range credentials {
|
||
descriptors[i] = credential.Descriptor()
|
||
}
|
||
|
||
return descriptors
|
||
}
|
||
|
||
// WebAuthnService 提供WebAuthn相关功能
|
||
type WebAuthnService struct {
|
||
cfg *config.Config
|
||
DB *gorm.DB
|
||
WebAuthn *webauthn.WebAuthn
|
||
// Sessions map[string]webauthn.SessionData // 用于存储注册和认证过程中的会话数据
|
||
Sessions *store.WebAuthnSessionStore
|
||
}
|
||
|
||
// NewWebAuthnService 创建新的WebAuthn服务
|
||
func NewWebAuthnService(cfg *config.Config, db *gorm.DB) (*WebAuthnService, error) {
|
||
// 创建WebAuthn配置
|
||
wconfig := &webauthn.Config{
|
||
RPDisplayName: cfg.AppName, // 依赖方(Relying Party)显示名称
|
||
RPID: cfg.RPID, // 依赖方ID(通常为域名)
|
||
RPOrigins: cfg.RPOrigins, // 依赖方源(URL)
|
||
AuthenticatorSelection: protocol.AuthenticatorSelection{
|
||
RequireResidentKey: protocol.ResidentKeyRequired(), // 要求认证器存储用户 ID (resident key)
|
||
ResidentKey: protocol.ResidentKeyRequirementRequired, // 使用 Discoverable 模式
|
||
UserVerification: protocol.VerificationPreferred, // 推荐用户验证
|
||
AuthenticatorAttachment: "", // 允许任何认证器 (平台或跨平台)
|
||
},
|
||
// EncodeUserIDAsString: true, // 将用户ID编码为字符串
|
||
}
|
||
|
||
wa, err := webauthn.New(wconfig)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
|
||
return &WebAuthnService{
|
||
cfg: cfg,
|
||
DB: db,
|
||
WebAuthn: wa,
|
||
// Sessions: make(map[string]webauthn.SessionData),
|
||
Sessions: store.NewWebAuthnSessionStore(),
|
||
}, nil
|
||
}
|
||
|
||
// GetUserWithCredentials 获取用户及其凭证
|
||
func (s *WebAuthnService) GetUserWithCredentials(userID int64) (*WebAuthnUser, error) {
|
||
var user model.User
|
||
if err := s.DB.Model(&model.User{}).Preload("Passkeys").First(&user, userID).Error; err != nil {
|
||
return nil, err
|
||
}
|
||
|
||
// 获取用户的所有Passkey
|
||
passkeys := user.Passkeys
|
||
|
||
// 将Passkey转换为webauthn.Credential
|
||
credentials := make([]webauthn.Credential, len(passkeys))
|
||
for i, pk := range passkeys {
|
||
credentialIDBytes, err := base64.StdEncoding.DecodeString(pk.CredentialID)
|
||
if err != nil {
|
||
return nil, fmt.Errorf("failed to decode CredentialID: %w", err)
|
||
}
|
||
publicKeyBytes, err := base64.StdEncoding.DecodeString(pk.PublicKey)
|
||
if err != nil {
|
||
return nil, fmt.Errorf("failed to decode PublicKey: %w", err)
|
||
}
|
||
aaguidBytes, err := base64.StdEncoding.DecodeString(pk.AAGUID)
|
||
if err != nil {
|
||
return nil, fmt.Errorf("failed to decode AAGUID: %w", err)
|
||
}
|
||
|
||
var transport []protocol.AuthenticatorTransport
|
||
if pk.Transport != "" {
|
||
transport = []protocol.AuthenticatorTransport{protocol.AuthenticatorTransport(pk.Transport)}
|
||
}
|
||
|
||
credentials[i] = webauthn.Credential{
|
||
ID: credentialIDBytes,
|
||
PublicKey: publicKeyBytes,
|
||
AttestationType: pk.AttestationType,
|
||
Transport: transport,
|
||
Flags: webauthn.CredentialFlags{
|
||
UserPresent: true,
|
||
UserVerified: true,
|
||
BackupEligible: pk.BackupEligible,
|
||
BackupState: pk.BackupState,
|
||
},
|
||
Authenticator: webauthn.Authenticator{
|
||
AAGUID: aaguidBytes,
|
||
SignCount: pk.SignCount,
|
||
CloneWarning: false,
|
||
},
|
||
}
|
||
}
|
||
|
||
// 创建WebAuthnUser
|
||
return &WebAuthnUser{
|
||
User: &user,
|
||
Credentials: credentials,
|
||
}, nil
|
||
}
|
||
|
||
// BeginRegistration 开始注册过程
|
||
func (s *WebAuthnService) BeginRegistration(userID int64) (*protocol.CredentialCreation, error) {
|
||
user, err := s.GetUserWithCredentials(userID)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
|
||
// 获取注册选项
|
||
options, sessionData, err := s.WebAuthn.BeginRegistration(user)
|
||
// webauthn.WithResidentKeyRequirement(protocol.ResidentKeyRequirementRequired),
|
||
// webauthn.WithExclusions(user.WebAuthnCredentialDescriptors()), // 排除已存在的凭证
|
||
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
|
||
// 保存会话数据
|
||
userid := strconv.Itoa(int(userID))
|
||
s.Sessions.SaveWebauthnSession(userid, sessionData)
|
||
|
||
return options, nil
|
||
}
|
||
|
||
// FinishRegistration 完成注册过程
|
||
func (s *WebAuthnService) FinishRegistration(userID int64, response *http.Request, deviceName string) (*model.Passkey, error) {
|
||
user, err := s.GetUserWithCredentials(userID)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
|
||
userid := strconv.Itoa(int(userID))
|
||
// 获取并清除会话数据
|
||
sessionData, err := s.Sessions.GetWebauthnSession(userid)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
s.Sessions.DeleteWebauthnSession(userid)
|
||
|
||
// 完成注册
|
||
credential, err := s.WebAuthn.FinishRegistration(user, *sessionData, response)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
|
||
ua := useragent.Parse(response.UserAgent())
|
||
|
||
var transport string
|
||
if len(credential.Transport) > 0 {
|
||
transport = string(credential.Transport[0]) // 通常只取第一个传输方式
|
||
}
|
||
// 创建Passkey记录
|
||
passkey := &model.Passkey{
|
||
UserID: userID,
|
||
CredentialID: base64.StdEncoding.EncodeToString(credential.ID),
|
||
PublicKey: base64.StdEncoding.EncodeToString(credential.PublicKey),
|
||
AttestationType: string(credential.AttestationType),
|
||
AAGUID: base64.StdEncoding.EncodeToString(credential.Authenticator.AAGUID),
|
||
SignCount: credential.Authenticator.SignCount,
|
||
Name: deviceName,
|
||
DeviceType: strings.TrimSpace(fmt.Sprintf("%s %s %s %s %s", ua.Device, ua.OS, ua.OSVersionNoFull(), ua.Name, ua.VersionNoFull())),
|
||
LastUsedAt: time.Now().Unix(),
|
||
BackupEligible: credential.Flags.BackupEligible,
|
||
BackupState: credential.Flags.BackupState,
|
||
Transport: transport,
|
||
}
|
||
|
||
// 保存Passkey
|
||
if err := s.DB.Create(passkey).Error; err != nil {
|
||
return nil, err
|
||
}
|
||
|
||
return passkey, nil
|
||
}
|
||
|
||
// BeginLogin 开始登录过程 (无需用户ID,针对未认证用户)
|
||
func (s *WebAuthnService) BeginLogin() (*protocol.CredentialAssertion, error) {
|
||
// 不指定用户ID,让客户端决定使用哪个凭证
|
||
options, session, err := s.WebAuthn.BeginDiscoverableLogin(
|
||
webauthn.WithUserVerification(protocol.VerificationPreferred), // 推荐用户验证
|
||
)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
|
||
s.Sessions.SaveWebauthnSession(session.Challenge, session)
|
||
|
||
return options, nil
|
||
}
|
||
|
||
// FinishLogin 完成登录过程
|
||
func (s *WebAuthnService) FinishLogin(challenge string, response *http.Request) (*WebAuthnUser, error) {
|
||
// 获取并清除会话数据
|
||
sessionData, err := s.Sessions.GetWebauthnSession(challenge)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
s.Sessions.DeleteWebauthnSession(challenge)
|
||
|
||
// 获取相应的用户
|
||
// var user model.User
|
||
// if err := s.DB.First(&user, passkey.UserID).Error; err != nil {
|
||
// return nil, err
|
||
// }
|
||
|
||
// 创建WebAuthnUser
|
||
// webAuthnUser, err := s.GetUserWithCredentials(user.ID)
|
||
// if err != nil {
|
||
// return nil, err
|
||
// }
|
||
|
||
// 完成登录
|
||
// _, err = s.WebAuthn.FinishLogin(webAuthnUser, sessionData, response)
|
||
// if err != nil {
|
||
// return nil, err
|
||
// }
|
||
var user *WebAuthnUser
|
||
wc, err := s.WebAuthn.FinishDiscoverableLogin(s.GetWebAuthnUser(&user), *sessionData, response)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
// 更新Passkey 这里SignCount应该是由验证器上传,但可能为0,手动+1
|
||
var pk model.Passkey
|
||
if err := s.DB.Model(&model.Passkey{}).Where("credential_id = ?", base64.StdEncoding.EncodeToString(wc.ID)).First(&pk).Error; err == nil {
|
||
if err := s.DB.Model(&model.Passkey{}).Where("id = ?", pk.ID, time.Now().Unix()).Updates(map[string]interface{}{
|
||
"sign_count": pk.SignCount + 1,
|
||
"last_used_at": time.Now().Unix(),
|
||
}).Error; err != nil {
|
||
fmt.Println(err)
|
||
}
|
||
|
||
}
|
||
|
||
return user, nil
|
||
}
|
||
|
||
func (s *WebAuthnService) GetWebAuthnUser(wau **WebAuthnUser) webauthn.DiscoverableUserHandler {
|
||
return func(rawID, userHandle []byte) (webauthn.User, error) {
|
||
userid, err := strconv.ParseInt(string(userHandle), 10, 64)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
*wau, err = s.GetUserWithCredentials(userid)
|
||
return *wau, err
|
||
}
|
||
}
|
||
|
||
// ListPasskeys 列出用户所有Passkey
|
||
func (s *WebAuthnService) ListPasskeys(userID int64) ([]model.Passkey, error) {
|
||
var passkeys []model.Passkey
|
||
if err := s.DB.Where("user_id = ?", userID).Find(&passkeys).Error; err != nil {
|
||
return nil, err
|
||
}
|
||
return passkeys, nil
|
||
}
|
||
|
||
// DeletePasskey 删除用户Passkey
|
||
func (s *WebAuthnService) DeletePasskey(userID int64, passkeyID int64) error {
|
||
return s.DB.Where("id = ? AND user_id = ?", passkeyID, userID).Delete(&model.Passkey{}).Error
|
||
}
|