package openai_compatible

import (
	"bufio"
	"bytes"
	"context"
	"encoding/json"
	"fmt"
	"io"
	"net/http"
	"net/url"
	"opencatd-open/internal/model"
	"opencatd-open/internal/utils"
	"opencatd-open/llm"
	"os"
	"strings"
)

// https://learn.microsoft.com/en-us/azure/ai-services/openai/api-version-deprecation#latest-preview-api-releases
const AzureApiVersion = "2024-10-21"

const defaultOpenAICompatibleEndpoint = "https://api.openai.com/v1/chat/completions"

const Github_Marketplace = "https://models.inference.ai.azure.com/chat/completions"

// OpenAICompatible is a chat client for OpenAI-compatible completion APIs:
// OpenAI itself, Azure OpenAI, GitHub Models, and custom endpoints.
type OpenAICompatible struct {
	Client     *http.Client
	ApiKey     *model.ApiKey
	tokenUsage *llm.TokenUsage
	Params     map[string]interface{}
	Done       chan struct{}
}

// NewOpenAICompatible builds a client from an ApiKey record, honoring the
// LOCAL_PROXY environment variable and any extra JSON request parameters
// stored on the key.
func NewOpenAICompatible(apikey *model.ApiKey) (*OpenAICompatible, error) {
	// Use a fresh client rather than http.DefaultClient so that setting the
	// proxy transport does not mutate the process-wide default client.
	hc := &http.Client{}
	if os.Getenv("LOCAL_PROXY") != "" {
		proxyUrl, err := url.Parse(os.Getenv("LOCAL_PROXY"))
		if err == nil {
			hc.Transport = &http.Transport{
				Proxy: http.ProxyURL(proxyUrl),
			}
		}
	}
	oc := OpenAICompatible{
		ApiKey:     apikey,
		Client:     hc,
		tokenUsage: &llm.TokenUsage{},
		Done:       make(chan struct{}),
	}
	if apikey.Parameters != nil {
		var params map[string]interface{}
		if err := json.Unmarshal([]byte(*apikey.Parameters), &params); err != nil {
			return nil, err
		}
		oc.Params = params
	}
	return &oc, nil
}

// buildRequest encodes chatReq (merged with any per-key parameters) and
// assembles the provider-specific HTTP request; it is shared by Chat and
// StreamChat.
func (o *OpenAICompatible) buildRequest(ctx context.Context, chatReq llm.ChatRequest) (*http.Request, error) {
	dst, err := utils.StructToMap(chatReq)
	if err != nil {
		return nil, err
	}
	if len(o.Params) > 0 {
		dst = utils.MergeJSONObjects(dst, o.Params)
	}
	var reqBody bytes.Buffer
	if err := json.NewEncoder(&reqBody).Encode(dst); err != nil {
		return nil, err
	}

	var req *http.Request
	switch *o.ApiKey.ApiType {
	case "azure":
		// Azure deployment names cannot contain dots, so strip them from the
		// model name (e.g. "gpt-3.5-turbo" -> "gpt-35-turbo").
		formatModel := func(in string) string {
			return strings.ReplaceAll(in, ".", "")
		}
		var buildurl string
		if o.ApiKey.Endpoint != nil && *o.ApiKey.Endpoint != "" {
			buildurl = fmt.Sprintf("%s/openai/deployments/%s/chat/completions?api-version=%s", *o.ApiKey.Endpoint, formatModel(chatReq.Model), AzureApiVersion)
		} else {
			buildurl = fmt.Sprintf("https://%s.openai.azure.com/openai/deployments/%s/chat/completions?api-version=%s", *o.ApiKey.ResourceNmae, formatModel(chatReq.Model), AzureApiVersion)
		}
		req, err = http.NewRequestWithContext(ctx, http.MethodPost, buildurl, &reqBody)
		if err != nil {
			return nil, err
		}
		req.Header.Set("api-key", *o.ApiKey.ApiKey)
	case "github":
		req, err = http.NewRequestWithContext(ctx, http.MethodPost, Github_Marketplace, &reqBody)
		if err != nil {
			return nil, err
		}
	default:
		endpoint := defaultOpenAICompatibleEndpoint
		if o.ApiKey.Endpoint != nil && *o.ApiKey.Endpoint != "" {
			endpoint = *o.ApiKey.Endpoint
		}
		req, err = http.NewRequestWithContext(ctx, http.MethodPost, endpoint, &reqBody)
		if err != nil {
			return nil, err
		}
	}
	req.Header.Set("Authorization", "Bearer "+*o.ApiKey.ApiKey)
	req.Header.Set("Content-Type", "application/json")
	req.Header.Set("Accept-Encoding", "identity")
	return req, nil
}

// Chat performs a single non-streaming completion and records token usage
// from the response.
func (o *OpenAICompatible) Chat(ctx context.Context, chatReq llm.ChatRequest) (*llm.ChatResponse, error) {
	chatReq.Stream = false
	req, err := o.buildRequest(ctx, chatReq)
	if err != nil {
		return nil, err
	}
	resp, err := o.Client.Do(req)
	if err != nil {
		return nil, err
	}
	defer resp.Body.Close()
	if resp.StatusCode != http.StatusOK {
		body, _ := io.ReadAll(resp.Body)
		return nil, fmt.Errorf("chat request failed: %s: %s", resp.Status, body)
	}
	var chatResp llm.ChatResponse
	if err := json.NewDecoder(resp.Body).Decode(&chatResp); err != nil {
		return nil, err
	}
	o.tokenUsage.PromptTokens = chatResp.Usage.PromptTokens
	o.tokenUsage.CompletionTokens = chatResp.Usage.CompletionTokens
	o.tokenUsage.TotalTokens = chatResp.Usage.TotalTokens
	return &chatResp, nil
}

// StreamChat performs a streaming completion. It returns a channel of decoded
// SSE chunks; the channel is closed when the stream ends, the server sends
// "data: [DONE]", or ctx is cancelled.
func (o *OpenAICompatible) StreamChat(ctx context.Context, chatReq llm.ChatRequest) (chan *llm.StreamChatResponse, error) {
	chatReq.Stream = true
	req, err := o.buildRequest(ctx, chatReq)
	if err != nil {
		return nil, err
	}
	resp, err := o.Client.Do(req)
	if err != nil {
		return nil, err
	}
	if resp.StatusCode != http.StatusOK {
		body, _ := io.ReadAll(resp.Body)
		resp.Body.Close()
		return nil, fmt.Errorf("stream chat request failed: %s: %s", resp.Status, body)
	}

	output := make(chan *llm.StreamChatResponse)
	scanner := bufio.NewScanner(resp.Body)
	// Consume the SSE stream line by line, decode each "data: " payload, and
	// forward it to the caller.
	go func() {
		defer resp.Body.Close()
		defer close(output)
		for scanner.Scan() {
			line := scanner.Bytes()
			if !bytes.HasPrefix(line, []byte("data: ")) {
				continue
			}
			if bytes.HasPrefix(line, []byte("data: [DONE]")) {
				break
			}
			line = bytes.TrimSpace(bytes.TrimPrefix(line, []byte("data: ")))
			var streamResp llm.StreamChatResponse
			if err := json.Unmarshal(line, &streamResp); err != nil {
				continue
			}
			if streamResp.Usage != nil {
				o.tokenUsage.PromptTokens += streamResp.Usage.PromptTokens
				o.tokenUsage.CompletionTokens += streamResp.Usage.CompletionTokens
				o.tokenUsage.TotalTokens += streamResp.Usage.TotalTokens
			}
			select {
			case <-ctx.Done():
				return
			case output <- &streamResp:
			}
		}
	}()
	return output, nil
}

// GetTokenUsage returns the usage accumulated by the most recent Chat or
// StreamChat call.
func (o *OpenAICompatible) GetTokenUsage() *llm.TokenUsage {
	return o.tokenUsage
}
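
// Usage sketch (illustrative, not compiled into the package): build a client
// from a model.ApiKey and run a non-streaming chat. The shapes assumed below
// (llm.ChatRequest.Messages, llm.Message{Role, Content}, and
// resp.Choices[0].Message.Content) are inferred, not confirmed by this file;
// adjust them to the real definitions in the llm and model packages.
//
//	key, apiType := "sk-...", "openai"
//	k := &model.ApiKey{ApiKey: &key, ApiType: &apiType} // nil Endpoint falls back to api.openai.com
//	oc, err := NewOpenAICompatible(k)
//	if err != nil {
//		log.Fatal(err)
//	}
//	resp, err := oc.Chat(context.Background(), llm.ChatRequest{
//		Model:    "gpt-4o-mini",
//		Messages: []llm.Message{{Role: "user", Content: "hello"}},
//	})
//	if err != nil {
//		log.Fatal(err)
//	}
//	fmt.Println(resp.Choices[0].Message.Content)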
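
// Streaming usage sketch (same caveats as above): StreamChat hands back a
// channel that is closed when the stream finishes or ctx is cancelled, so a
// plain range loop drains it safely; accumulated token usage can be read via
// GetTokenUsage afterwards. The chunk field layout (Choices[0].Delta.Content)
// is an assumption.
//
//	ch, err := oc.StreamChat(ctx, llm.ChatRequest{Model: "gpt-4o-mini", Messages: msgs})
//	if err != nil {
//		log.Fatal(err)
//	}
//	for chunk := range ch {
//		if len(chunk.Choices) > 0 {
//			fmt.Print(chunk.Choices[0].Delta.Content)
//		}
//	}
//	fmt.Printf("\nusage: %+v\n", *oc.GetTokenUsage())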