From c47b540664f90c69bb3a7c61ac82c78ccc97fe74 Mon Sep 17 00:00:00 2001
From: Sakurasan <1173092237@qq.com>
Date: Sat, 8 Jul 2023 20:56:54 +0800
Subject: [PATCH] update:num token

---
 router/router.go | 55 ++++++++++++++++++++++++++++++------------------
 1 file changed, 35 insertions(+), 20 deletions(-)

diff --git a/router/router.go b/router/router.go
index f4726af..507baa3 100644
--- a/router/router.go
+++ b/router/router.go
@@ -737,38 +737,53 @@ func fetchResponseContent(ctx *gin.Context, responseBody *bufio.Reader) <-chan s
 	return contentCh
 }
 
-func NumTokensFromMessages(messages []openai.ChatCompletionMessage, model string) (num_tokens int) {
+func NumTokensFromMessages(messages []openai.ChatCompletionMessage, model string) (numTokens int) {
 	tkm, err := tiktoken.EncodingForModel(model)
 	if err != nil {
 		err = fmt.Errorf("EncodingForModel: %v", err)
-		fmt.Println(err)
+		log.Println(err)
 		return
 	}
-	var tokens_per_message int
-	var tokens_per_name int
-	if model == "gpt-3.5-turbo-0301" || model == "gpt-3.5-turbo" {
-		tokens_per_message = 4
-		tokens_per_name = -1
-	} else if model == "gpt-4-0314" || model == "gpt-4" {
-		tokens_per_message = 3
-		tokens_per_name = 1
-	} else {
-		fmt.Println("Warning: model not found. Using cl100k_base encoding.")
-		tokens_per_message = 3
-		tokens_per_name = 1
+	var tokensPerMessage, tokensPerName int
+
+	switch model {
+	case "gpt-3.5-turbo-0613",
+		"gpt-3.5-turbo-16k-0613",
+		"gpt-4-0314",
+		"gpt-4-32k-0314",
+		"gpt-4-0613",
+		"gpt-4-32k-0613":
+		tokensPerMessage = 3
+		tokensPerName = 1
+	case "gpt-3.5-turbo-0301":
+		tokensPerMessage = 4 // every message follows <|start|>{role/name}\n{content}<|end|>\n
+		tokensPerName = -1 // if there's a name, the role is omitted
+	default:
+		if strings.Contains(model, "gpt-3.5-turbo") {
+			log.Println("warning: gpt-3.5-turbo may update over time. Returning num tokens assuming gpt-3.5-turbo-0613.")
+			return NumTokensFromMessages(messages, "gpt-3.5-turbo-0613")
+		} else if strings.Contains(model, "gpt-4") {
+			log.Println("warning: gpt-4 may update over time. Returning num tokens assuming gpt-4-0613.")
+			return NumTokensFromMessages(messages, "gpt-4-0613")
+		} else {
+			err = fmt.Errorf("warning: unknown model [%s]. Use default calculation method converted tokens.", model)
+			log.Println(err)
+			return NumTokensFromMessages(messages, "gpt-3.5-turbo-0613")
+		}
 	}
 	for _, message := range messages {
-		num_tokens += tokens_per_message
-		num_tokens += len(tkm.Encode(message.Content, nil, nil))
-		// num_tokens += len(tkm.Encode(message.Role, nil, nil))
+		numTokens += tokensPerMessage
+		numTokens += len(tkm.Encode(message.Content, nil, nil))
+		numTokens += len(tkm.Encode(message.Role, nil, nil))
+		numTokens += len(tkm.Encode(message.Name, nil, nil))
 		if message.Name != "" {
-			num_tokens += tokens_per_name
+			numTokens += tokensPerName
 		}
 	}
-	num_tokens += 3
-	return num_tokens
+	numTokens += 3
+	return numTokens
 }
 
 func NumTokensFromStr(messages string, model string) (num_tokens int) {