reface to openteam
This commit is contained in:
178
llm/openai/chat.go
Normal file
178
llm/openai/chat.go
Normal file
@@ -0,0 +1,178 @@
|
||||
package openai
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"os"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// Upstream endpoint constants.
const (
	// AzureApiVersion is the Azure OpenAI data-plane API version.
	// https://learn.microsoft.com/en-us/azure/ai-services/openai/api-version-deprecation#latest-preview-api-releases
	AzureApiVersion = "2024-10-21"

	// BaseHost is the default OpenAI API host.
	BaseHost = "api.openai.com"
	// OpenAI_Endpoint is the default chat-completions URL.
	OpenAI_Endpoint = "https://api.openai.com/v1/chat/completions"
	// Github_Marketplace is the GitHub Models (Azure-hosted) chat-completions URL.
	Github_Marketplace = "https://models.inference.ai.azure.com/chat/completions"
)

// Endpoint overrides, populated from the environment by init().
var (
	// Custom_Endpoint, when non-empty, replaces the default OpenAI endpoint.
	Custom_Endpoint string
	// AIGateWay_Endpoint routes traffic through an AI gateway, e.g.
	// "https://gateway.ai.cloudflare.com/v1/431ba10f11200d544922fbca177aaa7f/openai/openai/chat/completions"
	AIGateWay_Endpoint string
)
|
||||
|
||||
func init() {
|
||||
if os.Getenv("OpenAI_Endpoint") != "" {
|
||||
Custom_Endpoint = os.Getenv("OpenAI_Endpoint")
|
||||
}
|
||||
if os.Getenv("AIGateWay_Endpoint") != "" {
|
||||
AIGateWay_Endpoint = os.Getenv("AIGateWay_Endpoint")
|
||||
}
|
||||
}
|
||||
|
||||
// Vision Content

// VisionContent is one element of a multimodal message content array:
// either a text part or an image part.
type VisionContent struct {
	Type     string          `json:"type,omitempty"` // "text" or "image_url"
	Text     string          `json:"text,omitempty"`
	ImageURL *VisionImageURL `json:"image_url,omitempty"`
}

// VisionImageURL references an image (URL or data URI) with an optional
// detail level.
type VisionImageURL struct {
	URL    string `json:"url,omitempty"`
	Detail string `json:"detail,omitempty"`
}
|
||||
|
||||
// ChatCompletionMessage is one chat message. Content is `any` because the
// API accepts either a plain string or an array of multimodal parts
// (see VisionContent).
type ChatCompletionMessage struct {
	Role    string `json:"role"`
	Content any    `json:"content"`
	Name    string `json:"name,omitempty"`
	// MultiContent []VisionContent
}

// FunctionDefinition describes a callable tool function.
type FunctionDefinition struct {
	Name        string `json:"name"`
	Description string `json:"description,omitempty"`
	Parameters  any    `json:"parameters"` // JSON Schema object describing the arguments
}

// Tool wraps a function definition for the tools API.
type Tool struct {
	Type     string              `json:"type"` // e.g. "function"
	Function *FunctionDefinition `json:"function,omitempty"`
}

// StreamOption controls streaming extras; IncludeUsage asks the server to
// append a final usage chunk to the SSE stream.
type StreamOption struct {
	IncludeUsage bool `json:"include_usage,omitempty"`
}
|
||||
|
||||
// ChatCompletionRequest is the JSON body for a chat-completions call,
// mirroring the upstream OpenAI request schema.
type ChatCompletionRequest struct {
	Model            string                  `json:"model"`
	Messages         []ChatCompletionMessage `json:"messages"`
	MaxTokens        int                     `json:"max_tokens,omitempty"`
	Temperature      float64                 `json:"temperature,omitempty"`
	TopP             float64                 `json:"top_p,omitempty"`
	N                int                     `json:"n,omitempty"`
	Stream           bool                    `json:"stream"` // no omitempty: false must be serialized explicitly
	Stop             []string                `json:"stop,omitempty"`
	PresencePenalty  float64                 `json:"presence_penalty,omitempty"`
	FrequencyPenalty float64                 `json:"frequency_penalty,omitempty"`
	LogitBias        map[string]int          `json:"logit_bias,omitempty"`
	User             string                  `json:"user,omitempty"`
	// Functions []FunctionDefinition `json:"functions,omitempty"`
	// FunctionCall any `json:"function_call,omitempty"`
	Tools             []Tool `json:"tools,omitempty"`
	ParallelToolCalls bool   `json:"parallel_tool_calls,omitempty"`
	// ToolChoice any `json:"tool_choice,omitempty"`
	StreamOptions *StreamOption `json:"stream_options,omitempty"`
}
|
||||
|
||||
func (c ChatCompletionRequest) ToByteJson() []byte {
|
||||
bytejson, _ := json.Marshal(c)
|
||||
return bytejson
|
||||
}
|
||||
|
||||
// ToolCall is a model-issued tool invocation; Arguments is a raw JSON
// string (possibly streamed in fragments during SSE responses).
type ToolCall struct {
	ID   string `json:"id"`
	Type string `json:"type"`
	Function struct {
		Name      string `json:"name"`
		Arguments string `json:"arguments"`
	} `json:"function"`
}
|
||||
|
||||
// ChatCompletionResponse is a non-streaming chat-completions reply,
// including the token-usage breakdown used for billing.
type ChatCompletionResponse struct {
	ID      string `json:"id,omitempty"`
	Object  string `json:"object,omitempty"`
	Created int    `json:"created,omitempty"` // unix seconds
	Model   string `json:"model,omitempty"`
	Choices []struct {
		Index   int `json:"index,omitempty"`
		Message struct {
			Role      string     `json:"role,omitempty"`
			Content   string     `json:"content,omitempty"`
			ToolCalls []ToolCall `json:"tool_calls,omitempty"`
		} `json:"message,omitempty"`
		Logprobs     string `json:"logprobs,omitempty"`
		FinishReason string `json:"finish_reason,omitempty"`
	} `json:"choices,omitempty"`
	Usage struct {
		PromptTokens     int `json:"prompt_tokens,omitempty"`
		CompletionTokens int `json:"completion_tokens,omitempty"`
		TotalTokens      int `json:"total_tokens,omitempty"`
		PromptTokensDetails struct {
			CachedTokens int `json:"cached_tokens,omitempty"`
			AudioTokens  int `json:"audio_tokens,omitempty"`
		} `json:"prompt_tokens_details,omitempty"`
		CompletionTokensDetails struct {
			ReasoningTokens          int `json:"reasoning_tokens,omitempty"`
			AudioTokens              int `json:"audio_tokens,omitempty"`
			AcceptedPredictionTokens int `json:"accepted_prediction_tokens,omitempty"`
			RejectedPredictionTokens int `json:"rejected_prediction_tokens,omitempty"`
		} `json:"completion_tokens_details,omitempty"`
	} `json:"usage,omitempty"`
	SystemFingerprint string `json:"system_fingerprint,omitempty"`
}
|
||||
|
||||
// Choice is one streamed SSE chunk choice; Delta carries the incremental
// content or tool-call fragment for this chunk.
type Choice struct {
	Index int `json:"index"`
	Delta struct {
		Role      string     `json:"role"`
		Content   string     `json:"content"`
		ToolCalls []ToolCall `json:"tool_calls"`
	} `json:"delta"`
	FinishReason string `json:"finish_reason"`
	Usage struct {
		PromptTokens     int `json:"prompt_tokens"`
		CompletionTokens int `json:"completion_tokens"`
		TotalTokens      int `json:"total_tokens"`
	} `json:"usage"`
}
|
||||
|
||||
// ChatCompletionStreamResponse is one parsed SSE "data:" payload from a
// streaming chat-completions response.
type ChatCompletionStreamResponse struct {
	ID      string   `json:"id"`
	Object  string   `json:"object"`
	Created int      `json:"created"` // unix seconds
	Model   string   `json:"model"`
	Choices []Choice `json:"choices"`
}
|
||||
|
||||
func (c *ChatCompletionStreamResponse) ByteJson() []byte {
|
||||
bytejson, _ := json.Marshal(c)
|
||||
return bytejson
|
||||
}
|
||||
|
||||
// modelmap converts an OpenAI model name to its Azure deployment form by
// stripping dots, e.g. "gpt-3.5-turbo" -> "gpt-35-turbo". Names without a
// dot are returned unchanged.
func modelmap(in string) string {
	// strings.ReplaceAll is already a no-op when "in" contains no dot,
	// so the previous strings.Contains pre-check was redundant.
	return strings.ReplaceAll(in, ".", "")
}
|
||||
|
||||
// ErrResponse mirrors the OpenAI error envelope ({"error":{...}}) so
// upstream failures can be relayed to clients in the same shape.
type ErrResponse struct {
	Error struct {
		Message string `json:"message"`
		Code    string `json:"code"`
	} `json:"error"`
}
|
||||
|
||||
func (e *ErrResponse) ByteJson() []byte {
|
||||
bytejson, _ := json.Marshal(e)
|
||||
return bytejson
|
||||
}
|
||||
149
llm/openai/dall-e.go
Normal file
149
llm/openai/dall-e.go
Normal file
@@ -0,0 +1,149 @@
|
||||
package openai
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"net/http"
|
||||
"net/http/httputil"
|
||||
"net/url"
|
||||
"opencatd-open/pkg/tokenizer"
|
||||
"opencatd-open/store"
|
||||
"strconv"
|
||||
|
||||
"github.com/duke-git/lancet/v2/slice"
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// OpenAI Images API endpoints.
const (
	DalleEndpoint          = "https://api.openai.com/v1/images/generations"
	DalleEditEndpoint      = "https://api.openai.com/v1/images/edits"
	DalleVariationEndpoint = "https://api.openai.com/v1/images/variations"
)

// DallERequest is the JSON/form body for an image-generation call.
// N and Size carry form tags so they also bind from form submissions.
type DallERequest struct {
	Model          string `json:"model"`
	Prompt         string `json:"prompt"`
	N              int    `form:"n" json:"n,omitempty"`       // number of images
	Size           string `form:"size" json:"size,omitempty"` // e.g. "1024x1024"
	Quality        string `json:"quality,omitempty"`          // standard,hd
	Style          string `json:"style,omitempty"`            // vivid,natural
	ResponseFormat string `json:"response_format,omitempty"`  // url or b64_json
}
|
||||
|
||||
func DallEProxy(c *gin.Context) {
|
||||
|
||||
var dalleRequest DallERequest
|
||||
if err := c.ShouldBind(&dalleRequest); err != nil {
|
||||
c.JSON(400, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
if dalleRequest.N == 0 {
|
||||
dalleRequest.N = 1
|
||||
}
|
||||
|
||||
if dalleRequest.Size == "" {
|
||||
dalleRequest.Size = "512x512"
|
||||
}
|
||||
|
||||
model := dalleRequest.Model
|
||||
|
||||
var chatlog store.Tokens
|
||||
chatlog.CompletionCount = dalleRequest.N
|
||||
|
||||
if model == "dall-e" {
|
||||
model = "dall-e-2"
|
||||
}
|
||||
model = model + "." + dalleRequest.Size
|
||||
|
||||
if dalleRequest.Model == "dall-e-2" || dalleRequest.Model == "dall-e" {
|
||||
if !slice.Contain([]string{"256x256", "512x512", "1024x1024"}, dalleRequest.Size) {
|
||||
c.JSON(http.StatusBadRequest, gin.H{
|
||||
"error": gin.H{
|
||||
"message": fmt.Sprintf("Invalid size: %s for %s", dalleRequest.Size, dalleRequest.Model),
|
||||
},
|
||||
})
|
||||
return
|
||||
}
|
||||
} else if dalleRequest.Model == "dall-e-3" {
|
||||
if !slice.Contain([]string{"256x256", "512x512", "1024x1024", "1792x1024", "1024x1792"}, dalleRequest.Size) {
|
||||
c.JSON(http.StatusBadRequest, gin.H{
|
||||
"error": gin.H{
|
||||
"message": fmt.Sprintf("Invalid size: %s for %s", dalleRequest.Size, dalleRequest.Model),
|
||||
},
|
||||
})
|
||||
return
|
||||
}
|
||||
if dalleRequest.Quality == "hd" {
|
||||
model = model + ".hd"
|
||||
}
|
||||
} else {
|
||||
c.JSON(http.StatusBadRequest, gin.H{
|
||||
"error": gin.H{
|
||||
"message": fmt.Sprintf("Invalid model: %s", dalleRequest.Model),
|
||||
},
|
||||
})
|
||||
return
|
||||
}
|
||||
chatlog.Model = model
|
||||
|
||||
token, _ := c.Get("localuser")
|
||||
|
||||
lu, err := store.GetUserByToken(token.(string))
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{
|
||||
"error": gin.H{
|
||||
"message": err.Error(),
|
||||
},
|
||||
})
|
||||
return
|
||||
}
|
||||
chatlog.UserID = int(lu.ID)
|
||||
|
||||
key, err := store.SelectKeyCache("openai")
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{
|
||||
"error": gin.H{
|
||||
"message": err.Error(),
|
||||
},
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
targetURL, _ := url.Parse(DalleEndpoint)
|
||||
proxy := httputil.NewSingleHostReverseProxy(targetURL)
|
||||
proxy.Director = func(req *http.Request) {
|
||||
req.Header.Set("Authorization", "Bearer "+key.Key)
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
req.Host = targetURL.Host
|
||||
req.URL.Scheme = targetURL.Scheme
|
||||
req.URL.Host = targetURL.Host
|
||||
req.URL.Path = targetURL.Path
|
||||
req.URL.RawPath = targetURL.RawPath
|
||||
req.URL.RawQuery = targetURL.RawQuery
|
||||
|
||||
bytebody, _ := json.Marshal(dalleRequest)
|
||||
req.Body = io.NopCloser(bytes.NewBuffer(bytebody))
|
||||
req.ContentLength = int64(len(bytebody))
|
||||
req.Header.Set("Content-Length", strconv.Itoa(len(bytebody)))
|
||||
}
|
||||
|
||||
proxy.ModifyResponse = func(resp *http.Response) error {
|
||||
if resp.StatusCode == http.StatusOK {
|
||||
chatlog.TotalTokens = chatlog.PromptCount + chatlog.CompletionCount
|
||||
chatlog.Cost = fmt.Sprintf("%.6f", tokenizer.Cost(chatlog.Model, chatlog.PromptCount, chatlog.CompletionCount))
|
||||
if err := store.Record(&chatlog); err != nil {
|
||||
log.Println(err)
|
||||
}
|
||||
if err := store.SumDaily(chatlog.UserID); err != nil {
|
||||
log.Println(err)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
proxy.ServeHTTP(c.Writer, c.Request)
|
||||
}
|
||||
215
llm/openai/handle_proxy.go
Normal file
215
llm/openai/handle_proxy.go
Normal file
@@ -0,0 +1,215 @@
|
||||
package openai
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"net/http"
|
||||
"opencatd-open/pkg/tokenizer"
|
||||
"opencatd-open/store"
|
||||
"strings"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
func ChatProxy(c *gin.Context, chatReq *ChatCompletionRequest) {
|
||||
usagelog := store.Tokens{Model: chatReq.Model}
|
||||
|
||||
token, _ := c.Get("localuser")
|
||||
|
||||
lu, err := store.GetUserByToken(token.(string))
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{
|
||||
"error": gin.H{
|
||||
"message": err.Error(),
|
||||
},
|
||||
})
|
||||
return
|
||||
}
|
||||
usagelog.UserID = int(lu.ID)
|
||||
|
||||
var prompt string
|
||||
for _, msg := range chatReq.Messages {
|
||||
switch ct := msg.Content.(type) {
|
||||
case string:
|
||||
prompt += "<" + msg.Role + ">: " + msg.Content.(string) + "\n"
|
||||
case []any:
|
||||
for _, item := range ct {
|
||||
if m, ok := item.(map[string]interface{}); ok {
|
||||
if m["type"] == "text" {
|
||||
prompt += "<" + msg.Role + ">: " + m["text"].(string) + "\n"
|
||||
} else if m["type"] == "image_url" {
|
||||
if url, ok := m["image_url"].(map[string]interface{}); ok {
|
||||
fmt.Printf(" URL: %v\n", url["url"])
|
||||
if strings.HasPrefix(url["url"].(string), "http") {
|
||||
fmt.Println("网络图片:", url["url"].(string))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
default:
|
||||
c.JSON(http.StatusInternalServerError, gin.H{
|
||||
"error": gin.H{
|
||||
"message": "Invalid content type",
|
||||
},
|
||||
})
|
||||
return
|
||||
}
|
||||
if len(chatReq.Tools) > 0 {
|
||||
tooljson, _ := json.Marshal(chatReq.Tools)
|
||||
prompt += "<tools>: " + string(tooljson) + "\n"
|
||||
}
|
||||
}
|
||||
switch chatReq.Model {
|
||||
case "gpt-4o", "gpt-4o-mini", "chatgpt-4o-latest":
|
||||
chatReq.MaxTokens = 16384
|
||||
}
|
||||
if chatReq.Stream {
|
||||
chatReq.StreamOptions = &StreamOption{IncludeUsage: true}
|
||||
}
|
||||
|
||||
usagelog.PromptCount = tokenizer.NumTokensFromStr(prompt, chatReq.Model)
|
||||
|
||||
// onekey, err := store.SelectKeyCache("openai")
|
||||
onekey, err := store.SelectKeyCacheByModel(chatReq.Model)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
var req *http.Request
|
||||
|
||||
switch onekey.ApiType {
|
||||
case "github":
|
||||
req, err = http.NewRequest(c.Request.Method, Github_Marketplace, bytes.NewReader(chatReq.ToByteJson()))
|
||||
req.Header = c.Request.Header
|
||||
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", onekey.Key))
|
||||
case "azure":
|
||||
var buildurl string
|
||||
if onekey.EndPoint != "" {
|
||||
buildurl = fmt.Sprintf("%s/openai/deployments/%s/chat/completions?api-version=%s", onekey.EndPoint, modelmap(chatReq.Model), AzureApiVersion)
|
||||
} else {
|
||||
buildurl = fmt.Sprintf("https://%s.openai.azure.com/openai/deployments/%s/chat/completions?api-version=%s", onekey.ResourceNmae, modelmap(chatReq.Model), AzureApiVersion)
|
||||
}
|
||||
req, err = http.NewRequest(c.Request.Method, buildurl, bytes.NewReader(chatReq.ToByteJson()))
|
||||
req.Header = c.Request.Header
|
||||
req.Header.Set("api-key", onekey.Key)
|
||||
default:
|
||||
req, err = http.NewRequest(c.Request.Method, OpenAI_Endpoint, bytes.NewReader(chatReq.ToByteJson())) // default endpoint
|
||||
|
||||
if AIGateWay_Endpoint != "" { // cloudflare gateway的endpoint
|
||||
req, err = http.NewRequest(c.Request.Method, AIGateWay_Endpoint, bytes.NewReader(chatReq.ToByteJson()))
|
||||
}
|
||||
if Custom_Endpoint != "" { // 自定义endpoint
|
||||
req, err = http.NewRequest(c.Request.Method, Custom_Endpoint, bytes.NewReader(chatReq.ToByteJson()))
|
||||
}
|
||||
if onekey.EndPoint != "" { // 优先key的endpoint
|
||||
req, err = http.NewRequest(c.Request.Method, onekey.EndPoint+c.Request.RequestURI, bytes.NewReader(chatReq.ToByteJson()))
|
||||
}
|
||||
|
||||
req.Header = c.Request.Header
|
||||
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", onekey.Key))
|
||||
}
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
resp, err := http.DefaultClient.Do(req)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
var result string
|
||||
if chatReq.Stream {
|
||||
for key, value := range resp.Header {
|
||||
for _, v := range value {
|
||||
c.Writer.Header().Add(key, v)
|
||||
}
|
||||
}
|
||||
c.Writer.WriteHeader(resp.StatusCode)
|
||||
teeReader := io.TeeReader(resp.Body, c.Writer)
|
||||
// 流式响应
|
||||
scanner := bufio.NewScanner(teeReader)
|
||||
|
||||
for scanner.Scan() {
|
||||
line := scanner.Bytes()
|
||||
if len(line) > 0 && bytes.HasPrefix(line, []byte("data: ")) {
|
||||
if bytes.HasPrefix(line, []byte("data: [DONE]")) {
|
||||
break
|
||||
}
|
||||
var opiResp ChatCompletionStreamResponse
|
||||
line = bytes.Replace(line, []byte("data: "), []byte(""), -1)
|
||||
line = bytes.TrimSpace(line)
|
||||
if err := json.Unmarshal(line, &opiResp); err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
if opiResp.Choices != nil && len(opiResp.Choices) > 0 {
|
||||
if opiResp.Choices[0].Delta.Role != "" {
|
||||
result += "<" + opiResp.Choices[0].Delta.Role + "> "
|
||||
}
|
||||
result += opiResp.Choices[0].Delta.Content // 计算Content Token
|
||||
|
||||
if len(opiResp.Choices[0].Delta.ToolCalls) > 0 { // 计算ToolCalls token
|
||||
if opiResp.Choices[0].Delta.ToolCalls[0].Function.Name != "" {
|
||||
result += "name:" + opiResp.Choices[0].Delta.ToolCalls[0].Function.Name + " arguments:"
|
||||
}
|
||||
result += opiResp.Choices[0].Delta.ToolCalls[0].Function.Arguments
|
||||
}
|
||||
} else {
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
} else {
|
||||
|
||||
// 处理非流式响应
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
fmt.Println("Error reading response body:", err)
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
var opiResp ChatCompletionResponse
|
||||
if err := json.Unmarshal(body, &opiResp); err != nil {
|
||||
log.Println("Error parsing JSON:", err)
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
if opiResp.Choices != nil && len(opiResp.Choices) > 0 {
|
||||
if opiResp.Choices[0].Message.Role != "" {
|
||||
result += "<" + opiResp.Choices[0].Message.Role + "> "
|
||||
}
|
||||
result += opiResp.Choices[0].Message.Content
|
||||
|
||||
if len(opiResp.Choices[0].Message.ToolCalls) > 0 {
|
||||
if opiResp.Choices[0].Message.ToolCalls[0].Function.Name != "" {
|
||||
result += "name:" + opiResp.Choices[0].Message.ToolCalls[0].Function.Name + " arguments:"
|
||||
}
|
||||
result += opiResp.Choices[0].Message.ToolCalls[0].Function.Arguments
|
||||
}
|
||||
|
||||
}
|
||||
for k, v := range resp.Header {
|
||||
c.Writer.Header().Set(k, v[0])
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, opiResp)
|
||||
}
|
||||
usagelog.CompletionCount = tokenizer.NumTokensFromStr(result, chatReq.Model)
|
||||
usagelog.Cost = fmt.Sprintf("%.6f", tokenizer.Cost(usagelog.Model, usagelog.PromptCount, usagelog.CompletionCount))
|
||||
if err := store.Record(&usagelog); err != nil {
|
||||
log.Println(err)
|
||||
}
|
||||
if err := store.SumDaily(usagelog.UserID); err != nil {
|
||||
log.Println(err)
|
||||
}
|
||||
}
|
||||
197
llm/openai/realtime.go
Normal file
197
llm/openai/realtime.go
Normal file
@@ -0,0 +1,197 @@
|
||||
/*
|
||||
https://platform.openai.com/docs/guides/realtime
|
||||
https://learn.microsoft.com/zh-cn/azure/ai-services/openai/how-to/audio-real-time
|
||||
|
||||
wss://my-eastus2-openai-resource.openai.azure.com/openai/realtime?api-version=2024-10-01-preview&deployment=gpt-4o-realtime-preview-1001
|
||||
*/
|
||||
package openai
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"log"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"opencatd-open/pkg/tokenizer"
|
||||
"opencatd-open/store"
|
||||
"os"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/gorilla/websocket"
|
||||
"golang.org/x/sync/errgroup"
|
||||
)
|
||||
|
||||
// "wss://api.openai.com/v1/realtime?model=gpt-4o-realtime-preview-2024-10-01"

// realtimeURL is the OpenAI realtime websocket base URL (model is added
// as a query parameter).
const realtimeURL = "wss://api.openai.com/v1/realtime"

// azureRealtimeURL is the Azure realtime URL template; %s is the resource name.
const azureRealtimeURL = "wss://%s.openai.azure.com/openai/realtime?api-version=2024-10-01-preview&deployment=gpt-4o-realtime-preview"

// upgrader upgrades client HTTP connections to websocket. CheckOrigin
// accepts every origin, so cross-origin browser clients are allowed.
var upgrader = websocket.Upgrader{
	CheckOrigin: func(r *http.Request) bool {
		return true
	},
}
|
||||
|
||||
// Message is a generic realtime client event with a response payload.
type Message struct {
	Type     string   `json:"type"`
	Response Response `json:"response"`
}

// Response configures a realtime response request: which modalities to
// produce and the system-style instructions.
type Response struct {
	Modalities   []string `json:"modalities"`
	Instructions string   `json:"instructions"`
}
|
||||
|
||||
// RealTimeResponse is a realtime server event (e.g. "response.done")
// whose nested Response carries the output items and token usage.
type RealTimeResponse struct {
	Type    string `json:"type"`
	EventID string `json:"event_id"`
	Response struct {
		Object        string `json:"object"`
		ID            string `json:"id"`
		Status        string `json:"status"`
		StatusDetails any    `json:"status_details"`
		Output []struct {
			ID     string `json:"id"`
			Object string `json:"object"`
			Type   string `json:"type"`
			Status string `json:"status"`
			Role   string `json:"role"`
			Content []struct {
				Type       string `json:"type"`
				Transcript string `json:"transcript"`
			} `json:"content"`
		} `json:"output"`
		Usage Usage `json:"usage"`
	} `json:"response"`
}
|
||||
|
||||
// Usage is the realtime API token-usage breakdown, split by direction
// and by text/audio token type.
type Usage struct {
	TotalTokens  int `json:"total_tokens"`
	InputTokens  int `json:"input_tokens"`
	OutputTokens int `json:"output_tokens"`
	InputTokenDetails struct {
		CachedTokens int `json:"cached_tokens"`
		TextTokens   int `json:"text_tokens"`
		AudioTokens  int `json:"audio_tokens"`
	} `json:"input_token_details"`
	OutputTokenDetails struct {
		TextTokens  int `json:"text_tokens"`
		AudioTokens int `json:"audio_tokens"`
	} `json:"output_token_details"`
}
|
||||
|
||||
// RealTimeProxy upgrades the client connection to a websocket, dials the
// OpenAI (or Azure) realtime endpoint with a locally stored key, and
// relays messages in both directions until either side closes.
func RealTimeProxy(c *gin.Context) {
	log.Println(c.Request.URL.String())
	var model string = c.Query("model")
	value := url.Values{}
	value.Add("model", model)
	// Shadows the package constant with the fully-qualified URL for this model.
	realtimeURL := realtimeURL + "?" + value.Encode()

	// 升级 HTTP 连接为 WebSocket
	clientConn, err := upgrader.Upgrade(c.Writer, c.Request, nil)
	if err != nil {
		log.Println("Upgrade error:", err)
		return
	}
	defer clientConn.Close()

	apikey, err := store.SelectKeyCacheByModel(model)
	if err != nil {
		// NOTE(review): the connection was already hijacked by the upgrade,
		// so this JSON response may not reach the client — confirm.
		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
		return
	}
	// 连接到 OpenAI WebSocket
	headers := http.Header{"OpenAI-Beta": []string{"realtime=v1"}}

	// Azure authenticates via "api-key" and uses its own URL scheme;
	// everything else uses a Bearer token against the OpenAI URL.
	if apikey.ApiType == "azure" {
		headers.Set("api-key", apikey.Key)
		if apikey.EndPoint != "" {
			realtimeURL = fmt.Sprintf("%s/openai/realtime?api-version=2024-10-01-preview&deployment=gpt-4o-realtime-preview", apikey.EndPoint)
		} else {
			realtimeURL = fmt.Sprintf(azureRealtimeURL, apikey.ResourceNmae)
		}
	} else {
		headers.Set("Authorization", "Bearer "+apikey.Key)
	}

	// NOTE(review): conn aliases the package-global websocket.DefaultDialer
	// (a *Dialer), so setting Proxy below mutates global state shared by
	// all callers — confirm this is intended.
	conn := websocket.DefaultDialer
	if os.Getenv("LOCAL_PROXY") != "" {
		proxyUrl, _ := url.Parse(os.Getenv("LOCAL_PROXY"))
		conn.Proxy = http.ProxyURL(proxyUrl)
	}

	openAIConn, _, err := conn.Dial(realtimeURL, headers)
	if err != nil {
		log.Println("OpenAI dial error:", err)
		return
	}
	defer openAIConn.Close()

	ctx, cancel := context.WithCancel(c.Request.Context())
	defer cancel()

	// Pump both directions concurrently; the errgroup cancels the shared
	// context as soon as either direction fails.
	g, ctx := errgroup.WithContext(ctx)

	g.Go(func() error {
		return forwardMessages(ctx, c, clientConn, openAIConn)
	})

	g.Go(func() error {
		return forwardMessages(ctx, c, openAIConn, clientConn)
	})

	if err := g.Wait(); err != nil {
		log.Println("Error in message forwarding:", err)
		return
	}

}
|
||||
|
||||
func forwardMessages(ctx context.Context, c *gin.Context, src, dst *websocket.Conn) error {
|
||||
usagelog := store.Tokens{Model: "gpt-4o-realtime-preview"}
|
||||
|
||||
token, _ := c.Get("localuser")
|
||||
|
||||
lu, err := store.GetUserByToken(token.(string))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
usagelog.UserID = int(lu.ID)
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
default:
|
||||
messageType, message, err := src.ReadMessage()
|
||||
if err != nil {
|
||||
if websocket.IsCloseError(err, websocket.CloseNormalClosure, websocket.CloseGoingAway) {
|
||||
return nil // 正常关闭,不报错
|
||||
}
|
||||
return err
|
||||
}
|
||||
if messageType == websocket.TextMessage {
|
||||
var usage Usage
|
||||
err := json.Unmarshal(message, &usage)
|
||||
if err == nil {
|
||||
usagelog.PromptCount += usage.InputTokens
|
||||
usagelog.CompletionCount += usage.OutputTokens
|
||||
}
|
||||
|
||||
}
|
||||
err = dst.WriteMessage(messageType, message)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
defer func() {
|
||||
usagelog.Cost = fmt.Sprintf("%.6f", tokenizer.Cost(usagelog.Model, usagelog.PromptCount, usagelog.CompletionCount))
|
||||
if err := store.Record(&usagelog); err != nil {
|
||||
log.Println(err)
|
||||
}
|
||||
if err := store.SumDaily(usagelog.UserID); err != nil {
|
||||
log.Println(err)
|
||||
}
|
||||
}()
|
||||
return nil
|
||||
}
|
||||
93
llm/openai/tts.go
Normal file
93
llm/openai/tts.go
Normal file
@@ -0,0 +1,93 @@
|
||||
package openai
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"net/http"
|
||||
"net/http/httputil"
|
||||
"net/url"
|
||||
"opencatd-open/pkg/tokenizer"
|
||||
"opencatd-open/store"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// OpenAI text-to-speech endpoint.
const (
	SpeechEndpoint = "https://api.openai.com/v1/audio/speech"
)

// SpeechRequest is the JSON body for a text-to-speech call.
type SpeechRequest struct {
	Model string `json:"model"`
	Input string `json:"input"` // text to synthesize; billed per character
	Voice string `json:"voice"`
}
|
||||
|
||||
// SpeechProxy forwards a text-to-speech request to the OpenAI audio API
// through a reverse proxy and records usage, counting the input length in
// characters as the completion count.
func SpeechProxy(c *gin.Context) {
	var chatreq SpeechRequest
	if err := c.ShouldBindJSON(&chatreq); err != nil {
		c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
		return
	}

	var chatlog store.Tokens
	chatlog.Model = chatreq.Model
	// TTS is billed per input character; len() counts bytes, which matches
	// characters only for ASCII input — NOTE(review): confirm whether
	// multibyte input should use rune count instead.
	chatlog.CompletionCount = len(chatreq.Input)

	token, _ := c.Get("localuser")

	lu, err := store.GetUserByToken(token.(string))
	if err != nil {
		c.JSON(http.StatusInternalServerError, gin.H{
			"error": gin.H{
				"message": err.Error(),
			},
		})
		return
	}
	chatlog.UserID = int(lu.ID)

	key, err := store.SelectKeyCache("openai")
	if err != nil {
		c.JSON(http.StatusInternalServerError, gin.H{
			"error": gin.H{
				"message": err.Error(),
			},
		})
		return
	}

	targetURL, _ := url.Parse(SpeechEndpoint)
	proxy := httputil.NewSingleHostReverseProxy(targetURL)

	proxy.Director = func(req *http.Request) {
		req.Header = c.Request.Header
		req.Header["Authorization"] = []string{"Bearer " + key.Key}
		req.Host = targetURL.Host
		req.URL.Scheme = targetURL.Scheme
		req.URL.Host = targetURL.Host
		req.URL.Path = targetURL.Path
		req.URL.RawPath = targetURL.RawPath

		// Replay the already-bound request body to the upstream.
		reqBytes, _ := json.Marshal(chatreq)
		req.Body = io.NopCloser(bytes.NewReader(reqBytes))
		req.ContentLength = int64(len(reqBytes))

	}
	proxy.ModifyResponse = func(resp *http.Response) error {
		// Record usage only when the upstream accepted the request.
		if resp.StatusCode == http.StatusOK {
			chatlog.TotalTokens = chatlog.PromptCount + chatlog.CompletionCount
			chatlog.Cost = fmt.Sprintf("%.6f", tokenizer.Cost(chatlog.Model, chatlog.PromptCount, chatlog.CompletionCount))
			if err := store.Record(&chatlog); err != nil {
				log.Println(err)
			}
			if err := store.SumDaily(chatlog.UserID); err != nil {
				log.Println(err)
			}
		}
		return nil
	}
	proxy.ServeHTTP(c.Writer, c.Request)
}
|
||||
177
llm/openai/whisper.go
Normal file
177
llm/openai/whisper.go
Normal file
@@ -0,0 +1,177 @@
|
||||
package openai
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"mime/multipart"
|
||||
"net/http"
|
||||
"net/http/httputil"
|
||||
"net/url"
|
||||
"opencatd-open/pkg/tokenizer"
|
||||
"opencatd-open/store"
|
||||
"path/filepath"
|
||||
"time"
|
||||
|
||||
"github.com/faiface/beep"
|
||||
"github.com/faiface/beep/mp3"
|
||||
"github.com/faiface/beep/wav"
|
||||
"github.com/gin-gonic/gin"
|
||||
"gopkg.in/vansante/go-ffprobe.v2"
|
||||
)
|
||||
|
||||
// WhisperProxy forwards an audio transcription request to the configured
// endpoint (default api.openai.com) via a reverse proxy, counting whole
// seconds of audio as the prompt count for billing.
func WhisperProxy(c *gin.Context) {
	var chatlog store.Tokens

	// Buffer the multipart body so it can be parsed for accounting below
	// and then replayed to the upstream unchanged.
	byteBody, _ := io.ReadAll(c.Request.Body)
	c.Request.Body = io.NopCloser(bytes.NewBuffer(byteBody))

	model, _ := c.GetPostForm("model")

	key, err := store.SelectKeyCache("openai")
	if err != nil {
		c.JSON(http.StatusInternalServerError, gin.H{
			"error": gin.H{
				"message": err.Error(),
			},
		})
		return
	}

	chatlog.Model = model

	token, _ := c.Get("localuser")

	lu, err := store.GetUserByToken(token.(string))
	if err != nil {
		c.JSON(http.StatusInternalServerError, gin.H{
			"error": gin.H{
				"message": err.Error(),
			},
		})
		return
	}
	chatlog.UserID = int(lu.ID)

	// Measures audio duration (capped at 5 minutes) and rewinds the body.
	if err := ParseWhisperRequestTokens(c, &chatlog, byteBody); err != nil {
		c.JSON(http.StatusInternalServerError, gin.H{
			"error": gin.H{
				"message": err.Error(),
			},
		})
		return
	}
	if key.EndPoint == "" {
		key.EndPoint = "https://api.openai.com"
	}
	targetUrl, _ := url.ParseRequestURI(key.EndPoint + c.Request.URL.String())
	log.Println(targetUrl)
	proxy := httputil.NewSingleHostReverseProxy(targetUrl)
	proxy.Director = func(req *http.Request) {
		req.Host = targetUrl.Host
		req.URL.Scheme = targetUrl.Scheme
		req.URL.Host = targetUrl.Host

		req.Header.Set("Authorization", "Bearer "+key.Key)
	}

	proxy.ModifyResponse = func(resp *http.Response) error {
		// Record usage only on upstream success.
		if resp.StatusCode != http.StatusOK {
			return nil
		}
		chatlog.TotalTokens = chatlog.PromptCount + chatlog.CompletionCount
		chatlog.Cost = fmt.Sprintf("%.6f", tokenizer.Cost(chatlog.Model, chatlog.PromptCount, chatlog.CompletionCount))
		if err := store.Record(&chatlog); err != nil {
			log.Println(err)
		}
		if err := store.SumDaily(chatlog.UserID); err != nil {
			log.Println(err)
		}
		return nil
	}
	proxy.ServeHTTP(c.Writer, c.Request)
}
|
||||
|
||||
func probe(fileReader io.Reader) (time.Duration, error) {
|
||||
ctx, cancelFn := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancelFn()
|
||||
|
||||
data, err := ffprobe.ProbeReader(ctx, fileReader)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
duration := data.Format.DurationSeconds
|
||||
pduration, err := time.ParseDuration(fmt.Sprintf("%fs", duration))
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("Error parsing duration: %s", err)
|
||||
}
|
||||
return pduration, nil
|
||||
}
|
||||
|
||||
func getAudioDuration(file *multipart.FileHeader) (time.Duration, error) {
|
||||
var (
|
||||
streamer beep.StreamSeekCloser
|
||||
format beep.Format
|
||||
err error
|
||||
)
|
||||
|
||||
f, err := file.Open()
|
||||
defer f.Close()
|
||||
|
||||
// Get the file extension to determine the audio file type
|
||||
fileType := filepath.Ext(file.Filename)
|
||||
|
||||
switch fileType {
|
||||
case ".mp3":
|
||||
streamer, format, err = mp3.Decode(f)
|
||||
case ".wav":
|
||||
streamer, format, err = wav.Decode(f)
|
||||
case ".m4a":
|
||||
duration, err := probe(f)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return duration, nil
|
||||
default:
|
||||
return 0, errors.New("unsupported audio file format")
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
defer streamer.Close()
|
||||
|
||||
// Calculate the audio file's duration.
|
||||
numSamples := streamer.Len()
|
||||
sampleRate := format.SampleRate
|
||||
duration := time.Duration(numSamples) * time.Second / time.Duration(sampleRate)
|
||||
|
||||
return duration, nil
|
||||
}
|
||||
|
||||
func ParseWhisperRequestTokens(c *gin.Context, usage *store.Tokens, byteBody []byte) error {
|
||||
file, _ := c.FormFile("file")
|
||||
model, _ := c.GetPostForm("model")
|
||||
usage.Model = model
|
||||
|
||||
if file != nil {
|
||||
duration, err := getAudioDuration(file)
|
||||
if err != nil {
|
||||
return fmt.Errorf("Error getting audio duration:%s", err)
|
||||
}
|
||||
|
||||
if duration > 5*time.Minute {
|
||||
return fmt.Errorf("Audio duration exceeds 5 minutes")
|
||||
}
|
||||
// 计算时长,四舍五入到最接近的秒数
|
||||
usage.PromptCount = int(duration.Round(time.Second).Seconds())
|
||||
}
|
||||
|
||||
c.Request.Body = io.NopCloser(bytes.NewBuffer(byteBody))
|
||||
|
||||
return nil
|
||||
}
|
||||
Reference in New Issue
Block a user