package claude

import (
	"bufio"
	"bytes"
	"context"
	"encoding/json"
	"errors"
	"fmt"
	"io"
	"log"
	"net/http"
	"strings"

	"opencatd-open/llm/openai"
	"opencatd-open/llm/vertexai"
	apierror "opencatd-open/pkg/error"
	"opencatd-open/pkg/tokenizer"
	"opencatd-open/store"

	"github.com/gin-gonic/gin"
)

// ChatMessages serves an OpenAI-style chat completion request by relaying it to
// Anthropic's Messages API. The request goes either directly to
// ClaudeMessageEndpoint or through Google Vertex AI when the selected "claude"
// key has ApiType "vertex". The upstream SSE stream is translated into OpenAI
// stream chunks ("data: {...}" events followed by "data: [DONE]") and written
// to c. Prompt and completion sizes are counted with tokenizer.NumTokensFromStr
// and stored via store.Record / store.SumDaily after the stream ends.
//
// Error responses:
//   - 400 for message content this handler cannot convert;
//   - 500 for key selection, Vertex credential/model and transport failures;
//   - the upstream status and body, relayed verbatim, for non-200 upstream replies.
//
// NOTE(review): the reply is always written as an SSE stream. When
// chatReq.Stream is false the upstream presumably answers with a single JSON
// body, which relayChatMessagesStream does not translate. Confirm callers only
// route streaming requests here.
// NOTE(review): usagelog.UserID is never set in this handler, so SumDaily is
// called with its zero value. Confirm where the user id should come from.
func ChatMessages(c *gin.Context, chatReq *openai.ChatCompletionRequest) {
	apiKey, err := store.SelectKeyCache("claude")
	if err != nil {
		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
		return
	}
	isVertex := apiKey.ApiType == "vertex"

	claudReq, prompt, err := buildChatMessagesRequest(chatReq, isVertex)
	if err != nil {
		c.JSON(http.StatusBadRequest, gin.H{"error": gin.H{"message": err.Error()}})
		return
	}

	usagelog := store.Tokens{Model: chatReq.Model}
	usagelog.PromptCount = tokenizer.NumTokensFromStr(prompt, chatReq.Model)

	// ctx bounds the upstream request and every read of its body. It derives
	// from the client's request context, so a client disconnect also aborts the
	// upstream stream. cancel() below does the same explicitly.
	ctx, cancel := context.WithCancel(c.Request.Context())
	defer cancel()

	req, err := newChatMessagesHTTPRequest(ctx, isVertex, apiKey.Key, []byte(apiKey.ApiSecret), chatReq.Model, claudReq.ByteJson())
	if err != nil {
		c.JSON(http.StatusInternalServerError, apierror.ErrorData(err.Error()))
		return
	}

	rsp, err := http.DefaultClient.Do(req)
	if err != nil {
		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
		return
	}
	defer rsp.Body.Close()

	if rsp.StatusCode != http.StatusOK {
		// Relay the upstream error with its own status code and content type.
		// Otherwise gin would default to 200.
		if ct := rsp.Header.Get("Content-Type"); ct != "" {
			c.Header("Content-Type", ct)
		}
		c.Status(rsp.StatusCode)
		if _, err := io.Copy(c.Writer, rsp.Body); err != nil {
			log.Println("claude: relaying upstream error body:", err)
		}
		return
	}

	chunks := make(chan string)
	var completion string
	go func() {
		// Closing chunks happens after completion is assigned. That close is
		// the signal for both c.Stream and the drain loop below that the relay
		// has finished.
		defer close(chunks)
		completion = relayChatMessagesStream(rsp.Body, chatReq.Model, chunks)
	}()

	c.Writer.Header().Set("Content-Type", "text/event-stream")
	c.Writer.Header().Set("Cache-Control", "no-cache")
	c.Writer.Header().Set("Connection", "keep-alive")
	c.Writer.Header().Set("Transfer-Encoding", "chunked")
	c.Writer.Header().Set("X-Accel-Buffering", "no")

	c.Stream(func(w io.Writer) bool {
		chunk, ok := <-chunks
		if !ok {
			return false
		}
		// Write errors are not fatal here: c.Stream notices a gone client on
		// its next iteration.
		_, _ = io.WriteString(w, chunk)
		c.Writer.Flush()
		return true
	})

	// c.Stream returns early if the client went away. Stop the upstream read
	// and drain chunks until the relay goroutine closes it. This ensures the
	// goroutine cannot stay blocked on a send forever, and that completion is
	// safe to read.
	cancel()
	for range chunks {
	}

	go func() {
		usagelog.CompletionCount = tokenizer.NumTokensFromStr(completion, chatReq.Model)
		usagelog.Cost = fmt.Sprintf("%.6f", tokenizer.Cost(usagelog.Model, usagelog.PromptCount, usagelog.CompletionCount))
		if err := store.Record(&usagelog); err != nil {
			log.Println(err)
		}
		if err := store.SumDaily(usagelog.UserID); err != nil {
			log.Println(err)
		}
	}()
}

// buildChatMessagesRequest converts an OpenAI chat request into a Claude
// Messages request. It also returns the text used for prompt token counting.
//
// Handling of the conversation:
//   - All non-system content is merged into a single "user" message, with each
//     text part prefixed by "<role>:".
//   - String-valued "system" messages are joined with "\n" into
//     ChatRequest.System.
//   - Only base64 "data:image..." URLs are forwarded as images; other image
//     URLs are skipped.
//   - When isVertex is set, the body carries Vertex's anthropic_version and no
//     model, because the model is encoded in the Vertex URL instead.
//
// It returns an error for message content that is neither a string, a list of
// parts, nor nil, and for image data URLs without a "," separator.
func buildChatMessagesRequest(chatReq *openai.ChatCompletionRequest, isVertex bool) (ChatRequest, string, error) {
	var claudReq ChatRequest
	claudReq.Model = chatReq.Model
	claudReq.Stream = chatReq.Stream
	// claudReq.Temperature = chatReq.Temperature
	claudReq.TopP = chatReq.TopP
	claudReq.MaxTokens = 4096
	if isVertex {
		claudReq.AnthropicVersion = "vertex-2023-10-16"
		claudReq.Model = ""
	}

	var (
		contents []VisionContent
		systems  []string
		prompt   strings.Builder
	)
	for _, msg := range chatReq.Messages {
		switch content := msg.Content.(type) {
		case nil:
			// Nothing to forward for a message without content.
			continue
		case string:
			prompt.WriteString("<" + msg.Role + ">: " + content + "\n")
			if msg.Role == "system" {
				systems = append(systems, content)
				continue
			}
			contents = append(contents, VisionContent{Type: "text", Text: msg.Role + ":" + content})
		case []any:
			for _, item := range content {
				part, ok := item.(map[string]any)
				if !ok {
					continue
				}
				switch part["type"] {
				case "text":
					text, _ := part["text"].(string)
					prompt.WriteString("<" + msg.Role + ">: " + text + "\n")
					contents = append(contents, VisionContent{Type: "text", Text: msg.Role + ":" + text})
				case "image_url":
					imageURL, _ := part["image_url"].(map[string]any)
					rawURL, _ := imageURL["url"].(string)
					if !strings.HasPrefix(rawURL, "data:image") {
						// Only inline base64 images are converted; remote URLs are dropped.
						if rawURL != "" {
							log.Println("claude: skipping non-data image URL")
						}
						continue
					}
					mediaType, data, ok := splitImageDataURL(rawURL)
					if !ok {
						return ChatRequest{}, "", errors.New("malformed image data URL")
					}
					contents = append(contents, VisionContent{
						Type:   "image",
						Source: &VisionSource{Type: "base64", MediaType: mediaType, Data: data},
					})
				}
			}
		default:
			return ChatRequest{}, "", fmt.Errorf("invalid content type %T in %s message", content, msg.Role)
		}
	}

	// Tools only contribute to the prompt token count. This function does not
	// copy them into claudReq.
	if len(chatReq.Tools) > 0 {
		if toolJSON, err := json.Marshal(chatReq.Tools); err == nil {
			prompt.WriteString(": " + string(toolJSON) + "\n")
		}
	}

	claudReq.System = strings.Join(systems, "\n")
	claudReq.Messages = []VisionMessages{{Role: "user", Content: contents}}
	return claudReq, prompt.String(), nil
}

// splitImageDataURL splits a "data:<media type>[;params],<payload>" URL into
// its media type and payload. It reports false when there is no "," separator.
// The non-standard "image/jpg" is normalised to "image/jpeg".
func splitImageDataURL(rawURL string) (string, string, bool) {
	meta, payload, found := strings.Cut(strings.TrimPrefix(rawURL, "data:"), ",")
	if !found {
		return "", "", false
	}
	mediaType, _, _ := strings.Cut(meta, ";")
	if mediaType == "image/jpg" {
		mediaType = "image/jpeg"
	}
	return mediaType, payload, true
}

// newChatMessagesHTTPRequest builds the upstream POST carrying body, bound to ctx.
//
// Without isVertex it targets ClaudeMessageEndpoint and authenticates with key
// as "x-api-key".
//
// With isVertex it:
//   - decodes secret as a vertexai.VertexSecretKey;
//   - looks model up in vertexai.VertexClaudeModelMap;
//   - obtains a bearer token from vertexai.GcloudAuth;
//   - targets that model's streamRawPredict URL.
//
// Error messages from those steps are returned unchanged.
func newChatMessagesHTTPRequest(ctx context.Context, isVertex bool, key string, secret []byte, model string, body []byte) (*http.Request, error) {
	if !isVertex {
		req, err := http.NewRequestWithContext(ctx, http.MethodPost, ClaudeMessageEndpoint, bytes.NewReader(body))
		if err != nil {
			return nil, err
		}
		req.Header.Set("x-api-key", key)
		req.Header.Set("anthropic-version", "2023-06-01")
		req.Header.Set("Content-Type", "application/json")
		return req, nil
	}

	var vertexSecret vertexai.VertexSecretKey
	if err := json.Unmarshal(secret, &vertexSecret); err != nil {
		return nil, err
	}
	vcmodel, ok := vertexai.VertexClaudeModelMap[model]
	if !ok {
		return nil, errors.New("Model not found")
	}
	gcloudToken, err := vertexai.GcloudAuth(vertexSecret.ClientEmail, vertexSecret.PrivateKey)
	if err != nil {
		return nil, err
	}

	targetURL := fmt.Sprintf("https://%s-aiplatform.googleapis.com/v1/projects/%s/locations/%s/publishers/anthropic/models/%s:streamRawPredict",
		vcmodel.Region, vertexSecret.ProjectID, vcmodel.Region, vcmodel.VertexName)
	req, err := http.NewRequestWithContext(ctx, http.MethodPost, targetURL, bytes.NewReader(body))
	if err != nil {
		return nil, err
	}
	req.Header.Set("Authorization", "Bearer "+gcloudToken)
	req.Header.Set("Content-Type", "application/json")
	req.Header.Set("Accept", "text/event-stream")
	req.Header.Set("Accept-Encoding", "identity")
	return req, nil
}

// relayChatMessagesStream reads Claude SSE lines from body and processes each
// "data: " event as follows:
//   - An event that decodes as ClaudeStreamResponse becomes one OpenAI stream
//     chunk for model, sent on out as a complete SSE event ("data: ...\n\n").
//   - "data: [DONE]" is sent after an event that carries a stop reason, or
//     when the upstream itself sends [DONE].
//   - Reading stops at "message_stop".
//
// It returns the role marker plus all text deltas, for completion token
// accounting. The caller must keep receiving from out until this returns.
func relayChatMessagesStream(body io.Reader, model string, out chan<- string) string {
	dataPrefix := []byte("data: ")
	var completion strings.Builder

	scanner := bufio.NewScanner(body)
	// Raise bufio's 64 KiB per-line default so one long event does not stop the
	// scan with bufio.ErrTooLong.
	scanner.Buffer(make([]byte, 0, 64*1024), 4<<20)

	for scanner.Scan() {
		line := scanner.Bytes()
		if !bytes.HasPrefix(line, dataPrefix) {
			continue // "event:" lines and blank event separators.
		}
		// Strip only the leading prefix; the JSON itself may contain "data: ".
		payload := bytes.TrimSpace(bytes.TrimPrefix(line, dataPrefix))
		if bytes.HasPrefix(payload, []byte("[DONE]")) {
			out <- "data: [DONE]\n\n"
			break
		}

		var event ClaudeStreamResponse
		if err := json.Unmarshal(payload, &event); err != nil {
			continue
		}
		if event.Type == "message_stop" {
			break
		}
		if event.Type == "message_start" && event.Message.Role != "" {
			completion.WriteString("<" + event.Message.Role + ">")
		}
		completion.WriteString(event.Delta.Text)

		var choice openai.Choice
		choice.Delta.Role = event.Message.Role
		choice.Delta.Content = event.Delta.Text
		choice.FinishReason = event.Delta.StopReason
		chunk := openai.ChatCompletionStreamResponse{
			Model:   model,
			Choices: []openai.Choice{choice},
		}
		out <- "data: " + string(chunk.ByteJson()) + "\n\n"

		if event.Delta.StopReason != "" {
			out <- "data: [DONE]\n\n"
		}
	}
	if err := scanner.Err(); err != nil && !errors.Is(err, context.Canceled) {
		log.Println("claude: reading upstream stream:", err)
	}
	return completion.String()
}