From 24bac8e38de694b9c2d903c646e5320f193ee2c8 Mon Sep 17 00:00:00 2001 From: Sakurasan <26715255+Sakurasan@users.noreply.github.com> Date: Fri, 20 Dec 2024 03:25:56 +0800 Subject: [PATCH] update: openai struct --- pkg/claude/chat.go | 80 +++++++++++++++++++++++-------------------- pkg/google/chat.go | 85 ++++++++++++++++++++++------------------------ 2 files changed, 84 insertions(+), 81 deletions(-) diff --git a/pkg/claude/chat.go b/pkg/claude/chat.go index 39afe48..fdfda85 100644 --- a/pkg/claude/chat.go +++ b/pkg/claude/chat.go @@ -143,51 +143,57 @@ func ChatMessages(c *gin.Context, chatReq *openai.ChatCompletionRequest) { claudReq.Model = "" } - var prompt string - var claudecontent []VisionContent + var prompt string for _, msg := range chatReq.Messages { - if msg.Role == "system" { - claudReq.System = string(msg.Content) - continue - } - - var oaivisioncontent []openai.VisionContent - if err := json.Unmarshal(msg.Content, &oaivisioncontent); err != nil { - prompt += "<" + msg.Role + ">: " + string(msg.Content) + "\n" - - claudecontent = append(claudecontent, VisionContent{Type: "text", Text: msg.Role + ":" + string(msg.Content)}) - } else { - if len(oaivisioncontent) > 0 { - for _, content := range oaivisioncontent { - if content.Type == "text" { - prompt += "<" + msg.Role + ">: " + content.Text + "\n" - claudecontent = append(claudecontent, VisionContent{Type: "text", Text: msg.Role + ":" + content.Text}) - } else if content.Type == "image_url" { - if strings.HasPrefix(content.ImageURL.URL, "http") { - fmt.Println("链接:", content.ImageURL.URL) - } else if strings.HasPrefix(content.ImageURL.URL, "data:image") { - fmt.Println("base64:", content.ImageURL.URL[:20]) + switch ct := msg.Content.(type) { + case string: + prompt += "<" + msg.Role + ">: " + msg.Content.(string) + "\n" + if msg.Role == "system" { + claudReq.System = msg.Content.(string) + continue + } + claudecontent = append(claudecontent, VisionContent{Type: "text", Text: msg.Role + ":" + msg.Content.(string)}) + case []any: + for _, item := range ct { + if m, ok := item.(map[string]interface{}); ok { + if m["type"] == "text" { + prompt += "<" + msg.Role + ">: " + m["text"].(string) + "\n" + claudecontent = append(claudecontent, VisionContent{Type: "text", Text: msg.Role + ":" + m["text"].(string)}) + } else if m["type"] == "image_url" { + if url, ok := m["image_url"].(map[string]interface{}); ok { + fmt.Printf(" URL: %v\n", url["url"]) + if strings.HasPrefix(url["url"].(string), "http") { + fmt.Println("网络图片:", url["url"].(string)) + } else if strings.HasPrefix(url["url"].(string), "data:image") { + fmt.Println("base64:", url["url"].(string)[:20]) + var mediaType string + if strings.HasPrefix(url["url"].(string), "data:image/jpeg") { + mediaType = "image/jpeg" + } + if strings.HasPrefix(url["url"].(string), "data:image/png") { + mediaType = "image/png" + } + claudecontent = append(claudecontent, VisionContent{Type: "image", Source: &VisionSource{Type: "base64", MediaType: mediaType, Data: strings.Split(url["url"].(string), ",")[1]}}) + } } - // todo image tokens - var mediaType string - if strings.HasPrefix(content.ImageURL.URL, "data:image/jpeg") { - mediaType = "image/jpeg" - } - if strings.HasPrefix(content.ImageURL.URL, "data:image/png") { - mediaType = "image/png" - } - claudecontent = append(claudecontent, VisionContent{Type: "image", Source: &VisionSource{Type: "base64", MediaType: mediaType, Data: strings.Split(content.ImageURL.URL, ",")[1]}}) } } - } + default: + c.JSON(http.StatusInternalServerError, gin.H{ + "error": gin.H{ + "message": "Invalid content type", + }, + }) + return + } + if len(chatReq.Tools) > 0 { + tooljson, _ := json.Marshal(chatReq.Tools) + prompt += ": " + string(tooljson) + "\n" } - // if len(chatReq.Tools) > 0 { - // tooljson, _ := json.Marshal(chatReq.Tools) - // prompt += ": " + string(tooljson) + "\n" - // } } + claudReq.Messages = []VisionMessages{{Role: "user", Content: claudecontent}} usagelog.PromptCount = tokenizer.NumTokensFromStr(prompt, chatReq.Model) diff --git a/pkg/google/chat.go b/pkg/google/chat.go index ad18ce0..d0e78d7 100644 --- a/pkg/google/chat.go +++ b/pkg/google/chat.go @@ -97,59 +97,55 @@ func ChatProxy(c *gin.Context, chatReq *openai.ChatCompletionRequest) { var prompts []genai.Part var prompt string for _, msg := range chatReq.Messages { - var visioncontent []openai.VisionContent - if err := json.Unmarshal(msg.Content, &visioncontent); err != nil { - prompt += "<" + msg.Role + ">: " + string(msg.Content) + "\n" - prompts = append(prompts, genai.Text("<"+msg.Role+">: "+string(msg.Content))) - } else { - if len(visioncontent) > 0 { - for _, content := range visioncontent { - if content.Type == "text" { - prompt += "<" + msg.Role + ">: " + content.Text + "\n" - prompts = append(prompts, genai.Text("<"+msg.Role+">: "+content.Text)) - } else if content.Type == "image_url" { - if strings.HasPrefix(content.ImageURL.URL, "http") { - fmt.Println("链接:", content.ImageURL.URL) - } else if strings.HasPrefix(content.ImageURL.URL, "data:image") { - fmt.Println("base64:", content.ImageURL.URL[:20]) - if chatReq.Model != "gemini-pro-vision" { - chatReq.Model = "gemini-pro-vision" + switch ct := msg.Content.(type) { + case string: + prompt += "<" + msg.Role + ">: " + msg.Content.(string) + "\n" + prompts = append(prompts, genai.Text("<"+msg.Role+">: "+msg.Content.(string))) + case []any: + for _, item := range ct { + if m, ok := item.(map[string]interface{}); ok { + if m["type"] == "text" { + prompt += "<" + msg.Role + ">: " + m["text"].(string) + "\n" + prompts = append(prompts, genai.Text("<"+msg.Role+">: "+m["text"].(string))) + } else if m["type"] == "image_url" { + if url, ok := m["image_url"].(map[string]interface{}); ok { + if strings.HasPrefix(url["url"].(string), "http") { + fmt.Println("网络图片:", url["url"].(string)) + } else if strings.HasPrefix(url["url"].(string), "data:image") { + fmt.Println("base64:", url["url"].(string)[:20]) + var mime string + // openai 会以 data:image 开头,则去掉 data:image/png;base64, 和 data:image/jpeg;base64, + if strings.HasPrefix(url["url"].(string), "data:image/png") { + mime = "image/png" + } else if strings.HasPrefix(url["url"].(string), "data:image/jpeg") { + mime = "image/jpeg" + } else { + c.JSON(http.StatusInternalServerError, gin.H{"error": "Unsupported image format"}) + return + } + imageString := strings.Split(url["url"].(string), ",")[1] + imageBytes, err := base64.StdEncoding.DecodeString(imageString) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + prompts = append(prompts, genai.Blob{MIMEType: mime, Data: imageBytes}) } - - var mime string - // openai 会以 data:image 开头,则去掉 data:image/png;base64, 和 data:image/jpeg;base64, - if strings.HasPrefix(content.ImageURL.URL, "data:image/png") { - mime = "image/png" - } else if strings.HasPrefix(content.ImageURL.URL, "data:image/jpeg") { - mime = "image/jpeg" - } else { - c.JSON(http.StatusInternalServerError, gin.H{"error": "Unsupported image format"}) - return - } - imageString := strings.Split(content.ImageURL.URL, ",")[1] - imageBytes, err := base64.StdEncoding.DecodeString(imageString) - if err != nil { - c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) - return - } - prompts = append(prompts, genai.Blob{MIMEType: mime, Data: imageBytes}) } - - // todo image tokens } - } - } + default: + c.JSON(http.StatusInternalServerError, gin.H{ + "error": gin.H{ + "message": "Invalid content type", + }, + }) + return } if len(chatReq.Tools) > 0 { tooljson, _ := json.Marshal(chatReq.Tools) prompt += ": " + string(tooljson) + "\n" - - // for _, tool := range chatReq.Tools { - - // } - } } @@ -171,6 +167,7 @@ func ChatProxy(c *gin.Context, chatReq *openai.ChatCompletionRequest) { defer client.Close() model := client.GenerativeModel(chatReq.Model) + model.Tools = []*genai.Tool{} iter := model.GenerateContentStream(ctx, prompts...) datachan := make(chan string)