From e7ffc9e8b9455b6144322ec329b9967e6dbeb78c Mon Sep 17 00:00:00 2001 From: Sakurasan <26715255+Sakurasan@users.noreply.github.com> Date: Wed, 16 Apr 2025 18:01:27 +0800 Subject: [PATCH] reface to openteam --- .gitignore | 3 +- cmd/openteam/main.go | 120 +++++- {docker => deploy/docker}/Dockerfile | 0 deploy/docker/docker-compose.mariadb.yml | 25 ++ deploy/docker/docker-compose.pg.yml | 27 ++ {docker => deploy/docker}/docker-compose.yml | 0 docker-compose.yml | 26 -- go.mod | 76 ++-- go.sum | 283 +++++++------- internal/auth/auth.go | 77 ++++ {team => internal}/consts/consts.go | 18 +- internal/controller/apikey.go | 152 ++++++++ internal/controller/init.go | 27 ++ internal/controller/proxy/chat_proxy.go | 60 +++ internal/controller/proxy/proxy.go | 345 ++++++++++++++++++ .../controller}/team/middleware.go | 6 +- .../controller}/team/team.go | 194 +++++----- internal/controller/user.go | 218 +++++++++++ internal/controller/user_token.go | 218 +++++++++++ internal/controller/webauth.go | 108 ++++++ {team => internal}/dao/apikey.go | 88 +++-- {team => internal}/dao/token.go | 86 +++-- {team => internal}/dao/usage.go | 28 +- {team => internal}/dao/user.go | 90 ++--- internal/dto/batch.go | 6 + .../err_resp.go => internal/dto/error.go | 12 +- internal/dto/key.go | 107 ++++++ internal/dto/passkey.go | 11 + internal/dto/response.go | 28 ++ {team => internal}/dto/team/team.go | 28 +- internal/dto/user.go | 16 + internal/model/apikey.go | 50 +++ internal/model/passkey.go | 38 ++ internal/model/token.go | 22 ++ {team => internal}/model/usage.go | 0 internal/model/user.go | 53 +++ internal/service/apikey.go | 92 +++++ .../service/team}/apikey.go | 32 +- .../service/team}/token.go | 46 +-- internal/service/team/usage.go | 62 ++++ .../service => internal/service/team}/user.go | 185 ++++------ internal/service/token.go | 251 +++++++++++++ internal/service/usage.go | 22 ++ internal/service/user.go | 320 ++++++++++++++++ internal/service/webauth.go | 304 +++++++++++++++ internal/utils/convert.go | 16 + internal/utils/map_tools.go | 139 +++++++ internal/utils/password.go | 15 + internal/utils/pointer.go | 6 + llm/aws/aws.go | 36 ++ {pkg => llm}/azureopenai/azureopenai.go | 1 - llm/claude/chat.go | 138 +++++++ {pkg => llm}/claude/claude.go | 0 .../chat.go => llm/claude/handle_proxy.go | 105 +----- llm/claude/v2/chat.go | 246 +++++++++++++ {pkg => llm}/google/chat.go | 2 +- llm/google/v2/chat.go | 228 ++++++++++++ llm/llm.go | 20 + llm/openai/chat.go | 178 +++++++++ {pkg => llm}/openai/dall-e.go | 0 .../chat.go => llm/openai/handle_proxy.go | 172 --------- {pkg => llm}/openai/realtime.go | 0 {pkg => llm}/openai/tts.go | 0 {pkg => llm}/openai/whisper.go | 0 llm/openai_compatible/chat.go | 221 +++++++++++ llm/types.go | 41 +++ {pkg => llm}/vertexai/auth.go | 0 makefile | 2 +- middleware/auth.go | 55 +++ middleware/auth_team.go | 103 ++++++ middleware/cors.go | 15 + middleware/ratelimit.go | 53 +++ opencat.go | 6 - pkg/config/config.go | 241 ++++++++++++ pkg/search/bing_test.go | 30 ++ pkg/store/db.go | 29 +- pkg/store/gcache.go | 67 ++++ pkg/team/key.go | 2 +- router/chat.go | 6 +- router/router.go | 37 +- store/cache.go | 2 +- store/db.go | 2 +- store/keydb.go | 4 +- team/dashboard/dashboard.go | 16 - team/dashboard/login.go | 20 - team/key.go | 2 +- team/model/apikey.go | 45 --- team/model/token.go | 20 - team/model/user.go | 49 --- team/service/usage.go | 137 ------- wire/wire.go | 92 ++++- wire/wire_gen.go | 59 ++- 92 files changed, 5345 insertions(+), 1273 deletions(-) rename {docker => deploy/docker}/Dockerfile (100%) create mode 100644 deploy/docker/docker-compose.mariadb.yml create mode 100644 deploy/docker/docker-compose.pg.yml rename {docker => deploy/docker}/docker-compose.yml (100%) delete mode 100644 docker-compose.yml create mode 100644 internal/auth/auth.go rename {team => internal}/consts/consts.go (70%) create mode 100644 internal/controller/apikey.go create mode 100644 internal/controller/init.go create mode 100644 internal/controller/proxy/chat_proxy.go create mode 100644 internal/controller/proxy/proxy.go rename {team/handler => internal/controller}/team/middleware.go (92%) rename {team/handler => internal/controller}/team/team.go (71%) create mode 100644 internal/controller/user.go create mode 100644 internal/controller/user_token.go create mode 100644 internal/controller/webauth.go rename {team => internal}/dao/apikey.go (61%) rename {team => internal}/dao/token.go (57%) rename {team => internal}/dao/usage.go (95%) rename {team => internal}/dao/user.go (56%) create mode 100644 internal/dto/batch.go rename team/dto/openai/err_resp.go => internal/dto/error.go (53%) create mode 100644 internal/dto/key.go create mode 100644 internal/dto/passkey.go create mode 100644 internal/dto/response.go rename {team => internal}/dto/team/team.go (68%) create mode 100644 internal/dto/user.go create mode 100644 internal/model/apikey.go create mode 100644 internal/model/passkey.go create mode 100644 internal/model/token.go rename {team => internal}/model/usage.go (100%) create mode 100644 internal/model/user.go create mode 100644 internal/service/apikey.go rename {team/service => internal/service/team}/apikey.go (80%) rename {team/service => internal/service/team}/token.go (53%) create mode 100644 internal/service/team/usage.go rename {team/service => internal/service/team}/user.go (82%) create mode 100644 internal/service/token.go create mode 100644 internal/service/usage.go create mode 100644 internal/service/user.go create mode 100644 internal/service/webauth.go create mode 100644 internal/utils/convert.go create mode 100644 internal/utils/map_tools.go create mode 100644 internal/utils/password.go create mode 100644 llm/aws/aws.go rename {pkg => llm}/azureopenai/azureopenai.go (99%) create mode 100644 llm/claude/chat.go rename {pkg => llm}/claude/claude.go (100%) rename pkg/claude/chat.go => llm/claude/handle_proxy.go (70%) create mode 100644 llm/claude/v2/chat.go rename {pkg => llm}/google/chat.go (99%) create mode 100644 llm/google/v2/chat.go create mode 100644 llm/llm.go create mode 100644 llm/openai/chat.go rename {pkg => llm}/openai/dall-e.go (100%) rename pkg/openai/chat.go => llm/openai/handle_proxy.go (53%) rename {pkg => llm}/openai/realtime.go (100%) rename {pkg => llm}/openai/tts.go (100%) rename {pkg => llm}/openai/whisper.go (100%) create mode 100644 llm/openai_compatible/chat.go create mode 100644 llm/types.go rename {pkg => llm}/vertexai/auth.go (100%) create mode 100644 middleware/auth.go create mode 100644 middleware/auth_team.go create mode 100644 middleware/cors.go create mode 100644 middleware/ratelimit.go create mode 100644 pkg/config/config.go create mode 100644 pkg/search/bing_test.go create mode 100644 pkg/store/gcache.go delete mode 100644 team/dashboard/dashboard.go delete mode 100644 team/dashboard/login.go delete mode 100644 team/model/apikey.go delete mode 100644 team/model/token.go delete mode 100644 team/model/user.go delete mode 100644 team/service/usage.go diff --git a/.gitignore b/.gitignore index 2312a81..17faf3d 100644 --- a/.gitignore +++ b/.gitignore @@ -1,5 +1,6 @@ bin/ test/ +demo/ *.log *.db -demo/ \ No newline at end of file +.env \ No newline at end of file diff --git a/cmd/openteam/main.go b/cmd/openteam/main.go index f11a165..b343de6 100644 --- a/cmd/openteam/main.go +++ b/cmd/openteam/main.go @@ -2,13 +2,16 @@ package main import ( "context" + "fmt" "log" "net/http" + "opencatd-open/middleware" + "opencatd-open/pkg/config" "opencatd-open/pkg/store" - "opencatd-open/team/dashboard" "opencatd-open/wire" "os" "os/signal" + "sync" "syscall" "time" @@ -16,26 +19,47 @@ import ( ) func main() { - ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second) + ctx, cancel := context.WithCancel(context.Background()) + var wg sync.WaitGroup - _, err := store.InitDB() + cfg, err := config.LoadConfig() if err != nil { panic(err) } - team, err := wire.InitTeamHandler(ctx, store.DB) + db, err := store.InitDB(cfg) + if err != nil { + panic(err) + } + sqlDB, err := db.DB() + if err != nil { + log.Fatalf("Failed to get underlying *sql.DB: %v", err) + } + + team, err := wire.InitTeamHandler(ctx, cfg, db) + if err != nil { + panic(err) + } + + api, err := wire.InitAPIHandler(ctx, cfg, db) + if err != nil { + panic(err) + } + + proxy, err := wire.InitProxyHandler(ctx, cfg, db, &wg) if err != nil { panic(err) } r := gin.Default() + r.Use(middleware.CORS()) teamGroup := r.Group("/1") teamGroup.Use(team.AuthMiddleware()) { teamGroup.POST("/users/init", team.InitAdmin) // 获取当前用户信息 teamGroup.GET("/me", team.Me) - //// team.GET("/me/usages", team.HandleMeUsage) + // team.GET("/me/usages", team.HandleMeUsage) teamGroup.POST("/keys", team.CreateKey) teamGroup.GET("/keys", team.ListKeys) @@ -50,9 +74,59 @@ func main() { teamGroup.GET("/1/usages", team.ListUsages) } - api := r.Group("/api") + public := r.Group("/api/auth") { - api.POST("/login", dashboard.HandleLogin) + public.GET("/passkey/begin", api.PasskeyAuthBegin) + public.POST("/passkey/finish", api.PasskeyAuthFinish) + + public.POST("/register", api.Register) + public.POST("/login", api.Login) + } + + apiGroup := r.Group("/api", middleware.Auth) + { + apiGroup.GET("/profile", api.Profile) + apiGroup.POST("/profile/update", api.UpdateProfile) + apiGroup.POST("/profile/update/password", api.UpdatePassword) + // 绑定PassKey + apiGroup.GET("/profile/passkey", api.PasskeyCreateBegin) + apiGroup.POST("/profile/passkey", api.PasskeyCreateFinish) + apiGroup.GET("/profile/passkeys", api.ListPasskey) + apiGroup.DELETE("/profile/passkeys/:id", api.DeletePasskey) + + userGroup := apiGroup.Group("/users") + { + userGroup.POST("", api.CreateUser) + userGroup.GET("", api.ListUser) + userGroup.GET("/:id", api.GetUser) + userGroup.PUT("/:id", api.EditUser) + userGroup.DELETE("/:id", api.DeleteUser) + userGroup.POST("/batch/:option", api.UserOption) + } + + tokenGroup := apiGroup.Group("/tokens") + tokenGroup.POST("", api.CreateToken) + tokenGroup.GET("", api.ListToken) + // tokenGroup.GET("/:id", api.GetToken) + tokenGroup.POST("/reset/:id", api.ResetToken) + tokenGroup.PUT("/:id", api.UpdateToken) + tokenGroup.DELETE("/:id", api.DeleteToken) + // tokenGroup.POST("/batch/:option", api.TokenOption) + + apiGroup.POST("keys", api.CreateApiKey) + apiGroup.GET("keys", api.ListApiKey) + apiGroup.GET("keys/:id", api.GetApiKey) + apiGroup.PUT("keys/:id", api.UpdateApiKey) + apiGroup.DELETE("keys/:id", api.DeleteApiKey) + apiGroup.POST("keys/batch/:option", api.ApiKeyOption) + } + + v1 := r.Group("/v1") + v1.Use(middleware.AuthLLM(store.DB)) + { + // v1.POST("/v2/*proxypath", router.HandleProxy) + v1.POST("/v1/*proxypath", proxy.HandleProxy) + // v1.GET("/models", dashboard.HandleModels) } srv := &http.Server{ @@ -74,19 +148,31 @@ func main() { // kill -9 is syscall.SIGKILL but can't be catch signal.Notify(quit, syscall.SIGINT, syscall.SIGTERM) <-quit - log.Println("Shutdown Server ...") + fmt.Println("\nShutdown Server ...") + shutdownCtx, shutdownCancel := context.WithTimeout(context.Background(), 5*time.Second) + defer shutdownCancel() - defer cancel() - if err := srv.Shutdown(ctx); err != nil { - log.Fatal("Server Shutdown:", err) + if err := srv.Shutdown(shutdownCtx); err != nil { + log.Fatalln("Server Shutdown:", err) } - db, _ := store.DB.DB() - db.Close() - // catching ctx.Done(). timeout of 1 seconds. + + cancel() + + sqlDB.Close() + + waitChan := make(chan struct{}) + go func() { + wg.Wait() + close(waitChan) + }() + select { - case <-ctx.Done(): - log.Println("timeout of 5 seconds.") + case <-waitChan: + fmt.Println("All goroutines have finished") + case <-shutdownCtx.Done(): + fmt.Println("⚠️ Shutdown timeout") } - log.Println("Server exiting") + + fmt.Println("Server exited") } diff --git a/docker/Dockerfile b/deploy/docker/Dockerfile similarity index 100% rename from docker/Dockerfile rename to deploy/docker/Dockerfile diff --git a/deploy/docker/docker-compose.mariadb.yml b/deploy/docker/docker-compose.mariadb.yml new file mode 100644 index 0000000..928b1ba --- /dev/null +++ b/deploy/docker/docker-compose.mariadb.yml @@ -0,0 +1,25 @@ +version: '3.9' + +services: + mariadb: + image: mariadb + container_name: mysql + ports: + - "3306:3306" + volumes: + - ${PWD}/mysqldb:/var/lib/mysql + command: + - --character-set-server=utf8mb4 + - --collation-server=utf8mb4_unicode_ci + - --skip-character-set-client-handshake + environment: + MYSQL_ROOT_PASSWORD: openteam + MYSQL_DATABASE: openteam + MYSQL_USER: openteam + MYSQL_PASSWORD: openteam + + # adminer: + # image: adminer + # restart: always + # ports: + # - 8080:8080 diff --git a/deploy/docker/docker-compose.pg.yml b/deploy/docker/docker-compose.pg.yml new file mode 100644 index 0000000..8f23d13 --- /dev/null +++ b/deploy/docker/docker-compose.pg.yml @@ -0,0 +1,27 @@ +# CREATE EXTENSION vector; +# SELECT * FROM pg_extension; +# SELECT * FROM pg_available_extensions; + +version: '3.9' + +services: + pg: + image: pgvector/pgvector:pg17 + # image: paradedb/paradedb + container_name: pg + restart: always + # network_mode: host + ports: + - 5432:5432 + environment: + POSTGRES_DB: openteam + POSTGRES_USER: openteam + POSTGRES_PASSWORD: openteam + volumes: + - $PWD/pgdata:/var/lib/postgresql/data + + # adminer: + # image: adminer + # restart: always + # ports: + # - 8080:8080 diff --git a/docker/docker-compose.yml b/deploy/docker/docker-compose.yml similarity index 100% rename from docker/docker-compose.yml rename to deploy/docker/docker-compose.yml diff --git a/docker-compose.yml b/docker-compose.yml deleted file mode 100644 index 36ade56..0000000 --- a/docker-compose.yml +++ /dev/null @@ -1,26 +0,0 @@ -# Email: admin@example.com -# Password: changeme -version: '3' - -services: - npm: - image: jc21/nginx-proxy-manager - network_mode: host - ports: - - '80:80' - - '81:81' - - '443:443' - volumes: - - $PWD/data:/data - - $PWD/www:/var/www - - $PWD/letsencrypt:/etc/letsencrypt - environment: - - "TZ=Asia/Shanghai" # set timezone, default UTC - - "PUID=1000" # set group id, default 0 (root) - - "PGID=1000" - - # certbot: - # image: certbot/certbot - # volumes: - # - $PWD/data/certbot/conf:/etc/letsencrypt - # - $PWD/data/certbot/www:/var/www/certbot diff --git a/go.mod b/go.mod index e59a1a9..d733033 100644 --- a/go.mod +++ b/go.mod @@ -5,24 +5,35 @@ go 1.23.2 require ( cloud.google.com/go/vertexai v0.13.1 github.com/Sakurasan/to v0.0.0-20180919163141-e72657dd7c7d + github.com/bluele/gcache v0.0.2 github.com/coder/websocket v1.8.12 github.com/duke-git/lancet/v2 v2.3.3 github.com/faiface/beep v1.1.0 github.com/gin-contrib/cors v1.7.2 github.com/gin-gonic/gin v1.10.0 github.com/glebarez/sqlite v1.11.0 + github.com/go-ozzo/ozzo-validation/v4 v4.3.0 + github.com/go-webauthn/webauthn v0.12.3 github.com/golang-jwt/jwt v3.2.2+incompatible + github.com/golang-jwt/jwt/v5 v5.2.2 github.com/google/generative-ai-go v0.18.0 github.com/google/uuid v1.6.0 + github.com/google/wire v0.6.0 github.com/gorilla/websocket v1.5.3 github.com/joho/godotenv v1.5.1 github.com/lib/pq v1.10.9 + github.com/liushuangls/go-anthropic/v2 v2.15.0 + github.com/mileusna/useragent v1.3.5 github.com/patrickmn/go-cache v2.1.0+incompatible github.com/pkoukk/tiktoken-go v0.1.7 github.com/sashabaranov/go-openai v1.32.2 github.com/tidwall/gjson v1.18.0 - golang.org/x/sync v0.8.0 - google.golang.org/api v0.201.0 + golang.org/x/crypto v0.37.0 + golang.org/x/exp v0.0.0-20241009180824-f66d83c29e7c + golang.org/x/sync v0.13.0 + golang.org/x/time v0.10.0 + google.golang.org/api v0.224.0 + google.golang.org/genai v1.0.0 gopkg.in/vansante/go-ffprobe.v2 v2.2.0 gorm.io/driver/mysql v1.5.7 gorm.io/driver/postgres v1.5.11 @@ -30,14 +41,14 @@ require ( ) require ( - cloud.google.com/go v0.116.0 // indirect + cloud.google.com/go v0.120.0 // indirect cloud.google.com/go/ai v0.8.2 // indirect - cloud.google.com/go/aiplatform v1.68.0 // indirect - cloud.google.com/go/auth v0.9.8 // indirect - cloud.google.com/go/auth/oauth2adapt v0.2.4 // indirect - cloud.google.com/go/compute/metadata v0.5.2 // indirect - cloud.google.com/go/iam v1.2.1 // indirect - cloud.google.com/go/longrunning v0.6.1 // indirect + cloud.google.com/go/aiplatform v1.74.0 // indirect + cloud.google.com/go/auth v0.15.0 // indirect + cloud.google.com/go/auth/oauth2adapt v0.2.7 // indirect + cloud.google.com/go/compute/metadata v0.6.0 // indirect + cloud.google.com/go/iam v1.4.0 // indirect + cloud.google.com/go/longrunning v0.6.4 // indirect filippo.io/edwards25519 v1.1.0 // indirect github.com/bytedance/sonic v1.12.3 // indirect github.com/bytedance/sonic/loader v0.2.1 // indirect @@ -46,6 +57,7 @@ require ( github.com/dlclark/regexp2 v1.11.4 // indirect github.com/dustin/go-humanize v1.0.1 // indirect github.com/felixge/httpsnoop v1.0.4 // indirect + github.com/fxamacker/cbor/v2 v2.8.0 // indirect github.com/gabriel-vasile/mimetype v1.4.6 // indirect github.com/gin-contrib/sse v0.1.0 // indirect github.com/glebarez/go-sqlite v1.22.0 // indirect @@ -55,12 +67,14 @@ require ( github.com/go-playground/universal-translator v0.18.1 // indirect github.com/go-playground/validator/v10 v10.22.1 // indirect github.com/go-sql-driver/mysql v1.8.1 // indirect + github.com/go-webauthn/x v0.1.20 // indirect github.com/goccy/go-json v0.10.3 // indirect - github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da // indirect + github.com/google/go-cmp v0.7.0 // indirect + github.com/google/go-tpm v0.9.3 // indirect github.com/google/pprof v0.0.0-20240827171923-fa2c70bbbfe5 // indirect - github.com/google/s2a-go v0.1.8 // indirect - github.com/googleapis/enterprise-certificate-proxy v0.3.4 // indirect - github.com/googleapis/gax-go/v2 v2.13.0 // indirect + github.com/google/s2a-go v0.1.9 // indirect + github.com/googleapis/enterprise-certificate-proxy v0.3.6 // indirect + github.com/googleapis/gax-go/v2 v2.14.1 // indirect github.com/hajimehoshi/go-mp3 v0.3.4 // indirect github.com/jackc/pgpassfile v1.0.0 // indirect github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a // indirect @@ -70,9 +84,9 @@ require ( github.com/jinzhu/now v1.1.5 // indirect github.com/json-iterator/go v1.1.12 // indirect github.com/klauspost/cpuid/v2 v2.2.8 // indirect - github.com/kr/text v0.2.0 // indirect github.com/leodido/go-urn v1.4.0 // indirect github.com/mattn/go-isatty v0.0.20 // indirect + github.com/mitchellh/mapstructure v1.5.0 // indirect github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect github.com/modern-go/reflect2 v1.0.2 // indirect github.com/ncruces/go-strftime v0.1.9 // indirect @@ -83,25 +97,23 @@ require ( github.com/tidwall/pretty v1.2.0 // indirect github.com/twitchyliquid64/golang-asm v0.15.1 // indirect github.com/ugorji/go/codec v1.2.12 // indirect - go.opencensus.io v0.24.0 // indirect - go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.56.0 // indirect - go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.56.0 // indirect - go.opentelemetry.io/otel v1.31.0 // indirect - go.opentelemetry.io/otel/metric v1.31.0 // indirect - go.opentelemetry.io/otel/trace v1.31.0 // indirect + github.com/x448/float16 v0.8.4 // indirect + go.opentelemetry.io/auto/sdk v1.1.0 // indirect + go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.59.0 // indirect + go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.60.0 // indirect + go.opentelemetry.io/otel v1.35.0 // indirect + go.opentelemetry.io/otel/metric v1.35.0 // indirect + go.opentelemetry.io/otel/trace v1.35.0 // indirect golang.org/x/arch v0.11.0 // indirect - golang.org/x/crypto v0.28.0 // indirect - golang.org/x/exp v0.0.0-20241009180824-f66d83c29e7c // indirect - golang.org/x/net v0.30.0 // indirect - golang.org/x/oauth2 v0.23.0 // indirect - golang.org/x/sys v0.26.0 // indirect - golang.org/x/text v0.19.0 // indirect - golang.org/x/time v0.7.0 // indirect - google.golang.org/genproto v0.0.0-20241007155032-5fefd90f89a9 // indirect - google.golang.org/genproto/googleapis/api v0.0.0-20241015192408-796eee8c2d53 // indirect - google.golang.org/genproto/googleapis/rpc v0.0.0-20241015192408-796eee8c2d53 // indirect - google.golang.org/grpc v1.67.1 // indirect - google.golang.org/protobuf v1.35.1 // indirect + golang.org/x/net v0.39.0 // indirect + golang.org/x/oauth2 v0.28.0 // indirect + golang.org/x/sys v0.32.0 // indirect + golang.org/x/text v0.24.0 // indirect + google.golang.org/genproto v0.0.0-20250303144028-a0af3efb3deb // indirect + google.golang.org/genproto/googleapis/api v0.0.0-20250303144028-a0af3efb3deb // indirect + google.golang.org/genproto/googleapis/rpc v0.0.0-20250409194420-de1ac958c67a // indirect + google.golang.org/grpc v1.71.1 // indirect + google.golang.org/protobuf v1.36.6 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect modernc.org/libc v1.61.0 // indirect modernc.org/mathutil v1.6.0 // indirect diff --git a/go.sum b/go.sum index f1bb6ab..b1985f0 100644 --- a/go.sum +++ b/go.sum @@ -1,43 +1,41 @@ -cloud.google.com/go v0.26.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw= -cloud.google.com/go v0.116.0 h1:B3fRrSDkLRt5qSHWe40ERJvhvnQwdZiHu0bJOpldweE= -cloud.google.com/go v0.116.0/go.mod h1:cEPSRWPzZEswwdr9BxE6ChEn01dWlTaF05LiC2Xs70U= +cloud.google.com/go v0.120.0 h1:wc6bgG9DHyKqF5/vQvX1CiZrtHnxJjBlKUyF9nP6meA= +cloud.google.com/go v0.120.0/go.mod h1:/beW32s8/pGRuj4IILWQNd4uuebeT4dkOhKmkfit64Q= cloud.google.com/go/ai v0.8.2 h1:LEaQwqBv+k2ybrcdTtCTc9OPZXoEdcQaGrfvDYS6Bnk= cloud.google.com/go/ai v0.8.2/go.mod h1:Wb3EUUGWwB6yHBaUf/+oxUq/6XbCaU1yh0GrwUS8lr4= -cloud.google.com/go/aiplatform v1.68.0 h1:EPPqgHDJpBZKRvv+OsB3cr0jYz3EL2pZ+802rBPcG8U= -cloud.google.com/go/aiplatform v1.68.0/go.mod h1:105MFA3svHjC3Oazl7yjXAmIR89LKhRAeNdnDKJczME= -cloud.google.com/go/auth v0.9.8 h1:+CSJ0Gw9iVeSENVCKJoLHhdUykDgXSc4Qn+gu2BRtR8= -cloud.google.com/go/auth v0.9.8/go.mod h1:xxA5AqpDrvS+Gkmo9RqrGGRh6WSNKKOXhY3zNOr38tI= -cloud.google.com/go/auth/oauth2adapt v0.2.4 h1:0GWE/FUsXhf6C+jAkWgYm7X9tK8cuEIfy19DBn6B6bY= -cloud.google.com/go/auth/oauth2adapt v0.2.4/go.mod h1:jC/jOpwFP6JBxhB3P5Rr0a9HLMC/Pe3eaL4NmdvqPtc= -cloud.google.com/go/compute/metadata v0.5.2 h1:UxK4uu/Tn+I3p2dYWTfiX4wva7aYlKixAHn3fyqngqo= -cloud.google.com/go/compute/metadata v0.5.2/go.mod h1:C66sj2AluDcIqakBq/M8lw8/ybHgOZqin2obFxa/E5k= -cloud.google.com/go/iam v1.2.1 h1:QFct02HRb7H12J/3utj0qf5tobFh9V4vR6h9eX5EBRU= -cloud.google.com/go/iam v1.2.1/go.mod h1:3VUIJDPpwT6p/amXRC5GY8fCCh70lxPygguVtI0Z4/g= -cloud.google.com/go/longrunning v0.6.1 h1:lOLTFxYpr8hcRtcwWir5ITh1PAKUD/sG2lKrTSYjyMc= -cloud.google.com/go/longrunning v0.6.1/go.mod h1:nHISoOZpBcmlwbJmiVk5oDRz0qG/ZxPynEGs1iZ79s0= +cloud.google.com/go/aiplatform v1.74.0 h1:rE2P5H7FOAFISAZilmdkapbk4CVgwfVs6FDWlhGfuy0= +cloud.google.com/go/aiplatform v1.74.0/go.mod h1:hVEw30CetNut5FrblYd1AJUWRVSIjoyIvp0EVUh51HA= +cloud.google.com/go/auth v0.15.0 h1:Ly0u4aA5vG/fsSsxu98qCQBemXtAtJf+95z9HK+cxps= +cloud.google.com/go/auth v0.15.0/go.mod h1:WJDGqZ1o9E9wKIL+IwStfyn/+s59zl4Bi+1KQNVXLZ8= +cloud.google.com/go/auth/oauth2adapt v0.2.7 h1:/Lc7xODdqcEw8IrZ9SvwnlLX6j9FHQM74z6cBk9Rw6M= +cloud.google.com/go/auth/oauth2adapt v0.2.7/go.mod h1:NTbTTzfvPl1Y3V1nPpOgl2w6d/FjO7NNUQaWSox6ZMc= +cloud.google.com/go/compute/metadata v0.6.0 h1:A6hENjEsCDtC1k8byVsgwvVcioamEHvZ4j01OwKxG9I= +cloud.google.com/go/compute/metadata v0.6.0/go.mod h1:FjyFAW1MW0C203CEOMDTu3Dk1FlqW3Rga40jzHL4hfg= +cloud.google.com/go/iam v1.4.0 h1:ZNfy/TYfn2uh/ukvhp783WhnbVluqf/tzOaqVUPlIPA= +cloud.google.com/go/iam v1.4.0/go.mod h1:gMBgqPaERlriaOV0CUl//XUzDhSfXevn4OEUbg6VRs4= +cloud.google.com/go/longrunning v0.6.4 h1:3tyw9rO3E2XVXzSApn1gyEEnH2K9SynNQjMlBi3uHLg= +cloud.google.com/go/longrunning v0.6.4/go.mod h1:ttZpLCe6e7EXvn9OxpBRx7kZEB0efv8yBO6YnVMfhJs= cloud.google.com/go/vertexai v0.13.1 h1:E6I+eA6vNQxz7/rb0wdILdKg4hFmMNWZLp+dSy9DnEo= cloud.google.com/go/vertexai v0.13.1/go.mod h1:25DzKFzP9JByYxcNjJefu/px2dRjcRpCDSdULYL2avI= filippo.io/edwards25519 v1.1.0 h1:FNf4tywRC1HmFuKW5xopWpigGjJKiJSV0Cqo0cJWDaA= filippo.io/edwards25519 v1.1.0/go.mod h1:BxyFTGdWcka3PhytdK4V28tE5sGfRvvvRV7EaN4VDT4= -github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU= github.com/DATA-DOG/go-sqlmock v1.3.3/go.mod h1:f/Ixk793poVmq4qj/V1dPUg2JEAKC73Q5eFN3EC/SaM= github.com/Sakurasan/to v0.0.0-20180919163141-e72657dd7c7d h1:3v1QFdgk450QH+7C+lw1k+olbjK4fKGsrEfnEG/HLkY= github.com/Sakurasan/to v0.0.0-20180919163141-e72657dd7c7d/go.mod h1:2sp0vsMyh5sqmKl5N+ps/cSspqLkoXUlesSzsufIGRU= +github.com/asaskevich/govalidator v0.0.0-20200108200545-475eaeb16496 h1:zV3ejI06GQ59hwDQAvmK1qxOQGB3WuVTRoY0okPTAv0= +github.com/asaskevich/govalidator v0.0.0-20200108200545-475eaeb16496/go.mod h1:oGkLhpf+kjZl6xBf758TQhh5XrAeiJv/7FRz/2spLIg= +github.com/bluele/gcache v0.0.2 h1:WcbfdXICg7G/DGBh1PFfcirkWOQV+v077yF1pSy3DGw= +github.com/bluele/gcache v0.0.2/go.mod h1:m15KV+ECjptwSPxKhOhQoAFQVtUFjTVkc3H8o0t/fp0= github.com/bytedance/sonic v1.12.3 h1:W2MGa7RCU1QTeYRTPE3+88mVC0yXmsRQRChiyVocVjU= github.com/bytedance/sonic v1.12.3/go.mod h1:B8Gt/XvtZ3Fqj+iSKMypzymZxw/FVwgIGKzMzT9r/rk= github.com/bytedance/sonic/loader v0.1.1/go.mod h1:ncP89zfokxS5LZrJxl5z0UJcsk4M4yY2JpfqGeCtNLU= github.com/bytedance/sonic/loader v0.2.1 h1:1GgorWTqf12TA8mma4DDSbaQigE2wOgQo7iCjjJv3+E= github.com/bytedance/sonic/loader v0.2.1/go.mod h1:ncP89zfokxS5LZrJxl5z0UJcsk4M4yY2JpfqGeCtNLU= -github.com/census-instrumentation/opencensus-proto v0.2.1/go.mod h1:f6KPmirojxKA12rnyqOA5BBL4O983OfeGPqjHWSTneU= -github.com/client9/misspell v0.3.4/go.mod h1:qj6jICC3Q7zFZvVWo7KLAzC3yx5G7kyvSDkc90ppPyw= github.com/cloudwego/base64x v0.1.4 h1:jwCgWpFanWmN8xoIUHa2rtzmkd5J2plF/dnLS6Xd/0Y= github.com/cloudwego/base64x v0.1.4/go.mod h1:0zlkT4Wn5C6NdauXdJRhSKRlJvmclQ1hhJgA0rcu/8w= github.com/cloudwego/iasm v0.2.0 h1:1KNIy1I1H9hNNFEEH3DVnI4UujN+1zjpuk6gwHLTssg= github.com/cloudwego/iasm v0.2.0/go.mod h1:8rXZaNYT2n95jn+zTI1sDr+IgcD2GVs0nlbbQPiEFhY= -github.com/cncf/udpa/go v0.0.0-20191209042840-269d4d468f6f/go.mod h1:M8M6+tZqaGXZJjfX53e64911xZQV5JYwmTeXPW+k8Sc= github.com/coder/websocket v1.8.12 h1:5bUXkEPPIbewrnkU8LTCLVaxi4N4J8ahufH2vlo4NAo= github.com/coder/websocket v1.8.12/go.mod h1:LNVeNrXQZfe5qhS9ALED3uA+l5pPqvwXg3CKoDBB2gs= -github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= github.com/d4l3k/messagediff v1.2.2-0.20190829033028-7e0a312ae40b/go.mod h1:Oozbb1TVXFac9FtSIxHBMnBCq2qeH/2KkEQxENCrlLo= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= @@ -48,14 +46,12 @@ github.com/duke-git/lancet/v2 v2.3.3 h1:OhqzNzkbJBS9ZlWLo/C7g+WSAOAAyNj7p9CAiEHu github.com/duke-git/lancet/v2 v2.3.3/go.mod h1:zGa2R4xswg6EG9I6WnyubDbFO/+A/RROxIbXcwryTsc= github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY= github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto= -github.com/envoyproxy/go-control-plane v0.9.0/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymFceY/DCBVvsKhRF0iEA4= -github.com/envoyproxy/go-control-plane v0.9.1-0.20191026205805-5f8ba28d4473/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymFceY/DCBVvsKhRF0iEA4= -github.com/envoyproxy/go-control-plane v0.9.4/go.mod h1:6rpuAdCZL397s3pYoYcLgu1mIlRU8Am5FuJP05cCM98= -github.com/envoyproxy/protoc-gen-validate v0.1.0/go.mod h1:iSmxcyjqTsJpI2R4NaDN7+kN2VEUnK/pcBlmesArF7c= github.com/faiface/beep v1.1.0 h1:A2gWP6xf5Rh7RG/p9/VAW2jRSDEGQm5sbOb38sf5d4c= github.com/faiface/beep v1.1.0/go.mod h1:6I8p6kK2q4opL/eWb+kAkk38ehnTunWeToJB+s51sT4= github.com/felixge/httpsnoop v1.0.4 h1:NFTV2Zj1bL4mc9sqWACXbQFVBBg2W3GPvqp8/ESS2Wg= github.com/felixge/httpsnoop v1.0.4/go.mod h1:m8KPJKqk1gH5J9DgRY2ASl2lWCfGKXixSwevea8zH2U= +github.com/fxamacker/cbor/v2 v2.8.0 h1:fFtUGXUzXPHTIUdne5+zzMPTfffl3RD5qYnkY40vtxU= +github.com/fxamacker/cbor/v2 v2.8.0/go.mod h1:vM4b+DJCtHn+zz7h3FFp/hDAI9WNWCsZj23V5ytsSxQ= github.com/gabriel-vasile/mimetype v1.4.6 h1:3+PzJTKLkvgjeTbts6msPJt4DixhT4YtFNf1gtGe3zc= github.com/gabriel-vasile/mimetype v1.4.6/go.mod h1:JX1qVKqZd40hUPpAfiNTe0Sne7hdfKSbOqqmkq8GCXc= github.com/gdamore/encoding v1.0.0/go.mod h1:alR0ol34c49FCSBLjhosxzcPHQbf2trDkoo5dl+VrEg= @@ -78,6 +74,8 @@ github.com/go-logr/logr v1.4.2 h1:6pFjapn8bFcIbiKo3XT4j/BhANplGihG6tvd+8rYgrY= github.com/go-logr/logr v1.4.2/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY= github.com/go-logr/stdr v1.2.2 h1:hSWxHoqTgW2S2qGc0LTAI563KZ5YKYRhT3MFKZMbjag= github.com/go-logr/stdr v1.2.2/go.mod h1:mMo/vtBO5dYbehREoey6XUKy/eSumjCCveDpRre4VKE= +github.com/go-ozzo/ozzo-validation/v4 v4.3.0 h1:byhDUpfEwjsVQb1vBunvIjh2BHQ9ead57VkAEY4V+Es= +github.com/go-ozzo/ozzo-validation/v4 v4.3.0/go.mod h1:2NKgrcHl3z6cJs+3Oo940FPRiTzuqKbvfrL2RxCj6Ew= github.com/go-playground/assert/v2 v2.2.0 h1:JvknZsQTYeFEAhQwI4qEt9cyV5ONwRHC+lYKSsYSR8s= github.com/go-playground/assert/v2 v2.2.0/go.mod h1:VDjEfimB/XKnb+ZQfWdccd7VUvScMdVu0Titje2rxJ4= github.com/go-playground/locales v0.14.1 h1:EWaQ/wswjilfKLTECiXz7Rh+3BjFhfDFKv/oXslEjJA= @@ -89,48 +87,39 @@ github.com/go-playground/validator/v10 v10.22.1/go.mod h1:dbuPbCMFw/DrkbEynArYaC github.com/go-sql-driver/mysql v1.7.0/go.mod h1:OXbVy3sEdcQ2Doequ6Z5BW6fXNQTmx+9S1MCJN5yJMI= github.com/go-sql-driver/mysql v1.8.1 h1:LedoTUt/eveggdHS9qUFC1EFSa8bU2+1pZjSRpvNJ1Y= github.com/go-sql-driver/mysql v1.8.1/go.mod h1:wEBSXgmK//2ZFJyE+qWnIsVGmvmEKlqwuVSjsCm7DZg= +github.com/go-webauthn/webauthn v0.12.3 h1:hHQl1xkUuabUU9uS+ISNCMLs9z50p9mDUZI/FmkayNE= +github.com/go-webauthn/webauthn v0.12.3/go.mod h1:4JRe8Z3W7HIw8NGEWn2fnUwecoDzkkeach/NnvhkqGY= +github.com/go-webauthn/x v0.1.20 h1:brEBDqfiPtNNCdS/peu8gARtq8fIPsHz0VzpPjGvgiw= +github.com/go-webauthn/x v0.1.20/go.mod h1:n/gAc8ssZJGATM0qThE+W+vfgXiMedsWi3wf/C4lld0= github.com/goccy/go-json v0.10.3 h1:KZ5WoDbxAIgm2HNbYckL0se1fHD6rz5j4ywS6ebzDqA= github.com/goccy/go-json v0.10.3/go.mod h1:oq7eo15ShAhp70Anwd5lgX2pLfOS3QCiwU/PULtXL6M= github.com/golang-jwt/jwt v3.2.2+incompatible h1:IfV12K8xAKAnZqdXVzCZ+TOjboZ2keLg81eXfW3O+oY= github.com/golang-jwt/jwt v3.2.2+incompatible/go.mod h1:8pz2t5EyA70fFQQSrl6XZXzqecmYZeUEB8OUGHkxJ+I= -github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfUlaOkMMuAQtPIUF8ecWP5IEl/CR7VP2Q= -github.com/golang/groupcache v0.0.0-20200121045136-8c9f03a8e57e/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc= -github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da h1:oI5xCqsCo564l8iNU+DwB5epxmsaqB+rhGL0m5jtYqE= -github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc= -github.com/golang/mock v1.1.1/go.mod h1:oTYuIxOrZwtPieC+H1uAHpcLFnEyAGVDL/k47Jfbm0A= -github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= -github.com/golang/protobuf v1.3.2/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= -github.com/golang/protobuf v1.4.0-rc.1/go.mod h1:ceaxUfeHdC40wWswd/P6IGgMaK3YpKi5j83Wpe3EHw8= -github.com/golang/protobuf v1.4.0-rc.1.0.20200221234624-67d41d38c208/go.mod h1:xKAWHe0F5eneWXFV3EuXVDTCmh+JuBKY0li0aMyXATA= -github.com/golang/protobuf v1.4.0-rc.2/go.mod h1:LlEzMj4AhA7rCAGe4KMBDvJI+AwstrUpVNzEA03Pprs= -github.com/golang/protobuf v1.4.0-rc.4.0.20200313231945-b860323f09d0/go.mod h1:WU3c8KckQ9AFe+yFwt9sWVRKCVIyN9cPHBJSNnbL67w= -github.com/golang/protobuf v1.4.0/go.mod h1:jodUvKwWbYaEsadDk5Fwe5c77LiNKVO9IDvqG2KuDX0= -github.com/golang/protobuf v1.4.1/go.mod h1:U8fpvMrcmy5pZrNK1lt4xCsGvpyWQ/VVv6QDs8UjoX8= -github.com/golang/protobuf v1.4.3/go.mod h1:oDoupMAO8OvCJWAcko0GGGIgR6R6ocIYbsSw735rRwI= +github.com/golang-jwt/jwt/v5 v5.2.2 h1:Rl4B7itRWVtYIHFrSNd7vhTiz9UpLdi6gZhZ3wEeDy8= +github.com/golang-jwt/jwt/v5 v5.2.2/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk= github.com/golang/protobuf v1.5.4 h1:i7eJL8qZTpSEXOPTxNKhASYpMn+8e5Q6AdndVa1dWek= github.com/golang/protobuf v1.5.4/go.mod h1:lnTiLA8Wa4RWRcIUkrtSVa5nRhsEGBg48fD6rSs7xps= github.com/google/generative-ai-go v0.18.0 h1:6ybg9vOCLcI/UpBBYXOTVgvKmcUKFRNj+2Cj3GnebSo= github.com/google/generative-ai-go v0.18.0/go.mod h1:JYolL13VG7j79kM5BtHz4qwONHkeJQzOCkKXnpqtS/E= github.com/google/go-cmp v0.2.0/go.mod h1:oXzfMopK8JAjlY9xF4vHSVASa0yLyX7SntLO5aqRK0M= -github.com/google/go-cmp v0.3.0/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU= -github.com/google/go-cmp v0.3.1/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU= -github.com/google/go-cmp v0.4.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= -github.com/google/go-cmp v0.5.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= -github.com/google/go-cmp v0.5.3/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= -github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI= -github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= +github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8= +github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU= +github.com/google/go-tpm v0.9.3 h1:+yx0/anQuGzi+ssRqeD6WpXjW2L/V0dItUayO0i9sRc= +github.com/google/go-tpm v0.9.3/go.mod h1:h9jEsEECg7gtLis0upRBQU+GhYVH6jMjrFxI8u6bVUY= github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= github.com/google/pprof v0.0.0-20240827171923-fa2c70bbbfe5 h1:5iH8iuqE5apketRbSFBy+X1V0o+l+8NF1avt4HWl7cA= github.com/google/pprof v0.0.0-20240827171923-fa2c70bbbfe5/go.mod h1:vavhavw2zAxS5dIdcRluK6cSGGPlZynqzFM8NdvU144= -github.com/google/s2a-go v0.1.8 h1:zZDs9gcbt9ZPLV0ndSyQk6Kacx2g/X+SKYovpnz3SMM= -github.com/google/s2a-go v0.1.8/go.mod h1:6iNWHTpQ+nfNRN5E00MSdfDwVesa8hhS32PhPO8deJA= -github.com/google/uuid v1.1.2/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/google/s2a-go v0.1.9 h1:LGD7gtMgezd8a/Xak7mEWL0PjoTQFvpRudN895yqKW0= +github.com/google/s2a-go v0.1.9/go.mod h1:YA0Ei2ZQL3acow2O62kdp9UlnvMmU7kA6Eutn0dXayM= +github.com/google/subcommands v1.2.0/go.mod h1:ZjhPrFU+Olkh9WazFPsl27BQ4UPiG37m3yTrtFlrHVk= github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= -github.com/googleapis/enterprise-certificate-proxy v0.3.4 h1:XYIDZApgAnrN1c855gTgghdIA6Stxb52D5RnLI1SLyw= -github.com/googleapis/enterprise-certificate-proxy v0.3.4/go.mod h1:YKe7cfqYXjKGpGvmSg28/fFvhNzinZQm8DGnaburhGA= -github.com/googleapis/gax-go/v2 v2.13.0 h1:yitjD5f7jQHhyDsnhKEBU52NdvvdSeGzlAnDPT0hH1s= -github.com/googleapis/gax-go/v2 v2.13.0/go.mod h1:Z/fvTZXF8/uw7Xu5GuslPw+bplx6SS338j1Is2S+B7A= +github.com/google/wire v0.6.0 h1:HBkoIh4BdSxoyo9PveV8giw7ZsaBOvzWKfcg/6MrVwI= +github.com/google/wire v0.6.0/go.mod h1:F4QhpQ9EDIdJ1Mbop/NZBRB+5yrR6qg3BnctaoUk6NA= +github.com/googleapis/enterprise-certificate-proxy v0.3.6 h1:GW/XbdyBFQ8Qe+YAmFU9uHLo7OnF5tL52HFAgMmyrf4= +github.com/googleapis/enterprise-certificate-proxy v0.3.6/go.mod h1:MkHOF77EYAE7qfSuSS9PU6g4Nt4e11cnsDUowfwewLA= +github.com/googleapis/gax-go/v2 v2.14.1 h1:hb0FFeiPaQskmvakKu5EbCbpntQn48jyHuvrkurSS/Q= +github.com/googleapis/gax-go/v2 v2.14.1/go.mod h1:Hb/NubMaVM88SrNkvl8X/o8XWwDJEPqouaLeN2IUxoA= github.com/gorilla/websocket v1.5.3 h1:saDtZ6Pbx/0u+bgYQ3q96pZgCzfhKXGPqt7kZ72aNNg= github.com/gorilla/websocket v1.5.3/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= github.com/hajimehoshi/go-mp3 v0.3.0/go.mod h1:qMJj/CSDxx6CGHiZeCgbiq2DSUkbK0UbtXShQcnfyMM= @@ -163,20 +152,26 @@ github.com/klauspost/cpuid/v2 v2.0.9/go.mod h1:FInQzS24/EEf25PyTYn52gqo7WaD8xa02 github.com/klauspost/cpuid/v2 v2.2.8 h1:+StwCXwm9PdpiEkPyzBXIy+M9KUb4ODm0Zarf1kS5BM= github.com/klauspost/cpuid/v2 v2.2.8/go.mod h1:Lcz8mBdAVJIBVzewtcLocK12l3Y+JytZYpaMropDUws= github.com/knz/go-libedit v1.10.1/go.mod h1:MZTVkCWyz0oBc7JOWP3wNAzd002ZbM/5hgShxwh4x8M= -github.com/kr/pretty v0.3.0 h1:WgNl7dwNpEZ6jJ9k1snq4pZsg7DOEN8hP9Xw0Tsjwk0= -github.com/kr/pretty v0.3.0/go.mod h1:640gp4NfQd8pI5XOwp5fnNeVWj67G7CFk/SaSQn7NBk= +github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= +github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk= github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= github.com/leodido/go-urn v1.4.0 h1:WT9HwE9SGECu3lg4d/dIA+jxlljEa1/ffXKmRjqdmIQ= github.com/leodido/go-urn v1.4.0/go.mod h1:bvxc+MVxLKB4z00jd1z+Dvzr47oO32F/QSNjSBOlFxI= github.com/lib/pq v1.10.9 h1:YXG7RB+JIjhP29X+OtkiDnYaXQwpS4JEWq7dtCCRUEw= github.com/lib/pq v1.10.9/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o= +github.com/liushuangls/go-anthropic/v2 v2.15.0 h1:zpplg7BRV/9FlMmeMPI0eDwhViB0l9SkNrF8ErYlRoQ= +github.com/liushuangls/go-anthropic/v2 v2.15.0/go.mod h1:kq2yW3JVy1/rph8u5KzX7F3q95CEpCT2RXp/2nfCmb4= github.com/lucasb-eyer/go-colorful v1.0.2/go.mod h1:0MS4r+7BZKSJ5mw4/S5MPN+qHFF1fYclkSPilDOKW0s= github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY= github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= github.com/mattn/go-runewidth v0.0.4/go.mod h1:LwmH8dsx7+W8Uxz3IHJYH5QSwggIsqBzpuz5H//U1FU= github.com/mewkiz/flac v1.0.7/go.mod h1:yU74UH277dBUpqxPouHSQIar3G1X/QIclVbFahSd1pU= github.com/mewkiz/pkg v0.0.0-20190919212034-518ade7978e2/go.mod h1:3E2FUC/qYUfM8+r9zAwpeHJzqRVVMIYnpzD/clwWxyA= +github.com/mileusna/useragent v1.3.5 h1:SJM5NzBmh/hO+4LGeATKpaEX9+b4vcGg2qXGLiNGDws= +github.com/mileusna/useragent v1.3.5/go.mod h1:3d8TOmwL/5I8pJjyVDteHtgDGcefrFUX4ccGOMKNYYc= +github.com/mitchellh/mapstructure v1.5.0 h1:jeMsZIYE/09sWLaz43PL7Gy6RuMjD2eJVyuac5Z2hdY= +github.com/mitchellh/mapstructure v1.5.0/go.mod h1:bFUtVrKA4DC2yAKiSyO/QUcy7e+RRV2QTWOzhPopBRo= github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd h1:TRLaZ9cD/w8PVh93nsPXa1VrQ6jlwL5oN8l14QlcNfg= github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= @@ -195,23 +190,23 @@ github.com/pkoukk/tiktoken-go v0.1.7 h1:qOBHXX4PHtvIvmOtyg1EeKlwFRiMKAcoMp4Q+bLQ github.com/pkoukk/tiktoken-go v0.1.7/go.mod h1:9NiV+i9mJKGj1rYOT+njbv+ZwA/zJxYdewGl6qVatpg= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= -github.com/prometheus/client_model v0.0.0-20190812154241-14fe0d1b01d4/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA= github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec h1:W09IVJc94icq4NjY3clb7Lk8O1qJ8BdBEF8z0ibU0rE= github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo= -github.com/rogpeppe/go-internal v1.8.0 h1:FCbCCtXNOY3UtUuHUYaghJg4y7Fd14rXifAYUAtL9R8= -github.com/rogpeppe/go-internal v1.8.0/go.mod h1:WmiCO8CzOY8rg0OYDC4/i/2WRWAB6poM+XZ2dLUbcbE= +github.com/rogpeppe/go-internal v1.13.1 h1:KvO1DLK/DRN07sQ1LQKScxyZJuNnedQ5/wKSR38lUII= +github.com/rogpeppe/go-internal v1.13.1/go.mod h1:uMEvuHeurkdAXX61udpOXGD/AzZDWNMNyH2VO9fmH0o= github.com/sashabaranov/go-openai v1.32.2 h1:8z9PfYaLPbRzmJIYpwcWu6z3XU8F+RwVMF1QRSeSF2M= github.com/sashabaranov/go-openai v1.32.2/go.mod h1:lj5b/K+zjTSFxVLijLSTDZuP7adOgerWeFyZLUhAKRg= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= +github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= -github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg= -github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= +github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA= +github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= github.com/tidwall/gjson v1.18.0 h1:FIDeeyB800efLX89e5a8Y0BNH+LOngJyGrIWxG2FKQY= github.com/tidwall/gjson v1.18.0/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk= github.com/tidwall/match v1.1.1 h1:+Ho715JplO36QYgwN9PGYNhgZvoUSc9X2c80KVTi+GA= @@ -222,114 +217,126 @@ github.com/twitchyliquid64/golang-asm v0.15.1 h1:SU5vSMR7hnwNxj24w34ZyCi/FmDZTkS github.com/twitchyliquid64/golang-asm v0.15.1/go.mod h1:a1lVb/DtPvCB8fslRZhAngC2+aY1QWCk3Cedj/Gdt08= github.com/ugorji/go/codec v1.2.12 h1:9LC83zGrHhuUA9l16C9AHXAqEV/2wBQ4nkvumAE65EE= github.com/ugorji/go/codec v1.2.12/go.mod h1:UNopzCgEMSXjBc6AOMqYvWC1ktqTAfzJZUZgYf6w6lg= -go.opencensus.io v0.24.0 h1:y73uSU6J157QMP2kn2r30vwW1A2W2WFwSCGnAVxeaD0= -go.opencensus.io v0.24.0/go.mod h1:vNK8G9p7aAivkbmorf4v+7Hgx+Zs0yY+0fOtgBfjQKo= -go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.56.0 h1:yMkBS9yViCc7U7yeLzJPM2XizlfdVvBRSmsQDWu6qc0= -go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.56.0/go.mod h1:n8MR6/liuGB5EmTETUBeU5ZgqMOlqKRxUaqPQBOANZ8= -go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.56.0 h1:UP6IpuHFkUgOQL9FFQFrZ+5LiwhhYRbi7VZSIx6Nj5s= -go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.56.0/go.mod h1:qxuZLtbq5QDtdeSHsS7bcf6EH6uO6jUAgk764zd3rhM= -go.opentelemetry.io/otel v1.31.0 h1:NsJcKPIW0D0H3NgzPDHmo0WW6SptzPdqg/L1zsIm2hY= -go.opentelemetry.io/otel v1.31.0/go.mod h1:O0C14Yl9FgkjqcCZAsE053C13OaddMYr/hz6clDkEJE= -go.opentelemetry.io/otel/metric v1.31.0 h1:FSErL0ATQAmYHUIzSezZibnyVlft1ybhy4ozRPcF2fE= -go.opentelemetry.io/otel/metric v1.31.0/go.mod h1:C3dEloVbLuYoX41KpmAhOqNriGbA+qqH6PQ5E5mUfnY= -go.opentelemetry.io/otel/trace v1.31.0 h1:ffjsj1aRouKewfr85U2aGagJ46+MvodynlQ1HYdmJys= -go.opentelemetry.io/otel/trace v1.31.0/go.mod h1:TXZkRk7SM2ZQLtR6eoAWQFIHPvzQ06FJAsO1tJg480A= +github.com/x448/float16 v0.8.4 h1:qLwI1I70+NjRFUR3zs1JPUCgaCXSh3SW62uAKT1mSBM= +github.com/x448/float16 v0.8.4/go.mod h1:14CWIYCyZA/cWjXOioeEpHeN/83MdbZDRQHoFcYsOfg= +github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY= +go.opentelemetry.io/auto/sdk v1.1.0 h1:cH53jehLUN6UFLY71z+NDOiNJqDdPRaXzTel0sJySYA= +go.opentelemetry.io/auto/sdk v1.1.0/go.mod h1:3wSPjt5PWp2RhlCcmmOial7AvC4DQqZb7a7wCow3W8A= +go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.59.0 h1:rgMkmiGfix9vFJDcDi1PK8WEQP4FLQwLDfhp5ZLpFeE= +go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.59.0/go.mod h1:ijPqXp5P6IRRByFVVg9DY8P5HkxkHE5ARIa+86aXPf4= +go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.60.0 h1:sbiXRNDSWJOTobXh5HyQKjq6wUC5tNybqjIqDpAY4CU= +go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.60.0/go.mod h1:69uWxva0WgAA/4bu2Yy70SLDBwZXuQ6PbBpbsa5iZrQ= +go.opentelemetry.io/otel v1.35.0 h1:xKWKPxrxB6OtMCbmMY021CqC45J+3Onta9MqjhnusiQ= +go.opentelemetry.io/otel v1.35.0/go.mod h1:UEqy8Zp11hpkUrL73gSlELM0DupHoiq72dR+Zqel/+Y= +go.opentelemetry.io/otel/metric v1.35.0 h1:0znxYu2SNyuMSQT4Y9WDWej0VpcsxkuklLa4/siN90M= +go.opentelemetry.io/otel/metric v1.35.0/go.mod h1:nKVFgxBZ2fReX6IlyW28MgZojkoAkJGaE8CpgeAU3oE= +go.opentelemetry.io/otel/sdk v1.35.0 h1:iPctf8iprVySXSKJffSS79eOjl9pvxV9ZqOWT0QejKY= +go.opentelemetry.io/otel/sdk v1.35.0/go.mod h1:+ga1bZliga3DxJ3CQGg3updiaAJoNECOgJREo9KHGQg= +go.opentelemetry.io/otel/sdk/metric v1.35.0 h1:1RriWBmCKgkeHEhM7a2uMjMUfP7MsOF5JpUCaEqEI9o= +go.opentelemetry.io/otel/sdk/metric v1.35.0/go.mod h1:is6XYCUMpcKi+ZsOvfluY5YstFnhW0BidkR+gL+qN+w= +go.opentelemetry.io/otel/trace v1.35.0 h1:dPpEfJu1sDIqruz7BHFG3c7528f6ddfSWfFDVt/xgMs= +go.opentelemetry.io/otel/trace v1.35.0/go.mod h1:WUk7DtFp1Aw2MkvqGdwiXYDZZNvA/1J8o6xRXLrIkyc= golang.org/x/arch v0.11.0 h1:KXV8WWKCXm6tRpLirl2szsO5j/oOODwZf4hATmGVNs4= golang.org/x/arch v0.11.0/go.mod h1:FEVrYAQjsQXMVJ1nsMoVVXPZg6p2JE2mx8psSWTDQys= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= -golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= -golang.org/x/crypto v0.28.0 h1:GBDwsMXVQi34v5CCYUm2jkJvu4cbtru2U4TN2PSyQnw= -golang.org/x/crypto v0.28.0/go.mod h1:rmgy+3RHxRZMyY0jjAJShp2zgEdOqj2AO7U0pYmeQ7U= -golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= +golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= +golang.org/x/crypto v0.13.0/go.mod h1:y6Z2r+Rw4iayiXXAIxJIDAJ1zMW4yaTpebo8fPOliYc= +golang.org/x/crypto v0.18.0/go.mod h1:R0j02AL6hcrfOiy9T4ZYp/rcWeMxM3L6QYxlOuEG1mg= +golang.org/x/crypto v0.37.0 h1:kJNSjF/Xp7kU0iB2Z+9viTPMW4EqqsrywMXLJOOsXSE= +golang.org/x/crypto v0.37.0/go.mod h1:vg+k43peMZ0pUMhYmVAWysMK35e6ioLh3wB8ZCAfbVc= golang.org/x/exp v0.0.0-20190306152737-a1d7652674e8/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= golang.org/x/exp v0.0.0-20241009180824-f66d83c29e7c h1:7dEasQXItcW1xKJ2+gg5VOiBnqWrJc+rq0DPKyvvdbY= golang.org/x/exp v0.0.0-20241009180824-f66d83c29e7c/go.mod h1:NQtJDoLvd6faHhE7m4T/1IY708gDefGGjR/iUW8yQQ8= golang.org/x/image v0.0.0-20190220214146-31aff87c08e9/go.mod h1:kZ7UVZpmo3dzQBMxlp+ypCbDeSB+sBbTgSJuh5dn5js= golang.org/x/image v0.0.0-20190227222117-0694c2d4d067/go.mod h1:kZ7UVZpmo3dzQBMxlp+ypCbDeSB+sBbTgSJuh5dn5js= -golang.org/x/lint v0.0.0-20181026193005-c67002cb31c3/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE= -golang.org/x/lint v0.0.0-20190227174305-5b3e6a55c961/go.mod h1:wehouNa3lNwaWXcvxsM5YxQ5yQlVC4a0KAMCusXpPoU= -golang.org/x/lint v0.0.0-20190313153728-d0100b6bd8b3/go.mod h1:6SW0HCj/g11FgYtHlgUYUwCkIfeOF89ocIRzGO/8vkc= golang.org/x/mobile v0.0.0-20190415191353-3e0bab5405d6/go.mod h1:E/iHnbuqvinMTCcRqshq8CkpyQDoeVncDDYHnLhea+o= +golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4= +golang.org/x/mod v0.8.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs= +golang.org/x/mod v0.12.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs= +golang.org/x/mod v0.14.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c= golang.org/x/mod v0.21.0 h1:vvrHzRwRfVKSiLrG+d4FMl/Qi4ukBCE6kZlTUkDYRT0= golang.org/x/mod v0.21.0/go.mod h1:6SkKJ3Xj0I0BrPOZoBy3bdMptDDU9oJrpohJ3eWZ1fY= -golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= -golang.org/x/net v0.0.0-20180826012351-8a410e7b638d/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20190213061140-3a22650c66bd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= -golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= -golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= -golang.org/x/net v0.0.0-20201110031124-69a78807bb2b/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU= -golang.org/x/net v0.30.0 h1:AcW1SDZMkb8IpzCdQUaIq2sP4sZ4zw+55h6ynffypl4= -golang.org/x/net v0.30.0/go.mod h1:2wGyMJ5iFasEhkwi13ChkO/t1ECNC4X4eBKkVFyYFlU= -golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= -golang.org/x/oauth2 v0.23.0 h1:PbgcYx2W7i4LvjJWEbf0ngHV6qJYr86PkAV3bXdLEbs= -golang.org/x/oauth2 v0.23.0/go.mod h1:XYTD2NtWslqkgxebSiOHnXEap4TF09sJSc7H1sXbhtI= -golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= +golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= +golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c= +golang.org/x/net v0.6.0/go.mod h1:2Tu9+aMcznHK/AK1HMvgo6xiTLG5rD5rZLDS+rp2Bjs= +golang.org/x/net v0.10.0/go.mod h1:0qNGK6F8kojg2nk9dLZ2mShWaEBan6FAoqfSigmmuDg= +golang.org/x/net v0.15.0/go.mod h1:idbUs1IY1+zTqbi8yxTbhexhEEk5ur9LInksu6HrEpk= +golang.org/x/net v0.20.0/go.mod h1:z8BVo6PvndSri0LbOE3hAn0apkU+1YvI6E70E9jsnvY= +golang.org/x/net v0.39.0 h1:ZCu7HMWDxpXpaiKdhzIfaltL9Lp31x/3fCP11bc6/fY= +golang.org/x/net v0.39.0/go.mod h1:X7NRbYVEA+ewNkCNyJ513WmMdQ3BineSwVtN2zD/d+E= +golang.org/x/oauth2 v0.28.0 h1:CrgCKl8PPAVtLnU3c+EDw6x11699EWlsDeWNWKdIOkc= +golang.org/x/oauth2 v0.28.0/go.mod h1:onh5ek6nERTohokkhCD/y2cV4Do3fxFHFuAejCkRWT8= golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sync v0.8.0 h1:3NFvSEYkUoMifnESzZl15y791HH1qU2xm6eCJU5ZPXQ= -golang.org/x/sync v0.8.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= -golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.3.0/go.mod h1:FU7BRWz2tNW+3quACPkgCx/L+uEAv1htQ0V83Z9Rj+Y= +golang.org/x/sync v0.6.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= +golang.org/x/sync v0.13.0 h1:AauUjRAJ9OSnvULf/ARrrVywoJDy0YS2AwQ98I37610= +golang.org/x/sync v0.13.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190312061237-fead79001313/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20190429190828-d89cdac9e872/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20190626150813-e07cf5db2756/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220712014510-0a85c31ab51e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.26.0 h1:KHjCJyddX0LoSTb3J+vWpupP9p0oznkqVk/IfjymZbo= -golang.org/x/sys v0.26.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.16.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/sys v0.32.0 h1:s77OFDvIQeibCmezSnk/q6iAfkdiQaJi4VzroCFrN20= +golang.org/x/sys v0.32.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k= +golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= +golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= +golang.org/x/term v0.5.0/go.mod h1:jMB1sMXY+tzblOD4FWmEbocvup2/aLOaQEp7JmGp78k= +golang.org/x/term v0.8.0/go.mod h1:xPskH00ivmX89bAKVGSKKtLOWNx2+17Eiy94tnKShWo= +golang.org/x/term v0.12.0/go.mod h1:owVbMEjm3cBLCHdkQu9b1opXd4ETQWc3BhuQGKgXgvU= +golang.org/x/term v0.16.0/go.mod h1:yn7UURbUtPyrVJPGPq404EukNFxcm/foM+bV/bfcDsY= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= -golang.org/x/text v0.19.0 h1:kTxAhCbGbxhK0IwgSKiMO5awPoDQ0RpfiVYBfK860YM= -golang.org/x/text v0.19.0/go.mod h1:BuEKDfySbSR4drPmRPG/7iBdf8hvFMuRexcpahXilzY= -golang.org/x/time v0.7.0 h1:ntUhktv3OPE6TgYxXWv9vKvUSJyIFJlyohwbkEwPrKQ= -golang.org/x/time v0.7.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM= +golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= +golang.org/x/text v0.7.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8= +golang.org/x/text v0.9.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8= +golang.org/x/text v0.13.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE= +golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= +golang.org/x/text v0.24.0 h1:dd5Bzh4yt5KYA8f9CJHCP4FB4D51c2c6JvN37xJJkJ0= +golang.org/x/text v0.24.0/go.mod h1:L8rBsPeo2pSS+xqN0d5u2ikmjtmoJbDBT1b7nHvFCdU= +golang.org/x/time v0.10.0 h1:3usCWA8tQn0L8+hFJQNgzpWbd89begxN66o1Ojdn5L4= +golang.org/x/time v0.10.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= -golang.org/x/tools v0.0.0-20190114222345-bf090417da8b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= -golang.org/x/tools v0.0.0-20190226205152-f727befe758c/go.mod h1:9Yl7xja0Znq3iFh3HoIrodX9oNMXvdceNzlUR8zjMvY= -golang.org/x/tools v0.0.0-20190311212946-11955173bddd/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs= -golang.org/x/tools v0.0.0-20190524140312-2c0ae7006135/go.mod h1:RgjU9mgBXZiqYHBnxXauZ1Gv1EHHAz9KjViQ78xBX0Q= +golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= +golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc= +golang.org/x/tools v0.6.0/go.mod h1:Xwgl3UAJ/d3gWutnCtw505GrjyAbvKui8lOU390QaIU= +golang.org/x/tools v0.13.0/go.mod h1:HvlwmtVNQAhOuCjW7xxvovg8wbNq7LwfXh/k7wXUl58= +golang.org/x/tools v0.17.0/go.mod h1:xsh6VxdV005rRVaS6SSAf9oiAqljS7UZUacMZ8Bnsps= golang.org/x/tools v0.26.0 h1:v/60pFQmzmT9ExmjDv2gGIfi3OqfKoEP6I5+umXlbnQ= golang.org/x/tools v0.26.0/go.mod h1:TPVVj70c7JJ3WCazhD8OdXcZg/og+b9+tH/KxylGwH0= -golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= -google.golang.org/api v0.201.0 h1:+7AD9JNM3tREtawRMu8sOjSbb8VYcYXJG/2eEOmfDu0= -google.golang.org/api v0.201.0/go.mod h1:HVY0FCHVs89xIW9fzf/pBvOEm+OolHa86G/txFezyq4= -google.golang.org/appengine v1.1.0/go.mod h1:EbEs0AVv82hx2wNQdGPgUI5lhzA/G0D9YwlJXL52JkM= -google.golang.org/appengine v1.4.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4= -google.golang.org/genproto v0.0.0-20180817151627-c66870c02cf8/go.mod h1:JiN7NxoALGmiZfu7CAH4rXhgtRTLTxftemlI0sWmxmc= -google.golang.org/genproto v0.0.0-20190819201941-24fa4b261c55/go.mod h1:DMBHOl98Agz4BDEuKkezgsaosCRResVns1a3J2ZsMNc= -google.golang.org/genproto v0.0.0-20200526211855-cb27e3aa2013/go.mod h1:NbSheEEYHJ7i3ixzK3sjbqSGDJWnxyFXZblF3eUsNvo= -google.golang.org/genproto v0.0.0-20241007155032-5fefd90f89a9 h1:nFS3IivktIU5Mk6KQa+v6RKkHUpdQpphqGNLxqNnbEk= -google.golang.org/genproto v0.0.0-20241007155032-5fefd90f89a9/go.mod h1:tEzYTYZxbmVNOu0OAFH9HzdJtLn6h4Aj89zzlBCdHms= -google.golang.org/genproto/googleapis/api v0.0.0-20241015192408-796eee8c2d53 h1:fVoAXEKA4+yufmbdVYv+SE73+cPZbbbe8paLsHfkK+U= -google.golang.org/genproto/googleapis/api v0.0.0-20241015192408-796eee8c2d53/go.mod h1:riSXTwQ4+nqmPGtobMFyW5FqVAmIs0St6VPp4Ug7CE4= -google.golang.org/genproto/googleapis/rpc v0.0.0-20241015192408-796eee8c2d53 h1:X58yt85/IXCx0Y3ZwN6sEIKZzQtDEYaBWrDvErdXrRE= -google.golang.org/genproto/googleapis/rpc v0.0.0-20241015192408-796eee8c2d53/go.mod h1:GX3210XPVPUjJbTUbvwI8f2IpZDMZuPJWDzDuebbviI= -google.golang.org/grpc v1.19.0/go.mod h1:mqu4LbDTu4XGKhr4mRzUsmM4RtVoemTSY81AxZiDr8c= -google.golang.org/grpc v1.23.0/go.mod h1:Y5yQAOtifL1yxbo5wqy6BxZv8vAUGQwXBOALyacEbxg= -google.golang.org/grpc v1.25.1/go.mod h1:c3i+UQWmh7LiEpx4sFZnkU36qjEYZ0imhYfXVyQciAY= -google.golang.org/grpc v1.27.0/go.mod h1:qbnxyOmOxrQa7FizSgH+ReBfzJrCY1pSN7KXBS8abTk= -google.golang.org/grpc v1.33.2/go.mod h1:JMHMWHQWaTccqQQlmk3MJZS+GWXOdAesneDmEnv2fbc= -google.golang.org/grpc v1.67.1 h1:zWnc1Vrcno+lHZCOofnIMvycFcc0QRGIzm9dhnDX68E= -google.golang.org/grpc v1.67.1/go.mod h1:1gLDyUQU7CTLJI90u3nXZ9ekeghjeM7pTDZlqFNg2AA= -google.golang.org/protobuf v0.0.0-20200109180630-ec00e32a8dfd/go.mod h1:DFci5gLYBciE7Vtevhsrf46CRTquxDuWsQurQQe4oz8= -google.golang.org/protobuf v0.0.0-20200221191635-4d8936d0db64/go.mod h1:kwYJMbMJ01Woi6D6+Kah6886xMZcty6N08ah7+eCXa0= -google.golang.org/protobuf v0.0.0-20200228230310-ab0ca4ff8a60/go.mod h1:cfTl7dwQJ+fmap5saPgwCLgHXTUD7jkjRqWcaiX5VyM= -google.golang.org/protobuf v1.20.1-0.20200309200217-e05f789c0967/go.mod h1:A+miEFZTKqfCUM6K7xSMQL9OKL/b6hQv+e19PK+JZNE= -google.golang.org/protobuf v1.21.0/go.mod h1:47Nbq4nVaFHyn7ilMalzfO3qCViNmqZ2kzikPIcrTAo= -google.golang.org/protobuf v1.22.0/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2l/sGQquU= -google.golang.org/protobuf v1.23.0/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2l/sGQquU= -google.golang.org/protobuf v1.23.1-0.20200526195155-81db48ad09cc/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2l/sGQquU= -google.golang.org/protobuf v1.25.0/go.mod h1:9JNX74DMeImyA3h4bdi1ymwjUzf21/xIlbajtzgsN7c= -google.golang.org/protobuf v1.35.1 h1:m3LfL6/Ca+fqnjnlqQXNpFPABW1UD7mjh8KO2mKFytA= -google.golang.org/protobuf v1.35.1/go.mod h1:9fA7Ob0pmnwhb644+1+CVWFRbNajQ6iRojtC/QF5bRE= +golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +google.golang.org/api v0.224.0 h1:Ir4UPtDsNiwIOHdExr3fAj4xZ42QjK7uQte3lORLJwU= +google.golang.org/api v0.224.0/go.mod h1:3V39my2xAGkodXy0vEqcEtkqgw2GtrFL5WuBZlCTCOQ= +google.golang.org/genai v1.0.0 h1:9IIZimT9bJm0wiF55VAoGCL8MfOAZcwqRRlxZZ/KSoc= +google.golang.org/genai v1.0.0/go.mod h1:TyfOKRz/QyCaj6f/ZDt505x+YreXnY40l2I6k8TvgqY= +google.golang.org/genproto v0.0.0-20250303144028-a0af3efb3deb h1:ITgPrl429bc6+2ZraNSzMDk3I95nmQln2fuPstKwFDE= +google.golang.org/genproto v0.0.0-20250303144028-a0af3efb3deb/go.mod h1:sAo5UzpjUwgFBCzupwhcLcxHVDK7vG5IqI30YnwX2eE= +google.golang.org/genproto/googleapis/api v0.0.0-20250303144028-a0af3efb3deb h1:p31xT4yrYrSM/G4Sn2+TNUkVhFCbG9y8itM2S6Th950= +google.golang.org/genproto/googleapis/api v0.0.0-20250303144028-a0af3efb3deb/go.mod h1:jbe3Bkdp+Dh2IrslsFCklNhweNTBgSYanP1UXhJDhKg= +google.golang.org/genproto/googleapis/rpc v0.0.0-20250409194420-de1ac958c67a h1:GIqLhp/cYUkuGuiT+vJk8vhOP86L4+SP5j8yXgeVpvI= +google.golang.org/genproto/googleapis/rpc v0.0.0-20250409194420-de1ac958c67a/go.mod h1:qQ0YXyHHx3XkvlzUtpXDkS29lDSafHMZBAZDc03LQ3A= +google.golang.org/grpc v1.71.1 h1:ffsFWr7ygTUscGPI0KKK6TLrGz0476KUvvsbqWK0rPI= +google.golang.org/grpc v1.71.1/go.mod h1:H0GRtasmQOh9LkFoCPDu3ZrwUtD1YGE+b2vYBYd/8Ec= +google.golang.org/protobuf v1.36.6 h1:z1NpPI8ku2WgiWnf+t9wTPsn6eP1L7ksHUlkfLvd9xY= +google.golang.org/protobuf v1.36.6/go.mod h1:jduwjTPXsFjZGTmRluh+L6NjiWu7pchiJ2/5YcXBHnY= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= gopkg.in/vansante/go-ffprobe.v2 v2.2.0 h1:iuOqTsbfYuqIz4tAU9NWh22CmBGxlGHdgj4iqP+NUmY= gopkg.in/vansante/go-ffprobe.v2 v2.2.0/go.mod h1:qF0AlAjk7Nqzqf3y333Ly+KxN3cKF2JqA3JT5ZheUGE= +gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= @@ -340,8 +347,6 @@ gorm.io/driver/postgres v1.5.11/go.mod h1:DX3GReXH+3FPWGrrgffdvCk3DQ1dwDPdmbenSk gorm.io/gorm v1.25.7/go.mod h1:hbnx/Oo0ChWMn1BIhpy1oYozzpM15i4YPuHDmfYtwg8= gorm.io/gorm v1.25.12 h1:I0u8i2hWQItBq1WfE0o2+WuL9+8L21K9e2HHSTE/0f8= gorm.io/gorm v1.25.12/go.mod h1:xh7N7RHfYlNc5EmcI/El95gXusucDrQnHXe0+CgWcLQ= -honnef.co/go/tools v0.0.0-20190102054323-c2f93a96b099/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= -honnef.co/go/tools v0.0.0-20190523083050-ea95bdfd59fc/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= modernc.org/cc/v4 v4.21.4 h1:3Be/Rdo1fpr8GrQ7IVw9OHtplU4gWbb+wNgeoBMmGLQ= modernc.org/cc/v4 v4.21.4/go.mod h1:HM7VJTZbUCR3rV8EYBi9wxnJ0ZBRiGE5OeGXNA0IsLQ= modernc.org/ccgo/v4 v4.21.0 h1:kKPI3dF7RIag8YcToh5ZwDcVMIv6VGa0ED5cvh0LMW4= diff --git a/internal/auth/auth.go b/internal/auth/auth.go new file mode 100644 index 0000000..1011c1e --- /dev/null +++ b/internal/auth/auth.go @@ -0,0 +1,77 @@ +package auth + +import ( + "errors" + "opencatd-open/internal/model" + "time" + + "github.com/golang-jwt/jwt/v5" +) + +type Claims struct { + UserID int64 `json:"user_id"` + Name string `json:"name"` + Type string `json:"type"` + jwt.RegisteredClaims +} + +type TokenPair struct { + AccessToken string `json:"access_token"` + RefreshToken string `json:"refresh_token,omitempty"` +} + +func GenerateTokenPair(user *model.User, secret string, accessExpire, refreshExpire time.Duration) (*TokenPair, error) { + // Generate access token + accessToken, err := generateToken(user, "access", secret, accessExpire) + if err != nil { + return nil, err + } + + // Generate refresh token + refreshToken, err := generateToken(user, "refresh", secret, refreshExpire) + if err != nil { + return nil, err + } + + return &TokenPair{ + AccessToken: accessToken, + RefreshToken: refreshToken, + }, nil +} + +func generateToken(user *model.User, tokenType, secret string, expire time.Duration) (string, error) { + now := time.Now() + + claims := Claims{ + UserID: user.ID, + Name: user.Username, + Type: tokenType, + RegisteredClaims: jwt.RegisteredClaims{ + ExpiresAt: jwt.NewNumericDate(now.Add(expire)), + IssuedAt: jwt.NewNumericDate(now), + NotBefore: jwt.NewNumericDate(now), + }, + } + + token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims) + return token.SignedString([]byte(secret)) +} + +func ValidateToken(tokenString, secret string) (*Claims, error) { + token, err := jwt.ParseWithClaims(tokenString, &Claims{}, func(token *jwt.Token) (interface{}, error) { + if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok { + return nil, errors.New("unexpected signing method") + } + return []byte(secret), nil + }) + + if err != nil { + return nil, err + } + + if claims, ok := token.Claims.(*Claims); ok && token.Valid { + return claims, nil + } + + return nil, jwt.ErrInvalidKey +} diff --git a/team/consts/consts.go b/internal/consts/consts.go similarity index 70% rename from team/consts/consts.go rename to internal/consts/consts.go index 27a0671..45655e8 100644 --- a/team/consts/consts.go +++ b/internal/consts/consts.go @@ -2,12 +2,16 @@ package consts import "gorm.io/gorm" +const SecretKey = "openteam" + +const Day = 24 * 60 * 60 // day := 86400 + type UserRole int const ( RoleUser UserRole = iota * 10 RoleAdmin - RoleSuperAdmin + RoleRoot ) const ( @@ -35,10 +39,10 @@ func OpenOrClose(status bool) int { return StatusDisabled } -type DBType int +// type DBType int -const ( - DBTypeMySQL DBType = iota - DBTypePostgreSQL - DBTypeSQLite -) +// const ( +// DBTypeMySQL DBType = iota +// DBTypePostgreSQL +// DBTypeSQLite +// ) diff --git a/internal/controller/apikey.go b/internal/controller/apikey.go new file mode 100644 index 0000000..1157abb --- /dev/null +++ b/internal/controller/apikey.go @@ -0,0 +1,152 @@ +package controller + +import ( + "net/http" + "opencatd-open/internal/consts" + "opencatd-open/internal/dto" + "opencatd-open/internal/model" + "strconv" + "strings" + + "github.com/duke-git/lancet/v2/slice" + "github.com/gin-gonic/gin" +) + +func (a Api) CreateApiKey(c *gin.Context) { + role := c.MustGet("user_role").(*consts.UserRole) + if *role < consts.RoleAdmin { + dto.Fail(c, 403, "Permission denied") + return + } + req := new(model.ApiKey) + err := c.ShouldBind(&req) + if err != nil { + dto.Fail(c, 400, err.Error()) + } + + err = a.keyService.CreateApiKey(c, req) + if err != nil { + dto.Fail(c, 400, err.Error()) + } else { + dto.Success(c, nil) + } + +} + +func (a Api) GetApiKey(c *gin.Context) { + role := c.MustGet("user_role").(*consts.UserRole) + if *role < consts.RoleAdmin { + dto.Fail(c, 403, "Permission denied") + return + } + id, err := strconv.ParseInt(c.Param("id"), 10, 64) + if err != nil { + dto.Fail(c, 400, err.Error()) + return + } + key, err := a.keyService.GetApiKey(c, id) + if err != nil { + dto.Fail(c, 400, err.Error()) + } else { + dto.Success(c, key) + } +} + +func (a Api) ListApiKey(c *gin.Context) { + role := c.MustGet("user_role").(*consts.UserRole) + if *role < consts.RoleAdmin { + dto.Fail(c, 403, "Permission denied") + return + } + limit, _ := strconv.Atoi(c.DefaultQuery("pageSize", "20")) + page, _ := strconv.Atoi(c.DefaultQuery("page", "1")) + offset := (page - 1) * limit + active := c.QueryArray("active[]") + if !slice.ContainSubSlice([]string{"true", "false"}, active) { + dto.Fail(c, http.StatusBadRequest, "active must be true or false") + return + } + + keys, total, err := a.keyService.ListApiKey(c, limit, offset, active) + if err != nil { + dto.Fail(c, 500, err.Error()) + } else { + dto.Success(c, gin.H{ + "total": total, + "keys": keys, + }) + } +} + +func (a Api) DeleteApiKey(c *gin.Context) { + role := c.MustGet("user_role").(*consts.UserRole) + if *role < consts.RoleAdmin { + dto.Fail(c, 403, "Permission denied") + return + } + var batchid dto.BatchIDRequest + err := c.ShouldBind(&batchid) + if err != nil { + dto.Fail(c, 400, err.Error()) + return + } + + err = a.keyService.DeleteApiKey(c, batchid.IDs) + if err != nil { + dto.Fail(c, 500, err.Error()) + } else { + dto.Success(c, nil) + } +} + +func (a Api) UpdateApiKey(c *gin.Context) { + role := c.MustGet("user_role").(*consts.UserRole) + if *role < consts.RoleAdmin { + dto.Fail(c, 403, "Permission denied") + return + } + var req model.ApiKey + err := c.ShouldBind(&req) + if err != nil { + dto.Fail(c, 400, err.Error()) + return + } + + err = a.keyService.UpdateApiKey(c, &req) + if err != nil { + dto.Fail(c, 500, err.Error()) + } else { + dto.Success(c, nil) + } +} + +func (a Api) ApiKeyOption(c *gin.Context) { + role := c.MustGet("user_role").(*consts.UserRole) + if *role < consts.RoleAdmin { + dto.Fail(c, 403, "Permission denied") + return + } + option := strings.ToLower(c.Param("option")) + var batchid dto.BatchIDRequest + err := c.ShouldBind(&batchid) + if err != nil { + dto.Fail(c, 400, err.Error()) + return + } + switch option { + case "enable": + err = a.keyService.EnableApiKey(c, batchid.IDs) + case "disable": + err = a.keyService.DisableApiKey(c, batchid.IDs) + case "delete": + err = a.keyService.DeleteApiKey(c, batchid.IDs) + default: + dto.Fail(c, 400, "invalid option, only support enable, disable, delete") + return + } + if err != nil { + dto.Fail(c, 400, err.Error()) + return + } + dto.Success(c, nil) +} diff --git a/internal/controller/init.go b/internal/controller/init.go new file mode 100644 index 0000000..37ac0cd --- /dev/null +++ b/internal/controller/init.go @@ -0,0 +1,27 @@ +package controller + +import ( + "opencatd-open/internal/service" + + "gorm.io/gorm" +) + +type Api struct { + db *gorm.DB + userService *service.UserServiceImpl + tokenService *service.TokenServiceImpl + keyService *service.ApiKeyServiceImpl + webAuthService *service.WebAuthnService + usageService *service.UsageService +} + +func NewApi(db *gorm.DB, userService *service.UserServiceImpl, tokenService *service.TokenServiceImpl, keyService *service.ApiKeyServiceImpl, webAuthService *service.WebAuthnService, usageService *service.UsageService) *Api { + return &Api{ + db: db, + userService: userService, + tokenService: tokenService, + keyService: keyService, + webAuthService: webAuthService, + usageService: usageService, + } +} diff --git a/internal/controller/proxy/chat_proxy.go b/internal/controller/proxy/chat_proxy.go new file mode 100644 index 0000000..98d0c65 --- /dev/null +++ b/internal/controller/proxy/chat_proxy.go @@ -0,0 +1,60 @@ +package controller + +import ( + "fmt" + "net/http" + "opencatd-open/internal/dto" + "opencatd-open/llm" + "opencatd-open/llm/claude/v2" + "opencatd-open/llm/google/v2" + "opencatd-open/llm/openai_compatible" + + "github.com/gin-gonic/gin" +) + +func (h *Proxy) ChatHandler(c *gin.Context) { + var chatreq llm.ChatRequest + if err := c.ShouldBindJSON(&chatreq); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + return + } + + err := h.SelectApiKey(chatreq.Model) + if err != nil { + dto.WrapErrorAsOpenAI(c, 500, err.Error()) + return + } + + var llm llm.LLM + switch *h.apikey.ApiType { + case "claude": + llm, err = claude.NewClaude(h.apikey) + case "gemini": + llm, err = google.NewGemini(c, h.apikey) + case "openai", "azure", "github": + fallthrough + default: + llm, err = openai_compatible.NewOpenAICompatible(h.apikey) + if err != nil { + dto.WrapErrorAsOpenAI(c, 500, fmt.Errorf("create llm client error: %w", err).Error()) + return + } + } + + if !chatreq.Stream { + resp, err := llm.Chat(c, chatreq) + if err != nil { + dto.WrapErrorAsOpenAI(c, 500, err.Error()) + } + c.JSON(http.StatusOK, resp) + + } else { + datachan, err := llm.StreamChat(c, chatreq) + if err != nil { + dto.WrapErrorAsOpenAI(c, 500, err.Error()) + } + for data := range datachan { + c.SSEvent("", data) + } + } +} diff --git a/internal/controller/proxy/proxy.go b/internal/controller/proxy/proxy.go new file mode 100644 index 0000000..3941f36 --- /dev/null +++ b/internal/controller/proxy/proxy.go @@ -0,0 +1,345 @@ +package controller + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "io" + "log" + "math/rand" + "net/http" + "net/url" + "opencatd-open/internal/dao" + "opencatd-open/internal/model" + "opencatd-open/pkg/config" + "os" + "strings" + "sync" + "time" + + "github.com/gin-gonic/gin" + "github.com/lib/pq" + "github.com/tidwall/gjson" + "gorm.io/gorm" + "gorm.io/gorm/clause" +) + +type Proxy struct { + ctx context.Context + cfg *config.Config + db *gorm.DB + wg *sync.WaitGroup + usageChan chan *model.Usage // 用于异步处理的channel + apikey *model.ApiKey + httpClient *http.Client + + userDAO *dao.UserDAO + apiKeyDao *dao.ApiKeyDAO + tokenDAO *dao.TokenDAO + usageDAO *dao.UsageDAO + dailyUsageDAO *dao.DailyUsageDAO +} + +func NewProxy(ctx context.Context, cfg *config.Config, db *gorm.DB, wg *sync.WaitGroup, userDAO *dao.UserDAO, apiKeyDAO *dao.ApiKeyDAO, tokenDAO *dao.TokenDAO, usageDAO *dao.UsageDAO, dailyUsageDAO *dao.DailyUsageDAO) *Proxy { + client := http.DefaultClient + if os.Getenv("LOCAL_PROXY") != "" { + proxyUrl, err := url.Parse(os.Getenv("LOCAL_PROXY")) + if err == nil { + tr := &http.Transport{ + Proxy: http.ProxyURL(proxyUrl), + } + client.Transport = tr + } + } + np := &Proxy{ + ctx: ctx, + cfg: cfg, + db: db, + wg: wg, + httpClient: client, + usageChan: make(chan *model.Usage, cfg.UsageChanSize), + userDAO: userDAO, + apiKeyDao: apiKeyDAO, + tokenDAO: tokenDAO, + usageDAO: usageDAO, + dailyUsageDAO: dailyUsageDAO, + } + + go np.ProcessUsage() + go np.ScheduleTask() + + return np +} + +func (p *Proxy) HandleProxy(c *gin.Context) { + if c.Request.URL.Path == "/v1/chat/completions" { + p.ChatHandler(c) + return + } +} + +func (p *Proxy) SendUsage(usage *model.Usage) { + select { + case p.usageChan <- usage: + default: + log.Println("usage channel is full, skip processing") + bj, _ := json.Marshal(usage) + log.Println(string(bj)) + //TODO: send to a queue + } +} + +func (p *Proxy) ProcessUsage() { + for i := 0; i < p.cfg.UsageWorker; i++ { + p.wg.Add(1) + go func(i int) { + defer p.wg.Done() + for { + select { + case usage, ok := <-p.usageChan: + if !ok { + // channel 关闭,退出程序 + return + } + err := p.Do(usage) + if err != nil { + log.Printf("process usage error: %v\n", err) + } + case <-p.ctx.Done(): + // close(s.usageChan) + // for usage := range s.usageChan { + // if err := s.Do(usage); err != nil { + // fmt.Printf("[close event]process usage error: %v\n", err) + // } + // } + for { + select { + case usage, ok := <-p.usageChan: + if !ok { + return + } + if err := p.Do(usage); err != nil { + fmt.Printf("[close event]process usage error: %v\n", err) + } + default: + fmt.Printf("usageChan is empty,usage worker %d done\n", i) + return + } + } + } + } + + }(i) + + } +} + +func (p *Proxy) Do(usage *model.Usage) error { + err := p.db.Transaction(func(tx *gorm.DB) error { + // 1. 记录使用记录 + if err := tx.WithContext(p.ctx).Create(usage).Error; err != nil { + return fmt.Errorf("create usage error: %w", err) + } + + // 2. 更新每日统计(upsert 操作) + dailyUsage := model.DailyUsage{ + UserID: usage.UserID, + TokenID: usage.TokenID, + Capability: usage.Capability, + Date: time.Date(usage.Date.Year(), usage.Date.Month(), usage.Date.Day(), 0, 0, 0, 0, usage.Date.Location()), + Model: usage.Model, + Stream: usage.Stream, + PromptTokens: usage.PromptTokens, + CompletionTokens: usage.CompletionTokens, + TotalTokens: usage.TotalTokens, + Cost: usage.Cost, + } + + // 使用 OnConflict 实现 upsert + if err := tx.WithContext(p.ctx).Clauses(clause.OnConflict{ + Columns: []clause.Column{{Name: "user_id"}, {Name: "token_id"}, {Name: "capability"}, {Name: "date"}}, // 唯一键 + DoUpdates: clause.Assignments(map[string]interface{}{ + "prompt_tokens": gorm.Expr("prompt_tokens + ?", usage.PromptTokens), + "completion_tokens": gorm.Expr("completion_tokens + ?", usage.CompletionTokens), + "total_tokens": gorm.Expr("total_tokens + ?", usage.TotalTokens), + "cost": gorm.Expr("cost + ?", usage.Cost), + }), + }).Create(&dailyUsage).Error; err != nil { + return fmt.Errorf("upsert daily usage error: %w", err) + } + + // 3. 更新用户额度 + if err := tx.WithContext(p.ctx).Model(&model.User{}).Where("id = ?", usage.UserID).Updates(map[string]interface{}{ + "quota": gorm.Expr("quota - ?", usage.Cost), + "used_quota": gorm.Expr("used_quota + ?", usage.Cost), + }).Error; err != nil { + return fmt.Errorf("update user quota and used_quota error: %w", err) + } + + return nil + }) + return err +} + +func (p *Proxy) SelectApiKey(model string) error { + akpikeys, err := p.apiKeyDao.FindApiKeysBySupportModel(p.db, model) + if err != nil { + return err + } + if len(akpikeys) == 0 { + return errors.New("no available apikey") + } else { + if strings.HasPrefix(model, "gpt") { + keys, err := p.apiKeyDao.FindKeys(map[string]any{"type = ?": "openai"}) + if err != nil { + return err + } + akpikeys = append(akpikeys, keys...) + } + + if strings.HasPrefix(model, "gemini") { + keys, err := p.apiKeyDao.FindKeys(map[string]any{"type = ?": "gemini"}) + if err != nil { + return err + } + akpikeys = append(akpikeys, keys...) + } + + if strings.HasPrefix(model, "claude") { + keys, err := p.apiKeyDao.FindKeys(map[string]any{"type = ?": "claude"}) + if err != nil { + return err + } + akpikeys = append(akpikeys, keys...) + } + } + if len(akpikeys) == 0 { + return errors.New("no available apikey") + + } + + if len(akpikeys) == 1 { + p.apikey = &akpikeys[0] + return nil + } + length := len(akpikeys) - 1 + + p.apikey = &akpikeys[rand.Intn(length)] + + return nil +} + +func (p *Proxy) updateSupportModel() { + + keys, err := p.apiKeyDao.FindKeys(map[string]interface{}{"type in ?": "openai,azure,claude"}) + if err != nil { + return + } + for _, key := range keys { + var supportModels []string + if *key.ApiType == "openai" || *key.ApiType == "azure" { + supportModels, err = p.getOpenAISupportModels(key) + } + if *key.ApiType == "claude" { + supportModels, err = p.getClaudeSupportModels(key) + } + + if err != nil { + log.Println(err) + continue + } + if len(supportModels) == 0 { + continue + + } + if p.cfg.DB_Type == "sqlite" { + bytejson, _ := json.Marshal(supportModels) + if err := p.db.Model(&model.ApiKey{}).Where("id = ?", key.ID).UpdateColumn("support_models", string(bytejson)).Error; err != nil { + log.Println(err) + } + } else if p.cfg.DB_Type == "postgres" { + if err := p.db.Model(&model.ApiKey{}).Where("id = ?", key.ID).UpdateColumn("support_models", pq.StringArray(supportModels)).Error; err != nil { + log.Println(err) + } + } + + } + +} + +func (p *Proxy) ScheduleTask() { + + func() { + for { + select { + case <-time.After(time.Duration(p.cfg.TaskTimeInterval) * time.Minute): + p.updateSupportModel() + + case <-p.ctx.Done(): + fmt.Println("schedule task done") + return + } + } + }() +} + +func (p *Proxy) getOpenAISupportModels(apikey model.ApiKey) ([]string, error) { + openaiModelsUrl := "https://api.openai.com/v1/models" + // https://learn.microsoft.com/zh-cn/rest/api/azureopenai/models/list?view=rest-azureopenai-2025-02-01-preview&tabs=HTTP + azureModelsUrl := "/openai/deployments?api-version=2022-12-01" + + var supportModels []string + var req *http.Request + if *apikey.ApiType == "azure" { + req, _ = http.NewRequest("GET", *apikey.Endpoint+azureModelsUrl, nil) + req.Header.Set("Content-Type", "application/json") + req.Header.Set("api-key", *apikey.ApiKey) + } else { + req, _ = http.NewRequest("GET", openaiModelsUrl, nil) + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", "Bearer "+*apikey.ApiKey) + } + + resp, err := p.httpClient.Do(req) + if err != nil { + return nil, err + } + defer resp.Body.Close() + if resp.StatusCode == http.StatusOK { + bytesbody, _ := io.ReadAll(resp.Body) + result := gjson.GetBytes(bytesbody, "data.#.id").Array() + for _, v := range result { + model := v.Str + model = strings.Replace(model, "-35-", "-3.5-", -1) + model = strings.Replace(model, "-41-", "-4.1-", -1) + supportModels = append(supportModels, model) + } + } + return supportModels, nil +} + +func (p *Proxy) getClaudeSupportModels(apikey model.ApiKey) ([]string, error) { + // https://docs.anthropic.com/en/api/models-list + claudemodelsUrl := "https://api.anthropic.com/v1/models" + var supportModels []string + + req, _ := http.NewRequest("GET", claudemodelsUrl, nil) + req.Header.Set("Content-Type", "application/json") + req.Header.Set("x-api-key", *apikey.ApiKey) + req.Header.Set("anthropic-version", "2023-06-01") + + resp, err := p.httpClient.Do(req) + if err != nil { + return nil, err + } + defer resp.Body.Close() + if resp.StatusCode == http.StatusOK { + bytesbody, _ := io.ReadAll(resp.Body) + result := gjson.GetBytes(bytesbody, "data.#.id").Array() + for _, v := range result { + supportModels = append(supportModels, v.Str) + } + } + return supportModels, nil +} diff --git a/team/handler/team/middleware.go b/internal/controller/team/middleware.go similarity index 92% rename from team/handler/team/middleware.go rename to internal/controller/team/middleware.go index f6588c9..6f40097 100644 --- a/team/handler/team/middleware.go +++ b/internal/controller/team/middleware.go @@ -1,15 +1,15 @@ -package handler +package controller import ( "fmt" "net/http" - "opencatd-open/team/consts" + "opencatd-open/internal/consts" "github.com/gin-contrib/cors" "github.com/gin-gonic/gin" ) -func (h *TeamHandler) AuthMiddleware() gin.HandlerFunc { +func (h *Team) AuthMiddleware() gin.HandlerFunc { return func(c *gin.Context) { if c.Request.URL.Path == "/1/users/init" { c.Next() diff --git a/team/handler/team/team.go b/internal/controller/team/team.go similarity index 71% rename from team/handler/team/team.go rename to internal/controller/team/team.go index c87aab1..b699bbf 100644 --- a/team/handler/team/team.go +++ b/internal/controller/team/team.go @@ -1,24 +1,26 @@ -package handler +package controller import ( "errors" "net/http" + "slices" "strconv" "strings" "time" + "opencatd-open/internal/consts" + dto "opencatd-open/internal/dto/team" + "opencatd-open/internal/model" + service "opencatd-open/internal/service/team" "opencatd-open/internal/utils" - "opencatd-open/team/consts" - dto "opencatd-open/team/dto/team" - "opencatd-open/team/model" - "opencatd-open/team/service" + "github.com/duke-git/lancet/v2/slice" "github.com/gin-gonic/gin" "github.com/google/uuid" "gorm.io/gorm" ) -type TeamHandler struct { +type Team struct { db *gorm.DB userService service.UserService tokenService service.TokenService @@ -26,8 +28,8 @@ type TeamHandler struct { usageService service.UsageService } -func NewTeamHandler(userService service.UserService, tokenService service.TokenService, keyService service.ApiKeyService, usageService service.UsageService) *TeamHandler { - return &TeamHandler{ +func NewTeam(userService service.UserService, tokenService service.TokenService, keyService service.ApiKeyService, usageService service.UsageService) *Team { + return &Team{ userService: userService, tokenService: tokenService, keyService: keyService, @@ -36,7 +38,7 @@ func NewTeamHandler(userService service.UserService, tokenService service.TokenS } // initadmin -func (h *TeamHandler) InitAdmin(c *gin.Context) { +func (h *Team) InitAdmin(c *gin.Context) { admin, err := h.userService.GetUser(c, 1) if err != nil { @@ -45,12 +47,12 @@ func (h *TeamHandler) InitAdmin(c *gin.Context) { Name: "root", Username: "root", Password: "openteam", - Role: int(consts.RoleSuperAdmin), + Role: utils.ToPtr(consts.RoleRoot), Tokens: []model.Token{ { Name: "default", - Key: "team-" + strings.ReplaceAll(uuid.New().String(), "-", ""), - UnlimitedQuota: true, + Key: "sk-team-" + strings.ReplaceAll(uuid.New().String(), "-", ""), + UnlimitedQuota: utils.ToPtr(true), }, }, } @@ -59,8 +61,8 @@ func (h *TeamHandler) InitAdmin(c *gin.Context) { return } var result = dto.UserInfo{ - ID: int(user.ID), - Name: user.Name, + ID: user.ID, + Name: user.Username, Token: user.Tokens[0].Key, Status: utils.ToPtr(user.Status == consts.StatusEnabled), } @@ -80,18 +82,7 @@ func (h *TeamHandler) InitAdmin(c *gin.Context) { } } -func (h *TeamHandler) Me(c *gin.Context) { - // token := c.GetHeader("Authorization") - // token = strings.TrimPrefix(token, "Bearer ") - // userToken, err := h.tokenService.GetTokenByKey(token) - // if err != nil { - // c.JSON(http.StatusUnauthorized, gin.H{"error": err.Error()}) - // return - // } - // if userToken.ID != 1 { - // c.JSON(http.StatusForbidden, gin.H{"error": "only first user token can access"}) - // return - // } +func (h *Team) Me(c *gin.Context) { token, exists := c.Get("token") if !exists { c.JSON(http.StatusNotFound, gin.H{"error": "token not found"}) @@ -100,7 +91,7 @@ func (h *TeamHandler) Me(c *gin.Context) { userToken := token.(*model.Token) c.JSON(http.StatusOK, dto.UserInfo{ - ID: int(userToken.UserID), + ID: userToken.UserID, Name: userToken.User.Name, Token: userToken.Key, Status: utils.ToPtr(userToken.User.Status == consts.StatusEnabled), @@ -109,7 +100,7 @@ func (h *TeamHandler) Me(c *gin.Context) { } // CreateUser 创建用户 -func (h *TeamHandler) CreateUser(c *gin.Context) { +func (h *Team) CreateUser(c *gin.Context) { var userReq dto.UserInfo if err := c.ShouldBindJSON(&userReq); err != nil { c.JSON(http.StatusBadRequest, gin.H{"error": "invalid input"}) @@ -122,10 +113,14 @@ func (h *TeamHandler) CreateUser(c *gin.Context) { return } userToken := token.(*model.Token) - if userToken.User.Role < int(consts.RoleAdmin) { + if *userToken.User.Role < consts.RoleAdmin { // 普通用户只能创建自己的token create := &model.Token{ Name: userReq.Name, - Key: "team-" + strings.ReplaceAll(uuid.New().String(), "-", ""), + Key: "sk-team-" + strings.ReplaceAll(uuid.New().String(), "-", ""), + } + if userReq.Token != "" { + _key := strings.ReplaceAll(userReq.Token, "-", "") + create.Key = "sk-team-" + strings.ReplaceAll(_key, " ", "") } if err := h.tokenService.Create(c.Request.Context(), create); err != nil { c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) @@ -136,11 +131,11 @@ func (h *TeamHandler) CreateUser(c *gin.Context) { user := &model.User{ Name: userReq.Name, Username: userReq.Name, - Role: int(consts.RoleUser), + Role: utils.ToPtr(consts.RoleUser), Tokens: []model.Token{ { Name: "default", - Key: "team-" + strings.ReplaceAll(uuid.New().String(), "-", ""), + Key: "sk-team-" + strings.ReplaceAll(uuid.New().String(), "-", ""), }, }, } @@ -156,7 +151,7 @@ func (h *TeamHandler) CreateUser(c *gin.Context) { } // GetUser 获取用户信息 -func (h *TeamHandler) GetUser(c *gin.Context) { +func (h *Team) GetUser(c *gin.Context) { idStr := c.Param("id") id, err := strconv.ParseInt(idStr, 10, 64) if err != nil { @@ -174,7 +169,7 @@ func (h *TeamHandler) GetUser(c *gin.Context) { } // UpdateUser 更新用户信息 -func (h *TeamHandler) UpdateUser(c *gin.Context) { +func (h *Team) UpdateUser(c *gin.Context) { var user model.User if err := c.ShouldBindJSON(&user); err != nil { c.JSON(http.StatusBadRequest, gin.H{"error": "invalid input"}) @@ -197,7 +192,7 @@ func (h *TeamHandler) UpdateUser(c *gin.Context) { } // DeleteUser 删除用户 -func (h *TeamHandler) DeleteUser(c *gin.Context) { +func (h *Team) DeleteUser(c *gin.Context) { idStr := c.Param("id") id, err := strconv.ParseInt(idStr, 10, 64) if err != nil { @@ -212,14 +207,14 @@ func (h *TeamHandler) DeleteUser(c *gin.Context) { } userToken := token.(*model.Token) - if userToken.User.Role < int(consts.RoleAdmin) { // 用户只能删除自己的token - err := h.tokenService.Delete(c.Request.Context(), int(id)) + if *userToken.User.Role < consts.RoleAdmin { // 用户只能删除自己的token + err := h.tokenService.Delete(c.Request.Context(), id) if err != nil { c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) return } } else { - if err := h.userService.DeleteUser(c.Request.Context(), id, userToken.UserID); err != nil { + if err := h.userService.DeleteUser(c, id, userToken.UserID); err != nil { c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) return } @@ -228,7 +223,7 @@ func (h *TeamHandler) DeleteUser(c *gin.Context) { c.JSON(http.StatusOK, gin.H{"message": "ok"}) } -func (h *TeamHandler) ListUsages(c *gin.Context) { +func (h *Team) ListUsages(c *gin.Context) { fromStr := c.Query("from") toStr := c.Query("to") @@ -258,7 +253,7 @@ func (h *TeamHandler) ListUsages(c *gin.Context) { token, _ := c.Get("token") userToken := token.(*model.Token) - if userToken.User.Role < int(consts.RoleAdmin) { + if *userToken.User.Role < consts.RoleAdmin { listUsage, err = h.usageService.ListByDateRange(c.Request.Context(), from, to, map[string]interface{}{"user_id": userToken.UserID}) if err != nil { c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) @@ -277,29 +272,24 @@ func (h *TeamHandler) ListUsages(c *gin.Context) { } // ListUsers 获取用户列表 -func (h *TeamHandler) ListUsers(c *gin.Context) { - pageStr := c.DefaultQuery("page", "1") - pageSizeStr := c.DefaultQuery("pageSize", "100") +func (h *Team) ListUsers(c *gin.Context) { + limit, _ := strconv.Atoi(c.DefaultQuery("limit", "100")) + offset, _ := strconv.Atoi(c.DefaultQuery("offset", "0")) + active := c.DefaultQuery("active", "") - page, err := strconv.Atoi(pageStr) - if err != nil { - c.JSON(http.StatusBadRequest, gin.H{"error": "invalid page number"}) + if !slices.Contains([]string{"true", "false", ""}, active) { + c.JSON(http.StatusBadRequest, gin.H{"error": "invalid active value"}) return } - pageSize, err := strconv.Atoi(pageSizeStr) - if err != nil { - c.JSON(http.StatusBadRequest, gin.H{"error": "invalid page size"}) - return - } token, exists := c.Get("token") if !exists { c.JSON(http.StatusNotFound, gin.H{"error": "Unauthorized"}) return } userToken := token.(*model.Token) - if userToken.User.Role < int(consts.RoleAdmin) { - tokens, _, err := h.tokenService.ListsWithFilters(c, 0, 100, map[string]interface{}{"user_id": userToken.UserID}) + if *userToken.User.Role < consts.RoleAdmin { // 用户只能获取自己的token + tokens, _, err := h.tokenService.Lists(c, limit, offset) if err != nil { c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) return @@ -307,7 +297,7 @@ func (h *TeamHandler) ListUsers(c *gin.Context) { var userDTOs []dto.UserInfo for _, token := range tokens { userDTOs = append(userDTOs, dto.UserInfo{ - ID: int(token.User.ID), + ID: token.User.ID, Name: token.User.Name, Token: token.Key, Status: utils.ToPtr(token.User.Status == consts.StatusEnabled), @@ -317,7 +307,7 @@ func (h *TeamHandler) ListUsers(c *gin.Context) { return } - users, _, err := h.userService.ListUsers(c.Request.Context(), page, pageSize) + users, err := h.userService.ListUsers(c, limit, offset, active) if err != nil { c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) return @@ -326,7 +316,7 @@ func (h *TeamHandler) ListUsers(c *gin.Context) { var userDTOs []dto.UserInfo for _, user := range users { useres := dto.UserInfo{ - ID: int(user.ID), + ID: user.ID, Name: user.Name, Status: utils.ToPtr(user.Status == consts.StatusEnabled), @@ -340,9 +330,8 @@ func (h *TeamHandler) ListUsers(c *gin.Context) { c.JSON(http.StatusOK, userDTOs) } -func (h *TeamHandler) ResetUserToken(c *gin.Context) { - idstr := c.Param("id") - id, err := strconv.Atoi(idstr) +func (h *Team) ResetUserToken(c *gin.Context) { + id, err := strconv.ParseInt(c.Param("id"), 10, 64) if err != nil { c.JSON(http.StatusBadRequest, gin.H{"error": "invalid user id"}) return @@ -359,10 +348,10 @@ func (h *TeamHandler) ResetUserToken(c *gin.Context) { c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) return } - findtoken.Key = "team-" + strings.ReplaceAll(uuid.New().String(), "-", "") + findtoken.Key = "sk-team-" + strings.ReplaceAll(uuid.New().String(), "-", "") - if userToken.User.Role < int(consts.RoleAdmin) { // 非管理员只能修改自己的token - if userToken.User.Role <= findtoken.User.Role || userToken.UserID != findtoken.UserID { + if *userToken.User.Role < consts.RoleAdmin { // 非管理员只能修改自己的token + if *userToken.User.Role <= *findtoken.User.Role || userToken.UserID != findtoken.UserID { c.JSON(http.StatusForbidden, gin.H{"error": "forbidden"}) return } @@ -379,34 +368,34 @@ func (h *TeamHandler) ResetUserToken(c *gin.Context) { } c.JSON(http.StatusOK, dto.UserInfo{ - ID: int(findtoken.User.ID), + ID: findtoken.User.ID, Name: findtoken.User.Name, Token: findtoken.Key, }) } -func (h *TeamHandler) CreateKey(c *gin.Context) { +func (h *Team) CreateKey(c *gin.Context) { token, exists := c.Get("token") if !exists { c.JSON(http.StatusNotFound, gin.H{"error": "token not found"}) return } userToken := token.(*model.Token) - if userToken.User.Role < int(consts.RoleAdmin) { + if *userToken.User.Role < consts.RoleAdmin { c.JSON(http.StatusForbidden, gin.H{"error": "forbidden"}) return } - var key dto.KeyInfo + var key dto.ApiKeyInfo if err := c.ShouldBindJSON(&key); err != nil { c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) return } err := h.keyService.Create(&model.ApiKey{ - Name: key.Name, - ApiType: key.ApiType, - ApiKey: key.Key, - Endpoint: key.Endpoint, + Name: utils.ToPtr(key.Name), + ApiType: utils.ToPtr(key.ApiType), + ApiKey: utils.ToPtr(key.Key), + Endpoint: utils.ToPtr(key.Endpoint), }) if err != nil { c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) @@ -415,31 +404,40 @@ func (h *TeamHandler) CreateKey(c *gin.Context) { c.JSON(http.StatusOK, key) } -func (h *TeamHandler) ListKeys(c *gin.Context) { - keys, err := h.keyService.List(0, 100, nil) - if err != nil { - c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) +func (h *Team) ListKeys(c *gin.Context) { + limit, _ := strconv.Atoi(c.DefaultQuery("limit", "20")) + offset, _ := strconv.Atoi(c.DefaultQuery("offset", "0")) + active := c.Query("active") + if !slice.Contain([]string{"true", "false", ""}, active) { + c.JSON(http.StatusBadRequest, gin.H{"error": "invalid active value"}) + return } - var keysDTO []dto.KeyInfo + keys, err := h.keyService.List(limit, offset, active) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + + var keysDTO []dto.ApiKeyInfo for _, key := range keys { - keylength := len(key.ApiKey) / 3 + keylength := len(*key.ApiKey) / 3 if keylength < 1 { c.JSON(http.StatusBadRequest, gin.H{"error": "invalid key length"}) return } - keysDTO = append(keysDTO, dto.KeyInfo{ + keysDTO = append(keysDTO, dto.ApiKeyInfo{ ID: int(key.ID), - Name: key.Name, - ApiType: key.ApiType, - Endpoint: key.Endpoint, - Key: key.ApiKey[:keylength] + "****" + key.ApiKey[len(key.ApiKey)-keylength:], + Name: *key.Name, + ApiType: *key.ApiType, + Endpoint: *key.Endpoint, + Key: *key.ApiKey, }) } c.JSON(http.StatusOK, keysDTO) } -func (h *TeamHandler) UpdateKey(c *gin.Context) { +func (h *Team) UpdateKey(c *gin.Context) { // 1. 获取并验证ID id, err := strconv.ParseInt(c.Param("id"), 10, 64) if err != nil { @@ -448,7 +446,7 @@ func (h *TeamHandler) UpdateKey(c *gin.Context) { } // 2. 解析请求体 - var updateKey dto.KeyInfo // 更明确的命名 + var updateKey dto.ApiKeyInfo // 更明确的命名 if err := c.ShouldBindJSON(&updateKey); err != nil { c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) return @@ -473,7 +471,7 @@ func (h *TeamHandler) UpdateKey(c *gin.Context) { c.JSON(http.StatusOK, updatedKey) } -func (h *TeamHandler) DeleteKey(c *gin.Context) { +func (h *Team) DeleteKey(c *gin.Context) { // 1. 获取并验证ID id, err := strconv.ParseInt(c.Param("id"), 10, 64) if err != nil { @@ -491,7 +489,7 @@ func (h *TeamHandler) DeleteKey(c *gin.Context) { } // ChangePassword 修改密码 -func (h *TeamHandler) ChangePassword(c *gin.Context) { +func (h *Team) ChangePassword(c *gin.Context) { userID := c.GetInt64("userID") // 假设从上下文中获取用户ID var req struct { @@ -513,15 +511,14 @@ func (h *TeamHandler) ChangePassword(c *gin.Context) { } // ResetPassword 重置密码 -func (h *TeamHandler) ResetPassword(c *gin.Context) { - idStr := c.Param("id") - id, err := strconv.ParseInt(idStr, 10, 64) +func (h *Team) ResetPassword(c *gin.Context) { + id, err := strconv.ParseInt(c.Param("id"), 10, 64) if err != nil { c.JSON(http.StatusBadRequest, gin.H{"error": "invalid user id"}) return } - operatorID := c.GetInt64("userID") // 假设从上下文中获取操作者ID + operatorID := int64(c.GetInt("userID")) // 假设从上下文中获取操作者ID if err := h.userService.ResetPassword(c.Request.Context(), id, operatorID); err != nil { c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) return @@ -531,16 +528,16 @@ func (h *TeamHandler) ResetPassword(c *gin.Context) { } // EnableUser 启用用户 -func (h *TeamHandler) EnableUser(c *gin.Context) { - idStr := c.Param("id") - id, err := strconv.ParseInt(idStr, 10, 64) +func (h *Team) EnableUser(c *gin.Context) { + id, err := strconv.ParseInt(c.Param("id"), 10, 64) if err != nil { c.JSON(http.StatusBadRequest, gin.H{"error": "invalid user id"}) return } - operatorID := c.GetInt64("userID") // 假设从上下文中获取操作者ID - if err := h.userService.EnableUser(c.Request.Context(), id, operatorID); err != nil { + operatorID := int64(c.GetInt("userID")) // 假设从上下文中获取操作者ID + + if err := h.userService.BatchEnableUsers(c, []int64{id}, operatorID); err != nil { c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) return } @@ -549,16 +546,15 @@ func (h *TeamHandler) EnableUser(c *gin.Context) { } // DisableUser 禁用用户 -func (h *TeamHandler) DisableUser(c *gin.Context) { - idStr := c.Param("id") - id, err := strconv.ParseInt(idStr, 10, 64) +func (h *Team) DisableUser(c *gin.Context) { + id, err := strconv.ParseInt(c.Param("id"), 10, 64) if err != nil { c.JSON(http.StatusBadRequest, gin.H{"error": "invalid user id"}) return } - operatorID := c.GetInt64("userID") // 假设从上下文中获取操作者ID - if err := h.userService.DisableUser(c.Request.Context(), id, operatorID); err != nil { + operatorID := int64(c.GetInt("userID")) // 假设从上下文中获取操作者ID + if err := h.userService.BatchDisableUsers(c.Request.Context(), []int64{id}, operatorID); err != nil { c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) return } diff --git a/internal/controller/user.go b/internal/controller/user.go new file mode 100644 index 0000000..85e9814 --- /dev/null +++ b/internal/controller/user.go @@ -0,0 +1,218 @@ +package controller + +import ( + "fmt" + "net/http" + "opencatd-open/internal/dto" + "opencatd-open/internal/model" + "opencatd-open/internal/utils" + "strconv" + "strings" + + "github.com/duke-git/lancet/v2/slice" + "github.com/gin-gonic/gin" +) + +func (a Api) Register(c *gin.Context) { + req := new(dto.User) + err := c.ShouldBind(&req) + if err != nil { + dto.Fail(c, 400, err.Error()) + return + } + + err = a.userService.Register(c, &model.User{ + Username: req.Username, + Password: req.Password, + }) + if err != nil { + dto.Fail(c, 400, err.Error()) + return + } else { + dto.Success(c, nil) + } + +} + +func (a Api) Login(c *gin.Context) { + req := new(dto.User) + err := c.ShouldBind(&req) + if err != nil { + dto.Fail(c, 400, err.Error()) + return + } + + auth, err := a.userService.Login(c, req) + if err != nil { + dto.Fail(c, 400, err.Error()) + return + } else { + dto.Success(c, auth) + } + +} + +func (a Api) Profile(c *gin.Context) { + user, err := a.userService.Profile(c) + if err != nil { + dto.Fail(c, http.StatusUnauthorized, err.Error()) + return + } else { + dto.Success(c, user) + } +} + +func (a Api) UpdateProfile(c *gin.Context) { + var user = model.User{} + err := c.ShouldBind(&user) + if err != nil { + dto.Fail(c, 400, err.Error()) + return + } + + err = a.userService.Update(c, &model.User{Name: user.Name, Username: user.Username, Email: user.Email}) + if err != nil { + dto.Fail(c, http.StatusInternalServerError, err.Error()) + return + } + dto.Success(c, nil) +} + +func (a Api) UpdatePassword(c *gin.Context) { + var passwd dto.ChangePassword + err := c.ShouldBind(&passwd) + if err != nil { + dto.Fail(c, 400, err.Error()) + return + } + + _user := c.MustGet("user").(*model.User) + if _user.Password == "" { + hashpass, err := utils.HashPassword(passwd.NewPassword) + if err != nil { + dto.Fail(c, http.StatusInternalServerError, err.Error()) + return + } + _user.Password = hashpass + } else { + if !utils.CheckPassword(_user.Password, passwd.Password) { + dto.Fail(c, http.StatusBadRequest, "password not match") + return + } + hashpass, err := utils.HashPassword(passwd.NewPassword) + if err != nil { + dto.Fail(c, http.StatusInternalServerError, err.Error()) + return + } + _user.Password = hashpass + } + err = a.userService.Update(c, _user) + if err != nil { + dto.Fail(c, http.StatusInternalServerError, err.Error()) + return + } + dto.Success(c, nil) +} + +func (a Api) ListUser(c *gin.Context) { + limit, _ := strconv.Atoi(c.DefaultQuery("pageSize", "20")) + page, _ := strconv.Atoi(c.DefaultQuery("page", "1")) + offset := (page - 1) * limit + active := c.QueryArray("active[]") + if !slice.ContainSubSlice([]string{"true", "false", ""}, active) { + dto.Fail(c, http.StatusBadRequest, "active must be true or false") + return + } + + users, total, err := a.userService.List(c, limit, offset, active) + if err != nil { + dto.Fail(c, http.StatusInternalServerError, err.Error()) + return + } + dto.Success(c, gin.H{ + "users": users, + "total": total, + }) +} + +func (a Api) CreateUser(c *gin.Context) { + var user model.User + err := c.ShouldBind(&user) + if err != nil { + dto.Fail(c, 400, err.Error()) + return + } + fmt.Printf("user:%+v\n", user) + err = a.userService.Create(c, &user) + if err != nil { + dto.Fail(c, http.StatusInternalServerError, err.Error()) + return + } + + dto.Success(c, nil) +} + +func (a Api) GetUser(c *gin.Context) { + id, _ := strconv.ParseInt(c.Param("id"), 10, 64) + user, err := a.userService.GetByID(c, id) + if err != nil { + dto.Fail(c, 500, err.Error()) + return + } + dto.Success(c, user) +} + +func (a Api) EditUser(c *gin.Context) { + id, _ := strconv.ParseInt(c.Param("id"), 10, 64) + var user model.User + err := c.ShouldBind(&user) + if err != nil { + dto.Fail(c, 400, err.Error()) + return + } + + user.ID = int64(id) + err = a.userService.Update(c, &user) + if err != nil { + dto.Fail(c, 500, err.Error()) + return + } + dto.Success(c, nil) +} + +func (a Api) DeleteUser(c *gin.Context) { + id, _ := strconv.ParseInt(c.Param("id"), 10, 64) + err := a.userService.Delete(c, id) + if err != nil { + dto.Fail(c, 400, err.Error()) + return + } + dto.Success(c, nil) +} + +func (a Api) UserOption(c *gin.Context) { + option := strings.ToLower(c.Param("option")) + var batchid dto.BatchIDRequest + err := c.ShouldBind(&batchid) + if err != nil { + dto.Fail(c, 400, err.Error()) + return + } + switch option { + case "enable": + err = a.userService.BatchEnable(c, batchid.IDs) + case "disable": + err = a.userService.BatchDisable(c, batchid.IDs) + case "delete": + err = a.userService.BatchDelete(c, batchid.IDs) + default: + dto.Fail(c, 400, "invalid option, only support enable, disable, delete") + return + } + if err != nil { + dto.Fail(c, 400, err.Error()) + return + } + dto.Success(c, nil) + +} diff --git a/internal/controller/user_token.go b/internal/controller/user_token.go new file mode 100644 index 0000000..f0764d2 --- /dev/null +++ b/internal/controller/user_token.go @@ -0,0 +1,218 @@ +package controller + +import ( + "net/http" + "opencatd-open/internal/dto" + "opencatd-open/internal/model" + "opencatd-open/internal/utils" + "strconv" + "strings" + + "github.com/duke-git/lancet/v2/slice" + "github.com/gin-gonic/gin" +) + +func (a Api) CreateToken(c *gin.Context) { + userid := c.GetInt64("user_id") + user, err := a.userService.GetByID(c, userid) + if err != nil { + dto.Fail(c, http.StatusInternalServerError, err.Error()) + return + } + if len(user.Tokens) >= 20 { + dto.Fail(c, http.StatusForbidden, "user has reached the maximum number of tokens") + return + } + + var token model.Token + err = c.ShouldBindJSON(&token) + if err != nil { + c.JSON(400, gin.H{"error": err.Error()}) + return + } + token.UserID = userid + + err = a.tokenService.CreateToken(c, &token) + if err != nil { + dto.Fail(c, http.StatusInternalServerError, err.Error()) + return + } + dto.Success(c, nil) +} + +func (a Api) ListToken(c *gin.Context) { + limit, _ := strconv.Atoi(c.DefaultQuery("pageSize", "20")) + page, _ := strconv.Atoi(c.DefaultQuery("page", "1")) + offset := (page - 1) * limit + active := c.QueryArray("active[]") + if !slice.ContainSubSlice([]string{"true", "false"}, active) { + dto.Fail(c, http.StatusBadRequest, "active must be true or false") + } + + tokens, total, err := a.tokenService.ListToken(c, limit, offset, active) + if err != nil { + dto.Fail(c, http.StatusInternalServerError, err.Error()) + return + } + + dto.Success(c, gin.H{ + "total": total, + "tokens": tokens, + }) + +} + +func (a Api) GetToken(c *gin.Context) { + id, _ := strconv.ParseInt(c.Param("id"), 10, 64) + + token, err := a.tokenService.GetToken(c, id) + if err != nil { + dto.Fail(c, http.StatusInternalServerError, err.Error()) + return + } + + dto.Success(c, token) +} + +func (a Api) ResetToken(c *gin.Context) { + id, err := strconv.ParseInt(c.Param("id"), 10, 64) + if err != nil { + dto.Fail(c, http.StatusBadRequest, err.Error()) + return + } + + token, err := a.tokenService.GetToken(c, id) + if err != nil { + dto.Fail(c, http.StatusInternalServerError, err.Error()) + return + } + if token == nil { + dto.Fail(c, http.StatusNotFound, "token not found") + return + } + token.UsedQuota = utils.ToPtr(int64(0)) + + err = a.tokenService.UpdateToken(c, token) + if err != nil { + dto.Fail(c, http.StatusInternalServerError, err.Error()) + return + } + + dto.Success(c, nil) +} + +func (a Api) UpdateToken(c *gin.Context) { + id, err := strconv.ParseInt(c.Param("id"), 10, 64) + if err != nil { + dto.Fail(c, http.StatusBadRequest, err.Error()) + return + } + + var token model.Token + err = c.ShouldBindJSON(&token) + if err != nil { + c.JSON(400, gin.H{"error": err.Error()}) + return + } + token.ID = id + if token.UserID == 0 { + dto.Fail(c, http.StatusBadRequest, "user_id is required") + return + } + + var _token *model.Token + + user, err := a.userService.GetByID(c, token.UserID) + if err != nil { + dto.Fail(c, http.StatusInternalServerError, err.Error()) + return + } + if len(user.Tokens) == 0 { + dto.Fail(c, http.StatusForbidden, "user has no tokens") + return + } else { + if findtoken, ok := slice.Find(user.Tokens, + func(idx int, t model.Token) bool { + return t.ID == id + }); ok { + _token = findtoken + _token.User = user + } else { + dto.Fail(c, http.StatusForbidden, "user has no tokens") + return + } + } + // 更新_token信息 + if token.Name != "" { + _token.Name = token.Name + } + if token.Key != "" { + _token.Key = token.Key + } + if token.Active != nil { + _token.Active = token.Active + } + if token.Quota != nil { + _token.Quota = token.Quota + } + if token.UnlimitedQuota != nil { + _token.UnlimitedQuota = token.UnlimitedQuota + } + if token.ExpiredAt != nil { + _token.ExpiredAt = token.ExpiredAt + } + + err = a.tokenService.UpdateToken(c, _token) + if err != nil { + dto.Fail(c, http.StatusInternalServerError, err.Error()) + return + } + + dto.Success(c, nil) +} + +func (a Api) DeleteToken(c *gin.Context) { + id, err := strconv.ParseInt(c.Param("id"), 10, 64) + if err != nil { + dto.Fail(c, http.StatusBadRequest, err.Error()) + return + } + + err = a.tokenService.DeleteToken(c, id) + if err != nil { + dto.Fail(c, http.StatusInternalServerError, err.Error()) + return + } + + dto.Success(c, nil) +} + +func (a Api) TokenOption(c *gin.Context) { + option := strings.ToLower(c.Param("option")) + var batchid dto.BatchIDRequest + err := c.ShouldBind(&batchid) + if err != nil { + dto.Fail(c, 400, err.Error()) + return + } + if batchid.UserID == nil { + dto.Fail(c, 400, "user_id is required") + return + } + switch option { + case "enable": + err = a.tokenService.EnableTokens(c, *batchid.UserID, batchid.IDs) + case "disable": + err = a.tokenService.DisableTokens(c, *batchid.UserID, batchid.IDs) + case "delete": + err = a.tokenService.DeleteTokens(c, *batchid.UserID, batchid.IDs) + default: + dto.Fail(c, 400, "invalid option, only support enable, disable, delete") + return + } + if err != nil { + dto.Fail(c, 400, err.Error()) + return + } + dto.Success(c, nil) +} diff --git a/internal/controller/webauth.go b/internal/controller/webauth.go new file mode 100644 index 0000000..1fbf96d --- /dev/null +++ b/internal/controller/webauth.go @@ -0,0 +1,108 @@ +package controller + +import ( + "fmt" + "opencatd-open/internal/auth" + "opencatd-open/internal/consts" + "opencatd-open/internal/dto" + "strconv" + "time" + + "github.com/gin-gonic/gin" +) + +func (a *Api) PasskeyCreateBegin(c *gin.Context) { + userid := c.GetInt64("user_id") + cred, err := a.webAuthService.BeginRegistration(userid) + if err != nil { + dto.Fail(c, 500, err.Error()) + return + } + dto.Success(c, cred) +} + +func (a *Api) PasskeyCreateFinish(c *gin.Context) { + userid := c.GetInt64("user_id") + name := c.Query("name") + if name == "" { + name = fmt.Sprintf("User-%d-%d", userid, time.Now().Unix()) + } + // var body protocol.CredentialCreationResponse + // if err := c.ShouldBindJSON(&body); err != nil { + // dto.Fail(c, 400, err.Error()) + // return + // } + + // 获取用户凭证 + cred, err := a.webAuthService.FinishRegistration(userid, c.Request, name) + if err != nil { + dto.Fail(c, 500, err.Error()) + return + } + + dto.Success(c, cred) +} + +func (a *Api) ListPasskey(c *gin.Context) { + passkeys, err := a.webAuthService.ListPasskeys(c.GetInt64("user_id")) + if err != nil { + dto.Fail(c, 500, err.Error()) + return + } + var passkeysDto []dto.Passkey + for _, passkey := range passkeys { + passkeysDto = append(passkeysDto, dto.Passkey{ + ID: passkey.ID, + Name: passkey.Name, + DeviceType: passkey.DeviceType, + SignCount: passkey.SignCount, + LastUsedAt: passkey.LastUsedAt, + CreatedAt: passkey.CreatedAt, + UpdatedAt: passkey.UpdatedAt, + }) + } + + dto.Success(c, passkeysDto) +} + +func (a *Api) DeletePasskey(c *gin.Context) { + id, err := strconv.ParseInt(c.Param("id"), 10, 64) + if err != nil { + dto.Fail(c, 400, err.Error()) + return + } + if err = a.webAuthService.DeletePasskey(c.GetInt64("user_id"), id); err != nil { + dto.Fail(c, 500, err.Error()) + return + } + dto.Success(c, "删除成功") +} + +// 登陆 +func (a *Api) PasskeyAuthBegin(c *gin.Context) { + + cred, err := a.webAuthService.BeginLogin() + if err != nil { + dto.Fail(c, 500, err.Error()) + return + } + dto.Success(c, cred) +} + +func (a *Api) PasskeyAuthFinish(c *gin.Context) { + challenge := c.Query("challenge") + webAuthUser, err := a.webAuthService.FinishLogin(challenge, c.Request) + if err != nil { + dto.Fail(c, 500, err.Error()) + return + } + at, err := auth.GenerateTokenPair(webAuthUser.User, consts.SecretKey, consts.Day*time.Second, consts.Day*time.Second) + if err != nil { + dto.Fail(c, 500, err.Error()) + return + } + dto.Success(c, dto.Auth{ + Token: at.AccessToken, + ExpiresIn: time.Now().Add(consts.Day * time.Second).Unix(), + }) +} diff --git a/team/dao/apikey.go b/internal/dao/apikey.go similarity index 61% rename from team/dao/apikey.go rename to internal/dao/apikey.go index 88b0ccf..a113326 100644 --- a/team/dao/apikey.go +++ b/internal/dao/apikey.go @@ -2,8 +2,8 @@ package dao import ( "errors" - "opencatd-open/team/model" - "time" + "opencatd-open/internal/model" + "opencatd-open/pkg/config" "gorm.io/gorm" ) @@ -16,11 +16,8 @@ type ApiKeyRepository interface { GetByName(name string) (*model.ApiKey, error) GetByApiKey(apiKeyValue string) (*model.ApiKey, error) Update(apiKey *model.ApiKey) error - Delete(id int64) error - List(offset, limit int, status *int) ([]model.ApiKey, error) - ListWithFilters(offset, limit int, filters map[string]interface{}) ([]model.ApiKey, int64, error) - Enable(id int64) error - Disable(id int64) error + List(limit, offset int, status string) ([]*model.ApiKey, error) + ListWithFilters(limit, offset int, filters map[string]interface{}) ([]*model.ApiKey, int64, error) BatchEnable(ids []int64) error BatchDisable(ids []int64) error BatchDelete(ids []int64) error @@ -28,7 +25,8 @@ type ApiKeyRepository interface { } type ApiKeyDAO struct { - db *gorm.DB + cfg *config.Config + db *gorm.DB } func NewApiKeyDAO(db *gorm.DB) *ApiKeyDAO { @@ -73,26 +71,54 @@ func (dao *ApiKeyDAO) GetByApiKey(apiKeyValue string) (*model.ApiKey, error) { return &apiKey, nil } +func (dao *ApiKeyDAO) FindKeys(condition map[string]any) ([]model.ApiKey, error) { + var apiKeys []model.ApiKey + + query := dao.db.Model(&model.ApiKey{}) + for k, v := range condition { + query = query.Where(k, v) + } + err := query.Find(&apiKeys).Error + + return apiKeys, err +} + +func (dao *ApiKeyDAO) FindApiKeysBySupportModel(db *gorm.DB, modelName string) ([]model.ApiKey, error) { + var apiKeys []model.ApiKey + switch dao.cfg.DB_Type { + case "mysql": + return nil, errors.New("not support") + case "postgres": + return nil, errors.New("not support") + } + err := db.Model(&model.ApiKey{}). + Joins("CROSS JOIN JSON_EACH(apikeys.support_models)"). + Where("value = ?", modelName). + Find(&apiKeys).Error + return apiKeys, err +} + // UpdateApiKey 更新ApiKey信息 func (dao *ApiKeyDAO) Update(apiKey *model.ApiKey) error { if apiKey == nil { return errors.New("apiKey is nil") } - apiKey.UpdatedAt = time.Now().Unix() + // return dao.db.Model(&model.ApiKey{}). + // Select("name", "apitype", "apikey", "status", "endpoint", "resource_name", "deployment_name").Updates(apiKey).Error return dao.db.Save(apiKey).Error } // DeleteApiKey 删除ApiKey func (dao *ApiKeyDAO) Delete(id int64) error { - return dao.db.Delete(&model.ApiKey{}, id).Error + return dao.db.Unscoped().Delete(&model.ApiKey{}, id).Error } // ListApiKeys 获取ApiKey列表 -func (dao *ApiKeyDAO) List(offset, limit int, status *int) ([]model.ApiKey, error) { - var apiKeys []model.ApiKey - db := dao.db.Offset(offset).Limit(limit) - if status != nil { - db = db.Where("status = ?", *status) +func (dao *ApiKeyDAO) List(limit, offset int, status string) ([]*model.ApiKey, error) { + var apiKeys []*model.ApiKey + db := dao.db.Limit(limit).Offset(offset) + if status != "" { + db = db.Where("status = ?", status) } err := db.Find(&apiKeys).Error if err != nil { @@ -102,28 +128,20 @@ func (dao *ApiKeyDAO) List(offset, limit int, status *int) ([]model.ApiKey, erro } // ListApiKeysWithFilters 根据条件获取ApiKey列表 -func (dao *ApiKeyDAO) ListWithFilters(offset, limit int, filters map[string]interface{}) ([]model.ApiKey, int64, error) { - var apiKeys []model.ApiKey - db := dao.db.Offset(offset).Limit(limit) - for key, value := range filters { - db = db.Where(key+" = ?", value) +func (dao *ApiKeyDAO) ListWithFilters(limit, offset int, filters map[string]interface{}) ([]*model.ApiKey, int64, error) { + var apiKeys []*model.ApiKey + db := dao.db.Limit(limit).Offset(offset) + for k, v := range filters { + db = db.Where(k, v) } - var count int64 - err := db.Find(&apiKeys).Count(&count).Error + err := db.Find(&apiKeys).Error if err != nil { return nil, 0, err } - return apiKeys, count, nil -} + var total int64 + db.Model(&model.ApiKey{}).Count(&total) -// EnableApiKey 启用ApiKey -func (dao *ApiKeyDAO) Enable(id int64) error { - return dao.db.Model(&model.ApiKey{}).Where("id = ?", id).Update("status", 0).Error -} - -// DisableApiKey 禁用ApiKey -func (dao *ApiKeyDAO) Disable(id int64) error { - return dao.db.Model(&model.ApiKey{}).Where("id = ?", id).Update("status", 1).Error + return apiKeys, total, nil } // BatchEnableApiKeys 批量启用ApiKey @@ -131,7 +149,7 @@ func (dao *ApiKeyDAO) BatchEnable(ids []int64) error { if len(ids) == 0 { return errors.New("ids is empty") } - return dao.db.Model(&model.ApiKey{}).Where("id IN ?", ids).Update("status", 0).Error + return dao.db.Model(&model.ApiKey{}).Where("id IN ?", ids).Update("active", true).Error } // BatchDisableApiKeys 批量禁用ApiKey @@ -139,7 +157,7 @@ func (dao *ApiKeyDAO) BatchDisable(ids []int64) error { if len(ids) == 0 { return errors.New("ids is empty") } - return dao.db.Model(&model.ApiKey{}).Where("id IN ?", ids).Update("status", 1).Error + return dao.db.Model(&model.ApiKey{}).Where("id IN ?", ids).Update("active", false).Error } // BatchDeleteApiKey 批量删除ApiKey @@ -147,7 +165,7 @@ func (dao *ApiKeyDAO) BatchDelete(ids []int64) error { if len(ids) == 0 { return errors.New("ids is empty") } - return dao.db.Delete(&model.ApiKey{}, ids).Error + return dao.db.Unscoped().Delete(&model.ApiKey{}, ids).Error } // CountApiKeys 获取ApiKey总数 diff --git a/team/dao/token.go b/internal/dao/token.go similarity index 57% rename from team/dao/token.go rename to internal/dao/token.go index 2f1deb8..f46020d 100644 --- a/team/dao/token.go +++ b/internal/dao/token.go @@ -3,8 +3,8 @@ package dao import ( "context" "errors" - "opencatd-open/team/consts" - "opencatd-open/team/model" + "opencatd-open/internal/consts" + "opencatd-open/internal/model" "time" "gorm.io/gorm" @@ -15,19 +15,19 @@ var _ TokenRepository = (*TokenDAO)(nil) type TokenRepository interface { Create(ctx context.Context, token *model.Token) error - GetByID(ctx context.Context, id int) (*model.Token, error) + GetByID(ctx context.Context, id int64) (*model.Token, error) GetByKey(ctx context.Context, key string) (*model.Token, error) - GetByUserID(ctx context.Context, userID int) (*model.Token, error) + GetByUserID(ctx context.Context, userID int64) (*model.Token, error) Update(ctx context.Context, token *model.Token) error UpdateWithCondition(ctx context.Context, token *model.Token, filters map[string]interface{}, updates map[string]interface{}) error - Delete(ctx context.Context, id int) error - List(ctx context.Context, offset, limit int) ([]model.Token, error) - ListWithFilters(ctx context.Context, offset, limit int, filters map[string]interface{}) ([]model.Token, int64, error) + Delete(ctx context.Context, id int64, condition map[string]interface{}) error + List(ctx context.Context, limit, offset int) ([]*model.Token, error) + ListWithFilters(ctx context.Context, limit, offset int, filters map[string]interface{}) ([]*model.Token, int64, error) Disable(ctx context.Context, id int) error Enable(ctx context.Context, id int) error - BatchDisable(ctx context.Context, ids []int) error - BatchEnable(ctx context.Context, ids []int) error - BatchDelete(ctx context.Context, ids []int) error + BatchDisable(ctx context.Context, ids []int64, filters map[string]interface{}) error + BatchEnable(ctx context.Context, ids []int64, filters map[string]interface{}) error + BatchDelete(ctx context.Context, ids []int64, filters map[string]interface{}) error } type TokenDAO struct { @@ -47,9 +47,9 @@ func (dao *TokenDAO) Create(ctx context.Context, token *model.Token) error { } // 根据 ID 获取 Token -func (dao *TokenDAO) GetByID(ctx context.Context, id int) (*model.Token, error) { +func (dao *TokenDAO) GetByID(ctx context.Context, id int64) (*model.Token, error) { var token model.Token - err := dao.db.WithContext(ctx).First(&token, id).Error + err := dao.db.WithContext(ctx).Preload("User").First(&token, id).Error if err != nil { return nil, err } @@ -68,7 +68,7 @@ func (dao *TokenDAO) GetByKey(ctx context.Context, key string) (*model.Token, er } // 根据 UserID 获取 Token -func (dao *TokenDAO) GetByUserID(ctx context.Context, userID int) (*model.Token, error) { +func (dao *TokenDAO) GetByUserID(ctx context.Context, userID int64) (*model.Token, error) { var token model.Token err := dao.db.WithContext(ctx).Preload("User").Where("user_id = ?", userID).Find(&token).Error if err != nil { @@ -98,14 +98,21 @@ func (dao *TokenDAO) UpdateWithCondition(ctx context.Context, token *model.Token } // DeleteToken 删除 Token -func (dao *TokenDAO) Delete(ctx context.Context, id int) error { - return dao.db.WithContext(ctx).Delete(&model.Token{}, id).Error +func (dao *TokenDAO) Delete(ctx context.Context, id int64, condition map[string]interface{}) error { + if id <= 0 { + return errors.New("id is invalid") + } + query := dao.db.WithContext(ctx).Where("id = ?", id) + for key, value := range condition { + query = query.Where(key, value) + } + return query.Unscoped().Delete(&model.Token{}).Error } // ListTokens 获取 Token 列表 -func (dao *TokenDAO) List(ctx context.Context, offset, limit int) ([]model.Token, error) { - var tokens []model.Token - err := dao.db.WithContext(ctx).Offset(offset).Limit(limit).Find(&tokens).Error +func (dao *TokenDAO) List(ctx context.Context, limit, offset int) ([]*model.Token, error) { + var tokens []*model.Token + err := dao.db.WithContext(ctx).Limit(limit).Offset(offset).Find(&tokens).Error if err != nil { return nil, err } @@ -113,22 +120,22 @@ func (dao *TokenDAO) List(ctx context.Context, offset, limit int) ([]model.Token } // ListTokensWithFilters 获取 Token 列表,支持过滤 -func (dao *TokenDAO) ListWithFilters(ctx context.Context, offset, limit int, filters map[string]interface{}) ([]model.Token, int64, error) { - var tokens []model.Token +func (dao *TokenDAO) ListWithFilters(ctx context.Context, limit, offset int, filters map[string]interface{}) ([]*model.Token, int64, error) { + var tokens []*model.Token var count int64 db := dao.db.WithContext(ctx) - for key, value := range filters { - db = db.Where(key+" = ?", value) + if filters != nil { + for k, v := range filters { + db = db.Where(k, v) + } } - if err := db.Offset(offset).Limit(limit).Find(&tokens).Error; err != nil { + if err := db.Limit(limit).Offset(offset).Find(&tokens).Error; err != nil { return nil, 0, err } - if err := db.Model(&model.Token{}).Count(&count).Error; err != nil { - return nil, 0, err - } + db.Model(&model.Token{}).Count(&count) return tokens, count, nil } @@ -144,18 +151,31 @@ func (dao *TokenDAO) Enable(ctx context.Context, id int) error { } // BatchDisableTokens 批量禁用 Token -func (dao *TokenDAO) BatchDisable(ctx context.Context, ids []int) error { - return dao.db.WithContext(ctx).Model(&model.Token{}).Where("id IN ?", ids).Update("status", false).Error +func (dao *TokenDAO) BatchDisable(ctx context.Context, ids []int64, filters map[string]interface{}) error { + query := dao.db.WithContext(ctx).Model(&model.Token{}).Where("id IN ?", ids) + for key, value := range filters { + query = query.Where(key, value) + } + return query.Update("active", false).Error } // BatchEnableTokens 批量启用 Token -func (dao *TokenDAO) BatchEnable(ctx context.Context, ids []int) error { - return dao.db.WithContext(ctx).Model(&model.Token{}).Where("id IN ?", ids).Update("status", true).Error +func (dao *TokenDAO) BatchEnable(ctx context.Context, ids []int64, filters map[string]interface{}) error { + query := dao.db.WithContext(ctx).Model(&model.Token{}).Where("id IN ?", ids) + for key, value := range filters { + query = query.Where(key, value) + } + return query.Update("active", true).Error } // BatchDeleteTokens 批量删除 Token -func (dao *TokenDAO) BatchDelete(ctx context.Context, ids []int) error { - return dao.db.WithContext(ctx).Where("id IN ?", ids).Delete(&model.Token{}).Error +func (dao *TokenDAO) BatchDelete(ctx context.Context, ids []int64, filters map[string]interface{}) error { + query := dao.db.Unscoped().WithContext(ctx).Where("id IN ?", ids) + for key, value := range filters { + query = query.Where(key, value) + } + return query.Delete(&model.Token{}).Error + // return dao.db.WithContext(ctx).Where("name != 'default' AND id IN ?", ids).Delete(&model.Token{}).Error } // 检查 token 是否有效 @@ -170,7 +190,7 @@ func (dao *TokenDAO) IsValid(ctx context.Context, key string) (bool, error) { } return false, err } - if token.User.Status != consts.StatusEnabled || (token.User.UnlimitedQuota == 1 && token.User.Quota <= 0) { + if token.User.Status != consts.StatusEnabled || (*token.User.UnlimitedQuota && *token.User.Quota <= 0) { return false, nil } diff --git a/team/dao/usage.go b/internal/dao/usage.go similarity index 95% rename from team/dao/usage.go rename to internal/dao/usage.go index 9c5ea14..014c6f0 100644 --- a/team/dao/usage.go +++ b/internal/dao/usage.go @@ -3,10 +3,9 @@ package dao import ( "context" "fmt" - "opencatd-open/pkg/store" - "opencatd-open/team/consts" - dto "opencatd-open/team/dto/team" - "opencatd-open/team/model" + dto "opencatd-open/internal/dto/team" + "opencatd-open/internal/model" + "opencatd-open/pkg/config" "time" "gorm.io/gorm" @@ -58,14 +57,15 @@ type UsageDAO struct { } type DailyUsageDAO struct { - db *gorm.DB + cfg *config.Config + db *gorm.DB } -func NewUsageDAO(db *gorm.DB) *UsageDAO { +func NewUsageDAO(cfg *config.Config, db *gorm.DB) *UsageDAO { return &UsageDAO{db: db} } -func NewDailyUsageDAO(db *gorm.DB) *DailyUsageDAO { +func NewDailyUsageDAO(cfg *config.Config, db *gorm.DB) *DailyUsageDAO { return &DailyUsageDAO{db: db} } @@ -206,8 +206,8 @@ func (d *DailyUsageDAO) UpsertDailyUsage(ctx context.Context, usage *model.Usage db := d.db.WithContext(ctx) - switch store.DBType { - case consts.DBTypeMySQL: + switch d.cfg.DB_Type { + case "mysql": // MySQL: INSERT ... ON DUPLICATE KEY UPDATE return db.Clauses(clause.OnConflict{ Columns: []clause.Column{ @@ -221,7 +221,7 @@ func (d *DailyUsageDAO) UpsertDailyUsage(ctx context.Context, usage *model.Usage DoUpdates: clause.Assignments(updateColumns), }).Create(dailyUsage).Error - case consts.DBTypePostgreSQL: + case "postgres": // PostgreSQL: INSERT ... ON CONFLICT DO UPDATE updateColumns := map[string]interface{}{ "prompt_tokens": gorm.Expr("daily_usages.prompt_tokens + EXCLUDED.prompt_tokens"), @@ -239,8 +239,9 @@ func (d *DailyUsageDAO) UpsertDailyUsage(ctx context.Context, usage *model.Usage }, DoUpdates: clause.Assignments(updateColumns), }).Create(dailyUsage).Error - case consts.DBTypeSQLite: - // SQLite: 需要使用事务来模拟 upsert + case "sqlite": + fallthrough + default: return db.Transaction(func(tx *gorm.DB) error { var existing model.DailyUsage err := tx.Where("user_id = ? AND token_id = ? AND capability = ? AND date = ? AND model = ? AND stream = ?", @@ -261,9 +262,6 @@ func (d *DailyUsageDAO) UpsertDailyUsage(ctx context.Context, usage *model.Usage "total_tokens": gorm.Expr("total_tokens + ?", usage.TotalTokens), }).Error }) - - default: - return fmt.Errorf("不支持的数据库类型: %s", store.DBType) } } diff --git a/team/dao/user.go b/internal/dao/user.go similarity index 56% rename from team/dao/user.go rename to internal/dao/user.go index 1d00060..ec94ae6 100644 --- a/team/dao/user.go +++ b/internal/dao/user.go @@ -3,8 +3,7 @@ package dao import ( "errors" "fmt" - "opencatd-open/team/consts" - "opencatd-open/team/model" + "opencatd-open/internal/model" "time" "gorm.io/gorm" @@ -20,13 +19,12 @@ type UserRepository interface { GetByUsername(username string) (*model.User, error) Update(user *model.User) error Delete(id int64) error - List(offset, limit int) ([]model.User, error) - ListWithFilters(offset, limit int, filters map[string]interface{}) ([]model.User, int64, error) - Enable(id int64) error - Disable(id int64) error - BatchEnable(ids []int64) error - BatchDisable(ids []int64) error - BatchDelete(ids []int64) error + List(limit, offset int, condition map[string]interface{}) ([]model.User, int64, error) + // Enable(id int64) error + // Disable(id int64) error + BatchEnable(ids []int64, condition []string) error + BatchDisable(ids []int64, condition []string) error + BatchDelete(ids []int64, condition []string) error } type UserDAO struct { @@ -42,6 +40,7 @@ func (dao *UserDAO) Create(user *model.User) error { if user == nil { return errors.New("user is nil") } + fmt.Println(*user) return dao.db.Transaction(func(tx *gorm.DB) error { // 创建用户 @@ -57,7 +56,7 @@ func (dao *UserDAO) Create(user *model.User) error { func (dao *UserDAO) GetByID(id int64) (*model.User, error) { var user model.User // err := dao.db.First(&user, id).Error - err := dao.db.Preload("Tokens").First(&user, id).Error + err := dao.db.Preload("Tokens", "user_id = ?", id).First(&user, id).Error if err != nil { return nil, err } @@ -80,84 +79,85 @@ func (dao *UserDAO) Update(user *model.User) error { if user == nil { return errors.New("user is nil") } + user.UpdatedAt = time.Now().Unix() return dao.db.Save(user).Error } // 删除用户 func (dao *UserDAO) Delete(id int64) error { - return dao.db.Delete(&model.User{}, id).Error + return dao.db.Unscoped().Delete(&model.User{}, id).Error // return dao.db.Model(&model.User{}).Where("id = ?", id).Update("status", 2).Error } // 获取用户列表 -func (dao *UserDAO) List(offset, limit int) ([]model.User, error) { - var users []model.User - err := dao.db.Preload("Tokens").Offset(offset).Limit(limit).Find(&users).Error - if err != nil { - return nil, err +func (dao *UserDAO) List(limit, offset int, condition map[string]interface{}) ([]model.User, int64, error) { + if offset < 0 { + offset = 0 } - return users, nil -} - -// 获取用户列表,带过滤条件 -func (dao *UserDAO) ListWithFilters(offset, limit int, filters map[string]interface{}) ([]model.User, int64, error) { var users []model.User var total int64 - // 构建查询 - query := dao.db.Model(&model.User{}) + query := dao.db.Preload("Tokens").Model(&model.User{}) - // 添加过滤条件 - for key, value := range filters { - query = query.Where(key+" = ?", value) + for k, v := range condition { + query = query.Where(k, v) } - - // 查询总数 - if err := query.Count(&total).Error; err != nil { - return nil, 0, err - } - - // 分页查询 - err := query.Offset(offset).Limit(limit).Find(&users).Error + err := query.Limit(limit).Offset(offset).Find(&users).Error if err != nil { return nil, 0, err } + query = dao.db.Model(&model.User{}) + for k, v := range condition { + query = query.Where(k, v) + } + query.Count(&total) return users, total, nil } // 启用User -func (dao *UserDAO) Enable(id int64) error { - return dao.db.Model(&model.User{}).Where("id = ?", id).Update("status", consts.StatusEnabled).Error +func (dao *UserDAO) Enable(id uint) error { + return dao.db.Model(&model.User{}).Where("id = ?", id).Update("active", true).Error } // 禁用User -func (dao *UserDAO) Disable(id int64) error { - return dao.db.Model(&model.User{}).Where("id = ?", id).Update("status", consts.StatusDisabled).Error +func (dao *UserDAO) Disable(id uint) error { + return dao.db.Model(&model.User{}).Where("id = ?", id).Update("active", false).Error } // 批量启用User -func (dao *UserDAO) BatchEnable(ids []int64) error { +func (dao *UserDAO) BatchEnable(ids []int64, condition []string) error { if len(ids) == 0 { return errors.New("ids is empty") } - return dao.db.Model(&model.User{}).Where("id IN ?", ids).Update("status", 0).Error + query := dao.db.Model(&model.User{}).Where("id IN ?", ids) + for _, value := range condition { + query = query.Where(value) + } + return query.Update("active", true).Error } // 批量禁用User -func (dao *UserDAO) BatchDisable(ids []int64) error { +func (dao *UserDAO) BatchDisable(ids []int64, condition []string) error { if len(ids) == 0 { return errors.New("ids is empty") } - return dao.db.Model(&model.User{}).Where("id IN ?", ids).Update("status", 1).Error + query := dao.db.Model(&model.User{}).Where("id IN ?", ids) + for _, value := range condition { + query = query.Where(value) + } + return query.Update("active", false).Error } // 批量删除用户 -func (dao *UserDAO) BatchDelete(ids []int64) error { +func (dao *UserDAO) BatchDelete(ids []int64, condition []string) error { if len(ids) == 0 { return errors.New("ids is empty") } - return dao.db.Where("id IN ?", ids).Delete(&model.User{}).Error - // return dao.db.Model(&model.User{}).Where("id IN ?", ids).Update("status", 2).Error + query := dao.db.Unscoped().Where("id IN ?", ids) + for _, value := range condition { + query = query.Where(value) + } + return query.Delete(&model.User{}).Error } diff --git a/internal/dto/batch.go b/internal/dto/batch.go new file mode 100644 index 0000000..4778a56 --- /dev/null +++ b/internal/dto/batch.go @@ -0,0 +1,6 @@ +package dto + +type BatchIDRequest struct { + UserID *int64 `json:"user_id"` + IDs []int64 `json:"ids" binding:"required"` +} diff --git a/team/dto/openai/err_resp.go b/internal/dto/error.go similarity index 53% rename from team/dto/openai/err_resp.go rename to internal/dto/error.go index 3b82c23..214f3c9 100644 --- a/team/dto/openai/err_resp.go +++ b/internal/dto/error.go @@ -1,22 +1,20 @@ package dto import ( - "net/http" - "github.com/gin-gonic/gin" ) type Error struct { + Code int `json:"code,omitempty"` Message string `json:"message,omitempty"` - Code string `json:"code,omitempty"` } -func WarpErrAsOpenAI(c *gin.Context, msg string, code string) { - c.JSON(http.StatusForbidden, gin.H{ +func WrapErrorAsOpenAI(c *gin.Context, code int, msg string) { + c.JSON(code, gin.H{ "error": Error{ - Message: msg, Code: code, + Message: msg, }, }) - return + c.Abort() } diff --git a/internal/dto/key.go b/internal/dto/key.go new file mode 100644 index 0000000..f323e34 --- /dev/null +++ b/internal/dto/key.go @@ -0,0 +1,107 @@ +package dto + +import ( + "errors" + "regexp" + "time" + + validation "github.com/go-ozzo/ozzo-validation/v4" +) + +// TeamKey 结构体定义 +type TeamKey struct { + ID *int64 `json:"id,omitempty"` + UserID *int64 `json:"userID,omitempty"` + Name *string `json:"name,omitempty"` // 必须 + Key *string `json:"key,omitempty"` + Status *int64 `json:"status,omitempty"` // 默认1 允许,0禁止 + Quota *int64 `json:"quota,omitempty"` // UnlimitedQuota不为1 的时候必须 + UnlimitedQuota *bool `json:"unlimitedQuota,omitempty"` // 默认1 不限制,0限制 + UsedQuota *int64 `json:"usedQuota,omitempty"` + CreatedAt *int64 `json:"createdAt,omitempty"` + ExpiredAt *int64 `json:"expiredAt,omitempty"` // 可选 +} + +// DefaultTeamKey 创建一个具有默认值的 TeamKey +func DefaultTeamKey() TeamKey { + status := int64(1) // 默认允许 + unlimitedQuota := true // 默认不限制 + createdAt := time.Now().Unix() + + return TeamKey{ + Status: &status, + UnlimitedQuota: &unlimitedQuota, + CreatedAt: &createdAt, + } +} + +// Validate 验证 TeamKey 结构体 +func (t TeamKey) Validate() error { + // 自定义验证规则 + var quotaRule validation.Rule = validation.Skip + if t.UnlimitedQuota != nil && !*t.UnlimitedQuota { + quotaRule = validation.Required.Error("当 UnlimitedQuota 为 false 时,Quota 是必填项") + } + + // 过期时间校验 + var expiredAtRule validation.Rule = validation.Skip + if t.ExpiredAt != nil { + expiredAtRule = validation.Min(time.Now().Unix()).Error("过期时间不能早于当前时间") + } + + return validation.ValidateStruct(&t, + // ID 通常由系统生成,不需要验证 + + // UserID 可选,但如果提供必须大于 0 + validation.Field(&t.UserID, + validation.When(t.UserID != nil, validation.Min(int64(1)).Error("用户 ID 必须大于 0"))), + + // Name 是必填字段 + validation.Field(&t.Name, + validation.Required.Error("名称不能为空"), + validation.When(t.Name != nil, validation.Length(1, 100).Error("名称长度应在 1-100 之间"))), + + // Key 可选,但如果提供需要符合特定格式 + validation.Field(&t.Key, + validation.When(t.Key != nil, + validation.Length(1, 255).Error("Key 长度应在 1-255 之间")), + validation.Match(regexp.MustCompile(`^[^\s]+$`)).Error("Key 不能包含空格"), + ), + + // Status 只能是 0 或 1 + validation.Field(&t.Status, + validation.When(t.Status != nil, validation.In(int64(0), int64(1)).Error("状态只能是 0(禁止) 或 1(允许)"))), + + // Quota 要求依赖于 UnlimitedQuota + validation.Field(&t.Quota, quotaRule, + validation.When(t.Quota != nil, validation.Min(int64(1)).Error("配额必须大于 0"))), + + // UnlimitedQuota 是否限制配额 + validation.Field(&t.UnlimitedQuota), + + // UsedQuota 系统维护,不需要验证 + validation.Field(&t.UsedQuota, + validation.When(t.UsedQuota != nil, validation.Min(int64(0)).Error("已使用配额不能为负数"))), + + // CreatedAt 系统维护,不需要验证 + validation.Field(&t.CreatedAt), + + // ExpiredAt 可选,但如果提供必须大于当前时间 + validation.Field(&t.ExpiredAt, expiredAtRule), + ) +} + +// ValidateCreate 创建时的特殊验证 +func (t TeamKey) ValidateCreate() error { + // 首先进行基本验证 + if err := t.Validate(); err != nil { + return err + } + + // 创建时的额外验证 + if t.Name == nil { + return errors.New("创建时必须提供名称") + } + + return nil +} diff --git a/internal/dto/passkey.go b/internal/dto/passkey.go new file mode 100644 index 0000000..5dd7aa5 --- /dev/null +++ b/internal/dto/passkey.go @@ -0,0 +1,11 @@ +package dto + +type Passkey struct { + ID int64 `json:"id" gorm:"column:id;primaryKey;autoIncrement"` + Name string `json:"name" gorm:"column:name"` // 凭证名称,用于用户识别不同的设备 + SignCount uint32 `json:"sign_count" gorm:"column:sign_count"` // 签名计数器,用于防止重放攻击 + DeviceType string `json:"device_type" gorm:"column:device_type"` // 设备类型,如"platform"或"cross-platform" + LastUsedAt int64 `json:"last_used_at" gorm:"column:last_used_at"` // 最后使用时间 + CreatedAt int64 `json:"created_at,omitempty" gorm:"autoCreateTime"` + UpdatedAt int64 `json:"updated_at,omitempty" gorm:"autoUpdateTime"` +} diff --git a/internal/dto/response.go b/internal/dto/response.go new file mode 100644 index 0000000..e2e4820 --- /dev/null +++ b/internal/dto/response.go @@ -0,0 +1,28 @@ +package dto + +import ( + "net/http" + + "github.com/gin-gonic/gin" +) + +type Result struct { + Code int `json:"code"` + Msg string `json:"msg"` + Data any `json:"data,omitempty"` +} + +func Success(ctx *gin.Context, data any) { + ctx.JSON(http.StatusOK, Result{ + Code: 200, + Data: data, + Msg: "success", + }) +} + +func Fail(c *gin.Context, code int, err string) { + c.AbortWithStatusJSON(code, gin.H{ + "code": code, + "error": err, + }) +} diff --git a/team/dto/team/team.go b/internal/dto/team/team.go similarity index 68% rename from team/dto/team/team.go rename to internal/dto/team/team.go index c9994ab..41c80a7 100644 --- a/team/dto/team/team.go +++ b/internal/dto/team/team.go @@ -1,12 +1,12 @@ package dto import ( - "opencatd-open/team/consts" - "opencatd-open/team/model" + "opencatd-open/internal/model" + "opencatd-open/internal/utils" ) type UserInfo struct { - ID int `json:"id"` + ID int64 `json:"id"` Name string `json:"name"` Token string `json:"token"` Status *bool `json:"status,omitempty"` @@ -24,7 +24,7 @@ func (u UserInfo) HasStatusUpdate() bool { return u.Status != nil } -type KeyInfo struct { +type ApiKeyInfo struct { ID int `json:"id,omitempty"` Key string `json:"key,omitempty"` Name string `json:"name,omitempty"` @@ -34,43 +34,43 @@ type KeyInfo struct { } // 添加辅助方法判断字段是否需要更新 -func (k KeyInfo) HasNameUpdate() bool { +func (k ApiKeyInfo) HasNameUpdate() bool { return k.Name != "" } -func (k KeyInfo) HasKeyUpdate() bool { +func (k ApiKeyInfo) HasKeyUpdate() bool { return k.Key != "" } -func (k KeyInfo) HasStatusUpdate() bool { +func (k ApiKeyInfo) HasStatusUpdate() bool { return k.Status != nil } -func (k KeyInfo) HasApiTypeUpdate() bool { +func (k ApiKeyInfo) HasApiTypeUpdate() bool { return k.ApiType != "" } // 辅助函数:统一处理字段更新 -func (update *KeyInfo) UpdateFields(existing *model.ApiKey) *model.ApiKey { +func (update *ApiKeyInfo) UpdateFields(existing *model.ApiKey) *model.ApiKey { result := &model.ApiKey{ ID: existing.ID, Name: existing.Name, // 默认保持原值 ApiType: existing.ApiType, // 默认保持原值 ApiKey: existing.ApiKey, // 默认保持原值 - Status: existing.Status, // 默认保持原值 + Active: existing.Active, // 默认保持原值 } if update.HasNameUpdate() { - result.Name = update.Name + result.Name = utils.ToPtr(update.Name) } if update.HasKeyUpdate() { - result.ApiKey = update.Key + result.ApiKey = utils.ToPtr(update.Key) } if update.HasStatusUpdate() { - result.Status = consts.OpenOrClose(*update.Status) + result.Active = update.Status } if update.HasApiTypeUpdate() { - result.ApiType = update.ApiType + result.ApiType = utils.ToPtr(update.ApiType) } return result diff --git a/internal/dto/user.go b/internal/dto/user.go new file mode 100644 index 0000000..e83605f --- /dev/null +++ b/internal/dto/user.go @@ -0,0 +1,16 @@ +package dto + +type User struct { + Username string `json:"username" binding:"required,min=3,max=32"` + Password string `json:"password" binding:"required,min=4"` +} + +type Auth struct { + Token string `json:"token"` + ExpiresIn int64 `json:"expires_in"` +} + +type ChangePassword struct { + Password string `json:"password" binding:"required,min=4"` + NewPassword string `json:"newpassword" binding:"required,min=4"` +} diff --git a/internal/model/apikey.go b/internal/model/apikey.go new file mode 100644 index 0000000..47260b3 --- /dev/null +++ b/internal/model/apikey.go @@ -0,0 +1,50 @@ +package model + +import "github.com/lib/pq" //pq.StringArray + +type ApiKey_PG struct { + ID int64 `gorm:"column:id;primaryKey;autoIncrement" json:"id,omitempty"` + Name *string `gorm:"column:name;not null;unique;index:idx_apikey_name" json:"name,omitempty"` + ApiType *string `gorm:"column:apitype;not null;index:idx_apikey_apitype" json:"type,omitempty"` + ApiKey *string `gorm:"column:apikey;not null;index:idx_apikey_apikey" json:"apikey,omitempty"` + Active *bool `gorm:"column:active;default:true" json:"active,omitempty"` + Endpoint *string `gorm:"column:endpoint" json:"endpoint,omitempty"` + ResourceNmae *string `gorm:"column:resource_name" json:"resource_name,omitempty"` + DeploymentName *string `gorm:"column:deployment_name" json:"deployment_name,omitempty"` + ApiSecret *string `gorm:"column:api_secret" json:"api_secret,omitempty"` + ModelPrefix *string `gorm:"column:model_prefix" json:"model_prefix,omitempty"` + ModelAlias *string `gorm:"column:model_alias" json:"model_alias,omitempty"` + Parameters *string `gorm:"column:parameters" json:"parameters,omitempty"` + SupportModelsArray pq.StringArray `gorm:"column:support_models;type:text[]" json:"support_models_array,omitempty"` + SupportModels *string `gorm:"-" json:"support_models,omitempty"` + CreatedAt int64 `gorm:"column:created_at;autoUpdateTime" json:"created_at,omitempty"` + UpdatedAt int64 `gorm:"column:updated_at;autoCreateTime" json:"updated_at,omitempty"` +} + +func (ApiKey_PG) TableName() string { + return "apikeys" +} + +type ApiKey struct { + ID int64 `gorm:"column:id;primaryKey;autoIncrement" json:"id,omitempty"` + Name *string `gorm:"column:name;not null;unique;index:idx_apikey_name" json:"name,omitempty"` + ApiType *string `gorm:"column:apitype;not null;index:idx_apikey_apitype" json:"type,omitempty"` + ApiKey *string `gorm:"column:apikey;not null;index:idx_apikey_apikey" json:"apikey,omitempty"` + Active *bool `gorm:"column:active;default:true" json:"active,omitempty"` + Endpoint *string `gorm:"column:endpoint" json:"endpoint,omitempty"` + ResourceNmae *string `gorm:"column:resource_name" json:"resource_name,omitempty"` + DeploymentName *string `gorm:"column:deployment_name" json:"deployment_name,omitempty"` + AccessKey *string `gorm:"column:access_key" json:"access_key,omitempty"` + SecretKey *string `gorm:"column:secret_key" json:"secret_key,omitempty"` + ModelPrefix *string `gorm:"column:model_prefix" json:"model_prefix,omitempty"` + ModelAlias *string `gorm:"column:model_alias" json:"model_alias,omitempty"` + Parameters *string `gorm:"column:parameters" json:"parameters,omitempty"` + SupportModels *string `gorm:"column:support_models;type:json" json:"support_models,omitempty"` + SupportModelsArray []string `gorm:"-" json:"support_models_array,omitempty"` + CreatedAt int64 `gorm:"column:created_at;autoUpdateTime" json:"created_at,omitempty"` + UpdatedAt int64 `gorm:"column:updated_at;autoCreateTime" json:"updated_at,omitempty"` +} + +func (ApiKey) TableName() string { + return "apikeys" +} diff --git a/internal/model/passkey.go b/internal/model/passkey.go new file mode 100644 index 0000000..d09cab1 --- /dev/null +++ b/internal/model/passkey.go @@ -0,0 +1,38 @@ +package model + +import ( + "time" +) + +// Passkey 用户凭证密钥模型 +type Passkey struct { + ID int64 `json:"id" gorm:"column:id;primaryKey;autoIncrement"` + UserID int64 `json:"user_id" gorm:"column:user_id;index"` + CredentialID string `json:"credential_id" gorm:"column:credential_id;index"` // 凭证ID,用于识别特定的passkey + PublicKey string `json:"public_key" gorm:"column:public_key"` // 公钥,用于验证签名 + AttestationType string `json:"attestation_type" gorm:"column:attestation_type"` // 证明类型 + AAGUID string `json:"aaguid" gorm:"column:aaguid"` // 认证器标识符 + SignCount uint32 `json:"sign_count" gorm:"column:sign_count"` // 签名计数器,用于防止重放攻击 + Name string `json:"name" gorm:"column:name"` // 凭证名称,用于用户识别不同的设备 + DeviceType string `json:"device_type" gorm:"column:device_type"` // 设备类型 + BackupEligible bool `json:"backup_eligible" gorm:"column:backup_eligible"` // 是否可备份 + BackupState bool `json:"backup_state" gorm:"backup_state"` // 备份状态 + Transport string `json:"transport" gorm:"column:transport"` // 传输方式 (如usb、nfc、ble等) + LastUsedAt int64 `json:"last_used_at" gorm:"column:last_used_at;autoUpdateTime"` // 最后使用时间 + CreatedAt int64 `json:"created_at,omitempty" gorm:"column:created_at;autoCreateTime"` + UpdatedAt int64 `json:"updated_at,omitempty" gorm:"column:updated_at;autoUpdateTime"` + + // 关联用户模型(不存入数据库) + User User `json:"-" gorm:"foreignKey:UserID;references:ID;constraint:OnUpdate:CASCADE,OnDelete:CASCADE"` +} + +// 创建表结构 +func (Passkey) TableName() string { + return "passkeys" +} + +// UpdateSignCount 更新签名计数器和最后使用时间 +func (p *Passkey) UpdateSignCount(count uint32) { + p.SignCount = count + p.LastUsedAt = time.Now().Unix() +} diff --git a/internal/model/token.go b/internal/model/token.go new file mode 100644 index 0000000..60e9733 --- /dev/null +++ b/internal/model/token.go @@ -0,0 +1,22 @@ +package model + +// 用户的token +type Token struct { + ID int64 `gorm:"column:id;primaryKey;autoIncrement" json:"id,omitempty"` + UserID int64 `gorm:"column:user_id;not null;index:idx_token_user_id" json:"userid,omitempty"` + Name string `gorm:"column:name;not null;index:idx_token_name" json:"name,omitempty" binding:"required,min=1,max=20"` + Key string `gorm:"column:key;not null;uniqueIndex:idx_token_key;comment:token key" json:"key,omitempty"` + Active *bool `gorm:"column:active;default:true" json:"active,omitempty"` // + Quota *int64 `gorm:"column:quota;type:bigint;default:0" json:"quota,omitempty"` // default 0 + UnlimitedQuota *bool `gorm:"column:unlimited_quota;default:true" json:"unlimited_quota,omitempty"` // set Quota 1 unlimited + UsedQuota *int64 `gorm:"column:used_quota;type:bigint;default:0" json:"used_quota,omitempty"` + ExpiredAt *int64 `gorm:"column:expired_at;type:bigint;default:0" json:"expired_at,omitempty"` + NeverExpired *bool `gorm:"column:never_expires;type:bigint;" json:"never_expires,omitempty"` + CreatedAt int64 `gorm:"column:created_at;type:bigint;autoCreateTime" json:"created_at,omitempty"` + LastUsedAt int64 `gorm:"column:lastused_at;type:bigint;autoUpdateTime" json:"lastused_at,omitempty"` + User *User `gorm:"foreignKey:UserID;constraint:OnUpdate:CASCADE,OnDelete:CASCADE" json:"-"` +} + +func (Token) TableName() string { + return "tokens" +} diff --git a/team/model/usage.go b/internal/model/usage.go similarity index 100% rename from team/model/usage.go rename to internal/model/usage.go diff --git a/internal/model/user.go b/internal/model/user.go new file mode 100644 index 0000000..f084d60 --- /dev/null +++ b/internal/model/user.go @@ -0,0 +1,53 @@ +package model + +import ( + "opencatd-open/internal/consts" + "time" +) + +type User struct { + ID int64 `json:"id" gorm:"column:id;primaryKey;autoIncrement"` + Name string `json:"name" gorm:"column:name;index"` + Username string `json:"username" gorm:"column:username;unique;index"` + Password string `json:"-" gorm:"column:password;"` + NewPassword string `json:"newpassword" gorm:"-"` + Role *consts.UserRole `json:"role" gorm:"column:role;type:int;default:0"` // default user 0-10-20 + Active *bool `json:"active" gorm:"column:active;default:true;"` + Status int `json:"status" gorm:"column:status;type:int;default:1"` // disabled 0, enabled 1, deleted 2 + AvatarURL string `json:"avatar_url" gorm:"column:avatar_url;type:varchar(255)"` + EmailVerified *bool `json:"email_verified" gorm:"column:email_verified;default:false"` + Email string `json:"email" gorm:"column:email;type:varchar(255);index"` + Quota *float32 `json:"quota" gorm:"column:quota;bigint;default:0"` // default unlimited + UsedQuota *float32 `json:"used_quota" gorm:"column:used_quota;bigint;default:0"` // default 0 + UnlimitedQuota *bool `json:"unlimited_quota" gorm:"column:unlimited_quota;default:true;"` // 0 limited , 1 unlimited + Timezone string `json:"timezone" gorm:"column:timezone;type:varchar(50)"` + Language string `json:"language" gorm:"column:language;type:varchar(50)"` + + // 添加一对多关系 + Tokens []Token `json:"tokens" gorm:"foreignKey:UserID;references:ID;constraint:OnUpdate:CASCADE,OnDelete:CASCADE"` + Passkeys []Passkey `json:"passkeys" gorm:"foreignKey:UserID;references:ID;constraint:OnUpdate:CASCADE,OnDelete:CASCADE"` + + CreatedAt int64 `json:"created_at,omitempty" gorm:"autoCreateTime"` + UpdatedAt int64 `json:"updated_at,omitempty" gorm:"autoUpdateTime"` +} + +func (User) TableName() string { + return "users" +} + +type Session struct { + ID int64 `json:"id" gorm:"primaryKey;autoIncrement"` + UserID int64 `json:"user_id" gorm:"index:idx_user_id"` + Token string `json:"token" gorm:"type:varchar(64);uniqueIndex"` + DeviceType string `json:"device_type" gorm:"type:varchar(100);default:''"` + DeviceName string `json:"device_name" gorm:"type:varchar(100);default:''"` + LastActiveAt time.Time `json:"last_active_at" gorm:"type:timestamp;default:CURRENT_TIMESTAMP"` + LogoutAt time.Time `json:"logout_at" gorm:"type:timestamp;null"` + + CreatedAt time.Time `json:"created_at" gorm:"type:timestamp;not null;default:CURRENT_TIMESTAMP"` + UpdatedAt time.Time `json:"updated_at" gorm:"type:timestamp;not null;default:CURRENT_TIMESTAMP;update:CURRENT_TIMESTAMP"` +} + +func (Session) TableName() string { + return "sessions" +} diff --git a/internal/service/apikey.go b/internal/service/apikey.go new file mode 100644 index 0000000..bda4ab8 --- /dev/null +++ b/internal/service/apikey.go @@ -0,0 +1,92 @@ +package service + +import ( + "context" + "fmt" + "opencatd-open/internal/dao" + "opencatd-open/internal/model" + "opencatd-open/internal/utils" + + "gorm.io/gorm" +) + +type ApiKeyServiceImpl struct { + db *gorm.DB + apiKeyRepo dao.ApiKeyRepository +} + +func NewApiKeyService(db *gorm.DB, apiKeyDao dao.ApiKeyRepository) *ApiKeyServiceImpl { + return &ApiKeyServiceImpl{db: db, apiKeyRepo: apiKeyDao} +} + +func (s *ApiKeyServiceImpl) CreateApiKey(ctx context.Context, apikey *model.ApiKey) error { + return s.apiKeyRepo.Create(apikey) +} + +func (s *ApiKeyServiceImpl) GetApiKey(ctx context.Context, id int64) (*model.ApiKey, error) { + return s.apiKeyRepo.GetByID(id) +} + +func (s *ApiKeyServiceImpl) ListApiKey(ctx context.Context, limit, offset int, active []string) ([]*model.ApiKey, int64, error) { + var conditions = make(map[string]interface{}) + if len(active) > 0 { + conditions["active IN ?"] = utils.StringToBool(active) + } + return s.apiKeyRepo.ListWithFilters(limit, offset, conditions) +} + +func (s *ApiKeyServiceImpl) UpdateApiKey(ctx context.Context, apikey *model.ApiKey) error { + _key, err := s.apiKeyRepo.GetByID(apikey.ID) + if err != nil { + return fmt.Errorf("get apikey failed: %v", err) + } + if apikey.ApiKey != nil { + _key.ApiKey = apikey.ApiKey + } + if apikey.Active != nil { + _key.Active = apikey.Active + } + if apikey.Endpoint != nil { + _key.Endpoint = apikey.Endpoint + } + if apikey.ResourceNmae != nil { + _key.ResourceNmae = apikey.ResourceNmae + } + if apikey.DeploymentName != nil { + _key.DeploymentName = apikey.DeploymentName + } + if apikey.AccessKey != nil { + _key.AccessKey = apikey.AccessKey + } + if apikey.SecretKey != nil { + _key.SecretKey = apikey.SecretKey + } + if apikey.ModelAlias != nil { + _key.ModelAlias = apikey.ModelAlias + } + if apikey.ModelPrefix != nil { + _key.ModelPrefix = apikey.ModelPrefix + } + if apikey.Parameters != nil { + _key.Parameters = apikey.Parameters + } + if apikey.SupportModels != nil { + _key.SupportModels = apikey.SupportModels + } + if apikey.SupportModelsArray != nil { + _key.SupportModelsArray = apikey.SupportModelsArray + } + + return s.apiKeyRepo.Update(apikey) +} + +func (s *ApiKeyServiceImpl) DeleteApiKey(ctx context.Context, ids []int64) error { + return s.apiKeyRepo.BatchDelete(ids) +} + +func (s *ApiKeyServiceImpl) EnableApiKey(ctx context.Context, ids []int64) error { + return s.apiKeyRepo.BatchEnable(ids) +} +func (s *ApiKeyServiceImpl) DisableApiKey(ctx context.Context, ids []int64) error { + return s.apiKeyRepo.BatchDisable(ids) +} diff --git a/team/service/apikey.go b/internal/service/team/apikey.go similarity index 80% rename from team/service/apikey.go rename to internal/service/team/apikey.go index 543b0f7..f48efc2 100644 --- a/team/service/apikey.go +++ b/internal/service/team/apikey.go @@ -2,8 +2,8 @@ package service import ( "errors" - "opencatd-open/team/dao" - "opencatd-open/team/model" + "opencatd-open/internal/dao" + "opencatd-open/internal/model" "time" "gorm.io/gorm" @@ -18,10 +18,8 @@ type ApiKeyService interface { GetByApiKey(apiKeyValue string) (*model.ApiKey, error) Update(apiKey *model.ApiKey) error Delete(id int64) error - List(offset, limit int, status *int) ([]model.ApiKey, error) - ListWithFilters(offset, limit int, filters map[string]interface{}) ([]model.ApiKey, int64, error) - Enable(id int64) error - Disable(id int64) error + List(limit, offset int, status string) ([]*model.ApiKey, error) + ListWithFilters(limit, offset int, filters map[string]interface{}) ([]*model.ApiKey, int64, error) BatchEnable(ids []int64) error BatchDisable(ids []int64) error BatchDelete(ids []int64) error @@ -29,11 +27,11 @@ type ApiKeyService interface { } type ApiKeyServiceImpl struct { - apiKeyRepo dao.ApiKeyRepository db *gorm.DB + apiKeyRepo dao.ApiKeyRepository } -func NewApiKeyService(apiKeyDao dao.ApiKeyRepository, db *gorm.DB) ApiKeyService { +func NewApiKeyService(db *gorm.DB, apiKeyDao dao.ApiKeyRepository) ApiKeyService { return &ApiKeyServiceImpl{apiKeyRepo: apiKeyDao, db: db} } @@ -41,10 +39,10 @@ func (s *ApiKeyServiceImpl) Create(apiKey *model.ApiKey) error { if apiKey == nil { return errors.New("apiKey不能为空") } - if apiKey.Name == "" { + if apiKey.Name == nil { return errors.New("apiKey名称不能为空") } - if apiKey.ApiKey == "" { + if apiKey.ApiKey == nil { return errors.New("apiKey值不能为空") } apiKey.CreatedAt = time.Now().Unix() @@ -88,25 +86,25 @@ func (s *ApiKeyServiceImpl) Delete(id int64) error { if id <= 0 { return errors.New("id 必须大于 0") } - return s.apiKeyRepo.Delete(id) + return s.apiKeyRepo.BatchDelete([]int64{id}) } -func (s *ApiKeyServiceImpl) List(offset, limit int, status *int) ([]model.ApiKey, error) { +func (s *ApiKeyServiceImpl) List(offset, limit int, status string) ([]*model.ApiKey, error) { if offset < 0 { offset = 0 } if limit <= 0 { - limit = 10 // 设置默认值 + limit = 20 // 设置默认值 } return s.apiKeyRepo.List(offset, limit, status) } -func (s *ApiKeyServiceImpl) ListWithFilters(offset, limit int, filters map[string]interface{}) ([]model.ApiKey, int64, error) { +func (s *ApiKeyServiceImpl) ListWithFilters(offset, limit int, filters map[string]interface{}) ([]*model.ApiKey, int64, error) { if offset < 0 { offset = 0 } if limit <= 0 { - limit = 10 // 设置默认值 + limit = 20 // 设置默认值 } return s.apiKeyRepo.ListWithFilters(offset, limit, filters) @@ -116,14 +114,14 @@ func (s *ApiKeyServiceImpl) Enable(id int64) error { if id <= 0 { return errors.New("id 必须大于 0") } - return s.apiKeyRepo.Enable(id) + return s.apiKeyRepo.BatchEnable([]int64{id}) } func (s *ApiKeyServiceImpl) Disable(id int64) error { if id <= 0 { return errors.New("id 必须大于 0") } - return s.apiKeyRepo.Disable(id) + return s.apiKeyRepo.BatchDisable([]int64{id}) } func (s *ApiKeyServiceImpl) BatchEnable(ids []int64) error { diff --git a/team/service/token.go b/internal/service/team/token.go similarity index 53% rename from team/service/token.go rename to internal/service/team/token.go index 8a4491b..2038d1c 100644 --- a/team/service/token.go +++ b/internal/service/team/token.go @@ -2,8 +2,8 @@ package service import ( "context" - "opencatd-open/team/dao" - "opencatd-open/team/model" + "opencatd-open/internal/dao" + "opencatd-open/internal/model" "strings" "github.com/google/uuid" @@ -14,19 +14,15 @@ var _ TokenService = (*TokenServiceImpl)(nil) type TokenService interface { Create(ctx context.Context, token *model.Token) error - GetByID(ctx context.Context, id int) (*model.Token, error) + GetByID(ctx context.Context, id int64) (*model.Token, error) GetByKey(ctx context.Context, key string) (*model.Token, error) - GetByUserID(ctx context.Context, userID int) (*model.Token, error) + GetByUserID(ctx context.Context, userID int64) (*model.Token, error) Update(ctx context.Context, token *model.Token) error UpdateWithCondition(ctx context.Context, token *model.Token, filters map[string]interface{}, updates map[string]interface{}) error - Delete(ctx context.Context, id int) error - Lists(ctx context.Context, offset, limit int) ([]model.Token, error) - ListsWithFilters(ctx context.Context, offset, limit int, filters map[string]interface{}) ([]model.Token, int64, error) + Delete(ctx context.Context, id int64) error + Lists(ctx context.Context, limit, offset int) ([]*model.Token, int64, error) Disable(ctx context.Context, id int) error Enable(ctx context.Context, id int) error - BatchDisable(ctx context.Context, ids []int) error - BatchEnable(ctx context.Context, ids []int) error - BatchDelete(ctx context.Context, ids []int) error } type TokenServiceImpl struct { @@ -39,12 +35,12 @@ func NewTokenService(tokenRepo dao.TokenRepository) TokenService { func (s *TokenServiceImpl) Create(ctx context.Context, token *model.Token) error { if token.Key == "" { - token.Key = "team-" + strings.ReplaceAll(uuid.New().String(), "-", "") + token.Key = "sk-team-" + strings.ReplaceAll(uuid.New().String(), "-", "") } return s.tokenRepo.Create(ctx, token) } -func (s *TokenServiceImpl) GetByID(ctx context.Context, id int) (*model.Token, error) { +func (s *TokenServiceImpl) GetByID(ctx context.Context, id int64) (*model.Token, error) { return s.tokenRepo.GetByID(ctx, id) } @@ -52,7 +48,7 @@ func (s *TokenServiceImpl) GetByKey(ctx context.Context, key string) (*model.Tok return s.tokenRepo.GetByKey(ctx, key) } -func (s *TokenServiceImpl) GetByUserID(ctx context.Context, userID int) (*model.Token, error) { +func (s *TokenServiceImpl) GetByUserID(ctx context.Context, userID int64) (*model.Token, error) { return s.tokenRepo.GetByUserID(ctx, userID) } @@ -64,16 +60,12 @@ func (s *TokenServiceImpl) UpdateWithCondition(ctx context.Context, token *model return s.tokenRepo.UpdateWithCondition(ctx, token, filters, updates) } -func (s *TokenServiceImpl) Delete(ctx context.Context, id int) error { - return s.tokenRepo.Delete(ctx, id) +func (s *TokenServiceImpl) Delete(ctx context.Context, id int64) error { + return s.tokenRepo.Delete(ctx, id, nil) } -func (s *TokenServiceImpl) Lists(ctx context.Context, offset, limit int) ([]model.Token, error) { - return s.tokenRepo.List(ctx, offset, limit) -} - -func (s *TokenServiceImpl) ListsWithFilters(ctx context.Context, offset, limit int, filters map[string]interface{}) ([]model.Token, int64, error) { - return s.tokenRepo.ListWithFilters(ctx, offset, limit, filters) +func (s *TokenServiceImpl) Lists(ctx context.Context, limit, offset int) ([]*model.Token, int64, error) { + return s.tokenRepo.ListWithFilters(ctx, limit, offset, nil) } func (s *TokenServiceImpl) Disable(ctx context.Context, id int) error { @@ -83,15 +75,3 @@ func (s *TokenServiceImpl) Disable(ctx context.Context, id int) error { func (s *TokenServiceImpl) Enable(ctx context.Context, id int) error { return s.tokenRepo.Enable(ctx, id) } - -func (s *TokenServiceImpl) BatchDisable(ctx context.Context, ids []int) error { - return s.tokenRepo.BatchDisable(ctx, ids) -} - -func (s *TokenServiceImpl) BatchEnable(ctx context.Context, ids []int) error { - return s.tokenRepo.BatchEnable(ctx, ids) -} - -func (s *TokenServiceImpl) BatchDelete(ctx context.Context, ids []int) error { - return s.tokenRepo.BatchDelete(ctx, ids) -} diff --git a/internal/service/team/usage.go b/internal/service/team/usage.go new file mode 100644 index 0000000..4fe44f7 --- /dev/null +++ b/internal/service/team/usage.go @@ -0,0 +1,62 @@ +package service + +import ( + "context" + "opencatd-open/internal/dao" + dto "opencatd-open/internal/dto/team" + "opencatd-open/internal/model" + "opencatd-open/pkg/config" + "time" + + "gorm.io/gorm" +) + +var _ UsageService = (*usageService)(nil) + +type UsageService interface { + ListByUserID(ctx context.Context, userID int64, limit, offset int) ([]*model.Usage, error) + ListByCapability(ctx context.Context, capability string, limit, offset int) ([]*model.Usage, error) + ListByDateRange(ctx context.Context, start, end time.Time, filters map[string]interface{}) ([]*dto.UsageInfo, error) + + Delete(ctx context.Context, id int64) error +} + +type usageService struct { + ctx context.Context + cfg *config.Config + db *gorm.DB + + usageDAO dao.UsageRepository + dailyUsageDAO dao.DailyUsageRepository +} + +func NewUsageService(ctx context.Context, cfg *config.Config, db *gorm.DB, usageRepo dao.UsageRepository, dailyUsageRepo dao.DailyUsageRepository) UsageService { + srv := &usageService{ + ctx: ctx, + cfg: cfg, + db: db, + + usageDAO: usageRepo, + dailyUsageDAO: dailyUsageRepo, + } + + // 启动异步处理goroutine + + return srv +} + +func (s *usageService) ListByUserID(ctx context.Context, userID int64, limit int, offset int) ([]*model.Usage, error) { + return s.usageDAO.ListByUserID(ctx, userID, limit, offset) +} + +func (s *usageService) ListByCapability(ctx context.Context, capability string, limit, offset int) ([]*model.Usage, error) { + return s.usageDAO.ListByCapability(ctx, capability, limit, offset) +} + +func (s *usageService) ListByDateRange(ctx context.Context, start, end time.Time, filters map[string]interface{}) ([]*dto.UsageInfo, error) { + return s.dailyUsageDAO.StatUserUsages(ctx, start, end, filters) +} + +func (s *usageService) Delete(ctx context.Context, id int64) error { + return s.usageDAO.Delete(ctx, id) +} diff --git a/team/service/user.go b/internal/service/team/user.go similarity index 82% rename from team/service/user.go rename to internal/service/team/user.go index 9b809e7..d7a9dc6 100644 --- a/team/service/user.go +++ b/internal/service/team/user.go @@ -3,9 +3,9 @@ package service import ( "context" "errors" - "opencatd-open/team/consts" - "opencatd-open/team/dao" - "opencatd-open/team/model" + "opencatd-open/internal/consts" + "opencatd-open/internal/dao" + "opencatd-open/internal/model" "regexp" "strings" "time" @@ -123,10 +123,9 @@ type UserService interface { GetUserByUsername(ctx context.Context, username string) (*model.User, error) UpdateUser(ctx context.Context, user *model.User, operatorID int64) error DeleteUser(ctx context.Context, id int64, operatorID int64) error - ListUsers(ctx context.Context, page, pageSize int) ([]model.User, int64, error) - ListUsersWithFilters(ctx context.Context, page, pageSize int, filters map[string]interface{}) ([]model.User, int64, error) - EnableUser(ctx context.Context, id int64, operatorID int64) error - DisableUser(ctx context.Context, id int64, operatorID int64) error + ListUsers(ctx context.Context, limit, offset int, active string) ([]model.User, error) + // EnableUser(ctx context.Context, id int64, operatorID int64) error + // DisableUser(ctx context.Context, id int64, operatorID int64) error BatchEnableUsers(ctx context.Context, ids []int64, operatorID int64) error BatchDisableUsers(ctx context.Context, ids []int64, operatorID int64) error BatchDeleteUsers(ctx context.Context, ids []int64, operatorID int64) error @@ -144,7 +143,7 @@ type userService struct { } // NewUserService 创建 UserService 实例 -func NewUserService(userRepo dao.UserRepository, db *gorm.DB) UserService { +func NewUserService(db *gorm.DB, userRepo dao.UserRepository) UserService { return &userService{ userRepo: userRepo, db: db, @@ -204,7 +203,7 @@ func (s *userService) CheckPermission(ctx context.Context, requiredRole consts.U userToken := ctx.Value("Token").(*model.Token) // 检查用户角色 - if userToken.User.Role < int(requiredRole) { + if *userToken.User.Role < requiredRole { return ErrPermissionDenied } @@ -387,44 +386,22 @@ func (s *userService) ResetPassword(ctx context.Context, userID int64, operatorI } // ListUsers 获取用户列表(增加过滤功能) -func (s *userService) ListUsers(ctx context.Context, page, pageSize int) ([]model.User, int64, error) { - if page < 1 { - page = 1 +func (s *userService) ListUsers(ctx context.Context, limit, offset int, active string) ([]model.User, error) { + if limit < 0 { + limit = 20 } - if pageSize < 1 { - pageSize = 10 + if offset < 0 { + offset = 0 + } + var users []model.User + var err error + if active != "" { + users, _, err = s.userRepo.List(limit, offset, map[string]interface{}{"active in ?": strings.Split(active, ",")}) + } else { + users, _, err = s.userRepo.List(limit, offset, nil) } - offset := (page - 1) * pageSize - - users, err := s.userRepo.List(offset, pageSize) - if err != nil { - return nil, 0, err - } - - var total int64 = 0 - - return users, total, nil -} - -// ListUsers 获取用户列表(增加过滤功能) -func (s *userService) ListUsersWithFilters(ctx context.Context, page, pageSize int, filters map[string]interface{}) ([]model.User, int64, error) { - if page < 1 { - page = 1 - } - if pageSize < 1 { - pageSize = 10 - } - - offset := (page - 1) * pageSize - - // 使用新的 ListWithFilters 方法 - users, total, err := s.userRepo.ListWithFilters(offset, pageSize, filters) - if err != nil { - return nil, 0, err - } - - return users, total, nil + return users, err } // generateRandomPassword 生成随机密码 @@ -485,7 +462,7 @@ func (s *userService) DeleteUser(ctx context.Context, id int64, operatorID int64 } // 检查是否试图删除管理员 - if user.Role == int(consts.RoleAdmin) { + if *user.Role == consts.RoleAdmin { return ErrPermissionDenied } @@ -494,70 +471,70 @@ func (s *userService) DeleteUser(ctx context.Context, id int64, operatorID int64 } // EnableUser 启用用户 -func (s *userService) EnableUser(ctx context.Context, id int64, operatorID int64) error { - // 检查参数 - if id <= 0 { - return ErrInvalidUserInput - } +// func (s *userService) EnableUser(ctx context.Context, id int64, operatorID int64) error { +// // 检查参数 +// if id <= 0 { +// return ErrInvalidUserInput +// } - // 检查操作者权限 - if err := s.CheckPermission(ctx, consts.RoleAdmin); err != nil { - return err - } +// // 检查操作者权限 +// if err := s.CheckPermission(ctx, consts.RoleAdmin); err != nil { +// return err +// } - return s.withTransaction(ctx, func(tx *gorm.DB) error { - // 检查用户是否存在 - user, err := s.userRepo.GetByID(id) - if err != nil { - return ErrUserNotFound - } +// return s.withTransaction(ctx, func(tx *gorm.DB) error { +// // 检查用户是否存在 +// user, err := s.userRepo.GetByID(id) +// if err != nil { +// return ErrUserNotFound +// } - // 如果用户已经是启用状态,返回成功 - if user.Status == consts.StatusEnabled { - return nil - } +// // 如果用户已经是启用状态,返回成功 +// if user.Status == consts.StatusEnabled { +// return nil +// } - return s.userRepo.Enable(id) - }) -} +// return s.userRepo.Enable(id) +// }) +// } // DisableUser 禁用用户 -func (s *userService) DisableUser(ctx context.Context, id int64, operatorID int64) error { - // 检查参数 - if id <= 0 { - return ErrInvalidUserInput - } +// func (s *userService) DisableUser(ctx context.Context, id int64, operatorID int64) error { +// // 检查参数 +// if id <= 0 { +// return ErrInvalidUserInput +// } - // 检查操作者权限 - if err := s.CheckPermission(ctx, consts.RoleAdmin); err != nil { - return err - } +// // 检查操作者权限 +// if err := s.CheckPermission(ctx, consts.RoleAdmin); err != nil { +// return err +// } - // 不允许禁用自己 - if id == operatorID { - return ErrInvalidOperation - } +// // 不允许禁用自己 +// if id == operatorID { +// return ErrInvalidOperation +// } - return s.withTransaction(ctx, func(tx *gorm.DB) error { - // 检查用户是否存在 - user, err := s.userRepo.GetByID(id) - if err != nil { - return ErrUserNotFound - } +// return s.withTransaction(ctx, func(tx *gorm.DB) error { +// // 检查用户是否存在 +// user, err := s.userRepo.GetByID(id) +// if err != nil { +// return ErrUserNotFound +// } - // 检查是否试图禁用超级管理员 - if user.Role == int(consts.RoleAdmin) { - return ErrPermissionDenied - } +// // 检查是否试图禁用超级管理员 +// if user.Role == consts.RoleAdmin { +// return ErrPermissionDenied +// } - // 如果用户已经是禁用状态,返回成功 - if user.Status == consts.StatusDisabled { - return nil - } +// // 如果用户已经是禁用状态,返回成功 +// if user.Status == consts.StatusDisabled { +// return nil +// } - return s.userRepo.Disable(id) - }) -} +// return s.userRepo.Disable(id) +// }) +// } // BatchEnableUsers 批量启用用户 func (s *userService) BatchEnableUsers(ctx context.Context, ids []int64, operatorID int64) error { @@ -579,7 +556,7 @@ func (s *userService) BatchEnableUsers(ctx context.Context, ids []int64, operato if err != nil { return ErrUserNotFound } - if user.Status == consts.StatusEnabled { + if *user.Active == true { enabledUsers = append(enabledUsers, id) } } @@ -598,7 +575,7 @@ func (s *userService) BatchEnableUsers(ctx context.Context, ids []int64, operato } if len(toEnableIds) > 0 { - return s.userRepo.BatchEnable(toEnableIds) + return s.userRepo.BatchEnable(toEnableIds, nil) } return nil }) @@ -630,7 +607,7 @@ func (s *userService) BatchDisableUsers(ctx context.Context, ids []int64, operat return ErrUserNotFound } // 不允许禁用管理员 - if user.Role == int(consts.RoleAdmin) { + if *user.Role == consts.RoleAdmin { return ErrPermissionDenied } if user.Status == consts.StatusDisabled { @@ -652,7 +629,7 @@ func (s *userService) BatchDisableUsers(ctx context.Context, ids []int64, operat } if len(toDisableIds) > 0 { - return s.userRepo.BatchDisable(toDisableIds) + return s.userRepo.BatchDisable(toDisableIds, nil) } return nil }) @@ -682,17 +659,17 @@ func (s *userService) BatchDeleteUsers(ctx context.Context, ids []int64, operato if err != nil { return ErrUserNotFound } - if user.Role == int(consts.RoleAdmin) { + if *user.Role == consts.RoleAdmin { return ErrPermissionDenied } } - return s.userRepo.BatchDelete(ids) + return s.userRepo.BatchDelete(ids, nil) }) } // contains 检查切片中是否包含特定值 -func contains(slice []int64, item int64) bool { +func contains[T comparable](slice []T, item T) bool { for _, s := range slice { if s == item { return true diff --git a/internal/service/token.go b/internal/service/token.go new file mode 100644 index 0000000..2a16a16 --- /dev/null +++ b/internal/service/token.go @@ -0,0 +1,251 @@ +package service + +import ( + "context" + "fmt" + "opencatd-open/internal/consts" + "opencatd-open/internal/dao" + "opencatd-open/internal/model" + "opencatd-open/internal/utils" + "strings" + + "github.com/google/uuid" + "gorm.io/gorm" +) + +// var _ TokenService = (*TokenServiceImpl)(nil) + +// type TokenService interface { +// } + +type TokenServiceImpl struct { + db *gorm.DB + tokenRepo dao.TokenRepository +} + +func NewTokenService(db *gorm.DB, tokenRepo dao.TokenRepository) *TokenServiceImpl { + return &TokenServiceImpl{ + db: db, + tokenRepo: tokenRepo, + } +} + +func (t *TokenServiceImpl) CreateToken(ctx context.Context, token *model.Token) error { + if token.UserID == 0 { + token.UserID = ctx.Value("user_id").(int64) + } + if token.Active == nil { + token.Active = utils.ToPtr(true) + } + if token.UnlimitedQuota == nil { + token.UnlimitedQuota = utils.ToPtr(true) + } + if token.ExpiredAt == nil { + token.ExpiredAt = utils.ToPtr(int64(-1)) + } + + if token.Key == "" { + token.Key = "sk-team-" + strings.ReplaceAll(uuid.New().String(), "-", "") + } + if !strings.HasPrefix(token.Key, "sk-team-") { + token.Key = "sk-team-" + strings.ReplaceAll(token.Key, " ", "") + } + return t.tokenRepo.Create(ctx, token) +} + +func (t *TokenServiceImpl) GetToken(ctx context.Context, id int64) (*model.Token, error) { + userid := ctx.Value("user_id").(int64) + tk := &model.Token{} + return tk, t.db.Model(&model.Token{}).Where("user_id = ?", userid).Where("id = ?", id).First(tk).Error +} + +func (t *TokenServiceImpl) ListToken(ctx context.Context, limit, offset int, active []string) ([]*model.Token, int64, error) { + userid := ctx.Value("user_id").(int64) + condition := make(map[string]interface{}) + condition["user_id = ?"] = userid + if len(active) > 0 { + condition["active IN ?"] = utils.StringToBool(active) + return t.tokenRepo.ListWithFilters(ctx, limit, offset, condition) + } + return t.tokenRepo.ListWithFilters(ctx, limit, offset, condition) +} + +func (t *TokenServiceImpl) UpdateToken(ctx context.Context, token *model.Token) error { + userid := ctx.Value("user_id").(int64) // 操作者 + userRoleValue := ctx.Value("user_role") + if userRoleValue == nil { + return fmt.Errorf("user role not found in context") + } + + role, ok := userRoleValue.(*consts.UserRole) // 操作角色 + if !ok { + return fmt.Errorf("user role in context is not an integer") + } + + switch { + case *role < consts.RoleAdmin: + if userid != token.UserID { + return fmt.Errorf("Permission denied") + } + case *role == consts.RoleAdmin: + if *role <= *token.User.Role { + return fmt.Errorf("Permission denied") + } + } + + return t.db.Model(&model.Token{}).Where("id = ?", token.ID).Updates(token).Error +} + +func (t *TokenServiceImpl) ResetToken(ctx context.Context, id int64) error { + userid := ctx.Value("user_id").(int64) // 操作者 + userRoleValue := ctx.Value("user_role") + if userRoleValue == nil { + return fmt.Errorf("user role not found in context") + } + + role, ok := userRoleValue.(*consts.UserRole) // 操作角色 + if !ok { + return fmt.Errorf("user role in context is not an integer") + } + switch { + case *role < consts.RoleAdmin: + if userid != id { + return fmt.Errorf("Permission denied") + } + case *role == consts.RoleAdmin: + var user = &model.User{} + if err := t.db.Model(&model.User{}).Where("id = ?", id).First(user).Error; err != nil { + return fmt.Errorf("User not found") + } + if *role <= *user.Role { + return fmt.Errorf("Permission denied") + } + } + + token := "sk-team-" + strings.ReplaceAll(uuid.New().String(), "-", "") + return t.db.Model(&model.Token{}).Where("user_id = ?", userid).Where("id = ?", id).Update("token", token).Error +} +func (t *TokenServiceImpl) DeleteToken(ctx context.Context, id int64) error { + token, err := t.tokenRepo.GetByID(ctx, id) + if err != nil { + return fmt.Errorf("Token not found") + } + if token.User == nil { + return fmt.Errorf("Token user not found") + } + + role := ctx.Value("user_role").(*consts.UserRole) // 操作角色 + userid := ctx.Value("user_id").(int64) // 操作者 + + switch { + case *role < consts.RoleAdmin: + if userid != token.UserID { + return fmt.Errorf("Permission denied") + } + case *role == consts.RoleAdmin: + if *role <= *token.User.Role { + return fmt.Errorf("Permission denied") + } + } + + return t.db.Model(&model.Token{}).Where("id = ?", id).Delete(&model.Token{}).Error +} + +func (t *TokenServiceImpl) DeleteTokens(ctx context.Context, userid int64, ids []int64) error { + operator_id := ctx.Value("user_id").(int64) + + roleValue := ctx.Value("user_role") + if roleValue == nil { + return fmt.Errorf("user role not found in context") + } + operator_role, ok := roleValue.(*consts.UserRole) // 操作角色 + if !ok { + return fmt.Errorf("user role in context is not an integer") + } + + switch { + case *operator_role < consts.RoleAdmin: + if operator_id != userid { + return fmt.Errorf("Permission denied") + } + return t.tokenRepo.BatchDelete(ctx, ids, map[string]interface{}{"name != ?": "default", "user_id = ?": userid}) + case *operator_role == consts.RoleAdmin: + var user = &model.User{} + if err := t.db.Model(&model.User{}).Where("id = ?", userid).First(user).Error; err != nil { + return fmt.Errorf("User not found") + } + if *operator_role <= *user.Role { + return fmt.Errorf("Permission denied") + } + return t.tokenRepo.BatchDelete(ctx, ids, map[string]interface{}{"name != ?": "default", "user_id = ?": userid}) + default: + return t.tokenRepo.BatchDelete(ctx, ids, map[string]interface{}{"name != ?": "default"}) + } + +} + +func (t *TokenServiceImpl) EnableTokens(ctx context.Context, userid int64, ids []int64) error { + operator_id := ctx.Value("user_id").(int64) + + roleValue := ctx.Value("user_role") + if roleValue == nil { + return fmt.Errorf("user role not found in context") + } + operator_role, ok := roleValue.(*consts.UserRole) // 操作角色 + if !ok { + return fmt.Errorf("user role in context is not an integer") + } + + switch { + case *operator_role < consts.RoleAdmin: + if operator_id != userid { + return fmt.Errorf("Permission denied") + } + return t.tokenRepo.BatchEnable(ctx, ids, map[string]interface{}{"user_id = ?": userid}) + case *operator_role == consts.RoleAdmin: + var user = &model.User{} + if err := t.db.Model(&model.User{}).Where("id = ?", userid).First(user).Error; err != nil { + return fmt.Errorf("User not found") + } + if *operator_role <= *user.Role { + return fmt.Errorf("Permission denied") + } + return t.tokenRepo.BatchEnable(ctx, ids, map[string]interface{}{"user_id = ?": userid}) + default: + return t.tokenRepo.BatchEnable(ctx, ids, nil) + } + +} + +func (t *TokenServiceImpl) DisableTokens(ctx context.Context, userid int64, ids []int64) error { + operator_id := ctx.Value("user_id").(int64) + + roleValue := ctx.Value("user_role") + if roleValue == nil { + return fmt.Errorf("user role not found in context") + } + operator_role, ok := roleValue.(*consts.UserRole) // 操作角色 + if !ok { + return fmt.Errorf("user role in context is not an integer") + } + + switch { + case *operator_role < consts.RoleAdmin: + if operator_id != userid { + return fmt.Errorf("Permission denied") + } + return t.tokenRepo.BatchDisable(ctx, ids, map[string]interface{}{"user_id =": userid}) + case *operator_role == consts.RoleAdmin: + var user = &model.User{} + if err := t.db.Model(&model.User{}).Where("id = ?", userid).First(user).Error; err != nil { + return fmt.Errorf("User not found") + } + if *operator_role <= *user.Role { + return fmt.Errorf("Permission denied") + } + return t.tokenRepo.BatchDisable(ctx, ids, map[string]interface{}{"user_id =": userid}) + default: + return t.tokenRepo.BatchDisable(ctx, ids, nil) + } + +} diff --git a/internal/service/usage.go b/internal/service/usage.go new file mode 100644 index 0000000..e6f6cb8 --- /dev/null +++ b/internal/service/usage.go @@ -0,0 +1,22 @@ +package service + +import ( + "context" + "opencatd-open/pkg/config" + + "gorm.io/gorm" +) + +type UsageService struct { + Ctx context.Context + Cfg *config.Config + DB *gorm.DB +} + +func NewUsageService(ctx context.Context, cfg *config.Config, db *gorm.DB) *UsageService { + return &UsageService{ + Ctx: ctx, + Cfg: cfg, + DB: db, + } +} diff --git a/internal/service/user.go b/internal/service/user.go new file mode 100644 index 0000000..751f57d --- /dev/null +++ b/internal/service/user.go @@ -0,0 +1,320 @@ +package service + +import ( + "context" + "fmt" + "opencatd-open/internal/auth" + "opencatd-open/internal/consts" + "opencatd-open/internal/dao" + "opencatd-open/internal/dto" + "opencatd-open/internal/model" + "opencatd-open/internal/utils" + "strings" + "time" + + "github.com/google/uuid" + "gorm.io/gorm" +) + +type UserServiceImpl struct { + db *gorm.DB + userRepo dao.UserRepository +} + +func NewUserService(db *gorm.DB, userRepo dao.UserRepository) *UserServiceImpl { + return &UserServiceImpl{ + db: db, + userRepo: userRepo, + } +} + +func (s *UserServiceImpl) Register(ctx context.Context, req *model.User) error { + var _user model.User + var count int64 + err := s.db.Model(&model.User{}).Count(&count).Error + if err != nil { + return fmt.Errorf("username or email already exists") + } + if count == 0 { + _user.Name = "root" + _user.Role = utils.ToPtr(consts.RoleRoot) + _user.Active = utils.ToPtr(true) + _user.UnlimitedQuota = utils.ToPtr(true) + } + _user.Password, err = utils.HashPassword(req.Password) + if err != nil { + return err + } + _user.Username = req.Username + _user.Email = req.Email + _user.Tokens = []model.Token{ + { + Name: "default", + Key: "sk-team-" + strings.ReplaceAll(uuid.New().String(), "-", ""), + }, + } + + return s.userRepo.Create(&_user) +} + +func (s *UserServiceImpl) Login(ctx context.Context, req *dto.User) (*dto.Auth, error) { + var _user model.User + if err := s.db.Model(&model.User{}).Where("username = ?", req.Username).First(&_user).Error; err != nil { + if err := s.db.Model(&model.User{}).Where("email = ?", req.Username).First(&_user).Error; err != nil { + return nil, err + } + } + if utils.CheckPassword(_user.Password, req.Password) { + day := 86400 + at, err := auth.GenerateTokenPair(&_user, consts.SecretKey, time.Duration(day)*time.Second, time.Duration(day*7)*time.Second) + if err != nil { + return nil, err + } + return &dto.Auth{ + Token: at.AccessToken, + ExpiresIn: time.Now().Add(time.Duration(day) * time.Second).Unix(), + }, nil + } + return nil, fmt.Errorf("密码错误") +} + +func (s *UserServiceImpl) Profile(ctx context.Context) (*model.User, error) { + id := ctx.Value("user_id").(int64) + return s.userRepo.GetByID(id) +} + +func (s *UserServiceImpl) List(ctx context.Context, limit, offset int, active []string) ([]model.User, int64, error) { + userRoleValue := ctx.Value("user_role") + if userRoleValue == nil { + return nil, 0, fmt.Errorf("user role not found in context") + } + + role, ok := userRoleValue.(*consts.UserRole) + if !ok { + return nil, 0, fmt.Errorf("user role in context is not an integer") + } + + if *role < consts.RoleAdmin { + return nil, 0, fmt.Errorf("Unauthorized") + } else if *role < consts.RoleRoot { // 管理员只能查看普通用户 + var condition = map[string]interface{}{"role = ?": consts.RoleUser} + if len(active) > 0 { + boolCondition := utils.StringToBool(active) + condition["active IN ?"] = boolCondition + } + return s.userRepo.List(limit, offset, condition) + } else { + var condition = make(map[string]interface{}) + if len(active) > 0 { + boolCondition := utils.StringToBool(active) + condition["active IN ?"] = boolCondition + } + return s.userRepo.List(limit, offset, condition) + } +} + +func (s *UserServiceImpl) Create(ctx context.Context, req *model.User) error { + userRoleValue := ctx.Value("user_role") + if userRoleValue == nil { + return fmt.Errorf("user role not found in context") + } + + role, ok := userRoleValue.(*consts.UserRole) + if !ok { + return fmt.Errorf("user role in context is not an integer") + } + var _user model.User + + if *role < consts.RoleAdmin { + return fmt.Errorf("Forbidden") + } else if *role < consts.RoleRoot { + _user.Role = utils.ToPtr(consts.RoleRoot) + } else { + _user.Role = req.Role + } + _user.Username = req.Username + _user.Name = req.Name + _user.Email = req.Email + _user.Active = req.Active + _user.Quota = req.Quota + _user.UnlimitedQuota = req.UnlimitedQuota + _user.Language = req.Language + if hashpass, err := utils.HashPassword(req.Password); err != nil { + return err + } else { + _user.Password = hashpass + } + _user.Tokens = []model.Token{ + { + Name: "default", + Key: "sk-team-" + strings.ReplaceAll(uuid.New().String(), "-", ""), + }, + } + + return s.userRepo.Create(&_user) +} +func (s *UserServiceImpl) GetByID(ctx context.Context, id int64) (*model.User, error) { + return s.userRepo.GetByID(id) +} + +func (s *UserServiceImpl) Update(ctx context.Context, user *model.User) error { + _user := ctx.Value("user").(*model.User) // 被更新的用户 + if _user == nil { + return fmt.Errorf("user not found in context") + } + userid := ctx.Value("user_id").(int64) // 操作者 + userRoleValue := ctx.Value("user_role") + if userRoleValue == nil { + return fmt.Errorf("user role not found in context") + } + + role, ok := userRoleValue.(*consts.UserRole) // 操作者角色 + if !ok { + return fmt.Errorf("user role in context is not an integer") + } + switch { + case *role < consts.RoleAdmin: + if user.ID != userid { + return fmt.Errorf("Permission denied") + } + case *role == consts.RoleAdmin: + if *user.Role > *role { // 更新的用户角色不能高于操作者角色 + return fmt.Errorf("Permission denied") + } + if *_user.Role >= *role { // 管理员之间不能被修改 + return fmt.Errorf("Permission denied") + } + case *role > consts.RoleAdmin: // 根不能被修改 + if user.ID == userid { + user.Role = role // root不能修改自己的角色 + } else { + if user.Role != nil && user.Role == utils.ToPtr(consts.RoleRoot) { + return fmt.Errorf("Root user Only one can exist") + } + } + } + + if user.Name != "" { + _user.Name = user.Name + } + if user.Username != "" { + _user.Username = user.Username + } + if user.Email != "" { + _user.Email = user.Email + _user.EmailVerified = utils.ToPtr(false) + } + if user.Active != nil { + _user.Active = user.Active + } + if user.Role != nil { + _user.Role = user.Role + } + if user.Active != nil { + _user.Active = user.Active + } + if user.Quota != nil { + _user.Quota = user.Quota + } + if user.UsedQuota != nil { + _user.UsedQuota = user.UsedQuota + } + if user.UnlimitedQuota != nil { + _user.UnlimitedQuota = user.UnlimitedQuota + } + if user.Timezone != "" { + _user.Timezone = user.Timezone + } + if user.Language != "" { + _user.Language = user.Language + } + return s.userRepo.Update(_user) +} + +func (s *UserServiceImpl) Delete(ctx context.Context, id int64) error { + _user, err := s.userRepo.GetByID(id) // 被更新的用户 + if err != nil { + return err + } + userid := ctx.Value("user_id").(int64) + userRoleValue := ctx.Value("user_role") + if userRoleValue == nil { + return fmt.Errorf("user role not found in context") + } + role, ok := userRoleValue.(*consts.UserRole) // 操作者 + if !ok { + return fmt.Errorf("user role in context is not an integer") + } + + switch { + case *role < consts.RoleAdmin: + if _user.ID != userid { + return fmt.Errorf("Permission denied") + } + case *role == consts.RoleAdmin: + if *_user.Role >= *role { // 管理员之间不能被修改 + return fmt.Errorf("Permission denied") + } + case *_user.Role == consts.RoleRoot: // 根不能被修改 + return fmt.Errorf("Root user can not be modified") + } + + return s.userRepo.Delete(id) +} + +func (s *UserServiceImpl) BatchDelete(ctx context.Context, ids []int64) error { + userRoleValue := ctx.Value("user_role") + if userRoleValue == nil { + return fmt.Errorf("user role not found in context") + } + role, ok := userRoleValue.(*consts.UserRole) + if !ok { + return fmt.Errorf("user role in context is not an integer") + } + + switch { + case *role < consts.RoleAdmin: + return fmt.Errorf("Unauthorized") + case *role == consts.RoleAdmin: + return s.userRepo.BatchDelete(ids, []string{fmt.Sprintf("role < %d", role)}) + } + return s.userRepo.BatchDelete(ids, []string{fmt.Sprintf("role < %d", consts.RoleRoot)}) +} + +func (s *UserServiceImpl) BatchEnable(ctx context.Context, ids []int64) error { + userRoleValue := ctx.Value("user_role") + if userRoleValue == nil { + return fmt.Errorf("user role not found in context") + } + role, ok := userRoleValue.(*consts.UserRole) + if !ok { + return fmt.Errorf("user role in context is not an integer") + } + + switch { + case *role < consts.RoleAdmin: + return fmt.Errorf("Unauthorized") + case *role == consts.RoleAdmin: + return s.userRepo.BatchEnable(ids, []string{fmt.Sprintf("role < %d", role)}) + } + return s.userRepo.BatchEnable(ids, nil) +} + +func (s *UserServiceImpl) BatchDisable(ctx context.Context, ids []int64) error { + userRoleValue := ctx.Value("user_role") + if userRoleValue == nil { + return fmt.Errorf("user role not found in context") + } + role, ok := userRoleValue.(*consts.UserRole) + if !ok { + return fmt.Errorf("user role in context is not an integer") + } + + switch { + case *role < consts.RoleAdmin: + return fmt.Errorf("Unauthorized") + case *role == consts.RoleAdmin: + return s.userRepo.BatchDisable(ids, []string{fmt.Sprintf("role < %d", role)}) + } + return s.userRepo.BatchDisable(ids, nil) +} diff --git a/internal/service/webauth.go b/internal/service/webauth.go new file mode 100644 index 0000000..37c3691 --- /dev/null +++ b/internal/service/webauth.go @@ -0,0 +1,304 @@ +package service + +import ( + "encoding/base64" + "fmt" + "net/http" + "opencatd-open/internal/model" + "opencatd-open/pkg/config" + "opencatd-open/pkg/store" + "strconv" + "strings" + "time" + + "github.com/go-webauthn/webauthn/protocol" + "github.com/go-webauthn/webauthn/webauthn" + "github.com/mileusna/useragent" + "gorm.io/gorm" +) + +var _ webauthn.User = (*WebAuthnUser)(nil) + +// WebAuthnUser 实现webauthn.User接口的结构体 +type WebAuthnUser struct { + User *model.User + // ID int64 + // Name string + // DisplayName string + Credentials []webauthn.Credential +} + +// WebAuthnID 返回用户ID +func (u *WebAuthnUser) WebAuthnID() []byte { + return []byte(strconv.Itoa(int(u.User.ID))) +} + +// WebAuthnName 返回用户名 +func (u *WebAuthnUser) WebAuthnName() string { + return u.User.Username +} + +// WebAuthnDisplayName 返回用户显示名 +func (u *WebAuthnUser) WebAuthnDisplayName() string { + return u.User.Name +} + +// WebAuthnCredentials 返回用户所有凭证 +func (u *WebAuthnUser) WebAuthnCredentials() []webauthn.Credential { + return u.Credentials +} + +func (u *WebAuthnUser) WebAuthnCredentialDescriptors() (descriptors []protocol.CredentialDescriptor) { + credentials := u.WebAuthnCredentials() + + descriptors = make([]protocol.CredentialDescriptor, len(credentials)) + + for i, credential := range credentials { + descriptors[i] = credential.Descriptor() + } + + return descriptors +} + +// WebAuthnService 提供WebAuthn相关功能 +type WebAuthnService struct { + DB *gorm.DB + WebAuthn *webauthn.WebAuthn + // Sessions map[string]webauthn.SessionData // 用于存储注册和认证过程中的会话数据 + Sessions *store.WebAuthnSessionStore +} + +// NewWebAuthnService 创建新的WebAuthn服务 +func NewWebAuthnService(db *gorm.DB, cfg *config.Config) (*WebAuthnService, error) { + // 创建WebAuthn配置 + wconfig := &webauthn.Config{ + RPDisplayName: config.Cfg.AppName, // 依赖方(Relying Party)显示名称 + RPID: config.Cfg.Domain, // 依赖方ID(通常为域名) + RPOrigins: []string{config.Cfg.AppURL}, // 依赖方源(URL) + AuthenticatorSelection: protocol.AuthenticatorSelection{ + RequireResidentKey: protocol.ResidentKeyRequired(), // 要求认证器存储用户 ID (resident key) + ResidentKey: protocol.ResidentKeyRequirementRequired, // 使用 Discoverable 模式 + UserVerification: protocol.VerificationPreferred, // 推荐用户验证 + AuthenticatorAttachment: "", // 允许任何认证器 (平台或跨平台) + }, + // EncodeUserIDAsString: true, // 将用户ID编码为字符串 + } + + wa, err := webauthn.New(wconfig) + if err != nil { + return nil, err + } + + return &WebAuthnService{ + DB: db, + WebAuthn: wa, + // Sessions: make(map[string]webauthn.SessionData), + Sessions: store.NewWebAuthnSessionStore(), + }, nil +} + +// GetUserWithCredentials 获取用户及其凭证 +func (s *WebAuthnService) GetUserWithCredentials(userID int64) (*WebAuthnUser, error) { + var user model.User + if err := s.DB.Model(&model.User{}).Preload("Passkeys").First(&user, userID).Error; err != nil { + return nil, err + } + + // 获取用户的所有Passkey + passkeys := user.Passkeys + + // 将Passkey转换为webauthn.Credential + credentials := make([]webauthn.Credential, len(passkeys)) + for i, pk := range passkeys { + credentialIDBytes, err := base64.StdEncoding.DecodeString(pk.CredentialID) + if err != nil { + return nil, fmt.Errorf("failed to decode CredentialID: %w", err) + } + publicKeyBytes, err := base64.StdEncoding.DecodeString(pk.PublicKey) + if err != nil { + return nil, fmt.Errorf("failed to decode PublicKey: %w", err) + } + aaguidBytes, err := base64.StdEncoding.DecodeString(pk.AAGUID) + if err != nil { + return nil, fmt.Errorf("failed to decode AAGUID: %w", err) + } + + var transport []protocol.AuthenticatorTransport + if pk.Transport != "" { + transport = []protocol.AuthenticatorTransport{protocol.AuthenticatorTransport(pk.Transport)} + } + + credentials[i] = webauthn.Credential{ + ID: credentialIDBytes, + PublicKey: publicKeyBytes, + AttestationType: pk.AttestationType, + Transport: transport, + Flags: webauthn.CredentialFlags{ + UserPresent: true, + UserVerified: true, + BackupEligible: pk.BackupEligible, + BackupState: pk.BackupState, + }, + Authenticator: webauthn.Authenticator{ + AAGUID: aaguidBytes, + SignCount: pk.SignCount, + CloneWarning: false, + }, + } + } + + // 创建WebAuthnUser + return &WebAuthnUser{ + User: &user, + Credentials: credentials, + }, nil +} + +// BeginRegistration 开始注册过程 +func (s *WebAuthnService) BeginRegistration(userID int64) (*protocol.CredentialCreation, error) { + user, err := s.GetUserWithCredentials(userID) + if err != nil { + return nil, err + } + + // 获取注册选项 + options, sessionData, err := s.WebAuthn.BeginRegistration(user) + // webauthn.WithResidentKeyRequirement(protocol.ResidentKeyRequirementRequired), + // webauthn.WithExclusions(user.WebAuthnCredentialDescriptors()), // 排除已存在的凭证 + + if err != nil { + return nil, err + } + + // 保存会话数据 + userid := strconv.Itoa(int(userID)) + s.Sessions.SaveWebauthnSession(userid, sessionData) + + return options, nil +} + +// FinishRegistration 完成注册过程 +func (s *WebAuthnService) FinishRegistration(userID int64, response *http.Request, deviceName string) (*model.Passkey, error) { + user, err := s.GetUserWithCredentials(userID) + if err != nil { + return nil, err + } + + userid := strconv.Itoa(int(userID)) + // 获取并清除会话数据 + sessionData, err := s.Sessions.GetWebauthnSession(userid) + if err != nil { + return nil, err + } + s.Sessions.DeleteWebauthnSession(userid) + + // 完成注册 + credential, err := s.WebAuthn.FinishRegistration(user, *sessionData, response) + if err != nil { + return nil, err + } + + ua := useragent.Parse(response.UserAgent()) + + var transport string + if len(credential.Transport) > 0 { + transport = string(credential.Transport[0]) // 通常只取第一个传输方式 + } + // 创建Passkey记录 + passkey := &model.Passkey{ + UserID: userID, + CredentialID: base64.StdEncoding.EncodeToString(credential.ID), + PublicKey: base64.StdEncoding.EncodeToString(credential.PublicKey), + AttestationType: string(credential.AttestationType), + AAGUID: base64.StdEncoding.EncodeToString(credential.Authenticator.AAGUID), + SignCount: credential.Authenticator.SignCount, + Name: deviceName, + DeviceType: strings.TrimSpace(fmt.Sprintf("%s %s %s %s %s", ua.Device, ua.OS, ua.OSVersionNoFull(), ua.Name, ua.VersionNoFull())), + LastUsedAt: time.Now().Unix(), + BackupEligible: credential.Flags.BackupEligible, + BackupState: credential.Flags.BackupState, + Transport: transport, + } + + // 保存Passkey + if err := s.DB.Create(passkey).Error; err != nil { + return nil, err + } + + return passkey, nil +} + +// BeginLogin 开始登录过程 (无需用户ID,针对未认证用户) +func (s *WebAuthnService) BeginLogin() (*protocol.CredentialAssertion, error) { + // 不指定用户ID,让客户端决定使用哪个凭证 + options, session, err := s.WebAuthn.BeginDiscoverableLogin( + webauthn.WithUserVerification(protocol.VerificationPreferred), // 推荐用户验证 + ) + if err != nil { + return nil, err + } + + s.Sessions.SaveWebauthnSession(session.Challenge, session) + + return options, nil +} + +// FinishLogin 完成登录过程 +func (s *WebAuthnService) FinishLogin(challenge string, response *http.Request) (*WebAuthnUser, error) { + // 获取并清除会话数据 + sessionData, err := s.Sessions.GetWebauthnSession(challenge) + if err != nil { + return nil, err + } + s.Sessions.DeleteWebauthnSession(challenge) + + // 获取相应的用户 + // var user model.User + // if err := s.DB.First(&user, passkey.UserID).Error; err != nil { + // return nil, err + // } + + // 创建WebAuthnUser + // webAuthnUser, err := s.GetUserWithCredentials(user.ID) + // if err != nil { + // return nil, err + // } + + // 完成登录 + // _, err = s.WebAuthn.FinishLogin(webAuthnUser, sessionData, response) + // if err != nil { + // return nil, err + // } + var user *WebAuthnUser + _, err = s.WebAuthn.FinishDiscoverableLogin(s.GetWebAuthnUser(&user), *sessionData, response) + if err != nil { + return nil, err + } + // 更新Passkey的LastUsedAt + return user, nil +} + +func (s *WebAuthnService) GetWebAuthnUser(wau **WebAuthnUser) webauthn.DiscoverableUserHandler { + return func(rawID, userHandle []byte) (webauthn.User, error) { + userid, err := strconv.ParseInt(string(userHandle), 10, 64) + if err != nil { + return nil, err + } + *wau, err = s.GetUserWithCredentials(userid) + return *wau, err + } +} + +// ListPasskeys 列出用户所有Passkey +func (s *WebAuthnService) ListPasskeys(userID int64) ([]model.Passkey, error) { + var passkeys []model.Passkey + if err := s.DB.Where("user_id = ?", userID).Find(&passkeys).Error; err != nil { + return nil, err + } + return passkeys, nil +} + +// DeletePasskey 删除用户Passkey +func (s *WebAuthnService) DeletePasskey(userID int64, passkeyID int64) error { + return s.DB.Where("id = ? AND user_id = ?", passkeyID, userID).Delete(&model.Passkey{}).Error +} diff --git a/internal/utils/convert.go b/internal/utils/convert.go new file mode 100644 index 0000000..73a9500 --- /dev/null +++ b/internal/utils/convert.go @@ -0,0 +1,16 @@ +package utils + +import "strings" + +func StringToBool(strSlice []string) []bool { + boolSlice := make([]bool, len(strSlice)) + for i, str := range strSlice { + str = strings.ToLower(str) + if str == "true" { + boolSlice[i] = true + } else if str == "false" { + boolSlice[i] = false + } + } + return boolSlice +} diff --git a/internal/utils/map_tools.go b/internal/utils/map_tools.go new file mode 100644 index 0000000..fa59263 --- /dev/null +++ b/internal/utils/map_tools.go @@ -0,0 +1,139 @@ +package utils + +import ( + "fmt" + "reflect" + "strings" +) + +func MergeJSONObjects(dst, src map[string]interface{}) map[string]interface{} { + + result := make(map[string]interface{}) + for k, v := range dst { + result[k] = v + } + + for key, value2 := range src { + value1, exists := result[key] + + if exists { + map1Val, map1IsMap := value1.(map[string]interface{}) + map2Val, map2IsMap := value2.(map[string]interface{}) + + if map1IsMap && map2IsMap { + result[key] = MergeJSONObjects(map1Val, map2Val) + } else { + // 覆盖第一个map中的值 + result[key] = value2 + } + } else { + // 添加新的键值对 + result[key] = value2 + } + } + + return result +} + +func StructToMap(in interface{}) (map[string]interface{}, error) { + out := make(map[string]interface{}) + + v := reflect.ValueOf(in) + // If it's a pointer, dereference it + if v.Kind() == reflect.Ptr { + v = v.Elem() + } + + // Check if it's a struct + if v.Kind() != reflect.Struct { + return nil, fmt.Errorf("StructToMap only accepts structs or pointers to structs; got %T", v.Interface()) + } + + t := v.Type() // Get the type of the struct + for i := 0; i < v.NumField(); i++ { + // Get the field Value and Type + fieldV := v.Field(i) + fieldT := t.Field(i) + + // Skip unexported fields + if !fieldT.IsExported() { + continue + } + + // --- Handle JSON Tag --- + tag := fieldT.Tag.Get("json") + key := fieldT.Name // Default key is the field name + omitempty := false + + if tag != "" { + parts := strings.Split(tag, ",") + tagName := parts[0] + + if tagName == "-" { + // Skip fields tagged with "-" + continue + } + if tagName != "" { + key = tagName // Use tag name as key + } + + // Check for omitempty option + for _, part := range parts[1:] { + if part == "omitempty" { + omitempty = true + break + } + } + } + + // --- Handle omitempty --- + val := fieldV.Interface() + if omitempty && fieldV.IsZero() { + continue // Skip zero-value fields if omitempty is set + } + + // --- Handle Nested Structs/Pointers to Structs (Recursion) --- + // Check for pointer first + if fieldV.Kind() == reflect.Ptr { + // If pointer is nil and omitempty is set, it was already skipped + // If pointer is nil and omitempty is not set, add nil to map + if fieldV.IsNil() { + // Only add nil if omitempty is not set (already handled above) + if !omitempty { + out[key] = nil + } + continue // Move to next field + } + // If it points to a struct, dereference and recurse + if fieldV.Elem().Kind() == reflect.Struct { + nestedMap, err := StructToMap(fieldV.Interface()) // Pass the pointer + if err != nil { + // Decide how to handle nested errors, e.g., log or return + fmt.Printf("Warning: could not convert nested struct pointer %s: %v\n", fieldT.Name, err) + out[key] = val // Store original value on error? Or skip? + } else { + out[key] = nestedMap + } + continue // Move to next field after handling pointer + } + // If pointer to non-struct, just get the interface value (handled below) + val = fieldV.Interface() // Use the actual pointer value + + } else if fieldV.Kind() == reflect.Struct { + // If it's a struct (not a pointer), recurse + nestedMap, err := StructToMap(fieldV.Interface()) // Pass the struct value + if err != nil { + fmt.Printf("Warning: could not convert nested struct %s: %v\n", fieldT.Name, err) + out[key] = val // Store original value on error? Or skip? + } else { + out[key] = nestedMap + } + continue // Move to next field after handling struct + } + + // Assign the value (primitive, slice, map, non-struct pointer, etc.) + out[key] = val + } + + return out, nil +} diff --git a/internal/utils/password.go b/internal/utils/password.go new file mode 100644 index 0000000..8e90241 --- /dev/null +++ b/internal/utils/password.go @@ -0,0 +1,15 @@ +package utils + +import ( + "golang.org/x/crypto/bcrypt" +) + +func HashPassword(password string) (string, error) { + bytes, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost) + return string(bytes), err +} + +func CheckPassword(hash, password string) bool { + err := bcrypt.CompareHashAndPassword([]byte(hash), []byte(password)) + return err == nil +} diff --git a/internal/utils/pointer.go b/internal/utils/pointer.go index 7847c92..7ac4b6a 100644 --- a/internal/utils/pointer.go +++ b/internal/utils/pointer.go @@ -3,3 +3,9 @@ package utils func ToPtr[T any](v T) *T { return &v } + +func UpdatePtrField[T any](target *T, value *T) { + if value != nil { + *target = *value + } +} diff --git a/llm/aws/aws.go b/llm/aws/aws.go new file mode 100644 index 0000000..d5c7288 --- /dev/null +++ b/llm/aws/aws.go @@ -0,0 +1,36 @@ +// /* +// # AWS +// https://docs.aws.amazon.com/bedrock/latest/userguide/what-is-service.html +// https://aws.amazon.com/cn/bedrock/pricing/ +// Anthropic models Price for 1000 input tokens Price for 1000 output tokens +// Claude Instant $0.00163 $0.00551 + +// Claude $0.01102 $0.03268 + +// https://docs.aws.amazon.com/bedrock/latest/userguide/endpointsTable.html +// 地区名称 地区 端点 协议 +// 美国东部(弗吉尼亚北部) 美国东部1 bedrock-runtime.us-east-1.amazonaws.com HTTPS +// bedrock-runtime-fips.us-east-1.amazonaws.com HTTPS +// 美国西部(俄勒冈州) 美国西2号 bedrock-runtime.us-west-2.amazonaws.com HTTPS +// bedrock-runtime-fips.us-west-2.amazonaws.com HTTPS +// 亚太地区(新加坡) ap-东南-1 bedrock-runtime.ap-southeast-1.amazonaws.com HTTPS +// */ +// // + +package aws + +// import ( +// "context" +// "log" + +// "github.com/aws/aws-sdk-go-v2/config" +// ) + +// // ... + +// func CallClaude() { +// cfg, err := config.LoadDefaultConfig(context.TODO()) +// if err != nil { +// log.Fatalf("failed to load configuration, %v", err) +// } +// } diff --git a/pkg/azureopenai/azureopenai.go b/llm/azureopenai/azureopenai.go similarity index 99% rename from pkg/azureopenai/azureopenai.go rename to llm/azureopenai/azureopenai.go index 61dbc5d..f900a3d 100644 --- a/pkg/azureopenai/azureopenai.go +++ b/llm/azureopenai/azureopenai.go @@ -67,7 +67,6 @@ func Models(endpoint, apikey string) (*ModelsList, error) { return nil, err } return &modelsl, nil - } func RemoveTrailingSlash(s string) string { diff --git a/llm/claude/chat.go b/llm/claude/chat.go new file mode 100644 index 0000000..f7d2ecf --- /dev/null +++ b/llm/claude/chat.go @@ -0,0 +1,138 @@ +// https://docs.anthropic.com/claude/reference/messages_post + +package claude + +import ( + "context" + "encoding/json" + "opencatd-open/internal/model" + "opencatd-open/llm" + "opencatd-open/llm/openai" + + "github.com/gin-gonic/gin" +) + +func ChatProxy(c *gin.Context, chatReq *openai.ChatCompletionRequest) { + ChatMessages(c, chatReq) +} + +func ChatTextCompletions(c *gin.Context, chatReq *openai.ChatCompletionRequest) { + +} + +type ChatRequest struct { + Model string `json:"model,omitempty"` + Messages any `json:"messages,omitempty"` + MaxTokens int `json:"max_tokens,omitempty"` + Stream bool `json:"stream,omitempty"` + System string `json:"system,omitempty"` + TopK int `json:"top_k,omitempty"` + TopP float64 `json:"top_p,omitempty"` + Temperature float64 `json:"temperature,omitempty"` + AnthropicVersion string `json:"anthropic_version,omitempty"` +} + +func (c *ChatRequest) ByteJson() []byte { + bytejson, _ := json.Marshal(c) + return bytejson +} + +type ChatMessage struct { + Role string `json:"role,omitempty"` + Content string `json:"content,omitempty"` +} + +type VisionMessages struct { + Role string `json:"role,omitempty"` + Content []VisionContent `json:"content,omitempty"` +} + +type VisionContent struct { + Type string `json:"type,omitempty"` + Source *VisionSource `json:"source,omitempty"` + Text string `json:"text,omitempty"` +} + +type VisionSource struct { + Type string `json:"type,omitempty"` + MediaType string `json:"media_type,omitempty"` + Data string `json:"data,omitempty"` +} + +type ChatResponse struct { + ID string `json:"id"` + Type string `json:"type"` + Role string `json:"role"` + Model string `json:"model"` + StopSequence any `json:"stop_sequence"` + Usage struct { + InputTokens int `json:"input_tokens"` + OutputTokens int `json:"output_tokens"` + } `json:"usage"` + Content []struct { + Type string `json:"type"` + Text string `json:"text"` + } `json:"content"` + StopReason string `json:"stop_reason"` +} + +type ClaudeStreamResponse struct { + Type string `json:"type"` + Index int `json:"index"` + ContentBlock struct { + Type string `json:"type"` + Text string `json:"text"` + } `json:"content_block"` + Delta struct { + Type string `json:"type"` + Text string `json:"text"` + StopReason string `json:"stop_reason"` + StopSequence any `json:"stop_sequence"` + } `json:"delta"` + Message struct { + ID string `json:"id"` + Type string `json:"type"` + Role string `json:"role"` + Content []any `json:"content"` + Model string `json:"model"` + StopReason string `json:"stop_reason"` + StopSequence any `json:"stop_sequence"` + Usage struct { + InputTokens int `json:"input_tokens"` + OutputTokens int `json:"output_tokens"` + } `json:"usage"` + } `json:"message"` + Error struct { + Type string `json:"type"` + Message string `json:"message"` + } `json:"error"` + Usage struct { + OutputTokens int `json:"output_tokens"` + } `json:"usage"` +} + +type Claude struct { + Ctx context.Context + + ApiKey *model.ApiKey + tokenUsage *llm.TokenUsage + + Done chan struct{} +} + +func NewClaude(ctx context.Context, apiKey *model.ApiKey) (*Claude, error) { + return &Claude{ + Ctx: context.Background(), + ApiKey: apiKey, + tokenUsage: &llm.TokenUsage{}, + Done: make(chan struct{}), + }, nil +} + +func (c *Claude) Chat(ctx context.Context, chatReq llm.ChatRequest) (*llm.ChatResponse, error) { + return nil, nil +} + +func (g *Claude) StreamChat(ctx context.Context, chatReq llm.ChatRequest) (chan *llm.StreamChatResponse, error) { + return nil, nil +} diff --git a/pkg/claude/claude.go b/llm/claude/claude.go similarity index 100% rename from pkg/claude/claude.go rename to llm/claude/claude.go diff --git a/pkg/claude/chat.go b/llm/claude/handle_proxy.go similarity index 70% rename from pkg/claude/chat.go rename to llm/claude/handle_proxy.go index fdfda85..43136b8 100644 --- a/pkg/claude/chat.go +++ b/llm/claude/handle_proxy.go @@ -1,5 +1,3 @@ -// https://docs.anthropic.com/claude/reference/messages_post - package claude import ( @@ -10,115 +8,16 @@ import ( "io" "log" "net/http" + "opencatd-open/llm/openai" + "opencatd-open/llm/vertexai" "opencatd-open/pkg/error" - "opencatd-open/pkg/openai" "opencatd-open/pkg/tokenizer" - "opencatd-open/pkg/vertexai" "opencatd-open/store" "strings" "github.com/gin-gonic/gin" ) -func ChatProxy(c *gin.Context, chatReq *openai.ChatCompletionRequest) { - ChatMessages(c, chatReq) -} - -func ChatTextCompletions(c *gin.Context, chatReq *openai.ChatCompletionRequest) { - -} - -type ChatRequest struct { - Model string `json:"model,omitempty"` - Messages any `json:"messages,omitempty"` - MaxTokens int `json:"max_tokens,omitempty"` - Stream bool `json:"stream,omitempty"` - System string `json:"system,omitempty"` - TopK int `json:"top_k,omitempty"` - TopP float64 `json:"top_p,omitempty"` - Temperature float64 `json:"temperature,omitempty"` - AnthropicVersion string `json:"anthropic_version,omitempty"` -} - -func (c *ChatRequest) ByteJson() []byte { - bytejson, _ := json.Marshal(c) - return bytejson -} - -type ChatMessage struct { - Role string `json:"role,omitempty"` - Content string `json:"content,omitempty"` -} - -type VisionMessages struct { - Role string `json:"role,omitempty"` - Content []VisionContent `json:"content,omitempty"` -} - -type VisionContent struct { - Type string `json:"type,omitempty"` - Source *VisionSource `json:"source,omitempty"` - Text string `json:"text,omitempty"` -} - -type VisionSource struct { - Type string `json:"type,omitempty"` - MediaType string `json:"media_type,omitempty"` - Data string `json:"data,omitempty"` -} - -type ChatResponse struct { - ID string `json:"id"` - Type string `json:"type"` - Role string `json:"role"` - Model string `json:"model"` - StopSequence any `json:"stop_sequence"` - Usage struct { - InputTokens int `json:"input_tokens"` - OutputTokens int `json:"output_tokens"` - } `json:"usage"` - Content []struct { - Type string `json:"type"` - Text string `json:"text"` - } `json:"content"` - StopReason string `json:"stop_reason"` -} - -type ClaudeStreamResponse struct { - Type string `json:"type"` - Index int `json:"index"` - ContentBlock struct { - Type string `json:"type"` - Text string `json:"text"` - } `json:"content_block"` - Delta struct { - Type string `json:"type"` - Text string `json:"text"` - StopReason string `json:"stop_reason"` - StopSequence any `json:"stop_sequence"` - } `json:"delta"` - Message struct { - ID string `json:"id"` - Type string `json:"type"` - Role string `json:"role"` - Content []any `json:"content"` - Model string `json:"model"` - StopReason string `json:"stop_reason"` - StopSequence any `json:"stop_sequence"` - Usage struct { - InputTokens int `json:"input_tokens"` - OutputTokens int `json:"output_tokens"` - } `json:"usage"` - } `json:"message"` - Error struct { - Type string `json:"type"` - Message string `json:"message"` - } `json:"error"` - Usage struct { - OutputTokens int `json:"output_tokens"` - } `json:"usage"` -} - func ChatMessages(c *gin.Context, chatReq *openai.ChatCompletionRequest) { var ( req *http.Request diff --git a/llm/claude/v2/chat.go b/llm/claude/v2/chat.go new file mode 100644 index 0000000..9252895 --- /dev/null +++ b/llm/claude/v2/chat.go @@ -0,0 +1,246 @@ +package claude + +import ( + "context" + "encoding/base64" + "net/http" + "net/url" + "opencatd-open/internal/model" + "opencatd-open/llm" + "os" + "strings" + + "github.com/liushuangls/go-anthropic/v2" + "github.com/sashabaranov/go-openai" +) + +type Claude struct { + Ctx context.Context + ApiKey *model.ApiKey + tokenUsage *llm.TokenUsage + Done chan struct{} + Client *anthropic.Client +} + +func NewClaude(apiKey *model.ApiKey) (*Claude, error) { + opts := []anthropic.ClientOption{} + if os.Getenv("LOCAL_PROXY") != "" { + proxyUrl, err := url.Parse(os.Getenv("LOCAL_PROXY")) + if err == nil { + client := http.DefaultClient + client.Transport = &http.Transport{Proxy: http.ProxyURL(proxyUrl)} + opts = append(opts, anthropic.WithHTTPClient(client)) + } + } + return &Claude{ + Ctx: context.Background(), + ApiKey: apiKey, + tokenUsage: &llm.TokenUsage{}, + Done: make(chan struct{}), + Client: anthropic.NewClient(*apiKey.ApiKey, opts...), + }, nil +} + +func (c *Claude) Chat(ctx context.Context, chatReq llm.ChatRequest) (*llm.ChatResponse, error) { + var messages []anthropic.Message + + if len(chatReq.Messages) > 0 { + for _, msg := range chatReq.Messages { + var role anthropic.ChatRole + if msg.Role != "assistant" { + role = anthropic.RoleUser + } else { + role = anthropic.RoleAssistant + } + + var content []anthropic.MessageContent + if len(msg.MultiContent) > 0 { + for _, mc := range msg.MultiContent { + if mc.Type == "text" { + content = append(content, anthropic.MessageContent{Type: anthropic.MessagesContentTypeText, Text: &mc.Text}) + } + if mc.Type == "image_url" { + if strings.HasPrefix(mc.ImageURL.URL, "http") { + continue + } + if strings.HasPrefix(mc.ImageURL.URL, "data:image") { + var mediaType string + if strings.HasPrefix(mc.ImageURL.URL, "data:image/jpeg") { + mediaType = "image/jpeg" + } + if strings.HasPrefix(mc.ImageURL.URL, "data:image/png") { + mediaType = "image/png" + } + imageString := strings.Split(mc.ImageURL.URL, ",")[1] + imageBytes, _ := base64.StdEncoding.DecodeString(imageString) + + content = append(content, anthropic.MessageContent{Type: "image", Source: &anthropic.MessageContentSource{Type: "base64", MediaType: mediaType, Data: imageBytes}}) + } + + } + messages = append(messages, anthropic.Message{Role: role, Content: content}) + } + } else { + if len(msg.Content) > 0 { + content = append(content, anthropic.MessageContent{Type: "text", Text: &msg.Content}) + } + } + messages = append(messages, anthropic.Message{Role: role, Content: content}) + } + } + + var maxTokens int + if chatReq.MaxTokens > 0 { + maxTokens = chatReq.MaxTokens + } else { + if strings.Contains(chatReq.Model, "sonnet") || strings.Contains(chatReq.Model, "haiku") { + maxTokens = 8192 + } else { + maxTokens = 4096 + } + } + + resp, err := c.Client.CreateMessages(ctx, anthropic.MessagesRequest{ + Model: anthropic.Model(chatReq.Model), + Messages: messages, + MaxTokens: maxTokens, + Stream: false, + }) + if err != nil { + return nil, err + } + + c.tokenUsage.PromptTokens += resp.Usage.InputTokens + c.tokenUsage.CompletionTokens += resp.Usage.OutputTokens + c.tokenUsage.TotalTokens += resp.Usage.InputTokens + resp.Usage.OutputTokens + + return &llm.ChatResponse{ + Model: string(resp.Model), + Choices: []openai.ChatCompletionChoice{ + { + FinishReason: openai.FinishReason(resp.StopReason), + Message: openai.ChatCompletionMessage{ + Role: openai.ChatMessageRoleAssistant, + Content: *resp.Content[0].Text, + }, + }, + }, + }, nil +} + +func (c *Claude) StreamChat(ctx context.Context, chatReq llm.ChatRequest) (chan *llm.StreamChatResponse, error) { + var messages []anthropic.Message + + if len(chatReq.Messages) > 0 { + for _, msg := range chatReq.Messages { + var role anthropic.ChatRole + if msg.Role != "assistant" { + role = anthropic.RoleUser + } else { + role = anthropic.RoleAssistant + } + + var content []anthropic.MessageContent + if len(msg.MultiContent) > 0 { + for _, mc := range msg.MultiContent { + if mc.Type == "text" { + content = append(content, anthropic.MessageContent{Type: anthropic.MessagesContentTypeText, Text: &mc.Text}) + } + if mc.Type == "image_url" { + if strings.HasPrefix(mc.ImageURL.URL, "http") { + continue + } + if strings.HasPrefix(mc.ImageURL.URL, "data:image") { + var mediaType string + if strings.HasPrefix(mc.ImageURL.URL, "data:image/jpeg") { + mediaType = "image/jpeg" + } + if strings.HasPrefix(mc.ImageURL.URL, "data:image/png") { + mediaType = "image/png" + } + imageString := strings.Split(mc.ImageURL.URL, ",")[1] + imageBytes, _ := base64.StdEncoding.DecodeString(imageString) + + content = append(content, anthropic.MessageContent{Type: "image", Source: &anthropic.MessageContentSource{Type: "base64", MediaType: mediaType, Data: imageBytes}}) + } + + } + messages = append(messages, anthropic.Message{Role: role, Content: content}) + } + } else { + if len(msg.Content) > 0 { + content = append(content, anthropic.MessageContent{Type: "text", Text: &msg.Content}) + } + } + messages = append(messages, anthropic.Message{Role: role, Content: content}) + } + } + + var maxTokens int + if chatReq.MaxTokens > 0 { + maxTokens = chatReq.MaxTokens + } else { + if strings.Contains(chatReq.Model, "sonnet") || strings.Contains(chatReq.Model, "haiku") { + maxTokens = 8192 + } else { + maxTokens = 4096 + } + } + + datachan := make(chan *llm.StreamChatResponse) + // var resp anthropic.MessagesResponse + var err error + go func() { + defer close(datachan) + _, err = c.Client.CreateMessagesStream(ctx, anthropic.MessagesStreamRequest{ + MessagesRequest: anthropic.MessagesRequest{ + Model: anthropic.Model(chatReq.Model), + Messages: messages, + MaxTokens: maxTokens, + }, + OnContentBlockDelta: func(data anthropic.MessagesEventContentBlockDeltaData) { + datachan <- &llm.StreamChatResponse{ + Model: chatReq.Model, + Choices: []openai.ChatCompletionStreamChoice{ + { + Delta: openai.ChatCompletionStreamChoiceDelta{Content: *data.Delta.Text}, + }, + }, + } + }, + OnMessageStart: func(memss anthropic.MessagesEventMessageStartData) { + c.tokenUsage.PromptTokens += memss.Message.Usage.InputTokens + c.tokenUsage.CompletionTokens += memss.Message.Usage.OutputTokens + c.tokenUsage.TotalTokens += memss.Message.Usage.InputTokens + memss.Message.Usage.OutputTokens + }, + OnMessageDelta: func(memdd anthropic.MessagesEventMessageDeltaData) { + c.tokenUsage.PromptTokens += memdd.Usage.InputTokens + c.tokenUsage.CompletionTokens += memdd.Usage.OutputTokens + c.tokenUsage.TotalTokens += memdd.Usage.InputTokens + memdd.Usage.OutputTokens + + datachan <- &llm.StreamChatResponse{ + Model: chatReq.Model, + Choices: []openai.ChatCompletionStreamChoice{ + {FinishReason: openai.FinishReason(memdd.Delta.StopReason)}, + }, + } + }, + }) + + select { + case <-ctx.Done(): + return + default: + } + }() + if err != nil { + return nil, err + } + + return datachan, err +} + +func (c *Claude) GetTokenUsage() *llm.TokenUsage { + return c.tokenUsage + +} diff --git a/pkg/google/chat.go b/llm/google/chat.go similarity index 99% rename from pkg/google/chat.go rename to llm/google/chat.go index d0e78d7..fd6a184 100644 --- a/pkg/google/chat.go +++ b/llm/google/chat.go @@ -11,7 +11,7 @@ import ( "io" "log" "net/http" - "opencatd-open/pkg/openai" + "opencatd-open/llm/openai" "opencatd-open/pkg/tokenizer" "opencatd-open/store" "strings" diff --git a/llm/google/v2/chat.go b/llm/google/v2/chat.go new file mode 100644 index 0000000..444b729 --- /dev/null +++ b/llm/google/v2/chat.go @@ -0,0 +1,228 @@ +// https://github.com/google-gemini/api-examples/ + +package google + +import ( + "context" + "encoding/base64" + "fmt" + "net/http" + "net/url" + "opencatd-open/internal/model" + "opencatd-open/llm" + "os" + "strings" + + "github.com/sashabaranov/go-openai" + + "google.golang.org/genai" +) + +type Gemini struct { + Ctx context.Context + Client *genai.Client + ApiKey *model.ApiKey + tokenUsage *llm.TokenUsage + + Done chan struct{} +} + +func NewGemini(ctx context.Context, apiKey *model.ApiKey) (*Gemini, error) { + hc := http.DefaultClient + if os.Getenv("LOCAL_PROXY") != "" { + proxyUrl, err := url.Parse(os.Getenv("LOCAL_PROXY")) + if err == nil { + hc = &http.Client{Transport: &http.Transport{Proxy: http.ProxyURL(proxyUrl)}} + } + } + client, err := genai.NewClient(ctx, &genai.ClientConfig{ + APIKey: *apiKey.ApiKey, + Backend: genai.BackendGeminiAPI, + HTTPClient: hc, + }) + if err != nil { + return nil, err + } + + return &Gemini{ + Ctx: context.Background(), + Client: client, + ApiKey: apiKey, + tokenUsage: &llm.TokenUsage{}, + Done: make(chan struct{}), + }, nil + +} + +func (g *Gemini) Chat(ctx context.Context, chatReq llm.ChatRequest) (*llm.ChatResponse, error) { + var content []*genai.Content + if len(chatReq.Messages) > 0 { + for _, msg := range chatReq.Messages { + var role genai.Role + if msg.Role == "user" || msg.Role == "system" { + role = genai.RoleUser + } else { + role = genai.RoleModel + } + + if len(msg.MultiContent) > 0 { + for _, c := range msg.MultiContent { + var parts []*genai.Part + + if c.Type == "text" { + parts = append(parts, genai.NewPartFromText(c.Text)) + } + if c.Type == "image_url" { + if strings.HasPrefix(c.ImageURL.URL, "http") { + continue + } + if strings.HasPrefix(c.ImageURL.URL, "data:image") { + var mediaType string + if strings.HasPrefix(c.ImageURL.URL, "data:image/jpeg") { + mediaType = "image/jpeg" + } + if strings.HasPrefix(c.ImageURL.URL, "data:image/png") { + mediaType = "image/png" + } + imageString := strings.Split(c.ImageURL.URL, ",")[1] + imageBytes, _ := base64.StdEncoding.DecodeString(imageString) + + parts = append(parts, genai.NewPartFromBytes(imageBytes, mediaType)) + } + + } + content = append(content, genai.NewContentFromParts(parts, role)) + } + } else { + content = append(content, genai.NewContentFromText(msg.Content, role)) + } + + } + } + + tools := []*genai.Tool{{GoogleSearch: &genai.GoogleSearch{}}} + response, err := g.Client.Models.GenerateContent(g.Ctx, + chatReq.Model, + content, + &genai.GenerateContentConfig{Tools: tools}) + if err != nil { + return nil, err + } + + if response.UsageMetadata != nil { + g.tokenUsage.PromptTokens += int(response.UsageMetadata.PromptTokenCount) + g.tokenUsage.CompletionTokens += int(response.UsageMetadata.CandidatesTokenCount) + g.tokenUsage.ToolsTokens += int(response.UsageMetadata.ToolUsePromptTokenCount) + g.tokenUsage.TotalTokens += int(response.UsageMetadata.TotalTokenCount) + } + + // var text string + // if response.Candidates != nil && response.Candidates[0].Content != nil { + // for _, part := range response.Candidates[0].Content.Parts { + // text += part.Text + // } + // } + + return &llm.ChatResponse{ + Model: response.ModelVersion, + Choices: []openai.ChatCompletionChoice{ + { + Message: openai.ChatCompletionMessage{Content: response.Text(), Role: "assistant"}, + FinishReason: openai.FinishReason(response.Candidates[0].FinishReason), + }, + }, + Usage: openai.Usage{PromptTokens: g.tokenUsage.PromptTokens + g.tokenUsage.ToolsTokens, CompletionTokens: g.tokenUsage.CompletionTokens, TotalTokens: g.tokenUsage.TotalTokens}, + }, nil + +} + +func (g *Gemini) StreamChat(ctx context.Context, chatReq llm.ChatRequest) (chan *llm.StreamChatResponse, error) { + var contents []*genai.Content + if len(chatReq.Messages) > 0 { + for _, msg := range chatReq.Messages { + var role genai.Role + if msg.Role == "user" { + role = genai.RoleUser + } else { + role = genai.RoleModel + } + + if len(msg.MultiContent) > 0 { + for _, c := range msg.MultiContent { + var parts []*genai.Part + + if c.Type == "text" { + parts = append(parts, genai.NewPartFromText(c.Text)) + } + if c.Type == "image_url" { + if strings.HasPrefix(c.ImageURL.URL, "http") { + continue + } + if strings.HasPrefix(c.ImageURL.URL, "data:image") { + var mediaType string + if strings.HasPrefix(c.ImageURL.URL, "data:image/jpeg") { + mediaType = "image/jpeg" + } + if strings.HasPrefix(c.ImageURL.URL, "data:image/png") { + mediaType = "image/png" + } + imageString := strings.Split(c.ImageURL.URL, ",")[1] + imageBytes, _ := base64.StdEncoding.DecodeString(imageString) + + parts = append(parts, genai.NewPartFromBytes(imageBytes, mediaType)) + } + + } + contents = append(contents, genai.NewContentFromParts(parts, role)) + } + } else { + contents = append(contents, genai.NewContentFromText(msg.Content, role)) + } + + } + } + + datachan := make(chan *llm.StreamChatResponse) + var generr error + + tools := []*genai.Tool{{GoogleSearch: &genai.GoogleSearch{}}} + + go func() { + defer close(datachan) + for result, err := range g.Client.Models.GenerateContentStream(g.Ctx, chatReq.Model, contents, &genai.GenerateContentConfig{Tools: tools}) { + if err != nil { + fmt.Println(err) + generr = err + return + } + if result.UsageMetadata != nil { + g.tokenUsage.PromptTokens += int(result.UsageMetadata.PromptTokenCount) + g.tokenUsage.CompletionTokens += int(result.UsageMetadata.CandidatesTokenCount) + g.tokenUsage.ToolsTokens += int(result.UsageMetadata.ToolUsePromptTokenCount) + g.tokenUsage.TotalTokens += int(result.UsageMetadata.TotalTokenCount) + } + + datachan <- &llm.StreamChatResponse{ + Model: result.ModelVersion, + Choices: []openai.ChatCompletionStreamChoice{ + { + Delta: openai.ChatCompletionStreamChoiceDelta{ + Role: "assistant", + // Content: result.Candidates[0].Content.Parts[0].Text, + Content: result.Text(), + }, + FinishReason: openai.FinishReason(result.Candidates[0].FinishReason), + }, + }, + Usage: &openai.Usage{PromptTokens: g.tokenUsage.PromptTokens + g.tokenUsage.ToolsTokens, CompletionTokens: g.tokenUsage.CompletionTokens, TotalTokens: g.tokenUsage.TotalTokens}, + } + + } + }() + + return datachan, generr +} + +func (g *Gemini) GetTokenUsage() *llm.TokenUsage { + return g.tokenUsage +} diff --git a/llm/llm.go b/llm/llm.go new file mode 100644 index 0000000..0bf889e --- /dev/null +++ b/llm/llm.go @@ -0,0 +1,20 @@ +package llm + +import ( + "context" + "opencatd-open/internal/model" +) + +type LLM interface { + Chat(ctx context.Context, req ChatRequest) (*ChatResponse, error) + StreamChat(ctx context.Context, req ChatRequest) (chan *StreamChatResponse, error) + GetTokenUsage() *TokenUsage +} + +type llm struct { + ApiKey *model.ApiKey + Usage *model.Usage + tools any // TODO + Messages []any // TODO + llm LLM +} diff --git a/llm/openai/chat.go b/llm/openai/chat.go new file mode 100644 index 0000000..0ecc4af --- /dev/null +++ b/llm/openai/chat.go @@ -0,0 +1,178 @@ +package openai + +import ( + "encoding/json" + "os" + "strings" +) + +const ( + // https://learn.microsoft.com/en-us/azure/ai-services/openai/api-version-deprecation#latest-preview-api-releases + AzureApiVersion = "2024-10-21" + BaseHost = "api.openai.com" + OpenAI_Endpoint = "https://api.openai.com/v1/chat/completions" + Github_Marketplace = "https://models.inference.ai.azure.com/chat/completions" +) + +var ( + Custom_Endpoint string + AIGateWay_Endpoint string // "https://gateway.ai.cloudflare.com/v1/431ba10f11200d544922fbca177aaa7f/openai/openai/chat/completions" +) + +func init() { + if os.Getenv("OpenAI_Endpoint") != "" { + Custom_Endpoint = os.Getenv("OpenAI_Endpoint") + } + if os.Getenv("AIGateWay_Endpoint") != "" { + AIGateWay_Endpoint = os.Getenv("AIGateWay_Endpoint") + } +} + +// Vision Content +type VisionContent struct { + Type string `json:"type,omitempty"` + Text string `json:"text,omitempty"` + ImageURL *VisionImageURL `json:"image_url,omitempty"` +} +type VisionImageURL struct { + URL string `json:"url,omitempty"` + Detail string `json:"detail,omitempty"` +} + +type ChatCompletionMessage struct { + Role string `json:"role"` + Content any `json:"content"` + Name string `json:"name,omitempty"` + // MultiContent []VisionContent +} + +type FunctionDefinition struct { + Name string `json:"name"` + Description string `json:"description,omitempty"` + Parameters any `json:"parameters"` +} + +type Tool struct { + Type string `json:"type"` + Function *FunctionDefinition `json:"function,omitempty"` +} + +type StreamOption struct { + IncludeUsage bool `json:"include_usage,omitempty"` +} + +type ChatCompletionRequest struct { + Model string `json:"model"` + Messages []ChatCompletionMessage `json:"messages"` + MaxTokens int `json:"max_tokens,omitempty"` + Temperature float64 `json:"temperature,omitempty"` + TopP float64 `json:"top_p,omitempty"` + N int `json:"n,omitempty"` + Stream bool `json:"stream"` + Stop []string `json:"stop,omitempty"` + PresencePenalty float64 `json:"presence_penalty,omitempty"` + FrequencyPenalty float64 `json:"frequency_penalty,omitempty"` + LogitBias map[string]int `json:"logit_bias,omitempty"` + User string `json:"user,omitempty"` + // Functions []FunctionDefinition `json:"functions,omitempty"` + // FunctionCall any `json:"function_call,omitempty"` + Tools []Tool `json:"tools,omitempty"` + ParallelToolCalls bool `json:"parallel_tool_calls,omitempty"` + // ToolChoice any `json:"tool_choice,omitempty"` + StreamOptions *StreamOption `json:"stream_options,omitempty"` +} + +func (c ChatCompletionRequest) ToByteJson() []byte { + bytejson, _ := json.Marshal(c) + return bytejson +} + +type ToolCall struct { + ID string `json:"id"` + Type string `json:"type"` + Function struct { + Name string `json:"name"` + Arguments string `json:"arguments"` + } `json:"function"` +} + +type ChatCompletionResponse struct { + ID string `json:"id,omitempty"` + Object string `json:"object,omitempty"` + Created int `json:"created,omitempty"` + Model string `json:"model,omitempty"` + Choices []struct { + Index int `json:"index,omitempty"` + Message struct { + Role string `json:"role,omitempty"` + Content string `json:"content,omitempty"` + ToolCalls []ToolCall `json:"tool_calls,omitempty"` + } `json:"message,omitempty"` + Logprobs string `json:"logprobs,omitempty"` + FinishReason string `json:"finish_reason,omitempty"` + } `json:"choices,omitempty"` + Usage struct { + PromptTokens int `json:"prompt_tokens,omitempty"` + CompletionTokens int `json:"completion_tokens,omitempty"` + TotalTokens int `json:"total_tokens,omitempty"` + PromptTokensDetails struct { + CachedTokens int `json:"cached_tokens,omitempty"` + AudioTokens int `json:"audio_tokens,omitempty"` + } `json:"prompt_tokens_details,omitempty"` + CompletionTokensDetails struct { + ReasoningTokens int `json:"reasoning_tokens,omitempty"` + AudioTokens int `json:"audio_tokens,omitempty"` + AcceptedPredictionTokens int `json:"accepted_prediction_tokens,omitempty"` + RejectedPredictionTokens int `json:"rejected_prediction_tokens,omitempty"` + } `json:"completion_tokens_details,omitempty"` + } `json:"usage,omitempty"` + SystemFingerprint string `json:"system_fingerprint,omitempty"` +} + +type Choice struct { + Index int `json:"index"` + Delta struct { + Role string `json:"role"` + Content string `json:"content"` + ToolCalls []ToolCall `json:"tool_calls"` + } `json:"delta"` + FinishReason string `json:"finish_reason"` + Usage struct { + PromptTokens int `json:"prompt_tokens"` + CompletionTokens int `json:"completion_tokens"` + TotalTokens int `json:"total_tokens"` + } `json:"usage"` +} + +type ChatCompletionStreamResponse struct { + ID string `json:"id"` + Object string `json:"object"` + Created int `json:"created"` + Model string `json:"model"` + Choices []Choice `json:"choices"` +} + +func (c *ChatCompletionStreamResponse) ByteJson() []byte { + bytejson, _ := json.Marshal(c) + return bytejson +} + +func modelmap(in string) string { + // gpt-3.5-turbo -> gpt-35-turbo + if strings.Contains(in, ".") { + return strings.ReplaceAll(in, ".", "") + } + return in +} + +type ErrResponse struct { + Error struct { + Message string `json:"message"` + Code string `json:"code"` + } `json:"error"` +} + +func (e *ErrResponse) ByteJson() []byte { + bytejson, _ := json.Marshal(e) + return bytejson +} diff --git a/pkg/openai/dall-e.go b/llm/openai/dall-e.go similarity index 100% rename from pkg/openai/dall-e.go rename to llm/openai/dall-e.go diff --git a/pkg/openai/chat.go b/llm/openai/handle_proxy.go similarity index 53% rename from pkg/openai/chat.go rename to llm/openai/handle_proxy.go index 69a30fa..1e752fe 100644 --- a/pkg/openai/chat.go +++ b/llm/openai/handle_proxy.go @@ -10,163 +10,11 @@ import ( "net/http" "opencatd-open/pkg/tokenizer" "opencatd-open/store" - "os" "strings" "github.com/gin-gonic/gin" ) -const ( - // https://learn.microsoft.com/en-us/azure/ai-services/openai/api-version-deprecation#latest-preview-api-releases - AzureApiVersion = "2024-06-01" - BaseHost = "api.openai.com" - OpenAI_Endpoint = "https://api.openai.com/v1/chat/completions" - Github_Marketplace = "https://models.inference.ai.azure.com/chat/completions" -) - -var ( - Custom_Endpoint string - AIGateWay_Endpoint string // "https://gateway.ai.cloudflare.com/v1/431ba10f11200d544922fbca177aaa7f/openai/openai/chat/completions" -) - -func init() { - if os.Getenv("OpenAI_Endpoint") != "" { - Custom_Endpoint = os.Getenv("OpenAI_Endpoint") - } - if os.Getenv("AIGateWay_Endpoint") != "" { - AIGateWay_Endpoint = os.Getenv("AIGateWay_Endpoint") - } -} - -// Vision Content -type VisionContent struct { - Type string `json:"type,omitempty"` - Text string `json:"text,omitempty"` - ImageURL *VisionImageURL `json:"image_url,omitempty"` -} -type VisionImageURL struct { - URL string `json:"url,omitempty"` - Detail string `json:"detail,omitempty"` -} - -type ChatCompletionMessage struct { - Role string `json:"role"` - Content any `json:"content"` - Name string `json:"name,omitempty"` - // MultiContent []VisionContent -} - -type FunctionDefinition struct { - Name string `json:"name"` - Description string `json:"description,omitempty"` - Parameters any `json:"parameters"` -} - -type Tool struct { - Type string `json:"type"` - Function *FunctionDefinition `json:"function,omitempty"` -} - -type StreamOption struct { - IncludeUsage bool `json:"include_usage,omitempty"` -} - -type ChatCompletionRequest struct { - Model string `json:"model"` - Messages []ChatCompletionMessage `json:"messages"` - MaxTokens int `json:"max_tokens,omitempty"` - Temperature float64 `json:"temperature,omitempty"` - TopP float64 `json:"top_p,omitempty"` - N int `json:"n,omitempty"` - Stream bool `json:"stream"` - Stop []string `json:"stop,omitempty"` - PresencePenalty float64 `json:"presence_penalty,omitempty"` - FrequencyPenalty float64 `json:"frequency_penalty,omitempty"` - LogitBias map[string]int `json:"logit_bias,omitempty"` - User string `json:"user,omitempty"` - // Functions []FunctionDefinition `json:"functions,omitempty"` - // FunctionCall any `json:"function_call,omitempty"` - Tools []Tool `json:"tools,omitempty"` - ParallelToolCalls bool `json:"parallel_tool_calls,omitempty"` - // ToolChoice any `json:"tool_choice,omitempty"` - StreamOptions *StreamOption `json:"stream_options,omitempty"` -} - -func (c ChatCompletionRequest) ToByteJson() []byte { - bytejson, _ := json.Marshal(c) - return bytejson -} - -type ToolCall struct { - ID string `json:"id"` - Type string `json:"type"` - Function struct { - Name string `json:"name"` - Arguments string `json:"arguments"` - } `json:"function"` -} - -type ChatCompletionResponse struct { - ID string `json:"id,omitempty"` - Object string `json:"object,omitempty"` - Created int `json:"created,omitempty"` - Model string `json:"model,omitempty"` - Choices []struct { - Index int `json:"index,omitempty"` - Message struct { - Role string `json:"role,omitempty"` - Content string `json:"content,omitempty"` - ToolCalls []ToolCall `json:"tool_calls,omitempty"` - } `json:"message,omitempty"` - Logprobs string `json:"logprobs,omitempty"` - FinishReason string `json:"finish_reason,omitempty"` - } `json:"choices,omitempty"` - Usage struct { - PromptTokens int `json:"prompt_tokens,omitempty"` - CompletionTokens int `json:"completion_tokens,omitempty"` - TotalTokens int `json:"total_tokens,omitempty"` - PromptTokensDetails struct { - CachedTokens int `json:"cached_tokens,omitempty"` - AudioTokens int `json:"audio_tokens,omitempty"` - } `json:"prompt_tokens_details,omitempty"` - CompletionTokensDetails struct { - ReasoningTokens int `json:"reasoning_tokens,omitempty"` - AudioTokens int `json:"audio_tokens,omitempty"` - AcceptedPredictionTokens int `json:"accepted_prediction_tokens,omitempty"` - RejectedPredictionTokens int `json:"rejected_prediction_tokens,omitempty"` - } `json:"completion_tokens_details,omitempty"` - } `json:"usage,omitempty"` - SystemFingerprint string `json:"system_fingerprint,omitempty"` -} - -type Choice struct { - Index int `json:"index"` - Delta struct { - Role string `json:"role"` - Content string `json:"content"` - ToolCalls []ToolCall `json:"tool_calls"` - } `json:"delta"` - FinishReason string `json:"finish_reason"` - Usage struct { - PromptTokens int `json:"prompt_tokens"` - CompletionTokens int `json:"completion_tokens"` - TotalTokens int `json:"total_tokens"` - } `json:"usage"` -} - -type ChatCompletionStreamResponse struct { - ID string `json:"id"` - Object string `json:"object"` - Created int `json:"created"` - Model string `json:"model"` - Choices []Choice `json:"choices"` -} - -func (c *ChatCompletionStreamResponse) ByteJson() []byte { - bytejson, _ := json.Marshal(c) - return bytejson -} - func ChatProxy(c *gin.Context, chatReq *ChatCompletionRequest) { usagelog := store.Tokens{Model: chatReq.Model} @@ -365,23 +213,3 @@ func ChatProxy(c *gin.Context, chatReq *ChatCompletionRequest) { log.Println(err) } } - -func modelmap(in string) string { - // gpt-3.5-turbo -> gpt-35-turbo - if strings.Contains(in, ".") { - return strings.ReplaceAll(in, ".", "") - } - return in -} - -type ErrResponse struct { - Error struct { - Message string `json:"message"` - Code string `json:"code"` - } `json:"error"` -} - -func (e *ErrResponse) ByteJson() []byte { - bytejson, _ := json.Marshal(e) - return bytejson -} diff --git a/pkg/openai/realtime.go b/llm/openai/realtime.go similarity index 100% rename from pkg/openai/realtime.go rename to llm/openai/realtime.go diff --git a/pkg/openai/tts.go b/llm/openai/tts.go similarity index 100% rename from pkg/openai/tts.go rename to llm/openai/tts.go diff --git a/pkg/openai/whisper.go b/llm/openai/whisper.go similarity index 100% rename from pkg/openai/whisper.go rename to llm/openai/whisper.go diff --git a/llm/openai_compatible/chat.go b/llm/openai_compatible/chat.go new file mode 100644 index 0000000..d0d0367 --- /dev/null +++ b/llm/openai_compatible/chat.go @@ -0,0 +1,221 @@ +package openai_compatible + +import ( + "bufio" + "bytes" + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "net/url" + "opencatd-open/internal/model" + "opencatd-open/internal/utils" + "opencatd-open/llm" + "os" + "strings" +) + +// https://learn.microsoft.com/en-us/azure/ai-services/openai/api-version-deprecation#latest-preview-api-releases +const AzureApiVersion = "2024-10-21" +const defaultOpenAICompatibleEndpoint = "https://api.openai.com/v1/chat/completions" +const Github_Marketplace = "https://models.inference.ai.azure.com/chat/completions" + +type OpenAICompatible struct { + Client *http.Client + ApiKey *model.ApiKey + tokenUsage *llm.TokenUsage + Params map[string]interface{} + Done chan struct{} +} + +func NewOpenAICompatible(apikey *model.ApiKey) (*OpenAICompatible, error) { + hc := http.DefaultClient + if os.Getenv("LOCAL_PROXY") != "" { + proxyUrl, err := url.Parse(os.Getenv("LOCAL_PROXY")) + if err == nil { + tr := http.Transport{ + Proxy: http.ProxyURL(proxyUrl), + } + hc.Transport = &tr + } + } + + oc := OpenAICompatible{ + ApiKey: apikey, + Client: hc, + tokenUsage: &llm.TokenUsage{}, + Done: make(chan struct{}), + } + + if apikey.Parameters != nil { + var params map[string]interface{} + err := json.Unmarshal([]byte(*apikey.Parameters), ¶ms) + if err != nil { + return nil, err + } + oc.Params = params + } + + return &oc, nil +} + +func (o *OpenAICompatible) Chat(ctx context.Context, chatReq llm.ChatRequest) (*llm.ChatResponse, error) { + chatReq.Stream = false + dst, err := utils.StructToMap(chatReq) + if err != nil { + return nil, err + } + if len(o.Params) > 0 { + dst = utils.MergeJSONObjects(dst, o.Params) + } + + var reqBody bytes.Buffer + if err := json.NewEncoder(&reqBody).Encode(dst); err != nil { + return nil, err + } + + var req *http.Request + switch *o.ApiKey.ApiType { + case "azure": + formatModel := func(in string) string { + if strings.Contains(in, ".") { + return strings.ReplaceAll(in, ".", "") + } + return in + } + var buildurl string + if *o.ApiKey.Endpoint != "" { + buildurl = fmt.Sprintf("%s/openai/deployments/%s/chat/completions?api-version=%s", *o.ApiKey.Endpoint, formatModel(chatReq.Model), AzureApiVersion) + } else { + buildurl = fmt.Sprintf("https://%s.openai.azure.com/openai/deployments/%s/chat/completions?api-version=%s", *o.ApiKey.ResourceNmae, formatModel(chatReq.Model), AzureApiVersion) + } + req, _ = http.NewRequest(http.MethodPost, buildurl, &reqBody) + req.Header.Set("api-key", *o.ApiKey.ApiKey) + case "github": + req, _ = http.NewRequest(http.MethodPost, Github_Marketplace, &reqBody) + default: + if o.ApiKey.Endpoint == nil || *o.ApiKey.Endpoint == "" { + req, _ = http.NewRequest(http.MethodPost, defaultOpenAICompatibleEndpoint, &reqBody) + } else { + req, _ = http.NewRequest(http.MethodPost, *o.ApiKey.Endpoint, &reqBody) + } + } + req.Header.Set("Authorization", "Bearer "+*o.ApiKey.ApiKey) + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Accept-Encoding", "identity") + + resp, err := o.Client.Do(req) + if err != nil { + return nil, err + } + defer resp.Body.Close() + + var chatResp llm.ChatResponse + if err := json.NewDecoder(resp.Body).Decode(&chatResp); err != nil { + return nil, err + } + + o.tokenUsage.PromptTokens = chatResp.Usage.PromptTokens + o.tokenUsage.CompletionTokens = chatResp.Usage.CompletionTokens + o.tokenUsage.TotalTokens = chatResp.Usage.TotalTokens + return &chatResp, nil +} + +func (o *OpenAICompatible) StreamChat(ctx context.Context, chatReq llm.ChatRequest) (chan *llm.StreamChatResponse, error) { + chatReq.Stream = true + dst, err := utils.StructToMap(chatReq) + if err != nil { + return nil, err + } + if len(o.Params) > 0 { + dst = utils.MergeJSONObjects(dst, o.Params) + } + + var reqBody bytes.Buffer + if err := json.NewEncoder(&reqBody).Encode(dst); err != nil { + return nil, err + } + + var req *http.Request + switch *o.ApiKey.ApiType { + case "azure": + formatModel := func(in string) string { + if strings.Contains(in, ".") { + return strings.ReplaceAll(in, ".", "") + } + return in + } + var buildurl string + if *o.ApiKey.Endpoint != "" { + buildurl = fmt.Sprintf("%s/openai/deployments/%s/chat/completions?api-version=%s", *o.ApiKey.Endpoint, formatModel(chatReq.Model), AzureApiVersion) + } else { + buildurl = fmt.Sprintf("https://%s.openai.azure.com/openai/deployments/%s/chat/completions?api-version=%s", *o.ApiKey.ResourceNmae, formatModel(chatReq.Model), AzureApiVersion) + } + req, _ = http.NewRequest(http.MethodPost, buildurl, &reqBody) + req.Header.Set("api-key", *o.ApiKey.ApiKey) + case "github": + req, _ = http.NewRequest(http.MethodPost, Github_Marketplace, &reqBody) + default: + if o.ApiKey.Endpoint == nil || *o.ApiKey.Endpoint == "" { + req, _ = http.NewRequest(http.MethodPost, defaultOpenAICompatibleEndpoint, &reqBody) + } else { + req, _ = http.NewRequest(http.MethodPost, *o.ApiKey.Endpoint, &reqBody) + } + } + req.Header.Set("Authorization", "Bearer "+*o.ApiKey.ApiKey) + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Accept-Encoding", "identity") + + resp, err := o.Client.Do(req) + if err != nil { + return nil, err + } + + output := make(chan *llm.StreamChatResponse) + + b := new(bytes.Buffer) + teeReader := io.TeeReader(resp.Body, b) + // 流式响应 + scanner := bufio.NewScanner(teeReader) + + go func() { + defer resp.Body.Close() + defer close(output) + + for scanner.Scan() { + line := scanner.Bytes() + var streamResp llm.StreamChatResponse + if len(line) > 0 { + // fmt.Println(string(line)) + if bytes.HasPrefix(line, []byte("data: ")) { + if bytes.HasPrefix(line, []byte("data: [DONE]")) { + break + } + line = bytes.Replace(line, []byte("data: "), []byte(""), -1) + line = bytes.TrimSpace(line) + if err := json.Unmarshal(line, &streamResp); err != nil { + continue + } + if streamResp.Usage != nil { + o.tokenUsage.PromptTokens += streamResp.Usage.PromptTokens + o.tokenUsage.CompletionTokens += streamResp.Usage.CompletionTokens + o.tokenUsage.TotalTokens += streamResp.Usage.TotalTokens + } + output <- &streamResp + } + } + + // select { + // case <-ctx.Done(): + // return + // case output <- &streamResp: + // } + } + }() + return output, nil +} + +func (o *OpenAICompatible) GetTokenUsage() *llm.TokenUsage { + return o.tokenUsage +} diff --git a/llm/types.go b/llm/types.go new file mode 100644 index 0000000..5c6d99f --- /dev/null +++ b/llm/types.go @@ -0,0 +1,41 @@ +package llm + +import ( + "fmt" + + "github.com/sashabaranov/go-openai" +) + +type ChatRequest openai.ChatCompletionRequest + +type ChatResponse openai.ChatCompletionResponse + +type StreamChatResponse openai.ChatCompletionStreamResponse + +type ChatMessage openai.ChatCompletionMessage + +type TokenUsage struct { + PromptTokens int `json:"prompt_tokens"` + CompletionTokens int `json:"completion_tokens"` + ToolsTokens int `json:"total_tokens"` + TotalTokens int `json:"total_tokens"` +} + +type ErrorResponse struct { + Err struct { + Message string `json:"message,omitempty"` + Type string `json:"type,omitempty"` + Param string `json:"param,omitempty"` + Code string `json:"code,omitempty"` + } `json:"error,omitempty"` + HTTPStatusCode int `json:"-"` + HTTPStatus string `json:"-"` +} + +func (e ErrorResponse) Error() string { + if e.HTTPStatusCode > 0 { + return fmt.Sprintf("error, status code: %d, status: %s, message: %s", e.HTTPStatusCode, e.HTTPStatus, e.Err.Message) + } + + return e.Err.Message +} diff --git a/pkg/vertexai/auth.go b/llm/vertexai/auth.go similarity index 100% rename from pkg/vertexai/auth.go rename to llm/vertexai/auth.go diff --git a/makefile b/makefile index c14a6d7..cc0e579 100644 --- a/makefile +++ b/makefile @@ -25,7 +25,7 @@ web: build: # mkdir -p bin/ && go build -ldflags $(LDFlags) -o ./bin/ ./... rm -rf bin - mkdir -p bin/ && go build -ldflags "-s -w" -o ./bin/opencatd . + mkdir -p bin/ && go build -ldflags "-s -w" -o ./bin/opencatd opencat.go upx -9 bin/opencatd .PHONY:docker diff --git a/middleware/auth.go b/middleware/auth.go new file mode 100644 index 0000000..23a4ed2 --- /dev/null +++ b/middleware/auth.go @@ -0,0 +1,55 @@ +package middleware + +import ( + "fmt" + "net/http" + "opencatd-open/internal/auth" + "opencatd-open/internal/consts" + "opencatd-open/internal/dto" + "opencatd-open/internal/model" + "opencatd-open/pkg/store" + + "github.com/gin-gonic/gin" +) + +func Auth(c *gin.Context) { + authToken := c.GetHeader("Authorization") + if authToken == "" { + dto.Fail(c, http.StatusUnauthorized, "未提供认证信息") + return + } + authToken = authToken[7:] + claim, err := auth.ValidateToken(authToken, consts.SecretKey) + if err != nil { + c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{ + "code": http.StatusUnauthorized, + "error": "无效的认证信息", + }) + return + } + var user model.User + if err := store.GetDB().Model(&model.User{ID: int64(claim.UserID)}).First(&user).Error; err != nil { + c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{ + "code": http.StatusUnauthorized, + "error": "无效的认证信息", + }) + return + } + c.Set("user", &user) + c.Set("user_id", claim.UserID) + c.Set("user_role", user.Role) + c.Next() +} +func CheckRole(role consts.UserRole) func(c *gin.Context) { + fmt.Println("CheckRoleMiddleware") + return func(c *gin.Context) { + userRole := c.GetInt("user_role") // 操作者 + fmt.Println("userRole", userRole) + // if userRole < int(role) { + // dto.Fail(c, http.StatusForbidden, "permission denied") + // return + // } + + c.Next() + } +} diff --git a/middleware/auth_team.go b/middleware/auth_team.go new file mode 100644 index 0000000..a6b92ba --- /dev/null +++ b/middleware/auth_team.go @@ -0,0 +1,103 @@ +package middleware + +import ( + "net/http" + "opencatd-open/internal/dto" + "opencatd-open/internal/model" + + "github.com/gin-gonic/gin" + "gorm.io/gorm" +) + +func AuthTeam(db *gorm.DB) gin.HandlerFunc { + return func(c *gin.Context) { + + auth_token := c.GetHeader("Authorization") + if len(auth_token) < 7 || auth_token[:7] != "Bearer " { + c.JSON(http.StatusUnauthorized, gin.H{"error": "Unauthorized"}) + c.Abort() + return + } + auth_token = auth_token[7:] + token := model.Token{} + if err := db.Preload("Users").First(&token, "token = ?", auth_token).Error; err != nil { + dto.WrapErrorAsOpenAI(c, http.StatusUnauthorized, "invalid_api_key") + c.Abort() + return + } + + if token.User == nil { + c.JSON(http.StatusUnauthorized, gin.H{"error": "Unauthorized"}) + c.Abort() + return + } + + if !*token.User.Active || !*token.Active { + dto.WrapErrorAsOpenAI(c, http.StatusForbidden, "User or API key is not active") + c.Abort() + return + } + + if token.Name != "default" { + dto.WrapErrorAsOpenAI(c, http.StatusForbidden, "Only default api key accessible") + c.Abort() + return + } + + c.Set("user", token.User) + c.Set("authed", true) + // 可以在这里对 token 进行验证并检查权限 + c.Next() + + } +} + +func AuthLLM(db *gorm.DB) gin.HandlerFunc { + return func(c *gin.Context) { + + auth_token := c.GetHeader("Authorization") + if len(auth_token) < 7 || auth_token[:7] != "Bearer " { + c.JSON(http.StatusUnauthorized, gin.H{"error": "Unauthorized"}) + c.Abort() + return + } + auth_token = auth_token[7:] + token := model.Token{} + if err := db.Preload("Users").First(&token, "token = ?", auth_token).Error; err != nil { + dto.WrapErrorAsOpenAI(c, http.StatusUnauthorized, "invalid_api_key") + c.Abort() + return + } + + if token.User == nil { + c.JSON(http.StatusUnauthorized, gin.H{"error": "Unauthorized"}) + c.Abort() + return + } + + if !*token.User.Active || !*token.Active { + dto.WrapErrorAsOpenAI(c, http.StatusForbidden, "User or API key is not active") + c.Abort() + return + } + + if !*token.User.UnlimitedQuota && *token.User.Quota <= 0 { + dto.WrapErrorAsOpenAI(c, http.StatusForbidden, "quota_exceeded") + c.Abort() + return + } + + if !*token.UnlimitedQuota && *token.Quota <= 0 { + dto.WrapErrorAsOpenAI(c, http.StatusForbidden, "quota_exceeded") + c.Abort() + return + } + + c.Set("user", token.User) + c.Set("authed", true) + // 可以在这里对 token 进行验证并检查权限 + + c.Next() + + } +} diff --git a/middleware/cors.go b/middleware/cors.go new file mode 100644 index 0000000..d2a109a --- /dev/null +++ b/middleware/cors.go @@ -0,0 +1,15 @@ +package middleware + +import ( + "github.com/gin-contrib/cors" + "github.com/gin-gonic/gin" +) + +func CORS() gin.HandlerFunc { + config := cors.DefaultConfig() + config.AllowAllOrigins = true + config.AllowCredentials = true + config.AllowMethods = []string{"GET", "POST", "PUT", "DELETE", "OPTIONS"} + config.AllowHeaders = []string{"*"} + return cors.New(config) +} diff --git a/middleware/ratelimit.go b/middleware/ratelimit.go new file mode 100644 index 0000000..c35b47a --- /dev/null +++ b/middleware/ratelimit.go @@ -0,0 +1,53 @@ +package middleware + +import ( + "net/http" + "sync" + + "github.com/gin-gonic/gin" + "golang.org/x/time/rate" +) + +type IPRateLimiter struct { + ips map[string]*rate.Limiter + mu *sync.RWMutex + r rate.Limit + b int +} + +func NewIPRateLimiter(r rate.Limit, b int) *IPRateLimiter { + return &IPRateLimiter{ + ips: make(map[string]*rate.Limiter), + mu: &sync.RWMutex{}, + r: r, + b: b, + } +} + +func (i *IPRateLimiter) GetLimiter(ip string) *rate.Limiter { + i.mu.Lock() + defer i.mu.Unlock() + + limiter, exists := i.ips[ip] + if !exists { + limiter = rate.NewLimiter(i.r, i.b) + i.ips[ip] = limiter + } + + return limiter +} + +func RateLimit(limiter *IPRateLimiter) gin.HandlerFunc { + return func(c *gin.Context) { + ip := c.ClientIP() + if !limiter.GetLimiter(ip).Allow() { + c.JSON(http.StatusTooManyRequests, gin.H{ + "code": 429, + "message": "too many requests", + }) + c.Abort() + return + } + c.Next() + } +} diff --git a/opencat.go b/opencat.go index aa49857..33b105b 100644 --- a/opencat.go +++ b/opencat.go @@ -11,7 +11,6 @@ import ( "opencatd-open/router" "opencatd-open/store" "opencatd-open/team" - "opencatd-open/team/dashboard" "os" "github.com/duke-git/lancet/v2/fileutil" @@ -170,11 +169,6 @@ func main() { r.Any("/v1/*proxypath", router.HandleProxy) - api := r.Group("/api") - { - api.POST("/login", dashboard.HandleLogin) - } - // r.POST("/v1/chat/completions", router.HandleProy) // r.GET("/v1/models", router.HandleProy) // r.GET("/v1/dashboard/billing/subscription", router.HandleProy) diff --git a/pkg/config/config.go b/pkg/config/config.go new file mode 100644 index 0000000..e99294a --- /dev/null +++ b/pkg/config/config.go @@ -0,0 +1,241 @@ +package config + +import ( + "fmt" + "os" + "strconv" + "time" + + _ "github.com/joho/godotenv/autoload" +) + +var Cfg *Config + +// Config 结构体存储应用配置 +type Config struct { + // 服务器配置 + ServerPort int + ServerHost string + ReadTimeout time.Duration + WriteTimeout time.Duration + + // PassKey配置 + AppName string + Domain string + AppURL string + WebAuthnTimeout time.Duration + ChallengeExpiration time.Duration + + // 数据库配置 + DB_Type string + DSN string + DBMaxOpenConns int + DBMaxIdleConns int + // DBHost string + // DBPort int + // DBUser string + // DBPassword string + // DBName string + + // 缓存配置 + RedisHost string + RedisPort int + RedisPassword string + RedisDB int + + // 日志配置 + LogLevel string + LogPath string + + // 其他应用特定配置 + AllowRegister bool + UnlimitedQuota bool + DefaultActive bool + + UsageWorker int + UsageChanSize int + + TaskTimeInterval int +} + +func init() { + // 加载配置 + cfg, err := LoadConfig() + if err != nil { + panic(fmt.Sprintf("加载配置失败: %v", err)) + } + Cfg = cfg +} + +// LoadConfig 从环境变量加载配置 +func LoadConfig() (*Config, error) { + cfg := &Config{ + AppName: "OpenTeam", + Domain: "localhost", + AppURL: "https://localhost:5173", + // 默认值设置 + ServerPort: 8080, + ServerHost: "0.0.0.0", + ReadTimeout: 10 * time.Second, + WriteTimeout: 10 * time.Second, + + LogLevel: "info", + LogPath: "./logs/", + + DB_Type: "sqlite", + DSN: "", + DBMaxOpenConns: 10, + DBMaxIdleConns: 5, + + RedisDB: 0, + + // 系统设置 + AllowRegister: false, + UnlimitedQuota: false, + DefaultActive: false, + + UsageWorker: 1, + UsageChanSize: 1000, + TaskTimeInterval: 60, + } + + // PassKey配置 + if appName := os.Getenv("APP_NAME"); appName != "" { + cfg.AppName = appName + } + if domain := os.Getenv("DOMAIN"); domain != "" { + cfg.Domain = domain + } + if appURL := os.Getenv("APP_URL"); appURL != "" { + cfg.AppURL = appURL + } + + // 服务器配置 + if port := os.Getenv("SERVER_PORT"); port != "" { + if p, err := strconv.Atoi(port); err == nil { + cfg.ServerPort = p + } else { + return nil, fmt.Errorf("无效的SERVER_PORT: %s", port) + } + } + + if host := os.Getenv("SERVER_HOST"); host != "" { + cfg.ServerHost = host + } + + if timeout := os.Getenv("READ_TIMEOUT"); timeout != "" { + if t, err := strconv.Atoi(timeout); err == nil { + cfg.ReadTimeout = time.Duration(t) * time.Second + } else { + return nil, fmt.Errorf("无效的READ_TIMEOUT: %s", timeout) + } + } + + if timeout := os.Getenv("WRITE_TIMEOUT"); timeout != "" { + if t, err := strconv.Atoi(timeout); err == nil { + cfg.WriteTimeout = time.Duration(t) * time.Second + } else { + return nil, fmt.Errorf("无效的WRITE_TIMEOUT: %s", timeout) + } + } + + // 数据库配置 + if dbType := os.Getenv("DB_TYPE"); dbType != "" { + cfg.DB_Type = dbType + } else { + cfg.DB_Type = "sqlite" + } + + if dsn := os.Getenv("DB_DSN"); dsn != "" { + cfg.DSN = dsn + } + + if conns := os.Getenv("DB_MAX_OPEN_CONNS"); conns != "" { + if c, err := strconv.Atoi(conns); err == nil { + cfg.DBMaxOpenConns = c + } else { + return nil, fmt.Errorf("无效的DB_MAX_OPEN_CONNS: %s", conns) + } + } + + if conns := os.Getenv("DB_MAX_IDLE_CONNS"); conns != "" { + if c, err := strconv.Atoi(conns); err == nil { + cfg.DBMaxIdleConns = c + } else { + return nil, fmt.Errorf("无效的DB_MAX_IDLE_CONNS: %s", conns) + } + } + + // Redis配置 + if host := os.Getenv("REDIS_HOST"); host != "" { + cfg.RedisHost = host + } + + if port := os.Getenv("REDIS_PORT"); port != "" { + if p, err := strconv.Atoi(port); err == nil { + cfg.RedisPort = p + } else { + return nil, fmt.Errorf("无效的REDIS_PORT: %s", port) + } + } + + if password := os.Getenv("REDIS_PASSWORD"); password != "" { + cfg.RedisPassword = password + } + + if db := os.Getenv("REDIS_DB"); db != "" { + if d, err := strconv.Atoi(db); err == nil { + cfg.RedisDB = d + } else { + return nil, fmt.Errorf("无效的REDIS_DB: %s", db) + } + } + + // 日志配置 + if level := os.Getenv("LOG_LEVEL"); level != "" { + cfg.LogLevel = level + } + + if path := os.Getenv("LOG_PATH"); path != "" { + cfg.LogPath = path + } + + // 功能标志 + if allowRegister := os.Getenv("ALLOW_REGISTER"); allowRegister != "" { + if b, err := strconv.ParseBool(allowRegister); err == nil { + cfg.AllowRegister = b + } + } + + if unlimitedQuota := os.Getenv("UNLIMITED_QUOTA"); unlimitedQuota != "" { + if b, err := strconv.ParseBool(unlimitedQuota); err == nil { + cfg.UnlimitedQuota = b + } + } + + if defaultActive := os.Getenv("DEFAULT_ACTIVE"); defaultActive != "" { + if b, err := strconv.ParseBool(defaultActive); err == nil { + cfg.DefaultActive = b + } + } + + if worker := os.Getenv("USAGE_WORKER"); worker != "" { + if w, err := strconv.Atoi(worker); err == nil { + cfg.UsageWorker = w + } + } + + if size := os.Getenv("USAGE_CHAN_SIZE"); size != "" { + if s, err := strconv.Atoi(size); err == nil { + cfg.UsageChanSize = s + } + } + + if interval := os.Getenv("TASK_TIME_INTERVAL"); interval != "" { + if i, err := strconv.Atoi(interval); err == nil { + cfg.TaskTimeInterval = i + } + } + + return cfg, nil +} diff --git a/pkg/search/bing_test.go b/pkg/search/bing_test.go new file mode 100644 index 0000000..421fb33 --- /dev/null +++ b/pkg/search/bing_test.go @@ -0,0 +1,30 @@ +/* +文档 https://www.microsoft.com/en-us/bing/apis/bing-web-search-api +价格 https://www.microsoft.com/en-us/bing/apis/pricing + +curl -H "Ocp-Apim-Subscription-Key: " https://api.bing.microsoft.com/v7.0/search?q=今天上海天气怎么样 +curl -H "Ocp-Apim-Subscription-Key: 6fc7c97ebed54f75a5e383ee2272c917" https://api.bing.microsoft.com/v7.0/search?q=今天上海天气怎么样 +*/ + +package search + +import ( + "testing" +) + +func TestBingSearch(t *testing.T) { + var searchParams = SearchParams{ + Query: "上海明天天气怎么样", + Num: 3, + } + + t.Run("BingSearch", func(t *testing.T) { + got, err := BingSearch(searchParams) + if err != nil { + t.Errorf("BingSearch() error = %v", err) + return + } + t.Log(got) + }) + +} diff --git a/pkg/store/db.go b/pkg/store/db.go index dcd337e..72d9056 100644 --- a/pkg/store/db.go +++ b/pkg/store/db.go @@ -3,37 +3,36 @@ package store import ( "fmt" "log" - "opencatd-open/team/consts" - "opencatd-open/team/model" - "os" + "opencatd-open/internal/model" + "opencatd-open/pkg/config" "strings" // "gocloud.dev/mysql" // "gocloud.dev/postgres" "github.com/glebarez/sqlite" - "github.com/google/wire" "gorm.io/driver/mysql" "gorm.io/driver/postgres" // "gorm.io/driver/sqlite" + "gorm.io/gorm" ) var DB *gorm.DB -var DBType consts.DBType +// var DBType consts.DBType var IsPostgres bool -var DBSet = wire.NewSet( - InitDB, -) +func GetDB() *gorm.DB { + return DB +} // InitDB 初始化数据库连接 -func InitDB() (*gorm.DB, error) { +func InitDB(cfg *config.Config) (*gorm.DB, error) { var db *gorm.DB var err error // 从环境变量获取DSN - dsn := os.Getenv("DSN") + dsn := cfg.DSN if dsn == "" { log.Println("No DSN provided, using SQLite as default") @@ -43,10 +42,12 @@ func InitDB() (*gorm.DB, error) { // 解析DSN来确定数据库类型 if strings.HasPrefix(dsn, "postgres://") { IsPostgres = true - DBType = consts.DBTypePostgreSQL + + cfg.DB_Type = "postgres" db, err = initPostgres(dsn) } else if strings.HasPrefix(dsn, "mysql://") { - DBType = consts.DBTypeMySQL + + cfg.DB_Type = "mysql" db, err = initMySQL(dsn) } else { if dsn != "" { @@ -60,12 +61,12 @@ func InitDB() (*gorm.DB, error) { DB = db if IsPostgres { - err = db.AutoMigrate(&model.User{}, &model.Token{}, &model.ApiKey_PG{}, &model.Usage{}, &model.DailyUsage{}) + err = db.AutoMigrate(&model.User{}, &model.Token{}, &model.ApiKey_PG{}, &model.Usage{}, &model.DailyUsage{}, &model.Passkey{}) if err != nil { return nil, err } } else { - err = db.AutoMigrate(&model.User{}, &model.Token{}, &model.ApiKey{}, &model.Usage{}, &model.DailyUsage{}) + err = db.AutoMigrate(&model.User{}, &model.Token{}, &model.ApiKey{}, &model.Usage{}, &model.DailyUsage{}, &model.Passkey{}) if err != nil { return nil, err } diff --git a/pkg/store/gcache.go b/pkg/store/gcache.go new file mode 100644 index 0000000..8cecced --- /dev/null +++ b/pkg/store/gcache.go @@ -0,0 +1,67 @@ +package store + +import ( + "errors" + "log" + "time" + + "github.com/bluele/gcache" + "github.com/go-webauthn/webauthn/webauthn" + "github.com/google/uuid" +) + +// WebAuthnSessionStore 使用 gcache 存储 WebAuthn 会话数据 +type WebAuthnSessionStore struct { + cache gcache.Cache +} + +// NewWebAuthnSessionStore 创建一个新的会话存储实例 +func NewWebAuthnSessionStore() *WebAuthnSessionStore { + // 创建一个 LRU 缓存,最多存储 10000 个会话,每个会话有效期 5 分钟 + gc := gcache.New(10000). + LRU(). + Expiration(5 * time.Minute). + Build() + return &WebAuthnSessionStore{cache: gc} +} + +// GenerateSessionID 生成唯一的会话ID +func GenerateSessionID() string { + return uuid.NewString() +} + +// SaveWebauthnSession 保存 WebAuthn 会话数据 +func (s *WebAuthnSessionStore) SaveWebauthnSession(sessionID string, data *webauthn.SessionData) error { + return s.cache.Set(sessionID, data) +} + +// GetWebauthnSession 获取 WebAuthn 会话数据 +func (s *WebAuthnSessionStore) GetWebauthnSession(sessionID string) (*webauthn.SessionData, error) { + val, err := s.cache.Get(sessionID) + if err != nil { + if errors.Is(err, gcache.KeyNotFoundError) { + return nil, errors.New("会话未找到或已过期") + } + return nil, err // 其他 gcache 错误 + } + + sessionData, ok := val.(*webauthn.SessionData) + if !ok { + // 如果类型断言失败,说明缓存中存储了错误类型的数据 + log.Printf("警告:会话存储中发现非预期的类型,Key: %s", sessionID) + // 尝试删除无效数据 + _ = s.cache.Remove(sessionID) + return nil, errors.New("无效的会话数据类型") + } + return sessionData, nil +} + +// DeleteWebauthnSession 删除 WebAuthn 会话数据 +func (s *WebAuthnSessionStore) DeleteWebauthnSession(sessionID string) { + s.cache.Remove(sessionID) +} + +func (s *WebAuthnSessionStore) GetALL() map[any]any { + + return s.cache.GetALL(false) +} diff --git a/pkg/team/key.go b/pkg/team/key.go index 802f10b..57ba808 100644 --- a/pkg/team/key.go +++ b/pkg/team/key.go @@ -2,7 +2,7 @@ package team import ( "net/http" - "opencatd-open/pkg/azureopenai" + "opencatd-open/llm/azureopenai" "opencatd-open/store" "strings" diff --git a/router/chat.go b/router/chat.go index 0637192..fae2772 100644 --- a/router/chat.go +++ b/router/chat.go @@ -4,9 +4,9 @@ import ( "net/http" "strings" - "opencatd-open/pkg/claude" - "opencatd-open/pkg/google" - "opencatd-open/pkg/openai" + "opencatd-open/llm/claude" + "opencatd-open/llm/google" + "opencatd-open/llm/openai" "github.com/gin-gonic/gin" ) diff --git a/router/router.go b/router/router.go index d8c0d62..400535e 100644 --- a/router/router.go +++ b/router/router.go @@ -5,32 +5,17 @@ import ( "net" "net/http" "net/http/httputil" - "opencatd-open/pkg/claude" - oai "opencatd-open/pkg/openai" - "opencatd-open/store" + "opencatd-open/llm/claude" + oai "opencatd-open/llm/openai" "time" "github.com/gin-gonic/gin" ) func HandleProxy(c *gin.Context) { - var ( - localuser bool - ) - auth := c.Request.Header.Get("Authorization") - if len(auth) > 7 && auth[:7] == "Bearer " { - localuser = store.IsExistAuthCache(auth[7:]) - c.Set("localuser", auth[7:]) - } if c.Request.URL.Path == "/v1/complete" { - if localuser { - claude.ClaudeProxy(c) - return - } else { - HandleReverseProxy(c, "api.anthropic.com") - return - } - + claude.ClaudeProxy(c) + return } if c.Request.URL.Path == "/v1/audio/transcriptions" { oai.WhisperProxy(c) @@ -52,19 +37,7 @@ func HandleProxy(c *gin.Context) { } if c.Request.URL.Path == "/v1/chat/completions" { - if localuser { - if store.KeysCache.ItemCount() == 0 { - c.JSON(http.StatusBadGateway, gin.H{"error": gin.H{ - "message": "No Api-Key Available", - }}) - return - } - - ChatHandler(c) - return - } - } else { - HandleReverseProxy(c, "api.openai.com") + ChatHandler(c) return } diff --git a/store/cache.go b/store/cache.go index 5f8376b..70feb20 100644 --- a/store/cache.go +++ b/store/cache.go @@ -17,7 +17,7 @@ var ( AuthCache *cache.Cache ) -func init() { +func InitCache() { KeysCache = cache.New(cache.NoExpiration, cache.NoExpiration) AuthCache = cache.New(cache.NoExpiration, cache.NoExpiration) } diff --git a/store/db.go b/store/db.go index 65027bb..b5c2d6e 100644 --- a/store/db.go +++ b/store/db.go @@ -13,7 +13,7 @@ var db *gorm.DB var usage *gorm.DB -func init() { +func InitDB() { if _, err := os.Stat("db"); os.IsNotExist(err) { errDir := os.MkdirAll("db", 0755) if errDir != nil { diff --git a/store/keydb.go b/store/keydb.go index e5b0f81..5c83c1d 100644 --- a/store/keydb.go +++ b/store/keydb.go @@ -4,12 +4,12 @@ import ( "encoding/json" "fmt" "log" - "opencatd-open/pkg/vertexai" + "opencatd-open/llm/vertexai" "os" "time" ) -func init() { +func InitKey() { // check vertex if os.Getenv("Vertex") != "" { vertex_auth := os.Getenv("Vertex") diff --git a/team/dashboard/dashboard.go b/team/dashboard/dashboard.go deleted file mode 100644 index 82b769a..0000000 --- a/team/dashboard/dashboard.go +++ /dev/null @@ -1,16 +0,0 @@ -package dashboard - -import "github.com/gin-gonic/gin" - -func HandleTeam(c *gin.Context) { - c.JSON(200, gin.H{ - "code": 200, - "data": gin.H{ - "team": gin.H{ - "total_users": 10, - "total_keys": 20, - "total_projects": 30, - }, - }, - }) -} diff --git a/team/dashboard/login.go b/team/dashboard/login.go deleted file mode 100644 index 83a9c5b..0000000 --- a/team/dashboard/login.go +++ /dev/null @@ -1,20 +0,0 @@ -package dashboard - -import ( - "fmt" - - "github.com/gin-gonic/gin" -) - -func HandleLogin(c *gin.Context) { - var user map[string]string - c.ShouldBind(&user) - fmt.Sprintf("%v", user) - c.JSON(200, gin.H{ - "code": 200, - "msg": "success", - "data": gin.H{ - "token": "token", - }, - }) -} diff --git a/team/key.go b/team/key.go index 802f10b..57ba808 100644 --- a/team/key.go +++ b/team/key.go @@ -2,7 +2,7 @@ package team import ( "net/http" - "opencatd-open/pkg/azureopenai" + "opencatd-open/llm/azureopenai" "opencatd-open/store" "strings" diff --git a/team/model/apikey.go b/team/model/apikey.go deleted file mode 100644 index 40e84bd..0000000 --- a/team/model/apikey.go +++ /dev/null @@ -1,45 +0,0 @@ -package model - -import "github.com/lib/pq" //pq.StringArray - -type ApiKey_PG struct { - ID int64 `gorm:"column:id;primaryKey;autoIncrement"` - Name string `gorm:"column:name;not null;unique;index:idx_apikey_name"` - ApiType string `gorm:"column:apitype;not null;unique;index:idx_apikey_apitype"` - ApiKey string `gorm:"column:apikey;not null;unique;uniqueIndex:idx_apikey"` - Status int `gorm:"type:int;default:1"` // enabled 1, disabled 0 - Endpoint string `gorm:"column:endpoint;comment:接入点"` - ResourceNmae string `gorm:"column:resource_name;comment:azure资源名称"` - DeploymentName string `gorm:"column:deployment_name;comment:azure部署名称"` - ApiSecret string `gorm:"column:api_secret"` - ModelPrefix string `gorm:"column:model_prefix;comment:模型前缀"` - ModelAlias string `gorm:"column:model_alias;comment:模型别名"` - SupportModels pq.StringArray `gorm:"column:support_models;type:text[]"` - CreatedAt int64 `gorm:"column:created_at;autoUpdateTime" json:"created_at,omitempty"` - UpdatedAt int64 `gorm:"column:updated_at;autoCreateTime" json:"updated_at,omitempty"` -} - -func (ApiKey_PG) TableName() string { - return "apikeys" -} - -type ApiKey struct { - ID int64 `gorm:"column:id;primaryKey;autoIncrement"` - Name string `gorm:"column:name;not null;unique;index:idx_apikey_name"` - ApiType string `gorm:"column:apitype;not null;unique;index:idx_apikey_apitype"` - ApiKey string `gorm:"column:apikey;not null;unique;index:idx_apikey_apikey"` - Status int `gorm:"type:int;default:1"` // enabled 1, disabled 0 - Endpoint string `gorm:"column:endpoint"` - ResourceNmae string `gorm:"column:resource_name"` - DeploymentName string `gorm:"column:deployment_name"` - ApiSecret string `gorm:"column:api_secret"` - ModelPrefix string `gorm:"column:model_prefix"` - ModelAlias string `gorm:"column:model_alias"` - SupportModels []string `gorm:"column:support_models;type:json"` - CreatedAt int64 `gorm:"column:created_at;autoUpdateTime" json:"created_at,omitempty"` - UpdatedAt int64 `gorm:"column:updated_at;autoCreateTime" json:"updated_at,omitempty"` -} - -func (ApiKey) TableName() string { - return "apikeys" -} diff --git a/team/model/token.go b/team/model/token.go deleted file mode 100644 index 6a021ff..0000000 --- a/team/model/token.go +++ /dev/null @@ -1,20 +0,0 @@ -package model - -// 用户的token -type Token struct { - ID int64 `gorm:"column:id;primaryKey;autoIncrement"` - UserID int64 `gorm:"column:user_id;not null;index:idx_token_user_id"` - Name string `gorm:"column:name;index:idx_token_name"` - Key string `gorm:"column:key;not null;uniqueIndex:idx_token_key;comment:token key"` - Status int64 `gorm:"column:status;default:1;check:status IN (0,1)"` // enabled 1, disabled 0 - Quota int64 `gorm:"column:quota;type:bigint;default:0"` // default 0 - UnlimitedQuota bool `gorm:"column:unlimited_quota;default:true"` // set Quota 1 unlimited - UsedQuota int64 `gorm:"column:used_quota;type:bigint;default:0"` - CreatedAt int64 `gorm:"column:created_at;type:bigint;autoCreateTime"` - ExpiredAt int64 `gorm:"column:expired_at;type:bigint;default:-1"` // -1 means never expired - User User `gorm:"foreignKey:UserID;constraint:OnUpdate:CASCADE,OnDelete:CASCADE" json:"user"` -} - -func (Token) TableName() string { - return "tokens" -} diff --git a/team/model/user.go b/team/model/user.go deleted file mode 100644 index 3bf44d6..0000000 --- a/team/model/user.go +++ /dev/null @@ -1,49 +0,0 @@ -package model - -import ( - "time" -) - -type User struct { - ID int64 `json:"id" gorm:"column:id;primaryKey;autoIncrement"` - Name string `json:"name" gorm:"column:name;not null;unique;index"` - Username string `json:"username" gorm:"column:username;unique;index"` - Password string `json:"password" gorm:"column:password;"` - Role int `json:"role" gorm:"column:role;type:int;default:0"` // default user 0-10-20 - Status int `json:"status" gorm:"column:status;type:int;default:1"` // disabled 0, enabled 1, deleted 2 - Nickname string `json:"nickname" gorm:"column:nickname;type:varchar(50)"` - AvatarURL string `json:"avatar_url" gorm:"column:avatar_url;type:varchar(255)"` - Email string `json:"email" gorm:"column:email;type:varchar(255);index"` - Quota int64 `json:"quota" gorm:"column:quota;bigint;default:0"` // default unlimited - UnlimitedQuota int `json:"unlimited_quota" gorm:"column:unlimited_quota;default:1;check:(unlimited_quota IN (0,1))"` // 0 limited , 1 unlimited - Timezone string `json:"timezone" gorm:"column:timezone;type:varchar(50)"` - Language string `json:"language" gorm:"column:language;type:varchar(50)"` - - // 添加一对多关系 - // Token string `json:"-" gorm:"column:token;type:varchar(64);unique;index"` - Tokens []Token `json:"-" gorm:"foreignKey:UserID;references:ID;constraint:OnUpdate:CASCADE,OnDelete:CASCADE"` - - CreatedAt int64 `json:"created_at,omitempty" gorm:"autoCreateTime"` - UpdatedAt int64 `json:"updated_at,omitempty" gorm:"autoUpdateTime"` -} - -func (User) TableName() string { - return "users" -} - -type Session struct { - ID int64 `json:"id" gorm:"primaryKey;autoIncrement"` - UserID int64 `json:"user_id" gorm:"index:idx_user_id"` - Token string `json:"token" gorm:"type:varchar(64);uniqueIndex"` - DeviceType string `json:"device_type" gorm:"type:varchar(100);default:''"` - DeviceName string `json:"device_name" gorm:"type:varchar(100);default:''"` - LastActiveAt time.Time `json:"last_active_at" gorm:"type:timestamp;default:CURRENT_TIMESTAMP"` - LogoutAt time.Time `json:"logout_at" gorm:"type:timestamp;null"` - - CreatedAt time.Time `json:"created_at" gorm:"type:timestamp;not null;default:CURRENT_TIMESTAMP"` - UpdatedAt time.Time `json:"updated_at" gorm:"type:timestamp;not null;default:CURRENT_TIMESTAMP;update:CURRENT_TIMESTAMP"` -} - -func (Session) TableName() string { - return "sessions" -} diff --git a/team/service/usage.go b/team/service/usage.go deleted file mode 100644 index 5e131f6..0000000 --- a/team/service/usage.go +++ /dev/null @@ -1,137 +0,0 @@ -package service - -import ( - "context" - "fmt" - "opencatd-open/team/dao" - dto "opencatd-open/team/dto/team" - "opencatd-open/team/model" - "time" - - "log" - - "gorm.io/gorm" - "gorm.io/gorm/clause" -) - -var _ UsageService = (*usageService)(nil) - -type UsageService interface { - // AsyncProcessUsage 异步处理使用记录 - AsyncProcessUsage(usage *model.Usage) - - ListByUserID(ctx context.Context, userID int64, limit, offset int) ([]*model.Usage, error) - ListByCapability(ctx context.Context, capability string, limit, offset int) ([]*model.Usage, error) - ListByDateRange(ctx context.Context, start, end time.Time, filters map[string]interface{}) ([]*dto.UsageInfo, error) - - Delete(ctx context.Context, id int64) error -} - -type usageService struct { - db *gorm.DB - usageDAO dao.UsageRepository - dailyUsageDAO dao.DailyUsageRepository - usageChan chan *model.Usage // 用于异步处理的channel - ctx context.Context -} - -func NewUsageService(ctx context.Context, db *gorm.DB, usageRepo dao.UsageRepository, dailyUsageRepo dao.DailyUsageRepository) UsageService { - srv := &usageService{ - db: db, - usageDAO: usageRepo, - dailyUsageDAO: dailyUsageRepo, - usageChan: make(chan *model.Usage, 1000), // 设置合适的缓冲区大小 - ctx: ctx, - } - - // 启动异步处理goroutine - go srv.processUsageWorker() - - return srv -} - -func (s *usageService) AsyncProcessUsage(usage *model.Usage) { - select { - case s.usageChan <- usage: - // 成功发送到channel - default: - // channel已满,记录错误日志 - log.Println("usage channel is full, skip processing") - } -} - -func (s *usageService) processUsageWorker() { - for { - select { - case usage := <-s.usageChan: - err := s.processUsage(usage) - if err != nil { - log.Println("processUsage error:", err) - } - case <-s.ctx.Done(): - log.Println("processUsageWorker is exiting") - return - } - } -} - -// processUsageWorker 异步处理worker -func (s *usageService) processUsage(usage *model.Usage) error { - err := s.db.Transaction(func(tx *gorm.DB) error { - // 1. 记录使用记录 - if err := tx.WithContext(s.ctx).Create(usage).Error; err != nil { - return fmt.Errorf("create usage error: %w", err) - } - - // 2. 更新每日统计(upsert 操作) - dailyUsage := model.DailyUsage{ - UserID: usage.UserID, - TokenID: usage.TokenID, - Capability: usage.Capability, - Date: time.Date(usage.Date.Year(), usage.Date.Month(), usage.Date.Day(), 0, 0, 0, 0, usage.Date.Location()), - Model: usage.Model, - Stream: usage.Stream, - PromptTokens: usage.PromptTokens, - CompletionTokens: usage.CompletionTokens, - TotalTokens: usage.TotalTokens, - Cost: usage.Cost, - } - - // 使用 OnConflict 实现 upsert - if err := tx.WithContext(s.ctx).Clauses(clause.OnConflict{ - Columns: []clause.Column{{Name: "user_id"}, {Name: "token_id"}, {Name: "capability"}, {Name: "date"}}, // 唯一键 - DoUpdates: clause.Assignments(map[string]interface{}{ - "prompt_tokens": gorm.Expr("prompt_tokens + ?", usage.PromptTokens), - "completion_tokens": gorm.Expr("completion_tokens + ?", usage.CompletionTokens), - "total_tokens": gorm.Expr("total_tokens + ?", usage.TotalTokens), - "cost": gorm.Expr("cost + ?", usage.Cost), - }), - }).Create(&dailyUsage).Error; err != nil { - return fmt.Errorf("upsert daily usage error: %w", err) - } - - // 3. 更新用户额度 - if err := tx.WithContext(s.ctx).Model(&model.User{}).Where("id = ?", usage.UserID).Update("quota", gorm.Expr("quota - ?", usage.Cost)).Error; err != nil { - return fmt.Errorf("update user quota error: %w", err) - } - - return nil - }) - return err -} - -func (s *usageService) ListByUserID(ctx context.Context, userID int64, limit int, offset int) ([]*model.Usage, error) { - return s.usageDAO.ListByUserID(ctx, userID, limit, offset) -} - -func (s *usageService) ListByCapability(ctx context.Context, capability string, limit, offset int) ([]*model.Usage, error) { - return s.usageDAO.ListByCapability(ctx, capability, limit, offset) -} - -func (s *usageService) ListByDateRange(ctx context.Context, start, end time.Time, filters map[string]interface{}) ([]*dto.UsageInfo, error) { - return s.dailyUsageDAO.StatUserUsages(ctx, start, end, filters) -} - -func (s *usageService) Delete(ctx context.Context, id int64) error { - return s.usageDAO.Delete(ctx, id) -} diff --git a/wire/wire.go b/wire/wire.go index c2f9256..3149b3c 100644 --- a/wire/wire.go +++ b/wire/wire.go @@ -5,9 +5,15 @@ package wire import ( "context" - "opencatd-open/team/dao" - handler "opencatd-open/team/handler/team" - "opencatd-open/team/service" + "opencatd-open/internal/controller" + proxy "opencatd-open/internal/controller/proxy" + team_controller "opencatd-open/internal/controller/team" + "opencatd-open/internal/dao" + "opencatd-open/pkg/config" + "sync" + + service "opencatd-open/internal/service" + teamService "opencatd-open/internal/service/team" "github.com/google/wire" "gorm.io/gorm" @@ -17,19 +23,19 @@ import ( var userSet = wire.NewSet( dao.NewUserDAO, wire.Bind(new(dao.UserRepository), new(*dao.UserDAO)), - service.NewUserService, + teamService.NewUserService, ) var keySet = wire.NewSet( dao.NewApiKeyDAO, wire.Bind(new(dao.ApiKeyRepository), new(*dao.ApiKeyDAO)), - service.NewApiKeyService, + teamService.NewApiKeyService, ) var tokenSet = wire.NewSet( dao.NewTokenDAO, wire.Bind(new(dao.TokenRepository), new(*dao.TokenDAO)), - service.NewTokenService, + teamService.NewTokenService, ) var usageSet = wire.NewSet( @@ -37,11 +43,79 @@ var usageSet = wire.NewSet( wire.Bind(new(dao.UsageRepository), new(*dao.UsageDAO)), dao.NewDailyUsageDAO, wire.Bind(new(dao.DailyUsageRepository), new(*dao.DailyUsageDAO)), - service.NewUsageService, + teamService.NewUsageService, ) // 初始化 TeamHandler -func InitTeamHandler(ctx context.Context, db *gorm.DB) (*handler.TeamHandler, error) { - wire.Build(userSet, keySet, tokenSet, usageSet, handler.NewTeamHandler) +func InitTeamHandler(ctx context.Context, cfg *config.Config, db *gorm.DB) (*team_controller.Team, error) { + wire.Build(userSet, keySet, tokenSet, usageSet, team_controller.NewTeam) + return nil, nil +} + +// var userApi = wire.NewSet( +// dao.NewUserDAO, +// wire.Bind(new(dao.UserRepository), new(*dao.UserDAO)), +// service.NewUserService, +// ) + +// var keyApi = wire.NewSet( +// dao.NewApiKeyDAO, +// wire.Bind(new(dao.ApiKeyRepository), new(*dao.ApiKeyDAO)), +// service.NewApiKeyService, +// ) + +// var tokenApi = wire.NewSet( +// dao.NewTokenDAO, +// wire.Bind(new(dao.TokenRepository), new(*dao.TokenDAO)), +// service.NewTokenService, +// ) + +// func InitAPIHandler(ctx context.Context, db *gorm.DB) (*controller.Api, error) { +// wire.Build(userApi, keyApi, tokenApi, controller.NewApi) +// return nil, nil +// } + +var repositorySet = wire.NewSet( + dao.NewUserDAO, + wire.Bind(new(dao.UserRepository), new(*dao.UserDAO)), + dao.NewTokenDAO, + wire.Bind(new(dao.TokenRepository), new(*dao.TokenDAO)), + dao.NewApiKeyDAO, + wire.Bind(new(dao.ApiKeyRepository), new(*dao.ApiKeyDAO)), + + dao.NewUsageDAO, + wire.Bind(new(dao.UsageRepository), new(*dao.UsageDAO)), + dao.NewDailyUsageDAO, + wire.Bind(new(dao.DailyUsageRepository), new(*dao.DailyUsageDAO)), +) + +var serviceSet = wire.NewSet( + service.NewUserService, + service.NewTokenService, + service.NewApiKeyService, + service.NewWebAuthnService, + + service.NewUsageService, +) + +func InitAPIHandler(ctx context.Context, cfg *config.Config, db *gorm.DB) (*controller.Api, error) { + wire.Build( + repositorySet, + serviceSet, + controller.NewApi, + ) + return nil, nil +} + +var proxySet = wire.NewSet( + dao.NewUserDAO, + dao.NewApiKeyDAO, + dao.NewTokenDAO, + dao.NewUsageDAO, + dao.NewDailyUsageDAO, +) + +func InitProxyHandler(ctx context.Context, cfg *config.Config, db *gorm.DB, wg *sync.WaitGroup) (*proxy.Proxy, error) { + wire.Build(proxySet, proxy.NewProxy) return nil, nil } diff --git a/wire/wire_gen.go b/wire/wire_gen.go index 93b6b5e..95a0e14 100644 --- a/wire/wire_gen.go +++ b/wire/wire_gen.go @@ -9,26 +9,57 @@ import ( "context" "github.com/google/wire" "gorm.io/gorm" - "opencatd-open/team/dao" - "opencatd-open/team/handler/team" - "opencatd-open/team/service" + controller2 "opencatd-open/internal/controller" + controller3 "opencatd-open/internal/controller/proxy" + "opencatd-open/internal/controller/team" + "opencatd-open/internal/dao" + service2 "opencatd-open/internal/service" + "opencatd-open/internal/service/team" + "opencatd-open/pkg/config" + "sync" ) // Injectors from wire.go: // 初始化 TeamHandler -func InitTeamHandler(ctx context.Context, db *gorm.DB) (*handler.TeamHandler, error) { +func InitTeamHandler(ctx context.Context, cfg *config.Config, db *gorm.DB) (*controller.Team, error) { userDAO := dao.NewUserDAO(db) - userService := service.NewUserService(userDAO, db) + userService := service.NewUserService(db, userDAO) tokenDAO := dao.NewTokenDAO(db) tokenService := service.NewTokenService(tokenDAO) apiKeyDAO := dao.NewApiKeyDAO(db) - apiKeyService := service.NewApiKeyService(apiKeyDAO, db) - usageDAO := dao.NewUsageDAO(db) - dailyUsageDAO := dao.NewDailyUsageDAO(db) - usageService := service.NewUsageService(ctx, db, usageDAO, dailyUsageDAO) - teamHandler := handler.NewTeamHandler(userService, tokenService, apiKeyService, usageService) - return teamHandler, nil + apiKeyService := service.NewApiKeyService(db, apiKeyDAO) + usageDAO := dao.NewUsageDAO(cfg, db) + dailyUsageDAO := dao.NewDailyUsageDAO(cfg, db) + usageService := service.NewUsageService(ctx, cfg, db, usageDAO, dailyUsageDAO) + team := controller.NewTeam(userService, tokenService, apiKeyService, usageService) + return team, nil +} + +func InitAPIHandler(ctx context.Context, cfg *config.Config, db *gorm.DB) (*controller2.Api, error) { + userDAO := dao.NewUserDAO(db) + userServiceImpl := service2.NewUserService(db, userDAO) + tokenDAO := dao.NewTokenDAO(db) + tokenServiceImpl := service2.NewTokenService(db, tokenDAO) + apiKeyDAO := dao.NewApiKeyDAO(db) + apiKeyServiceImpl := service2.NewApiKeyService(db, apiKeyDAO) + webAuthnService, err := service2.NewWebAuthnService(db, cfg) + if err != nil { + return nil, err + } + usageService := service2.NewUsageService(ctx, cfg, db) + api := controller2.NewApi(db, userServiceImpl, tokenServiceImpl, apiKeyServiceImpl, webAuthnService, usageService) + return api, nil +} + +func InitProxyHandler(ctx context.Context, cfg *config.Config, db *gorm.DB, wg *sync.WaitGroup) (*controller3.Proxy, error) { + userDAO := dao.NewUserDAO(db) + apiKeyDAO := dao.NewApiKeyDAO(db) + tokenDAO := dao.NewTokenDAO(db) + usageDAO := dao.NewUsageDAO(cfg, db) + dailyUsageDAO := dao.NewDailyUsageDAO(cfg, db) + proxy := controller3.NewProxy(ctx, cfg, db, wg, userDAO, apiKeyDAO, tokenDAO, usageDAO, dailyUsageDAO) + return proxy, nil } // wire.go: @@ -41,3 +72,9 @@ var keySet = wire.NewSet(dao.NewApiKeyDAO, wire.Bind(new(dao.ApiKeyRepository), var tokenSet = wire.NewSet(dao.NewTokenDAO, wire.Bind(new(dao.TokenRepository), new(*dao.TokenDAO)), service.NewTokenService) var usageSet = wire.NewSet(dao.NewUsageDAO, wire.Bind(new(dao.UsageRepository), new(*dao.UsageDAO)), dao.NewDailyUsageDAO, wire.Bind(new(dao.DailyUsageRepository), new(*dao.DailyUsageDAO)), service.NewUsageService) + +var repositorySet = wire.NewSet(dao.NewUserDAO, wire.Bind(new(dao.UserRepository), new(*dao.UserDAO)), dao.NewTokenDAO, wire.Bind(new(dao.TokenRepository), new(*dao.TokenDAO)), dao.NewApiKeyDAO, wire.Bind(new(dao.ApiKeyRepository), new(*dao.ApiKeyDAO)), dao.NewUsageDAO, wire.Bind(new(dao.UsageRepository), new(*dao.UsageDAO)), dao.NewDailyUsageDAO, wire.Bind(new(dao.DailyUsageRepository), new(*dao.DailyUsageDAO))) + +var serviceSet = wire.NewSet(service2.NewUserService, service2.NewTokenService, service2.NewApiKeyService, service2.NewWebAuthnService, service2.NewUsageService) + +var proxySet = wire.NewSet(dao.NewUserDAO, dao.NewApiKeyDAO, dao.NewTokenDAO, dao.NewUsageDAO, dao.NewDailyUsageDAO)