// https://github.com/google-gemini/api-examples/
// https://ai.google.dev/gemini-api/docs/models?hl=zh-cn

package google

import (
	"context"
	"encoding/base64"
	"fmt"
	"net/http"
	"net/url"
	"os"
	"strings"

	"opencatd-open/internal/model"
	"opencatd-open/llm"

	"github.com/sashabaranov/go-openai"
	"google.golang.org/genai"
)

// Gemini wraps Google's genai client and exposes OpenAI-style chat,
// streaming, and token-usage accounting for opencatd-open.
type Gemini struct {
	Ctx        context.Context
	Client     *genai.Client
	ApiKey     *model.ApiKey
	tokenUsage *llm.TokenUsage

	Done chan struct{}
}

// NewGemini builds a Gemini provider backed by Google's genai client. If the
// LOCAL_PROXY environment variable is set, requests are routed through that
// proxy.
func NewGemini(ctx context.Context, apiKey *model.ApiKey) (*Gemini, error) {
	hc := http.DefaultClient
	if os.Getenv("LOCAL_PROXY") != "" {
		proxyUrl, err := url.Parse(os.Getenv("LOCAL_PROXY"))
		if err == nil {
			hc = &http.Client{Transport: &http.Transport{Proxy: http.ProxyURL(proxyUrl)}}
		}
	}
	client, err := genai.NewClient(ctx, &genai.ClientConfig{
		APIKey:     *apiKey.ApiKey,
		Backend:    genai.BackendGeminiAPI,
		HTTPClient: hc,
	})
	if err != nil {
		return nil, err
	}

	return &Gemini{
		// Keep the caller-supplied context for the API calls made later.
		Ctx:        ctx,
		Client:     client,
		ApiKey:     apiKey,
		tokenUsage: &llm.TokenUsage{},
		Done:       make(chan struct{}),
	}, nil
}
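
// A minimal usage sketch. The key value and model name are placeholders, and
// the literals assume model.ApiKey stores the key as a *string (as the
// dereference above suggests) and that llm.ChatRequest carries Model plus
// openai-style Messages, which is how this file reads it:
//
//	secret := "your-api-key"
//	g, err := NewGemini(context.Background(), &model.ApiKey{ApiKey: &secret})
//	if err != nil {
//		// handle error
//	}
//	resp, err := g.Chat(context.Background(), llm.ChatRequest{
//		Model:    "gemini-2.5-flash",
//		Messages: []openai.ChatCompletionMessage{{Role: "user", Content: "Hello"}},
//	})
//	_ = resp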

// Chat sends a single (non-streaming) generation request. OpenAI-style chat
// messages are converted to genai.Content values, Google Search grounding is
// enabled as a tool, and token usage is accumulated on the provider.
func (g *Gemini) Chat(ctx context.Context, chatReq llm.ChatRequest) (*llm.ChatResponse, error) {
	var content []*genai.Content
	for _, msg := range chatReq.Messages {
		// Gemini only distinguishes user and model roles, so system prompts
		// are folded into the user role.
		var role genai.Role
		if msg.Role == "user" || msg.Role == "system" {
			role = genai.RoleUser
		} else {
			role = genai.RoleModel
		}

		if len(msg.MultiContent) > 0 {
			// Collect all parts of a multimodal message into a single Content.
			var parts []*genai.Part
			for _, c := range msg.MultiContent {
				if c.Type == "text" {
					parts = append(parts, genai.NewPartFromText(c.Text))
				}
				if c.Type == "image_url" {
					// Remote image URLs are not fetched; only inline data URLs
					// (JPEG or PNG) are forwarded as bytes.
					if strings.HasPrefix(c.ImageURL.URL, "http") {
						continue
					}
					if strings.HasPrefix(c.ImageURL.URL, "data:image") {
						var mediaType string
						if strings.HasPrefix(c.ImageURL.URL, "data:image/jpeg") {
							mediaType = "image/jpeg"
						}
						if strings.HasPrefix(c.ImageURL.URL, "data:image/png") {
							mediaType = "image/png"
						}
						segments := strings.SplitN(c.ImageURL.URL, ",", 2)
						if len(segments) != 2 {
							continue
						}
						imageBytes, err := base64.StdEncoding.DecodeString(segments[1])
						if err != nil {
							continue
						}
						parts = append(parts, genai.NewPartFromBytes(imageBytes, mediaType))
					}
				}
			}
			if len(parts) > 0 {
				content = append(content, genai.NewContentFromParts(parts, role))
			}
		} else {
			content = append(content, genai.NewContentFromText(msg.Content, role))
		}
	}

	tools := []*genai.Tool{{GoogleSearch: &genai.GoogleSearch{}}}
	response, err := g.Client.Models.GenerateContent(g.Ctx,
		chatReq.Model,
		content,
		&genai.GenerateContentConfig{Tools: tools})
	if err != nil {
		return nil, err
	}

	if g.tokenUsage.Model == "" && response.ModelVersion != "" {
		g.tokenUsage.Model = response.ModelVersion
	}
	if response.UsageMetadata != nil {
		g.tokenUsage.PromptTokens += int(response.UsageMetadata.PromptTokenCount)
		g.tokenUsage.CompletionTokens += int(response.UsageMetadata.CandidatesTokenCount)
		g.tokenUsage.ToolsTokens += int(response.UsageMetadata.ToolUsePromptTokenCount)
		g.tokenUsage.TotalTokens += int(response.UsageMetadata.TotalTokenCount)
	}

	var finishReason openai.FinishReason
	if len(response.Candidates) > 0 {
		finishReason = openai.FinishReason(response.Candidates[0].FinishReason)
	}

	return &llm.ChatResponse{
		Model: response.ModelVersion,
		Choices: []openai.ChatCompletionChoice{
			{
				Message:      openai.ChatCompletionMessage{Content: response.Text(), Role: "assistant"},
				FinishReason: finishReason,
			},
		},
		Usage: openai.Usage{
			PromptTokens:     g.tokenUsage.PromptTokens + g.tokenUsage.ToolsTokens,
			CompletionTokens: g.tokenUsage.CompletionTokens,
			TotalTokens:      g.tokenUsage.TotalTokens,
		},
	}, nil
}
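
// Sketch of the multimodal message shape Chat converts. Field names follow the
// go-openai ChatCompletionMessage layout already used above; the data URL is a
// truncated placeholder:
//
//	msg := openai.ChatCompletionMessage{
//		Role: "user",
//		MultiContent: []openai.ChatMessagePart{
//			{Type: "text", Text: "Describe this image"},
//			{Type: "image_url", ImageURL: &openai.ChatMessageImageURL{
//				URL: "data:image/png;base64,iVBORw0KGgo...",
//			}},
//		},
//	}
//
// Only inline data: URLs (JPEG or PNG) are forwarded as image bytes; plain
// http(s) image URLs are skipped rather than fetched.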

// StreamChat sends a streaming generation request. Each chunk is converted to
// an OpenAI-style stream choice and sent on the returned channel, which is
// closed when the stream ends; mid-stream errors are logged and end the stream.
func (g *Gemini) StreamChat(ctx context.Context, chatReq llm.ChatRequest) (chan *llm.StreamChatResponse, error) {
	var contents []*genai.Content
	for _, msg := range chatReq.Messages {
		// Fold system prompts into the user role, mirroring Chat.
		var role genai.Role
		if msg.Role == "user" || msg.Role == "system" {
			role = genai.RoleUser
		} else {
			role = genai.RoleModel
		}

		if len(msg.MultiContent) > 0 {
			// Collect all parts of a multimodal message into a single Content.
			var parts []*genai.Part
			for _, c := range msg.MultiContent {
				if c.Type == "text" {
					parts = append(parts, genai.NewPartFromText(c.Text))
				}
				if c.Type == "image_url" {
					// Remote image URLs are not fetched; only inline data URLs
					// (JPEG or PNG) are forwarded as bytes.
					if strings.HasPrefix(c.ImageURL.URL, "http") {
						continue
					}
					if strings.HasPrefix(c.ImageURL.URL, "data:image") {
						var mediaType string
						if strings.HasPrefix(c.ImageURL.URL, "data:image/jpeg") {
							mediaType = "image/jpeg"
						}
						if strings.HasPrefix(c.ImageURL.URL, "data:image/png") {
							mediaType = "image/png"
						}
						segments := strings.SplitN(c.ImageURL.URL, ",", 2)
						if len(segments) != 2 {
							continue
						}
						imageBytes, err := base64.StdEncoding.DecodeString(segments[1])
						if err != nil {
							continue
						}
						parts = append(parts, genai.NewPartFromBytes(imageBytes, mediaType))
					}
				}
			}
			if len(parts) > 0 {
				contents = append(contents, genai.NewContentFromParts(parts, role))
			}
		} else {
			contents = append(contents, genai.NewContentFromText(msg.Content, role))
		}
	}

	datachan := make(chan *llm.StreamChatResponse)

	tools := []*genai.Tool{{GoogleSearch: &genai.GoogleSearch{}}}

	go func() {
		defer close(datachan)
		for result, err := range g.Client.Models.GenerateContentStream(g.Ctx, chatReq.Model, contents, &genai.GenerateContentConfig{Tools: tools}) {
			if err != nil {
				// A mid-stream error terminates the stream; the deferred close
				// above signals consumers that no more chunks will arrive.
				fmt.Println(err)
				return
			}
			if result.UsageMetadata != nil {
				g.tokenUsage.PromptTokens += int(result.UsageMetadata.PromptTokenCount)
				g.tokenUsage.CompletionTokens += int(result.UsageMetadata.CandidatesTokenCount)
				g.tokenUsage.ToolsTokens += int(result.UsageMetadata.ToolUsePromptTokenCount)
				g.tokenUsage.TotalTokens += int(result.UsageMetadata.TotalTokenCount)
			}

			var finishReason openai.FinishReason
			if len(result.Candidates) > 0 {
				finishReason = openai.FinishReason(result.Candidates[0].FinishReason)
			}

			datachan <- &llm.StreamChatResponse{
				Model: result.ModelVersion,
				Choices: []openai.ChatCompletionStreamChoice{
					{
						Delta: openai.ChatCompletionStreamChoiceDelta{
							Role:    "assistant",
							Content: result.Text(),
						},
						FinishReason: finishReason,
					},
				},
				Usage: &openai.Usage{
					PromptTokens:     g.tokenUsage.PromptTokens + g.tokenUsage.ToolsTokens,
					CompletionTokens: g.tokenUsage.CompletionTokens,
					TotalTokens:      g.tokenUsage.TotalTokens,
				},
			}
		}
	}()

	return datachan, nil
}
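
// A minimal consumption sketch for StreamChat (request construction as in the
// NewGemini sketch above):
//
//	ch, err := g.StreamChat(context.Background(), req)
//	if err != nil {
//		// handle error
//	}
//	for chunk := range ch {
//		if len(chunk.Choices) > 0 {
//			fmt.Print(chunk.Choices[0].Delta.Content)
//		}
//	}
//
// The producer goroutine closes the channel when the stream ends or fails, so
// a plain range loop drains it safely.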

// GetTokenUsage returns the token usage accumulated across all requests made
// through this provider instance.
func (g *Gemini) GetTokenUsage() *llm.TokenUsage {
	return g.tokenUsage
}