reface to openteam
This commit is contained in:
77
internal/auth/auth.go
Normal file
77
internal/auth/auth.go
Normal file
@@ -0,0 +1,77 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"opencatd-open/internal/model"
|
||||
"time"
|
||||
|
||||
"github.com/golang-jwt/jwt/v5"
|
||||
)
|
||||
|
||||
type Claims struct {
|
||||
UserID int64 `json:"user_id"`
|
||||
Name string `json:"name"`
|
||||
Type string `json:"type"`
|
||||
jwt.RegisteredClaims
|
||||
}
|
||||
|
||||
type TokenPair struct {
|
||||
AccessToken string `json:"access_token"`
|
||||
RefreshToken string `json:"refresh_token,omitempty"`
|
||||
}
|
||||
|
||||
func GenerateTokenPair(user *model.User, secret string, accessExpire, refreshExpire time.Duration) (*TokenPair, error) {
|
||||
// Generate access token
|
||||
accessToken, err := generateToken(user, "access", secret, accessExpire)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Generate refresh token
|
||||
refreshToken, err := generateToken(user, "refresh", secret, refreshExpire)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &TokenPair{
|
||||
AccessToken: accessToken,
|
||||
RefreshToken: refreshToken,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func generateToken(user *model.User, tokenType, secret string, expire time.Duration) (string, error) {
|
||||
now := time.Now()
|
||||
|
||||
claims := Claims{
|
||||
UserID: user.ID,
|
||||
Name: user.Username,
|
||||
Type: tokenType,
|
||||
RegisteredClaims: jwt.RegisteredClaims{
|
||||
ExpiresAt: jwt.NewNumericDate(now.Add(expire)),
|
||||
IssuedAt: jwt.NewNumericDate(now),
|
||||
NotBefore: jwt.NewNumericDate(now),
|
||||
},
|
||||
}
|
||||
|
||||
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
|
||||
return token.SignedString([]byte(secret))
|
||||
}
|
||||
|
||||
func ValidateToken(tokenString, secret string) (*Claims, error) {
|
||||
token, err := jwt.ParseWithClaims(tokenString, &Claims{}, func(token *jwt.Token) (interface{}, error) {
|
||||
if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
|
||||
return nil, errors.New("unexpected signing method")
|
||||
}
|
||||
return []byte(secret), nil
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if claims, ok := token.Claims.(*Claims); ok && token.Valid {
|
||||
return claims, nil
|
||||
}
|
||||
|
||||
return nil, jwt.ErrInvalidKey
|
||||
}
|
||||
48
internal/consts/consts.go
Normal file
48
internal/consts/consts.go
Normal file
@@ -0,0 +1,48 @@
|
||||
package consts
|
||||
|
||||
import "gorm.io/gorm"
|
||||
|
||||
const SecretKey = "openteam"
|
||||
|
||||
const Day = 24 * 60 * 60 // day := 86400
|
||||
|
||||
type UserRole int
|
||||
|
||||
const (
|
||||
RoleUser UserRole = iota * 10
|
||||
RoleAdmin
|
||||
RoleRoot
|
||||
)
|
||||
|
||||
const (
|
||||
StatusDisabled = iota
|
||||
StatusEnabled
|
||||
StatusExpired // 过期
|
||||
StatusExhausted // 耗尽
|
||||
|
||||
StatusDeleted = -1
|
||||
)
|
||||
const (
|
||||
Limited = iota
|
||||
Unlimited
|
||||
UnlimitedQuota = 999999
|
||||
)
|
||||
|
||||
var (
|
||||
ErrUserNotFound = gorm.ErrRecordNotFound
|
||||
)
|
||||
|
||||
func OpenOrClose(status bool) int {
|
||||
if status {
|
||||
return StatusEnabled
|
||||
}
|
||||
return StatusDisabled
|
||||
}
|
||||
|
||||
// type DBType int
|
||||
|
||||
// const (
|
||||
// DBTypeMySQL DBType = iota
|
||||
// DBTypePostgreSQL
|
||||
// DBTypeSQLite
|
||||
// )
|
||||
152
internal/controller/apikey.go
Normal file
152
internal/controller/apikey.go
Normal file
@@ -0,0 +1,152 @@
|
||||
package controller
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"opencatd-open/internal/consts"
|
||||
"opencatd-open/internal/dto"
|
||||
"opencatd-open/internal/model"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/duke-git/lancet/v2/slice"
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
func (a Api) CreateApiKey(c *gin.Context) {
|
||||
role := c.MustGet("user_role").(*consts.UserRole)
|
||||
if *role < consts.RoleAdmin {
|
||||
dto.Fail(c, 403, "Permission denied")
|
||||
return
|
||||
}
|
||||
req := new(model.ApiKey)
|
||||
err := c.ShouldBind(&req)
|
||||
if err != nil {
|
||||
dto.Fail(c, 400, err.Error())
|
||||
}
|
||||
|
||||
err = a.keyService.CreateApiKey(c, req)
|
||||
if err != nil {
|
||||
dto.Fail(c, 400, err.Error())
|
||||
} else {
|
||||
dto.Success(c, nil)
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func (a Api) GetApiKey(c *gin.Context) {
|
||||
role := c.MustGet("user_role").(*consts.UserRole)
|
||||
if *role < consts.RoleAdmin {
|
||||
dto.Fail(c, 403, "Permission denied")
|
||||
return
|
||||
}
|
||||
id, err := strconv.ParseInt(c.Param("id"), 10, 64)
|
||||
if err != nil {
|
||||
dto.Fail(c, 400, err.Error())
|
||||
return
|
||||
}
|
||||
key, err := a.keyService.GetApiKey(c, id)
|
||||
if err != nil {
|
||||
dto.Fail(c, 400, err.Error())
|
||||
} else {
|
||||
dto.Success(c, key)
|
||||
}
|
||||
}
|
||||
|
||||
func (a Api) ListApiKey(c *gin.Context) {
|
||||
role := c.MustGet("user_role").(*consts.UserRole)
|
||||
if *role < consts.RoleAdmin {
|
||||
dto.Fail(c, 403, "Permission denied")
|
||||
return
|
||||
}
|
||||
limit, _ := strconv.Atoi(c.DefaultQuery("pageSize", "20"))
|
||||
page, _ := strconv.Atoi(c.DefaultQuery("page", "1"))
|
||||
offset := (page - 1) * limit
|
||||
active := c.QueryArray("active[]")
|
||||
if !slice.ContainSubSlice([]string{"true", "false"}, active) {
|
||||
dto.Fail(c, http.StatusBadRequest, "active must be true or false")
|
||||
return
|
||||
}
|
||||
|
||||
keys, total, err := a.keyService.ListApiKey(c, limit, offset, active)
|
||||
if err != nil {
|
||||
dto.Fail(c, 500, err.Error())
|
||||
} else {
|
||||
dto.Success(c, gin.H{
|
||||
"total": total,
|
||||
"keys": keys,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func (a Api) DeleteApiKey(c *gin.Context) {
|
||||
role := c.MustGet("user_role").(*consts.UserRole)
|
||||
if *role < consts.RoleAdmin {
|
||||
dto.Fail(c, 403, "Permission denied")
|
||||
return
|
||||
}
|
||||
var batchid dto.BatchIDRequest
|
||||
err := c.ShouldBind(&batchid)
|
||||
if err != nil {
|
||||
dto.Fail(c, 400, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
err = a.keyService.DeleteApiKey(c, batchid.IDs)
|
||||
if err != nil {
|
||||
dto.Fail(c, 500, err.Error())
|
||||
} else {
|
||||
dto.Success(c, nil)
|
||||
}
|
||||
}
|
||||
|
||||
func (a Api) UpdateApiKey(c *gin.Context) {
|
||||
role := c.MustGet("user_role").(*consts.UserRole)
|
||||
if *role < consts.RoleAdmin {
|
||||
dto.Fail(c, 403, "Permission denied")
|
||||
return
|
||||
}
|
||||
var req model.ApiKey
|
||||
err := c.ShouldBind(&req)
|
||||
if err != nil {
|
||||
dto.Fail(c, 400, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
err = a.keyService.UpdateApiKey(c, &req)
|
||||
if err != nil {
|
||||
dto.Fail(c, 500, err.Error())
|
||||
} else {
|
||||
dto.Success(c, nil)
|
||||
}
|
||||
}
|
||||
|
||||
func (a Api) ApiKeyOption(c *gin.Context) {
|
||||
role := c.MustGet("user_role").(*consts.UserRole)
|
||||
if *role < consts.RoleAdmin {
|
||||
dto.Fail(c, 403, "Permission denied")
|
||||
return
|
||||
}
|
||||
option := strings.ToLower(c.Param("option"))
|
||||
var batchid dto.BatchIDRequest
|
||||
err := c.ShouldBind(&batchid)
|
||||
if err != nil {
|
||||
dto.Fail(c, 400, err.Error())
|
||||
return
|
||||
}
|
||||
switch option {
|
||||
case "enable":
|
||||
err = a.keyService.EnableApiKey(c, batchid.IDs)
|
||||
case "disable":
|
||||
err = a.keyService.DisableApiKey(c, batchid.IDs)
|
||||
case "delete":
|
||||
err = a.keyService.DeleteApiKey(c, batchid.IDs)
|
||||
default:
|
||||
dto.Fail(c, 400, "invalid option, only support enable, disable, delete")
|
||||
return
|
||||
}
|
||||
if err != nil {
|
||||
dto.Fail(c, 400, err.Error())
|
||||
return
|
||||
}
|
||||
dto.Success(c, nil)
|
||||
}
|
||||
27
internal/controller/init.go
Normal file
27
internal/controller/init.go
Normal file
@@ -0,0 +1,27 @@
|
||||
package controller
|
||||
|
||||
import (
|
||||
"opencatd-open/internal/service"
|
||||
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
type Api struct {
|
||||
db *gorm.DB
|
||||
userService *service.UserServiceImpl
|
||||
tokenService *service.TokenServiceImpl
|
||||
keyService *service.ApiKeyServiceImpl
|
||||
webAuthService *service.WebAuthnService
|
||||
usageService *service.UsageService
|
||||
}
|
||||
|
||||
func NewApi(db *gorm.DB, userService *service.UserServiceImpl, tokenService *service.TokenServiceImpl, keyService *service.ApiKeyServiceImpl, webAuthService *service.WebAuthnService, usageService *service.UsageService) *Api {
|
||||
return &Api{
|
||||
db: db,
|
||||
userService: userService,
|
||||
tokenService: tokenService,
|
||||
keyService: keyService,
|
||||
webAuthService: webAuthService,
|
||||
usageService: usageService,
|
||||
}
|
||||
}
|
||||
60
internal/controller/proxy/chat_proxy.go
Normal file
60
internal/controller/proxy/chat_proxy.go
Normal file
@@ -0,0 +1,60 @@
|
||||
package controller
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"opencatd-open/internal/dto"
|
||||
"opencatd-open/llm"
|
||||
"opencatd-open/llm/claude/v2"
|
||||
"opencatd-open/llm/google/v2"
|
||||
"opencatd-open/llm/openai_compatible"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
func (h *Proxy) ChatHandler(c *gin.Context) {
|
||||
var chatreq llm.ChatRequest
|
||||
if err := c.ShouldBindJSON(&chatreq); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
err := h.SelectApiKey(chatreq.Model)
|
||||
if err != nil {
|
||||
dto.WrapErrorAsOpenAI(c, 500, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
var llm llm.LLM
|
||||
switch *h.apikey.ApiType {
|
||||
case "claude":
|
||||
llm, err = claude.NewClaude(h.apikey)
|
||||
case "gemini":
|
||||
llm, err = google.NewGemini(c, h.apikey)
|
||||
case "openai", "azure", "github":
|
||||
fallthrough
|
||||
default:
|
||||
llm, err = openai_compatible.NewOpenAICompatible(h.apikey)
|
||||
if err != nil {
|
||||
dto.WrapErrorAsOpenAI(c, 500, fmt.Errorf("create llm client error: %w", err).Error())
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
if !chatreq.Stream {
|
||||
resp, err := llm.Chat(c, chatreq)
|
||||
if err != nil {
|
||||
dto.WrapErrorAsOpenAI(c, 500, err.Error())
|
||||
}
|
||||
c.JSON(http.StatusOK, resp)
|
||||
|
||||
} else {
|
||||
datachan, err := llm.StreamChat(c, chatreq)
|
||||
if err != nil {
|
||||
dto.WrapErrorAsOpenAI(c, 500, err.Error())
|
||||
}
|
||||
for data := range datachan {
|
||||
c.SSEvent("", data)
|
||||
}
|
||||
}
|
||||
}
|
||||
345
internal/controller/proxy/proxy.go
Normal file
345
internal/controller/proxy/proxy.go
Normal file
@@ -0,0 +1,345 @@
|
||||
package controller
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"math/rand"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"opencatd-open/internal/dao"
|
||||
"opencatd-open/internal/model"
|
||||
"opencatd-open/pkg/config"
|
||||
"os"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/lib/pq"
|
||||
"github.com/tidwall/gjson"
|
||||
"gorm.io/gorm"
|
||||
"gorm.io/gorm/clause"
|
||||
)
|
||||
|
||||
type Proxy struct {
|
||||
ctx context.Context
|
||||
cfg *config.Config
|
||||
db *gorm.DB
|
||||
wg *sync.WaitGroup
|
||||
usageChan chan *model.Usage // 用于异步处理的channel
|
||||
apikey *model.ApiKey
|
||||
httpClient *http.Client
|
||||
|
||||
userDAO *dao.UserDAO
|
||||
apiKeyDao *dao.ApiKeyDAO
|
||||
tokenDAO *dao.TokenDAO
|
||||
usageDAO *dao.UsageDAO
|
||||
dailyUsageDAO *dao.DailyUsageDAO
|
||||
}
|
||||
|
||||
func NewProxy(ctx context.Context, cfg *config.Config, db *gorm.DB, wg *sync.WaitGroup, userDAO *dao.UserDAO, apiKeyDAO *dao.ApiKeyDAO, tokenDAO *dao.TokenDAO, usageDAO *dao.UsageDAO, dailyUsageDAO *dao.DailyUsageDAO) *Proxy {
|
||||
client := http.DefaultClient
|
||||
if os.Getenv("LOCAL_PROXY") != "" {
|
||||
proxyUrl, err := url.Parse(os.Getenv("LOCAL_PROXY"))
|
||||
if err == nil {
|
||||
tr := &http.Transport{
|
||||
Proxy: http.ProxyURL(proxyUrl),
|
||||
}
|
||||
client.Transport = tr
|
||||
}
|
||||
}
|
||||
np := &Proxy{
|
||||
ctx: ctx,
|
||||
cfg: cfg,
|
||||
db: db,
|
||||
wg: wg,
|
||||
httpClient: client,
|
||||
usageChan: make(chan *model.Usage, cfg.UsageChanSize),
|
||||
userDAO: userDAO,
|
||||
apiKeyDao: apiKeyDAO,
|
||||
tokenDAO: tokenDAO,
|
||||
usageDAO: usageDAO,
|
||||
dailyUsageDAO: dailyUsageDAO,
|
||||
}
|
||||
|
||||
go np.ProcessUsage()
|
||||
go np.ScheduleTask()
|
||||
|
||||
return np
|
||||
}
|
||||
|
||||
func (p *Proxy) HandleProxy(c *gin.Context) {
|
||||
if c.Request.URL.Path == "/v1/chat/completions" {
|
||||
p.ChatHandler(c)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
func (p *Proxy) SendUsage(usage *model.Usage) {
|
||||
select {
|
||||
case p.usageChan <- usage:
|
||||
default:
|
||||
log.Println("usage channel is full, skip processing")
|
||||
bj, _ := json.Marshal(usage)
|
||||
log.Println(string(bj))
|
||||
//TODO: send to a queue
|
||||
}
|
||||
}
|
||||
|
||||
func (p *Proxy) ProcessUsage() {
|
||||
for i := 0; i < p.cfg.UsageWorker; i++ {
|
||||
p.wg.Add(1)
|
||||
go func(i int) {
|
||||
defer p.wg.Done()
|
||||
for {
|
||||
select {
|
||||
case usage, ok := <-p.usageChan:
|
||||
if !ok {
|
||||
// channel 关闭,退出程序
|
||||
return
|
||||
}
|
||||
err := p.Do(usage)
|
||||
if err != nil {
|
||||
log.Printf("process usage error: %v\n", err)
|
||||
}
|
||||
case <-p.ctx.Done():
|
||||
// close(s.usageChan)
|
||||
// for usage := range s.usageChan {
|
||||
// if err := s.Do(usage); err != nil {
|
||||
// fmt.Printf("[close event]process usage error: %v\n", err)
|
||||
// }
|
||||
// }
|
||||
for {
|
||||
select {
|
||||
case usage, ok := <-p.usageChan:
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
if err := p.Do(usage); err != nil {
|
||||
fmt.Printf("[close event]process usage error: %v\n", err)
|
||||
}
|
||||
default:
|
||||
fmt.Printf("usageChan is empty,usage worker %d done\n", i)
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
}(i)
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
func (p *Proxy) Do(usage *model.Usage) error {
|
||||
err := p.db.Transaction(func(tx *gorm.DB) error {
|
||||
// 1. 记录使用记录
|
||||
if err := tx.WithContext(p.ctx).Create(usage).Error; err != nil {
|
||||
return fmt.Errorf("create usage error: %w", err)
|
||||
}
|
||||
|
||||
// 2. 更新每日统计(upsert 操作)
|
||||
dailyUsage := model.DailyUsage{
|
||||
UserID: usage.UserID,
|
||||
TokenID: usage.TokenID,
|
||||
Capability: usage.Capability,
|
||||
Date: time.Date(usage.Date.Year(), usage.Date.Month(), usage.Date.Day(), 0, 0, 0, 0, usage.Date.Location()),
|
||||
Model: usage.Model,
|
||||
Stream: usage.Stream,
|
||||
PromptTokens: usage.PromptTokens,
|
||||
CompletionTokens: usage.CompletionTokens,
|
||||
TotalTokens: usage.TotalTokens,
|
||||
Cost: usage.Cost,
|
||||
}
|
||||
|
||||
// 使用 OnConflict 实现 upsert
|
||||
if err := tx.WithContext(p.ctx).Clauses(clause.OnConflict{
|
||||
Columns: []clause.Column{{Name: "user_id"}, {Name: "token_id"}, {Name: "capability"}, {Name: "date"}}, // 唯一键
|
||||
DoUpdates: clause.Assignments(map[string]interface{}{
|
||||
"prompt_tokens": gorm.Expr("prompt_tokens + ?", usage.PromptTokens),
|
||||
"completion_tokens": gorm.Expr("completion_tokens + ?", usage.CompletionTokens),
|
||||
"total_tokens": gorm.Expr("total_tokens + ?", usage.TotalTokens),
|
||||
"cost": gorm.Expr("cost + ?", usage.Cost),
|
||||
}),
|
||||
}).Create(&dailyUsage).Error; err != nil {
|
||||
return fmt.Errorf("upsert daily usage error: %w", err)
|
||||
}
|
||||
|
||||
// 3. 更新用户额度
|
||||
if err := tx.WithContext(p.ctx).Model(&model.User{}).Where("id = ?", usage.UserID).Updates(map[string]interface{}{
|
||||
"quota": gorm.Expr("quota - ?", usage.Cost),
|
||||
"used_quota": gorm.Expr("used_quota + ?", usage.Cost),
|
||||
}).Error; err != nil {
|
||||
return fmt.Errorf("update user quota and used_quota error: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
return err
|
||||
}
|
||||
|
||||
func (p *Proxy) SelectApiKey(model string) error {
|
||||
akpikeys, err := p.apiKeyDao.FindApiKeysBySupportModel(p.db, model)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if len(akpikeys) == 0 {
|
||||
return errors.New("no available apikey")
|
||||
} else {
|
||||
if strings.HasPrefix(model, "gpt") {
|
||||
keys, err := p.apiKeyDao.FindKeys(map[string]any{"type = ?": "openai"})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
akpikeys = append(akpikeys, keys...)
|
||||
}
|
||||
|
||||
if strings.HasPrefix(model, "gemini") {
|
||||
keys, err := p.apiKeyDao.FindKeys(map[string]any{"type = ?": "gemini"})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
akpikeys = append(akpikeys, keys...)
|
||||
}
|
||||
|
||||
if strings.HasPrefix(model, "claude") {
|
||||
keys, err := p.apiKeyDao.FindKeys(map[string]any{"type = ?": "claude"})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
akpikeys = append(akpikeys, keys...)
|
||||
}
|
||||
}
|
||||
if len(akpikeys) == 0 {
|
||||
return errors.New("no available apikey")
|
||||
|
||||
}
|
||||
|
||||
if len(akpikeys) == 1 {
|
||||
p.apikey = &akpikeys[0]
|
||||
return nil
|
||||
}
|
||||
length := len(akpikeys) - 1
|
||||
|
||||
p.apikey = &akpikeys[rand.Intn(length)]
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *Proxy) updateSupportModel() {
|
||||
|
||||
keys, err := p.apiKeyDao.FindKeys(map[string]interface{}{"type in ?": "openai,azure,claude"})
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
for _, key := range keys {
|
||||
var supportModels []string
|
||||
if *key.ApiType == "openai" || *key.ApiType == "azure" {
|
||||
supportModels, err = p.getOpenAISupportModels(key)
|
||||
}
|
||||
if *key.ApiType == "claude" {
|
||||
supportModels, err = p.getClaudeSupportModels(key)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
log.Println(err)
|
||||
continue
|
||||
}
|
||||
if len(supportModels) == 0 {
|
||||
continue
|
||||
|
||||
}
|
||||
if p.cfg.DB_Type == "sqlite" {
|
||||
bytejson, _ := json.Marshal(supportModels)
|
||||
if err := p.db.Model(&model.ApiKey{}).Where("id = ?", key.ID).UpdateColumn("support_models", string(bytejson)).Error; err != nil {
|
||||
log.Println(err)
|
||||
}
|
||||
} else if p.cfg.DB_Type == "postgres" {
|
||||
if err := p.db.Model(&model.ApiKey{}).Where("id = ?", key.ID).UpdateColumn("support_models", pq.StringArray(supportModels)).Error; err != nil {
|
||||
log.Println(err)
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func (p *Proxy) ScheduleTask() {
|
||||
|
||||
func() {
|
||||
for {
|
||||
select {
|
||||
case <-time.After(time.Duration(p.cfg.TaskTimeInterval) * time.Minute):
|
||||
p.updateSupportModel()
|
||||
|
||||
case <-p.ctx.Done():
|
||||
fmt.Println("schedule task done")
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
func (p *Proxy) getOpenAISupportModels(apikey model.ApiKey) ([]string, error) {
|
||||
openaiModelsUrl := "https://api.openai.com/v1/models"
|
||||
// https://learn.microsoft.com/zh-cn/rest/api/azureopenai/models/list?view=rest-azureopenai-2025-02-01-preview&tabs=HTTP
|
||||
azureModelsUrl := "/openai/deployments?api-version=2022-12-01"
|
||||
|
||||
var supportModels []string
|
||||
var req *http.Request
|
||||
if *apikey.ApiType == "azure" {
|
||||
req, _ = http.NewRequest("GET", *apikey.Endpoint+azureModelsUrl, nil)
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("api-key", *apikey.ApiKey)
|
||||
} else {
|
||||
req, _ = http.NewRequest("GET", openaiModelsUrl, nil)
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Authorization", "Bearer "+*apikey.ApiKey)
|
||||
}
|
||||
|
||||
resp, err := p.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
if resp.StatusCode == http.StatusOK {
|
||||
bytesbody, _ := io.ReadAll(resp.Body)
|
||||
result := gjson.GetBytes(bytesbody, "data.#.id").Array()
|
||||
for _, v := range result {
|
||||
model := v.Str
|
||||
model = strings.Replace(model, "-35-", "-3.5-", -1)
|
||||
model = strings.Replace(model, "-41-", "-4.1-", -1)
|
||||
supportModels = append(supportModels, model)
|
||||
}
|
||||
}
|
||||
return supportModels, nil
|
||||
}
|
||||
|
||||
func (p *Proxy) getClaudeSupportModels(apikey model.ApiKey) ([]string, error) {
|
||||
// https://docs.anthropic.com/en/api/models-list
|
||||
claudemodelsUrl := "https://api.anthropic.com/v1/models"
|
||||
var supportModels []string
|
||||
|
||||
req, _ := http.NewRequest("GET", claudemodelsUrl, nil)
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("x-api-key", *apikey.ApiKey)
|
||||
req.Header.Set("anthropic-version", "2023-06-01")
|
||||
|
||||
resp, err := p.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
if resp.StatusCode == http.StatusOK {
|
||||
bytesbody, _ := io.ReadAll(resp.Body)
|
||||
result := gjson.GetBytes(bytesbody, "data.#.id").Array()
|
||||
for _, v := range result {
|
||||
supportModels = append(supportModels, v.Str)
|
||||
}
|
||||
}
|
||||
return supportModels, nil
|
||||
}
|
||||
53
internal/controller/team/middleware.go
Normal file
53
internal/controller/team/middleware.go
Normal file
@@ -0,0 +1,53 @@
|
||||
package controller
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"opencatd-open/internal/consts"
|
||||
|
||||
"github.com/gin-contrib/cors"
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
func (h *Team) AuthMiddleware() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
if c.Request.URL.Path == "/1/users/init" {
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
authtoken := c.GetHeader("Authorization")
|
||||
if authtoken == "" || len(authtoken) <= 7 || authtoken[:7] != "Bearer " {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{"error": "Unauthorized"})
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
authtoken = authtoken[7:]
|
||||
token, err := h.tokenService.GetByKey(c, authtoken)
|
||||
if err != nil {
|
||||
fmt.Println(err)
|
||||
}
|
||||
if token.Name != "default" {
|
||||
c.JSON(http.StatusForbidden, gin.H{"error": "only default token can access"})
|
||||
c.Abort()
|
||||
}
|
||||
if token.User.Status != consts.StatusEnabled {
|
||||
c.JSON(http.StatusForbidden, gin.H{"error": "user is disabled"})
|
||||
c.Abort()
|
||||
}
|
||||
c.Set("local_user", true)
|
||||
c.Set("token", token)
|
||||
|
||||
// 可以在这里对 token 进行验证并检查权限
|
||||
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
func CORS() gin.HandlerFunc {
|
||||
config := cors.DefaultConfig()
|
||||
config.AllowAllOrigins = true
|
||||
config.AllowCredentials = true
|
||||
config.AllowMethods = []string{"GET", "POST", "PUT", "DELETE", "OPTIONS"}
|
||||
config.AllowHeaders = []string{"*"}
|
||||
return cors.New(config)
|
||||
}
|
||||
563
internal/controller/team/team.go
Normal file
563
internal/controller/team/team.go
Normal file
@@ -0,0 +1,563 @@
|
||||
package controller
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net/http"
|
||||
"slices"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"opencatd-open/internal/consts"
|
||||
dto "opencatd-open/internal/dto/team"
|
||||
"opencatd-open/internal/model"
|
||||
service "opencatd-open/internal/service/team"
|
||||
"opencatd-open/internal/utils"
|
||||
|
||||
"github.com/duke-git/lancet/v2/slice"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/google/uuid"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
type Team struct {
|
||||
db *gorm.DB
|
||||
userService service.UserService
|
||||
tokenService service.TokenService
|
||||
keyService service.ApiKeyService
|
||||
usageService service.UsageService
|
||||
}
|
||||
|
||||
func NewTeam(userService service.UserService, tokenService service.TokenService, keyService service.ApiKeyService, usageService service.UsageService) *Team {
|
||||
return &Team{
|
||||
userService: userService,
|
||||
tokenService: tokenService,
|
||||
keyService: keyService,
|
||||
usageService: usageService,
|
||||
}
|
||||
}
|
||||
|
||||
// initadmin
|
||||
func (h *Team) InitAdmin(c *gin.Context) {
|
||||
admin, err := h.userService.GetUser(c, 1)
|
||||
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
user := &model.User{
|
||||
Name: "root",
|
||||
Username: "root",
|
||||
Password: "openteam",
|
||||
Role: utils.ToPtr(consts.RoleRoot),
|
||||
Tokens: []model.Token{
|
||||
{
|
||||
Name: "default",
|
||||
Key: "sk-team-" + strings.ReplaceAll(uuid.New().String(), "-", ""),
|
||||
UnlimitedQuota: utils.ToPtr(true),
|
||||
},
|
||||
},
|
||||
}
|
||||
if err := h.userService.CreateUser(c, user); err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
var result = dto.UserInfo{
|
||||
ID: user.ID,
|
||||
Name: user.Username,
|
||||
Token: user.Tokens[0].Key,
|
||||
Status: utils.ToPtr(user.Status == consts.StatusEnabled),
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, result)
|
||||
return
|
||||
} else {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
}
|
||||
if admin != nil {
|
||||
c.JSON(http.StatusForbidden, gin.H{
|
||||
"error": "super user already exists, use cli to reset password",
|
||||
})
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
func (h *Team) Me(c *gin.Context) {
|
||||
token, exists := c.Get("token")
|
||||
if !exists {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": "token not found"})
|
||||
return
|
||||
}
|
||||
userToken := token.(*model.Token)
|
||||
|
||||
c.JSON(http.StatusOK, dto.UserInfo{
|
||||
ID: userToken.UserID,
|
||||
Name: userToken.User.Name,
|
||||
Token: userToken.Key,
|
||||
Status: utils.ToPtr(userToken.User.Status == consts.StatusEnabled),
|
||||
})
|
||||
|
||||
}
|
||||
|
||||
// CreateUser 创建用户
|
||||
func (h *Team) CreateUser(c *gin.Context) {
|
||||
var userReq dto.UserInfo
|
||||
if err := c.ShouldBindJSON(&userReq); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid input"})
|
||||
return
|
||||
}
|
||||
|
||||
token, exists := c.Get("token")
|
||||
if !exists {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": "Unauthorized"})
|
||||
return
|
||||
}
|
||||
userToken := token.(*model.Token)
|
||||
if *userToken.User.Role < consts.RoleAdmin { // 普通用户只能创建自己的token
|
||||
create := &model.Token{
|
||||
Name: userReq.Name,
|
||||
Key: "sk-team-" + strings.ReplaceAll(uuid.New().String(), "-", ""),
|
||||
}
|
||||
if userReq.Token != "" {
|
||||
_key := strings.ReplaceAll(userReq.Token, "-", "")
|
||||
create.Key = "sk-team-" + strings.ReplaceAll(_key, " ", "")
|
||||
}
|
||||
if err := h.tokenService.Create(c.Request.Context(), create); err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
} else {
|
||||
user := &model.User{
|
||||
Name: userReq.Name,
|
||||
Username: userReq.Name,
|
||||
Role: utils.ToPtr(consts.RoleUser),
|
||||
Tokens: []model.Token{
|
||||
{
|
||||
Name: "default",
|
||||
Key: "sk-team-" + strings.ReplaceAll(uuid.New().String(), "-", ""),
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
// 默认角色为普通用户
|
||||
if err := h.userService.CreateUser(c.Request.Context(), user); err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{"message": "ok"})
|
||||
}
|
||||
|
||||
// GetUser 获取用户信息
|
||||
func (h *Team) GetUser(c *gin.Context) {
|
||||
idStr := c.Param("id")
|
||||
id, err := strconv.ParseInt(idStr, 10, 64)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid user id"})
|
||||
return
|
||||
}
|
||||
|
||||
user, err := h.userService.GetUser(c.Request.Context(), id)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, user)
|
||||
}
|
||||
|
||||
// UpdateUser 更新用户信息
|
||||
func (h *Team) UpdateUser(c *gin.Context) {
|
||||
var user model.User
|
||||
if err := c.ShouldBindJSON(&user); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid input"})
|
||||
return
|
||||
}
|
||||
token, exists := c.Get("token")
|
||||
if !exists {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": "Unauthorized"})
|
||||
return
|
||||
}
|
||||
userToken := token.(*model.Token)
|
||||
|
||||
operatorID := userToken.UserID // 假设从上下文中获取操作者ID
|
||||
if err := h.userService.UpdateUser(c.Request.Context(), &user, operatorID); err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{"message": "ok"})
|
||||
}
|
||||
|
||||
// DeleteUser 删除用户
|
||||
func (h *Team) DeleteUser(c *gin.Context) {
|
||||
idStr := c.Param("id")
|
||||
id, err := strconv.ParseInt(idStr, 10, 64)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid user id"})
|
||||
return
|
||||
}
|
||||
|
||||
token, exists := c.Get("token")
|
||||
if !exists {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": "Unauthorized"})
|
||||
return
|
||||
}
|
||||
userToken := token.(*model.Token)
|
||||
|
||||
if *userToken.User.Role < consts.RoleAdmin { // 用户只能删除自己的token
|
||||
err := h.tokenService.Delete(c.Request.Context(), id)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
} else {
|
||||
if err := h.userService.DeleteUser(c, id, userToken.UserID); err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{"message": "ok"})
|
||||
}
|
||||
|
||||
func (h *Team) ListUsages(c *gin.Context) {
|
||||
fromStr := c.Query("from")
|
||||
toStr := c.Query("to")
|
||||
|
||||
var from, to time.Time
|
||||
loc, _ := time.LoadLocation("Local")
|
||||
|
||||
var listUsage []*dto.UsageInfo
|
||||
var err error
|
||||
|
||||
if fromStr != "" && toStr != "" {
|
||||
|
||||
from, err = time.Parse("2006-01-02", fromStr)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid from date"})
|
||||
return
|
||||
}
|
||||
to, err = time.Parse("2006-01-02", toStr)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid to date"})
|
||||
return
|
||||
}
|
||||
} else {
|
||||
year, month, _ := time.Now().In(loc).Date()
|
||||
from = time.Date(year, month, 1, 0, 0, 0, 0, loc)
|
||||
to = from.AddDate(0, 1, 0)
|
||||
}
|
||||
|
||||
token, _ := c.Get("token")
|
||||
userToken := token.(*model.Token)
|
||||
if *userToken.User.Role < consts.RoleAdmin {
|
||||
listUsage, err = h.usageService.ListByDateRange(c.Request.Context(), from, to, map[string]interface{}{"user_id": userToken.UserID})
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
} else {
|
||||
listUsage, err = h.usageService.ListByDateRange(c.Request.Context(), from, to, nil)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, listUsage)
|
||||
|
||||
}
|
||||
|
||||
// ListUsers 获取用户列表
|
||||
func (h *Team) ListUsers(c *gin.Context) {
|
||||
limit, _ := strconv.Atoi(c.DefaultQuery("limit", "100"))
|
||||
offset, _ := strconv.Atoi(c.DefaultQuery("offset", "0"))
|
||||
active := c.DefaultQuery("active", "")
|
||||
|
||||
if !slices.Contains([]string{"true", "false", ""}, active) {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid active value"})
|
||||
return
|
||||
}
|
||||
|
||||
token, exists := c.Get("token")
|
||||
if !exists {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": "Unauthorized"})
|
||||
return
|
||||
}
|
||||
userToken := token.(*model.Token)
|
||||
if *userToken.User.Role < consts.RoleAdmin { // 用户只能获取自己的token
|
||||
tokens, _, err := h.tokenService.Lists(c, limit, offset)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
var userDTOs []dto.UserInfo
|
||||
for _, token := range tokens {
|
||||
userDTOs = append(userDTOs, dto.UserInfo{
|
||||
ID: token.User.ID,
|
||||
Name: token.User.Name,
|
||||
Token: token.Key,
|
||||
Status: utils.ToPtr(token.User.Status == consts.StatusEnabled),
|
||||
})
|
||||
}
|
||||
c.JSON(http.StatusOK, userDTOs)
|
||||
return
|
||||
}
|
||||
|
||||
users, err := h.userService.ListUsers(c, limit, offset, active)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
var userDTOs []dto.UserInfo
|
||||
for _, user := range users {
|
||||
useres := dto.UserInfo{
|
||||
ID: user.ID,
|
||||
Name: user.Name,
|
||||
|
||||
Status: utils.ToPtr(user.Status == consts.StatusEnabled),
|
||||
}
|
||||
if len(user.Tokens) > 0 {
|
||||
useres.Token = user.Tokens[0].Key
|
||||
}
|
||||
userDTOs = append(userDTOs, useres)
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, userDTOs)
|
||||
}
|
||||
|
||||
func (h *Team) ResetUserToken(c *gin.Context) {
|
||||
id, err := strconv.ParseInt(c.Param("id"), 10, 64)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid user id"})
|
||||
return
|
||||
}
|
||||
token, exists := c.Get("token")
|
||||
if !exists {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": "Unauthorized"})
|
||||
return
|
||||
}
|
||||
userToken := token.(*model.Token)
|
||||
|
||||
findtoken, err := h.tokenService.GetByUserID(c, id)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
findtoken.Key = "sk-team-" + strings.ReplaceAll(uuid.New().String(), "-", "")
|
||||
|
||||
if *userToken.User.Role < consts.RoleAdmin { // 非管理员只能修改自己的token
|
||||
if *userToken.User.Role <= *findtoken.User.Role || userToken.UserID != findtoken.UserID {
|
||||
c.JSON(http.StatusForbidden, gin.H{"error": "forbidden"})
|
||||
return
|
||||
}
|
||||
err := h.tokenService.UpdateWithCondition(c, findtoken, map[string]interface{}{"user_id": userToken.UserID}, nil)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
} else {
|
||||
if err := h.tokenService.Update(c, findtoken); err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, dto.UserInfo{
|
||||
ID: findtoken.User.ID,
|
||||
Name: findtoken.User.Name,
|
||||
Token: findtoken.Key,
|
||||
})
|
||||
}
|
||||
|
||||
func (h *Team) CreateKey(c *gin.Context) {
|
||||
token, exists := c.Get("token")
|
||||
if !exists {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": "token not found"})
|
||||
return
|
||||
}
|
||||
userToken := token.(*model.Token)
|
||||
if *userToken.User.Role < consts.RoleAdmin {
|
||||
c.JSON(http.StatusForbidden, gin.H{"error": "forbidden"})
|
||||
return
|
||||
}
|
||||
|
||||
var key dto.ApiKeyInfo
|
||||
if err := c.ShouldBindJSON(&key); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
err := h.keyService.Create(&model.ApiKey{
|
||||
Name: utils.ToPtr(key.Name),
|
||||
ApiType: utils.ToPtr(key.ApiType),
|
||||
ApiKey: utils.ToPtr(key.Key),
|
||||
Endpoint: utils.ToPtr(key.Endpoint),
|
||||
})
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, key)
|
||||
}
|
||||
|
||||
func (h *Team) ListKeys(c *gin.Context) {
|
||||
limit, _ := strconv.Atoi(c.DefaultQuery("limit", "20"))
|
||||
offset, _ := strconv.Atoi(c.DefaultQuery("offset", "0"))
|
||||
active := c.Query("active")
|
||||
if !slice.Contain([]string{"true", "false", ""}, active) {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid active value"})
|
||||
return
|
||||
}
|
||||
|
||||
keys, err := h.keyService.List(limit, offset, active)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
var keysDTO []dto.ApiKeyInfo
|
||||
for _, key := range keys {
|
||||
keylength := len(*key.ApiKey) / 3
|
||||
if keylength < 1 {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid key length"})
|
||||
return
|
||||
}
|
||||
keysDTO = append(keysDTO, dto.ApiKeyInfo{
|
||||
ID: int(key.ID),
|
||||
Name: *key.Name,
|
||||
ApiType: *key.ApiType,
|
||||
Endpoint: *key.Endpoint,
|
||||
Key: *key.ApiKey,
|
||||
})
|
||||
}
|
||||
c.JSON(http.StatusOK, keysDTO)
|
||||
}
|
||||
|
||||
func (h *Team) UpdateKey(c *gin.Context) {
|
||||
// 1. 获取并验证ID
|
||||
id, err := strconv.ParseInt(c.Param("id"), 10, 64)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid key id"})
|
||||
return
|
||||
}
|
||||
|
||||
// 2. 解析请求体
|
||||
var updateKey dto.ApiKeyInfo // 更明确的命名
|
||||
if err := c.ShouldBindJSON(&updateKey); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
// 3. 获取现有记录
|
||||
existingKey, err := h.keyService.GetByID(id)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
// 4. 使用 UpdateFields 方法统一处理字段更新
|
||||
updatedKey := updateKey.UpdateFields(existingKey)
|
||||
|
||||
// 5. 保存更新
|
||||
if err := h.keyService.Update(updatedKey); err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, updatedKey)
|
||||
}
|
||||
|
||||
func (h *Team) DeleteKey(c *gin.Context) {
|
||||
// 1. 获取并验证ID
|
||||
id, err := strconv.ParseInt(c.Param("id"), 10, 64)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid key id"})
|
||||
return
|
||||
}
|
||||
|
||||
// 2. 删除记录
|
||||
if err := h.keyService.Delete(id); err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{"message": "ok"})
|
||||
}
|
||||
|
||||
// ChangePassword 修改密码
|
||||
func (h *Team) ChangePassword(c *gin.Context) {
|
||||
userID := c.GetInt64("userID") // 假设从上下文中获取用户ID
|
||||
|
||||
var req struct {
|
||||
OldPassword string `json:"oldPassword"`
|
||||
NewPassword string `json:"newPassword"`
|
||||
}
|
||||
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid input"})
|
||||
return
|
||||
}
|
||||
|
||||
if err := h.userService.ChangePassword(c.Request.Context(), userID, req.OldPassword, req.NewPassword); err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{"message": "ok"})
|
||||
}
|
||||
|
||||
// ResetPassword 重置密码
|
||||
func (h *Team) ResetPassword(c *gin.Context) {
|
||||
id, err := strconv.ParseInt(c.Param("id"), 10, 64)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid user id"})
|
||||
return
|
||||
}
|
||||
|
||||
operatorID := int64(c.GetInt("userID")) // 假设从上下文中获取操作者ID
|
||||
if err := h.userService.ResetPassword(c.Request.Context(), id, operatorID); err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{"message": "password reset successfully"})
|
||||
}
|
||||
|
||||
// EnableUser 启用用户
|
||||
func (h *Team) EnableUser(c *gin.Context) {
|
||||
id, err := strconv.ParseInt(c.Param("id"), 10, 64)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid user id"})
|
||||
return
|
||||
}
|
||||
|
||||
operatorID := int64(c.GetInt("userID")) // 假设从上下文中获取操作者ID
|
||||
|
||||
if err := h.userService.BatchEnableUsers(c, []int64{id}, operatorID); err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{"message": "user enabled successfully"})
|
||||
}
|
||||
|
||||
// DisableUser 禁用用户
|
||||
func (h *Team) DisableUser(c *gin.Context) {
|
||||
id, err := strconv.ParseInt(c.Param("id"), 10, 64)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid user id"})
|
||||
return
|
||||
}
|
||||
|
||||
operatorID := int64(c.GetInt("userID")) // 假设从上下文中获取操作者ID
|
||||
if err := h.userService.BatchDisableUsers(c.Request.Context(), []int64{id}, operatorID); err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{"message": "user disabled successfully"})
|
||||
}
|
||||
218
internal/controller/user.go
Normal file
218
internal/controller/user.go
Normal file
@@ -0,0 +1,218 @@
|
||||
package controller
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"opencatd-open/internal/dto"
|
||||
"opencatd-open/internal/model"
|
||||
"opencatd-open/internal/utils"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/duke-git/lancet/v2/slice"
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
func (a Api) Register(c *gin.Context) {
|
||||
req := new(dto.User)
|
||||
err := c.ShouldBind(&req)
|
||||
if err != nil {
|
||||
dto.Fail(c, 400, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
err = a.userService.Register(c, &model.User{
|
||||
Username: req.Username,
|
||||
Password: req.Password,
|
||||
})
|
||||
if err != nil {
|
||||
dto.Fail(c, 400, err.Error())
|
||||
return
|
||||
} else {
|
||||
dto.Success(c, nil)
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func (a Api) Login(c *gin.Context) {
|
||||
req := new(dto.User)
|
||||
err := c.ShouldBind(&req)
|
||||
if err != nil {
|
||||
dto.Fail(c, 400, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
auth, err := a.userService.Login(c, req)
|
||||
if err != nil {
|
||||
dto.Fail(c, 400, err.Error())
|
||||
return
|
||||
} else {
|
||||
dto.Success(c, auth)
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func (a Api) Profile(c *gin.Context) {
|
||||
user, err := a.userService.Profile(c)
|
||||
if err != nil {
|
||||
dto.Fail(c, http.StatusUnauthorized, err.Error())
|
||||
return
|
||||
} else {
|
||||
dto.Success(c, user)
|
||||
}
|
||||
}
|
||||
|
||||
func (a Api) UpdateProfile(c *gin.Context) {
|
||||
var user = model.User{}
|
||||
err := c.ShouldBind(&user)
|
||||
if err != nil {
|
||||
dto.Fail(c, 400, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
err = a.userService.Update(c, &model.User{Name: user.Name, Username: user.Username, Email: user.Email})
|
||||
if err != nil {
|
||||
dto.Fail(c, http.StatusInternalServerError, err.Error())
|
||||
return
|
||||
}
|
||||
dto.Success(c, nil)
|
||||
}
|
||||
|
||||
func (a Api) UpdatePassword(c *gin.Context) {
|
||||
var passwd dto.ChangePassword
|
||||
err := c.ShouldBind(&passwd)
|
||||
if err != nil {
|
||||
dto.Fail(c, 400, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
_user := c.MustGet("user").(*model.User)
|
||||
if _user.Password == "" {
|
||||
hashpass, err := utils.HashPassword(passwd.NewPassword)
|
||||
if err != nil {
|
||||
dto.Fail(c, http.StatusInternalServerError, err.Error())
|
||||
return
|
||||
}
|
||||
_user.Password = hashpass
|
||||
} else {
|
||||
if !utils.CheckPassword(_user.Password, passwd.Password) {
|
||||
dto.Fail(c, http.StatusBadRequest, "password not match")
|
||||
return
|
||||
}
|
||||
hashpass, err := utils.HashPassword(passwd.NewPassword)
|
||||
if err != nil {
|
||||
dto.Fail(c, http.StatusInternalServerError, err.Error())
|
||||
return
|
||||
}
|
||||
_user.Password = hashpass
|
||||
}
|
||||
err = a.userService.Update(c, _user)
|
||||
if err != nil {
|
||||
dto.Fail(c, http.StatusInternalServerError, err.Error())
|
||||
return
|
||||
}
|
||||
dto.Success(c, nil)
|
||||
}
|
||||
|
||||
func (a Api) ListUser(c *gin.Context) {
|
||||
limit, _ := strconv.Atoi(c.DefaultQuery("pageSize", "20"))
|
||||
page, _ := strconv.Atoi(c.DefaultQuery("page", "1"))
|
||||
offset := (page - 1) * limit
|
||||
active := c.QueryArray("active[]")
|
||||
if !slice.ContainSubSlice([]string{"true", "false", ""}, active) {
|
||||
dto.Fail(c, http.StatusBadRequest, "active must be true or false")
|
||||
return
|
||||
}
|
||||
|
||||
users, total, err := a.userService.List(c, limit, offset, active)
|
||||
if err != nil {
|
||||
dto.Fail(c, http.StatusInternalServerError, err.Error())
|
||||
return
|
||||
}
|
||||
dto.Success(c, gin.H{
|
||||
"users": users,
|
||||
"total": total,
|
||||
})
|
||||
}
|
||||
|
||||
func (a Api) CreateUser(c *gin.Context) {
|
||||
var user model.User
|
||||
err := c.ShouldBind(&user)
|
||||
if err != nil {
|
||||
dto.Fail(c, 400, err.Error())
|
||||
return
|
||||
}
|
||||
fmt.Printf("user:%+v\n", user)
|
||||
err = a.userService.Create(c, &user)
|
||||
if err != nil {
|
||||
dto.Fail(c, http.StatusInternalServerError, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
dto.Success(c, nil)
|
||||
}
|
||||
|
||||
func (a Api) GetUser(c *gin.Context) {
|
||||
id, _ := strconv.ParseInt(c.Param("id"), 10, 64)
|
||||
user, err := a.userService.GetByID(c, id)
|
||||
if err != nil {
|
||||
dto.Fail(c, 500, err.Error())
|
||||
return
|
||||
}
|
||||
dto.Success(c, user)
|
||||
}
|
||||
|
||||
func (a Api) EditUser(c *gin.Context) {
|
||||
id, _ := strconv.ParseInt(c.Param("id"), 10, 64)
|
||||
var user model.User
|
||||
err := c.ShouldBind(&user)
|
||||
if err != nil {
|
||||
dto.Fail(c, 400, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
user.ID = int64(id)
|
||||
err = a.userService.Update(c, &user)
|
||||
if err != nil {
|
||||
dto.Fail(c, 500, err.Error())
|
||||
return
|
||||
}
|
||||
dto.Success(c, nil)
|
||||
}
|
||||
|
||||
func (a Api) DeleteUser(c *gin.Context) {
|
||||
id, _ := strconv.ParseInt(c.Param("id"), 10, 64)
|
||||
err := a.userService.Delete(c, id)
|
||||
if err != nil {
|
||||
dto.Fail(c, 400, err.Error())
|
||||
return
|
||||
}
|
||||
dto.Success(c, nil)
|
||||
}
|
||||
|
||||
func (a Api) UserOption(c *gin.Context) {
|
||||
option := strings.ToLower(c.Param("option"))
|
||||
var batchid dto.BatchIDRequest
|
||||
err := c.ShouldBind(&batchid)
|
||||
if err != nil {
|
||||
dto.Fail(c, 400, err.Error())
|
||||
return
|
||||
}
|
||||
switch option {
|
||||
case "enable":
|
||||
err = a.userService.BatchEnable(c, batchid.IDs)
|
||||
case "disable":
|
||||
err = a.userService.BatchDisable(c, batchid.IDs)
|
||||
case "delete":
|
||||
err = a.userService.BatchDelete(c, batchid.IDs)
|
||||
default:
|
||||
dto.Fail(c, 400, "invalid option, only support enable, disable, delete")
|
||||
return
|
||||
}
|
||||
if err != nil {
|
||||
dto.Fail(c, 400, err.Error())
|
||||
return
|
||||
}
|
||||
dto.Success(c, nil)
|
||||
|
||||
}
|
||||
218
internal/controller/user_token.go
Normal file
218
internal/controller/user_token.go
Normal file
@@ -0,0 +1,218 @@
|
||||
package controller
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"opencatd-open/internal/dto"
|
||||
"opencatd-open/internal/model"
|
||||
"opencatd-open/internal/utils"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/duke-git/lancet/v2/slice"
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
func (a Api) CreateToken(c *gin.Context) {
|
||||
userid := c.GetInt64("user_id")
|
||||
user, err := a.userService.GetByID(c, userid)
|
||||
if err != nil {
|
||||
dto.Fail(c, http.StatusInternalServerError, err.Error())
|
||||
return
|
||||
}
|
||||
if len(user.Tokens) >= 20 {
|
||||
dto.Fail(c, http.StatusForbidden, "user has reached the maximum number of tokens")
|
||||
return
|
||||
}
|
||||
|
||||
var token model.Token
|
||||
err = c.ShouldBindJSON(&token)
|
||||
if err != nil {
|
||||
c.JSON(400, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
token.UserID = userid
|
||||
|
||||
err = a.tokenService.CreateToken(c, &token)
|
||||
if err != nil {
|
||||
dto.Fail(c, http.StatusInternalServerError, err.Error())
|
||||
return
|
||||
}
|
||||
dto.Success(c, nil)
|
||||
}
|
||||
|
||||
func (a Api) ListToken(c *gin.Context) {
|
||||
limit, _ := strconv.Atoi(c.DefaultQuery("pageSize", "20"))
|
||||
page, _ := strconv.Atoi(c.DefaultQuery("page", "1"))
|
||||
offset := (page - 1) * limit
|
||||
active := c.QueryArray("active[]")
|
||||
if !slice.ContainSubSlice([]string{"true", "false"}, active) {
|
||||
dto.Fail(c, http.StatusBadRequest, "active must be true or false")
|
||||
}
|
||||
|
||||
tokens, total, err := a.tokenService.ListToken(c, limit, offset, active)
|
||||
if err != nil {
|
||||
dto.Fail(c, http.StatusInternalServerError, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
dto.Success(c, gin.H{
|
||||
"total": total,
|
||||
"tokens": tokens,
|
||||
})
|
||||
|
||||
}
|
||||
|
||||
func (a Api) GetToken(c *gin.Context) {
|
||||
id, _ := strconv.ParseInt(c.Param("id"), 10, 64)
|
||||
|
||||
token, err := a.tokenService.GetToken(c, id)
|
||||
if err != nil {
|
||||
dto.Fail(c, http.StatusInternalServerError, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
dto.Success(c, token)
|
||||
}
|
||||
|
||||
func (a Api) ResetToken(c *gin.Context) {
|
||||
id, err := strconv.ParseInt(c.Param("id"), 10, 64)
|
||||
if err != nil {
|
||||
dto.Fail(c, http.StatusBadRequest, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
token, err := a.tokenService.GetToken(c, id)
|
||||
if err != nil {
|
||||
dto.Fail(c, http.StatusInternalServerError, err.Error())
|
||||
return
|
||||
}
|
||||
if token == nil {
|
||||
dto.Fail(c, http.StatusNotFound, "token not found")
|
||||
return
|
||||
}
|
||||
token.UsedQuota = utils.ToPtr(int64(0))
|
||||
|
||||
err = a.tokenService.UpdateToken(c, token)
|
||||
if err != nil {
|
||||
dto.Fail(c, http.StatusInternalServerError, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
dto.Success(c, nil)
|
||||
}
|
||||
|
||||
func (a Api) UpdateToken(c *gin.Context) {
|
||||
id, err := strconv.ParseInt(c.Param("id"), 10, 64)
|
||||
if err != nil {
|
||||
dto.Fail(c, http.StatusBadRequest, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
var token model.Token
|
||||
err = c.ShouldBindJSON(&token)
|
||||
if err != nil {
|
||||
c.JSON(400, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
token.ID = id
|
||||
if token.UserID == 0 {
|
||||
dto.Fail(c, http.StatusBadRequest, "user_id is required")
|
||||
return
|
||||
}
|
||||
|
||||
var _token *model.Token
|
||||
|
||||
user, err := a.userService.GetByID(c, token.UserID)
|
||||
if err != nil {
|
||||
dto.Fail(c, http.StatusInternalServerError, err.Error())
|
||||
return
|
||||
}
|
||||
if len(user.Tokens) == 0 {
|
||||
dto.Fail(c, http.StatusForbidden, "user has no tokens")
|
||||
return
|
||||
} else {
|
||||
if findtoken, ok := slice.Find(user.Tokens,
|
||||
func(idx int, t model.Token) bool {
|
||||
return t.ID == id
|
||||
}); ok {
|
||||
_token = findtoken
|
||||
_token.User = user
|
||||
} else {
|
||||
dto.Fail(c, http.StatusForbidden, "user has no tokens")
|
||||
return
|
||||
}
|
||||
}
|
||||
// 更新_token信息
|
||||
if token.Name != "" {
|
||||
_token.Name = token.Name
|
||||
}
|
||||
if token.Key != "" {
|
||||
_token.Key = token.Key
|
||||
}
|
||||
if token.Active != nil {
|
||||
_token.Active = token.Active
|
||||
}
|
||||
if token.Quota != nil {
|
||||
_token.Quota = token.Quota
|
||||
}
|
||||
if token.UnlimitedQuota != nil {
|
||||
_token.UnlimitedQuota = token.UnlimitedQuota
|
||||
}
|
||||
if token.ExpiredAt != nil {
|
||||
_token.ExpiredAt = token.ExpiredAt
|
||||
}
|
||||
|
||||
err = a.tokenService.UpdateToken(c, _token)
|
||||
if err != nil {
|
||||
dto.Fail(c, http.StatusInternalServerError, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
dto.Success(c, nil)
|
||||
}
|
||||
|
||||
func (a Api) DeleteToken(c *gin.Context) {
|
||||
id, err := strconv.ParseInt(c.Param("id"), 10, 64)
|
||||
if err != nil {
|
||||
dto.Fail(c, http.StatusBadRequest, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
err = a.tokenService.DeleteToken(c, id)
|
||||
if err != nil {
|
||||
dto.Fail(c, http.StatusInternalServerError, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
dto.Success(c, nil)
|
||||
}
|
||||
|
||||
func (a Api) TokenOption(c *gin.Context) {
|
||||
option := strings.ToLower(c.Param("option"))
|
||||
var batchid dto.BatchIDRequest
|
||||
err := c.ShouldBind(&batchid)
|
||||
if err != nil {
|
||||
dto.Fail(c, 400, err.Error())
|
||||
return
|
||||
}
|
||||
if batchid.UserID == nil {
|
||||
dto.Fail(c, 400, "user_id is required")
|
||||
return
|
||||
}
|
||||
switch option {
|
||||
case "enable":
|
||||
err = a.tokenService.EnableTokens(c, *batchid.UserID, batchid.IDs)
|
||||
case "disable":
|
||||
err = a.tokenService.DisableTokens(c, *batchid.UserID, batchid.IDs)
|
||||
case "delete":
|
||||
err = a.tokenService.DeleteTokens(c, *batchid.UserID, batchid.IDs)
|
||||
default:
|
||||
dto.Fail(c, 400, "invalid option, only support enable, disable, delete")
|
||||
return
|
||||
}
|
||||
if err != nil {
|
||||
dto.Fail(c, 400, err.Error())
|
||||
return
|
||||
}
|
||||
dto.Success(c, nil)
|
||||
}
|
||||
108
internal/controller/webauth.go
Normal file
108
internal/controller/webauth.go
Normal file
@@ -0,0 +1,108 @@
|
||||
package controller
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"opencatd-open/internal/auth"
|
||||
"opencatd-open/internal/consts"
|
||||
"opencatd-open/internal/dto"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
func (a *Api) PasskeyCreateBegin(c *gin.Context) {
|
||||
userid := c.GetInt64("user_id")
|
||||
cred, err := a.webAuthService.BeginRegistration(userid)
|
||||
if err != nil {
|
||||
dto.Fail(c, 500, err.Error())
|
||||
return
|
||||
}
|
||||
dto.Success(c, cred)
|
||||
}
|
||||
|
||||
func (a *Api) PasskeyCreateFinish(c *gin.Context) {
|
||||
userid := c.GetInt64("user_id")
|
||||
name := c.Query("name")
|
||||
if name == "" {
|
||||
name = fmt.Sprintf("User-%d-%d", userid, time.Now().Unix())
|
||||
}
|
||||
// var body protocol.CredentialCreationResponse
|
||||
// if err := c.ShouldBindJSON(&body); err != nil {
|
||||
// dto.Fail(c, 400, err.Error())
|
||||
// return
|
||||
// }
|
||||
|
||||
// 获取用户凭证
|
||||
cred, err := a.webAuthService.FinishRegistration(userid, c.Request, name)
|
||||
if err != nil {
|
||||
dto.Fail(c, 500, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
dto.Success(c, cred)
|
||||
}
|
||||
|
||||
func (a *Api) ListPasskey(c *gin.Context) {
|
||||
passkeys, err := a.webAuthService.ListPasskeys(c.GetInt64("user_id"))
|
||||
if err != nil {
|
||||
dto.Fail(c, 500, err.Error())
|
||||
return
|
||||
}
|
||||
var passkeysDto []dto.Passkey
|
||||
for _, passkey := range passkeys {
|
||||
passkeysDto = append(passkeysDto, dto.Passkey{
|
||||
ID: passkey.ID,
|
||||
Name: passkey.Name,
|
||||
DeviceType: passkey.DeviceType,
|
||||
SignCount: passkey.SignCount,
|
||||
LastUsedAt: passkey.LastUsedAt,
|
||||
CreatedAt: passkey.CreatedAt,
|
||||
UpdatedAt: passkey.UpdatedAt,
|
||||
})
|
||||
}
|
||||
|
||||
dto.Success(c, passkeysDto)
|
||||
}
|
||||
|
||||
func (a *Api) DeletePasskey(c *gin.Context) {
|
||||
id, err := strconv.ParseInt(c.Param("id"), 10, 64)
|
||||
if err != nil {
|
||||
dto.Fail(c, 400, err.Error())
|
||||
return
|
||||
}
|
||||
if err = a.webAuthService.DeletePasskey(c.GetInt64("user_id"), id); err != nil {
|
||||
dto.Fail(c, 500, err.Error())
|
||||
return
|
||||
}
|
||||
dto.Success(c, "删除成功")
|
||||
}
|
||||
|
||||
// 登陆
|
||||
func (a *Api) PasskeyAuthBegin(c *gin.Context) {
|
||||
|
||||
cred, err := a.webAuthService.BeginLogin()
|
||||
if err != nil {
|
||||
dto.Fail(c, 500, err.Error())
|
||||
return
|
||||
}
|
||||
dto.Success(c, cred)
|
||||
}
|
||||
|
||||
func (a *Api) PasskeyAuthFinish(c *gin.Context) {
|
||||
challenge := c.Query("challenge")
|
||||
webAuthUser, err := a.webAuthService.FinishLogin(challenge, c.Request)
|
||||
if err != nil {
|
||||
dto.Fail(c, 500, err.Error())
|
||||
return
|
||||
}
|
||||
at, err := auth.GenerateTokenPair(webAuthUser.User, consts.SecretKey, consts.Day*time.Second, consts.Day*time.Second)
|
||||
if err != nil {
|
||||
dto.Fail(c, 500, err.Error())
|
||||
return
|
||||
}
|
||||
dto.Success(c, dto.Auth{
|
||||
Token: at.AccessToken,
|
||||
ExpiresIn: time.Now().Add(consts.Day * time.Second).Unix(),
|
||||
})
|
||||
}
|
||||
179
internal/dao/apikey.go
Normal file
179
internal/dao/apikey.go
Normal file
@@ -0,0 +1,179 @@
|
||||
package dao
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"opencatd-open/internal/model"
|
||||
"opencatd-open/pkg/config"
|
||||
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
var _ ApiKeyRepository = (*ApiKeyDAO)(nil)
|
||||
|
||||
type ApiKeyRepository interface {
|
||||
Create(apiKey *model.ApiKey) error
|
||||
GetByID(id int64) (*model.ApiKey, error)
|
||||
GetByName(name string) (*model.ApiKey, error)
|
||||
GetByApiKey(apiKeyValue string) (*model.ApiKey, error)
|
||||
Update(apiKey *model.ApiKey) error
|
||||
List(limit, offset int, status string) ([]*model.ApiKey, error)
|
||||
ListWithFilters(limit, offset int, filters map[string]interface{}) ([]*model.ApiKey, int64, error)
|
||||
BatchEnable(ids []int64) error
|
||||
BatchDisable(ids []int64) error
|
||||
BatchDelete(ids []int64) error
|
||||
Count() (int64, error)
|
||||
}
|
||||
|
||||
type ApiKeyDAO struct {
|
||||
cfg *config.Config
|
||||
db *gorm.DB
|
||||
}
|
||||
|
||||
func NewApiKeyDAO(db *gorm.DB) *ApiKeyDAO {
|
||||
return &ApiKeyDAO{db: db}
|
||||
}
|
||||
|
||||
// CreateApiKey 创建ApiKey
|
||||
func (dao *ApiKeyDAO) Create(apiKey *model.ApiKey) error {
|
||||
if apiKey == nil {
|
||||
return errors.New("apiKey is nil")
|
||||
}
|
||||
return dao.db.Create(apiKey).Error
|
||||
}
|
||||
|
||||
// GetApiKeyByID 根据ID获取ApiKey
|
||||
func (dao *ApiKeyDAO) GetByID(id int64) (*model.ApiKey, error) {
|
||||
var apiKey model.ApiKey
|
||||
err := dao.db.First(&apiKey, id).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &apiKey, nil
|
||||
}
|
||||
|
||||
// GetApiKeyByName 根据名称获取ApiKey
|
||||
func (dao *ApiKeyDAO) GetByName(name string) (*model.ApiKey, error) {
|
||||
var apiKey model.ApiKey
|
||||
err := dao.db.Where("name = ?", name).First(&apiKey).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &apiKey, nil
|
||||
}
|
||||
|
||||
// GetApiKeyByApiKey 根据ApiKey值获取ApiKey
|
||||
func (dao *ApiKeyDAO) GetByApiKey(apiKeyValue string) (*model.ApiKey, error) {
|
||||
var apiKey model.ApiKey
|
||||
err := dao.db.Where("api_key = ?", apiKeyValue).First(&apiKey).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &apiKey, nil
|
||||
}
|
||||
|
||||
func (dao *ApiKeyDAO) FindKeys(condition map[string]any) ([]model.ApiKey, error) {
|
||||
var apiKeys []model.ApiKey
|
||||
|
||||
query := dao.db.Model(&model.ApiKey{})
|
||||
for k, v := range condition {
|
||||
query = query.Where(k, v)
|
||||
}
|
||||
err := query.Find(&apiKeys).Error
|
||||
|
||||
return apiKeys, err
|
||||
}
|
||||
|
||||
func (dao *ApiKeyDAO) FindApiKeysBySupportModel(db *gorm.DB, modelName string) ([]model.ApiKey, error) {
|
||||
var apiKeys []model.ApiKey
|
||||
switch dao.cfg.DB_Type {
|
||||
case "mysql":
|
||||
return nil, errors.New("not support")
|
||||
case "postgres":
|
||||
return nil, errors.New("not support")
|
||||
}
|
||||
err := db.Model(&model.ApiKey{}).
|
||||
Joins("CROSS JOIN JSON_EACH(apikeys.support_models)").
|
||||
Where("value = ?", modelName).
|
||||
Find(&apiKeys).Error
|
||||
return apiKeys, err
|
||||
}
|
||||
|
||||
// UpdateApiKey 更新ApiKey信息
|
||||
func (dao *ApiKeyDAO) Update(apiKey *model.ApiKey) error {
|
||||
if apiKey == nil {
|
||||
return errors.New("apiKey is nil")
|
||||
}
|
||||
// return dao.db.Model(&model.ApiKey{}).
|
||||
// Select("name", "apitype", "apikey", "status", "endpoint", "resource_name", "deployment_name").Updates(apiKey).Error
|
||||
return dao.db.Save(apiKey).Error
|
||||
}
|
||||
|
||||
// DeleteApiKey 删除ApiKey
|
||||
func (dao *ApiKeyDAO) Delete(id int64) error {
|
||||
return dao.db.Unscoped().Delete(&model.ApiKey{}, id).Error
|
||||
}
|
||||
|
||||
// ListApiKeys 获取ApiKey列表
|
||||
func (dao *ApiKeyDAO) List(limit, offset int, status string) ([]*model.ApiKey, error) {
|
||||
var apiKeys []*model.ApiKey
|
||||
db := dao.db.Limit(limit).Offset(offset)
|
||||
if status != "" {
|
||||
db = db.Where("status = ?", status)
|
||||
}
|
||||
err := db.Find(&apiKeys).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return apiKeys, nil
|
||||
}
|
||||
|
||||
// ListApiKeysWithFilters 根据条件获取ApiKey列表
|
||||
func (dao *ApiKeyDAO) ListWithFilters(limit, offset int, filters map[string]interface{}) ([]*model.ApiKey, int64, error) {
|
||||
var apiKeys []*model.ApiKey
|
||||
db := dao.db.Limit(limit).Offset(offset)
|
||||
for k, v := range filters {
|
||||
db = db.Where(k, v)
|
||||
}
|
||||
err := db.Find(&apiKeys).Error
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
var total int64
|
||||
db.Model(&model.ApiKey{}).Count(&total)
|
||||
|
||||
return apiKeys, total, nil
|
||||
}
|
||||
|
||||
// BatchEnableApiKeys 批量启用ApiKey
|
||||
func (dao *ApiKeyDAO) BatchEnable(ids []int64) error {
|
||||
if len(ids) == 0 {
|
||||
return errors.New("ids is empty")
|
||||
}
|
||||
return dao.db.Model(&model.ApiKey{}).Where("id IN ?", ids).Update("active", true).Error
|
||||
}
|
||||
|
||||
// BatchDisableApiKeys 批量禁用ApiKey
|
||||
func (dao *ApiKeyDAO) BatchDisable(ids []int64) error {
|
||||
if len(ids) == 0 {
|
||||
return errors.New("ids is empty")
|
||||
}
|
||||
return dao.db.Model(&model.ApiKey{}).Where("id IN ?", ids).Update("active", false).Error
|
||||
}
|
||||
|
||||
// BatchDeleteApiKey 批量删除ApiKey
|
||||
func (dao *ApiKeyDAO) BatchDelete(ids []int64) error {
|
||||
if len(ids) == 0 {
|
||||
return errors.New("ids is empty")
|
||||
}
|
||||
return dao.db.Unscoped().Delete(&model.ApiKey{}, ids).Error
|
||||
}
|
||||
|
||||
// CountApiKeys 获取ApiKey总数
|
||||
func (dao *ApiKeyDAO) Count() (int64, error) {
|
||||
var count int64
|
||||
err := dao.db.Model(&model.ApiKey{}).Count(&count).Error
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return count, nil
|
||||
}
|
||||
198
internal/dao/token.go
Normal file
198
internal/dao/token.go
Normal file
@@ -0,0 +1,198 @@
|
||||
package dao
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"opencatd-open/internal/consts"
|
||||
"opencatd-open/internal/model"
|
||||
"time"
|
||||
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// 确保 TokenDAO 实现了 TokenRepository 接口
|
||||
var _ TokenRepository = (*TokenDAO)(nil)
|
||||
|
||||
type TokenRepository interface {
|
||||
Create(ctx context.Context, token *model.Token) error
|
||||
GetByID(ctx context.Context, id int64) (*model.Token, error)
|
||||
GetByKey(ctx context.Context, key string) (*model.Token, error)
|
||||
GetByUserID(ctx context.Context, userID int64) (*model.Token, error)
|
||||
Update(ctx context.Context, token *model.Token) error
|
||||
UpdateWithCondition(ctx context.Context, token *model.Token, filters map[string]interface{}, updates map[string]interface{}) error
|
||||
Delete(ctx context.Context, id int64, condition map[string]interface{}) error
|
||||
List(ctx context.Context, limit, offset int) ([]*model.Token, error)
|
||||
ListWithFilters(ctx context.Context, limit, offset int, filters map[string]interface{}) ([]*model.Token, int64, error)
|
||||
Disable(ctx context.Context, id int) error
|
||||
Enable(ctx context.Context, id int) error
|
||||
BatchDisable(ctx context.Context, ids []int64, filters map[string]interface{}) error
|
||||
BatchEnable(ctx context.Context, ids []int64, filters map[string]interface{}) error
|
||||
BatchDelete(ctx context.Context, ids []int64, filters map[string]interface{}) error
|
||||
}
|
||||
|
||||
type TokenDAO struct {
|
||||
db *gorm.DB
|
||||
}
|
||||
|
||||
func NewTokenDAO(db *gorm.DB) *TokenDAO {
|
||||
return &TokenDAO{db: db}
|
||||
}
|
||||
|
||||
// CreateToken 创建 Token
|
||||
func (dao *TokenDAO) Create(ctx context.Context, token *model.Token) error {
|
||||
if token == nil {
|
||||
return errors.New("token is nil")
|
||||
}
|
||||
return dao.db.WithContext(ctx).Create(token).Error
|
||||
}
|
||||
|
||||
// 根据 ID 获取 Token
|
||||
func (dao *TokenDAO) GetByID(ctx context.Context, id int64) (*model.Token, error) {
|
||||
var token model.Token
|
||||
err := dao.db.WithContext(ctx).Preload("User").First(&token, id).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &token, nil
|
||||
}
|
||||
|
||||
// 根据 Key 获取 Token
|
||||
func (dao *TokenDAO) GetByKey(ctx context.Context, key string) (*model.Token, error) {
|
||||
var token model.Token
|
||||
// err := dao.db.Where("key = ?", key).First(&token).Error
|
||||
err := dao.db.WithContext(ctx).Preload("User").Where("key = ?", key).First(&token).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &token, nil
|
||||
}
|
||||
|
||||
// 根据 UserID 获取 Token
|
||||
func (dao *TokenDAO) GetByUserID(ctx context.Context, userID int64) (*model.Token, error) {
|
||||
var token model.Token
|
||||
err := dao.db.WithContext(ctx).Preload("User").Where("user_id = ?", userID).Find(&token).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &token, nil
|
||||
}
|
||||
|
||||
// UpdateToken 更新 Token 信息
|
||||
func (dao *TokenDAO) Update(ctx context.Context, token *model.Token) error {
|
||||
if token == nil {
|
||||
return errors.New("token is nil")
|
||||
}
|
||||
return dao.db.WithContext(ctx).Save(token).Error
|
||||
}
|
||||
|
||||
// UpdateTokenWithFilters 更新 Token 信息,支持过滤
|
||||
func (dao *TokenDAO) UpdateWithCondition(ctx context.Context, token *model.Token, filters map[string]interface{}, updates map[string]interface{}) error {
|
||||
if token == nil {
|
||||
return errors.New("token is nil")
|
||||
}
|
||||
db := dao.db.WithContext(ctx)
|
||||
for key, value := range filters {
|
||||
db = db.Where(key+" = ?", value)
|
||||
}
|
||||
return db.Model(&model.Token{}).Updates(updates).Error
|
||||
}
|
||||
|
||||
// DeleteToken 删除 Token
|
||||
func (dao *TokenDAO) Delete(ctx context.Context, id int64, condition map[string]interface{}) error {
|
||||
if id <= 0 {
|
||||
return errors.New("id is invalid")
|
||||
}
|
||||
query := dao.db.WithContext(ctx).Where("id = ?", id)
|
||||
for key, value := range condition {
|
||||
query = query.Where(key, value)
|
||||
}
|
||||
return query.Unscoped().Delete(&model.Token{}).Error
|
||||
}
|
||||
|
||||
// ListTokens 获取 Token 列表
|
||||
func (dao *TokenDAO) List(ctx context.Context, limit, offset int) ([]*model.Token, error) {
|
||||
var tokens []*model.Token
|
||||
err := dao.db.WithContext(ctx).Limit(limit).Offset(offset).Find(&tokens).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return tokens, nil
|
||||
}
|
||||
|
||||
// ListTokensWithFilters 获取 Token 列表,支持过滤
|
||||
func (dao *TokenDAO) ListWithFilters(ctx context.Context, limit, offset int, filters map[string]interface{}) ([]*model.Token, int64, error) {
|
||||
var tokens []*model.Token
|
||||
var count int64
|
||||
|
||||
db := dao.db.WithContext(ctx)
|
||||
if filters != nil {
|
||||
for k, v := range filters {
|
||||
db = db.Where(k, v)
|
||||
}
|
||||
}
|
||||
|
||||
if err := db.Limit(limit).Offset(offset).Find(&tokens).Error; err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
db.Model(&model.Token{}).Count(&count)
|
||||
|
||||
return tokens, count, nil
|
||||
}
|
||||
|
||||
// DisableToken 禁用 Token
|
||||
func (dao *TokenDAO) Disable(ctx context.Context, id int) error {
|
||||
return dao.db.WithContext(ctx).Model(&model.Token{}).Where("id = ?", id).Update("status", false).Error
|
||||
}
|
||||
|
||||
// EnableToken 启用 Token
|
||||
func (dao *TokenDAO) Enable(ctx context.Context, id int) error {
|
||||
return dao.db.WithContext(ctx).Model(&model.Token{}).Where("id = ?", id).Update("status", true).Error
|
||||
}
|
||||
|
||||
// BatchDisableTokens 批量禁用 Token
|
||||
func (dao *TokenDAO) BatchDisable(ctx context.Context, ids []int64, filters map[string]interface{}) error {
|
||||
query := dao.db.WithContext(ctx).Model(&model.Token{}).Where("id IN ?", ids)
|
||||
for key, value := range filters {
|
||||
query = query.Where(key, value)
|
||||
}
|
||||
return query.Update("active", false).Error
|
||||
}
|
||||
|
||||
// BatchEnableTokens 批量启用 Token
|
||||
func (dao *TokenDAO) BatchEnable(ctx context.Context, ids []int64, filters map[string]interface{}) error {
|
||||
query := dao.db.WithContext(ctx).Model(&model.Token{}).Where("id IN ?", ids)
|
||||
for key, value := range filters {
|
||||
query = query.Where(key, value)
|
||||
}
|
||||
return query.Update("active", true).Error
|
||||
}
|
||||
|
||||
// BatchDeleteTokens 批量删除 Token
|
||||
func (dao *TokenDAO) BatchDelete(ctx context.Context, ids []int64, filters map[string]interface{}) error {
|
||||
query := dao.db.Unscoped().WithContext(ctx).Where("id IN ?", ids)
|
||||
for key, value := range filters {
|
||||
query = query.Where(key, value)
|
||||
}
|
||||
return query.Delete(&model.Token{}).Error
|
||||
// return dao.db.WithContext(ctx).Where("name != 'default' AND id IN ?", ids).Delete(&model.Token{}).Error
|
||||
}
|
||||
|
||||
// 检查 token 是否有效
|
||||
func (dao *TokenDAO) IsValid(ctx context.Context, key string) (bool, error) {
|
||||
var token model.Token
|
||||
err := dao.db.WithContext(ctx).Where("key = ? AND status = ? AND (expired_time = -1 OR expired_time > ?)",
|
||||
key, consts.StatusEnabled, time.Now().Unix()).First(&token).Error
|
||||
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return false, nil
|
||||
}
|
||||
return false, err
|
||||
}
|
||||
if token.User.Status != consts.StatusEnabled || (*token.User.UnlimitedQuota && *token.User.Quota <= 0) {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
return true, nil
|
||||
}
|
||||
295
internal/dao/usage.go
Normal file
295
internal/dao/usage.go
Normal file
@@ -0,0 +1,295 @@
|
||||
package dao
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
dto "opencatd-open/internal/dto/team"
|
||||
"opencatd-open/internal/model"
|
||||
"opencatd-open/pkg/config"
|
||||
"time"
|
||||
|
||||
"gorm.io/gorm"
|
||||
"gorm.io/gorm/clause"
|
||||
)
|
||||
|
||||
var _ UsageRepository = (*UsageDAO)(nil)
|
||||
var _ DailyUsageRepository = (*DailyUsageDAO)(nil)
|
||||
|
||||
type UsageRepository interface {
|
||||
// Create
|
||||
Create(ctx context.Context, usage *model.Usage) error
|
||||
BatchCreate(ctx context.Context, usages []*model.Usage) error
|
||||
|
||||
// Read
|
||||
ListByUserID(ctx context.Context, userID int64, limit, offset int) ([]*model.Usage, error)
|
||||
ListByTokenID(ctx context.Context, tokenID int64, limit, offset int) ([]*model.Usage, error)
|
||||
ListByDateRange(ctx context.Context, start, end time.Time) ([]*model.Usage, error)
|
||||
ListByCapability(ctx context.Context, capability string, limit, offset int) ([]*model.Usage, error)
|
||||
|
||||
// Delete
|
||||
Delete(ctx context.Context, id int64) error
|
||||
|
||||
// Statistics
|
||||
CountByUserID(ctx context.Context, userID int64) (int64, error)
|
||||
}
|
||||
|
||||
type DailyUsageRepository interface {
|
||||
// Create
|
||||
Create(ctx context.Context, usage *model.DailyUsage) error
|
||||
BatchCreate(ctx context.Context, usages []*model.DailyUsage) error
|
||||
|
||||
// Read
|
||||
ListByUserID(ctx context.Context, userID int64, limit, offset int) ([]*model.DailyUsage, error)
|
||||
ListByTokenID(ctx context.Context, tokenID int64, limit, offset int) ([]*model.DailyUsage, error)
|
||||
ListByDateRange(ctx context.Context, start, end time.Time) ([]*model.DailyUsage, error)
|
||||
GetByDate(ctx context.Context, userID int64, date time.Time) (*model.DailyUsage, error)
|
||||
|
||||
// Delete
|
||||
Delete(ctx context.Context, id int64) error
|
||||
|
||||
// Statistics
|
||||
CountByUserID(ctx context.Context, userID int64) (int64, error)
|
||||
StatUserUsages(ctx context.Context, start, end time.Time, filters map[string]interface{}) ([]*dto.UsageInfo, error)
|
||||
}
|
||||
|
||||
type UsageDAO struct {
|
||||
db *gorm.DB
|
||||
}
|
||||
|
||||
type DailyUsageDAO struct {
|
||||
cfg *config.Config
|
||||
db *gorm.DB
|
||||
}
|
||||
|
||||
func NewUsageDAO(cfg *config.Config, db *gorm.DB) *UsageDAO {
|
||||
return &UsageDAO{db: db}
|
||||
}
|
||||
|
||||
func NewDailyUsageDAO(cfg *config.Config, db *gorm.DB) *DailyUsageDAO {
|
||||
return &DailyUsageDAO{db: db}
|
||||
}
|
||||
|
||||
// Usage DAO implementations
|
||||
func (d *UsageDAO) Create(ctx context.Context, usage *model.Usage) error {
|
||||
return d.db.WithContext(ctx).Create(usage).Error
|
||||
}
|
||||
|
||||
func (d *UsageDAO) BatchCreate(ctx context.Context, usages []*model.Usage) error {
|
||||
return d.db.WithContext(ctx).Create(usages).Error
|
||||
}
|
||||
|
||||
func (d *UsageDAO) GetByID(ctx context.Context, id int64) (*model.Usage, error) {
|
||||
var usage model.Usage
|
||||
err := d.db.WithContext(ctx).First(&usage, id).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &usage, nil
|
||||
}
|
||||
|
||||
func (d *UsageDAO) ListByUserID(ctx context.Context, userID int64, limit, offset int) ([]*model.Usage, error) {
|
||||
var usages []*model.Usage
|
||||
err := d.db.WithContext(ctx).
|
||||
Where("user_id = ?", userID).
|
||||
Limit(limit).
|
||||
Offset(offset).
|
||||
Find(&usages).Error
|
||||
return usages, err
|
||||
}
|
||||
|
||||
func (d *UsageDAO) ListByTokenID(ctx context.Context, tokenID int64, limit, offset int) ([]*model.Usage, error) {
|
||||
var usages []*model.Usage
|
||||
err := d.db.WithContext(ctx).
|
||||
Where("token_id = ?", tokenID).
|
||||
Limit(limit).
|
||||
Offset(offset).
|
||||
Find(&usages).Error
|
||||
return usages, err
|
||||
}
|
||||
|
||||
func (d *UsageDAO) ListByDateRange(ctx context.Context, start, end time.Time) ([]*model.Usage, error) {
|
||||
var usages []*model.Usage
|
||||
err := d.db.WithContext(ctx).
|
||||
Where("date BETWEEN ? AND ?", start, end).
|
||||
Find(&usages).Error
|
||||
return usages, err
|
||||
}
|
||||
|
||||
func (d *UsageDAO) ListByCapability(ctx context.Context, capability string, limit, offset int) ([]*model.Usage, error) {
|
||||
var usages []*model.Usage
|
||||
err := d.db.WithContext(ctx).
|
||||
Where("capability = ?", capability).
|
||||
Limit(limit).
|
||||
Offset(offset).
|
||||
Find(&usages).Error
|
||||
return usages, err
|
||||
}
|
||||
|
||||
func (d *UsageDAO) Delete(ctx context.Context, id int64) error {
|
||||
return d.db.WithContext(ctx).Delete(&model.Usage{}, id).Error
|
||||
}
|
||||
|
||||
func (d *UsageDAO) CountByUserID(ctx context.Context, userID int64) (int64, error) {
|
||||
var count int64
|
||||
err := d.db.WithContext(ctx).Model(&model.Usage{}).Where("user_id = ?", userID).Count(&count).Error
|
||||
return count, err
|
||||
}
|
||||
|
||||
// DailyUsage DAO implementations
|
||||
func (d *DailyUsageDAO) Create(ctx context.Context, usage *model.DailyUsage) error {
|
||||
return d.db.WithContext(ctx).Create(usage).Error
|
||||
}
|
||||
|
||||
func (d *DailyUsageDAO) BatchCreate(ctx context.Context, usages []*model.DailyUsage) error {
|
||||
return d.db.WithContext(ctx).Create(usages).Error
|
||||
}
|
||||
|
||||
func (d *DailyUsageDAO) ListByUserID(ctx context.Context, userID int64, limit, offset int) ([]*model.DailyUsage, error) {
|
||||
var usages []*model.DailyUsage
|
||||
err := d.db.WithContext(ctx).
|
||||
Where("user_id = ?", userID).
|
||||
Limit(limit).
|
||||
Offset(offset).
|
||||
Find(&usages).Error
|
||||
return usages, err
|
||||
}
|
||||
|
||||
func (d *DailyUsageDAO) ListByTokenID(ctx context.Context, tokenID int64, limit, offset int) ([]*model.DailyUsage, error) {
|
||||
var usages []*model.DailyUsage
|
||||
err := d.db.WithContext(ctx).
|
||||
Where("token_id = ?", tokenID).
|
||||
Limit(limit).
|
||||
Offset(offset).
|
||||
Find(&usages).Error
|
||||
return usages, err
|
||||
}
|
||||
|
||||
func (d *DailyUsageDAO) ListByDateRange(ctx context.Context, start, end time.Time) ([]*model.DailyUsage, error) {
|
||||
var usages []*model.DailyUsage
|
||||
err := d.db.WithContext(ctx).
|
||||
Where("date BETWEEN ? AND ?", start, end).
|
||||
Find(&usages).Error
|
||||
return usages, err
|
||||
}
|
||||
|
||||
func (d *DailyUsageDAO) GetByDate(ctx context.Context, userID int64, date time.Time) (*model.DailyUsage, error) {
|
||||
var usage model.DailyUsage
|
||||
err := d.db.WithContext(ctx).
|
||||
Where("user_id = ? AND date = ?", userID, date).
|
||||
First(&usage).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &usage, nil
|
||||
}
|
||||
|
||||
// UpsertDailyUsage 根据不同数据库类型执行 Upsert
|
||||
func (d *DailyUsageDAO) UpsertDailyUsage(ctx context.Context, usage *model.Usage) error {
|
||||
date := usage.Date.Truncate(24 * time.Hour)
|
||||
dailyUsage := &model.DailyUsage{
|
||||
UserID: usage.UserID,
|
||||
TokenID: usage.TokenID,
|
||||
Capability: usage.Capability,
|
||||
Model: usage.Model,
|
||||
Stream: usage.Stream,
|
||||
PromptTokens: usage.PromptTokens,
|
||||
CompletionTokens: usage.CompletionTokens,
|
||||
TotalTokens: usage.TotalTokens,
|
||||
Cost: usage.Cost,
|
||||
}
|
||||
|
||||
updateColumns := map[string]interface{}{
|
||||
"prompt_tokens": gorm.Expr("prompt_tokens + VALUES(prompt_tokens)"),
|
||||
"completion_tokens": gorm.Expr("completion_tokens + VALUES(completion_tokens)"),
|
||||
"total_tokens": gorm.Expr("total_tokens + VALUES(total_tokens)"),
|
||||
}
|
||||
|
||||
db := d.db.WithContext(ctx)
|
||||
|
||||
switch d.cfg.DB_Type {
|
||||
case "mysql":
|
||||
// MySQL: INSERT ... ON DUPLICATE KEY UPDATE
|
||||
return db.Clauses(clause.OnConflict{
|
||||
Columns: []clause.Column{
|
||||
{Name: "user_id"},
|
||||
{Name: "token_id"},
|
||||
{Name: "capability"},
|
||||
{Name: "date"},
|
||||
{Name: "model"},
|
||||
{Name: "stream"},
|
||||
},
|
||||
DoUpdates: clause.Assignments(updateColumns),
|
||||
}).Create(dailyUsage).Error
|
||||
|
||||
case "postgres":
|
||||
// PostgreSQL: INSERT ... ON CONFLICT DO UPDATE
|
||||
updateColumns := map[string]interface{}{
|
||||
"prompt_tokens": gorm.Expr("daily_usages.prompt_tokens + EXCLUDED.prompt_tokens"),
|
||||
"completion_tokens": gorm.Expr("daily_usages.completion_tokens + EXCLUDED.completion_tokens"),
|
||||
"total_tokens": gorm.Expr("daily_usages.total_tokens + EXCLUDED.total_tokens"),
|
||||
}
|
||||
return db.Clauses(clause.OnConflict{
|
||||
Columns: []clause.Column{
|
||||
{Name: "user_id"},
|
||||
{Name: "token_id"},
|
||||
{Name: "capability"},
|
||||
{Name: "date"},
|
||||
{Name: "model"},
|
||||
{Name: "stream"},
|
||||
},
|
||||
DoUpdates: clause.Assignments(updateColumns),
|
||||
}).Create(dailyUsage).Error
|
||||
case "sqlite":
|
||||
fallthrough
|
||||
default:
|
||||
return db.Transaction(func(tx *gorm.DB) error {
|
||||
var existing model.DailyUsage
|
||||
err := tx.Where("user_id = ? AND token_id = ? AND capability = ? AND date = ? AND model = ? AND stream = ?",
|
||||
usage.UserID, usage.TokenID, usage.Capability, date, usage.Model, usage.Stream).
|
||||
First(&existing).Error
|
||||
|
||||
if err == gorm.ErrRecordNotFound {
|
||||
// 记录不存在,创建新记录
|
||||
return tx.Create(dailyUsage).Error
|
||||
} else if err != nil {
|
||||
return err // 返回其他错误
|
||||
}
|
||||
|
||||
// 记录存在,更新
|
||||
return tx.Model(&existing).Updates(map[string]interface{}{
|
||||
"prompt_tokens": gorm.Expr("prompt_tokens + ?", usage.PromptTokens),
|
||||
"completion_tokens": gorm.Expr("completion_tokens + ?", usage.CompletionTokens),
|
||||
"total_tokens": gorm.Expr("total_tokens + ?", usage.TotalTokens),
|
||||
}).Error
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func (d *DailyUsageDAO) Delete(ctx context.Context, id int64) error {
|
||||
return d.db.WithContext(ctx).Delete(&model.DailyUsage{}, id).Error
|
||||
}
|
||||
|
||||
func (d *DailyUsageDAO) CountByUserID(ctx context.Context, userID int64) (int64, error) {
|
||||
var count int64
|
||||
err := d.db.WithContext(ctx).Model(&model.DailyUsage{}).Where("user_id = ?", userID).Count(&count).Error
|
||||
return count, err
|
||||
}
|
||||
|
||||
func (d *DailyUsageDAO) StatUserUsages(ctx context.Context, from, to time.Time, filters map[string]interface{}) ([]*dto.UsageInfo, error) {
|
||||
var usages []*dto.UsageInfo
|
||||
|
||||
query := d.db.WithContext(ctx).
|
||||
Model(&model.DailyUsage{}).
|
||||
Select("user_id as userId, sum(total_tokens) as totalUnit, sum(cast(cost as decimal(20,6))) as cost")
|
||||
for key, value := range filters {
|
||||
query = query.Where(fmt.Sprintf("%s = ?", key), value)
|
||||
}
|
||||
query = query.Group("user_id").Where("date >= ? AND date <= ?", from, to)
|
||||
|
||||
err := query.Group("user_id").Find(&usages).Error
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to list usages: %w", err)
|
||||
}
|
||||
|
||||
return usages, nil
|
||||
}
|
||||
163
internal/dao/user.go
Normal file
163
internal/dao/user.go
Normal file
@@ -0,0 +1,163 @@
|
||||
package dao
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"opencatd-open/internal/model"
|
||||
"time"
|
||||
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// 确保 UserDAO 实现了 UserRepository 接口
|
||||
var _ UserRepository = (*UserDAO)(nil)
|
||||
|
||||
// UserRepository 定义用户数据访问操作的接口
|
||||
type UserRepository interface {
|
||||
Create(user *model.User) error
|
||||
GetByID(id int64) (*model.User, error)
|
||||
GetByUsername(username string) (*model.User, error)
|
||||
Update(user *model.User) error
|
||||
Delete(id int64) error
|
||||
List(limit, offset int, condition map[string]interface{}) ([]model.User, int64, error)
|
||||
// Enable(id int64) error
|
||||
// Disable(id int64) error
|
||||
BatchEnable(ids []int64, condition []string) error
|
||||
BatchDisable(ids []int64, condition []string) error
|
||||
BatchDelete(ids []int64, condition []string) error
|
||||
}
|
||||
|
||||
type UserDAO struct {
|
||||
db *gorm.DB
|
||||
}
|
||||
|
||||
func NewUserDAO(db *gorm.DB) *UserDAO {
|
||||
return &UserDAO{db: db}
|
||||
}
|
||||
|
||||
// 创建用户
|
||||
func (dao *UserDAO) Create(user *model.User) error {
|
||||
if user == nil {
|
||||
return errors.New("user is nil")
|
||||
}
|
||||
fmt.Println(*user)
|
||||
|
||||
return dao.db.Transaction(func(tx *gorm.DB) error {
|
||||
// 创建用户
|
||||
if err := tx.Create(user).Error; err != nil {
|
||||
return fmt.Errorf("failed to create user: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
// 根据ID获取用户
|
||||
func (dao *UserDAO) GetByID(id int64) (*model.User, error) {
|
||||
var user model.User
|
||||
// err := dao.db.First(&user, id).Error
|
||||
err := dao.db.Preload("Tokens", "user_id = ?", id).First(&user, id).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &user, nil
|
||||
}
|
||||
|
||||
// 根据用户名获取用户
|
||||
func (dao *UserDAO) GetByUsername(username string) (*model.User, error) {
|
||||
var user model.User
|
||||
// err := dao.db.Where("user_name = ?", username).First(&user).Error
|
||||
err := dao.db.Preload("Tokens").Where("user_name = ?", username).First(&user).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &user, nil
|
||||
}
|
||||
|
||||
// 更新用户信息
|
||||
func (dao *UserDAO) Update(user *model.User) error {
|
||||
if user == nil {
|
||||
return errors.New("user is nil")
|
||||
}
|
||||
|
||||
user.UpdatedAt = time.Now().Unix()
|
||||
return dao.db.Save(user).Error
|
||||
}
|
||||
|
||||
// 删除用户
|
||||
func (dao *UserDAO) Delete(id int64) error {
|
||||
return dao.db.Unscoped().Delete(&model.User{}, id).Error
|
||||
// return dao.db.Model(&model.User{}).Where("id = ?", id).Update("status", 2).Error
|
||||
}
|
||||
|
||||
// 获取用户列表
|
||||
func (dao *UserDAO) List(limit, offset int, condition map[string]interface{}) ([]model.User, int64, error) {
|
||||
if offset < 0 {
|
||||
offset = 0
|
||||
}
|
||||
var users []model.User
|
||||
var total int64
|
||||
|
||||
query := dao.db.Preload("Tokens").Model(&model.User{})
|
||||
|
||||
for k, v := range condition {
|
||||
query = query.Where(k, v)
|
||||
}
|
||||
err := query.Limit(limit).Offset(offset).Find(&users).Error
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
query = dao.db.Model(&model.User{})
|
||||
for k, v := range condition {
|
||||
query = query.Where(k, v)
|
||||
}
|
||||
query.Count(&total)
|
||||
|
||||
return users, total, nil
|
||||
}
|
||||
|
||||
// 启用User
|
||||
func (dao *UserDAO) Enable(id uint) error {
|
||||
return dao.db.Model(&model.User{}).Where("id = ?", id).Update("active", true).Error
|
||||
}
|
||||
|
||||
// 禁用User
|
||||
func (dao *UserDAO) Disable(id uint) error {
|
||||
return dao.db.Model(&model.User{}).Where("id = ?", id).Update("active", false).Error
|
||||
}
|
||||
|
||||
// 批量启用User
|
||||
func (dao *UserDAO) BatchEnable(ids []int64, condition []string) error {
|
||||
if len(ids) == 0 {
|
||||
return errors.New("ids is empty")
|
||||
}
|
||||
query := dao.db.Model(&model.User{}).Where("id IN ?", ids)
|
||||
for _, value := range condition {
|
||||
query = query.Where(value)
|
||||
}
|
||||
return query.Update("active", true).Error
|
||||
}
|
||||
|
||||
// 批量禁用User
|
||||
func (dao *UserDAO) BatchDisable(ids []int64, condition []string) error {
|
||||
if len(ids) == 0 {
|
||||
return errors.New("ids is empty")
|
||||
}
|
||||
query := dao.db.Model(&model.User{}).Where("id IN ?", ids)
|
||||
for _, value := range condition {
|
||||
query = query.Where(value)
|
||||
}
|
||||
return query.Update("active", false).Error
|
||||
}
|
||||
|
||||
// 批量删除用户
|
||||
func (dao *UserDAO) BatchDelete(ids []int64, condition []string) error {
|
||||
if len(ids) == 0 {
|
||||
return errors.New("ids is empty")
|
||||
}
|
||||
query := dao.db.Unscoped().Where("id IN ?", ids)
|
||||
for _, value := range condition {
|
||||
query = query.Where(value)
|
||||
}
|
||||
return query.Delete(&model.User{}).Error
|
||||
}
|
||||
6
internal/dto/batch.go
Normal file
6
internal/dto/batch.go
Normal file
@@ -0,0 +1,6 @@
|
||||
package dto
|
||||
|
||||
type BatchIDRequest struct {
|
||||
UserID *int64 `json:"user_id"`
|
||||
IDs []int64 `json:"ids" binding:"required"`
|
||||
}
|
||||
20
internal/dto/error.go
Normal file
20
internal/dto/error.go
Normal file
@@ -0,0 +1,20 @@
|
||||
package dto
|
||||
|
||||
import (
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
type Error struct {
|
||||
Code int `json:"code,omitempty"`
|
||||
Message string `json:"message,omitempty"`
|
||||
}
|
||||
|
||||
func WrapErrorAsOpenAI(c *gin.Context, code int, msg string) {
|
||||
c.JSON(code, gin.H{
|
||||
"error": Error{
|
||||
Code: code,
|
||||
Message: msg,
|
||||
},
|
||||
})
|
||||
c.Abort()
|
||||
}
|
||||
107
internal/dto/key.go
Normal file
107
internal/dto/key.go
Normal file
@@ -0,0 +1,107 @@
|
||||
package dto
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"regexp"
|
||||
"time"
|
||||
|
||||
validation "github.com/go-ozzo/ozzo-validation/v4"
|
||||
)
|
||||
|
||||
// TeamKey 结构体定义
|
||||
type TeamKey struct {
|
||||
ID *int64 `json:"id,omitempty"`
|
||||
UserID *int64 `json:"userID,omitempty"`
|
||||
Name *string `json:"name,omitempty"` // 必须
|
||||
Key *string `json:"key,omitempty"`
|
||||
Status *int64 `json:"status,omitempty"` // 默认1 允许,0禁止
|
||||
Quota *int64 `json:"quota,omitempty"` // UnlimitedQuota不为1 的时候必须
|
||||
UnlimitedQuota *bool `json:"unlimitedQuota,omitempty"` // 默认1 不限制,0限制
|
||||
UsedQuota *int64 `json:"usedQuota,omitempty"`
|
||||
CreatedAt *int64 `json:"createdAt,omitempty"`
|
||||
ExpiredAt *int64 `json:"expiredAt,omitempty"` // 可选
|
||||
}
|
||||
|
||||
// DefaultTeamKey 创建一个具有默认值的 TeamKey
|
||||
func DefaultTeamKey() TeamKey {
|
||||
status := int64(1) // 默认允许
|
||||
unlimitedQuota := true // 默认不限制
|
||||
createdAt := time.Now().Unix()
|
||||
|
||||
return TeamKey{
|
||||
Status: &status,
|
||||
UnlimitedQuota: &unlimitedQuota,
|
||||
CreatedAt: &createdAt,
|
||||
}
|
||||
}
|
||||
|
||||
// Validate 验证 TeamKey 结构体
|
||||
func (t TeamKey) Validate() error {
|
||||
// 自定义验证规则
|
||||
var quotaRule validation.Rule = validation.Skip
|
||||
if t.UnlimitedQuota != nil && !*t.UnlimitedQuota {
|
||||
quotaRule = validation.Required.Error("当 UnlimitedQuota 为 false 时,Quota 是必填项")
|
||||
}
|
||||
|
||||
// 过期时间校验
|
||||
var expiredAtRule validation.Rule = validation.Skip
|
||||
if t.ExpiredAt != nil {
|
||||
expiredAtRule = validation.Min(time.Now().Unix()).Error("过期时间不能早于当前时间")
|
||||
}
|
||||
|
||||
return validation.ValidateStruct(&t,
|
||||
// ID 通常由系统生成,不需要验证
|
||||
|
||||
// UserID 可选,但如果提供必须大于 0
|
||||
validation.Field(&t.UserID,
|
||||
validation.When(t.UserID != nil, validation.Min(int64(1)).Error("用户 ID 必须大于 0"))),
|
||||
|
||||
// Name 是必填字段
|
||||
validation.Field(&t.Name,
|
||||
validation.Required.Error("名称不能为空"),
|
||||
validation.When(t.Name != nil, validation.Length(1, 100).Error("名称长度应在 1-100 之间"))),
|
||||
|
||||
// Key 可选,但如果提供需要符合特定格式
|
||||
validation.Field(&t.Key,
|
||||
validation.When(t.Key != nil,
|
||||
validation.Length(1, 255).Error("Key 长度应在 1-255 之间")),
|
||||
validation.Match(regexp.MustCompile(`^[^\s]+$`)).Error("Key 不能包含空格"),
|
||||
),
|
||||
|
||||
// Status 只能是 0 或 1
|
||||
validation.Field(&t.Status,
|
||||
validation.When(t.Status != nil, validation.In(int64(0), int64(1)).Error("状态只能是 0(禁止) 或 1(允许)"))),
|
||||
|
||||
// Quota 要求依赖于 UnlimitedQuota
|
||||
validation.Field(&t.Quota, quotaRule,
|
||||
validation.When(t.Quota != nil, validation.Min(int64(1)).Error("配额必须大于 0"))),
|
||||
|
||||
// UnlimitedQuota 是否限制配额
|
||||
validation.Field(&t.UnlimitedQuota),
|
||||
|
||||
// UsedQuota 系统维护,不需要验证
|
||||
validation.Field(&t.UsedQuota,
|
||||
validation.When(t.UsedQuota != nil, validation.Min(int64(0)).Error("已使用配额不能为负数"))),
|
||||
|
||||
// CreatedAt 系统维护,不需要验证
|
||||
validation.Field(&t.CreatedAt),
|
||||
|
||||
// ExpiredAt 可选,但如果提供必须大于当前时间
|
||||
validation.Field(&t.ExpiredAt, expiredAtRule),
|
||||
)
|
||||
}
|
||||
|
||||
// ValidateCreate 创建时的特殊验证
|
||||
func (t TeamKey) ValidateCreate() error {
|
||||
// 首先进行基本验证
|
||||
if err := t.Validate(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 创建时的额外验证
|
||||
if t.Name == nil {
|
||||
return errors.New("创建时必须提供名称")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
11
internal/dto/passkey.go
Normal file
11
internal/dto/passkey.go
Normal file
@@ -0,0 +1,11 @@
|
||||
package dto
|
||||
|
||||
type Passkey struct {
|
||||
ID int64 `json:"id" gorm:"column:id;primaryKey;autoIncrement"`
|
||||
Name string `json:"name" gorm:"column:name"` // 凭证名称,用于用户识别不同的设备
|
||||
SignCount uint32 `json:"sign_count" gorm:"column:sign_count"` // 签名计数器,用于防止重放攻击
|
||||
DeviceType string `json:"device_type" gorm:"column:device_type"` // 设备类型,如"platform"或"cross-platform"
|
||||
LastUsedAt int64 `json:"last_used_at" gorm:"column:last_used_at"` // 最后使用时间
|
||||
CreatedAt int64 `json:"created_at,omitempty" gorm:"autoCreateTime"`
|
||||
UpdatedAt int64 `json:"updated_at,omitempty" gorm:"autoUpdateTime"`
|
||||
}
|
||||
28
internal/dto/response.go
Normal file
28
internal/dto/response.go
Normal file
@@ -0,0 +1,28 @@
|
||||
package dto
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
type Result struct {
|
||||
Code int `json:"code"`
|
||||
Msg string `json:"msg"`
|
||||
Data any `json:"data,omitempty"`
|
||||
}
|
||||
|
||||
func Success(ctx *gin.Context, data any) {
|
||||
ctx.JSON(http.StatusOK, Result{
|
||||
Code: 200,
|
||||
Data: data,
|
||||
Msg: "success",
|
||||
})
|
||||
}
|
||||
|
||||
func Fail(c *gin.Context, code int, err string) {
|
||||
c.AbortWithStatusJSON(code, gin.H{
|
||||
"code": code,
|
||||
"error": err,
|
||||
})
|
||||
}
|
||||
83
internal/dto/team/team.go
Normal file
83
internal/dto/team/team.go
Normal file
@@ -0,0 +1,83 @@
|
||||
package dto
|
||||
|
||||
import (
|
||||
"opencatd-open/internal/model"
|
||||
"opencatd-open/internal/utils"
|
||||
)
|
||||
|
||||
type UserInfo struct {
|
||||
ID int64 `json:"id"`
|
||||
Name string `json:"name"`
|
||||
Token string `json:"token"`
|
||||
Status *bool `json:"status,omitempty"`
|
||||
}
|
||||
|
||||
func (u UserInfo) HasNameUpdate() bool {
|
||||
return u.Name != ""
|
||||
}
|
||||
|
||||
func (u UserInfo) HasTokenUpdate() bool {
|
||||
return u.Token != ""
|
||||
}
|
||||
|
||||
func (u UserInfo) HasStatusUpdate() bool {
|
||||
return u.Status != nil
|
||||
}
|
||||
|
||||
type ApiKeyInfo struct {
|
||||
ID int `json:"id,omitempty"`
|
||||
Key string `json:"key,omitempty"`
|
||||
Name string `json:"name,omitempty"`
|
||||
ApiType string `json:"api_type,omitempty"`
|
||||
Endpoint string `json:"endpoint,omitempty"`
|
||||
Status *bool `json:"status,omitempty"`
|
||||
}
|
||||
|
||||
// 添加辅助方法判断字段是否需要更新
|
||||
func (k ApiKeyInfo) HasNameUpdate() bool {
|
||||
return k.Name != ""
|
||||
}
|
||||
|
||||
func (k ApiKeyInfo) HasKeyUpdate() bool {
|
||||
return k.Key != ""
|
||||
}
|
||||
|
||||
func (k ApiKeyInfo) HasStatusUpdate() bool {
|
||||
return k.Status != nil
|
||||
}
|
||||
|
||||
func (k ApiKeyInfo) HasApiTypeUpdate() bool {
|
||||
return k.ApiType != ""
|
||||
}
|
||||
|
||||
// 辅助函数:统一处理字段更新
|
||||
func (update *ApiKeyInfo) UpdateFields(existing *model.ApiKey) *model.ApiKey {
|
||||
result := &model.ApiKey{
|
||||
ID: existing.ID,
|
||||
Name: existing.Name, // 默认保持原值
|
||||
ApiType: existing.ApiType, // 默认保持原值
|
||||
ApiKey: existing.ApiKey, // 默认保持原值
|
||||
Active: existing.Active, // 默认保持原值
|
||||
}
|
||||
|
||||
if update.HasNameUpdate() {
|
||||
result.Name = utils.ToPtr(update.Name)
|
||||
}
|
||||
if update.HasKeyUpdate() {
|
||||
result.ApiKey = utils.ToPtr(update.Key)
|
||||
}
|
||||
if update.HasStatusUpdate() {
|
||||
result.Active = update.Status
|
||||
}
|
||||
if update.HasApiTypeUpdate() {
|
||||
result.ApiType = utils.ToPtr(update.ApiType)
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
type UsageInfo struct {
|
||||
UserId int `json:"userId"`
|
||||
TotalUnit int `json:"totalUnit"`
|
||||
Cost string `json:"cost"`
|
||||
}
|
||||
16
internal/dto/user.go
Normal file
16
internal/dto/user.go
Normal file
@@ -0,0 +1,16 @@
|
||||
package dto
|
||||
|
||||
type User struct {
|
||||
Username string `json:"username" binding:"required,min=3,max=32"`
|
||||
Password string `json:"password" binding:"required,min=4"`
|
||||
}
|
||||
|
||||
type Auth struct {
|
||||
Token string `json:"token"`
|
||||
ExpiresIn int64 `json:"expires_in"`
|
||||
}
|
||||
|
||||
type ChangePassword struct {
|
||||
Password string `json:"password" binding:"required,min=4"`
|
||||
NewPassword string `json:"newpassword" binding:"required,min=4"`
|
||||
}
|
||||
50
internal/model/apikey.go
Normal file
50
internal/model/apikey.go
Normal file
@@ -0,0 +1,50 @@
|
||||
package model
|
||||
|
||||
import "github.com/lib/pq" //pq.StringArray
|
||||
|
||||
type ApiKey_PG struct {
|
||||
ID int64 `gorm:"column:id;primaryKey;autoIncrement" json:"id,omitempty"`
|
||||
Name *string `gorm:"column:name;not null;unique;index:idx_apikey_name" json:"name,omitempty"`
|
||||
ApiType *string `gorm:"column:apitype;not null;index:idx_apikey_apitype" json:"type,omitempty"`
|
||||
ApiKey *string `gorm:"column:apikey;not null;index:idx_apikey_apikey" json:"apikey,omitempty"`
|
||||
Active *bool `gorm:"column:active;default:true" json:"active,omitempty"`
|
||||
Endpoint *string `gorm:"column:endpoint" json:"endpoint,omitempty"`
|
||||
ResourceNmae *string `gorm:"column:resource_name" json:"resource_name,omitempty"`
|
||||
DeploymentName *string `gorm:"column:deployment_name" json:"deployment_name,omitempty"`
|
||||
ApiSecret *string `gorm:"column:api_secret" json:"api_secret,omitempty"`
|
||||
ModelPrefix *string `gorm:"column:model_prefix" json:"model_prefix,omitempty"`
|
||||
ModelAlias *string `gorm:"column:model_alias" json:"model_alias,omitempty"`
|
||||
Parameters *string `gorm:"column:parameters" json:"parameters,omitempty"`
|
||||
SupportModelsArray pq.StringArray `gorm:"column:support_models;type:text[]" json:"support_models_array,omitempty"`
|
||||
SupportModels *string `gorm:"-" json:"support_models,omitempty"`
|
||||
CreatedAt int64 `gorm:"column:created_at;autoUpdateTime" json:"created_at,omitempty"`
|
||||
UpdatedAt int64 `gorm:"column:updated_at;autoCreateTime" json:"updated_at,omitempty"`
|
||||
}
|
||||
|
||||
func (ApiKey_PG) TableName() string {
|
||||
return "apikeys"
|
||||
}
|
||||
|
||||
type ApiKey struct {
|
||||
ID int64 `gorm:"column:id;primaryKey;autoIncrement" json:"id,omitempty"`
|
||||
Name *string `gorm:"column:name;not null;unique;index:idx_apikey_name" json:"name,omitempty"`
|
||||
ApiType *string `gorm:"column:apitype;not null;index:idx_apikey_apitype" json:"type,omitempty"`
|
||||
ApiKey *string `gorm:"column:apikey;not null;index:idx_apikey_apikey" json:"apikey,omitempty"`
|
||||
Active *bool `gorm:"column:active;default:true" json:"active,omitempty"`
|
||||
Endpoint *string `gorm:"column:endpoint" json:"endpoint,omitempty"`
|
||||
ResourceNmae *string `gorm:"column:resource_name" json:"resource_name,omitempty"`
|
||||
DeploymentName *string `gorm:"column:deployment_name" json:"deployment_name,omitempty"`
|
||||
AccessKey *string `gorm:"column:access_key" json:"access_key,omitempty"`
|
||||
SecretKey *string `gorm:"column:secret_key" json:"secret_key,omitempty"`
|
||||
ModelPrefix *string `gorm:"column:model_prefix" json:"model_prefix,omitempty"`
|
||||
ModelAlias *string `gorm:"column:model_alias" json:"model_alias,omitempty"`
|
||||
Parameters *string `gorm:"column:parameters" json:"parameters,omitempty"`
|
||||
SupportModels *string `gorm:"column:support_models;type:json" json:"support_models,omitempty"`
|
||||
SupportModelsArray []string `gorm:"-" json:"support_models_array,omitempty"`
|
||||
CreatedAt int64 `gorm:"column:created_at;autoUpdateTime" json:"created_at,omitempty"`
|
||||
UpdatedAt int64 `gorm:"column:updated_at;autoCreateTime" json:"updated_at,omitempty"`
|
||||
}
|
||||
|
||||
func (ApiKey) TableName() string {
|
||||
return "apikeys"
|
||||
}
|
||||
38
internal/model/passkey.go
Normal file
38
internal/model/passkey.go
Normal file
@@ -0,0 +1,38 @@
|
||||
package model
|
||||
|
||||
import (
|
||||
"time"
|
||||
)
|
||||
|
||||
// Passkey 用户凭证密钥模型
|
||||
type Passkey struct {
|
||||
ID int64 `json:"id" gorm:"column:id;primaryKey;autoIncrement"`
|
||||
UserID int64 `json:"user_id" gorm:"column:user_id;index"`
|
||||
CredentialID string `json:"credential_id" gorm:"column:credential_id;index"` // 凭证ID,用于识别特定的passkey
|
||||
PublicKey string `json:"public_key" gorm:"column:public_key"` // 公钥,用于验证签名
|
||||
AttestationType string `json:"attestation_type" gorm:"column:attestation_type"` // 证明类型
|
||||
AAGUID string `json:"aaguid" gorm:"column:aaguid"` // 认证器标识符
|
||||
SignCount uint32 `json:"sign_count" gorm:"column:sign_count"` // 签名计数器,用于防止重放攻击
|
||||
Name string `json:"name" gorm:"column:name"` // 凭证名称,用于用户识别不同的设备
|
||||
DeviceType string `json:"device_type" gorm:"column:device_type"` // 设备类型
|
||||
BackupEligible bool `json:"backup_eligible" gorm:"column:backup_eligible"` // 是否可备份
|
||||
BackupState bool `json:"backup_state" gorm:"backup_state"` // 备份状态
|
||||
Transport string `json:"transport" gorm:"column:transport"` // 传输方式 (如usb、nfc、ble等)
|
||||
LastUsedAt int64 `json:"last_used_at" gorm:"column:last_used_at;autoUpdateTime"` // 最后使用时间
|
||||
CreatedAt int64 `json:"created_at,omitempty" gorm:"column:created_at;autoCreateTime"`
|
||||
UpdatedAt int64 `json:"updated_at,omitempty" gorm:"column:updated_at;autoUpdateTime"`
|
||||
|
||||
// 关联用户模型(不存入数据库)
|
||||
User User `json:"-" gorm:"foreignKey:UserID;references:ID;constraint:OnUpdate:CASCADE,OnDelete:CASCADE"`
|
||||
}
|
||||
|
||||
// 创建表结构
|
||||
func (Passkey) TableName() string {
|
||||
return "passkeys"
|
||||
}
|
||||
|
||||
// UpdateSignCount 更新签名计数器和最后使用时间
|
||||
func (p *Passkey) UpdateSignCount(count uint32) {
|
||||
p.SignCount = count
|
||||
p.LastUsedAt = time.Now().Unix()
|
||||
}
|
||||
22
internal/model/token.go
Normal file
22
internal/model/token.go
Normal file
@@ -0,0 +1,22 @@
|
||||
package model
|
||||
|
||||
// 用户的token
|
||||
type Token struct {
|
||||
ID int64 `gorm:"column:id;primaryKey;autoIncrement" json:"id,omitempty"`
|
||||
UserID int64 `gorm:"column:user_id;not null;index:idx_token_user_id" json:"userid,omitempty"`
|
||||
Name string `gorm:"column:name;not null;index:idx_token_name" json:"name,omitempty" binding:"required,min=1,max=20"`
|
||||
Key string `gorm:"column:key;not null;uniqueIndex:idx_token_key;comment:token key" json:"key,omitempty"`
|
||||
Active *bool `gorm:"column:active;default:true" json:"active,omitempty"` //
|
||||
Quota *int64 `gorm:"column:quota;type:bigint;default:0" json:"quota,omitempty"` // default 0
|
||||
UnlimitedQuota *bool `gorm:"column:unlimited_quota;default:true" json:"unlimited_quota,omitempty"` // set Quota 1 unlimited
|
||||
UsedQuota *int64 `gorm:"column:used_quota;type:bigint;default:0" json:"used_quota,omitempty"`
|
||||
ExpiredAt *int64 `gorm:"column:expired_at;type:bigint;default:0" json:"expired_at,omitempty"`
|
||||
NeverExpired *bool `gorm:"column:never_expires;type:bigint;" json:"never_expires,omitempty"`
|
||||
CreatedAt int64 `gorm:"column:created_at;type:bigint;autoCreateTime" json:"created_at,omitempty"`
|
||||
LastUsedAt int64 `gorm:"column:lastused_at;type:bigint;autoUpdateTime" json:"lastused_at,omitempty"`
|
||||
User *User `gorm:"foreignKey:UserID;constraint:OnUpdate:CASCADE,OnDelete:CASCADE" json:"-"`
|
||||
}
|
||||
|
||||
func (Token) TableName() string {
|
||||
return "tokens"
|
||||
}
|
||||
74
internal/model/usage.go
Normal file
74
internal/model/usage.go
Normal file
@@ -0,0 +1,74 @@
|
||||
package model
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"opencatd-open/store"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
type Usage struct {
|
||||
ID int64 `gorm:"column:id;primaryKey;autoIncrement"`
|
||||
UserID int64 `gorm:"column:user_id;index:idx_user_id"`
|
||||
TokenID int64 `gorm:"column:token_id;index:idx_token_id"`
|
||||
Capability string `gorm:"column:capability;index:idx_usage_capability;comment:模型能力"`
|
||||
Date time.Time `gorm:"column:date;autoCreateTime;index:idx_date"`
|
||||
Model string `gorm:"column:model"`
|
||||
Stream bool `gorm:"column:stream"`
|
||||
PromptTokens float64 `gorm:"column:prompt_tokens"`
|
||||
CompletionTokens float64 `gorm:"column:completion_tokens"`
|
||||
TotalTokens float64 `gorm:"column:total_tokens"`
|
||||
Cost string `gorm:"column:cost"`
|
||||
}
|
||||
|
||||
func (Usage) TableName() string {
|
||||
return "usages"
|
||||
}
|
||||
|
||||
type DailyUsage struct {
|
||||
ID int64 `gorm:"column:id;primaryKey;autoIncrement"`
|
||||
UserID int64 `gorm:"column:user_id;uniqueIndex:idx_daily_unique,priority:1"`
|
||||
TokenID int64 `gorm:"column:token_id;index:idx_daily_token_id"`
|
||||
Capability string `gorm:"column:capability;uniqueIndex:idx_daily_unique,priority:2;comment:模型能力"`
|
||||
Date time.Time `gorm:"column:date;autoCreateTime;uniqueIndex:idx_daily_unique,priority:3"`
|
||||
Model string `gorm:"column:model"`
|
||||
Stream bool `gorm:"column:stream"`
|
||||
PromptTokens float64 `gorm:"column:prompt_tokens"`
|
||||
CompletionTokens float64 `gorm:"column:completion_tokens"`
|
||||
TotalTokens float64 `gorm:"column:total_tokens"`
|
||||
Cost string `gorm:"column:cost"`
|
||||
}
|
||||
|
||||
func (DailyUsage) TableName() string {
|
||||
return "daily_usages"
|
||||
}
|
||||
|
||||
func HandleUsage(c *gin.Context) {
|
||||
fromStr := c.Query("from")
|
||||
toStr := c.Query("to")
|
||||
getMonthStartAndEnd := func() (start, end string) {
|
||||
loc, _ := time.LoadLocation("Local")
|
||||
now := time.Now().In(loc)
|
||||
|
||||
year, month, _ := now.Date()
|
||||
|
||||
startOfMonth := time.Date(year, month, 1, 0, 0, 0, 0, loc)
|
||||
endOfMonth := startOfMonth.AddDate(0, 1, 0)
|
||||
|
||||
start = startOfMonth.Format("2006-01-02")
|
||||
end = endOfMonth.Format("2006-01-02")
|
||||
return
|
||||
}
|
||||
if fromStr == "" || toStr == "" {
|
||||
fromStr, toStr = getMonthStartAndEnd()
|
||||
}
|
||||
|
||||
usage, err := store.QueryUsage(fromStr, toStr)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusForbidden, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(200, usage)
|
||||
}
|
||||
53
internal/model/user.go
Normal file
53
internal/model/user.go
Normal file
@@ -0,0 +1,53 @@
|
||||
package model
|
||||
|
||||
import (
|
||||
"opencatd-open/internal/consts"
|
||||
"time"
|
||||
)
|
||||
|
||||
type User struct {
|
||||
ID int64 `json:"id" gorm:"column:id;primaryKey;autoIncrement"`
|
||||
Name string `json:"name" gorm:"column:name;index"`
|
||||
Username string `json:"username" gorm:"column:username;unique;index"`
|
||||
Password string `json:"-" gorm:"column:password;"`
|
||||
NewPassword string `json:"newpassword" gorm:"-"`
|
||||
Role *consts.UserRole `json:"role" gorm:"column:role;type:int;default:0"` // default user 0-10-20
|
||||
Active *bool `json:"active" gorm:"column:active;default:true;"`
|
||||
Status int `json:"status" gorm:"column:status;type:int;default:1"` // disabled 0, enabled 1, deleted 2
|
||||
AvatarURL string `json:"avatar_url" gorm:"column:avatar_url;type:varchar(255)"`
|
||||
EmailVerified *bool `json:"email_verified" gorm:"column:email_verified;default:false"`
|
||||
Email string `json:"email" gorm:"column:email;type:varchar(255);index"`
|
||||
Quota *float32 `json:"quota" gorm:"column:quota;bigint;default:0"` // default unlimited
|
||||
UsedQuota *float32 `json:"used_quota" gorm:"column:used_quota;bigint;default:0"` // default 0
|
||||
UnlimitedQuota *bool `json:"unlimited_quota" gorm:"column:unlimited_quota;default:true;"` // 0 limited , 1 unlimited
|
||||
Timezone string `json:"timezone" gorm:"column:timezone;type:varchar(50)"`
|
||||
Language string `json:"language" gorm:"column:language;type:varchar(50)"`
|
||||
|
||||
// 添加一对多关系
|
||||
Tokens []Token `json:"tokens" gorm:"foreignKey:UserID;references:ID;constraint:OnUpdate:CASCADE,OnDelete:CASCADE"`
|
||||
Passkeys []Passkey `json:"passkeys" gorm:"foreignKey:UserID;references:ID;constraint:OnUpdate:CASCADE,OnDelete:CASCADE"`
|
||||
|
||||
CreatedAt int64 `json:"created_at,omitempty" gorm:"autoCreateTime"`
|
||||
UpdatedAt int64 `json:"updated_at,omitempty" gorm:"autoUpdateTime"`
|
||||
}
|
||||
|
||||
func (User) TableName() string {
|
||||
return "users"
|
||||
}
|
||||
|
||||
type Session struct {
|
||||
ID int64 `json:"id" gorm:"primaryKey;autoIncrement"`
|
||||
UserID int64 `json:"user_id" gorm:"index:idx_user_id"`
|
||||
Token string `json:"token" gorm:"type:varchar(64);uniqueIndex"`
|
||||
DeviceType string `json:"device_type" gorm:"type:varchar(100);default:''"`
|
||||
DeviceName string `json:"device_name" gorm:"type:varchar(100);default:''"`
|
||||
LastActiveAt time.Time `json:"last_active_at" gorm:"type:timestamp;default:CURRENT_TIMESTAMP"`
|
||||
LogoutAt time.Time `json:"logout_at" gorm:"type:timestamp;null"`
|
||||
|
||||
CreatedAt time.Time `json:"created_at" gorm:"type:timestamp;not null;default:CURRENT_TIMESTAMP"`
|
||||
UpdatedAt time.Time `json:"updated_at" gorm:"type:timestamp;not null;default:CURRENT_TIMESTAMP;update:CURRENT_TIMESTAMP"`
|
||||
}
|
||||
|
||||
func (Session) TableName() string {
|
||||
return "sessions"
|
||||
}
|
||||
92
internal/service/apikey.go
Normal file
92
internal/service/apikey.go
Normal file
@@ -0,0 +1,92 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"opencatd-open/internal/dao"
|
||||
"opencatd-open/internal/model"
|
||||
"opencatd-open/internal/utils"
|
||||
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
type ApiKeyServiceImpl struct {
|
||||
db *gorm.DB
|
||||
apiKeyRepo dao.ApiKeyRepository
|
||||
}
|
||||
|
||||
func NewApiKeyService(db *gorm.DB, apiKeyDao dao.ApiKeyRepository) *ApiKeyServiceImpl {
|
||||
return &ApiKeyServiceImpl{db: db, apiKeyRepo: apiKeyDao}
|
||||
}
|
||||
|
||||
func (s *ApiKeyServiceImpl) CreateApiKey(ctx context.Context, apikey *model.ApiKey) error {
|
||||
return s.apiKeyRepo.Create(apikey)
|
||||
}
|
||||
|
||||
func (s *ApiKeyServiceImpl) GetApiKey(ctx context.Context, id int64) (*model.ApiKey, error) {
|
||||
return s.apiKeyRepo.GetByID(id)
|
||||
}
|
||||
|
||||
func (s *ApiKeyServiceImpl) ListApiKey(ctx context.Context, limit, offset int, active []string) ([]*model.ApiKey, int64, error) {
|
||||
var conditions = make(map[string]interface{})
|
||||
if len(active) > 0 {
|
||||
conditions["active IN ?"] = utils.StringToBool(active)
|
||||
}
|
||||
return s.apiKeyRepo.ListWithFilters(limit, offset, conditions)
|
||||
}
|
||||
|
||||
func (s *ApiKeyServiceImpl) UpdateApiKey(ctx context.Context, apikey *model.ApiKey) error {
|
||||
_key, err := s.apiKeyRepo.GetByID(apikey.ID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("get apikey failed: %v", err)
|
||||
}
|
||||
if apikey.ApiKey != nil {
|
||||
_key.ApiKey = apikey.ApiKey
|
||||
}
|
||||
if apikey.Active != nil {
|
||||
_key.Active = apikey.Active
|
||||
}
|
||||
if apikey.Endpoint != nil {
|
||||
_key.Endpoint = apikey.Endpoint
|
||||
}
|
||||
if apikey.ResourceNmae != nil {
|
||||
_key.ResourceNmae = apikey.ResourceNmae
|
||||
}
|
||||
if apikey.DeploymentName != nil {
|
||||
_key.DeploymentName = apikey.DeploymentName
|
||||
}
|
||||
if apikey.AccessKey != nil {
|
||||
_key.AccessKey = apikey.AccessKey
|
||||
}
|
||||
if apikey.SecretKey != nil {
|
||||
_key.SecretKey = apikey.SecretKey
|
||||
}
|
||||
if apikey.ModelAlias != nil {
|
||||
_key.ModelAlias = apikey.ModelAlias
|
||||
}
|
||||
if apikey.ModelPrefix != nil {
|
||||
_key.ModelPrefix = apikey.ModelPrefix
|
||||
}
|
||||
if apikey.Parameters != nil {
|
||||
_key.Parameters = apikey.Parameters
|
||||
}
|
||||
if apikey.SupportModels != nil {
|
||||
_key.SupportModels = apikey.SupportModels
|
||||
}
|
||||
if apikey.SupportModelsArray != nil {
|
||||
_key.SupportModelsArray = apikey.SupportModelsArray
|
||||
}
|
||||
|
||||
return s.apiKeyRepo.Update(apikey)
|
||||
}
|
||||
|
||||
func (s *ApiKeyServiceImpl) DeleteApiKey(ctx context.Context, ids []int64) error {
|
||||
return s.apiKeyRepo.BatchDelete(ids)
|
||||
}
|
||||
|
||||
func (s *ApiKeyServiceImpl) EnableApiKey(ctx context.Context, ids []int64) error {
|
||||
return s.apiKeyRepo.BatchEnable(ids)
|
||||
}
|
||||
func (s *ApiKeyServiceImpl) DisableApiKey(ctx context.Context, ids []int64) error {
|
||||
return s.apiKeyRepo.BatchDisable(ids)
|
||||
}
|
||||
150
internal/service/team/apikey.go
Normal file
150
internal/service/team/apikey.go
Normal file
@@ -0,0 +1,150 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"opencatd-open/internal/dao"
|
||||
"opencatd-open/internal/model"
|
||||
"time"
|
||||
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
var _ ApiKeyService = (*ApiKeyServiceImpl)(nil)
|
||||
|
||||
type ApiKeyService interface {
|
||||
Create(apiKey *model.ApiKey) error
|
||||
GetByID(id int64) (*model.ApiKey, error)
|
||||
GetByName(name string) (*model.ApiKey, error)
|
||||
GetByApiKey(apiKeyValue string) (*model.ApiKey, error)
|
||||
Update(apiKey *model.ApiKey) error
|
||||
Delete(id int64) error
|
||||
List(limit, offset int, status string) ([]*model.ApiKey, error)
|
||||
ListWithFilters(limit, offset int, filters map[string]interface{}) ([]*model.ApiKey, int64, error)
|
||||
BatchEnable(ids []int64) error
|
||||
BatchDisable(ids []int64) error
|
||||
BatchDelete(ids []int64) error
|
||||
Count() (int64, error)
|
||||
}
|
||||
|
||||
type ApiKeyServiceImpl struct {
|
||||
db *gorm.DB
|
||||
apiKeyRepo dao.ApiKeyRepository
|
||||
}
|
||||
|
||||
func NewApiKeyService(db *gorm.DB, apiKeyDao dao.ApiKeyRepository) ApiKeyService {
|
||||
return &ApiKeyServiceImpl{apiKeyRepo: apiKeyDao, db: db}
|
||||
}
|
||||
|
||||
func (s *ApiKeyServiceImpl) Create(apiKey *model.ApiKey) error {
|
||||
if apiKey == nil {
|
||||
return errors.New("apiKey不能为空")
|
||||
}
|
||||
if apiKey.Name == nil {
|
||||
return errors.New("apiKey名称不能为空")
|
||||
}
|
||||
if apiKey.ApiKey == nil {
|
||||
return errors.New("apiKey值不能为空")
|
||||
}
|
||||
apiKey.CreatedAt = time.Now().Unix()
|
||||
apiKey.UpdatedAt = time.Now().Unix()
|
||||
|
||||
return s.apiKeyRepo.Create(apiKey)
|
||||
}
|
||||
|
||||
func (s *ApiKeyServiceImpl) GetByID(id int64) (*model.ApiKey, error) {
|
||||
if id <= 0 {
|
||||
return nil, errors.New("id 必须大于 0")
|
||||
}
|
||||
return s.apiKeyRepo.GetByID(id)
|
||||
}
|
||||
|
||||
func (s *ApiKeyServiceImpl) GetByName(name string) (*model.ApiKey, error) {
|
||||
if name == "" {
|
||||
return nil, errors.New("name 不能为空")
|
||||
}
|
||||
return s.apiKeyRepo.GetByName(name)
|
||||
}
|
||||
|
||||
func (s *ApiKeyServiceImpl) GetByApiKey(apiKeyValue string) (*model.ApiKey, error) {
|
||||
if apiKeyValue == "" {
|
||||
return nil, errors.New("apiKeyValue 不能为空")
|
||||
}
|
||||
return s.apiKeyRepo.GetByApiKey(apiKeyValue)
|
||||
}
|
||||
|
||||
func (s *ApiKeyServiceImpl) Update(apiKey *model.ApiKey) error {
|
||||
if apiKey == nil {
|
||||
return errors.New("apiKey不能为空")
|
||||
}
|
||||
if apiKey.ID <= 0 {
|
||||
return errors.New("apiKey ID 必须大于 0")
|
||||
}
|
||||
return s.apiKeyRepo.Update(apiKey)
|
||||
}
|
||||
|
||||
func (s *ApiKeyServiceImpl) Delete(id int64) error {
|
||||
if id <= 0 {
|
||||
return errors.New("id 必须大于 0")
|
||||
}
|
||||
return s.apiKeyRepo.BatchDelete([]int64{id})
|
||||
}
|
||||
|
||||
func (s *ApiKeyServiceImpl) List(offset, limit int, status string) ([]*model.ApiKey, error) {
|
||||
if offset < 0 {
|
||||
offset = 0
|
||||
}
|
||||
if limit <= 0 {
|
||||
limit = 20 // 设置默认值
|
||||
}
|
||||
return s.apiKeyRepo.List(offset, limit, status)
|
||||
}
|
||||
|
||||
func (s *ApiKeyServiceImpl) ListWithFilters(offset, limit int, filters map[string]interface{}) ([]*model.ApiKey, int64, error) {
|
||||
if offset < 0 {
|
||||
offset = 0
|
||||
}
|
||||
if limit <= 0 {
|
||||
limit = 20 // 设置默认值
|
||||
}
|
||||
|
||||
return s.apiKeyRepo.ListWithFilters(offset, limit, filters)
|
||||
}
|
||||
|
||||
func (s *ApiKeyServiceImpl) Enable(id int64) error {
|
||||
if id <= 0 {
|
||||
return errors.New("id 必须大于 0")
|
||||
}
|
||||
return s.apiKeyRepo.BatchEnable([]int64{id})
|
||||
}
|
||||
|
||||
func (s *ApiKeyServiceImpl) Disable(id int64) error {
|
||||
if id <= 0 {
|
||||
return errors.New("id 必须大于 0")
|
||||
}
|
||||
return s.apiKeyRepo.BatchDisable([]int64{id})
|
||||
}
|
||||
|
||||
func (s *ApiKeyServiceImpl) BatchEnable(ids []int64) error {
|
||||
if len(ids) == 0 {
|
||||
return errors.New("ids 不能为空")
|
||||
}
|
||||
return s.apiKeyRepo.BatchEnable(ids)
|
||||
}
|
||||
|
||||
func (s *ApiKeyServiceImpl) BatchDisable(ids []int64) error {
|
||||
if len(ids) == 0 {
|
||||
return errors.New("ids 不能为空")
|
||||
}
|
||||
return s.apiKeyRepo.BatchDisable(ids)
|
||||
}
|
||||
|
||||
func (s *ApiKeyServiceImpl) BatchDelete(ids []int64) error {
|
||||
if len(ids) == 0 {
|
||||
return errors.New("ids 不能为空")
|
||||
}
|
||||
return s.apiKeyRepo.BatchDelete(ids)
|
||||
}
|
||||
|
||||
func (s *ApiKeyServiceImpl) Count() (int64, error) {
|
||||
return s.apiKeyRepo.Count()
|
||||
}
|
||||
77
internal/service/team/token.go
Normal file
77
internal/service/team/token.go
Normal file
@@ -0,0 +1,77 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"opencatd-open/internal/dao"
|
||||
"opencatd-open/internal/model"
|
||||
"strings"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
// 确保 TokenService 实现了 TokenServiceInterface 接口
|
||||
var _ TokenService = (*TokenServiceImpl)(nil)
|
||||
|
||||
type TokenService interface {
|
||||
Create(ctx context.Context, token *model.Token) error
|
||||
GetByID(ctx context.Context, id int64) (*model.Token, error)
|
||||
GetByKey(ctx context.Context, key string) (*model.Token, error)
|
||||
GetByUserID(ctx context.Context, userID int64) (*model.Token, error)
|
||||
Update(ctx context.Context, token *model.Token) error
|
||||
UpdateWithCondition(ctx context.Context, token *model.Token, filters map[string]interface{}, updates map[string]interface{}) error
|
||||
Delete(ctx context.Context, id int64) error
|
||||
Lists(ctx context.Context, limit, offset int) ([]*model.Token, int64, error)
|
||||
Disable(ctx context.Context, id int) error
|
||||
Enable(ctx context.Context, id int) error
|
||||
}
|
||||
|
||||
type TokenServiceImpl struct {
|
||||
tokenRepo dao.TokenRepository
|
||||
}
|
||||
|
||||
func NewTokenService(tokenRepo dao.TokenRepository) TokenService {
|
||||
return &TokenServiceImpl{tokenRepo: tokenRepo}
|
||||
}
|
||||
|
||||
func (s *TokenServiceImpl) Create(ctx context.Context, token *model.Token) error {
|
||||
if token.Key == "" {
|
||||
token.Key = "sk-team-" + strings.ReplaceAll(uuid.New().String(), "-", "")
|
||||
}
|
||||
return s.tokenRepo.Create(ctx, token)
|
||||
}
|
||||
|
||||
func (s *TokenServiceImpl) GetByID(ctx context.Context, id int64) (*model.Token, error) {
|
||||
return s.tokenRepo.GetByID(ctx, id)
|
||||
}
|
||||
|
||||
func (s *TokenServiceImpl) GetByKey(ctx context.Context, key string) (*model.Token, error) {
|
||||
return s.tokenRepo.GetByKey(ctx, key)
|
||||
}
|
||||
|
||||
func (s *TokenServiceImpl) GetByUserID(ctx context.Context, userID int64) (*model.Token, error) {
|
||||
return s.tokenRepo.GetByUserID(ctx, userID)
|
||||
}
|
||||
|
||||
func (s *TokenServiceImpl) Update(ctx context.Context, token *model.Token) error {
|
||||
return s.tokenRepo.Update(ctx, token)
|
||||
}
|
||||
|
||||
func (s *TokenServiceImpl) UpdateWithCondition(ctx context.Context, token *model.Token, filters map[string]interface{}, updates map[string]interface{}) error {
|
||||
return s.tokenRepo.UpdateWithCondition(ctx, token, filters, updates)
|
||||
}
|
||||
|
||||
func (s *TokenServiceImpl) Delete(ctx context.Context, id int64) error {
|
||||
return s.tokenRepo.Delete(ctx, id, nil)
|
||||
}
|
||||
|
||||
func (s *TokenServiceImpl) Lists(ctx context.Context, limit, offset int) ([]*model.Token, int64, error) {
|
||||
return s.tokenRepo.ListWithFilters(ctx, limit, offset, nil)
|
||||
}
|
||||
|
||||
func (s *TokenServiceImpl) Disable(ctx context.Context, id int) error {
|
||||
return s.tokenRepo.Disable(ctx, id)
|
||||
}
|
||||
|
||||
func (s *TokenServiceImpl) Enable(ctx context.Context, id int) error {
|
||||
return s.tokenRepo.Enable(ctx, id)
|
||||
}
|
||||
62
internal/service/team/usage.go
Normal file
62
internal/service/team/usage.go
Normal file
@@ -0,0 +1,62 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"opencatd-open/internal/dao"
|
||||
dto "opencatd-open/internal/dto/team"
|
||||
"opencatd-open/internal/model"
|
||||
"opencatd-open/pkg/config"
|
||||
"time"
|
||||
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
var _ UsageService = (*usageService)(nil)
|
||||
|
||||
type UsageService interface {
|
||||
ListByUserID(ctx context.Context, userID int64, limit, offset int) ([]*model.Usage, error)
|
||||
ListByCapability(ctx context.Context, capability string, limit, offset int) ([]*model.Usage, error)
|
||||
ListByDateRange(ctx context.Context, start, end time.Time, filters map[string]interface{}) ([]*dto.UsageInfo, error)
|
||||
|
||||
Delete(ctx context.Context, id int64) error
|
||||
}
|
||||
|
||||
type usageService struct {
|
||||
ctx context.Context
|
||||
cfg *config.Config
|
||||
db *gorm.DB
|
||||
|
||||
usageDAO dao.UsageRepository
|
||||
dailyUsageDAO dao.DailyUsageRepository
|
||||
}
|
||||
|
||||
func NewUsageService(ctx context.Context, cfg *config.Config, db *gorm.DB, usageRepo dao.UsageRepository, dailyUsageRepo dao.DailyUsageRepository) UsageService {
|
||||
srv := &usageService{
|
||||
ctx: ctx,
|
||||
cfg: cfg,
|
||||
db: db,
|
||||
|
||||
usageDAO: usageRepo,
|
||||
dailyUsageDAO: dailyUsageRepo,
|
||||
}
|
||||
|
||||
// 启动异步处理goroutine
|
||||
|
||||
return srv
|
||||
}
|
||||
|
||||
func (s *usageService) ListByUserID(ctx context.Context, userID int64, limit int, offset int) ([]*model.Usage, error) {
|
||||
return s.usageDAO.ListByUserID(ctx, userID, limit, offset)
|
||||
}
|
||||
|
||||
func (s *usageService) ListByCapability(ctx context.Context, capability string, limit, offset int) ([]*model.Usage, error) {
|
||||
return s.usageDAO.ListByCapability(ctx, capability, limit, offset)
|
||||
}
|
||||
|
||||
func (s *usageService) ListByDateRange(ctx context.Context, start, end time.Time, filters map[string]interface{}) ([]*dto.UsageInfo, error) {
|
||||
return s.dailyUsageDAO.StatUserUsages(ctx, start, end, filters)
|
||||
}
|
||||
|
||||
func (s *usageService) Delete(ctx context.Context, id int64) error {
|
||||
return s.usageDAO.Delete(ctx, id)
|
||||
}
|
||||
679
internal/service/team/user.go
Normal file
679
internal/service/team/user.go
Normal file
@@ -0,0 +1,679 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"opencatd-open/internal/consts"
|
||||
"opencatd-open/internal/dao"
|
||||
"opencatd-open/internal/model"
|
||||
"regexp"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
"golang.org/x/exp/rand"
|
||||
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrUserNotFound = errors.New("user not found")
|
||||
ErrInvalidUserInput = errors.New("invalid user input")
|
||||
ErrUserExists = errors.New("user already exists")
|
||||
ErrInvalidPassword = errors.New("invalid password format")
|
||||
ErrPermissionDenied = errors.New("permission denied")
|
||||
ErrInvalidOperation = errors.New("invalid operation")
|
||||
ErrTransactionFailed = errors.New("transaction failed")
|
||||
)
|
||||
|
||||
// PasswordPolicy 定义密码策略
|
||||
type PasswordPolicy struct {
|
||||
MinLength int
|
||||
MaxLength int
|
||||
NeedNumber bool
|
||||
NeedUpper bool
|
||||
NeedLower bool
|
||||
NeedSymbol bool
|
||||
}
|
||||
|
||||
// 生成随机数字
|
||||
func generateNumber() string {
|
||||
rand.Seed(uint64(time.Now().UnixNano()))
|
||||
return string('0' + rand.Intn(10))
|
||||
}
|
||||
|
||||
// 生成随机大写字母
|
||||
func generateUpper() string {
|
||||
rand.Seed(uint64(time.Now().UnixNano()))
|
||||
return string('A' + rand.Intn(26))
|
||||
}
|
||||
|
||||
// 生成随机小写字母
|
||||
func generateLower() string {
|
||||
rand.Seed(uint64(time.Now().UnixNano()))
|
||||
return string('a' + rand.Intn(26))
|
||||
}
|
||||
|
||||
// 生成随机特殊符号
|
||||
func generateSymbol() string {
|
||||
rand.Seed(uint64(time.Now().UnixNano()))
|
||||
symbols := "!@#$%^&*"
|
||||
return string(symbols[rand.Intn(len(symbols))])
|
||||
}
|
||||
|
||||
// GeneratePassword 根据密码策略生成密码
|
||||
func GeneratePassword(policy PasswordPolicy) string {
|
||||
rand.Seed(uint64(time.Now().UnixNano()))
|
||||
|
||||
// 确保满足所有必须的字符类型
|
||||
var password string
|
||||
if policy.NeedNumber {
|
||||
password += generateNumber()
|
||||
}
|
||||
if policy.NeedUpper {
|
||||
password += generateUpper()
|
||||
}
|
||||
if policy.NeedLower {
|
||||
password += generateLower()
|
||||
}
|
||||
if policy.NeedSymbol {
|
||||
password += generateSymbol()
|
||||
}
|
||||
|
||||
// 计算还需要多少个字符
|
||||
remainingLength := policy.MinLength - len(password)
|
||||
if remainingLength < 0 {
|
||||
remainingLength = 0
|
||||
}
|
||||
// 剩余长度随机生成密码字符
|
||||
for i := 0; i < remainingLength; i++ {
|
||||
randType := rand.Intn(4) // 0:数字, 1:大写, 2:小写, 3:符号
|
||||
switch randType {
|
||||
case 0:
|
||||
password += generateNumber()
|
||||
case 1:
|
||||
password += generateUpper()
|
||||
case 2:
|
||||
password += generateLower()
|
||||
case 3:
|
||||
password += generateSymbol()
|
||||
}
|
||||
}
|
||||
|
||||
// 如果密码长度超过最大值,则截断
|
||||
if len(password) > policy.MaxLength {
|
||||
password = password[:policy.MaxLength]
|
||||
}
|
||||
|
||||
// 将密码打乱
|
||||
passwordRune := []rune(password)
|
||||
rand.Shuffle(len(passwordRune), func(i, j int) {
|
||||
passwordRune[i], passwordRune[j] = passwordRune[j], passwordRune[i]
|
||||
})
|
||||
|
||||
return string(passwordRune)
|
||||
}
|
||||
|
||||
var _ UserService = (*userService)(nil)
|
||||
|
||||
// UserService 定义用户服务的接口
|
||||
type UserService interface {
|
||||
CreateUser(ctx context.Context, user *model.User) error
|
||||
GetUser(ctx context.Context, id int64) (*model.User, error)
|
||||
GetUserByUsername(ctx context.Context, username string) (*model.User, error)
|
||||
UpdateUser(ctx context.Context, user *model.User, operatorID int64) error
|
||||
DeleteUser(ctx context.Context, id int64, operatorID int64) error
|
||||
ListUsers(ctx context.Context, limit, offset int, active string) ([]model.User, error)
|
||||
// EnableUser(ctx context.Context, id int64, operatorID int64) error
|
||||
// DisableUser(ctx context.Context, id int64, operatorID int64) error
|
||||
BatchEnableUsers(ctx context.Context, ids []int64, operatorID int64) error
|
||||
BatchDisableUsers(ctx context.Context, ids []int64, operatorID int64) error
|
||||
BatchDeleteUsers(ctx context.Context, ids []int64, operatorID int64) error
|
||||
ChangePassword(ctx context.Context, userID int64, oldPassword, newPassword string) error
|
||||
ResetPassword(ctx context.Context, userID int64, operatorID int64) error
|
||||
ValidatePassword(password string) error
|
||||
CheckPermission(ctx context.Context, requiredRole consts.UserRole) error
|
||||
}
|
||||
|
||||
// userService 实现 UserService 接口
|
||||
type userService struct {
|
||||
userRepo dao.UserRepository
|
||||
db *gorm.DB
|
||||
pwdPolicy PasswordPolicy
|
||||
}
|
||||
|
||||
// NewUserService 创建 UserService 实例
|
||||
func NewUserService(db *gorm.DB, userRepo dao.UserRepository) UserService {
|
||||
return &userService{
|
||||
userRepo: userRepo,
|
||||
db: db,
|
||||
pwdPolicy: PasswordPolicy{
|
||||
MinLength: 8,
|
||||
MaxLength: 32,
|
||||
NeedNumber: true,
|
||||
NeedUpper: true,
|
||||
NeedLower: true,
|
||||
NeedSymbol: true,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// hashPassword 使用 bcrypt 加密密码
|
||||
func (s *userService) hashPassword(password string) (string, error) {
|
||||
hashedBytes, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return string(hashedBytes), nil
|
||||
}
|
||||
|
||||
// comparePasswords 比较密码
|
||||
func (s *userService) comparePasswords(hashedPassword, password string) bool {
|
||||
err := bcrypt.CompareHashAndPassword([]byte(hashedPassword), []byte(password))
|
||||
return err == nil
|
||||
}
|
||||
|
||||
// ValidatePassword 验证密码是否符合策略
|
||||
func (s *userService) ValidatePassword(password string) error {
|
||||
if len(password) < s.pwdPolicy.MinLength || len(password) > s.pwdPolicy.MaxLength {
|
||||
return ErrInvalidPassword
|
||||
}
|
||||
|
||||
if s.pwdPolicy.NeedNumber && !regexp.MustCompile(`[0-9]`).MatchString(password) {
|
||||
return ErrInvalidPassword
|
||||
}
|
||||
|
||||
if s.pwdPolicy.NeedUpper && !regexp.MustCompile(`[A-Z]`).MatchString(password) {
|
||||
return ErrInvalidPassword
|
||||
}
|
||||
|
||||
if s.pwdPolicy.NeedLower && !regexp.MustCompile(`[a-z]`).MatchString(password) {
|
||||
return ErrInvalidPassword
|
||||
}
|
||||
|
||||
if s.pwdPolicy.NeedSymbol && !regexp.MustCompile(`[!@#$%^&*]`).MatchString(password) {
|
||||
return ErrInvalidPassword
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// CheckPermission 检查用户权限
|
||||
func (s *userService) CheckPermission(ctx context.Context, requiredRole consts.UserRole) error {
|
||||
userToken := ctx.Value("Token").(*model.Token)
|
||||
|
||||
// 检查用户角色
|
||||
if *userToken.User.Role < requiredRole {
|
||||
return ErrPermissionDenied
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// withTransaction 事务处理封装
|
||||
func (s *userService) withTransaction(ctx context.Context, fn func(tx *gorm.DB) error) error {
|
||||
tx := s.db.WithContext(ctx).Begin()
|
||||
if tx.Error != nil {
|
||||
return ErrTransactionFailed
|
||||
}
|
||||
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
tx.Rollback()
|
||||
panic(r)
|
||||
}
|
||||
}()
|
||||
|
||||
if err := fn(tx); err != nil {
|
||||
tx.Rollback()
|
||||
return err
|
||||
}
|
||||
|
||||
if err := tx.Commit().Error; err != nil {
|
||||
return ErrTransactionFailed
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// CreateUser 创建用户
|
||||
func (s *userService) CreateUser(ctx context.Context, user *model.User) error {
|
||||
if user == nil {
|
||||
return ErrInvalidUserInput
|
||||
}
|
||||
|
||||
if user.Password == "" {
|
||||
user.Password = GeneratePassword(s.pwdPolicy)
|
||||
}
|
||||
|
||||
// 使用事务处理
|
||||
return s.withTransaction(ctx, func(tx *gorm.DB) error {
|
||||
// 检查用户名是否已存在
|
||||
// _, err := s.userRepo.GetByID(user.ID)
|
||||
// if err != nil {
|
||||
// return err
|
||||
// }
|
||||
|
||||
// 加密密码
|
||||
hashedPassword, err := s.hashPassword(user.Password)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
user.Password = hashedPassword
|
||||
|
||||
return s.userRepo.Create(user)
|
||||
})
|
||||
}
|
||||
|
||||
// GetUser 根据 ID 获取用户
|
||||
func (s *userService) GetUser(ctx context.Context, id int64) (*model.User, error) {
|
||||
if id <= 0 {
|
||||
return nil, ErrInvalidUserInput
|
||||
}
|
||||
|
||||
user, err := s.userRepo.GetByID(id)
|
||||
if err != nil {
|
||||
return nil, err // 返回其他数据库错误
|
||||
}
|
||||
// 处理返回结果,清除敏感信息
|
||||
user.Password = "" // 清除密码信息
|
||||
|
||||
return user, nil
|
||||
}
|
||||
|
||||
// GetUserByUsername 根据用户名获取用户
|
||||
func (s *userService) GetUserByUsername(ctx context.Context, username string) (*model.User, error) {
|
||||
if username == "" {
|
||||
return nil, ErrInvalidUserInput
|
||||
}
|
||||
|
||||
user, err := s.userRepo.GetByUsername(username)
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, ErrUserNotFound
|
||||
}
|
||||
return nil, err // 返回其他数据库错误
|
||||
}
|
||||
|
||||
// 处理返回结果,清除敏感信息
|
||||
user.Password = "" // 清除密码信息
|
||||
|
||||
return user, nil
|
||||
}
|
||||
|
||||
// UpdateUser 更新用户信息
|
||||
func (s *userService) UpdateUser(ctx context.Context, user *model.User, operatorID int64) error {
|
||||
if user == nil || user.ID <= 0 {
|
||||
return ErrInvalidUserInput
|
||||
}
|
||||
|
||||
return s.withTransaction(ctx, func(tx *gorm.DB) error {
|
||||
// 检查用户是否存在
|
||||
existingUser, err := s.userRepo.GetByID(user.ID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 如果修改了用户名,检查新用户名是否已存在
|
||||
if user.Username != existingUser.Username {
|
||||
tmpUser, err := s.userRepo.GetByUsername(user.Username)
|
||||
if err == nil && tmpUser != nil && tmpUser.ID != user.ID {
|
||||
return ErrUserExists
|
||||
}
|
||||
}
|
||||
|
||||
// 保持原有密码
|
||||
user.Password = existingUser.Password
|
||||
user.UpdatedAt = time.Now().Unix()
|
||||
|
||||
return s.userRepo.Update(user)
|
||||
})
|
||||
}
|
||||
|
||||
// ChangePassword 修改密码
|
||||
func (s *userService) ChangePassword(ctx context.Context, userID int64, oldPassword, newPassword string) error {
|
||||
// 验证新密码
|
||||
if err := s.ValidatePassword(newPassword); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return s.withTransaction(ctx, func(tx *gorm.DB) error {
|
||||
user, err := s.userRepo.GetByID(userID)
|
||||
if err != nil {
|
||||
return ErrUserNotFound
|
||||
}
|
||||
|
||||
// 验证旧密码
|
||||
if !s.comparePasswords(user.Password, oldPassword) {
|
||||
return ErrInvalidPassword
|
||||
}
|
||||
|
||||
// 加密新密码
|
||||
hashedPassword, err := s.hashPassword(newPassword)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
user.Password = hashedPassword
|
||||
user.UpdatedAt = time.Now().Unix()
|
||||
|
||||
return s.userRepo.Update(user)
|
||||
})
|
||||
}
|
||||
|
||||
// ResetPassword 重置密码
|
||||
func (s *userService) ResetPassword(ctx context.Context, userID int64, operatorID int64) error {
|
||||
return s.withTransaction(ctx, func(tx *gorm.DB) error {
|
||||
user, err := s.userRepo.GetByID(userID)
|
||||
if err != nil {
|
||||
return ErrUserNotFound
|
||||
}
|
||||
|
||||
// 生成随机密码
|
||||
newPassword := generateRandomPassword()
|
||||
hashedPassword, err := s.hashPassword(newPassword)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
user.Password = hashedPassword
|
||||
user.UpdatedAt = time.Now().Unix()
|
||||
|
||||
// TODO: 发送新密码给用户邮箱
|
||||
|
||||
return s.userRepo.Update(user)
|
||||
})
|
||||
}
|
||||
|
||||
// ListUsers 获取用户列表(增加过滤功能)
|
||||
func (s *userService) ListUsers(ctx context.Context, limit, offset int, active string) ([]model.User, error) {
|
||||
if limit < 0 {
|
||||
limit = 20
|
||||
}
|
||||
if offset < 0 {
|
||||
offset = 0
|
||||
}
|
||||
var users []model.User
|
||||
var err error
|
||||
if active != "" {
|
||||
users, _, err = s.userRepo.List(limit, offset, map[string]interface{}{"active in ?": strings.Split(active, ",")})
|
||||
} else {
|
||||
users, _, err = s.userRepo.List(limit, offset, nil)
|
||||
}
|
||||
|
||||
return users, err
|
||||
}
|
||||
|
||||
// generateRandomPassword 生成随机密码
|
||||
func generateRandomPassword() string {
|
||||
const (
|
||||
lowerChars = "abcdefghijklmnopqrstuvwxyz"
|
||||
upperChars = "ABCDEFGHIJKLMNOPQRSTUVWXYZ"
|
||||
numberChars = "0123456789"
|
||||
specialChars = "!@#$%^&*"
|
||||
)
|
||||
// rand.NewSource(uint64(time.Now().UnixNano()))
|
||||
|
||||
// 确保每种字符都至少出现一次
|
||||
password := []string{
|
||||
string(lowerChars[rand.Intn(len(lowerChars))]),
|
||||
string(upperChars[rand.Intn(len(upperChars))]),
|
||||
string(numberChars[rand.Intn(len(numberChars))]),
|
||||
string(specialChars[rand.Intn(len(specialChars))]),
|
||||
}
|
||||
|
||||
// 所有可用字符
|
||||
allChars := lowerChars + upperChars + numberChars + specialChars
|
||||
|
||||
// 生成剩余的12个字符
|
||||
for i := 0; i < 12; i++ {
|
||||
password = append(password, string(allChars[rand.Intn(len(allChars))]))
|
||||
}
|
||||
|
||||
// 打乱密码字符顺序
|
||||
rand.Shuffle(len(password), func(i, j int) {
|
||||
password[i], password[j] = password[j], password[i]
|
||||
})
|
||||
|
||||
return strings.Join(password, "")
|
||||
}
|
||||
|
||||
// DeleteUser 删除用户
|
||||
func (s *userService) DeleteUser(ctx context.Context, id int64, operatorID int64) error {
|
||||
// 检查参数
|
||||
if id <= 0 {
|
||||
return ErrInvalidUserInput
|
||||
}
|
||||
|
||||
if err := s.CheckPermission(ctx, consts.RoleAdmin); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 不允许删除自己
|
||||
if id == operatorID {
|
||||
return ErrInvalidOperation
|
||||
}
|
||||
|
||||
return s.withTransaction(ctx, func(tx *gorm.DB) error {
|
||||
// 检查用户是否存在
|
||||
user, err := s.userRepo.GetByID(id)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 检查是否试图删除管理员
|
||||
if *user.Role == consts.RoleAdmin {
|
||||
return ErrPermissionDenied
|
||||
}
|
||||
|
||||
return s.userRepo.Delete(id)
|
||||
})
|
||||
}
|
||||
|
||||
// EnableUser 启用用户
|
||||
// func (s *userService) EnableUser(ctx context.Context, id int64, operatorID int64) error {
|
||||
// // 检查参数
|
||||
// if id <= 0 {
|
||||
// return ErrInvalidUserInput
|
||||
// }
|
||||
|
||||
// // 检查操作者权限
|
||||
// if err := s.CheckPermission(ctx, consts.RoleAdmin); err != nil {
|
||||
// return err
|
||||
// }
|
||||
|
||||
// return s.withTransaction(ctx, func(tx *gorm.DB) error {
|
||||
// // 检查用户是否存在
|
||||
// user, err := s.userRepo.GetByID(id)
|
||||
// if err != nil {
|
||||
// return ErrUserNotFound
|
||||
// }
|
||||
|
||||
// // 如果用户已经是启用状态,返回成功
|
||||
// if user.Status == consts.StatusEnabled {
|
||||
// return nil
|
||||
// }
|
||||
|
||||
// return s.userRepo.Enable(id)
|
||||
// })
|
||||
// }
|
||||
|
||||
// DisableUser 禁用用户
|
||||
// func (s *userService) DisableUser(ctx context.Context, id int64, operatorID int64) error {
|
||||
// // 检查参数
|
||||
// if id <= 0 {
|
||||
// return ErrInvalidUserInput
|
||||
// }
|
||||
|
||||
// // 检查操作者权限
|
||||
// if err := s.CheckPermission(ctx, consts.RoleAdmin); err != nil {
|
||||
// return err
|
||||
// }
|
||||
|
||||
// // 不允许禁用自己
|
||||
// if id == operatorID {
|
||||
// return ErrInvalidOperation
|
||||
// }
|
||||
|
||||
// return s.withTransaction(ctx, func(tx *gorm.DB) error {
|
||||
// // 检查用户是否存在
|
||||
// user, err := s.userRepo.GetByID(id)
|
||||
// if err != nil {
|
||||
// return ErrUserNotFound
|
||||
// }
|
||||
|
||||
// // 检查是否试图禁用超级管理员
|
||||
// if user.Role == consts.RoleAdmin {
|
||||
// return ErrPermissionDenied
|
||||
// }
|
||||
|
||||
// // 如果用户已经是禁用状态,返回成功
|
||||
// if user.Status == consts.StatusDisabled {
|
||||
// return nil
|
||||
// }
|
||||
|
||||
// return s.userRepo.Disable(id)
|
||||
// })
|
||||
// }
|
||||
|
||||
// BatchEnableUsers 批量启用用户
|
||||
func (s *userService) BatchEnableUsers(ctx context.Context, ids []int64, operatorID int64) error {
|
||||
// 检查参数
|
||||
if len(ids) == 0 {
|
||||
return ErrInvalidUserInput
|
||||
}
|
||||
|
||||
// 检查操作者权限
|
||||
if err := s.CheckPermission(ctx, consts.RoleAdmin); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return s.withTransaction(ctx, func(tx *gorm.DB) error {
|
||||
// 检查所有用户是否存在,并收集当前状态
|
||||
enabledUsers := make([]int64, 0)
|
||||
for _, id := range ids {
|
||||
user, err := s.userRepo.GetByID(id)
|
||||
if err != nil {
|
||||
return ErrUserNotFound
|
||||
}
|
||||
if *user.Active == true {
|
||||
enabledUsers = append(enabledUsers, id)
|
||||
}
|
||||
}
|
||||
|
||||
// 如果所有用户都已经是启用状态,返回成功
|
||||
if len(enabledUsers) == len(ids) {
|
||||
return nil
|
||||
}
|
||||
|
||||
// 过滤掉已经启用的用户,只处理需要启用的用户
|
||||
toEnableIds := make([]int64, 0)
|
||||
for _, id := range ids {
|
||||
if !contains(enabledUsers, id) {
|
||||
toEnableIds = append(toEnableIds, id)
|
||||
}
|
||||
}
|
||||
|
||||
if len(toEnableIds) > 0 {
|
||||
return s.userRepo.BatchEnable(toEnableIds, nil)
|
||||
}
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
// BatchDisableUsers 批量禁用用户
|
||||
func (s *userService) BatchDisableUsers(ctx context.Context, ids []int64, operatorID int64) error {
|
||||
// 检查参数
|
||||
if len(ids) == 0 {
|
||||
return ErrInvalidUserInput
|
||||
}
|
||||
|
||||
// 检查操作者权限
|
||||
if err := s.CheckPermission(ctx, consts.RoleAdmin); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 不允许包含自己
|
||||
if contains(ids, operatorID) {
|
||||
return ErrInvalidOperation
|
||||
}
|
||||
|
||||
return s.withTransaction(ctx, func(tx *gorm.DB) error {
|
||||
// 检查所有用户是否存在
|
||||
disabledUsers := make([]int64, 0)
|
||||
for _, id := range ids {
|
||||
user, err := s.userRepo.GetByID(id)
|
||||
if err != nil {
|
||||
return ErrUserNotFound
|
||||
}
|
||||
// 不允许禁用管理员
|
||||
if *user.Role == consts.RoleAdmin {
|
||||
return ErrPermissionDenied
|
||||
}
|
||||
if user.Status == consts.StatusDisabled {
|
||||
disabledUsers = append(disabledUsers, id)
|
||||
}
|
||||
}
|
||||
|
||||
// 如果所有用户都已经是禁用状态,返回成功
|
||||
if len(disabledUsers) == len(ids) {
|
||||
return nil
|
||||
}
|
||||
|
||||
// 过滤掉已经禁用的用户,只处理需要禁用的用户
|
||||
toDisableIds := make([]int64, 0)
|
||||
for _, id := range ids {
|
||||
if !contains(disabledUsers, id) {
|
||||
toDisableIds = append(toDisableIds, id)
|
||||
}
|
||||
}
|
||||
|
||||
if len(toDisableIds) > 0 {
|
||||
return s.userRepo.BatchDisable(toDisableIds, nil)
|
||||
}
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
// BatchDeleteUsers 批量删除用户
|
||||
func (s *userService) BatchDeleteUsers(ctx context.Context, ids []int64, operatorID int64) error {
|
||||
// 检查参数
|
||||
if len(ids) == 0 {
|
||||
return ErrInvalidUserInput
|
||||
}
|
||||
|
||||
// 检查操作者权限
|
||||
if err := s.CheckPermission(ctx, consts.RoleAdmin); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 不允许包含自己
|
||||
if contains(ids, operatorID) {
|
||||
return ErrInvalidOperation
|
||||
}
|
||||
|
||||
return s.withTransaction(ctx, func(tx *gorm.DB) error {
|
||||
// 检查所有用户是否存在,并确保不会删除管理员
|
||||
for _, id := range ids {
|
||||
user, err := s.userRepo.GetByID(id)
|
||||
if err != nil {
|
||||
return ErrUserNotFound
|
||||
}
|
||||
if *user.Role == consts.RoleAdmin {
|
||||
return ErrPermissionDenied
|
||||
}
|
||||
}
|
||||
|
||||
return s.userRepo.BatchDelete(ids, nil)
|
||||
})
|
||||
}
|
||||
|
||||
// contains 检查切片中是否包含特定值
|
||||
func contains[T comparable](slice []T, item T) bool {
|
||||
for _, s := range slice {
|
||||
if s == item {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
251
internal/service/token.go
Normal file
251
internal/service/token.go
Normal file
@@ -0,0 +1,251 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"opencatd-open/internal/consts"
|
||||
"opencatd-open/internal/dao"
|
||||
"opencatd-open/internal/model"
|
||||
"opencatd-open/internal/utils"
|
||||
"strings"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// var _ TokenService = (*TokenServiceImpl)(nil)
|
||||
|
||||
// type TokenService interface {
|
||||
// }
|
||||
|
||||
type TokenServiceImpl struct {
|
||||
db *gorm.DB
|
||||
tokenRepo dao.TokenRepository
|
||||
}
|
||||
|
||||
func NewTokenService(db *gorm.DB, tokenRepo dao.TokenRepository) *TokenServiceImpl {
|
||||
return &TokenServiceImpl{
|
||||
db: db,
|
||||
tokenRepo: tokenRepo,
|
||||
}
|
||||
}
|
||||
|
||||
func (t *TokenServiceImpl) CreateToken(ctx context.Context, token *model.Token) error {
|
||||
if token.UserID == 0 {
|
||||
token.UserID = ctx.Value("user_id").(int64)
|
||||
}
|
||||
if token.Active == nil {
|
||||
token.Active = utils.ToPtr(true)
|
||||
}
|
||||
if token.UnlimitedQuota == nil {
|
||||
token.UnlimitedQuota = utils.ToPtr(true)
|
||||
}
|
||||
if token.ExpiredAt == nil {
|
||||
token.ExpiredAt = utils.ToPtr(int64(-1))
|
||||
}
|
||||
|
||||
if token.Key == "" {
|
||||
token.Key = "sk-team-" + strings.ReplaceAll(uuid.New().String(), "-", "")
|
||||
}
|
||||
if !strings.HasPrefix(token.Key, "sk-team-") {
|
||||
token.Key = "sk-team-" + strings.ReplaceAll(token.Key, " ", "")
|
||||
}
|
||||
return t.tokenRepo.Create(ctx, token)
|
||||
}
|
||||
|
||||
func (t *TokenServiceImpl) GetToken(ctx context.Context, id int64) (*model.Token, error) {
|
||||
userid := ctx.Value("user_id").(int64)
|
||||
tk := &model.Token{}
|
||||
return tk, t.db.Model(&model.Token{}).Where("user_id = ?", userid).Where("id = ?", id).First(tk).Error
|
||||
}
|
||||
|
||||
func (t *TokenServiceImpl) ListToken(ctx context.Context, limit, offset int, active []string) ([]*model.Token, int64, error) {
|
||||
userid := ctx.Value("user_id").(int64)
|
||||
condition := make(map[string]interface{})
|
||||
condition["user_id = ?"] = userid
|
||||
if len(active) > 0 {
|
||||
condition["active IN ?"] = utils.StringToBool(active)
|
||||
return t.tokenRepo.ListWithFilters(ctx, limit, offset, condition)
|
||||
}
|
||||
return t.tokenRepo.ListWithFilters(ctx, limit, offset, condition)
|
||||
}
|
||||
|
||||
func (t *TokenServiceImpl) UpdateToken(ctx context.Context, token *model.Token) error {
|
||||
userid := ctx.Value("user_id").(int64) // 操作者
|
||||
userRoleValue := ctx.Value("user_role")
|
||||
if userRoleValue == nil {
|
||||
return fmt.Errorf("user role not found in context")
|
||||
}
|
||||
|
||||
role, ok := userRoleValue.(*consts.UserRole) // 操作角色
|
||||
if !ok {
|
||||
return fmt.Errorf("user role in context is not an integer")
|
||||
}
|
||||
|
||||
switch {
|
||||
case *role < consts.RoleAdmin:
|
||||
if userid != token.UserID {
|
||||
return fmt.Errorf("Permission denied")
|
||||
}
|
||||
case *role == consts.RoleAdmin:
|
||||
if *role <= *token.User.Role {
|
||||
return fmt.Errorf("Permission denied")
|
||||
}
|
||||
}
|
||||
|
||||
return t.db.Model(&model.Token{}).Where("id = ?", token.ID).Updates(token).Error
|
||||
}
|
||||
|
||||
func (t *TokenServiceImpl) ResetToken(ctx context.Context, id int64) error {
|
||||
userid := ctx.Value("user_id").(int64) // 操作者
|
||||
userRoleValue := ctx.Value("user_role")
|
||||
if userRoleValue == nil {
|
||||
return fmt.Errorf("user role not found in context")
|
||||
}
|
||||
|
||||
role, ok := userRoleValue.(*consts.UserRole) // 操作角色
|
||||
if !ok {
|
||||
return fmt.Errorf("user role in context is not an integer")
|
||||
}
|
||||
switch {
|
||||
case *role < consts.RoleAdmin:
|
||||
if userid != id {
|
||||
return fmt.Errorf("Permission denied")
|
||||
}
|
||||
case *role == consts.RoleAdmin:
|
||||
var user = &model.User{}
|
||||
if err := t.db.Model(&model.User{}).Where("id = ?", id).First(user).Error; err != nil {
|
||||
return fmt.Errorf("User not found")
|
||||
}
|
||||
if *role <= *user.Role {
|
||||
return fmt.Errorf("Permission denied")
|
||||
}
|
||||
}
|
||||
|
||||
token := "sk-team-" + strings.ReplaceAll(uuid.New().String(), "-", "")
|
||||
return t.db.Model(&model.Token{}).Where("user_id = ?", userid).Where("id = ?", id).Update("token", token).Error
|
||||
}
|
||||
func (t *TokenServiceImpl) DeleteToken(ctx context.Context, id int64) error {
|
||||
token, err := t.tokenRepo.GetByID(ctx, id)
|
||||
if err != nil {
|
||||
return fmt.Errorf("Token not found")
|
||||
}
|
||||
if token.User == nil {
|
||||
return fmt.Errorf("Token user not found")
|
||||
}
|
||||
|
||||
role := ctx.Value("user_role").(*consts.UserRole) // 操作角色
|
||||
userid := ctx.Value("user_id").(int64) // 操作者
|
||||
|
||||
switch {
|
||||
case *role < consts.RoleAdmin:
|
||||
if userid != token.UserID {
|
||||
return fmt.Errorf("Permission denied")
|
||||
}
|
||||
case *role == consts.RoleAdmin:
|
||||
if *role <= *token.User.Role {
|
||||
return fmt.Errorf("Permission denied")
|
||||
}
|
||||
}
|
||||
|
||||
return t.db.Model(&model.Token{}).Where("id = ?", id).Delete(&model.Token{}).Error
|
||||
}
|
||||
|
||||
func (t *TokenServiceImpl) DeleteTokens(ctx context.Context, userid int64, ids []int64) error {
|
||||
operator_id := ctx.Value("user_id").(int64)
|
||||
|
||||
roleValue := ctx.Value("user_role")
|
||||
if roleValue == nil {
|
||||
return fmt.Errorf("user role not found in context")
|
||||
}
|
||||
operator_role, ok := roleValue.(*consts.UserRole) // 操作角色
|
||||
if !ok {
|
||||
return fmt.Errorf("user role in context is not an integer")
|
||||
}
|
||||
|
||||
switch {
|
||||
case *operator_role < consts.RoleAdmin:
|
||||
if operator_id != userid {
|
||||
return fmt.Errorf("Permission denied")
|
||||
}
|
||||
return t.tokenRepo.BatchDelete(ctx, ids, map[string]interface{}{"name != ?": "default", "user_id = ?": userid})
|
||||
case *operator_role == consts.RoleAdmin:
|
||||
var user = &model.User{}
|
||||
if err := t.db.Model(&model.User{}).Where("id = ?", userid).First(user).Error; err != nil {
|
||||
return fmt.Errorf("User not found")
|
||||
}
|
||||
if *operator_role <= *user.Role {
|
||||
return fmt.Errorf("Permission denied")
|
||||
}
|
||||
return t.tokenRepo.BatchDelete(ctx, ids, map[string]interface{}{"name != ?": "default", "user_id = ?": userid})
|
||||
default:
|
||||
return t.tokenRepo.BatchDelete(ctx, ids, map[string]interface{}{"name != ?": "default"})
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func (t *TokenServiceImpl) EnableTokens(ctx context.Context, userid int64, ids []int64) error {
|
||||
operator_id := ctx.Value("user_id").(int64)
|
||||
|
||||
roleValue := ctx.Value("user_role")
|
||||
if roleValue == nil {
|
||||
return fmt.Errorf("user role not found in context")
|
||||
}
|
||||
operator_role, ok := roleValue.(*consts.UserRole) // 操作角色
|
||||
if !ok {
|
||||
return fmt.Errorf("user role in context is not an integer")
|
||||
}
|
||||
|
||||
switch {
|
||||
case *operator_role < consts.RoleAdmin:
|
||||
if operator_id != userid {
|
||||
return fmt.Errorf("Permission denied")
|
||||
}
|
||||
return t.tokenRepo.BatchEnable(ctx, ids, map[string]interface{}{"user_id = ?": userid})
|
||||
case *operator_role == consts.RoleAdmin:
|
||||
var user = &model.User{}
|
||||
if err := t.db.Model(&model.User{}).Where("id = ?", userid).First(user).Error; err != nil {
|
||||
return fmt.Errorf("User not found")
|
||||
}
|
||||
if *operator_role <= *user.Role {
|
||||
return fmt.Errorf("Permission denied")
|
||||
}
|
||||
return t.tokenRepo.BatchEnable(ctx, ids, map[string]interface{}{"user_id = ?": userid})
|
||||
default:
|
||||
return t.tokenRepo.BatchEnable(ctx, ids, nil)
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func (t *TokenServiceImpl) DisableTokens(ctx context.Context, userid int64, ids []int64) error {
|
||||
operator_id := ctx.Value("user_id").(int64)
|
||||
|
||||
roleValue := ctx.Value("user_role")
|
||||
if roleValue == nil {
|
||||
return fmt.Errorf("user role not found in context")
|
||||
}
|
||||
operator_role, ok := roleValue.(*consts.UserRole) // 操作角色
|
||||
if !ok {
|
||||
return fmt.Errorf("user role in context is not an integer")
|
||||
}
|
||||
|
||||
switch {
|
||||
case *operator_role < consts.RoleAdmin:
|
||||
if operator_id != userid {
|
||||
return fmt.Errorf("Permission denied")
|
||||
}
|
||||
return t.tokenRepo.BatchDisable(ctx, ids, map[string]interface{}{"user_id =": userid})
|
||||
case *operator_role == consts.RoleAdmin:
|
||||
var user = &model.User{}
|
||||
if err := t.db.Model(&model.User{}).Where("id = ?", userid).First(user).Error; err != nil {
|
||||
return fmt.Errorf("User not found")
|
||||
}
|
||||
if *operator_role <= *user.Role {
|
||||
return fmt.Errorf("Permission denied")
|
||||
}
|
||||
return t.tokenRepo.BatchDisable(ctx, ids, map[string]interface{}{"user_id =": userid})
|
||||
default:
|
||||
return t.tokenRepo.BatchDisable(ctx, ids, nil)
|
||||
}
|
||||
|
||||
}
|
||||
22
internal/service/usage.go
Normal file
22
internal/service/usage.go
Normal file
@@ -0,0 +1,22 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"opencatd-open/pkg/config"
|
||||
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
type UsageService struct {
|
||||
Ctx context.Context
|
||||
Cfg *config.Config
|
||||
DB *gorm.DB
|
||||
}
|
||||
|
||||
func NewUsageService(ctx context.Context, cfg *config.Config, db *gorm.DB) *UsageService {
|
||||
return &UsageService{
|
||||
Ctx: ctx,
|
||||
Cfg: cfg,
|
||||
DB: db,
|
||||
}
|
||||
}
|
||||
320
internal/service/user.go
Normal file
320
internal/service/user.go
Normal file
@@ -0,0 +1,320 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"opencatd-open/internal/auth"
|
||||
"opencatd-open/internal/consts"
|
||||
"opencatd-open/internal/dao"
|
||||
"opencatd-open/internal/dto"
|
||||
"opencatd-open/internal/model"
|
||||
"opencatd-open/internal/utils"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
type UserServiceImpl struct {
|
||||
db *gorm.DB
|
||||
userRepo dao.UserRepository
|
||||
}
|
||||
|
||||
func NewUserService(db *gorm.DB, userRepo dao.UserRepository) *UserServiceImpl {
|
||||
return &UserServiceImpl{
|
||||
db: db,
|
||||
userRepo: userRepo,
|
||||
}
|
||||
}
|
||||
|
||||
func (s *UserServiceImpl) Register(ctx context.Context, req *model.User) error {
|
||||
var _user model.User
|
||||
var count int64
|
||||
err := s.db.Model(&model.User{}).Count(&count).Error
|
||||
if err != nil {
|
||||
return fmt.Errorf("username or email already exists")
|
||||
}
|
||||
if count == 0 {
|
||||
_user.Name = "root"
|
||||
_user.Role = utils.ToPtr(consts.RoleRoot)
|
||||
_user.Active = utils.ToPtr(true)
|
||||
_user.UnlimitedQuota = utils.ToPtr(true)
|
||||
}
|
||||
_user.Password, err = utils.HashPassword(req.Password)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_user.Username = req.Username
|
||||
_user.Email = req.Email
|
||||
_user.Tokens = []model.Token{
|
||||
{
|
||||
Name: "default",
|
||||
Key: "sk-team-" + strings.ReplaceAll(uuid.New().String(), "-", ""),
|
||||
},
|
||||
}
|
||||
|
||||
return s.userRepo.Create(&_user)
|
||||
}
|
||||
|
||||
func (s *UserServiceImpl) Login(ctx context.Context, req *dto.User) (*dto.Auth, error) {
|
||||
var _user model.User
|
||||
if err := s.db.Model(&model.User{}).Where("username = ?", req.Username).First(&_user).Error; err != nil {
|
||||
if err := s.db.Model(&model.User{}).Where("email = ?", req.Username).First(&_user).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
if utils.CheckPassword(_user.Password, req.Password) {
|
||||
day := 86400
|
||||
at, err := auth.GenerateTokenPair(&_user, consts.SecretKey, time.Duration(day)*time.Second, time.Duration(day*7)*time.Second)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &dto.Auth{
|
||||
Token: at.AccessToken,
|
||||
ExpiresIn: time.Now().Add(time.Duration(day) * time.Second).Unix(),
|
||||
}, nil
|
||||
}
|
||||
return nil, fmt.Errorf("密码错误")
|
||||
}
|
||||
|
||||
func (s *UserServiceImpl) Profile(ctx context.Context) (*model.User, error) {
|
||||
id := ctx.Value("user_id").(int64)
|
||||
return s.userRepo.GetByID(id)
|
||||
}
|
||||
|
||||
func (s *UserServiceImpl) List(ctx context.Context, limit, offset int, active []string) ([]model.User, int64, error) {
|
||||
userRoleValue := ctx.Value("user_role")
|
||||
if userRoleValue == nil {
|
||||
return nil, 0, fmt.Errorf("user role not found in context")
|
||||
}
|
||||
|
||||
role, ok := userRoleValue.(*consts.UserRole)
|
||||
if !ok {
|
||||
return nil, 0, fmt.Errorf("user role in context is not an integer")
|
||||
}
|
||||
|
||||
if *role < consts.RoleAdmin {
|
||||
return nil, 0, fmt.Errorf("Unauthorized")
|
||||
} else if *role < consts.RoleRoot { // 管理员只能查看普通用户
|
||||
var condition = map[string]interface{}{"role = ?": consts.RoleUser}
|
||||
if len(active) > 0 {
|
||||
boolCondition := utils.StringToBool(active)
|
||||
condition["active IN ?"] = boolCondition
|
||||
}
|
||||
return s.userRepo.List(limit, offset, condition)
|
||||
} else {
|
||||
var condition = make(map[string]interface{})
|
||||
if len(active) > 0 {
|
||||
boolCondition := utils.StringToBool(active)
|
||||
condition["active IN ?"] = boolCondition
|
||||
}
|
||||
return s.userRepo.List(limit, offset, condition)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *UserServiceImpl) Create(ctx context.Context, req *model.User) error {
|
||||
userRoleValue := ctx.Value("user_role")
|
||||
if userRoleValue == nil {
|
||||
return fmt.Errorf("user role not found in context")
|
||||
}
|
||||
|
||||
role, ok := userRoleValue.(*consts.UserRole)
|
||||
if !ok {
|
||||
return fmt.Errorf("user role in context is not an integer")
|
||||
}
|
||||
var _user model.User
|
||||
|
||||
if *role < consts.RoleAdmin {
|
||||
return fmt.Errorf("Forbidden")
|
||||
} else if *role < consts.RoleRoot {
|
||||
_user.Role = utils.ToPtr(consts.RoleRoot)
|
||||
} else {
|
||||
_user.Role = req.Role
|
||||
}
|
||||
_user.Username = req.Username
|
||||
_user.Name = req.Name
|
||||
_user.Email = req.Email
|
||||
_user.Active = req.Active
|
||||
_user.Quota = req.Quota
|
||||
_user.UnlimitedQuota = req.UnlimitedQuota
|
||||
_user.Language = req.Language
|
||||
if hashpass, err := utils.HashPassword(req.Password); err != nil {
|
||||
return err
|
||||
} else {
|
||||
_user.Password = hashpass
|
||||
}
|
||||
_user.Tokens = []model.Token{
|
||||
{
|
||||
Name: "default",
|
||||
Key: "sk-team-" + strings.ReplaceAll(uuid.New().String(), "-", ""),
|
||||
},
|
||||
}
|
||||
|
||||
return s.userRepo.Create(&_user)
|
||||
}
|
||||
func (s *UserServiceImpl) GetByID(ctx context.Context, id int64) (*model.User, error) {
|
||||
return s.userRepo.GetByID(id)
|
||||
}
|
||||
|
||||
func (s *UserServiceImpl) Update(ctx context.Context, user *model.User) error {
|
||||
_user := ctx.Value("user").(*model.User) // 被更新的用户
|
||||
if _user == nil {
|
||||
return fmt.Errorf("user not found in context")
|
||||
}
|
||||
userid := ctx.Value("user_id").(int64) // 操作者
|
||||
userRoleValue := ctx.Value("user_role")
|
||||
if userRoleValue == nil {
|
||||
return fmt.Errorf("user role not found in context")
|
||||
}
|
||||
|
||||
role, ok := userRoleValue.(*consts.UserRole) // 操作者角色
|
||||
if !ok {
|
||||
return fmt.Errorf("user role in context is not an integer")
|
||||
}
|
||||
switch {
|
||||
case *role < consts.RoleAdmin:
|
||||
if user.ID != userid {
|
||||
return fmt.Errorf("Permission denied")
|
||||
}
|
||||
case *role == consts.RoleAdmin:
|
||||
if *user.Role > *role { // 更新的用户角色不能高于操作者角色
|
||||
return fmt.Errorf("Permission denied")
|
||||
}
|
||||
if *_user.Role >= *role { // 管理员之间不能被修改
|
||||
return fmt.Errorf("Permission denied")
|
||||
}
|
||||
case *role > consts.RoleAdmin: // 根不能被修改
|
||||
if user.ID == userid {
|
||||
user.Role = role // root不能修改自己的角色
|
||||
} else {
|
||||
if user.Role != nil && user.Role == utils.ToPtr(consts.RoleRoot) {
|
||||
return fmt.Errorf("Root user Only one can exist")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if user.Name != "" {
|
||||
_user.Name = user.Name
|
||||
}
|
||||
if user.Username != "" {
|
||||
_user.Username = user.Username
|
||||
}
|
||||
if user.Email != "" {
|
||||
_user.Email = user.Email
|
||||
_user.EmailVerified = utils.ToPtr(false)
|
||||
}
|
||||
if user.Active != nil {
|
||||
_user.Active = user.Active
|
||||
}
|
||||
if user.Role != nil {
|
||||
_user.Role = user.Role
|
||||
}
|
||||
if user.Active != nil {
|
||||
_user.Active = user.Active
|
||||
}
|
||||
if user.Quota != nil {
|
||||
_user.Quota = user.Quota
|
||||
}
|
||||
if user.UsedQuota != nil {
|
||||
_user.UsedQuota = user.UsedQuota
|
||||
}
|
||||
if user.UnlimitedQuota != nil {
|
||||
_user.UnlimitedQuota = user.UnlimitedQuota
|
||||
}
|
||||
if user.Timezone != "" {
|
||||
_user.Timezone = user.Timezone
|
||||
}
|
||||
if user.Language != "" {
|
||||
_user.Language = user.Language
|
||||
}
|
||||
return s.userRepo.Update(_user)
|
||||
}
|
||||
|
||||
func (s *UserServiceImpl) Delete(ctx context.Context, id int64) error {
|
||||
_user, err := s.userRepo.GetByID(id) // 被更新的用户
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
userid := ctx.Value("user_id").(int64)
|
||||
userRoleValue := ctx.Value("user_role")
|
||||
if userRoleValue == nil {
|
||||
return fmt.Errorf("user role not found in context")
|
||||
}
|
||||
role, ok := userRoleValue.(*consts.UserRole) // 操作者
|
||||
if !ok {
|
||||
return fmt.Errorf("user role in context is not an integer")
|
||||
}
|
||||
|
||||
switch {
|
||||
case *role < consts.RoleAdmin:
|
||||
if _user.ID != userid {
|
||||
return fmt.Errorf("Permission denied")
|
||||
}
|
||||
case *role == consts.RoleAdmin:
|
||||
if *_user.Role >= *role { // 管理员之间不能被修改
|
||||
return fmt.Errorf("Permission denied")
|
||||
}
|
||||
case *_user.Role == consts.RoleRoot: // 根不能被修改
|
||||
return fmt.Errorf("Root user can not be modified")
|
||||
}
|
||||
|
||||
return s.userRepo.Delete(id)
|
||||
}
|
||||
|
||||
func (s *UserServiceImpl) BatchDelete(ctx context.Context, ids []int64) error {
|
||||
userRoleValue := ctx.Value("user_role")
|
||||
if userRoleValue == nil {
|
||||
return fmt.Errorf("user role not found in context")
|
||||
}
|
||||
role, ok := userRoleValue.(*consts.UserRole)
|
||||
if !ok {
|
||||
return fmt.Errorf("user role in context is not an integer")
|
||||
}
|
||||
|
||||
switch {
|
||||
case *role < consts.RoleAdmin:
|
||||
return fmt.Errorf("Unauthorized")
|
||||
case *role == consts.RoleAdmin:
|
||||
return s.userRepo.BatchDelete(ids, []string{fmt.Sprintf("role < %d", role)})
|
||||
}
|
||||
return s.userRepo.BatchDelete(ids, []string{fmt.Sprintf("role < %d", consts.RoleRoot)})
|
||||
}
|
||||
|
||||
func (s *UserServiceImpl) BatchEnable(ctx context.Context, ids []int64) error {
|
||||
userRoleValue := ctx.Value("user_role")
|
||||
if userRoleValue == nil {
|
||||
return fmt.Errorf("user role not found in context")
|
||||
}
|
||||
role, ok := userRoleValue.(*consts.UserRole)
|
||||
if !ok {
|
||||
return fmt.Errorf("user role in context is not an integer")
|
||||
}
|
||||
|
||||
switch {
|
||||
case *role < consts.RoleAdmin:
|
||||
return fmt.Errorf("Unauthorized")
|
||||
case *role == consts.RoleAdmin:
|
||||
return s.userRepo.BatchEnable(ids, []string{fmt.Sprintf("role < %d", role)})
|
||||
}
|
||||
return s.userRepo.BatchEnable(ids, nil)
|
||||
}
|
||||
|
||||
func (s *UserServiceImpl) BatchDisable(ctx context.Context, ids []int64) error {
|
||||
userRoleValue := ctx.Value("user_role")
|
||||
if userRoleValue == nil {
|
||||
return fmt.Errorf("user role not found in context")
|
||||
}
|
||||
role, ok := userRoleValue.(*consts.UserRole)
|
||||
if !ok {
|
||||
return fmt.Errorf("user role in context is not an integer")
|
||||
}
|
||||
|
||||
switch {
|
||||
case *role < consts.RoleAdmin:
|
||||
return fmt.Errorf("Unauthorized")
|
||||
case *role == consts.RoleAdmin:
|
||||
return s.userRepo.BatchDisable(ids, []string{fmt.Sprintf("role < %d", role)})
|
||||
}
|
||||
return s.userRepo.BatchDisable(ids, nil)
|
||||
}
|
||||
304
internal/service/webauth.go
Normal file
304
internal/service/webauth.go
Normal file
@@ -0,0 +1,304 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"opencatd-open/internal/model"
|
||||
"opencatd-open/pkg/config"
|
||||
"opencatd-open/pkg/store"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/go-webauthn/webauthn/protocol"
|
||||
"github.com/go-webauthn/webauthn/webauthn"
|
||||
"github.com/mileusna/useragent"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
var _ webauthn.User = (*WebAuthnUser)(nil)
|
||||
|
||||
// WebAuthnUser 实现webauthn.User接口的结构体
|
||||
type WebAuthnUser struct {
|
||||
User *model.User
|
||||
// ID int64
|
||||
// Name string
|
||||
// DisplayName string
|
||||
Credentials []webauthn.Credential
|
||||
}
|
||||
|
||||
// WebAuthnID 返回用户ID
|
||||
func (u *WebAuthnUser) WebAuthnID() []byte {
|
||||
return []byte(strconv.Itoa(int(u.User.ID)))
|
||||
}
|
||||
|
||||
// WebAuthnName 返回用户名
|
||||
func (u *WebAuthnUser) WebAuthnName() string {
|
||||
return u.User.Username
|
||||
}
|
||||
|
||||
// WebAuthnDisplayName 返回用户显示名
|
||||
func (u *WebAuthnUser) WebAuthnDisplayName() string {
|
||||
return u.User.Name
|
||||
}
|
||||
|
||||
// WebAuthnCredentials 返回用户所有凭证
|
||||
func (u *WebAuthnUser) WebAuthnCredentials() []webauthn.Credential {
|
||||
return u.Credentials
|
||||
}
|
||||
|
||||
func (u *WebAuthnUser) WebAuthnCredentialDescriptors() (descriptors []protocol.CredentialDescriptor) {
|
||||
credentials := u.WebAuthnCredentials()
|
||||
|
||||
descriptors = make([]protocol.CredentialDescriptor, len(credentials))
|
||||
|
||||
for i, credential := range credentials {
|
||||
descriptors[i] = credential.Descriptor()
|
||||
}
|
||||
|
||||
return descriptors
|
||||
}
|
||||
|
||||
// WebAuthnService 提供WebAuthn相关功能
|
||||
type WebAuthnService struct {
|
||||
DB *gorm.DB
|
||||
WebAuthn *webauthn.WebAuthn
|
||||
// Sessions map[string]webauthn.SessionData // 用于存储注册和认证过程中的会话数据
|
||||
Sessions *store.WebAuthnSessionStore
|
||||
}
|
||||
|
||||
// NewWebAuthnService 创建新的WebAuthn服务
|
||||
func NewWebAuthnService(db *gorm.DB, cfg *config.Config) (*WebAuthnService, error) {
|
||||
// 创建WebAuthn配置
|
||||
wconfig := &webauthn.Config{
|
||||
RPDisplayName: config.Cfg.AppName, // 依赖方(Relying Party)显示名称
|
||||
RPID: config.Cfg.Domain, // 依赖方ID(通常为域名)
|
||||
RPOrigins: []string{config.Cfg.AppURL}, // 依赖方源(URL)
|
||||
AuthenticatorSelection: protocol.AuthenticatorSelection{
|
||||
RequireResidentKey: protocol.ResidentKeyRequired(), // 要求认证器存储用户 ID (resident key)
|
||||
ResidentKey: protocol.ResidentKeyRequirementRequired, // 使用 Discoverable 模式
|
||||
UserVerification: protocol.VerificationPreferred, // 推荐用户验证
|
||||
AuthenticatorAttachment: "", // 允许任何认证器 (平台或跨平台)
|
||||
},
|
||||
// EncodeUserIDAsString: true, // 将用户ID编码为字符串
|
||||
}
|
||||
|
||||
wa, err := webauthn.New(wconfig)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &WebAuthnService{
|
||||
DB: db,
|
||||
WebAuthn: wa,
|
||||
// Sessions: make(map[string]webauthn.SessionData),
|
||||
Sessions: store.NewWebAuthnSessionStore(),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// GetUserWithCredentials 获取用户及其凭证
|
||||
func (s *WebAuthnService) GetUserWithCredentials(userID int64) (*WebAuthnUser, error) {
|
||||
var user model.User
|
||||
if err := s.DB.Model(&model.User{}).Preload("Passkeys").First(&user, userID).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 获取用户的所有Passkey
|
||||
passkeys := user.Passkeys
|
||||
|
||||
// 将Passkey转换为webauthn.Credential
|
||||
credentials := make([]webauthn.Credential, len(passkeys))
|
||||
for i, pk := range passkeys {
|
||||
credentialIDBytes, err := base64.StdEncoding.DecodeString(pk.CredentialID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to decode CredentialID: %w", err)
|
||||
}
|
||||
publicKeyBytes, err := base64.StdEncoding.DecodeString(pk.PublicKey)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to decode PublicKey: %w", err)
|
||||
}
|
||||
aaguidBytes, err := base64.StdEncoding.DecodeString(pk.AAGUID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to decode AAGUID: %w", err)
|
||||
}
|
||||
|
||||
var transport []protocol.AuthenticatorTransport
|
||||
if pk.Transport != "" {
|
||||
transport = []protocol.AuthenticatorTransport{protocol.AuthenticatorTransport(pk.Transport)}
|
||||
}
|
||||
|
||||
credentials[i] = webauthn.Credential{
|
||||
ID: credentialIDBytes,
|
||||
PublicKey: publicKeyBytes,
|
||||
AttestationType: pk.AttestationType,
|
||||
Transport: transport,
|
||||
Flags: webauthn.CredentialFlags{
|
||||
UserPresent: true,
|
||||
UserVerified: true,
|
||||
BackupEligible: pk.BackupEligible,
|
||||
BackupState: pk.BackupState,
|
||||
},
|
||||
Authenticator: webauthn.Authenticator{
|
||||
AAGUID: aaguidBytes,
|
||||
SignCount: pk.SignCount,
|
||||
CloneWarning: false,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// 创建WebAuthnUser
|
||||
return &WebAuthnUser{
|
||||
User: &user,
|
||||
Credentials: credentials,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// BeginRegistration 开始注册过程
|
||||
func (s *WebAuthnService) BeginRegistration(userID int64) (*protocol.CredentialCreation, error) {
|
||||
user, err := s.GetUserWithCredentials(userID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 获取注册选项
|
||||
options, sessionData, err := s.WebAuthn.BeginRegistration(user)
|
||||
// webauthn.WithResidentKeyRequirement(protocol.ResidentKeyRequirementRequired),
|
||||
// webauthn.WithExclusions(user.WebAuthnCredentialDescriptors()), // 排除已存在的凭证
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 保存会话数据
|
||||
userid := strconv.Itoa(int(userID))
|
||||
s.Sessions.SaveWebauthnSession(userid, sessionData)
|
||||
|
||||
return options, nil
|
||||
}
|
||||
|
||||
// FinishRegistration 完成注册过程
|
||||
func (s *WebAuthnService) FinishRegistration(userID int64, response *http.Request, deviceName string) (*model.Passkey, error) {
|
||||
user, err := s.GetUserWithCredentials(userID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
userid := strconv.Itoa(int(userID))
|
||||
// 获取并清除会话数据
|
||||
sessionData, err := s.Sessions.GetWebauthnSession(userid)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
s.Sessions.DeleteWebauthnSession(userid)
|
||||
|
||||
// 完成注册
|
||||
credential, err := s.WebAuthn.FinishRegistration(user, *sessionData, response)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
ua := useragent.Parse(response.UserAgent())
|
||||
|
||||
var transport string
|
||||
if len(credential.Transport) > 0 {
|
||||
transport = string(credential.Transport[0]) // 通常只取第一个传输方式
|
||||
}
|
||||
// 创建Passkey记录
|
||||
passkey := &model.Passkey{
|
||||
UserID: userID,
|
||||
CredentialID: base64.StdEncoding.EncodeToString(credential.ID),
|
||||
PublicKey: base64.StdEncoding.EncodeToString(credential.PublicKey),
|
||||
AttestationType: string(credential.AttestationType),
|
||||
AAGUID: base64.StdEncoding.EncodeToString(credential.Authenticator.AAGUID),
|
||||
SignCount: credential.Authenticator.SignCount,
|
||||
Name: deviceName,
|
||||
DeviceType: strings.TrimSpace(fmt.Sprintf("%s %s %s %s %s", ua.Device, ua.OS, ua.OSVersionNoFull(), ua.Name, ua.VersionNoFull())),
|
||||
LastUsedAt: time.Now().Unix(),
|
||||
BackupEligible: credential.Flags.BackupEligible,
|
||||
BackupState: credential.Flags.BackupState,
|
||||
Transport: transport,
|
||||
}
|
||||
|
||||
// 保存Passkey
|
||||
if err := s.DB.Create(passkey).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return passkey, nil
|
||||
}
|
||||
|
||||
// BeginLogin 开始登录过程 (无需用户ID,针对未认证用户)
|
||||
func (s *WebAuthnService) BeginLogin() (*protocol.CredentialAssertion, error) {
|
||||
// 不指定用户ID,让客户端决定使用哪个凭证
|
||||
options, session, err := s.WebAuthn.BeginDiscoverableLogin(
|
||||
webauthn.WithUserVerification(protocol.VerificationPreferred), // 推荐用户验证
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
s.Sessions.SaveWebauthnSession(session.Challenge, session)
|
||||
|
||||
return options, nil
|
||||
}
|
||||
|
||||
// FinishLogin 完成登录过程
|
||||
func (s *WebAuthnService) FinishLogin(challenge string, response *http.Request) (*WebAuthnUser, error) {
|
||||
// 获取并清除会话数据
|
||||
sessionData, err := s.Sessions.GetWebauthnSession(challenge)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
s.Sessions.DeleteWebauthnSession(challenge)
|
||||
|
||||
// 获取相应的用户
|
||||
// var user model.User
|
||||
// if err := s.DB.First(&user, passkey.UserID).Error; err != nil {
|
||||
// return nil, err
|
||||
// }
|
||||
|
||||
// 创建WebAuthnUser
|
||||
// webAuthnUser, err := s.GetUserWithCredentials(user.ID)
|
||||
// if err != nil {
|
||||
// return nil, err
|
||||
// }
|
||||
|
||||
// 完成登录
|
||||
// _, err = s.WebAuthn.FinishLogin(webAuthnUser, sessionData, response)
|
||||
// if err != nil {
|
||||
// return nil, err
|
||||
// }
|
||||
var user *WebAuthnUser
|
||||
_, err = s.WebAuthn.FinishDiscoverableLogin(s.GetWebAuthnUser(&user), *sessionData, response)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// 更新Passkey的LastUsedAt
|
||||
return user, nil
|
||||
}
|
||||
|
||||
func (s *WebAuthnService) GetWebAuthnUser(wau **WebAuthnUser) webauthn.DiscoverableUserHandler {
|
||||
return func(rawID, userHandle []byte) (webauthn.User, error) {
|
||||
userid, err := strconv.ParseInt(string(userHandle), 10, 64)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
*wau, err = s.GetUserWithCredentials(userid)
|
||||
return *wau, err
|
||||
}
|
||||
}
|
||||
|
||||
// ListPasskeys 列出用户所有Passkey
|
||||
func (s *WebAuthnService) ListPasskeys(userID int64) ([]model.Passkey, error) {
|
||||
var passkeys []model.Passkey
|
||||
if err := s.DB.Where("user_id = ?", userID).Find(&passkeys).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return passkeys, nil
|
||||
}
|
||||
|
||||
// DeletePasskey 删除用户Passkey
|
||||
func (s *WebAuthnService) DeletePasskey(userID int64, passkeyID int64) error {
|
||||
return s.DB.Where("id = ? AND user_id = ?", passkeyID, userID).Delete(&model.Passkey{}).Error
|
||||
}
|
||||
16
internal/utils/convert.go
Normal file
16
internal/utils/convert.go
Normal file
@@ -0,0 +1,16 @@
|
||||
package utils
|
||||
|
||||
import "strings"
|
||||
|
||||
func StringToBool(strSlice []string) []bool {
|
||||
boolSlice := make([]bool, len(strSlice))
|
||||
for i, str := range strSlice {
|
||||
str = strings.ToLower(str)
|
||||
if str == "true" {
|
||||
boolSlice[i] = true
|
||||
} else if str == "false" {
|
||||
boolSlice[i] = false
|
||||
}
|
||||
}
|
||||
return boolSlice
|
||||
}
|
||||
139
internal/utils/map_tools.go
Normal file
139
internal/utils/map_tools.go
Normal file
@@ -0,0 +1,139 @@
|
||||
package utils
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"reflect"
|
||||
"strings"
|
||||
)
|
||||
|
||||
func MergeJSONObjects(dst, src map[string]interface{}) map[string]interface{} {
|
||||
|
||||
result := make(map[string]interface{})
|
||||
for k, v := range dst {
|
||||
result[k] = v
|
||||
}
|
||||
|
||||
for key, value2 := range src {
|
||||
value1, exists := result[key]
|
||||
|
||||
if exists {
|
||||
map1Val, map1IsMap := value1.(map[string]interface{})
|
||||
map2Val, map2IsMap := value2.(map[string]interface{})
|
||||
|
||||
if map1IsMap && map2IsMap {
|
||||
result[key] = MergeJSONObjects(map1Val, map2Val)
|
||||
} else {
|
||||
// 覆盖第一个map中的值
|
||||
result[key] = value2
|
||||
}
|
||||
} else {
|
||||
// 添加新的键值对
|
||||
result[key] = value2
|
||||
}
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
func StructToMap(in interface{}) (map[string]interface{}, error) {
|
||||
out := make(map[string]interface{})
|
||||
|
||||
v := reflect.ValueOf(in)
|
||||
// If it's a pointer, dereference it
|
||||
if v.Kind() == reflect.Ptr {
|
||||
v = v.Elem()
|
||||
}
|
||||
|
||||
// Check if it's a struct
|
||||
if v.Kind() != reflect.Struct {
|
||||
return nil, fmt.Errorf("StructToMap only accepts structs or pointers to structs; got %T", v.Interface())
|
||||
}
|
||||
|
||||
t := v.Type() // Get the type of the struct
|
||||
for i := 0; i < v.NumField(); i++ {
|
||||
// Get the field Value and Type
|
||||
fieldV := v.Field(i)
|
||||
fieldT := t.Field(i)
|
||||
|
||||
// Skip unexported fields
|
||||
if !fieldT.IsExported() {
|
||||
continue
|
||||
}
|
||||
|
||||
// --- Handle JSON Tag ---
|
||||
tag := fieldT.Tag.Get("json")
|
||||
key := fieldT.Name // Default key is the field name
|
||||
omitempty := false
|
||||
|
||||
if tag != "" {
|
||||
parts := strings.Split(tag, ",")
|
||||
tagName := parts[0]
|
||||
|
||||
if tagName == "-" {
|
||||
// Skip fields tagged with "-"
|
||||
continue
|
||||
}
|
||||
if tagName != "" {
|
||||
key = tagName // Use tag name as key
|
||||
}
|
||||
|
||||
// Check for omitempty option
|
||||
for _, part := range parts[1:] {
|
||||
if part == "omitempty" {
|
||||
omitempty = true
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// --- Handle omitempty ---
|
||||
val := fieldV.Interface()
|
||||
if omitempty && fieldV.IsZero() {
|
||||
continue // Skip zero-value fields if omitempty is set
|
||||
}
|
||||
|
||||
// --- Handle Nested Structs/Pointers to Structs (Recursion) ---
|
||||
// Check for pointer first
|
||||
if fieldV.Kind() == reflect.Ptr {
|
||||
// If pointer is nil and omitempty is set, it was already skipped
|
||||
// If pointer is nil and omitempty is not set, add nil to map
|
||||
if fieldV.IsNil() {
|
||||
// Only add nil if omitempty is not set (already handled above)
|
||||
if !omitempty {
|
||||
out[key] = nil
|
||||
}
|
||||
continue // Move to next field
|
||||
}
|
||||
// If it points to a struct, dereference and recurse
|
||||
if fieldV.Elem().Kind() == reflect.Struct {
|
||||
nestedMap, err := StructToMap(fieldV.Interface()) // Pass the pointer
|
||||
if err != nil {
|
||||
// Decide how to handle nested errors, e.g., log or return
|
||||
fmt.Printf("Warning: could not convert nested struct pointer %s: %v\n", fieldT.Name, err)
|
||||
out[key] = val // Store original value on error? Or skip?
|
||||
} else {
|
||||
out[key] = nestedMap
|
||||
}
|
||||
continue // Move to next field after handling pointer
|
||||
}
|
||||
// If pointer to non-struct, just get the interface value (handled below)
|
||||
val = fieldV.Interface() // Use the actual pointer value
|
||||
|
||||
} else if fieldV.Kind() == reflect.Struct {
|
||||
// If it's a struct (not a pointer), recurse
|
||||
nestedMap, err := StructToMap(fieldV.Interface()) // Pass the struct value
|
||||
if err != nil {
|
||||
fmt.Printf("Warning: could not convert nested struct %s: %v\n", fieldT.Name, err)
|
||||
out[key] = val // Store original value on error? Or skip?
|
||||
} else {
|
||||
out[key] = nestedMap
|
||||
}
|
||||
continue // Move to next field after handling struct
|
||||
}
|
||||
|
||||
// Assign the value (primitive, slice, map, non-struct pointer, etc.)
|
||||
out[key] = val
|
||||
}
|
||||
|
||||
return out, nil
|
||||
}
|
||||
15
internal/utils/password.go
Normal file
15
internal/utils/password.go
Normal file
@@ -0,0 +1,15 @@
|
||||
package utils
|
||||
|
||||
import (
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
)
|
||||
|
||||
func HashPassword(password string) (string, error) {
|
||||
bytes, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost)
|
||||
return string(bytes), err
|
||||
}
|
||||
|
||||
func CheckPassword(hash, password string) bool {
|
||||
err := bcrypt.CompareHashAndPassword([]byte(hash), []byte(password))
|
||||
return err == nil
|
||||
}
|
||||
@@ -3,3 +3,9 @@ package utils
|
||||
func ToPtr[T any](v T) *T {
|
||||
return &v
|
||||
}
|
||||
|
||||
func UpdatePtrField[T any](target *T, value *T) {
|
||||
if value != nil {
|
||||
*target = *value
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user