refact & update new model
This commit is contained in:
@@ -17,8 +17,10 @@ import (
|
||||
)
|
||||
|
||||
const (
|
||||
AzureApiVersion = "2024-02-01"
|
||||
OpenAI_Endpoint = "https://api.openai.com/v1/chat/completions"
|
||||
AzureApiVersion = "2024-02-01"
|
||||
BaseHost = "api.openai.com"
|
||||
OpenAI_Endpoint = "https://api.openai.com/v1/chat/completions"
|
||||
Github_Marketplace = "https://models.inference.ai.azure.com/chat/completions"
|
||||
)
|
||||
|
||||
var (
|
||||
@@ -204,6 +206,10 @@ func ChatProxy(c *gin.Context, chatReq *ChatCompletionRequest) {
|
||||
var req *http.Request
|
||||
|
||||
switch onekey.ApiType {
|
||||
case "github":
|
||||
req, err = http.NewRequest(c.Request.Method, Github_Marketplace, bytes.NewReader(chatReq.ToByteJson()))
|
||||
req.Header = c.Request.Header
|
||||
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", onekey.Key))
|
||||
case "azure":
|
||||
var buildurl string
|
||||
if onekey.EndPoint != "" {
|
||||
@@ -215,15 +221,18 @@ func ChatProxy(c *gin.Context, chatReq *ChatCompletionRequest) {
|
||||
req.Header = c.Request.Header
|
||||
req.Header.Set("api-key", onekey.Key)
|
||||
default:
|
||||
req, err = http.NewRequest(c.Request.Method, OpenAI_Endpoint, bytes.NewReader(chatReq.ToByteJson()))
|
||||
if onekey.EndPoint != "" { // 优先key的endpoint
|
||||
req, err = http.NewRequest(c.Request.Method, onekey.EndPoint+c.Request.RequestURI, bytes.NewReader(chatReq.ToByteJson()))
|
||||
} else {
|
||||
if BaseURL != "" { // 其次BaseURL
|
||||
req, err = http.NewRequest(c.Request.Method, BaseURL+c.Request.RequestURI, bytes.NewReader(chatReq.ToByteJson()))
|
||||
} else { // 最后是gateway的endpoint
|
||||
req, err = http.NewRequest(c.Request.Method, AIGateWay_Endpoint, bytes.NewReader(chatReq.ToByteJson()))
|
||||
}
|
||||
}
|
||||
if AIGateWay_Endpoint != "" { // cloudflare gateway的endpoint
|
||||
req, err = http.NewRequest(c.Request.Method, AIGateWay_Endpoint, bytes.NewReader(chatReq.ToByteJson()))
|
||||
}
|
||||
customEndpoint := os.Getenv("CUSTOM_ENDPOINT") // 最后是用户自定义的endpoint CUSTOM_ENDPOINT=true OpenAI_Endpoint
|
||||
if customEndpoint == "true" && OpenAI_Endpoint != "" {
|
||||
req, err = http.NewRequest(c.Request.Method, BaseURL, bytes.NewReader(chatReq.ToByteJson()))
|
||||
}
|
||||
|
||||
req.Header = c.Request.Header
|
||||
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", onekey.Key))
|
||||
}
|
||||
|
||||
177
pkg/openai/whisper.go
Normal file
177
pkg/openai/whisper.go
Normal file
@@ -0,0 +1,177 @@
|
||||
package openai
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"mime/multipart"
|
||||
"net/http"
|
||||
"net/http/httputil"
|
||||
"net/url"
|
||||
"opencatd-open/pkg/tokenizer"
|
||||
"opencatd-open/store"
|
||||
"path/filepath"
|
||||
"time"
|
||||
|
||||
"github.com/faiface/beep"
|
||||
"github.com/faiface/beep/mp3"
|
||||
"github.com/faiface/beep/wav"
|
||||
"github.com/gin-gonic/gin"
|
||||
"gopkg.in/vansante/go-ffprobe.v2"
|
||||
)
|
||||
|
||||
func WhisperProxy(c *gin.Context) {
|
||||
var chatlog store.Tokens
|
||||
|
||||
byteBody, _ := io.ReadAll(c.Request.Body)
|
||||
c.Request.Body = io.NopCloser(bytes.NewBuffer(byteBody))
|
||||
|
||||
model, _ := c.GetPostForm("model")
|
||||
|
||||
key, err := store.SelectKeyCache("openai")
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{
|
||||
"error": gin.H{
|
||||
"message": err.Error(),
|
||||
},
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
chatlog.Model = model
|
||||
|
||||
token, _ := c.Get("localuser")
|
||||
|
||||
lu, err := store.GetUserByToken(token.(string))
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{
|
||||
"error": gin.H{
|
||||
"message": err.Error(),
|
||||
},
|
||||
})
|
||||
return
|
||||
}
|
||||
chatlog.UserID = int(lu.ID)
|
||||
|
||||
if err := ParseWhisperRequestTokens(c, &chatlog, byteBody); err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{
|
||||
"error": gin.H{
|
||||
"message": err.Error(),
|
||||
},
|
||||
})
|
||||
return
|
||||
}
|
||||
if key.EndPoint == "" {
|
||||
key.EndPoint = "https://api.openai.com"
|
||||
}
|
||||
targetUrl, _ := url.ParseRequestURI(key.EndPoint + c.Request.URL.String())
|
||||
log.Println(targetUrl)
|
||||
proxy := httputil.NewSingleHostReverseProxy(targetUrl)
|
||||
proxy.Director = func(req *http.Request) {
|
||||
req.Host = targetUrl.Host
|
||||
req.URL.Scheme = targetUrl.Scheme
|
||||
req.URL.Host = targetUrl.Host
|
||||
|
||||
req.Header.Set("Authorization", "Bearer "+key.Key)
|
||||
}
|
||||
|
||||
proxy.ModifyResponse = func(resp *http.Response) error {
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil
|
||||
}
|
||||
chatlog.TotalTokens = chatlog.PromptCount + chatlog.CompletionCount
|
||||
chatlog.Cost = fmt.Sprintf("%.6f", tokenizer.Cost(chatlog.Model, chatlog.PromptCount, chatlog.CompletionCount))
|
||||
if err := store.Record(&chatlog); err != nil {
|
||||
log.Println(err)
|
||||
}
|
||||
if err := store.SumDaily(chatlog.UserID); err != nil {
|
||||
log.Println(err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
proxy.ServeHTTP(c.Writer, c.Request)
|
||||
}
|
||||
|
||||
func probe(fileReader io.Reader) (time.Duration, error) {
|
||||
ctx, cancelFn := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancelFn()
|
||||
|
||||
data, err := ffprobe.ProbeReader(ctx, fileReader)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
duration := data.Format.DurationSeconds
|
||||
pduration, err := time.ParseDuration(fmt.Sprintf("%fs", duration))
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("Error parsing duration: %s", err)
|
||||
}
|
||||
return pduration, nil
|
||||
}
|
||||
|
||||
func getAudioDuration(file *multipart.FileHeader) (time.Duration, error) {
|
||||
var (
|
||||
streamer beep.StreamSeekCloser
|
||||
format beep.Format
|
||||
err error
|
||||
)
|
||||
|
||||
f, err := file.Open()
|
||||
defer f.Close()
|
||||
|
||||
// Get the file extension to determine the audio file type
|
||||
fileType := filepath.Ext(file.Filename)
|
||||
|
||||
switch fileType {
|
||||
case ".mp3":
|
||||
streamer, format, err = mp3.Decode(f)
|
||||
case ".wav":
|
||||
streamer, format, err = wav.Decode(f)
|
||||
case ".m4a":
|
||||
duration, err := probe(f)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return duration, nil
|
||||
default:
|
||||
return 0, errors.New("unsupported audio file format")
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
defer streamer.Close()
|
||||
|
||||
// Calculate the audio file's duration.
|
||||
numSamples := streamer.Len()
|
||||
sampleRate := format.SampleRate
|
||||
duration := time.Duration(numSamples) * time.Second / time.Duration(sampleRate)
|
||||
|
||||
return duration, nil
|
||||
}
|
||||
|
||||
func ParseWhisperRequestTokens(c *gin.Context, usage *store.Tokens, byteBody []byte) error {
|
||||
file, _ := c.FormFile("file")
|
||||
model, _ := c.GetPostForm("model")
|
||||
usage.Model = model
|
||||
|
||||
if file != nil {
|
||||
duration, err := getAudioDuration(file)
|
||||
if err != nil {
|
||||
return fmt.Errorf("Error getting audio duration:%s", err)
|
||||
}
|
||||
|
||||
if duration > 5*time.Minute {
|
||||
return fmt.Errorf("Audio duration exceeds 5 minutes")
|
||||
}
|
||||
// 计算时长,四舍五入到最接近的秒数
|
||||
usage.PromptCount = int(duration.Round(time.Second).Seconds())
|
||||
}
|
||||
|
||||
c.Request.Body = io.NopCloser(bytes.NewBuffer(byteBody))
|
||||
|
||||
return nil
|
||||
}
|
||||
Reference in New Issue
Block a user