Files
opencatd-open/llm/claude/handle_proxy.go
2025-04-16 18:01:27 +08:00

231 lines
6.9 KiB
Go
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package claude
import (
"bufio"
"bytes"
"encoding/json"
"fmt"
"io"
"log"
"net/http"
"opencatd-open/llm/openai"
"opencatd-open/llm/vertexai"
"opencatd-open/pkg/error"
"opencatd-open/pkg/tokenizer"
"opencatd-open/store"
"strings"
"github.com/gin-gonic/gin"
)
// ChatMessages adapts an OpenAI-style chat completion request to Claude's
// Messages API and relays the streamed answer to the client as OpenAI-style
// server-sent events.
//
// The upstream is ClaudeMessageEndpoint, or a Vertex AI streamRawPredict URL
// when the key selected from the store has ApiType "vertex". A plain-string
// system message becomes the request's System field; every other text or
// base64 image part is flattened into a single "user" turn, with text parts
// prefixed by the original role. Prompt and completion token counts are
// computed with tokenizer.NumTokensFromStr and recorded through the store
// once the upstream stream has ended.
//
// Failures before streaming starts are reported as JSON. A non-200 upstream
// reply is passed through with its status code and body. Content parts whose
// fields have the wrong JSON type are rejected with 400.
//
// NOTE(review): only "data: " lines of the upstream body are relayed, so a
// non-streaming upstream reply would produce an empty event stream — confirm
// whether callers ever send Stream=false here.
func ChatMessages(c *gin.Context, chatReq *openai.ChatCompletionRequest) {
	var (
		req       *http.Request
		targetURL = ClaudeMessageEndpoint
	)

	apiKey, err := store.SelectKeyCache("claude")
	if err != nil {
		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
		return
	}
	isVertex := apiKey.ApiType == "vertex"

	usagelog := store.Tokens{Model: chatReq.Model}

	var claudReq ChatRequest
	claudReq.Model = chatReq.Model
	claudReq.Stream = chatReq.Stream
	// claudReq.Temperature = chatReq.Temperature
	claudReq.TopP = chatReq.TopP
	claudReq.MaxTokens = 4096
	if isVertex {
		// On Vertex the model is selected by the URL, not by the body.
		claudReq.AnthropicVersion = "vertex-2023-10-16"
		claudReq.Model = ""
	}

	// badPart rejects a content part that would previously have caused a panic.
	badPart := func(reason string) {
		c.JSON(http.StatusBadRequest, gin.H{
			"error": gin.H{
				"message": reason,
			},
		})
	}

	var (
		claudecontent []VisionContent
		prompt        strings.Builder
	)
	for _, msg := range chatReq.Messages {
		switch ct := msg.Content.(type) {
		case string:
			prompt.WriteString("<" + msg.Role + ">: " + ct + "\n")
			if msg.Role == "system" {
				claudReq.System = ct
				continue
			}
			claudecontent = append(claudecontent, VisionContent{Type: "text", Text: msg.Role + ":" + ct})
		case []any:
			for _, item := range ct {
				part, ok := item.(map[string]any)
				if !ok {
					continue
				}
				switch part["type"] {
				case "text":
					text, ok := part["text"].(string)
					if !ok {
						badPart("Invalid text content")
						return
					}
					prompt.WriteString("<" + msg.Role + ">: " + text + "\n")
					claudecontent = append(claudecontent, VisionContent{Type: "text", Text: msg.Role + ":" + text})
				case "image_url":
					image, ok := part["image_url"].(map[string]any)
					if !ok {
						continue
					}
					rawURL, ok := image["url"].(string)
					if !ok {
						badPart("Invalid image_url content")
						return
					}
					switch {
					case strings.HasPrefix(rawURL, "http"):
						// Remote images are only logged; they are not forwarded upstream.
						fmt.Println("网络图片:", rawURL)
					case strings.HasPrefix(rawURL, "data:image"):
						// Log a short preview only; the payload can be megabytes long.
						preview := rawURL
						if len(preview) > 20 {
							preview = preview[:20]
						}
						fmt.Println("base64:", preview)

						// A data URL looks like "data:image/png;base64,<payload>".
						header, payload, found := strings.Cut(rawURL, ",")
						if !found {
							badPart("Invalid image data URL")
							return
						}
						mediaType, _, _ := strings.Cut(strings.TrimPrefix(header, "data:"), ";")
						claudecontent = append(claudecontent, VisionContent{
							Type:   "image",
							Source: &VisionSource{Type: "base64", MediaType: mediaType, Data: payload},
						})
					}
				}
			}
		default:
			c.JSON(http.StatusInternalServerError, gin.H{
				"error": gin.H{
					"message": "Invalid content type",
				},
			})
			return
		}
	}
	// Count the tool definitions once, not once per message.
	if len(chatReq.Tools) > 0 {
		if tooljson, err := json.Marshal(chatReq.Tools); err != nil {
			log.Println(err)
		} else {
			prompt.WriteString("<tools>: " + string(tooljson) + "\n")
		}
	}

	claudReq.Messages = []VisionMessages{{Role: "user", Content: claudecontent}}
	usagelog.PromptCount = tokenizer.NumTokensFromStr(prompt.String(), chatReq.Model)

	var gcloudToken string
	if isVertex {
		var vertexSecret vertexai.VertexSecretKey
		if err := json.Unmarshal([]byte(apiKey.ApiSecret), &vertexSecret); err != nil {
			c.JSON(http.StatusInternalServerError, error.ErrorData(err.Error()))
			return
		}
		vcmodel, ok := vertexai.VertexClaudeModelMap[chatReq.Model]
		if !ok {
			c.JSON(http.StatusInternalServerError, error.ErrorData("Model not found"))
			return
		}
		// Obtain a gcloud access token for the Vertex request.
		token, err := vertexai.GcloudAuth(vertexSecret.ClientEmail, vertexSecret.PrivateKey)
		if err != nil {
			c.JSON(http.StatusInternalServerError, error.ErrorData(err.Error()))
			return
		}
		gcloudToken = token
		// Build the Vertex request URL.
		targetURL = fmt.Sprintf("https://%s-aiplatform.googleapis.com/v1/projects/%s/locations/%s/publishers/anthropic/models/%s:streamRawPredict", vcmodel.Region, vertexSecret.ProjectID, vcmodel.Region, vcmodel.VertexName)
	}

	// Tie the upstream call to the client's request so it is cancelled when
	// the client goes away or the handler returns.
	ctx := c.Request.Context()
	req, err = http.NewRequestWithContext(ctx, http.MethodPost, targetURL, bytes.NewReader(claudReq.ByteJson()))
	if err != nil {
		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
		return
	}
	if isVertex {
		req.Header.Set("Authorization", "Bearer "+gcloudToken)
		req.Header.Set("Content-Type", "application/json")
		req.Header.Set("Accept", "text/event-stream")
		req.Header.Set("Accept-Encoding", "identity")
	} else {
		req.Header.Set("x-api-key", apiKey.Key)
		req.Header.Set("anthropic-version", "2023-06-01")
		req.Header.Set("Content-Type", "application/json")
	}

	rsp, err := http.DefaultClient.Do(req)
	if err != nil {
		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
		return
	}
	defer rsp.Body.Close()

	if rsp.StatusCode != http.StatusOK {
		// Pass the upstream failure through with its own status code.
		if contentType := rsp.Header.Get("Content-Type"); contentType != "" {
			c.Header("Content-Type", contentType)
		}
		c.Status(rsp.StatusCode)
		if _, err := io.Copy(c.Writer, rsp.Body); err != nil {
			log.Println(err)
		}
		return
	}

	dataChan := make(chan string)
	// stopChan := make(chan bool)

	// result is written only by the producer goroutine and read only after
	// dataChan has been closed, so it needs no further synchronisation.
	var result strings.Builder

	scanner := bufio.NewScanner(rsp.Body)
	// Allow event lines larger than bufio.Scanner's 64 KiB default.
	scanner.Buffer(make([]byte, 0, 64*1024), 1024*1024)

	go func() {
		defer close(dataChan)

		// emit hands one event to the writer, giving up once the request
		// context is done so this goroutine cannot block forever.
		emit := func(event string) bool {
			select {
			case dataChan <- event:
				return true
			case <-ctx.Done():
				return false
			}
		}

		prefix := []byte("data: ")
		for scanner.Scan() {
			line := scanner.Bytes()
			if !bytes.HasPrefix(line, prefix) {
				continue
			}
			if bytes.HasPrefix(line, []byte("data: [DONE]")) {
				emit(string(line) + "\n\n")
				return
			}

			var claudeResp ClaudeStreamResponse
			// Strip only the leading field name; the payload itself may
			// legitimately contain "data: ".
			payload := bytes.TrimSpace(bytes.TrimPrefix(line, prefix))
			if err := json.Unmarshal(payload, &claudeResp); err != nil {
				continue
			}

			switch claudeResp.Type {
			case "message_start":
				if claudeResp.Message.Role != "" {
					result.WriteString("<" + claudeResp.Message.Role + ">")
				}
			case "message_stop":
				return
			}
			if claudeResp.Delta.Text != "" {
				result.WriteString(claudeResp.Delta.Text)
			}

			var choice openai.Choice
			choice.Delta.Role = claudeResp.Message.Role
			choice.Delta.Content = claudeResp.Delta.Text
			choice.FinishReason = claudeResp.Delta.StopReason
			chatResp := openai.ChatCompletionStreamResponse{
				Model:   chatReq.Model,
				Choices: []openai.Choice{choice},
			}
			// Each server-sent event is terminated by a blank line.
			if !emit("data: " + string(chatResp.ByteJson()) + "\n\n") {
				return
			}
			if claudeResp.Delta.StopReason != "" {
				if !emit("data: [DONE]\n\n") {
					return
				}
			}
		}
		if err := scanner.Err(); err != nil && ctx.Err() == nil {
			log.Println(err)
		}
	}()

	c.Writer.Header().Set("Content-Type", "text/event-stream")
	c.Writer.Header().Set("Cache-Control", "no-cache")
	c.Writer.Header().Set("Connection", "keep-alive")
	c.Writer.Header().Set("Transfer-Encoding", "chunked")
	c.Writer.Header().Set("X-Accel-Buffering", "no")

	c.Stream(func(w io.Writer) bool {
		if data, ok := <-dataChan; ok {
			c.Writer.WriteString(data)
			c.Writer.Flush()
			return true
		}

		// The producer has finished: record usage without delaying the response.
		go func() {
			usagelog.CompletionCount = tokenizer.NumTokensFromStr(result.String(), chatReq.Model)
			usagelog.Cost = fmt.Sprintf("%.6f", tokenizer.Cost(usagelog.Model, usagelog.PromptCount, usagelog.CompletionCount))
			if err := store.Record(&usagelog); err != nil {
				log.Println(err)
			}
			if err := store.SumDaily(usagelog.UserID); err != nil {
				log.Println(err)
			}
		}()
		return false
	})
}