From f8e539c9b43a3d9729b2c735620a44f4259425f6 Mon Sep 17 00:00:00 2001 From: Sakurasan <26715255+Sakurasan@users.noreply.github.com> Date: Fri, 18 Apr 2025 19:05:23 +0800 Subject: [PATCH] add cli load --- cmd/openteam/main.go | 188 +++---------------------------------- go.mod | 3 + go.sum | 8 ++ internal/cli/cli.go | 91 ++++++++++++++++++ internal/consts/consts.go | 14 +++ router/setRouter.go | 193 ++++++++++++++++++++++++++++++++++++++ 6 files changed, 324 insertions(+), 173 deletions(-) create mode 100644 internal/cli/cli.go create mode 100644 router/setRouter.go diff --git a/cmd/openteam/main.go b/cmd/openteam/main.go index 2e28247..d38ae69 100644 --- a/cmd/openteam/main.go +++ b/cmd/openteam/main.go @@ -1,32 +1,23 @@ package main import ( - "context" "embed" "fmt" "io/fs" "log" - "net/http" - "opencatd-open/middleware" + "opencatd-open/internal/cli" + "opencatd-open/internal/consts" "opencatd-open/pkg/config" "opencatd-open/pkg/store" - "opencatd-open/wire" - "os" - "os/signal" - "sync" - "syscall" - "time" + "opencatd-open/router" - "github.com/gin-gonic/gin" + "github.com/spf13/cobra" ) //go:embed dist/* var web embed.FS func main() { - ctx, cancel := context.WithCancel(context.Background()) - var wg sync.WaitGroup - cfg, err := config.LoadConfig() if err != nil { panic(err) @@ -36,169 +27,20 @@ func main() { if err != nil { panic(err) } - sqlDB, err := db.DB() - if err != nil { - log.Fatalf("Failed to get underlying *sql.DB: %v", err) + + rootCmd := &cobra.Command{ + Use: "openteam", + Short: "openteam cli", + Long: consts.Logo, + Run: func(cmd *cobra.Command, args []string) { + router.SetRouter(cfg, db, &web) + }, } - team, err := wire.InitTeamHandler(ctx, cfg, db) - if err != nil { - panic(err) + rootCmd.AddCommand(cli.LoadCmd) + if err := rootCmd.Execute(); err != nil { + log.Fatal(err) } - - api, err := wire.InitAPIHandler(ctx, cfg, db) - if err != nil { - panic(err) - } - - proxy, err := wire.InitProxyHandler(ctx, cfg, db, &wg) - if err != nil { - panic(err) - } - - r := gin.Default() - r.Use(middleware.CORS()) - - teamGroup := r.Group("/1") - teamGroup.Use(team.AuthMiddleware()) - { - teamGroup.POST("/users/init", team.InitAdmin) - // 获取当前用户信息 - teamGroup.GET("/me", team.Me) - // team.GET("/me/usages", team.HandleMeUsage) - - teamGroup.POST("/keys", team.CreateKey) - teamGroup.GET("/keys", team.ListKeys) - teamGroup.POST("/keys/:id", team.UpdateKey) - teamGroup.DELETE("/keys/:id", team.DeleteKey) - - teamGroup.POST("/users", team.CreateUser) - teamGroup.GET("/users", team.ListUsers) - teamGroup.POST("/users/:id/reset", team.ResetUserToken) - teamGroup.DELETE("/users/:id", team.DeleteUser) - - teamGroup.GET("/1/usages", team.ListUsages) - } - - public := r.Group("/api/auth") - { - public.GET("/passkey/begin", api.PasskeyAuthBegin) - public.POST("/passkey/finish", api.PasskeyAuthFinish) - - public.POST("/register", api.Register) - public.POST("/login", api.Login) - } - - apiGroup := r.Group("/api", middleware.Auth) - { - apiGroup.GET("/profile", api.Profile) - apiGroup.POST("/profile/update", api.UpdateProfile) - apiGroup.POST("/profile/update/password", api.UpdatePassword) - // 绑定PassKey - apiGroup.GET("/profile/passkey", api.PasskeyCreateBegin) - apiGroup.POST("/profile/passkey", api.PasskeyCreateFinish) - apiGroup.GET("/profile/passkeys", api.ListPasskey) - apiGroup.DELETE("/profile/passkeys/:id", api.DeletePasskey) - - userGroup := apiGroup.Group("/users") - { - userGroup.POST("", api.CreateUser) - userGroup.GET("", api.ListUser) - userGroup.GET("/:id", api.GetUser) - userGroup.PUT("/:id", api.EditUser) - userGroup.DELETE("/:id", api.DeleteUser) - userGroup.POST("/batch/:option", api.UserOption) - } - - tokenGroup := apiGroup.Group("/tokens") - tokenGroup.POST("", api.CreateToken) - tokenGroup.GET("", api.ListToken) - // tokenGroup.GET("/:id", api.GetToken) - tokenGroup.POST("/reset/:id", api.ResetToken) - tokenGroup.PUT("/:id", api.UpdateToken) - tokenGroup.DELETE("/:id", api.DeleteToken) - // tokenGroup.POST("/batch/:option", api.TokenOption) - - apiGroup.POST("keys", api.CreateApiKey) - apiGroup.GET("keys", api.ListApiKey) - apiGroup.GET("keys/:id", api.GetApiKey) - apiGroup.PUT("keys/:id", api.UpdateApiKey) - apiGroup.DELETE("keys/:id", api.DeleteApiKey) - apiGroup.POST("keys/batch/:option", api.ApiKeyOption) - } - - v1 := r.Group("/v1") - v1.Use(middleware.AuthLLM(db)) - { - // v1.POST("/v2/*proxypath", router.HandleProxy) - v1.POST("/*proxypath", proxy.HandleProxy) - // v1.GET("/models", dashboard.HandleModels) - } - - idxFS, err := fs.Sub(web, "dist") - if err != nil { - panic(err) - } - - assetsFS, err := fs.Sub(web, "dist/assets") - if err != nil { - panic(err) - } - r.StaticFS("/assets", http.FS(assetsFS)) - - r.NoRoute(func(c *gin.Context) { - if c.Writer.Status() == http.StatusNotFound { - c.FileFromFS("/", http.FS(idxFS)) - } - }) - - srv := &http.Server{ - Addr: fmt.Sprintf(":%d", cfg.Port), - Handler: r, - } - - go func() { - fmt.Println("Starting server at port:", cfg.Port) - // 服务启动 - if err := srv.ListenAndServe(); err != nil && err != http.ErrServerClosed { - log.Fatalf("listen: %s\n", err) - } - }() - - // 等待中断信号来优雅地关闭服务器 - quit := make(chan os.Signal, 1) - // kill (no param) default send syscall.SIGTERM - // kill -2 is syscall.SIGINT - // kill -9 is syscall.SIGKILL but can't be catch - signal.Notify(quit, syscall.SIGINT, syscall.SIGTERM) - <-quit - fmt.Println("\nShutdown Server ...") - shutdownCtx, shutdownCancel := context.WithTimeout(context.Background(), 5*time.Second) - defer shutdownCancel() - - if err := srv.Shutdown(shutdownCtx); err != nil { - log.Fatalln("Server Shutdown:", err) - } - - cancel() - - sqlDB.Close() - - waitChan := make(chan struct{}) - go func() { - wg.Wait() - close(waitChan) - }() - - select { - case <-waitChan: - fmt.Println("All goroutines have finished") - case <-shutdownCtx.Done(): - fmt.Println("⚠️ Shutdown timeout") - } - - fmt.Println("Server exited") - } func printFilesAndDirs(fsys fs.FS, prefix string) error { diff --git a/go.mod b/go.mod index c8bab2d..8e51db1 100644 --- a/go.mod +++ b/go.mod @@ -27,6 +27,7 @@ require ( github.com/patrickmn/go-cache v2.1.0+incompatible github.com/pkoukk/tiktoken-go v0.1.7 github.com/sashabaranov/go-openai v1.32.2 + github.com/spf13/cobra v1.9.1 github.com/tidwall/gjson v1.18.0 golang.org/x/crypto v0.37.0 golang.org/x/exp v0.0.0-20241009180824-f66d83c29e7c @@ -75,6 +76,7 @@ require ( github.com/googleapis/enterprise-certificate-proxy v0.3.6 // indirect github.com/googleapis/gax-go/v2 v2.14.1 // indirect github.com/hajimehoshi/go-mp3 v0.3.4 // indirect + github.com/inconshreveable/mousetrap v1.1.0 // indirect github.com/jackc/pgpassfile v1.0.0 // indirect github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a // indirect github.com/jackc/pgx/v5 v5.5.5 // indirect @@ -92,6 +94,7 @@ require ( github.com/pelletier/go-toml/v2 v2.2.3 // indirect github.com/pkg/errors v0.9.1 // indirect github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // indirect + github.com/spf13/pflag v1.0.6 // indirect github.com/tidwall/match v1.1.1 // indirect github.com/tidwall/pretty v1.2.0 // indirect github.com/twitchyliquid64/golang-asm v0.15.1 // indirect diff --git a/go.sum b/go.sum index c97781b..14cbe9f 100644 --- a/go.sum +++ b/go.sum @@ -35,6 +35,7 @@ github.com/cloudwego/base64x v0.1.5/go.mod h1:0zlkT4Wn5C6NdauXdJRhSKRlJvmclQ1hhJ github.com/cloudwego/iasm v0.2.0/go.mod h1:8rXZaNYT2n95jn+zTI1sDr+IgcD2GVs0nlbbQPiEFhY= github.com/coder/websocket v1.8.12 h1:5bUXkEPPIbewrnkU8LTCLVaxi4N4J8ahufH2vlo4NAo= github.com/coder/websocket v1.8.12/go.mod h1:LNVeNrXQZfe5qhS9ALED3uA+l5pPqvwXg3CKoDBB2gs= +github.com/cpuguy83/go-md2man/v2 v2.0.6/go.mod h1:oOW0eioCTA6cOiMLiUPZOpcVxMig6NIQQ7OS05n1F4g= github.com/d4l3k/messagediff v1.2.2-0.20190829033028-7e0a312ae40b/go.mod h1:Oozbb1TVXFac9FtSIxHBMnBCq2qeH/2KkEQxENCrlLo= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= @@ -129,6 +130,8 @@ github.com/hajimehoshi/oto v0.7.1/go.mod h1:wovJ8WWMfFKvP587mhHgot/MBr4DnNy9m6Ee github.com/hajimehoshi/oto/v2 v2.3.1/go.mod h1:seWLbgHH7AyUMYKfKYT9pg7PhUu9/SisyJvNTT+ASQo= github.com/icza/bitio v1.0.0/go.mod h1:0jGnlLAx8MKMr9VGnn/4YrvZiprkvBelsVIbA9Jjr9A= github.com/icza/mighty v0.0.0-20180919140131-cfd07d671de6/go.mod h1:xQig96I1VNBDIWGCdTt54nHt6EeI639SmHycLYL7FkA= +github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8= +github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw= github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM= github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg= github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a h1:bbPeKD0xmW/Y25WS6cokEszi5g+S0QxI/d45PkRi7Nk= @@ -193,8 +196,13 @@ github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec h1:W09IVJc94 github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo= github.com/rogpeppe/go-internal v1.13.1 h1:KvO1DLK/DRN07sQ1LQKScxyZJuNnedQ5/wKSR38lUII= github.com/rogpeppe/go-internal v1.13.1/go.mod h1:uMEvuHeurkdAXX61udpOXGD/AzZDWNMNyH2VO9fmH0o= +github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM= github.com/sashabaranov/go-openai v1.32.2 h1:8z9PfYaLPbRzmJIYpwcWu6z3XU8F+RwVMF1QRSeSF2M= github.com/sashabaranov/go-openai v1.32.2/go.mod h1:lj5b/K+zjTSFxVLijLSTDZuP7adOgerWeFyZLUhAKRg= +github.com/spf13/cobra v1.9.1 h1:CXSaggrXdbHK9CF+8ywj8Amf7PBRmPCOJugH954Nnlo= +github.com/spf13/cobra v1.9.1/go.mod h1:nDyEzZ8ogv936Cinf6g1RU9MRY64Ir93oCnqb9wxYW0= +github.com/spf13/pflag v1.0.6 h1:jFzHGLGAlb3ruxLB8MhbI6A8+AQX/2eW4qeyNZXNp2o= +github.com/spf13/pflag v1.0.6/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= diff --git a/internal/cli/cli.go b/internal/cli/cli.go new file mode 100644 index 0000000..9fedc13 --- /dev/null +++ b/internal/cli/cli.go @@ -0,0 +1,91 @@ +package cli + +import ( + "encoding/json" + "fmt" + "log" + "opencatd-open/internal/model" + "opencatd-open/pkg/store" + "os" + "strings" + + "github.com/duke-git/lancet/v2/fileutil" + + "github.com/google/uuid" + "github.com/spf13/cobra" +) + +var LoadCmd = &cobra.Command{ + Use: "load", + Short: "import user.json -> db", + Long: "\nimport user.json -> db", + Run: func(cmd *cobra.Command, args []string) { + db := store.GetDB() + var cont int64 + if err := db.Model(model.User{}).Count(&cont).Error; err != nil { + fmt.Println(err) + return + } + if cont == 0 { + fmt.Println("创建管理员之后再操作") + } + if !fileutil.IsExist("./db/user.json") { + log.Fatalln("404! user.json is not found.") + return + } + file, err := os.Open("./db/user.json") + if err != nil { + fmt.Println("Error opening file:", err) + return + } + defer file.Close() + + var usermap []map[string]string + + if err := json.NewDecoder(file).Decode(&usermap); err != nil { + fmt.Println("解析文件失败:", err) + return + } + for _, um := range usermap { + var name string + if um["username"] != "" { + name = um["name"] + } else if um["name"] == "" { + name = um["username"] + } else { + fmt.Println("获取不到数据") + continue + } + var user = model.User{ + Username: name, + Name: name, + Tokens: []model.Token{ + { + Name: "default", + Key: "sk-team-" + strings.ReplaceAll(uuid.New().String(), "-", ""), + }, + { + Name: name, + Key: um["token"], + }, + }, + } + if err := db.Create(&user).Error; err != nil { + fmt.Printf("\nCreate User %s Error:%s", user.Username, err) + } + } + + }, +} + +var SaveCmd = &cobra.Command{ + Use: "save", + Short: "backup user info -> user.json", + Run: func(cmd *cobra.Command, args []string) { + + }, +} + +func init() { + // SaveCmd.Flags().StringP("user", "u", "", "Save User") +} diff --git a/internal/consts/consts.go b/internal/consts/consts.go index 45655e8..f3c1310 100644 --- a/internal/consts/consts.go +++ b/internal/consts/consts.go @@ -2,6 +2,20 @@ package consts import "gorm.io/gorm" +const Logo = ` + ____ _____ + / __ \ |_ _| +| | | |_ __ ___ _ __ | | ___ __ _ _ __ ___ +| | | | '_ \ / _ \ '_ \ | | / _ \/ _' | '_ ' _ \ +| |__| | |_) | __/ | | | | || __/ (_| | | | | | | + \____/| .__/ \___|_| |_| \_/ \___|\__,_|_| |_| |_| + | | + |_| + +https://github.com/mirrors2/openteam +--------------------------------------------------- +` + const SecretKey = "openteam" const Day = 24 * 60 * 60 // day := 86400 diff --git a/router/setRouter.go b/router/setRouter.go new file mode 100644 index 0000000..301971d --- /dev/null +++ b/router/setRouter.go @@ -0,0 +1,193 @@ +package router + +import ( + "context" + "embed" + "fmt" + "io/fs" + "log" + "net/http" + "opencatd-open/middleware" + "opencatd-open/pkg/config" + "opencatd-open/wire" + "os" + "os/signal" + "sync" + "syscall" + "time" + + "github.com/gin-gonic/gin" + "gorm.io/gorm" +) + +func SetRouter(cfg *config.Config, db *gorm.DB, web *embed.FS) { + ctx, cancel := context.WithCancel(context.Background()) + var wg sync.WaitGroup + + if cfg == nil || db == nil { + panic("cfg or db is nil") + } + + sqlDB, err := db.DB() + if err != nil { + log.Fatalf("Failed to get underlying *sql.DB: %v", err) + } + + team, err := wire.InitTeamHandler(ctx, cfg, db) + if err != nil { + panic(err) + } + + api, err := wire.InitAPIHandler(ctx, cfg, db) + if err != nil { + panic(err) + } + + proxy, err := wire.InitProxyHandler(ctx, cfg, db, &wg) + if err != nil { + panic(err) + } + + r := gin.Default() + r.Use(middleware.CORS()) + + teamGroup := r.Group("/1") + teamGroup.Use(team.AuthMiddleware()) + { + teamGroup.POST("/users/init", team.InitAdmin) + // 获取当前用户信息 + teamGroup.GET("/me", team.Me) + // team.GET("/me/usages", team.HandleMeUsage) + + teamGroup.POST("/keys", team.CreateKey) + teamGroup.GET("/keys", team.ListKeys) + teamGroup.POST("/keys/:id", team.UpdateKey) + teamGroup.DELETE("/keys/:id", team.DeleteKey) + + teamGroup.POST("/users", team.CreateUser) + teamGroup.GET("/users", team.ListUsers) + teamGroup.POST("/users/:id/reset", team.ResetUserToken) + teamGroup.DELETE("/users/:id", team.DeleteUser) + + teamGroup.GET("/1/usages", team.ListUsages) + } + + public := r.Group("/api/auth") + { + public.GET("/passkey/begin", api.PasskeyAuthBegin) + public.POST("/passkey/finish", api.PasskeyAuthFinish) + + public.POST("/register", api.Register) + public.POST("/login", api.Login) + } + + apiGroup := r.Group("/api", middleware.Auth) + { + apiGroup.GET("/profile", api.Profile) + apiGroup.POST("/profile/update", api.UpdateProfile) + apiGroup.POST("/profile/update/password", api.UpdatePassword) + // 绑定PassKey + apiGroup.GET("/profile/passkey", api.PasskeyCreateBegin) + apiGroup.POST("/profile/passkey", api.PasskeyCreateFinish) + apiGroup.GET("/profile/passkeys", api.ListPasskey) + apiGroup.DELETE("/profile/passkeys/:id", api.DeletePasskey) + + userGroup := apiGroup.Group("/users") + { + userGroup.POST("", api.CreateUser) + userGroup.GET("", api.ListUser) + userGroup.GET("/:id", api.GetUser) + userGroup.PUT("/:id", api.EditUser) + userGroup.DELETE("/:id", api.DeleteUser) + userGroup.POST("/batch/:option", api.UserOption) + } + + tokenGroup := apiGroup.Group("/tokens") + tokenGroup.POST("", api.CreateToken) + tokenGroup.GET("", api.ListToken) + // tokenGroup.GET("/:id", api.GetToken) + tokenGroup.POST("/reset/:id", api.ResetToken) + tokenGroup.PUT("/:id", api.UpdateToken) + tokenGroup.DELETE("/:id", api.DeleteToken) + // tokenGroup.POST("/batch/:option", api.TokenOption) + + apiGroup.POST("keys", api.CreateApiKey) + apiGroup.GET("keys", api.ListApiKey) + apiGroup.GET("keys/:id", api.GetApiKey) + apiGroup.PUT("keys/:id", api.UpdateApiKey) + apiGroup.DELETE("keys/:id", api.DeleteApiKey) + apiGroup.POST("keys/batch/:option", api.ApiKeyOption) + } + + v1 := r.Group("/v1") + v1.Use(middleware.AuthLLM(db)) + { + // v1.POST("/v2/*proxypath", router.HandleProxy) + v1.POST("/*proxypath", proxy.HandleProxy) + // v1.GET("/models", dashboard.HandleModels) + } + + idxFS, err := fs.Sub(web, "dist") + if err != nil { + panic(err) + } + + assetsFS, err := fs.Sub(web, "dist/assets") + if err != nil { + panic(err) + } + r.StaticFS("/assets", http.FS(assetsFS)) + + r.NoRoute(func(c *gin.Context) { + if c.Writer.Status() == http.StatusNotFound { + c.FileFromFS("/", http.FS(idxFS)) + } + }) + + srv := &http.Server{ + Addr: fmt.Sprintf(":%d", cfg.Port), + Handler: r, + } + + go func() { + fmt.Println("Starting server at port:", cfg.Port) + // 服务启动 + if err := srv.ListenAndServe(); err != nil && err != http.ErrServerClosed { + log.Fatalf("listen: %s\n", err) + } + }() + + // 等待中断信号来优雅地关闭服务器 + quit := make(chan os.Signal, 1) + // kill (no param) default send syscall.SIGTERM + // kill -2 is syscall.SIGINT + // kill -9 is syscall.SIGKILL but can't be catch + signal.Notify(quit, syscall.SIGINT, syscall.SIGTERM) + <-quit + fmt.Println("\nShutdown Server ...") + shutdownCtx, shutdownCancel := context.WithTimeout(context.Background(), 5*time.Second) + defer shutdownCancel() + + if err := srv.Shutdown(shutdownCtx); err != nil { + log.Fatalln("Server Shutdown:", err) + } + + cancel() + + sqlDB.Close() + + waitChan := make(chan struct{}) + go func() { + wg.Wait() + close(waitChan) + }() + + select { + case <-waitChan: + fmt.Println("All goroutines have finished") + case <-shutdownCtx.Done(): + fmt.Println("⚠️ Shutdown timeout") + } + + fmt.Println("Server exited") +}