package controller
|
|
|
|
import (
|
|
"context"
|
|
"encoding/json"
|
|
"errors"
|
|
"fmt"
|
|
"io"
|
|
"log"
|
|
"math/rand"
|
|
"net/http"
|
|
"net/url"
|
|
"opencatd-open/internal/dao"
|
|
"opencatd-open/internal/model"
|
|
"opencatd-open/internal/utils"
|
|
"opencatd-open/llm"
|
|
"opencatd-open/pkg/config"
|
|
"opencatd-open/pkg/tokenizer"
|
|
"os"
|
|
"strings"
|
|
"sync"
|
|
"time"
|
|
|
|
"github.com/bluele/gcache"
|
|
"github.com/gin-gonic/gin"
|
|
"github.com/lib/pq"
|
|
"github.com/tidwall/gjson"
|
|
"gorm.io/gorm"
|
|
)
|
|
|
|
// Proxy routes incoming chat/completions traffic to the configured upstream
// LLM providers and records token usage asynchronously.
type Proxy struct {
	ctx context.Context
	cfg *config.Config
	db  *gorm.DB
	wg  *sync.WaitGroup
	// usageChan buffers token-usage records for asynchronous persistence
	// by the ProcessUsage workers.
	usageChan chan *llm.TokenUsage
	// apikey is the upstream key most recently chosen by SelectApiKey.
	// NOTE(review): this field is written per-request without any locking;
	// concurrent requests may race on it — confirm the intended usage.
	apikey     *model.ApiKey
	httpClient *http.Client
	cache      gcache.Cache

	userDAO       *dao.UserDAO
	apiKeyDao     *dao.ApiKeyDAO
	tokenDAO      *dao.TokenDAO
	usageDAO      *dao.UsageDAO
	dailyUsageDAO *dao.DailyUsageDAO
}
|
|
|
|
func NewProxy(ctx context.Context, cfg *config.Config, db *gorm.DB, wg *sync.WaitGroup, userDAO *dao.UserDAO, apiKeyDAO *dao.ApiKeyDAO, tokenDAO *dao.TokenDAO, usageDAO *dao.UsageDAO, dailyUsageDAO *dao.DailyUsageDAO) *Proxy {
|
|
client := http.DefaultClient
|
|
if os.Getenv("LOCAL_PROXY") != "" {
|
|
proxyUrl, err := url.Parse(os.Getenv("LOCAL_PROXY"))
|
|
if err == nil {
|
|
tr := &http.Transport{
|
|
Proxy: http.ProxyURL(proxyUrl),
|
|
}
|
|
client.Transport = tr
|
|
}
|
|
}
|
|
|
|
np := &Proxy{
|
|
ctx: ctx,
|
|
cfg: cfg,
|
|
db: db,
|
|
wg: wg,
|
|
httpClient: client,
|
|
cache: gcache.New(1).Build(),
|
|
usageChan: make(chan *llm.TokenUsage, cfg.UsageChanSize),
|
|
userDAO: userDAO,
|
|
apiKeyDao: apiKeyDAO,
|
|
tokenDAO: tokenDAO,
|
|
usageDAO: usageDAO,
|
|
dailyUsageDAO: dailyUsageDAO,
|
|
}
|
|
|
|
go np.ProcessUsage()
|
|
go np.ScheduleTask()
|
|
np.setModelCache()
|
|
return np
|
|
}
|
|
|
|
func (p *Proxy) HandleProxy(c *gin.Context) {
|
|
if c.Request.URL.Path == "/v1/chat/completions" {
|
|
p.ChatHandler(c)
|
|
return
|
|
}
|
|
if strings.HasPrefix(c.Request.URL.Path, "/v1/messages") {
|
|
p.ProxyClaude(c)
|
|
return
|
|
}
|
|
}
|
|
|
|
func (p *Proxy) SendUsage(usage *llm.TokenUsage) {
|
|
select {
|
|
case p.usageChan <- usage:
|
|
default:
|
|
log.Println("usage channel is full, skip processing")
|
|
bj, _ := json.Marshal(usage)
|
|
log.Println(string(bj))
|
|
//TODO: send to a queue
|
|
}
|
|
}
|
|
|
|
// ProcessUsage starts cfg.UsageWorker goroutines that consume usage records
// from usageChan and persist them via Do. Each worker is registered with the
// shared WaitGroup so callers can wait for a clean shutdown.
//
// Shutdown semantics: when the proxy context is cancelled, each worker
// drains whatever is still buffered in usageChan (non-blocking) before
// returning, so records accepted by SendUsage are not lost on shutdown.
func (p *Proxy) ProcessUsage() {
	for i := 0; i < p.cfg.UsageWorker; i++ {
		p.wg.Add(1)
		go func(i int) {
			defer p.wg.Done()
			for {
				select {
				case usage, ok := <-p.usageChan:
					if !ok {
						// Channel closed: exit the worker.
						return
					}
					err := p.Do(usage)
					if err != nil {
						log.Printf("process usage error: %v\n", err)
					}
				case <-p.ctx.Done():
					// Context cancelled: drain the remaining buffered
					// records, then exit once the channel is empty.
					for {
						select {
						case usage, ok := <-p.usageChan:
							if !ok {
								return
							}
							if err := p.Do(usage); err != nil {
								fmt.Printf("[close event]process usage error: %v\n", err)
							}
						default:
							fmt.Printf("usageChan is empty,usage worker %d done\n", i)
							return
						}
					}
				}
			}
		}(i)
	}
}
|
|
|
|
// Do persists one usage record inside a single database transaction:
//  1. insert the raw usage row,
//  2. upsert the per-user daily aggregate,
//  3. charge the user's quota,
//  4. charge the token's quota.
// The monetary cost is derived from the model's pricing table via
// tokenizer.Cost and stored as a fixed-point decimal string.
func (p *Proxy) Do(llmusage *llm.TokenUsage) error {
	err := p.db.Transaction(func(tx *gorm.DB) error {
		now := time.Now()
		// Truncate "now" to midnight; this is the daily-aggregate key.
		today, _ := time.Parse("2006-01-02", now.Format("2006-01-02"))

		cost := tokenizer.Cost(llmusage.Model, llmusage.PromptTokens, llmusage.CompletionTokens)
		// NOTE(review): this lookup goes through the DAO (backed by p.db),
		// not the surrounding tx — confirm that reading the token outside
		// the transaction is intended.
		token, err := p.tokenDAO.GetByID(p.ctx, llmusage.TokenID)
		if err != nil {
			return err
		}

		usage := &model.Usage{
			UserID:           llmusage.User.ID,
			TokenID:          llmusage.TokenID,
			Date:             now,
			Model:            llmusage.Model,
			Stream:           llmusage.Stream,
			PromptTokens:     llmusage.PromptTokens,
			CompletionTokens: llmusage.CompletionTokens,
			TotalTokens:      llmusage.TotalTokens,
			Cost:             fmt.Sprintf("%.8f", cost),
		}
		// 1. Insert the raw usage record.
		if err := tx.WithContext(p.ctx).Create(usage).Error; err != nil {
			return fmt.Errorf("create usage error: %w", err)
		}

		// 2. Update the daily aggregate, inserting it on the first request
		// of the day.
		var dailyUsage model.DailyUsage
		// NOTE(review): First()'s error is never inspected, so a query
		// failure is indistinguishable from "no row yet" and falls into the
		// Create branch — confirm this is acceptable.
		result := tx.WithContext(p.ctx).Where("user_id = ? and date = ?", llmusage.User.ID, today).First(&dailyUsage)
		if result.RowsAffected == 0 {
			dailyUsage.UserID = llmusage.User.ID
			dailyUsage.TokenID = llmusage.TokenID
			dailyUsage.Date = today
			dailyUsage.Model = llmusage.Model
			dailyUsage.Stream = llmusage.Stream
			dailyUsage.PromptTokens = llmusage.PromptTokens
			dailyUsage.CompletionTokens = llmusage.CompletionTokens
			dailyUsage.TotalTokens = llmusage.TotalTokens
			dailyUsage.Cost = fmt.Sprintf("%.8f", cost)
			if err := tx.WithContext(p.ctx).Create(&dailyUsage).Error; err != nil {
				return fmt.Errorf("create daily usage error: %w", err)
			}
		} else {
			// NOTE(review): only the token counters are incremented here;
			// the daily "cost" column is never updated after the first
			// insert, so the aggregate cost stays at the first request's
			// value. Looks like a bug — verify against the reporting code
			// before changing (cost is stored as a string).
			if err := tx.WithContext(p.ctx).Model(&model.DailyUsage{}).Where("user_id = ? and date = ?", llmusage.User.ID, today).
				Updates(map[string]interface{}{
					"prompt_tokens":     gorm.Expr("prompt_tokens + ?", llmusage.PromptTokens),
					"completion_tokens": gorm.Expr("completion_tokens + ?", llmusage.CompletionTokens),
					"total_tokens":      gorm.Expr("total_tokens + ?", llmusage.TotalTokens),
				}).Error; err != nil {
				return fmt.Errorf("update daily usage error: %w", err)
			}
		}

		// 3. Charge the user. Unlimited users only accumulate used_quota;
		// limited users additionally have their remaining quota reduced.
		if *llmusage.User.UnlimitedQuota {
			if err := tx.WithContext(p.ctx).Model(&model.User{}).Where("id = ?", llmusage.User.ID).Updates(map[string]interface{}{
				"used_quota": gorm.Expr("used_quota + ?", fmt.Sprintf("%.8f", cost)),
			}).Error; err != nil {
				return fmt.Errorf("update user quota and used_quota error: %w", err)
			}
		} else {
			if err := tx.WithContext(p.ctx).Model(&model.User{}).Where("id = ?", llmusage.User.ID).Updates(map[string]interface{}{
				"quota":      gorm.Expr("quota - ?", fmt.Sprintf("%.8f", cost)),
				"used_quota": gorm.Expr("used_quota + ?", fmt.Sprintf("%.8f", cost)),
			}).Error; err != nil {
				return fmt.Errorf("update user quota and used_quota error: %w", err)
			}
		}

		// 4. Charge the token, with the same unlimited/limited split.
		if *token.UnlimitedQuota {
			if err := tx.WithContext(p.ctx).Model(&model.Token{}).Where("id = ?", llmusage.TokenID).Updates(map[string]interface{}{
				"used_quota": gorm.Expr("used_quota + ?", fmt.Sprintf("%.8f", cost)),
			}).Error; err != nil {
				return fmt.Errorf("update token quota and used_quota error: %w", err)
			}
		} else {
			if err := tx.WithContext(p.ctx).Model(&model.Token{}).Where("id = ?", llmusage.TokenID).Updates(map[string]interface{}{
				"quota":      gorm.Expr("quota - ?", fmt.Sprintf("%.8f", cost)),
				"used_quota": gorm.Expr("used_quota + ?", fmt.Sprintf("%.8f", cost)),
			}).Error; err != nil {
				return fmt.Errorf("update token quota and used_quota error: %w", err)
			}
		}

		return nil
	})
	return err
}
|
|
|
|
func (p *Proxy) SelectApiKey(model string) error {
|
|
akpikeys, err := p.apiKeyDao.FindApiKeysBySupportModel(p.db, model)
|
|
if err != nil || len(akpikeys) == 0 {
|
|
if strings.HasPrefix(model, "gpt") || strings.HasPrefix(model, "o1") || strings.HasPrefix(model, "o3") || strings.HasPrefix(model, "o4") {
|
|
keys, err := p.apiKeyDao.FindKeys(map[string]any{"active = ?": true, "apitype = ?": "openai"})
|
|
if err != nil {
|
|
return err
|
|
}
|
|
akpikeys = append(akpikeys, keys...)
|
|
}
|
|
|
|
if strings.HasPrefix(model, "gemini") {
|
|
keys, err := p.apiKeyDao.FindKeys(map[string]any{"active = ?": true, "apitype = ?": "gemini"})
|
|
if err != nil {
|
|
return err
|
|
}
|
|
akpikeys = append(akpikeys, keys...)
|
|
}
|
|
|
|
if strings.HasPrefix(model, "claude") {
|
|
keys, err := p.apiKeyDao.FindKeys(map[string]any{"active = ?": true, "apitype = ?": "claude"})
|
|
if err != nil {
|
|
return err
|
|
}
|
|
akpikeys = append(akpikeys, keys...)
|
|
}
|
|
}
|
|
if len(akpikeys) == 0 {
|
|
return errors.New("no available apikey")
|
|
}
|
|
|
|
if len(akpikeys) == 1 {
|
|
p.apikey = &akpikeys[0]
|
|
return nil
|
|
}
|
|
length := len(akpikeys) - 1
|
|
|
|
p.apikey = &akpikeys[rand.Intn(length)]
|
|
|
|
return nil
|
|
}
|
|
|
|
func (p *Proxy) updateSupportModel() {
|
|
|
|
keys, err := p.apiKeyDao.FindKeys(map[string]interface{}{"apitype in ?": []string{"openai", "azure", "claude"}})
|
|
if err != nil {
|
|
return
|
|
}
|
|
for _, key := range keys {
|
|
var supportModels []string
|
|
if *key.ApiType == "openai" || *key.ApiType == "azure" {
|
|
supportModels, err = p.getOpenAISupportModels(key)
|
|
}
|
|
if *key.ApiType == "claude" {
|
|
supportModels, err = p.getClaudeSupportModels(key)
|
|
}
|
|
|
|
if err != nil {
|
|
log.Println(err)
|
|
continue
|
|
}
|
|
if len(supportModels) == 0 {
|
|
continue
|
|
|
|
}
|
|
if p.cfg.DB_Type == "sqlite" {
|
|
bytejson, _ := json.Marshal(supportModels)
|
|
if err := p.db.Model(&model.ApiKey{}).Where("id = ?", key.ID).UpdateColumn("support_models", string(bytejson)).Error; err != nil {
|
|
log.Println(err)
|
|
}
|
|
} else if p.cfg.DB_Type == "postgres" {
|
|
if err := p.db.Model(&model.ApiKey{}).Where("id = ?", key.ID).UpdateColumn("support_models", pq.StringArray(supportModels)).Error; err != nil {
|
|
log.Println(err)
|
|
}
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
func (p *Proxy) ScheduleTask() {
|
|
|
|
func() {
|
|
for {
|
|
select {
|
|
case <-time.After(time.Duration(p.cfg.TaskTimeInterval) * time.Minute):
|
|
p.updateSupportModel()
|
|
case <-time.After(time.Hour * 12):
|
|
if err := p.setModelCache(); err != nil {
|
|
fmt.Println("refrash model cache err:", err)
|
|
}
|
|
case <-p.ctx.Done():
|
|
fmt.Println("schedule task done")
|
|
return
|
|
}
|
|
}
|
|
}()
|
|
}
|
|
|
|
func (p *Proxy) getOpenAISupportModels(apikey model.ApiKey) ([]string, error) {
|
|
openaiModelsUrl := "https://api.openai.com/v1/models"
|
|
// https://learn.microsoft.com/zh-cn/rest/api/azureopenai/models/list?view=rest-azureopenai-2025-02-01-preview&tabs=HTTP
|
|
azureModelsUrl := "/openai/deployments?api-version=2022-12-01"
|
|
|
|
var supportModels []string
|
|
var req *http.Request
|
|
if *apikey.ApiType == "azure" {
|
|
if strings.HasSuffix(*apikey.Endpoint, "/") {
|
|
apikey.Endpoint = utils.ToPtr(strings.TrimSuffix(*apikey.Endpoint, "/"))
|
|
}
|
|
req, _ = http.NewRequest("GET", *apikey.Endpoint+azureModelsUrl, nil)
|
|
req.Header.Set("Content-Type", "application/json")
|
|
req.Header.Set("api-key", *apikey.ApiKey)
|
|
} else {
|
|
req, _ = http.NewRequest("GET", openaiModelsUrl, nil)
|
|
req.Header.Set("Content-Type", "application/json")
|
|
req.Header.Set("Authorization", "Bearer "+*apikey.ApiKey)
|
|
}
|
|
|
|
resp, err := p.httpClient.Do(req)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
defer resp.Body.Close()
|
|
if resp.StatusCode == http.StatusOK {
|
|
bytesbody, _ := io.ReadAll(resp.Body)
|
|
result := gjson.GetBytes(bytesbody, "data.#.id").Array()
|
|
for _, v := range result {
|
|
model := v.Str
|
|
model = strings.Replace(model, "-35-", "-3.5-", -1)
|
|
model = strings.Replace(model, "-41-", "-4.1-", -1)
|
|
supportModels = append(supportModels, model)
|
|
}
|
|
}
|
|
return supportModels, nil
|
|
}
|
|
|
|
func (p *Proxy) getClaudeSupportModels(apikey model.ApiKey) ([]string, error) {
|
|
// https://docs.anthropic.com/en/api/models-list
|
|
claudemodelsUrl := "https://api.anthropic.com/v1/models"
|
|
var supportModels []string
|
|
|
|
req, _ := http.NewRequest("GET", claudemodelsUrl, nil)
|
|
req.Header.Set("Content-Type", "application/json")
|
|
req.Header.Set("x-api-key", *apikey.ApiKey)
|
|
req.Header.Set("anthropic-version", "2023-06-01")
|
|
|
|
resp, err := p.httpClient.Do(req)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
defer resp.Body.Close()
|
|
if resp.StatusCode == http.StatusOK {
|
|
bytesbody, _ := io.ReadAll(resp.Body)
|
|
result := gjson.GetBytes(bytesbody, "data.#.id").Array()
|
|
for _, v := range result {
|
|
supportModels = append(supportModels, v.Str)
|
|
}
|
|
}
|
|
return supportModels, nil
|
|
}
|