This commit is contained in:
c菌
2023-09-16 18:14:37 +08:00
parent de31741d06
commit d9ecd1ea74

View File

@@ -287,7 +287,7 @@ func HandleAddKey(c *gin.Context) {
return
}
k := &store.Key{
ApiType: "azure_openai",
ApiType: "azure",
Name: body.Name,
Key: body.Key,
ResourceNmae: keynames[1],
@@ -299,7 +299,7 @@ func HandleAddKey(c *gin.Context) {
}})
return
}
} else if strings.HasPrefix(body.Name, "anthropic.") {
} else if strings.HasPrefix(body.Name, "claude.") {
keynames := strings.Split(body.Name, ".")
if len(keynames) < 2 {
c.JSON(http.StatusBadRequest, gin.H{"error": gin.H{
@@ -311,7 +311,8 @@ func HandleAddKey(c *gin.Context) {
body.Endpoint = "https://api.anthropic.com"
}
k := &store.Key{
ApiType: "anthropic",
// ApiType: "anthropic",
ApiType: "claude",
Name: body.Name,
Key: body.Key,
ResourceNmae: keynames[1],
@@ -460,6 +461,7 @@ func HandleProy(c *gin.Context) {
chatreq = openai.ChatCompletionRequest{}
chatres = openai.ChatCompletionResponse{}
chatlog store.Tokens
onekey store.Key
pre_prompt string
req *http.Request
err error
@@ -491,10 +493,6 @@ func HandleProy(c *gin.Context) {
c.AbortWithError(http.StatusBadRequest, err)
return
}
if strings.HasPrefix(chatreq.Model, "claude-") {
claude.Translate(c, &chatreq)
return
}
chatlog.Model = chatreq.Model
for _, m := range chatreq.Messages {
@@ -508,7 +506,15 @@ func HandleProy(c *gin.Context) {
var body bytes.Buffer
json.NewEncoder(&body).Encode(chatreq)
onekey := store.FromKeyCacheRandomItemKey()
if strings.HasPrefix(chatreq.Model, "claude-") {
onekey, err = store.SelectKeyCache("claude")
if err != nil {
c.AbortWithError(http.StatusForbidden, err)
}
} else {
onekey = store.FromKeyCacheRandomItemKey()
}
// 创建 API 请求
switch onekey.ApiType {
case "claude":
@@ -554,7 +560,7 @@ func HandleProy(c *gin.Context) {
req, err = http.NewRequest(c.Request.Method, baseUrl+c.Request.RequestURI, c.Request.Body)
if err != nil {
log.Println(err)
c.JSON(http.StatusOK, gin.H{"error": err.Error()})
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
req.Header = c.Request.Header
@@ -563,7 +569,7 @@ func HandleProy(c *gin.Context) {
resp, err := client.Do(req)
if err != nil {
log.Println(err)
c.JSON(http.StatusOK, gin.H{"error": err.Error()})
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
defer resp.Body.Close()
@@ -595,14 +601,42 @@ func HandleProy(c *gin.Context) {
reader := bufio.NewReader(resp.Body)
if resp.StatusCode == 200 && localuser {
if isStream {
contentCh := fetchResponseContent(c, reader)
var buffer bytes.Buffer
for content := range contentCh {
buffer.WriteString(content)
switch onekey.ApiType {
case "claude":
claude.TransRsp(c, isStream, reader)
return
case "openai", "azure", "azure_openai":
fallthrough
default:
if isStream {
contentCh := fetchResponseContent(c, reader)
var buffer bytes.Buffer
for content := range contentCh {
buffer.WriteString(content)
}
chatlog.CompletionCount = NumTokensFromStr(buffer.String(), chatreq.Model)
chatlog.TotalTokens = chatlog.PromptCount + chatlog.CompletionCount
chatlog.Cost = fmt.Sprintf("%.6f", Cost(chatlog.Model, chatlog.PromptCount, chatlog.CompletionCount))
if err := store.Record(&chatlog); err != nil {
log.Println(err)
}
if err := store.SumDaily(chatlog.UserID); err != nil {
log.Println(err)
}
return
}
chatlog.CompletionCount = NumTokensFromStr(buffer.String(), chatreq.Model)
chatlog.TotalTokens = chatlog.PromptCount + chatlog.CompletionCount
res, err := io.ReadAll(reader)
if err != nil {
c.JSON(http.StatusUnauthorized, gin.H{"error": gin.H{
"message": err.Error(),
}})
return
}
reader = bufio.NewReader(bytes.NewBuffer(res))
json.NewDecoder(bytes.NewBuffer(res)).Decode(&chatres)
chatlog.PromptCount = chatres.Usage.PromptTokens
chatlog.CompletionCount = chatres.Usage.CompletionTokens
chatlog.TotalTokens = chatres.Usage.TotalTokens
chatlog.Cost = fmt.Sprintf("%.6f", Cost(chatlog.Model, chatlog.PromptCount, chatlog.CompletionCount))
if err := store.Record(&chatlog); err != nil {
log.Println(err)
@@ -610,26 +644,6 @@ func HandleProy(c *gin.Context) {
if err := store.SumDaily(chatlog.UserID); err != nil {
log.Println(err)
}
return
}
res, err := io.ReadAll(reader)
if err != nil {
c.JSON(http.StatusUnauthorized, gin.H{"error": gin.H{
"message": err.Error(),
}})
return
}
reader = bufio.NewReader(bytes.NewBuffer(res))
json.NewDecoder(bytes.NewBuffer(res)).Decode(&chatres)
chatlog.PromptCount = chatres.Usage.PromptTokens
chatlog.CompletionCount = chatres.Usage.CompletionTokens
chatlog.TotalTokens = chatres.Usage.TotalTokens
chatlog.Cost = fmt.Sprintf("%.6f", Cost(chatlog.Model, chatlog.PromptCount, chatlog.CompletionCount))
if err := store.Record(&chatlog); err != nil {
log.Println(err)
}
if err := store.SumDaily(chatlog.UserID); err != nil {
log.Println(err)
}
}