From 9c04122680f9d13becd5f169a6443798186148b8 Mon Sep 17 00:00:00 2001 From: Sakurasan <1173092237@qq.com> Date: Sat, 27 May 2023 22:47:32 +0800 Subject: [PATCH] azure openai --- README.md | 9 ++++ pkg/azureopenai/azureopenai.go | 4 ++ router/router.go | 86 ++++++++++++++++++++++++++-------- store/cache.go | 8 ++-- store/keydb.go | 25 ++++++++-- 5 files changed, 105 insertions(+), 27 deletions(-) diff --git a/README.md b/README.md index ed13f83..62ff7c3 100644 --- a/README.md +++ b/README.md @@ -10,6 +10,15 @@ OpenCat for Team的开源实现 ~~基本~~实现了opencatd的全部功能 +## Extra Support: + +| 任务 | 完成情况 | +| --- | --- | +|Azure OpenAI | ✅| +| ... | ... | + + + ## 快速上手 ``` docker run -d --name opencatd -p 80:80 -v /etc/opencatd:/app/db mirrors2/opencatd-open diff --git a/pkg/azureopenai/azureopenai.go b/pkg/azureopenai/azureopenai.go index 38a17db..d9d7a26 100644 --- a/pkg/azureopenai/azureopenai.go +++ b/pkg/azureopenai/azureopenai.go @@ -9,6 +9,10 @@ curl $AZURE_OPENAI_ENDPOINT/openai/deployments/gpt-35-turbo/chat/completions?api "messages": [{"role": "user", "content": "你好"}] }' + curl $AZURE_OPENAI_ENDPOINT/openai/deployments?api-version=2022-12-01 \ + -H "Content-Type: application/json" \ + -H "api-key: $AZURE_OPENAI_KEY" \ + */ package azureopenai diff --git a/router/router.go b/router/router.go index 43643a5..c207fda 100644 --- a/router/router.go +++ b/router/router.go @@ -210,16 +210,45 @@ func HandleUsers(c *gin.Context) { func HandleAddKey(c *gin.Context) { var body Key if err := c.BindJSON(&body); err != nil { - c.JSON(http.StatusOK, gin.H{"error": err.Error()}) + c.JSON(http.StatusBadRequest, gin.H{"error": gin.H{ + "message": err.Error(), + }}) return } - if err := store.AddKey(body.Key, body.Name); err != nil { - c.JSON(http.StatusOK, gin.H{"error": err.Error()}) - return + if strings.HasPrefix(strings.ToLower(body.Name), "azure") { + keynames := strings.Split(body.Name, ".") + if len(keynames) < 2 { + c.JSON(http.StatusBadRequest, gin.H{"error": gin.H{ + "message": "Invalid Key Name", + }}) + return + } + k := &store.Key{ + ApiType: "azure_openai", + Name: body.Name, + Key: body.Key, + EndPoint: keynames[1], + } + if err := store.CreateKey(k); err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": gin.H{ + "message": err.Error(), + }}) + return + } + } else { + if err := store.AddKey("openai", body.Key, body.Name); err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": gin.H{ + "message": err.Error(), + }}) + return + } } + k, err := store.GetKeyrByName(body.Name) if err != nil { - c.JSON(http.StatusOK, gin.H{"error": err.Error()}) + c.JSON(http.StatusInternalServerError, gin.H{"error": gin.H{ + "message": err.Error(), + }}) return } c.JSON(http.StatusOK, k) @@ -333,6 +362,13 @@ func HandleProy(c *gin.Context) { } if c.Request.URL.Path == "/v1/chat/completions" && localuser { + if store.KeysCache.ItemCount() == 0 { + c.JSON(http.StatusBadGateway, gin.H{"error": gin.H{ + "message": "No Api-Key Available", + }}) + return + } + onekey := store.FromKeyCacheRandomItemKey() if err := c.BindJSON(&chatreq); err != nil { c.AbortWithError(http.StatusBadRequest, err) @@ -350,12 +386,24 @@ func HandleProy(c *gin.Context) { var body bytes.Buffer json.NewEncoder(&body).Encode(chatreq) // 创建 API 请求 - req, err = http.NewRequest(c.Request.Method, baseUrl+c.Request.RequestURI, &body) + switch onekey.ApiType { + case "azure_openai": + req, err = http.NewRequest(c.Request.Method, fmt.Sprintf("https://%s.openai.azure.com/openai/deployments/%s/chat/completions?api-version=2023-03-15-preview", onekey.EndPoint, modelmap(chatreq.Model)), &body) + req.Header = c.Request.Header + req.Header.Set("api-key", onekey.Key) + case "openai": + fallthrough + default: + req, err = http.NewRequest(c.Request.Method, baseUrl+c.Request.RequestURI, &body) + req.Header = c.Request.Header + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", onekey.Key)) + } if err != nil { log.Println(err) c.JSON(http.StatusOK, gin.H{"error": err.Error()}) return } + } else { req, err = http.NewRequest(c.Request.Method, baseUrl+c.Request.RequestURI, c.Request.Body) if err != nil { @@ -363,17 +411,7 @@ func HandleProy(c *gin.Context) { c.JSON(http.StatusOK, gin.H{"error": err.Error()}) return } - } - - req.Header = c.Request.Header - if localuser { - if store.KeysCache.ItemCount() == 0 { - c.JSON(http.StatusBadGateway, gin.H{"error": gin.H{ - "message": "No Api-Key Available", - }}) - return - } - req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", store.FromKeyCacheRandomItem())) + req.Header = c.Request.Header } resp, err := client.Do(req) @@ -500,8 +538,8 @@ func HandleReverseProxy(c *gin.Context) { c.JSON(http.StatusOK, gin.H{"error": "No Api-Key Available"}) return } - // c.Request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", store.FromKeyCacheRandomItem())) - req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", store.FromKeyCacheRandomItem())) + onekey := store.FromKeyCacheRandomItemKey() + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", onekey.Key)) } proxy.ServeHTTP(c.Writer, req) @@ -629,3 +667,13 @@ func NumTokensFromStr(messages string, model string) (num_tokens int) { num_tokens += len(tkm.Encode(messages, nil, nil)) return num_tokens } + +func modelmap(in string) string { + switch in { + case "gpt-3.5-turbo": + return "gpt-35-turbo" + case "gpt-4": + return "gpt-4" + } + return in +} diff --git a/store/cache.go b/store/cache.go index cd7abef..211e409 100644 --- a/store/cache.go +++ b/store/cache.go @@ -26,18 +26,18 @@ func LoadKeysCache() { return } for idx, key := range keys { - KeysCache.Set(to.String(idx), key.Key, cache.NoExpiration) + KeysCache.Set(to.String(idx), key, cache.NoExpiration) } } -func FromKeyCacheRandomItem() string { +func FromKeyCacheRandomItemKey() Key { items := KeysCache.Items() if len(items) == 1 { - return items[to.String(0)].Object.(string) + return items[to.String(0)].Object.(Key) } idx := rand.Intn(len(items)) item := items[to.String(idx)] - return item.Object.(string) + return item.Object.(Key) } func LoadAuthCache() { diff --git a/store/keydb.go b/store/keydb.go index 64d3510..fcab5a8 100644 --- a/store/keydb.go +++ b/store/keydb.go @@ -1,6 +1,9 @@ package store -import "time" +import ( + "encoding/json" + "time" +) type Key struct { ID uint `gorm:"primarykey" json:"id,omitempty"` @@ -14,6 +17,11 @@ type Key struct { UpdatedAt time.Time `json:"updatedAt,omitempty"` } +func (k Key) ToString() string { + bdate, _ := json.Marshal(k) + return string(bdate) +} + func GetKeyrByName(name string) (*Key, error) { var key Key result := db.First(&key, "name = ?", name) @@ -32,10 +40,11 @@ func GetAllKeys() ([]Key, error) { } // 添加记录 -func AddKey(apikey, name string) error { +func AddKey(apitype, apikey, name string) error { key := Key{ - Key: apikey, - Name: name, + ApiType: apitype, + Key: apikey, + Name: name, } if err := db.Create(&key).Error; err != nil { return err @@ -44,6 +53,14 @@ func AddKey(apikey, name string) error { return nil } +func CreateKey(k *Key) error { + if err := db.Create(&k).Error; err != nil { + return err + } + LoadKeysCache() + return nil +} + // 删除记录 func DeleteKey(id uint) error { if err := db.Delete(&Key{}, id).Error; err != nil {