refact & update new model
This commit is contained in:
29
README.md
29
README.md
@@ -1,4 +1,7 @@
|
||||
# opencatd-open
|
||||
# ~~opencatd-open~~ [OpenTeam](https://github.com/mirrors2/opencatd-open)
|
||||
|
||||
本项目即将更名,后续请关注 👉🏻 https://github.com/mirrors2/openteam
|
||||
|
||||
|
||||
<a title="Docker Image CI" target="_blank" href="https://github.com/mirrors2/opencatd-open/actions"><img alt="GitHub Workflow Status" src="https://img.shields.io/github/actions/workflow/status/mirrors2/opencatd-open/ci.yaml?label=Actions&logo=github&style=flat-square"></a>
|
||||
<a title="Docker Pulls" target="_blank" href="https://hub.docker.com/r/mirrors2/opencatd-open"><img src="https://img.shields.io/docker/pulls/mirrors2/opencatd-open.svg?logo=docker&label=docker&style=flat-square"></a>
|
||||
@@ -14,11 +17,11 @@ OpenCat for Team的开源实现
|
||||
|
||||
## Extra Support:
|
||||
|
||||
| 任务 | 完成情况 |
|
||||
| --- | --- |
|
||||
|[Azure OpenAI](./doc/azure.md) | ✅|
|
||||
|[Claude](./doc/azure.md) | ✅|
|
||||
|[Gemini](./doc/gemini.md) | ✅|
|
||||
| 🎯 | 🚧 |Extra Provider|
|
||||
| --- | --- | --- |
|
||||
|[OpenAI](./doc/azure.md) | ✅|Azure, Github Marketplace|
|
||||
|[Claude](./doc/azure.md) | ✅|VertexAI|
|
||||
|[Gemini](./doc/gemini.md) | ✅||
|
||||
| ... | ... |
|
||||
|
||||
|
||||
@@ -80,6 +83,18 @@ wget https://github.com/mirrors2/opencatd-open/raw/main/docker/docker-compose.ym
|
||||
|
||||
pandora for team
|
||||
- [pandora for team](./doc/pandora.md)
|
||||
|
||||
如何自定义HOST地址? (仅OpenAI)
|
||||
- 需修改环境变量,优先级递增
|
||||
- Cloudflare AI Gateway地址 `AIGateWay_Endpoint=https://gateway.ai.cloudflare.com/v1/123456789/xxxx/openai/chat/completions`
|
||||
- 自定义的endpoint `$CUSTOM_ENDPOINT=true && $OpenAI_Endpoint=https://your.domain/v1/chat/completions`
|
||||
|
||||
设置主页跳转地址?
|
||||
- 修改环境变量 `CUSTOM_REDIRECT=https://your.domain`
|
||||
|
||||
## 赞助
|
||||
[](https://www.buymeacoffee.com/littlecjun)
|
||||
|
||||
# License
|
||||
|
||||
[GNU General Public License v3.0](License)
|
||||
[](https://github.com/mirrors2/opencatd-open/blob/main/License)
|
||||
|
||||
26
docker-compose.yml
Normal file
26
docker-compose.yml
Normal file
@@ -0,0 +1,26 @@
|
||||
# Email: admin@example.com
|
||||
# Password: changeme
|
||||
version: '3'
|
||||
|
||||
services:
|
||||
npm:
|
||||
image: jc21/nginx-proxy-manager
|
||||
network_mode: host
|
||||
ports:
|
||||
- '80:80'
|
||||
- '81:81'
|
||||
- '443:443'
|
||||
volumes:
|
||||
- $PWD/data:/data
|
||||
- $PWD/www:/var/www
|
||||
- $PWD/letsencrypt:/etc/letsencrypt
|
||||
environment:
|
||||
- "TZ=Asia/Shanghai" # set timezone, default UTC
|
||||
- "PUID=1000" # set group id, default 0 (root)
|
||||
- "PGID=1000"
|
||||
|
||||
# certbot:
|
||||
# image: certbot/certbot
|
||||
# volumes:
|
||||
# - $PWD/data/certbot/conf:/etc/letsencrypt
|
||||
# - $PWD/data/certbot/www:/var/www/certbot
|
||||
@@ -1,4 +1,4 @@
|
||||
FROM node:18.12.1-alpine3.16 AS frontend
|
||||
FROM node:20-alpine AS frontend
|
||||
WORKDIR /frontend-build
|
||||
COPY ./web/ .
|
||||
RUN npm install && npm run build && rm -rf node_modules
|
||||
|
||||
@@ -1,10 +1,24 @@
|
||||
version: '3.7'
|
||||
services:
|
||||
services:
|
||||
opencatd:
|
||||
image: mirrors2/opencatd-open
|
||||
container_name: opencatd-open
|
||||
container_name: opencatd-open
|
||||
restart: unless-stopped
|
||||
#network_mode: host
|
||||
ports:
|
||||
- 80:80
|
||||
volumes:
|
||||
- /etc/opencatd:/app/db
|
||||
- $PWD/db:/app/db
|
||||
logging:
|
||||
# driver: "json-file"
|
||||
options:
|
||||
max-size: 10m
|
||||
max-file: 3
|
||||
# environment:
|
||||
# Vertex: |
|
||||
# {
|
||||
# "type": "service_account",
|
||||
# "universe_domain": "googleapis.com"
|
||||
# }
|
||||
|
||||
|
||||
49
opencat.go
49
opencat.go
@@ -8,6 +8,7 @@ import (
|
||||
"io/fs"
|
||||
"log"
|
||||
"net/http"
|
||||
"opencatd-open/pkg/team"
|
||||
"opencatd-open/router"
|
||||
"opencatd-open/store"
|
||||
"os"
|
||||
@@ -143,41 +144,29 @@ func main() {
|
||||
r := gin.Default()
|
||||
group := r.Group("/1")
|
||||
{
|
||||
group.Use(router.AuthMiddleware())
|
||||
group.Use(team.AuthMiddleware())
|
||||
|
||||
// 获取当前用户信息
|
||||
group.GET("/me", router.HandleMe)
|
||||
group.GET("/me", team.HandleMe)
|
||||
group.GET("/me/usages", team.HandleMeUsage)
|
||||
|
||||
group.GET("/me/usages", router.HandleMeUsage)
|
||||
group.GET("/keys", team.HandleKeys) // 获取所有Key
|
||||
group.POST("/keys", team.HandleAddKey) // 添加Key
|
||||
group.DELETE("/keys/:id", team.HandleDelKey) // 删除Key
|
||||
|
||||
// 获取所有Key
|
||||
group.GET("/keys", router.HandleKeys)
|
||||
group.GET("/users", team.HandleUsers) // 获取所有用户信息
|
||||
group.POST("/users", team.HandleAddUser) // 添加用户
|
||||
group.DELETE("/users/:id", team.HandleDelUser) // 删除用户
|
||||
|
||||
// 获取所有用户信息
|
||||
group.GET("/users", router.HandleUsers)
|
||||
|
||||
group.GET("/usages", router.HandleUsage)
|
||||
|
||||
// 添加Key
|
||||
group.POST("/keys", router.HandleAddKey)
|
||||
|
||||
// 删除Key
|
||||
group.DELETE("/keys/:id", router.HandleDelKey)
|
||||
|
||||
// 添加用户
|
||||
group.POST("/users", router.HandleAddUser)
|
||||
|
||||
// 删除用户
|
||||
group.DELETE("/users/:id", router.HandleDelUser)
|
||||
group.GET("/usages", team.HandleUsage)
|
||||
|
||||
// 重置用户Token
|
||||
group.POST("/users/:id/reset", router.HandleResetUserToken)
|
||||
group.POST("/users/:id/reset", team.HandleResetUserToken)
|
||||
}
|
||||
|
||||
// 初始化用户
|
||||
r.POST("/1/users/init", router.Handleinit)
|
||||
r.POST("/1/users/init", team.Handleinit)
|
||||
|
||||
r.Any("/v1/*proxypath", router.HandleProy)
|
||||
r.Any("/v1/*proxypath", router.HandleProxy)
|
||||
|
||||
// r.POST("/v1/chat/completions", router.HandleProy)
|
||||
// r.GET("/v1/models", router.HandleProy)
|
||||
@@ -188,7 +177,15 @@ func main() {
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
r.GET("/", gin.WrapH(http.FileServer(http.FS(idxFS))))
|
||||
redirect := os.Getenv("CUSTOM_REDIRECT")
|
||||
if redirect != "" {
|
||||
r.GET("/", func(c *gin.Context) {
|
||||
c.Redirect(http.StatusMovedPermanently, redirect)
|
||||
})
|
||||
|
||||
} else {
|
||||
r.GET("/", gin.WrapH(http.FileServer(http.FS(idxFS))))
|
||||
}
|
||||
assetsFS, err := fs.Sub(web, "dist/assets")
|
||||
if err != nil {
|
||||
panic(err)
|
||||
|
||||
@@ -10,8 +10,10 @@ import (
|
||||
"io"
|
||||
"log"
|
||||
"net/http"
|
||||
"opencatd-open/pkg/error"
|
||||
"opencatd-open/pkg/openai"
|
||||
"opencatd-open/pkg/tokenizer"
|
||||
"opencatd-open/pkg/vertexai"
|
||||
"opencatd-open/store"
|
||||
"strings"
|
||||
|
||||
@@ -27,14 +29,15 @@ func ChatTextCompletions(c *gin.Context, chatReq *openai.ChatCompletionRequest)
|
||||
}
|
||||
|
||||
type ChatRequest struct {
|
||||
Model string `json:"model,omitempty"`
|
||||
Messages any `json:"messages,omitempty"`
|
||||
MaxTokens int `json:"max_tokens,omitempty"`
|
||||
Stream bool `json:"stream,omitempty"`
|
||||
System string `json:"system,omitempty"`
|
||||
TopK int `json:"top_k,omitempty"`
|
||||
TopP float64 `json:"top_p,omitempty"`
|
||||
Temperature float64 `json:"temperature,omitempty"`
|
||||
Model string `json:"model,omitempty"`
|
||||
Messages any `json:"messages,omitempty"`
|
||||
MaxTokens int `json:"max_tokens,omitempty"`
|
||||
Stream bool `json:"stream,omitempty"`
|
||||
System string `json:"system,omitempty"`
|
||||
TopK int `json:"top_k,omitempty"`
|
||||
TopP float64 `json:"top_p,omitempty"`
|
||||
Temperature float64 `json:"temperature,omitempty"`
|
||||
AnthropicVersion string `json:"anthropic_version,omitempty"`
|
||||
}
|
||||
|
||||
func (c *ChatRequest) ByteJson() []byte {
|
||||
@@ -117,8 +120,12 @@ type ClaudeStreamResponse struct {
|
||||
}
|
||||
|
||||
func ChatMessages(c *gin.Context, chatReq *openai.ChatCompletionRequest) {
|
||||
var (
|
||||
req *http.Request
|
||||
targetURL = ClaudeMessageEndpoint
|
||||
)
|
||||
|
||||
onekey, err := store.SelectKeyCache("claude")
|
||||
apiKey, err := store.SelectKeyCache("claude")
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
@@ -131,6 +138,10 @@ func ChatMessages(c *gin.Context, chatReq *openai.ChatCompletionRequest) {
|
||||
// claudReq.Temperature = chatReq.Temperature
|
||||
claudReq.TopP = chatReq.TopP
|
||||
claudReq.MaxTokens = 4096
|
||||
if apiKey.ApiType == "vertex" {
|
||||
claudReq.AnthropicVersion = "vertex-2023-10-16"
|
||||
claudReq.Model = ""
|
||||
}
|
||||
|
||||
var prompt string
|
||||
|
||||
@@ -181,10 +192,40 @@ func ChatMessages(c *gin.Context, chatReq *openai.ChatCompletionRequest) {
|
||||
|
||||
usagelog.PromptCount = tokenizer.NumTokensFromStr(prompt, chatReq.Model)
|
||||
|
||||
req, _ := http.NewRequest("POST", MessageEndpoint, bytes.NewReader(claudReq.ByteJson()))
|
||||
req.Header.Set("x-api-key", onekey.Key)
|
||||
req.Header.Set("anthropic-version", "2023-06-01")
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
if apiKey.ApiType == "vertex" {
|
||||
var vertexSecret vertexai.VertexSecretKey
|
||||
if err := json.Unmarshal([]byte(apiKey.ApiSecret), &vertexSecret); err != nil {
|
||||
c.JSON(http.StatusInternalServerError, error.ErrorData(err.Error()))
|
||||
return
|
||||
}
|
||||
|
||||
vcmodel, ok := vertexai.VertexClaudeModelMap[chatReq.Model]
|
||||
if !ok {
|
||||
c.JSON(http.StatusInternalServerError, error.ErrorData("Model not found"))
|
||||
return
|
||||
}
|
||||
|
||||
// 获取gcloud token,临时放置在apiKey.Key中
|
||||
gcloudToken, err := vertexai.GcloudAuth(vertexSecret.ClientEmail, vertexSecret.PrivateKey)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, error.ErrorData(err.Error()))
|
||||
return
|
||||
}
|
||||
|
||||
// 拼接vertex的请求地址
|
||||
targetURL = fmt.Sprintf("https://%s-aiplatform.googleapis.com/v1/projects/%s/locations/%s/publishers/anthropic/models/%s:streamRawPredict", vcmodel.Region, vertexSecret.ProjectID, vcmodel.Region, vcmodel.VertexName)
|
||||
|
||||
req, _ = http.NewRequest("POST", targetURL, bytes.NewReader(claudReq.ByteJson()))
|
||||
req.Header.Set("Authorization", "Bearer "+gcloudToken)
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Accept", "text/event-stream")
|
||||
req.Header.Set("Accept-Encoding", "identity")
|
||||
} else {
|
||||
req, _ = http.NewRequest("POST", targetURL, bytes.NewReader(claudReq.ByteJson()))
|
||||
req.Header.Set("x-api-key", apiKey.Key)
|
||||
req.Header.Set("anthropic-version", "2023-06-01")
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
}
|
||||
|
||||
client := http.DefaultClient
|
||||
rsp, err := client.Do(req)
|
||||
|
||||
@@ -68,8 +68,8 @@ import (
|
||||
)
|
||||
|
||||
var (
|
||||
ClaudeUrl = "https://api.anthropic.com/v1/complete"
|
||||
MessageEndpoint = "https://api.anthropic.com/v1/messages"
|
||||
ClaudeUrl = "https://api.anthropic.com/v1/complete"
|
||||
ClaudeMessageEndpoint = "https://api.anthropic.com/v1/messages"
|
||||
)
|
||||
|
||||
type MessageModule struct {
|
||||
|
||||
11
pkg/error/errdata.go
Normal file
11
pkg/error/errdata.go
Normal file
@@ -0,0 +1,11 @@
|
||||
package error
|
||||
|
||||
import "github.com/gin-gonic/gin"
|
||||
|
||||
func ErrorData(message string) gin.H {
|
||||
return gin.H{
|
||||
"error": gin.H{
|
||||
"message": message,
|
||||
},
|
||||
}
|
||||
}
|
||||
@@ -17,8 +17,10 @@ import (
|
||||
)
|
||||
|
||||
const (
|
||||
AzureApiVersion = "2024-02-01"
|
||||
OpenAI_Endpoint = "https://api.openai.com/v1/chat/completions"
|
||||
AzureApiVersion = "2024-02-01"
|
||||
BaseHost = "api.openai.com"
|
||||
OpenAI_Endpoint = "https://api.openai.com/v1/chat/completions"
|
||||
Github_Marketplace = "https://models.inference.ai.azure.com/chat/completions"
|
||||
)
|
||||
|
||||
var (
|
||||
@@ -204,6 +206,10 @@ func ChatProxy(c *gin.Context, chatReq *ChatCompletionRequest) {
|
||||
var req *http.Request
|
||||
|
||||
switch onekey.ApiType {
|
||||
case "github":
|
||||
req, err = http.NewRequest(c.Request.Method, Github_Marketplace, bytes.NewReader(chatReq.ToByteJson()))
|
||||
req.Header = c.Request.Header
|
||||
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", onekey.Key))
|
||||
case "azure":
|
||||
var buildurl string
|
||||
if onekey.EndPoint != "" {
|
||||
@@ -215,15 +221,18 @@ func ChatProxy(c *gin.Context, chatReq *ChatCompletionRequest) {
|
||||
req.Header = c.Request.Header
|
||||
req.Header.Set("api-key", onekey.Key)
|
||||
default:
|
||||
req, err = http.NewRequest(c.Request.Method, OpenAI_Endpoint, bytes.NewReader(chatReq.ToByteJson()))
|
||||
if onekey.EndPoint != "" { // 优先key的endpoint
|
||||
req, err = http.NewRequest(c.Request.Method, onekey.EndPoint+c.Request.RequestURI, bytes.NewReader(chatReq.ToByteJson()))
|
||||
} else {
|
||||
if BaseURL != "" { // 其次BaseURL
|
||||
req, err = http.NewRequest(c.Request.Method, BaseURL+c.Request.RequestURI, bytes.NewReader(chatReq.ToByteJson()))
|
||||
} else { // 最后是gateway的endpoint
|
||||
req, err = http.NewRequest(c.Request.Method, AIGateWay_Endpoint, bytes.NewReader(chatReq.ToByteJson()))
|
||||
}
|
||||
}
|
||||
if AIGateWay_Endpoint != "" { // cloudflare gateway的endpoint
|
||||
req, err = http.NewRequest(c.Request.Method, AIGateWay_Endpoint, bytes.NewReader(chatReq.ToByteJson()))
|
||||
}
|
||||
customEndpoint := os.Getenv("CUSTOM_ENDPOINT") // 最后是用户自定义的endpoint CUSTOM_ENDPOINT=true OpenAI_Endpoint
|
||||
if customEndpoint == "true" && OpenAI_Endpoint != "" {
|
||||
req, err = http.NewRequest(c.Request.Method, BaseURL, bytes.NewReader(chatReq.ToByteJson()))
|
||||
}
|
||||
|
||||
req.Header = c.Request.Header
|
||||
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", onekey.Key))
|
||||
}
|
||||
|
||||
177
pkg/openai/whisper.go
Normal file
177
pkg/openai/whisper.go
Normal file
@@ -0,0 +1,177 @@
|
||||
package openai
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"mime/multipart"
|
||||
"net/http"
|
||||
"net/http/httputil"
|
||||
"net/url"
|
||||
"opencatd-open/pkg/tokenizer"
|
||||
"opencatd-open/store"
|
||||
"path/filepath"
|
||||
"time"
|
||||
|
||||
"github.com/faiface/beep"
|
||||
"github.com/faiface/beep/mp3"
|
||||
"github.com/faiface/beep/wav"
|
||||
"github.com/gin-gonic/gin"
|
||||
"gopkg.in/vansante/go-ffprobe.v2"
|
||||
)
|
||||
|
||||
func WhisperProxy(c *gin.Context) {
|
||||
var chatlog store.Tokens
|
||||
|
||||
byteBody, _ := io.ReadAll(c.Request.Body)
|
||||
c.Request.Body = io.NopCloser(bytes.NewBuffer(byteBody))
|
||||
|
||||
model, _ := c.GetPostForm("model")
|
||||
|
||||
key, err := store.SelectKeyCache("openai")
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{
|
||||
"error": gin.H{
|
||||
"message": err.Error(),
|
||||
},
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
chatlog.Model = model
|
||||
|
||||
token, _ := c.Get("localuser")
|
||||
|
||||
lu, err := store.GetUserByToken(token.(string))
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{
|
||||
"error": gin.H{
|
||||
"message": err.Error(),
|
||||
},
|
||||
})
|
||||
return
|
||||
}
|
||||
chatlog.UserID = int(lu.ID)
|
||||
|
||||
if err := ParseWhisperRequestTokens(c, &chatlog, byteBody); err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{
|
||||
"error": gin.H{
|
||||
"message": err.Error(),
|
||||
},
|
||||
})
|
||||
return
|
||||
}
|
||||
if key.EndPoint == "" {
|
||||
key.EndPoint = "https://api.openai.com"
|
||||
}
|
||||
targetUrl, _ := url.ParseRequestURI(key.EndPoint + c.Request.URL.String())
|
||||
log.Println(targetUrl)
|
||||
proxy := httputil.NewSingleHostReverseProxy(targetUrl)
|
||||
proxy.Director = func(req *http.Request) {
|
||||
req.Host = targetUrl.Host
|
||||
req.URL.Scheme = targetUrl.Scheme
|
||||
req.URL.Host = targetUrl.Host
|
||||
|
||||
req.Header.Set("Authorization", "Bearer "+key.Key)
|
||||
}
|
||||
|
||||
proxy.ModifyResponse = func(resp *http.Response) error {
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil
|
||||
}
|
||||
chatlog.TotalTokens = chatlog.PromptCount + chatlog.CompletionCount
|
||||
chatlog.Cost = fmt.Sprintf("%.6f", tokenizer.Cost(chatlog.Model, chatlog.PromptCount, chatlog.CompletionCount))
|
||||
if err := store.Record(&chatlog); err != nil {
|
||||
log.Println(err)
|
||||
}
|
||||
if err := store.SumDaily(chatlog.UserID); err != nil {
|
||||
log.Println(err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
proxy.ServeHTTP(c.Writer, c.Request)
|
||||
}
|
||||
|
||||
func probe(fileReader io.Reader) (time.Duration, error) {
|
||||
ctx, cancelFn := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancelFn()
|
||||
|
||||
data, err := ffprobe.ProbeReader(ctx, fileReader)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
duration := data.Format.DurationSeconds
|
||||
pduration, err := time.ParseDuration(fmt.Sprintf("%fs", duration))
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("Error parsing duration: %s", err)
|
||||
}
|
||||
return pduration, nil
|
||||
}
|
||||
|
||||
func getAudioDuration(file *multipart.FileHeader) (time.Duration, error) {
|
||||
var (
|
||||
streamer beep.StreamSeekCloser
|
||||
format beep.Format
|
||||
err error
|
||||
)
|
||||
|
||||
f, err := file.Open()
|
||||
defer f.Close()
|
||||
|
||||
// Get the file extension to determine the audio file type
|
||||
fileType := filepath.Ext(file.Filename)
|
||||
|
||||
switch fileType {
|
||||
case ".mp3":
|
||||
streamer, format, err = mp3.Decode(f)
|
||||
case ".wav":
|
||||
streamer, format, err = wav.Decode(f)
|
||||
case ".m4a":
|
||||
duration, err := probe(f)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return duration, nil
|
||||
default:
|
||||
return 0, errors.New("unsupported audio file format")
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
defer streamer.Close()
|
||||
|
||||
// Calculate the audio file's duration.
|
||||
numSamples := streamer.Len()
|
||||
sampleRate := format.SampleRate
|
||||
duration := time.Duration(numSamples) * time.Second / time.Duration(sampleRate)
|
||||
|
||||
return duration, nil
|
||||
}
|
||||
|
||||
func ParseWhisperRequestTokens(c *gin.Context, usage *store.Tokens, byteBody []byte) error {
|
||||
file, _ := c.FormFile("file")
|
||||
model, _ := c.GetPostForm("model")
|
||||
usage.Model = model
|
||||
|
||||
if file != nil {
|
||||
duration, err := getAudioDuration(file)
|
||||
if err != nil {
|
||||
return fmt.Errorf("Error getting audio duration:%s", err)
|
||||
}
|
||||
|
||||
if duration > 5*time.Minute {
|
||||
return fmt.Errorf("Audio duration exceeds 5 minutes")
|
||||
}
|
||||
// 计算时长,四舍五入到最接近的秒数
|
||||
usage.PromptCount = int(duration.Round(time.Second).Seconds())
|
||||
}
|
||||
|
||||
c.Request.Body = io.NopCloser(bytes.NewBuffer(byteBody))
|
||||
|
||||
return nil
|
||||
}
|
||||
182
pkg/team/key.go
Normal file
182
pkg/team/key.go
Normal file
@@ -0,0 +1,182 @@
|
||||
package team
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"opencatd-open/pkg/azureopenai"
|
||||
"opencatd-open/store"
|
||||
"strings"
|
||||
|
||||
"github.com/Sakurasan/to"
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
type Key struct {
|
||||
ID int `json:"id,omitempty"`
|
||||
Key string `json:"key,omitempty"`
|
||||
Name string `json:"name,omitempty"`
|
||||
ApiType string `json:"api_type,omitempty"`
|
||||
Endpoint string `json:"endpoint,omitempty"`
|
||||
UpdatedAt string `json:"updatedAt,omitempty"`
|
||||
CreatedAt string `json:"createdAt,omitempty"`
|
||||
}
|
||||
|
||||
func HandleKeys(c *gin.Context) {
|
||||
keys, err := store.GetAllKeys()
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"error": err.Error(),
|
||||
})
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, keys)
|
||||
}
|
||||
|
||||
func HandleAddKey(c *gin.Context) {
|
||||
var body Key
|
||||
if err := c.BindJSON(&body); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": gin.H{
|
||||
"message": err.Error(),
|
||||
}})
|
||||
return
|
||||
}
|
||||
body.Name = strings.ToLower(strings.TrimSpace(body.Name))
|
||||
body.Key = strings.TrimSpace(body.Key)
|
||||
if strings.HasPrefix(body.Name, "azure.") {
|
||||
keynames := strings.Split(body.Name, ".")
|
||||
if len(keynames) < 2 {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": gin.H{
|
||||
"message": "Invalid Key Name",
|
||||
}})
|
||||
return
|
||||
}
|
||||
k := &store.Key{
|
||||
ApiType: "azure",
|
||||
Name: body.Name,
|
||||
Key: body.Key,
|
||||
ResourceNmae: keynames[1],
|
||||
EndPoint: body.Endpoint,
|
||||
}
|
||||
if err := store.CreateKey(k); err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": gin.H{
|
||||
"message": err.Error(),
|
||||
}})
|
||||
return
|
||||
}
|
||||
} else if strings.HasPrefix(body.Name, "claude.") {
|
||||
keynames := strings.Split(body.Name, ".")
|
||||
if len(keynames) < 2 {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": gin.H{
|
||||
"message": "Invalid Key Name",
|
||||
}})
|
||||
return
|
||||
}
|
||||
if body.Endpoint == "" {
|
||||
body.Endpoint = "https://api.anthropic.com"
|
||||
}
|
||||
k := &store.Key{
|
||||
// ApiType: "anthropic",
|
||||
ApiType: "claude",
|
||||
Name: body.Name,
|
||||
Key: body.Key,
|
||||
ResourceNmae: keynames[1],
|
||||
EndPoint: body.Endpoint,
|
||||
}
|
||||
if err := store.CreateKey(k); err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": gin.H{
|
||||
"message": err.Error(),
|
||||
}})
|
||||
return
|
||||
}
|
||||
} else if strings.HasPrefix(body.Name, "google.") {
|
||||
keynames := strings.Split(body.Name, ".")
|
||||
if len(keynames) < 2 {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": gin.H{
|
||||
"message": "Invalid Key Name",
|
||||
}})
|
||||
return
|
||||
}
|
||||
|
||||
k := &store.Key{
|
||||
// ApiType: "anthropic",
|
||||
ApiType: "google",
|
||||
Name: body.Name,
|
||||
Key: body.Key,
|
||||
ResourceNmae: keynames[1],
|
||||
EndPoint: body.Endpoint,
|
||||
}
|
||||
if err := store.CreateKey(k); err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": gin.H{
|
||||
"message": err.Error(),
|
||||
}})
|
||||
return
|
||||
}
|
||||
} else if strings.HasPrefix(body.Name, "github.") {
|
||||
keynames := strings.Split(body.Name, ".")
|
||||
if len(keynames) < 2 {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": gin.H{
|
||||
"message": "Invalid Key Name",
|
||||
}})
|
||||
return
|
||||
}
|
||||
|
||||
k := &store.Key{
|
||||
ApiType: "github",
|
||||
Name: body.Name,
|
||||
Key: body.Key,
|
||||
ResourceNmae: keynames[1],
|
||||
EndPoint: body.Endpoint,
|
||||
}
|
||||
if err := store.CreateKey(k); err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": gin.H{
|
||||
"message": err.Error(),
|
||||
}})
|
||||
return
|
||||
}
|
||||
} else {
|
||||
if body.ApiType == "" {
|
||||
if err := store.AddKey("openai", body.Key, body.Name); err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": gin.H{
|
||||
"message": err.Error(),
|
||||
}})
|
||||
return
|
||||
}
|
||||
} else {
|
||||
k := &store.Key{
|
||||
ApiType: body.ApiType,
|
||||
Name: body.Name,
|
||||
Key: body.Key,
|
||||
ResourceNmae: azureopenai.GetResourceName(body.Endpoint),
|
||||
EndPoint: body.Endpoint,
|
||||
}
|
||||
if err := store.CreateKey(k); err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": gin.H{
|
||||
"message": err.Error(),
|
||||
}})
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
k, err := store.GetKeyrByName(body.Name)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": gin.H{
|
||||
"message": err.Error(),
|
||||
}})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, k)
|
||||
}
|
||||
|
||||
func HandleDelKey(c *gin.Context) {
|
||||
id := to.Int(c.Param("id"))
|
||||
if id < 1 {
|
||||
c.JSON(http.StatusOK, gin.H{"error": "invalid key id"})
|
||||
return
|
||||
}
|
||||
if err := store.DeleteKey(uint(id)); err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{"error": "invalid key id"})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{"message": "ok"})
|
||||
}
|
||||
104
pkg/team/me.go
Normal file
104
pkg/team/me.go
Normal file
@@ -0,0 +1,104 @@
|
||||
package team
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net/http"
|
||||
"opencatd-open/store"
|
||||
"time"
|
||||
|
||||
"github.com/Sakurasan/to"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/google/uuid"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
func Handleinit(c *gin.Context) {
|
||||
user, err := store.GetUserByID(1)
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
u := store.User{Name: "root", Token: uuid.NewString()}
|
||||
u.ID = 1
|
||||
if err := store.CreateUser(&u); err != nil {
|
||||
c.JSON(http.StatusForbidden, gin.H{
|
||||
"error": err.Error(),
|
||||
})
|
||||
return
|
||||
} else {
|
||||
rootToken = u.Token
|
||||
resJSON := User{
|
||||
false,
|
||||
int(u.ID),
|
||||
u.UpdatedAt.Format(time.RFC3339),
|
||||
u.Name,
|
||||
u.Token,
|
||||
u.CreatedAt.Format(time.RFC3339),
|
||||
}
|
||||
c.JSON(http.StatusOK, resJSON)
|
||||
return
|
||||
}
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"error": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
if user.ID == uint(1) {
|
||||
c.JSON(http.StatusForbidden, gin.H{
|
||||
"error": "super user already exists, use cli to reset password",
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func HandleMe(c *gin.Context) {
|
||||
token := c.GetHeader("Authorization")
|
||||
u, err := store.GetUserByToken(token[7:])
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"error": err.Error(),
|
||||
})
|
||||
}
|
||||
|
||||
resJSON := User{
|
||||
false,
|
||||
int(u.ID),
|
||||
u.UpdatedAt.Format(time.RFC3339),
|
||||
u.Name,
|
||||
u.Token,
|
||||
u.CreatedAt.Format(time.RFC3339),
|
||||
}
|
||||
c.JSON(http.StatusOK, resJSON)
|
||||
}
|
||||
|
||||
func HandleMeUsage(c *gin.Context) {
|
||||
token := c.GetHeader("Authorization")
|
||||
fromStr := c.Query("from")
|
||||
toStr := c.Query("to")
|
||||
getMonthStartAndEnd := func() (start, end string) {
|
||||
loc, _ := time.LoadLocation("Local")
|
||||
now := time.Now().In(loc)
|
||||
|
||||
year, month, _ := now.Date()
|
||||
|
||||
startOfMonth := time.Date(year, month, 1, 0, 0, 0, 0, loc)
|
||||
endOfMonth := startOfMonth.AddDate(0, 1, 0)
|
||||
|
||||
start = startOfMonth.Format("2006-01-02")
|
||||
end = endOfMonth.Format("2006-01-02")
|
||||
return
|
||||
}
|
||||
if fromStr == "" || toStr == "" {
|
||||
fromStr, toStr = getMonthStartAndEnd()
|
||||
}
|
||||
user, err := store.GetUserByToken(token)
|
||||
if err != nil {
|
||||
c.AbortWithError(http.StatusForbidden, err)
|
||||
return
|
||||
}
|
||||
usage, err := store.QueryUserUsage(to.String(user.ID), fromStr, toStr)
|
||||
if err != nil {
|
||||
c.AbortWithError(http.StatusForbidden, err)
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(200, usage)
|
||||
}
|
||||
59
pkg/team/middleware.go
Normal file
59
pkg/team/middleware.go
Normal file
@@ -0,0 +1,59 @@
|
||||
package team
|
||||
|
||||
import (
|
||||
"log"
|
||||
"net/http"
|
||||
"opencatd-open/store"
|
||||
"strings"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
var (
|
||||
rootToken string
|
||||
)
|
||||
|
||||
func AuthMiddleware() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
if rootToken == "" {
|
||||
u, err := store.GetUserByID(uint(1))
|
||||
if err != nil {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{"error": "Unauthorized"})
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
rootToken = u.Token
|
||||
}
|
||||
token := c.GetHeader("Authorization")
|
||||
if token == "" || token[:7] != "Bearer " {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{"error": "Unauthorized"})
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
if store.IsExistAuthCache(token[7:]) {
|
||||
if strings.HasPrefix(c.Request.URL.Path, "/1/me") {
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
}
|
||||
if token[7:] != rootToken {
|
||||
u, err := store.GetUserByID(uint(1))
|
||||
if err != nil {
|
||||
log.Println(err)
|
||||
c.JSON(http.StatusUnauthorized, gin.H{"error": "Unauthorized"})
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
if token[:7] != u.Token {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{"error": "Unauthorized"})
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
rootToken = u.Token
|
||||
store.LoadAuthCache()
|
||||
}
|
||||
// 可以在这里对 token 进行验证并检查权限
|
||||
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
38
pkg/team/usage.go
Normal file
38
pkg/team/usage.go
Normal file
@@ -0,0 +1,38 @@
|
||||
package team
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"opencatd-open/store"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
func HandleUsage(c *gin.Context) {
|
||||
fromStr := c.Query("from")
|
||||
toStr := c.Query("to")
|
||||
getMonthStartAndEnd := func() (start, end string) {
|
||||
loc, _ := time.LoadLocation("Local")
|
||||
now := time.Now().In(loc)
|
||||
|
||||
year, month, _ := now.Date()
|
||||
|
||||
startOfMonth := time.Date(year, month, 1, 0, 0, 0, 0, loc)
|
||||
endOfMonth := startOfMonth.AddDate(0, 1, 0)
|
||||
|
||||
start = startOfMonth.Format("2006-01-02")
|
||||
end = endOfMonth.Format("2006-01-02")
|
||||
return
|
||||
}
|
||||
if fromStr == "" || toStr == "" {
|
||||
fromStr, toStr = getMonthStartAndEnd()
|
||||
}
|
||||
|
||||
usage, err := store.QueryUsage(fromStr, toStr)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusForbidden, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(200, usage)
|
||||
}
|
||||
89
pkg/team/user.go
Normal file
89
pkg/team/user.go
Normal file
@@ -0,0 +1,89 @@
|
||||
package team
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"opencatd-open/store"
|
||||
|
||||
"github.com/Sakurasan/to"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
type User struct {
|
||||
IsDelete bool `json:"IsDelete,omitempty"`
|
||||
ID int `json:"id,omitempty"`
|
||||
UpdatedAt string `json:"updatedAt,omitempty"`
|
||||
Name string `json:"name,omitempty"`
|
||||
Token string `json:"token,omitempty"`
|
||||
CreatedAt string `json:"createdAt,omitempty"`
|
||||
}
|
||||
|
||||
func HandleUsers(c *gin.Context) {
|
||||
users, err := store.GetAllUsers()
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"error": err.Error(),
|
||||
})
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, users)
|
||||
}
|
||||
|
||||
func HandleAddUser(c *gin.Context) {
|
||||
var body User
|
||||
if err := c.BindJSON(&body); err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
if len(body.Name) == 0 {
|
||||
c.JSON(http.StatusOK, gin.H{"error": "invalid user name"})
|
||||
return
|
||||
}
|
||||
|
||||
if err := store.AddUser(body.Name, uuid.NewString()); err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
u, err := store.GetUserByName(body.Name)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, u)
|
||||
}
|
||||
|
||||
func HandleDelUser(c *gin.Context) {
|
||||
id := to.Int(c.Param("id"))
|
||||
if id <= 1 {
|
||||
c.JSON(http.StatusOK, gin.H{"error": "invalid user id"})
|
||||
return
|
||||
}
|
||||
if err := store.DeleteUser(uint(id)); err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{"message": "ok"})
|
||||
}
|
||||
|
||||
func HandleResetUserToken(c *gin.Context) {
|
||||
id := to.Int(c.Param("id"))
|
||||
newtoken := c.Query("token")
|
||||
if newtoken == "" {
|
||||
newtoken = uuid.NewString()
|
||||
}
|
||||
|
||||
if err := store.UpdateUser(uint(id), newtoken); err != nil {
|
||||
c.JSON(http.StatusForbidden, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
u, err := store.GetUserByID(uint(id))
|
||||
if err != nil {
|
||||
c.JSON(http.StatusForbidden, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
if u.ID == uint(1) {
|
||||
rootToken = u.Token
|
||||
}
|
||||
c.JSON(http.StatusOK, u)
|
||||
}
|
||||
@@ -97,8 +97,14 @@ func Cost(model string, promptCount, completionCount int) float64 {
|
||||
cost = 0.01*float64(prompt/1000) + 0.03*float64(completion/1000)
|
||||
case "gpt-4o", "gpt-4o-2024-05-13":
|
||||
cost = 0.005*float64(prompt/1000) + 0.015*float64(completion/1000)
|
||||
case "gpt-4o-2024-08-06":
|
||||
cost = 0.0025*float64(prompt/1000) + 0.010*float64(completion/1000)
|
||||
case "gpt-4o-mini", "gpt-4o-mini-2024-07-18":
|
||||
cost = 0.00015*float64(prompt/1000) + 0.0006*float64(completion/1000)
|
||||
case "o1-preview", "o1-preview-2024-09-12":
|
||||
cost = 0.015*float64(prompt/1000) + 0.06*float64(completion/1000)
|
||||
case "o1-mini", "o1-mini-2024-09-12":
|
||||
cost = 0.003*float64(prompt/1000) + 0.012*float64(completion/1000)
|
||||
case "whisper-1":
|
||||
// 0.006$/min
|
||||
cost = 0.006 * float64(prompt+completion) / 60
|
||||
@@ -149,7 +155,8 @@ func Cost(model string, promptCount, completionCount int) float64 {
|
||||
cost = (0.003/1000)*float64(prompt) + (0.015/1000)*float64(completion)
|
||||
case "claude-3-opus-20240229":
|
||||
cost = (0.015/1000)*float64(prompt) + (0.075/1000)*float64(completion)
|
||||
|
||||
case "claude-3-5-sonnet", "claude-3-5-sonnet-20240620":
|
||||
cost = (0.003/1000)*float64(prompt) + (0.015/1000)*float64(completion)
|
||||
// google
|
||||
// https://ai.google.dev/pricing?hl=zh-cn
|
||||
case "gemini-pro":
|
||||
|
||||
167
pkg/vertexai/auth.go
Normal file
167
pkg/vertexai/auth.go
Normal file
@@ -0,0 +1,167 @@
|
||||
/*
|
||||
https://docs.anthropic.com/zh-CN/api/claude-on-vertex-ai
|
||||
|
||||
MODEL_ID=claude-3-5-sonnet@20240620
|
||||
REGION=us-east5
|
||||
PROJECT_ID=MY_PROJECT_ID
|
||||
|
||||
curl \
|
||||
-X POST \
|
||||
-H "Authorization: Bearer $(gcloud auth print-access-token)" \
|
||||
-H "Content-Type: application/json" \
|
||||
https://$LOCATION-aiplatform.googleapis.com/v1/projects/${PROJECT_ID}/locations/${LOCATION}/publishers/anthropic/models/${MODEL_ID}:streamRawPredict \
|
||||
-d '{
|
||||
"anthropic_version": "vertex-2023-10-16",
|
||||
"messages": [{
|
||||
"role": "user",
|
||||
"content": "介绍一下你自己"
|
||||
}],
|
||||
"stream": true,
|
||||
"max_tokens": 4096
|
||||
}'
|
||||
*/
|
||||
|
||||
package vertexai
|
||||
|
||||
import (
|
||||
"crypto/rsa"
|
||||
"crypto/x509"
|
||||
"encoding/json"
|
||||
"encoding/pem"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"time"
|
||||
|
||||
"github.com/golang-jwt/jwt"
|
||||
)
|
||||
|
||||
// json文件存储在ApiKey.ApiSecret中
|
||||
type VertexSecretKey struct {
|
||||
Type string `json:"type"`
|
||||
ProjectID string `json:"project_id"`
|
||||
PrivateKeyID string `json:"private_key_id"`
|
||||
PrivateKey string `json:"private_key"`
|
||||
ClientEmail string `json:"client_email"`
|
||||
ClientID string `json:"client_id"`
|
||||
AuthURI string `json:"auth_uri"`
|
||||
TokenURI string `json:"token_uri"`
|
||||
AuthProviderX509CertURL string `json:"auth_provider_x509_cert_url"`
|
||||
ClientX509CertURL string `json:"client_x509_cert_url"`
|
||||
UniverseDomain string `json:"universe_domain"`
|
||||
}
|
||||
|
||||
type VertexClaudeModel struct {
|
||||
VertexName string
|
||||
Region string
|
||||
}
|
||||
|
||||
var VertexClaudeModelMap = map[string]VertexClaudeModel{
|
||||
"claude-3-opus": {
|
||||
VertexName: "claude-3-opus@20240229",
|
||||
Region: "us-east5",
|
||||
},
|
||||
"claude-3-sonnet": {
|
||||
VertexName: "claude-3-sonnet@20240229",
|
||||
Region: "us-central1",
|
||||
// Region: "asia-southeast1",
|
||||
},
|
||||
"claude-3-haiku": {
|
||||
VertexName: "claude-3-haiku@20240307",
|
||||
Region: "us-central1",
|
||||
// Region: "europe-west4",
|
||||
},
|
||||
"claude-3-opus-20240229": {
|
||||
VertexName: "claude-3-opus@20240229",
|
||||
Region: "us-east5",
|
||||
},
|
||||
"claude-3-sonnet-20240229": {
|
||||
VertexName: "claude-3-sonnet@20240229",
|
||||
Region: "us-central1",
|
||||
// Region: "asia-southeast1",
|
||||
},
|
||||
"claude-3-haiku-20240307": {
|
||||
VertexName: "claude-3-haiku@20240307",
|
||||
Region: "us-central1",
|
||||
// Region: "europe-west4",
|
||||
},
|
||||
"claude-3-5-sonnet": {
|
||||
VertexName: "claude-3-5-sonnet@20240620",
|
||||
Region: "us-east5",
|
||||
// Region: "europe-west1",
|
||||
},
|
||||
"claude-3-5-sonnet-20240620": {
|
||||
VertexName: "claude-3-5-sonnet@20240620",
|
||||
Region: "us-east5",
|
||||
// Region: "europe-west1",
|
||||
},
|
||||
}
|
||||
|
||||
func createSignedJWT(email, privateKeyPEM string) (string, error) {
|
||||
block, _ := pem.Decode([]byte(privateKeyPEM))
|
||||
if block == nil {
|
||||
return "", fmt.Errorf("failed to parse PEM block containing the private key")
|
||||
}
|
||||
|
||||
privateKey, err := x509.ParsePKCS8PrivateKey(block.Bytes)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
rsaKey, ok := privateKey.(*rsa.PrivateKey)
|
||||
if !ok {
|
||||
return "", fmt.Errorf("not an RSA private key")
|
||||
}
|
||||
|
||||
now := time.Now()
|
||||
claims := jwt.MapClaims{
|
||||
"iss": email,
|
||||
"aud": "https://www.googleapis.com/oauth2/v4/token",
|
||||
"iat": now.Unix(),
|
||||
"exp": now.Add(10 * time.Minute).Unix(),
|
||||
"scope": "https://www.googleapis.com/auth/cloud-platform",
|
||||
}
|
||||
|
||||
token := jwt.NewWithClaims(jwt.SigningMethodRS256, claims)
|
||||
return token.SignedString(rsaKey)
|
||||
}
|
||||
|
||||
func exchangeJwtForAccessToken(signedJWT string) (string, error) {
|
||||
authURL := "https://www.googleapis.com/oauth2/v4/token"
|
||||
data := url.Values{}
|
||||
data.Set("grant_type", "urn:ietf:params:oauth:grant-type:jwt-bearer")
|
||||
data.Set("assertion", signedJWT)
|
||||
|
||||
resp, err := http.PostForm(authURL, data)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
var result map[string]interface{}
|
||||
if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
accessToken, ok := result["access_token"].(string)
|
||||
if !ok {
|
||||
return "", fmt.Errorf("access token not found in response")
|
||||
}
|
||||
|
||||
return accessToken, nil
|
||||
}
|
||||
|
||||
// 获取gcloud auth token
|
||||
func GcloudAuth(ClientEmail, PrivateKey string) (string, error) {
|
||||
signedJWT, err := createSignedJWT(ClientEmail, PrivateKey)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
token, err := exchangeJwtForAccessToken(signedJWT)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("Invalid jwt token: %v\n", err)
|
||||
}
|
||||
|
||||
return token, nil
|
||||
}
|
||||
@@ -18,33 +18,7 @@ func ChatHandler(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
// if chatreq.Messages[len(chatreq.Messages)-1].Role == "user" {
|
||||
// result, err := search.BingSearch(search.SearchParams{Query: string(chatreq.Messages[len(chatreq.Messages)-1].Content)})
|
||||
// if err == nil {
|
||||
// var msgs []openai.ChatCompletionMessage
|
||||
// for i, m := range chatreq.Messages {
|
||||
// var buf bytes.Buffer
|
||||
// buf.WriteString("根据我提问的语言回答我,我将提供一些从搜索引擎获取的信息(以websearch:开头)。你自行判断是否使用搜索引擎获取的内容。不要原封不动照抄,根据你自己的知识库提炼信息之后回答我\n\n")
|
||||
// if m.Role == "system" {
|
||||
// buf.Write(m.Content)
|
||||
// msgs = append(msgs, openai.ChatCompletionMessage{Role: m.Role, Content: buf.Bytes()})
|
||||
// } else {
|
||||
// msgs = append(msgs, openai.ChatCompletionMessage{Role: m.Role, Content: buf.Bytes()})
|
||||
// }
|
||||
// if i == len(chatreq.Messages)-1 {
|
||||
// m.Content = append(m.Content, json.RawMessage("\n\nwebsearch:")...)
|
||||
// m.Content = append(m.Content, json.RawMessage(result.(string))...)
|
||||
// msgs = append(msgs, openai.ChatCompletionMessage{Role: m.Role, Content: m.Content})
|
||||
// } else {
|
||||
// msgs = append(msgs, openai.ChatCompletionMessage{Role: m.Role, Content: m.Content})
|
||||
// }
|
||||
|
||||
// }
|
||||
// chatreq.Messages = msgs
|
||||
// }
|
||||
// }
|
||||
|
||||
if strings.HasPrefix(chatreq.Model, "gpt") {
|
||||
if strings.HasPrefix(chatreq.Model, "gpt") || strings.HasPrefix(chatreq.Model, "o1-") {
|
||||
openai.ChatProxy(c, &chatreq)
|
||||
return
|
||||
}
|
||||
|
||||
948
router/router.go
948
router/router.go
File diff suppressed because it is too large
Load Diff
@@ -53,6 +53,14 @@ func SelectKeyCache(apitype string) (Key, error) {
|
||||
if item.Object.(Key).ApiType == "azure" {
|
||||
keys = append(keys, item.Object.(Key))
|
||||
}
|
||||
if item.Object.(Key).ApiType == "github" {
|
||||
keys = append(keys, item.Object.(Key))
|
||||
}
|
||||
}
|
||||
if apitype == "claude" {
|
||||
if item.Object.(Key).ApiType == "vertex" {
|
||||
keys = append(keys, item.Object.(Key))
|
||||
}
|
||||
}
|
||||
}
|
||||
if len(keys) == 0 {
|
||||
|
||||
@@ -2,9 +2,35 @@ package store
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"log"
|
||||
"opencatd-open/pkg/vertexai"
|
||||
"os"
|
||||
"time"
|
||||
)
|
||||
|
||||
func init() {
|
||||
// check vertex
|
||||
if os.Getenv("Vertex") != "" {
|
||||
vertex_auth := os.Getenv("Vertex")
|
||||
var Vertex vertexai.VertexSecretKey
|
||||
if err := json.Unmarshal([]byte(vertex_auth), &Vertex); err != nil {
|
||||
log.Fatalln(err)
|
||||
return
|
||||
}
|
||||
key := Key{
|
||||
ApiType: "vertex",
|
||||
Name: Vertex.ProjectID,
|
||||
Key: vertex_auth,
|
||||
ApiSecret: vertex_auth,
|
||||
}
|
||||
if err := db.FirstOrCreate(&key).Error; err != nil {
|
||||
log.Println(fmt.Errorf("create vertex key error: %v", err))
|
||||
}
|
||||
}
|
||||
LoadKeysCache()
|
||||
}
|
||||
|
||||
type Key struct {
|
||||
ID uint `gorm:"primarykey" json:"id,omitempty"`
|
||||
Key string `gorm:"unique;not null" json:"key,omitempty"`
|
||||
@@ -14,6 +40,7 @@ type Key struct {
|
||||
EndPoint string `gorm:"column:endpoint"`
|
||||
ResourceNmae string `gorm:"column:resource_name"`
|
||||
DeploymentName string `gorm:"column:deployment_name"`
|
||||
ApiSecret string `gorm:"column:api_secret"`
|
||||
CreatedAt time.Time `json:"createdAt,omitempty"`
|
||||
UpdatedAt time.Time `json:"updatedAt,omitempty"`
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user