package openai_compatible

import (
	"bufio"
	"bytes"
	"context"
	"encoding/json"
	"fmt"
	"net/http"
	"net/url"
	"os"
	"strings"

	"github.com/sashabaranov/go-openai"

	"opencatd-open/internal/model"
	"opencatd-open/internal/utils"
	"opencatd-open/llm"
)

// AzureApiVersion pins the Azure OpenAI data-plane API version.
// https://learn.microsoft.com/en-us/azure/ai-services/openai/api-version-deprecation#latest-preview-api-releases
const AzureApiVersion = "2024-10-21"

// defaultOpenAICompatibleEndpoint is used when a key has no explicit endpoint.
const defaultOpenAICompatibleEndpoint = "https://api.openai.com/v1/chat/completions"

// Github_Marketplace is the chat-completions endpoint for GitHub Models.
const Github_Marketplace = "https://models.inference.ai.azure.com/chat/completions"

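// OpenAICompatible is a chat client for any provider that speaks the OpenAI
// chat-completions wire format (OpenAI itself, Azure OpenAI, GitHub Models,
// or a self-hosted gateway).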
type OpenAICompatible struct {
	Client     *http.Client
	ApiKey     *model.ApiKey
	tokenUsage *llm.TokenUsage
	Params     map[string]interface{} // extra request fields from ApiKey.Parameters, merged into every payload
	Done       chan struct{}
}

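// NewOpenAICompatible builds a client for the given key. It honors the
// LOCAL_PROXY environment variable for outbound traffic and decodes any
// per-key JSON parameters into Params.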
func NewOpenAICompatible(apikey *model.ApiKey) (*OpenAICompatible, error) {
	// Use a dedicated client so the proxy transport never mutates http.DefaultClient.
	hc := &http.Client{}
	if os.Getenv("LOCAL_PROXY") != "" {
		proxyUrl, err := url.Parse(os.Getenv("LOCAL_PROXY"))
		if err == nil {
			hc.Transport = &http.Transport{
				Proxy: http.ProxyURL(proxyUrl),
			}
		}
	}

	oc := OpenAICompatible{
		ApiKey:     apikey,
		Client:     hc,
		tokenUsage: &llm.TokenUsage{},
		Done:       make(chan struct{}),
	}

	if apikey.Parameters != nil {
		var params map[string]interface{}
		if err := json.Unmarshal([]byte(*apikey.Parameters), &params); err != nil {
			return nil, err
		}
		oc.Params = params
	}

	return &oc, nil
}

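// Chat sends a single non-streaming chat completion request and records the
// returned token usage on the client.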
func (o *OpenAICompatible) Chat(ctx context.Context, chatReq llm.ChatRequest) (*llm.ChatResponse, error) {
	chatReq.Stream = false
	dst, err := utils.StructToMap(chatReq)
	if err != nil {
		return nil, err
	}
	if len(o.Params) > 0 {
		dst = utils.MergeJSONObjects(dst, o.Params)
	}

	var reqBody bytes.Buffer
	if err := json.NewEncoder(&reqBody).Encode(dst); err != nil {
		return nil, err
	}

	apiType := ""
	if o.ApiKey.ApiType != nil {
		apiType = *o.ApiKey.ApiType
	}

	var req *http.Request
	switch apiType {
	case "azure":
		// Azure deployment names cannot contain dots, e.g. "gpt-3.5-turbo" -> "gpt-35-turbo".
		formatModel := func(in string) string {
			return strings.ReplaceAll(in, ".", "")
		}
		var buildurl string
		if o.ApiKey.Endpoint != nil && *o.ApiKey.Endpoint != "" {
			// Trim a trailing slash so the joined URL has no double slash.
			o.ApiKey.Endpoint = utils.ToPtr(strings.TrimSuffix(*o.ApiKey.Endpoint, "/"))
			buildurl = fmt.Sprintf("%s/openai/deployments/%s/chat/completions?api-version=%s", *o.ApiKey.Endpoint, formatModel(chatReq.Model), AzureApiVersion)
		} else {
			buildurl = fmt.Sprintf("https://%s.openai.azure.com/openai/deployments/%s/chat/completions?api-version=%s", *o.ApiKey.ResourceNmae, formatModel(chatReq.Model), AzureApiVersion)
		}
		req, _ = http.NewRequestWithContext(ctx, http.MethodPost, buildurl, &reqBody)
		req.Header.Set("api-key", *o.ApiKey.ApiKey)
	case "github":
		req, _ = http.NewRequestWithContext(ctx, http.MethodPost, Github_Marketplace, &reqBody)
	default:
		if o.ApiKey.Endpoint == nil || *o.ApiKey.Endpoint == "" {
			req, _ = http.NewRequestWithContext(ctx, http.MethodPost, defaultOpenAICompatibleEndpoint, &reqBody)
		} else {
			req, _ = http.NewRequestWithContext(ctx, http.MethodPost, *o.ApiKey.Endpoint, &reqBody)
		}
	}
	req.Header.Set("Authorization", "Bearer "+*o.ApiKey.ApiKey)
	req.Header.Set("Content-Type", "application/json")
	req.Header.Set("Accept-Encoding", "identity")

	resp, err := o.Client.Do(req)
	if err != nil {
		return nil, err
	}
	defer resp.Body.Close()

	var chatResp llm.ChatResponse
	if err := json.NewDecoder(resp.Body).Decode(&chatResp); err != nil {
		return nil, err
	}

	if o.tokenUsage.Model == "" && chatResp.Model != "" {
		o.tokenUsage.Model = chatResp.Model
	}
	o.tokenUsage.PromptTokens = chatResp.Usage.PromptTokens
	o.tokenUsage.CompletionTokens = chatResp.Usage.CompletionTokens
	o.tokenUsage.TotalTokens = chatResp.Usage.TotalTokens
	return &chatResp, nil
}

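// StreamChat sends a streaming chat completion request and returns a channel
// of parsed SSE chunks; the channel is closed when the stream finishes or ctx
// is canceled. Token usage is accumulated from the usage chunk requested via
// stream_options.include_usage.
//
// A minimal consumption sketch (construction of llm.ChatRequest beyond Model
// is elided):
//
//	ch, err := client.StreamChat(ctx, llm.ChatRequest{Model: "gpt-4o"})
//	if err != nil {
//		return err
//	}
//	for chunk := range ch {
//		_ = chunk // forward delta content to the caller
//	}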
func (o *OpenAICompatible) StreamChat(ctx context.Context, chatReq llm.ChatRequest) (chan *llm.StreamChatResponse, error) {
	chatReq.Stream = true
	// Ask the server to append a final SSE chunk that carries token usage.
	chatReq.StreamOptions = &openai.StreamOptions{IncludeUsage: true}
	dst, err := utils.StructToMap(chatReq)
	if err != nil {
		return nil, err
	}
	if len(o.Params) > 0 {
		dst = utils.MergeJSONObjects(dst, o.Params)
	}

	var reqBody bytes.Buffer
	if err := json.NewEncoder(&reqBody).Encode(dst); err != nil {
		return nil, err
	}

	apiType := ""
	if o.ApiKey.ApiType != nil {
		apiType = *o.ApiKey.ApiType
	}

	var req *http.Request
	switch apiType {
	case "azure":
		// Azure deployment names cannot contain dots, e.g. "gpt-3.5-turbo" -> "gpt-35-turbo".
		formatModel := func(in string) string {
			return strings.ReplaceAll(in, ".", "")
		}
		var buildurl string
		if o.ApiKey.Endpoint != nil && *o.ApiKey.Endpoint != "" {
			o.ApiKey.Endpoint = utils.ToPtr(strings.TrimSuffix(*o.ApiKey.Endpoint, "/"))
			buildurl = fmt.Sprintf("%s/openai/deployments/%s/chat/completions?api-version=%s", *o.ApiKey.Endpoint, formatModel(chatReq.Model), AzureApiVersion)
		} else {
			buildurl = fmt.Sprintf("https://%s.openai.azure.com/openai/deployments/%s/chat/completions?api-version=%s", *o.ApiKey.ResourceNmae, formatModel(chatReq.Model), AzureApiVersion)
		}
		req, _ = http.NewRequestWithContext(ctx, http.MethodPost, buildurl, &reqBody)
		req.Header.Set("api-key", *o.ApiKey.ApiKey)
	case "github":
		req, _ = http.NewRequestWithContext(ctx, http.MethodPost, Github_Marketplace, &reqBody)
	default:
		if o.ApiKey.Endpoint == nil || *o.ApiKey.Endpoint == "" {
			req, _ = http.NewRequestWithContext(ctx, http.MethodPost, defaultOpenAICompatibleEndpoint, &reqBody)
		} else {
			req, _ = http.NewRequestWithContext(ctx, http.MethodPost, *o.ApiKey.Endpoint, &reqBody)
		}
	}
	req.Header.Set("Authorization", "Bearer "+*o.ApiKey.ApiKey)
	req.Header.Set("Content-Type", "application/json")
	req.Header.Set("Accept-Encoding", "identity")

	resp, err := o.Client.Do(req)
	if err != nil {
		return nil, err
	}

	output := make(chan *llm.StreamChatResponse)

	// Scan the streaming response line by line (SSE framing).
	scanner := bufio.NewScanner(resp.Body)
	// Raise the scanner limit beyond the 64 KiB default so long chunks are not dropped.
	scanner.Buffer(make([]byte, 0, 64*1024), 1024*1024)

	go func() {
		defer resp.Body.Close()
		defer close(output)

		for scanner.Scan() {
			line := scanner.Bytes()
			// fmt.Println(string(line))
			if len(line) == 0 || !bytes.HasPrefix(line, []byte("data: ")) {
				continue
			}
			if bytes.HasPrefix(line, []byte("data: [DONE]")) {
				break
			}
			line = bytes.TrimSpace(bytes.TrimPrefix(line, []byte("data: ")))
			var streamResp llm.StreamChatResponse
			if err := json.Unmarshal(line, &streamResp); err != nil {
				continue
			}
			if streamResp.Usage != nil {
				o.tokenUsage.PromptTokens += streamResp.Usage.PromptTokens
				o.tokenUsage.CompletionTokens += streamResp.Usage.CompletionTokens
				o.tokenUsage.TotalTokens += streamResp.Usage.TotalTokens
			}
			// Respect caller cancellation instead of blocking forever on a gone consumer.
			select {
			case <-ctx.Done():
				return
			case output <- &streamResp:
			}
		}
		fmt.Println("llm usage:", o.tokenUsage.Model, o.tokenUsage.PromptTokens, o.tokenUsage.CompletionTokens, o.tokenUsage.TotalTokens)
	}()
	return output, nil
}

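// GetTokenUsage returns the token counts recorded by the most recent Chat or
// StreamChat call.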
func (o *OpenAICompatible) GetTokenUsage() *llm.TokenUsage {
	return o.tokenUsage
}