update: openai struct

This commit is contained in:
Sakurasan
2024-12-20 03:25:56 +08:00
parent fb5b1a55ae
commit 24bac8e38d
2 changed files with 84 additions and 81 deletions

View File

@@ -143,51 +143,57 @@ func ChatMessages(c *gin.Context, chatReq *openai.ChatCompletionRequest) {
claudReq.Model = "" claudReq.Model = ""
} }
var prompt string
var claudecontent []VisionContent var claudecontent []VisionContent
var prompt string
for _, msg := range chatReq.Messages { for _, msg := range chatReq.Messages {
switch ct := msg.Content.(type) {
case string:
prompt += "<" + msg.Role + ">: " + msg.Content.(string) + "\n"
if msg.Role == "system" { if msg.Role == "system" {
claudReq.System = string(msg.Content) claudReq.System = msg.Content.(string)
continue continue
} }
claudecontent = append(claudecontent, VisionContent{Type: "text", Text: msg.Role + ":" + msg.Content.(string)})
var oaivisioncontent []openai.VisionContent case []any:
if err := json.Unmarshal(msg.Content, &oaivisioncontent); err != nil { for _, item := range ct {
prompt += "<" + msg.Role + ">: " + string(msg.Content) + "\n" if m, ok := item.(map[string]interface{}); ok {
if m["type"] == "text" {
claudecontent = append(claudecontent, VisionContent{Type: "text", Text: msg.Role + ":" + string(msg.Content)}) prompt += "<" + msg.Role + ">: " + m["text"].(string) + "\n"
} else { claudecontent = append(claudecontent, VisionContent{Type: "text", Text: msg.Role + ":" + m["text"].(string)})
if len(oaivisioncontent) > 0 { } else if m["type"] == "image_url" {
for _, content := range oaivisioncontent { if url, ok := m["image_url"].(map[string]interface{}); ok {
if content.Type == "text" { fmt.Printf(" URL: %v\n", url["url"])
prompt += "<" + msg.Role + ">: " + content.Text + "\n" if strings.HasPrefix(url["url"].(string), "http") {
claudecontent = append(claudecontent, VisionContent{Type: "text", Text: msg.Role + ":" + content.Text}) fmt.Println("网络图片:", url["url"].(string))
} else if content.Type == "image_url" { } else if strings.HasPrefix(url["url"].(string), "data:image") {
if strings.HasPrefix(content.ImageURL.URL, "http") { fmt.Println("base64:", url["url"].(string)[:20])
fmt.Println("链接:", content.ImageURL.URL)
} else if strings.HasPrefix(content.ImageURL.URL, "data:image") {
fmt.Println("base64:", content.ImageURL.URL[:20])
}
// todo image tokens
var mediaType string var mediaType string
if strings.HasPrefix(content.ImageURL.URL, "data:image/jpeg") { if strings.HasPrefix(url["url"].(string), "data:image/jpeg") {
mediaType = "image/jpeg" mediaType = "image/jpeg"
} }
if strings.HasPrefix(content.ImageURL.URL, "data:image/png") { if strings.HasPrefix(url["url"].(string), "data:image/png") {
mediaType = "image/png" mediaType = "image/png"
} }
claudecontent = append(claudecontent, VisionContent{Type: "image", Source: &VisionSource{Type: "base64", MediaType: mediaType, Data: strings.Split(content.ImageURL.URL, ",")[1]}}) claudecontent = append(claudecontent, VisionContent{Type: "image", Source: &VisionSource{Type: "base64", MediaType: mediaType, Data: strings.Split(url["url"].(string), ",")[1]}})
}
}
}
}
}
default:
c.JSON(http.StatusInternalServerError, gin.H{
"error": gin.H{
"message": "Invalid content type",
},
})
return
}
if len(chatReq.Tools) > 0 {
tooljson, _ := json.Marshal(chatReq.Tools)
prompt += "<tools>: " + string(tooljson) + "\n"
} }
} }
}
}
// if len(chatReq.Tools) > 0 {
// tooljson, _ := json.Marshal(chatReq.Tools)
// prompt += "<tools>: " + string(tooljson) + "\n"
// }
}
claudReq.Messages = []VisionMessages{{Role: "user", Content: claudecontent}} claudReq.Messages = []VisionMessages{{Role: "user", Content: claudecontent}}
usagelog.PromptCount = tokenizer.NumTokensFromStr(prompt, chatReq.Model) usagelog.PromptCount = tokenizer.NumTokensFromStr(prompt, chatReq.Model)

View File

@@ -97,36 +97,33 @@ func ChatProxy(c *gin.Context, chatReq *openai.ChatCompletionRequest) {
var prompts []genai.Part var prompts []genai.Part
var prompt string var prompt string
for _, msg := range chatReq.Messages { for _, msg := range chatReq.Messages {
var visioncontent []openai.VisionContent switch ct := msg.Content.(type) {
if err := json.Unmarshal(msg.Content, &visioncontent); err != nil { case string:
prompt += "<" + msg.Role + ">: " + string(msg.Content) + "\n" prompt += "<" + msg.Role + ">: " + msg.Content.(string) + "\n"
prompts = append(prompts, genai.Text("<"+msg.Role+">: "+string(msg.Content))) prompts = append(prompts, genai.Text("<"+msg.Role+">: "+msg.Content.(string)))
} else { case []any:
if len(visioncontent) > 0 { for _, item := range ct {
for _, content := range visioncontent { if m, ok := item.(map[string]interface{}); ok {
if content.Type == "text" { if m["type"] == "text" {
prompt += "<" + msg.Role + ">: " + content.Text + "\n" prompt += "<" + msg.Role + ">: " + m["text"].(string) + "\n"
prompts = append(prompts, genai.Text("<"+msg.Role+">: "+content.Text)) prompts = append(prompts, genai.Text("<"+msg.Role+">: "+m["text"].(string)))
} else if content.Type == "image_url" { } else if m["type"] == "image_url" {
if strings.HasPrefix(content.ImageURL.URL, "http") { if url, ok := m["image_url"].(map[string]interface{}); ok {
fmt.Println("链接:", content.ImageURL.URL) if strings.HasPrefix(url["url"].(string), "http") {
} else if strings.HasPrefix(content.ImageURL.URL, "data:image") { fmt.Println("网络图片:", url["url"].(string))
fmt.Println("base64:", content.ImageURL.URL[:20]) } else if strings.HasPrefix(url["url"].(string), "data:image") {
if chatReq.Model != "gemini-pro-vision" { fmt.Println("base64:", url["url"].(string)[:20])
chatReq.Model = "gemini-pro-vision"
}
var mime string var mime string
// openai 会以 data:image 开头,则去掉 data:image/png;base64, 和 data:image/jpeg;base64, // openai 会以 data:image 开头,则去掉 data:image/png;base64, 和 data:image/jpeg;base64,
if strings.HasPrefix(content.ImageURL.URL, "data:image/png") { if strings.HasPrefix(url["url"].(string), "data:image/png") {
mime = "image/png" mime = "image/png"
} else if strings.HasPrefix(content.ImageURL.URL, "data:image/jpeg") { } else if strings.HasPrefix(url["url"].(string), "data:image/jpeg") {
mime = "image/jpeg" mime = "image/jpeg"
} else { } else {
c.JSON(http.StatusInternalServerError, gin.H{"error": "Unsupported image format"}) c.JSON(http.StatusInternalServerError, gin.H{"error": "Unsupported image format"})
return return
} }
imageString := strings.Split(content.ImageURL.URL, ",")[1] imageString := strings.Split(url["url"].(string), ",")[1]
imageBytes, err := base64.StdEncoding.DecodeString(imageString) imageBytes, err := base64.StdEncoding.DecodeString(imageString)
if err != nil { if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
@@ -134,22 +131,21 @@ func ChatProxy(c *gin.Context, chatReq *openai.ChatCompletionRequest) {
} }
prompts = append(prompts, genai.Blob{MIMEType: mime, Data: imageBytes}) prompts = append(prompts, genai.Blob{MIMEType: mime, Data: imageBytes})
} }
// todo image tokens
} }
} }
} }
} }
default:
c.JSON(http.StatusInternalServerError, gin.H{
"error": gin.H{
"message": "Invalid content type",
},
})
return
}
if len(chatReq.Tools) > 0 { if len(chatReq.Tools) > 0 {
tooljson, _ := json.Marshal(chatReq.Tools) tooljson, _ := json.Marshal(chatReq.Tools)
prompt += "<tools>: " + string(tooljson) + "\n" prompt += "<tools>: " + string(tooljson) + "\n"
// for _, tool := range chatReq.Tools {
// }
} }
} }
@@ -171,6 +167,7 @@ func ChatProxy(c *gin.Context, chatReq *openai.ChatCompletionRequest) {
defer client.Close() defer client.Close()
model := client.GenerativeModel(chatReq.Model) model := client.GenerativeModel(chatReq.Model)
model.Tools = []*genai.Tool{}
iter := model.GenerateContentStream(ctx, prompts...) iter := model.GenerateContentStream(ctx, prompts...)
datachan := make(chan string) datachan := make(chan string)