This commit is contained in:
Sakurasan
2023-04-24 22:44:03 +08:00
parent 365bc37487
commit 5ce5466723
5 changed files with 183 additions and 35 deletions
+95 -25
View File
@@ -17,10 +17,12 @@ import (
"time"
"github.com/Sakurasan/to"
"github.com/duke-git/lancet/v2/cryptor"
"github.com/gin-gonic/gin"
"github.com/google/uuid"
"github.com/pkoukk/tiktoken-go"
"github.com/sashabaranov/go-openai"
"gorm.io/gorm"
// "github.com/pkoukk/tiktoken-go"
)
var (
@@ -295,15 +297,24 @@ func GenerateToken() string {
return token.String()
}
type _ struct {
model string
promptCount int
completionCount int
}
// type Tokens struct {
// UserID int
// PromptCount int
// CompletionCount int
// TotalTokens int
// Model string
// PromptHash string
// }
func HandleProy(c *gin.Context) {
var localuser bool
var isStream bool
var (
localuser bool
isStream bool
chatreq = openai.ChatCompletionRequest{}
chatres = openai.ChatCompletionResponse{}
chatlog store.Tokens
pre_prompt string
)
auth := c.Request.Header.Get("Authorization")
if len(auth) > 7 && auth[:7] == "Bearer " {
localuser = store.IsExistAuthCache(auth[7:])
@@ -324,14 +335,20 @@ func HandleProy(c *gin.Context) {
}
client.Transport = tr
if c.Request.URL.Path == "/v1/chat/completions" {
var chatreq = ChatCompletionRequest{}
if c.Request.URL.Path == "/v1/chat/completions" && localuser {
if err := c.BindJSON(&chatreq); err != nil {
return
// c.AbortWithError(http.StatusBadRequest,)
// c.AbortWithError(http.StatusBadRequest, err)
}
chatlog.Model = chatreq.Model
for _, m := range chatreq.Messages {
pre_prompt += m.Content + "\n"
}
chatlog.PromptHash = cryptor.Md5String(pre_prompt)
chatlog.PromptCount = NumTokensFromMessages(chatreq.Messages, chatreq.Model)
isStream = chatreq.Stream
chatlog.UserID, _ = store.GetUserID(auth[7:])
}
// 创建 API 请求
@@ -378,17 +395,24 @@ func HandleProy(c *gin.Context) {
resp.Header.Del("content-security-policy-report-only")
resp.Header.Del("clear-site-data")
// bodyRes, err := io.ReadAll(resp.Body)
// if err != nil {
// log.Println(err)
// c.JSON(http.StatusUnauthorized, gin.H{"error": err.Error()})
// return
// }
reader := bufio.NewReader(resp.Body)
// var resbuf = bytes.NewBuffer(nil)
var resbuf = bytes.NewBuffer(nil)
if resp.StatusCode == 200 && localuser {
if isStream {
chatlog.CompletionCount = NumTokensFromStr(<-fetchResponseContent(resbuf, reader), chatreq.Model)
chatlog.TotalTokens = chatlog.PromptCount + chatlog.CompletionCount
} else {
reader.WriteTo(resbuf)
json.NewDecoder(resbuf).Decode(&chatres)
chatlog.PromptCount = chatres.Usage.PromptTokens
chatlog.CompletionCount = chatres.Usage.CompletionTokens
chatlog.TotalTokens = chatres.Usage.TotalTokens
}
chatlog.Cost = Cost(chatlog.Model, chatlog.PromptCount, chatlog.CompletionCount)
store.Record(&chatlog)
// todo insert usage && calc daily_usage
if resp.StatusCode == 200 && isStream {
//todo
}
resbody := io.NopCloser(reader)
// 返回 API 响应主体
@@ -450,12 +474,12 @@ func HandleReverseProxy(c *gin.Context) {
func Cost(model string, promptCount, completionCount int) float64 {
var cost float64
switch model {
case "gpt-3.5":
case "gpt-3.5-turbo", "gpt-3.5-turbo-0301":
cost = 0.002 * float64((promptCount+completionCount)/1000)
case "gpt-4-32k":
cost = 0.06*float64(promptCount/1000) + 0.12*float64(completionCount/1000)
case "gpt-4":
case "gpt-4", "gpt-4-0314":
cost = 0.03*float64(promptCount/1000) + 0.06*float64(completionCount/1000)
case "gpt-4-32k", "gpt-4-32k-0314":
cost = 0.06*float64(promptCount/1000) + 0.12*float64(completionCount/1000)
}
return cost
}
@@ -528,3 +552,49 @@ func fetchResponseContent(buf *bytes.Buffer, responseBody *bufio.Reader) <-chan
}()
return contentCh
}
func NumTokensFromMessages(messages []openai.ChatCompletionMessage, model string) (num_tokens int) {
tkm, err := tiktoken.EncodingForModel(model)
if err != nil {
err = fmt.Errorf("EncodingForModel: %v", err)
fmt.Println(err)
return
}
var tokens_per_message int
var tokens_per_name int
if model == "gpt-3.5-turbo-0301" || model == "gpt-3.5-turbo" {
tokens_per_message = 4
tokens_per_name = -1
} else if model == "gpt-4-0314" || model == "gpt-4" {
tokens_per_message = 3
tokens_per_name = 1
} else {
fmt.Println("Warning: model not found. Using cl100k_base encoding.")
tokens_per_message = 3
tokens_per_name = 1
}
for _, message := range messages {
num_tokens += tokens_per_message
num_tokens += len(tkm.Encode(message.Content, nil, nil))
// num_tokens += len(tkm.Encode(message.Role, nil, nil))
if message.Name != "" {
num_tokens += tokens_per_name
}
}
num_tokens += 3
return num_tokens
}
func NumTokensFromStr(messages string, model string) (num_tokens int) {
tkm, err := tiktoken.EncodingForModel(model)
if err != nil {
err = fmt.Errorf("EncodingForModel: %v", err)
fmt.Println(err)
return
}
num_tokens += len(tkm.Encode(messages, nil, nil))
return num_tokens
}