This commit is contained in:
c菌
2023-09-16 18:14:37 +08:00
parent de31741d06
commit d9ecd1ea74

View File

@@ -287,7 +287,7 @@ func HandleAddKey(c *gin.Context) {
return return
} }
k := &store.Key{ k := &store.Key{
ApiType: "azure_openai", ApiType: "azure",
Name: body.Name, Name: body.Name,
Key: body.Key, Key: body.Key,
ResourceNmae: keynames[1], ResourceNmae: keynames[1],
@@ -299,7 +299,7 @@ func HandleAddKey(c *gin.Context) {
}}) }})
return return
} }
} else if strings.HasPrefix(body.Name, "anthropic.") { } else if strings.HasPrefix(body.Name, "claude.") {
keynames := strings.Split(body.Name, ".") keynames := strings.Split(body.Name, ".")
if len(keynames) < 2 { if len(keynames) < 2 {
c.JSON(http.StatusBadRequest, gin.H{"error": gin.H{ c.JSON(http.StatusBadRequest, gin.H{"error": gin.H{
@@ -311,7 +311,8 @@ func HandleAddKey(c *gin.Context) {
body.Endpoint = "https://api.anthropic.com" body.Endpoint = "https://api.anthropic.com"
} }
k := &store.Key{ k := &store.Key{
ApiType: "anthropic", // ApiType: "anthropic",
ApiType: "claude",
Name: body.Name, Name: body.Name,
Key: body.Key, Key: body.Key,
ResourceNmae: keynames[1], ResourceNmae: keynames[1],
@@ -460,6 +461,7 @@ func HandleProy(c *gin.Context) {
chatreq = openai.ChatCompletionRequest{} chatreq = openai.ChatCompletionRequest{}
chatres = openai.ChatCompletionResponse{} chatres = openai.ChatCompletionResponse{}
chatlog store.Tokens chatlog store.Tokens
onekey store.Key
pre_prompt string pre_prompt string
req *http.Request req *http.Request
err error err error
@@ -491,10 +493,6 @@ func HandleProy(c *gin.Context) {
c.AbortWithError(http.StatusBadRequest, err) c.AbortWithError(http.StatusBadRequest, err)
return return
} }
if strings.HasPrefix(chatreq.Model, "claude-") {
claude.Translate(c, &chatreq)
return
}
chatlog.Model = chatreq.Model chatlog.Model = chatreq.Model
for _, m := range chatreq.Messages { for _, m := range chatreq.Messages {
@@ -508,7 +506,15 @@ func HandleProy(c *gin.Context) {
var body bytes.Buffer var body bytes.Buffer
json.NewEncoder(&body).Encode(chatreq) json.NewEncoder(&body).Encode(chatreq)
onekey := store.FromKeyCacheRandomItemKey() if strings.HasPrefix(chatreq.Model, "claude-") {
onekey, err = store.SelectKeyCache("claude")
if err != nil {
c.AbortWithError(http.StatusForbidden, err)
}
} else {
onekey = store.FromKeyCacheRandomItemKey()
}
// 创建 API 请求 // 创建 API 请求
switch onekey.ApiType { switch onekey.ApiType {
case "claude": case "claude":
@@ -554,7 +560,7 @@ func HandleProy(c *gin.Context) {
req, err = http.NewRequest(c.Request.Method, baseUrl+c.Request.RequestURI, c.Request.Body) req, err = http.NewRequest(c.Request.Method, baseUrl+c.Request.RequestURI, c.Request.Body)
if err != nil { if err != nil {
log.Println(err) log.Println(err)
c.JSON(http.StatusOK, gin.H{"error": err.Error()}) c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return return
} }
req.Header = c.Request.Header req.Header = c.Request.Header
@@ -563,7 +569,7 @@ func HandleProy(c *gin.Context) {
resp, err := client.Do(req) resp, err := client.Do(req)
if err != nil { if err != nil {
log.Println(err) log.Println(err)
c.JSON(http.StatusOK, gin.H{"error": err.Error()}) c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return return
} }
defer resp.Body.Close() defer resp.Body.Close()
@@ -595,14 +601,42 @@ func HandleProy(c *gin.Context) {
reader := bufio.NewReader(resp.Body) reader := bufio.NewReader(resp.Body)
if resp.StatusCode == 200 && localuser { if resp.StatusCode == 200 && localuser {
if isStream { switch onekey.ApiType {
contentCh := fetchResponseContent(c, reader) case "claude":
var buffer bytes.Buffer claude.TransRsp(c, isStream, reader)
for content := range contentCh { return
buffer.WriteString(content) case "openai", "azure", "azure_openai":
fallthrough
default:
if isStream {
contentCh := fetchResponseContent(c, reader)
var buffer bytes.Buffer
for content := range contentCh {
buffer.WriteString(content)
}
chatlog.CompletionCount = NumTokensFromStr(buffer.String(), chatreq.Model)
chatlog.TotalTokens = chatlog.PromptCount + chatlog.CompletionCount
chatlog.Cost = fmt.Sprintf("%.6f", Cost(chatlog.Model, chatlog.PromptCount, chatlog.CompletionCount))
if err := store.Record(&chatlog); err != nil {
log.Println(err)
}
if err := store.SumDaily(chatlog.UserID); err != nil {
log.Println(err)
}
return
} }
chatlog.CompletionCount = NumTokensFromStr(buffer.String(), chatreq.Model) res, err := io.ReadAll(reader)
chatlog.TotalTokens = chatlog.PromptCount + chatlog.CompletionCount if err != nil {
c.JSON(http.StatusUnauthorized, gin.H{"error": gin.H{
"message": err.Error(),
}})
return
}
reader = bufio.NewReader(bytes.NewBuffer(res))
json.NewDecoder(bytes.NewBuffer(res)).Decode(&chatres)
chatlog.PromptCount = chatres.Usage.PromptTokens
chatlog.CompletionCount = chatres.Usage.CompletionTokens
chatlog.TotalTokens = chatres.Usage.TotalTokens
chatlog.Cost = fmt.Sprintf("%.6f", Cost(chatlog.Model, chatlog.PromptCount, chatlog.CompletionCount)) chatlog.Cost = fmt.Sprintf("%.6f", Cost(chatlog.Model, chatlog.PromptCount, chatlog.CompletionCount))
if err := store.Record(&chatlog); err != nil { if err := store.Record(&chatlog); err != nil {
log.Println(err) log.Println(err)
@@ -610,26 +644,6 @@ func HandleProy(c *gin.Context) {
if err := store.SumDaily(chatlog.UserID); err != nil { if err := store.SumDaily(chatlog.UserID); err != nil {
log.Println(err) log.Println(err)
} }
return
}
res, err := io.ReadAll(reader)
if err != nil {
c.JSON(http.StatusUnauthorized, gin.H{"error": gin.H{
"message": err.Error(),
}})
return
}
reader = bufio.NewReader(bytes.NewBuffer(res))
json.NewDecoder(bytes.NewBuffer(res)).Decode(&chatres)
chatlog.PromptCount = chatres.Usage.PromptTokens
chatlog.CompletionCount = chatres.Usage.CompletionTokens
chatlog.TotalTokens = chatres.Usage.TotalTokens
chatlog.Cost = fmt.Sprintf("%.6f", Cost(chatlog.Model, chatlog.PromptCount, chatlog.CompletionCount))
if err := store.Record(&chatlog); err != nil {
log.Println(err)
}
if err := store.SumDaily(chatlog.UserID); err != nil {
log.Println(err)
} }
} }