diff --git a/pkg/claude/claude.go b/pkg/claude/claude.go
index 4387d27..8b4ae07 100644
--- a/pkg/claude/claude.go
+++ b/pkg/claude/claude.go
@@ -42,6 +42,7 @@ import (
 	"net/http"
 	"net/http/httputil"
 	"net/url"
+	"opencatd-open/pkg/tokenizer"
 	"opencatd-open/store"
 	"strings"
 	"sync"
@@ -135,7 +136,7 @@ func ClaudeProxy(c *gin.Context) {
 		return
 	}
 
-	key, err := store.SelectKeyCache("anthropic")
+	key, err := store.SelectKeyCache("claude") // anthropic
 	if err != nil {
 		c.JSON(http.StatusInternalServerError, gin.H{
 			"error": gin.H{
@@ -160,7 +161,7 @@ func ClaudeProxy(c *gin.Context) {
 	}
 	chatlog.UserID = int(lu.ID)
 
-	// todo calc prompt token
+	chatlog.PromptCount = tokenizer.NumTokensFromStr(complete.Prompt, complete.Model)
 
 	if key.EndPoint == "" {
 		key.EndPoint = "https://api.anthropic.com"
@@ -192,6 +193,7 @@ func ClaudeProxy(c *gin.Context) {
 			log.Println(err)
 			return nil
 		}
+		chatlog.CompletionCount = tokenizer.NumTokensFromStr(complete_resp.Completion, chatlog.Model)
 	} else {
 		var completion string
 		for {
@@ -213,10 +215,12 @@ func ClaudeProxy(c *gin.Context) {
 			}
 		}
 		log.Println("completion:", completion)
+		chatlog.CompletionCount = tokenizer.NumTokensFromStr(completion, chatlog.Model)
 	}
-	chatlog.TotalTokens = chatlog.PromptCount + chatlog.CompletionCount
-	// todo calc cost
+	// calc cost
+	chatlog.TotalTokens = chatlog.PromptCount + chatlog.CompletionCount
+	chatlog.Cost = fmt.Sprintf("%.6f", tokenizer.Cost(chatlog.Model, chatlog.PromptCount, chatlog.CompletionCount))
 
 	if err := store.Record(&chatlog); err != nil {
 		log.Println(err)
 	}
@@ -259,7 +263,7 @@ func TransReq(chatreq *openai.ChatCompletionRequest) (*bytes.Buffer, error) {
 	return payload, nil
 }
 
-func TransRsp(c *gin.Context, isStream bool, reader *bufio.Reader) {
+func TransRsp(c *gin.Context, isStream bool, chatlog store.Tokens, reader *bufio.Reader) {
 	if !isStream {
 		var completersp CompleteResponse
 		var chatrsp openai.ChatCompletionResponse
@@ -286,13 +290,24 @@ func TransRsp(c *gin.Context, isStream bool, reader *bufio.Reader) {
 			})
 			return
 		}
+		chatlog.CompletionCount = tokenizer.NumTokensFromStr(completersp.Completion, chatlog.Model)
+		chatlog.TotalTokens = chatlog.PromptCount + chatlog.CompletionCount
+		chatlog.Cost = fmt.Sprintf("%.6f", tokenizer.Cost(chatlog.Model, chatlog.PromptCount, chatlog.CompletionCount))
+		if err := store.Record(&chatlog); err != nil {
+			log.Println(err)
+		}
+		if err := store.SumDaily(chatlog.UserID); err != nil {
+			log.Println(err)
+		}
+
 		c.JSON(http.StatusOK, payload)
 		return
 	} else {
 		var (
-			wg       sync.WaitGroup
-			dataChan = make(chan string)
-			stopChan = make(chan bool)
+			wg            sync.WaitGroup
+			dataChan      = make(chan string)
+			stopChan      = make(chan bool)
+			complete_resp string
 		)
 		wg.Add(2)
 		go func() {
@@ -305,6 +320,7 @@ func TransRsp(c *gin.Context, isStream bool, reader *bufio.Reader) {
 				json.NewDecoder(strings.NewReader(line[6:])).Decode(&result)
 				if result.StopReason == "" {
 					if result.Completion != "" {
+						complete_resp += result.Completion
 						chatrsp := openai.ChatCompletionStreamResponse{
 							ID:    result.LogID,
 							Model: result.Model,
@@ -372,6 +388,15 @@ func TransRsp(c *gin.Context, isStream bool, reader *bufio.Reader) {
 			}
 		}()
 		wg.Wait()
+		chatlog.CompletionCount = tokenizer.NumTokensFromStr(complete_resp, chatlog.Model)
+		chatlog.TotalTokens = chatlog.PromptCount + chatlog.CompletionCount
+		chatlog.Cost = fmt.Sprintf("%.6f", tokenizer.Cost(chatlog.Model, chatlog.PromptCount, chatlog.CompletionCount))
+		if err := store.Record(&chatlog); err != nil {
+			log.Println(err)
+		}
+		if err := store.SumDaily(chatlog.UserID); err != nil {
+			log.Println(err)
+		}
 	}
 }
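With the claude.go changes, the Claude path is metered the same way as the OpenAI path: count prompt tokens, count completion tokens, price them, record. The flow can be exercised in isolation; a minimal sketch, assuming the `opencatd-open` module path used in the imports above (the prompt and completion strings are invented):

```go
package main

import (
	"fmt"

	"opencatd-open/pkg/tokenizer"
)

func main() {
	// Stand-ins for what ClaudeProxy extracts from the request body
	// and accumulates from the (possibly streamed) response.
	prompt := "\n\nHuman: Why is the sky blue?\n\nAssistant:"
	completion := " Because shorter wavelengths of sunlight scatter more strongly in the atmosphere."

	// Same calls the handler now makes. tiktoken has no encoding for
	// Claude models, so NumTokensFromStr falls back to cl100k_base,
	// which makes these counts an approximation (see pkg/tokenizer below).
	promptCount := tokenizer.NumTokensFromStr(prompt, "claude-2")
	completionCount := tokenizer.NumTokensFromStr(completion, "claude-2")
	cost := tokenizer.Cost("claude-2", promptCount, completionCount)

	fmt.Printf("prompt=%d completion=%d total=%d cost=$%.6f\n",
		promptCount, completionCount, promptCount+completionCount, cost)
}
```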
diff --git a/pkg/tokenizer/tokenizer.go b/pkg/tokenizer/tokenizer.go
new file mode 100644
index 0000000..ec0e7fb
--- /dev/null
+++ b/pkg/tokenizer/tokenizer.go
@@ -0,0 +1,114 @@
+package tokenizer
+
+import (
+	"fmt"
+	"log"
+	"strings"
+
+	"github.com/pkoukk/tiktoken-go"
+	"github.com/sashabaranov/go-openai"
+)
+
+func NumTokensFromMessages(messages []openai.ChatCompletionMessage, model string) (numTokens int) {
+	tkm, err := tiktoken.EncodingForModel(model)
+	if err != nil {
+		err = fmt.Errorf("EncodingForModel: %v", err)
+		log.Println(err)
+		return
+	}
+
+	var tokensPerMessage, tokensPerName int
+
+	switch model {
+	case "gpt-3.5-turbo",
+		"gpt-3.5-turbo-0613",
+		"gpt-3.5-turbo-16k",
+		"gpt-3.5-turbo-16k-0613",
+		"gpt-4",
+		"gpt-4-0314",
+		"gpt-4-0613",
+		"gpt-4-32k",
+		"gpt-4-32k-0314",
+		"gpt-4-32k-0613":
+		tokensPerMessage = 3
+		tokensPerName = 1
+	case "gpt-3.5-turbo-0301":
+		tokensPerMessage = 4 // every message follows <|start|>{role/name}\n{content}<|end|>\n
+		tokensPerName = -1   // if there's a name, the role is omitted
+	default:
+		if strings.Contains(model, "gpt-3.5-turbo") {
+			log.Println("warning: gpt-3.5-turbo may update over time. Returning num tokens assuming gpt-3.5-turbo-0613.")
+			return NumTokensFromMessages(messages, "gpt-3.5-turbo-0613")
+		} else if strings.Contains(model, "gpt-4") {
+			log.Println("warning: gpt-4 may update over time. Returning num tokens assuming gpt-4-0613.")
+			return NumTokensFromMessages(messages, "gpt-4-0613")
+		} else {
+			err = fmt.Errorf("warning: unknown model [%s], falling back to gpt-3.5-turbo-0613 token counting", model)
+			log.Println(err)
+			return NumTokensFromMessages(messages, "gpt-3.5-turbo-0613")
+		}
+	}
+
+	for _, message := range messages {
+		numTokens += tokensPerMessage
+		numTokens += len(tkm.Encode(message.Content, nil, nil))
+		numTokens += len(tkm.Encode(message.Role, nil, nil))
+		numTokens += len(tkm.Encode(message.Name, nil, nil))
+		if message.Name != "" {
+			numTokens += tokensPerName
+		}
+	}
+	numTokens += 3
+	return numTokens
+}
+
+func NumTokensFromStr(messages string, model string) (num_tokens int) {
+	tkm, err := tiktoken.EncodingForModel(model)
+	if err != nil {
+		fmt.Println(err)
+		fmt.Println("unsupported model, falling back to the cl100k_base encoding")
+		tkm, _ = tiktoken.GetEncoding("cl100k_base")
+	}
+
+	num_tokens += len(tkm.Encode(messages, nil, nil))
+	return num_tokens
+}
+
+// https://openai.com/pricing
+func Cost(model string, promptCount, completionCount int) float64 {
+	var cost, prompt, completion float64
+	prompt = float64(promptCount)
+	completion = float64(completionCount)
+
+	switch model {
+	case "gpt-3.5-turbo-0301":
+		cost = 0.002 * float64((prompt+completion)/1000)
+	case "gpt-3.5-turbo", "gpt-3.5-turbo-0613":
+		cost = 0.0015*float64((prompt)/1000) + 0.002*float64(completion/1000)
+	case "gpt-3.5-turbo-16k", "gpt-3.5-turbo-16k-0613":
+		cost = 0.003*float64((prompt)/1000) + 0.004*float64(completion/1000)
+	case "gpt-4", "gpt-4-0613", "gpt-4-0314":
+		cost = 0.03*float64(prompt/1000) + 0.06*float64(completion/1000)
+	case "gpt-4-32k", "gpt-4-32k-0613":
+		cost = 0.06*float64(prompt/1000) + 0.12*float64(completion/1000)
+	case "whisper-1":
+		// 0.006$/min
+		cost = 0.006 * float64(prompt+completion) / 60
+	// claude /million tokens
+	case "claude-v1", "claude-v1-100k":
+		cost = 11.02/1000000*float64(prompt) + (32.68/1000000)*float64(completion)
+	case "claude-instant-v1", "claude-instant-v1-100k":
+		cost = (1.63/1000000)*float64(prompt) + (5.51/1000000)*float64(completion)
+	case "claude-2":
+		cost = (11.02/1000000)*float64(prompt) + (32.68/1000000)*float64(completion)
+	default:
+		if strings.Contains(model, "gpt-3.5-turbo") {
+			cost = 0.003 * float64((prompt+completion)/1000)
+		} else if strings.Contains(model, "gpt-4") {
+			cost = 0.06 * float64((prompt+completion)/1000)
+		} else {
+			cost = 0.002 * float64((prompt+completion)/1000)
+		}
+	}
+	return cost
+}
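The rates in Cost are the easiest thing in this file to break silently, so a small table test is cheap insurance. A sketch, assuming it sits alongside the code as a hypothetical `pkg/tokenizer/tokenizer_test.go`; the expected values are just the rates from the switch above applied to round token counts:

```go
package tokenizer

import (
	"math"
	"testing"
)

// Spot-check the pricing table: GPT rates are USD per 1K tokens,
// Claude rates are USD per 1M tokens.
func TestCost(t *testing.T) {
	cases := []struct {
		model              string
		prompt, completion int
		want               float64
	}{
		{"gpt-3.5-turbo", 1000, 1000, 0.0015 + 0.002},
		{"gpt-3.5-turbo-16k", 1000, 1000, 0.003 + 0.004},
		{"gpt-4", 1000, 1000, 0.03 + 0.06},
		{"claude-2", 1_000_000, 1_000_000, 11.02 + 32.68},
	}
	for _, c := range cases {
		got := Cost(c.model, c.prompt, c.completion)
		if math.Abs(got-c.want) > 1e-6 {
			t.Errorf("Cost(%q, %d, %d) = %v, want %v",
				c.model, c.prompt, c.completion, got, c.want)
		}
	}
}
```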
diff --git a/router/router.go b/router/router.go
index a7bdc32..0ddb29f 100644
--- a/router/router.go
+++ b/router/router.go
@@ -17,6 +17,7 @@ import (
 	"net/url"
 	"opencatd-open/pkg/azureopenai"
 	"opencatd-open/pkg/claude"
+	"opencatd-open/pkg/tokenizer"
 	"opencatd-open/store"
 	"os"
 	"path/filepath"
@@ -30,7 +31,6 @@ import (
 	"github.com/faiface/beep/wav"
 	"github.com/gin-gonic/gin"
 	"github.com/google/uuid"
-	"github.com/pkoukk/tiktoken-go"
 	"github.com/sashabaranov/go-openai"
 	"gopkg.in/vansante/go-ffprobe.v2"
 	"gorm.io/gorm"
@@ -499,7 +499,7 @@ func HandleProy(c *gin.Context) {
 			pre_prompt += m.Content + "\n"
 		}
 		chatlog.PromptHash = cryptor.Md5String(pre_prompt)
-		chatlog.PromptCount = NumTokensFromMessages(chatreq.Messages, chatreq.Model)
+		chatlog.PromptCount = tokenizer.NumTokensFromMessages(chatreq.Messages, chatreq.Model)
 		isStream = chatreq.Stream
 
 		chatlog.UserID, _ = store.GetUserID(auth[7:])
@@ -603,7 +603,7 @@ func HandleProy(c *gin.Context) {
 	if resp.StatusCode == 200 && localuser {
 		switch onekey.ApiType {
 		case "claude":
-			claude.TransRsp(c, isStream, reader)
+			claude.TransRsp(c, isStream, chatlog, reader)
 			return
 		case "openai", "azure", "azure_openai":
 			fallthrough
@@ -614,9 +614,9 @@ func HandleProy(c *gin.Context) {
 			for content := range contentCh {
 				buffer.WriteString(content)
 			}
-			chatlog.CompletionCount = NumTokensFromStr(buffer.String(), chatreq.Model)
+			chatlog.CompletionCount = tokenizer.NumTokensFromStr(buffer.String(), chatreq.Model)
 			chatlog.TotalTokens = chatlog.PromptCount + chatlog.CompletionCount
-			chatlog.Cost = fmt.Sprintf("%.6f", Cost(chatlog.Model, chatlog.PromptCount, chatlog.CompletionCount))
+			chatlog.Cost = fmt.Sprintf("%.6f", tokenizer.Cost(chatlog.Model, chatlog.PromptCount, chatlog.CompletionCount))
 			if err := store.Record(&chatlog); err != nil {
 				log.Println(err)
 			}
@@ -637,7 +637,7 @@ func HandleProy(c *gin.Context) {
 			chatlog.PromptCount = chatres.Usage.PromptTokens
 			chatlog.CompletionCount = chatres.Usage.CompletionTokens
 			chatlog.TotalTokens = chatres.Usage.TotalTokens
-			chatlog.Cost = fmt.Sprintf("%.6f", Cost(chatlog.Model, chatlog.PromptCount, chatlog.CompletionCount))
+			chatlog.Cost = fmt.Sprintf("%.6f", tokenizer.Cost(chatlog.Model, chatlog.PromptCount, chatlog.CompletionCount))
 			if err := store.Record(&chatlog); err != nil {
 				log.Println(err)
 			}
@@ -705,45 +705,6 @@ func HandleReverseProxy(c *gin.Context) {
 
 }
 
-// https://openai.com/pricing
-func Cost(model string, promptCount, completionCount int) float64 {
-	var cost, prompt, completion float64
-	prompt = float64(promptCount)
-	completion = float64(completionCount)
-
-	switch model {
-	case "gpt-3.5-turbo-0301":
-		cost = 0.002 * float64((prompt+completion)/1000)
-	case "gpt-3.5-turbo", "gpt-3.5-turbo-0613":
-		cost = 0.0015*float64((prompt)/1000) + 0.002*float64(completion/1000)
-	case "gpt-3.5-turbo-16k", "gpt-3.5-turbo-16k-0613":
-		cost = 0.003*float64((prompt)/1000) + 0.004*float64(completion/1000)
-	case "gpt-4", "gpt-4-0613", "gpt-4-0314":
-		cost = 0.03*float64(prompt/1000) + 0.06*float64(completion/1000)
-	case "gpt-4-32k", "gpt-4-32k-0613":
-		cost = 0.06*float64(prompt/1000) + 0.12*float64(completion/1000)
-	case "whisper-1":
-		// 0.006$/min
-		cost = 0.006 * float64(prompt+completion) / 60
-	// claude /million tokens
-	case "claude-v1", "claude-v1-100k":
-		cost = 11.02/1000000*float64(prompt) + (32.68/1000000)*float64(completion)
-	case "claude-instant-v1", "claude-instant-v1-100k":
-		cost = (1.63/1000000)*float64(prompt) + (5.51/1000000)*float64(completion)
-	case "claude-2":
-		cost = (11.02/1000000)*float64(prompt) + (32.68/1000000)*float64(completion)
-	default:
-		if strings.Contains(model, "gpt-3.5-turbo") {
-			cost = 0.003 * float64((prompt+completion)/1000)
-		} else if strings.Contains(model, "gpt-4") {
-			cost = 0.06 * float64((prompt+completion)/1000)
-		} else {
-			cost = 0.002 * float64((prompt+completion)/1000)
-		}
-	}
-	return cost
-}
-
 func HandleUsage(c *gin.Context) {
 	fromStr := c.Query("from")
 	toStr := c.Query("to")
@@ -821,71 +782,6 @@ func fetchResponseContent(ctx *gin.Context, responseBody *bufio.Reader) <-chan s
 	return contentCh
 }
 
-func NumTokensFromMessages(messages []openai.ChatCompletionMessage, model string) (numTokens int) {
-	tkm, err := tiktoken.EncodingForModel(model)
-	if err != nil {
-		err = fmt.Errorf("EncodingForModel: %v", err)
-		log.Println(err)
-		return
-	}
-
-	var tokensPerMessage, tokensPerName int
-
-	switch model {
-	case "gpt-3.5-turbo",
-		"gpt-3.5-turbo-0613",
-		"gpt-3.5-turbo-16k",
-		"gpt-3.5-turbo-16k-0613",
-		"gpt-4",
-		"gpt-4-0314",
-		"gpt-4-0613",
-		"gpt-4-32k",
-		"gpt-4-32k-0314",
-		"gpt-4-32k-0613":
-		tokensPerMessage = 3
-		tokensPerName = 1
-	case "gpt-3.5-turbo-0301":
-		tokensPerMessage = 4 // every message follows <|start|>{role/name}\n{content}<|end|>\n
-		tokensPerName = -1   // if there's a name, the role is omitted
-	default:
-		if strings.Contains(model, "gpt-3.5-turbo") {
-			log.Println("warning: gpt-3.5-turbo may update over time. Returning num tokens assuming gpt-3.5-turbo-0613.")
-			return NumTokensFromMessages(messages, "gpt-3.5-turbo-0613")
-		} else if strings.Contains(model, "gpt-4") {
-			log.Println("warning: gpt-4 may update over time. Returning num tokens assuming gpt-4-0613.")
-			return NumTokensFromMessages(messages, "gpt-4-0613")
-		} else {
-			err = fmt.Errorf("warning: unknown model [%s]. Use default calculation method converted tokens.", model)
-			log.Println(err)
-			return NumTokensFromMessages(messages, "gpt-3.5-turbo-0613")
-		}
-	}
-
-	for _, message := range messages {
-		numTokens += tokensPerMessage
-		numTokens += len(tkm.Encode(message.Content, nil, nil))
-		numTokens += len(tkm.Encode(message.Role, nil, nil))
-		numTokens += len(tkm.Encode(message.Name, nil, nil))
-		if message.Name != "" {
-			numTokens += tokensPerName
-		}
-	}
-	numTokens += 3
-	return numTokens
-}
-
-func NumTokensFromStr(messages string, model string) (num_tokens int) {
-	tkm, err := tiktoken.EncodingForModel(model)
-	if err != nil {
-		err = fmt.Errorf("EncodingForModel: %v", err)
-		fmt.Println(err)
-		return
-	}
-
-	num_tokens += len(tkm.Encode(messages, nil, nil))
-	return num_tokens
-}
-
 func modelmap(in string) string {
 	// gpt-3.5-turbo -> gpt-35-turbo
 	if strings.Contains(in, ".") {
@@ -954,7 +850,7 @@ func WhisperProxy(c *gin.Context) {
 			return nil
 		}
 		chatlog.TotalTokens = chatlog.PromptCount + chatlog.CompletionCount
-		chatlog.Cost = fmt.Sprintf("%.6f", Cost(chatlog.Model, chatlog.PromptCount, chatlog.CompletionCount))
+		chatlog.Cost = fmt.Sprintf("%.6f", tokenizer.Cost(chatlog.Model, chatlog.PromptCount, chatlog.CompletionCount))
 		if err := store.Record(&chatlog); err != nil {
 			log.Println(err)
 		}
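With NumTokensFromMessages, NumTokensFromStr, and Cost now exported from pkg/tokenizer, any package can reproduce the prompt-side accounting HandleProy performs before forwarding a request. A minimal usage sketch, again assuming the `opencatd-open` module path (the messages are invented):

```go
package main

import (
	"fmt"

	"github.com/sashabaranov/go-openai"

	"opencatd-open/pkg/tokenizer"
)

func main() {
	// The same per-message counting HandleProy applies to chatreq.Messages.
	messages := []openai.ChatCompletionMessage{
		{Role: openai.ChatMessageRoleSystem, Content: "You are a helpful assistant."},
		{Role: openai.ChatMessageRoleUser, Content: "Hello!"},
	}
	promptCount := tokenizer.NumTokensFromMessages(messages, "gpt-3.5-turbo-0613")
	fmt.Println("prompt tokens:", promptCount)
}
```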