package controller

import (
	"context"
	"encoding/json"
	"errors"
	"fmt"
	"io"
	"log"
	"math/rand"
	"net/http"
	"net/url"
	"os"
	"strings"
	"sync"
	"time"

	"github.com/bluele/gcache"
	"github.com/gin-gonic/gin"
	"github.com/lib/pq"
	"github.com/tidwall/gjson"
	"gorm.io/gorm"

	"opencatd-open/internal/dao"
	"opencatd-open/internal/model"
	"opencatd-open/internal/utils"
	"opencatd-open/llm"
	"opencatd-open/pkg/config"
	"opencatd-open/pkg/tokenizer"
)

// Proxy forwards LLM API requests to upstream providers and records token
// usage/cost asynchronously through usageChan.
type Proxy struct {
	ctx       context.Context
	cfg       *config.Config
	db        *gorm.DB
	wg        *sync.WaitGroup
	usageChan chan *llm.TokenUsage // buffered channel for asynchronous usage accounting

	// apikey is the most recently selected upstream key.
	// NOTE(review): this field is written by SelectApiKey and shared across
	// concurrent requests without synchronization — likely a data race;
	// consider returning the key from SelectApiKey instead. Interface kept
	// unchanged here.
	apikey *model.ApiKey

	httpClient *http.Client
	cache      gcache.Cache

	userDAO       *dao.UserDAO
	apiKeyDao     *dao.ApiKeyDAO
	tokenDAO      *dao.TokenDAO
	usageDAO      *dao.UsageDAO
	dailyUsageDAO *dao.DailyUsageDAO
}

// NewProxy builds a Proxy, starts the usage workers and the schedule task,
// and primes the model cache. If LOCAL_PROXY is set to a parsable URL, the
// HTTP client routes through it.
func NewProxy(ctx context.Context, cfg *config.Config, db *gorm.DB, wg *sync.WaitGroup,
	userDAO *dao.UserDAO, apiKeyDAO *dao.ApiKeyDAO, tokenDAO *dao.TokenDAO,
	usageDAO *dao.UsageDAO, dailyUsageDAO *dao.DailyUsageDAO) *Proxy {
	// Use a dedicated client: mutating http.DefaultClient's Transport (as the
	// previous code did) changes behavior for every other user of the default
	// client in the process.
	client := &http.Client{}
	if raw := os.Getenv("LOCAL_PROXY"); raw != "" {
		if proxyURL, err := url.Parse(raw); err == nil {
			client.Transport = &http.Transport{Proxy: http.ProxyURL(proxyURL)}
		}
	}

	np := &Proxy{
		ctx:           ctx,
		cfg:           cfg,
		db:            db,
		wg:            wg,
		httpClient:    client,
		cache:         gcache.New(1).Build(),
		usageChan:     make(chan *llm.TokenUsage, cfg.UsageChanSize),
		userDAO:       userDAO,
		apiKeyDao:     apiKeyDAO,
		tokenDAO:      tokenDAO,
		usageDAO:      usageDAO,
		dailyUsageDAO: dailyUsageDAO,
	}
	go np.ProcessUsage()
	go np.ScheduleTask()
	np.setModelCache()
	return np
}

// HandleProxy dispatches an incoming request to the matching provider handler
// based on its URL path. Unmatched paths fall through without a response.
func (p *Proxy) HandleProxy(c *gin.Context) {
	if c.Request.URL.Path == "/v1/chat/completions" {
		p.ChatHandler(c)
		return
	}
	if strings.HasPrefix(c.Request.URL.Path, "/v1/messages") {
		p.ProxyClaude(c)
		return
	}
}

// SendUsage enqueues a usage record for asynchronous processing. It never
// blocks: if the channel is full the record is logged and dropped.
func (p *Proxy) SendUsage(usage *llm.TokenUsage) {
	select {
	case p.usageChan <- usage:
	default:
		log.Println("usage channel is full, skip processing")
		bj, _ := json.Marshal(usage)
		log.Println(string(bj))
		// TODO: send to a queue
	}
}

// ProcessUsage starts cfg.UsageWorker goroutines that drain usageChan and
// persist each record via Do. On context cancellation each worker drains any
// remaining buffered records before exiting.
func (p *Proxy) ProcessUsage() {
	for i := 0; i < p.cfg.UsageWorker; i++ {
		p.wg.Add(1)
		go func(i int) {
			defer p.wg.Done()
			for {
				select {
				case usage, ok := <-p.usageChan:
					if !ok {
						// channel closed: worker exits
						return
					}
					if err := p.Do(usage); err != nil {
						log.Printf("process usage error: %v\n", err)
					}
				case <-p.ctx.Done():
					// Context cancelled: flush whatever is still buffered,
					// then exit.
					for {
						select {
						case usage, ok := <-p.usageChan:
							if !ok {
								return
							}
							if err := p.Do(usage); err != nil {
								fmt.Printf("[close event]process usage error: %v\n", err)
							}
						default:
							fmt.Printf("usageChan is empty,usage worker %d done\n", i)
							return
						}
					}
				}
			}
		}(i)
	}
}

// Do persists one usage record inside a single transaction: it inserts the
// per-request usage row, upserts the per-day aggregate, and debits/credits the
// user's and token's quotas.
func (p *Proxy) Do(llmusage *llm.TokenUsage) error {
	err := p.db.Transaction(func(tx *gorm.DB) error {
		now := time.Now()
		// Truncate to the day for the daily aggregate key.
		today, _ := time.Parse("2006-01-02", now.Format("2006-01-02"))
		cost := tokenizer.Cost(llmusage.Model, llmusage.PromptTokens, llmusage.CompletionTokens)

		token, err := p.tokenDAO.GetByID(p.ctx, llmusage.TokenID)
		if err != nil {
			return err
		}

		usage := &model.Usage{
			UserID:           llmusage.User.ID,
			TokenID:          llmusage.TokenID,
			Date:             now,
			Model:            llmusage.Model,
			Stream:           llmusage.Stream,
			PromptTokens:     llmusage.PromptTokens,
			CompletionTokens: llmusage.CompletionTokens,
			TotalTokens:      llmusage.TotalTokens,
			Cost:             fmt.Sprintf("%.8f", cost),
		}

		// 1. Record the individual usage row.
		if err := tx.WithContext(p.ctx).Create(usage).Error; err != nil {
			return fmt.Errorf("create usage error: %w", err)
		}

		// 2. Upsert the daily aggregate. Distinguish "row not found" from a
		// real query error: the previous RowsAffected==0 check attempted a
		// Create on ANY query failure, masking DB errors.
		var dailyUsage model.DailyUsage
		findErr := tx.WithContext(p.ctx).Where("user_id = ? and date = ?", llmusage.User.ID, today).First(&dailyUsage).Error
		switch {
		case errors.Is(findErr, gorm.ErrRecordNotFound):
			dailyUsage.UserID = llmusage.User.ID
			dailyUsage.TokenID = llmusage.TokenID
			dailyUsage.Date = today
			dailyUsage.Model = llmusage.Model
			dailyUsage.Stream = llmusage.Stream
			dailyUsage.PromptTokens = llmusage.PromptTokens
			dailyUsage.CompletionTokens = llmusage.CompletionTokens
			dailyUsage.TotalTokens = llmusage.TotalTokens
			dailyUsage.Cost = fmt.Sprintf("%.8f", cost)
			if err := tx.WithContext(p.ctx).Create(&dailyUsage).Error; err != nil {
				return fmt.Errorf("create daily usage error: %w", err)
			}
		case findErr != nil:
			return fmt.Errorf("query daily usage error: %w", findErr)
		default:
			// NOTE(review): only token counters are accumulated here — the
			// string `cost` column keeps the first row's value. Presumably a
			// bug, but cost is stored as a string so a portable in-DB add is
			// not possible here; confirm and fix at the schema level.
			if err := tx.WithContext(p.ctx).Model(&model.DailyUsage{}).Where("user_id = ? and date = ?", llmusage.User.ID, today).
				Updates(map[string]interface{}{
					"prompt_tokens":     gorm.Expr("prompt_tokens + ?", llmusage.PromptTokens),
					"completion_tokens": gorm.Expr("completion_tokens + ?", llmusage.CompletionTokens),
					"total_tokens":      gorm.Expr("total_tokens + ?", llmusage.TotalTokens),
				}).Error; err != nil {
				return fmt.Errorf("update daily usage error: %w", err)
			}
		}

		// 3. Update the user's quota. Unlimited users only accrue used_quota;
		// limited users are additionally debited.
		if *llmusage.User.UnlimitedQuota {
			if err := tx.WithContext(p.ctx).Model(&model.User{}).Where("id = ?", llmusage.User.ID).Updates(map[string]interface{}{
				"used_quota": gorm.Expr("used_quota + ?", fmt.Sprintf("%.8f", cost)),
			}).Error; err != nil {
				return fmt.Errorf("update user quota and used_quota error: %w", err)
			}
		} else {
			if err := tx.WithContext(p.ctx).Model(&model.User{}).Where("id = ?", llmusage.User.ID).Updates(map[string]interface{}{
				"quota":      gorm.Expr("quota - ?", fmt.Sprintf("%.8f", cost)),
				"used_quota": gorm.Expr("used_quota + ?", fmt.Sprintf("%.8f", cost)),
			}).Error; err != nil {
				return fmt.Errorf("update user quota and used_quota error: %w", err)
			}
		}

		// 4. Update the token's quota, mirroring the user-quota logic above.
		if *token.UnlimitedQuota {
			if err := tx.WithContext(p.ctx).Model(&model.Token{}).Where("id = ?", llmusage.TokenID).Updates(map[string]interface{}{
				"used_quota": gorm.Expr("used_quota + ?", fmt.Sprintf("%.8f", cost)),
			}).Error; err != nil {
				return fmt.Errorf("update token quota and used_quota error: %w", err)
			}
		} else {
			if err := tx.WithContext(p.ctx).Model(&model.Token{}).Where("id = ?", llmusage.TokenID).Updates(map[string]interface{}{
				"quota":      gorm.Expr("quota - ?", fmt.Sprintf("%.8f", cost)),
				"used_quota": gorm.Expr("used_quota + ?", fmt.Sprintf("%.8f", cost)),
			}).Error; err != nil {
				return fmt.Errorf("update token quota and used_quota error: %w", err)
			}
		}
		return nil
	})
	return err
}

// SelectApiKey picks an upstream API key that supports the given model,
// falling back to all active keys of the matching provider type when no key
// explicitly lists the model. The chosen key is stored in p.apikey.
func (p *Proxy) SelectApiKey(model string) error {
	akpikeys, err := p.apiKeyDao.FindApiKeysBySupportModel(p.db, model)
	if err != nil || len(akpikeys) == 0 {
		// Fallback: infer the provider from the model-name prefix and take
		// every active key of that provider.
		if strings.HasPrefix(model, "gpt") || strings.HasPrefix(model, "o1") || strings.HasPrefix(model, "o3") || strings.HasPrefix(model, "o4") {
			keys, err := p.apiKeyDao.FindKeys(map[string]any{"active = ?": true, "apitype = ?": "openai"})
			if err != nil {
				return err
			}
			akpikeys = append(akpikeys, keys...)
		}
		if strings.HasPrefix(model, "gemini") {
			keys, err := p.apiKeyDao.FindKeys(map[string]any{"active = ?": true, "apitype = ?": "gemini"})
			if err != nil {
				return err
			}
			akpikeys = append(akpikeys, keys...)
		}
		if strings.HasPrefix(model, "claude") {
			keys, err := p.apiKeyDao.FindKeys(map[string]any{"active = ?": true, "apitype = ?": "claude"})
			if err != nil {
				return err
			}
			akpikeys = append(akpikeys, keys...)
		}
	}
	if len(akpikeys) == 0 {
		return errors.New("no available apikey")
	}
	// BUG FIX: the previous code used rand.Intn(len-1), which could never
	// select the last key. Intn(len) picks uniformly over all candidates
	// (and handles the single-key case too).
	p.apikey = &akpikeys[rand.Intn(len(akpikeys))]
	return nil
}

// updateSupportModel queries each openai/azure/claude key's provider for the
// models it can serve and persists the list on the key row (JSON for sqlite,
// a native string array for postgres).
func (p *Proxy) updateSupportModel() {
	keys, err := p.apiKeyDao.FindKeys(map[string]interface{}{"apitype in ?": []string{"openai", "azure", "claude"}})
	if err != nil {
		return
	}
	for _, key := range keys {
		var supportModels []string
		if *key.ApiType == "openai" || *key.ApiType == "azure" {
			supportModels, err = p.getOpenAISupportModels(key)
		}
		if *key.ApiType == "claude" {
			supportModels, err = p.getClaudeSupportModels(key)
		}
		if err != nil {
			log.Println(err)
			continue
		}
		if len(supportModels) == 0 {
			continue
		}
		if p.cfg.DB_Type == "sqlite" {
			bytejson, _ := json.Marshal(supportModels)
			if err := p.db.Model(&model.ApiKey{}).Where("id = ?", key.ID).UpdateColumn("support_models", string(bytejson)).Error; err != nil {
				log.Println(err)
			}
		} else if p.cfg.DB_Type == "postgres" {
			if err := p.db.Model(&model.ApiKey{}).Where("id = ?", key.ID).UpdateColumn("support_models", pq.StringArray(supportModels)).Error; err != nil {
				log.Println(err)
			}
		}
	}
}

// ScheduleTask runs periodic maintenance until the context is cancelled:
// refreshing each key's supported-model list every cfg.TaskTimeInterval
// minutes and the model cache every 12 hours.
func (p *Proxy) ScheduleTask() {
	// BUG FIX: the previous loop called time.After for both cases inside the
	// select, creating fresh timers on every iteration — so whenever the
	// shorter interval fired, the 12-hour timer was discarded and restarted,
	// and the cache refresh effectively never ran. Tickers keep their own
	// schedule across iterations.
	modelTicker := time.NewTicker(time.Duration(p.cfg.TaskTimeInterval) * time.Minute)
	defer modelTicker.Stop()
	cacheTicker := time.NewTicker(12 * time.Hour)
	defer cacheTicker.Stop()
	for {
		select {
		case <-modelTicker.C:
			p.updateSupportModel()
		case <-cacheTicker.C:
			if err := p.setModelCache(); err != nil {
				fmt.Println("refresh model cache err:", err)
			}
		case <-p.ctx.Done():
			fmt.Println("schedule task done")
			return
		}
	}
}

// getOpenAISupportModels lists the models available to an OpenAI or Azure
// API key. Azure deployment names like "-35-"/"-41-" are normalized to the
// OpenAI-style "-3.5-"/"-4.1-". A non-200 response yields an empty list
// without error.
func (p *Proxy) getOpenAISupportModels(apikey model.ApiKey) ([]string, error) {
	openaiModelsUrl := "https://api.openai.com/v1/models"
	// https://learn.microsoft.com/zh-cn/rest/api/azureopenai/models/list?view=rest-azureopenai-2025-02-01-preview&tabs=HTTP
	azureModelsUrl := "/openai/deployments?api-version=2022-12-01"

	var supportModels []string
	var req *http.Request
	if *apikey.ApiType == "azure" {
		if strings.HasSuffix(*apikey.Endpoint, "/") {
			apikey.Endpoint = utils.ToPtr(strings.TrimSuffix(*apikey.Endpoint, "/"))
		}
		req, _ = http.NewRequest("GET", *apikey.Endpoint+azureModelsUrl, nil)
		req.Header.Set("Content-Type", "application/json")
		req.Header.Set("api-key", *apikey.ApiKey)
	} else {
		req, _ = http.NewRequest("GET", openaiModelsUrl, nil)
		req.Header.Set("Content-Type", "application/json")
		req.Header.Set("Authorization", "Bearer "+*apikey.ApiKey)
	}

	resp, err := p.httpClient.Do(req)
	if err != nil {
		return nil, err
	}
	defer resp.Body.Close()

	if resp.StatusCode == http.StatusOK {
		bytesbody, _ := io.ReadAll(resp.Body)
		result := gjson.GetBytes(bytesbody, "data.#.id").Array()
		for _, v := range result {
			// Renamed from `model` to avoid shadowing the imported
			// `model` package.
			name := v.Str
			name = strings.Replace(name, "-35-", "-3.5-", -1)
			name = strings.Replace(name, "-41-", "-4.1-", -1)
			supportModels = append(supportModels, name)
		}
	}
	return supportModels, nil
}

// getClaudeSupportModels lists the models available to an Anthropic API key.
// A non-200 response yields an empty list without error.
func (p *Proxy) getClaudeSupportModels(apikey model.ApiKey) ([]string, error) {
	// https://docs.anthropic.com/en/api/models-list
	claudemodelsUrl := "https://api.anthropic.com/v1/models"

	var supportModels []string
	req, _ := http.NewRequest("GET", claudemodelsUrl, nil)
	req.Header.Set("Content-Type", "application/json")
	req.Header.Set("x-api-key", *apikey.ApiKey)
	req.Header.Set("anthropic-version", "2023-06-01")

	resp, err := p.httpClient.Do(req)
	if err != nil {
		return nil, err
	}
	defer resp.Body.Close()

	if resp.StatusCode == http.StatusOK {
		bytesbody, _ := io.ReadAll(resp.Body)
		result := gjson.GetBytes(bytesbody, "data.#.id").Array()
		for _, v := range result {
			supportModels = append(supportModels, v.Str)
		}
	}
	return supportModels, nil
}