From 7fd82b43f45f54dd846a36e39c2606a2170f0f26 Mon Sep 17 00:00:00 2001
From: Sakurasan <26715255+Sakurasan@users.noreply.github.com>
Date: Fri, 13 Sep 2024 21:32:10 +0800
Subject: [PATCH] refact & update new model
---
README.md | 29 +-
docker-compose.yml | 26 +
docker/Dockerfile | 2 +-
docker/docker-compose.yml | 20 +-
opencat.go | 49 +-
pkg/claude/chat.go | 67 ++-
pkg/claude/claude.go | 4 +-
pkg/error/errdata.go | 11 +
pkg/openai/chat.go | 25 +-
pkg/openai/whisper.go | 177 +++++++
pkg/team/key.go | 182 +++++++
pkg/team/me.go | 104 ++++
pkg/team/middleware.go | 59 +++
pkg/team/usage.go | 38 ++
pkg/team/user.go | 89 ++++
pkg/tokenizer/tokenizer.go | 9 +-
pkg/vertexai/auth.go | 167 +++++++
router/chat.go | 28 +-
router/router.go | 948 +++----------------------------------
store/cache.go | 8 +
store/keydb.go | 27 ++
21 files changed, 1094 insertions(+), 975 deletions(-)
create mode 100644 docker-compose.yml
create mode 100644 pkg/error/errdata.go
create mode 100644 pkg/openai/whisper.go
create mode 100644 pkg/team/key.go
create mode 100644 pkg/team/me.go
create mode 100644 pkg/team/middleware.go
create mode 100644 pkg/team/usage.go
create mode 100644 pkg/team/user.go
create mode 100644 pkg/vertexai/auth.go
diff --git a/README.md b/README.md
index 12334ff..6acde4b 100644
--- a/README.md
+++ b/README.md
@@ -1,4 +1,7 @@
-# opencatd-open
+# ~~opencatd-open~~ [OpenTeam](https://github.com/mirrors2/opencatd-open)
+
+ 本项目即将更名,后续请关注 👉🏻 https://github.com/mirrors2/openteam
+
@@ -14,11 +17,11 @@ OpenCat for Team的开源实现
## Extra Support:
-| 任务 | 完成情况 |
-| --- | --- |
-|[Azure OpenAI](./doc/azure.md) | ✅|
-|[Claude](./doc/azure.md) | ✅|
-|[Gemini](./doc/gemini.md) | ✅|
+| 🎯 | 🚧 |Extra Provider|
+| --- | --- | --- |
+|[OpenAI](./doc/azure.md) | ✅|Azure, Github Marketplace|
+|[Claude](./doc/azure.md) | ✅|VertexAI|
+|[Gemini](./doc/gemini.md) | ✅||
| ... | ... |
@@ -80,6 +83,18 @@ wget https://github.com/mirrors2/opencatd-open/raw/main/docker/docker-compose.ym
pandora for team
- [pandora for team](./doc/pandora.md)
+
+如何自定义HOST地址? (仅OpenAI)
+ - 需修改环境变量,优先级递增
+ - Cloudflare AI Gateway地址 `AIGateWay_Endpoint=https://gateway.ai.cloudflare.com/v1/123456789/xxxx/openai/chat/completions`
+ - 自定义的endpoint `$CUSTOM_ENDPOINT=true && $OpenAI_Endpoint=https://your.domain/v1/chat/completions`
+
+设置主页跳转地址?
+ - 修改环境变量 `CUSTOM_REDIRECT=https://your.domain`
+
+## 赞助
+[](https://www.buymeacoffee.com/littlecjun)
+
# License
-[GNU General Public License v3.0](License)
\ No newline at end of file
+[](https://github.com/mirrors2/opencatd-open/blob/main/License)
diff --git a/docker-compose.yml b/docker-compose.yml
new file mode 100644
index 0000000..36ade56
--- /dev/null
+++ b/docker-compose.yml
@@ -0,0 +1,26 @@
+# Email: admin@example.com
+# Password: changeme
+version: '3'
+
+services:
+ npm:
+ image: jc21/nginx-proxy-manager
+ network_mode: host
+ ports:
+ - '80:80'
+ - '81:81'
+ - '443:443'
+ volumes:
+ - $PWD/data:/data
+ - $PWD/www:/var/www
+ - $PWD/letsencrypt:/etc/letsencrypt
+ environment:
+ - "TZ=Asia/Shanghai" # set timezone, default UTC
+ - "PUID=1000" # set group id, default 0 (root)
+ - "PGID=1000"
+
+ # certbot:
+ # image: certbot/certbot
+ # volumes:
+ # - $PWD/data/certbot/conf:/etc/letsencrypt
+ # - $PWD/data/certbot/www:/var/www/certbot
diff --git a/docker/Dockerfile b/docker/Dockerfile
index f11aa0d..968a3d4 100644
--- a/docker/Dockerfile
+++ b/docker/Dockerfile
@@ -1,4 +1,4 @@
-FROM node:18.12.1-alpine3.16 AS frontend
+FROM node:20-alpine AS frontend
WORKDIR /frontend-build
COPY ./web/ .
RUN npm install && npm run build && rm -rf node_modules
diff --git a/docker/docker-compose.yml b/docker/docker-compose.yml
index 154e9fe..29bf135 100644
--- a/docker/docker-compose.yml
+++ b/docker/docker-compose.yml
@@ -1,10 +1,24 @@
version: '3.7'
-services:
+services:
opencatd:
image: mirrors2/opencatd-open
- container_name: opencatd-open
+ container_name: opencatd-open
restart: unless-stopped
+ #network_mode: host
ports:
- 80:80
volumes:
- - /etc/opencatd:/app/db
\ No newline at end of file
+ - $PWD/db:/app/db
+ logging:
+ # driver: "json-file"
+ options:
+ max-size: 10m
+ max-file: 3
+ # environment:
+ # Vertex: |
+ # {
+ # "type": "service_account",
+ # "universe_domain": "googleapis.com"
+ # }
+
+
\ No newline at end of file
diff --git a/opencat.go b/opencat.go
index 6dee33a..a3ebfe5 100644
--- a/opencat.go
+++ b/opencat.go
@@ -8,6 +8,7 @@ import (
"io/fs"
"log"
"net/http"
+ "opencatd-open/pkg/team"
"opencatd-open/router"
"opencatd-open/store"
"os"
@@ -143,41 +144,29 @@ func main() {
r := gin.Default()
group := r.Group("/1")
{
- group.Use(router.AuthMiddleware())
+ group.Use(team.AuthMiddleware())
// 获取当前用户信息
- group.GET("/me", router.HandleMe)
+ group.GET("/me", team.HandleMe)
+ group.GET("/me/usages", team.HandleMeUsage)
- group.GET("/me/usages", router.HandleMeUsage)
+ group.GET("/keys", team.HandleKeys) // 获取所有Key
+ group.POST("/keys", team.HandleAddKey) // 添加Key
+ group.DELETE("/keys/:id", team.HandleDelKey) // 删除Key
- // 获取所有Key
- group.GET("/keys", router.HandleKeys)
+ group.GET("/users", team.HandleUsers) // 获取所有用户信息
+ group.POST("/users", team.HandleAddUser) // 添加用户
+ group.DELETE("/users/:id", team.HandleDelUser) // 删除用户
- // 获取所有用户信息
- group.GET("/users", router.HandleUsers)
-
- group.GET("/usages", router.HandleUsage)
-
- // 添加Key
- group.POST("/keys", router.HandleAddKey)
-
- // 删除Key
- group.DELETE("/keys/:id", router.HandleDelKey)
-
- // 添加用户
- group.POST("/users", router.HandleAddUser)
-
- // 删除用户
- group.DELETE("/users/:id", router.HandleDelUser)
+ group.GET("/usages", team.HandleUsage)
// 重置用户Token
- group.POST("/users/:id/reset", router.HandleResetUserToken)
+ group.POST("/users/:id/reset", team.HandleResetUserToken)
}
-
// 初始化用户
- r.POST("/1/users/init", router.Handleinit)
+ r.POST("/1/users/init", team.Handleinit)
- r.Any("/v1/*proxypath", router.HandleProy)
+ r.Any("/v1/*proxypath", router.HandleProxy)
// r.POST("/v1/chat/completions", router.HandleProy)
// r.GET("/v1/models", router.HandleProy)
@@ -188,7 +177,15 @@ func main() {
if err != nil {
panic(err)
}
- r.GET("/", gin.WrapH(http.FileServer(http.FS(idxFS))))
+ redirect := os.Getenv("CUSTOM_REDIRECT")
+ if redirect != "" {
+ r.GET("/", func(c *gin.Context) {
+ c.Redirect(http.StatusMovedPermanently, redirect)
+ })
+
+ } else {
+ r.GET("/", gin.WrapH(http.FileServer(http.FS(idxFS))))
+ }
assetsFS, err := fs.Sub(web, "dist/assets")
if err != nil {
panic(err)
diff --git a/pkg/claude/chat.go b/pkg/claude/chat.go
index 8987dc7..39afe48 100644
--- a/pkg/claude/chat.go
+++ b/pkg/claude/chat.go
@@ -10,8 +10,10 @@ import (
"io"
"log"
"net/http"
+ "opencatd-open/pkg/error"
"opencatd-open/pkg/openai"
"opencatd-open/pkg/tokenizer"
+ "opencatd-open/pkg/vertexai"
"opencatd-open/store"
"strings"
@@ -27,14 +29,15 @@ func ChatTextCompletions(c *gin.Context, chatReq *openai.ChatCompletionRequest)
}
type ChatRequest struct {
- Model string `json:"model,omitempty"`
- Messages any `json:"messages,omitempty"`
- MaxTokens int `json:"max_tokens,omitempty"`
- Stream bool `json:"stream,omitempty"`
- System string `json:"system,omitempty"`
- TopK int `json:"top_k,omitempty"`
- TopP float64 `json:"top_p,omitempty"`
- Temperature float64 `json:"temperature,omitempty"`
+ Model string `json:"model,omitempty"`
+ Messages any `json:"messages,omitempty"`
+ MaxTokens int `json:"max_tokens,omitempty"`
+ Stream bool `json:"stream,omitempty"`
+ System string `json:"system,omitempty"`
+ TopK int `json:"top_k,omitempty"`
+ TopP float64 `json:"top_p,omitempty"`
+ Temperature float64 `json:"temperature,omitempty"`
+ AnthropicVersion string `json:"anthropic_version,omitempty"`
}
func (c *ChatRequest) ByteJson() []byte {
@@ -117,8 +120,12 @@ type ClaudeStreamResponse struct {
}
func ChatMessages(c *gin.Context, chatReq *openai.ChatCompletionRequest) {
+ var (
+ req *http.Request
+ targetURL = ClaudeMessageEndpoint
+ )
- onekey, err := store.SelectKeyCache("claude")
+ apiKey, err := store.SelectKeyCache("claude")
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
@@ -131,6 +138,10 @@ func ChatMessages(c *gin.Context, chatReq *openai.ChatCompletionRequest) {
// claudReq.Temperature = chatReq.Temperature
claudReq.TopP = chatReq.TopP
claudReq.MaxTokens = 4096
+ if apiKey.ApiType == "vertex" {
+ claudReq.AnthropicVersion = "vertex-2023-10-16"
+ claudReq.Model = ""
+ }
var prompt string
@@ -181,10 +192,40 @@ func ChatMessages(c *gin.Context, chatReq *openai.ChatCompletionRequest) {
usagelog.PromptCount = tokenizer.NumTokensFromStr(prompt, chatReq.Model)
- req, _ := http.NewRequest("POST", MessageEndpoint, bytes.NewReader(claudReq.ByteJson()))
- req.Header.Set("x-api-key", onekey.Key)
- req.Header.Set("anthropic-version", "2023-06-01")
- req.Header.Set("Content-Type", "application/json")
+ if apiKey.ApiType == "vertex" {
+ var vertexSecret vertexai.VertexSecretKey
+ if err := json.Unmarshal([]byte(apiKey.ApiSecret), &vertexSecret); err != nil {
+ c.JSON(http.StatusInternalServerError, error.ErrorData(err.Error()))
+ return
+ }
+
+ vcmodel, ok := vertexai.VertexClaudeModelMap[chatReq.Model]
+ if !ok {
+ c.JSON(http.StatusInternalServerError, error.ErrorData("Model not found"))
+ return
+ }
+
+ // 获取gcloud token,临时放置在apiKey.Key中
+ gcloudToken, err := vertexai.GcloudAuth(vertexSecret.ClientEmail, vertexSecret.PrivateKey)
+ if err != nil {
+ c.JSON(http.StatusInternalServerError, error.ErrorData(err.Error()))
+ return
+ }
+
+ // 拼接vertex的请求地址
+ targetURL = fmt.Sprintf("https://%s-aiplatform.googleapis.com/v1/projects/%s/locations/%s/publishers/anthropic/models/%s:streamRawPredict", vcmodel.Region, vertexSecret.ProjectID, vcmodel.Region, vcmodel.VertexName)
+
+ req, _ = http.NewRequest("POST", targetURL, bytes.NewReader(claudReq.ByteJson()))
+ req.Header.Set("Authorization", "Bearer "+gcloudToken)
+ req.Header.Set("Content-Type", "application/json")
+ req.Header.Set("Accept", "text/event-stream")
+ req.Header.Set("Accept-Encoding", "identity")
+ } else {
+ req, _ = http.NewRequest("POST", targetURL, bytes.NewReader(claudReq.ByteJson()))
+ req.Header.Set("x-api-key", apiKey.Key)
+ req.Header.Set("anthropic-version", "2023-06-01")
+ req.Header.Set("Content-Type", "application/json")
+ }
client := http.DefaultClient
rsp, err := client.Do(req)
diff --git a/pkg/claude/claude.go b/pkg/claude/claude.go
index 0964f46..3c5f110 100644
--- a/pkg/claude/claude.go
+++ b/pkg/claude/claude.go
@@ -68,8 +68,8 @@ import (
)
var (
- ClaudeUrl = "https://api.anthropic.com/v1/complete"
- MessageEndpoint = "https://api.anthropic.com/v1/messages"
+ ClaudeUrl = "https://api.anthropic.com/v1/complete"
+ ClaudeMessageEndpoint = "https://api.anthropic.com/v1/messages"
)
type MessageModule struct {
diff --git a/pkg/error/errdata.go b/pkg/error/errdata.go
new file mode 100644
index 0000000..7a5b146
--- /dev/null
+++ b/pkg/error/errdata.go
@@ -0,0 +1,11 @@
+package error
+
+import "github.com/gin-gonic/gin"
+
+func ErrorData(message string) gin.H {
+ return gin.H{
+ "error": gin.H{
+ "message": message,
+ },
+ }
+}
diff --git a/pkg/openai/chat.go b/pkg/openai/chat.go
index b7270b7..f6c9dce 100644
--- a/pkg/openai/chat.go
+++ b/pkg/openai/chat.go
@@ -17,8 +17,10 @@ import (
)
const (
- AzureApiVersion = "2024-02-01"
- OpenAI_Endpoint = "https://api.openai.com/v1/chat/completions"
+ AzureApiVersion = "2024-02-01"
+ BaseHost = "api.openai.com"
+ OpenAI_Endpoint = "https://api.openai.com/v1/chat/completions"
+ Github_Marketplace = "https://models.inference.ai.azure.com/chat/completions"
)
var (
@@ -204,6 +206,10 @@ func ChatProxy(c *gin.Context, chatReq *ChatCompletionRequest) {
var req *http.Request
switch onekey.ApiType {
+ case "github":
+ req, err = http.NewRequest(c.Request.Method, Github_Marketplace, bytes.NewReader(chatReq.ToByteJson()))
+ req.Header = c.Request.Header
+ req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", onekey.Key))
case "azure":
var buildurl string
if onekey.EndPoint != "" {
@@ -215,15 +221,18 @@ func ChatProxy(c *gin.Context, chatReq *ChatCompletionRequest) {
req.Header = c.Request.Header
req.Header.Set("api-key", onekey.Key)
default:
+ req, err = http.NewRequest(c.Request.Method, OpenAI_Endpoint, bytes.NewReader(chatReq.ToByteJson()))
if onekey.EndPoint != "" { // 优先key的endpoint
req, err = http.NewRequest(c.Request.Method, onekey.EndPoint+c.Request.RequestURI, bytes.NewReader(chatReq.ToByteJson()))
- } else {
- if BaseURL != "" { // 其次BaseURL
- req, err = http.NewRequest(c.Request.Method, BaseURL+c.Request.RequestURI, bytes.NewReader(chatReq.ToByteJson()))
- } else { // 最后是gateway的endpoint
- req, err = http.NewRequest(c.Request.Method, AIGateWay_Endpoint, bytes.NewReader(chatReq.ToByteJson()))
- }
}
+ if AIGateWay_Endpoint != "" { // cloudflare gateway的endpoint
+ req, err = http.NewRequest(c.Request.Method, AIGateWay_Endpoint, bytes.NewReader(chatReq.ToByteJson()))
+ }
+ customEndpoint := os.Getenv("CUSTOM_ENDPOINT") // 最后是用户自定义的endpoint CUSTOM_ENDPOINT=true OpenAI_Endpoint
+ if customEndpoint == "true" && OpenAI_Endpoint != "" {
+ req, err = http.NewRequest(c.Request.Method, BaseURL, bytes.NewReader(chatReq.ToByteJson()))
+ }
+
req.Header = c.Request.Header
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", onekey.Key))
}
diff --git a/pkg/openai/whisper.go b/pkg/openai/whisper.go
new file mode 100644
index 0000000..ff8fd18
--- /dev/null
+++ b/pkg/openai/whisper.go
@@ -0,0 +1,177 @@
+package openai
+
+import (
+ "bytes"
+ "context"
+ "errors"
+ "fmt"
+ "io"
+ "log"
+ "mime/multipart"
+ "net/http"
+ "net/http/httputil"
+ "net/url"
+ "opencatd-open/pkg/tokenizer"
+ "opencatd-open/store"
+ "path/filepath"
+ "time"
+
+ "github.com/faiface/beep"
+ "github.com/faiface/beep/mp3"
+ "github.com/faiface/beep/wav"
+ "github.com/gin-gonic/gin"
+ "gopkg.in/vansante/go-ffprobe.v2"
+)
+
+func WhisperProxy(c *gin.Context) {
+ var chatlog store.Tokens
+
+ byteBody, _ := io.ReadAll(c.Request.Body)
+ c.Request.Body = io.NopCloser(bytes.NewBuffer(byteBody))
+
+ model, _ := c.GetPostForm("model")
+
+ key, err := store.SelectKeyCache("openai")
+ if err != nil {
+ c.JSON(http.StatusInternalServerError, gin.H{
+ "error": gin.H{
+ "message": err.Error(),
+ },
+ })
+ return
+ }
+
+ chatlog.Model = model
+
+ token, _ := c.Get("localuser")
+
+ lu, err := store.GetUserByToken(token.(string))
+ if err != nil {
+ c.JSON(http.StatusInternalServerError, gin.H{
+ "error": gin.H{
+ "message": err.Error(),
+ },
+ })
+ return
+ }
+ chatlog.UserID = int(lu.ID)
+
+ if err := ParseWhisperRequestTokens(c, &chatlog, byteBody); err != nil {
+ c.JSON(http.StatusInternalServerError, gin.H{
+ "error": gin.H{
+ "message": err.Error(),
+ },
+ })
+ return
+ }
+ if key.EndPoint == "" {
+ key.EndPoint = "https://api.openai.com"
+ }
+ targetUrl, _ := url.ParseRequestURI(key.EndPoint + c.Request.URL.String())
+ log.Println(targetUrl)
+ proxy := httputil.NewSingleHostReverseProxy(targetUrl)
+ proxy.Director = func(req *http.Request) {
+ req.Host = targetUrl.Host
+ req.URL.Scheme = targetUrl.Scheme
+ req.URL.Host = targetUrl.Host
+
+ req.Header.Set("Authorization", "Bearer "+key.Key)
+ }
+
+ proxy.ModifyResponse = func(resp *http.Response) error {
+ if resp.StatusCode != http.StatusOK {
+ return nil
+ }
+ chatlog.TotalTokens = chatlog.PromptCount + chatlog.CompletionCount
+ chatlog.Cost = fmt.Sprintf("%.6f", tokenizer.Cost(chatlog.Model, chatlog.PromptCount, chatlog.CompletionCount))
+ if err := store.Record(&chatlog); err != nil {
+ log.Println(err)
+ }
+ if err := store.SumDaily(chatlog.UserID); err != nil {
+ log.Println(err)
+ }
+ return nil
+ }
+ proxy.ServeHTTP(c.Writer, c.Request)
+}
+
+func probe(fileReader io.Reader) (time.Duration, error) {
+ ctx, cancelFn := context.WithTimeout(context.Background(), 5*time.Second)
+ defer cancelFn()
+
+ data, err := ffprobe.ProbeReader(ctx, fileReader)
+ if err != nil {
+ return 0, err
+ }
+
+ duration := data.Format.DurationSeconds
+ pduration, err := time.ParseDuration(fmt.Sprintf("%fs", duration))
+ if err != nil {
+ return 0, fmt.Errorf("Error parsing duration: %s", err)
+ }
+ return pduration, nil
+}
+
+func getAudioDuration(file *multipart.FileHeader) (time.Duration, error) {
+ var (
+ streamer beep.StreamSeekCloser
+ format beep.Format
+ err error
+ )
+
+ f, err := file.Open()
+ defer f.Close()
+
+ // Get the file extension to determine the audio file type
+ fileType := filepath.Ext(file.Filename)
+
+ switch fileType {
+ case ".mp3":
+ streamer, format, err = mp3.Decode(f)
+ case ".wav":
+ streamer, format, err = wav.Decode(f)
+ case ".m4a":
+ duration, err := probe(f)
+ if err != nil {
+ return 0, err
+ }
+ return duration, nil
+ default:
+ return 0, errors.New("unsupported audio file format")
+ }
+
+ if err != nil {
+ return 0, err
+ }
+ defer streamer.Close()
+
+ // Calculate the audio file's duration.
+ numSamples := streamer.Len()
+ sampleRate := format.SampleRate
+ duration := time.Duration(numSamples) * time.Second / time.Duration(sampleRate)
+
+ return duration, nil
+}
+
+func ParseWhisperRequestTokens(c *gin.Context, usage *store.Tokens, byteBody []byte) error {
+ file, _ := c.FormFile("file")
+ model, _ := c.GetPostForm("model")
+ usage.Model = model
+
+ if file != nil {
+ duration, err := getAudioDuration(file)
+ if err != nil {
+ return fmt.Errorf("Error getting audio duration:%s", err)
+ }
+
+ if duration > 5*time.Minute {
+ return fmt.Errorf("Audio duration exceeds 5 minutes")
+ }
+ // 计算时长,四舍五入到最接近的秒数
+ usage.PromptCount = int(duration.Round(time.Second).Seconds())
+ }
+
+ c.Request.Body = io.NopCloser(bytes.NewBuffer(byteBody))
+
+ return nil
+}
diff --git a/pkg/team/key.go b/pkg/team/key.go
new file mode 100644
index 0000000..802f10b
--- /dev/null
+++ b/pkg/team/key.go
@@ -0,0 +1,182 @@
+package team
+
+import (
+ "net/http"
+ "opencatd-open/pkg/azureopenai"
+ "opencatd-open/store"
+ "strings"
+
+ "github.com/Sakurasan/to"
+ "github.com/gin-gonic/gin"
+)
+
+type Key struct {
+ ID int `json:"id,omitempty"`
+ Key string `json:"key,omitempty"`
+ Name string `json:"name,omitempty"`
+ ApiType string `json:"api_type,omitempty"`
+ Endpoint string `json:"endpoint,omitempty"`
+ UpdatedAt string `json:"updatedAt,omitempty"`
+ CreatedAt string `json:"createdAt,omitempty"`
+}
+
+func HandleKeys(c *gin.Context) {
+ keys, err := store.GetAllKeys()
+ if err != nil {
+ c.JSON(http.StatusOK, gin.H{
+ "error": err.Error(),
+ })
+ }
+
+ c.JSON(http.StatusOK, keys)
+}
+
+func HandleAddKey(c *gin.Context) {
+ var body Key
+ if err := c.BindJSON(&body); err != nil {
+ c.JSON(http.StatusBadRequest, gin.H{"error": gin.H{
+ "message": err.Error(),
+ }})
+ return
+ }
+ body.Name = strings.ToLower(strings.TrimSpace(body.Name))
+ body.Key = strings.TrimSpace(body.Key)
+ if strings.HasPrefix(body.Name, "azure.") {
+ keynames := strings.Split(body.Name, ".")
+ if len(keynames) < 2 {
+ c.JSON(http.StatusBadRequest, gin.H{"error": gin.H{
+ "message": "Invalid Key Name",
+ }})
+ return
+ }
+ k := &store.Key{
+ ApiType: "azure",
+ Name: body.Name,
+ Key: body.Key,
+ ResourceNmae: keynames[1],
+ EndPoint: body.Endpoint,
+ }
+ if err := store.CreateKey(k); err != nil {
+ c.JSON(http.StatusInternalServerError, gin.H{"error": gin.H{
+ "message": err.Error(),
+ }})
+ return
+ }
+ } else if strings.HasPrefix(body.Name, "claude.") {
+ keynames := strings.Split(body.Name, ".")
+ if len(keynames) < 2 {
+ c.JSON(http.StatusBadRequest, gin.H{"error": gin.H{
+ "message": "Invalid Key Name",
+ }})
+ return
+ }
+ if body.Endpoint == "" {
+ body.Endpoint = "https://api.anthropic.com"
+ }
+ k := &store.Key{
+ // ApiType: "anthropic",
+ ApiType: "claude",
+ Name: body.Name,
+ Key: body.Key,
+ ResourceNmae: keynames[1],
+ EndPoint: body.Endpoint,
+ }
+ if err := store.CreateKey(k); err != nil {
+ c.JSON(http.StatusInternalServerError, gin.H{"error": gin.H{
+ "message": err.Error(),
+ }})
+ return
+ }
+ } else if strings.HasPrefix(body.Name, "google.") {
+ keynames := strings.Split(body.Name, ".")
+ if len(keynames) < 2 {
+ c.JSON(http.StatusBadRequest, gin.H{"error": gin.H{
+ "message": "Invalid Key Name",
+ }})
+ return
+ }
+
+ k := &store.Key{
+ // ApiType: "anthropic",
+ ApiType: "google",
+ Name: body.Name,
+ Key: body.Key,
+ ResourceNmae: keynames[1],
+ EndPoint: body.Endpoint,
+ }
+ if err := store.CreateKey(k); err != nil {
+ c.JSON(http.StatusInternalServerError, gin.H{"error": gin.H{
+ "message": err.Error(),
+ }})
+ return
+ }
+ } else if strings.HasPrefix(body.Name, "github.") {
+ keynames := strings.Split(body.Name, ".")
+ if len(keynames) < 2 {
+ c.JSON(http.StatusBadRequest, gin.H{"error": gin.H{
+ "message": "Invalid Key Name",
+ }})
+ return
+ }
+
+ k := &store.Key{
+ ApiType: "github",
+ Name: body.Name,
+ Key: body.Key,
+ ResourceNmae: keynames[1],
+ EndPoint: body.Endpoint,
+ }
+ if err := store.CreateKey(k); err != nil {
+ c.JSON(http.StatusInternalServerError, gin.H{"error": gin.H{
+ "message": err.Error(),
+ }})
+ return
+ }
+ } else {
+ if body.ApiType == "" {
+ if err := store.AddKey("openai", body.Key, body.Name); err != nil {
+ c.JSON(http.StatusInternalServerError, gin.H{"error": gin.H{
+ "message": err.Error(),
+ }})
+ return
+ }
+ } else {
+ k := &store.Key{
+ ApiType: body.ApiType,
+ Name: body.Name,
+ Key: body.Key,
+ ResourceNmae: azureopenai.GetResourceName(body.Endpoint),
+ EndPoint: body.Endpoint,
+ }
+ if err := store.CreateKey(k); err != nil {
+ c.JSON(http.StatusInternalServerError, gin.H{"error": gin.H{
+ "message": err.Error(),
+ }})
+ return
+ }
+ }
+
+ }
+
+ k, err := store.GetKeyrByName(body.Name)
+ if err != nil {
+ c.JSON(http.StatusInternalServerError, gin.H{"error": gin.H{
+ "message": err.Error(),
+ }})
+ return
+ }
+ c.JSON(http.StatusOK, k)
+}
+
+func HandleDelKey(c *gin.Context) {
+ id := to.Int(c.Param("id"))
+ if id < 1 {
+ c.JSON(http.StatusOK, gin.H{"error": "invalid key id"})
+ return
+ }
+ if err := store.DeleteKey(uint(id)); err != nil {
+ c.JSON(http.StatusOK, gin.H{"error": "invalid key id"})
+ return
+ }
+ c.JSON(http.StatusOK, gin.H{"message": "ok"})
+}
diff --git a/pkg/team/me.go b/pkg/team/me.go
new file mode 100644
index 0000000..146ce7c
--- /dev/null
+++ b/pkg/team/me.go
@@ -0,0 +1,104 @@
+package team
+
+import (
+ "errors"
+ "net/http"
+ "opencatd-open/store"
+ "time"
+
+ "github.com/Sakurasan/to"
+ "github.com/gin-gonic/gin"
+ "github.com/google/uuid"
+ "gorm.io/gorm"
+)
+
+func Handleinit(c *gin.Context) {
+ user, err := store.GetUserByID(1)
+ if err != nil {
+ if errors.Is(err, gorm.ErrRecordNotFound) {
+ u := store.User{Name: "root", Token: uuid.NewString()}
+ u.ID = 1
+ if err := store.CreateUser(&u); err != nil {
+ c.JSON(http.StatusForbidden, gin.H{
+ "error": err.Error(),
+ })
+ return
+ } else {
+ rootToken = u.Token
+ resJSON := User{
+ false,
+ int(u.ID),
+ u.UpdatedAt.Format(time.RFC3339),
+ u.Name,
+ u.Token,
+ u.CreatedAt.Format(time.RFC3339),
+ }
+ c.JSON(http.StatusOK, resJSON)
+ return
+ }
+ }
+ c.JSON(http.StatusOK, gin.H{
+ "error": err.Error(),
+ })
+ return
+ }
+ if user.ID == uint(1) {
+ c.JSON(http.StatusForbidden, gin.H{
+ "error": "super user already exists, use cli to reset password",
+ })
+ }
+}
+
+func HandleMe(c *gin.Context) {
+ token := c.GetHeader("Authorization")
+ u, err := store.GetUserByToken(token[7:])
+ if err != nil {
+ c.JSON(http.StatusOK, gin.H{
+ "error": err.Error(),
+ })
+ }
+
+ resJSON := User{
+ false,
+ int(u.ID),
+ u.UpdatedAt.Format(time.RFC3339),
+ u.Name,
+ u.Token,
+ u.CreatedAt.Format(time.RFC3339),
+ }
+ c.JSON(http.StatusOK, resJSON)
+}
+
+func HandleMeUsage(c *gin.Context) {
+ token := c.GetHeader("Authorization")
+ fromStr := c.Query("from")
+ toStr := c.Query("to")
+ getMonthStartAndEnd := func() (start, end string) {
+ loc, _ := time.LoadLocation("Local")
+ now := time.Now().In(loc)
+
+ year, month, _ := now.Date()
+
+ startOfMonth := time.Date(year, month, 1, 0, 0, 0, 0, loc)
+ endOfMonth := startOfMonth.AddDate(0, 1, 0)
+
+ start = startOfMonth.Format("2006-01-02")
+ end = endOfMonth.Format("2006-01-02")
+ return
+ }
+ if fromStr == "" || toStr == "" {
+ fromStr, toStr = getMonthStartAndEnd()
+ }
+ user, err := store.GetUserByToken(token)
+ if err != nil {
+ c.AbortWithError(http.StatusForbidden, err)
+ return
+ }
+ usage, err := store.QueryUserUsage(to.String(user.ID), fromStr, toStr)
+ if err != nil {
+ c.AbortWithError(http.StatusForbidden, err)
+ return
+ }
+
+ c.JSON(200, usage)
+}
diff --git a/pkg/team/middleware.go b/pkg/team/middleware.go
new file mode 100644
index 0000000..ab01e79
--- /dev/null
+++ b/pkg/team/middleware.go
@@ -0,0 +1,59 @@
+package team
+
+import (
+ "log"
+ "net/http"
+ "opencatd-open/store"
+ "strings"
+
+ "github.com/gin-gonic/gin"
+)
+
+var (
+ rootToken string
+)
+
+func AuthMiddleware() gin.HandlerFunc {
+ return func(c *gin.Context) {
+ if rootToken == "" {
+ u, err := store.GetUserByID(uint(1))
+ if err != nil {
+ c.JSON(http.StatusUnauthorized, gin.H{"error": "Unauthorized"})
+ c.Abort()
+ return
+ }
+ rootToken = u.Token
+ }
+ token := c.GetHeader("Authorization")
+ if token == "" || token[:7] != "Bearer " {
+ c.JSON(http.StatusUnauthorized, gin.H{"error": "Unauthorized"})
+ c.Abort()
+ return
+ }
+ if store.IsExistAuthCache(token[7:]) {
+ if strings.HasPrefix(c.Request.URL.Path, "/1/me") {
+ c.Next()
+ return
+ }
+ }
+ if token[7:] != rootToken {
+ u, err := store.GetUserByID(uint(1))
+ if err != nil {
+ log.Println(err)
+ c.JSON(http.StatusUnauthorized, gin.H{"error": "Unauthorized"})
+ c.Abort()
+ return
+ }
+ if token[:7] != u.Token {
+ c.JSON(http.StatusUnauthorized, gin.H{"error": "Unauthorized"})
+ c.Abort()
+ return
+ }
+ rootToken = u.Token
+ store.LoadAuthCache()
+ }
+ // 可以在这里对 token 进行验证并检查权限
+
+ c.Next()
+ }
+}
diff --git a/pkg/team/usage.go b/pkg/team/usage.go
new file mode 100644
index 0000000..5d9763e
--- /dev/null
+++ b/pkg/team/usage.go
@@ -0,0 +1,38 @@
+package team
+
+import (
+ "net/http"
+ "opencatd-open/store"
+ "time"
+
+ "github.com/gin-gonic/gin"
+)
+
+func HandleUsage(c *gin.Context) {
+ fromStr := c.Query("from")
+ toStr := c.Query("to")
+ getMonthStartAndEnd := func() (start, end string) {
+ loc, _ := time.LoadLocation("Local")
+ now := time.Now().In(loc)
+
+ year, month, _ := now.Date()
+
+ startOfMonth := time.Date(year, month, 1, 0, 0, 0, 0, loc)
+ endOfMonth := startOfMonth.AddDate(0, 1, 0)
+
+ start = startOfMonth.Format("2006-01-02")
+ end = endOfMonth.Format("2006-01-02")
+ return
+ }
+ if fromStr == "" || toStr == "" {
+ fromStr, toStr = getMonthStartAndEnd()
+ }
+
+ usage, err := store.QueryUsage(fromStr, toStr)
+ if err != nil {
+ c.JSON(http.StatusForbidden, gin.H{"error": err.Error()})
+ return
+ }
+
+ c.JSON(200, usage)
+}
diff --git a/pkg/team/user.go b/pkg/team/user.go
new file mode 100644
index 0000000..b848be1
--- /dev/null
+++ b/pkg/team/user.go
@@ -0,0 +1,89 @@
+package team
+
+import (
+ "net/http"
+ "opencatd-open/store"
+
+ "github.com/Sakurasan/to"
+ "github.com/gin-gonic/gin"
+ "github.com/google/uuid"
+)
+
+type User struct {
+ IsDelete bool `json:"IsDelete,omitempty"`
+ ID int `json:"id,omitempty"`
+ UpdatedAt string `json:"updatedAt,omitempty"`
+ Name string `json:"name,omitempty"`
+ Token string `json:"token,omitempty"`
+ CreatedAt string `json:"createdAt,omitempty"`
+}
+
+func HandleUsers(c *gin.Context) {
+ users, err := store.GetAllUsers()
+ if err != nil {
+ c.JSON(http.StatusOK, gin.H{
+ "error": err.Error(),
+ })
+ }
+
+ c.JSON(http.StatusOK, users)
+}
+
+func HandleAddUser(c *gin.Context) {
+ var body User
+ if err := c.BindJSON(&body); err != nil {
+ c.JSON(http.StatusOK, gin.H{"error": err.Error()})
+ return
+ }
+ if len(body.Name) == 0 {
+ c.JSON(http.StatusOK, gin.H{"error": "invalid user name"})
+ return
+ }
+
+ if err := store.AddUser(body.Name, uuid.NewString()); err != nil {
+ c.JSON(http.StatusOK, gin.H{"error": err.Error()})
+ return
+ }
+ u, err := store.GetUserByName(body.Name)
+ if err != nil {
+ c.JSON(http.StatusOK, gin.H{"error": err.Error()})
+ return
+ }
+ c.JSON(http.StatusOK, u)
+}
+
+func HandleDelUser(c *gin.Context) {
+ id := to.Int(c.Param("id"))
+ if id <= 1 {
+ c.JSON(http.StatusOK, gin.H{"error": "invalid user id"})
+ return
+ }
+ if err := store.DeleteUser(uint(id)); err != nil {
+ c.JSON(http.StatusOK, gin.H{"error": err.Error()})
+ return
+ }
+
+ c.JSON(http.StatusOK, gin.H{"message": "ok"})
+}
+
+func HandleResetUserToken(c *gin.Context) {
+ id := to.Int(c.Param("id"))
+ newtoken := c.Query("token")
+ if newtoken == "" {
+ newtoken = uuid.NewString()
+ }
+
+ if err := store.UpdateUser(uint(id), newtoken); err != nil {
+ c.JSON(http.StatusForbidden, gin.H{"error": err.Error()})
+ return
+ }
+ u, err := store.GetUserByID(uint(id))
+ if err != nil {
+ c.JSON(http.StatusForbidden, gin.H{"error": err.Error()})
+ return
+ }
+ if u.ID == uint(1) {
+ rootToken = u.Token
+ }
+ c.JSON(http.StatusOK, u)
+}
diff --git a/pkg/tokenizer/tokenizer.go b/pkg/tokenizer/tokenizer.go
index 3f8be85..4176396 100644
--- a/pkg/tokenizer/tokenizer.go
+++ b/pkg/tokenizer/tokenizer.go
@@ -97,8 +97,14 @@ func Cost(model string, promptCount, completionCount int) float64 {
cost = 0.01*float64(prompt/1000) + 0.03*float64(completion/1000)
case "gpt-4o", "gpt-4o-2024-05-13":
cost = 0.005*float64(prompt/1000) + 0.015*float64(completion/1000)
+ case "gpt-4o-2024-08-06":
+ cost = 0.0025*float64(prompt/1000) + 0.010*float64(completion/1000)
case "gpt-4o-mini", "gpt-4o-mini-2024-07-18":
cost = 0.00015*float64(prompt/1000) + 0.0006*float64(completion/1000)
+ case "o1-preview", "o1-preview-2024-09-12":
+ cost = 0.015*float64(prompt/1000) + 0.06*float64(completion/1000)
+ case "o1-mini", "o1-mini-2024-09-12":
+ cost = 0.003*float64(prompt/1000) + 0.012*float64(completion/1000)
case "whisper-1":
// 0.006$/min
cost = 0.006 * float64(prompt+completion) / 60
@@ -149,7 +155,8 @@ func Cost(model string, promptCount, completionCount int) float64 {
cost = (0.003/1000)*float64(prompt) + (0.015/1000)*float64(completion)
case "claude-3-opus-20240229":
cost = (0.015/1000)*float64(prompt) + (0.075/1000)*float64(completion)
-
+ case "claude-3-5-sonnet", "claude-3-5-sonnet-20240620":
+ cost = (0.003/1000)*float64(prompt) + (0.015/1000)*float64(completion)
// google
// https://ai.google.dev/pricing?hl=zh-cn
case "gemini-pro":
diff --git a/pkg/vertexai/auth.go b/pkg/vertexai/auth.go
new file mode 100644
index 0000000..a7a6967
--- /dev/null
+++ b/pkg/vertexai/auth.go
@@ -0,0 +1,167 @@
+/*
+https://docs.anthropic.com/zh-CN/api/claude-on-vertex-ai
+
+MODEL_ID=claude-3-5-sonnet@20240620
+REGION=us-east5
+PROJECT_ID=MY_PROJECT_ID
+
+curl \
+-X POST \
+-H "Authorization: Bearer $(gcloud auth print-access-token)" \
+-H "Content-Type: application/json" \
+https://$LOCATION-aiplatform.googleapis.com/v1/projects/${PROJECT_ID}/locations/${LOCATION}/publishers/anthropic/models/${MODEL_ID}:streamRawPredict \
+-d '{
+ "anthropic_version": "vertex-2023-10-16",
+ "messages": [{
+ "role": "user",
+ "content": "介绍一下你自己"
+ }],
+ "stream": true,
+ "max_tokens": 4096
+}'
+*/
+
+package vertexai
+
+import (
+ "crypto/rsa"
+ "crypto/x509"
+ "encoding/json"
+ "encoding/pem"
+ "fmt"
+ "net/http"
+ "net/url"
+ "time"
+
+ "github.com/golang-jwt/jwt"
+)
+
+// json文件存储在ApiKey.ApiSecret中
+type VertexSecretKey struct {
+ Type string `json:"type"`
+ ProjectID string `json:"project_id"`
+ PrivateKeyID string `json:"private_key_id"`
+ PrivateKey string `json:"private_key"`
+ ClientEmail string `json:"client_email"`
+ ClientID string `json:"client_id"`
+ AuthURI string `json:"auth_uri"`
+ TokenURI string `json:"token_uri"`
+ AuthProviderX509CertURL string `json:"auth_provider_x509_cert_url"`
+ ClientX509CertURL string `json:"client_x509_cert_url"`
+ UniverseDomain string `json:"universe_domain"`
+}
+
+type VertexClaudeModel struct {
+ VertexName string
+ Region string
+}
+
+var VertexClaudeModelMap = map[string]VertexClaudeModel{
+ "claude-3-opus": {
+ VertexName: "claude-3-opus@20240229",
+ Region: "us-east5",
+ },
+ "claude-3-sonnet": {
+ VertexName: "claude-3-sonnet@20240229",
+ Region: "us-central1",
+ // Region: "asia-southeast1",
+ },
+ "claude-3-haiku": {
+ VertexName: "claude-3-haiku@20240307",
+ Region: "us-central1",
+ // Region: "europe-west4",
+ },
+ "claude-3-opus-20240229": {
+ VertexName: "claude-3-opus@20240229",
+ Region: "us-east5",
+ },
+ "claude-3-sonnet-20240229": {
+ VertexName: "claude-3-sonnet@20240229",
+ Region: "us-central1",
+ // Region: "asia-southeast1",
+ },
+ "claude-3-haiku-20240307": {
+ VertexName: "claude-3-haiku@20240307",
+ Region: "us-central1",
+ // Region: "europe-west4",
+ },
+ "claude-3-5-sonnet": {
+ VertexName: "claude-3-5-sonnet@20240620",
+ Region: "us-east5",
+ // Region: "europe-west1",
+ },
+ "claude-3-5-sonnet-20240620": {
+ VertexName: "claude-3-5-sonnet@20240620",
+ Region: "us-east5",
+ // Region: "europe-west1",
+ },
+}
+
+func createSignedJWT(email, privateKeyPEM string) (string, error) {
+ block, _ := pem.Decode([]byte(privateKeyPEM))
+ if block == nil {
+ return "", fmt.Errorf("failed to parse PEM block containing the private key")
+ }
+
+ privateKey, err := x509.ParsePKCS8PrivateKey(block.Bytes)
+ if err != nil {
+ return "", err
+ }
+
+ rsaKey, ok := privateKey.(*rsa.PrivateKey)
+ if !ok {
+ return "", fmt.Errorf("not an RSA private key")
+ }
+
+ now := time.Now()
+ claims := jwt.MapClaims{
+ "iss": email,
+ "aud": "https://www.googleapis.com/oauth2/v4/token",
+ "iat": now.Unix(),
+ "exp": now.Add(10 * time.Minute).Unix(),
+ "scope": "https://www.googleapis.com/auth/cloud-platform",
+ }
+
+ token := jwt.NewWithClaims(jwt.SigningMethodRS256, claims)
+ return token.SignedString(rsaKey)
+}
+
+func exchangeJwtForAccessToken(signedJWT string) (string, error) {
+ authURL := "https://www.googleapis.com/oauth2/v4/token"
+ data := url.Values{}
+ data.Set("grant_type", "urn:ietf:params:oauth:grant-type:jwt-bearer")
+ data.Set("assertion", signedJWT)
+
+ resp, err := http.PostForm(authURL, data)
+ if err != nil {
+ return "", err
+ }
+ defer resp.Body.Close()
+
+ var result map[string]interface{}
+ if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
+ return "", err
+ }
+
+ accessToken, ok := result["access_token"].(string)
+ if !ok {
+ return "", fmt.Errorf("access token not found in response")
+ }
+
+ return accessToken, nil
+}
+
+// 获取gcloud auth token
+func GcloudAuth(ClientEmail, PrivateKey string) (string, error) {
+ signedJWT, err := createSignedJWT(ClientEmail, PrivateKey)
+ if err != nil {
+ return "", err
+ }
+
+ token, err := exchangeJwtForAccessToken(signedJWT)
+ if err != nil {
+ return "", fmt.Errorf("Invalid jwt token: %v\n", err)
+ }
+
+ return token, nil
+}
diff --git a/router/chat.go b/router/chat.go
index 509f6d9..5292a9f 100644
--- a/router/chat.go
+++ b/router/chat.go
@@ -18,33 +18,7 @@ func ChatHandler(c *gin.Context) {
return
}
- // if chatreq.Messages[len(chatreq.Messages)-1].Role == "user" {
- // result, err := search.BingSearch(search.SearchParams{Query: string(chatreq.Messages[len(chatreq.Messages)-1].Content)})
- // if err == nil {
- // var msgs []openai.ChatCompletionMessage
- // for i, m := range chatreq.Messages {
- // var buf bytes.Buffer
- // buf.WriteString("根据我提问的语言回答我,我将提供一些从搜索引擎获取的信息(以websearch:开头)。你自行判断是否使用搜索引擎获取的内容。不要原封不动照抄,根据你自己的知识库提炼信息之后回答我\n\n")
- // if m.Role == "system" {
- // buf.Write(m.Content)
- // msgs = append(msgs, openai.ChatCompletionMessage{Role: m.Role, Content: buf.Bytes()})
- // } else {
- // msgs = append(msgs, openai.ChatCompletionMessage{Role: m.Role, Content: buf.Bytes()})
- // }
- // if i == len(chatreq.Messages)-1 {
- // m.Content = append(m.Content, json.RawMessage("\n\nwebsearch:")...)
- // m.Content = append(m.Content, json.RawMessage(result.(string))...)
- // msgs = append(msgs, openai.ChatCompletionMessage{Role: m.Role, Content: m.Content})
- // } else {
- // msgs = append(msgs, openai.ChatCompletionMessage{Role: m.Role, Content: m.Content})
- // }
-
- // }
- // chatreq.Messages = msgs
- // }
- // }
-
- if strings.HasPrefix(chatreq.Model, "gpt") {
+ if strings.HasPrefix(chatreq.Model, "gpt") || strings.HasPrefix(chatreq.Model, "o1-") {
openai.ChatProxy(c, &chatreq)
return
}
diff --git a/router/router.go b/router/router.go
index 7e639e4..d1e250b 100644
--- a/router/router.go
+++ b/router/router.go
@@ -1,108 +1,66 @@
package router
import (
- "bufio"
- "bytes"
- "context"
"crypto/tls"
- "encoding/json"
- "errors"
"fmt"
- "io"
"log"
- "mime/multipart"
"net"
"net/http"
"net/http/httputil"
- "net/url"
- "opencatd-open/pkg/azureopenai"
"opencatd-open/pkg/claude"
oai "opencatd-open/pkg/openai"
- "opencatd-open/pkg/tokenizer"
"opencatd-open/store"
"os"
- "path/filepath"
- "strings"
"time"
- "github.com/Sakurasan/to"
- "github.com/duke-git/lancet/v2/cryptor"
- "github.com/faiface/beep"
- "github.com/faiface/beep/mp3"
- "github.com/faiface/beep/wav"
"github.com/gin-gonic/gin"
- "github.com/google/uuid"
- "github.com/sashabaranov/go-openai"
- "gopkg.in/vansante/go-ffprobe.v2"
- "gorm.io/gorm"
)
var (
- rootToken string
baseUrl = "https://api.openai.com"
GPT3Dot5Turbo = "gpt-3.5-turbo"
GPT4 = "gpt-4"
- client = getHttpClient()
)
-type User struct {
- IsDelete bool `json:"IsDelete,omitempty"`
- ID int `json:"id,omitempty"`
- UpdatedAt string `json:"updatedAt,omitempty"`
- Name string `json:"name,omitempty"`
- Token string `json:"token,omitempty"`
- CreatedAt string `json:"createdAt,omitempty"`
-}
+// type ChatCompletionMessage struct {
+// Role string `json:"role"`
+// Content string `json:"content"`
+// Name string `json:"name,omitempty"`
+// }
-type Key struct {
- ID int `json:"id,omitempty"`
- Key string `json:"key,omitempty"`
- Name string `json:"name,omitempty"`
- ApiType string `json:"api_type,omitempty"`
- Endpoint string `json:"endpoint,omitempty"`
- UpdatedAt string `json:"updatedAt,omitempty"`
- CreatedAt string `json:"createdAt,omitempty"`
-}
+// type ChatCompletionRequest struct {
+// Model string `json:"model"`
+// Messages []ChatCompletionMessage `json:"messages"`
+// MaxTokens int `json:"max_tokens,omitempty"`
+// Temperature float32 `json:"temperature,omitempty"`
+// TopP float32 `json:"top_p,omitempty"`
+// N int `json:"n,omitempty"`
+// Stream bool `json:"stream,omitempty"`
+// Stop []string `json:"stop,omitempty"`
+// PresencePenalty float32 `json:"presence_penalty,omitempty"`
+// FrequencyPenalty float32 `json:"frequency_penalty,omitempty"`
+// LogitBias map[string]int `json:"logit_bias,omitempty"`
+// User string `json:"user,omitempty"`
+// }
-type ChatCompletionMessage struct {
- Role string `json:"role"`
- Content string `json:"content"`
- Name string `json:"name,omitempty"`
-}
+// type ChatCompletionChoice struct {
+// Index int `json:"index"`
+// Message ChatCompletionMessage `json:"message"`
+// FinishReason string `json:"finish_reason"`
+// }
-type ChatCompletionRequest struct {
- Model string `json:"model"`
- Messages []ChatCompletionMessage `json:"messages"`
- MaxTokens int `json:"max_tokens,omitempty"`
- Temperature float32 `json:"temperature,omitempty"`
- TopP float32 `json:"top_p,omitempty"`
- N int `json:"n,omitempty"`
- Stream bool `json:"stream,omitempty"`
- Stop []string `json:"stop,omitempty"`
- PresencePenalty float32 `json:"presence_penalty,omitempty"`
- FrequencyPenalty float32 `json:"frequency_penalty,omitempty"`
- LogitBias map[string]int `json:"logit_bias,omitempty"`
- User string `json:"user,omitempty"`
-}
-
-type ChatCompletionChoice struct {
- Index int `json:"index"`
- Message ChatCompletionMessage `json:"message"`
- FinishReason string `json:"finish_reason"`
-}
-
-type ChatCompletionResponse struct {
- ID string `json:"id"`
- Object string `json:"object"`
- Created int64 `json:"created"`
- Model string `json:"model"`
- Choices []ChatCompletionChoice `json:"choices"`
- Usage struct {
- PromptTokens int `json:"prompt_tokens"`
- CompletionTokens int `json:"completion_tokens"`
- TotalTokens int `json:"total_tokens"`
- } `json:"usage"`
-}
+// type ChatCompletionResponse struct {
+// ID string `json:"id"`
+// Object string `json:"object"`
+// Created int64 `json:"created"`
+// Model string `json:"model"`
+// Choices []ChatCompletionChoice `json:"choices"`
+// Usage struct {
+// PromptTokens int `json:"prompt_tokens"`
+// CompletionTokens int `json:"completion_tokens"`
+// TotalTokens int `json:"total_tokens"`
+// } `json:"usage"`
+// }
func init() {
if openai_endpoint := os.Getenv("openai_endpoint"); openai_endpoint != "" {
@@ -111,385 +69,9 @@ func init() {
}
}
-func AuthMiddleware() gin.HandlerFunc {
- return func(c *gin.Context) {
- if rootToken == "" {
- u, err := store.GetUserByID(uint(1))
- if err != nil {
- c.JSON(http.StatusUnauthorized, gin.H{"error": "Unauthorized"})
- c.Abort()
- return
- }
- rootToken = u.Token
- }
- token := c.GetHeader("Authorization")
- if token == "" || token[:7] != "Bearer " {
- c.JSON(http.StatusUnauthorized, gin.H{"error": "Unauthorized"})
- c.Abort()
- return
- }
- if store.IsExistAuthCache(token[7:]) {
- if strings.HasPrefix(c.Request.URL.Path, "/1/me") {
- c.Next()
- return
- }
- }
- if token[7:] != rootToken {
- u, err := store.GetUserByID(uint(1))
- if err != nil {
- log.Println(err)
- c.JSON(http.StatusUnauthorized, gin.H{"error": "Unauthorized"})
- c.Abort()
- return
- }
- if token[:7] != u.Token {
- c.JSON(http.StatusUnauthorized, gin.H{"error": "Unauthorized"})
- c.Abort()
- return
- }
- rootToken = u.Token
- store.LoadAuthCache()
- }
- // 可以在这里对 token 进行验证并检查权限
-
- c.Next()
- }
-}
-
-func Handleinit(c *gin.Context) {
- user, err := store.GetUserByID(1)
- if err != nil {
- if errors.Is(err, gorm.ErrRecordNotFound) {
- u := store.User{Name: "root", Token: uuid.NewString()}
- u.ID = 1
- if err := store.CreateUser(&u); err != nil {
- c.JSON(http.StatusForbidden, gin.H{
- "error": err.Error(),
- })
- return
- } else {
- rootToken = u.Token
- resJSON := User{
- false,
- int(u.ID),
- u.UpdatedAt.Format(time.RFC3339),
- u.Name,
- u.Token,
- u.CreatedAt.Format(time.RFC3339),
- }
- c.JSON(http.StatusOK, resJSON)
- return
- }
- }
- c.JSON(http.StatusOK, gin.H{
- "error": err.Error(),
- })
- return
- }
- if user.ID == uint(1) {
- c.JSON(http.StatusForbidden, gin.H{
- "error": "super user already exists, use cli to reset password",
- })
- }
-}
-
-func HandleMe(c *gin.Context) {
- token := c.GetHeader("Authorization")
- u, err := store.GetUserByToken(token[7:])
- if err != nil {
- c.JSON(http.StatusOK, gin.H{
- "error": err.Error(),
- })
- }
-
- resJSON := User{
- false,
- int(u.ID),
- u.UpdatedAt.Format(time.RFC3339),
- u.Name,
- u.Token,
- u.CreatedAt.Format(time.RFC3339),
- }
- c.JSON(http.StatusOK, resJSON)
-}
-
-func HandleMeUsage(c *gin.Context) {
- token := c.GetHeader("Authorization")
- fromStr := c.Query("from")
- toStr := c.Query("to")
- getMonthStartAndEnd := func() (start, end string) {
- loc, _ := time.LoadLocation("Local")
- now := time.Now().In(loc)
-
- year, month, _ := now.Date()
-
- startOfMonth := time.Date(year, month, 1, 0, 0, 0, 0, loc)
- endOfMonth := startOfMonth.AddDate(0, 1, 0)
-
- start = startOfMonth.Format("2006-01-02")
- end = endOfMonth.Format("2006-01-02")
- return
- }
- if fromStr == "" || toStr == "" {
- fromStr, toStr = getMonthStartAndEnd()
- }
- user, err := store.GetUserByToken(token)
- if err != nil {
- c.AbortWithError(http.StatusForbidden, err)
- return
- }
- usage, err := store.QueryUserUsage(to.String(user.ID), fromStr, toStr)
- if err != nil {
- c.AbortWithError(http.StatusForbidden, err)
- return
- }
-
- c.JSON(200, usage)
-}
-
-func HandleKeys(c *gin.Context) {
- keys, err := store.GetAllKeys()
- if err != nil {
- c.JSON(http.StatusOK, gin.H{
- "error": err.Error(),
- })
- }
-
- c.JSON(http.StatusOK, keys)
-}
-
-func HandleUsers(c *gin.Context) {
- users, err := store.GetAllUsers()
- if err != nil {
- c.JSON(http.StatusOK, gin.H{
- "error": err.Error(),
- })
- }
-
- c.JSON(http.StatusOK, users)
-}
-
-func HandleAddKey(c *gin.Context) {
- var body Key
- if err := c.BindJSON(&body); err != nil {
- c.JSON(http.StatusBadRequest, gin.H{"error": gin.H{
- "message": err.Error(),
- }})
- return
- }
- body.Name = strings.ToLower(strings.TrimSpace(body.Name))
- body.Key = strings.TrimSpace(body.Key)
- if strings.HasPrefix(body.Name, "azure.") {
- keynames := strings.Split(body.Name, ".")
- if len(keynames) < 2 {
- c.JSON(http.StatusBadRequest, gin.H{"error": gin.H{
- "message": "Invalid Key Name",
- }})
- return
- }
- k := &store.Key{
- ApiType: "azure",
- Name: body.Name,
- Key: body.Key,
- ResourceNmae: keynames[1],
- EndPoint: body.Endpoint,
- }
- if err := store.CreateKey(k); err != nil {
- c.JSON(http.StatusInternalServerError, gin.H{"error": gin.H{
- "message": err.Error(),
- }})
- return
- }
- } else if strings.HasPrefix(body.Name, "claude.") {
- keynames := strings.Split(body.Name, ".")
- if len(keynames) < 2 {
- c.JSON(http.StatusBadRequest, gin.H{"error": gin.H{
- "message": "Invalid Key Name",
- }})
- return
- }
- if body.Endpoint == "" {
- body.Endpoint = "https://api.anthropic.com"
- }
- k := &store.Key{
- // ApiType: "anthropic",
- ApiType: "claude",
- Name: body.Name,
- Key: body.Key,
- ResourceNmae: keynames[1],
- EndPoint: body.Endpoint,
- }
- if err := store.CreateKey(k); err != nil {
- c.JSON(http.StatusInternalServerError, gin.H{"error": gin.H{
- "message": err.Error(),
- }})
- return
- }
- } else if strings.HasPrefix(body.Name, "google.") {
- keynames := strings.Split(body.Name, ".")
- if len(keynames) < 2 {
- c.JSON(http.StatusBadRequest, gin.H{"error": gin.H{
- "message": "Invalid Key Name",
- }})
- return
- }
-
- k := &store.Key{
- // ApiType: "anthropic",
- ApiType: "google",
- Name: body.Name,
- Key: body.Key,
- ResourceNmae: keynames[1],
- EndPoint: body.Endpoint,
- }
- if err := store.CreateKey(k); err != nil {
- c.JSON(http.StatusInternalServerError, gin.H{"error": gin.H{
- "message": err.Error(),
- }})
- return
- }
- } else {
- if body.ApiType == "" {
- if err := store.AddKey("openai", body.Key, body.Name); err != nil {
- c.JSON(http.StatusInternalServerError, gin.H{"error": gin.H{
- "message": err.Error(),
- }})
- return
- }
- } else {
- k := &store.Key{
- ApiType: body.ApiType,
- Name: body.Name,
- Key: body.Key,
- ResourceNmae: azureopenai.GetResourceName(body.Endpoint),
- EndPoint: body.Endpoint,
- }
- if err := store.CreateKey(k); err != nil {
- c.JSON(http.StatusInternalServerError, gin.H{"error": gin.H{
- "message": err.Error(),
- }})
- return
- }
- }
-
- }
-
- k, err := store.GetKeyrByName(body.Name)
- if err != nil {
- c.JSON(http.StatusInternalServerError, gin.H{"error": gin.H{
- "message": err.Error(),
- }})
- return
- }
- c.JSON(http.StatusOK, k)
-}
-
-func HandleDelKey(c *gin.Context) {
- id := to.Int(c.Param("id"))
- if id < 1 {
- c.JSON(http.StatusOK, gin.H{"error": "invalid key id"})
- return
- }
- if err := store.DeleteKey(uint(id)); err != nil {
- c.JSON(http.StatusOK, gin.H{"error": "invalid key id"})
- return
- }
- c.JSON(http.StatusOK, gin.H{"message": "ok"})
-}
-
-func HandleAddUser(c *gin.Context) {
- var body User
- if err := c.BindJSON(&body); err != nil {
- c.JSON(http.StatusOK, gin.H{"error": err.Error()})
- return
- }
- if len(body.Name) == 0 {
- c.JSON(http.StatusOK, gin.H{"error": "invalid user name"})
- return
- }
-
- if err := store.AddUser(body.Name, uuid.NewString()); err != nil {
- c.JSON(http.StatusOK, gin.H{"error": err.Error()})
- return
- }
- u, err := store.GetUserByName(body.Name)
- if err != nil {
- c.JSON(http.StatusOK, gin.H{"error": err.Error()})
- return
- }
- c.JSON(http.StatusOK, u)
-}
-
-func HandleDelUser(c *gin.Context) {
- id := to.Int(c.Param("id"))
- if id <= 1 {
- c.JSON(http.StatusOK, gin.H{"error": "invalid user id"})
- return
- }
- if err := store.DeleteUser(uint(id)); err != nil {
- c.JSON(http.StatusOK, gin.H{"error": err.Error()})
- return
- }
-
- c.JSON(http.StatusOK, gin.H{"message": "ok"})
-}
-
-func HandleResetUserToken(c *gin.Context) {
- id := to.Int(c.Param("id"))
- newtoken := c.Query("token")
- if newtoken == "" {
- newtoken = uuid.NewString()
- }
-
- if err := store.UpdateUser(uint(id), newtoken); err != nil {
- c.JSON(http.StatusForbidden, gin.H{"error": err.Error()})
- return
- }
- u, err := store.GetUserByID(uint(id))
- if err != nil {
- c.JSON(http.StatusForbidden, gin.H{"error": err.Error()})
- return
- }
- if u.ID == uint(1) {
- rootToken = u.Token
- }
- c.JSON(http.StatusOK, u)
-}
-
-func GenerateToken() string {
- token := uuid.New()
- return token.String()
-}
-
-func getHttpClient() *http.Client {
- tr := &http.Transport{
- Proxy: http.ProxyFromEnvironment,
- DialContext: (&net.Dialer{
- Timeout: 30 * time.Second,
- KeepAlive: 30 * time.Second,
- }).DialContext,
- ForceAttemptHTTP2: true,
- MaxIdleConns: 100,
- IdleConnTimeout: 90 * time.Second,
- TLSHandshakeTimeout: 10 * time.Second,
- ExpectContinueTimeout: 1 * time.Second,
- TLSClientConfig: &tls.Config{InsecureSkipVerify: true},
- }
- return &http.Client{Transport: tr}
-}
-
-func HandleProy(c *gin.Context) {
+func HandleProxy(c *gin.Context) {
var (
- localuser bool
- isStream bool
- chatreq = openai.ChatCompletionRequest{}
- chatres = openai.ChatCompletionResponse{}
- chatlog store.Tokens
- onekey store.Key
- pre_prompt string
- req *http.Request
- err error
- // wg sync.WaitGroup
+ localuser bool
)
auth := c.Request.Header.Get("Authorization")
if len(auth) > 7 && auth[:7] == "Bearer " {
@@ -497,11 +79,17 @@ func HandleProy(c *gin.Context) {
c.Set("localuser", auth[7:])
}
if c.Request.URL.Path == "/v1/complete" {
- claude.ClaudeProxy(c)
- return
+ if localuser {
+ claude.ClaudeProxy(c)
+ return
+ } else {
+ HandleReverseProxy(c, "api.anthropic.com")
+ return
+ }
+
}
if c.Request.URL.Path == "/v1/audio/transcriptions" {
- WhisperProxy(c)
+ oai.WhisperProxy(c)
return
}
if c.Request.URL.Path == "/v1/audio/speech" {
@@ -514,191 +102,30 @@ func HandleProy(c *gin.Context) {
return
}
- if c.Request.URL.Path == "/v1/chat/completions" && localuser {
- if store.KeysCache.ItemCount() == 0 {
- c.JSON(http.StatusBadGateway, gin.H{"error": gin.H{
- "message": "No Api-Key Available",
- }})
- return
- }
-
- ChatHandler(c)
- return
-
- if err := c.BindJSON(&chatreq); err != nil {
- c.AbortWithError(http.StatusBadRequest, err)
- return
- }
-
- chatlog.Model = chatreq.Model
- for _, m := range chatreq.Messages {
- pre_prompt += m.Content + "\n"
- }
- chatlog.PromptHash = cryptor.Md5String(pre_prompt)
- chatlog.PromptCount = tokenizer.NumTokensFromMessages(chatreq.Messages, chatreq.Model)
- isStream = chatreq.Stream
- chatlog.UserID, _ = store.GetUserID(auth[7:])
-
- var body bytes.Buffer
- json.NewEncoder(&body).Encode(chatreq)
-
- if strings.HasPrefix(chatreq.Model, "claude-") {
- onekey, err = store.SelectKeyCache("claude")
- if err != nil {
- c.AbortWithError(http.StatusForbidden, err)
- }
- } else {
- onekey = store.FromKeyCacheRandomItemKey()
- }
-
- // 创建 API 请求
- switch onekey.ApiType {
- case "claude":
- payload, _ := claude.TransReq(&chatreq)
- buildurl := "https://api.anthropic.com/v1/complete"
- req, err = http.NewRequest("POST", buildurl, payload)
- req.Header.Add("accept", "application/json")
- req.Header.Add("anthropic-version", "2023-06-01")
- req.Header.Add("x-api-key", onekey.Key)
- req.Header.Add("content-type", "application/json")
- case "azure":
- fallthrough
- case "azure_openai":
- var buildurl string
- var apiVersion = "2023-05-15"
- if onekey.EndPoint != "" {
- buildurl = fmt.Sprintf("%s/openai/deployments/%s/chat/completions?api-version=%s", onekey.EndPoint, modelmap(chatreq.Model), apiVersion)
- } else {
- buildurl = fmt.Sprintf("https://%s.openai.azure.com/openai/deployments/%s/chat/completions?api-version=%s", onekey.ResourceNmae, modelmap(chatreq.Model), apiVersion)
- }
- req, err = http.NewRequest(c.Request.Method, buildurl, &body)
- req.Header = c.Request.Header
- req.Header.Set("api-key", onekey.Key)
- case "openai":
- fallthrough
- default:
- if onekey.EndPoint != "" {
- req, err = http.NewRequest(c.Request.Method, onekey.EndPoint+c.Request.RequestURI, &body)
- } else {
- req, err = http.NewRequest(c.Request.Method, baseUrl+c.Request.RequestURI, &body)
- }
-
- req.Header = c.Request.Header
- req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", onekey.Key))
- }
- if err != nil {
- log.Println(err)
- c.JSON(http.StatusOK, gin.H{"error": err.Error()})
- return
- }
-
- } else {
- req, err = http.NewRequest(c.Request.Method, baseUrl+c.Request.RequestURI, c.Request.Body)
- if err != nil {
- log.Println(err)
- c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
- return
- }
- req.Header = c.Request.Header
- }
-
- resp, err := client.Do(req)
- if err != nil {
- log.Println(err)
- c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
- return
- }
- defer resp.Body.Close()
-
- // 复制 API 响应头部
- for name, values := range resp.Header {
- for _, value := range values {
- c.Writer.Header().Add(name, value)
- }
- }
- head := map[string]string{
- "Cache-Control": "no-store",
- "access-control-allow-origin": "*",
- "access-control-allow-credentials": "true",
- }
- for k, v := range head {
- if _, ok := resp.Header[k]; !ok {
- c.Writer.Header().Set(k, v)
- }
- }
- resp.Header.Del("content-security-policy")
- resp.Header.Del("content-security-policy-report-only")
- resp.Header.Del("clear-site-data")
-
- c.Writer.WriteHeader(resp.StatusCode)
- writer := bufio.NewWriter(c.Writer)
- defer writer.Flush()
-
- reader := bufio.NewReader(resp.Body)
-
- if resp.StatusCode == 200 && localuser {
- switch onekey.ApiType {
- case "claude":
- claude.TransRsp(c, isStream, chatlog, reader)
- return
- case "openai", "azure", "azure_openai":
- fallthrough
- default:
- if isStream {
- contentCh := fetchResponseContent(c, reader)
- var buffer bytes.Buffer
- for content := range contentCh {
- buffer.WriteString(content)
- }
- chatlog.CompletionCount = tokenizer.NumTokensFromStr(buffer.String(), chatreq.Model)
- chatlog.TotalTokens = chatlog.PromptCount + chatlog.CompletionCount
- chatlog.Cost = fmt.Sprintf("%.6f", tokenizer.Cost(chatlog.Model, chatlog.PromptCount, chatlog.CompletionCount))
- if err := store.Record(&chatlog); err != nil {
- log.Println(err)
- }
- if err := store.SumDaily(chatlog.UserID); err != nil {
- log.Println(err)
- }
- return
- }
- res, err := io.ReadAll(reader)
- if err != nil {
- c.JSON(http.StatusUnauthorized, gin.H{"error": gin.H{
- "message": err.Error(),
+ if c.Request.URL.Path == "/v1/chat/completions" {
+ if localuser {
+ if store.KeysCache.ItemCount() == 0 {
+ c.JSON(http.StatusBadGateway, gin.H{"error": gin.H{
+ "message": "No Api-Key Available",
}})
return
}
- reader = bufio.NewReader(bytes.NewBuffer(res))
- json.NewDecoder(bytes.NewBuffer(res)).Decode(&chatres)
- chatlog.PromptCount = chatres.Usage.PromptTokens
- chatlog.CompletionCount = chatres.Usage.CompletionTokens
- chatlog.TotalTokens = chatres.Usage.TotalTokens
- chatlog.Cost = fmt.Sprintf("%.6f", tokenizer.Cost(chatlog.Model, chatlog.PromptCount, chatlog.CompletionCount))
- if err := store.Record(&chatlog); err != nil {
- log.Println(err)
- }
- if err := store.SumDaily(chatlog.UserID); err != nil {
- log.Println(err)
- }
- }
- }
- // 返回 API 响应主体
- if _, err := io.Copy(writer, reader); err != nil {
- log.Println(err)
- c.JSON(http.StatusUnauthorized, gin.H{"error": gin.H{
- "message": err.Error(),
- }})
+ ChatHandler(c)
+ return
+ }
+ } else {
+ HandleReverseProxy(c, "api.openai.com")
return
}
+
}
-func HandleReverseProxy(c *gin.Context) {
+func HandleReverseProxy(c *gin.Context, targetHost string) {
proxy := &httputil.ReverseProxy{
Director: func(req *http.Request) {
req.URL.Scheme = "https"
- req.URL.Host = "api.openai.com"
- // req.Header.Set("Authorization", "Bearer YOUR_API_KEY_HERE")
+ req.URL.Host = targetHost
},
Transport: &http.Transport{
Proxy: http.ProxyFromEnvironment,
@@ -715,266 +142,13 @@ func HandleReverseProxy(c *gin.Context) {
},
}
- var localuser bool
- auth := c.Request.Header.Get("Authorization")
- if len(auth) > 7 && auth[:7] == "Bearer " {
- log.Println(store.IsExistAuthCache(auth[7:]))
- localuser = store.IsExistAuthCache(auth[7:])
- }
-
req, err := http.NewRequest(c.Request.Method, c.Request.URL.Path, c.Request.Body)
if err != nil {
c.JSON(http.StatusOK, gin.H{"error": err.Error()})
return
}
req.Header = c.Request.Header
- if localuser {
- if store.KeysCache.ItemCount() == 0 {
- c.JSON(http.StatusOK, gin.H{"error": "No Api-Key Available"})
- return
- }
- onekey := store.FromKeyCacheRandomItemKey()
- req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", onekey.Key))
- }
proxy.ServeHTTP(c.Writer, req)
-
-}
-
-func HandleUsage(c *gin.Context) {
- fromStr := c.Query("from")
- toStr := c.Query("to")
- getMonthStartAndEnd := func() (start, end string) {
- loc, _ := time.LoadLocation("Local")
- now := time.Now().In(loc)
-
- year, month, _ := now.Date()
-
- startOfMonth := time.Date(year, month, 1, 0, 0, 0, 0, loc)
- endOfMonth := startOfMonth.AddDate(0, 1, 0)
-
- start = startOfMonth.Format("2006-01-02")
- end = endOfMonth.Format("2006-01-02")
- return
- }
- if fromStr == "" || toStr == "" {
- fromStr, toStr = getMonthStartAndEnd()
- }
-
- usage, err := store.QueryUsage(fromStr, toStr)
- if err != nil {
- c.JSON(http.StatusForbidden, gin.H{"error": err.Error()})
- return
- }
-
- c.JSON(200, usage)
-}
-
-func fetchResponseContent(ctx *gin.Context, responseBody *bufio.Reader) <-chan string {
- contentCh := make(chan string)
- go func() {
- defer close(contentCh)
- for {
- line, err := responseBody.ReadString('\n')
- if err == nil {
- lines := strings.Split(line, "")
- for _, word := range lines {
- ctx.Writer.WriteString(word)
- ctx.Writer.Flush()
- }
- if line == "\n" {
- continue
- }
- if strings.HasPrefix(line, "data:") {
- line = strings.TrimSpace(strings.TrimPrefix(line, "data:"))
- if strings.HasSuffix(line, "[DONE]") {
- break
- }
- line = strings.TrimSpace(line)
- }
-
- dec := json.NewDecoder(strings.NewReader(line))
- var data map[string]interface{}
- if err := dec.Decode(&data); err == io.EOF {
- log.Println("EOF:", err)
- break
- } else if err != nil {
- fmt.Println("Error decoding response:", err)
- return
- }
- if choices, ok := data["choices"].([]interface{}); ok {
- for _, choice := range choices {
- choiceMap := choice.(map[string]interface{})
- if content, ok := choiceMap["delta"].(map[string]interface{})["content"]; ok {
- contentCh <- content.(string)
- }
- }
- }
- } else {
- break
- }
- }
- }()
- return contentCh
-}
-
-func modelmap(in string) string {
- // gpt-3.5-turbo -> gpt-35-turbo
- if strings.Contains(in, ".") {
- return strings.ReplaceAll(in, ".", "")
- }
- return in
-}
-
-func WhisperProxy(c *gin.Context) {
- var chatlog store.Tokens
-
- byteBody, _ := io.ReadAll(c.Request.Body)
- c.Request.Body = io.NopCloser(bytes.NewBuffer(byteBody))
-
- model, _ := c.GetPostForm("model")
-
- key, err := store.SelectKeyCache("openai")
- if err != nil {
- c.JSON(http.StatusInternalServerError, gin.H{
- "error": gin.H{
- "message": err.Error(),
- },
- })
- return
- }
-
- chatlog.Model = model
-
- token, _ := c.Get("localuser")
-
- lu, err := store.GetUserByToken(token.(string))
- if err != nil {
- c.JSON(http.StatusInternalServerError, gin.H{
- "error": gin.H{
- "message": err.Error(),
- },
- })
- return
- }
- chatlog.UserID = int(lu.ID)
-
- if err := ParseWhisperRequestTokens(c, &chatlog, byteBody); err != nil {
- c.JSON(http.StatusInternalServerError, gin.H{
- "error": gin.H{
- "message": err.Error(),
- },
- })
- return
- }
- if key.EndPoint == "" {
- key.EndPoint = "https://api.openai.com"
- }
- targetUrl, _ := url.ParseRequestURI(key.EndPoint + c.Request.URL.String())
- log.Println(targetUrl)
- proxy := httputil.NewSingleHostReverseProxy(targetUrl)
- proxy.Director = func(req *http.Request) {
- req.Host = targetUrl.Host
- req.URL.Scheme = targetUrl.Scheme
- req.URL.Host = targetUrl.Host
-
- req.Header.Set("Authorization", "Bearer "+key.Key)
- }
-
- proxy.ModifyResponse = func(resp *http.Response) error {
- if resp.StatusCode != http.StatusOK {
- return nil
- }
- chatlog.TotalTokens = chatlog.PromptCount + chatlog.CompletionCount
- chatlog.Cost = fmt.Sprintf("%.6f", tokenizer.Cost(chatlog.Model, chatlog.PromptCount, chatlog.CompletionCount))
- if err := store.Record(&chatlog); err != nil {
- log.Println(err)
- }
- if err := store.SumDaily(chatlog.UserID); err != nil {
- log.Println(err)
- }
- return nil
- }
- proxy.ServeHTTP(c.Writer, c.Request)
-}
-
-func probe(fileReader io.Reader) (time.Duration, error) {
- ctx, cancelFn := context.WithTimeout(context.Background(), 5*time.Second)
- defer cancelFn()
-
- data, err := ffprobe.ProbeReader(ctx, fileReader)
- if err != nil {
- return 0, err
- }
-
- duration := data.Format.DurationSeconds
- pduration, err := time.ParseDuration(fmt.Sprintf("%fs", duration))
- if err != nil {
- return 0, fmt.Errorf("Error parsing duration: %s", err)
- }
- return pduration, nil
-}
-
-func getAudioDuration(file *multipart.FileHeader) (time.Duration, error) {
- var (
- streamer beep.StreamSeekCloser
- format beep.Format
- err error
- )
-
- f, err := file.Open()
- defer f.Close()
-
- // Get the file extension to determine the audio file type
- fileType := filepath.Ext(file.Filename)
-
- switch fileType {
- case ".mp3":
- streamer, format, err = mp3.Decode(f)
- case ".wav":
- streamer, format, err = wav.Decode(f)
- case ".m4a":
- duration, err := probe(f)
- if err != nil {
- return 0, err
- }
- return duration, nil
- default:
- return 0, errors.New("unsupported audio file format")
- }
-
- if err != nil {
- return 0, err
- }
- defer streamer.Close()
-
- // Calculate the audio file's duration.
- numSamples := streamer.Len()
- sampleRate := format.SampleRate
- duration := time.Duration(numSamples) * time.Second / time.Duration(sampleRate)
-
- return duration, nil
-}
-
-func ParseWhisperRequestTokens(c *gin.Context, usage *store.Tokens, byteBody []byte) error {
- file, _ := c.FormFile("file")
- model, _ := c.GetPostForm("model")
- usage.Model = model
-
- if file != nil {
- duration, err := getAudioDuration(file)
- if err != nil {
- return fmt.Errorf("Error getting audio duration:%s", err)
- }
-
- if duration > 5*time.Minute {
- return fmt.Errorf("Audio duration exceeds 5 minutes")
- }
- // 计算时长,四舍五入到最接近的秒数
- usage.PromptCount = int(duration.Round(time.Second).Seconds())
- }
-
- c.Request.Body = io.NopCloser(bytes.NewBuffer(byteBody))
-
- return nil
+ return
}
diff --git a/store/cache.go b/store/cache.go
index 64009f4..2a9462e 100644
--- a/store/cache.go
+++ b/store/cache.go
@@ -53,6 +53,14 @@ func SelectKeyCache(apitype string) (Key, error) {
if item.Object.(Key).ApiType == "azure" {
keys = append(keys, item.Object.(Key))
}
+ if item.Object.(Key).ApiType == "github" {
+ keys = append(keys, item.Object.(Key))
+ }
+ }
+ if apitype == "claude" {
+ if item.Object.(Key).ApiType == "vertex" {
+ keys = append(keys, item.Object.(Key))
+ }
}
}
if len(keys) == 0 {
diff --git a/store/keydb.go b/store/keydb.go
index c5d106b..c7d55db 100644
--- a/store/keydb.go
+++ b/store/keydb.go
@@ -2,9 +2,35 @@ package store
import (
"encoding/json"
+ "fmt"
+ "log"
+ "opencatd-open/pkg/vertexai"
+ "os"
"time"
)
+func init() {
+ // check vertex
+ if os.Getenv("Vertex") != "" {
+ vertex_auth := os.Getenv("Vertex")
+ var Vertex vertexai.VertexSecretKey
+ if err := json.Unmarshal([]byte(vertex_auth), &Vertex); err != nil {
+ log.Fatalln(err)
+ return
+ }
+ key := Key{
+ ApiType: "vertex",
+ Name: Vertex.ProjectID,
+ Key: vertex_auth,
+ ApiSecret: vertex_auth,
+ }
+ if err := db.FirstOrCreate(&key).Error; err != nil {
+ log.Println(fmt.Errorf("create vertex key error: %v", err))
+ }
+ }
+ LoadKeysCache()
+}
+
type Key struct {
ID uint `gorm:"primarykey" json:"id,omitempty"`
Key string `gorm:"unique;not null" json:"key,omitempty"`
@@ -14,6 +40,7 @@ type Key struct {
EndPoint string `gorm:"column:endpoint"`
ResourceNmae string `gorm:"column:resource_name"`
DeploymentName string `gorm:"column:deployment_name"`
+ ApiSecret string `gorm:"column:api_secret"`
CreatedAt time.Time `json:"createdAt,omitempty"`
UpdatedAt time.Time `json:"updatedAt,omitempty"`
}