// Placeholder record holding a model name together with its prompt and
// completion token counts. The blank type name means nothing can reference it.
// NOTE(review): presumably meant to carry the arguments of Cost — confirm or remove.
type _ struct {
	model           string
	promptCount     int
	completionCount int
}

// Cost returns the price of a request against model, derived from the prompt
// and completion token counts and the per-1000-token rates hard-coded below.
//
// Token counts are converted to float64 before dividing by 1000 so that
// fractions of a thousand tokens are billed proportionally instead of being
// truncated to zero. Models that are not listed cost 0.
func Cost(model string, promptCount, completionCount int) float64 {
	const tokensPerUnit = 1000.0

	prompt := float64(promptCount) / tokensPerUnit
	completion := float64(completionCount) / tokensPerUnit

	switch model {
	case "gpt-3.5":
		// Single rate for prompt and completion tokens.
		return 0.002 * (prompt + completion)
	case "gpt-4-32k":
		return 0.06*prompt + 0.12*completion
	case "gpt-4":
		return 0.03*prompt + 0.06*completion
	default:
		return 0
	}
}

// fetchResponseContent reads a line-oriented streaming response from
// responseBody and sends every choices[].delta.content string it finds on the
// returned channel. Each raw line that is read is also appended to buf (when
// buf is non-nil) so the caller can inspect or replay the full response.
//
// The channel is closed when the stream ends: on a "data: ... [DONE]" line, on
// a read error (after any final unterminated line has been processed), or on
// a line that is not valid JSON.
//
// The reader runs in its own goroutine and sends on an unbuffered channel, so
// the caller must receive until the channel is closed; otherwise the goroutine
// stays blocked. buf is written from that goroutine and must not be touched by
// the caller until the channel has been closed.
func fetchResponseContent(buf *bytes.Buffer, responseBody *bufio.Reader) <-chan string {
	contentCh := make(chan string)
	go func() {
		defer close(contentCh)
		for {
			line, readErr := responseBody.ReadString('\n')
			// ReadString returns whatever it read before failing, so a final
			// line without a trailing newline still has to be handled.
			if line != "" {
				if buf != nil {
					buf.WriteString(line)
				}
				if !emitDeltaContent(line, contentCh) {
					return
				}
			}
			if readErr != nil {
				return
			}
		}
	}()
	return contentCh
}

// emitDeltaContent parses one raw response line and forwards any delta content
// it carries to out. It reports whether the caller should keep reading: false
// means the stream is finished ("[DONE]") or the line could not be decoded.
func emitDeltaContent(line string, out chan<- string) bool {
	payload := strings.TrimSpace(line)
	if payload == "" {
		// Blank separator line, with either LF or CRLF line endings.
		return true
	}
	if strings.HasPrefix(payload, "data:") {
		payload = strings.TrimSpace(strings.TrimPrefix(payload, "data:"))
		if strings.HasSuffix(payload, "[DONE]") {
			return false
		}
		if payload == "" {
			return true
		}
	}

	var data map[string]interface{}
	if err := json.NewDecoder(strings.NewReader(payload)).Decode(&data); err != nil {
		log.Println("Error decoding response:", err)
		return false
	}

	choices, ok := data["choices"].([]interface{})
	if !ok {
		return true
	}
	for _, choice := range choices {
		// Every assertion is checked: chunks without a delta, or with a null
		// or non-string content, are skipped instead of panicking.
		choiceMap, ok := choice.(map[string]interface{})
		if !ok {
			continue
		}
		delta, ok := choiceMap["delta"].(map[string]interface{})
		if !ok {
			continue
		}
		if content, ok := delta["content"].(string); ok {
			out <- content
		}
	}
	return true
}