reface to openteam
llm/openai_compatible/chat.go (new file, 221 lines)
@@ -0,0 +1,221 @@
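// Package openai_compatible proxies chat completion requests to
// OpenAI-compatible endpoints (api.openai.com, Azure OpenAI deployments,
// and the GitHub Models marketplace) and tracks token usage.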
package openai_compatible

import (
	"bufio"
	"bytes"
	"context"
	"encoding/json"
	"fmt"
	"io"
	"net/http"
	"net/url"
	"opencatd-open/internal/model"
	"opencatd-open/internal/utils"
	"opencatd-open/llm"
	"os"
	"strings"
)

// https://learn.microsoft.com/en-us/azure/ai-services/openai/api-version-deprecation#latest-preview-api-releases
const AzureApiVersion = "2024-10-21"
const defaultOpenAICompatibleEndpoint = "https://api.openai.com/v1/chat/completions"
const Github_Marketplace = "https://models.inference.ai.azure.com/chat/completions"

type OpenAICompatible struct {
	Client     *http.Client
	ApiKey     *model.ApiKey
	tokenUsage *llm.TokenUsage
	Params     map[string]interface{}
	Done       chan struct{}
}

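// NewOpenAICompatible builds a client for the given API key. If the
// LOCAL_PROXY environment variable is set, outgoing requests are routed
// through that proxy; any extra JSON parameters stored on the key are
// parsed once here and later merged into every request body.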
func NewOpenAICompatible(apikey *model.ApiKey) (*OpenAICompatible, error) {
	// Use a dedicated client instead of mutating http.DefaultClient, so the
	// proxy transport does not leak into unrelated requests.
	hc := &http.Client{}
	if os.Getenv("LOCAL_PROXY") != "" {
		proxyUrl, err := url.Parse(os.Getenv("LOCAL_PROXY"))
		if err == nil {
			tr := http.Transport{
				Proxy: http.ProxyURL(proxyUrl),
			}
			hc.Transport = &tr
		}
	}

	oc := OpenAICompatible{
		ApiKey:     apikey,
		Client:     hc,
		tokenUsage: &llm.TokenUsage{},
		Done:       make(chan struct{}),
	}

	if apikey.Parameters != nil {
		var params map[string]interface{}
		err := json.Unmarshal([]byte(*apikey.Parameters), &params)
		if err != nil {
			return nil, err
		}
		oc.Params = params
	}

	return &oc, nil
}

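// Chat performs a single, non-streaming completion request and records the
// token usage reported by the upstream API.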
func (o *OpenAICompatible) Chat(ctx context.Context, chatReq llm.ChatRequest) (*llm.ChatResponse, error) {
	chatReq.Stream = false
	dst, err := utils.StructToMap(chatReq)
	if err != nil {
		return nil, err
	}
	// Per-key parameters extend or override the request body.
	if len(o.Params) > 0 {
		dst = utils.MergeJSONObjects(dst, o.Params)
	}

	var reqBody bytes.Buffer
	if err := json.NewEncoder(&reqBody).Encode(dst); err != nil {
		return nil, err
	}

	var req *http.Request
	switch *o.ApiKey.ApiType {
	case "azure":
		// Azure deployment names conventionally drop dots from the model
		// name (e.g. "gpt-3.5-turbo" -> "gpt-35-turbo").
		formatModel := func(in string) string {
			if strings.Contains(in, ".") {
				return strings.ReplaceAll(in, ".", "")
			}
			return in
		}
		var buildurl string
		if *o.ApiKey.Endpoint != "" {
			buildurl = fmt.Sprintf("%s/openai/deployments/%s/chat/completions?api-version=%s", *o.ApiKey.Endpoint, formatModel(chatReq.Model), AzureApiVersion)
		} else {
			buildurl = fmt.Sprintf("https://%s.openai.azure.com/openai/deployments/%s/chat/completions?api-version=%s", *o.ApiKey.ResourceNmae, formatModel(chatReq.Model), AzureApiVersion)
		}
		req, _ = http.NewRequest(http.MethodPost, buildurl, &reqBody)
		req.Header.Set("api-key", *o.ApiKey.ApiKey)
	case "github":
		req, _ = http.NewRequest(http.MethodPost, Github_Marketplace, &reqBody)
	default:
		if o.ApiKey.Endpoint == nil || *o.ApiKey.Endpoint == "" {
			req, _ = http.NewRequest(http.MethodPost, defaultOpenAICompatibleEndpoint, &reqBody)
		} else {
			req, _ = http.NewRequest(http.MethodPost, *o.ApiKey.Endpoint, &reqBody)
		}
	}
	req.Header.Set("Authorization", "Bearer "+*o.ApiKey.ApiKey)
	req.Header.Set("Content-Type", "application/json")
	req.Header.Set("Accept-Encoding", "identity")

	resp, err := o.Client.Do(req)
	if err != nil {
		return nil, err
	}
	defer resp.Body.Close()

	var chatResp llm.ChatResponse
	if err := json.NewDecoder(resp.Body).Decode(&chatResp); err != nil {
		return nil, err
	}

	o.tokenUsage.PromptTokens = chatResp.Usage.PromptTokens
	o.tokenUsage.CompletionTokens = chatResp.Usage.CompletionTokens
	o.tokenUsage.TotalTokens = chatResp.Usage.TotalTokens
	return &chatResp, nil
}

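// StreamChat sends a streaming completion request and forwards each SSE
// "data:" chunk on the returned channel; the channel is closed when the
// stream ends or the [DONE] sentinel arrives. Token usage is accumulated
// from any usage objects present in the stream.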
func (o *OpenAICompatible) StreamChat(ctx context.Context, chatReq llm.ChatRequest) (chan *llm.StreamChatResponse, error) {
	chatReq.Stream = true
	dst, err := utils.StructToMap(chatReq)
	if err != nil {
		return nil, err
	}
	if len(o.Params) > 0 {
		dst = utils.MergeJSONObjects(dst, o.Params)
	}

	var reqBody bytes.Buffer
	if err := json.NewEncoder(&reqBody).Encode(dst); err != nil {
		return nil, err
	}

	var req *http.Request
	switch *o.ApiKey.ApiType {
	case "azure":
		formatModel := func(in string) string {
			if strings.Contains(in, ".") {
				return strings.ReplaceAll(in, ".", "")
			}
			return in
		}
		var buildurl string
		if *o.ApiKey.Endpoint != "" {
			buildurl = fmt.Sprintf("%s/openai/deployments/%s/chat/completions?api-version=%s", *o.ApiKey.Endpoint, formatModel(chatReq.Model), AzureApiVersion)
		} else {
			buildurl = fmt.Sprintf("https://%s.openai.azure.com/openai/deployments/%s/chat/completions?api-version=%s", *o.ApiKey.ResourceNmae, formatModel(chatReq.Model), AzureApiVersion)
		}
		req, _ = http.NewRequest(http.MethodPost, buildurl, &reqBody)
		req.Header.Set("api-key", *o.ApiKey.ApiKey)
	case "github":
		req, _ = http.NewRequest(http.MethodPost, Github_Marketplace, &reqBody)
	default:
		if o.ApiKey.Endpoint == nil || *o.ApiKey.Endpoint == "" {
			req, _ = http.NewRequest(http.MethodPost, defaultOpenAICompatibleEndpoint, &reqBody)
		} else {
			req, _ = http.NewRequest(http.MethodPost, *o.ApiKey.Endpoint, &reqBody)
		}
	}
	req.Header.Set("Authorization", "Bearer "+*o.ApiKey.ApiKey)
	req.Header.Set("Content-Type", "application/json")
	req.Header.Set("Accept-Encoding", "identity")

	resp, err := o.Client.Do(req)
	if err != nil {
		return nil, err
	}

	output := make(chan *llm.StreamChatResponse)

	b := new(bytes.Buffer)
	teeReader := io.TeeReader(resp.Body, b)
	// Parse the streamed (SSE) response line by line.
	scanner := bufio.NewScanner(teeReader)

	go func() {
		defer resp.Body.Close()
		defer close(output)

		for scanner.Scan() {
			line := scanner.Bytes()
			var streamResp llm.StreamChatResponse
			if len(line) > 0 {
				// fmt.Println(string(line))
				if bytes.HasPrefix(line, []byte("data: ")) {
					// "data: [DONE]" marks the end of the stream.
					if bytes.HasPrefix(line, []byte("data: [DONE]")) {
						break
					}
					line = bytes.Replace(line, []byte("data: "), []byte(""), -1)
					line = bytes.TrimSpace(line)
					if err := json.Unmarshal(line, &streamResp); err != nil {
						continue
					}
					if streamResp.Usage != nil {
						o.tokenUsage.PromptTokens += streamResp.Usage.PromptTokens
						o.tokenUsage.CompletionTokens += streamResp.Usage.CompletionTokens
						o.tokenUsage.TotalTokens += streamResp.Usage.TotalTokens
					}
					output <- &streamResp
				}
			}

			// select {
			// case <-ctx.Done():
			// 	return
			// case output <- &streamResp:
			// }
		}
	}()
	return output, nil
}

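// GetTokenUsage returns the prompt/completion token counts accumulated
// across previous Chat and StreamChat calls.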
func (o *OpenAICompatible) GetTokenUsage() *llm.TokenUsage {
	return o.tokenUsage
}