From 470e49b85057d740da5d8e9a57d0ba3db82d5bc6 Mon Sep 17 00:00:00 2001 From: Sakurasan <26715255+Sakurasan@users.noreply.github.com> Date: Mon, 21 Apr 2025 01:30:17 +0800 Subject: [PATCH] support fetch models --- internal/controller/proxy/models.go | 58 +++++++++++++++++++++++++++++ internal/controller/proxy/proxy.go | 4 ++ router/setRouter.go | 2 +- 3 files changed, 63 insertions(+), 1 deletion(-) create mode 100644 internal/controller/proxy/models.go diff --git a/internal/controller/proxy/models.go b/internal/controller/proxy/models.go new file mode 100644 index 0000000..df5fceb --- /dev/null +++ b/internal/controller/proxy/models.go @@ -0,0 +1,58 @@ +package controller + +import ( + "encoding/json" + "fmt" + "net/http" + "opencatd-open/internal/dto" + + "github.com/gin-gonic/gin" +) + +func (p *Proxy) HandleModels(c *gin.Context) { + models, err := p.getCache() + if err != nil { + dto.Fail(c, http.StatusBadGateway, err.Error()) + return + } + type _model struct { + ID string `json:"id"` + } + var ms []_model + for _, model := range models { + ms = append(ms, _model{ID: model}) + } + dto.Success(c, ms) +} + +func (p *Proxy) setCache() error { + apikeys, err := p.apiKeyDao.FindKeys(nil) + models := make(map[string]bool) + if err == nil && len(apikeys) > 0 { + for _, k := range apikeys { + if len(k.SupportModelsArray) > 0 { + for _, sm := range k.SupportModelsArray { + models[sm] = true + } + } else { + var sma []string + json.Unmarshal([]byte(*k.SupportModels), &sma) // nolint:errCheck + for _, sm := range sma { + models[sm] = true + } + } + } + } else { + return fmt.Errorf("empty data") + } + var support_models []string + for m, _ := range models { + support_models = append(support_models, m) + } + return p.cache.Set("models", support_models) +} + +func (p *Proxy) getCache() ([]string, error) { + models, err := p.cache.Get("models") + return models.([]string), err +} diff --git a/internal/controller/proxy/proxy.go b/internal/controller/proxy/proxy.go index 7891c8f..351ac8c 100644 --- a/internal/controller/proxy/proxy.go +++ b/internal/controller/proxy/proxy.go @@ -19,6 +19,7 @@ import ( "sync" "time" + "github.com/bluele/gcache" "github.com/gin-gonic/gin" "github.com/lib/pq" "github.com/tidwall/gjson" @@ -34,6 +35,7 @@ type Proxy struct { usageChan chan *model.Usage // 用于异步处理的channel apikey *model.ApiKey httpClient *http.Client + cache gcache.Cache userDAO *dao.UserDAO apiKeyDao *dao.ApiKeyDAO @@ -53,12 +55,14 @@ func NewProxy(ctx context.Context, cfg *config.Config, db *gorm.DB, wg *sync.Wai client.Transport = tr } } + np := &Proxy{ ctx: ctx, cfg: cfg, db: db, wg: wg, httpClient: client, + cache: gcache.New(1).Build(), usageChan: make(chan *model.Usage, cfg.UsageChanSize), userDAO: userDAO, apiKeyDao: apiKeyDAO, diff --git a/router/setRouter.go b/router/setRouter.go index 301971d..a45ea34 100644 --- a/router/setRouter.go +++ b/router/setRouter.go @@ -124,7 +124,7 @@ func SetRouter(cfg *config.Config, db *gorm.DB, web *embed.FS) { { // v1.POST("/v2/*proxypath", router.HandleProxy) v1.POST("/*proxypath", proxy.HandleProxy) - // v1.GET("/models", dashboard.HandleModels) + v1.GET("/models", proxy.HandleModels) } idxFS, err := fs.Sub(web, "dist")