diff --git a/router/router.go b/router/router.go index f261ece..300e6c2 100644 --- a/router/router.go +++ b/router/router.go @@ -287,7 +287,7 @@ func HandleAddKey(c *gin.Context) { return } k := &store.Key{ - ApiType: "azure_openai", + ApiType: "azure", Name: body.Name, Key: body.Key, ResourceNmae: keynames[1], @@ -299,7 +299,7 @@ func HandleAddKey(c *gin.Context) { }}) return } - } else if strings.HasPrefix(body.Name, "anthropic.") { + } else if strings.HasPrefix(body.Name, "claude.") { keynames := strings.Split(body.Name, ".") if len(keynames) < 2 { c.JSON(http.StatusBadRequest, gin.H{"error": gin.H{ @@ -311,7 +311,8 @@ func HandleAddKey(c *gin.Context) { body.Endpoint = "https://api.anthropic.com" } k := &store.Key{ - ApiType: "anthropic", + // ApiType: "anthropic", + ApiType: "claude", Name: body.Name, Key: body.Key, ResourceNmae: keynames[1], @@ -460,6 +461,7 @@ func HandleProy(c *gin.Context) { chatreq = openai.ChatCompletionRequest{} chatres = openai.ChatCompletionResponse{} chatlog store.Tokens + onekey store.Key pre_prompt string req *http.Request err error @@ -491,10 +493,6 @@ func HandleProy(c *gin.Context) { c.AbortWithError(http.StatusBadRequest, err) return } - if strings.HasPrefix(chatreq.Model, "claude-") { - claude.Translate(c, &chatreq) - return - } chatlog.Model = chatreq.Model for _, m := range chatreq.Messages { @@ -508,7 +506,15 @@ func HandleProy(c *gin.Context) { var body bytes.Buffer json.NewEncoder(&body).Encode(chatreq) - onekey := store.FromKeyCacheRandomItemKey() + if strings.HasPrefix(chatreq.Model, "claude-") { + onekey, err = store.SelectKeyCache("claude") + if err != nil { + c.AbortWithError(http.StatusForbidden, err) + } + } else { + onekey = store.FromKeyCacheRandomItemKey() + } + // 创建 API 请求 switch onekey.ApiType { case "claude": @@ -554,7 +560,7 @@ func HandleProy(c *gin.Context) { req, err = http.NewRequest(c.Request.Method, baseUrl+c.Request.RequestURI, c.Request.Body) if err != nil { log.Println(err) - c.JSON(http.StatusOK, gin.H{"error": err.Error()}) + c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) return } req.Header = c.Request.Header @@ -563,7 +569,7 @@ func HandleProy(c *gin.Context) { resp, err := client.Do(req) if err != nil { log.Println(err) - c.JSON(http.StatusOK, gin.H{"error": err.Error()}) + c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) return } defer resp.Body.Close() @@ -595,14 +601,42 @@ func HandleProy(c *gin.Context) { reader := bufio.NewReader(resp.Body) if resp.StatusCode == 200 && localuser { - if isStream { - contentCh := fetchResponseContent(c, reader) - var buffer bytes.Buffer - for content := range contentCh { - buffer.WriteString(content) + switch onekey.ApiType { + case "claude": + claude.TransRsp(c, isStream, reader) + return + case "openai", "azure", "azure_openai": + fallthrough + default: + if isStream { + contentCh := fetchResponseContent(c, reader) + var buffer bytes.Buffer + for content := range contentCh { + buffer.WriteString(content) + } + chatlog.CompletionCount = NumTokensFromStr(buffer.String(), chatreq.Model) + chatlog.TotalTokens = chatlog.PromptCount + chatlog.CompletionCount + chatlog.Cost = fmt.Sprintf("%.6f", Cost(chatlog.Model, chatlog.PromptCount, chatlog.CompletionCount)) + if err := store.Record(&chatlog); err != nil { + log.Println(err) + } + if err := store.SumDaily(chatlog.UserID); err != nil { + log.Println(err) + } + return } - chatlog.CompletionCount = NumTokensFromStr(buffer.String(), chatreq.Model) - chatlog.TotalTokens = chatlog.PromptCount + chatlog.CompletionCount + res, err := io.ReadAll(reader) + if err != nil { + c.JSON(http.StatusUnauthorized, gin.H{"error": gin.H{ + "message": err.Error(), + }}) + return + } + reader = bufio.NewReader(bytes.NewBuffer(res)) + json.NewDecoder(bytes.NewBuffer(res)).Decode(&chatres) + chatlog.PromptCount = chatres.Usage.PromptTokens + chatlog.CompletionCount = chatres.Usage.CompletionTokens + chatlog.TotalTokens = chatres.Usage.TotalTokens chatlog.Cost = fmt.Sprintf("%.6f", Cost(chatlog.Model, chatlog.PromptCount, chatlog.CompletionCount)) if err := store.Record(&chatlog); err != nil { log.Println(err) @@ -610,26 +644,6 @@ func HandleProy(c *gin.Context) { if err := store.SumDaily(chatlog.UserID); err != nil { log.Println(err) } - return - } - res, err := io.ReadAll(reader) - if err != nil { - c.JSON(http.StatusUnauthorized, gin.H{"error": gin.H{ - "message": err.Error(), - }}) - return - } - reader = bufio.NewReader(bytes.NewBuffer(res)) - json.NewDecoder(bytes.NewBuffer(res)).Decode(&chatres) - chatlog.PromptCount = chatres.Usage.PromptTokens - chatlog.CompletionCount = chatres.Usage.CompletionTokens - chatlog.TotalTokens = chatres.Usage.TotalTokens - chatlog.Cost = fmt.Sprintf("%.6f", Cost(chatlog.Model, chatlog.PromptCount, chatlog.CompletionCount)) - if err := store.Record(&chatlog); err != nil { - log.Println(err) - } - if err := store.SumDaily(chatlog.UserID); err != nil { - log.Println(err) } }