reface to openteam

This commit is contained in:
Sakurasan
2025-04-16 18:01:27 +08:00
parent bc223d6530
commit e7ffc9e8b9
92 changed files with 5345 additions and 1273 deletions
+304
View File
@@ -0,0 +1,304 @@
package service
import (
"encoding/base64"
"fmt"
"net/http"
"opencatd-open/internal/model"
"opencatd-open/pkg/config"
"opencatd-open/pkg/store"
"strconv"
"strings"
"time"
"github.com/go-webauthn/webauthn/protocol"
"github.com/go-webauthn/webauthn/webauthn"
"github.com/mileusna/useragent"
"gorm.io/gorm"
)
var _ webauthn.User = (*WebAuthnUser)(nil)
// WebAuthnUser 实现webauthn.User接口的结构体
type WebAuthnUser struct {
User *model.User
// ID int64
// Name string
// DisplayName string
Credentials []webauthn.Credential
}
// WebAuthnID 返回用户ID
func (u *WebAuthnUser) WebAuthnID() []byte {
return []byte(strconv.Itoa(int(u.User.ID)))
}
// WebAuthnName 返回用户名
func (u *WebAuthnUser) WebAuthnName() string {
return u.User.Username
}
// WebAuthnDisplayName 返回用户显示名
func (u *WebAuthnUser) WebAuthnDisplayName() string {
return u.User.Name
}
// WebAuthnCredentials 返回用户所有凭证
func (u *WebAuthnUser) WebAuthnCredentials() []webauthn.Credential {
return u.Credentials
}
func (u *WebAuthnUser) WebAuthnCredentialDescriptors() (descriptors []protocol.CredentialDescriptor) {
credentials := u.WebAuthnCredentials()
descriptors = make([]protocol.CredentialDescriptor, len(credentials))
for i, credential := range credentials {
descriptors[i] = credential.Descriptor()
}
return descriptors
}
// WebAuthnService 提供WebAuthn相关功能
type WebAuthnService struct {
DB *gorm.DB
WebAuthn *webauthn.WebAuthn
// Sessions map[string]webauthn.SessionData // 用于存储注册和认证过程中的会话数据
Sessions *store.WebAuthnSessionStore
}
// NewWebAuthnService 创建新的WebAuthn服务
func NewWebAuthnService(db *gorm.DB, cfg *config.Config) (*WebAuthnService, error) {
// 创建WebAuthn配置
wconfig := &webauthn.Config{
RPDisplayName: config.Cfg.AppName, // 依赖方(Relying Party)显示名称
RPID: config.Cfg.Domain, // 依赖方ID(通常为域名)
RPOrigins: []string{config.Cfg.AppURL}, // 依赖方源(URL)
AuthenticatorSelection: protocol.AuthenticatorSelection{
RequireResidentKey: protocol.ResidentKeyRequired(), // 要求认证器存储用户 ID (resident key)
ResidentKey: protocol.ResidentKeyRequirementRequired, // 使用 Discoverable 模式
UserVerification: protocol.VerificationPreferred, // 推荐用户验证
AuthenticatorAttachment: "", // 允许任何认证器 (平台或跨平台)
},
// EncodeUserIDAsString: true, // 将用户ID编码为字符串
}
wa, err := webauthn.New(wconfig)
if err != nil {
return nil, err
}
return &WebAuthnService{
DB: db,
WebAuthn: wa,
// Sessions: make(map[string]webauthn.SessionData),
Sessions: store.NewWebAuthnSessionStore(),
}, nil
}
// GetUserWithCredentials 获取用户及其凭证
func (s *WebAuthnService) GetUserWithCredentials(userID int64) (*WebAuthnUser, error) {
var user model.User
if err := s.DB.Model(&model.User{}).Preload("Passkeys").First(&user, userID).Error; err != nil {
return nil, err
}
// 获取用户的所有Passkey
passkeys := user.Passkeys
// 将Passkey转换为webauthn.Credential
credentials := make([]webauthn.Credential, len(passkeys))
for i, pk := range passkeys {
credentialIDBytes, err := base64.StdEncoding.DecodeString(pk.CredentialID)
if err != nil {
return nil, fmt.Errorf("failed to decode CredentialID: %w", err)
}
publicKeyBytes, err := base64.StdEncoding.DecodeString(pk.PublicKey)
if err != nil {
return nil, fmt.Errorf("failed to decode PublicKey: %w", err)
}
aaguidBytes, err := base64.StdEncoding.DecodeString(pk.AAGUID)
if err != nil {
return nil, fmt.Errorf("failed to decode AAGUID: %w", err)
}
var transport []protocol.AuthenticatorTransport
if pk.Transport != "" {
transport = []protocol.AuthenticatorTransport{protocol.AuthenticatorTransport(pk.Transport)}
}
credentials[i] = webauthn.Credential{
ID: credentialIDBytes,
PublicKey: publicKeyBytes,
AttestationType: pk.AttestationType,
Transport: transport,
Flags: webauthn.CredentialFlags{
UserPresent: true,
UserVerified: true,
BackupEligible: pk.BackupEligible,
BackupState: pk.BackupState,
},
Authenticator: webauthn.Authenticator{
AAGUID: aaguidBytes,
SignCount: pk.SignCount,
CloneWarning: false,
},
}
}
// 创建WebAuthnUser
return &WebAuthnUser{
User: &user,
Credentials: credentials,
}, nil
}
// BeginRegistration 开始注册过程
func (s *WebAuthnService) BeginRegistration(userID int64) (*protocol.CredentialCreation, error) {
user, err := s.GetUserWithCredentials(userID)
if err != nil {
return nil, err
}
// 获取注册选项
options, sessionData, err := s.WebAuthn.BeginRegistration(user)
// webauthn.WithResidentKeyRequirement(protocol.ResidentKeyRequirementRequired),
// webauthn.WithExclusions(user.WebAuthnCredentialDescriptors()), // 排除已存在的凭证
if err != nil {
return nil, err
}
// 保存会话数据
userid := strconv.Itoa(int(userID))
s.Sessions.SaveWebauthnSession(userid, sessionData)
return options, nil
}
// FinishRegistration 完成注册过程
func (s *WebAuthnService) FinishRegistration(userID int64, response *http.Request, deviceName string) (*model.Passkey, error) {
user, err := s.GetUserWithCredentials(userID)
if err != nil {
return nil, err
}
userid := strconv.Itoa(int(userID))
// 获取并清除会话数据
sessionData, err := s.Sessions.GetWebauthnSession(userid)
if err != nil {
return nil, err
}
s.Sessions.DeleteWebauthnSession(userid)
// 完成注册
credential, err := s.WebAuthn.FinishRegistration(user, *sessionData, response)
if err != nil {
return nil, err
}
ua := useragent.Parse(response.UserAgent())
var transport string
if len(credential.Transport) > 0 {
transport = string(credential.Transport[0]) // 通常只取第一个传输方式
}
// 创建Passkey记录
passkey := &model.Passkey{
UserID: userID,
CredentialID: base64.StdEncoding.EncodeToString(credential.ID),
PublicKey: base64.StdEncoding.EncodeToString(credential.PublicKey),
AttestationType: string(credential.AttestationType),
AAGUID: base64.StdEncoding.EncodeToString(credential.Authenticator.AAGUID),
SignCount: credential.Authenticator.SignCount,
Name: deviceName,
DeviceType: strings.TrimSpace(fmt.Sprintf("%s %s %s %s %s", ua.Device, ua.OS, ua.OSVersionNoFull(), ua.Name, ua.VersionNoFull())),
LastUsedAt: time.Now().Unix(),
BackupEligible: credential.Flags.BackupEligible,
BackupState: credential.Flags.BackupState,
Transport: transport,
}
// 保存Passkey
if err := s.DB.Create(passkey).Error; err != nil {
return nil, err
}
return passkey, nil
}
// BeginLogin 开始登录过程 (无需用户ID,针对未认证用户)
func (s *WebAuthnService) BeginLogin() (*protocol.CredentialAssertion, error) {
// 不指定用户ID,让客户端决定使用哪个凭证
options, session, err := s.WebAuthn.BeginDiscoverableLogin(
webauthn.WithUserVerification(protocol.VerificationPreferred), // 推荐用户验证
)
if err != nil {
return nil, err
}
s.Sessions.SaveWebauthnSession(session.Challenge, session)
return options, nil
}
// FinishLogin 完成登录过程
func (s *WebAuthnService) FinishLogin(challenge string, response *http.Request) (*WebAuthnUser, error) {
// 获取并清除会话数据
sessionData, err := s.Sessions.GetWebauthnSession(challenge)
if err != nil {
return nil, err
}
s.Sessions.DeleteWebauthnSession(challenge)
// 获取相应的用户
// var user model.User
// if err := s.DB.First(&user, passkey.UserID).Error; err != nil {
// return nil, err
// }
// 创建WebAuthnUser
// webAuthnUser, err := s.GetUserWithCredentials(user.ID)
// if err != nil {
// return nil, err
// }
// 完成登录
// _, err = s.WebAuthn.FinishLogin(webAuthnUser, sessionData, response)
// if err != nil {
// return nil, err
// }
var user *WebAuthnUser
_, err = s.WebAuthn.FinishDiscoverableLogin(s.GetWebAuthnUser(&user), *sessionData, response)
if err != nil {
return nil, err
}
// 更新Passkey的LastUsedAt
return user, nil
}
func (s *WebAuthnService) GetWebAuthnUser(wau **WebAuthnUser) webauthn.DiscoverableUserHandler {
return func(rawID, userHandle []byte) (webauthn.User, error) {
userid, err := strconv.ParseInt(string(userHandle), 10, 64)
if err != nil {
return nil, err
}
*wau, err = s.GetUserWithCredentials(userid)
return *wau, err
}
}
// ListPasskeys 列出用户所有Passkey
func (s *WebAuthnService) ListPasskeys(userID int64) ([]model.Passkey, error) {
var passkeys []model.Passkey
if err := s.DB.Where("user_id = ?", userID).Find(&passkeys).Error; err != nil {
return nil, err
}
return passkeys, nil
}
// DeletePasskey 删除用户Passkey
func (s *WebAuthnService) DeletePasskey(userID int64, passkeyID int64) error {
return s.DB.Where("id = ? AND user_id = ?", passkeyID, userID).Delete(&model.Passkey{}).Error
}