refact & update new model

This commit is contained in:
Sakurasan
2024-09-13 21:32:10 +08:00
parent c11824f5aa
commit 7fd82b43f4
21 changed files with 1094 additions and 975 deletions

View File

@@ -1,4 +1,7 @@
# opencatd-open
# ~~opencatd-open~~ [OpenTeam](https://github.com/mirrors2/opencatd-open)
本项目即将更名,后续请关注 👉🏻 https://github.com/mirrors2/openteam
<a title="Docker Image CI" target="_blank" href="https://github.com/mirrors2/opencatd-open/actions"><img alt="GitHub Workflow Status" src="https://img.shields.io/github/actions/workflow/status/mirrors2/opencatd-open/ci.yaml?label=Actions&logo=github&style=flat-square"></a>
<a title="Docker Pulls" target="_blank" href="https://hub.docker.com/r/mirrors2/opencatd-open"><img src="https://img.shields.io/docker/pulls/mirrors2/opencatd-open.svg?logo=docker&label=docker&style=flat-square"></a>
@@ -14,11 +17,11 @@ OpenCat for Team的开源实现
## Extra Support:
| 任务 | 完成情况 |
| --- | --- |
|[Azure OpenAI](./doc/azure.md) | ✅|
|[Claude](./doc/azure.md) | ✅|
|[Gemini](./doc/gemini.md) | ✅|
| 🎯 | 🚧 |Extra Provider|
| --- | --- | --- |
|[OpenAI](./doc/azure.md) | ✅|Azure, Github Marketplace|
|[Claude](./doc/azure.md) | ✅|VertexAI|
|[Gemini](./doc/gemini.md) | ✅||
| ... | ... |
@@ -80,6 +83,18 @@ wget https://github.com/mirrors2/opencatd-open/raw/main/docker/docker-compose.ym
pandora for team
- [pandora for team](./doc/pandora.md)
如何自定义HOST地址? (仅OpenAI)
- 需修改环境变量,优先级递增
- Cloudflare AI Gateway地址 `AIGateWay_Endpoint=https://gateway.ai.cloudflare.com/v1/123456789/xxxx/openai/chat/completions`
- 自定义的endpoint `$CUSTOM_ENDPOINT=true && $OpenAI_Endpoint=https://your.domain/v1/chat/completions`
设置主页跳转地址?
- 修改环境变量 `CUSTOM_REDIRECT=https://your.domain`
## 赞助
[![Buy Me A Coffee](https://img.shields.io/badge/Buy%20Me%20A%20Coffee-FFDD55?style=flat-square&logo=buy-me-a-coffee&logoColor=black)](https://www.buymeacoffee.com/littlecjun)
# License
[GNU General Public License v3.0](License)
[![GitHub License](https://img.shields.io/github/license/mirrors2/opencatd-open.svg?logo=github&style=flat-square)](https://github.com/mirrors2/opencatd-open/blob/main/License)

26
docker-compose.yml Normal file
View File

@@ -0,0 +1,26 @@
# Email: admin@example.com
# Password: changeme
version: '3'
services:
npm:
image: jc21/nginx-proxy-manager
network_mode: host
ports:
- '80:80'
- '81:81'
- '443:443'
volumes:
- $PWD/data:/data
- $PWD/www:/var/www
- $PWD/letsencrypt:/etc/letsencrypt
environment:
- "TZ=Asia/Shanghai" # set timezone, default UTC
- "PUID=1000" # set group id, default 0 (root)
- "PGID=1000"
# certbot:
# image: certbot/certbot
# volumes:
# - $PWD/data/certbot/conf:/etc/letsencrypt
# - $PWD/data/certbot/www:/var/www/certbot

View File

@@ -1,4 +1,4 @@
FROM node:18.12.1-alpine3.16 AS frontend
FROM node:20-alpine AS frontend
WORKDIR /frontend-build
COPY ./web/ .
RUN npm install && npm run build && rm -rf node_modules

View File

@@ -1,10 +1,24 @@
version: '3.7'
services:
services:
opencatd:
image: mirrors2/opencatd-open
container_name: opencatd-open
container_name: opencatd-open
restart: unless-stopped
#network_mode: host
ports:
- 80:80
volumes:
- /etc/opencatd:/app/db
- $PWD/db:/app/db
logging:
# driver: "json-file"
options:
max-size: 10m
max-file: 3
# environment:
# Vertex: |
# {
# "type": "service_account",
# "universe_domain": "googleapis.com"
# }

View File

@@ -8,6 +8,7 @@ import (
"io/fs"
"log"
"net/http"
"opencatd-open/pkg/team"
"opencatd-open/router"
"opencatd-open/store"
"os"
@@ -143,41 +144,29 @@ func main() {
r := gin.Default()
group := r.Group("/1")
{
group.Use(router.AuthMiddleware())
group.Use(team.AuthMiddleware())
// 获取当前用户信息
group.GET("/me", router.HandleMe)
group.GET("/me", team.HandleMe)
group.GET("/me/usages", team.HandleMeUsage)
group.GET("/me/usages", router.HandleMeUsage)
group.GET("/keys", team.HandleKeys) // 获取所有Key
group.POST("/keys", team.HandleAddKey) // 添加Key
group.DELETE("/keys/:id", team.HandleDelKey) // 删除Key
// 获取所有Key
group.GET("/keys", router.HandleKeys)
group.GET("/users", team.HandleUsers) // 获取所有用户信息
group.POST("/users", team.HandleAddUser) // 添加用户
group.DELETE("/users/:id", team.HandleDelUser) // 删除用户
// 获取所有用户信息
group.GET("/users", router.HandleUsers)
group.GET("/usages", router.HandleUsage)
// 添加Key
group.POST("/keys", router.HandleAddKey)
// 删除Key
group.DELETE("/keys/:id", router.HandleDelKey)
// 添加用户
group.POST("/users", router.HandleAddUser)
// 删除用户
group.DELETE("/users/:id", router.HandleDelUser)
group.GET("/usages", team.HandleUsage)
// 重置用户Token
group.POST("/users/:id/reset", router.HandleResetUserToken)
group.POST("/users/:id/reset", team.HandleResetUserToken)
}
// 初始化用户
r.POST("/1/users/init", router.Handleinit)
r.POST("/1/users/init", team.Handleinit)
r.Any("/v1/*proxypath", router.HandleProy)
r.Any("/v1/*proxypath", router.HandleProxy)
// r.POST("/v1/chat/completions", router.HandleProy)
// r.GET("/v1/models", router.HandleProy)
@@ -188,7 +177,15 @@ func main() {
if err != nil {
panic(err)
}
r.GET("/", gin.WrapH(http.FileServer(http.FS(idxFS))))
redirect := os.Getenv("CUSTOM_REDIRECT")
if redirect != "" {
r.GET("/", func(c *gin.Context) {
c.Redirect(http.StatusMovedPermanently, redirect)
})
} else {
r.GET("/", gin.WrapH(http.FileServer(http.FS(idxFS))))
}
assetsFS, err := fs.Sub(web, "dist/assets")
if err != nil {
panic(err)

View File

@@ -10,8 +10,10 @@ import (
"io"
"log"
"net/http"
"opencatd-open/pkg/error"
"opencatd-open/pkg/openai"
"opencatd-open/pkg/tokenizer"
"opencatd-open/pkg/vertexai"
"opencatd-open/store"
"strings"
@@ -27,14 +29,15 @@ func ChatTextCompletions(c *gin.Context, chatReq *openai.ChatCompletionRequest)
}
type ChatRequest struct {
Model string `json:"model,omitempty"`
Messages any `json:"messages,omitempty"`
MaxTokens int `json:"max_tokens,omitempty"`
Stream bool `json:"stream,omitempty"`
System string `json:"system,omitempty"`
TopK int `json:"top_k,omitempty"`
TopP float64 `json:"top_p,omitempty"`
Temperature float64 `json:"temperature,omitempty"`
Model string `json:"model,omitempty"`
Messages any `json:"messages,omitempty"`
MaxTokens int `json:"max_tokens,omitempty"`
Stream bool `json:"stream,omitempty"`
System string `json:"system,omitempty"`
TopK int `json:"top_k,omitempty"`
TopP float64 `json:"top_p,omitempty"`
Temperature float64 `json:"temperature,omitempty"`
AnthropicVersion string `json:"anthropic_version,omitempty"`
}
func (c *ChatRequest) ByteJson() []byte {
@@ -117,8 +120,12 @@ type ClaudeStreamResponse struct {
}
func ChatMessages(c *gin.Context, chatReq *openai.ChatCompletionRequest) {
var (
req *http.Request
targetURL = ClaudeMessageEndpoint
)
onekey, err := store.SelectKeyCache("claude")
apiKey, err := store.SelectKeyCache("claude")
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
@@ -131,6 +138,10 @@ func ChatMessages(c *gin.Context, chatReq *openai.ChatCompletionRequest) {
// claudReq.Temperature = chatReq.Temperature
claudReq.TopP = chatReq.TopP
claudReq.MaxTokens = 4096
if apiKey.ApiType == "vertex" {
claudReq.AnthropicVersion = "vertex-2023-10-16"
claudReq.Model = ""
}
var prompt string
@@ -181,10 +192,40 @@ func ChatMessages(c *gin.Context, chatReq *openai.ChatCompletionRequest) {
usagelog.PromptCount = tokenizer.NumTokensFromStr(prompt, chatReq.Model)
req, _ := http.NewRequest("POST", MessageEndpoint, bytes.NewReader(claudReq.ByteJson()))
req.Header.Set("x-api-key", onekey.Key)
req.Header.Set("anthropic-version", "2023-06-01")
req.Header.Set("Content-Type", "application/json")
if apiKey.ApiType == "vertex" {
var vertexSecret vertexai.VertexSecretKey
if err := json.Unmarshal([]byte(apiKey.ApiSecret), &vertexSecret); err != nil {
c.JSON(http.StatusInternalServerError, error.ErrorData(err.Error()))
return
}
vcmodel, ok := vertexai.VertexClaudeModelMap[chatReq.Model]
if !ok {
c.JSON(http.StatusInternalServerError, error.ErrorData("Model not found"))
return
}
// 获取gcloud token临时放置在apiKey.Key中
gcloudToken, err := vertexai.GcloudAuth(vertexSecret.ClientEmail, vertexSecret.PrivateKey)
if err != nil {
c.JSON(http.StatusInternalServerError, error.ErrorData(err.Error()))
return
}
// 拼接vertex的请求地址
targetURL = fmt.Sprintf("https://%s-aiplatform.googleapis.com/v1/projects/%s/locations/%s/publishers/anthropic/models/%s:streamRawPredict", vcmodel.Region, vertexSecret.ProjectID, vcmodel.Region, vcmodel.VertexName)
req, _ = http.NewRequest("POST", targetURL, bytes.NewReader(claudReq.ByteJson()))
req.Header.Set("Authorization", "Bearer "+gcloudToken)
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Accept", "text/event-stream")
req.Header.Set("Accept-Encoding", "identity")
} else {
req, _ = http.NewRequest("POST", targetURL, bytes.NewReader(claudReq.ByteJson()))
req.Header.Set("x-api-key", apiKey.Key)
req.Header.Set("anthropic-version", "2023-06-01")
req.Header.Set("Content-Type", "application/json")
}
client := http.DefaultClient
rsp, err := client.Do(req)

View File

@@ -68,8 +68,8 @@ import (
)
var (
ClaudeUrl = "https://api.anthropic.com/v1/complete"
MessageEndpoint = "https://api.anthropic.com/v1/messages"
ClaudeUrl = "https://api.anthropic.com/v1/complete"
ClaudeMessageEndpoint = "https://api.anthropic.com/v1/messages"
)
type MessageModule struct {

11
pkg/error/errdata.go Normal file
View File

@@ -0,0 +1,11 @@
package error
import "github.com/gin-gonic/gin"
func ErrorData(message string) gin.H {
return gin.H{
"error": gin.H{
"message": message,
},
}
}

View File

@@ -17,8 +17,10 @@ import (
)
const (
AzureApiVersion = "2024-02-01"
OpenAI_Endpoint = "https://api.openai.com/v1/chat/completions"
AzureApiVersion = "2024-02-01"
BaseHost = "api.openai.com"
OpenAI_Endpoint = "https://api.openai.com/v1/chat/completions"
Github_Marketplace = "https://models.inference.ai.azure.com/chat/completions"
)
var (
@@ -204,6 +206,10 @@ func ChatProxy(c *gin.Context, chatReq *ChatCompletionRequest) {
var req *http.Request
switch onekey.ApiType {
case "github":
req, err = http.NewRequest(c.Request.Method, Github_Marketplace, bytes.NewReader(chatReq.ToByteJson()))
req.Header = c.Request.Header
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", onekey.Key))
case "azure":
var buildurl string
if onekey.EndPoint != "" {
@@ -215,15 +221,18 @@ func ChatProxy(c *gin.Context, chatReq *ChatCompletionRequest) {
req.Header = c.Request.Header
req.Header.Set("api-key", onekey.Key)
default:
req, err = http.NewRequest(c.Request.Method, OpenAI_Endpoint, bytes.NewReader(chatReq.ToByteJson()))
if onekey.EndPoint != "" { // 优先key的endpoint
req, err = http.NewRequest(c.Request.Method, onekey.EndPoint+c.Request.RequestURI, bytes.NewReader(chatReq.ToByteJson()))
} else {
if BaseURL != "" { // 其次BaseURL
req, err = http.NewRequest(c.Request.Method, BaseURL+c.Request.RequestURI, bytes.NewReader(chatReq.ToByteJson()))
} else { // 最后是gateway的endpoint
req, err = http.NewRequest(c.Request.Method, AIGateWay_Endpoint, bytes.NewReader(chatReq.ToByteJson()))
}
}
if AIGateWay_Endpoint != "" { // cloudflare gateway的endpoint
req, err = http.NewRequest(c.Request.Method, AIGateWay_Endpoint, bytes.NewReader(chatReq.ToByteJson()))
}
customEndpoint := os.Getenv("CUSTOM_ENDPOINT") // 最后是用户自定义的endpoint CUSTOM_ENDPOINT=true OpenAI_Endpoint
if customEndpoint == "true" && OpenAI_Endpoint != "" {
req, err = http.NewRequest(c.Request.Method, BaseURL, bytes.NewReader(chatReq.ToByteJson()))
}
req.Header = c.Request.Header
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", onekey.Key))
}

177
pkg/openai/whisper.go Normal file
View File

@@ -0,0 +1,177 @@
package openai
import (
"bytes"
"context"
"errors"
"fmt"
"io"
"log"
"mime/multipart"
"net/http"
"net/http/httputil"
"net/url"
"opencatd-open/pkg/tokenizer"
"opencatd-open/store"
"path/filepath"
"time"
"github.com/faiface/beep"
"github.com/faiface/beep/mp3"
"github.com/faiface/beep/wav"
"github.com/gin-gonic/gin"
"gopkg.in/vansante/go-ffprobe.v2"
)
func WhisperProxy(c *gin.Context) {
var chatlog store.Tokens
byteBody, _ := io.ReadAll(c.Request.Body)
c.Request.Body = io.NopCloser(bytes.NewBuffer(byteBody))
model, _ := c.GetPostForm("model")
key, err := store.SelectKeyCache("openai")
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{
"error": gin.H{
"message": err.Error(),
},
})
return
}
chatlog.Model = model
token, _ := c.Get("localuser")
lu, err := store.GetUserByToken(token.(string))
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{
"error": gin.H{
"message": err.Error(),
},
})
return
}
chatlog.UserID = int(lu.ID)
if err := ParseWhisperRequestTokens(c, &chatlog, byteBody); err != nil {
c.JSON(http.StatusInternalServerError, gin.H{
"error": gin.H{
"message": err.Error(),
},
})
return
}
if key.EndPoint == "" {
key.EndPoint = "https://api.openai.com"
}
targetUrl, _ := url.ParseRequestURI(key.EndPoint + c.Request.URL.String())
log.Println(targetUrl)
proxy := httputil.NewSingleHostReverseProxy(targetUrl)
proxy.Director = func(req *http.Request) {
req.Host = targetUrl.Host
req.URL.Scheme = targetUrl.Scheme
req.URL.Host = targetUrl.Host
req.Header.Set("Authorization", "Bearer "+key.Key)
}
proxy.ModifyResponse = func(resp *http.Response) error {
if resp.StatusCode != http.StatusOK {
return nil
}
chatlog.TotalTokens = chatlog.PromptCount + chatlog.CompletionCount
chatlog.Cost = fmt.Sprintf("%.6f", tokenizer.Cost(chatlog.Model, chatlog.PromptCount, chatlog.CompletionCount))
if err := store.Record(&chatlog); err != nil {
log.Println(err)
}
if err := store.SumDaily(chatlog.UserID); err != nil {
log.Println(err)
}
return nil
}
proxy.ServeHTTP(c.Writer, c.Request)
}
func probe(fileReader io.Reader) (time.Duration, error) {
ctx, cancelFn := context.WithTimeout(context.Background(), 5*time.Second)
defer cancelFn()
data, err := ffprobe.ProbeReader(ctx, fileReader)
if err != nil {
return 0, err
}
duration := data.Format.DurationSeconds
pduration, err := time.ParseDuration(fmt.Sprintf("%fs", duration))
if err != nil {
return 0, fmt.Errorf("Error parsing duration: %s", err)
}
return pduration, nil
}
func getAudioDuration(file *multipart.FileHeader) (time.Duration, error) {
var (
streamer beep.StreamSeekCloser
format beep.Format
err error
)
f, err := file.Open()
defer f.Close()
// Get the file extension to determine the audio file type
fileType := filepath.Ext(file.Filename)
switch fileType {
case ".mp3":
streamer, format, err = mp3.Decode(f)
case ".wav":
streamer, format, err = wav.Decode(f)
case ".m4a":
duration, err := probe(f)
if err != nil {
return 0, err
}
return duration, nil
default:
return 0, errors.New("unsupported audio file format")
}
if err != nil {
return 0, err
}
defer streamer.Close()
// Calculate the audio file's duration.
numSamples := streamer.Len()
sampleRate := format.SampleRate
duration := time.Duration(numSamples) * time.Second / time.Duration(sampleRate)
return duration, nil
}
func ParseWhisperRequestTokens(c *gin.Context, usage *store.Tokens, byteBody []byte) error {
file, _ := c.FormFile("file")
model, _ := c.GetPostForm("model")
usage.Model = model
if file != nil {
duration, err := getAudioDuration(file)
if err != nil {
return fmt.Errorf("Error getting audio duration:%s", err)
}
if duration > 5*time.Minute {
return fmt.Errorf("Audio duration exceeds 5 minutes")
}
// 计算时长,四舍五入到最接近的秒数
usage.PromptCount = int(duration.Round(time.Second).Seconds())
}
c.Request.Body = io.NopCloser(bytes.NewBuffer(byteBody))
return nil
}

182
pkg/team/key.go Normal file
View File

@@ -0,0 +1,182 @@
package team
import (
"net/http"
"opencatd-open/pkg/azureopenai"
"opencatd-open/store"
"strings"
"github.com/Sakurasan/to"
"github.com/gin-gonic/gin"
)
type Key struct {
ID int `json:"id,omitempty"`
Key string `json:"key,omitempty"`
Name string `json:"name,omitempty"`
ApiType string `json:"api_type,omitempty"`
Endpoint string `json:"endpoint,omitempty"`
UpdatedAt string `json:"updatedAt,omitempty"`
CreatedAt string `json:"createdAt,omitempty"`
}
func HandleKeys(c *gin.Context) {
keys, err := store.GetAllKeys()
if err != nil {
c.JSON(http.StatusOK, gin.H{
"error": err.Error(),
})
}
c.JSON(http.StatusOK, keys)
}
func HandleAddKey(c *gin.Context) {
var body Key
if err := c.BindJSON(&body); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": gin.H{
"message": err.Error(),
}})
return
}
body.Name = strings.ToLower(strings.TrimSpace(body.Name))
body.Key = strings.TrimSpace(body.Key)
if strings.HasPrefix(body.Name, "azure.") {
keynames := strings.Split(body.Name, ".")
if len(keynames) < 2 {
c.JSON(http.StatusBadRequest, gin.H{"error": gin.H{
"message": "Invalid Key Name",
}})
return
}
k := &store.Key{
ApiType: "azure",
Name: body.Name,
Key: body.Key,
ResourceNmae: keynames[1],
EndPoint: body.Endpoint,
}
if err := store.CreateKey(k); err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": gin.H{
"message": err.Error(),
}})
return
}
} else if strings.HasPrefix(body.Name, "claude.") {
keynames := strings.Split(body.Name, ".")
if len(keynames) < 2 {
c.JSON(http.StatusBadRequest, gin.H{"error": gin.H{
"message": "Invalid Key Name",
}})
return
}
if body.Endpoint == "" {
body.Endpoint = "https://api.anthropic.com"
}
k := &store.Key{
// ApiType: "anthropic",
ApiType: "claude",
Name: body.Name,
Key: body.Key,
ResourceNmae: keynames[1],
EndPoint: body.Endpoint,
}
if err := store.CreateKey(k); err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": gin.H{
"message": err.Error(),
}})
return
}
} else if strings.HasPrefix(body.Name, "google.") {
keynames := strings.Split(body.Name, ".")
if len(keynames) < 2 {
c.JSON(http.StatusBadRequest, gin.H{"error": gin.H{
"message": "Invalid Key Name",
}})
return
}
k := &store.Key{
// ApiType: "anthropic",
ApiType: "google",
Name: body.Name,
Key: body.Key,
ResourceNmae: keynames[1],
EndPoint: body.Endpoint,
}
if err := store.CreateKey(k); err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": gin.H{
"message": err.Error(),
}})
return
}
} else if strings.HasPrefix(body.Name, "github.") {
keynames := strings.Split(body.Name, ".")
if len(keynames) < 2 {
c.JSON(http.StatusBadRequest, gin.H{"error": gin.H{
"message": "Invalid Key Name",
}})
return
}
k := &store.Key{
ApiType: "github",
Name: body.Name,
Key: body.Key,
ResourceNmae: keynames[1],
EndPoint: body.Endpoint,
}
if err := store.CreateKey(k); err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": gin.H{
"message": err.Error(),
}})
return
}
} else {
if body.ApiType == "" {
if err := store.AddKey("openai", body.Key, body.Name); err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": gin.H{
"message": err.Error(),
}})
return
}
} else {
k := &store.Key{
ApiType: body.ApiType,
Name: body.Name,
Key: body.Key,
ResourceNmae: azureopenai.GetResourceName(body.Endpoint),
EndPoint: body.Endpoint,
}
if err := store.CreateKey(k); err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": gin.H{
"message": err.Error(),
}})
return
}
}
}
k, err := store.GetKeyrByName(body.Name)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": gin.H{
"message": err.Error(),
}})
return
}
c.JSON(http.StatusOK, k)
}
func HandleDelKey(c *gin.Context) {
id := to.Int(c.Param("id"))
if id < 1 {
c.JSON(http.StatusOK, gin.H{"error": "invalid key id"})
return
}
if err := store.DeleteKey(uint(id)); err != nil {
c.JSON(http.StatusOK, gin.H{"error": "invalid key id"})
return
}
c.JSON(http.StatusOK, gin.H{"message": "ok"})
}

104
pkg/team/me.go Normal file
View File

@@ -0,0 +1,104 @@
package team
import (
"errors"
"net/http"
"opencatd-open/store"
"time"
"github.com/Sakurasan/to"
"github.com/gin-gonic/gin"
"github.com/google/uuid"
"gorm.io/gorm"
)
func Handleinit(c *gin.Context) {
user, err := store.GetUserByID(1)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
u := store.User{Name: "root", Token: uuid.NewString()}
u.ID = 1
if err := store.CreateUser(&u); err != nil {
c.JSON(http.StatusForbidden, gin.H{
"error": err.Error(),
})
return
} else {
rootToken = u.Token
resJSON := User{
false,
int(u.ID),
u.UpdatedAt.Format(time.RFC3339),
u.Name,
u.Token,
u.CreatedAt.Format(time.RFC3339),
}
c.JSON(http.StatusOK, resJSON)
return
}
}
c.JSON(http.StatusOK, gin.H{
"error": err.Error(),
})
return
}
if user.ID == uint(1) {
c.JSON(http.StatusForbidden, gin.H{
"error": "super user already exists, use cli to reset password",
})
}
}
func HandleMe(c *gin.Context) {
token := c.GetHeader("Authorization")
u, err := store.GetUserByToken(token[7:])
if err != nil {
c.JSON(http.StatusOK, gin.H{
"error": err.Error(),
})
}
resJSON := User{
false,
int(u.ID),
u.UpdatedAt.Format(time.RFC3339),
u.Name,
u.Token,
u.CreatedAt.Format(time.RFC3339),
}
c.JSON(http.StatusOK, resJSON)
}
func HandleMeUsage(c *gin.Context) {
token := c.GetHeader("Authorization")
fromStr := c.Query("from")
toStr := c.Query("to")
getMonthStartAndEnd := func() (start, end string) {
loc, _ := time.LoadLocation("Local")
now := time.Now().In(loc)
year, month, _ := now.Date()
startOfMonth := time.Date(year, month, 1, 0, 0, 0, 0, loc)
endOfMonth := startOfMonth.AddDate(0, 1, 0)
start = startOfMonth.Format("2006-01-02")
end = endOfMonth.Format("2006-01-02")
return
}
if fromStr == "" || toStr == "" {
fromStr, toStr = getMonthStartAndEnd()
}
user, err := store.GetUserByToken(token)
if err != nil {
c.AbortWithError(http.StatusForbidden, err)
return
}
usage, err := store.QueryUserUsage(to.String(user.ID), fromStr, toStr)
if err != nil {
c.AbortWithError(http.StatusForbidden, err)
return
}
c.JSON(200, usage)
}

59
pkg/team/middleware.go Normal file
View File

@@ -0,0 +1,59 @@
package team
import (
"log"
"net/http"
"opencatd-open/store"
"strings"
"github.com/gin-gonic/gin"
)
var (
rootToken string
)
func AuthMiddleware() gin.HandlerFunc {
return func(c *gin.Context) {
if rootToken == "" {
u, err := store.GetUserByID(uint(1))
if err != nil {
c.JSON(http.StatusUnauthorized, gin.H{"error": "Unauthorized"})
c.Abort()
return
}
rootToken = u.Token
}
token := c.GetHeader("Authorization")
if token == "" || token[:7] != "Bearer " {
c.JSON(http.StatusUnauthorized, gin.H{"error": "Unauthorized"})
c.Abort()
return
}
if store.IsExistAuthCache(token[7:]) {
if strings.HasPrefix(c.Request.URL.Path, "/1/me") {
c.Next()
return
}
}
if token[7:] != rootToken {
u, err := store.GetUserByID(uint(1))
if err != nil {
log.Println(err)
c.JSON(http.StatusUnauthorized, gin.H{"error": "Unauthorized"})
c.Abort()
return
}
if token[:7] != u.Token {
c.JSON(http.StatusUnauthorized, gin.H{"error": "Unauthorized"})
c.Abort()
return
}
rootToken = u.Token
store.LoadAuthCache()
}
// 可以在这里对 token 进行验证并检查权限
c.Next()
}
}

38
pkg/team/usage.go Normal file
View File

@@ -0,0 +1,38 @@
package team
import (
"net/http"
"opencatd-open/store"
"time"
"github.com/gin-gonic/gin"
)
func HandleUsage(c *gin.Context) {
fromStr := c.Query("from")
toStr := c.Query("to")
getMonthStartAndEnd := func() (start, end string) {
loc, _ := time.LoadLocation("Local")
now := time.Now().In(loc)
year, month, _ := now.Date()
startOfMonth := time.Date(year, month, 1, 0, 0, 0, 0, loc)
endOfMonth := startOfMonth.AddDate(0, 1, 0)
start = startOfMonth.Format("2006-01-02")
end = endOfMonth.Format("2006-01-02")
return
}
if fromStr == "" || toStr == "" {
fromStr, toStr = getMonthStartAndEnd()
}
usage, err := store.QueryUsage(fromStr, toStr)
if err != nil {
c.JSON(http.StatusForbidden, gin.H{"error": err.Error()})
return
}
c.JSON(200, usage)
}

89
pkg/team/user.go Normal file
View File

@@ -0,0 +1,89 @@
package team
import (
"net/http"
"opencatd-open/store"
"github.com/Sakurasan/to"
"github.com/gin-gonic/gin"
"github.com/google/uuid"
)
type User struct {
IsDelete bool `json:"IsDelete,omitempty"`
ID int `json:"id,omitempty"`
UpdatedAt string `json:"updatedAt,omitempty"`
Name string `json:"name,omitempty"`
Token string `json:"token,omitempty"`
CreatedAt string `json:"createdAt,omitempty"`
}
func HandleUsers(c *gin.Context) {
users, err := store.GetAllUsers()
if err != nil {
c.JSON(http.StatusOK, gin.H{
"error": err.Error(),
})
}
c.JSON(http.StatusOK, users)
}
func HandleAddUser(c *gin.Context) {
var body User
if err := c.BindJSON(&body); err != nil {
c.JSON(http.StatusOK, gin.H{"error": err.Error()})
return
}
if len(body.Name) == 0 {
c.JSON(http.StatusOK, gin.H{"error": "invalid user name"})
return
}
if err := store.AddUser(body.Name, uuid.NewString()); err != nil {
c.JSON(http.StatusOK, gin.H{"error": err.Error()})
return
}
u, err := store.GetUserByName(body.Name)
if err != nil {
c.JSON(http.StatusOK, gin.H{"error": err.Error()})
return
}
c.JSON(http.StatusOK, u)
}
func HandleDelUser(c *gin.Context) {
id := to.Int(c.Param("id"))
if id <= 1 {
c.JSON(http.StatusOK, gin.H{"error": "invalid user id"})
return
}
if err := store.DeleteUser(uint(id)); err != nil {
c.JSON(http.StatusOK, gin.H{"error": err.Error()})
return
}
c.JSON(http.StatusOK, gin.H{"message": "ok"})
}
func HandleResetUserToken(c *gin.Context) {
id := to.Int(c.Param("id"))
newtoken := c.Query("token")
if newtoken == "" {
newtoken = uuid.NewString()
}
if err := store.UpdateUser(uint(id), newtoken); err != nil {
c.JSON(http.StatusForbidden, gin.H{"error": err.Error()})
return
}
u, err := store.GetUserByID(uint(id))
if err != nil {
c.JSON(http.StatusForbidden, gin.H{"error": err.Error()})
return
}
if u.ID == uint(1) {
rootToken = u.Token
}
c.JSON(http.StatusOK, u)
}

View File

@@ -97,8 +97,14 @@ func Cost(model string, promptCount, completionCount int) float64 {
cost = 0.01*float64(prompt/1000) + 0.03*float64(completion/1000)
case "gpt-4o", "gpt-4o-2024-05-13":
cost = 0.005*float64(prompt/1000) + 0.015*float64(completion/1000)
case "gpt-4o-2024-08-06":
cost = 0.0025*float64(prompt/1000) + 0.010*float64(completion/1000)
case "gpt-4o-mini", "gpt-4o-mini-2024-07-18":
cost = 0.00015*float64(prompt/1000) + 0.0006*float64(completion/1000)
case "o1-preview", "o1-preview-2024-09-12":
cost = 0.015*float64(prompt/1000) + 0.06*float64(completion/1000)
case "o1-mini", "o1-mini-2024-09-12":
cost = 0.003*float64(prompt/1000) + 0.012*float64(completion/1000)
case "whisper-1":
// 0.006$/min
cost = 0.006 * float64(prompt+completion) / 60
@@ -149,7 +155,8 @@ func Cost(model string, promptCount, completionCount int) float64 {
cost = (0.003/1000)*float64(prompt) + (0.015/1000)*float64(completion)
case "claude-3-opus-20240229":
cost = (0.015/1000)*float64(prompt) + (0.075/1000)*float64(completion)
case "claude-3-5-sonnet", "claude-3-5-sonnet-20240620":
cost = (0.003/1000)*float64(prompt) + (0.015/1000)*float64(completion)
// google
// https://ai.google.dev/pricing?hl=zh-cn
case "gemini-pro":

167
pkg/vertexai/auth.go Normal file
View File

@@ -0,0 +1,167 @@
/*
https://docs.anthropic.com/zh-CN/api/claude-on-vertex-ai
MODEL_ID=claude-3-5-sonnet@20240620
REGION=us-east5
PROJECT_ID=MY_PROJECT_ID
curl \
-X POST \
-H "Authorization: Bearer $(gcloud auth print-access-token)" \
-H "Content-Type: application/json" \
https://$LOCATION-aiplatform.googleapis.com/v1/projects/${PROJECT_ID}/locations/${LOCATION}/publishers/anthropic/models/${MODEL_ID}:streamRawPredict \
-d '{
"anthropic_version": "vertex-2023-10-16",
"messages": [{
"role": "user",
"content": "介绍一下你自己"
}],
"stream": true,
"max_tokens": 4096
}'
*/
package vertexai
import (
"crypto/rsa"
"crypto/x509"
"encoding/json"
"encoding/pem"
"fmt"
"net/http"
"net/url"
"time"
"github.com/golang-jwt/jwt"
)
// json文件存储在ApiKey.ApiSecret中
type VertexSecretKey struct {
Type string `json:"type"`
ProjectID string `json:"project_id"`
PrivateKeyID string `json:"private_key_id"`
PrivateKey string `json:"private_key"`
ClientEmail string `json:"client_email"`
ClientID string `json:"client_id"`
AuthURI string `json:"auth_uri"`
TokenURI string `json:"token_uri"`
AuthProviderX509CertURL string `json:"auth_provider_x509_cert_url"`
ClientX509CertURL string `json:"client_x509_cert_url"`
UniverseDomain string `json:"universe_domain"`
}
type VertexClaudeModel struct {
VertexName string
Region string
}
var VertexClaudeModelMap = map[string]VertexClaudeModel{
"claude-3-opus": {
VertexName: "claude-3-opus@20240229",
Region: "us-east5",
},
"claude-3-sonnet": {
VertexName: "claude-3-sonnet@20240229",
Region: "us-central1",
// Region: "asia-southeast1",
},
"claude-3-haiku": {
VertexName: "claude-3-haiku@20240307",
Region: "us-central1",
// Region: "europe-west4",
},
"claude-3-opus-20240229": {
VertexName: "claude-3-opus@20240229",
Region: "us-east5",
},
"claude-3-sonnet-20240229": {
VertexName: "claude-3-sonnet@20240229",
Region: "us-central1",
// Region: "asia-southeast1",
},
"claude-3-haiku-20240307": {
VertexName: "claude-3-haiku@20240307",
Region: "us-central1",
// Region: "europe-west4",
},
"claude-3-5-sonnet": {
VertexName: "claude-3-5-sonnet@20240620",
Region: "us-east5",
// Region: "europe-west1",
},
"claude-3-5-sonnet-20240620": {
VertexName: "claude-3-5-sonnet@20240620",
Region: "us-east5",
// Region: "europe-west1",
},
}
func createSignedJWT(email, privateKeyPEM string) (string, error) {
block, _ := pem.Decode([]byte(privateKeyPEM))
if block == nil {
return "", fmt.Errorf("failed to parse PEM block containing the private key")
}
privateKey, err := x509.ParsePKCS8PrivateKey(block.Bytes)
if err != nil {
return "", err
}
rsaKey, ok := privateKey.(*rsa.PrivateKey)
if !ok {
return "", fmt.Errorf("not an RSA private key")
}
now := time.Now()
claims := jwt.MapClaims{
"iss": email,
"aud": "https://www.googleapis.com/oauth2/v4/token",
"iat": now.Unix(),
"exp": now.Add(10 * time.Minute).Unix(),
"scope": "https://www.googleapis.com/auth/cloud-platform",
}
token := jwt.NewWithClaims(jwt.SigningMethodRS256, claims)
return token.SignedString(rsaKey)
}
func exchangeJwtForAccessToken(signedJWT string) (string, error) {
authURL := "https://www.googleapis.com/oauth2/v4/token"
data := url.Values{}
data.Set("grant_type", "urn:ietf:params:oauth:grant-type:jwt-bearer")
data.Set("assertion", signedJWT)
resp, err := http.PostForm(authURL, data)
if err != nil {
return "", err
}
defer resp.Body.Close()
var result map[string]interface{}
if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
return "", err
}
accessToken, ok := result["access_token"].(string)
if !ok {
return "", fmt.Errorf("access token not found in response")
}
return accessToken, nil
}
// 获取gcloud auth token
func GcloudAuth(ClientEmail, PrivateKey string) (string, error) {
signedJWT, err := createSignedJWT(ClientEmail, PrivateKey)
if err != nil {
return "", err
}
token, err := exchangeJwtForAccessToken(signedJWT)
if err != nil {
return "", fmt.Errorf("Invalid jwt token: %v\n", err)
}
return token, nil
}

View File

@@ -18,33 +18,7 @@ func ChatHandler(c *gin.Context) {
return
}
// if chatreq.Messages[len(chatreq.Messages)-1].Role == "user" {
// result, err := search.BingSearch(search.SearchParams{Query: string(chatreq.Messages[len(chatreq.Messages)-1].Content)})
// if err == nil {
// var msgs []openai.ChatCompletionMessage
// for i, m := range chatreq.Messages {
// var buf bytes.Buffer
// buf.WriteString("根据我提问的语言回答我,我将提供一些从搜索引擎获取的信息(以websearch:开头)。你自行判断是否使用搜索引擎获取的内容。不要原封不动照抄,根据你自己的知识库提炼信息之后回答我\n\n")
// if m.Role == "system" {
// buf.Write(m.Content)
// msgs = append(msgs, openai.ChatCompletionMessage{Role: m.Role, Content: buf.Bytes()})
// } else {
// msgs = append(msgs, openai.ChatCompletionMessage{Role: m.Role, Content: buf.Bytes()})
// }
// if i == len(chatreq.Messages)-1 {
// m.Content = append(m.Content, json.RawMessage("\n\nwebsearch:")...)
// m.Content = append(m.Content, json.RawMessage(result.(string))...)
// msgs = append(msgs, openai.ChatCompletionMessage{Role: m.Role, Content: m.Content})
// } else {
// msgs = append(msgs, openai.ChatCompletionMessage{Role: m.Role, Content: m.Content})
// }
// }
// chatreq.Messages = msgs
// }
// }
if strings.HasPrefix(chatreq.Model, "gpt") {
if strings.HasPrefix(chatreq.Model, "gpt") || strings.HasPrefix(chatreq.Model, "o1-") {
openai.ChatProxy(c, &chatreq)
return
}

File diff suppressed because it is too large Load Diff

View File

@@ -53,6 +53,14 @@ func SelectKeyCache(apitype string) (Key, error) {
if item.Object.(Key).ApiType == "azure" {
keys = append(keys, item.Object.(Key))
}
if item.Object.(Key).ApiType == "github" {
keys = append(keys, item.Object.(Key))
}
}
if apitype == "claude" {
if item.Object.(Key).ApiType == "vertex" {
keys = append(keys, item.Object.(Key))
}
}
}
if len(keys) == 0 {

View File

@@ -2,9 +2,35 @@ package store
import (
"encoding/json"
"fmt"
"log"
"opencatd-open/pkg/vertexai"
"os"
"time"
)
func init() {
// check vertex
if os.Getenv("Vertex") != "" {
vertex_auth := os.Getenv("Vertex")
var Vertex vertexai.VertexSecretKey
if err := json.Unmarshal([]byte(vertex_auth), &Vertex); err != nil {
log.Fatalln(err)
return
}
key := Key{
ApiType: "vertex",
Name: Vertex.ProjectID,
Key: vertex_auth,
ApiSecret: vertex_auth,
}
if err := db.FirstOrCreate(&key).Error; err != nil {
log.Println(fmt.Errorf("create vertex key error: %v", err))
}
}
LoadKeysCache()
}
type Key struct {
ID uint `gorm:"primarykey" json:"id,omitempty"`
Key string `gorm:"unique;not null" json:"key,omitempty"`
@@ -14,6 +40,7 @@ type Key struct {
EndPoint string `gorm:"column:endpoint"`
ResourceNmae string `gorm:"column:resource_name"`
DeploymentName string `gorm:"column:deployment_name"`
ApiSecret string `gorm:"column:api_secret"`
CreatedAt time.Time `json:"createdAt,omitempty"`
UpdatedAt time.Time `json:"updatedAt,omitempty"`
}