Files
opencatd-open/llm/google/v2/chat.go
2025-04-21 19:10:27 +08:00

233 lines
6.5 KiB
Go

// https://github.com/google-gemini/api-examples/
// https://ai.google.dev/gemini-api/docs/models?hl=zh-cn
package google
import (
"context"
"encoding/base64"
"fmt"
"net/http"
"net/url"
"opencatd-open/internal/model"
"opencatd-open/llm"
"os"
"strings"
"github.com/sashabaranov/go-openai"
"google.golang.org/genai"
)
// Gemini adapts Google's genai client to this project's llm provider
// interface, accumulating token usage across calls.
type Gemini struct {
	// Ctx is the context stored at construction time.
	// NOTE(review): storing a Context in a struct is discouraged in Go;
	// prefer the per-call ctx passed to Chat/StreamChat.
	Ctx context.Context
	// Client is the underlying Google GenAI client.
	Client *genai.Client
	// ApiKey is the key record used to authenticate the client.
	ApiKey *model.ApiKey
	// tokenUsage accumulates prompt/completion/tool/total token counts
	// across Chat and StreamChat calls.
	tokenUsage *llm.TokenUsage
	// Done is a completion-signal channel; not used within this file —
	// presumably closed/consumed by callers. TODO confirm.
	Done chan struct{}
}
// NewGemini builds a Gemini provider backed by the Google GenAI API.
//
// If the LOCAL_PROXY environment variable is set to a parseable URL, all
// HTTP traffic is routed through that proxy; an unparseable value is
// deliberately ignored (best-effort) and the default client is used.
// Returns an error only when the genai client itself cannot be created.
func NewGemini(ctx context.Context, apiKey *model.ApiKey) (*Gemini, error) {
	hc := http.DefaultClient
	if proxy := os.Getenv("LOCAL_PROXY"); proxy != "" {
		// Best-effort: a malformed LOCAL_PROXY falls back to the
		// default client rather than failing construction.
		if proxyURL, err := url.Parse(proxy); err == nil {
			hc = &http.Client{Transport: &http.Transport{Proxy: http.ProxyURL(proxyURL)}}
		}
	}
	client, err := genai.NewClient(ctx, &genai.ClientConfig{
		APIKey:     *apiKey.ApiKey,
		Backend:    genai.BackendGeminiAPI,
		HTTPClient: hc,
	})
	if err != nil {
		return nil, err
	}
	return &Gemini{
		// Fix: keep the caller's ctx instead of discarding it for
		// context.Background(), so cancellation/deadlines established
		// by the caller are honored by code that reads g.Ctx.
		Ctx:        ctx,
		Client:     client,
		ApiKey:     apiKey,
		tokenUsage: &llm.TokenUsage{},
		Done:       make(chan struct{}),
	}, nil
}
// Chat performs a single (non-streaming) generation request.
//
// OpenAI-style messages are translated to genai Contents: "user" and
// "system" roles map to the user role (Gemini has no system role here),
// everything else maps to the model role. Multimodal messages support
// text parts and inline base64 data-URI images (jpeg/png); remote
// http(s) image URLs are not fetched and are skipped.
// Token usage is accumulated onto g.tokenUsage across calls.
func (g *Gemini) Chat(ctx context.Context, chatReq llm.ChatRequest) (*llm.ChatResponse, error) {
	var contents []*genai.Content
	for _, msg := range chatReq.Messages {
		role := genai.RoleModel
		if msg.Role == "user" || msg.Role == "system" {
			role = genai.RoleUser
		}
		if len(msg.MultiContent) == 0 {
			contents = append(contents, genai.NewContentFromText(msg.Content, role))
			continue
		}
		// Fix: accumulate every part of this message into ONE Content.
		// Previously parts was reset and a Content appended inside the
		// per-part loop, splitting one multimodal message into several
		// single-part turns and emitting empty Contents for skipped parts.
		var parts []*genai.Part
		for _, c := range msg.MultiContent {
			switch c.Type {
			case "text":
				parts = append(parts, genai.NewPartFromText(c.Text))
			case "image_url":
				u := c.ImageURL.URL
				// Only inline data URIs are supported; remote URLs
				// (and anything else) are skipped, as before.
				if !strings.HasPrefix(u, "data:image") {
					continue
				}
				var mediaType string
				switch {
				case strings.HasPrefix(u, "data:image/jpeg"):
					mediaType = "image/jpeg"
				case strings.HasPrefix(u, "data:image/png"):
					mediaType = "image/png"
				}
				// Fix: strings.Split(u, ",")[1] panicked on a data URI
				// with no comma; Cut handles the malformed case safely.
				_, payload, ok := strings.Cut(u, ",")
				if !ok {
					continue // malformed data URI
				}
				// Fix: decode error was silently ignored, sending
				// garbage bytes; skip undecodable images instead.
				imageBytes, err := base64.StdEncoding.DecodeString(payload)
				if err != nil {
					continue
				}
				parts = append(parts, genai.NewPartFromBytes(imageBytes, mediaType))
			}
		}
		if len(parts) > 0 {
			contents = append(contents, genai.NewContentFromParts(parts, role))
		}
	}
	tools := []*genai.Tool{{GoogleSearch: &genai.GoogleSearch{}}}
	// Fix: use the caller's ctx (not g.Ctx) so cancellation and
	// deadlines propagate to the API call.
	response, err := g.Client.Models.GenerateContent(ctx,
		chatReq.Model,
		contents,
		&genai.GenerateContentConfig{Tools: tools})
	if err != nil {
		return nil, err
	}
	if g.tokenUsage.Model == "" && response.ModelVersion != "" {
		g.tokenUsage.Model = response.ModelVersion
	}
	if response.UsageMetadata != nil {
		g.tokenUsage.PromptTokens += int(response.UsageMetadata.PromptTokenCount)
		g.tokenUsage.CompletionTokens += int(response.UsageMetadata.CandidatesTokenCount)
		g.tokenUsage.ToolsTokens += int(response.UsageMetadata.ToolUsePromptTokenCount)
		g.tokenUsage.TotalTokens += int(response.UsageMetadata.TotalTokenCount)
	}
	// Fix: guard Candidates[0] — an empty candidate list (e.g. fully
	// blocked response) previously panicked.
	var finishReason openai.FinishReason
	if len(response.Candidates) > 0 {
		finishReason = openai.FinishReason(response.Candidates[0].FinishReason)
	}
	return &llm.ChatResponse{
		Model: response.ModelVersion,
		Choices: []openai.ChatCompletionChoice{
			{
				Message:      openai.ChatCompletionMessage{Content: response.Text(), Role: "assistant"},
				FinishReason: finishReason,
			},
		},
		// Tool-prompt tokens are billed as prompt tokens in the
		// OpenAI-shaped usage report.
		Usage: openai.Usage{PromptTokens: g.tokenUsage.PromptTokens + g.tokenUsage.ToolsTokens, CompletionTokens: g.tokenUsage.CompletionTokens, TotalTokens: g.tokenUsage.TotalTokens},
	}, nil
}
// StreamChat performs a streaming generation request and returns a
// channel of incremental responses. The channel is closed when the
// stream ends or errors; stream errors are logged, not returned (there
// is no error field on llm.StreamChatResponse to carry them).
//
// Message translation matches Chat: "user"/"system" → user role,
// otherwise model role; multimodal parts support text and inline
// base64 data-URI images (jpeg/png).
func (g *Gemini) StreamChat(ctx context.Context, chatReq llm.ChatRequest) (chan *llm.StreamChatResponse, error) {
	var contents []*genai.Content
	for _, msg := range chatReq.Messages {
		// Fix: map "system" to the user role here too — previously
		// StreamChat sent system messages with the model role while
		// Chat sent them with the user role.
		role := genai.RoleModel
		if msg.Role == "user" || msg.Role == "system" {
			role = genai.RoleUser
		}
		if len(msg.MultiContent) == 0 {
			contents = append(contents, genai.NewContentFromText(msg.Content, role))
			continue
		}
		// Fix: accumulate every part of this message into ONE Content
		// (previously each part produced its own single-part Content,
		// and skipped parts produced empty Contents).
		var parts []*genai.Part
		for _, c := range msg.MultiContent {
			switch c.Type {
			case "text":
				parts = append(parts, genai.NewPartFromText(c.Text))
			case "image_url":
				u := c.ImageURL.URL
				// Only inline data URIs are supported; remote URLs
				// are skipped, as before.
				if !strings.HasPrefix(u, "data:image") {
					continue
				}
				var mediaType string
				switch {
				case strings.HasPrefix(u, "data:image/jpeg"):
					mediaType = "image/jpeg"
				case strings.HasPrefix(u, "data:image/png"):
					mediaType = "image/png"
				}
				// Fix: Split(...)[1] panicked on a data URI with no comma.
				_, payload, ok := strings.Cut(u, ",")
				if !ok {
					continue // malformed data URI
				}
				// Fix: decode error was silently ignored; skip bad images.
				imageBytes, err := base64.StdEncoding.DecodeString(payload)
				if err != nil {
					continue
				}
				parts = append(parts, genai.NewPartFromBytes(imageBytes, mediaType))
			}
		}
		if len(parts) > 0 {
			contents = append(contents, genai.NewContentFromParts(parts, role))
		}
	}
	datachan := make(chan *llm.StreamChatResponse)
	tools := []*genai.Tool{{GoogleSearch: &genai.GoogleSearch{}}}
	go func() {
		defer close(datachan)
		// Fix: use the caller's ctx (not g.Ctx) so cancellation stops
		// the stream.
		for result, err := range g.Client.Models.GenerateContentStream(ctx, chatReq.Model, contents, &genai.GenerateContentConfig{Tools: tools}) {
			if err != nil {
				fmt.Println(err)
				return
			}
			if result.UsageMetadata != nil {
				g.tokenUsage.PromptTokens += int(result.UsageMetadata.PromptTokenCount)
				g.tokenUsage.CompletionTokens += int(result.UsageMetadata.CandidatesTokenCount)
				g.tokenUsage.ToolsTokens += int(result.UsageMetadata.ToolUsePromptTokenCount)
				g.tokenUsage.TotalTokens += int(result.UsageMetadata.TotalTokenCount)
			}
			// Fix: guard Candidates[0] — an empty candidate list panicked.
			var finishReason openai.FinishReason
			if len(result.Candidates) > 0 {
				finishReason = openai.FinishReason(result.Candidates[0].FinishReason)
			}
			resp := &llm.StreamChatResponse{
				Model: result.ModelVersion,
				Choices: []openai.ChatCompletionStreamChoice{
					{
						Delta: openai.ChatCompletionStreamChoiceDelta{
							Role:    "assistant",
							Content: result.Text(),
						},
						FinishReason: finishReason,
					},
				},
				Usage: &openai.Usage{PromptTokens: g.tokenUsage.PromptTokens + g.tokenUsage.ToolsTokens, CompletionTokens: g.tokenUsage.CompletionTokens, TotalTokens: g.tokenUsage.TotalTokens},
			}
			// Fix: don't block forever on a send if the consumer stops
			// reading; bail out when ctx is cancelled.
			select {
			case datachan <- resp:
			case <-ctx.Done():
				return
			}
		}
	}()
	// Fix: the old code returned a `generr` variable that the goroutine
	// wrote without synchronization — a data race that was virtually
	// always nil by the time this returned. Stream errors end the
	// stream (channel close) instead.
	return datachan, nil
}
// GetTokenUsage returns the usage accumulated across all Chat and
// StreamChat calls on this instance. The returned pointer is the live
// internal struct, not a copy — callers should treat it as read-only.
// NOTE(review): Go convention would drop the Get prefix (TokenUsage),
// but the name is presumably fixed by the llm provider interface.
func (g *Gemini) GetTokenUsage() *llm.TokenUsage {
	return g.tokenUsage
}