Squashed commit of feat/claude

This commit is contained in:
Sakurasan
2023-09-16 21:18:11 +08:00
parent 678928cafd
commit 545147abe0
3 changed files with 265 additions and 42 deletions

View File

@@ -16,6 +16,7 @@ import (
"net/http/httputil"
"net/url"
"opencatd-open/pkg/azureopenai"
"opencatd-open/pkg/claude"
"opencatd-open/store"
"os"
"path/filepath"
@@ -286,7 +287,7 @@ func HandleAddKey(c *gin.Context) {
return
}
k := &store.Key{
ApiType: "azure_openai",
ApiType: "azure",
Name: body.Name,
Key: body.Key,
ResourceNmae: keynames[1],
@@ -298,7 +299,7 @@ func HandleAddKey(c *gin.Context) {
}})
return
}
} else if strings.HasPrefix(body.Name, "anthropic.") {
} else if strings.HasPrefix(body.Name, "claude.") {
keynames := strings.Split(body.Name, ".")
if len(keynames) < 2 {
c.JSON(http.StatusBadRequest, gin.H{"error": gin.H{
@@ -310,7 +311,8 @@ func HandleAddKey(c *gin.Context) {
body.Endpoint = "https://api.anthropic.com"
}
k := &store.Key{
ApiType: "anthropic",
// ApiType: "anthropic",
ApiType: "claude",
Name: body.Name,
Key: body.Key,
ResourceNmae: keynames[1],
@@ -459,6 +461,7 @@ func HandleProy(c *gin.Context) {
chatreq = openai.ChatCompletionRequest{}
chatres = openai.ChatCompletionResponse{}
chatlog store.Tokens
onekey store.Key
pre_prompt string
req *http.Request
err error
@@ -469,6 +472,10 @@ func HandleProy(c *gin.Context) {
localuser = store.IsExistAuthCache(auth[7:])
c.Set("localuser", auth[7:])
}
if c.Request.URL.Path == "/v1/complete" {
claude.ClaudeProxy(c)
return
}
if c.Request.URL.Path == "/v1/audio/transcriptions" {
WhisperProxy(c)
return
@@ -481,12 +488,12 @@ func HandleProy(c *gin.Context) {
}})
return
}
onekey := store.FromKeyCacheRandomItemKey()
if err := c.BindJSON(&chatreq); err != nil {
c.AbortWithError(http.StatusBadRequest, err)
return
}
chatlog.Model = chatreq.Model
for _, m := range chatreq.Messages {
pre_prompt += m.Content + "\n"
@@ -498,8 +505,28 @@ func HandleProy(c *gin.Context) {
var body bytes.Buffer
json.NewEncoder(&body).Encode(chatreq)
if strings.HasPrefix(chatreq.Model, "claude-") {
onekey, err = store.SelectKeyCache("claude")
if err != nil {
c.AbortWithError(http.StatusForbidden, err)
}
} else {
onekey = store.FromKeyCacheRandomItemKey()
}
// 创建 API 请求
switch onekey.ApiType {
case "claude":
payload, _ := claude.TransReq(&chatreq)
buildurl := "https://api.anthropic.com/v1/complete"
req, err = http.NewRequest("POST", buildurl, payload)
req.Header.Add("accept", "application/json")
req.Header.Add("anthropic-version", "2023-06-01")
req.Header.Add("x-api-key", onekey.Key)
req.Header.Add("content-type", "application/json")
case "azure":
fallthrough
case "azure_openai":
var buildurl string
var apiVersion = "2023-05-15"
@@ -533,7 +560,7 @@ func HandleProy(c *gin.Context) {
req, err = http.NewRequest(c.Request.Method, baseUrl+c.Request.RequestURI, c.Request.Body)
if err != nil {
log.Println(err)
c.JSON(http.StatusOK, gin.H{"error": err.Error()})
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
req.Header = c.Request.Header
@@ -542,7 +569,7 @@ func HandleProy(c *gin.Context) {
resp, err := client.Do(req)
if err != nil {
log.Println(err)
c.JSON(http.StatusOK, gin.H{"error": err.Error()})
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
defer resp.Body.Close()
@@ -574,14 +601,42 @@ func HandleProy(c *gin.Context) {
reader := bufio.NewReader(resp.Body)
if resp.StatusCode == 200 && localuser {
if isStream {
contentCh := fetchResponseContent(c, reader)
var buffer bytes.Buffer
for content := range contentCh {
buffer.WriteString(content)
switch onekey.ApiType {
case "claude":
claude.TransRsp(c, isStream, reader)
return
case "openai", "azure", "azure_openai":
fallthrough
default:
if isStream {
contentCh := fetchResponseContent(c, reader)
var buffer bytes.Buffer
for content := range contentCh {
buffer.WriteString(content)
}
chatlog.CompletionCount = NumTokensFromStr(buffer.String(), chatreq.Model)
chatlog.TotalTokens = chatlog.PromptCount + chatlog.CompletionCount
chatlog.Cost = fmt.Sprintf("%.6f", Cost(chatlog.Model, chatlog.PromptCount, chatlog.CompletionCount))
if err := store.Record(&chatlog); err != nil {
log.Println(err)
}
if err := store.SumDaily(chatlog.UserID); err != nil {
log.Println(err)
}
return
}
chatlog.CompletionCount = NumTokensFromStr(buffer.String(), chatreq.Model)
chatlog.TotalTokens = chatlog.PromptCount + chatlog.CompletionCount
res, err := io.ReadAll(reader)
if err != nil {
c.JSON(http.StatusUnauthorized, gin.H{"error": gin.H{
"message": err.Error(),
}})
return
}
reader = bufio.NewReader(bytes.NewBuffer(res))
json.NewDecoder(bytes.NewBuffer(res)).Decode(&chatres)
chatlog.PromptCount = chatres.Usage.PromptTokens
chatlog.CompletionCount = chatres.Usage.CompletionTokens
chatlog.TotalTokens = chatres.Usage.TotalTokens
chatlog.Cost = fmt.Sprintf("%.6f", Cost(chatlog.Model, chatlog.PromptCount, chatlog.CompletionCount))
if err := store.Record(&chatlog); err != nil {
log.Println(err)
@@ -589,26 +644,6 @@ func HandleProy(c *gin.Context) {
if err := store.SumDaily(chatlog.UserID); err != nil {
log.Println(err)
}
return
}
res, err := io.ReadAll(reader)
if err != nil {
c.JSON(http.StatusUnauthorized, gin.H{"error": gin.H{
"message": err.Error(),
}})
return
}
reader = bufio.NewReader(bytes.NewBuffer(res))
json.NewDecoder(bytes.NewBuffer(res)).Decode(&chatres)
chatlog.PromptCount = chatres.Usage.PromptTokens
chatlog.CompletionCount = chatres.Usage.CompletionTokens
chatlog.TotalTokens = chatres.Usage.TotalTokens
chatlog.Cost = fmt.Sprintf("%.6f", Cost(chatlog.Model, chatlog.PromptCount, chatlog.CompletionCount))
if err := store.Record(&chatlog); err != nil {
log.Println(err)
}
if err := store.SumDaily(chatlog.UserID); err != nil {
log.Println(err)
}
}