reface to openteam

This commit is contained in:
Sakurasan
2025-04-16 18:01:27 +08:00
parent bc223d6530
commit e7ffc9e8b9
92 changed files with 5345 additions and 1273 deletions

178
llm/openai/chat.go Normal file
View File

@@ -0,0 +1,178 @@
package openai
import (
"encoding/json"
"os"
"strings"
)
const (
	// AzureApiVersion pins the Azure OpenAI REST api-version used for chat calls.
	// https://learn.microsoft.com/en-us/azure/ai-services/openai/api-version-deprecation#latest-preview-api-releases
	AzureApiVersion = "2024-10-21"
	// BaseHost is the default OpenAI API host.
	BaseHost = "api.openai.com"
	// OpenAI_Endpoint is the default chat-completions endpoint.
	OpenAI_Endpoint = "https://api.openai.com/v1/chat/completions"
	// Github_Marketplace is the GitHub Models (Azure-hosted) chat endpoint.
	Github_Marketplace = "https://models.inference.ai.azure.com/chat/completions"
)

var (
	// Custom_Endpoint overrides the default chat endpoint; seeded from the
	// OpenAI_Endpoint environment variable in init().
	Custom_Endpoint string
	// AIGateWay_Endpoint routes traffic through an AI gateway, e.g.
	// "https://gateway.ai.cloudflare.com/v1/431ba10f11200d544922fbca177aaa7f/openai/openai/chat/completions"
	AIGateWay_Endpoint string
)
// init seeds the endpoint overrides from the process environment.
// Unset variables leave the corresponding package-level value empty.
func init() {
	if v := os.Getenv("OpenAI_Endpoint"); v != "" {
		Custom_Endpoint = v
	}
	if v := os.Getenv("AIGateWay_Endpoint"); v != "" {
		AIGateWay_Endpoint = v
	}
}
// VisionContent is one element of a multimodal message "content" array:
// either a text part or an image part, selected by Type.
type VisionContent struct {
	Type     string          `json:"type,omitempty"` // "text" or "image_url"
	Text     string          `json:"text,omitempty"`
	ImageURL *VisionImageURL `json:"image_url,omitempty"`
}

// VisionImageURL carries the image reference for an image content part.
type VisionImageURL struct {
	URL    string `json:"url,omitempty"`    // http(s) URL or data: URI
	Detail string `json:"detail,omitempty"` // presumably "low"/"high"/"auto" — TODO confirm against API docs
}
// ChatCompletionMessage is a single chat turn. Content is `any` because the
// API accepts either a plain string or an array of multimodal parts
// (see VisionContent); ChatProxy type-switches on it.
type ChatCompletionMessage struct {
	Role    string `json:"role"`
	Content any    `json:"content"`
	Name    string `json:"name,omitempty"`
	// MultiContent []VisionContent
}

// FunctionDefinition describes a callable tool function offered to the model.
type FunctionDefinition struct {
	Name        string `json:"name"`
	Description string `json:"description,omitempty"`
	Parameters  any    `json:"parameters"` // JSON-schema-shaped object
}

// Tool wraps a function definition for the request "tools" field.
type Tool struct {
	Type     string              `json:"type"` // currently "function"
	Function *FunctionDefinition `json:"function,omitempty"`
}

// StreamOption controls streaming extras; IncludeUsage asks the server to
// emit token usage in the final stream chunk.
type StreamOption struct {
	IncludeUsage bool `json:"include_usage,omitempty"`
}
// ChatCompletionRequest mirrors the OpenAI chat-completions request body.
// Stream deliberately omits ",omitempty" so an explicit false is sent.
type ChatCompletionRequest struct {
	Model             string                  `json:"model"`
	Messages          []ChatCompletionMessage `json:"messages"`
	MaxTokens         int                     `json:"max_tokens,omitempty"`
	Temperature       float64                 `json:"temperature,omitempty"`
	TopP              float64                 `json:"top_p,omitempty"`
	N                 int                     `json:"n,omitempty"`
	Stream            bool                    `json:"stream"`
	Stop              []string                `json:"stop,omitempty"`
	PresencePenalty   float64                 `json:"presence_penalty,omitempty"`
	FrequencyPenalty  float64                 `json:"frequency_penalty,omitempty"`
	LogitBias         map[string]int          `json:"logit_bias,omitempty"`
	User              string                  `json:"user,omitempty"`
	// Functions []FunctionDefinition `json:"functions,omitempty"`
	// FunctionCall any `json:"function_call,omitempty"`
	Tools             []Tool        `json:"tools,omitempty"`
	ParallelToolCalls bool          `json:"parallel_tool_calls,omitempty"`
	// ToolChoice any `json:"tool_choice,omitempty"`
	StreamOptions     *StreamOption `json:"stream_options,omitempty"` // set by ChatProxy when streaming
}
// ToByteJson marshals the request to its JSON wire form. The marshal
// error is deliberately discarded; on failure the result is nil, which
// matches json.Marshal's behavior on error.
func (c ChatCompletionRequest) ToByteJson() []byte {
	data, err := json.Marshal(c)
	if err != nil {
		return nil
	}
	return data
}
// ToolCall is one tool invocation emitted by the model; Arguments is the
// raw JSON string the model produced for the function parameters.
type ToolCall struct {
	ID       string `json:"id"`
	Type     string `json:"type"` // currently "function"
	Function struct {
		Name      string `json:"name"`
		Arguments string `json:"arguments"`
	} `json:"function"`
}
// ChatCompletionResponse mirrors the non-streaming chat-completions
// response body, including the detailed token-usage breakdown.
type ChatCompletionResponse struct {
	ID      string `json:"id,omitempty"`
	Object  string `json:"object,omitempty"`
	Created int    `json:"created,omitempty"` // unix seconds
	Model   string `json:"model,omitempty"`
	Choices []struct {
		Index   int `json:"index,omitempty"`
		Message struct {
			Role      string     `json:"role,omitempty"`
			Content   string     `json:"content,omitempty"`
			ToolCalls []ToolCall `json:"tool_calls,omitempty"`
		} `json:"message,omitempty"`
		Logprobs     string `json:"logprobs,omitempty"`
		FinishReason string `json:"finish_reason,omitempty"`
	} `json:"choices,omitempty"`
	Usage struct {
		PromptTokens     int `json:"prompt_tokens,omitempty"`
		CompletionTokens int `json:"completion_tokens,omitempty"`
		TotalTokens      int `json:"total_tokens,omitempty"`
		PromptTokensDetails struct {
			CachedTokens int `json:"cached_tokens,omitempty"`
			AudioTokens  int `json:"audio_tokens,omitempty"`
		} `json:"prompt_tokens_details,omitempty"`
		CompletionTokensDetails struct {
			ReasoningTokens          int `json:"reasoning_tokens,omitempty"`
			AudioTokens              int `json:"audio_tokens,omitempty"`
			AcceptedPredictionTokens int `json:"accepted_prediction_tokens,omitempty"`
			RejectedPredictionTokens int `json:"rejected_prediction_tokens,omitempty"`
		} `json:"completion_tokens_details,omitempty"`
	} `json:"usage,omitempty"`
	SystemFingerprint string `json:"system_fingerprint,omitempty"`
}
// Choice is one element of a streaming chunk; Delta carries the
// incremental content/tool-call fragment for this chunk.
type Choice struct {
	Index int `json:"index"`
	Delta struct {
		Role      string     `json:"role"`
		Content   string     `json:"content"`
		ToolCalls []ToolCall `json:"tool_calls"`
	} `json:"delta"`
	FinishReason string `json:"finish_reason"`
	Usage        struct {
		PromptTokens     int `json:"prompt_tokens"`
		CompletionTokens int `json:"completion_tokens"`
		TotalTokens      int `json:"total_tokens"`
	} `json:"usage"`
}

// ChatCompletionStreamResponse is one SSE "data:" chunk of a streamed
// chat completion.
type ChatCompletionStreamResponse struct {
	ID      string   `json:"id"`
	Object  string   `json:"object"`
	Created int      `json:"created"` // unix seconds
	Model   string   `json:"model"`
	Choices []Choice `json:"choices"`
}
// ByteJson marshals the stream chunk to JSON bytes. A marshal failure
// yields nil, mirroring json.Marshal's error-path result.
func (c *ChatCompletionStreamResponse) ByteJson() []byte {
	data, err := json.Marshal(c)
	if err != nil {
		return nil
	}
	return data
}
// modelmap converts an OpenAI model name to its Azure deployment
// spelling by stripping dots, e.g. gpt-3.5-turbo -> gpt-35-turbo.
func modelmap(in string) string {
	// ReplaceAll is a no-op on names without a dot, so no guard needed.
	return strings.ReplaceAll(in, ".", "")
}
// ErrResponse mirrors the OpenAI error envelope {"error":{...}}.
type ErrResponse struct {
	Error struct {
		Message string `json:"message"`
		Code    string `json:"code"`
	} `json:"error"`
}

// ByteJson marshals the error envelope to JSON bytes.
// The marshal error is ignored; on failure the slice is nil.
func (e *ErrResponse) ByteJson() []byte {
	bytejson, _ := json.Marshal(e)
	return bytejson
}

149
llm/openai/dall-e.go Normal file
View File

@@ -0,0 +1,149 @@
package openai
import (
"bytes"
"encoding/json"
"fmt"
"io"
"log"
"net/http"
"net/http/httputil"
"net/url"
"opencatd-open/pkg/tokenizer"
"opencatd-open/store"
"strconv"
"github.com/duke-git/lancet/v2/slice"
"github.com/gin-gonic/gin"
)
const (
	// DalleEndpoint is the image-generation endpoint; the edit/variation
	// endpoints below are defined but not currently proxied.
	DalleEndpoint          = "https://api.openai.com/v1/images/generations"
	DalleEditEndpoint      = "https://api.openai.com/v1/images/edits"
	DalleVariationEndpoint = "https://api.openai.com/v1/images/variations"
)

// DallERequest is the inbound image-generation request. N and Size also
// carry `form` tags so they bind from form posts as well as JSON.
type DallERequest struct {
	Model          string `json:"model"`
	Prompt         string `json:"prompt"`
	N              int    `form:"n" json:"n,omitempty"`
	Size           string `form:"size" json:"size,omitempty"`
	Quality        string `json:"quality,omitempty"`         // standard,hd
	Style          string `json:"style,omitempty"`           // vivid,natural
	ResponseFormat string `json:"response_format,omitempty"` // url or b64_json
}
// DallEProxy validates an image-generation request, forwards it to the
// OpenAI images endpoint through a reverse proxy, and records per-image
// usage on success. The pricing key is "model.size[.hd]".
func DallEProxy(c *gin.Context) {
	var dalleRequest DallERequest
	if err := c.ShouldBind(&dalleRequest); err != nil {
		c.JSON(400, gin.H{"error": err.Error()})
		return
	}
	if dalleRequest.N == 0 {
		dalleRequest.N = 1
	}
	if dalleRequest.Size == "" {
		// Fix: dall-e-3 does not accept 512x512; default it to its
		// smallest supported size instead.
		if dalleRequest.Model == "dall-e-3" {
			dalleRequest.Size = "1024x1024"
		} else {
			dalleRequest.Size = "512x512"
		}
	}
	model := dalleRequest.Model
	var chatlog store.Tokens
	// One "completion" per generated image; PromptCount stays 0 because
	// image pricing is per image, not per token.
	chatlog.CompletionCount = dalleRequest.N
	if model == "dall-e" {
		model = "dall-e-2" // "dall-e" is an alias for dall-e-2
	}
	model = model + "." + dalleRequest.Size
	if dalleRequest.Model == "dall-e-2" || dalleRequest.Model == "dall-e" {
		if !slice.Contain([]string{"256x256", "512x512", "1024x1024"}, dalleRequest.Size) {
			c.JSON(http.StatusBadRequest, gin.H{
				"error": gin.H{
					"message": fmt.Sprintf("Invalid size: %s for %s", dalleRequest.Size, dalleRequest.Model),
				},
			})
			return
		}
	} else if dalleRequest.Model == "dall-e-3" {
		// Fix: dall-e-3 supports only these sizes; the original list also
		// accepted 256x256/512x512, which the upstream API rejects.
		if !slice.Contain([]string{"1024x1024", "1792x1024", "1024x1792"}, dalleRequest.Size) {
			c.JSON(http.StatusBadRequest, gin.H{
				"error": gin.H{
					"message": fmt.Sprintf("Invalid size: %s for %s", dalleRequest.Size, dalleRequest.Model),
				},
			})
			return
		}
		if dalleRequest.Quality == "hd" {
			model = model + ".hd" // hd quality has its own price point
		}
	} else {
		c.JSON(http.StatusBadRequest, gin.H{
			"error": gin.H{
				"message": fmt.Sprintf("Invalid model: %s", dalleRequest.Model),
			},
		})
		return
	}
	chatlog.Model = model
	// "localuser" is set by the auth middleware upstream of this handler.
	token, _ := c.Get("localuser")
	lu, err := store.GetUserByToken(token.(string))
	if err != nil {
		c.JSON(http.StatusInternalServerError, gin.H{
			"error": gin.H{
				"message": err.Error(),
			},
		})
		return
	}
	chatlog.UserID = int(lu.ID)
	key, err := store.SelectKeyCache("openai")
	if err != nil {
		c.JSON(http.StatusInternalServerError, gin.H{
			"error": gin.H{
				"message": err.Error(),
			},
		})
		return
	}
	targetURL, _ := url.Parse(DalleEndpoint)
	proxy := httputil.NewSingleHostReverseProxy(targetURL)
	proxy.Director = func(req *http.Request) {
		req.Header.Set("Authorization", "Bearer "+key.Key)
		req.Header.Set("Content-Type", "application/json")
		req.Host = targetURL.Host
		req.URL.Scheme = targetURL.Scheme
		req.URL.Host = targetURL.Host
		req.URL.Path = targetURL.Path
		req.URL.RawPath = targetURL.RawPath
		req.URL.RawQuery = targetURL.RawQuery
		// Re-serialize the (possibly defaulted) request as the upstream body.
		bytebody, _ := json.Marshal(dalleRequest)
		req.Body = io.NopCloser(bytes.NewBuffer(bytebody))
		req.ContentLength = int64(len(bytebody))
		req.Header.Set("Content-Length", strconv.Itoa(len(bytebody)))
	}
	proxy.ModifyResponse = func(resp *http.Response) error {
		// Only bill successful generations.
		if resp.StatusCode == http.StatusOK {
			chatlog.TotalTokens = chatlog.PromptCount + chatlog.CompletionCount
			chatlog.Cost = fmt.Sprintf("%.6f", tokenizer.Cost(chatlog.Model, chatlog.PromptCount, chatlog.CompletionCount))
			if err := store.Record(&chatlog); err != nil {
				log.Println(err)
			}
			if err := store.SumDaily(chatlog.UserID); err != nil {
				log.Println(err)
			}
		}
		return nil
	}
	proxy.ServeHTTP(c.Writer, c.Request)
}

215
llm/openai/handle_proxy.go Normal file
View File

@@ -0,0 +1,215 @@
package openai
import (
"bufio"
"bytes"
"encoding/json"
"fmt"
"io"
"log"
"net/http"
"opencatd-open/pkg/tokenizer"
"opencatd-open/store"
"strings"
"github.com/gin-gonic/gin"
)
// ChatProxy forwards a chat-completions request to the configured upstream
// (OpenAI, Azure, GitHub Models, or a gateway/custom endpoint), relays the
// response — SSE stream or plain JSON — back to the client, and records
// token usage computed locally with the tokenizer package.
func ChatProxy(c *gin.Context, chatReq *ChatCompletionRequest) {
	usagelog := store.Tokens{Model: chatReq.Model}
	// "localuser" is set by the auth middleware upstream of this handler.
	token, _ := c.Get("localuser")
	lu, err := store.GetUserByToken(token.(string))
	if err != nil {
		c.JSON(http.StatusInternalServerError, gin.H{
			"error": gin.H{
				"message": err.Error(),
			},
		})
		return
	}
	usagelog.UserID = int(lu.ID)
	// Rebuild a textual view of the conversation so prompt tokens can be
	// counted locally.
	var prompt string
	for _, msg := range chatReq.Messages {
		switch ct := msg.Content.(type) {
		case string:
			prompt += "<" + msg.Role + ">: " + ct + "\n"
		case []any:
			// Multimodal content: count text parts; image parts are only
			// logged (no token accounting for images here).
			for _, item := range ct {
				if m, ok := item.(map[string]interface{}); ok {
					if m["type"] == "text" {
						prompt += "<" + msg.Role + ">: " + m["text"].(string) + "\n"
					} else if m["type"] == "image_url" {
						if url, ok := m["image_url"].(map[string]interface{}); ok {
							fmt.Printf(" URL: %v\n", url["url"])
							if strings.HasPrefix(url["url"].(string), "http") {
								fmt.Println("网络图片:", url["url"].(string))
							}
						}
					}
				}
			}
		default:
			c.JSON(http.StatusInternalServerError, gin.H{
				"error": gin.H{
					"message": "Invalid content type",
				},
			})
			return
		}
	}
	// Fix: append the tool definitions exactly once. The original did this
	// inside the message loop, counting the tools JSON once per message and
	// inflating PromptCount.
	if len(chatReq.Tools) > 0 {
		tooljson, _ := json.Marshal(chatReq.Tools)
		prompt += "<tools>: " + string(tooljson) + "\n"
	}
	switch chatReq.Model {
	case "gpt-4o", "gpt-4o-mini", "chatgpt-4o-latest":
		chatReq.MaxTokens = 16384
	}
	if chatReq.Stream {
		// Ask the server to include usage in the final stream chunk.
		chatReq.StreamOptions = &StreamOption{IncludeUsage: true}
	}
	usagelog.PromptCount = tokenizer.NumTokensFromStr(prompt, chatReq.Model)
	// onekey, err := store.SelectKeyCache("openai")
	onekey, err := store.SelectKeyCacheByModel(chatReq.Model)
	if err != nil {
		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
		return
	}
	// Resolve the upstream URL for the key's API type, then build the
	// request once.
	var buildurl string
	switch onekey.ApiType {
	case "github":
		buildurl = Github_Marketplace
	case "azure":
		if onekey.EndPoint != "" {
			buildurl = fmt.Sprintf("%s/openai/deployments/%s/chat/completions?api-version=%s", onekey.EndPoint, modelmap(chatReq.Model), AzureApiVersion)
		} else {
			buildurl = fmt.Sprintf("https://%s.openai.azure.com/openai/deployments/%s/chat/completions?api-version=%s", onekey.ResourceNmae, modelmap(chatReq.Model), AzureApiVersion)
		}
	default:
		buildurl = OpenAI_Endpoint // default endpoint
		if AIGateWay_Endpoint != "" { // cloudflare AI gateway endpoint
			buildurl = AIGateWay_Endpoint
		}
		if Custom_Endpoint != "" { // custom endpoint override
			buildurl = Custom_Endpoint
		}
		if onekey.EndPoint != "" { // the key's own endpoint takes priority
			buildurl = onekey.EndPoint + c.Request.RequestURI
		}
	}
	// Fix: check the NewRequest error before touching req.Header — the
	// original assigned req.Header first and would panic on a nil req.
	req, err := http.NewRequest(c.Request.Method, buildurl, bytes.NewReader(chatReq.ToByteJson()))
	if err != nil {
		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
		return
	}
	req.Header = c.Request.Header
	if onekey.ApiType == "azure" {
		req.Header.Set("api-key", onekey.Key)
	} else {
		req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", onekey.Key))
	}
	resp, err := http.DefaultClient.Do(req)
	if err != nil {
		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
		return
	}
	defer resp.Body.Close()
	// result accumulates the assistant output text so completion tokens can
	// be counted locally.
	var result string
	if chatReq.Stream {
		for key, value := range resp.Header {
			for _, v := range value {
				c.Writer.Header().Add(key, v)
			}
		}
		c.Writer.WriteHeader(resp.StatusCode)
		// TeeReader relays the raw SSE bytes to the client while this
		// handler parses the same bytes for token accounting.
		teeReader := io.TeeReader(resp.Body, c.Writer)
		scanner := bufio.NewScanner(teeReader)
		for scanner.Scan() {
			line := scanner.Bytes()
			if len(line) > 0 && bytes.HasPrefix(line, []byte("data: ")) {
				if bytes.HasPrefix(line, []byte("data: [DONE]")) {
					break
				}
				var opiResp ChatCompletionStreamResponse
				line = bytes.Replace(line, []byte("data: "), []byte(""), -1)
				line = bytes.TrimSpace(line)
				if err := json.Unmarshal(line, &opiResp); err != nil {
					continue // skip malformed chunks
				}
				if len(opiResp.Choices) > 0 {
					if opiResp.Choices[0].Delta.Role != "" {
						result += "<" + opiResp.Choices[0].Delta.Role + "> "
					}
					result += opiResp.Choices[0].Delta.Content // content tokens
					if len(opiResp.Choices[0].Delta.ToolCalls) > 0 { // tool-call tokens
						if opiResp.Choices[0].Delta.ToolCalls[0].Function.Name != "" {
							result += "name:" + opiResp.Choices[0].Delta.ToolCalls[0].Function.Name + " arguments:"
						}
						result += opiResp.Choices[0].Delta.ToolCalls[0].Function.Arguments
					}
				}
			}
		}
	} else {
		// Non-streaming: read, parse, and re-emit the whole JSON body.
		body, err := io.ReadAll(resp.Body)
		if err != nil {
			fmt.Println("Error reading response body:", err)
			c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
			return
		}
		var opiResp ChatCompletionResponse
		if err := json.Unmarshal(body, &opiResp); err != nil {
			log.Println("Error parsing JSON:", err)
			c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
			return
		}
		if len(opiResp.Choices) > 0 {
			if opiResp.Choices[0].Message.Role != "" {
				result += "<" + opiResp.Choices[0].Message.Role + "> "
			}
			result += opiResp.Choices[0].Message.Content
			if len(opiResp.Choices[0].Message.ToolCalls) > 0 {
				if opiResp.Choices[0].Message.ToolCalls[0].Function.Name != "" {
					result += "name:" + opiResp.Choices[0].Message.ToolCalls[0].Function.Name + " arguments:"
				}
				result += opiResp.Choices[0].Message.ToolCalls[0].Function.Arguments
			}
		}
		for k, v := range resp.Header {
			c.Writer.Header().Set(k, v[0])
		}
		c.JSON(http.StatusOK, opiResp)
	}
	usagelog.CompletionCount = tokenizer.NumTokensFromStr(result, chatReq.Model)
	usagelog.Cost = fmt.Sprintf("%.6f", tokenizer.Cost(usagelog.Model, usagelog.PromptCount, usagelog.CompletionCount))
	if err := store.Record(&usagelog); err != nil {
		log.Println(err)
	}
	if err := store.SumDaily(usagelog.UserID); err != nil {
		log.Println(err)
	}
}

197
llm/openai/realtime.go Normal file
View File

@@ -0,0 +1,197 @@
/*
https://platform.openai.com/docs/guides/realtime
https://learn.microsoft.com/zh-cn/azure/ai-services/openai/how-to/audio-real-time
wss://my-eastus2-openai-resource.openai.azure.com/openai/realtime?api-version=2024-10-01-preview&deployment=gpt-4o-realtime-preview-1001
*/
package openai
import (
"context"
"encoding/json"
"fmt"
"log"
"net/http"
"net/url"
"opencatd-open/pkg/tokenizer"
"opencatd-open/store"
"os"
"github.com/gin-gonic/gin"
"github.com/gorilla/websocket"
"golang.org/x/sync/errgroup"
)
// realtimeURL is the OpenAI realtime WebSocket base; a "?model=..." query
// is appended per request, e.g.
// "wss://api.openai.com/v1/realtime?model=gpt-4o-realtime-preview-2024-10-01".
const realtimeURL = "wss://api.openai.com/v1/realtime"

// azureRealtimeURL is the Azure realtime template; %s is the resource name.
const azureRealtimeURL = "wss://%s.openai.azure.com/openai/realtime?api-version=2024-10-01-preview&deployment=gpt-4o-realtime-preview"

// upgrader accepts every Origin — this proxy performs its own auth via the
// "localuser" middleware, so CSWSH protection is intentionally disabled here.
var upgrader = websocket.Upgrader{
	CheckOrigin: func(r *http.Request) bool {
		return true
	},
}
// Message is a client-side realtime event carrying a response request.
type Message struct {
	Type     string   `json:"type"`
	Response Response `json:"response"`
}

// Response describes the requested output modalities and instructions.
type Response struct {
	Modalities   []string `json:"modalities"` // e.g. "text", "audio" — TODO confirm
	Instructions string   `json:"instructions"`
}
// RealTimeResponse mirrors a server realtime event; token usage arrives
// nested under "response.usage" (see forwardMessages).
type RealTimeResponse struct {
	Type     string `json:"type"` // event name, e.g. "response.done"
	EventID  string `json:"event_id"`
	Response struct {
		Object        string `json:"object"`
		ID            string `json:"id"`
		Status        string `json:"status"`
		StatusDetails any    `json:"status_details"`
		Output        []struct {
			ID      string `json:"id"`
			Object  string `json:"object"`
			Type    string `json:"type"`
			Status  string `json:"status"`
			Role    string `json:"role"`
			Content []struct {
				Type       string `json:"type"`
				Transcript string `json:"transcript"`
			} `json:"content"`
		} `json:"output"`
		Usage Usage `json:"usage"`
	} `json:"response"`
}

// Usage is the realtime token-usage breakdown, split by text/audio.
type Usage struct {
	TotalTokens       int `json:"total_tokens"`
	InputTokens       int `json:"input_tokens"`
	OutputTokens      int `json:"output_tokens"`
	InputTokenDetails struct {
		CachedTokens int `json:"cached_tokens"`
		TextTokens   int `json:"text_tokens"`
		AudioTokens  int `json:"audio_tokens"`
	} `json:"input_token_details"`
	OutputTokenDetails struct {
		TextTokens  int `json:"text_tokens"`
		AudioTokens int `json:"audio_tokens"`
	} `json:"output_token_details"`
}
// RealTimeProxy upgrades the inbound connection to a WebSocket and bridges
// it to the OpenAI (or Azure) realtime endpoint, forwarding frames in both
// directions until either side closes or the request context ends.
func RealTimeProxy(c *gin.Context) {
	log.Println(c.Request.URL.String())
	var model string = c.Query("model")
	value := url.Values{}
	value.Add("model", model)
	realtimeURL := realtimeURL + "?" + value.Encode()
	// Upgrade the inbound HTTP connection to a WebSocket.
	clientConn, err := upgrader.Upgrade(c.Writer, c.Request, nil)
	if err != nil {
		log.Println("Upgrade error:", err)
		return
	}
	defer clientConn.Close()
	apikey, err := store.SelectKeyCacheByModel(model)
	if err != nil {
		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
		return
	}
	// Connect to the upstream realtime WebSocket.
	headers := http.Header{"OpenAI-Beta": []string{"realtime=v1"}}
	if apikey.ApiType == "azure" {
		headers.Set("api-key", apikey.Key)
		if apikey.EndPoint != "" {
			realtimeURL = fmt.Sprintf("%s/openai/realtime?api-version=2024-10-01-preview&deployment=gpt-4o-realtime-preview", apikey.EndPoint)
		} else {
			realtimeURL = fmt.Sprintf(azureRealtimeURL, apikey.ResourceNmae)
		}
	} else {
		headers.Set("Authorization", "Bearer "+apikey.Key)
	}
	// Fix: copy the default dialer by value. The original assigned the
	// *pointer* websocket.DefaultDialer and then set .Proxy on it, mutating
	// the shared global dialer for every other caller in the process.
	dialer := *websocket.DefaultDialer
	if os.Getenv("LOCAL_PROXY") != "" {
		if proxyUrl, perr := url.Parse(os.Getenv("LOCAL_PROXY")); perr == nil {
			dialer.Proxy = http.ProxyURL(proxyUrl)
		}
	}
	openAIConn, _, err := dialer.Dial(realtimeURL, headers)
	if err != nil {
		log.Println("OpenAI dial error:", err)
		return
	}
	defer openAIConn.Close()
	ctx, cancel := context.WithCancel(c.Request.Context())
	defer cancel()
	// Pump frames both ways; the first error (or close) cancels the peer.
	g, ctx := errgroup.WithContext(ctx)
	g.Go(func() error {
		return forwardMessages(ctx, c, clientConn, openAIConn)
	})
	g.Go(func() error {
		return forwardMessages(ctx, c, openAIConn, clientConn)
	})
	if err := g.Wait(); err != nil {
		log.Println("Error in message forwarding:", err)
		return
	}
}
// forwardMessages copies WebSocket frames from src to dst until the
// context is cancelled or the connection closes, accumulating realtime
// token usage along the way and recording it on exit.
func forwardMessages(ctx context.Context, c *gin.Context, src, dst *websocket.Conn) error {
	usagelog := store.Tokens{Model: "gpt-4o-realtime-preview"}
	token, _ := c.Get("localuser")
	lu, err := store.GetUserByToken(token.(string))
	if err != nil {
		return err
	}
	usagelog.UserID = int(lu.ID)
	// Fix: register the usage-recording defer BEFORE the forwarding loop.
	// The original placed this defer after the infinite for-loop, where it
	// was unreachable (the loop only exits via return), so usage was never
	// recorded.
	defer func() {
		usagelog.Cost = fmt.Sprintf("%.6f", tokenizer.Cost(usagelog.Model, usagelog.PromptCount, usagelog.CompletionCount))
		if err := store.Record(&usagelog); err != nil {
			log.Println(err)
		}
		if err := store.SumDaily(usagelog.UserID); err != nil {
			log.Println(err)
		}
	}()
	for {
		select {
		case <-ctx.Done():
			return ctx.Err()
		default:
			messageType, message, err := src.ReadMessage()
			if err != nil {
				if websocket.IsCloseError(err, websocket.CloseNormalClosure, websocket.CloseGoingAway) {
					return nil // normal close, not an error
				}
				return err
			}
			if messageType == websocket.TextMessage {
				// Fix: realtime usage arrives nested under "response.usage"
				// in the "response.done" event; the original unmarshalled
				// the top-level message into Usage directly, which always
				// yielded zeros. NOTE(review): event name per OpenAI
				// realtime docs — confirm against the live API.
				var rt RealTimeResponse
				if err := json.Unmarshal(message, &rt); err == nil && rt.Type == "response.done" {
					usagelog.PromptCount += rt.Response.Usage.InputTokens
					usagelog.CompletionCount += rt.Response.Usage.OutputTokens
				}
			}
			if err := dst.WriteMessage(messageType, message); err != nil {
				return err
			}
		}
	}
}

93
llm/openai/tts.go Normal file
View File

@@ -0,0 +1,93 @@
package openai
import (
"bytes"
"encoding/json"
"fmt"
"io"
"log"
"net/http"
"net/http/httputil"
"net/url"
"opencatd-open/pkg/tokenizer"
"opencatd-open/store"
"github.com/gin-gonic/gin"
)
const (
	// SpeechEndpoint is the OpenAI text-to-speech endpoint.
	SpeechEndpoint = "https://api.openai.com/v1/audio/speech"
)

// SpeechRequest is the inbound TTS request body.
type SpeechRequest struct {
	Model string `json:"model"`
	Input string `json:"input"` // text to synthesize
	Voice string `json:"voice"`
}
// SpeechProxy forwards a text-to-speech request to the OpenAI speech
// endpoint through a reverse proxy and records usage on success.
func SpeechProxy(c *gin.Context) {
	var chatreq SpeechRequest
	if err := c.ShouldBindJSON(&chatreq); err != nil {
		c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
		return
	}
	var chatlog store.Tokens
	chatlog.Model = chatreq.Model
	// NOTE(review): len() counts bytes, not characters; for non-ASCII input
	// this overstates the billed character count — confirm pricing unit.
	chatlog.CompletionCount = len(chatreq.Input)
	// "localuser" is set by the auth middleware upstream of this handler.
	token, _ := c.Get("localuser")
	lu, err := store.GetUserByToken(token.(string))
	if err != nil {
		c.JSON(http.StatusInternalServerError, gin.H{
			"error": gin.H{
				"message": err.Error(),
			},
		})
		return
	}
	chatlog.UserID = int(lu.ID)
	key, err := store.SelectKeyCache("openai")
	if err != nil {
		c.JSON(http.StatusInternalServerError, gin.H{
			"error": gin.H{
				"message": err.Error(),
			},
		})
		return
	}
	targetURL, _ := url.Parse(SpeechEndpoint)
	proxy := httputil.NewSingleHostReverseProxy(targetURL)
	proxy.Director = func(req *http.Request) {
		req.Header = c.Request.Header
		req.Header["Authorization"] = []string{"Bearer " + key.Key}
		req.Host = targetURL.Host
		req.URL.Scheme = targetURL.Scheme
		req.URL.Host = targetURL.Host
		req.URL.Path = targetURL.Path
		req.URL.RawPath = targetURL.RawPath
		// Re-serialize the bound request as the upstream body.
		reqBytes, _ := json.Marshal(chatreq)
		req.Body = io.NopCloser(bytes.NewReader(reqBytes))
		req.ContentLength = int64(len(reqBytes))
	}
	proxy.ModifyResponse = func(resp *http.Response) error {
		// Only bill successful synthesis.
		if resp.StatusCode == http.StatusOK {
			chatlog.TotalTokens = chatlog.PromptCount + chatlog.CompletionCount
			chatlog.Cost = fmt.Sprintf("%.6f", tokenizer.Cost(chatlog.Model, chatlog.PromptCount, chatlog.CompletionCount))
			if err := store.Record(&chatlog); err != nil {
				log.Println(err)
			}
			if err := store.SumDaily(chatlog.UserID); err != nil {
				log.Println(err)
			}
		}
		return nil
	}
	proxy.ServeHTTP(c.Writer, c.Request)
}

177
llm/openai/whisper.go Normal file
View File

@@ -0,0 +1,177 @@
package openai
import (
"bytes"
"context"
"errors"
"fmt"
"io"
"log"
"mime/multipart"
"net/http"
"net/http/httputil"
"net/url"
"opencatd-open/pkg/tokenizer"
"opencatd-open/store"
"path/filepath"
"time"
"github.com/faiface/beep"
"github.com/faiface/beep/mp3"
"github.com/faiface/beep/wav"
"github.com/gin-gonic/gin"
"gopkg.in/vansante/go-ffprobe.v2"
)
// WhisperProxy forwards an audio-transcription request (multipart form) to
// the configured OpenAI-compatible endpoint and records usage; the billed
// "prompt count" is the audio duration in seconds (see
// ParseWhisperRequestTokens).
func WhisperProxy(c *gin.Context) {
	var chatlog store.Tokens
	// Buffer the multipart body so it can be parsed here for token
	// accounting and then replayed to the upstream.
	byteBody, _ := io.ReadAll(c.Request.Body)
	c.Request.Body = io.NopCloser(bytes.NewBuffer(byteBody))
	model, _ := c.GetPostForm("model")
	key, err := store.SelectKeyCache("openai")
	if err != nil {
		c.JSON(http.StatusInternalServerError, gin.H{
			"error": gin.H{
				"message": err.Error(),
			},
		})
		return
	}
	chatlog.Model = model
	// "localuser" is set by the auth middleware upstream of this handler.
	token, _ := c.Get("localuser")
	lu, err := store.GetUserByToken(token.(string))
	if err != nil {
		c.JSON(http.StatusInternalServerError, gin.H{
			"error": gin.H{
				"message": err.Error(),
			},
		})
		return
	}
	chatlog.UserID = int(lu.ID)
	// Measures the uploaded audio's duration and restores the request body.
	if err := ParseWhisperRequestTokens(c, &chatlog, byteBody); err != nil {
		c.JSON(http.StatusInternalServerError, gin.H{
			"error": gin.H{
				"message": err.Error(),
			},
		})
		return
	}
	if key.EndPoint == "" {
		key.EndPoint = "https://api.openai.com"
	}
	// The original request path (e.g. /v1/audio/transcriptions) is kept.
	targetUrl, _ := url.ParseRequestURI(key.EndPoint + c.Request.URL.String())
	log.Println(targetUrl)
	proxy := httputil.NewSingleHostReverseProxy(targetUrl)
	proxy.Director = func(req *http.Request) {
		req.Host = targetUrl.Host
		req.URL.Scheme = targetUrl.Scheme
		req.URL.Host = targetUrl.Host
		req.Header.Set("Authorization", "Bearer "+key.Key)
	}
	proxy.ModifyResponse = func(resp *http.Response) error {
		// Only bill successful transcriptions.
		if resp.StatusCode != http.StatusOK {
			return nil
		}
		chatlog.TotalTokens = chatlog.PromptCount + chatlog.CompletionCount
		chatlog.Cost = fmt.Sprintf("%.6f", tokenizer.Cost(chatlog.Model, chatlog.PromptCount, chatlog.CompletionCount))
		if err := store.Record(&chatlog); err != nil {
			log.Println(err)
		}
		if err := store.SumDaily(chatlog.UserID); err != nil {
			log.Println(err)
		}
		return nil
	}
	proxy.ServeHTTP(c.Writer, c.Request)
}
// probe returns the duration of the audio stream read from fileReader,
// using ffprobe with a 5-second timeout.
func probe(fileReader io.Reader) (time.Duration, error) {
	ctx, cancelFn := context.WithTimeout(context.Background(), 5*time.Second)
	defer cancelFn()
	data, err := ffprobe.ProbeReader(ctx, fileReader)
	if err != nil {
		return 0, err
	}
	// Convert the float seconds directly instead of formatting a string and
	// re-parsing it with time.ParseDuration (the original round-tripped
	// through fmt.Sprintf("%fs", ...)).
	return time.Duration(data.Format.DurationSeconds * float64(time.Second)), nil
}
// getAudioDuration returns the playing time of an uploaded audio file.
// MP3 and WAV are decoded with beep; M4A falls back to ffprobe; any other
// extension is rejected.
func getAudioDuration(file *multipart.FileHeader) (time.Duration, error) {
	f, err := file.Open()
	if err != nil {
		// Fix: the original never checked this error and registered
		// `defer f.Close()` first, which would dereference a nil file.
		return 0, err
	}
	defer f.Close()
	var (
		streamer beep.StreamSeekCloser
		format   beep.Format
	)
	// The file extension selects the decoder.
	switch filepath.Ext(file.Filename) {
	case ".mp3":
		streamer, format, err = mp3.Decode(f)
	case ".wav":
		streamer, format, err = wav.Decode(f)
	case ".m4a":
		// beep has no m4a decoder; measure with ffprobe instead.
		return probe(f)
	default:
		return 0, errors.New("unsupported audio file format")
	}
	if err != nil {
		return 0, err
	}
	defer streamer.Close()
	// duration = samples / sample-rate.
	numSamples := streamer.Len()
	sampleRate := format.SampleRate
	return time.Duration(numSamples) * time.Second / time.Duration(sampleRate), nil
}
// ParseWhisperRequestTokens inspects the multipart transcription request,
// sets usage.Model and usage.PromptCount (audio duration in whole seconds),
// enforces a 5-minute cap, and restores the request body from byteBody so
// the proxy can replay it upstream.
func ParseWhisperRequestTokens(c *gin.Context, usage *store.Tokens, byteBody []byte) error {
	// Errors are deliberately ignored: a missing file or model simply
	// leaves PromptCount/Model at their zero values.
	file, _ := c.FormFile("file")
	model, _ := c.GetPostForm("model")
	usage.Model = model
	if file != nil {
		duration, err := getAudioDuration(file)
		if err != nil {
			return fmt.Errorf("Error getting audio duration:%s", err)
		}
		if duration > 5*time.Minute {
			return fmt.Errorf("Audio duration exceeds 5 minutes")
		}
		// 计算时长,四舍五入到最接近的秒数
		// Bill by duration, rounded to the nearest whole second.
		usage.PromptCount = int(duration.Round(time.Second).Seconds())
	}
	c.Request.Body = io.NopCloser(bytes.NewBuffer(byteBody))
	return nil
}