package controller import ( "context" "encoding/json" "errors" "fmt" "io" "log" "math/rand" "net/http" "net/url" "opencatd-open/internal/dao" "opencatd-open/internal/model" "opencatd-open/internal/utils" "opencatd-open/pkg/config" "os" "strings" "sync" "time" "github.com/bluele/gcache" "github.com/gin-gonic/gin" "github.com/lib/pq" "github.com/tidwall/gjson" "gorm.io/gorm" "gorm.io/gorm/clause" ) type Proxy struct { ctx context.Context cfg *config.Config db *gorm.DB wg *sync.WaitGroup usageChan chan *model.Usage // 用于异步处理的channel apikey *model.ApiKey httpClient *http.Client cache gcache.Cache userDAO *dao.UserDAO apiKeyDao *dao.ApiKeyDAO tokenDAO *dao.TokenDAO usageDAO *dao.UsageDAO dailyUsageDAO *dao.DailyUsageDAO } func NewProxy(ctx context.Context, cfg *config.Config, db *gorm.DB, wg *sync.WaitGroup, userDAO *dao.UserDAO, apiKeyDAO *dao.ApiKeyDAO, tokenDAO *dao.TokenDAO, usageDAO *dao.UsageDAO, dailyUsageDAO *dao.DailyUsageDAO) *Proxy { client := http.DefaultClient if os.Getenv("LOCAL_PROXY") != "" { proxyUrl, err := url.Parse(os.Getenv("LOCAL_PROXY")) if err == nil { tr := &http.Transport{ Proxy: http.ProxyURL(proxyUrl), } client.Transport = tr } } np := &Proxy{ ctx: ctx, cfg: cfg, db: db, wg: wg, httpClient: client, cache: gcache.New(1).Build(), usageChan: make(chan *model.Usage, cfg.UsageChanSize), userDAO: userDAO, apiKeyDao: apiKeyDAO, tokenDAO: tokenDAO, usageDAO: usageDAO, dailyUsageDAO: dailyUsageDAO, } go np.ProcessUsage() go np.ScheduleTask() np.setModelCache() return np } func (p *Proxy) HandleProxy(c *gin.Context) { if c.Request.URL.Path == "/v1/chat/completions" { p.ChatHandler(c) return } } func (p *Proxy) SendUsage(usage *model.Usage) { select { case p.usageChan <- usage: default: log.Println("usage channel is full, skip processing") bj, _ := json.Marshal(usage) log.Println(string(bj)) //TODO: send to a queue } } func (p *Proxy) ProcessUsage() { for i := 0; i < p.cfg.UsageWorker; i++ { p.wg.Add(1) go func(i int) { defer p.wg.Done() for { select { case usage, ok := <-p.usageChan: if !ok { // channel 关闭,退出程序 return } err := p.Do(usage) if err != nil { log.Printf("process usage error: %v\n", err) } case <-p.ctx.Done(): // close(s.usageChan) // for usage := range s.usageChan { // if err := s.Do(usage); err != nil { // fmt.Printf("[close event]process usage error: %v\n", err) // } // } for { select { case usage, ok := <-p.usageChan: if !ok { return } if err := p.Do(usage); err != nil { fmt.Printf("[close event]process usage error: %v\n", err) } default: fmt.Printf("usageChan is empty,usage worker %d done\n", i) return } } } } }(i) } } func (p *Proxy) Do(usage *model.Usage) error { err := p.db.Transaction(func(tx *gorm.DB) error { // 1. 记录使用记录 if err := tx.WithContext(p.ctx).Create(usage).Error; err != nil { return fmt.Errorf("create usage error: %w", err) } // 2. 更新每日统计(upsert 操作) dailyUsage := model.DailyUsage{ UserID: usage.UserID, TokenID: usage.TokenID, Capability: usage.Capability, Date: time.Date(usage.Date.Year(), usage.Date.Month(), usage.Date.Day(), 0, 0, 0, 0, usage.Date.Location()), Model: usage.Model, Stream: usage.Stream, PromptTokens: usage.PromptTokens, CompletionTokens: usage.CompletionTokens, TotalTokens: usage.TotalTokens, Cost: usage.Cost, } // 使用 OnConflict 实现 upsert if err := tx.WithContext(p.ctx).Clauses(clause.OnConflict{ Columns: []clause.Column{{Name: "user_id"}, {Name: "token_id"}, {Name: "capability"}, {Name: "date"}}, // 唯一键 DoUpdates: clause.Assignments(map[string]interface{}{ "prompt_tokens": gorm.Expr("prompt_tokens + ?", usage.PromptTokens), "completion_tokens": gorm.Expr("completion_tokens + ?", usage.CompletionTokens), "total_tokens": gorm.Expr("total_tokens + ?", usage.TotalTokens), "cost": gorm.Expr("cost + ?", usage.Cost), }), }).Create(&dailyUsage).Error; err != nil { return fmt.Errorf("upsert daily usage error: %w", err) } // 3. 更新用户额度 if err := tx.WithContext(p.ctx).Model(&model.User{}).Where("id = ?", usage.UserID).Updates(map[string]interface{}{ "quota": gorm.Expr("quota - ?", usage.Cost), "used_quota": gorm.Expr("used_quota + ?", usage.Cost), }).Error; err != nil { return fmt.Errorf("update user quota and used_quota error: %w", err) } return nil }) return err } func (p *Proxy) SelectApiKey(model string) error { akpikeys, err := p.apiKeyDao.FindApiKeysBySupportModel(p.db, model) fmt.Println(len(akpikeys), err) if err != nil || len(akpikeys) == 0 { if strings.HasPrefix(model, "gpt") || strings.HasPrefix(model, "o1") || strings.HasPrefix(model, "o3") || strings.HasPrefix(model, "o4") { keys, err := p.apiKeyDao.FindKeys(map[string]any{"active = ?": true, "apitype = ?": "openai"}) if err != nil { return err } akpikeys = append(akpikeys, keys...) } if strings.HasPrefix(model, "gemini") { keys, err := p.apiKeyDao.FindKeys(map[string]any{"active = ?": true, "apitype = ?": "gemini"}) if err != nil { return err } akpikeys = append(akpikeys, keys...) } if strings.HasPrefix(model, "claude") { keys, err := p.apiKeyDao.FindKeys(map[string]any{"active = ?": true, "apitype = ?": "claude"}) if err != nil { return err } akpikeys = append(akpikeys, keys...) } } if len(akpikeys) == 0 { return errors.New("no available apikey") } if len(akpikeys) == 1 { p.apikey = &akpikeys[0] return nil } length := len(akpikeys) - 1 p.apikey = &akpikeys[rand.Intn(length)] return nil } func (p *Proxy) updateSupportModel() { keys, err := p.apiKeyDao.FindKeys(map[string]interface{}{"apitype in ?": []string{"openai", "azure", "claude"}}) if err != nil { return } for _, key := range keys { var supportModels []string if *key.ApiType == "openai" || *key.ApiType == "azure" { supportModels, err = p.getOpenAISupportModels(key) } if *key.ApiType == "claude" { supportModels, err = p.getClaudeSupportModels(key) } if err != nil { log.Println(err) continue } if len(supportModels) == 0 { continue } if p.cfg.DB_Type == "sqlite" { bytejson, _ := json.Marshal(supportModels) if err := p.db.Model(&model.ApiKey{}).Where("id = ?", key.ID).UpdateColumn("support_models", string(bytejson)).Error; err != nil { log.Println(err) } } else if p.cfg.DB_Type == "postgres" { if err := p.db.Model(&model.ApiKey{}).Where("id = ?", key.ID).UpdateColumn("support_models", pq.StringArray(supportModels)).Error; err != nil { log.Println(err) } } } } func (p *Proxy) ScheduleTask() { func() { for { select { case <-time.After(time.Duration(p.cfg.TaskTimeInterval) * time.Minute): p.updateSupportModel() case <-time.After(time.Hour * 12): if err := p.setModelCache(); err != nil { fmt.Println("refrash model cache err:", err) } case <-p.ctx.Done(): fmt.Println("schedule task done") return } } }() } func (p *Proxy) getOpenAISupportModels(apikey model.ApiKey) ([]string, error) { openaiModelsUrl := "https://api.openai.com/v1/models" // https://learn.microsoft.com/zh-cn/rest/api/azureopenai/models/list?view=rest-azureopenai-2025-02-01-preview&tabs=HTTP azureModelsUrl := "/openai/deployments?api-version=2022-12-01" var supportModels []string var req *http.Request if *apikey.ApiType == "azure" { if strings.HasSuffix(*apikey.Endpoint, "/") { apikey.Endpoint = utils.ToPtr(strings.TrimSuffix(*apikey.Endpoint, "/")) } req, _ = http.NewRequest("GET", *apikey.Endpoint+azureModelsUrl, nil) req.Header.Set("Content-Type", "application/json") req.Header.Set("api-key", *apikey.ApiKey) } else { req, _ = http.NewRequest("GET", openaiModelsUrl, nil) req.Header.Set("Content-Type", "application/json") req.Header.Set("Authorization", "Bearer "+*apikey.ApiKey) } resp, err := p.httpClient.Do(req) if err != nil { return nil, err } defer resp.Body.Close() if resp.StatusCode == http.StatusOK { bytesbody, _ := io.ReadAll(resp.Body) result := gjson.GetBytes(bytesbody, "data.#.id").Array() for _, v := range result { model := v.Str model = strings.Replace(model, "-35-", "-3.5-", -1) model = strings.Replace(model, "-41-", "-4.1-", -1) supportModels = append(supportModels, model) } } return supportModels, nil } func (p *Proxy) getClaudeSupportModels(apikey model.ApiKey) ([]string, error) { // https://docs.anthropic.com/en/api/models-list claudemodelsUrl := "https://api.anthropic.com/v1/models" var supportModels []string req, _ := http.NewRequest("GET", claudemodelsUrl, nil) req.Header.Set("Content-Type", "application/json") req.Header.Set("x-api-key", *apikey.ApiKey) req.Header.Set("anthropic-version", "2023-06-01") resp, err := p.httpClient.Do(req) if err != nil { return nil, err } defer resp.Body.Close() if resp.StatusCode == http.StatusOK { bytesbody, _ := io.ReadAll(resp.Body) result := gjson.GetBytes(bytesbody, "data.#.id").Array() for _, v := range result { supportModels = append(supportModels, v.Str) } } return supportModels, nil }