diff --git a/pkg/azureopenai/azureopenai.go b/pkg/azureopenai/azureopenai.go index d9d7a26..db5b775 100644 --- a/pkg/azureopenai/azureopenai.go +++ b/pkg/azureopenai/azureopenai.go @@ -20,6 +20,7 @@ package azureopenai import ( "encoding/json" "net/http" + "regexp" "strings" ) @@ -46,7 +47,7 @@ type ModelsList struct { } func Models(endpoint, apikey string) (*ModelsList, error) { - endpoint = removeTrailingSlash(endpoint) + endpoint = RemoveTrailingSlash(endpoint) var modelsl ModelsList req, _ := http.NewRequest(http.MethodGet, endpoint+"/openai/deployments?api-version=2022-12-01", nil) req.Header.Set("api-key", apikey) @@ -63,10 +64,19 @@ func Models(endpoint, apikey string) (*ModelsList, error) { } -func removeTrailingSlash(s string) string { +func RemoveTrailingSlash(s string) string { const prefix = "openai.azure.com/" - if strings.HasPrefix(s, prefix) && strings.HasSuffix(s, "/") { + if strings.HasSuffix(strings.TrimSpace(s), prefix) && strings.HasSuffix(s, "/") { return s[:len(s)-1] } return s } + +func GetResourceName(url string) string { + re := regexp.MustCompile(`https?://(.+)\.openai\.azure\.com/?`) + match := re.FindStringSubmatch(url) + if len(match) > 1 { + return match[1] + } + return "" +} diff --git a/router/router.go b/router/router.go index 3862aa0..f2d76fd 100644 --- a/router/router.go +++ b/router/router.go @@ -12,6 +12,7 @@ import ( "net" "net/http" "net/http/httputil" + "opencatd-open/pkg/azureopenai" "opencatd-open/store" "os" "strings" @@ -46,8 +47,10 @@ type User struct { type Key struct { ID int `json:"id,omitempty"` Key string `json:"key,omitempty"` - UpdatedAt string `json:"updatedAt,omitempty"` Name string `json:"name,omitempty"` + ApiType string `json:"api_type,omitempty"` + Endpoint string `json:"endpoint,omitempty"` + UpdatedAt string `json:"updatedAt,omitempty"` CreatedAt string `json:"createdAt,omitempty"` } @@ -223,7 +226,9 @@ func HandleAddKey(c *gin.Context) { }}) return } - if strings.HasPrefix(strings.ToLower(body.Name), "azure") { + body.Name = strings.ToLower(strings.TrimSpace(body.Name)) + body.Key = strings.TrimSpace(body.Key) + if strings.HasPrefix(body.Name, "azure.") { keynames := strings.Split(body.Name, ".") if len(keynames) < 2 { c.JSON(http.StatusBadRequest, gin.H{"error": gin.H{ @@ -236,6 +241,7 @@ func HandleAddKey(c *gin.Context) { Name: body.Name, Key: body.Key, ResourceNmae: keynames[1], + EndPoint: body.Endpoint, } if err := store.CreateKey(k); err != nil { c.JSON(http.StatusInternalServerError, gin.H{"error": gin.H{ @@ -244,12 +250,29 @@ func HandleAddKey(c *gin.Context) { return } } else { - if err := store.AddKey("openai", body.Key, body.Name); err != nil { - c.JSON(http.StatusInternalServerError, gin.H{"error": gin.H{ - "message": err.Error(), - }}) - return + if body.ApiType == "" { + if err := store.AddKey("openai", body.Key, body.Name); err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": gin.H{ + "message": err.Error(), + }}) + return + } + } else { + k := &store.Key{ + ApiType: body.ApiType, + Name: body.Name, + Key: body.Key, + ResourceNmae: azureopenai.GetResourceName(body.Endpoint), + EndPoint: body.Endpoint, + } + if err := store.CreateKey(k); err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": gin.H{ + "message": err.Error(), + }}) + return + } } + } k, err := store.GetKeyrByName(body.Name) @@ -281,10 +304,10 @@ func HandleAddUser(c *gin.Context) { c.JSON(http.StatusOK, gin.H{"error": err.Error()}) return } - // if len(body.Name) == 0 { - // c.JSON(http.StatusOK, gin.H{"error": "invalid user name"}) - // return - // } + if len(body.Name) == 0 { + c.JSON(http.StatusOK, gin.H{"error": "invalid user name"}) + return + } if err := store.AddUser(body.Name, uuid.NewString()); err != nil { c.JSON(http.StatusOK, gin.H{"error": err.Error()}) @@ -396,7 +419,14 @@ func HandleProy(c *gin.Context) { // 创建 API 请求 switch onekey.ApiType { case "azure_openai": - req, err = http.NewRequest(c.Request.Method, fmt.Sprintf("https://%s.openai.azure.com/openai/deployments/%s/chat/completions?api-version=2023-05-15", onekey.ResourceNmae, modelmap(chatreq.Model)), &body) + var buildurl string + var apiVersion = "2023-05-15" + if onekey.EndPoint != "" { + buildurl = fmt.Sprintf("https://%s/openai/deployments/%s/chat/completions?api-version=%s", onekey.EndPoint, modelmap(chatreq.Model), apiVersion) + } else { + buildurl = fmt.Sprintf("https://%s.openai.azure.com/openai/deployments/%s/chat/completions?api-version=%s", onekey.ResourceNmae, modelmap(chatreq.Model), apiVersion) + } + req, err = http.NewRequest(c.Request.Method, buildurl, &body) req.Header = c.Request.Header req.Header.Set("api-key", onekey.Key) case "openai":