team api
This commit is contained in:
92
cmd/openteam/main.go
Normal file
92
cmd/openteam/main.go
Normal file
@@ -0,0 +1,92 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"log"
|
||||
"net/http"
|
||||
"opencatd-open/pkg/store"
|
||||
"opencatd-open/team/dashboard"
|
||||
"opencatd-open/wire"
|
||||
"os"
|
||||
"os/signal"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
func main() {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second)
|
||||
|
||||
_, err := store.InitDB()
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
team, err := wire.InitTeamHandler(ctx, store.DB)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
r := gin.Default()
|
||||
teamGroup := r.Group("/1")
|
||||
teamGroup.Use(team.AuthMiddleware())
|
||||
{
|
||||
teamGroup.POST("/users/init", team.InitAdmin)
|
||||
// 获取当前用户信息
|
||||
teamGroup.GET("/me", team.Me)
|
||||
//// team.GET("/me/usages", team.HandleMeUsage)
|
||||
|
||||
teamGroup.POST("/keys", team.CreateKey)
|
||||
teamGroup.GET("/keys", team.ListKeys)
|
||||
teamGroup.POST("/keys/:id", team.UpdateKey)
|
||||
teamGroup.DELETE("/keys/:id", team.DeleteKey)
|
||||
|
||||
teamGroup.POST("/users", team.CreateUser)
|
||||
teamGroup.GET("/users", team.ListUsers)
|
||||
teamGroup.POST("/users/:id/reset", team.ResetUserToken)
|
||||
teamGroup.DELETE("/users/:id", team.DeleteUser)
|
||||
|
||||
teamGroup.GET("/1/usages", team.ListUsages)
|
||||
}
|
||||
|
||||
api := r.Group("/api")
|
||||
{
|
||||
api.POST("/login", dashboard.HandleLogin)
|
||||
}
|
||||
|
||||
srv := &http.Server{
|
||||
Addr: ":8080",
|
||||
Handler: r,
|
||||
}
|
||||
|
||||
go func() {
|
||||
// 服务启动
|
||||
if err := srv.ListenAndServe(); err != nil && err != http.ErrServerClosed {
|
||||
log.Fatalf("listen: %s\n", err)
|
||||
}
|
||||
}()
|
||||
|
||||
// 等待中断信号来优雅地关闭服务器
|
||||
quit := make(chan os.Signal, 1)
|
||||
// kill (no param) default send syscall.SIGTERM
|
||||
// kill -2 is syscall.SIGINT
|
||||
// kill -9 is syscall.SIGKILL but can't be catch
|
||||
signal.Notify(quit, syscall.SIGINT, syscall.SIGTERM)
|
||||
<-quit
|
||||
log.Println("Shutdown Server ...")
|
||||
|
||||
defer cancel()
|
||||
if err := srv.Shutdown(ctx); err != nil {
|
||||
log.Fatal("Server Shutdown:", err)
|
||||
}
|
||||
db, _ := store.DB.DB()
|
||||
db.Close()
|
||||
// catching ctx.Done(). timeout of 1 seconds.
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
log.Println("timeout of 5 seconds.")
|
||||
}
|
||||
log.Println("Server exiting")
|
||||
|
||||
}
|
||||
5
internal/utils/pointer.go
Normal file
5
internal/utils/pointer.go
Normal file
@@ -0,0 +1,5 @@
|
||||
package utils
|
||||
|
||||
func ToPtr[T any](v T) *T {
|
||||
return &v
|
||||
}
|
||||
22
opencat.go
22
opencat.go
@@ -8,9 +8,10 @@ import (
|
||||
"io/fs"
|
||||
"log"
|
||||
"net/http"
|
||||
"opencatd-open/pkg/team"
|
||||
"opencatd-open/router"
|
||||
"opencatd-open/store"
|
||||
"opencatd-open/team"
|
||||
"opencatd-open/team/dashboard"
|
||||
"os"
|
||||
|
||||
"github.com/duke-git/lancet/v2/fileutil"
|
||||
@@ -35,7 +36,7 @@ func main() {
|
||||
args := os.Args[1:]
|
||||
if len(args) > 0 {
|
||||
type user struct {
|
||||
ID uint
|
||||
ID int64
|
||||
Name string
|
||||
Token string
|
||||
}
|
||||
@@ -155,20 +156,25 @@ func main() {
|
||||
group.POST("/keys", team.HandleAddKey) // 添加Key
|
||||
group.DELETE("/keys/:id", team.HandleDelKey) // 删除Key
|
||||
|
||||
group.GET("/users", team.HandleUsers) // 获取所有用户信息
|
||||
group.POST("/users", team.HandleAddUser) // 添加用户
|
||||
group.DELETE("/users/:id", team.HandleDelUser) // 删除用户
|
||||
// group.GET("/users", team.HandleUsers) // 获取所有用户信息
|
||||
// group.POST("/users", team.HandleAddUser) // 添加用户
|
||||
// group.DELETE("/users/:id", team.HandleDelUser) // 删除用户
|
||||
|
||||
group.GET("/usages", team.HandleUsage)
|
||||
// group.GET("/usages", team.HandleUsage)
|
||||
|
||||
// 重置用户Token
|
||||
group.POST("/users/:id/reset", team.HandleResetUserToken)
|
||||
// // 重置用户Token
|
||||
// group.POST("/users/:id/reset", team.HandleResetUserToken)
|
||||
}
|
||||
// 初始化用户
|
||||
r.POST("/1/users/init", team.Handleinit)
|
||||
|
||||
r.Any("/v1/*proxypath", router.HandleProxy)
|
||||
|
||||
api := r.Group("/api")
|
||||
{
|
||||
api.POST("/login", dashboard.HandleLogin)
|
||||
}
|
||||
|
||||
// r.POST("/v1/chat/completions", router.HandleProy)
|
||||
// r.GET("/v1/models", router.HandleProy)
|
||||
// r.GET("/v1/dashboard/billing/subscription", router.HandleProy)
|
||||
|
||||
@@ -3,6 +3,7 @@ package store
|
||||
import (
|
||||
"fmt"
|
||||
"log"
|
||||
"opencatd-open/team/consts"
|
||||
"opencatd-open/team/model"
|
||||
"os"
|
||||
"strings"
|
||||
@@ -10,13 +11,23 @@ import (
|
||||
// "gocloud.dev/mysql"
|
||||
// "gocloud.dev/postgres"
|
||||
"github.com/glebarez/sqlite"
|
||||
"github.com/google/wire"
|
||||
"gorm.io/driver/mysql"
|
||||
"gorm.io/driver/postgres"
|
||||
|
||||
// "gorm.io/driver/sqlite"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
var DB *gorm.DB
|
||||
|
||||
var DBType consts.DBType
|
||||
var IsPostgres bool
|
||||
|
||||
var DBSet = wire.NewSet(
|
||||
InitDB,
|
||||
)
|
||||
|
||||
// InitDB 初始化数据库连接
|
||||
func InitDB() (*gorm.DB, error) {
|
||||
var db *gorm.DB
|
||||
@@ -32,21 +43,35 @@ func InitDB() (*gorm.DB, error) {
|
||||
// 解析DSN来确定数据库类型
|
||||
if strings.HasPrefix(dsn, "postgres://") {
|
||||
IsPostgres = true
|
||||
DBType = consts.DBTypePostgreSQL
|
||||
db, err = initPostgres(dsn)
|
||||
} else if strings.HasPrefix(dsn, "mysql://") {
|
||||
DBType = consts.DBTypeMySQL
|
||||
db, err = initMySQL(dsn)
|
||||
} else {
|
||||
if dsn != "" {
|
||||
return nil, fmt.Errorf("unsupported database type in DSN: %s", dsn)
|
||||
}
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
DB = db
|
||||
|
||||
if IsPostgres {
|
||||
err = db.AutoMigrate(&model.User{}, &model.ApiKey_PG{}, &model.Token{}, &model.Session{}, &model.Usage{}, &model.DailyUsage{})
|
||||
err = db.AutoMigrate(&model.User{}, &model.Token{}, &model.ApiKey_PG{}, &model.Usage{}, &model.DailyUsage{})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
} else {
|
||||
err = db.AutoMigrate(&model.User{}, &model.Token{}, &model.ApiKey{}, &model.Usage{}, &model.DailyUsage{})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("unsupported database type in DSN: %s", dsn)
|
||||
return db, nil
|
||||
}
|
||||
|
||||
// initSQLite 初始化 SQLite 数据库
|
||||
@@ -55,6 +80,7 @@ func initSQLite() (*gorm.DB, error) {
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to connect to SQLite: %v", err)
|
||||
}
|
||||
// db, err := gorm.Open(sqlite.Open("file::memory:?cache=shared"), &gorm.Config{})
|
||||
return db, nil
|
||||
}
|
||||
|
||||
|
||||
@@ -42,7 +42,7 @@ func Handleinit(c *gin.Context) {
|
||||
})
|
||||
return
|
||||
}
|
||||
if user.ID == uint(1) {
|
||||
if user.ID == 1 {
|
||||
c.JSON(http.StatusForbidden, gin.H{
|
||||
"error": "super user already exists, use cli to reset password",
|
||||
})
|
||||
|
||||
@@ -82,7 +82,7 @@ func HandleResetUserToken(c *gin.Context) {
|
||||
c.JSON(http.StatusForbidden, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
if u.ID == uint(1) {
|
||||
if u.ID == 1 {
|
||||
rootToken = u.Token
|
||||
}
|
||||
c.JSON(http.StatusOK, u)
|
||||
|
||||
@@ -183,7 +183,7 @@ func Cost(model string, promptCount, completionCount int) float64 {
|
||||
cost = (0.00035/1000)*float64(prompt) + (0.00053/1000)*float64(completion)
|
||||
case "gemini-2.0-flash-exp":
|
||||
cost = (0.00035/1000)*float64(prompt) + (0.00053/1000)*float64(completion)
|
||||
case "gemini-2.0-flash-thinking-exp-1219":
|
||||
case "gemini-2.0-flash-thinking-exp-1219", "gemini-2.0-flash-thinking-exp-01-21":
|
||||
cost = (0.00035/1000)*float64(prompt) + (0.00053/1000)*float64(completion)
|
||||
case "learnlm-1.5-pro-experimental", " gemini-exp-1114", "gemini-exp-1121", "gemini-exp-1206":
|
||||
cost = (0.00035/1000)*float64(prompt) + (0.00053/1000)*float64(completion)
|
||||
|
||||
@@ -82,7 +82,7 @@ func SelectKeyCacheByModel(model string) (Key, error) {
|
||||
}
|
||||
items := KeysCache.Items()
|
||||
for _, item := range items {
|
||||
if strings.Contains(model, "realtime") {
|
||||
if strings.Contains(model, "realtime") || strings.HasPrefix(model, "o1-") {
|
||||
if item.Object.(Key).ApiType == "openai" {
|
||||
keys = append(keys, item.Object.(Key))
|
||||
}
|
||||
@@ -101,7 +101,7 @@ func SelectKeyCacheByModel(model string) (Key, error) {
|
||||
keys = append(keys, item.Object.(Key))
|
||||
}
|
||||
}
|
||||
if strings.HasPrefix(model, "o1-") || strings.HasPrefix(model, "chatgpt-") {
|
||||
if strings.HasPrefix(model, "chatgpt-") {
|
||||
if item.Object.(Key).ApiType == "openai" {
|
||||
keys = append(keys, item.Object.(Key))
|
||||
}
|
||||
|
||||
@@ -1,17 +1,44 @@
|
||||
package consts
|
||||
|
||||
import "gorm.io/gorm"
|
||||
|
||||
type UserRole int
|
||||
|
||||
const (
|
||||
RoleGuest = iota * 10
|
||||
RoleUser
|
||||
RoleUser UserRole = iota * 10
|
||||
RoleAdmin
|
||||
RoleSuperAdmin
|
||||
)
|
||||
|
||||
const (
|
||||
StatusEnabled = iota
|
||||
StatusDisabled
|
||||
StatusDisabled = iota
|
||||
StatusEnabled
|
||||
StatusExpired // 过期
|
||||
StatusExhausted // 耗尽
|
||||
StatusDeleted
|
||||
|
||||
StatusDeleted = -1
|
||||
)
|
||||
const (
|
||||
Limited = iota
|
||||
Unlimited
|
||||
UnlimitedQuota = 999999
|
||||
)
|
||||
|
||||
var (
|
||||
ErrUserNotFound = gorm.ErrRecordNotFound
|
||||
)
|
||||
|
||||
func OpenOrClose(status bool) int {
|
||||
if status {
|
||||
return StatusEnabled
|
||||
}
|
||||
return StatusDisabled
|
||||
}
|
||||
|
||||
type DBType int
|
||||
|
||||
const (
|
||||
DBTypeMySQL DBType = iota
|
||||
DBTypePostgreSQL
|
||||
DBTypeSQLite
|
||||
)
|
||||
const UnlimitedQuota = -1
|
||||
|
||||
@@ -18,6 +18,7 @@ type ApiKeyRepository interface {
|
||||
Update(apiKey *model.ApiKey) error
|
||||
Delete(id int64) error
|
||||
List(offset, limit int, status *int) ([]model.ApiKey, error)
|
||||
ListWithFilters(offset, limit int, filters map[string]interface{}) ([]model.ApiKey, int64, error)
|
||||
Enable(id int64) error
|
||||
Disable(id int64) error
|
||||
BatchEnable(ids []int64) error
|
||||
@@ -77,7 +78,7 @@ func (dao *ApiKeyDAO) Update(apiKey *model.ApiKey) error {
|
||||
if apiKey == nil {
|
||||
return errors.New("apiKey is nil")
|
||||
}
|
||||
apiKey.UpdatedAt = time.Now()
|
||||
apiKey.UpdatedAt = time.Now().Unix()
|
||||
return dao.db.Save(apiKey).Error
|
||||
}
|
||||
|
||||
@@ -100,6 +101,21 @@ func (dao *ApiKeyDAO) List(offset, limit int, status *int) ([]model.ApiKey, erro
|
||||
return apiKeys, nil
|
||||
}
|
||||
|
||||
// ListApiKeysWithFilters 根据条件获取ApiKey列表
|
||||
func (dao *ApiKeyDAO) ListWithFilters(offset, limit int, filters map[string]interface{}) ([]model.ApiKey, int64, error) {
|
||||
var apiKeys []model.ApiKey
|
||||
db := dao.db.Offset(offset).Limit(limit)
|
||||
for key, value := range filters {
|
||||
db = db.Where(key+" = ?", value)
|
||||
}
|
||||
var count int64
|
||||
err := db.Find(&apiKeys).Count(&count).Error
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
return apiKeys, count, nil
|
||||
}
|
||||
|
||||
// EnableApiKey 启用ApiKey
|
||||
func (dao *ApiKeyDAO) Enable(id int64) error {
|
||||
return dao.db.Model(&model.ApiKey{}).Where("id = ?", id).Update("status", 0).Error
|
||||
|
||||
@@ -1,28 +1,33 @@
|
||||
package dao
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"opencatd-open/team/consts"
|
||||
"opencatd-open/team/model"
|
||||
"time"
|
||||
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// 确保 TokenDAO 实现了 TokenDAOInterface 接口
|
||||
var _ TokenDAOInterface = (*TokenDAO)(nil)
|
||||
// 确保 TokenDAO 实现了 TokenRepository 接口
|
||||
var _ TokenRepository = (*TokenDAO)(nil)
|
||||
|
||||
type TokenDAOInterface interface {
|
||||
Create(token *model.Token) error
|
||||
GetByID(id int) (*model.Token, error)
|
||||
GetByKey(key string) (*model.Token, error)
|
||||
GetByUserID(userID int) ([]model.Token, error)
|
||||
Update(token *model.Token) error
|
||||
Delete(id int) error
|
||||
List(offset, limit int) ([]model.Token, error)
|
||||
Disable(id int) error
|
||||
Enable(id int) error
|
||||
BatchDisable(ids []int) error
|
||||
BatchEnable(ids []int) error
|
||||
BatchDelete(ids []int) error
|
||||
type TokenRepository interface {
|
||||
Create(ctx context.Context, token *model.Token) error
|
||||
GetByID(ctx context.Context, id int) (*model.Token, error)
|
||||
GetByKey(ctx context.Context, key string) (*model.Token, error)
|
||||
GetByUserID(ctx context.Context, userID int) (*model.Token, error)
|
||||
Update(ctx context.Context, token *model.Token) error
|
||||
UpdateWithCondition(ctx context.Context, token *model.Token, filters map[string]interface{}, updates map[string]interface{}) error
|
||||
Delete(ctx context.Context, id int) error
|
||||
List(ctx context.Context, offset, limit int) ([]model.Token, error)
|
||||
ListWithFilters(ctx context.Context, offset, limit int, filters map[string]interface{}) ([]model.Token, int64, error)
|
||||
Disable(ctx context.Context, id int) error
|
||||
Enable(ctx context.Context, id int) error
|
||||
BatchDisable(ctx context.Context, ids []int) error
|
||||
BatchEnable(ctx context.Context, ids []int) error
|
||||
BatchDelete(ctx context.Context, ids []int) error
|
||||
}
|
||||
|
||||
type TokenDAO struct {
|
||||
@@ -34,87 +39,140 @@ func NewTokenDAO(db *gorm.DB) *TokenDAO {
|
||||
}
|
||||
|
||||
// CreateToken 创建 Token
|
||||
func (dao *TokenDAO) Create(token *model.Token) error {
|
||||
func (dao *TokenDAO) Create(ctx context.Context, token *model.Token) error {
|
||||
if token == nil {
|
||||
return errors.New("token is nil")
|
||||
}
|
||||
return dao.db.Create(token).Error
|
||||
return dao.db.WithContext(ctx).Create(token).Error
|
||||
}
|
||||
|
||||
// GetTokenByID 根据 ID 获取 Token
|
||||
func (dao *TokenDAO) GetByID(id int) (*model.Token, error) {
|
||||
// 根据 ID 获取 Token
|
||||
func (dao *TokenDAO) GetByID(ctx context.Context, id int) (*model.Token, error) {
|
||||
var token model.Token
|
||||
err := dao.db.First(&token, id).Error
|
||||
err := dao.db.WithContext(ctx).First(&token, id).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &token, nil
|
||||
}
|
||||
|
||||
// GetTokenByKey 根据 Key 获取 Token
|
||||
func (dao *TokenDAO) GetByKey(key string) (*model.Token, error) {
|
||||
// 根据 Key 获取 Token
|
||||
func (dao *TokenDAO) GetByKey(ctx context.Context, key string) (*model.Token, error) {
|
||||
var token model.Token
|
||||
err := dao.db.Where("key = ?", key).First(&token).Error
|
||||
// err := dao.db.Where("key = ?", key).First(&token).Error
|
||||
err := dao.db.WithContext(ctx).Preload("User").Where("key = ?", key).First(&token).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &token, nil
|
||||
}
|
||||
|
||||
// GetTokensByUserID 根据 UserID 获取 Token 列表
|
||||
func (dao *TokenDAO) GetByUserID(userID int) ([]model.Token, error) {
|
||||
var tokens []model.Token
|
||||
err := dao.db.Where("user_id = ?", userID).Find(&tokens).Error
|
||||
// 根据 UserID 获取 Token
|
||||
func (dao *TokenDAO) GetByUserID(ctx context.Context, userID int) (*model.Token, error) {
|
||||
var token model.Token
|
||||
err := dao.db.WithContext(ctx).Preload("User").Where("user_id = ?", userID).Find(&token).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return tokens, nil
|
||||
return &token, nil
|
||||
}
|
||||
|
||||
// UpdateToken 更新 Token 信息
|
||||
func (dao *TokenDAO) Update(token *model.Token) error {
|
||||
func (dao *TokenDAO) Update(ctx context.Context, token *model.Token) error {
|
||||
if token == nil {
|
||||
return errors.New("token is nil")
|
||||
}
|
||||
return dao.db.Save(token).Error
|
||||
return dao.db.WithContext(ctx).Save(token).Error
|
||||
}
|
||||
|
||||
// UpdateTokenWithFilters 更新 Token 信息,支持过滤
|
||||
func (dao *TokenDAO) UpdateWithCondition(ctx context.Context, token *model.Token, filters map[string]interface{}, updates map[string]interface{}) error {
|
||||
if token == nil {
|
||||
return errors.New("token is nil")
|
||||
}
|
||||
db := dao.db.WithContext(ctx)
|
||||
for key, value := range filters {
|
||||
db = db.Where(key+" = ?", value)
|
||||
}
|
||||
return db.Model(&model.Token{}).Updates(updates).Error
|
||||
}
|
||||
|
||||
// DeleteToken 删除 Token
|
||||
func (dao *TokenDAO) Delete(id int) error {
|
||||
return dao.db.Delete(&model.Token{}, id).Error
|
||||
func (dao *TokenDAO) Delete(ctx context.Context, id int) error {
|
||||
return dao.db.WithContext(ctx).Delete(&model.Token{}, id).Error
|
||||
}
|
||||
|
||||
// ListTokens 获取 Token 列表
|
||||
func (dao *TokenDAO) List(offset, limit int) ([]model.Token, error) {
|
||||
func (dao *TokenDAO) List(ctx context.Context, offset, limit int) ([]model.Token, error) {
|
||||
var tokens []model.Token
|
||||
err := dao.db.Offset(offset).Limit(limit).Find(&tokens).Error
|
||||
err := dao.db.WithContext(ctx).Offset(offset).Limit(limit).Find(&tokens).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return tokens, nil
|
||||
}
|
||||
|
||||
// ListTokensWithFilters 获取 Token 列表,支持过滤
|
||||
func (dao *TokenDAO) ListWithFilters(ctx context.Context, offset, limit int, filters map[string]interface{}) ([]model.Token, int64, error) {
|
||||
var tokens []model.Token
|
||||
var count int64
|
||||
|
||||
db := dao.db.WithContext(ctx)
|
||||
for key, value := range filters {
|
||||
db = db.Where(key+" = ?", value)
|
||||
}
|
||||
|
||||
if err := db.Offset(offset).Limit(limit).Find(&tokens).Error; err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
if err := db.Model(&model.Token{}).Count(&count).Error; err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
return tokens, count, nil
|
||||
}
|
||||
|
||||
// DisableToken 禁用 Token
|
||||
func (dao *TokenDAO) Disable(id int) error {
|
||||
return dao.db.Model(&model.Token{}).Where("id = ?", id).Update("status", false).Error
|
||||
func (dao *TokenDAO) Disable(ctx context.Context, id int) error {
|
||||
return dao.db.WithContext(ctx).Model(&model.Token{}).Where("id = ?", id).Update("status", false).Error
|
||||
}
|
||||
|
||||
// EnableToken 启用 Token
|
||||
func (dao *TokenDAO) Enable(id int) error {
|
||||
return dao.db.Model(&model.Token{}).Where("id = ?", id).Update("status", true).Error
|
||||
func (dao *TokenDAO) Enable(ctx context.Context, id int) error {
|
||||
return dao.db.WithContext(ctx).Model(&model.Token{}).Where("id = ?", id).Update("status", true).Error
|
||||
}
|
||||
|
||||
// BatchDisableTokens 批量禁用 Token
|
||||
func (dao *TokenDAO) BatchDisable(ids []int) error {
|
||||
return dao.db.Model(&model.Token{}).Where("id IN ?", ids).Update("status", false).Error
|
||||
func (dao *TokenDAO) BatchDisable(ctx context.Context, ids []int) error {
|
||||
return dao.db.WithContext(ctx).Model(&model.Token{}).Where("id IN ?", ids).Update("status", false).Error
|
||||
}
|
||||
|
||||
// BatchEnableTokens 批量启用 Token
|
||||
func (dao *TokenDAO) BatchEnable(ids []int) error {
|
||||
return dao.db.Model(&model.Token{}).Where("id IN ?", ids).Update("status", true).Error
|
||||
func (dao *TokenDAO) BatchEnable(ctx context.Context, ids []int) error {
|
||||
return dao.db.WithContext(ctx).Model(&model.Token{}).Where("id IN ?", ids).Update("status", true).Error
|
||||
}
|
||||
|
||||
// BatchDeleteTokens 批量删除 Token
|
||||
func (dao *TokenDAO) BatchDelete(ids []int) error {
|
||||
return dao.db.Where("id IN ?", ids).Delete(&model.Token{}).Error
|
||||
func (dao *TokenDAO) BatchDelete(ctx context.Context, ids []int) error {
|
||||
return dao.db.WithContext(ctx).Where("id IN ?", ids).Delete(&model.Token{}).Error
|
||||
}
|
||||
|
||||
// 检查 token 是否有效
|
||||
func (dao *TokenDAO) IsValid(ctx context.Context, key string) (bool, error) {
|
||||
var token model.Token
|
||||
err := dao.db.WithContext(ctx).Where("key = ? AND status = ? AND (expired_time = -1 OR expired_time > ?)",
|
||||
key, consts.StatusEnabled, time.Now().Unix()).First(&token).Error
|
||||
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return false, nil
|
||||
}
|
||||
return false, err
|
||||
}
|
||||
if token.User.Status != consts.StatusEnabled || (token.User.UnlimitedQuota == 1 && token.User.Quota <= 0) {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
return true, nil
|
||||
}
|
||||
|
||||
297
team/dao/usage.go
Normal file
297
team/dao/usage.go
Normal file
@@ -0,0 +1,297 @@
|
||||
package dao
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"opencatd-open/pkg/store"
|
||||
"opencatd-open/team/consts"
|
||||
dto "opencatd-open/team/dto/team"
|
||||
"opencatd-open/team/model"
|
||||
"time"
|
||||
|
||||
"gorm.io/gorm"
|
||||
"gorm.io/gorm/clause"
|
||||
)
|
||||
|
||||
var _ UsageRepository = (*UsageDAO)(nil)
|
||||
var _ DailyUsageRepository = (*DailyUsageDAO)(nil)
|
||||
|
||||
type UsageRepository interface {
|
||||
// Create
|
||||
Create(ctx context.Context, usage *model.Usage) error
|
||||
BatchCreate(ctx context.Context, usages []*model.Usage) error
|
||||
|
||||
// Read
|
||||
ListByUserID(ctx context.Context, userID int64, limit, offset int) ([]*model.Usage, error)
|
||||
ListByTokenID(ctx context.Context, tokenID int64, limit, offset int) ([]*model.Usage, error)
|
||||
ListByDateRange(ctx context.Context, start, end time.Time) ([]*model.Usage, error)
|
||||
ListByCapability(ctx context.Context, capability string, limit, offset int) ([]*model.Usage, error)
|
||||
|
||||
// Delete
|
||||
Delete(ctx context.Context, id int64) error
|
||||
|
||||
// Statistics
|
||||
CountByUserID(ctx context.Context, userID int64) (int64, error)
|
||||
}
|
||||
|
||||
type DailyUsageRepository interface {
|
||||
// Create
|
||||
Create(ctx context.Context, usage *model.DailyUsage) error
|
||||
BatchCreate(ctx context.Context, usages []*model.DailyUsage) error
|
||||
|
||||
// Read
|
||||
ListByUserID(ctx context.Context, userID int64, limit, offset int) ([]*model.DailyUsage, error)
|
||||
ListByTokenID(ctx context.Context, tokenID int64, limit, offset int) ([]*model.DailyUsage, error)
|
||||
ListByDateRange(ctx context.Context, start, end time.Time) ([]*model.DailyUsage, error)
|
||||
GetByDate(ctx context.Context, userID int64, date time.Time) (*model.DailyUsage, error)
|
||||
|
||||
// Delete
|
||||
Delete(ctx context.Context, id int64) error
|
||||
|
||||
// Statistics
|
||||
CountByUserID(ctx context.Context, userID int64) (int64, error)
|
||||
StatUserUsages(ctx context.Context, start, end time.Time, filters map[string]interface{}) ([]*dto.UsageInfo, error)
|
||||
}
|
||||
|
||||
type UsageDAO struct {
|
||||
db *gorm.DB
|
||||
}
|
||||
|
||||
type DailyUsageDAO struct {
|
||||
db *gorm.DB
|
||||
}
|
||||
|
||||
func NewUsageDAO(db *gorm.DB) *UsageDAO {
|
||||
return &UsageDAO{db: db}
|
||||
}
|
||||
|
||||
func NewDailyUsageDAO(db *gorm.DB) *DailyUsageDAO {
|
||||
return &DailyUsageDAO{db: db}
|
||||
}
|
||||
|
||||
// Usage DAO implementations
|
||||
func (d *UsageDAO) Create(ctx context.Context, usage *model.Usage) error {
|
||||
return d.db.WithContext(ctx).Create(usage).Error
|
||||
}
|
||||
|
||||
func (d *UsageDAO) BatchCreate(ctx context.Context, usages []*model.Usage) error {
|
||||
return d.db.WithContext(ctx).Create(usages).Error
|
||||
}
|
||||
|
||||
func (d *UsageDAO) GetByID(ctx context.Context, id int64) (*model.Usage, error) {
|
||||
var usage model.Usage
|
||||
err := d.db.WithContext(ctx).First(&usage, id).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &usage, nil
|
||||
}
|
||||
|
||||
func (d *UsageDAO) ListByUserID(ctx context.Context, userID int64, limit, offset int) ([]*model.Usage, error) {
|
||||
var usages []*model.Usage
|
||||
err := d.db.WithContext(ctx).
|
||||
Where("user_id = ?", userID).
|
||||
Limit(limit).
|
||||
Offset(offset).
|
||||
Find(&usages).Error
|
||||
return usages, err
|
||||
}
|
||||
|
||||
func (d *UsageDAO) ListByTokenID(ctx context.Context, tokenID int64, limit, offset int) ([]*model.Usage, error) {
|
||||
var usages []*model.Usage
|
||||
err := d.db.WithContext(ctx).
|
||||
Where("token_id = ?", tokenID).
|
||||
Limit(limit).
|
||||
Offset(offset).
|
||||
Find(&usages).Error
|
||||
return usages, err
|
||||
}
|
||||
|
||||
func (d *UsageDAO) ListByDateRange(ctx context.Context, start, end time.Time) ([]*model.Usage, error) {
|
||||
var usages []*model.Usage
|
||||
err := d.db.WithContext(ctx).
|
||||
Where("date BETWEEN ? AND ?", start, end).
|
||||
Find(&usages).Error
|
||||
return usages, err
|
||||
}
|
||||
|
||||
func (d *UsageDAO) ListByCapability(ctx context.Context, capability string, limit, offset int) ([]*model.Usage, error) {
|
||||
var usages []*model.Usage
|
||||
err := d.db.WithContext(ctx).
|
||||
Where("capability = ?", capability).
|
||||
Limit(limit).
|
||||
Offset(offset).
|
||||
Find(&usages).Error
|
||||
return usages, err
|
||||
}
|
||||
|
||||
func (d *UsageDAO) Delete(ctx context.Context, id int64) error {
|
||||
return d.db.WithContext(ctx).Delete(&model.Usage{}, id).Error
|
||||
}
|
||||
|
||||
func (d *UsageDAO) CountByUserID(ctx context.Context, userID int64) (int64, error) {
|
||||
var count int64
|
||||
err := d.db.WithContext(ctx).Model(&model.Usage{}).Where("user_id = ?", userID).Count(&count).Error
|
||||
return count, err
|
||||
}
|
||||
|
||||
// DailyUsage DAO implementations
|
||||
func (d *DailyUsageDAO) Create(ctx context.Context, usage *model.DailyUsage) error {
|
||||
return d.db.WithContext(ctx).Create(usage).Error
|
||||
}
|
||||
|
||||
func (d *DailyUsageDAO) BatchCreate(ctx context.Context, usages []*model.DailyUsage) error {
|
||||
return d.db.WithContext(ctx).Create(usages).Error
|
||||
}
|
||||
|
||||
func (d *DailyUsageDAO) ListByUserID(ctx context.Context, userID int64, limit, offset int) ([]*model.DailyUsage, error) {
|
||||
var usages []*model.DailyUsage
|
||||
err := d.db.WithContext(ctx).
|
||||
Where("user_id = ?", userID).
|
||||
Limit(limit).
|
||||
Offset(offset).
|
||||
Find(&usages).Error
|
||||
return usages, err
|
||||
}
|
||||
|
||||
func (d *DailyUsageDAO) ListByTokenID(ctx context.Context, tokenID int64, limit, offset int) ([]*model.DailyUsage, error) {
|
||||
var usages []*model.DailyUsage
|
||||
err := d.db.WithContext(ctx).
|
||||
Where("token_id = ?", tokenID).
|
||||
Limit(limit).
|
||||
Offset(offset).
|
||||
Find(&usages).Error
|
||||
return usages, err
|
||||
}
|
||||
|
||||
func (d *DailyUsageDAO) ListByDateRange(ctx context.Context, start, end time.Time) ([]*model.DailyUsage, error) {
|
||||
var usages []*model.DailyUsage
|
||||
err := d.db.WithContext(ctx).
|
||||
Where("date BETWEEN ? AND ?", start, end).
|
||||
Find(&usages).Error
|
||||
return usages, err
|
||||
}
|
||||
|
||||
func (d *DailyUsageDAO) GetByDate(ctx context.Context, userID int64, date time.Time) (*model.DailyUsage, error) {
|
||||
var usage model.DailyUsage
|
||||
err := d.db.WithContext(ctx).
|
||||
Where("user_id = ? AND date = ?", userID, date).
|
||||
First(&usage).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &usage, nil
|
||||
}
|
||||
|
||||
// UpsertDailyUsage 根据不同数据库类型执行 Upsert
|
||||
func (d *DailyUsageDAO) UpsertDailyUsage(ctx context.Context, usage *model.Usage) error {
|
||||
date := usage.Date.Truncate(24 * time.Hour)
|
||||
dailyUsage := &model.DailyUsage{
|
||||
UserID: usage.UserID,
|
||||
TokenID: usage.TokenID,
|
||||
Capability: usage.Capability,
|
||||
Model: usage.Model,
|
||||
Stream: usage.Stream,
|
||||
PromptTokens: usage.PromptTokens,
|
||||
CompletionTokens: usage.CompletionTokens,
|
||||
TotalTokens: usage.TotalTokens,
|
||||
Cost: usage.Cost,
|
||||
}
|
||||
|
||||
updateColumns := map[string]interface{}{
|
||||
"prompt_tokens": gorm.Expr("prompt_tokens + VALUES(prompt_tokens)"),
|
||||
"completion_tokens": gorm.Expr("completion_tokens + VALUES(completion_tokens)"),
|
||||
"total_tokens": gorm.Expr("total_tokens + VALUES(total_tokens)"),
|
||||
}
|
||||
|
||||
db := d.db.WithContext(ctx)
|
||||
|
||||
switch store.DBType {
|
||||
case consts.DBTypeMySQL:
|
||||
// MySQL: INSERT ... ON DUPLICATE KEY UPDATE
|
||||
return db.Clauses(clause.OnConflict{
|
||||
Columns: []clause.Column{
|
||||
{Name: "user_id"},
|
||||
{Name: "token_id"},
|
||||
{Name: "capability"},
|
||||
{Name: "date"},
|
||||
{Name: "model"},
|
||||
{Name: "stream"},
|
||||
},
|
||||
DoUpdates: clause.Assignments(updateColumns),
|
||||
}).Create(dailyUsage).Error
|
||||
|
||||
case consts.DBTypePostgreSQL:
|
||||
// PostgreSQL: INSERT ... ON CONFLICT DO UPDATE
|
||||
updateColumns := map[string]interface{}{
|
||||
"prompt_tokens": gorm.Expr("daily_usages.prompt_tokens + EXCLUDED.prompt_tokens"),
|
||||
"completion_tokens": gorm.Expr("daily_usages.completion_tokens + EXCLUDED.completion_tokens"),
|
||||
"total_tokens": gorm.Expr("daily_usages.total_tokens + EXCLUDED.total_tokens"),
|
||||
}
|
||||
return db.Clauses(clause.OnConflict{
|
||||
Columns: []clause.Column{
|
||||
{Name: "user_id"},
|
||||
{Name: "token_id"},
|
||||
{Name: "capability"},
|
||||
{Name: "date"},
|
||||
{Name: "model"},
|
||||
{Name: "stream"},
|
||||
},
|
||||
DoUpdates: clause.Assignments(updateColumns),
|
||||
}).Create(dailyUsage).Error
|
||||
case consts.DBTypeSQLite:
|
||||
// SQLite: 需要使用事务来模拟 upsert
|
||||
return db.Transaction(func(tx *gorm.DB) error {
|
||||
var existing model.DailyUsage
|
||||
err := tx.Where("user_id = ? AND token_id = ? AND capability = ? AND date = ? AND model = ? AND stream = ?",
|
||||
usage.UserID, usage.TokenID, usage.Capability, date, usage.Model, usage.Stream).
|
||||
First(&existing).Error
|
||||
|
||||
if err == gorm.ErrRecordNotFound {
|
||||
// 记录不存在,创建新记录
|
||||
return tx.Create(dailyUsage).Error
|
||||
} else if err != nil {
|
||||
return err // 返回其他错误
|
||||
}
|
||||
|
||||
// 记录存在,更新
|
||||
return tx.Model(&existing).Updates(map[string]interface{}{
|
||||
"prompt_tokens": gorm.Expr("prompt_tokens + ?", usage.PromptTokens),
|
||||
"completion_tokens": gorm.Expr("completion_tokens + ?", usage.CompletionTokens),
|
||||
"total_tokens": gorm.Expr("total_tokens + ?", usage.TotalTokens),
|
||||
}).Error
|
||||
})
|
||||
|
||||
default:
|
||||
return fmt.Errorf("不支持的数据库类型: %s", store.DBType)
|
||||
}
|
||||
}
|
||||
|
||||
func (d *DailyUsageDAO) Delete(ctx context.Context, id int64) error {
|
||||
return d.db.WithContext(ctx).Delete(&model.DailyUsage{}, id).Error
|
||||
}
|
||||
|
||||
func (d *DailyUsageDAO) CountByUserID(ctx context.Context, userID int64) (int64, error) {
|
||||
var count int64
|
||||
err := d.db.WithContext(ctx).Model(&model.DailyUsage{}).Where("user_id = ?", userID).Count(&count).Error
|
||||
return count, err
|
||||
}
|
||||
|
||||
func (d *DailyUsageDAO) StatUserUsages(ctx context.Context, from, to time.Time, filters map[string]interface{}) ([]*dto.UsageInfo, error) {
|
||||
var usages []*dto.UsageInfo
|
||||
|
||||
query := d.db.WithContext(ctx).
|
||||
Model(&model.DailyUsage{}).
|
||||
Select("user_id as userId, sum(total_tokens) as totalUnit, sum(cast(cost as decimal(20,6))) as cost")
|
||||
for key, value := range filters {
|
||||
query = query.Where(fmt.Sprintf("%s = ?", key), value)
|
||||
}
|
||||
query = query.Group("user_id").Where("date >= ? AND date <= ?", from, to)
|
||||
|
||||
err := query.Group("user_id").Find(&usages).Error
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to list usages: %w", err)
|
||||
}
|
||||
|
||||
return usages, nil
|
||||
}
|
||||
@@ -2,6 +2,8 @@ package dao
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"opencatd-open/team/consts"
|
||||
"opencatd-open/team/model"
|
||||
"time"
|
||||
|
||||
@@ -19,6 +21,7 @@ type UserRepository interface {
|
||||
Update(user *model.User) error
|
||||
Delete(id int64) error
|
||||
List(offset, limit int) ([]model.User, error)
|
||||
ListWithFilters(offset, limit int, filters map[string]interface{}) ([]model.User, int64, error)
|
||||
Enable(id int64) error
|
||||
Disable(id int64) error
|
||||
BatchEnable(ids []int64) error
|
||||
@@ -34,70 +37,107 @@ func NewUserDAO(db *gorm.DB) *UserDAO {
|
||||
return &UserDAO{db: db}
|
||||
}
|
||||
|
||||
// CreateUser 创建用户
|
||||
// 创建用户
|
||||
func (dao *UserDAO) Create(user *model.User) error {
|
||||
if user == nil {
|
||||
return errors.New("user is nil")
|
||||
}
|
||||
return dao.db.Create(user).Error
|
||||
|
||||
return dao.db.Transaction(func(tx *gorm.DB) error {
|
||||
// 创建用户
|
||||
if err := tx.Create(user).Error; err != nil {
|
||||
return fmt.Errorf("failed to create user: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
// GetUserByID 根据ID获取用户
|
||||
// 根据ID获取用户
|
||||
func (dao *UserDAO) GetByID(id int64) (*model.User, error) {
|
||||
var user model.User
|
||||
err := dao.db.First(&user, id).Error
|
||||
// err := dao.db.First(&user, id).Error
|
||||
err := dao.db.Preload("Tokens").First(&user, id).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &user, nil
|
||||
}
|
||||
|
||||
// GetUserByUsername 根据用户名获取用户
|
||||
// 根据用户名获取用户
|
||||
func (dao *UserDAO) GetByUsername(username string) (*model.User, error) {
|
||||
var user model.User
|
||||
err := dao.db.Where("user_name = ?", username).First(&user).Error
|
||||
// err := dao.db.Where("user_name = ?", username).First(&user).Error
|
||||
err := dao.db.Preload("Tokens").Where("user_name = ?", username).First(&user).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &user, nil
|
||||
}
|
||||
|
||||
// UpdateUser 更新用户信息
|
||||
// 更新用户信息
|
||||
func (dao *UserDAO) Update(user *model.User) error {
|
||||
if user == nil {
|
||||
return errors.New("user is nil")
|
||||
}
|
||||
user.UpdatedAt = time.Now()
|
||||
user.UpdatedAt = time.Now().Unix()
|
||||
return dao.db.Save(user).Error
|
||||
}
|
||||
|
||||
// DeleteUser 删除用户
|
||||
// 删除用户
|
||||
func (dao *UserDAO) Delete(id int64) error {
|
||||
return dao.db.Delete(&model.User{}, id).Error
|
||||
// return dao.db.Model(&model.User{}).Where("id = ?", id).Update("status", 2).Error
|
||||
}
|
||||
|
||||
// ListUsers 获取用户列表
|
||||
// 获取用户列表
|
||||
func (dao *UserDAO) List(offset, limit int) ([]model.User, error) {
|
||||
var users []model.User
|
||||
err := dao.db.Offset(offset).Limit(limit).Find(&users).Error
|
||||
err := dao.db.Preload("Tokens").Offset(offset).Limit(limit).Find(&users).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return users, nil
|
||||
}
|
||||
|
||||
// EnableUser 启用User
|
||||
// 获取用户列表,带过滤条件
|
||||
func (dao *UserDAO) ListWithFilters(offset, limit int, filters map[string]interface{}) ([]model.User, int64, error) {
|
||||
var users []model.User
|
||||
var total int64
|
||||
|
||||
// 构建查询
|
||||
query := dao.db.Model(&model.User{})
|
||||
|
||||
// 添加过滤条件
|
||||
for key, value := range filters {
|
||||
query = query.Where(key+" = ?", value)
|
||||
}
|
||||
|
||||
// 查询总数
|
||||
if err := query.Count(&total).Error; err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
// 分页查询
|
||||
err := query.Offset(offset).Limit(limit).Find(&users).Error
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
return users, total, nil
|
||||
}
|
||||
|
||||
// 启用User
|
||||
func (dao *UserDAO) Enable(id int64) error {
|
||||
return dao.db.Model(&model.User{}).Where("id = ?", id).Update("status", 0).Error
|
||||
return dao.db.Model(&model.User{}).Where("id = ?", id).Update("status", consts.StatusEnabled).Error
|
||||
}
|
||||
|
||||
// DisableUser 禁用User
|
||||
// 禁用User
|
||||
func (dao *UserDAO) Disable(id int64) error {
|
||||
return dao.db.Model(&model.User{}).Where("id = ?", id).Update("status", 1).Error
|
||||
return dao.db.Model(&model.User{}).Where("id = ?", id).Update("status", consts.StatusDisabled).Error
|
||||
}
|
||||
|
||||
// BatchEnableUsers 批量启用User
|
||||
// 批量启用User
|
||||
func (dao *UserDAO) BatchEnable(ids []int64) error {
|
||||
if len(ids) == 0 {
|
||||
return errors.New("ids is empty")
|
||||
@@ -105,7 +145,7 @@ func (dao *UserDAO) BatchEnable(ids []int64) error {
|
||||
return dao.db.Model(&model.User{}).Where("id IN ?", ids).Update("status", 0).Error
|
||||
}
|
||||
|
||||
// BatchDisableUsers 批量禁用User
|
||||
// 批量禁用User
|
||||
func (dao *UserDAO) BatchDisable(ids []int64) error {
|
||||
if len(ids) == 0 {
|
||||
return errors.New("ids is empty")
|
||||
@@ -113,7 +153,7 @@ func (dao *UserDAO) BatchDisable(ids []int64) error {
|
||||
return dao.db.Model(&model.User{}).Where("id IN ?", ids).Update("status", 1).Error
|
||||
}
|
||||
|
||||
// BatchDeleteUser 批量删除用户
|
||||
// 批量删除用户
|
||||
func (dao *UserDAO) BatchDelete(ids []int64) error {
|
||||
if len(ids) == 0 {
|
||||
return errors.New("ids is empty")
|
||||
|
||||
16
team/dashboard/dashboard.go
Normal file
16
team/dashboard/dashboard.go
Normal file
@@ -0,0 +1,16 @@
|
||||
package dashboard
|
||||
|
||||
import "github.com/gin-gonic/gin"
|
||||
|
||||
func HandleTeam(c *gin.Context) {
|
||||
c.JSON(200, gin.H{
|
||||
"code": 200,
|
||||
"data": gin.H{
|
||||
"team": gin.H{
|
||||
"total_users": 10,
|
||||
"total_keys": 20,
|
||||
"total_projects": 30,
|
||||
},
|
||||
},
|
||||
})
|
||||
}
|
||||
20
team/dashboard/login.go
Normal file
20
team/dashboard/login.go
Normal file
@@ -0,0 +1,20 @@
|
||||
package dashboard
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
func HandleLogin(c *gin.Context) {
|
||||
var user map[string]string
|
||||
c.ShouldBind(&user)
|
||||
fmt.Sprintf("%v", user)
|
||||
c.JSON(200, gin.H{
|
||||
"code": 200,
|
||||
"msg": "success",
|
||||
"data": gin.H{
|
||||
"token": "token",
|
||||
},
|
||||
})
|
||||
}
|
||||
22
team/dto/openai/err_resp.go
Normal file
22
team/dto/openai/err_resp.go
Normal file
@@ -0,0 +1,22 @@
|
||||
package dto
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
type Error struct {
|
||||
Message string `json:"message,omitempty"`
|
||||
Code string `json:"code,omitempty"`
|
||||
}
|
||||
|
||||
func WarpErrAsOpenAI(c *gin.Context, msg string, code string) {
|
||||
c.JSON(http.StatusForbidden, gin.H{
|
||||
"error": Error{
|
||||
Message: msg,
|
||||
Code: code,
|
||||
},
|
||||
})
|
||||
return
|
||||
}
|
||||
83
team/dto/team/team.go
Normal file
83
team/dto/team/team.go
Normal file
@@ -0,0 +1,83 @@
|
||||
package dto
|
||||
|
||||
import (
|
||||
"opencatd-open/team/consts"
|
||||
"opencatd-open/team/model"
|
||||
)
|
||||
|
||||
type UserInfo struct {
|
||||
ID int `json:"id"`
|
||||
Name string `json:"name"`
|
||||
Token string `json:"token"`
|
||||
Status *bool `json:"status,omitempty"`
|
||||
}
|
||||
|
||||
func (u UserInfo) HasNameUpdate() bool {
|
||||
return u.Name != ""
|
||||
}
|
||||
|
||||
func (u UserInfo) HasTokenUpdate() bool {
|
||||
return u.Token != ""
|
||||
}
|
||||
|
||||
func (u UserInfo) HasStatusUpdate() bool {
|
||||
return u.Status != nil
|
||||
}
|
||||
|
||||
type KeyInfo struct {
|
||||
ID int `json:"id,omitempty"`
|
||||
Key string `json:"key,omitempty"`
|
||||
Name string `json:"name,omitempty"`
|
||||
ApiType string `json:"api_type,omitempty"`
|
||||
Endpoint string `json:"endpoint,omitempty"`
|
||||
Status *bool `json:"status,omitempty"`
|
||||
}
|
||||
|
||||
// 添加辅助方法判断字段是否需要更新
|
||||
func (k KeyInfo) HasNameUpdate() bool {
|
||||
return k.Name != ""
|
||||
}
|
||||
|
||||
func (k KeyInfo) HasKeyUpdate() bool {
|
||||
return k.Key != ""
|
||||
}
|
||||
|
||||
func (k KeyInfo) HasStatusUpdate() bool {
|
||||
return k.Status != nil
|
||||
}
|
||||
|
||||
func (k KeyInfo) HasApiTypeUpdate() bool {
|
||||
return k.ApiType != ""
|
||||
}
|
||||
|
||||
// 辅助函数:统一处理字段更新
|
||||
func (update *KeyInfo) UpdateFields(existing *model.ApiKey) *model.ApiKey {
|
||||
result := &model.ApiKey{
|
||||
ID: existing.ID,
|
||||
Name: existing.Name, // 默认保持原值
|
||||
ApiType: existing.ApiType, // 默认保持原值
|
||||
ApiKey: existing.ApiKey, // 默认保持原值
|
||||
Status: existing.Status, // 默认保持原值
|
||||
}
|
||||
|
||||
if update.HasNameUpdate() {
|
||||
result.Name = update.Name
|
||||
}
|
||||
if update.HasKeyUpdate() {
|
||||
result.ApiKey = update.Key
|
||||
}
|
||||
if update.HasStatusUpdate() {
|
||||
result.Status = consts.OpenOrClose(*update.Status)
|
||||
}
|
||||
if update.HasApiTypeUpdate() {
|
||||
result.ApiType = update.ApiType
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
type UsageInfo struct {
|
||||
UserId int `json:"userId"`
|
||||
TotalUnit int `json:"totalUnit"`
|
||||
Cost string `json:"cost"`
|
||||
}
|
||||
53
team/handler/team/middleware.go
Normal file
53
team/handler/team/middleware.go
Normal file
@@ -0,0 +1,53 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"opencatd-open/team/consts"
|
||||
|
||||
"github.com/gin-contrib/cors"
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
func (h *TeamHandler) AuthMiddleware() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
if c.Request.URL.Path == "/1/users/init" {
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
authtoken := c.GetHeader("Authorization")
|
||||
if authtoken == "" || len(authtoken) <= 7 || authtoken[:7] != "Bearer " {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{"error": "Unauthorized"})
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
authtoken = authtoken[7:]
|
||||
token, err := h.tokenService.GetByKey(c, authtoken)
|
||||
if err != nil {
|
||||
fmt.Println(err)
|
||||
}
|
||||
if token.Name != "default" {
|
||||
c.JSON(http.StatusForbidden, gin.H{"error": "only default token can access"})
|
||||
c.Abort()
|
||||
}
|
||||
if token.User.Status != consts.StatusEnabled {
|
||||
c.JSON(http.StatusForbidden, gin.H{"error": "user is disabled"})
|
||||
c.Abort()
|
||||
}
|
||||
c.Set("local_user", true)
|
||||
c.Set("token", token)
|
||||
|
||||
// 可以在这里对 token 进行验证并检查权限
|
||||
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
func CORS() gin.HandlerFunc {
|
||||
config := cors.DefaultConfig()
|
||||
config.AllowAllOrigins = true
|
||||
config.AllowCredentials = true
|
||||
config.AllowMethods = []string{"GET", "POST", "PUT", "DELETE", "OPTIONS"}
|
||||
config.AllowHeaders = []string{"*"}
|
||||
return cors.New(config)
|
||||
}
|
||||
567
team/handler/team/team.go
Normal file
567
team/handler/team/team.go
Normal file
@@ -0,0 +1,567 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"opencatd-open/internal/utils"
|
||||
"opencatd-open/team/consts"
|
||||
dto "opencatd-open/team/dto/team"
|
||||
"opencatd-open/team/model"
|
||||
"opencatd-open/team/service"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/google/uuid"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
type TeamHandler struct {
|
||||
db *gorm.DB
|
||||
userService service.UserService
|
||||
tokenService service.TokenService
|
||||
keyService service.ApiKeyService
|
||||
usageService service.UsageService
|
||||
}
|
||||
|
||||
func NewTeamHandler(userService service.UserService, tokenService service.TokenService, keyService service.ApiKeyService, usageService service.UsageService) *TeamHandler {
|
||||
return &TeamHandler{
|
||||
userService: userService,
|
||||
tokenService: tokenService,
|
||||
keyService: keyService,
|
||||
usageService: usageService,
|
||||
}
|
||||
}
|
||||
|
||||
// initadmin
|
||||
func (h *TeamHandler) InitAdmin(c *gin.Context) {
|
||||
admin, err := h.userService.GetUser(c, 1)
|
||||
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
user := &model.User{
|
||||
Name: "root",
|
||||
Username: "root",
|
||||
Password: "openteam",
|
||||
Role: int(consts.RoleSuperAdmin),
|
||||
Tokens: []model.Token{
|
||||
{
|
||||
Name: "default",
|
||||
Key: "team-" + strings.ReplaceAll(uuid.New().String(), "-", ""),
|
||||
UnlimitedQuota: true,
|
||||
},
|
||||
},
|
||||
}
|
||||
if err := h.userService.CreateUser(c, user); err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
var result = dto.UserInfo{
|
||||
ID: int(user.ID),
|
||||
Name: user.Name,
|
||||
Token: user.Tokens[0].Key,
|
||||
Status: utils.ToPtr(user.Status == consts.StatusEnabled),
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, result)
|
||||
return
|
||||
} else {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
}
|
||||
if admin != nil {
|
||||
c.JSON(http.StatusForbidden, gin.H{
|
||||
"error": "super user already exists, use cli to reset password",
|
||||
})
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
func (h *TeamHandler) Me(c *gin.Context) {
|
||||
// token := c.GetHeader("Authorization")
|
||||
// token = strings.TrimPrefix(token, "Bearer ")
|
||||
// userToken, err := h.tokenService.GetTokenByKey(token)
|
||||
// if err != nil {
|
||||
// c.JSON(http.StatusUnauthorized, gin.H{"error": err.Error()})
|
||||
// return
|
||||
// }
|
||||
// if userToken.ID != 1 {
|
||||
// c.JSON(http.StatusForbidden, gin.H{"error": "only first user token can access"})
|
||||
// return
|
||||
// }
|
||||
token, exists := c.Get("token")
|
||||
if !exists {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": "token not found"})
|
||||
return
|
||||
}
|
||||
userToken := token.(*model.Token)
|
||||
|
||||
c.JSON(http.StatusOK, dto.UserInfo{
|
||||
ID: int(userToken.UserID),
|
||||
Name: userToken.User.Name,
|
||||
Token: userToken.Key,
|
||||
Status: utils.ToPtr(userToken.User.Status == consts.StatusEnabled),
|
||||
})
|
||||
|
||||
}
|
||||
|
||||
// CreateUser 创建用户
|
||||
func (h *TeamHandler) CreateUser(c *gin.Context) {
|
||||
var userReq dto.UserInfo
|
||||
if err := c.ShouldBindJSON(&userReq); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid input"})
|
||||
return
|
||||
}
|
||||
|
||||
token, exists := c.Get("token")
|
||||
if !exists {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": "Unauthorized"})
|
||||
return
|
||||
}
|
||||
userToken := token.(*model.Token)
|
||||
if userToken.User.Role < int(consts.RoleAdmin) {
|
||||
create := &model.Token{
|
||||
Name: userReq.Name,
|
||||
Key: "team-" + strings.ReplaceAll(uuid.New().String(), "-", ""),
|
||||
}
|
||||
if err := h.tokenService.Create(c.Request.Context(), create); err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
} else {
|
||||
user := &model.User{
|
||||
Name: userReq.Name,
|
||||
Username: userReq.Name,
|
||||
Role: int(consts.RoleUser),
|
||||
Tokens: []model.Token{
|
||||
{
|
||||
Name: "default",
|
||||
Key: "team-" + strings.ReplaceAll(uuid.New().String(), "-", ""),
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
// 默认角色为普通用户
|
||||
if err := h.userService.CreateUser(c.Request.Context(), user); err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{"message": "ok"})
|
||||
}
|
||||
|
||||
// GetUser 获取用户信息
|
||||
func (h *TeamHandler) GetUser(c *gin.Context) {
|
||||
idStr := c.Param("id")
|
||||
id, err := strconv.ParseInt(idStr, 10, 64)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid user id"})
|
||||
return
|
||||
}
|
||||
|
||||
user, err := h.userService.GetUser(c.Request.Context(), id)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, user)
|
||||
}
|
||||
|
||||
// UpdateUser 更新用户信息
|
||||
func (h *TeamHandler) UpdateUser(c *gin.Context) {
|
||||
var user model.User
|
||||
if err := c.ShouldBindJSON(&user); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid input"})
|
||||
return
|
||||
}
|
||||
token, exists := c.Get("token")
|
||||
if !exists {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": "Unauthorized"})
|
||||
return
|
||||
}
|
||||
userToken := token.(*model.Token)
|
||||
|
||||
operatorID := userToken.UserID // 假设从上下文中获取操作者ID
|
||||
if err := h.userService.UpdateUser(c.Request.Context(), &user, operatorID); err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{"message": "ok"})
|
||||
}
|
||||
|
||||
// DeleteUser 删除用户
|
||||
func (h *TeamHandler) DeleteUser(c *gin.Context) {
|
||||
idStr := c.Param("id")
|
||||
id, err := strconv.ParseInt(idStr, 10, 64)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid user id"})
|
||||
return
|
||||
}
|
||||
|
||||
token, exists := c.Get("token")
|
||||
if !exists {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": "Unauthorized"})
|
||||
return
|
||||
}
|
||||
userToken := token.(*model.Token)
|
||||
|
||||
if userToken.User.Role < int(consts.RoleAdmin) { // 用户只能删除自己的token
|
||||
err := h.tokenService.Delete(c.Request.Context(), int(id))
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
} else {
|
||||
if err := h.userService.DeleteUser(c.Request.Context(), id, userToken.UserID); err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{"message": "ok"})
|
||||
}
|
||||
|
||||
func (h *TeamHandler) ListUsages(c *gin.Context) {
|
||||
fromStr := c.Query("from")
|
||||
toStr := c.Query("to")
|
||||
|
||||
var from, to time.Time
|
||||
loc, _ := time.LoadLocation("Local")
|
||||
|
||||
var listUsage []*dto.UsageInfo
|
||||
var err error
|
||||
|
||||
if fromStr != "" && toStr != "" {
|
||||
|
||||
from, err = time.Parse("2006-01-02", fromStr)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid from date"})
|
||||
return
|
||||
}
|
||||
to, err = time.Parse("2006-01-02", toStr)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid to date"})
|
||||
return
|
||||
}
|
||||
} else {
|
||||
year, month, _ := time.Now().In(loc).Date()
|
||||
from = time.Date(year, month, 1, 0, 0, 0, 0, loc)
|
||||
to = from.AddDate(0, 1, 0)
|
||||
}
|
||||
|
||||
token, _ := c.Get("token")
|
||||
userToken := token.(*model.Token)
|
||||
if userToken.User.Role < int(consts.RoleAdmin) {
|
||||
listUsage, err = h.usageService.ListByDateRange(c.Request.Context(), from, to, map[string]interface{}{"user_id": userToken.UserID})
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
} else {
|
||||
listUsage, err = h.usageService.ListByDateRange(c.Request.Context(), from, to, nil)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, listUsage)
|
||||
|
||||
}
|
||||
|
||||
// ListUsers 获取用户列表
|
||||
func (h *TeamHandler) ListUsers(c *gin.Context) {
|
||||
pageStr := c.DefaultQuery("page", "1")
|
||||
pageSizeStr := c.DefaultQuery("pageSize", "100")
|
||||
|
||||
page, err := strconv.Atoi(pageStr)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid page number"})
|
||||
return
|
||||
}
|
||||
|
||||
pageSize, err := strconv.Atoi(pageSizeStr)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid page size"})
|
||||
return
|
||||
}
|
||||
token, exists := c.Get("token")
|
||||
if !exists {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": "Unauthorized"})
|
||||
return
|
||||
}
|
||||
userToken := token.(*model.Token)
|
||||
if userToken.User.Role < int(consts.RoleAdmin) {
|
||||
tokens, _, err := h.tokenService.ListsWithFilters(c, 0, 100, map[string]interface{}{"user_id": userToken.UserID})
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
var userDTOs []dto.UserInfo
|
||||
for _, token := range tokens {
|
||||
userDTOs = append(userDTOs, dto.UserInfo{
|
||||
ID: int(token.User.ID),
|
||||
Name: token.User.Name,
|
||||
Token: token.Key,
|
||||
Status: utils.ToPtr(token.User.Status == consts.StatusEnabled),
|
||||
})
|
||||
}
|
||||
c.JSON(http.StatusOK, userDTOs)
|
||||
return
|
||||
}
|
||||
|
||||
users, _, err := h.userService.ListUsers(c.Request.Context(), page, pageSize)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
var userDTOs []dto.UserInfo
|
||||
for _, user := range users {
|
||||
useres := dto.UserInfo{
|
||||
ID: int(user.ID),
|
||||
Name: user.Name,
|
||||
|
||||
Status: utils.ToPtr(user.Status == consts.StatusEnabled),
|
||||
}
|
||||
if len(user.Tokens) > 0 {
|
||||
useres.Token = user.Tokens[0].Key
|
||||
}
|
||||
userDTOs = append(userDTOs, useres)
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, userDTOs)
|
||||
}
|
||||
|
||||
func (h *TeamHandler) ResetUserToken(c *gin.Context) {
|
||||
idstr := c.Param("id")
|
||||
id, err := strconv.Atoi(idstr)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid user id"})
|
||||
return
|
||||
}
|
||||
token, exists := c.Get("token")
|
||||
if !exists {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": "Unauthorized"})
|
||||
return
|
||||
}
|
||||
userToken := token.(*model.Token)
|
||||
|
||||
findtoken, err := h.tokenService.GetByUserID(c, id)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
findtoken.Key = "team-" + strings.ReplaceAll(uuid.New().String(), "-", "")
|
||||
|
||||
if userToken.User.Role < int(consts.RoleAdmin) { // 非管理员只能修改自己的token
|
||||
if userToken.User.Role <= findtoken.User.Role || userToken.UserID != findtoken.UserID {
|
||||
c.JSON(http.StatusForbidden, gin.H{"error": "forbidden"})
|
||||
return
|
||||
}
|
||||
err := h.tokenService.UpdateWithCondition(c, findtoken, map[string]interface{}{"user_id": userToken.UserID}, nil)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
} else {
|
||||
if err := h.tokenService.Update(c, findtoken); err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, dto.UserInfo{
|
||||
ID: int(findtoken.User.ID),
|
||||
Name: findtoken.User.Name,
|
||||
Token: findtoken.Key,
|
||||
})
|
||||
}
|
||||
|
||||
func (h *TeamHandler) CreateKey(c *gin.Context) {
|
||||
token, exists := c.Get("token")
|
||||
if !exists {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": "token not found"})
|
||||
return
|
||||
}
|
||||
userToken := token.(*model.Token)
|
||||
if userToken.User.Role < int(consts.RoleAdmin) {
|
||||
c.JSON(http.StatusForbidden, gin.H{"error": "forbidden"})
|
||||
return
|
||||
}
|
||||
|
||||
var key dto.KeyInfo
|
||||
if err := c.ShouldBindJSON(&key); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
err := h.keyService.Create(&model.ApiKey{
|
||||
Name: key.Name,
|
||||
ApiType: key.ApiType,
|
||||
ApiKey: key.Key,
|
||||
Endpoint: key.Endpoint,
|
||||
})
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, key)
|
||||
}
|
||||
|
||||
func (h *TeamHandler) ListKeys(c *gin.Context) {
|
||||
keys, err := h.keyService.List(0, 100, nil)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
}
|
||||
|
||||
var keysDTO []dto.KeyInfo
|
||||
for _, key := range keys {
|
||||
keylength := len(key.ApiKey) / 3
|
||||
if keylength < 1 {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid key length"})
|
||||
return
|
||||
}
|
||||
keysDTO = append(keysDTO, dto.KeyInfo{
|
||||
ID: int(key.ID),
|
||||
Name: key.Name,
|
||||
ApiType: key.ApiType,
|
||||
Endpoint: key.Endpoint,
|
||||
Key: key.ApiKey[:keylength] + "****" + key.ApiKey[len(key.ApiKey)-keylength:],
|
||||
})
|
||||
}
|
||||
c.JSON(http.StatusOK, keysDTO)
|
||||
}
|
||||
|
||||
func (h *TeamHandler) UpdateKey(c *gin.Context) {
|
||||
// 1. 获取并验证ID
|
||||
id, err := strconv.ParseInt(c.Param("id"), 10, 64)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid key id"})
|
||||
return
|
||||
}
|
||||
|
||||
// 2. 解析请求体
|
||||
var updateKey dto.KeyInfo // 更明确的命名
|
||||
if err := c.ShouldBindJSON(&updateKey); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
// 3. 获取现有记录
|
||||
existingKey, err := h.keyService.GetByID(id)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
// 4. 使用 UpdateFields 方法统一处理字段更新
|
||||
updatedKey := updateKey.UpdateFields(existingKey)
|
||||
|
||||
// 5. 保存更新
|
||||
if err := h.keyService.Update(updatedKey); err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, updatedKey)
|
||||
}
|
||||
|
||||
func (h *TeamHandler) DeleteKey(c *gin.Context) {
|
||||
// 1. 获取并验证ID
|
||||
id, err := strconv.ParseInt(c.Param("id"), 10, 64)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid key id"})
|
||||
return
|
||||
}
|
||||
|
||||
// 2. 删除记录
|
||||
if err := h.keyService.Delete(id); err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{"message": "ok"})
|
||||
}
|
||||
|
||||
// ChangePassword 修改密码
|
||||
func (h *TeamHandler) ChangePassword(c *gin.Context) {
|
||||
userID := c.GetInt64("userID") // 假设从上下文中获取用户ID
|
||||
|
||||
var req struct {
|
||||
OldPassword string `json:"oldPassword"`
|
||||
NewPassword string `json:"newPassword"`
|
||||
}
|
||||
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid input"})
|
||||
return
|
||||
}
|
||||
|
||||
if err := h.userService.ChangePassword(c.Request.Context(), userID, req.OldPassword, req.NewPassword); err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{"message": "ok"})
|
||||
}
|
||||
|
||||
// ResetPassword 重置密码
|
||||
func (h *TeamHandler) ResetPassword(c *gin.Context) {
|
||||
idStr := c.Param("id")
|
||||
id, err := strconv.ParseInt(idStr, 10, 64)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid user id"})
|
||||
return
|
||||
}
|
||||
|
||||
operatorID := c.GetInt64("userID") // 假设从上下文中获取操作者ID
|
||||
if err := h.userService.ResetPassword(c.Request.Context(), id, operatorID); err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{"message": "password reset successfully"})
|
||||
}
|
||||
|
||||
// EnableUser 启用用户
|
||||
func (h *TeamHandler) EnableUser(c *gin.Context) {
|
||||
idStr := c.Param("id")
|
||||
id, err := strconv.ParseInt(idStr, 10, 64)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid user id"})
|
||||
return
|
||||
}
|
||||
|
||||
operatorID := c.GetInt64("userID") // 假设从上下文中获取操作者ID
|
||||
if err := h.userService.EnableUser(c.Request.Context(), id, operatorID); err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{"message": "user enabled successfully"})
|
||||
}
|
||||
|
||||
// DisableUser 禁用用户
|
||||
func (h *TeamHandler) DisableUser(c *gin.Context) {
|
||||
idStr := c.Param("id")
|
||||
id, err := strconv.ParseInt(idStr, 10, 64)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid user id"})
|
||||
return
|
||||
}
|
||||
|
||||
operatorID := c.GetInt64("userID") // 假设从上下文中获取操作者ID
|
||||
if err := h.userService.DisableUser(c.Request.Context(), id, operatorID); err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{"message": "user disabled successfully"})
|
||||
}
|
||||
@@ -1,19 +0,0 @@
|
||||
package model
|
||||
|
||||
// 用户的token
|
||||
type Token struct {
|
||||
Id int64 `gorm:"column:id;primaryKey;autoIncrement"`
|
||||
UserId int64 `gorm:"column:user_id;not null;index:idx_token_user_id"`
|
||||
Name string `gorm:"column:name;index:idx_token_name"`
|
||||
Key string `gorm:"column:key;type:char(48);uniqueIndex:idx_token_key"`
|
||||
Status bool `gorm:"column:status;default:true"` // enabled 0, disabled 1
|
||||
Quota int64 `gorm:"column:quota;type:bigint;default:0"` // -1 means unlimited
|
||||
UnlimitedQuota bool `gorm:"column:unlimited_quota;default:true"`
|
||||
UsedQuota int64 `gorm:"column:used_quota;type:bigint;default:0"`
|
||||
CreatedTime int64 `gorm:"column:created_time;type:bigint"`
|
||||
ExpiredTime int64 `gorm:"column:expired_time;type:bigint;default:-1"` // -1 means never expired
|
||||
}
|
||||
|
||||
func (Token) TableName() string {
|
||||
return "token"
|
||||
}
|
||||
@@ -1,26 +1,22 @@
|
||||
package model
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/lib/pq"
|
||||
)
|
||||
import "github.com/lib/pq" //pq.StringArray
|
||||
|
||||
type ApiKey_PG struct {
|
||||
ID int64 `gorm:"column:id;primaryKey;autoIncrement"`
|
||||
Name string `gorm:"column:name;not null;unique;index:idx_apikey_name"`
|
||||
ApiType string `gorm:"column:apitype;not null;unique;index:idx_apikey_apitype"`
|
||||
ApiKey string `gorm:"column:apikey;not null;unique;index:idx_apikey_apikey"`
|
||||
Status int `gorm:"type:int;default:0"` // enabled 0, disabled 1
|
||||
Endpoint string `gorm:"column:endpoint"`
|
||||
ResourceNmae string `gorm:"column:resource_name"`
|
||||
DeploymentName string `gorm:"column:deployment_name"`
|
||||
ApiKey string `gorm:"column:apikey;not null;unique;uniqueIndex:idx_apikey"`
|
||||
Status int `gorm:"type:int;default:1"` // enabled 1, disabled 0
|
||||
Endpoint string `gorm:"column:endpoint;comment:接入点"`
|
||||
ResourceNmae string `gorm:"column:resource_name;comment:azure资源名称"`
|
||||
DeploymentName string `gorm:"column:deployment_name;comment:azure部署名称"`
|
||||
ApiSecret string `gorm:"column:api_secret"`
|
||||
ModelPrefix string `gorm:"column:model_prefix"`
|
||||
ModelAlias string `gorm:"column:model_alias"`
|
||||
SupportModels pq.StringArray `gorm:"type:text[]"`
|
||||
UpdatedAt time.Time `json:"updatedAt,omitempty"`
|
||||
CreatedAt time.Time `json:"createdAt,omitempty"`
|
||||
ModelPrefix string `gorm:"column:model_prefix;comment:模型前缀"`
|
||||
ModelAlias string `gorm:"column:model_alias;comment:模型别名"`
|
||||
SupportModels pq.StringArray `gorm:"column:support_models;type:text[]"`
|
||||
CreatedAt int64 `gorm:"column:created_at;autoUpdateTime" json:"created_at,omitempty"`
|
||||
UpdatedAt int64 `gorm:"column:updated_at;autoCreateTime" json:"updated_at,omitempty"`
|
||||
}
|
||||
|
||||
func (ApiKey_PG) TableName() string {
|
||||
@@ -28,20 +24,20 @@ func (ApiKey_PG) TableName() string {
|
||||
}
|
||||
|
||||
type ApiKey struct {
|
||||
ID int64 `gorm:"column:id;primaryKey;autoIncrement"`
|
||||
Name string `gorm:"column:name;not null;unique;index:idx_apikey_name"`
|
||||
ApiType string `gorm:"column:apitype;not null;unique;index:idx_apikey_apitype"`
|
||||
ApiKey string `gorm:"column:apikey;not null;unique;index:idx_apikey_apikey"`
|
||||
Status int `json:"status" gorm:"type:int;default:0"` // enabled 0, disabled 1
|
||||
Endpoint string `gorm:"column:endpoint"`
|
||||
ResourceNmae string `gorm:"column:resource_name"`
|
||||
DeploymentName string `gorm:"column:deployment_name"`
|
||||
ApiSecret string `gorm:"column:api_secret"`
|
||||
ModelPrefix string `gorm:"column:model_prefix"`
|
||||
ModelAlias string `gorm:"column:model_alias"`
|
||||
SupportModels []string `gorm:"type:json"`
|
||||
CreatedAt time.Time `json:"created_at,omitempty" gorm:"autoUpdateTime"`
|
||||
UpdatedAt time.Time `json:"updated_at,omitempty" gorm:"autoCreateTime"`
|
||||
ID int64 `gorm:"column:id;primaryKey;autoIncrement"`
|
||||
Name string `gorm:"column:name;not null;unique;index:idx_apikey_name"`
|
||||
ApiType string `gorm:"column:apitype;not null;unique;index:idx_apikey_apitype"`
|
||||
ApiKey string `gorm:"column:apikey;not null;unique;index:idx_apikey_apikey"`
|
||||
Status int `gorm:"type:int;default:1"` // enabled 1, disabled 0
|
||||
Endpoint string `gorm:"column:endpoint"`
|
||||
ResourceNmae string `gorm:"column:resource_name"`
|
||||
DeploymentName string `gorm:"column:deployment_name"`
|
||||
ApiSecret string `gorm:"column:api_secret"`
|
||||
ModelPrefix string `gorm:"column:model_prefix"`
|
||||
ModelAlias string `gorm:"column:model_alias"`
|
||||
SupportModels []string `gorm:"column:support_models;type:json"`
|
||||
CreatedAt int64 `gorm:"column:created_at;autoUpdateTime" json:"created_at,omitempty"`
|
||||
UpdatedAt int64 `gorm:"column:updated_at;autoCreateTime" json:"updated_at,omitempty"`
|
||||
}
|
||||
|
||||
func (ApiKey) TableName() string {
|
||||
|
||||
20
team/model/token.go
Normal file
20
team/model/token.go
Normal file
@@ -0,0 +1,20 @@
|
||||
package model
|
||||
|
||||
// 用户的token
|
||||
type Token struct {
|
||||
ID int64 `gorm:"column:id;primaryKey;autoIncrement"`
|
||||
UserID int64 `gorm:"column:user_id;not null;index:idx_token_user_id"`
|
||||
Name string `gorm:"column:name;index:idx_token_name"`
|
||||
Key string `gorm:"column:key;not null;uniqueIndex:idx_token_key;comment:token key"`
|
||||
Status int64 `gorm:"column:status;default:1;check:status IN (0,1)"` // enabled 1, disabled 0
|
||||
Quota int64 `gorm:"column:quota;type:bigint;default:0"` // default 0
|
||||
UnlimitedQuota bool `gorm:"column:unlimited_quota;default:true"` // set Quota 1 unlimited
|
||||
UsedQuota int64 `gorm:"column:used_quota;type:bigint;default:0"`
|
||||
CreatedAt int64 `gorm:"column:created_at;type:bigint;autoCreateTime"`
|
||||
ExpiredAt int64 `gorm:"column:expired_at;type:bigint;default:-1"` // -1 means never expired
|
||||
User User `gorm:"foreignKey:UserID;constraint:OnUpdate:CASCADE,OnDelete:CASCADE" json:"user"`
|
||||
}
|
||||
|
||||
func (Token) TableName() string {
|
||||
return "tokens"
|
||||
}
|
||||
@@ -8,40 +8,42 @@ import (
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
type DailyUsage struct {
|
||||
type Usage struct {
|
||||
ID int64 `gorm:"column:id;primaryKey;autoIncrement"`
|
||||
UserID int64 `gorm:"column:user_id;index:idx_user_id"`
|
||||
TokenId int64 `gorm:"column:token_id;index:idx_token_id"`
|
||||
Capability string `gorm:"column:capability;index:idx_capability;comment:模型能力"`
|
||||
TokenID int64 `gorm:"column:token_id;index:idx_token_id"`
|
||||
Capability string `gorm:"column:capability;index:idx_usage_capability;comment:模型能力"`
|
||||
Date time.Time `gorm:"column:date;autoCreateTime;index:idx_date"`
|
||||
Model string `gorm:"column:model"`
|
||||
Stream bool `gorm:"column:stream"`
|
||||
PromptTokens int `gorm:"column:prompt_tokens"`
|
||||
CompletionTokens int `gorm:"column:completion_tokens"`
|
||||
TotalTokens int `gorm:"column:total_tokens"`
|
||||
PromptTokens float64 `gorm:"column:prompt_tokens"`
|
||||
CompletionTokens float64 `gorm:"column:completion_tokens"`
|
||||
TotalTokens float64 `gorm:"column:total_tokens"`
|
||||
Cost string `gorm:"column:cost"`
|
||||
CreatedAt time.Time `gorm:"column:created_at;autoCreateTime"`
|
||||
}
|
||||
|
||||
func (DailyUsage) TableName() string {
|
||||
return "daily_usages"
|
||||
}
|
||||
|
||||
type Usage struct {
|
||||
ID int `gorm:"column:id"`
|
||||
UserID int `gorm:"column:user_id"`
|
||||
SKU string `gorm:"column:sku"`
|
||||
PromptUnits int `gorm:"column:prompt_units"`
|
||||
CompletionUnits int `gorm:"column:completion_units"`
|
||||
TotalUnit int `gorm:"column:total_unit"`
|
||||
Cost string `gorm:"column:cost"`
|
||||
Date time.Time `gorm:"column:date"`
|
||||
}
|
||||
|
||||
func (Usage) TableName() string {
|
||||
return "usages"
|
||||
}
|
||||
|
||||
type DailyUsage struct {
|
||||
ID int64 `gorm:"column:id;primaryKey;autoIncrement"`
|
||||
UserID int64 `gorm:"column:user_id;uniqueIndex:idx_daily_unique,priority:1"`
|
||||
TokenID int64 `gorm:"column:token_id;index:idx_daily_token_id"`
|
||||
Capability string `gorm:"column:capability;uniqueIndex:idx_daily_unique,priority:2;comment:模型能力"`
|
||||
Date time.Time `gorm:"column:date;autoCreateTime;uniqueIndex:idx_daily_unique,priority:3"`
|
||||
Model string `gorm:"column:model"`
|
||||
Stream bool `gorm:"column:stream"`
|
||||
PromptTokens float64 `gorm:"column:prompt_tokens"`
|
||||
CompletionTokens float64 `gorm:"column:completion_tokens"`
|
||||
TotalTokens float64 `gorm:"column:total_tokens"`
|
||||
Cost string `gorm:"column:cost"`
|
||||
}
|
||||
|
||||
func (DailyUsage) TableName() string {
|
||||
return "daily_usages"
|
||||
}
|
||||
|
||||
func HandleUsage(c *gin.Context) {
|
||||
fromStr := c.Query("from")
|
||||
toStr := c.Query("to")
|
||||
|
||||
@@ -1,99 +1,34 @@
|
||||
package model
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"opencatd-open/store"
|
||||
"time"
|
||||
|
||||
"github.com/Sakurasan/to"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
type User struct {
|
||||
ID int64 `json:"id" gorm:"primaryKey;autoIncrement"`
|
||||
Name string `json:"name" gorm:"unique;index" validate:"max=12"`
|
||||
UserName string `json:"username" gorm:"unique;index" validate:"max=12"`
|
||||
Password string `json:"password" gorm:"not null;" validate:"min=8,max=20"`
|
||||
Role int `json:"role" gorm:"type:int;default:1"` // default user
|
||||
Status int `json:"status" gorm:"type:int;default:0"` // enabled 0, disabled 1 deleted 2
|
||||
Nickname string `json:"nickname" gorm:"type:varchar(50)"`
|
||||
AvatarURL string `json:"avatar_url" gorm:"type:varchar(255)"`
|
||||
Email string `json:"email" gorm:"type:varchar(255);unique;index"`
|
||||
Quota int64 `json:"quota" gorm:"bigint;default:-1"` // default unlimited
|
||||
Token string `json:"token,omitempty"`
|
||||
Timezone string `json:"timezone" gorm:"type:varchar(50)"`
|
||||
Language string `json:"language" gorm:"type:varchar(50)"`
|
||||
ID int64 `json:"id" gorm:"column:id;primaryKey;autoIncrement"`
|
||||
Name string `json:"name" gorm:"column:name;not null;unique;index"`
|
||||
Username string `json:"username" gorm:"column:username;unique;index"`
|
||||
Password string `json:"password" gorm:"column:password;"`
|
||||
Role int `json:"role" gorm:"column:role;type:int;default:0"` // default user 0-10-20
|
||||
Status int `json:"status" gorm:"column:status;type:int;default:1"` // disabled 0, enabled 1, deleted 2
|
||||
Nickname string `json:"nickname" gorm:"column:nickname;type:varchar(50)"`
|
||||
AvatarURL string `json:"avatar_url" gorm:"column:avatar_url;type:varchar(255)"`
|
||||
Email string `json:"email" gorm:"column:email;type:varchar(255);index"`
|
||||
Quota int64 `json:"quota" gorm:"column:quota;bigint;default:0"` // default unlimited
|
||||
UnlimitedQuota int `json:"unlimited_quota" gorm:"column:unlimited_quota;default:1;check:(unlimited_quota IN (0,1))"` // 0 limited , 1 unlimited
|
||||
Timezone string `json:"timezone" gorm:"column:timezone;type:varchar(50)"`
|
||||
Language string `json:"language" gorm:"column:language;type:varchar(50)"`
|
||||
|
||||
CreatedAt time.Time `json:"created_at,omitempty" gorm:"autoCreateTime"`
|
||||
UpdatedAt time.Time `json:"updated_at,omitempty" gorm:"autoUpdateTime"`
|
||||
// 添加一对多关系
|
||||
// Token string `json:"-" gorm:"column:token;type:varchar(64);unique;index"`
|
||||
Tokens []Token `json:"-" gorm:"foreignKey:UserID;references:ID;constraint:OnUpdate:CASCADE,OnDelete:CASCADE"`
|
||||
|
||||
CreatedAt int64 `json:"created_at,omitempty" gorm:"autoCreateTime"`
|
||||
UpdatedAt int64 `json:"updated_at,omitempty" gorm:"autoUpdateTime"`
|
||||
}
|
||||
|
||||
func HandleUsers(c *gin.Context) {
|
||||
users, err := store.GetAllUsers()
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"error": err.Error(),
|
||||
})
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, users)
|
||||
}
|
||||
|
||||
func HandleAddUser(c *gin.Context) {
|
||||
var body User
|
||||
if err := c.BindJSON(&body); err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
if len(body.Name) == 0 {
|
||||
c.JSON(http.StatusOK, gin.H{"error": "invalid user name"})
|
||||
return
|
||||
}
|
||||
|
||||
if err := store.AddUser(body.Name, uuid.NewString()); err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
u, err := store.GetUserByName(body.Name)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, u)
|
||||
}
|
||||
|
||||
func HandleDelUser(c *gin.Context) {
|
||||
id := to.Int(c.Param("id"))
|
||||
if id <= 1 {
|
||||
c.JSON(http.StatusOK, gin.H{"error": "invalid user id"})
|
||||
return
|
||||
}
|
||||
if err := store.DeleteUser(uint(id)); err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{"message": "ok"})
|
||||
}
|
||||
|
||||
func HandleResetUserToken(c *gin.Context) {
|
||||
id := to.Int(c.Param("id"))
|
||||
newtoken := c.Query("token")
|
||||
if newtoken == "" {
|
||||
newtoken = uuid.NewString()
|
||||
}
|
||||
|
||||
if err := store.UpdateUser(uint(id), newtoken); err != nil {
|
||||
c.JSON(http.StatusForbidden, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
u, err := store.GetUserByID(uint(id))
|
||||
if err != nil {
|
||||
c.JSON(http.StatusForbidden, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, u)
|
||||
func (User) TableName() string {
|
||||
return "users"
|
||||
}
|
||||
|
||||
type Session struct {
|
||||
@@ -104,8 +39,9 @@ type Session struct {
|
||||
DeviceName string `json:"device_name" gorm:"type:varchar(100);default:''"`
|
||||
LastActiveAt time.Time `json:"last_active_at" gorm:"type:timestamp;default:CURRENT_TIMESTAMP"`
|
||||
LogoutAt time.Time `json:"logout_at" gorm:"type:timestamp;null"`
|
||||
CreatedAt time.Time `json:"created_at" gorm:"type:timestamp;not null;default:CURRENT_TIMESTAMP"`
|
||||
UpdatedAt time.Time `json:"updated_at" gorm:"type:timestamp;not null;default:CURRENT_TIMESTAMP;update:CURRENT_TIMESTAMP"`
|
||||
|
||||
CreatedAt time.Time `json:"created_at" gorm:"type:timestamp;not null;default:CURRENT_TIMESTAMP"`
|
||||
UpdatedAt time.Time `json:"updated_at" gorm:"type:timestamp;not null;default:CURRENT_TIMESTAMP;update:CURRENT_TIMESTAMP"`
|
||||
}
|
||||
|
||||
func (Session) TableName() string {
|
||||
|
||||
152
team/service/apikey.go
Normal file
152
team/service/apikey.go
Normal file
@@ -0,0 +1,152 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"opencatd-open/team/dao"
|
||||
"opencatd-open/team/model"
|
||||
"time"
|
||||
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
var _ ApiKeyService = (*ApiKeyServiceImpl)(nil)
|
||||
|
||||
type ApiKeyService interface {
|
||||
Create(apiKey *model.ApiKey) error
|
||||
GetByID(id int64) (*model.ApiKey, error)
|
||||
GetByName(name string) (*model.ApiKey, error)
|
||||
GetByApiKey(apiKeyValue string) (*model.ApiKey, error)
|
||||
Update(apiKey *model.ApiKey) error
|
||||
Delete(id int64) error
|
||||
List(offset, limit int, status *int) ([]model.ApiKey, error)
|
||||
ListWithFilters(offset, limit int, filters map[string]interface{}) ([]model.ApiKey, int64, error)
|
||||
Enable(id int64) error
|
||||
Disable(id int64) error
|
||||
BatchEnable(ids []int64) error
|
||||
BatchDisable(ids []int64) error
|
||||
BatchDelete(ids []int64) error
|
||||
Count() (int64, error)
|
||||
}
|
||||
|
||||
type ApiKeyServiceImpl struct {
|
||||
apiKeyRepo dao.ApiKeyRepository
|
||||
db *gorm.DB
|
||||
}
|
||||
|
||||
func NewApiKeyService(apiKeyDao dao.ApiKeyRepository, db *gorm.DB) ApiKeyService {
|
||||
return &ApiKeyServiceImpl{apiKeyRepo: apiKeyDao, db: db}
|
||||
}
|
||||
|
||||
func (s *ApiKeyServiceImpl) Create(apiKey *model.ApiKey) error {
|
||||
if apiKey == nil {
|
||||
return errors.New("apiKey不能为空")
|
||||
}
|
||||
if apiKey.Name == "" {
|
||||
return errors.New("apiKey名称不能为空")
|
||||
}
|
||||
if apiKey.ApiKey == "" {
|
||||
return errors.New("apiKey值不能为空")
|
||||
}
|
||||
apiKey.CreatedAt = time.Now().Unix()
|
||||
apiKey.UpdatedAt = time.Now().Unix()
|
||||
|
||||
return s.apiKeyRepo.Create(apiKey)
|
||||
}
|
||||
|
||||
func (s *ApiKeyServiceImpl) GetByID(id int64) (*model.ApiKey, error) {
|
||||
if id <= 0 {
|
||||
return nil, errors.New("id 必须大于 0")
|
||||
}
|
||||
return s.apiKeyRepo.GetByID(id)
|
||||
}
|
||||
|
||||
func (s *ApiKeyServiceImpl) GetByName(name string) (*model.ApiKey, error) {
|
||||
if name == "" {
|
||||
return nil, errors.New("name 不能为空")
|
||||
}
|
||||
return s.apiKeyRepo.GetByName(name)
|
||||
}
|
||||
|
||||
func (s *ApiKeyServiceImpl) GetByApiKey(apiKeyValue string) (*model.ApiKey, error) {
|
||||
if apiKeyValue == "" {
|
||||
return nil, errors.New("apiKeyValue 不能为空")
|
||||
}
|
||||
return s.apiKeyRepo.GetByApiKey(apiKeyValue)
|
||||
}
|
||||
|
||||
func (s *ApiKeyServiceImpl) Update(apiKey *model.ApiKey) error {
|
||||
if apiKey == nil {
|
||||
return errors.New("apiKey不能为空")
|
||||
}
|
||||
if apiKey.ID <= 0 {
|
||||
return errors.New("apiKey ID 必须大于 0")
|
||||
}
|
||||
return s.apiKeyRepo.Update(apiKey)
|
||||
}
|
||||
|
||||
func (s *ApiKeyServiceImpl) Delete(id int64) error {
|
||||
if id <= 0 {
|
||||
return errors.New("id 必须大于 0")
|
||||
}
|
||||
return s.apiKeyRepo.Delete(id)
|
||||
}
|
||||
|
||||
func (s *ApiKeyServiceImpl) List(offset, limit int, status *int) ([]model.ApiKey, error) {
|
||||
if offset < 0 {
|
||||
offset = 0
|
||||
}
|
||||
if limit <= 0 {
|
||||
limit = 10 // 设置默认值
|
||||
}
|
||||
return s.apiKeyRepo.List(offset, limit, status)
|
||||
}
|
||||
|
||||
func (s *ApiKeyServiceImpl) ListWithFilters(offset, limit int, filters map[string]interface{}) ([]model.ApiKey, int64, error) {
|
||||
if offset < 0 {
|
||||
offset = 0
|
||||
}
|
||||
if limit <= 0 {
|
||||
limit = 10 // 设置默认值
|
||||
}
|
||||
|
||||
return s.apiKeyRepo.ListWithFilters(offset, limit, filters)
|
||||
}
|
||||
|
||||
func (s *ApiKeyServiceImpl) Enable(id int64) error {
|
||||
if id <= 0 {
|
||||
return errors.New("id 必须大于 0")
|
||||
}
|
||||
return s.apiKeyRepo.Enable(id)
|
||||
}
|
||||
|
||||
func (s *ApiKeyServiceImpl) Disable(id int64) error {
|
||||
if id <= 0 {
|
||||
return errors.New("id 必须大于 0")
|
||||
}
|
||||
return s.apiKeyRepo.Disable(id)
|
||||
}
|
||||
|
||||
func (s *ApiKeyServiceImpl) BatchEnable(ids []int64) error {
|
||||
if len(ids) == 0 {
|
||||
return errors.New("ids 不能为空")
|
||||
}
|
||||
return s.apiKeyRepo.BatchEnable(ids)
|
||||
}
|
||||
|
||||
func (s *ApiKeyServiceImpl) BatchDisable(ids []int64) error {
|
||||
if len(ids) == 0 {
|
||||
return errors.New("ids 不能为空")
|
||||
}
|
||||
return s.apiKeyRepo.BatchDisable(ids)
|
||||
}
|
||||
|
||||
func (s *ApiKeyServiceImpl) BatchDelete(ids []int64) error {
|
||||
if len(ids) == 0 {
|
||||
return errors.New("ids 不能为空")
|
||||
}
|
||||
return s.apiKeyRepo.BatchDelete(ids)
|
||||
}
|
||||
|
||||
func (s *ApiKeyServiceImpl) Count() (int64, error) {
|
||||
return s.apiKeyRepo.Count()
|
||||
}
|
||||
97
team/service/token.go
Normal file
97
team/service/token.go
Normal file
@@ -0,0 +1,97 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"opencatd-open/team/dao"
|
||||
"opencatd-open/team/model"
|
||||
"strings"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
// 确保 TokenService 实现了 TokenServiceInterface 接口
|
||||
var _ TokenService = (*TokenServiceImpl)(nil)
|
||||
|
||||
type TokenService interface {
|
||||
Create(ctx context.Context, token *model.Token) error
|
||||
GetByID(ctx context.Context, id int) (*model.Token, error)
|
||||
GetByKey(ctx context.Context, key string) (*model.Token, error)
|
||||
GetByUserID(ctx context.Context, userID int) (*model.Token, error)
|
||||
Update(ctx context.Context, token *model.Token) error
|
||||
UpdateWithCondition(ctx context.Context, token *model.Token, filters map[string]interface{}, updates map[string]interface{}) error
|
||||
Delete(ctx context.Context, id int) error
|
||||
Lists(ctx context.Context, offset, limit int) ([]model.Token, error)
|
||||
ListsWithFilters(ctx context.Context, offset, limit int, filters map[string]interface{}) ([]model.Token, int64, error)
|
||||
Disable(ctx context.Context, id int) error
|
||||
Enable(ctx context.Context, id int) error
|
||||
BatchDisable(ctx context.Context, ids []int) error
|
||||
BatchEnable(ctx context.Context, ids []int) error
|
||||
BatchDelete(ctx context.Context, ids []int) error
|
||||
}
|
||||
|
||||
type TokenServiceImpl struct {
|
||||
tokenRepo dao.TokenRepository
|
||||
}
|
||||
|
||||
func NewTokenService(tokenRepo dao.TokenRepository) TokenService {
|
||||
return &TokenServiceImpl{tokenRepo: tokenRepo}
|
||||
}
|
||||
|
||||
func (s *TokenServiceImpl) Create(ctx context.Context, token *model.Token) error {
|
||||
if token.Key == "" {
|
||||
token.Key = "team-" + strings.ReplaceAll(uuid.New().String(), "-", "")
|
||||
}
|
||||
return s.tokenRepo.Create(ctx, token)
|
||||
}
|
||||
|
||||
func (s *TokenServiceImpl) GetByID(ctx context.Context, id int) (*model.Token, error) {
|
||||
return s.tokenRepo.GetByID(ctx, id)
|
||||
}
|
||||
|
||||
func (s *TokenServiceImpl) GetByKey(ctx context.Context, key string) (*model.Token, error) {
|
||||
return s.tokenRepo.GetByKey(ctx, key)
|
||||
}
|
||||
|
||||
func (s *TokenServiceImpl) GetByUserID(ctx context.Context, userID int) (*model.Token, error) {
|
||||
return s.tokenRepo.GetByUserID(ctx, userID)
|
||||
}
|
||||
|
||||
func (s *TokenServiceImpl) Update(ctx context.Context, token *model.Token) error {
|
||||
return s.tokenRepo.Update(ctx, token)
|
||||
}
|
||||
|
||||
func (s *TokenServiceImpl) UpdateWithCondition(ctx context.Context, token *model.Token, filters map[string]interface{}, updates map[string]interface{}) error {
|
||||
return s.tokenRepo.UpdateWithCondition(ctx, token, filters, updates)
|
||||
}
|
||||
|
||||
func (s *TokenServiceImpl) Delete(ctx context.Context, id int) error {
|
||||
return s.tokenRepo.Delete(ctx, id)
|
||||
}
|
||||
|
||||
func (s *TokenServiceImpl) Lists(ctx context.Context, offset, limit int) ([]model.Token, error) {
|
||||
return s.tokenRepo.List(ctx, offset, limit)
|
||||
}
|
||||
|
||||
func (s *TokenServiceImpl) ListsWithFilters(ctx context.Context, offset, limit int, filters map[string]interface{}) ([]model.Token, int64, error) {
|
||||
return s.tokenRepo.ListWithFilters(ctx, offset, limit, filters)
|
||||
}
|
||||
|
||||
func (s *TokenServiceImpl) Disable(ctx context.Context, id int) error {
|
||||
return s.tokenRepo.Disable(ctx, id)
|
||||
}
|
||||
|
||||
func (s *TokenServiceImpl) Enable(ctx context.Context, id int) error {
|
||||
return s.tokenRepo.Enable(ctx, id)
|
||||
}
|
||||
|
||||
func (s *TokenServiceImpl) BatchDisable(ctx context.Context, ids []int) error {
|
||||
return s.tokenRepo.BatchDisable(ctx, ids)
|
||||
}
|
||||
|
||||
func (s *TokenServiceImpl) BatchEnable(ctx context.Context, ids []int) error {
|
||||
return s.tokenRepo.BatchEnable(ctx, ids)
|
||||
}
|
||||
|
||||
func (s *TokenServiceImpl) BatchDelete(ctx context.Context, ids []int) error {
|
||||
return s.tokenRepo.BatchDelete(ctx, ids)
|
||||
}
|
||||
137
team/service/usage.go
Normal file
137
team/service/usage.go
Normal file
@@ -0,0 +1,137 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"opencatd-open/team/dao"
|
||||
dto "opencatd-open/team/dto/team"
|
||||
"opencatd-open/team/model"
|
||||
"time"
|
||||
|
||||
"log"
|
||||
|
||||
"gorm.io/gorm"
|
||||
"gorm.io/gorm/clause"
|
||||
)
|
||||
|
||||
var _ UsageService = (*usageService)(nil)
|
||||
|
||||
type UsageService interface {
|
||||
// AsyncProcessUsage 异步处理使用记录
|
||||
AsyncProcessUsage(usage *model.Usage)
|
||||
|
||||
ListByUserID(ctx context.Context, userID int64, limit, offset int) ([]*model.Usage, error)
|
||||
ListByCapability(ctx context.Context, capability string, limit, offset int) ([]*model.Usage, error)
|
||||
ListByDateRange(ctx context.Context, start, end time.Time, filters map[string]interface{}) ([]*dto.UsageInfo, error)
|
||||
|
||||
Delete(ctx context.Context, id int64) error
|
||||
}
|
||||
|
||||
type usageService struct {
|
||||
db *gorm.DB
|
||||
usageDAO dao.UsageRepository
|
||||
dailyUsageDAO dao.DailyUsageRepository
|
||||
usageChan chan *model.Usage // 用于异步处理的channel
|
||||
ctx context.Context
|
||||
}
|
||||
|
||||
func NewUsageService(ctx context.Context, db *gorm.DB, usageRepo dao.UsageRepository, dailyUsageRepo dao.DailyUsageRepository) UsageService {
|
||||
srv := &usageService{
|
||||
db: db,
|
||||
usageDAO: usageRepo,
|
||||
dailyUsageDAO: dailyUsageRepo,
|
||||
usageChan: make(chan *model.Usage, 1000), // 设置合适的缓冲区大小
|
||||
ctx: ctx,
|
||||
}
|
||||
|
||||
// 启动异步处理goroutine
|
||||
go srv.processUsageWorker()
|
||||
|
||||
return srv
|
||||
}
|
||||
|
||||
func (s *usageService) AsyncProcessUsage(usage *model.Usage) {
|
||||
select {
|
||||
case s.usageChan <- usage:
|
||||
// 成功发送到channel
|
||||
default:
|
||||
// channel已满,记录错误日志
|
||||
log.Println("usage channel is full, skip processing")
|
||||
}
|
||||
}
|
||||
|
||||
func (s *usageService) processUsageWorker() {
|
||||
for {
|
||||
select {
|
||||
case usage := <-s.usageChan:
|
||||
err := s.processUsage(usage)
|
||||
if err != nil {
|
||||
log.Println("processUsage error:", err)
|
||||
}
|
||||
case <-s.ctx.Done():
|
||||
log.Println("processUsageWorker is exiting")
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// processUsageWorker 异步处理worker
|
||||
func (s *usageService) processUsage(usage *model.Usage) error {
|
||||
err := s.db.Transaction(func(tx *gorm.DB) error {
|
||||
// 1. 记录使用记录
|
||||
if err := tx.WithContext(s.ctx).Create(usage).Error; err != nil {
|
||||
return fmt.Errorf("create usage error: %w", err)
|
||||
}
|
||||
|
||||
// 2. 更新每日统计(upsert 操作)
|
||||
dailyUsage := model.DailyUsage{
|
||||
UserID: usage.UserID,
|
||||
TokenID: usage.TokenID,
|
||||
Capability: usage.Capability,
|
||||
Date: time.Date(usage.Date.Year(), usage.Date.Month(), usage.Date.Day(), 0, 0, 0, 0, usage.Date.Location()),
|
||||
Model: usage.Model,
|
||||
Stream: usage.Stream,
|
||||
PromptTokens: usage.PromptTokens,
|
||||
CompletionTokens: usage.CompletionTokens,
|
||||
TotalTokens: usage.TotalTokens,
|
||||
Cost: usage.Cost,
|
||||
}
|
||||
|
||||
// 使用 OnConflict 实现 upsert
|
||||
if err := tx.WithContext(s.ctx).Clauses(clause.OnConflict{
|
||||
Columns: []clause.Column{{Name: "user_id"}, {Name: "token_id"}, {Name: "capability"}, {Name: "date"}}, // 唯一键
|
||||
DoUpdates: clause.Assignments(map[string]interface{}{
|
||||
"prompt_tokens": gorm.Expr("prompt_tokens + ?", usage.PromptTokens),
|
||||
"completion_tokens": gorm.Expr("completion_tokens + ?", usage.CompletionTokens),
|
||||
"total_tokens": gorm.Expr("total_tokens + ?", usage.TotalTokens),
|
||||
"cost": gorm.Expr("cost + ?", usage.Cost),
|
||||
}),
|
||||
}).Create(&dailyUsage).Error; err != nil {
|
||||
return fmt.Errorf("upsert daily usage error: %w", err)
|
||||
}
|
||||
|
||||
// 3. 更新用户额度
|
||||
if err := tx.WithContext(s.ctx).Model(&model.User{}).Where("id = ?", usage.UserID).Update("quota", gorm.Expr("quota - ?", usage.Cost)).Error; err != nil {
|
||||
return fmt.Errorf("update user quota error: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *usageService) ListByUserID(ctx context.Context, userID int64, limit int, offset int) ([]*model.Usage, error) {
|
||||
return s.usageDAO.ListByUserID(ctx, userID, limit, offset)
|
||||
}
|
||||
|
||||
func (s *usageService) ListByCapability(ctx context.Context, capability string, limit, offset int) ([]*model.Usage, error) {
|
||||
return s.usageDAO.ListByCapability(ctx, capability, limit, offset)
|
||||
}
|
||||
|
||||
func (s *usageService) ListByDateRange(ctx context.Context, start, end time.Time, filters map[string]interface{}) ([]*dto.UsageInfo, error) {
|
||||
return s.dailyUsageDAO.StatUserUsages(ctx, start, end, filters)
|
||||
}
|
||||
|
||||
func (s *usageService) Delete(ctx context.Context, id int64) error {
|
||||
return s.usageDAO.Delete(ctx, id)
|
||||
}
|
||||
702
team/service/user.go
Normal file
702
team/service/user.go
Normal file
@@ -0,0 +1,702 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"opencatd-open/team/consts"
|
||||
"opencatd-open/team/dao"
|
||||
"opencatd-open/team/model"
|
||||
"regexp"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
"golang.org/x/exp/rand"
|
||||
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrUserNotFound = errors.New("user not found")
|
||||
ErrInvalidUserInput = errors.New("invalid user input")
|
||||
ErrUserExists = errors.New("user already exists")
|
||||
ErrInvalidPassword = errors.New("invalid password format")
|
||||
ErrPermissionDenied = errors.New("permission denied")
|
||||
ErrInvalidOperation = errors.New("invalid operation")
|
||||
ErrTransactionFailed = errors.New("transaction failed")
|
||||
)
|
||||
|
||||
// PasswordPolicy 定义密码策略
|
||||
type PasswordPolicy struct {
|
||||
MinLength int
|
||||
MaxLength int
|
||||
NeedNumber bool
|
||||
NeedUpper bool
|
||||
NeedLower bool
|
||||
NeedSymbol bool
|
||||
}
|
||||
|
||||
// 生成随机数字
|
||||
func generateNumber() string {
|
||||
rand.Seed(uint64(time.Now().UnixNano()))
|
||||
return string('0' + rand.Intn(10))
|
||||
}
|
||||
|
||||
// 生成随机大写字母
|
||||
func generateUpper() string {
|
||||
rand.Seed(uint64(time.Now().UnixNano()))
|
||||
return string('A' + rand.Intn(26))
|
||||
}
|
||||
|
||||
// 生成随机小写字母
|
||||
func generateLower() string {
|
||||
rand.Seed(uint64(time.Now().UnixNano()))
|
||||
return string('a' + rand.Intn(26))
|
||||
}
|
||||
|
||||
// 生成随机特殊符号
|
||||
func generateSymbol() string {
|
||||
rand.Seed(uint64(time.Now().UnixNano()))
|
||||
symbols := "!@#$%^&*"
|
||||
return string(symbols[rand.Intn(len(symbols))])
|
||||
}
|
||||
|
||||
// GeneratePassword 根据密码策略生成密码
|
||||
func GeneratePassword(policy PasswordPolicy) string {
|
||||
rand.Seed(uint64(time.Now().UnixNano()))
|
||||
|
||||
// 确保满足所有必须的字符类型
|
||||
var password string
|
||||
if policy.NeedNumber {
|
||||
password += generateNumber()
|
||||
}
|
||||
if policy.NeedUpper {
|
||||
password += generateUpper()
|
||||
}
|
||||
if policy.NeedLower {
|
||||
password += generateLower()
|
||||
}
|
||||
if policy.NeedSymbol {
|
||||
password += generateSymbol()
|
||||
}
|
||||
|
||||
// 计算还需要多少个字符
|
||||
remainingLength := policy.MinLength - len(password)
|
||||
if remainingLength < 0 {
|
||||
remainingLength = 0
|
||||
}
|
||||
// 剩余长度随机生成密码字符
|
||||
for i := 0; i < remainingLength; i++ {
|
||||
randType := rand.Intn(4) // 0:数字, 1:大写, 2:小写, 3:符号
|
||||
switch randType {
|
||||
case 0:
|
||||
password += generateNumber()
|
||||
case 1:
|
||||
password += generateUpper()
|
||||
case 2:
|
||||
password += generateLower()
|
||||
case 3:
|
||||
password += generateSymbol()
|
||||
}
|
||||
}
|
||||
|
||||
// 如果密码长度超过最大值,则截断
|
||||
if len(password) > policy.MaxLength {
|
||||
password = password[:policy.MaxLength]
|
||||
}
|
||||
|
||||
// 将密码打乱
|
||||
passwordRune := []rune(password)
|
||||
rand.Shuffle(len(passwordRune), func(i, j int) {
|
||||
passwordRune[i], passwordRune[j] = passwordRune[j], passwordRune[i]
|
||||
})
|
||||
|
||||
return string(passwordRune)
|
||||
}
|
||||
|
||||
var _ UserService = (*userService)(nil)
|
||||
|
||||
// UserService 定义用户服务的接口
|
||||
type UserService interface {
|
||||
CreateUser(ctx context.Context, user *model.User) error
|
||||
GetUser(ctx context.Context, id int64) (*model.User, error)
|
||||
GetUserByUsername(ctx context.Context, username string) (*model.User, error)
|
||||
UpdateUser(ctx context.Context, user *model.User, operatorID int64) error
|
||||
DeleteUser(ctx context.Context, id int64, operatorID int64) error
|
||||
ListUsers(ctx context.Context, page, pageSize int) ([]model.User, int64, error)
|
||||
ListUsersWithFilters(ctx context.Context, page, pageSize int, filters map[string]interface{}) ([]model.User, int64, error)
|
||||
EnableUser(ctx context.Context, id int64, operatorID int64) error
|
||||
DisableUser(ctx context.Context, id int64, operatorID int64) error
|
||||
BatchEnableUsers(ctx context.Context, ids []int64, operatorID int64) error
|
||||
BatchDisableUsers(ctx context.Context, ids []int64, operatorID int64) error
|
||||
BatchDeleteUsers(ctx context.Context, ids []int64, operatorID int64) error
|
||||
ChangePassword(ctx context.Context, userID int64, oldPassword, newPassword string) error
|
||||
ResetPassword(ctx context.Context, userID int64, operatorID int64) error
|
||||
ValidatePassword(password string) error
|
||||
CheckPermission(ctx context.Context, requiredRole consts.UserRole) error
|
||||
}
|
||||
|
||||
// userService 实现 UserService 接口
|
||||
type userService struct {
|
||||
userRepo dao.UserRepository
|
||||
db *gorm.DB
|
||||
pwdPolicy PasswordPolicy
|
||||
}
|
||||
|
||||
// NewUserService 创建 UserService 实例
|
||||
func NewUserService(userRepo dao.UserRepository, db *gorm.DB) UserService {
|
||||
return &userService{
|
||||
userRepo: userRepo,
|
||||
db: db,
|
||||
pwdPolicy: PasswordPolicy{
|
||||
MinLength: 8,
|
||||
MaxLength: 32,
|
||||
NeedNumber: true,
|
||||
NeedUpper: true,
|
||||
NeedLower: true,
|
||||
NeedSymbol: true,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// hashPassword 使用 bcrypt 加密密码
|
||||
func (s *userService) hashPassword(password string) (string, error) {
|
||||
hashedBytes, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return string(hashedBytes), nil
|
||||
}
|
||||
|
||||
// comparePasswords 比较密码
|
||||
func (s *userService) comparePasswords(hashedPassword, password string) bool {
|
||||
err := bcrypt.CompareHashAndPassword([]byte(hashedPassword), []byte(password))
|
||||
return err == nil
|
||||
}
|
||||
|
||||
// ValidatePassword 验证密码是否符合策略
|
||||
func (s *userService) ValidatePassword(password string) error {
|
||||
if len(password) < s.pwdPolicy.MinLength || len(password) > s.pwdPolicy.MaxLength {
|
||||
return ErrInvalidPassword
|
||||
}
|
||||
|
||||
if s.pwdPolicy.NeedNumber && !regexp.MustCompile(`[0-9]`).MatchString(password) {
|
||||
return ErrInvalidPassword
|
||||
}
|
||||
|
||||
if s.pwdPolicy.NeedUpper && !regexp.MustCompile(`[A-Z]`).MatchString(password) {
|
||||
return ErrInvalidPassword
|
||||
}
|
||||
|
||||
if s.pwdPolicy.NeedLower && !regexp.MustCompile(`[a-z]`).MatchString(password) {
|
||||
return ErrInvalidPassword
|
||||
}
|
||||
|
||||
if s.pwdPolicy.NeedSymbol && !regexp.MustCompile(`[!@#$%^&*]`).MatchString(password) {
|
||||
return ErrInvalidPassword
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// CheckPermission 检查用户权限
|
||||
func (s *userService) CheckPermission(ctx context.Context, requiredRole consts.UserRole) error {
|
||||
userToken := ctx.Value("Token").(*model.Token)
|
||||
|
||||
// 检查用户角色
|
||||
if userToken.User.Role < int(requiredRole) {
|
||||
return ErrPermissionDenied
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// withTransaction 事务处理封装
|
||||
func (s *userService) withTransaction(ctx context.Context, fn func(tx *gorm.DB) error) error {
|
||||
tx := s.db.WithContext(ctx).Begin()
|
||||
if tx.Error != nil {
|
||||
return ErrTransactionFailed
|
||||
}
|
||||
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
tx.Rollback()
|
||||
panic(r)
|
||||
}
|
||||
}()
|
||||
|
||||
if err := fn(tx); err != nil {
|
||||
tx.Rollback()
|
||||
return err
|
||||
}
|
||||
|
||||
if err := tx.Commit().Error; err != nil {
|
||||
return ErrTransactionFailed
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// CreateUser 创建用户
|
||||
func (s *userService) CreateUser(ctx context.Context, user *model.User) error {
|
||||
if user == nil {
|
||||
return ErrInvalidUserInput
|
||||
}
|
||||
|
||||
if user.Password == "" {
|
||||
user.Password = GeneratePassword(s.pwdPolicy)
|
||||
}
|
||||
|
||||
// 使用事务处理
|
||||
return s.withTransaction(ctx, func(tx *gorm.DB) error {
|
||||
// 检查用户名是否已存在
|
||||
// _, err := s.userRepo.GetByID(user.ID)
|
||||
// if err != nil {
|
||||
// return err
|
||||
// }
|
||||
|
||||
// 加密密码
|
||||
hashedPassword, err := s.hashPassword(user.Password)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
user.Password = hashedPassword
|
||||
|
||||
return s.userRepo.Create(user)
|
||||
})
|
||||
}
|
||||
|
||||
// GetUser 根据 ID 获取用户
|
||||
func (s *userService) GetUser(ctx context.Context, id int64) (*model.User, error) {
|
||||
if id <= 0 {
|
||||
return nil, ErrInvalidUserInput
|
||||
}
|
||||
|
||||
user, err := s.userRepo.GetByID(id)
|
||||
if err != nil {
|
||||
return nil, err // 返回其他数据库错误
|
||||
}
|
||||
// 处理返回结果,清除敏感信息
|
||||
user.Password = "" // 清除密码信息
|
||||
|
||||
return user, nil
|
||||
}
|
||||
|
||||
// GetUserByUsername 根据用户名获取用户
|
||||
func (s *userService) GetUserByUsername(ctx context.Context, username string) (*model.User, error) {
|
||||
if username == "" {
|
||||
return nil, ErrInvalidUserInput
|
||||
}
|
||||
|
||||
user, err := s.userRepo.GetByUsername(username)
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, ErrUserNotFound
|
||||
}
|
||||
return nil, err // 返回其他数据库错误
|
||||
}
|
||||
|
||||
// 处理返回结果,清除敏感信息
|
||||
user.Password = "" // 清除密码信息
|
||||
|
||||
return user, nil
|
||||
}
|
||||
|
||||
// UpdateUser 更新用户信息
|
||||
func (s *userService) UpdateUser(ctx context.Context, user *model.User, operatorID int64) error {
|
||||
if user == nil || user.ID <= 0 {
|
||||
return ErrInvalidUserInput
|
||||
}
|
||||
|
||||
return s.withTransaction(ctx, func(tx *gorm.DB) error {
|
||||
// 检查用户是否存在
|
||||
existingUser, err := s.userRepo.GetByID(user.ID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 如果修改了用户名,检查新用户名是否已存在
|
||||
if user.Username != existingUser.Username {
|
||||
tmpUser, err := s.userRepo.GetByUsername(user.Username)
|
||||
if err == nil && tmpUser != nil && tmpUser.ID != user.ID {
|
||||
return ErrUserExists
|
||||
}
|
||||
}
|
||||
|
||||
// 保持原有密码
|
||||
user.Password = existingUser.Password
|
||||
user.UpdatedAt = time.Now().Unix()
|
||||
|
||||
return s.userRepo.Update(user)
|
||||
})
|
||||
}
|
||||
|
||||
// ChangePassword 修改密码
|
||||
func (s *userService) ChangePassword(ctx context.Context, userID int64, oldPassword, newPassword string) error {
|
||||
// 验证新密码
|
||||
if err := s.ValidatePassword(newPassword); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return s.withTransaction(ctx, func(tx *gorm.DB) error {
|
||||
user, err := s.userRepo.GetByID(userID)
|
||||
if err != nil {
|
||||
return ErrUserNotFound
|
||||
}
|
||||
|
||||
// 验证旧密码
|
||||
if !s.comparePasswords(user.Password, oldPassword) {
|
||||
return ErrInvalidPassword
|
||||
}
|
||||
|
||||
// 加密新密码
|
||||
hashedPassword, err := s.hashPassword(newPassword)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
user.Password = hashedPassword
|
||||
user.UpdatedAt = time.Now().Unix()
|
||||
|
||||
return s.userRepo.Update(user)
|
||||
})
|
||||
}
|
||||
|
||||
// ResetPassword 重置密码
|
||||
func (s *userService) ResetPassword(ctx context.Context, userID int64, operatorID int64) error {
|
||||
return s.withTransaction(ctx, func(tx *gorm.DB) error {
|
||||
user, err := s.userRepo.GetByID(userID)
|
||||
if err != nil {
|
||||
return ErrUserNotFound
|
||||
}
|
||||
|
||||
// 生成随机密码
|
||||
newPassword := generateRandomPassword()
|
||||
hashedPassword, err := s.hashPassword(newPassword)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
user.Password = hashedPassword
|
||||
user.UpdatedAt = time.Now().Unix()
|
||||
|
||||
// TODO: 发送新密码给用户邮箱
|
||||
|
||||
return s.userRepo.Update(user)
|
||||
})
|
||||
}
|
||||
|
||||
// ListUsers 获取用户列表(增加过滤功能)
|
||||
func (s *userService) ListUsers(ctx context.Context, page, pageSize int) ([]model.User, int64, error) {
|
||||
if page < 1 {
|
||||
page = 1
|
||||
}
|
||||
if pageSize < 1 {
|
||||
pageSize = 10
|
||||
}
|
||||
|
||||
offset := (page - 1) * pageSize
|
||||
|
||||
users, err := s.userRepo.List(offset, pageSize)
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
var total int64 = 0
|
||||
|
||||
return users, total, nil
|
||||
}
|
||||
|
||||
// ListUsers 获取用户列表(增加过滤功能)
|
||||
func (s *userService) ListUsersWithFilters(ctx context.Context, page, pageSize int, filters map[string]interface{}) ([]model.User, int64, error) {
|
||||
if page < 1 {
|
||||
page = 1
|
||||
}
|
||||
if pageSize < 1 {
|
||||
pageSize = 10
|
||||
}
|
||||
|
||||
offset := (page - 1) * pageSize
|
||||
|
||||
// 使用新的 ListWithFilters 方法
|
||||
users, total, err := s.userRepo.ListWithFilters(offset, pageSize, filters)
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
return users, total, nil
|
||||
}
|
||||
|
||||
// generateRandomPassword 生成随机密码
|
||||
func generateRandomPassword() string {
|
||||
const (
|
||||
lowerChars = "abcdefghijklmnopqrstuvwxyz"
|
||||
upperChars = "ABCDEFGHIJKLMNOPQRSTUVWXYZ"
|
||||
numberChars = "0123456789"
|
||||
specialChars = "!@#$%^&*"
|
||||
)
|
||||
// rand.NewSource(uint64(time.Now().UnixNano()))
|
||||
|
||||
// 确保每种字符都至少出现一次
|
||||
password := []string{
|
||||
string(lowerChars[rand.Intn(len(lowerChars))]),
|
||||
string(upperChars[rand.Intn(len(upperChars))]),
|
||||
string(numberChars[rand.Intn(len(numberChars))]),
|
||||
string(specialChars[rand.Intn(len(specialChars))]),
|
||||
}
|
||||
|
||||
// 所有可用字符
|
||||
allChars := lowerChars + upperChars + numberChars + specialChars
|
||||
|
||||
// 生成剩余的12个字符
|
||||
for i := 0; i < 12; i++ {
|
||||
password = append(password, string(allChars[rand.Intn(len(allChars))]))
|
||||
}
|
||||
|
||||
// 打乱密码字符顺序
|
||||
rand.Shuffle(len(password), func(i, j int) {
|
||||
password[i], password[j] = password[j], password[i]
|
||||
})
|
||||
|
||||
return strings.Join(password, "")
|
||||
}
|
||||
|
||||
// DeleteUser 删除用户
|
||||
func (s *userService) DeleteUser(ctx context.Context, id int64, operatorID int64) error {
|
||||
// 检查参数
|
||||
if id <= 0 {
|
||||
return ErrInvalidUserInput
|
||||
}
|
||||
|
||||
if err := s.CheckPermission(ctx, consts.RoleAdmin); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 不允许删除自己
|
||||
if id == operatorID {
|
||||
return ErrInvalidOperation
|
||||
}
|
||||
|
||||
return s.withTransaction(ctx, func(tx *gorm.DB) error {
|
||||
// 检查用户是否存在
|
||||
user, err := s.userRepo.GetByID(id)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 检查是否试图删除管理员
|
||||
if user.Role == int(consts.RoleAdmin) {
|
||||
return ErrPermissionDenied
|
||||
}
|
||||
|
||||
return s.userRepo.Delete(id)
|
||||
})
|
||||
}
|
||||
|
||||
// EnableUser 启用用户
|
||||
func (s *userService) EnableUser(ctx context.Context, id int64, operatorID int64) error {
|
||||
// 检查参数
|
||||
if id <= 0 {
|
||||
return ErrInvalidUserInput
|
||||
}
|
||||
|
||||
// 检查操作者权限
|
||||
if err := s.CheckPermission(ctx, consts.RoleAdmin); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return s.withTransaction(ctx, func(tx *gorm.DB) error {
|
||||
// 检查用户是否存在
|
||||
user, err := s.userRepo.GetByID(id)
|
||||
if err != nil {
|
||||
return ErrUserNotFound
|
||||
}
|
||||
|
||||
// 如果用户已经是启用状态,返回成功
|
||||
if user.Status == consts.StatusEnabled {
|
||||
return nil
|
||||
}
|
||||
|
||||
return s.userRepo.Enable(id)
|
||||
})
|
||||
}
|
||||
|
||||
// DisableUser 禁用用户
|
||||
func (s *userService) DisableUser(ctx context.Context, id int64, operatorID int64) error {
|
||||
// 检查参数
|
||||
if id <= 0 {
|
||||
return ErrInvalidUserInput
|
||||
}
|
||||
|
||||
// 检查操作者权限
|
||||
if err := s.CheckPermission(ctx, consts.RoleAdmin); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 不允许禁用自己
|
||||
if id == operatorID {
|
||||
return ErrInvalidOperation
|
||||
}
|
||||
|
||||
return s.withTransaction(ctx, func(tx *gorm.DB) error {
|
||||
// 检查用户是否存在
|
||||
user, err := s.userRepo.GetByID(id)
|
||||
if err != nil {
|
||||
return ErrUserNotFound
|
||||
}
|
||||
|
||||
// 检查是否试图禁用超级管理员
|
||||
if user.Role == int(consts.RoleAdmin) {
|
||||
return ErrPermissionDenied
|
||||
}
|
||||
|
||||
// 如果用户已经是禁用状态,返回成功
|
||||
if user.Status == consts.StatusDisabled {
|
||||
return nil
|
||||
}
|
||||
|
||||
return s.userRepo.Disable(id)
|
||||
})
|
||||
}
|
||||
|
||||
// BatchEnableUsers 批量启用用户
|
||||
func (s *userService) BatchEnableUsers(ctx context.Context, ids []int64, operatorID int64) error {
|
||||
// 检查参数
|
||||
if len(ids) == 0 {
|
||||
return ErrInvalidUserInput
|
||||
}
|
||||
|
||||
// 检查操作者权限
|
||||
if err := s.CheckPermission(ctx, consts.RoleAdmin); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return s.withTransaction(ctx, func(tx *gorm.DB) error {
|
||||
// 检查所有用户是否存在,并收集当前状态
|
||||
enabledUsers := make([]int64, 0)
|
||||
for _, id := range ids {
|
||||
user, err := s.userRepo.GetByID(id)
|
||||
if err != nil {
|
||||
return ErrUserNotFound
|
||||
}
|
||||
if user.Status == consts.StatusEnabled {
|
||||
enabledUsers = append(enabledUsers, id)
|
||||
}
|
||||
}
|
||||
|
||||
// 如果所有用户都已经是启用状态,返回成功
|
||||
if len(enabledUsers) == len(ids) {
|
||||
return nil
|
||||
}
|
||||
|
||||
// 过滤掉已经启用的用户,只处理需要启用的用户
|
||||
toEnableIds := make([]int64, 0)
|
||||
for _, id := range ids {
|
||||
if !contains(enabledUsers, id) {
|
||||
toEnableIds = append(toEnableIds, id)
|
||||
}
|
||||
}
|
||||
|
||||
if len(toEnableIds) > 0 {
|
||||
return s.userRepo.BatchEnable(toEnableIds)
|
||||
}
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
// BatchDisableUsers 批量禁用用户
|
||||
func (s *userService) BatchDisableUsers(ctx context.Context, ids []int64, operatorID int64) error {
|
||||
// 检查参数
|
||||
if len(ids) == 0 {
|
||||
return ErrInvalidUserInput
|
||||
}
|
||||
|
||||
// 检查操作者权限
|
||||
if err := s.CheckPermission(ctx, consts.RoleAdmin); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 不允许包含自己
|
||||
if contains(ids, operatorID) {
|
||||
return ErrInvalidOperation
|
||||
}
|
||||
|
||||
return s.withTransaction(ctx, func(tx *gorm.DB) error {
|
||||
// 检查所有用户是否存在
|
||||
disabledUsers := make([]int64, 0)
|
||||
for _, id := range ids {
|
||||
user, err := s.userRepo.GetByID(id)
|
||||
if err != nil {
|
||||
return ErrUserNotFound
|
||||
}
|
||||
// 不允许禁用管理员
|
||||
if user.Role == int(consts.RoleAdmin) {
|
||||
return ErrPermissionDenied
|
||||
}
|
||||
if user.Status == consts.StatusDisabled {
|
||||
disabledUsers = append(disabledUsers, id)
|
||||
}
|
||||
}
|
||||
|
||||
// 如果所有用户都已经是禁用状态,返回成功
|
||||
if len(disabledUsers) == len(ids) {
|
||||
return nil
|
||||
}
|
||||
|
||||
// 过滤掉已经禁用的用户,只处理需要禁用的用户
|
||||
toDisableIds := make([]int64, 0)
|
||||
for _, id := range ids {
|
||||
if !contains(disabledUsers, id) {
|
||||
toDisableIds = append(toDisableIds, id)
|
||||
}
|
||||
}
|
||||
|
||||
if len(toDisableIds) > 0 {
|
||||
return s.userRepo.BatchDisable(toDisableIds)
|
||||
}
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
// BatchDeleteUsers 批量删除用户
|
||||
func (s *userService) BatchDeleteUsers(ctx context.Context, ids []int64, operatorID int64) error {
|
||||
// 检查参数
|
||||
if len(ids) == 0 {
|
||||
return ErrInvalidUserInput
|
||||
}
|
||||
|
||||
// 检查操作者权限
|
||||
if err := s.CheckPermission(ctx, consts.RoleAdmin); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 不允许包含自己
|
||||
if contains(ids, operatorID) {
|
||||
return ErrInvalidOperation
|
||||
}
|
||||
|
||||
return s.withTransaction(ctx, func(tx *gorm.DB) error {
|
||||
// 检查所有用户是否存在,并确保不会删除管理员
|
||||
for _, id := range ids {
|
||||
user, err := s.userRepo.GetByID(id)
|
||||
if err != nil {
|
||||
return ErrUserNotFound
|
||||
}
|
||||
if user.Role == int(consts.RoleAdmin) {
|
||||
return ErrPermissionDenied
|
||||
}
|
||||
}
|
||||
|
||||
return s.userRepo.BatchDelete(ids)
|
||||
})
|
||||
}
|
||||
|
||||
// contains 检查切片中是否包含特定值
|
||||
func contains(slice []int64, item int64) bool {
|
||||
for _, s := range slice {
|
||||
if s == item {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
47
wire/wire.go
Normal file
47
wire/wire.go
Normal file
@@ -0,0 +1,47 @@
|
||||
//go:build wireinject
|
||||
// +build wireinject
|
||||
|
||||
package wire
|
||||
|
||||
import (
|
||||
"context"
|
||||
"opencatd-open/team/dao"
|
||||
handler "opencatd-open/team/handler/team"
|
||||
"opencatd-open/team/service"
|
||||
|
||||
"github.com/google/wire"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// 定义 provider set
|
||||
var userSet = wire.NewSet(
|
||||
dao.NewUserDAO,
|
||||
wire.Bind(new(dao.UserRepository), new(*dao.UserDAO)),
|
||||
service.NewUserService,
|
||||
)
|
||||
|
||||
var keySet = wire.NewSet(
|
||||
dao.NewApiKeyDAO,
|
||||
wire.Bind(new(dao.ApiKeyRepository), new(*dao.ApiKeyDAO)),
|
||||
service.NewApiKeyService,
|
||||
)
|
||||
|
||||
var tokenSet = wire.NewSet(
|
||||
dao.NewTokenDAO,
|
||||
wire.Bind(new(dao.TokenRepository), new(*dao.TokenDAO)),
|
||||
service.NewTokenService,
|
||||
)
|
||||
|
||||
var usageSet = wire.NewSet(
|
||||
dao.NewUsageDAO,
|
||||
wire.Bind(new(dao.UsageRepository), new(*dao.UsageDAO)),
|
||||
dao.NewDailyUsageDAO,
|
||||
wire.Bind(new(dao.DailyUsageRepository), new(*dao.DailyUsageDAO)),
|
||||
service.NewUsageService,
|
||||
)
|
||||
|
||||
// 初始化 TeamHandler
|
||||
func InitTeamHandler(ctx context.Context, db *gorm.DB) (*handler.TeamHandler, error) {
|
||||
wire.Build(userSet, keySet, tokenSet, usageSet, handler.NewTeamHandler)
|
||||
return nil, nil
|
||||
}
|
||||
43
wire/wire_gen.go
Normal file
43
wire/wire_gen.go
Normal file
@@ -0,0 +1,43 @@
|
||||
// Code generated by Wire. DO NOT EDIT.
|
||||
|
||||
//go:generate go run github.com/google/wire/cmd/wire
|
||||
//+build !wireinject
|
||||
|
||||
package wire
|
||||
|
||||
import (
|
||||
"context"
|
||||
"github.com/google/wire"
|
||||
"gorm.io/gorm"
|
||||
"opencatd-open/team/dao"
|
||||
"opencatd-open/team/handler/team"
|
||||
"opencatd-open/team/service"
|
||||
)
|
||||
|
||||
// Injectors from wire.go:
|
||||
|
||||
// 初始化 TeamHandler
|
||||
func InitTeamHandler(ctx context.Context, db *gorm.DB) (*handler.TeamHandler, error) {
|
||||
userDAO := dao.NewUserDAO(db)
|
||||
userService := service.NewUserService(userDAO, db)
|
||||
tokenDAO := dao.NewTokenDAO(db)
|
||||
tokenService := service.NewTokenService(tokenDAO)
|
||||
apiKeyDAO := dao.NewApiKeyDAO(db)
|
||||
apiKeyService := service.NewApiKeyService(apiKeyDAO, db)
|
||||
usageDAO := dao.NewUsageDAO(db)
|
||||
dailyUsageDAO := dao.NewDailyUsageDAO(db)
|
||||
usageService := service.NewUsageService(ctx, db, usageDAO, dailyUsageDAO)
|
||||
teamHandler := handler.NewTeamHandler(userService, tokenService, apiKeyService, usageService)
|
||||
return teamHandler, nil
|
||||
}
|
||||
|
||||
// wire.go:
|
||||
|
||||
// 定义 provider set
|
||||
var userSet = wire.NewSet(dao.NewUserDAO, wire.Bind(new(dao.UserRepository), new(*dao.UserDAO)), service.NewUserService)
|
||||
|
||||
var keySet = wire.NewSet(dao.NewApiKeyDAO, wire.Bind(new(dao.ApiKeyRepository), new(*dao.ApiKeyDAO)), service.NewApiKeyService)
|
||||
|
||||
var tokenSet = wire.NewSet(dao.NewTokenDAO, wire.Bind(new(dao.TokenRepository), new(*dao.TokenDAO)), service.NewTokenService)
|
||||
|
||||
var usageSet = wire.NewSet(dao.NewUsageDAO, wire.Bind(new(dao.UsageRepository), new(*dao.UsageDAO)), dao.NewDailyUsageDAO, wire.Bind(new(dao.DailyUsageRepository), new(*dao.DailyUsageDAO)), service.NewUsageService)
|
||||
Reference in New Issue
Block a user