reface to openteam
This commit is contained in:
228
llm/google/v2/chat.go
Normal file
228
llm/google/v2/chat.go
Normal file
@@ -0,0 +1,228 @@
|
||||
// https://github.com/google-gemini/api-examples/
|
||||
|
||||
package google
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"opencatd-open/internal/model"
|
||||
"opencatd-open/llm"
|
||||
"os"
|
||||
"strings"
|
||||
|
||||
"github.com/sashabaranov/go-openai"
|
||||
|
||||
"google.golang.org/genai"
|
||||
)
|
||||
|
||||
type Gemini struct {
|
||||
Ctx context.Context
|
||||
Client *genai.Client
|
||||
ApiKey *model.ApiKey
|
||||
tokenUsage *llm.TokenUsage
|
||||
|
||||
Done chan struct{}
|
||||
}
|
||||
|
||||
func NewGemini(ctx context.Context, apiKey *model.ApiKey) (*Gemini, error) {
|
||||
hc := http.DefaultClient
|
||||
if os.Getenv("LOCAL_PROXY") != "" {
|
||||
proxyUrl, err := url.Parse(os.Getenv("LOCAL_PROXY"))
|
||||
if err == nil {
|
||||
hc = &http.Client{Transport: &http.Transport{Proxy: http.ProxyURL(proxyUrl)}}
|
||||
}
|
||||
}
|
||||
client, err := genai.NewClient(ctx, &genai.ClientConfig{
|
||||
APIKey: *apiKey.ApiKey,
|
||||
Backend: genai.BackendGeminiAPI,
|
||||
HTTPClient: hc,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &Gemini{
|
||||
Ctx: context.Background(),
|
||||
Client: client,
|
||||
ApiKey: apiKey,
|
||||
tokenUsage: &llm.TokenUsage{},
|
||||
Done: make(chan struct{}),
|
||||
}, nil
|
||||
|
||||
}
|
||||
|
||||
func (g *Gemini) Chat(ctx context.Context, chatReq llm.ChatRequest) (*llm.ChatResponse, error) {
|
||||
var content []*genai.Content
|
||||
if len(chatReq.Messages) > 0 {
|
||||
for _, msg := range chatReq.Messages {
|
||||
var role genai.Role
|
||||
if msg.Role == "user" || msg.Role == "system" {
|
||||
role = genai.RoleUser
|
||||
} else {
|
||||
role = genai.RoleModel
|
||||
}
|
||||
|
||||
if len(msg.MultiContent) > 0 {
|
||||
for _, c := range msg.MultiContent {
|
||||
var parts []*genai.Part
|
||||
|
||||
if c.Type == "text" {
|
||||
parts = append(parts, genai.NewPartFromText(c.Text))
|
||||
}
|
||||
if c.Type == "image_url" {
|
||||
if strings.HasPrefix(c.ImageURL.URL, "http") {
|
||||
continue
|
||||
}
|
||||
if strings.HasPrefix(c.ImageURL.URL, "data:image") {
|
||||
var mediaType string
|
||||
if strings.HasPrefix(c.ImageURL.URL, "data:image/jpeg") {
|
||||
mediaType = "image/jpeg"
|
||||
}
|
||||
if strings.HasPrefix(c.ImageURL.URL, "data:image/png") {
|
||||
mediaType = "image/png"
|
||||
}
|
||||
imageString := strings.Split(c.ImageURL.URL, ",")[1]
|
||||
imageBytes, _ := base64.StdEncoding.DecodeString(imageString)
|
||||
|
||||
parts = append(parts, genai.NewPartFromBytes(imageBytes, mediaType))
|
||||
}
|
||||
|
||||
}
|
||||
content = append(content, genai.NewContentFromParts(parts, role))
|
||||
}
|
||||
} else {
|
||||
content = append(content, genai.NewContentFromText(msg.Content, role))
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
tools := []*genai.Tool{{GoogleSearch: &genai.GoogleSearch{}}}
|
||||
response, err := g.Client.Models.GenerateContent(g.Ctx,
|
||||
chatReq.Model,
|
||||
content,
|
||||
&genai.GenerateContentConfig{Tools: tools})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if response.UsageMetadata != nil {
|
||||
g.tokenUsage.PromptTokens += int(response.UsageMetadata.PromptTokenCount)
|
||||
g.tokenUsage.CompletionTokens += int(response.UsageMetadata.CandidatesTokenCount)
|
||||
g.tokenUsage.ToolsTokens += int(response.UsageMetadata.ToolUsePromptTokenCount)
|
||||
g.tokenUsage.TotalTokens += int(response.UsageMetadata.TotalTokenCount)
|
||||
}
|
||||
|
||||
// var text string
|
||||
// if response.Candidates != nil && response.Candidates[0].Content != nil {
|
||||
// for _, part := range response.Candidates[0].Content.Parts {
|
||||
// text += part.Text
|
||||
// }
|
||||
// }
|
||||
|
||||
return &llm.ChatResponse{
|
||||
Model: response.ModelVersion,
|
||||
Choices: []openai.ChatCompletionChoice{
|
||||
{
|
||||
Message: openai.ChatCompletionMessage{Content: response.Text(), Role: "assistant"},
|
||||
FinishReason: openai.FinishReason(response.Candidates[0].FinishReason),
|
||||
},
|
||||
},
|
||||
Usage: openai.Usage{PromptTokens: g.tokenUsage.PromptTokens + g.tokenUsage.ToolsTokens, CompletionTokens: g.tokenUsage.CompletionTokens, TotalTokens: g.tokenUsage.TotalTokens},
|
||||
}, nil
|
||||
|
||||
}
|
||||
|
||||
func (g *Gemini) StreamChat(ctx context.Context, chatReq llm.ChatRequest) (chan *llm.StreamChatResponse, error) {
|
||||
var contents []*genai.Content
|
||||
if len(chatReq.Messages) > 0 {
|
||||
for _, msg := range chatReq.Messages {
|
||||
var role genai.Role
|
||||
if msg.Role == "user" {
|
||||
role = genai.RoleUser
|
||||
} else {
|
||||
role = genai.RoleModel
|
||||
}
|
||||
|
||||
if len(msg.MultiContent) > 0 {
|
||||
for _, c := range msg.MultiContent {
|
||||
var parts []*genai.Part
|
||||
|
||||
if c.Type == "text" {
|
||||
parts = append(parts, genai.NewPartFromText(c.Text))
|
||||
}
|
||||
if c.Type == "image_url" {
|
||||
if strings.HasPrefix(c.ImageURL.URL, "http") {
|
||||
continue
|
||||
}
|
||||
if strings.HasPrefix(c.ImageURL.URL, "data:image") {
|
||||
var mediaType string
|
||||
if strings.HasPrefix(c.ImageURL.URL, "data:image/jpeg") {
|
||||
mediaType = "image/jpeg"
|
||||
}
|
||||
if strings.HasPrefix(c.ImageURL.URL, "data:image/png") {
|
||||
mediaType = "image/png"
|
||||
}
|
||||
imageString := strings.Split(c.ImageURL.URL, ",")[1]
|
||||
imageBytes, _ := base64.StdEncoding.DecodeString(imageString)
|
||||
|
||||
parts = append(parts, genai.NewPartFromBytes(imageBytes, mediaType))
|
||||
}
|
||||
|
||||
}
|
||||
contents = append(contents, genai.NewContentFromParts(parts, role))
|
||||
}
|
||||
} else {
|
||||
contents = append(contents, genai.NewContentFromText(msg.Content, role))
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
datachan := make(chan *llm.StreamChatResponse)
|
||||
var generr error
|
||||
|
||||
tools := []*genai.Tool{{GoogleSearch: &genai.GoogleSearch{}}}
|
||||
|
||||
go func() {
|
||||
defer close(datachan)
|
||||
for result, err := range g.Client.Models.GenerateContentStream(g.Ctx, chatReq.Model, contents, &genai.GenerateContentConfig{Tools: tools}) {
|
||||
if err != nil {
|
||||
fmt.Println(err)
|
||||
generr = err
|
||||
return
|
||||
}
|
||||
if result.UsageMetadata != nil {
|
||||
g.tokenUsage.PromptTokens += int(result.UsageMetadata.PromptTokenCount)
|
||||
g.tokenUsage.CompletionTokens += int(result.UsageMetadata.CandidatesTokenCount)
|
||||
g.tokenUsage.ToolsTokens += int(result.UsageMetadata.ToolUsePromptTokenCount)
|
||||
g.tokenUsage.TotalTokens += int(result.UsageMetadata.TotalTokenCount)
|
||||
}
|
||||
|
||||
datachan <- &llm.StreamChatResponse{
|
||||
Model: result.ModelVersion,
|
||||
Choices: []openai.ChatCompletionStreamChoice{
|
||||
{
|
||||
Delta: openai.ChatCompletionStreamChoiceDelta{
|
||||
Role: "assistant",
|
||||
// Content: result.Candidates[0].Content.Parts[0].Text,
|
||||
Content: result.Text(),
|
||||
},
|
||||
FinishReason: openai.FinishReason(result.Candidates[0].FinishReason),
|
||||
},
|
||||
},
|
||||
Usage: &openai.Usage{PromptTokens: g.tokenUsage.PromptTokens + g.tokenUsage.ToolsTokens, CompletionTokens: g.tokenUsage.CompletionTokens, TotalTokens: g.tokenUsage.TotalTokens},
|
||||
}
|
||||
|
||||
}
|
||||
}()
|
||||
|
||||
return datachan, generr
|
||||
}
|
||||
|
||||
func (g *Gemini) GetTokenUsage() *llm.TokenUsage {
|
||||
return g.tokenUsage
|
||||
}
|
||||
Reference in New Issue
Block a user