This commit is contained in:
Sakurasan
2023-08-10 22:54:56 +08:00
parent adbc388920
commit a0155fc0b1
6 changed files with 225 additions and 29 deletions

View File

@@ -3,27 +3,35 @@ package router
import (
"bufio"
"bytes"
"context"
"crypto/tls"
"encoding/json"
"errors"
"fmt"
"io"
"log"
"mime/multipart"
"net"
"net/http"
"net/http/httputil"
"net/url"
"opencatd-open/pkg/azureopenai"
"opencatd-open/store"
"os"
"path/filepath"
"strings"
"time"
"github.com/Sakurasan/to"
"github.com/duke-git/lancet/v2/cryptor"
"github.com/faiface/beep"
"github.com/faiface/beep/mp3"
"github.com/faiface/beep/wav"
"github.com/gin-gonic/gin"
"github.com/google/uuid"
"github.com/pkoukk/tiktoken-go"
"github.com/sashabaranov/go-openai"
"gopkg.in/vansante/go-ffprobe.v2"
"gorm.io/gorm"
)
@@ -435,6 +443,11 @@ func HandleProy(c *gin.Context) {
auth := c.Request.Header.Get("Authorization")
if len(auth) > 7 && auth[:7] == "Bearer " {
localuser = store.IsExistAuthCache(auth[7:])
c.Set("localuser", localuser)
}
if c.Request.URL.Path == "/v1/audio/transcriptions" {
WhisperProxy(c)
return
}
if c.Request.URL.Path == "/v1/chat/completions" && localuser {
@@ -632,6 +645,8 @@ func HandleReverseProxy(c *gin.Context) {
proxy.ServeHTTP(c.Writer, req)
}
// https://openai.com/pricing
func Cost(model string, promptCount, completionCount int) float64 {
var cost, prompt, completion float64
prompt = float64(promptCount)
@@ -648,6 +663,9 @@ func Cost(model string, promptCount, completionCount int) float64 {
cost = 0.03*float64(prompt/1000) + 0.06*float64(completion/1000)
case "gpt-4-32k", "gpt-4-32k-0613":
cost = 0.06*float64(prompt/1000) + 0.12*float64(completion/1000)
case "whisper-1":
// 0.006$/min
cost = 0.006 * float64(prompt+completion) / 60
default:
if strings.Contains(model, "gpt-3.5-turbo") {
cost = 0.003 * float64((prompt+completion)/1000)
@@ -809,3 +827,145 @@ func modelmap(in string) string {
}
return in
}
func WhisperProxy(c *gin.Context) {
var chatlog store.Tokens
byteBody, _ := io.ReadAll(c.Request.Body)
c.Request.Body = io.NopCloser(bytes.NewBuffer(byteBody))
model, _ := c.GetPostForm("model")
key, err := store.SelectKeyCache("openai")
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{
"error": gin.H{
"message": err.Error(),
},
})
return
}
chatlog.Model = model
token, _ := c.Get("localuser")
lu, err := store.GetUserByToken(token.(string))
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{
"error": gin.H{
"message": err.Error(),
},
})
return
}
chatlog.UserID = int(lu.ID)
ParseWhisperRequestTokens(c, &chatlog, byteBody)
targetUrl, _ := url.ParseRequestURI(key.EndPoint)
proxy := httputil.NewSingleHostReverseProxy(targetUrl)
proxy.Director = func(req *http.Request) {
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", "Bearer "+key.Key)
}
proxy.ModifyResponse = func(resp *http.Response) error {
if resp.StatusCode != http.StatusOK {
return nil
}
chatlog.TotalTokens = chatlog.PromptCount + chatlog.CompletionCount
chatlog.Cost = fmt.Sprintf("%.6f", Cost(chatlog.Model, chatlog.PromptCount, chatlog.CompletionCount))
if err := store.Record(&chatlog); err != nil {
log.Println(err)
}
if err := store.SumDaily(chatlog.UserID); err != nil {
log.Println(err)
}
return nil
}
proxy.ServeHTTP(c.Writer, c.Request)
return
}
func probe(fileReader io.Reader) (time.Duration, error) {
ctx, cancelFn := context.WithTimeout(context.Background(), 5*time.Second)
defer cancelFn()
data, err := ffprobe.ProbeReader(ctx, fileReader)
if err != nil {
return 0, err
}
duration := data.Format.DurationSeconds
pduration, err := time.ParseDuration(fmt.Sprintf("%fs", duration))
if err != nil {
return 0, fmt.Errorf("Error parsing duration: %s", err)
}
return pduration, nil
}
func getAudioDuration(file *multipart.FileHeader) (time.Duration, error) {
var (
streamer beep.StreamSeekCloser
format beep.Format
err error
)
f, err := file.Open()
defer f.Close()
// Get the file extension to determine the audio file type
fileType := filepath.Ext(file.Filename)
switch fileType {
case ".mp3":
streamer, format, err = mp3.Decode(f)
case ".wav":
streamer, format, err = wav.Decode(f)
case ".m4a":
duration, err := probe(f)
if err != nil {
return 0, err
}
return duration, nil
default:
return 0, errors.New("unsupported audio file format")
}
if err != nil {
return 0, err
}
defer streamer.Close()
// Calculate the audio file's duration.
numSamples := streamer.Len()
sampleRate := format.SampleRate
duration := time.Duration(numSamples) * time.Second / time.Duration(sampleRate)
return duration, nil
}
func ParseWhisperRequestTokens(c *gin.Context, usage *store.Tokens, byteBody []byte) error {
file, _ := c.FormFile("file")
model, _ := c.GetPostForm("model")
usage.Model = model
if file != nil {
duration, err := getAudioDuration(file)
if err != nil {
return fmt.Errorf("Error getting audio duration:%s", err)
}
if duration > 5*time.Minute {
return fmt.Errorf("Audio duration exceeds 5 minutes")
}
// 计算时长,四舍五入到最接近的秒数
usage.PromptCount = int(duration.Round(time.Second).Seconds())
}
c.Request.Body = io.NopCloser(bytes.NewBuffer(byteBody))
return nil
}