azure openai

This commit is contained in:
Sakurasan
2023-05-27 22:47:32 +08:00
parent 2f7567d23e
commit 9c04122680
5 changed files with 105 additions and 27 deletions

View File

@@ -10,6 +10,15 @@ OpenCat for Team的开源实现
~~基本~~实现了opencatd的全部功能
## Extra Support:
| 任务 | 完成情况 |
| --- | --- |
|Azure OpenAI | ✅|
| ... | ... |
## 快速上手
```
docker run -d --name opencatd -p 80:80 -v /etc/opencatd:/app/db mirrors2/opencatd-open

View File

@@ -9,6 +9,10 @@ curl $AZURE_OPENAI_ENDPOINT/openai/deployments/gpt-35-turbo/chat/completions?api
"messages": [{"role": "user", "content": "你好"}]
}'
curl $AZURE_OPENAI_ENDPOINT/openai/deployments?api-version=2022-12-01 \
-H "Content-Type: application/json" \
-H "api-key: $AZURE_OPENAI_KEY" \
*/
package azureopenai

View File

@@ -210,16 +210,45 @@ func HandleUsers(c *gin.Context) {
func HandleAddKey(c *gin.Context) {
var body Key
if err := c.BindJSON(&body); err != nil {
c.JSON(http.StatusOK, gin.H{"error": err.Error()})
c.JSON(http.StatusBadRequest, gin.H{"error": gin.H{
"message": err.Error(),
}})
return
}
if err := store.AddKey(body.Key, body.Name); err != nil {
c.JSON(http.StatusOK, gin.H{"error": err.Error()})
return
if strings.HasPrefix(strings.ToLower(body.Name), "azure") {
keynames := strings.Split(body.Name, ".")
if len(keynames) < 2 {
c.JSON(http.StatusBadRequest, gin.H{"error": gin.H{
"message": "Invalid Key Name",
}})
return
}
k := &store.Key{
ApiType: "azure_openai",
Name: body.Name,
Key: body.Key,
EndPoint: keynames[1],
}
if err := store.CreateKey(k); err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": gin.H{
"message": err.Error(),
}})
return
}
} else {
if err := store.AddKey("openai", body.Key, body.Name); err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": gin.H{
"message": err.Error(),
}})
return
}
}
k, err := store.GetKeyrByName(body.Name)
if err != nil {
c.JSON(http.StatusOK, gin.H{"error": err.Error()})
c.JSON(http.StatusInternalServerError, gin.H{"error": gin.H{
"message": err.Error(),
}})
return
}
c.JSON(http.StatusOK, k)
@@ -333,6 +362,13 @@ func HandleProy(c *gin.Context) {
}
if c.Request.URL.Path == "/v1/chat/completions" && localuser {
if store.KeysCache.ItemCount() == 0 {
c.JSON(http.StatusBadGateway, gin.H{"error": gin.H{
"message": "No Api-Key Available",
}})
return
}
onekey := store.FromKeyCacheRandomItemKey()
if err := c.BindJSON(&chatreq); err != nil {
c.AbortWithError(http.StatusBadRequest, err)
@@ -350,12 +386,24 @@ func HandleProy(c *gin.Context) {
var body bytes.Buffer
json.NewEncoder(&body).Encode(chatreq)
// 创建 API 请求
req, err = http.NewRequest(c.Request.Method, baseUrl+c.Request.RequestURI, &body)
switch onekey.ApiType {
case "azure_openai":
req, err = http.NewRequest(c.Request.Method, fmt.Sprintf("https://%s.openai.azure.com/openai/deployments/%s/chat/completions?api-version=2023-03-15-preview", onekey.EndPoint, modelmap(chatreq.Model)), &body)
req.Header = c.Request.Header
req.Header.Set("api-key", onekey.Key)
case "openai":
fallthrough
default:
req, err = http.NewRequest(c.Request.Method, baseUrl+c.Request.RequestURI, &body)
req.Header = c.Request.Header
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", onekey.Key))
}
if err != nil {
log.Println(err)
c.JSON(http.StatusOK, gin.H{"error": err.Error()})
return
}
} else {
req, err = http.NewRequest(c.Request.Method, baseUrl+c.Request.RequestURI, c.Request.Body)
if err != nil {
@@ -363,17 +411,7 @@ func HandleProy(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{"error": err.Error()})
return
}
}
req.Header = c.Request.Header
if localuser {
if store.KeysCache.ItemCount() == 0 {
c.JSON(http.StatusBadGateway, gin.H{"error": gin.H{
"message": "No Api-Key Available",
}})
return
}
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", store.FromKeyCacheRandomItem()))
req.Header = c.Request.Header
}
resp, err := client.Do(req)
@@ -500,8 +538,8 @@ func HandleReverseProxy(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{"error": "No Api-Key Available"})
return
}
// c.Request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", store.FromKeyCacheRandomItem()))
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", store.FromKeyCacheRandomItem()))
onekey := store.FromKeyCacheRandomItemKey()
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", onekey.Key))
}
proxy.ServeHTTP(c.Writer, req)
@@ -629,3 +667,13 @@ func NumTokensFromStr(messages string, model string) (num_tokens int) {
num_tokens += len(tkm.Encode(messages, nil, nil))
return num_tokens
}
func modelmap(in string) string {
switch in {
case "gpt-3.5-turbo":
return "gpt-35-turbo"
case "gpt-4":
return "gpt-4"
}
return in
}

View File

@@ -26,18 +26,18 @@ func LoadKeysCache() {
return
}
for idx, key := range keys {
KeysCache.Set(to.String(idx), key.Key, cache.NoExpiration)
KeysCache.Set(to.String(idx), key, cache.NoExpiration)
}
}
func FromKeyCacheRandomItem() string {
func FromKeyCacheRandomItemKey() Key {
items := KeysCache.Items()
if len(items) == 1 {
return items[to.String(0)].Object.(string)
return items[to.String(0)].Object.(Key)
}
idx := rand.Intn(len(items))
item := items[to.String(idx)]
return item.Object.(string)
return item.Object.(Key)
}
func LoadAuthCache() {

View File

@@ -1,6 +1,9 @@
package store
import "time"
import (
"encoding/json"
"time"
)
type Key struct {
ID uint `gorm:"primarykey" json:"id,omitempty"`
@@ -14,6 +17,11 @@ type Key struct {
UpdatedAt time.Time `json:"updatedAt,omitempty"`
}
func (k Key) ToString() string {
bdate, _ := json.Marshal(k)
return string(bdate)
}
func GetKeyrByName(name string) (*Key, error) {
var key Key
result := db.First(&key, "name = ?", name)
@@ -32,10 +40,11 @@ func GetAllKeys() ([]Key, error) {
}
// 添加记录
func AddKey(apikey, name string) error {
func AddKey(apitype, apikey, name string) error {
key := Key{
Key: apikey,
Name: name,
ApiType: apitype,
Key: apikey,
Name: name,
}
if err := db.Create(&key).Error; err != nil {
return err
@@ -44,6 +53,14 @@ func AddKey(apikey, name string) error {
return nil
}
func CreateKey(k *Key) error {
if err := db.Create(&k).Error; err != nil {
return err
}
LoadKeysCache()
return nil
}
// 删除记录
func DeleteKey(id uint) error {
if err := db.Delete(&Key{}, id).Error; err != nil {