Compare commits

...

2 Commits

Author SHA1 Message Date
Sakurasan
73e53c2333 add models task 2025-04-21 01:40:06 +08:00
Sakurasan
470e49b850 support fetch models 2025-04-21 01:30:17 +08:00
3 changed files with 69 additions and 4 deletions

View File

@@ -0,0 +1,58 @@
package controller
import (
"encoding/json"
"fmt"
"net/http"
"opencatd-open/internal/dto"
"github.com/gin-gonic/gin"
)
// HandleModels handles GET /v1/models: it reads the cached list of
// supported model IDs and responds with them wrapped as {"id": "..."}
// objects via dto.Success. On a cache error it replies 502.
func (p *Proxy) HandleModels(c *gin.Context) {
	models, err := p.getModelCache()
	if err != nil {
		dto.Fail(c, http.StatusBadGateway, err.Error())
		return
	}
	type modelItem struct {
		ID string `json:"id"`
	}
	// Pre-size and start non-nil so an empty model list serializes as
	// [] rather than null (a nil slice encodes to JSON null).
	ms := make([]modelItem, 0, len(models))
	for _, m := range models {
		ms = append(ms, modelItem{ID: m})
	}
	dto.Success(c, ms)
}
// setModelCache rebuilds the "models" cache entry from the stored API keys.
// For each key it prefers the pre-parsed SupportModelsArray and falls back
// to decoding the JSON-encoded SupportModels column; the union of all model
// names is deduplicated and stored under the "models" cache key.
func (p *Proxy) setModelCache() error {
	apikeys, err := p.apiKeyDao.FindKeys(nil)
	if err != nil {
		return fmt.Errorf("load api keys: %w", err)
	}
	if len(apikeys) == 0 {
		// Distinguish "no keys configured" from a DAO failure; the
		// original collapsed both into an opaque "empty data" error.
		return fmt.Errorf("no api keys configured")
	}
	models := make(map[string]bool, len(apikeys))
	for _, k := range apikeys {
		if len(k.SupportModelsArray) > 0 {
			for _, sm := range k.SupportModelsArray {
				models[sm] = true
			}
			continue
		}
		// Guard the pointer: a key with no stored models must not panic.
		if k.SupportModels == nil {
			continue
		}
		var sma []string
		if err := json.Unmarshal([]byte(*k.SupportModels), &sma); err != nil {
			// Skip a malformed column rather than silently iterating
			// over a half-decoded slice.
			continue
		}
		for _, sm := range sma {
			models[sm] = true
		}
	}
	supportModels := make([]string, 0, len(models))
	for m := range models {
		supportModels = append(supportModels, m)
	}
	return p.cache.Set("models", supportModels)
}
// getModelCache returns the list of supported model IDs previously stored
// under the "models" cache key by setModelCache.
func (p *Proxy) getModelCache() ([]string, error) {
	v, err := p.cache.Get("models")
	if err != nil {
		// Check the error before asserting: on a cache miss Get returns a
		// nil value, and the original unconditional v.([]string) panicked.
		return nil, err
	}
	models, ok := v.([]string)
	if !ok {
		return nil, fmt.Errorf("unexpected cache value type %T for key %q", v, "models")
	}
	return models, nil
}

View File

@@ -19,6 +19,7 @@ import (
"sync"
"time"
"github.com/bluele/gcache"
"github.com/gin-gonic/gin"
"github.com/lib/pq"
"github.com/tidwall/gjson"
@@ -34,6 +35,7 @@ type Proxy struct {
usageChan chan *model.Usage // 用于异步处理的channel
apikey *model.ApiKey
httpClient *http.Client
cache gcache.Cache
userDAO *dao.UserDAO
apiKeyDao *dao.ApiKeyDAO
@@ -53,12 +55,14 @@ func NewProxy(ctx context.Context, cfg *config.Config, db *gorm.DB, wg *sync.Wai
client.Transport = tr
}
}
np := &Proxy{
ctx: ctx,
cfg: cfg,
db: db,
wg: wg,
httpClient: client,
cache: gcache.New(1).Build(),
usageChan: make(chan *model.Usage, cfg.UsageChanSize),
userDAO: userDAO,
apiKeyDao: apiKeyDAO,
@@ -69,7 +73,7 @@ func NewProxy(ctx context.Context, cfg *config.Config, db *gorm.DB, wg *sync.Wai
go np.ProcessUsage()
go np.ScheduleTask()
np.setModelCache()
return np
}
@@ -185,7 +189,7 @@ func (p *Proxy) Do(usage *model.Usage) error {
func (p *Proxy) SelectApiKey(model string) error {
akpikeys, err := p.apiKeyDao.FindApiKeysBySupportModel(p.db, model)
fmt.Println(len(akpikeys), err)
if err != nil || len(akpikeys) == 0 {
if strings.HasPrefix(model, "gpt") || strings.HasPrefix(model, "o1") || strings.HasPrefix(model, "o3") || strings.HasPrefix(model, "o4") {
keys, err := p.apiKeyDao.FindKeys(map[string]any{"active = ?": true, "apitype = ?": "openai"})
@@ -271,7 +275,10 @@ func (p *Proxy) ScheduleTask() {
select {
case <-time.After(time.Duration(p.cfg.TaskTimeInterval) * time.Minute):
p.updateSupportModel()
case <-time.After(time.Hour * 12):
if err := p.setModelCache(); err != nil {
fmt.Println("refrash model cache err:", err)
}
case <-p.ctx.Done():
fmt.Println("schedule task done")
return

View File

@@ -124,7 +124,7 @@ func SetRouter(cfg *config.Config, db *gorm.DB, web *embed.FS) {
{
// v1.POST("/v2/*proxypath", router.HandleProxy)
v1.POST("/*proxypath", proxy.HandleProxy)
// v1.GET("/models", dashboard.HandleModels)
v1.GET("/models", proxy.HandleModels)
}
idxFS, err := fs.Sub(web, "dist")