package router import ( "bufio" "bytes" "context" "crypto/tls" "encoding/json" "errors" "fmt" "io" "log" "mime/multipart" "net" "net/http" "net/http/httputil" "net/url" "opencatd-open/pkg/azureopenai" "opencatd-open/store" "os" "path/filepath" "strings" "time" "github.com/Sakurasan/to" "github.com/duke-git/lancet/v2/cryptor" "github.com/faiface/beep" "github.com/faiface/beep/mp3" "github.com/faiface/beep/wav" "github.com/gin-gonic/gin" "github.com/google/uuid" "github.com/pkoukk/tiktoken-go" "github.com/sashabaranov/go-openai" "gopkg.in/vansante/go-ffprobe.v2" "gorm.io/gorm" ) var ( rootToken string baseUrl = "https://api.openai.com" GPT3Dot5Turbo = "gpt-3.5-turbo" GPT4 = "gpt-4" client = getHttpClient() ) type User struct { IsDelete bool `json:"IsDelete,omitempty"` ID int `json:"id,omitempty"` UpdatedAt string `json:"updatedAt,omitempty"` Name string `json:"name,omitempty"` Token string `json:"token,omitempty"` CreatedAt string `json:"createdAt,omitempty"` } type Key struct { ID int `json:"id,omitempty"` Key string `json:"key,omitempty"` Name string `json:"name,omitempty"` ApiType string `json:"api_type,omitempty"` Endpoint string `json:"endpoint,omitempty"` UpdatedAt string `json:"updatedAt,omitempty"` CreatedAt string `json:"createdAt,omitempty"` } type ChatCompletionMessage struct { Role string `json:"role"` Content string `json:"content"` Name string `json:"name,omitempty"` } type ChatCompletionRequest struct { Model string `json:"model"` Messages []ChatCompletionMessage `json:"messages"` MaxTokens int `json:"max_tokens,omitempty"` Temperature float32 `json:"temperature,omitempty"` TopP float32 `json:"top_p,omitempty"` N int `json:"n,omitempty"` Stream bool `json:"stream,omitempty"` Stop []string `json:"stop,omitempty"` PresencePenalty float32 `json:"presence_penalty,omitempty"` FrequencyPenalty float32 `json:"frequency_penalty,omitempty"` LogitBias map[string]int `json:"logit_bias,omitempty"` User string `json:"user,omitempty"` } type ChatCompletionChoice struct { Index int `json:"index"` Message ChatCompletionMessage `json:"message"` FinishReason string `json:"finish_reason"` } type ChatCompletionResponse struct { ID string `json:"id"` Object string `json:"object"` Created int64 `json:"created"` Model string `json:"model"` Choices []ChatCompletionChoice `json:"choices"` Usage struct { PromptTokens int `json:"prompt_tokens"` CompletionTokens int `json:"completion_tokens"` TotalTokens int `json:"total_tokens"` } `json:"usage"` } func init() { if openai_endpoint := os.Getenv("openai_endpoint"); openai_endpoint != "" { log.Println(fmt.Sprintf("replace %s to %s", baseUrl, openai_endpoint)) baseUrl = openai_endpoint } } func AuthMiddleware() gin.HandlerFunc { return func(c *gin.Context) { if rootToken == "" { u, err := store.GetUserByID(uint(1)) if err != nil { c.JSON(http.StatusUnauthorized, gin.H{"error": "Unauthorized"}) c.Abort() return } rootToken = u.Token } token := c.GetHeader("Authorization") if token == "" || token[:7] != "Bearer " { c.JSON(http.StatusUnauthorized, gin.H{"error": "Unauthorized"}) c.Abort() return } if store.IsExistAuthCache(token[7:]) { if strings.HasPrefix(c.Request.URL.Path, "/1/me") { c.Next() return } } if token[7:] != rootToken { u, err := store.GetUserByID(uint(1)) if err != nil { log.Println(err) c.JSON(http.StatusUnauthorized, gin.H{"error": "Unauthorized"}) c.Abort() return } if token[:7] != u.Token { c.JSON(http.StatusUnauthorized, gin.H{"error": "Unauthorized"}) c.Abort() return } rootToken = u.Token store.LoadAuthCache() } // 可以在这里对 token 进行验证并检查权限 c.Next() } } func Handleinit(c *gin.Context) { user, err := store.GetUserByID(1) if err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { u := store.User{Name: "root", Token: uuid.NewString()} u.ID = 1 if err := store.CreateUser(&u); err != nil { c.JSON(http.StatusForbidden, gin.H{ "error": err.Error(), }) return } else { rootToken = u.Token resJSON := User{ false, int(u.ID), u.UpdatedAt.Format(time.RFC3339), u.Name, u.Token, u.CreatedAt.Format(time.RFC3339), } c.JSON(http.StatusOK, resJSON) return } } c.JSON(http.StatusOK, gin.H{ "error": err.Error(), }) return } if user.ID == uint(1) { c.JSON(http.StatusForbidden, gin.H{ "error": "super user already exists, use cli to reset password", }) } } func HandleMe(c *gin.Context) { token := c.GetHeader("Authorization") u, err := store.GetUserByToken(token[7:]) if err != nil { c.JSON(http.StatusOK, gin.H{ "error": err.Error(), }) } resJSON := User{ false, int(u.ID), u.UpdatedAt.Format(time.RFC3339), u.Name, u.Token, u.CreatedAt.Format(time.RFC3339), } c.JSON(http.StatusOK, resJSON) } func HandleMeUsage(c *gin.Context) { token := c.GetHeader("Authorization") fromStr := c.Query("from") toStr := c.Query("to") getMonthStartAndEnd := func() (start, end string) { loc, _ := time.LoadLocation("Local") now := time.Now().In(loc) year, month, _ := now.Date() startOfMonth := time.Date(year, month, 1, 0, 0, 0, 0, loc) endOfMonth := startOfMonth.AddDate(0, 1, 0) start = startOfMonth.Format("2006-01-02") end = endOfMonth.Format("2006-01-02") return } if fromStr == "" || toStr == "" { fromStr, toStr = getMonthStartAndEnd() } user, err := store.GetUserByToken(token) if err != nil { c.AbortWithError(http.StatusForbidden, err) return } usage, err := store.QueryUserUsage(to.String(user.ID), fromStr, toStr) if err != nil { c.AbortWithError(http.StatusForbidden, err) return } c.JSON(200, usage) } func HandleKeys(c *gin.Context) { keys, err := store.GetAllKeys() if err != nil { c.JSON(http.StatusOK, gin.H{ "error": err.Error(), }) } c.JSON(http.StatusOK, keys) } func HandleUsers(c *gin.Context) { users, err := store.GetAllUsers() if err != nil { c.JSON(http.StatusOK, gin.H{ "error": err.Error(), }) } c.JSON(http.StatusOK, users) } func HandleAddKey(c *gin.Context) { var body Key if err := c.BindJSON(&body); err != nil { c.JSON(http.StatusBadRequest, gin.H{"error": gin.H{ "message": err.Error(), }}) return } body.Name = strings.ToLower(strings.TrimSpace(body.Name)) body.Key = strings.TrimSpace(body.Key) if strings.HasPrefix(body.Name, "azure.") { keynames := strings.Split(body.Name, ".") if len(keynames) < 2 { c.JSON(http.StatusBadRequest, gin.H{"error": gin.H{ "message": "Invalid Key Name", }}) return } k := &store.Key{ ApiType: "azure_openai", Name: body.Name, Key: body.Key, ResourceNmae: keynames[1], EndPoint: body.Endpoint, } if err := store.CreateKey(k); err != nil { c.JSON(http.StatusInternalServerError, gin.H{"error": gin.H{ "message": err.Error(), }}) return } } else { if body.ApiType == "" { if err := store.AddKey("openai", body.Key, body.Name); err != nil { c.JSON(http.StatusInternalServerError, gin.H{"error": gin.H{ "message": err.Error(), }}) return } } else { k := &store.Key{ ApiType: body.ApiType, Name: body.Name, Key: body.Key, ResourceNmae: azureopenai.GetResourceName(body.Endpoint), EndPoint: body.Endpoint, } if err := store.CreateKey(k); err != nil { c.JSON(http.StatusInternalServerError, gin.H{"error": gin.H{ "message": err.Error(), }}) return } } } k, err := store.GetKeyrByName(body.Name) if err != nil { c.JSON(http.StatusInternalServerError, gin.H{"error": gin.H{ "message": err.Error(), }}) return } c.JSON(http.StatusOK, k) } func HandleDelKey(c *gin.Context) { id := to.Int(c.Param("id")) if id < 1 { c.JSON(http.StatusOK, gin.H{"error": "invalid key id"}) return } if err := store.DeleteKey(uint(id)); err != nil { c.JSON(http.StatusOK, gin.H{"error": "invalid key id"}) return } c.JSON(http.StatusOK, gin.H{"message": "ok"}) } func HandleAddUser(c *gin.Context) { var body User if err := c.BindJSON(&body); err != nil { c.JSON(http.StatusOK, gin.H{"error": err.Error()}) return } if len(body.Name) == 0 { c.JSON(http.StatusOK, gin.H{"error": "invalid user name"}) return } if err := store.AddUser(body.Name, uuid.NewString()); err != nil { c.JSON(http.StatusOK, gin.H{"error": err.Error()}) return } u, err := store.GetUserByName(body.Name) if err != nil { c.JSON(http.StatusOK, gin.H{"error": err.Error()}) return } c.JSON(http.StatusOK, u) } func HandleDelUser(c *gin.Context) { id := to.Int(c.Param("id")) if id <= 1 { c.JSON(http.StatusOK, gin.H{"error": "invalid user id"}) return } if err := store.DeleteUser(uint(id)); err != nil { c.JSON(http.StatusOK, gin.H{"error": err.Error()}) return } c.JSON(http.StatusOK, gin.H{"message": "ok"}) } func HandleResetUserToken(c *gin.Context) { id := to.Int(c.Param("id")) newtoken := c.Query("token") if newtoken == "" { newtoken = uuid.NewString() } if err := store.UpdateUser(uint(id), newtoken); err != nil { c.JSON(http.StatusForbidden, gin.H{"error": err.Error()}) return } u, err := store.GetUserByID(uint(id)) if err != nil { c.JSON(http.StatusForbidden, gin.H{"error": err.Error()}) return } if u.ID == uint(1) { rootToken = u.Token } c.JSON(http.StatusOK, u) } func GenerateToken() string { token := uuid.New() return token.String() } func getHttpClient() *http.Client { tr := &http.Transport{ Proxy: http.ProxyFromEnvironment, DialContext: (&net.Dialer{ Timeout: 30 * time.Second, KeepAlive: 30 * time.Second, }).DialContext, ForceAttemptHTTP2: true, MaxIdleConns: 100, IdleConnTimeout: 90 * time.Second, TLSHandshakeTimeout: 10 * time.Second, ExpectContinueTimeout: 1 * time.Second, TLSClientConfig: &tls.Config{InsecureSkipVerify: true}, } return &http.Client{Transport: tr} } func HandleProy(c *gin.Context) { var ( localuser bool isStream bool chatreq = openai.ChatCompletionRequest{} chatres = openai.ChatCompletionResponse{} chatlog store.Tokens pre_prompt string req *http.Request err error // wg sync.WaitGroup ) auth := c.Request.Header.Get("Authorization") if len(auth) > 7 && auth[:7] == "Bearer " { localuser = store.IsExistAuthCache(auth[7:]) c.Set("localuser", localuser) } if c.Request.URL.Path == "/v1/audio/transcriptions" { WhisperProxy(c) return } if c.Request.URL.Path == "/v1/chat/completions" && localuser { if store.KeysCache.ItemCount() == 0 { c.JSON(http.StatusBadGateway, gin.H{"error": gin.H{ "message": "No Api-Key Available", }}) return } onekey := store.FromKeyCacheRandomItemKey() if err := c.BindJSON(&chatreq); err != nil { c.AbortWithError(http.StatusBadRequest, err) return } chatlog.Model = chatreq.Model for _, m := range chatreq.Messages { pre_prompt += m.Content + "\n" } chatlog.PromptHash = cryptor.Md5String(pre_prompt) chatlog.PromptCount = NumTokensFromMessages(chatreq.Messages, chatreq.Model) isStream = chatreq.Stream chatlog.UserID, _ = store.GetUserID(auth[7:]) var body bytes.Buffer json.NewEncoder(&body).Encode(chatreq) // 创建 API 请求 switch onekey.ApiType { case "azure_openai": var buildurl string var apiVersion = "2023-05-15" if onekey.EndPoint != "" { buildurl = fmt.Sprintf("%s/openai/deployments/%s/chat/completions?api-version=%s", onekey.EndPoint, modelmap(chatreq.Model), apiVersion) } else { buildurl = fmt.Sprintf("https://%s.openai.azure.com/openai/deployments/%s/chat/completions?api-version=%s", onekey.ResourceNmae, modelmap(chatreq.Model), apiVersion) } req, err = http.NewRequest(c.Request.Method, buildurl, &body) req.Header = c.Request.Header req.Header.Set("api-key", onekey.Key) case "openai": fallthrough default: if onekey.EndPoint != "" { req, err = http.NewRequest(c.Request.Method, onekey.EndPoint+c.Request.RequestURI, &body) } else { req, err = http.NewRequest(c.Request.Method, baseUrl+c.Request.RequestURI, &body) } req.Header = c.Request.Header req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", onekey.Key)) } if err != nil { log.Println(err) c.JSON(http.StatusOK, gin.H{"error": err.Error()}) return } } else { req, err = http.NewRequest(c.Request.Method, baseUrl+c.Request.RequestURI, c.Request.Body) if err != nil { log.Println(err) c.JSON(http.StatusOK, gin.H{"error": err.Error()}) return } req.Header = c.Request.Header } resp, err := client.Do(req) if err != nil { log.Println(err) c.JSON(http.StatusOK, gin.H{"error": err.Error()}) return } defer resp.Body.Close() // 复制 API 响应头部 for name, values := range resp.Header { for _, value := range values { c.Writer.Header().Add(name, value) } } head := map[string]string{ "Cache-Control": "no-store", "access-control-allow-origin": "*", "access-control-allow-credentials": "true", } for k, v := range head { if _, ok := resp.Header[k]; !ok { c.Writer.Header().Set(k, v) } } resp.Header.Del("content-security-policy") resp.Header.Del("content-security-policy-report-only") resp.Header.Del("clear-site-data") c.Writer.WriteHeader(resp.StatusCode) writer := bufio.NewWriter(c.Writer) defer writer.Flush() reader := bufio.NewReader(resp.Body) if resp.StatusCode == 200 && localuser { if isStream { contentCh := fetchResponseContent(c, reader) var buffer bytes.Buffer for content := range contentCh { buffer.WriteString(content) } chatlog.CompletionCount = NumTokensFromStr(buffer.String(), chatreq.Model) chatlog.TotalTokens = chatlog.PromptCount + chatlog.CompletionCount chatlog.Cost = fmt.Sprintf("%.6f", Cost(chatlog.Model, chatlog.PromptCount, chatlog.CompletionCount)) if err := store.Record(&chatlog); err != nil { log.Println(err) } if err := store.SumDaily(chatlog.UserID); err != nil { log.Println(err) } return } res, err := io.ReadAll(reader) if err != nil { c.JSON(http.StatusUnauthorized, gin.H{"error": gin.H{ "message": err.Error(), }}) return } reader = bufio.NewReader(bytes.NewBuffer(res)) json.NewDecoder(bytes.NewBuffer(res)).Decode(&chatres) chatlog.PromptCount = chatres.Usage.PromptTokens chatlog.CompletionCount = chatres.Usage.CompletionTokens chatlog.TotalTokens = chatres.Usage.TotalTokens chatlog.Cost = fmt.Sprintf("%.6f", Cost(chatlog.Model, chatlog.PromptCount, chatlog.CompletionCount)) if err := store.Record(&chatlog); err != nil { log.Println(err) } if err := store.SumDaily(chatlog.UserID); err != nil { log.Println(err) } } // 返回 API 响应主体 if _, err := io.Copy(writer, reader); err != nil { log.Println(err) c.JSON(http.StatusUnauthorized, gin.H{"error": gin.H{ "message": err.Error(), }}) return } } func HandleReverseProxy(c *gin.Context) { proxy := &httputil.ReverseProxy{ Director: func(req *http.Request) { req.URL.Scheme = "https" req.URL.Host = "api.openai.com" // req.Header.Set("Authorization", "Bearer YOUR_API_KEY_HERE") }, Transport: &http.Transport{ Proxy: http.ProxyFromEnvironment, DialContext: (&net.Dialer{ Timeout: 30 * time.Second, KeepAlive: 30 * time.Second, }).DialContext, ForceAttemptHTTP2: true, MaxIdleConns: 100, IdleConnTimeout: 90 * time.Second, TLSHandshakeTimeout: 10 * time.Second, ExpectContinueTimeout: 1 * time.Second, TLSClientConfig: &tls.Config{InsecureSkipVerify: true}, }, } var localuser bool auth := c.Request.Header.Get("Authorization") if len(auth) > 7 && auth[:7] == "Bearer " { log.Println(store.IsExistAuthCache(auth[7:])) localuser = store.IsExistAuthCache(auth[7:]) } req, err := http.NewRequest(c.Request.Method, c.Request.URL.Path, c.Request.Body) if err != nil { c.JSON(http.StatusOK, gin.H{"error": err.Error()}) return } req.Header = c.Request.Header if localuser { if store.KeysCache.ItemCount() == 0 { c.JSON(http.StatusOK, gin.H{"error": "No Api-Key Available"}) return } onekey := store.FromKeyCacheRandomItemKey() req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", onekey.Key)) } proxy.ServeHTTP(c.Writer, req) } // https://openai.com/pricing func Cost(model string, promptCount, completionCount int) float64 { var cost, prompt, completion float64 prompt = float64(promptCount) completion = float64(completionCount) switch model { case "gpt-3.5-turbo-0301": cost = 0.002 * float64((prompt+completion)/1000) case "gpt-3.5-turbo", "gpt-3.5-turbo-0613": cost = 0.0015*float64((prompt)/1000) + 0.002*float64(completion/1000) case "gpt-3.5-turbo-16k", "gpt-3.5-turbo-16k-0613": cost = 0.003*float64((prompt)/1000) + 0.004*float64(completion/1000) case "gpt-4", "gpt-4-0613", "gpt-4-0314": cost = 0.03*float64(prompt/1000) + 0.06*float64(completion/1000) case "gpt-4-32k", "gpt-4-32k-0613": cost = 0.06*float64(prompt/1000) + 0.12*float64(completion/1000) case "whisper-1": // 0.006$/min cost = 0.006 * float64(prompt+completion) / 60 default: if strings.Contains(model, "gpt-3.5-turbo") { cost = 0.003 * float64((prompt+completion)/1000) } else if strings.Contains(model, "gpt-4") { cost = 0.06 * float64((prompt+completion)/1000) } else { cost = 0.002 * float64((prompt+completion)/1000) } } return cost } func HandleUsage(c *gin.Context) { fromStr := c.Query("from") toStr := c.Query("to") getMonthStartAndEnd := func() (start, end string) { loc, _ := time.LoadLocation("Local") now := time.Now().In(loc) year, month, _ := now.Date() startOfMonth := time.Date(year, month, 1, 0, 0, 0, 0, loc) endOfMonth := startOfMonth.AddDate(0, 1, 0) start = startOfMonth.Format("2006-01-02") end = endOfMonth.Format("2006-01-02") return } if fromStr == "" || toStr == "" { fromStr, toStr = getMonthStartAndEnd() } usage, err := store.QueryUsage(fromStr, toStr) if err != nil { c.JSON(http.StatusForbidden, gin.H{"error": err.Error()}) return } c.JSON(200, usage) } func fetchResponseContent(ctx *gin.Context, responseBody *bufio.Reader) <-chan string { contentCh := make(chan string) go func() { defer close(contentCh) for { line, err := responseBody.ReadString('\n') if err == nil { lines := strings.Split(line, "") for _, word := range lines { ctx.Writer.WriteString(word) ctx.Writer.Flush() } if line == "\n" { continue } if strings.HasPrefix(line, "data:") { line = strings.TrimSpace(strings.TrimPrefix(line, "data:")) if strings.HasSuffix(line, "[DONE]") { break } line = strings.TrimSpace(line) } dec := json.NewDecoder(strings.NewReader(line)) var data map[string]interface{} if err := dec.Decode(&data); err == io.EOF { log.Println("EOF:", err) break } else if err != nil { fmt.Println("Error decoding response:", err) return } if choices, ok := data["choices"].([]interface{}); ok { for _, choice := range choices { choiceMap := choice.(map[string]interface{}) if content, ok := choiceMap["delta"].(map[string]interface{})["content"]; ok { contentCh <- content.(string) } } } } else { break } } }() return contentCh } func NumTokensFromMessages(messages []openai.ChatCompletionMessage, model string) (numTokens int) { tkm, err := tiktoken.EncodingForModel(model) if err != nil { err = fmt.Errorf("EncodingForModel: %v", err) log.Println(err) return } var tokensPerMessage, tokensPerName int switch model { case "gpt-3.5-turbo", "gpt-3.5-turbo-0613", "gpt-3.5-turbo-16k", "gpt-3.5-turbo-16k-0613", "gpt-4", "gpt-4-0314", "gpt-4-0613", "gpt-4-32k", "gpt-4-32k-0314", "gpt-4-32k-0613": tokensPerMessage = 3 tokensPerName = 1 case "gpt-3.5-turbo-0301": tokensPerMessage = 4 // every message follows <|start|>{role/name}\n{content}<|end|>\n tokensPerName = -1 // if there's a name, the role is omitted default: if strings.Contains(model, "gpt-3.5-turbo") { log.Println("warning: gpt-3.5-turbo may update over time. Returning num tokens assuming gpt-3.5-turbo-0613.") return NumTokensFromMessages(messages, "gpt-3.5-turbo-0613") } else if strings.Contains(model, "gpt-4") { log.Println("warning: gpt-4 may update over time. Returning num tokens assuming gpt-4-0613.") return NumTokensFromMessages(messages, "gpt-4-0613") } else { err = fmt.Errorf("warning: unknown model [%s]. Use default calculation method converted tokens.", model) log.Println(err) return NumTokensFromMessages(messages, "gpt-3.5-turbo-0613") } } for _, message := range messages { numTokens += tokensPerMessage numTokens += len(tkm.Encode(message.Content, nil, nil)) numTokens += len(tkm.Encode(message.Role, nil, nil)) numTokens += len(tkm.Encode(message.Name, nil, nil)) if message.Name != "" { numTokens += tokensPerName } } numTokens += 3 return numTokens } func NumTokensFromStr(messages string, model string) (num_tokens int) { tkm, err := tiktoken.EncodingForModel(model) if err != nil { err = fmt.Errorf("EncodingForModel: %v", err) fmt.Println(err) return } num_tokens += len(tkm.Encode(messages, nil, nil)) return num_tokens } func modelmap(in string) string { // gpt-3.5-turbo -> gpt-35-turbo if strings.Contains(in, ".") { return strings.ReplaceAll(in, ".", "") } return in } func WhisperProxy(c *gin.Context) { var chatlog store.Tokens byteBody, _ := io.ReadAll(c.Request.Body) c.Request.Body = io.NopCloser(bytes.NewBuffer(byteBody)) model, _ := c.GetPostForm("model") key, err := store.SelectKeyCache("openai") if err != nil { c.JSON(http.StatusInternalServerError, gin.H{ "error": gin.H{ "message": err.Error(), }, }) return } chatlog.Model = model token, _ := c.Get("localuser") lu, err := store.GetUserByToken(token.(string)) if err != nil { c.JSON(http.StatusInternalServerError, gin.H{ "error": gin.H{ "message": err.Error(), }, }) return } chatlog.UserID = int(lu.ID) if err := ParseWhisperRequestTokens(c, &chatlog, byteBody); err != nil { c.JSON(http.StatusInternalServerError, gin.H{ "error": gin.H{ "message": err.Error(), }, }) return } targetUrl, _ := url.ParseRequestURI(key.EndPoint) proxy := httputil.NewSingleHostReverseProxy(targetUrl) proxy.Director = func(req *http.Request) { req.Header.Set("Content-Type", "application/json") req.Header.Set("Authorization", "Bearer "+key.Key) } proxy.ModifyResponse = func(resp *http.Response) error { if resp.StatusCode != http.StatusOK { return nil } chatlog.TotalTokens = chatlog.PromptCount + chatlog.CompletionCount chatlog.Cost = fmt.Sprintf("%.6f", Cost(chatlog.Model, chatlog.PromptCount, chatlog.CompletionCount)) if err := store.Record(&chatlog); err != nil { log.Println(err) } if err := store.SumDaily(chatlog.UserID); err != nil { log.Println(err) } return nil } proxy.ServeHTTP(c.Writer, c.Request) } func probe(fileReader io.Reader) (time.Duration, error) { ctx, cancelFn := context.WithTimeout(context.Background(), 5*time.Second) defer cancelFn() data, err := ffprobe.ProbeReader(ctx, fileReader) if err != nil { return 0, err } duration := data.Format.DurationSeconds pduration, err := time.ParseDuration(fmt.Sprintf("%fs", duration)) if err != nil { return 0, fmt.Errorf("Error parsing duration: %s", err) } return pduration, nil } func getAudioDuration(file *multipart.FileHeader) (time.Duration, error) { var ( streamer beep.StreamSeekCloser format beep.Format err error ) f, err := file.Open() defer f.Close() // Get the file extension to determine the audio file type fileType := filepath.Ext(file.Filename) switch fileType { case ".mp3": streamer, format, err = mp3.Decode(f) case ".wav": streamer, format, err = wav.Decode(f) case ".m4a": duration, err := probe(f) if err != nil { return 0, err } return duration, nil default: return 0, errors.New("unsupported audio file format") } if err != nil { return 0, err } defer streamer.Close() // Calculate the audio file's duration. numSamples := streamer.Len() sampleRate := format.SampleRate duration := time.Duration(numSamples) * time.Second / time.Duration(sampleRate) return duration, nil } func ParseWhisperRequestTokens(c *gin.Context, usage *store.Tokens, byteBody []byte) error { file, _ := c.FormFile("file") model, _ := c.GetPostForm("model") usage.Model = model if file != nil { duration, err := getAudioDuration(file) if err != nil { return fmt.Errorf("Error getting audio duration:%s", err) } if duration > 5*time.Minute { return fmt.Errorf("Audio duration exceeds 5 minutes") } // 计算时长,四舍五入到最接近的秒数 usage.PromptCount = int(duration.Round(time.Second).Seconds()) } c.Request.Body = io.NopCloser(bytes.NewBuffer(byteBody)) return nil }