From 7fd82b43f45f54dd846a36e39c2606a2170f0f26 Mon Sep 17 00:00:00 2001 From: Sakurasan <26715255+Sakurasan@users.noreply.github.com> Date: Fri, 13 Sep 2024 21:32:10 +0800 Subject: [PATCH] refact & update new model --- README.md | 29 +- docker-compose.yml | 26 + docker/Dockerfile | 2 +- docker/docker-compose.yml | 20 +- opencat.go | 49 +- pkg/claude/chat.go | 67 ++- pkg/claude/claude.go | 4 +- pkg/error/errdata.go | 11 + pkg/openai/chat.go | 25 +- pkg/openai/whisper.go | 177 +++++++ pkg/team/key.go | 182 +++++++ pkg/team/me.go | 104 ++++ pkg/team/middleware.go | 59 +++ pkg/team/usage.go | 38 ++ pkg/team/user.go | 89 ++++ pkg/tokenizer/tokenizer.go | 9 +- pkg/vertexai/auth.go | 167 +++++++ router/chat.go | 28 +- router/router.go | 948 +++---------------------------------- store/cache.go | 8 + store/keydb.go | 27 ++ 21 files changed, 1094 insertions(+), 975 deletions(-) create mode 100644 docker-compose.yml create mode 100644 pkg/error/errdata.go create mode 100644 pkg/openai/whisper.go create mode 100644 pkg/team/key.go create mode 100644 pkg/team/me.go create mode 100644 pkg/team/middleware.go create mode 100644 pkg/team/usage.go create mode 100644 pkg/team/user.go create mode 100644 pkg/vertexai/auth.go diff --git a/README.md b/README.md index 12334ff..6acde4b 100644 --- a/README.md +++ b/README.md @@ -1,4 +1,7 @@ -# opencatd-open +# ~~opencatd-open~~ [OpenTeam](https://github.com/mirrors2/opencatd-open) + + 本项目即将更名,后续请关注 👉🏻 https://github.com/mirrors2/openteam + GitHub Workflow Status @@ -14,11 +17,11 @@ OpenCat for Team的开源实现 ## Extra Support: -| 任务 | 完成情况 | -| --- | --- | -|[Azure OpenAI](./doc/azure.md) | ✅| -|[Claude](./doc/azure.md) | ✅| -|[Gemini](./doc/gemini.md) | ✅| +| 🎯 | 🚧 |Extra Provider| +| --- | --- | --- | +|[OpenAI](./doc/azure.md) | ✅|Azure, Github Marketplace| +|[Claude](./doc/azure.md) | ✅|VertexAI| +|[Gemini](./doc/gemini.md) | ✅|| | ... | ... | @@ -80,6 +83,18 @@ wget https://github.com/mirrors2/opencatd-open/raw/main/docker/docker-compose.ym pandora for team - [pandora for team](./doc/pandora.md) + +如何自定义HOST地址? (仅OpenAI) + - 需修改环境变量,优先级递增 + - Cloudflare AI Gateway地址 `AIGateWay_Endpoint=https://gateway.ai.cloudflare.com/v1/123456789/xxxx/openai/chat/completions` + - 自定义的endpoint `$CUSTOM_ENDPOINT=true && $OpenAI_Endpoint=https://your.domain/v1/chat/completions` + +设置主页跳转地址? + - 修改环境变量 `CUSTOM_REDIRECT=https://your.domain` + +## 赞助 +[![Buy Me A Coffee](https://img.shields.io/badge/Buy%20Me%20A%20Coffee-FFDD55?style=flat-square&logo=buy-me-a-coffee&logoColor=black)](https://www.buymeacoffee.com/littlecjun) + # License -[GNU General Public License v3.0](License) \ No newline at end of file +[![GitHub License](https://img.shields.io/github/license/mirrors2/opencatd-open.svg?logo=github&style=flat-square)](https://github.com/mirrors2/opencatd-open/blob/main/License) diff --git a/docker-compose.yml b/docker-compose.yml new file mode 100644 index 0000000..36ade56 --- /dev/null +++ b/docker-compose.yml @@ -0,0 +1,26 @@ +# Email: admin@example.com +# Password: changeme +version: '3' + +services: + npm: + image: jc21/nginx-proxy-manager + network_mode: host + ports: + - '80:80' + - '81:81' + - '443:443' + volumes: + - $PWD/data:/data + - $PWD/www:/var/www + - $PWD/letsencrypt:/etc/letsencrypt + environment: + - "TZ=Asia/Shanghai" # set timezone, default UTC + - "PUID=1000" # set group id, default 0 (root) + - "PGID=1000" + + # certbot: + # image: certbot/certbot + # volumes: + # - $PWD/data/certbot/conf:/etc/letsencrypt + # - $PWD/data/certbot/www:/var/www/certbot diff --git a/docker/Dockerfile b/docker/Dockerfile index f11aa0d..968a3d4 100644 --- a/docker/Dockerfile +++ b/docker/Dockerfile @@ -1,4 +1,4 @@ -FROM node:18.12.1-alpine3.16 AS frontend +FROM node:20-alpine AS frontend WORKDIR /frontend-build COPY ./web/ . RUN npm install && npm run build && rm -rf node_modules diff --git a/docker/docker-compose.yml b/docker/docker-compose.yml index 154e9fe..29bf135 100644 --- a/docker/docker-compose.yml +++ b/docker/docker-compose.yml @@ -1,10 +1,24 @@ version: '3.7' -services: +services: opencatd: image: mirrors2/opencatd-open - container_name: opencatd-open + container_name: opencatd-open restart: unless-stopped + #network_mode: host ports: - 80:80 volumes: - - /etc/opencatd:/app/db \ No newline at end of file + - $PWD/db:/app/db + logging: + # driver: "json-file" + options: + max-size: 10m + max-file: 3 + # environment: + # Vertex: | + # { + # "type": "service_account", + # "universe_domain": "googleapis.com" + # } + + \ No newline at end of file diff --git a/opencat.go b/opencat.go index 6dee33a..a3ebfe5 100644 --- a/opencat.go +++ b/opencat.go @@ -8,6 +8,7 @@ import ( "io/fs" "log" "net/http" + "opencatd-open/pkg/team" "opencatd-open/router" "opencatd-open/store" "os" @@ -143,41 +144,29 @@ func main() { r := gin.Default() group := r.Group("/1") { - group.Use(router.AuthMiddleware()) + group.Use(team.AuthMiddleware()) // 获取当前用户信息 - group.GET("/me", router.HandleMe) + group.GET("/me", team.HandleMe) + group.GET("/me/usages", team.HandleMeUsage) - group.GET("/me/usages", router.HandleMeUsage) + group.GET("/keys", team.HandleKeys) // 获取所有Key + group.POST("/keys", team.HandleAddKey) // 添加Key + group.DELETE("/keys/:id", team.HandleDelKey) // 删除Key - // 获取所有Key - group.GET("/keys", router.HandleKeys) + group.GET("/users", team.HandleUsers) // 获取所有用户信息 + group.POST("/users", team.HandleAddUser) // 添加用户 + group.DELETE("/users/:id", team.HandleDelUser) // 删除用户 - // 获取所有用户信息 - group.GET("/users", router.HandleUsers) - - group.GET("/usages", router.HandleUsage) - - // 添加Key - group.POST("/keys", router.HandleAddKey) - - // 删除Key - group.DELETE("/keys/:id", router.HandleDelKey) - - // 添加用户 - group.POST("/users", router.HandleAddUser) - - // 删除用户 - group.DELETE("/users/:id", router.HandleDelUser) + group.GET("/usages", team.HandleUsage) // 重置用户Token - group.POST("/users/:id/reset", router.HandleResetUserToken) + group.POST("/users/:id/reset", team.HandleResetUserToken) } - // 初始化用户 - r.POST("/1/users/init", router.Handleinit) + r.POST("/1/users/init", team.Handleinit) - r.Any("/v1/*proxypath", router.HandleProy) + r.Any("/v1/*proxypath", router.HandleProxy) // r.POST("/v1/chat/completions", router.HandleProy) // r.GET("/v1/models", router.HandleProy) @@ -188,7 +177,15 @@ func main() { if err != nil { panic(err) } - r.GET("/", gin.WrapH(http.FileServer(http.FS(idxFS)))) + redirect := os.Getenv("CUSTOM_REDIRECT") + if redirect != "" { + r.GET("/", func(c *gin.Context) { + c.Redirect(http.StatusMovedPermanently, redirect) + }) + + } else { + r.GET("/", gin.WrapH(http.FileServer(http.FS(idxFS)))) + } assetsFS, err := fs.Sub(web, "dist/assets") if err != nil { panic(err) diff --git a/pkg/claude/chat.go b/pkg/claude/chat.go index 8987dc7..39afe48 100644 --- a/pkg/claude/chat.go +++ b/pkg/claude/chat.go @@ -10,8 +10,10 @@ import ( "io" "log" "net/http" + "opencatd-open/pkg/error" "opencatd-open/pkg/openai" "opencatd-open/pkg/tokenizer" + "opencatd-open/pkg/vertexai" "opencatd-open/store" "strings" @@ -27,14 +29,15 @@ func ChatTextCompletions(c *gin.Context, chatReq *openai.ChatCompletionRequest) } type ChatRequest struct { - Model string `json:"model,omitempty"` - Messages any `json:"messages,omitempty"` - MaxTokens int `json:"max_tokens,omitempty"` - Stream bool `json:"stream,omitempty"` - System string `json:"system,omitempty"` - TopK int `json:"top_k,omitempty"` - TopP float64 `json:"top_p,omitempty"` - Temperature float64 `json:"temperature,omitempty"` + Model string `json:"model,omitempty"` + Messages any `json:"messages,omitempty"` + MaxTokens int `json:"max_tokens,omitempty"` + Stream bool `json:"stream,omitempty"` + System string `json:"system,omitempty"` + TopK int `json:"top_k,omitempty"` + TopP float64 `json:"top_p,omitempty"` + Temperature float64 `json:"temperature,omitempty"` + AnthropicVersion string `json:"anthropic_version,omitempty"` } func (c *ChatRequest) ByteJson() []byte { @@ -117,8 +120,12 @@ type ClaudeStreamResponse struct { } func ChatMessages(c *gin.Context, chatReq *openai.ChatCompletionRequest) { + var ( + req *http.Request + targetURL = ClaudeMessageEndpoint + ) - onekey, err := store.SelectKeyCache("claude") + apiKey, err := store.SelectKeyCache("claude") if err != nil { c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) return @@ -131,6 +138,10 @@ func ChatMessages(c *gin.Context, chatReq *openai.ChatCompletionRequest) { // claudReq.Temperature = chatReq.Temperature claudReq.TopP = chatReq.TopP claudReq.MaxTokens = 4096 + if apiKey.ApiType == "vertex" { + claudReq.AnthropicVersion = "vertex-2023-10-16" + claudReq.Model = "" + } var prompt string @@ -181,10 +192,40 @@ func ChatMessages(c *gin.Context, chatReq *openai.ChatCompletionRequest) { usagelog.PromptCount = tokenizer.NumTokensFromStr(prompt, chatReq.Model) - req, _ := http.NewRequest("POST", MessageEndpoint, bytes.NewReader(claudReq.ByteJson())) - req.Header.Set("x-api-key", onekey.Key) - req.Header.Set("anthropic-version", "2023-06-01") - req.Header.Set("Content-Type", "application/json") + if apiKey.ApiType == "vertex" { + var vertexSecret vertexai.VertexSecretKey + if err := json.Unmarshal([]byte(apiKey.ApiSecret), &vertexSecret); err != nil { + c.JSON(http.StatusInternalServerError, error.ErrorData(err.Error())) + return + } + + vcmodel, ok := vertexai.VertexClaudeModelMap[chatReq.Model] + if !ok { + c.JSON(http.StatusInternalServerError, error.ErrorData("Model not found")) + return + } + + // 获取gcloud token,临时放置在apiKey.Key中 + gcloudToken, err := vertexai.GcloudAuth(vertexSecret.ClientEmail, vertexSecret.PrivateKey) + if err != nil { + c.JSON(http.StatusInternalServerError, error.ErrorData(err.Error())) + return + } + + // 拼接vertex的请求地址 + targetURL = fmt.Sprintf("https://%s-aiplatform.googleapis.com/v1/projects/%s/locations/%s/publishers/anthropic/models/%s:streamRawPredict", vcmodel.Region, vertexSecret.ProjectID, vcmodel.Region, vcmodel.VertexName) + + req, _ = http.NewRequest("POST", targetURL, bytes.NewReader(claudReq.ByteJson())) + req.Header.Set("Authorization", "Bearer "+gcloudToken) + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Accept", "text/event-stream") + req.Header.Set("Accept-Encoding", "identity") + } else { + req, _ = http.NewRequest("POST", targetURL, bytes.NewReader(claudReq.ByteJson())) + req.Header.Set("x-api-key", apiKey.Key) + req.Header.Set("anthropic-version", "2023-06-01") + req.Header.Set("Content-Type", "application/json") + } client := http.DefaultClient rsp, err := client.Do(req) diff --git a/pkg/claude/claude.go b/pkg/claude/claude.go index 0964f46..3c5f110 100644 --- a/pkg/claude/claude.go +++ b/pkg/claude/claude.go @@ -68,8 +68,8 @@ import ( ) var ( - ClaudeUrl = "https://api.anthropic.com/v1/complete" - MessageEndpoint = "https://api.anthropic.com/v1/messages" + ClaudeUrl = "https://api.anthropic.com/v1/complete" + ClaudeMessageEndpoint = "https://api.anthropic.com/v1/messages" ) type MessageModule struct { diff --git a/pkg/error/errdata.go b/pkg/error/errdata.go new file mode 100644 index 0000000..7a5b146 --- /dev/null +++ b/pkg/error/errdata.go @@ -0,0 +1,11 @@ +package error + +import "github.com/gin-gonic/gin" + +func ErrorData(message string) gin.H { + return gin.H{ + "error": gin.H{ + "message": message, + }, + } +} diff --git a/pkg/openai/chat.go b/pkg/openai/chat.go index b7270b7..f6c9dce 100644 --- a/pkg/openai/chat.go +++ b/pkg/openai/chat.go @@ -17,8 +17,10 @@ import ( ) const ( - AzureApiVersion = "2024-02-01" - OpenAI_Endpoint = "https://api.openai.com/v1/chat/completions" + AzureApiVersion = "2024-02-01" + BaseHost = "api.openai.com" + OpenAI_Endpoint = "https://api.openai.com/v1/chat/completions" + Github_Marketplace = "https://models.inference.ai.azure.com/chat/completions" ) var ( @@ -204,6 +206,10 @@ func ChatProxy(c *gin.Context, chatReq *ChatCompletionRequest) { var req *http.Request switch onekey.ApiType { + case "github": + req, err = http.NewRequest(c.Request.Method, Github_Marketplace, bytes.NewReader(chatReq.ToByteJson())) + req.Header = c.Request.Header + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", onekey.Key)) case "azure": var buildurl string if onekey.EndPoint != "" { @@ -215,15 +221,18 @@ func ChatProxy(c *gin.Context, chatReq *ChatCompletionRequest) { req.Header = c.Request.Header req.Header.Set("api-key", onekey.Key) default: + req, err = http.NewRequest(c.Request.Method, OpenAI_Endpoint, bytes.NewReader(chatReq.ToByteJson())) if onekey.EndPoint != "" { // 优先key的endpoint req, err = http.NewRequest(c.Request.Method, onekey.EndPoint+c.Request.RequestURI, bytes.NewReader(chatReq.ToByteJson())) - } else { - if BaseURL != "" { // 其次BaseURL - req, err = http.NewRequest(c.Request.Method, BaseURL+c.Request.RequestURI, bytes.NewReader(chatReq.ToByteJson())) - } else { // 最后是gateway的endpoint - req, err = http.NewRequest(c.Request.Method, AIGateWay_Endpoint, bytes.NewReader(chatReq.ToByteJson())) - } } + if AIGateWay_Endpoint != "" { // cloudflare gateway的endpoint + req, err = http.NewRequest(c.Request.Method, AIGateWay_Endpoint, bytes.NewReader(chatReq.ToByteJson())) + } + customEndpoint := os.Getenv("CUSTOM_ENDPOINT") // 最后是用户自定义的endpoint CUSTOM_ENDPOINT=true OpenAI_Endpoint + if customEndpoint == "true" && OpenAI_Endpoint != "" { + req, err = http.NewRequest(c.Request.Method, BaseURL, bytes.NewReader(chatReq.ToByteJson())) + } + req.Header = c.Request.Header req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", onekey.Key)) } diff --git a/pkg/openai/whisper.go b/pkg/openai/whisper.go new file mode 100644 index 0000000..ff8fd18 --- /dev/null +++ b/pkg/openai/whisper.go @@ -0,0 +1,177 @@ +package openai + +import ( + "bytes" + "context" + "errors" + "fmt" + "io" + "log" + "mime/multipart" + "net/http" + "net/http/httputil" + "net/url" + "opencatd-open/pkg/tokenizer" + "opencatd-open/store" + "path/filepath" + "time" + + "github.com/faiface/beep" + "github.com/faiface/beep/mp3" + "github.com/faiface/beep/wav" + "github.com/gin-gonic/gin" + "gopkg.in/vansante/go-ffprobe.v2" +) + +func WhisperProxy(c *gin.Context) { + var chatlog store.Tokens + + byteBody, _ := io.ReadAll(c.Request.Body) + c.Request.Body = io.NopCloser(bytes.NewBuffer(byteBody)) + + model, _ := c.GetPostForm("model") + + key, err := store.SelectKeyCache("openai") + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{ + "error": gin.H{ + "message": err.Error(), + }, + }) + return + } + + chatlog.Model = model + + token, _ := c.Get("localuser") + + lu, err := store.GetUserByToken(token.(string)) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{ + "error": gin.H{ + "message": err.Error(), + }, + }) + return + } + chatlog.UserID = int(lu.ID) + + if err := ParseWhisperRequestTokens(c, &chatlog, byteBody); err != nil { + c.JSON(http.StatusInternalServerError, gin.H{ + "error": gin.H{ + "message": err.Error(), + }, + }) + return + } + if key.EndPoint == "" { + key.EndPoint = "https://api.openai.com" + } + targetUrl, _ := url.ParseRequestURI(key.EndPoint + c.Request.URL.String()) + log.Println(targetUrl) + proxy := httputil.NewSingleHostReverseProxy(targetUrl) + proxy.Director = func(req *http.Request) { + req.Host = targetUrl.Host + req.URL.Scheme = targetUrl.Scheme + req.URL.Host = targetUrl.Host + + req.Header.Set("Authorization", "Bearer "+key.Key) + } + + proxy.ModifyResponse = func(resp *http.Response) error { + if resp.StatusCode != http.StatusOK { + return nil + } + chatlog.TotalTokens = chatlog.PromptCount + chatlog.CompletionCount + chatlog.Cost = fmt.Sprintf("%.6f", tokenizer.Cost(chatlog.Model, chatlog.PromptCount, chatlog.CompletionCount)) + if err := store.Record(&chatlog); err != nil { + log.Println(err) + } + if err := store.SumDaily(chatlog.UserID); err != nil { + log.Println(err) + } + return nil + } + proxy.ServeHTTP(c.Writer, c.Request) +} + +func probe(fileReader io.Reader) (time.Duration, error) { + ctx, cancelFn := context.WithTimeout(context.Background(), 5*time.Second) + defer cancelFn() + + data, err := ffprobe.ProbeReader(ctx, fileReader) + if err != nil { + return 0, err + } + + duration := data.Format.DurationSeconds + pduration, err := time.ParseDuration(fmt.Sprintf("%fs", duration)) + if err != nil { + return 0, fmt.Errorf("Error parsing duration: %s", err) + } + return pduration, nil +} + +func getAudioDuration(file *multipart.FileHeader) (time.Duration, error) { + var ( + streamer beep.StreamSeekCloser + format beep.Format + err error + ) + + f, err := file.Open() + defer f.Close() + + // Get the file extension to determine the audio file type + fileType := filepath.Ext(file.Filename) + + switch fileType { + case ".mp3": + streamer, format, err = mp3.Decode(f) + case ".wav": + streamer, format, err = wav.Decode(f) + case ".m4a": + duration, err := probe(f) + if err != nil { + return 0, err + } + return duration, nil + default: + return 0, errors.New("unsupported audio file format") + } + + if err != nil { + return 0, err + } + defer streamer.Close() + + // Calculate the audio file's duration. + numSamples := streamer.Len() + sampleRate := format.SampleRate + duration := time.Duration(numSamples) * time.Second / time.Duration(sampleRate) + + return duration, nil +} + +func ParseWhisperRequestTokens(c *gin.Context, usage *store.Tokens, byteBody []byte) error { + file, _ := c.FormFile("file") + model, _ := c.GetPostForm("model") + usage.Model = model + + if file != nil { + duration, err := getAudioDuration(file) + if err != nil { + return fmt.Errorf("Error getting audio duration:%s", err) + } + + if duration > 5*time.Minute { + return fmt.Errorf("Audio duration exceeds 5 minutes") + } + // 计算时长,四舍五入到最接近的秒数 + usage.PromptCount = int(duration.Round(time.Second).Seconds()) + } + + c.Request.Body = io.NopCloser(bytes.NewBuffer(byteBody)) + + return nil +} diff --git a/pkg/team/key.go b/pkg/team/key.go new file mode 100644 index 0000000..802f10b --- /dev/null +++ b/pkg/team/key.go @@ -0,0 +1,182 @@ +package team + +import ( + "net/http" + "opencatd-open/pkg/azureopenai" + "opencatd-open/store" + "strings" + + "github.com/Sakurasan/to" + "github.com/gin-gonic/gin" +) + +type Key struct { + ID int `json:"id,omitempty"` + Key string `json:"key,omitempty"` + Name string `json:"name,omitempty"` + ApiType string `json:"api_type,omitempty"` + Endpoint string `json:"endpoint,omitempty"` + UpdatedAt string `json:"updatedAt,omitempty"` + CreatedAt string `json:"createdAt,omitempty"` +} + +func HandleKeys(c *gin.Context) { + keys, err := store.GetAllKeys() + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "error": err.Error(), + }) + } + + c.JSON(http.StatusOK, keys) +} + +func HandleAddKey(c *gin.Context) { + var body Key + if err := c.BindJSON(&body); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": gin.H{ + "message": err.Error(), + }}) + return + } + body.Name = strings.ToLower(strings.TrimSpace(body.Name)) + body.Key = strings.TrimSpace(body.Key) + if strings.HasPrefix(body.Name, "azure.") { + keynames := strings.Split(body.Name, ".") + if len(keynames) < 2 { + c.JSON(http.StatusBadRequest, gin.H{"error": gin.H{ + "message": "Invalid Key Name", + }}) + return + } + k := &store.Key{ + ApiType: "azure", + Name: body.Name, + Key: body.Key, + ResourceNmae: keynames[1], + EndPoint: body.Endpoint, + } + if err := store.CreateKey(k); err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": gin.H{ + "message": err.Error(), + }}) + return + } + } else if strings.HasPrefix(body.Name, "claude.") { + keynames := strings.Split(body.Name, ".") + if len(keynames) < 2 { + c.JSON(http.StatusBadRequest, gin.H{"error": gin.H{ + "message": "Invalid Key Name", + }}) + return + } + if body.Endpoint == "" { + body.Endpoint = "https://api.anthropic.com" + } + k := &store.Key{ + // ApiType: "anthropic", + ApiType: "claude", + Name: body.Name, + Key: body.Key, + ResourceNmae: keynames[1], + EndPoint: body.Endpoint, + } + if err := store.CreateKey(k); err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": gin.H{ + "message": err.Error(), + }}) + return + } + } else if strings.HasPrefix(body.Name, "google.") { + keynames := strings.Split(body.Name, ".") + if len(keynames) < 2 { + c.JSON(http.StatusBadRequest, gin.H{"error": gin.H{ + "message": "Invalid Key Name", + }}) + return + } + + k := &store.Key{ + // ApiType: "anthropic", + ApiType: "google", + Name: body.Name, + Key: body.Key, + ResourceNmae: keynames[1], + EndPoint: body.Endpoint, + } + if err := store.CreateKey(k); err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": gin.H{ + "message": err.Error(), + }}) + return + } + } else if strings.HasPrefix(body.Name, "github.") { + keynames := strings.Split(body.Name, ".") + if len(keynames) < 2 { + c.JSON(http.StatusBadRequest, gin.H{"error": gin.H{ + "message": "Invalid Key Name", + }}) + return + } + + k := &store.Key{ + ApiType: "github", + Name: body.Name, + Key: body.Key, + ResourceNmae: keynames[1], + EndPoint: body.Endpoint, + } + if err := store.CreateKey(k); err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": gin.H{ + "message": err.Error(), + }}) + return + } + } else { + if body.ApiType == "" { + if err := store.AddKey("openai", body.Key, body.Name); err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": gin.H{ + "message": err.Error(), + }}) + return + } + } else { + k := &store.Key{ + ApiType: body.ApiType, + Name: body.Name, + Key: body.Key, + ResourceNmae: azureopenai.GetResourceName(body.Endpoint), + EndPoint: body.Endpoint, + } + if err := store.CreateKey(k); err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": gin.H{ + "message": err.Error(), + }}) + return + } + } + + } + + k, err := store.GetKeyrByName(body.Name) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": gin.H{ + "message": err.Error(), + }}) + return + } + c.JSON(http.StatusOK, k) +} + +func HandleDelKey(c *gin.Context) { + id := to.Int(c.Param("id")) + if id < 1 { + c.JSON(http.StatusOK, gin.H{"error": "invalid key id"}) + return + } + if err := store.DeleteKey(uint(id)); err != nil { + c.JSON(http.StatusOK, gin.H{"error": "invalid key id"}) + return + } + c.JSON(http.StatusOK, gin.H{"message": "ok"}) +} diff --git a/pkg/team/me.go b/pkg/team/me.go new file mode 100644 index 0000000..146ce7c --- /dev/null +++ b/pkg/team/me.go @@ -0,0 +1,104 @@ +package team + +import ( + "errors" + "net/http" + "opencatd-open/store" + "time" + + "github.com/Sakurasan/to" + "github.com/gin-gonic/gin" + "github.com/google/uuid" + "gorm.io/gorm" +) + +func Handleinit(c *gin.Context) { + user, err := store.GetUserByID(1) + if err != nil { + if errors.Is(err, gorm.ErrRecordNotFound) { + u := store.User{Name: "root", Token: uuid.NewString()} + u.ID = 1 + if err := store.CreateUser(&u); err != nil { + c.JSON(http.StatusForbidden, gin.H{ + "error": err.Error(), + }) + return + } else { + rootToken = u.Token + resJSON := User{ + false, + int(u.ID), + u.UpdatedAt.Format(time.RFC3339), + u.Name, + u.Token, + u.CreatedAt.Format(time.RFC3339), + } + c.JSON(http.StatusOK, resJSON) + return + } + } + c.JSON(http.StatusOK, gin.H{ + "error": err.Error(), + }) + return + } + if user.ID == uint(1) { + c.JSON(http.StatusForbidden, gin.H{ + "error": "super user already exists, use cli to reset password", + }) + } +} + +func HandleMe(c *gin.Context) { + token := c.GetHeader("Authorization") + u, err := store.GetUserByToken(token[7:]) + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "error": err.Error(), + }) + } + + resJSON := User{ + false, + int(u.ID), + u.UpdatedAt.Format(time.RFC3339), + u.Name, + u.Token, + u.CreatedAt.Format(time.RFC3339), + } + c.JSON(http.StatusOK, resJSON) +} + +func HandleMeUsage(c *gin.Context) { + token := c.GetHeader("Authorization") + fromStr := c.Query("from") + toStr := c.Query("to") + getMonthStartAndEnd := func() (start, end string) { + loc, _ := time.LoadLocation("Local") + now := time.Now().In(loc) + + year, month, _ := now.Date() + + startOfMonth := time.Date(year, month, 1, 0, 0, 0, 0, loc) + endOfMonth := startOfMonth.AddDate(0, 1, 0) + + start = startOfMonth.Format("2006-01-02") + end = endOfMonth.Format("2006-01-02") + return + } + if fromStr == "" || toStr == "" { + fromStr, toStr = getMonthStartAndEnd() + } + user, err := store.GetUserByToken(token) + if err != nil { + c.AbortWithError(http.StatusForbidden, err) + return + } + usage, err := store.QueryUserUsage(to.String(user.ID), fromStr, toStr) + if err != nil { + c.AbortWithError(http.StatusForbidden, err) + return + } + + c.JSON(200, usage) +} diff --git a/pkg/team/middleware.go b/pkg/team/middleware.go new file mode 100644 index 0000000..ab01e79 --- /dev/null +++ b/pkg/team/middleware.go @@ -0,0 +1,59 @@ +package team + +import ( + "log" + "net/http" + "opencatd-open/store" + "strings" + + "github.com/gin-gonic/gin" +) + +var ( + rootToken string +) + +func AuthMiddleware() gin.HandlerFunc { + return func(c *gin.Context) { + if rootToken == "" { + u, err := store.GetUserByID(uint(1)) + if err != nil { + c.JSON(http.StatusUnauthorized, gin.H{"error": "Unauthorized"}) + c.Abort() + return + } + rootToken = u.Token + } + token := c.GetHeader("Authorization") + if token == "" || token[:7] != "Bearer " { + c.JSON(http.StatusUnauthorized, gin.H{"error": "Unauthorized"}) + c.Abort() + return + } + if store.IsExistAuthCache(token[7:]) { + if strings.HasPrefix(c.Request.URL.Path, "/1/me") { + c.Next() + return + } + } + if token[7:] != rootToken { + u, err := store.GetUserByID(uint(1)) + if err != nil { + log.Println(err) + c.JSON(http.StatusUnauthorized, gin.H{"error": "Unauthorized"}) + c.Abort() + return + } + if token[:7] != u.Token { + c.JSON(http.StatusUnauthorized, gin.H{"error": "Unauthorized"}) + c.Abort() + return + } + rootToken = u.Token + store.LoadAuthCache() + } + // 可以在这里对 token 进行验证并检查权限 + + c.Next() + } +} diff --git a/pkg/team/usage.go b/pkg/team/usage.go new file mode 100644 index 0000000..5d9763e --- /dev/null +++ b/pkg/team/usage.go @@ -0,0 +1,38 @@ +package team + +import ( + "net/http" + "opencatd-open/store" + "time" + + "github.com/gin-gonic/gin" +) + +func HandleUsage(c *gin.Context) { + fromStr := c.Query("from") + toStr := c.Query("to") + getMonthStartAndEnd := func() (start, end string) { + loc, _ := time.LoadLocation("Local") + now := time.Now().In(loc) + + year, month, _ := now.Date() + + startOfMonth := time.Date(year, month, 1, 0, 0, 0, 0, loc) + endOfMonth := startOfMonth.AddDate(0, 1, 0) + + start = startOfMonth.Format("2006-01-02") + end = endOfMonth.Format("2006-01-02") + return + } + if fromStr == "" || toStr == "" { + fromStr, toStr = getMonthStartAndEnd() + } + + usage, err := store.QueryUsage(fromStr, toStr) + if err != nil { + c.JSON(http.StatusForbidden, gin.H{"error": err.Error()}) + return + } + + c.JSON(200, usage) +} diff --git a/pkg/team/user.go b/pkg/team/user.go new file mode 100644 index 0000000..b848be1 --- /dev/null +++ b/pkg/team/user.go @@ -0,0 +1,89 @@ +package team + +import ( + "net/http" + "opencatd-open/store" + + "github.com/Sakurasan/to" + "github.com/gin-gonic/gin" + "github.com/google/uuid" +) + +type User struct { + IsDelete bool `json:"IsDelete,omitempty"` + ID int `json:"id,omitempty"` + UpdatedAt string `json:"updatedAt,omitempty"` + Name string `json:"name,omitempty"` + Token string `json:"token,omitempty"` + CreatedAt string `json:"createdAt,omitempty"` +} + +func HandleUsers(c *gin.Context) { + users, err := store.GetAllUsers() + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "error": err.Error(), + }) + } + + c.JSON(http.StatusOK, users) +} + +func HandleAddUser(c *gin.Context) { + var body User + if err := c.BindJSON(&body); err != nil { + c.JSON(http.StatusOK, gin.H{"error": err.Error()}) + return + } + if len(body.Name) == 0 { + c.JSON(http.StatusOK, gin.H{"error": "invalid user name"}) + return + } + + if err := store.AddUser(body.Name, uuid.NewString()); err != nil { + c.JSON(http.StatusOK, gin.H{"error": err.Error()}) + return + } + u, err := store.GetUserByName(body.Name) + if err != nil { + c.JSON(http.StatusOK, gin.H{"error": err.Error()}) + return + } + c.JSON(http.StatusOK, u) +} + +func HandleDelUser(c *gin.Context) { + id := to.Int(c.Param("id")) + if id <= 1 { + c.JSON(http.StatusOK, gin.H{"error": "invalid user id"}) + return + } + if err := store.DeleteUser(uint(id)); err != nil { + c.JSON(http.StatusOK, gin.H{"error": err.Error()}) + return + } + + c.JSON(http.StatusOK, gin.H{"message": "ok"}) +} + +func HandleResetUserToken(c *gin.Context) { + id := to.Int(c.Param("id")) + newtoken := c.Query("token") + if newtoken == "" { + newtoken = uuid.NewString() + } + + if err := store.UpdateUser(uint(id), newtoken); err != nil { + c.JSON(http.StatusForbidden, gin.H{"error": err.Error()}) + return + } + u, err := store.GetUserByID(uint(id)) + if err != nil { + c.JSON(http.StatusForbidden, gin.H{"error": err.Error()}) + return + } + if u.ID == uint(1) { + rootToken = u.Token + } + c.JSON(http.StatusOK, u) +} diff --git a/pkg/tokenizer/tokenizer.go b/pkg/tokenizer/tokenizer.go index 3f8be85..4176396 100644 --- a/pkg/tokenizer/tokenizer.go +++ b/pkg/tokenizer/tokenizer.go @@ -97,8 +97,14 @@ func Cost(model string, promptCount, completionCount int) float64 { cost = 0.01*float64(prompt/1000) + 0.03*float64(completion/1000) case "gpt-4o", "gpt-4o-2024-05-13": cost = 0.005*float64(prompt/1000) + 0.015*float64(completion/1000) + case "gpt-4o-2024-08-06": + cost = 0.0025*float64(prompt/1000) + 0.010*float64(completion/1000) case "gpt-4o-mini", "gpt-4o-mini-2024-07-18": cost = 0.00015*float64(prompt/1000) + 0.0006*float64(completion/1000) + case "o1-preview", "o1-preview-2024-09-12": + cost = 0.015*float64(prompt/1000) + 0.06*float64(completion/1000) + case "o1-mini", "o1-mini-2024-09-12": + cost = 0.003*float64(prompt/1000) + 0.012*float64(completion/1000) case "whisper-1": // 0.006$/min cost = 0.006 * float64(prompt+completion) / 60 @@ -149,7 +155,8 @@ func Cost(model string, promptCount, completionCount int) float64 { cost = (0.003/1000)*float64(prompt) + (0.015/1000)*float64(completion) case "claude-3-opus-20240229": cost = (0.015/1000)*float64(prompt) + (0.075/1000)*float64(completion) - + case "claude-3-5-sonnet", "claude-3-5-sonnet-20240620": + cost = (0.003/1000)*float64(prompt) + (0.015/1000)*float64(completion) // google // https://ai.google.dev/pricing?hl=zh-cn case "gemini-pro": diff --git a/pkg/vertexai/auth.go b/pkg/vertexai/auth.go new file mode 100644 index 0000000..a7a6967 --- /dev/null +++ b/pkg/vertexai/auth.go @@ -0,0 +1,167 @@ +/* +https://docs.anthropic.com/zh-CN/api/claude-on-vertex-ai + +MODEL_ID=claude-3-5-sonnet@20240620 +REGION=us-east5 +PROJECT_ID=MY_PROJECT_ID + +curl \ +-X POST \ +-H "Authorization: Bearer $(gcloud auth print-access-token)" \ +-H "Content-Type: application/json" \ +https://$LOCATION-aiplatform.googleapis.com/v1/projects/${PROJECT_ID}/locations/${LOCATION}/publishers/anthropic/models/${MODEL_ID}:streamRawPredict \ +-d '{ + "anthropic_version": "vertex-2023-10-16", + "messages": [{ + "role": "user", + "content": "介绍一下你自己" + }], + "stream": true, + "max_tokens": 4096 +}' +*/ + +package vertexai + +import ( + "crypto/rsa" + "crypto/x509" + "encoding/json" + "encoding/pem" + "fmt" + "net/http" + "net/url" + "time" + + "github.com/golang-jwt/jwt" +) + +// json文件存储在ApiKey.ApiSecret中 +type VertexSecretKey struct { + Type string `json:"type"` + ProjectID string `json:"project_id"` + PrivateKeyID string `json:"private_key_id"` + PrivateKey string `json:"private_key"` + ClientEmail string `json:"client_email"` + ClientID string `json:"client_id"` + AuthURI string `json:"auth_uri"` + TokenURI string `json:"token_uri"` + AuthProviderX509CertURL string `json:"auth_provider_x509_cert_url"` + ClientX509CertURL string `json:"client_x509_cert_url"` + UniverseDomain string `json:"universe_domain"` +} + +type VertexClaudeModel struct { + VertexName string + Region string +} + +var VertexClaudeModelMap = map[string]VertexClaudeModel{ + "claude-3-opus": { + VertexName: "claude-3-opus@20240229", + Region: "us-east5", + }, + "claude-3-sonnet": { + VertexName: "claude-3-sonnet@20240229", + Region: "us-central1", + // Region: "asia-southeast1", + }, + "claude-3-haiku": { + VertexName: "claude-3-haiku@20240307", + Region: "us-central1", + // Region: "europe-west4", + }, + "claude-3-opus-20240229": { + VertexName: "claude-3-opus@20240229", + Region: "us-east5", + }, + "claude-3-sonnet-20240229": { + VertexName: "claude-3-sonnet@20240229", + Region: "us-central1", + // Region: "asia-southeast1", + }, + "claude-3-haiku-20240307": { + VertexName: "claude-3-haiku@20240307", + Region: "us-central1", + // Region: "europe-west4", + }, + "claude-3-5-sonnet": { + VertexName: "claude-3-5-sonnet@20240620", + Region: "us-east5", + // Region: "europe-west1", + }, + "claude-3-5-sonnet-20240620": { + VertexName: "claude-3-5-sonnet@20240620", + Region: "us-east5", + // Region: "europe-west1", + }, +} + +func createSignedJWT(email, privateKeyPEM string) (string, error) { + block, _ := pem.Decode([]byte(privateKeyPEM)) + if block == nil { + return "", fmt.Errorf("failed to parse PEM block containing the private key") + } + + privateKey, err := x509.ParsePKCS8PrivateKey(block.Bytes) + if err != nil { + return "", err + } + + rsaKey, ok := privateKey.(*rsa.PrivateKey) + if !ok { + return "", fmt.Errorf("not an RSA private key") + } + + now := time.Now() + claims := jwt.MapClaims{ + "iss": email, + "aud": "https://www.googleapis.com/oauth2/v4/token", + "iat": now.Unix(), + "exp": now.Add(10 * time.Minute).Unix(), + "scope": "https://www.googleapis.com/auth/cloud-platform", + } + + token := jwt.NewWithClaims(jwt.SigningMethodRS256, claims) + return token.SignedString(rsaKey) +} + +func exchangeJwtForAccessToken(signedJWT string) (string, error) { + authURL := "https://www.googleapis.com/oauth2/v4/token" + data := url.Values{} + data.Set("grant_type", "urn:ietf:params:oauth:grant-type:jwt-bearer") + data.Set("assertion", signedJWT) + + resp, err := http.PostForm(authURL, data) + if err != nil { + return "", err + } + defer resp.Body.Close() + + var result map[string]interface{} + if err := json.NewDecoder(resp.Body).Decode(&result); err != nil { + return "", err + } + + accessToken, ok := result["access_token"].(string) + if !ok { + return "", fmt.Errorf("access token not found in response") + } + + return accessToken, nil +} + +// 获取gcloud auth token +func GcloudAuth(ClientEmail, PrivateKey string) (string, error) { + signedJWT, err := createSignedJWT(ClientEmail, PrivateKey) + if err != nil { + return "", err + } + + token, err := exchangeJwtForAccessToken(signedJWT) + if err != nil { + return "", fmt.Errorf("Invalid jwt token: %v\n", err) + } + + return token, nil +} diff --git a/router/chat.go b/router/chat.go index 509f6d9..5292a9f 100644 --- a/router/chat.go +++ b/router/chat.go @@ -18,33 +18,7 @@ func ChatHandler(c *gin.Context) { return } - // if chatreq.Messages[len(chatreq.Messages)-1].Role == "user" { - // result, err := search.BingSearch(search.SearchParams{Query: string(chatreq.Messages[len(chatreq.Messages)-1].Content)}) - // if err == nil { - // var msgs []openai.ChatCompletionMessage - // for i, m := range chatreq.Messages { - // var buf bytes.Buffer - // buf.WriteString("根据我提问的语言回答我,我将提供一些从搜索引擎获取的信息(以websearch:开头)。你自行判断是否使用搜索引擎获取的内容。不要原封不动照抄,根据你自己的知识库提炼信息之后回答我\n\n") - // if m.Role == "system" { - // buf.Write(m.Content) - // msgs = append(msgs, openai.ChatCompletionMessage{Role: m.Role, Content: buf.Bytes()}) - // } else { - // msgs = append(msgs, openai.ChatCompletionMessage{Role: m.Role, Content: buf.Bytes()}) - // } - // if i == len(chatreq.Messages)-1 { - // m.Content = append(m.Content, json.RawMessage("\n\nwebsearch:")...) - // m.Content = append(m.Content, json.RawMessage(result.(string))...) - // msgs = append(msgs, openai.ChatCompletionMessage{Role: m.Role, Content: m.Content}) - // } else { - // msgs = append(msgs, openai.ChatCompletionMessage{Role: m.Role, Content: m.Content}) - // } - - // } - // chatreq.Messages = msgs - // } - // } - - if strings.HasPrefix(chatreq.Model, "gpt") { + if strings.HasPrefix(chatreq.Model, "gpt") || strings.HasPrefix(chatreq.Model, "o1-") { openai.ChatProxy(c, &chatreq) return } diff --git a/router/router.go b/router/router.go index 7e639e4..d1e250b 100644 --- a/router/router.go +++ b/router/router.go @@ -1,108 +1,66 @@ package router import ( - "bufio" - "bytes" - "context" "crypto/tls" - "encoding/json" - "errors" "fmt" - "io" "log" - "mime/multipart" "net" "net/http" "net/http/httputil" - "net/url" - "opencatd-open/pkg/azureopenai" "opencatd-open/pkg/claude" oai "opencatd-open/pkg/openai" - "opencatd-open/pkg/tokenizer" "opencatd-open/store" "os" - "path/filepath" - "strings" "time" - "github.com/Sakurasan/to" - "github.com/duke-git/lancet/v2/cryptor" - "github.com/faiface/beep" - "github.com/faiface/beep/mp3" - "github.com/faiface/beep/wav" "github.com/gin-gonic/gin" - "github.com/google/uuid" - "github.com/sashabaranov/go-openai" - "gopkg.in/vansante/go-ffprobe.v2" - "gorm.io/gorm" ) var ( - rootToken string baseUrl = "https://api.openai.com" GPT3Dot5Turbo = "gpt-3.5-turbo" GPT4 = "gpt-4" - client = getHttpClient() ) -type User struct { - IsDelete bool `json:"IsDelete,omitempty"` - ID int `json:"id,omitempty"` - UpdatedAt string `json:"updatedAt,omitempty"` - Name string `json:"name,omitempty"` - Token string `json:"token,omitempty"` - CreatedAt string `json:"createdAt,omitempty"` -} +// type ChatCompletionMessage struct { +// Role string `json:"role"` +// Content string `json:"content"` +// Name string `json:"name,omitempty"` +// } -type Key struct { - ID int `json:"id,omitempty"` - Key string `json:"key,omitempty"` - Name string `json:"name,omitempty"` - ApiType string `json:"api_type,omitempty"` - Endpoint string `json:"endpoint,omitempty"` - UpdatedAt string `json:"updatedAt,omitempty"` - CreatedAt string `json:"createdAt,omitempty"` -} +// type ChatCompletionRequest struct { +// Model string `json:"model"` +// Messages []ChatCompletionMessage `json:"messages"` +// MaxTokens int `json:"max_tokens,omitempty"` +// Temperature float32 `json:"temperature,omitempty"` +// TopP float32 `json:"top_p,omitempty"` +// N int `json:"n,omitempty"` +// Stream bool `json:"stream,omitempty"` +// Stop []string `json:"stop,omitempty"` +// PresencePenalty float32 `json:"presence_penalty,omitempty"` +// FrequencyPenalty float32 `json:"frequency_penalty,omitempty"` +// LogitBias map[string]int `json:"logit_bias,omitempty"` +// User string `json:"user,omitempty"` +// } -type ChatCompletionMessage struct { - Role string `json:"role"` - Content string `json:"content"` - Name string `json:"name,omitempty"` -} +// type ChatCompletionChoice struct { +// Index int `json:"index"` +// Message ChatCompletionMessage `json:"message"` +// FinishReason string `json:"finish_reason"` +// } -type ChatCompletionRequest struct { - Model string `json:"model"` - Messages []ChatCompletionMessage `json:"messages"` - MaxTokens int `json:"max_tokens,omitempty"` - Temperature float32 `json:"temperature,omitempty"` - TopP float32 `json:"top_p,omitempty"` - N int `json:"n,omitempty"` - Stream bool `json:"stream,omitempty"` - Stop []string `json:"stop,omitempty"` - PresencePenalty float32 `json:"presence_penalty,omitempty"` - FrequencyPenalty float32 `json:"frequency_penalty,omitempty"` - LogitBias map[string]int `json:"logit_bias,omitempty"` - User string `json:"user,omitempty"` -} - -type ChatCompletionChoice struct { - Index int `json:"index"` - Message ChatCompletionMessage `json:"message"` - FinishReason string `json:"finish_reason"` -} - -type ChatCompletionResponse struct { - ID string `json:"id"` - Object string `json:"object"` - Created int64 `json:"created"` - Model string `json:"model"` - Choices []ChatCompletionChoice `json:"choices"` - Usage struct { - PromptTokens int `json:"prompt_tokens"` - CompletionTokens int `json:"completion_tokens"` - TotalTokens int `json:"total_tokens"` - } `json:"usage"` -} +// type ChatCompletionResponse struct { +// ID string `json:"id"` +// Object string `json:"object"` +// Created int64 `json:"created"` +// Model string `json:"model"` +// Choices []ChatCompletionChoice `json:"choices"` +// Usage struct { +// PromptTokens int `json:"prompt_tokens"` +// CompletionTokens int `json:"completion_tokens"` +// TotalTokens int `json:"total_tokens"` +// } `json:"usage"` +// } func init() { if openai_endpoint := os.Getenv("openai_endpoint"); openai_endpoint != "" { @@ -111,385 +69,9 @@ func init() { } } -func AuthMiddleware() gin.HandlerFunc { - return func(c *gin.Context) { - if rootToken == "" { - u, err := store.GetUserByID(uint(1)) - if err != nil { - c.JSON(http.StatusUnauthorized, gin.H{"error": "Unauthorized"}) - c.Abort() - return - } - rootToken = u.Token - } - token := c.GetHeader("Authorization") - if token == "" || token[:7] != "Bearer " { - c.JSON(http.StatusUnauthorized, gin.H{"error": "Unauthorized"}) - c.Abort() - return - } - if store.IsExistAuthCache(token[7:]) { - if strings.HasPrefix(c.Request.URL.Path, "/1/me") { - c.Next() - return - } - } - if token[7:] != rootToken { - u, err := store.GetUserByID(uint(1)) - if err != nil { - log.Println(err) - c.JSON(http.StatusUnauthorized, gin.H{"error": "Unauthorized"}) - c.Abort() - return - } - if token[:7] != u.Token { - c.JSON(http.StatusUnauthorized, gin.H{"error": "Unauthorized"}) - c.Abort() - return - } - rootToken = u.Token - store.LoadAuthCache() - } - // 可以在这里对 token 进行验证并检查权限 - - c.Next() - } -} - -func Handleinit(c *gin.Context) { - user, err := store.GetUserByID(1) - if err != nil { - if errors.Is(err, gorm.ErrRecordNotFound) { - u := store.User{Name: "root", Token: uuid.NewString()} - u.ID = 1 - if err := store.CreateUser(&u); err != nil { - c.JSON(http.StatusForbidden, gin.H{ - "error": err.Error(), - }) - return - } else { - rootToken = u.Token - resJSON := User{ - false, - int(u.ID), - u.UpdatedAt.Format(time.RFC3339), - u.Name, - u.Token, - u.CreatedAt.Format(time.RFC3339), - } - c.JSON(http.StatusOK, resJSON) - return - } - } - c.JSON(http.StatusOK, gin.H{ - "error": err.Error(), - }) - return - } - if user.ID == uint(1) { - c.JSON(http.StatusForbidden, gin.H{ - "error": "super user already exists, use cli to reset password", - }) - } -} - -func HandleMe(c *gin.Context) { - token := c.GetHeader("Authorization") - u, err := store.GetUserByToken(token[7:]) - if err != nil { - c.JSON(http.StatusOK, gin.H{ - "error": err.Error(), - }) - } - - resJSON := User{ - false, - int(u.ID), - u.UpdatedAt.Format(time.RFC3339), - u.Name, - u.Token, - u.CreatedAt.Format(time.RFC3339), - } - c.JSON(http.StatusOK, resJSON) -} - -func HandleMeUsage(c *gin.Context) { - token := c.GetHeader("Authorization") - fromStr := c.Query("from") - toStr := c.Query("to") - getMonthStartAndEnd := func() (start, end string) { - loc, _ := time.LoadLocation("Local") - now := time.Now().In(loc) - - year, month, _ := now.Date() - - startOfMonth := time.Date(year, month, 1, 0, 0, 0, 0, loc) - endOfMonth := startOfMonth.AddDate(0, 1, 0) - - start = startOfMonth.Format("2006-01-02") - end = endOfMonth.Format("2006-01-02") - return - } - if fromStr == "" || toStr == "" { - fromStr, toStr = getMonthStartAndEnd() - } - user, err := store.GetUserByToken(token) - if err != nil { - c.AbortWithError(http.StatusForbidden, err) - return - } - usage, err := store.QueryUserUsage(to.String(user.ID), fromStr, toStr) - if err != nil { - c.AbortWithError(http.StatusForbidden, err) - return - } - - c.JSON(200, usage) -} - -func HandleKeys(c *gin.Context) { - keys, err := store.GetAllKeys() - if err != nil { - c.JSON(http.StatusOK, gin.H{ - "error": err.Error(), - }) - } - - c.JSON(http.StatusOK, keys) -} - -func HandleUsers(c *gin.Context) { - users, err := store.GetAllUsers() - if err != nil { - c.JSON(http.StatusOK, gin.H{ - "error": err.Error(), - }) - } - - c.JSON(http.StatusOK, users) -} - -func HandleAddKey(c *gin.Context) { - var body Key - if err := c.BindJSON(&body); err != nil { - c.JSON(http.StatusBadRequest, gin.H{"error": gin.H{ - "message": err.Error(), - }}) - return - } - body.Name = strings.ToLower(strings.TrimSpace(body.Name)) - body.Key = strings.TrimSpace(body.Key) - if strings.HasPrefix(body.Name, "azure.") { - keynames := strings.Split(body.Name, ".") - if len(keynames) < 2 { - c.JSON(http.StatusBadRequest, gin.H{"error": gin.H{ - "message": "Invalid Key Name", - }}) - return - } - k := &store.Key{ - ApiType: "azure", - Name: body.Name, - Key: body.Key, - ResourceNmae: keynames[1], - EndPoint: body.Endpoint, - } - if err := store.CreateKey(k); err != nil { - c.JSON(http.StatusInternalServerError, gin.H{"error": gin.H{ - "message": err.Error(), - }}) - return - } - } else if strings.HasPrefix(body.Name, "claude.") { - keynames := strings.Split(body.Name, ".") - if len(keynames) < 2 { - c.JSON(http.StatusBadRequest, gin.H{"error": gin.H{ - "message": "Invalid Key Name", - }}) - return - } - if body.Endpoint == "" { - body.Endpoint = "https://api.anthropic.com" - } - k := &store.Key{ - // ApiType: "anthropic", - ApiType: "claude", - Name: body.Name, - Key: body.Key, - ResourceNmae: keynames[1], - EndPoint: body.Endpoint, - } - if err := store.CreateKey(k); err != nil { - c.JSON(http.StatusInternalServerError, gin.H{"error": gin.H{ - "message": err.Error(), - }}) - return - } - } else if strings.HasPrefix(body.Name, "google.") { - keynames := strings.Split(body.Name, ".") - if len(keynames) < 2 { - c.JSON(http.StatusBadRequest, gin.H{"error": gin.H{ - "message": "Invalid Key Name", - }}) - return - } - - k := &store.Key{ - // ApiType: "anthropic", - ApiType: "google", - Name: body.Name, - Key: body.Key, - ResourceNmae: keynames[1], - EndPoint: body.Endpoint, - } - if err := store.CreateKey(k); err != nil { - c.JSON(http.StatusInternalServerError, gin.H{"error": gin.H{ - "message": err.Error(), - }}) - return - } - } else { - if body.ApiType == "" { - if err := store.AddKey("openai", body.Key, body.Name); err != nil { - c.JSON(http.StatusInternalServerError, gin.H{"error": gin.H{ - "message": err.Error(), - }}) - return - } - } else { - k := &store.Key{ - ApiType: body.ApiType, - Name: body.Name, - Key: body.Key, - ResourceNmae: azureopenai.GetResourceName(body.Endpoint), - EndPoint: body.Endpoint, - } - if err := store.CreateKey(k); err != nil { - c.JSON(http.StatusInternalServerError, gin.H{"error": gin.H{ - "message": err.Error(), - }}) - return - } - } - - } - - k, err := store.GetKeyrByName(body.Name) - if err != nil { - c.JSON(http.StatusInternalServerError, gin.H{"error": gin.H{ - "message": err.Error(), - }}) - return - } - c.JSON(http.StatusOK, k) -} - -func HandleDelKey(c *gin.Context) { - id := to.Int(c.Param("id")) - if id < 1 { - c.JSON(http.StatusOK, gin.H{"error": "invalid key id"}) - return - } - if err := store.DeleteKey(uint(id)); err != nil { - c.JSON(http.StatusOK, gin.H{"error": "invalid key id"}) - return - } - c.JSON(http.StatusOK, gin.H{"message": "ok"}) -} - -func HandleAddUser(c *gin.Context) { - var body User - if err := c.BindJSON(&body); err != nil { - c.JSON(http.StatusOK, gin.H{"error": err.Error()}) - return - } - if len(body.Name) == 0 { - c.JSON(http.StatusOK, gin.H{"error": "invalid user name"}) - return - } - - if err := store.AddUser(body.Name, uuid.NewString()); err != nil { - c.JSON(http.StatusOK, gin.H{"error": err.Error()}) - return - } - u, err := store.GetUserByName(body.Name) - if err != nil { - c.JSON(http.StatusOK, gin.H{"error": err.Error()}) - return - } - c.JSON(http.StatusOK, u) -} - -func HandleDelUser(c *gin.Context) { - id := to.Int(c.Param("id")) - if id <= 1 { - c.JSON(http.StatusOK, gin.H{"error": "invalid user id"}) - return - } - if err := store.DeleteUser(uint(id)); err != nil { - c.JSON(http.StatusOK, gin.H{"error": err.Error()}) - return - } - - c.JSON(http.StatusOK, gin.H{"message": "ok"}) -} - -func HandleResetUserToken(c *gin.Context) { - id := to.Int(c.Param("id")) - newtoken := c.Query("token") - if newtoken == "" { - newtoken = uuid.NewString() - } - - if err := store.UpdateUser(uint(id), newtoken); err != nil { - c.JSON(http.StatusForbidden, gin.H{"error": err.Error()}) - return - } - u, err := store.GetUserByID(uint(id)) - if err != nil { - c.JSON(http.StatusForbidden, gin.H{"error": err.Error()}) - return - } - if u.ID == uint(1) { - rootToken = u.Token - } - c.JSON(http.StatusOK, u) -} - -func GenerateToken() string { - token := uuid.New() - return token.String() -} - -func getHttpClient() *http.Client { - tr := &http.Transport{ - Proxy: http.ProxyFromEnvironment, - DialContext: (&net.Dialer{ - Timeout: 30 * time.Second, - KeepAlive: 30 * time.Second, - }).DialContext, - ForceAttemptHTTP2: true, - MaxIdleConns: 100, - IdleConnTimeout: 90 * time.Second, - TLSHandshakeTimeout: 10 * time.Second, - ExpectContinueTimeout: 1 * time.Second, - TLSClientConfig: &tls.Config{InsecureSkipVerify: true}, - } - return &http.Client{Transport: tr} -} - -func HandleProy(c *gin.Context) { +func HandleProxy(c *gin.Context) { var ( - localuser bool - isStream bool - chatreq = openai.ChatCompletionRequest{} - chatres = openai.ChatCompletionResponse{} - chatlog store.Tokens - onekey store.Key - pre_prompt string - req *http.Request - err error - // wg sync.WaitGroup + localuser bool ) auth := c.Request.Header.Get("Authorization") if len(auth) > 7 && auth[:7] == "Bearer " { @@ -497,11 +79,17 @@ func HandleProy(c *gin.Context) { c.Set("localuser", auth[7:]) } if c.Request.URL.Path == "/v1/complete" { - claude.ClaudeProxy(c) - return + if localuser { + claude.ClaudeProxy(c) + return + } else { + HandleReverseProxy(c, "api.anthropic.com") + return + } + } if c.Request.URL.Path == "/v1/audio/transcriptions" { - WhisperProxy(c) + oai.WhisperProxy(c) return } if c.Request.URL.Path == "/v1/audio/speech" { @@ -514,191 +102,30 @@ func HandleProy(c *gin.Context) { return } - if c.Request.URL.Path == "/v1/chat/completions" && localuser { - if store.KeysCache.ItemCount() == 0 { - c.JSON(http.StatusBadGateway, gin.H{"error": gin.H{ - "message": "No Api-Key Available", - }}) - return - } - - ChatHandler(c) - return - - if err := c.BindJSON(&chatreq); err != nil { - c.AbortWithError(http.StatusBadRequest, err) - return - } - - chatlog.Model = chatreq.Model - for _, m := range chatreq.Messages { - pre_prompt += m.Content + "\n" - } - chatlog.PromptHash = cryptor.Md5String(pre_prompt) - chatlog.PromptCount = tokenizer.NumTokensFromMessages(chatreq.Messages, chatreq.Model) - isStream = chatreq.Stream - chatlog.UserID, _ = store.GetUserID(auth[7:]) - - var body bytes.Buffer - json.NewEncoder(&body).Encode(chatreq) - - if strings.HasPrefix(chatreq.Model, "claude-") { - onekey, err = store.SelectKeyCache("claude") - if err != nil { - c.AbortWithError(http.StatusForbidden, err) - } - } else { - onekey = store.FromKeyCacheRandomItemKey() - } - - // 创建 API 请求 - switch onekey.ApiType { - case "claude": - payload, _ := claude.TransReq(&chatreq) - buildurl := "https://api.anthropic.com/v1/complete" - req, err = http.NewRequest("POST", buildurl, payload) - req.Header.Add("accept", "application/json") - req.Header.Add("anthropic-version", "2023-06-01") - req.Header.Add("x-api-key", onekey.Key) - req.Header.Add("content-type", "application/json") - case "azure": - fallthrough - case "azure_openai": - var buildurl string - var apiVersion = "2023-05-15" - if onekey.EndPoint != "" { - buildurl = fmt.Sprintf("%s/openai/deployments/%s/chat/completions?api-version=%s", onekey.EndPoint, modelmap(chatreq.Model), apiVersion) - } else { - buildurl = fmt.Sprintf("https://%s.openai.azure.com/openai/deployments/%s/chat/completions?api-version=%s", onekey.ResourceNmae, modelmap(chatreq.Model), apiVersion) - } - req, err = http.NewRequest(c.Request.Method, buildurl, &body) - req.Header = c.Request.Header - req.Header.Set("api-key", onekey.Key) - case "openai": - fallthrough - default: - if onekey.EndPoint != "" { - req, err = http.NewRequest(c.Request.Method, onekey.EndPoint+c.Request.RequestURI, &body) - } else { - req, err = http.NewRequest(c.Request.Method, baseUrl+c.Request.RequestURI, &body) - } - - req.Header = c.Request.Header - req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", onekey.Key)) - } - if err != nil { - log.Println(err) - c.JSON(http.StatusOK, gin.H{"error": err.Error()}) - return - } - - } else { - req, err = http.NewRequest(c.Request.Method, baseUrl+c.Request.RequestURI, c.Request.Body) - if err != nil { - log.Println(err) - c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) - return - } - req.Header = c.Request.Header - } - - resp, err := client.Do(req) - if err != nil { - log.Println(err) - c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) - return - } - defer resp.Body.Close() - - // 复制 API 响应头部 - for name, values := range resp.Header { - for _, value := range values { - c.Writer.Header().Add(name, value) - } - } - head := map[string]string{ - "Cache-Control": "no-store", - "access-control-allow-origin": "*", - "access-control-allow-credentials": "true", - } - for k, v := range head { - if _, ok := resp.Header[k]; !ok { - c.Writer.Header().Set(k, v) - } - } - resp.Header.Del("content-security-policy") - resp.Header.Del("content-security-policy-report-only") - resp.Header.Del("clear-site-data") - - c.Writer.WriteHeader(resp.StatusCode) - writer := bufio.NewWriter(c.Writer) - defer writer.Flush() - - reader := bufio.NewReader(resp.Body) - - if resp.StatusCode == 200 && localuser { - switch onekey.ApiType { - case "claude": - claude.TransRsp(c, isStream, chatlog, reader) - return - case "openai", "azure", "azure_openai": - fallthrough - default: - if isStream { - contentCh := fetchResponseContent(c, reader) - var buffer bytes.Buffer - for content := range contentCh { - buffer.WriteString(content) - } - chatlog.CompletionCount = tokenizer.NumTokensFromStr(buffer.String(), chatreq.Model) - chatlog.TotalTokens = chatlog.PromptCount + chatlog.CompletionCount - chatlog.Cost = fmt.Sprintf("%.6f", tokenizer.Cost(chatlog.Model, chatlog.PromptCount, chatlog.CompletionCount)) - if err := store.Record(&chatlog); err != nil { - log.Println(err) - } - if err := store.SumDaily(chatlog.UserID); err != nil { - log.Println(err) - } - return - } - res, err := io.ReadAll(reader) - if err != nil { - c.JSON(http.StatusUnauthorized, gin.H{"error": gin.H{ - "message": err.Error(), + if c.Request.URL.Path == "/v1/chat/completions" { + if localuser { + if store.KeysCache.ItemCount() == 0 { + c.JSON(http.StatusBadGateway, gin.H{"error": gin.H{ + "message": "No Api-Key Available", }}) return } - reader = bufio.NewReader(bytes.NewBuffer(res)) - json.NewDecoder(bytes.NewBuffer(res)).Decode(&chatres) - chatlog.PromptCount = chatres.Usage.PromptTokens - chatlog.CompletionCount = chatres.Usage.CompletionTokens - chatlog.TotalTokens = chatres.Usage.TotalTokens - chatlog.Cost = fmt.Sprintf("%.6f", tokenizer.Cost(chatlog.Model, chatlog.PromptCount, chatlog.CompletionCount)) - if err := store.Record(&chatlog); err != nil { - log.Println(err) - } - if err := store.SumDaily(chatlog.UserID); err != nil { - log.Println(err) - } - } - } - // 返回 API 响应主体 - if _, err := io.Copy(writer, reader); err != nil { - log.Println(err) - c.JSON(http.StatusUnauthorized, gin.H{"error": gin.H{ - "message": err.Error(), - }}) + ChatHandler(c) + return + } + } else { + HandleReverseProxy(c, "api.openai.com") return } + } -func HandleReverseProxy(c *gin.Context) { +func HandleReverseProxy(c *gin.Context, targetHost string) { proxy := &httputil.ReverseProxy{ Director: func(req *http.Request) { req.URL.Scheme = "https" - req.URL.Host = "api.openai.com" - // req.Header.Set("Authorization", "Bearer YOUR_API_KEY_HERE") + req.URL.Host = targetHost }, Transport: &http.Transport{ Proxy: http.ProxyFromEnvironment, @@ -715,266 +142,13 @@ func HandleReverseProxy(c *gin.Context) { }, } - var localuser bool - auth := c.Request.Header.Get("Authorization") - if len(auth) > 7 && auth[:7] == "Bearer " { - log.Println(store.IsExistAuthCache(auth[7:])) - localuser = store.IsExistAuthCache(auth[7:]) - } - req, err := http.NewRequest(c.Request.Method, c.Request.URL.Path, c.Request.Body) if err != nil { c.JSON(http.StatusOK, gin.H{"error": err.Error()}) return } req.Header = c.Request.Header - if localuser { - if store.KeysCache.ItemCount() == 0 { - c.JSON(http.StatusOK, gin.H{"error": "No Api-Key Available"}) - return - } - onekey := store.FromKeyCacheRandomItemKey() - req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", onekey.Key)) - } proxy.ServeHTTP(c.Writer, req) - -} - -func HandleUsage(c *gin.Context) { - fromStr := c.Query("from") - toStr := c.Query("to") - getMonthStartAndEnd := func() (start, end string) { - loc, _ := time.LoadLocation("Local") - now := time.Now().In(loc) - - year, month, _ := now.Date() - - startOfMonth := time.Date(year, month, 1, 0, 0, 0, 0, loc) - endOfMonth := startOfMonth.AddDate(0, 1, 0) - - start = startOfMonth.Format("2006-01-02") - end = endOfMonth.Format("2006-01-02") - return - } - if fromStr == "" || toStr == "" { - fromStr, toStr = getMonthStartAndEnd() - } - - usage, err := store.QueryUsage(fromStr, toStr) - if err != nil { - c.JSON(http.StatusForbidden, gin.H{"error": err.Error()}) - return - } - - c.JSON(200, usage) -} - -func fetchResponseContent(ctx *gin.Context, responseBody *bufio.Reader) <-chan string { - contentCh := make(chan string) - go func() { - defer close(contentCh) - for { - line, err := responseBody.ReadString('\n') - if err == nil { - lines := strings.Split(line, "") - for _, word := range lines { - ctx.Writer.WriteString(word) - ctx.Writer.Flush() - } - if line == "\n" { - continue - } - if strings.HasPrefix(line, "data:") { - line = strings.TrimSpace(strings.TrimPrefix(line, "data:")) - if strings.HasSuffix(line, "[DONE]") { - break - } - line = strings.TrimSpace(line) - } - - dec := json.NewDecoder(strings.NewReader(line)) - var data map[string]interface{} - if err := dec.Decode(&data); err == io.EOF { - log.Println("EOF:", err) - break - } else if err != nil { - fmt.Println("Error decoding response:", err) - return - } - if choices, ok := data["choices"].([]interface{}); ok { - for _, choice := range choices { - choiceMap := choice.(map[string]interface{}) - if content, ok := choiceMap["delta"].(map[string]interface{})["content"]; ok { - contentCh <- content.(string) - } - } - } - } else { - break - } - } - }() - return contentCh -} - -func modelmap(in string) string { - // gpt-3.5-turbo -> gpt-35-turbo - if strings.Contains(in, ".") { - return strings.ReplaceAll(in, ".", "") - } - return in -} - -func WhisperProxy(c *gin.Context) { - var chatlog store.Tokens - - byteBody, _ := io.ReadAll(c.Request.Body) - c.Request.Body = io.NopCloser(bytes.NewBuffer(byteBody)) - - model, _ := c.GetPostForm("model") - - key, err := store.SelectKeyCache("openai") - if err != nil { - c.JSON(http.StatusInternalServerError, gin.H{ - "error": gin.H{ - "message": err.Error(), - }, - }) - return - } - - chatlog.Model = model - - token, _ := c.Get("localuser") - - lu, err := store.GetUserByToken(token.(string)) - if err != nil { - c.JSON(http.StatusInternalServerError, gin.H{ - "error": gin.H{ - "message": err.Error(), - }, - }) - return - } - chatlog.UserID = int(lu.ID) - - if err := ParseWhisperRequestTokens(c, &chatlog, byteBody); err != nil { - c.JSON(http.StatusInternalServerError, gin.H{ - "error": gin.H{ - "message": err.Error(), - }, - }) - return - } - if key.EndPoint == "" { - key.EndPoint = "https://api.openai.com" - } - targetUrl, _ := url.ParseRequestURI(key.EndPoint + c.Request.URL.String()) - log.Println(targetUrl) - proxy := httputil.NewSingleHostReverseProxy(targetUrl) - proxy.Director = func(req *http.Request) { - req.Host = targetUrl.Host - req.URL.Scheme = targetUrl.Scheme - req.URL.Host = targetUrl.Host - - req.Header.Set("Authorization", "Bearer "+key.Key) - } - - proxy.ModifyResponse = func(resp *http.Response) error { - if resp.StatusCode != http.StatusOK { - return nil - } - chatlog.TotalTokens = chatlog.PromptCount + chatlog.CompletionCount - chatlog.Cost = fmt.Sprintf("%.6f", tokenizer.Cost(chatlog.Model, chatlog.PromptCount, chatlog.CompletionCount)) - if err := store.Record(&chatlog); err != nil { - log.Println(err) - } - if err := store.SumDaily(chatlog.UserID); err != nil { - log.Println(err) - } - return nil - } - proxy.ServeHTTP(c.Writer, c.Request) -} - -func probe(fileReader io.Reader) (time.Duration, error) { - ctx, cancelFn := context.WithTimeout(context.Background(), 5*time.Second) - defer cancelFn() - - data, err := ffprobe.ProbeReader(ctx, fileReader) - if err != nil { - return 0, err - } - - duration := data.Format.DurationSeconds - pduration, err := time.ParseDuration(fmt.Sprintf("%fs", duration)) - if err != nil { - return 0, fmt.Errorf("Error parsing duration: %s", err) - } - return pduration, nil -} - -func getAudioDuration(file *multipart.FileHeader) (time.Duration, error) { - var ( - streamer beep.StreamSeekCloser - format beep.Format - err error - ) - - f, err := file.Open() - defer f.Close() - - // Get the file extension to determine the audio file type - fileType := filepath.Ext(file.Filename) - - switch fileType { - case ".mp3": - streamer, format, err = mp3.Decode(f) - case ".wav": - streamer, format, err = wav.Decode(f) - case ".m4a": - duration, err := probe(f) - if err != nil { - return 0, err - } - return duration, nil - default: - return 0, errors.New("unsupported audio file format") - } - - if err != nil { - return 0, err - } - defer streamer.Close() - - // Calculate the audio file's duration. - numSamples := streamer.Len() - sampleRate := format.SampleRate - duration := time.Duration(numSamples) * time.Second / time.Duration(sampleRate) - - return duration, nil -} - -func ParseWhisperRequestTokens(c *gin.Context, usage *store.Tokens, byteBody []byte) error { - file, _ := c.FormFile("file") - model, _ := c.GetPostForm("model") - usage.Model = model - - if file != nil { - duration, err := getAudioDuration(file) - if err != nil { - return fmt.Errorf("Error getting audio duration:%s", err) - } - - if duration > 5*time.Minute { - return fmt.Errorf("Audio duration exceeds 5 minutes") - } - // 计算时长,四舍五入到最接近的秒数 - usage.PromptCount = int(duration.Round(time.Second).Seconds()) - } - - c.Request.Body = io.NopCloser(bytes.NewBuffer(byteBody)) - - return nil + return } diff --git a/store/cache.go b/store/cache.go index 64009f4..2a9462e 100644 --- a/store/cache.go +++ b/store/cache.go @@ -53,6 +53,14 @@ func SelectKeyCache(apitype string) (Key, error) { if item.Object.(Key).ApiType == "azure" { keys = append(keys, item.Object.(Key)) } + if item.Object.(Key).ApiType == "github" { + keys = append(keys, item.Object.(Key)) + } + } + if apitype == "claude" { + if item.Object.(Key).ApiType == "vertex" { + keys = append(keys, item.Object.(Key)) + } } } if len(keys) == 0 { diff --git a/store/keydb.go b/store/keydb.go index c5d106b..c7d55db 100644 --- a/store/keydb.go +++ b/store/keydb.go @@ -2,9 +2,35 @@ package store import ( "encoding/json" + "fmt" + "log" + "opencatd-open/pkg/vertexai" + "os" "time" ) +func init() { + // check vertex + if os.Getenv("Vertex") != "" { + vertex_auth := os.Getenv("Vertex") + var Vertex vertexai.VertexSecretKey + if err := json.Unmarshal([]byte(vertex_auth), &Vertex); err != nil { + log.Fatalln(err) + return + } + key := Key{ + ApiType: "vertex", + Name: Vertex.ProjectID, + Key: vertex_auth, + ApiSecret: vertex_auth, + } + if err := db.FirstOrCreate(&key).Error; err != nil { + log.Println(fmt.Errorf("create vertex key error: %v", err)) + } + } + LoadKeysCache() +} + type Key struct { ID uint `gorm:"primarykey" json:"id,omitempty"` Key string `gorm:"unique;not null" json:"key,omitempty"` @@ -14,6 +40,7 @@ type Key struct { EndPoint string `gorm:"column:endpoint"` ResourceNmae string `gorm:"column:resource_name"` DeploymentName string `gorm:"column:deployment_name"` + ApiSecret string `gorm:"column:api_secret"` CreatedAt time.Time `json:"createdAt,omitempty"` UpdatedAt time.Time `json:"updatedAt,omitempty"` }