diff --git a/pkg/claude/chat.go b/pkg/claude/chat.go
new file mode 100644
index 0000000..7532fb2
--- /dev/null
+++ b/pkg/claude/chat.go
@@ -0,0 +1,290 @@
+// https://docs.anthropic.com/claude/reference/messages_post
+
+package claude
+
+import (
+	"bufio"
+	"bytes"
+	"encoding/json"
+	"fmt"
+	"io"
+	"log"
+	"net/http"
+	"opencatd-open/pkg/openai"
+	"opencatd-open/pkg/tokenizer"
+	"opencatd-open/store"
+	"strings"
+
+	"github.com/gin-gonic/gin"
+)
+
+// ChatProxy forwards an OpenAI-style chat request to the Claude Messages API.
+func ChatProxy(c *gin.Context, chatReq *openai.ChatCompletionRequest) {
+	ChatMessages(c, chatReq)
+}
+
+// ChatTextCompletions is a placeholder for the legacy text-completion endpoint.
+func ChatTextCompletions(c *gin.Context, chatReq *openai.ChatCompletionRequest) {
+
+}
+
+// ClaudeRequest is the request payload for POST /v1/messages.
+type ClaudeRequest struct {
+	Model       string  `json:"model,omitempty"`
+	Messages    any     `json:"messages,omitempty"`
+	MaxTokens   int     `json:"max_tokens,omitempty"`
+	Stream      bool    `json:"stream,omitempty"`
+	System      string  `json:"system,omitempty"`
+	TopK        int     `json:"top_k,omitempty"`
+	TopP        float64 `json:"top_p,omitempty"`
+	Temperature float64 `json:"temperature,omitempty"`
+}
+
+// ByteJson marshals the request; the error is ignored because every field
+// is trivially marshalable.
+func (c *ClaudeRequest) ByteJson() []byte {
+	bytejson, _ := json.Marshal(c)
+	return bytejson
+}
+
+// ClaudeMessages is a plain-text message entry.
+type ClaudeMessages struct {
+	Role    string `json:"role,omitempty"`
+	Content string `json:"content,omitempty"`
+}
+
+// VisionMessages is a multi-part (text + image) message entry.
+type VisionMessages struct {
+	Role    string          `json:"role,omitempty"`
+	Content []ClaudeContent `json:"content,omitempty"`
+}
+
+type ClaudeContent struct {
+	Type   string        `json:"type,omitempty"`
+	Source *ClaudeSource `json:"source,omitempty"`
+	Text   string        `json:"text,omitempty"`
+}
+
+type ClaudeSource struct {
+	Type      string `json:"type,omitempty"`
+	MediaType string `json:"media_type,omitempty"`
+	Data      string `json:"data,omitempty"`
+}
+
+// ClaudeResponse is a non-streaming /v1/messages response.
+type ClaudeResponse struct {
+	ID           string `json:"id"`
+	Type         string `json:"type"`
+	Role         string `json:"role"`
+	Model        string `json:"model"`
+	StopSequence any    `json:"stop_sequence"`
+	Usage        struct {
+		InputTokens  int `json:"input_tokens"`
+		OutputTokens int `json:"output_tokens"`
+	} `json:"usage"`
+	Content []struct {
+		Type string `json:"type"`
+		Text string `json:"text"`
+	} `json:"content"`
+	StopReason string `json:"stop_reason"`
+}
+
+// ClaudeStreamResponse is the union of all SSE event payloads emitted by the
+// Messages API (message_start, content_block_delta, message_stop, error, ...).
+type ClaudeStreamResponse struct {
+	Type         string `json:"type"`
+	Index        int    `json:"index"`
+	ContentBlock struct {
+		Type string `json:"type"`
+		Text string `json:"text"`
+	} `json:"content_block"`
+	Delta struct {
+		Type         string `json:"type"`
+		Text         string `json:"text"`
+		StopReason   string `json:"stop_reason"`
+		StopSequence any    `json:"stop_sequence"`
+	} `json:"delta"`
+	Message struct {
+		ID           string `json:"id"`
+		Type         string `json:"type"`
+		Role         string `json:"role"`
+		Content      []any  `json:"content"`
+		Model        string `json:"model"`
+		StopReason   string `json:"stop_reason"`
+		StopSequence any    `json:"stop_sequence"`
+		Usage        struct {
+			InputTokens  int `json:"input_tokens"`
+			OutputTokens int `json:"output_tokens"`
+		} `json:"usage"`
+	} `json:"message"`
+	Error struct {
+		Type    string `json:"type"`
+		Message string `json:"message"`
+	} `json:"error"`
+	Usage struct {
+		OutputTokens int `json:"output_tokens"`
+	} `json:"usage"`
+}
+
+// rawText unquotes a JSON-string content field; msg.Content is raw JSON, so a
+// plain-text message arrives as `"hello"` — previously the surrounding quotes
+// were forwarded to Claude verbatim.
+func rawText(raw json.RawMessage) string {
+	var s string
+	if err := json.Unmarshal(raw, &s); err == nil {
+		return s
+	}
+	return string(raw)
+}
+
+// ChatMessages translates the OpenAI chat request into a Claude Messages call,
+// re-emits the upstream SSE stream in OpenAI format, and records token usage.
+func ChatMessages(c *gin.Context, chatReq *openai.ChatCompletionRequest) {
+	usagelog := store.Tokens{Model: chatReq.Model}
+
+	// BUG FIX: attribute usage to the authenticated user (parity with
+	// openai.ChatProxy); previously UserID stayed 0 and SumDaily summed
+	// against a nonexistent user.
+	if token, ok := c.Get("localuser"); ok {
+		if lu, err := store.GetUserByToken(token.(string)); err == nil {
+			usagelog.UserID = int(lu.ID)
+		}
+	}
+
+	var claudReq ClaudeRequest
+	claudReq.Model = chatReq.Model
+	claudReq.Stream = chatReq.Stream
+	claudReq.Temperature = chatReq.Temperature
+	claudReq.TopP = chatReq.TopP
+	claudReq.MaxTokens = 4096
+
+	var msgs []any
+	var prompt string
+	for _, msg := range chatReq.Messages {
+		if msg.Role == "system" {
+			// Claude takes the system prompt as a top-level field.
+			claudReq.System = rawText(msg.Content)
+			continue
+		}
+
+		var visioncontent []openai.VisionContent
+		if err := json.Unmarshal(msg.Content, &visioncontent); err != nil {
+			// Plain-text message.
+			text := rawText(msg.Content)
+			prompt += "<" + msg.Role + ">: " + text + "\n"
+			msgs = append(msgs, ClaudeMessages{Role: msg.Role, Content: text})
+		} else if len(visioncontent) > 0 {
+			var visionMessage VisionMessages
+			visionMessage.Role = msg.Role
+
+			// BUG FIX: accumulate every part. The slice was previously
+			// redeclared inside this loop, so only the last text/image
+			// part of a multi-part message survived.
+			var claudecontent []ClaudeContent
+			for _, content := range visioncontent {
+				if content.Type == "text" {
+					prompt += "<" + msg.Role + ">: " + content.Text + "\n"
+					claudecontent = append(claudecontent, ClaudeContent{Type: "text", Text: content.Text})
+				} else if content.Type == "image_url" {
+					if strings.HasPrefix(content.ImageURL.URL, "http") {
+						fmt.Println("链接:", content.ImageURL.URL)
+					} else if strings.HasPrefix(content.ImageURL.URL, "data:image") {
+						fmt.Println("base64:", content.ImageURL.URL[:20])
+					}
+					// todo image tokens
+					var mediaType string
+					if strings.HasPrefix(content.ImageURL.URL, "data:image/jpeg") {
+						mediaType = "image/jpeg"
+					}
+					if strings.HasPrefix(content.ImageURL.URL, "data:image/png") {
+						mediaType = "image/png"
+					}
+					// Only data-URL images are forwarded; the base64 body
+					// follows the first comma.
+					claudecontent = append(claudecontent, ClaudeContent{
+						Type: "image",
+						Source: &ClaudeSource{
+							Type:      "base64",
+							MediaType: mediaType,
+							Data:      strings.Split(content.ImageURL.URL, ",")[1],
+						},
+					})
+				}
+			}
+			visionMessage.Content = claudecontent
+			msgs = append(msgs, visionMessage)
+		}
+	}
+	// Assign once, after the loop (was redundantly reassigned per message).
+	claudReq.Messages = msgs
+
+	usagelog.PromptCount = tokenizer.NumTokensFromStr(prompt, chatReq.Model)
+
+	// BUG FIX: send the JSON payload itself. The previous code used
+	// strings.NewReader(fmt.Sprintf("%v", bytes.NewReader(...))), which
+	// transmitted the fmt rendering of a *bytes.Reader struct.
+	req, err := http.NewRequest(http.MethodPost, MessageEndpoint, bytes.NewReader(claudReq.ByteJson()))
+	if err != nil {
+		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
+		return
+	}
+	// BUG FIX: the request previously carried no credentials or content type,
+	// so the upstream call could never authenticate.
+	req.Header.Set("Content-Type", "application/json")
+	req.Header.Set("anthropic-version", "2023-06-01")
+	if onekey, err := store.SelectKeyCache("claude"); err == nil {
+		req.Header.Set("x-api-key", onekey.Key)
+	}
+
+	rsp, err := http.DefaultClient.Do(req)
+	if err != nil {
+		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
+		return
+	}
+	defer rsp.Body.Close()
+	if rsp.StatusCode != http.StatusOK {
+		io.Copy(c.Writer, rsp.Body)
+		return
+	}
+
+	dataChan := make(chan string, 1)
+
+	var result string
+
+	// BUG FIX: scan the body directly. io.TeeReader(rsp.Body, c.Writer)
+	// copied the raw Claude SSE to the client in addition to the converted
+	// OpenAI events written below, interleaving two streams.
+	scanner := bufio.NewScanner(rsp.Body)
+
+	go func() {
+		defer close(dataChan)
+		for scanner.Scan() {
+			line := scanner.Bytes()
+			if len(line) == 0 || bytes.HasPrefix(line, []byte("event:")) {
+				// Skip SSE separators and event-name lines; previously empty
+				// lines reached the writer and triggered a 502 WriteHeader.
+				continue
+			}
+			if !bytes.HasPrefix(line, []byte("data: ")) {
+				// Anything else (e.g. a pre-stream error body) is passed
+				// through and flagged downstream as a gateway error.
+				dataChan <- string(line) + "\n"
+				continue
+			}
+			if bytes.HasPrefix(line, []byte("data: [DONE]")) {
+				dataChan <- string(line) + "\n\n"
+				break
+			}
+			var claudeResp ClaudeStreamResponse
+			payload := bytes.TrimSpace(bytes.TrimPrefix(line, []byte("data: ")))
+			if err := json.Unmarshal(payload, &claudeResp); err != nil {
+				continue
+			}
+
+			if claudeResp.Type == "message_start" {
+				if claudeResp.Message.Role != "" {
+					result += "<" + claudeResp.Message.Role + ">"
+				}
+			} else if claudeResp.Type == "message_stop" {
+				break
+			}
+
+			if claudeResp.Delta.Text != "" {
+				result += claudeResp.Delta.Text
+			}
+			var choice openai.Choice
+			choice.Delta.Role = claudeResp.Message.Role
+			choice.Delta.Content = claudeResp.Delta.Text
+			choice.FinishReason = claudeResp.Delta.StopReason
+
+			chatResp := openai.ChatCompletionStreamResponse{
+				Model:   chatReq.Model,
+				Choices: []openai.Choice{choice},
+			}
+			// SSE events are terminated by a blank line.
+			dataChan <- "data: " + string(chatResp.ByteJson()) + "\n\n"
+			if claudeResp.Delta.StopReason != "" {
+				dataChan <- "data: [DONE]\n\n"
+			}
+		}
+		if err := scanner.Err(); err != nil {
+			log.Println(err)
+		}
+	}()
+
+	c.Writer.Header().Set("Content-Type", "text/event-stream")
+	c.Writer.Header().Set("Cache-Control", "no-cache")
+	c.Writer.Header().Set("Connection", "keep-alive")
+	c.Writer.Header().Set("Transfer-Encoding", "chunked")
+	c.Writer.Header().Set("X-Accel-Buffering", "no")
+
+	c.Stream(func(w io.Writer) bool {
+		if data, ok := <-dataChan; ok {
+			if strings.HasPrefix(data, "data: ") {
+				c.Writer.WriteString(data)
+			} else {
+				// Non-SSE payload from upstream: surface as a gateway error.
+				c.Writer.WriteHeader(http.StatusBadGateway)
+				c.Writer.WriteString(data)
+			}
+			c.Writer.Flush()
+			return true
+		}
+		// Stream finished: record usage asynchronously.
+		go func() {
+			usagelog.CompletionCount = tokenizer.NumTokensFromStr(result, chatReq.Model)
+			usagelog.Cost = fmt.Sprintf("%.6f", tokenizer.Cost(usagelog.Model, usagelog.PromptCount, usagelog.CompletionCount))
+			if err := store.Record(&usagelog); err != nil {
+				log.Println(err)
+			}
+			if err := store.SumDaily(usagelog.UserID); err != nil {
+				log.Println(err)
+			}
+		}()
+		return false
+	})
+}
diff --git a/pkg/claude/claude.go b/pkg/claude/claude.go
index 8b4ae07..0964f46 100644
--- a/pkg/claude/claude.go
+++ b/pkg/claude/claude.go
@@ -22,11 +22,26 @@ data: {"completion":"","stop_reason":"stop_sequence","model":"claude-2.0","stop"
 
 # Model Pricing
 
-Claude Instant |100,000 tokens |Prompt $1.63/million tokens |Completion $5.51/million tokens
+Claude Instant |100,000 tokens |Prompt $1.63/million tokens |Completion $5.51/million tokens
 
-Claude 2 |100,000 tokens |Prompt $11.02/million tokens |Completion $32.68/million tokens
+Claude 2 |100,000 tokens |Prompt $11.02/million tokens |Completion $32.68/million tokens
 
 *Claude 1 is still accessible and offered at the same price as Claude 2.
 
+# AWS
+https://docs.aws.amazon.com/bedrock/latest/userguide/what-is-service.html
+https://aws.amazon.com/cn/bedrock/pricing/
+Anthropic models Price for 1000 input tokens Price for 1000 output tokens
+Claude Instant $0.00163 $0.00551
+
+Claude $0.01102 $0.03268
+
+https://docs.aws.amazon.com/bedrock/latest/userguide/endpointsTable.html
+地区名称 地区 端点 协议
+美国东部(弗吉尼亚北部) 美国东部1 bedrock-runtime.us-east-1.amazonaws.com HTTPS
+ bedrock-runtime-fips.us-east-1.amazonaws.com HTTPS
+美国西部(俄勒冈州) 美国西2号 bedrock-runtime.us-west-2.amazonaws.com HTTPS
+ bedrock-runtime-fips.us-west-2.amazonaws.com HTTPS
+亚太地区(新加坡) ap-东南-1 bedrock-runtime.ap-southeast-1.amazonaws.com HTTPS
 */
 
 // package anthropic
@@ -53,7 +68,8 @@ import (
 )
 
 var (
-	ClaudeUrl = "https://api.anthropic.com/v1/complete"
+	ClaudeUrl       = "https://api.anthropic.com/v1/complete"
+	MessageEndpoint = "https://api.anthropic.com/v1/messages"
 )
 
 type MessageModule struct {
diff --git a/pkg/openai/chat.go b/pkg/openai/chat.go
new file mode 100644
index 0000000..9b9f61e
--- /dev/null
+++ b/pkg/openai/chat.go
@@ -0,0 +1,309 @@
+package openai
+
+import (
+	"bufio"
+	"bytes"
+	"encoding/json"
+	"fmt"
+	"io"
+	"log"
+	"net/http"
+	"opencatd-open/pkg/tokenizer"
+	"opencatd-open/store"
+	"os"
+	"strings"
+
+	"github.com/gin-gonic/gin"
+)
+
+const (
+	OpenAI_Endpoint = "https://api.openai.com/v1/chat/completions"
+)
+
+var (
+	BaseURL            string // "https://api.openai.com"
+	AIGateWay_Endpoint = "https://gateway.ai.cloudflare.com/v1/431ba10f11200d544922fbca177aaa7f/openai/openai/chat/completions"
+)
+
+func init() {
+	if os.Getenv("OpenAI_Endpoint") != "" {
+		BaseURL = os.Getenv("OpenAI_Endpoint")
+	}
+	if os.Getenv("AIGateWay_Endpoint") != "" {
+		AIGateWay_Endpoint = os.Getenv("AIGateWay_Endpoint")
+	}
+}
+
+// VisionContent is one part of a multi-part (text + image) message.
+type VisionContent struct {
+	Type     string          `json:"type,omitempty"`
+	Text     string          `json:"text,omitempty"`
+	ImageURL *VisionImageURL `json:"image_url,omitempty"`
+}
+type VisionImageURL struct {
+	URL    string `json:"url,omitempty"`
+	Detail string `json:"detail,omitempty"`
+}
+
+// ChatCompletionMessage keeps Content as raw JSON because it may be either a
+// plain string or an array of VisionContent parts.
+type ChatCompletionMessage struct {
+	Role    string          `json:"role"`
+	Content json.RawMessage `json:"content"`
+	Name    string          `json:"name,omitempty"`
+}
+
+type FunctionDefinition struct {
+	Name        string `json:"name"`
+	Description string `json:"description,omitempty"`
+	Parameters  any    `json:"parameters"`
+}
+
+type Tool struct {
+	Type     string              `json:"type"`
+	Function *FunctionDefinition `json:"function,omitempty"`
+}
+
+type ChatCompletionRequest struct {
+	Model            string                  `json:"model"`
+	Messages         []ChatCompletionMessage `json:"messages"`
+	MaxTokens        int                     `json:"max_tokens,omitempty"`
+	Temperature      float64                 `json:"temperature,omitempty"`
+	TopP             float64                 `json:"top_p,omitempty"`
+	N                int                     `json:"n,omitempty"`
+	Stream           bool                    `json:"stream,omitempty"`
+	Stop             []string                `json:"stop,omitempty"`
+	PresencePenalty  float64                 `json:"presence_penalty,omitempty"`
+	FrequencyPenalty float64                 `json:"frequency_penalty,omitempty"`
+	LogitBias        map[string]int          `json:"logit_bias,omitempty"`
+	User             string                  `json:"user,omitempty"`
+	// Functions []FunctionDefinition `json:"functions,omitempty"`
+	// FunctionCall any `json:"function_call,omitempty"`
+	Tools []Tool `json:"tools,omitempty"`
+	// ToolChoice any `json:"tool_choice,omitempty"`
+}
+
+func (c ChatCompletionRequest) ToByteJson() []byte {
+	bytejson, _ := json.Marshal(c)
+	return bytejson
+}
+
+type ToolCall struct {
+	ID       string `json:"id"`
+	Type     string `json:"type"`
+	Function struct {
+		Name      string `json:"name"`
+		Arguments string `json:"arguments"`
+	} `json:"function"`
+}
+
+type ChatCompletionResponse struct {
+	ID      string `json:"id"`
+	Object  string `json:"object"`
+	Created int    `json:"created"`
+	Model   string `json:"model"`
+	Choices []struct {
+		Index   int `json:"index"`
+		Message struct {
+			Role      string     `json:"role"`
+			Content   string     `json:"content"`
+			ToolCalls []ToolCall `json:"tool_calls"`
+		} `json:"message"`
+		Logprobs     string `json:"logprobs"`
+		FinishReason string `json:"finish_reason"`
+	} `json:"choices"`
+	Usage struct {
+		PromptTokens     int `json:"prompt_tokens"`
+		CompletionTokens int `json:"completion_tokens"`
+		TotalTokens      int `json:"total_tokens"`
+	} `json:"usage"`
+	SystemFingerprint string `json:"system_fingerprint"`
+}
+
+type Choice struct {
+	Index int `json:"index"`
+	Delta struct {
+		Role      string     `json:"role"`
+		Content   string     `json:"content"`
+		ToolCalls []ToolCall `json:"tool_calls"`
+	} `json:"delta"`
+	FinishReason string `json:"finish_reason"`
+	Usage        struct {
+		PromptTokens     int `json:"prompt_tokens"`
+		CompletionTokens int `json:"completion_tokens"`
+		TotalTokens      int `json:"total_tokens"`
+	} `json:"usage"`
+}
+
+type ChatCompletionStreamResponse struct {
+	ID      string   `json:"id"`
+	Object  string   `json:"object"`
+	Created int      `json:"created"`
+	Model   string   `json:"model"`
+	Choices []Choice `json:"choices"`
+}
+
+func (c *ChatCompletionStreamResponse) ByteJson() []byte {
+	bytejson, _ := json.Marshal(c)
+	return bytejson
+}
+
+// ChatProxy forwards the chat request upstream (key endpoint > BaseURL >
+// AI Gateway), tees the response back to the client, and records token usage.
+func ChatProxy(c *gin.Context, chatReq *ChatCompletionRequest) {
+	usagelog := store.Tokens{Model: chatReq.Model}
+
+	token, _ := c.Get("localuser")
+
+	lu, err := store.GetUserByToken(token.(string))
+	if err != nil {
+		c.JSON(http.StatusInternalServerError, gin.H{
+			"error": gin.H{
+				"message": err.Error(),
+			},
+		})
+		return
+	}
+	usagelog.UserID = int(lu.ID)
+
+	// Rebuild the prompt text locally for token counting.
+	var prompt string
+	for _, msg := range chatReq.Messages {
+		var visioncontent []VisionContent
+		if err := json.Unmarshal(msg.Content, &visioncontent); err != nil {
+			prompt += "<" + msg.Role + ">: " + string(msg.Content) + "\n"
+		} else if len(visioncontent) > 0 {
+			for _, content := range visioncontent {
+				if content.Type == "text" {
+					prompt += "<" + msg.Role + ">: " + content.Text + "\n"
+				} else if content.Type == "image_url" {
+					if strings.HasPrefix(content.ImageURL.URL, "http") {
+						fmt.Println("链接:", content.ImageURL.URL)
+					} else if strings.HasPrefix(content.ImageURL.URL, "data:image") {
+						fmt.Println("base64:", content.ImageURL.URL[:20])
+					}
+					// todo image tokens
+				}
+			}
+		}
+		if len(chatReq.Tools) > 0 {
+			tooljson, _ := json.Marshal(chatReq.Tools)
+			prompt += ": " + string(tooljson) + "\n"
+		}
+	}
+
+	usagelog.PromptCount = tokenizer.NumTokensFromStr(prompt, chatReq.Model)
+
+	onekey, err := store.SelectKeyCache("openai")
+	if err != nil {
+		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
+		return
+	}
+
+	var req *http.Request
+
+	if onekey.EndPoint != "" { // the key's own endpoint takes priority
+		req, err = http.NewRequest(c.Request.Method, onekey.EndPoint+c.Request.RequestURI, bytes.NewReader(chatReq.ToByteJson()))
+	} else {
+		if BaseURL != "" { // then the configured BaseURL
+			req, err = http.NewRequest(c.Request.Method, BaseURL+c.Request.RequestURI, bytes.NewReader(chatReq.ToByteJson()))
+		} else { // finally the gateway endpoint
+			req, err = http.NewRequest(c.Request.Method, AIGateWay_Endpoint, bytes.NewReader(chatReq.ToByteJson()))
+		}
+	}
+	if err != nil {
+		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
+		return
+	}
+
+	// BUG FIX: Clone the header map. Assigning it directly aliased the
+	// inbound request's headers, so the Set below mutated them.
+	req.Header = c.Request.Header.Clone()
+	req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", onekey.Key))
+
+	resp, err := http.DefaultClient.Do(req)
+	if err != nil {
+		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
+		return
+	}
+	defer resp.Body.Close()
+	// Tee the upstream body to the client while counting completion tokens.
+	teeReader := io.TeeReader(resp.Body, c.Writer)
+
+	var result string
+	if chatReq.Stream {
+		// Streaming response.
+		scanner := bufio.NewScanner(teeReader)
+
+		for scanner.Scan() {
+			line := scanner.Bytes()
+			if len(line) > 0 && bytes.HasPrefix(line, []byte("data: ")) {
+				if bytes.HasPrefix(line, []byte("data: [DONE]")) {
+					break
+				}
+				var opiResp ChatCompletionStreamResponse
+				line = bytes.Replace(line, []byte("data: "), []byte(""), -1)
+				line = bytes.TrimSpace(line)
+				if err := json.Unmarshal(line, &opiResp); err != nil {
+					continue
+				}
+
+				if len(opiResp.Choices) > 0 {
+					if opiResp.Choices[0].Delta.Role != "" {
+						result += "<" + opiResp.Choices[0].Delta.Role + "> "
+					}
+					result += opiResp.Choices[0].Delta.Content // count content tokens
+
+					if len(opiResp.Choices[0].Delta.ToolCalls) > 0 { // count tool-call tokens
+						if opiResp.Choices[0].Delta.ToolCalls[0].Function.Name != "" {
+							result += "name:" + opiResp.Choices[0].Delta.ToolCalls[0].Function.Name + " arguments:"
+						}
+						result += opiResp.Choices[0].Delta.ToolCalls[0].Function.Arguments
+					}
+				} else {
+					continue
+				}
+			}
+		}
+	} else {
+		// Non-streaming response.
+		body, err := io.ReadAll(teeReader)
+		if err != nil {
+			c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
+			return
+		}
+		var opiResp ChatCompletionResponse
+		if err := json.Unmarshal(body, &opiResp); err != nil {
+			log.Println("Error parsing JSON:", err)
+			c.JSON(http.StatusInternalServerError, gin.H{"error": "Error parsing JSON," + err.Error()})
+			return
+		}
+		if len(opiResp.Choices) > 0 {
+			if opiResp.Choices[0].Message.Role != "" {
+				result += "<" + opiResp.Choices[0].Message.Role + "> "
+			}
+			result += opiResp.Choices[0].Message.Content
+
+			if len(opiResp.Choices[0].Message.ToolCalls) > 0 {
+				if opiResp.Choices[0].Message.ToolCalls[0].Function.Name != "" {
+					result += "name:" + opiResp.Choices[0].Message.ToolCalls[0].Function.Name + " arguments:"
+				}
+				result += opiResp.Choices[0].Message.ToolCalls[0].Function.Arguments
+			}
+		}
+	}
+	usagelog.CompletionCount = tokenizer.NumTokensFromStr(result, chatReq.Model)
+	usagelog.Cost = fmt.Sprintf("%.6f", tokenizer.Cost(usagelog.Model, usagelog.PromptCount, usagelog.CompletionCount))
+	if err := store.Record(&usagelog); err != nil {
+		log.Println(err)
+	}
+	if err := store.SumDaily(usagelog.UserID); err != nil {
+		log.Println(err)
+	}
+}
+
+type ErrResponse struct {
+	Error struct {
+		Message string `json:"message"`
+		Code    string `json:"code"`
+	} `json:"error"`
+}
diff --git a/router/chat.go b/router/chat.go
new file mode 100644
index 0000000..f132428
--- /dev/null
+++ b/router/chat.go
@@ -0,0 +1,29 @@
+package router
+
+import (
+	"net/http"
+	"strings"
+
+	"opencatd-open/pkg/claude"
+	"opencatd-open/pkg/openai"
+
+	"github.com/gin-gonic/gin"
+)
+
+// ChatHandler dispatches /v1/chat/completions by model-name prefix.
+func ChatHandler(c *gin.Context) {
+	var chatreq openai.ChatCompletionRequest
+	if err := c.ShouldBindJSON(&chatreq); err != nil {
+		c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
+		return
+	}
+
+	switch {
+	case strings.HasPrefix(chatreq.Model, "gpt"):
+		openai.ChatProxy(c, &chatreq)
+	case strings.HasPrefix(chatreq.Model, "claude"):
+		claude.ChatProxy(c, &chatreq)
+	default:
+		// BUG FIX: previously an unrecognized model fell through and the
+		// handler returned with no response at all.
+		c.JSON(http.StatusBadRequest, gin.H{"error": gin.H{"message": "unsupported model: " + chatreq.Model}})
+	}
+}
diff --git a/router/router.go b/router/router.go
index dd3b5fb..af2b2c1 100644
--- a/router/router.go
+++ b/router/router.go
@@ -499,6 +499,9 @@ func HandleProy(c *gin.Context) {
 		return
 	}
 
+	// Delegate to the model-aware chat handler; the legacy code below is
+	// unreachable after this return and should be removed in a follow-up.
+	ChatHandler(c)
+	return
+
 	if err := c.BindJSON(&chatreq); err != nil {
 		c.AbortWithError(http.StatusBadRequest, err)
 		return