add cli load

This commit is contained in:
Sakurasan
2025-04-18 19:05:23 +08:00
parent e0b531c578
commit f8e539c9b4
6 changed files with 324 additions and 173 deletions

View File

@@ -1,32 +1,23 @@
package main
import (
"context"
"embed"
"fmt"
"io/fs"
"log"
"net/http"
"opencatd-open/middleware"
"opencatd-open/internal/cli"
"opencatd-open/internal/consts"
"opencatd-open/pkg/config"
"opencatd-open/pkg/store"
"opencatd-open/wire"
"os"
"os/signal"
"sync"
"syscall"
"time"
"opencatd-open/router"
"github.com/gin-gonic/gin"
"github.com/spf13/cobra"
)
//go:embed dist/*
var web embed.FS
func main() {
ctx, cancel := context.WithCancel(context.Background())
var wg sync.WaitGroup
cfg, err := config.LoadConfig()
if err != nil {
panic(err)
@@ -36,169 +27,20 @@ func main() {
if err != nil {
panic(err)
}
sqlDB, err := db.DB()
if err != nil {
log.Fatalf("Failed to get underlying *sql.DB: %v", err)
rootCmd := &cobra.Command{
Use: "openteam",
Short: "openteam cli",
Long: consts.Logo,
Run: func(cmd *cobra.Command, args []string) {
router.SetRouter(cfg, db, &web)
},
}
team, err := wire.InitTeamHandler(ctx, cfg, db)
if err != nil {
panic(err)
rootCmd.AddCommand(cli.LoadCmd)
if err := rootCmd.Execute(); err != nil {
log.Fatal(err)
}
api, err := wire.InitAPIHandler(ctx, cfg, db)
if err != nil {
panic(err)
}
proxy, err := wire.InitProxyHandler(ctx, cfg, db, &wg)
if err != nil {
panic(err)
}
r := gin.Default()
r.Use(middleware.CORS())
teamGroup := r.Group("/1")
teamGroup.Use(team.AuthMiddleware())
{
teamGroup.POST("/users/init", team.InitAdmin)
// 获取当前用户信息
teamGroup.GET("/me", team.Me)
// team.GET("/me/usages", team.HandleMeUsage)
teamGroup.POST("/keys", team.CreateKey)
teamGroup.GET("/keys", team.ListKeys)
teamGroup.POST("/keys/:id", team.UpdateKey)
teamGroup.DELETE("/keys/:id", team.DeleteKey)
teamGroup.POST("/users", team.CreateUser)
teamGroup.GET("/users", team.ListUsers)
teamGroup.POST("/users/:id/reset", team.ResetUserToken)
teamGroup.DELETE("/users/:id", team.DeleteUser)
teamGroup.GET("/1/usages", team.ListUsages)
}
public := r.Group("/api/auth")
{
public.GET("/passkey/begin", api.PasskeyAuthBegin)
public.POST("/passkey/finish", api.PasskeyAuthFinish)
public.POST("/register", api.Register)
public.POST("/login", api.Login)
}
apiGroup := r.Group("/api", middleware.Auth)
{
apiGroup.GET("/profile", api.Profile)
apiGroup.POST("/profile/update", api.UpdateProfile)
apiGroup.POST("/profile/update/password", api.UpdatePassword)
// 绑定PassKey
apiGroup.GET("/profile/passkey", api.PasskeyCreateBegin)
apiGroup.POST("/profile/passkey", api.PasskeyCreateFinish)
apiGroup.GET("/profile/passkeys", api.ListPasskey)
apiGroup.DELETE("/profile/passkeys/:id", api.DeletePasskey)
userGroup := apiGroup.Group("/users")
{
userGroup.POST("", api.CreateUser)
userGroup.GET("", api.ListUser)
userGroup.GET("/:id", api.GetUser)
userGroup.PUT("/:id", api.EditUser)
userGroup.DELETE("/:id", api.DeleteUser)
userGroup.POST("/batch/:option", api.UserOption)
}
tokenGroup := apiGroup.Group("/tokens")
tokenGroup.POST("", api.CreateToken)
tokenGroup.GET("", api.ListToken)
// tokenGroup.GET("/:id", api.GetToken)
tokenGroup.POST("/reset/:id", api.ResetToken)
tokenGroup.PUT("/:id", api.UpdateToken)
tokenGroup.DELETE("/:id", api.DeleteToken)
// tokenGroup.POST("/batch/:option", api.TokenOption)
apiGroup.POST("keys", api.CreateApiKey)
apiGroup.GET("keys", api.ListApiKey)
apiGroup.GET("keys/:id", api.GetApiKey)
apiGroup.PUT("keys/:id", api.UpdateApiKey)
apiGroup.DELETE("keys/:id", api.DeleteApiKey)
apiGroup.POST("keys/batch/:option", api.ApiKeyOption)
}
v1 := r.Group("/v1")
v1.Use(middleware.AuthLLM(db))
{
// v1.POST("/v2/*proxypath", router.HandleProxy)
v1.POST("/*proxypath", proxy.HandleProxy)
// v1.GET("/models", dashboard.HandleModels)
}
idxFS, err := fs.Sub(web, "dist")
if err != nil {
panic(err)
}
assetsFS, err := fs.Sub(web, "dist/assets")
if err != nil {
panic(err)
}
r.StaticFS("/assets", http.FS(assetsFS))
r.NoRoute(func(c *gin.Context) {
if c.Writer.Status() == http.StatusNotFound {
c.FileFromFS("/", http.FS(idxFS))
}
})
srv := &http.Server{
Addr: fmt.Sprintf(":%d", cfg.Port),
Handler: r,
}
go func() {
fmt.Println("Starting server at port:", cfg.Port)
// 服务启动
if err := srv.ListenAndServe(); err != nil && err != http.ErrServerClosed {
log.Fatalf("listen: %s\n", err)
}
}()
// 等待中断信号来优雅地关闭服务器
quit := make(chan os.Signal, 1)
// kill (no param) default send syscall.SIGTERM
// kill -2 is syscall.SIGINT
// kill -9 is syscall.SIGKILL but can't be catch
signal.Notify(quit, syscall.SIGINT, syscall.SIGTERM)
<-quit
fmt.Println("\nShutdown Server ...")
shutdownCtx, shutdownCancel := context.WithTimeout(context.Background(), 5*time.Second)
defer shutdownCancel()
if err := srv.Shutdown(shutdownCtx); err != nil {
log.Fatalln("Server Shutdown:", err)
}
cancel()
sqlDB.Close()
waitChan := make(chan struct{})
go func() {
wg.Wait()
close(waitChan)
}()
select {
case <-waitChan:
fmt.Println("All goroutines have finished")
case <-shutdownCtx.Done():
fmt.Println("⚠️ Shutdown timeout")
}
fmt.Println("Server exited")
}
func printFilesAndDirs(fsys fs.FS, prefix string) error {

3
go.mod
View File

@@ -27,6 +27,7 @@ require (
github.com/patrickmn/go-cache v2.1.0+incompatible
github.com/pkoukk/tiktoken-go v0.1.7
github.com/sashabaranov/go-openai v1.32.2
github.com/spf13/cobra v1.9.1
github.com/tidwall/gjson v1.18.0
golang.org/x/crypto v0.37.0
golang.org/x/exp v0.0.0-20241009180824-f66d83c29e7c
@@ -75,6 +76,7 @@ require (
github.com/googleapis/enterprise-certificate-proxy v0.3.6 // indirect
github.com/googleapis/gax-go/v2 v2.14.1 // indirect
github.com/hajimehoshi/go-mp3 v0.3.4 // indirect
github.com/inconshreveable/mousetrap v1.1.0 // indirect
github.com/jackc/pgpassfile v1.0.0 // indirect
github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a // indirect
github.com/jackc/pgx/v5 v5.5.5 // indirect
@@ -92,6 +94,7 @@ require (
github.com/pelletier/go-toml/v2 v2.2.3 // indirect
github.com/pkg/errors v0.9.1 // indirect
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // indirect
github.com/spf13/pflag v1.0.6 // indirect
github.com/tidwall/match v1.1.1 // indirect
github.com/tidwall/pretty v1.2.0 // indirect
github.com/twitchyliquid64/golang-asm v0.15.1 // indirect

8
go.sum
View File

@@ -35,6 +35,7 @@ github.com/cloudwego/base64x v0.1.5/go.mod h1:0zlkT4Wn5C6NdauXdJRhSKRlJvmclQ1hhJ
github.com/cloudwego/iasm v0.2.0/go.mod h1:8rXZaNYT2n95jn+zTI1sDr+IgcD2GVs0nlbbQPiEFhY=
github.com/coder/websocket v1.8.12 h1:5bUXkEPPIbewrnkU8LTCLVaxi4N4J8ahufH2vlo4NAo=
github.com/coder/websocket v1.8.12/go.mod h1:LNVeNrXQZfe5qhS9ALED3uA+l5pPqvwXg3CKoDBB2gs=
github.com/cpuguy83/go-md2man/v2 v2.0.6/go.mod h1:oOW0eioCTA6cOiMLiUPZOpcVxMig6NIQQ7OS05n1F4g=
github.com/d4l3k/messagediff v1.2.2-0.20190829033028-7e0a312ae40b/go.mod h1:Oozbb1TVXFac9FtSIxHBMnBCq2qeH/2KkEQxENCrlLo=
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
@@ -129,6 +130,8 @@ github.com/hajimehoshi/oto v0.7.1/go.mod h1:wovJ8WWMfFKvP587mhHgot/MBr4DnNy9m6Ee
github.com/hajimehoshi/oto/v2 v2.3.1/go.mod h1:seWLbgHH7AyUMYKfKYT9pg7PhUu9/SisyJvNTT+ASQo=
github.com/icza/bitio v1.0.0/go.mod h1:0jGnlLAx8MKMr9VGnn/4YrvZiprkvBelsVIbA9Jjr9A=
github.com/icza/mighty v0.0.0-20180919140131-cfd07d671de6/go.mod h1:xQig96I1VNBDIWGCdTt54nHt6EeI639SmHycLYL7FkA=
github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8=
github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw=
github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM=
github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg=
github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a h1:bbPeKD0xmW/Y25WS6cokEszi5g+S0QxI/d45PkRi7Nk=
@@ -193,8 +196,13 @@ github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec h1:W09IVJc94
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo=
github.com/rogpeppe/go-internal v1.13.1 h1:KvO1DLK/DRN07sQ1LQKScxyZJuNnedQ5/wKSR38lUII=
github.com/rogpeppe/go-internal v1.13.1/go.mod h1:uMEvuHeurkdAXX61udpOXGD/AzZDWNMNyH2VO9fmH0o=
github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM=
github.com/sashabaranov/go-openai v1.32.2 h1:8z9PfYaLPbRzmJIYpwcWu6z3XU8F+RwVMF1QRSeSF2M=
github.com/sashabaranov/go-openai v1.32.2/go.mod h1:lj5b/K+zjTSFxVLijLSTDZuP7adOgerWeFyZLUhAKRg=
github.com/spf13/cobra v1.9.1 h1:CXSaggrXdbHK9CF+8ywj8Amf7PBRmPCOJugH954Nnlo=
github.com/spf13/cobra v1.9.1/go.mod h1:nDyEzZ8ogv936Cinf6g1RU9MRY64Ir93oCnqb9wxYW0=
github.com/spf13/pflag v1.0.6 h1:jFzHGLGAlb3ruxLB8MhbI6A8+AQX/2eW4qeyNZXNp2o=
github.com/spf13/pflag v1.0.6/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw=
github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo=

91
internal/cli/cli.go Normal file
View File

@@ -0,0 +1,91 @@
package cli
import (
"encoding/json"
"fmt"
"log"
"opencatd-open/internal/model"
"opencatd-open/pkg/store"
"os"
"strings"
"github.com/duke-git/lancet/v2/fileutil"
"github.com/google/uuid"
"github.com/spf13/cobra"
)
var LoadCmd = &cobra.Command{
Use: "load",
Short: "import user.json -> db",
Long: "\nimport user.json -> db",
Run: func(cmd *cobra.Command, args []string) {
db := store.GetDB()
var cont int64
if err := db.Model(model.User{}).Count(&cont).Error; err != nil {
fmt.Println(err)
return
}
if cont == 0 {
fmt.Println("创建管理员之后再操作")
}
if !fileutil.IsExist("./db/user.json") {
log.Fatalln("404! user.json is not found.")
return
}
file, err := os.Open("./db/user.json")
if err != nil {
fmt.Println("Error opening file:", err)
return
}
defer file.Close()
var usermap []map[string]string
if err := json.NewDecoder(file).Decode(&usermap); err != nil {
fmt.Println("解析文件失败:", err)
return
}
for _, um := range usermap {
var name string
if um["username"] != "" {
name = um["name"]
} else if um["name"] == "" {
name = um["username"]
} else {
fmt.Println("获取不到数据")
continue
}
var user = model.User{
Username: name,
Name: name,
Tokens: []model.Token{
{
Name: "default",
Key: "sk-team-" + strings.ReplaceAll(uuid.New().String(), "-", ""),
},
{
Name: name,
Key: um["token"],
},
},
}
if err := db.Create(&user).Error; err != nil {
fmt.Printf("\nCreate User %s Error:%s", user.Username, err)
}
}
},
}
var SaveCmd = &cobra.Command{
Use: "save",
Short: "backup user info -> user.json",
Run: func(cmd *cobra.Command, args []string) {
},
}
func init() {
// SaveCmd.Flags().StringP("user", "u", "", "Save User")
}

View File

@@ -2,6 +2,20 @@ package consts
import "gorm.io/gorm"
const Logo = `
____ _____
/ __ \ |_ _|
| | | |_ __ ___ _ __ | | ___ __ _ _ __ ___
| | | | '_ \ / _ \ '_ \ | | / _ \/ _' | '_ ' _ \
| |__| | |_) | __/ | | | | || __/ (_| | | | | | |
\____/| .__/ \___|_| |_| \_/ \___|\__,_|_| |_| |_|
| |
|_|
https://github.com/mirrors2/openteam
---------------------------------------------------
`
const SecretKey = "openteam"
const Day = 24 * 60 * 60 // day := 86400

193
router/setRouter.go Normal file
View File

@@ -0,0 +1,193 @@
package router
import (
"context"
"embed"
"fmt"
"io/fs"
"log"
"net/http"
"opencatd-open/middleware"
"opencatd-open/pkg/config"
"opencatd-open/wire"
"os"
"os/signal"
"sync"
"syscall"
"time"
"github.com/gin-gonic/gin"
"gorm.io/gorm"
)
func SetRouter(cfg *config.Config, db *gorm.DB, web *embed.FS) {
ctx, cancel := context.WithCancel(context.Background())
var wg sync.WaitGroup
if cfg == nil || db == nil {
panic("cfg or db is nil")
}
sqlDB, err := db.DB()
if err != nil {
log.Fatalf("Failed to get underlying *sql.DB: %v", err)
}
team, err := wire.InitTeamHandler(ctx, cfg, db)
if err != nil {
panic(err)
}
api, err := wire.InitAPIHandler(ctx, cfg, db)
if err != nil {
panic(err)
}
proxy, err := wire.InitProxyHandler(ctx, cfg, db, &wg)
if err != nil {
panic(err)
}
r := gin.Default()
r.Use(middleware.CORS())
teamGroup := r.Group("/1")
teamGroup.Use(team.AuthMiddleware())
{
teamGroup.POST("/users/init", team.InitAdmin)
// 获取当前用户信息
teamGroup.GET("/me", team.Me)
// team.GET("/me/usages", team.HandleMeUsage)
teamGroup.POST("/keys", team.CreateKey)
teamGroup.GET("/keys", team.ListKeys)
teamGroup.POST("/keys/:id", team.UpdateKey)
teamGroup.DELETE("/keys/:id", team.DeleteKey)
teamGroup.POST("/users", team.CreateUser)
teamGroup.GET("/users", team.ListUsers)
teamGroup.POST("/users/:id/reset", team.ResetUserToken)
teamGroup.DELETE("/users/:id", team.DeleteUser)
teamGroup.GET("/1/usages", team.ListUsages)
}
public := r.Group("/api/auth")
{
public.GET("/passkey/begin", api.PasskeyAuthBegin)
public.POST("/passkey/finish", api.PasskeyAuthFinish)
public.POST("/register", api.Register)
public.POST("/login", api.Login)
}
apiGroup := r.Group("/api", middleware.Auth)
{
apiGroup.GET("/profile", api.Profile)
apiGroup.POST("/profile/update", api.UpdateProfile)
apiGroup.POST("/profile/update/password", api.UpdatePassword)
// 绑定PassKey
apiGroup.GET("/profile/passkey", api.PasskeyCreateBegin)
apiGroup.POST("/profile/passkey", api.PasskeyCreateFinish)
apiGroup.GET("/profile/passkeys", api.ListPasskey)
apiGroup.DELETE("/profile/passkeys/:id", api.DeletePasskey)
userGroup := apiGroup.Group("/users")
{
userGroup.POST("", api.CreateUser)
userGroup.GET("", api.ListUser)
userGroup.GET("/:id", api.GetUser)
userGroup.PUT("/:id", api.EditUser)
userGroup.DELETE("/:id", api.DeleteUser)
userGroup.POST("/batch/:option", api.UserOption)
}
tokenGroup := apiGroup.Group("/tokens")
tokenGroup.POST("", api.CreateToken)
tokenGroup.GET("", api.ListToken)
// tokenGroup.GET("/:id", api.GetToken)
tokenGroup.POST("/reset/:id", api.ResetToken)
tokenGroup.PUT("/:id", api.UpdateToken)
tokenGroup.DELETE("/:id", api.DeleteToken)
// tokenGroup.POST("/batch/:option", api.TokenOption)
apiGroup.POST("keys", api.CreateApiKey)
apiGroup.GET("keys", api.ListApiKey)
apiGroup.GET("keys/:id", api.GetApiKey)
apiGroup.PUT("keys/:id", api.UpdateApiKey)
apiGroup.DELETE("keys/:id", api.DeleteApiKey)
apiGroup.POST("keys/batch/:option", api.ApiKeyOption)
}
v1 := r.Group("/v1")
v1.Use(middleware.AuthLLM(db))
{
// v1.POST("/v2/*proxypath", router.HandleProxy)
v1.POST("/*proxypath", proxy.HandleProxy)
// v1.GET("/models", dashboard.HandleModels)
}
idxFS, err := fs.Sub(web, "dist")
if err != nil {
panic(err)
}
assetsFS, err := fs.Sub(web, "dist/assets")
if err != nil {
panic(err)
}
r.StaticFS("/assets", http.FS(assetsFS))
r.NoRoute(func(c *gin.Context) {
if c.Writer.Status() == http.StatusNotFound {
c.FileFromFS("/", http.FS(idxFS))
}
})
srv := &http.Server{
Addr: fmt.Sprintf(":%d", cfg.Port),
Handler: r,
}
go func() {
fmt.Println("Starting server at port:", cfg.Port)
// 服务启动
if err := srv.ListenAndServe(); err != nil && err != http.ErrServerClosed {
log.Fatalf("listen: %s\n", err)
}
}()
// 等待中断信号来优雅地关闭服务器
quit := make(chan os.Signal, 1)
// kill (no param) default send syscall.SIGTERM
// kill -2 is syscall.SIGINT
// kill -9 is syscall.SIGKILL but can't be catch
signal.Notify(quit, syscall.SIGINT, syscall.SIGTERM)
<-quit
fmt.Println("\nShutdown Server ...")
shutdownCtx, shutdownCancel := context.WithTimeout(context.Background(), 5*time.Second)
defer shutdownCancel()
if err := srv.Shutdown(shutdownCtx); err != nil {
log.Fatalln("Server Shutdown:", err)
}
cancel()
sqlDB.Close()
waitChan := make(chan struct{})
go func() {
wg.Wait()
close(waitChan)
}()
select {
case <-waitChan:
fmt.Println("All goroutines have finished")
case <-shutdownCtx.Done():
fmt.Println("⚠️ Shutdown timeout")
}
fmt.Println("Server exited")
}