opencatd-open/llm/openai_compatible/chat.go

package openai_compatible

import (
	"bufio"
	"bytes"
	"context"
	"encoding/json"
	"fmt"
	"net/http"
	"net/url"
	"os"
	"strings"

	"opencatd-open/internal/model"
	"opencatd-open/internal/utils"
	"opencatd-open/llm"

	"github.com/sashabaranov/go-openai"
)
const (
	// https://learn.microsoft.com/en-us/azure/ai-services/openai/api-version-deprecation#latest-preview-api-releases
	AzureApiVersion                 = "2024-10-21"
	defaultOpenAICompatibleEndpoint = "https://api.openai.com/v1/chat/completions"
	Github_Marketplace              = "https://models.inference.ai.azure.com/chat/completions"
)
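
// OpenAICompatible is a client for OpenAI-compatible chat-completions APIs
// (OpenAI itself, Azure OpenAI, and the GitHub Models marketplace) that
// accumulates token usage across requests.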
type OpenAICompatible struct {
	Client     *http.Client
	ApiKey     *model.ApiKey
	tokenUsage *llm.TokenUsage
	Params     map[string]interface{}
	Done       chan struct{}
}
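
// NewOpenAICompatible builds a client for the given key. It honors the
// LOCAL_PROXY environment variable and decodes any extra JSON parameters
// stored on the key so they can be merged into each request body.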
func NewOpenAICompatible(apikey *model.ApiKey) (*OpenAICompatible, error) {
	// Use a dedicated client so setting a proxy does not mutate
	// http.DefaultClient for the rest of the process.
	hc := &http.Client{}
	if os.Getenv("LOCAL_PROXY") != "" {
		proxyUrl, err := url.Parse(os.Getenv("LOCAL_PROXY"))
		if err == nil {
			hc.Transport = &http.Transport{
				Proxy: http.ProxyURL(proxyUrl),
			}
		}
	}
	oc := OpenAICompatible{
		ApiKey:     apikey,
		Client:     hc,
		tokenUsage: &llm.TokenUsage{},
		Done:       make(chan struct{}),
	}
	if apikey.Parameters != nil {
		var params map[string]interface{}
		if err := json.Unmarshal([]byte(*apikey.Parameters), &params); err != nil {
			return nil, err
		}
		oc.Params = params
	}
	return &oc, nil
}

// formatModel strips dots from the model name, since Azure deployment
// names do not allow "." (e.g. "gpt-4.1" maps to deployment "gpt-41").
func formatModel(in string) string {
	return strings.ReplaceAll(in, ".", "")
}

// buildRequest assembles the HTTP request for the configured provider:
// Azure uses per-deployment URLs with an api-version query parameter,
// GitHub Models uses the fixed marketplace endpoint, and everything else
// falls back to the key's endpoint or the public OpenAI API.
func (o *OpenAICompatible) buildRequest(ctx context.Context, modelName string, body *bytes.Buffer) (*http.Request, error) {
	var req *http.Request
	var err error
	switch *o.ApiKey.ApiType {
	case "azure":
		var buildurl string
		if o.ApiKey.Endpoint != nil && *o.ApiKey.Endpoint != "" {
			endpoint := strings.TrimSuffix(*o.ApiKey.Endpoint, "/")
			buildurl = fmt.Sprintf("%s/openai/deployments/%s/chat/completions?api-version=%s", endpoint, formatModel(modelName), AzureApiVersion)
		} else {
			buildurl = fmt.Sprintf("https://%s.openai.azure.com/openai/deployments/%s/chat/completions?api-version=%s", *o.ApiKey.ResourceNmae, formatModel(modelName), AzureApiVersion)
		}
		if req, err = http.NewRequestWithContext(ctx, http.MethodPost, buildurl, body); err != nil {
			return nil, err
		}
		req.Header.Set("api-key", *o.ApiKey.ApiKey)
	case "github":
		if req, err = http.NewRequestWithContext(ctx, http.MethodPost, Github_Marketplace, body); err != nil {
			return nil, err
		}
	default:
		endpoint := defaultOpenAICompatibleEndpoint
		if o.ApiKey.Endpoint != nil && *o.ApiKey.Endpoint != "" {
			endpoint = *o.ApiKey.Endpoint
		}
		if req, err = http.NewRequestWithContext(ctx, http.MethodPost, endpoint, body); err != nil {
			return nil, err
		}
	}
	req.Header.Set("Authorization", "Bearer "+*o.ApiKey.ApiKey)
	req.Header.Set("Content-Type", "application/json")
	req.Header.Set("Accept-Encoding", "identity")
	return req, nil
}

// Chat sends a single non-streaming chat-completions request and records
// the token usage reported in the response.
func (o *OpenAICompatible) Chat(ctx context.Context, chatReq llm.ChatRequest) (*llm.ChatResponse, error) {
	chatReq.Stream = false
	dst, err := utils.StructToMap(chatReq)
	if err != nil {
		return nil, err
	}
	if len(o.Params) > 0 {
		dst = utils.MergeJSONObjects(dst, o.Params)
	}
	var reqBody bytes.Buffer
	if err := json.NewEncoder(&reqBody).Encode(dst); err != nil {
		return nil, err
	}
	req, err := o.buildRequest(ctx, chatReq.Model, &reqBody)
	if err != nil {
		return nil, err
	}
	resp, err := o.Client.Do(req)
	if err != nil {
		return nil, err
	}
	defer resp.Body.Close()
	var chatResp llm.ChatResponse
	if err := json.NewDecoder(resp.Body).Decode(&chatResp); err != nil {
		return nil, err
	}
	if o.tokenUsage.Model == "" && chatResp.Model != "" {
		o.tokenUsage.Model = chatResp.Model
	}
	o.tokenUsage.PromptTokens = chatResp.Usage.PromptTokens
	o.tokenUsage.CompletionTokens = chatResp.Usage.CompletionTokens
	o.tokenUsage.TotalTokens = chatResp.Usage.TotalTokens
	return &chatResp, nil
}
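
// StreamChat issues a streaming chat-completions request and returns a
// channel of parsed SSE chunks. The channel is closed when the upstream
// sends [DONE], the stream ends, or ctx is cancelled; token usage from
// usage-bearing chunks is accumulated into tokenUsage.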
func (o *OpenAICompatible) StreamChat(ctx context.Context, chatReq llm.ChatRequest) (chan *llm.StreamChatResponse, error) {
	chatReq.Stream = true
	// Ask the upstream to append a final chunk carrying token usage.
	chatReq.StreamOptions = &openai.StreamOptions{IncludeUsage: true}
	dst, err := utils.StructToMap(chatReq)
	if err != nil {
		return nil, err
	}
	if len(o.Params) > 0 {
		dst = utils.MergeJSONObjects(dst, o.Params)
	}
	var reqBody bytes.Buffer
	if err := json.NewEncoder(&reqBody).Encode(dst); err != nil {
		return nil, err
	}
	req, err := o.buildRequest(ctx, chatReq.Model, &reqBody)
	if err != nil {
		return nil, err
	}
	resp, err := o.Client.Do(req)
	if err != nil {
		return nil, err
	}
	output := make(chan *llm.StreamChatResponse)
	// Parse the SSE stream line by line.
	scanner := bufio.NewScanner(resp.Body)
	// Allow data lines larger than bufio.Scanner's 64 KiB default.
	scanner.Buffer(make([]byte, 0, 64*1024), 1024*1024)
	go func() {
		defer resp.Body.Close()
		defer close(output)
		for scanner.Scan() {
			line := scanner.Bytes()
			if !bytes.HasPrefix(line, []byte("data: ")) {
				continue
			}
			if bytes.HasPrefix(line, []byte("data: [DONE]")) {
				break
			}
			line = bytes.TrimSpace(bytes.TrimPrefix(line, []byte("data: ")))
			var streamResp llm.StreamChatResponse
			if err := json.Unmarshal(line, &streamResp); err != nil {
				continue
			}
			if streamResp.Usage != nil {
				o.tokenUsage.PromptTokens += streamResp.Usage.PromptTokens
				o.tokenUsage.CompletionTokens += streamResp.Usage.CompletionTokens
				o.tokenUsage.TotalTokens += streamResp.Usage.TotalTokens
			}
			// Stop sending if the caller has gone away.
			select {
			case <-ctx.Done():
				return
			case output <- &streamResp:
			}
		}
		fmt.Println("llm usage:", o.tokenUsage.Model, o.tokenUsage.PromptTokens, o.tokenUsage.CompletionTokens, o.tokenUsage.TotalTokens)
	}()
	return output, nil
}
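
// GetTokenUsage reports the usage accumulated by Chat and StreamChat calls.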
func (o *OpenAICompatible) GetTokenUsage() *llm.TokenUsage {
	return o.tokenUsage
}