add realtime proxy

This commit is contained in:
Sakurasan
2024-10-06 00:44:48 +08:00
parent eef24913e0
commit 236dffa256
9 changed files with 329 additions and 278 deletions

View File

@@ -17,7 +17,8 @@ import (
)
const (
AzureApiVersion = "2024-02-01"
// https://learn.microsoft.com/en-us/azure/ai-services/openai/api-version-deprecation#latest-preview-api-releases
AzureApiVersion = "2024-06-01"
BaseHost = "api.openai.com"
OpenAI_Endpoint = "https://api.openai.com/v1/chat/completions"
Github_Marketplace = "https://models.inference.ai.azure.com/chat/completions"
@@ -65,6 +66,10 @@ type Tool struct {
Function *FunctionDefinition `json:"function,omitempty"`
}
type StreamOption struct {
IncludeUsage bool `json:"include_Usage,omitempty"`
}
type ChatCompletionRequest struct {
Model string `json:"model"`
Messages []ChatCompletionMessage `json:"messages"`
@@ -80,8 +85,10 @@ type ChatCompletionRequest struct {
User string `json:"user,omitempty"`
// Functions []FunctionDefinition `json:"functions,omitempty"`
// FunctionCall any `json:"function_call,omitempty"`
Tools []Tool `json:"tools,omitempty"`
Tools []Tool `json:"tools,omitempty"`
ParallelToolCalls bool `json:"parallel_tool_calls,omitempty"`
// ToolChoice any `json:"tool_choice,omitempty"`
StreamOptions StreamOption `json:"stream_options,omitempty"`
}
func (c ChatCompletionRequest) ToByteJson() []byte {
@@ -194,10 +201,18 @@ func ChatProxy(c *gin.Context, chatReq *ChatCompletionRequest) {
prompt += "<tools>: " + string(tooljson) + "\n"
}
}
switch chatReq.Model {
case "gpt-4o", "gpt-4o-mini", "chatgpt-4o-latest":
chatReq.MaxTokens = 16384
}
if chatReq.Stream == true {
chatReq.StreamOptions.IncludeUsage = true
}
usagelog.PromptCount = tokenizer.NumTokensFromStr(prompt, chatReq.Model)
onekey, err := store.SelectKeyCache("openai")
// onekey, err := store.SelectKeyCache("openai")
onekey, err := store.SelectKeyCacheByModel(chatReq.Model)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return

View File

@@ -33,7 +33,7 @@ type DallERequest struct {
ResponseFormat string `json:"response_format,omitempty"` // url or b64_json
}
func DalleHandler(c *gin.Context) {
func DallEProxy(c *gin.Context) {
var dalleRequest DallERequest
if err := c.ShouldBind(&dalleRequest); err != nil {

109
pkg/openai/realtime.go Normal file
View File

@@ -0,0 +1,109 @@
/*
https://platform.openai.com/docs/guides/realtime
wss://my-eastus2-openai-resource.openai.azure.com/openai/realtime?api-version=2024-10-01-preview&deployment=gpt-4o-realtime-preview-1001
*/
package openai
import (
"context"
"log"
"net/http"
"net/url"
"os"
"time"
"github.com/gin-gonic/gin"
"github.com/gorilla/websocket"
"golang.org/x/sync/errgroup"
)
// "wss://api.openai.com/v1/realtime?model=gpt-4o-realtime-preview-2024-10-01"
const realtimeURL = "wss://api.openai.com/v1/realtime"
// upgrader promotes the incoming HTTP request to a WebSocket connection.
// NOTE(review): CheckOrigin unconditionally accepts every Origin header,
// so any web page can open this proxy cross-origin (CSWSH). Confirm this
// is intentional before exposing the endpoint publicly.
var upgrader = websocket.Upgrader{
CheckOrigin: func(r *http.Request) bool {
return true
},
}
// Message is a minimal envelope for a Realtime API event: an event
// "type" plus a response payload. It is not referenced in the code
// visible here — presumably intended for building "response.create"
// events; verify against callers before removing.
type Message struct {
Type string `json:"type"`
Response Response `json:"response"`
}
// Response carries the response-configuration portion of a Realtime API
// event: the requested output modalities (e.g. "text", "audio") and the
// system-style instructions. Not referenced in the code visible here —
// verify usage against callers.
type Response struct {
Modalities []string `json:"modalities"`
Instructions string `json:"instructions"`
}
// RealTimeProxy bridges a client WebSocket connection to the OpenAI
// Realtime API (wss://api.openai.com/v1/realtime), pumping frames in
// both directions until either side closes or fails.
//
// The upstream model is selected via the "model" query parameter; the
// API key is read from the OPENAI_API_KEY environment variable.
func RealTimeProxy(c *gin.Context) {
	log.Println(c.Request.URL.String())

	// Build the upstream URL, forwarding only the model selection.
	query := url.Values{}
	query.Add("model", c.Query("model"))
	upstreamURL := realtimeURL + "?" + query.Encode()

	// Upgrade the incoming HTTP request to a WebSocket.
	clientConn, err := upgrader.Upgrade(c.Writer, c.Request, nil)
	if err != nil {
		log.Println("Upgrade error:", err)
		return
	}
	defer clientConn.Close()

	// Dial the OpenAI realtime endpoint. The outbound proxy (if any)
	// now comes from the standard HTTP_PROXY/HTTPS_PROXY environment
	// variables instead of a hard-coded 127.0.0.1:7890, which only
	// worked on the developer's machine.
	headers := http.Header{
		"Authorization": []string{"Bearer " + os.Getenv("OPENAI_API_KEY")},
		"OpenAI-Beta":   []string{"realtime=v1"},
	}
	dialer := websocket.Dialer{
		Proxy:            http.ProxyFromEnvironment,
		HandshakeTimeout: 45 * time.Second,
	}
	openAIConn, _, err := dialer.Dial(upstreamURL, headers)
	if err != nil {
		log.Println("OpenAI dial error:", err)
		return
	}
	defer openAIConn.Close()

	// Forward frames in both directions; the first error (including a
	// normal close) cancels the shared context and ends both legs.
	ctx, cancel := context.WithCancel(c.Request.Context())
	defer cancel()
	g, ctx := errgroup.WithContext(ctx)
	g.Go(func() error {
		return forwardMessages(ctx, clientConn, openAIConn)
	})
	g.Go(func() error {
		return forwardMessages(ctx, openAIConn, clientConn)
	})
	if err := g.Wait(); err != nil {
		log.Println("Error in message forwarding:", err)
		return
	}
}
// forwardMessages copies WebSocket frames from src to dst until the
// context is cancelled or either connection fails, returning the
// terminating error. The original frame type is preserved: the previous
// version re-sent every frame as websocket.TextMessage, which would
// mislabel binary frames.
func forwardMessages(ctx context.Context, src, dst *websocket.Conn) error {
	for {
		select {
		case <-ctx.Done():
			return ctx.Err()
		default:
			// NOTE: ReadMessage blocks, so cancellation is only
			// observed between frames; the caller closing either
			// connection unblocks the read with an error.
			msgType, message, err := src.ReadMessage()
			if err != nil {
				return err
			}
			log.Println("Received message:", string(message))
			if err := dst.WriteMessage(msgType, message); err != nil {
				return err
			}
		}
	}
}

View File

@@ -25,7 +25,7 @@ type SpeechRequest struct {
Voice string `json:"voice"`
}
func SpeechHandler(c *gin.Context) {
func SpeechProxy(c *gin.Context) {
var chatreq SpeechRequest
if err := c.ShouldBindJSON(&chatreq); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})

View File

@@ -95,16 +95,28 @@ func Cost(model string, promptCount, completionCount int) float64 {
cost = 0.01*float64(prompt/1000) + 0.03*float64(completion/1000)
case "gpt-4-turbo", "gpt-4-turbo-2024-04-09":
cost = 0.01*float64(prompt/1000) + 0.03*float64(completion/1000)
case "gpt-4o", "gpt-4o-2024-05-13":
// omni
case "gpt-4o", "gpt-4o-2024-08-06":
cost = 0.0025*float64(prompt/1000) + 0.01*float64(completion/1000)
case "gpt-4o-2024-05-13":
cost = 0.005*float64(prompt/1000) + 0.015*float64(completion/1000)
case "gpt-4o-2024-08-06":
cost = 0.0025*float64(prompt/1000) + 0.010*float64(completion/1000)
case "gpt-4o-mini", "gpt-4o-mini-2024-07-18":
cost = 0.00015*float64(prompt/1000) + 0.0006*float64(completion/1000)
case "chatgpt-4o-latest":
cost = 0.005*float64(prompt/1000) + 0.015*float64(completion/1000)
// o1
case "o1-preview", "o1-preview-2024-09-12":
cost = 0.015*float64(prompt/1000) + 0.06*float64(completion/1000)
case "o1-mini", "o1-mini-2024-09-12":
cost = 0.003*float64(prompt/1000) + 0.012*float64(completion/1000)
// Realtime API
// Audio*
// $0.1 / 1K input tokens
// $0.2 / 1K output tokens
case "gpt-4o-realtime-preview", "gpt-4o-realtime-preview-2024-10-01":
cost = 0.005*float64(prompt/1000) + 0.020*float64(completion/1000)
case "gpt-4o-realtime-preview.audio", "gpt-4o-realtime-preview-2024-10-01.audio":
cost = 0.1*float64(prompt/1000) + 0.2*float64(completion/1000)
case "whisper-1":
// 0.006$/min
cost = 0.006 * float64(prompt+completion) / 60