reface to openteam
This commit is contained in:
152
internal/controller/apikey.go
Normal file
152
internal/controller/apikey.go
Normal file
@@ -0,0 +1,152 @@
|
||||
package controller
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"opencatd-open/internal/consts"
|
||||
"opencatd-open/internal/dto"
|
||||
"opencatd-open/internal/model"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/duke-git/lancet/v2/slice"
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
func (a Api) CreateApiKey(c *gin.Context) {
|
||||
role := c.MustGet("user_role").(*consts.UserRole)
|
||||
if *role < consts.RoleAdmin {
|
||||
dto.Fail(c, 403, "Permission denied")
|
||||
return
|
||||
}
|
||||
req := new(model.ApiKey)
|
||||
err := c.ShouldBind(&req)
|
||||
if err != nil {
|
||||
dto.Fail(c, 400, err.Error())
|
||||
}
|
||||
|
||||
err = a.keyService.CreateApiKey(c, req)
|
||||
if err != nil {
|
||||
dto.Fail(c, 400, err.Error())
|
||||
} else {
|
||||
dto.Success(c, nil)
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func (a Api) GetApiKey(c *gin.Context) {
|
||||
role := c.MustGet("user_role").(*consts.UserRole)
|
||||
if *role < consts.RoleAdmin {
|
||||
dto.Fail(c, 403, "Permission denied")
|
||||
return
|
||||
}
|
||||
id, err := strconv.ParseInt(c.Param("id"), 10, 64)
|
||||
if err != nil {
|
||||
dto.Fail(c, 400, err.Error())
|
||||
return
|
||||
}
|
||||
key, err := a.keyService.GetApiKey(c, id)
|
||||
if err != nil {
|
||||
dto.Fail(c, 400, err.Error())
|
||||
} else {
|
||||
dto.Success(c, key)
|
||||
}
|
||||
}
|
||||
|
||||
func (a Api) ListApiKey(c *gin.Context) {
|
||||
role := c.MustGet("user_role").(*consts.UserRole)
|
||||
if *role < consts.RoleAdmin {
|
||||
dto.Fail(c, 403, "Permission denied")
|
||||
return
|
||||
}
|
||||
limit, _ := strconv.Atoi(c.DefaultQuery("pageSize", "20"))
|
||||
page, _ := strconv.Atoi(c.DefaultQuery("page", "1"))
|
||||
offset := (page - 1) * limit
|
||||
active := c.QueryArray("active[]")
|
||||
if !slice.ContainSubSlice([]string{"true", "false"}, active) {
|
||||
dto.Fail(c, http.StatusBadRequest, "active must be true or false")
|
||||
return
|
||||
}
|
||||
|
||||
keys, total, err := a.keyService.ListApiKey(c, limit, offset, active)
|
||||
if err != nil {
|
||||
dto.Fail(c, 500, err.Error())
|
||||
} else {
|
||||
dto.Success(c, gin.H{
|
||||
"total": total,
|
||||
"keys": keys,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func (a Api) DeleteApiKey(c *gin.Context) {
|
||||
role := c.MustGet("user_role").(*consts.UserRole)
|
||||
if *role < consts.RoleAdmin {
|
||||
dto.Fail(c, 403, "Permission denied")
|
||||
return
|
||||
}
|
||||
var batchid dto.BatchIDRequest
|
||||
err := c.ShouldBind(&batchid)
|
||||
if err != nil {
|
||||
dto.Fail(c, 400, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
err = a.keyService.DeleteApiKey(c, batchid.IDs)
|
||||
if err != nil {
|
||||
dto.Fail(c, 500, err.Error())
|
||||
} else {
|
||||
dto.Success(c, nil)
|
||||
}
|
||||
}
|
||||
|
||||
func (a Api) UpdateApiKey(c *gin.Context) {
|
||||
role := c.MustGet("user_role").(*consts.UserRole)
|
||||
if *role < consts.RoleAdmin {
|
||||
dto.Fail(c, 403, "Permission denied")
|
||||
return
|
||||
}
|
||||
var req model.ApiKey
|
||||
err := c.ShouldBind(&req)
|
||||
if err != nil {
|
||||
dto.Fail(c, 400, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
err = a.keyService.UpdateApiKey(c, &req)
|
||||
if err != nil {
|
||||
dto.Fail(c, 500, err.Error())
|
||||
} else {
|
||||
dto.Success(c, nil)
|
||||
}
|
||||
}
|
||||
|
||||
func (a Api) ApiKeyOption(c *gin.Context) {
|
||||
role := c.MustGet("user_role").(*consts.UserRole)
|
||||
if *role < consts.RoleAdmin {
|
||||
dto.Fail(c, 403, "Permission denied")
|
||||
return
|
||||
}
|
||||
option := strings.ToLower(c.Param("option"))
|
||||
var batchid dto.BatchIDRequest
|
||||
err := c.ShouldBind(&batchid)
|
||||
if err != nil {
|
||||
dto.Fail(c, 400, err.Error())
|
||||
return
|
||||
}
|
||||
switch option {
|
||||
case "enable":
|
||||
err = a.keyService.EnableApiKey(c, batchid.IDs)
|
||||
case "disable":
|
||||
err = a.keyService.DisableApiKey(c, batchid.IDs)
|
||||
case "delete":
|
||||
err = a.keyService.DeleteApiKey(c, batchid.IDs)
|
||||
default:
|
||||
dto.Fail(c, 400, "invalid option, only support enable, disable, delete")
|
||||
return
|
||||
}
|
||||
if err != nil {
|
||||
dto.Fail(c, 400, err.Error())
|
||||
return
|
||||
}
|
||||
dto.Success(c, nil)
|
||||
}
|
||||
27
internal/controller/init.go
Normal file
27
internal/controller/init.go
Normal file
@@ -0,0 +1,27 @@
|
||||
package controller
|
||||
|
||||
import (
|
||||
"opencatd-open/internal/service"
|
||||
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
type Api struct {
|
||||
db *gorm.DB
|
||||
userService *service.UserServiceImpl
|
||||
tokenService *service.TokenServiceImpl
|
||||
keyService *service.ApiKeyServiceImpl
|
||||
webAuthService *service.WebAuthnService
|
||||
usageService *service.UsageService
|
||||
}
|
||||
|
||||
func NewApi(db *gorm.DB, userService *service.UserServiceImpl, tokenService *service.TokenServiceImpl, keyService *service.ApiKeyServiceImpl, webAuthService *service.WebAuthnService, usageService *service.UsageService) *Api {
|
||||
return &Api{
|
||||
db: db,
|
||||
userService: userService,
|
||||
tokenService: tokenService,
|
||||
keyService: keyService,
|
||||
webAuthService: webAuthService,
|
||||
usageService: usageService,
|
||||
}
|
||||
}
|
||||
60
internal/controller/proxy/chat_proxy.go
Normal file
60
internal/controller/proxy/chat_proxy.go
Normal file
@@ -0,0 +1,60 @@
|
||||
package controller
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"opencatd-open/internal/dto"
|
||||
"opencatd-open/llm"
|
||||
"opencatd-open/llm/claude/v2"
|
||||
"opencatd-open/llm/google/v2"
|
||||
"opencatd-open/llm/openai_compatible"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
func (h *Proxy) ChatHandler(c *gin.Context) {
|
||||
var chatreq llm.ChatRequest
|
||||
if err := c.ShouldBindJSON(&chatreq); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
err := h.SelectApiKey(chatreq.Model)
|
||||
if err != nil {
|
||||
dto.WrapErrorAsOpenAI(c, 500, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
var llm llm.LLM
|
||||
switch *h.apikey.ApiType {
|
||||
case "claude":
|
||||
llm, err = claude.NewClaude(h.apikey)
|
||||
case "gemini":
|
||||
llm, err = google.NewGemini(c, h.apikey)
|
||||
case "openai", "azure", "github":
|
||||
fallthrough
|
||||
default:
|
||||
llm, err = openai_compatible.NewOpenAICompatible(h.apikey)
|
||||
if err != nil {
|
||||
dto.WrapErrorAsOpenAI(c, 500, fmt.Errorf("create llm client error: %w", err).Error())
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
if !chatreq.Stream {
|
||||
resp, err := llm.Chat(c, chatreq)
|
||||
if err != nil {
|
||||
dto.WrapErrorAsOpenAI(c, 500, err.Error())
|
||||
}
|
||||
c.JSON(http.StatusOK, resp)
|
||||
|
||||
} else {
|
||||
datachan, err := llm.StreamChat(c, chatreq)
|
||||
if err != nil {
|
||||
dto.WrapErrorAsOpenAI(c, 500, err.Error())
|
||||
}
|
||||
for data := range datachan {
|
||||
c.SSEvent("", data)
|
||||
}
|
||||
}
|
||||
}
|
||||
345
internal/controller/proxy/proxy.go
Normal file
345
internal/controller/proxy/proxy.go
Normal file
@@ -0,0 +1,345 @@
|
||||
package controller
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"math/rand"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"opencatd-open/internal/dao"
|
||||
"opencatd-open/internal/model"
|
||||
"opencatd-open/pkg/config"
|
||||
"os"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/lib/pq"
|
||||
"github.com/tidwall/gjson"
|
||||
"gorm.io/gorm"
|
||||
"gorm.io/gorm/clause"
|
||||
)
|
||||
|
||||
type Proxy struct {
|
||||
ctx context.Context
|
||||
cfg *config.Config
|
||||
db *gorm.DB
|
||||
wg *sync.WaitGroup
|
||||
usageChan chan *model.Usage // 用于异步处理的channel
|
||||
apikey *model.ApiKey
|
||||
httpClient *http.Client
|
||||
|
||||
userDAO *dao.UserDAO
|
||||
apiKeyDao *dao.ApiKeyDAO
|
||||
tokenDAO *dao.TokenDAO
|
||||
usageDAO *dao.UsageDAO
|
||||
dailyUsageDAO *dao.DailyUsageDAO
|
||||
}
|
||||
|
||||
func NewProxy(ctx context.Context, cfg *config.Config, db *gorm.DB, wg *sync.WaitGroup, userDAO *dao.UserDAO, apiKeyDAO *dao.ApiKeyDAO, tokenDAO *dao.TokenDAO, usageDAO *dao.UsageDAO, dailyUsageDAO *dao.DailyUsageDAO) *Proxy {
|
||||
client := http.DefaultClient
|
||||
if os.Getenv("LOCAL_PROXY") != "" {
|
||||
proxyUrl, err := url.Parse(os.Getenv("LOCAL_PROXY"))
|
||||
if err == nil {
|
||||
tr := &http.Transport{
|
||||
Proxy: http.ProxyURL(proxyUrl),
|
||||
}
|
||||
client.Transport = tr
|
||||
}
|
||||
}
|
||||
np := &Proxy{
|
||||
ctx: ctx,
|
||||
cfg: cfg,
|
||||
db: db,
|
||||
wg: wg,
|
||||
httpClient: client,
|
||||
usageChan: make(chan *model.Usage, cfg.UsageChanSize),
|
||||
userDAO: userDAO,
|
||||
apiKeyDao: apiKeyDAO,
|
||||
tokenDAO: tokenDAO,
|
||||
usageDAO: usageDAO,
|
||||
dailyUsageDAO: dailyUsageDAO,
|
||||
}
|
||||
|
||||
go np.ProcessUsage()
|
||||
go np.ScheduleTask()
|
||||
|
||||
return np
|
||||
}
|
||||
|
||||
func (p *Proxy) HandleProxy(c *gin.Context) {
|
||||
if c.Request.URL.Path == "/v1/chat/completions" {
|
||||
p.ChatHandler(c)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
func (p *Proxy) SendUsage(usage *model.Usage) {
|
||||
select {
|
||||
case p.usageChan <- usage:
|
||||
default:
|
||||
log.Println("usage channel is full, skip processing")
|
||||
bj, _ := json.Marshal(usage)
|
||||
log.Println(string(bj))
|
||||
//TODO: send to a queue
|
||||
}
|
||||
}
|
||||
|
||||
func (p *Proxy) ProcessUsage() {
|
||||
for i := 0; i < p.cfg.UsageWorker; i++ {
|
||||
p.wg.Add(1)
|
||||
go func(i int) {
|
||||
defer p.wg.Done()
|
||||
for {
|
||||
select {
|
||||
case usage, ok := <-p.usageChan:
|
||||
if !ok {
|
||||
// channel 关闭,退出程序
|
||||
return
|
||||
}
|
||||
err := p.Do(usage)
|
||||
if err != nil {
|
||||
log.Printf("process usage error: %v\n", err)
|
||||
}
|
||||
case <-p.ctx.Done():
|
||||
// close(s.usageChan)
|
||||
// for usage := range s.usageChan {
|
||||
// if err := s.Do(usage); err != nil {
|
||||
// fmt.Printf("[close event]process usage error: %v\n", err)
|
||||
// }
|
||||
// }
|
||||
for {
|
||||
select {
|
||||
case usage, ok := <-p.usageChan:
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
if err := p.Do(usage); err != nil {
|
||||
fmt.Printf("[close event]process usage error: %v\n", err)
|
||||
}
|
||||
default:
|
||||
fmt.Printf("usageChan is empty,usage worker %d done\n", i)
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
}(i)
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
func (p *Proxy) Do(usage *model.Usage) error {
|
||||
err := p.db.Transaction(func(tx *gorm.DB) error {
|
||||
// 1. 记录使用记录
|
||||
if err := tx.WithContext(p.ctx).Create(usage).Error; err != nil {
|
||||
return fmt.Errorf("create usage error: %w", err)
|
||||
}
|
||||
|
||||
// 2. 更新每日统计(upsert 操作)
|
||||
dailyUsage := model.DailyUsage{
|
||||
UserID: usage.UserID,
|
||||
TokenID: usage.TokenID,
|
||||
Capability: usage.Capability,
|
||||
Date: time.Date(usage.Date.Year(), usage.Date.Month(), usage.Date.Day(), 0, 0, 0, 0, usage.Date.Location()),
|
||||
Model: usage.Model,
|
||||
Stream: usage.Stream,
|
||||
PromptTokens: usage.PromptTokens,
|
||||
CompletionTokens: usage.CompletionTokens,
|
||||
TotalTokens: usage.TotalTokens,
|
||||
Cost: usage.Cost,
|
||||
}
|
||||
|
||||
// 使用 OnConflict 实现 upsert
|
||||
if err := tx.WithContext(p.ctx).Clauses(clause.OnConflict{
|
||||
Columns: []clause.Column{{Name: "user_id"}, {Name: "token_id"}, {Name: "capability"}, {Name: "date"}}, // 唯一键
|
||||
DoUpdates: clause.Assignments(map[string]interface{}{
|
||||
"prompt_tokens": gorm.Expr("prompt_tokens + ?", usage.PromptTokens),
|
||||
"completion_tokens": gorm.Expr("completion_tokens + ?", usage.CompletionTokens),
|
||||
"total_tokens": gorm.Expr("total_tokens + ?", usage.TotalTokens),
|
||||
"cost": gorm.Expr("cost + ?", usage.Cost),
|
||||
}),
|
||||
}).Create(&dailyUsage).Error; err != nil {
|
||||
return fmt.Errorf("upsert daily usage error: %w", err)
|
||||
}
|
||||
|
||||
// 3. 更新用户额度
|
||||
if err := tx.WithContext(p.ctx).Model(&model.User{}).Where("id = ?", usage.UserID).Updates(map[string]interface{}{
|
||||
"quota": gorm.Expr("quota - ?", usage.Cost),
|
||||
"used_quota": gorm.Expr("used_quota + ?", usage.Cost),
|
||||
}).Error; err != nil {
|
||||
return fmt.Errorf("update user quota and used_quota error: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
return err
|
||||
}
|
||||
|
||||
func (p *Proxy) SelectApiKey(model string) error {
|
||||
akpikeys, err := p.apiKeyDao.FindApiKeysBySupportModel(p.db, model)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if len(akpikeys) == 0 {
|
||||
return errors.New("no available apikey")
|
||||
} else {
|
||||
if strings.HasPrefix(model, "gpt") {
|
||||
keys, err := p.apiKeyDao.FindKeys(map[string]any{"type = ?": "openai"})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
akpikeys = append(akpikeys, keys...)
|
||||
}
|
||||
|
||||
if strings.HasPrefix(model, "gemini") {
|
||||
keys, err := p.apiKeyDao.FindKeys(map[string]any{"type = ?": "gemini"})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
akpikeys = append(akpikeys, keys...)
|
||||
}
|
||||
|
||||
if strings.HasPrefix(model, "claude") {
|
||||
keys, err := p.apiKeyDao.FindKeys(map[string]any{"type = ?": "claude"})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
akpikeys = append(akpikeys, keys...)
|
||||
}
|
||||
}
|
||||
if len(akpikeys) == 0 {
|
||||
return errors.New("no available apikey")
|
||||
|
||||
}
|
||||
|
||||
if len(akpikeys) == 1 {
|
||||
p.apikey = &akpikeys[0]
|
||||
return nil
|
||||
}
|
||||
length := len(akpikeys) - 1
|
||||
|
||||
p.apikey = &akpikeys[rand.Intn(length)]
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *Proxy) updateSupportModel() {
|
||||
|
||||
keys, err := p.apiKeyDao.FindKeys(map[string]interface{}{"type in ?": "openai,azure,claude"})
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
for _, key := range keys {
|
||||
var supportModels []string
|
||||
if *key.ApiType == "openai" || *key.ApiType == "azure" {
|
||||
supportModels, err = p.getOpenAISupportModels(key)
|
||||
}
|
||||
if *key.ApiType == "claude" {
|
||||
supportModels, err = p.getClaudeSupportModels(key)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
log.Println(err)
|
||||
continue
|
||||
}
|
||||
if len(supportModels) == 0 {
|
||||
continue
|
||||
|
||||
}
|
||||
if p.cfg.DB_Type == "sqlite" {
|
||||
bytejson, _ := json.Marshal(supportModels)
|
||||
if err := p.db.Model(&model.ApiKey{}).Where("id = ?", key.ID).UpdateColumn("support_models", string(bytejson)).Error; err != nil {
|
||||
log.Println(err)
|
||||
}
|
||||
} else if p.cfg.DB_Type == "postgres" {
|
||||
if err := p.db.Model(&model.ApiKey{}).Where("id = ?", key.ID).UpdateColumn("support_models", pq.StringArray(supportModels)).Error; err != nil {
|
||||
log.Println(err)
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func (p *Proxy) ScheduleTask() {
|
||||
|
||||
func() {
|
||||
for {
|
||||
select {
|
||||
case <-time.After(time.Duration(p.cfg.TaskTimeInterval) * time.Minute):
|
||||
p.updateSupportModel()
|
||||
|
||||
case <-p.ctx.Done():
|
||||
fmt.Println("schedule task done")
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
func (p *Proxy) getOpenAISupportModels(apikey model.ApiKey) ([]string, error) {
|
||||
openaiModelsUrl := "https://api.openai.com/v1/models"
|
||||
// https://learn.microsoft.com/zh-cn/rest/api/azureopenai/models/list?view=rest-azureopenai-2025-02-01-preview&tabs=HTTP
|
||||
azureModelsUrl := "/openai/deployments?api-version=2022-12-01"
|
||||
|
||||
var supportModels []string
|
||||
var req *http.Request
|
||||
if *apikey.ApiType == "azure" {
|
||||
req, _ = http.NewRequest("GET", *apikey.Endpoint+azureModelsUrl, nil)
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("api-key", *apikey.ApiKey)
|
||||
} else {
|
||||
req, _ = http.NewRequest("GET", openaiModelsUrl, nil)
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Authorization", "Bearer "+*apikey.ApiKey)
|
||||
}
|
||||
|
||||
resp, err := p.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
if resp.StatusCode == http.StatusOK {
|
||||
bytesbody, _ := io.ReadAll(resp.Body)
|
||||
result := gjson.GetBytes(bytesbody, "data.#.id").Array()
|
||||
for _, v := range result {
|
||||
model := v.Str
|
||||
model = strings.Replace(model, "-35-", "-3.5-", -1)
|
||||
model = strings.Replace(model, "-41-", "-4.1-", -1)
|
||||
supportModels = append(supportModels, model)
|
||||
}
|
||||
}
|
||||
return supportModels, nil
|
||||
}
|
||||
|
||||
func (p *Proxy) getClaudeSupportModels(apikey model.ApiKey) ([]string, error) {
|
||||
// https://docs.anthropic.com/en/api/models-list
|
||||
claudemodelsUrl := "https://api.anthropic.com/v1/models"
|
||||
var supportModels []string
|
||||
|
||||
req, _ := http.NewRequest("GET", claudemodelsUrl, nil)
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("x-api-key", *apikey.ApiKey)
|
||||
req.Header.Set("anthropic-version", "2023-06-01")
|
||||
|
||||
resp, err := p.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
if resp.StatusCode == http.StatusOK {
|
||||
bytesbody, _ := io.ReadAll(resp.Body)
|
||||
result := gjson.GetBytes(bytesbody, "data.#.id").Array()
|
||||
for _, v := range result {
|
||||
supportModels = append(supportModels, v.Str)
|
||||
}
|
||||
}
|
||||
return supportModels, nil
|
||||
}
|
||||
53
internal/controller/team/middleware.go
Normal file
53
internal/controller/team/middleware.go
Normal file
@@ -0,0 +1,53 @@
|
||||
package controller
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"opencatd-open/internal/consts"
|
||||
|
||||
"github.com/gin-contrib/cors"
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
func (h *Team) AuthMiddleware() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
if c.Request.URL.Path == "/1/users/init" {
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
authtoken := c.GetHeader("Authorization")
|
||||
if authtoken == "" || len(authtoken) <= 7 || authtoken[:7] != "Bearer " {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{"error": "Unauthorized"})
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
authtoken = authtoken[7:]
|
||||
token, err := h.tokenService.GetByKey(c, authtoken)
|
||||
if err != nil {
|
||||
fmt.Println(err)
|
||||
}
|
||||
if token.Name != "default" {
|
||||
c.JSON(http.StatusForbidden, gin.H{"error": "only default token can access"})
|
||||
c.Abort()
|
||||
}
|
||||
if token.User.Status != consts.StatusEnabled {
|
||||
c.JSON(http.StatusForbidden, gin.H{"error": "user is disabled"})
|
||||
c.Abort()
|
||||
}
|
||||
c.Set("local_user", true)
|
||||
c.Set("token", token)
|
||||
|
||||
// 可以在这里对 token 进行验证并检查权限
|
||||
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
func CORS() gin.HandlerFunc {
|
||||
config := cors.DefaultConfig()
|
||||
config.AllowAllOrigins = true
|
||||
config.AllowCredentials = true
|
||||
config.AllowMethods = []string{"GET", "POST", "PUT", "DELETE", "OPTIONS"}
|
||||
config.AllowHeaders = []string{"*"}
|
||||
return cors.New(config)
|
||||
}
|
||||
563
internal/controller/team/team.go
Normal file
563
internal/controller/team/team.go
Normal file
@@ -0,0 +1,563 @@
|
||||
package controller
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net/http"
|
||||
"slices"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"opencatd-open/internal/consts"
|
||||
dto "opencatd-open/internal/dto/team"
|
||||
"opencatd-open/internal/model"
|
||||
service "opencatd-open/internal/service/team"
|
||||
"opencatd-open/internal/utils"
|
||||
|
||||
"github.com/duke-git/lancet/v2/slice"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/google/uuid"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
type Team struct {
|
||||
db *gorm.DB
|
||||
userService service.UserService
|
||||
tokenService service.TokenService
|
||||
keyService service.ApiKeyService
|
||||
usageService service.UsageService
|
||||
}
|
||||
|
||||
func NewTeam(userService service.UserService, tokenService service.TokenService, keyService service.ApiKeyService, usageService service.UsageService) *Team {
|
||||
return &Team{
|
||||
userService: userService,
|
||||
tokenService: tokenService,
|
||||
keyService: keyService,
|
||||
usageService: usageService,
|
||||
}
|
||||
}
|
||||
|
||||
// initadmin
|
||||
func (h *Team) InitAdmin(c *gin.Context) {
|
||||
admin, err := h.userService.GetUser(c, 1)
|
||||
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
user := &model.User{
|
||||
Name: "root",
|
||||
Username: "root",
|
||||
Password: "openteam",
|
||||
Role: utils.ToPtr(consts.RoleRoot),
|
||||
Tokens: []model.Token{
|
||||
{
|
||||
Name: "default",
|
||||
Key: "sk-team-" + strings.ReplaceAll(uuid.New().String(), "-", ""),
|
||||
UnlimitedQuota: utils.ToPtr(true),
|
||||
},
|
||||
},
|
||||
}
|
||||
if err := h.userService.CreateUser(c, user); err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
var result = dto.UserInfo{
|
||||
ID: user.ID,
|
||||
Name: user.Username,
|
||||
Token: user.Tokens[0].Key,
|
||||
Status: utils.ToPtr(user.Status == consts.StatusEnabled),
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, result)
|
||||
return
|
||||
} else {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
}
|
||||
if admin != nil {
|
||||
c.JSON(http.StatusForbidden, gin.H{
|
||||
"error": "super user already exists, use cli to reset password",
|
||||
})
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
func (h *Team) Me(c *gin.Context) {
|
||||
token, exists := c.Get("token")
|
||||
if !exists {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": "token not found"})
|
||||
return
|
||||
}
|
||||
userToken := token.(*model.Token)
|
||||
|
||||
c.JSON(http.StatusOK, dto.UserInfo{
|
||||
ID: userToken.UserID,
|
||||
Name: userToken.User.Name,
|
||||
Token: userToken.Key,
|
||||
Status: utils.ToPtr(userToken.User.Status == consts.StatusEnabled),
|
||||
})
|
||||
|
||||
}
|
||||
|
||||
// CreateUser 创建用户
|
||||
func (h *Team) CreateUser(c *gin.Context) {
|
||||
var userReq dto.UserInfo
|
||||
if err := c.ShouldBindJSON(&userReq); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid input"})
|
||||
return
|
||||
}
|
||||
|
||||
token, exists := c.Get("token")
|
||||
if !exists {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": "Unauthorized"})
|
||||
return
|
||||
}
|
||||
userToken := token.(*model.Token)
|
||||
if *userToken.User.Role < consts.RoleAdmin { // 普通用户只能创建自己的token
|
||||
create := &model.Token{
|
||||
Name: userReq.Name,
|
||||
Key: "sk-team-" + strings.ReplaceAll(uuid.New().String(), "-", ""),
|
||||
}
|
||||
if userReq.Token != "" {
|
||||
_key := strings.ReplaceAll(userReq.Token, "-", "")
|
||||
create.Key = "sk-team-" + strings.ReplaceAll(_key, " ", "")
|
||||
}
|
||||
if err := h.tokenService.Create(c.Request.Context(), create); err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
} else {
|
||||
user := &model.User{
|
||||
Name: userReq.Name,
|
||||
Username: userReq.Name,
|
||||
Role: utils.ToPtr(consts.RoleUser),
|
||||
Tokens: []model.Token{
|
||||
{
|
||||
Name: "default",
|
||||
Key: "sk-team-" + strings.ReplaceAll(uuid.New().String(), "-", ""),
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
// 默认角色为普通用户
|
||||
if err := h.userService.CreateUser(c.Request.Context(), user); err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{"message": "ok"})
|
||||
}
|
||||
|
||||
// GetUser 获取用户信息
|
||||
func (h *Team) GetUser(c *gin.Context) {
|
||||
idStr := c.Param("id")
|
||||
id, err := strconv.ParseInt(idStr, 10, 64)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid user id"})
|
||||
return
|
||||
}
|
||||
|
||||
user, err := h.userService.GetUser(c.Request.Context(), id)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, user)
|
||||
}
|
||||
|
||||
// UpdateUser 更新用户信息
|
||||
func (h *Team) UpdateUser(c *gin.Context) {
|
||||
var user model.User
|
||||
if err := c.ShouldBindJSON(&user); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid input"})
|
||||
return
|
||||
}
|
||||
token, exists := c.Get("token")
|
||||
if !exists {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": "Unauthorized"})
|
||||
return
|
||||
}
|
||||
userToken := token.(*model.Token)
|
||||
|
||||
operatorID := userToken.UserID // 假设从上下文中获取操作者ID
|
||||
if err := h.userService.UpdateUser(c.Request.Context(), &user, operatorID); err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{"message": "ok"})
|
||||
}
|
||||
|
||||
// DeleteUser 删除用户
|
||||
func (h *Team) DeleteUser(c *gin.Context) {
|
||||
idStr := c.Param("id")
|
||||
id, err := strconv.ParseInt(idStr, 10, 64)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid user id"})
|
||||
return
|
||||
}
|
||||
|
||||
token, exists := c.Get("token")
|
||||
if !exists {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": "Unauthorized"})
|
||||
return
|
||||
}
|
||||
userToken := token.(*model.Token)
|
||||
|
||||
if *userToken.User.Role < consts.RoleAdmin { // 用户只能删除自己的token
|
||||
err := h.tokenService.Delete(c.Request.Context(), id)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
} else {
|
||||
if err := h.userService.DeleteUser(c, id, userToken.UserID); err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{"message": "ok"})
|
||||
}
|
||||
|
||||
func (h *Team) ListUsages(c *gin.Context) {
|
||||
fromStr := c.Query("from")
|
||||
toStr := c.Query("to")
|
||||
|
||||
var from, to time.Time
|
||||
loc, _ := time.LoadLocation("Local")
|
||||
|
||||
var listUsage []*dto.UsageInfo
|
||||
var err error
|
||||
|
||||
if fromStr != "" && toStr != "" {
|
||||
|
||||
from, err = time.Parse("2006-01-02", fromStr)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid from date"})
|
||||
return
|
||||
}
|
||||
to, err = time.Parse("2006-01-02", toStr)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid to date"})
|
||||
return
|
||||
}
|
||||
} else {
|
||||
year, month, _ := time.Now().In(loc).Date()
|
||||
from = time.Date(year, month, 1, 0, 0, 0, 0, loc)
|
||||
to = from.AddDate(0, 1, 0)
|
||||
}
|
||||
|
||||
token, _ := c.Get("token")
|
||||
userToken := token.(*model.Token)
|
||||
if *userToken.User.Role < consts.RoleAdmin {
|
||||
listUsage, err = h.usageService.ListByDateRange(c.Request.Context(), from, to, map[string]interface{}{"user_id": userToken.UserID})
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
} else {
|
||||
listUsage, err = h.usageService.ListByDateRange(c.Request.Context(), from, to, nil)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, listUsage)
|
||||
|
||||
}
|
||||
|
||||
// ListUsers 获取用户列表
|
||||
func (h *Team) ListUsers(c *gin.Context) {
|
||||
limit, _ := strconv.Atoi(c.DefaultQuery("limit", "100"))
|
||||
offset, _ := strconv.Atoi(c.DefaultQuery("offset", "0"))
|
||||
active := c.DefaultQuery("active", "")
|
||||
|
||||
if !slices.Contains([]string{"true", "false", ""}, active) {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid active value"})
|
||||
return
|
||||
}
|
||||
|
||||
token, exists := c.Get("token")
|
||||
if !exists {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": "Unauthorized"})
|
||||
return
|
||||
}
|
||||
userToken := token.(*model.Token)
|
||||
if *userToken.User.Role < consts.RoleAdmin { // 用户只能获取自己的token
|
||||
tokens, _, err := h.tokenService.Lists(c, limit, offset)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
var userDTOs []dto.UserInfo
|
||||
for _, token := range tokens {
|
||||
userDTOs = append(userDTOs, dto.UserInfo{
|
||||
ID: token.User.ID,
|
||||
Name: token.User.Name,
|
||||
Token: token.Key,
|
||||
Status: utils.ToPtr(token.User.Status == consts.StatusEnabled),
|
||||
})
|
||||
}
|
||||
c.JSON(http.StatusOK, userDTOs)
|
||||
return
|
||||
}
|
||||
|
||||
users, err := h.userService.ListUsers(c, limit, offset, active)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
var userDTOs []dto.UserInfo
|
||||
for _, user := range users {
|
||||
useres := dto.UserInfo{
|
||||
ID: user.ID,
|
||||
Name: user.Name,
|
||||
|
||||
Status: utils.ToPtr(user.Status == consts.StatusEnabled),
|
||||
}
|
||||
if len(user.Tokens) > 0 {
|
||||
useres.Token = user.Tokens[0].Key
|
||||
}
|
||||
userDTOs = append(userDTOs, useres)
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, userDTOs)
|
||||
}
|
||||
|
||||
func (h *Team) ResetUserToken(c *gin.Context) {
|
||||
id, err := strconv.ParseInt(c.Param("id"), 10, 64)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid user id"})
|
||||
return
|
||||
}
|
||||
token, exists := c.Get("token")
|
||||
if !exists {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": "Unauthorized"})
|
||||
return
|
||||
}
|
||||
userToken := token.(*model.Token)
|
||||
|
||||
findtoken, err := h.tokenService.GetByUserID(c, id)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
findtoken.Key = "sk-team-" + strings.ReplaceAll(uuid.New().String(), "-", "")
|
||||
|
||||
if *userToken.User.Role < consts.RoleAdmin { // 非管理员只能修改自己的token
|
||||
if *userToken.User.Role <= *findtoken.User.Role || userToken.UserID != findtoken.UserID {
|
||||
c.JSON(http.StatusForbidden, gin.H{"error": "forbidden"})
|
||||
return
|
||||
}
|
||||
err := h.tokenService.UpdateWithCondition(c, findtoken, map[string]interface{}{"user_id": userToken.UserID}, nil)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
} else {
|
||||
if err := h.tokenService.Update(c, findtoken); err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, dto.UserInfo{
|
||||
ID: findtoken.User.ID,
|
||||
Name: findtoken.User.Name,
|
||||
Token: findtoken.Key,
|
||||
})
|
||||
}
|
||||
|
||||
func (h *Team) CreateKey(c *gin.Context) {
|
||||
token, exists := c.Get("token")
|
||||
if !exists {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": "token not found"})
|
||||
return
|
||||
}
|
||||
userToken := token.(*model.Token)
|
||||
if *userToken.User.Role < consts.RoleAdmin {
|
||||
c.JSON(http.StatusForbidden, gin.H{"error": "forbidden"})
|
||||
return
|
||||
}
|
||||
|
||||
var key dto.ApiKeyInfo
|
||||
if err := c.ShouldBindJSON(&key); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
err := h.keyService.Create(&model.ApiKey{
|
||||
Name: utils.ToPtr(key.Name),
|
||||
ApiType: utils.ToPtr(key.ApiType),
|
||||
ApiKey: utils.ToPtr(key.Key),
|
||||
Endpoint: utils.ToPtr(key.Endpoint),
|
||||
})
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, key)
|
||||
}
|
||||
|
||||
func (h *Team) ListKeys(c *gin.Context) {
|
||||
limit, _ := strconv.Atoi(c.DefaultQuery("limit", "20"))
|
||||
offset, _ := strconv.Atoi(c.DefaultQuery("offset", "0"))
|
||||
active := c.Query("active")
|
||||
if !slice.Contain([]string{"true", "false", ""}, active) {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid active value"})
|
||||
return
|
||||
}
|
||||
|
||||
keys, err := h.keyService.List(limit, offset, active)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
var keysDTO []dto.ApiKeyInfo
|
||||
for _, key := range keys {
|
||||
keylength := len(*key.ApiKey) / 3
|
||||
if keylength < 1 {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid key length"})
|
||||
return
|
||||
}
|
||||
keysDTO = append(keysDTO, dto.ApiKeyInfo{
|
||||
ID: int(key.ID),
|
||||
Name: *key.Name,
|
||||
ApiType: *key.ApiType,
|
||||
Endpoint: *key.Endpoint,
|
||||
Key: *key.ApiKey,
|
||||
})
|
||||
}
|
||||
c.JSON(http.StatusOK, keysDTO)
|
||||
}
|
||||
|
||||
func (h *Team) UpdateKey(c *gin.Context) {
|
||||
// 1. 获取并验证ID
|
||||
id, err := strconv.ParseInt(c.Param("id"), 10, 64)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid key id"})
|
||||
return
|
||||
}
|
||||
|
||||
// 2. 解析请求体
|
||||
var updateKey dto.ApiKeyInfo // 更明确的命名
|
||||
if err := c.ShouldBindJSON(&updateKey); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
// 3. 获取现有记录
|
||||
existingKey, err := h.keyService.GetByID(id)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
// 4. 使用 UpdateFields 方法统一处理字段更新
|
||||
updatedKey := updateKey.UpdateFields(existingKey)
|
||||
|
||||
// 5. 保存更新
|
||||
if err := h.keyService.Update(updatedKey); err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, updatedKey)
|
||||
}
|
||||
|
||||
func (h *Team) DeleteKey(c *gin.Context) {
|
||||
// 1. 获取并验证ID
|
||||
id, err := strconv.ParseInt(c.Param("id"), 10, 64)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid key id"})
|
||||
return
|
||||
}
|
||||
|
||||
// 2. 删除记录
|
||||
if err := h.keyService.Delete(id); err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{"message": "ok"})
|
||||
}
|
||||
|
||||
// ChangePassword 修改密码
|
||||
func (h *Team) ChangePassword(c *gin.Context) {
|
||||
userID := c.GetInt64("userID") // 假设从上下文中获取用户ID
|
||||
|
||||
var req struct {
|
||||
OldPassword string `json:"oldPassword"`
|
||||
NewPassword string `json:"newPassword"`
|
||||
}
|
||||
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid input"})
|
||||
return
|
||||
}
|
||||
|
||||
if err := h.userService.ChangePassword(c.Request.Context(), userID, req.OldPassword, req.NewPassword); err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{"message": "ok"})
|
||||
}
|
||||
|
||||
// ResetPassword 重置密码
|
||||
func (h *Team) ResetPassword(c *gin.Context) {
|
||||
id, err := strconv.ParseInt(c.Param("id"), 10, 64)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid user id"})
|
||||
return
|
||||
}
|
||||
|
||||
operatorID := int64(c.GetInt("userID")) // 假设从上下文中获取操作者ID
|
||||
if err := h.userService.ResetPassword(c.Request.Context(), id, operatorID); err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{"message": "password reset successfully"})
|
||||
}
|
||||
|
||||
// EnableUser 启用用户
|
||||
func (h *Team) EnableUser(c *gin.Context) {
|
||||
id, err := strconv.ParseInt(c.Param("id"), 10, 64)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid user id"})
|
||||
return
|
||||
}
|
||||
|
||||
operatorID := int64(c.GetInt("userID")) // 假设从上下文中获取操作者ID
|
||||
|
||||
if err := h.userService.BatchEnableUsers(c, []int64{id}, operatorID); err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{"message": "user enabled successfully"})
|
||||
}
|
||||
|
||||
// DisableUser 禁用用户
|
||||
func (h *Team) DisableUser(c *gin.Context) {
|
||||
id, err := strconv.ParseInt(c.Param("id"), 10, 64)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid user id"})
|
||||
return
|
||||
}
|
||||
|
||||
operatorID := int64(c.GetInt("userID")) // 假设从上下文中获取操作者ID
|
||||
if err := h.userService.BatchDisableUsers(c.Request.Context(), []int64{id}, operatorID); err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{"message": "user disabled successfully"})
|
||||
}
|
||||
218
internal/controller/user.go
Normal file
218
internal/controller/user.go
Normal file
@@ -0,0 +1,218 @@
|
||||
package controller
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"opencatd-open/internal/dto"
|
||||
"opencatd-open/internal/model"
|
||||
"opencatd-open/internal/utils"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/duke-git/lancet/v2/slice"
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
func (a Api) Register(c *gin.Context) {
|
||||
req := new(dto.User)
|
||||
err := c.ShouldBind(&req)
|
||||
if err != nil {
|
||||
dto.Fail(c, 400, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
err = a.userService.Register(c, &model.User{
|
||||
Username: req.Username,
|
||||
Password: req.Password,
|
||||
})
|
||||
if err != nil {
|
||||
dto.Fail(c, 400, err.Error())
|
||||
return
|
||||
} else {
|
||||
dto.Success(c, nil)
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func (a Api) Login(c *gin.Context) {
|
||||
req := new(dto.User)
|
||||
err := c.ShouldBind(&req)
|
||||
if err != nil {
|
||||
dto.Fail(c, 400, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
auth, err := a.userService.Login(c, req)
|
||||
if err != nil {
|
||||
dto.Fail(c, 400, err.Error())
|
||||
return
|
||||
} else {
|
||||
dto.Success(c, auth)
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func (a Api) Profile(c *gin.Context) {
|
||||
user, err := a.userService.Profile(c)
|
||||
if err != nil {
|
||||
dto.Fail(c, http.StatusUnauthorized, err.Error())
|
||||
return
|
||||
} else {
|
||||
dto.Success(c, user)
|
||||
}
|
||||
}
|
||||
|
||||
func (a Api) UpdateProfile(c *gin.Context) {
|
||||
var user = model.User{}
|
||||
err := c.ShouldBind(&user)
|
||||
if err != nil {
|
||||
dto.Fail(c, 400, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
err = a.userService.Update(c, &model.User{Name: user.Name, Username: user.Username, Email: user.Email})
|
||||
if err != nil {
|
||||
dto.Fail(c, http.StatusInternalServerError, err.Error())
|
||||
return
|
||||
}
|
||||
dto.Success(c, nil)
|
||||
}
|
||||
|
||||
func (a Api) UpdatePassword(c *gin.Context) {
|
||||
var passwd dto.ChangePassword
|
||||
err := c.ShouldBind(&passwd)
|
||||
if err != nil {
|
||||
dto.Fail(c, 400, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
_user := c.MustGet("user").(*model.User)
|
||||
if _user.Password == "" {
|
||||
hashpass, err := utils.HashPassword(passwd.NewPassword)
|
||||
if err != nil {
|
||||
dto.Fail(c, http.StatusInternalServerError, err.Error())
|
||||
return
|
||||
}
|
||||
_user.Password = hashpass
|
||||
} else {
|
||||
if !utils.CheckPassword(_user.Password, passwd.Password) {
|
||||
dto.Fail(c, http.StatusBadRequest, "password not match")
|
||||
return
|
||||
}
|
||||
hashpass, err := utils.HashPassword(passwd.NewPassword)
|
||||
if err != nil {
|
||||
dto.Fail(c, http.StatusInternalServerError, err.Error())
|
||||
return
|
||||
}
|
||||
_user.Password = hashpass
|
||||
}
|
||||
err = a.userService.Update(c, _user)
|
||||
if err != nil {
|
||||
dto.Fail(c, http.StatusInternalServerError, err.Error())
|
||||
return
|
||||
}
|
||||
dto.Success(c, nil)
|
||||
}
|
||||
|
||||
func (a Api) ListUser(c *gin.Context) {
|
||||
limit, _ := strconv.Atoi(c.DefaultQuery("pageSize", "20"))
|
||||
page, _ := strconv.Atoi(c.DefaultQuery("page", "1"))
|
||||
offset := (page - 1) * limit
|
||||
active := c.QueryArray("active[]")
|
||||
if !slice.ContainSubSlice([]string{"true", "false", ""}, active) {
|
||||
dto.Fail(c, http.StatusBadRequest, "active must be true or false")
|
||||
return
|
||||
}
|
||||
|
||||
users, total, err := a.userService.List(c, limit, offset, active)
|
||||
if err != nil {
|
||||
dto.Fail(c, http.StatusInternalServerError, err.Error())
|
||||
return
|
||||
}
|
||||
dto.Success(c, gin.H{
|
||||
"users": users,
|
||||
"total": total,
|
||||
})
|
||||
}
|
||||
|
||||
func (a Api) CreateUser(c *gin.Context) {
|
||||
var user model.User
|
||||
err := c.ShouldBind(&user)
|
||||
if err != nil {
|
||||
dto.Fail(c, 400, err.Error())
|
||||
return
|
||||
}
|
||||
fmt.Printf("user:%+v\n", user)
|
||||
err = a.userService.Create(c, &user)
|
||||
if err != nil {
|
||||
dto.Fail(c, http.StatusInternalServerError, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
dto.Success(c, nil)
|
||||
}
|
||||
|
||||
func (a Api) GetUser(c *gin.Context) {
|
||||
id, _ := strconv.ParseInt(c.Param("id"), 10, 64)
|
||||
user, err := a.userService.GetByID(c, id)
|
||||
if err != nil {
|
||||
dto.Fail(c, 500, err.Error())
|
||||
return
|
||||
}
|
||||
dto.Success(c, user)
|
||||
}
|
||||
|
||||
func (a Api) EditUser(c *gin.Context) {
|
||||
id, _ := strconv.ParseInt(c.Param("id"), 10, 64)
|
||||
var user model.User
|
||||
err := c.ShouldBind(&user)
|
||||
if err != nil {
|
||||
dto.Fail(c, 400, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
user.ID = int64(id)
|
||||
err = a.userService.Update(c, &user)
|
||||
if err != nil {
|
||||
dto.Fail(c, 500, err.Error())
|
||||
return
|
||||
}
|
||||
dto.Success(c, nil)
|
||||
}
|
||||
|
||||
func (a Api) DeleteUser(c *gin.Context) {
|
||||
id, _ := strconv.ParseInt(c.Param("id"), 10, 64)
|
||||
err := a.userService.Delete(c, id)
|
||||
if err != nil {
|
||||
dto.Fail(c, 400, err.Error())
|
||||
return
|
||||
}
|
||||
dto.Success(c, nil)
|
||||
}
|
||||
|
||||
func (a Api) UserOption(c *gin.Context) {
|
||||
option := strings.ToLower(c.Param("option"))
|
||||
var batchid dto.BatchIDRequest
|
||||
err := c.ShouldBind(&batchid)
|
||||
if err != nil {
|
||||
dto.Fail(c, 400, err.Error())
|
||||
return
|
||||
}
|
||||
switch option {
|
||||
case "enable":
|
||||
err = a.userService.BatchEnable(c, batchid.IDs)
|
||||
case "disable":
|
||||
err = a.userService.BatchDisable(c, batchid.IDs)
|
||||
case "delete":
|
||||
err = a.userService.BatchDelete(c, batchid.IDs)
|
||||
default:
|
||||
dto.Fail(c, 400, "invalid option, only support enable, disable, delete")
|
||||
return
|
||||
}
|
||||
if err != nil {
|
||||
dto.Fail(c, 400, err.Error())
|
||||
return
|
||||
}
|
||||
dto.Success(c, nil)
|
||||
|
||||
}
|
||||
218
internal/controller/user_token.go
Normal file
218
internal/controller/user_token.go
Normal file
@@ -0,0 +1,218 @@
|
||||
package controller
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"opencatd-open/internal/dto"
|
||||
"opencatd-open/internal/model"
|
||||
"opencatd-open/internal/utils"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/duke-git/lancet/v2/slice"
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
func (a Api) CreateToken(c *gin.Context) {
|
||||
userid := c.GetInt64("user_id")
|
||||
user, err := a.userService.GetByID(c, userid)
|
||||
if err != nil {
|
||||
dto.Fail(c, http.StatusInternalServerError, err.Error())
|
||||
return
|
||||
}
|
||||
if len(user.Tokens) >= 20 {
|
||||
dto.Fail(c, http.StatusForbidden, "user has reached the maximum number of tokens")
|
||||
return
|
||||
}
|
||||
|
||||
var token model.Token
|
||||
err = c.ShouldBindJSON(&token)
|
||||
if err != nil {
|
||||
c.JSON(400, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
token.UserID = userid
|
||||
|
||||
err = a.tokenService.CreateToken(c, &token)
|
||||
if err != nil {
|
||||
dto.Fail(c, http.StatusInternalServerError, err.Error())
|
||||
return
|
||||
}
|
||||
dto.Success(c, nil)
|
||||
}
|
||||
|
||||
func (a Api) ListToken(c *gin.Context) {
|
||||
limit, _ := strconv.Atoi(c.DefaultQuery("pageSize", "20"))
|
||||
page, _ := strconv.Atoi(c.DefaultQuery("page", "1"))
|
||||
offset := (page - 1) * limit
|
||||
active := c.QueryArray("active[]")
|
||||
if !slice.ContainSubSlice([]string{"true", "false"}, active) {
|
||||
dto.Fail(c, http.StatusBadRequest, "active must be true or false")
|
||||
}
|
||||
|
||||
tokens, total, err := a.tokenService.ListToken(c, limit, offset, active)
|
||||
if err != nil {
|
||||
dto.Fail(c, http.StatusInternalServerError, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
dto.Success(c, gin.H{
|
||||
"total": total,
|
||||
"tokens": tokens,
|
||||
})
|
||||
|
||||
}
|
||||
|
||||
func (a Api) GetToken(c *gin.Context) {
|
||||
id, _ := strconv.ParseInt(c.Param("id"), 10, 64)
|
||||
|
||||
token, err := a.tokenService.GetToken(c, id)
|
||||
if err != nil {
|
||||
dto.Fail(c, http.StatusInternalServerError, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
dto.Success(c, token)
|
||||
}
|
||||
|
||||
func (a Api) ResetToken(c *gin.Context) {
|
||||
id, err := strconv.ParseInt(c.Param("id"), 10, 64)
|
||||
if err != nil {
|
||||
dto.Fail(c, http.StatusBadRequest, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
token, err := a.tokenService.GetToken(c, id)
|
||||
if err != nil {
|
||||
dto.Fail(c, http.StatusInternalServerError, err.Error())
|
||||
return
|
||||
}
|
||||
if token == nil {
|
||||
dto.Fail(c, http.StatusNotFound, "token not found")
|
||||
return
|
||||
}
|
||||
token.UsedQuota = utils.ToPtr(int64(0))
|
||||
|
||||
err = a.tokenService.UpdateToken(c, token)
|
||||
if err != nil {
|
||||
dto.Fail(c, http.StatusInternalServerError, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
dto.Success(c, nil)
|
||||
}
|
||||
|
||||
func (a Api) UpdateToken(c *gin.Context) {
|
||||
id, err := strconv.ParseInt(c.Param("id"), 10, 64)
|
||||
if err != nil {
|
||||
dto.Fail(c, http.StatusBadRequest, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
var token model.Token
|
||||
err = c.ShouldBindJSON(&token)
|
||||
if err != nil {
|
||||
c.JSON(400, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
token.ID = id
|
||||
if token.UserID == 0 {
|
||||
dto.Fail(c, http.StatusBadRequest, "user_id is required")
|
||||
return
|
||||
}
|
||||
|
||||
var _token *model.Token
|
||||
|
||||
user, err := a.userService.GetByID(c, token.UserID)
|
||||
if err != nil {
|
||||
dto.Fail(c, http.StatusInternalServerError, err.Error())
|
||||
return
|
||||
}
|
||||
if len(user.Tokens) == 0 {
|
||||
dto.Fail(c, http.StatusForbidden, "user has no tokens")
|
||||
return
|
||||
} else {
|
||||
if findtoken, ok := slice.Find(user.Tokens,
|
||||
func(idx int, t model.Token) bool {
|
||||
return t.ID == id
|
||||
}); ok {
|
||||
_token = findtoken
|
||||
_token.User = user
|
||||
} else {
|
||||
dto.Fail(c, http.StatusForbidden, "user has no tokens")
|
||||
return
|
||||
}
|
||||
}
|
||||
// 更新_token信息
|
||||
if token.Name != "" {
|
||||
_token.Name = token.Name
|
||||
}
|
||||
if token.Key != "" {
|
||||
_token.Key = token.Key
|
||||
}
|
||||
if token.Active != nil {
|
||||
_token.Active = token.Active
|
||||
}
|
||||
if token.Quota != nil {
|
||||
_token.Quota = token.Quota
|
||||
}
|
||||
if token.UnlimitedQuota != nil {
|
||||
_token.UnlimitedQuota = token.UnlimitedQuota
|
||||
}
|
||||
if token.ExpiredAt != nil {
|
||||
_token.ExpiredAt = token.ExpiredAt
|
||||
}
|
||||
|
||||
err = a.tokenService.UpdateToken(c, _token)
|
||||
if err != nil {
|
||||
dto.Fail(c, http.StatusInternalServerError, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
dto.Success(c, nil)
|
||||
}
|
||||
|
||||
func (a Api) DeleteToken(c *gin.Context) {
|
||||
id, err := strconv.ParseInt(c.Param("id"), 10, 64)
|
||||
if err != nil {
|
||||
dto.Fail(c, http.StatusBadRequest, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
err = a.tokenService.DeleteToken(c, id)
|
||||
if err != nil {
|
||||
dto.Fail(c, http.StatusInternalServerError, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
dto.Success(c, nil)
|
||||
}
|
||||
|
||||
func (a Api) TokenOption(c *gin.Context) {
|
||||
option := strings.ToLower(c.Param("option"))
|
||||
var batchid dto.BatchIDRequest
|
||||
err := c.ShouldBind(&batchid)
|
||||
if err != nil {
|
||||
dto.Fail(c, 400, err.Error())
|
||||
return
|
||||
}
|
||||
if batchid.UserID == nil {
|
||||
dto.Fail(c, 400, "user_id is required")
|
||||
return
|
||||
}
|
||||
switch option {
|
||||
case "enable":
|
||||
err = a.tokenService.EnableTokens(c, *batchid.UserID, batchid.IDs)
|
||||
case "disable":
|
||||
err = a.tokenService.DisableTokens(c, *batchid.UserID, batchid.IDs)
|
||||
case "delete":
|
||||
err = a.tokenService.DeleteTokens(c, *batchid.UserID, batchid.IDs)
|
||||
default:
|
||||
dto.Fail(c, 400, "invalid option, only support enable, disable, delete")
|
||||
return
|
||||
}
|
||||
if err != nil {
|
||||
dto.Fail(c, 400, err.Error())
|
||||
return
|
||||
}
|
||||
dto.Success(c, nil)
|
||||
}
|
||||
108
internal/controller/webauth.go
Normal file
108
internal/controller/webauth.go
Normal file
@@ -0,0 +1,108 @@
|
||||
package controller
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"opencatd-open/internal/auth"
|
||||
"opencatd-open/internal/consts"
|
||||
"opencatd-open/internal/dto"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
func (a *Api) PasskeyCreateBegin(c *gin.Context) {
|
||||
userid := c.GetInt64("user_id")
|
||||
cred, err := a.webAuthService.BeginRegistration(userid)
|
||||
if err != nil {
|
||||
dto.Fail(c, 500, err.Error())
|
||||
return
|
||||
}
|
||||
dto.Success(c, cred)
|
||||
}
|
||||
|
||||
func (a *Api) PasskeyCreateFinish(c *gin.Context) {
|
||||
userid := c.GetInt64("user_id")
|
||||
name := c.Query("name")
|
||||
if name == "" {
|
||||
name = fmt.Sprintf("User-%d-%d", userid, time.Now().Unix())
|
||||
}
|
||||
// var body protocol.CredentialCreationResponse
|
||||
// if err := c.ShouldBindJSON(&body); err != nil {
|
||||
// dto.Fail(c, 400, err.Error())
|
||||
// return
|
||||
// }
|
||||
|
||||
// 获取用户凭证
|
||||
cred, err := a.webAuthService.FinishRegistration(userid, c.Request, name)
|
||||
if err != nil {
|
||||
dto.Fail(c, 500, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
dto.Success(c, cred)
|
||||
}
|
||||
|
||||
func (a *Api) ListPasskey(c *gin.Context) {
|
||||
passkeys, err := a.webAuthService.ListPasskeys(c.GetInt64("user_id"))
|
||||
if err != nil {
|
||||
dto.Fail(c, 500, err.Error())
|
||||
return
|
||||
}
|
||||
var passkeysDto []dto.Passkey
|
||||
for _, passkey := range passkeys {
|
||||
passkeysDto = append(passkeysDto, dto.Passkey{
|
||||
ID: passkey.ID,
|
||||
Name: passkey.Name,
|
||||
DeviceType: passkey.DeviceType,
|
||||
SignCount: passkey.SignCount,
|
||||
LastUsedAt: passkey.LastUsedAt,
|
||||
CreatedAt: passkey.CreatedAt,
|
||||
UpdatedAt: passkey.UpdatedAt,
|
||||
})
|
||||
}
|
||||
|
||||
dto.Success(c, passkeysDto)
|
||||
}
|
||||
|
||||
func (a *Api) DeletePasskey(c *gin.Context) {
|
||||
id, err := strconv.ParseInt(c.Param("id"), 10, 64)
|
||||
if err != nil {
|
||||
dto.Fail(c, 400, err.Error())
|
||||
return
|
||||
}
|
||||
if err = a.webAuthService.DeletePasskey(c.GetInt64("user_id"), id); err != nil {
|
||||
dto.Fail(c, 500, err.Error())
|
||||
return
|
||||
}
|
||||
dto.Success(c, "删除成功")
|
||||
}
|
||||
|
||||
// 登陆
|
||||
func (a *Api) PasskeyAuthBegin(c *gin.Context) {
|
||||
|
||||
cred, err := a.webAuthService.BeginLogin()
|
||||
if err != nil {
|
||||
dto.Fail(c, 500, err.Error())
|
||||
return
|
||||
}
|
||||
dto.Success(c, cred)
|
||||
}
|
||||
|
||||
func (a *Api) PasskeyAuthFinish(c *gin.Context) {
|
||||
challenge := c.Query("challenge")
|
||||
webAuthUser, err := a.webAuthService.FinishLogin(challenge, c.Request)
|
||||
if err != nil {
|
||||
dto.Fail(c, 500, err.Error())
|
||||
return
|
||||
}
|
||||
at, err := auth.GenerateTokenPair(webAuthUser.User, consts.SecretKey, consts.Day*time.Second, consts.Day*time.Second)
|
||||
if err != nil {
|
||||
dto.Fail(c, 500, err.Error())
|
||||
return
|
||||
}
|
||||
dto.Success(c, dto.Auth{
|
||||
Token: at.AccessToken,
|
||||
ExpiresIn: time.Now().Add(consts.Day * time.Second).Unix(),
|
||||
})
|
||||
}
|
||||
Reference in New Issue
Block a user