reface to openteam

This commit is contained in:
Sakurasan
2025-04-16 18:01:27 +08:00
parent bc223d6530
commit e7ffc9e8b9
92 changed files with 5345 additions and 1273 deletions

View File

@@ -0,0 +1,60 @@
package controller
import (
"fmt"
"net/http"
"opencatd-open/internal/dto"
"opencatd-open/llm"
"opencatd-open/llm/claude/v2"
"opencatd-open/llm/google/v2"
"opencatd-open/llm/openai_compatible"
"github.com/gin-gonic/gin"
)
// ChatHandler handles POST /v1/chat/completions: it binds the request,
// selects an API key that supports the requested model, builds the matching
// provider client, and dispatches either a blocking or streaming completion.
func (h *Proxy) ChatHandler(c *gin.Context) {
	var chatreq llm.ChatRequest
	if err := c.ShouldBindJSON(&chatreq); err != nil {
		c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
		return
	}
	if err := h.SelectApiKey(chatreq.Model); err != nil {
		dto.WrapErrorAsOpenAI(c, 500, err.Error())
		return
	}
	// Named "client" rather than "llm" to avoid shadowing the llm package.
	var client llm.LLM
	var err error
	switch *h.apikey.ApiType {
	case "claude":
		client, err = claude.NewClaude(h.apikey)
	case "gemini":
		client, err = google.NewGemini(c, h.apikey)
	default: // "openai", "azure", "github", and any other OpenAI-compatible type
		client, err = openai_compatible.NewOpenAICompatible(h.apikey)
	}
	// Check the constructor error for every provider, not only the default
	// branch; a failed claude/gemini constructor would otherwise leave
	// client nil and crash below.
	if err != nil {
		dto.WrapErrorAsOpenAI(c, 500, fmt.Errorf("create llm client error: %w", err).Error())
		return
	}
	if !chatreq.Stream {
		resp, err := client.Chat(c, chatreq)
		if err != nil {
			dto.WrapErrorAsOpenAI(c, 500, err.Error())
			return // was missing: previously fell through and wrote a 200 with a nil response
		}
		c.JSON(http.StatusOK, resp)
		return
	}
	datachan, err := client.StreamChat(c, chatreq)
	if err != nil {
		dto.WrapErrorAsOpenAI(c, 500, err.Error())
		return // was missing: ranging over a nil channel would block forever
	}
	for data := range datachan {
		c.SSEvent("", data)
	}
}

View File

@@ -0,0 +1,345 @@
package controller
import (
"context"
"encoding/json"
"errors"
"fmt"
"io"
"log"
"math/rand"
"net/http"
"net/url"
"opencatd-open/internal/dao"
"opencatd-open/internal/model"
"opencatd-open/pkg/config"
"os"
"strings"
"sync"
"time"
"github.com/gin-gonic/gin"
"github.com/lib/pq"
"github.com/tidwall/gjson"
"gorm.io/gorm"
"gorm.io/gorm/clause"
)
// Proxy is the central request handler: it routes chat requests to
// provider-specific LLM clients, records token usage asynchronously via
// usageChan, and periodically refreshes each key's supported-model list.
type Proxy struct {
	ctx context.Context
	cfg *config.Config
	db *gorm.DB
	wg *sync.WaitGroup // waited on for usage workers; Add/Done in ProcessUsage
	usageChan chan *model.Usage // buffered channel for asynchronous usage processing
	apikey *model.ApiKey // set by SelectApiKey; NOTE(review): a single shared field — concurrent requests may race on it, verify
	httpClient *http.Client // honors LOCAL_PROXY when set (see NewProxy)
	userDAO *dao.UserDAO
	apiKeyDao *dao.ApiKeyDAO
	tokenDAO *dao.TokenDAO
	usageDAO *dao.UsageDAO
	dailyUsageDAO *dao.DailyUsageDAO
}
// NewProxy wires up a Proxy with its DAOs and an HTTP client (optionally
// routed through the LOCAL_PROXY environment variable), then starts the
// background usage workers and the schedule task.
func NewProxy(ctx context.Context, cfg *config.Config, db *gorm.DB, wg *sync.WaitGroup, userDAO *dao.UserDAO, apiKeyDAO *dao.ApiKeyDAO, tokenDAO *dao.TokenDAO, usageDAO *dao.UsageDAO, dailyUsageDAO *dao.DailyUsageDAO) *Proxy {
	// Use a dedicated client: the previous code assigned a Transport to
	// http.DefaultClient, mutating shared global state for the whole process.
	client := &http.Client{}
	if proxyEnv := os.Getenv("LOCAL_PROXY"); proxyEnv != "" {
		if proxyUrl, err := url.Parse(proxyEnv); err == nil {
			client.Transport = &http.Transport{
				Proxy: http.ProxyURL(proxyUrl),
			}
		}
	}
	np := &Proxy{
		ctx:           ctx,
		cfg:           cfg,
		db:            db,
		wg:            wg,
		httpClient:    client,
		usageChan:     make(chan *model.Usage, cfg.UsageChanSize),
		userDAO:       userDAO,
		apiKeyDao:     apiKeyDAO,
		tokenDAO:      tokenDAO,
		usageDAO:      usageDAO,
		dailyUsageDAO: dailyUsageDAO,
	}
	go np.ProcessUsage()
	go np.ScheduleTask()
	return np
}
// HandleProxy dispatches an incoming request by URL path; currently only
// chat-completion requests are routed to a dedicated handler.
func (p *Proxy) HandleProxy(c *gin.Context) {
	switch c.Request.URL.Path {
	case "/v1/chat/completions":
		p.ChatHandler(c)
	}
}
// SendUsage enqueues a usage record for asynchronous persistence. When the
// queue is full the record is dropped, leaving a JSON trace in the log.
func (p *Proxy) SendUsage(usage *model.Usage) {
	select {
	case p.usageChan <- usage:
		return
	default:
	}
	log.Println("usage channel is full, skip processing")
	raw, _ := json.Marshal(usage)
	log.Println(string(raw))
	//TODO: send to a queue
}
// ProcessUsage starts cfg.UsageWorker goroutines that drain usageChan and
// persist each record via Do. A worker exits when the channel is closed, or,
// on ctx cancellation, after first draining whatever is still buffered so
// that already-accepted usage records are not lost.
func (p *Proxy) ProcessUsage() {
	for i := 0; i < p.cfg.UsageWorker; i++ {
		p.wg.Add(1)
		go func(i int) {
			defer p.wg.Done()
			for {
				select {
				case usage, ok := <-p.usageChan:
					if !ok {
						// channel closed: stop this worker
						return
					}
					err := p.Do(usage)
					if err != nil {
						log.Printf("process usage error: %v\n", err)
					}
				case <-p.ctx.Done():
					// close(s.usageChan)
					// for usage := range s.usageChan {
					// if err := s.Do(usage); err != nil {
					// fmt.Printf("[close event]process usage error: %v\n", err)
					// }
					// }
					// Context cancelled: drain remaining buffered records,
					// then exit via the empty `default` branch.
					for {
						select {
						case usage, ok := <-p.usageChan:
							if !ok {
								return
							}
							if err := p.Do(usage); err != nil {
								fmt.Printf("[close event]process usage error: %v\n", err)
							}
						default:
							fmt.Printf("usageChan is empty,usage worker %d done\n", i)
							return
						}
					}
				}
			}
		}(i)
	}
}
// Do persists one usage record atomically: it inserts the raw usage row,
// upserts the per-day aggregate, and charges the cost against the user's
// quota — all inside a single database transaction, so a failure in any
// step rolls back the others.
func (p *Proxy) Do(usage *model.Usage) error {
	err := p.db.Transaction(func(tx *gorm.DB) error {
		// 1. Insert the raw usage record.
		if err := tx.WithContext(p.ctx).Create(usage).Error; err != nil {
			return fmt.Errorf("create usage error: %w", err)
		}
		// 2. Update the daily aggregate (upsert). Date is truncated to
		// midnight in the usage's own location so one row accumulates per
		// calendar day.
		dailyUsage := model.DailyUsage{
			UserID: usage.UserID,
			TokenID: usage.TokenID,
			Capability: usage.Capability,
			Date: time.Date(usage.Date.Year(), usage.Date.Month(), usage.Date.Day(), 0, 0, 0, 0, usage.Date.Location()),
			Model: usage.Model,
			Stream: usage.Stream,
			PromptTokens: usage.PromptTokens,
			CompletionTokens: usage.CompletionTokens,
			TotalTokens: usage.TotalTokens,
			Cost: usage.Cost,
		}
		// Upsert via OnConflict: when a row already exists for the unique key
		// below, add the new token counts and cost onto the stored values.
		if err := tx.WithContext(p.ctx).Clauses(clause.OnConflict{
			Columns: []clause.Column{{Name: "user_id"}, {Name: "token_id"}, {Name: "capability"}, {Name: "date"}}, // unique key
			DoUpdates: clause.Assignments(map[string]interface{}{
				"prompt_tokens": gorm.Expr("prompt_tokens + ?", usage.PromptTokens),
				"completion_tokens": gorm.Expr("completion_tokens + ?", usage.CompletionTokens),
				"total_tokens": gorm.Expr("total_tokens + ?", usage.TotalTokens),
				"cost": gorm.Expr("cost + ?", usage.Cost),
			}),
		}).Create(&dailyUsage).Error; err != nil {
			return fmt.Errorf("upsert daily usage error: %w", err)
		}
		// 3. Charge the user: decrease the remaining quota and increase the
		// used quota by the cost of this request.
		if err := tx.WithContext(p.ctx).Model(&model.User{}).Where("id = ?", usage.UserID).Updates(map[string]interface{}{
			"quota": gorm.Expr("quota - ?", usage.Cost),
			"used_quota": gorm.Expr("used_quota + ?", usage.Cost),
		}).Error; err != nil {
			return fmt.Errorf("update user quota and used_quota error: %w", err)
		}
		return nil
	})
	return err
}
// SelectApiKey picks an API key that supports the given model and stores it
// in p.apikey. Keys explicitly listing the model are augmented with all keys
// of the provider matching the model-name prefix (gpt/gemini/claude); when
// more than one candidate remains, one is chosen uniformly at random.
func (p *Proxy) SelectApiKey(model string) error {
	apikeys, err := p.apiKeyDao.FindApiKeysBySupportModel(p.db, model)
	if err != nil {
		return err
	}
	if len(apikeys) == 0 {
		return errors.New("no available apikey")
	}
	// Add every key of the provider implied by the model-name prefix.
	prefixTypes := []struct{ prefix, apiType string }{
		{"gpt", "openai"},
		{"gemini", "gemini"},
		{"claude", "claude"},
	}
	for _, pt := range prefixTypes {
		if !strings.HasPrefix(model, pt.prefix) {
			continue
		}
		keys, err := p.apiKeyDao.FindKeys(map[string]any{"type = ?": pt.apiType})
		if err != nil {
			return err
		}
		apikeys = append(apikeys, keys...)
	}
	if len(apikeys) == 1 {
		p.apikey = &apikeys[0]
		return nil
	}
	// Fix: the previous code used rand.Intn(len-1), which could never pick
	// the last key in the slice.
	p.apikey = &apikeys[rand.Intn(len(apikeys))]
	return nil
}
// updateSupportModel refreshes the cached supported-model list of every
// openai/azure/claude key by querying the provider's models endpoint and
// writing the result back in a DB-type-specific column format.
func (p *Proxy) updateSupportModel() {
	// Pass the IN values as a slice: the previous code sent the single
	// string "openai,azure,claude", which a SQL IN clause matches against
	// no row.
	keys, err := p.apiKeyDao.FindKeys(map[string]interface{}{"type in ?": []string{"openai", "azure", "claude"}})
	if err != nil {
		return
	}
	for _, key := range keys {
		// Per-iteration error variable: the previous code reused the outer
		// err, so a stale error from an earlier key could leak into later
		// iterations.
		var supportModels []string
		var ferr error
		switch *key.ApiType {
		case "openai", "azure":
			supportModels, ferr = p.getOpenAISupportModels(key)
		case "claude":
			supportModels, ferr = p.getClaudeSupportModels(key)
		}
		if ferr != nil {
			log.Println(ferr)
			continue
		}
		if len(supportModels) == 0 {
			continue
		}
		switch p.cfg.DB_Type {
		case "sqlite":
			// sqlite has no array column; store the list as a JSON string.
			bytejson, _ := json.Marshal(supportModels)
			if err := p.db.Model(&model.ApiKey{}).Where("id = ?", key.ID).UpdateColumn("support_models", string(bytejson)).Error; err != nil {
				log.Println(err)
			}
		case "postgres":
			if err := p.db.Model(&model.ApiKey{}).Where("id = ?", key.ID).UpdateColumn("support_models", pq.StringArray(supportModels)).Error; err != nil {
				log.Println(err)
			}
		}
	}
}
// ScheduleTask refreshes the supported-model cache every TaskTimeInterval
// minutes until the proxy's context is cancelled.
func (p *Proxy) ScheduleTask() {
	interval := time.Duration(p.cfg.TaskTimeInterval) * time.Minute
	for {
		select {
		case <-p.ctx.Done():
			fmt.Println("schedule task done")
			return
		case <-time.After(interval):
			p.updateSupportModel()
		}
	}
}
// getOpenAISupportModels lists the model IDs available to an OpenAI or
// Azure API key. Azure deployment names like "-35-"/"-41-" are normalized
// to the OpenAI-style "-3.5-"/"-4.1-" spelling.
func (p *Proxy) getOpenAISupportModels(apikey model.ApiKey) ([]string, error) {
	openaiModelsUrl := "https://api.openai.com/v1/models"
	// https://learn.microsoft.com/zh-cn/rest/api/azureopenai/models/list?view=rest-azureopenai-2025-02-01-preview&tabs=HTTP
	azureModelsUrl := "/openai/deployments?api-version=2022-12-01"
	var req *http.Request
	if *apikey.ApiType == "azure" {
		req, _ = http.NewRequest("GET", *apikey.Endpoint+azureModelsUrl, nil)
		req.Header.Set("api-key", *apikey.ApiKey)
	} else {
		req, _ = http.NewRequest("GET", openaiModelsUrl, nil)
		req.Header.Set("Authorization", "Bearer "+*apikey.ApiKey)
	}
	req.Header.Set("Content-Type", "application/json")
	resp, err := p.httpClient.Do(req)
	if err != nil {
		return nil, err
	}
	defer resp.Body.Close()
	// Surface failures: the previous code silently treated a non-200 reply
	// (and a body read error) as "no models".
	if resp.StatusCode != http.StatusOK {
		return nil, fmt.Errorf("list models: unexpected status %s", resp.Status)
	}
	bytesbody, err := io.ReadAll(resp.Body)
	if err != nil {
		return nil, err
	}
	var supportModels []string
	for _, v := range gjson.GetBytes(bytesbody, "data.#.id").Array() {
		// Named "name" rather than "model" to avoid shadowing the model package.
		name := v.Str
		name = strings.Replace(name, "-35-", "-3.5-", -1)
		name = strings.Replace(name, "-41-", "-4.1-", -1)
		supportModels = append(supportModels, name)
	}
	return supportModels, nil
}
// getClaudeSupportModels lists the model IDs available to an Anthropic key.
func (p *Proxy) getClaudeSupportModels(apikey model.ApiKey) ([]string, error) {
	// https://docs.anthropic.com/en/api/models-list
	claudemodelsUrl := "https://api.anthropic.com/v1/models"
	req, _ := http.NewRequest("GET", claudemodelsUrl, nil)
	req.Header.Set("Content-Type", "application/json")
	req.Header.Set("x-api-key", *apikey.ApiKey)
	req.Header.Set("anthropic-version", "2023-06-01")
	resp, err := p.httpClient.Do(req)
	if err != nil {
		return nil, err
	}
	defer resp.Body.Close()
	// Surface failures: the previous code silently treated a non-200 reply
	// (and a body read error) as "no models".
	if resp.StatusCode != http.StatusOK {
		return nil, fmt.Errorf("list models: unexpected status %s", resp.Status)
	}
	bytesbody, err := io.ReadAll(resp.Body)
	if err != nil {
		return nil, err
	}
	var supportModels []string
	for _, v := range gjson.GetBytes(bytesbody, "data.#.id").Array() {
		supportModels = append(supportModels, v.Str)
	}
	return supportModels, nil
}