diff --git a/README.md b/README.md index ed13f83..dd89b79 100644 --- a/README.md +++ b/README.md @@ -10,6 +10,15 @@ OpenCat for Team的开源实现 ~~基本~~实现了opencatd的全部功能 +## Extra Support: + +| 任务 | 完成情况 | +| --- | --- | +|[Azure OpenAI](./doc/azure.md) | ✅| +| ... | ... | + + + ## 快速上手 ``` docker run -d --name opencatd -p 80:80 -v /etc/opencatd:/app/db mirrors2/opencatd-open diff --git a/doc/azure.md b/doc/azure.md index c190ddb..5999a41 100644 --- a/doc/azure.md +++ b/doc/azure.md @@ -1,4 +1,17 @@ -# Azure OpenAI +# Azure OpenAI for team -需要获取 api-key和endpoint +1.需要获取 api-key和endpoint [https://[resource name].openai.azure.com/) ![](./azure_key%26endpoint.png) + +> 2.Pleause use model name as deployment name + +| model name | deployment name | +| --- | --- | +|gpt-35-turbo | gpt-35-turbo | +| gpt-4 | gpt-4 | + +## How to use +- opencat 使用方式 + - key name以 azure.[resource name]的方式添加 + - 密钥任取一个 + - azure_openai_for_team diff --git a/doc/azure_openai_for_team.png b/doc/azure_openai_for_team.png new file mode 100644 index 0000000..4b7df14 Binary files /dev/null and b/doc/azure_openai_for_team.png differ diff --git a/pkg/azureopenai/azureopenai.go b/pkg/azureopenai/azureopenai.go index f61d017..d9d7a26 100644 --- a/pkg/azureopenai/azureopenai.go +++ b/pkg/azureopenai/azureopenai.go @@ -9,6 +9,10 @@ curl $AZURE_OPENAI_ENDPOINT/openai/deployments/gpt-35-turbo/chat/completions?api "messages": [{"role": "user", "content": "你好"}] }' + curl $AZURE_OPENAI_ENDPOINT/openai/deployments?api-version=2022-12-01 \ + -H "Content-Type: application/json" \ + -H "api-key: $AZURE_OPENAI_KEY" \ + */ package azureopenai @@ -16,6 +20,7 @@ package azureopenai import ( "encoding/json" "net/http" + "strings" ) var ( @@ -41,6 +46,7 @@ type ModelsList struct { } func Models(endpoint, apikey string) (*ModelsList, error) { + endpoint = removeTrailingSlash(endpoint) var modelsl ModelsList req, _ := http.NewRequest(http.MethodGet, endpoint+"/openai/deployments?api-version=2022-12-01", nil) req.Header.Set("api-key", apikey) @@ -56,3 +62,11 @@ func Models(endpoint, apikey string) (*ModelsList, error) { return &modelsl, nil } + +func removeTrailingSlash(s string) string { + const prefix = "openai.azure.com/" + if strings.HasPrefix(s, prefix) && strings.HasSuffix(s, "/") { + return s[:len(s)-1] + } + return s +} diff --git a/pkg/azureopenai/azureopenai_test.go b/pkg/azureopenai/azureopenai_test.go new file mode 100644 index 0000000..1fe8e7d --- /dev/null +++ b/pkg/azureopenai/azureopenai_test.go @@ -0,0 +1,54 @@ +/* +https://learn.microsoft.com/zh-cn/azure/cognitive-services/openai/chatgpt-quickstart + +curl $AZURE_OPENAI_ENDPOINT/openai/deployments/gpt-35-turbo/chat/completions?api-version=2023-03-15-preview \ + -H "Content-Type: application/json" \ + -H "api-key: $AZURE_OPENAI_KEY" \ + -d '{ + "model": "gpt-3.5-turbo", + "messages": [{"role": "user", "content": "你好"}] + }' + +*/ + +package azureopenai + +import ( + "fmt" + "testing" +) + +func TestModels(t *testing.T) { + type args struct { + endpoint string + apikey string + } + tests := []struct { + name string + args args + }{ + { + name: "test", + args: args{ + endpoint: "https://mirrors2.openai.azure.com", + apikey: "696a7729234c438cb38f24da22ee602d", + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := Models(tt.args.endpoint, tt.args.apikey) + if err != nil { + t.Errorf("Models() error = %v", err) + return + } + for _, data := range got.Data { + fmt.Println(data.Model, data.ID) + } + }) + } +} + +// curl https://mirrors2.openai.azure.com/openai/deployments?api-version=2023-03-15-preview \ +// -H "Content-Type: application/json" \ +// -H "api-key: 696a7729234c438cb38f24da22ee602d" diff --git a/router/router.go b/router/router.go index 43643a5..c0a1972 100644 --- a/router/router.go +++ b/router/router.go @@ -210,16 +210,45 @@ func HandleUsers(c *gin.Context) { func HandleAddKey(c *gin.Context) { var body Key if err := c.BindJSON(&body); err != nil { - c.JSON(http.StatusOK, gin.H{"error": err.Error()}) + c.JSON(http.StatusBadRequest, gin.H{"error": gin.H{ + "message": err.Error(), + }}) return } - if err := store.AddKey(body.Key, body.Name); err != nil { - c.JSON(http.StatusOK, gin.H{"error": err.Error()}) - return + if strings.HasPrefix(strings.ToLower(body.Name), "azure") { + keynames := strings.Split(body.Name, ".") + if len(keynames) < 2 { + c.JSON(http.StatusBadRequest, gin.H{"error": gin.H{ + "message": "Invalid Key Name", + }}) + return + } + k := &store.Key{ + ApiType: "azure_openai", + Name: body.Name, + Key: body.Key, + ResourceNmae: keynames[1], + } + if err := store.CreateKey(k); err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": gin.H{ + "message": err.Error(), + }}) + return + } + } else { + if err := store.AddKey("openai", body.Key, body.Name); err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": gin.H{ + "message": err.Error(), + }}) + return + } } + k, err := store.GetKeyrByName(body.Name) if err != nil { - c.JSON(http.StatusOK, gin.H{"error": err.Error()}) + c.JSON(http.StatusInternalServerError, gin.H{"error": gin.H{ + "message": err.Error(), + }}) return } c.JSON(http.StatusOK, k) @@ -333,6 +362,13 @@ func HandleProy(c *gin.Context) { } if c.Request.URL.Path == "/v1/chat/completions" && localuser { + if store.KeysCache.ItemCount() == 0 { + c.JSON(http.StatusBadGateway, gin.H{"error": gin.H{ + "message": "No Api-Key Available", + }}) + return + } + onekey := store.FromKeyCacheRandomItemKey() if err := c.BindJSON(&chatreq); err != nil { c.AbortWithError(http.StatusBadRequest, err) @@ -350,12 +386,24 @@ func HandleProy(c *gin.Context) { var body bytes.Buffer json.NewEncoder(&body).Encode(chatreq) // 创建 API 请求 - req, err = http.NewRequest(c.Request.Method, baseUrl+c.Request.RequestURI, &body) + switch onekey.ApiType { + case "azure_openai": + req, err = http.NewRequest(c.Request.Method, fmt.Sprintf("https://%s.openai.azure.com/openai/deployments/%s/chat/completions?api-version=2023-03-15-preview", onekey.ResourceNmae, modelmap(chatreq.Model)), &body) + req.Header = c.Request.Header + req.Header.Set("api-key", onekey.Key) + case "openai": + fallthrough + default: + req, err = http.NewRequest(c.Request.Method, baseUrl+c.Request.RequestURI, &body) + req.Header = c.Request.Header + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", onekey.Key)) + } if err != nil { log.Println(err) c.JSON(http.StatusOK, gin.H{"error": err.Error()}) return } + } else { req, err = http.NewRequest(c.Request.Method, baseUrl+c.Request.RequestURI, c.Request.Body) if err != nil { @@ -363,17 +411,7 @@ func HandleProy(c *gin.Context) { c.JSON(http.StatusOK, gin.H{"error": err.Error()}) return } - } - - req.Header = c.Request.Header - if localuser { - if store.KeysCache.ItemCount() == 0 { - c.JSON(http.StatusBadGateway, gin.H{"error": gin.H{ - "message": "No Api-Key Available", - }}) - return - } - req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", store.FromKeyCacheRandomItem())) + req.Header = c.Request.Header } resp, err := client.Do(req) @@ -411,7 +449,6 @@ func HandleProy(c *gin.Context) { reader := bufio.NewReader(resp.Body) if resp.StatusCode == 200 && localuser { - if isStream { contentCh := fetchResponseContent(c, reader) var buffer bytes.Buffer @@ -500,8 +537,8 @@ func HandleReverseProxy(c *gin.Context) { c.JSON(http.StatusOK, gin.H{"error": "No Api-Key Available"}) return } - // c.Request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", store.FromKeyCacheRandomItem())) - req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", store.FromKeyCacheRandomItem())) + onekey := store.FromKeyCacheRandomItemKey() + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", onekey.Key)) } proxy.ServeHTTP(c.Writer, req) @@ -629,3 +666,13 @@ func NumTokensFromStr(messages string, model string) (num_tokens int) { num_tokens += len(tkm.Encode(messages, nil, nil)) return num_tokens } + +func modelmap(in string) string { + switch in { + case "gpt-3.5-turbo": + return "gpt-35-turbo" + case "gpt-4": + return "gpt-4" + } + return in +} diff --git a/store/cache.go b/store/cache.go index cd7abef..211e409 100644 --- a/store/cache.go +++ b/store/cache.go @@ -26,18 +26,18 @@ func LoadKeysCache() { return } for idx, key := range keys { - KeysCache.Set(to.String(idx), key.Key, cache.NoExpiration) + KeysCache.Set(to.String(idx), key, cache.NoExpiration) } } -func FromKeyCacheRandomItem() string { +func FromKeyCacheRandomItemKey() Key { items := KeysCache.Items() if len(items) == 1 { - return items[to.String(0)].Object.(string) + return items[to.String(0)].Object.(Key) } idx := rand.Intn(len(items)) item := items[to.String(idx)] - return item.Object.(string) + return item.Object.(Key) } func LoadAuthCache() { diff --git a/store/keydb.go b/store/keydb.go index 05f6d9c..c5d106b 100644 --- a/store/keydb.go +++ b/store/keydb.go @@ -1,19 +1,28 @@ package store -import "time" +import ( + "encoding/json" + "time" +) type Key struct { - ID uint `gorm:"primarykey" json:"id,omitempty"` - Key string `gorm:"unique;not null" json:"key,omitempty"` - Name string `gorm:"unique;not null" json:"name,omitempty"` - UserId string `json:"-,omitempty"` - KeyType string - EndPoint string - DeploymentName string + ID uint `gorm:"primarykey" json:"id,omitempty"` + Key string `gorm:"unique;not null" json:"key,omitempty"` + Name string `gorm:"unique;not null" json:"name,omitempty"` + UserId string `json:"-,omitempty"` + ApiType string `gorm:"column:api_type"` + EndPoint string `gorm:"column:endpoint"` + ResourceNmae string `gorm:"column:resource_name"` + DeploymentName string `gorm:"column:deployment_name"` CreatedAt time.Time `json:"createdAt,omitempty"` UpdatedAt time.Time `json:"updatedAt,omitempty"` } +func (k Key) ToString() string { + bdate, _ := json.Marshal(k) + return string(bdate) +} + func GetKeyrByName(name string) (*Key, error) { var key Key result := db.First(&key, "name = ?", name) @@ -32,10 +41,11 @@ func GetAllKeys() ([]Key, error) { } // 添加记录 -func AddKey(apikey, name string) error { +func AddKey(apitype, apikey, name string) error { key := Key{ - Key: apikey, - Name: name, + ApiType: apitype, + Key: apikey, + Name: name, } if err := db.Create(&key).Error; err != nil { return err @@ -44,6 +54,14 @@ func AddKey(apikey, name string) error { return nil } +func CreateKey(k *Key) error { + if err := db.Create(&k).Error; err != nil { + return err + } + LoadKeysCache() + return nil +} + // 删除记录 func DeleteKey(id uint) error { if err := db.Delete(&Key{}, id).Error; err != nil {