From c838e7be166810ef546d24a076ddddd69cc08bea Mon Sep 17 00:00:00 2001 From: Sakurasan <1173092237@qq.com> Date: Thu, 23 Nov 2023 18:03:15 +0800 Subject: [PATCH] support dall-e --- README.md | 2 + go.mod | 4 +- go.sum | 10 ++- pkg/openai/dall-e.go | 149 +++++++++++++++++++++++++++++++++++++ pkg/tokenizer/tokenizer.go | 23 ++++++ router/router.go | 5 ++ 6 files changed, 188 insertions(+), 5 deletions(-) create mode 100644 pkg/openai/dall-e.go diff --git a/README.md b/README.md index f5d2746..63579c1 100644 --- a/README.md +++ b/README.md @@ -10,6 +10,8 @@ OpenCat for Team的开源实现 ~~基本~~实现了opencatd的全部功能 +(openai附属能力:whisper,tts,dall-e(text to image)...) + ## Extra Support: | 任务 | 完成情况 | diff --git a/go.mod b/go.mod index eb10981..0661027 100644 --- a/go.mod +++ b/go.mod @@ -29,6 +29,7 @@ require ( github.com/go-playground/universal-translator v0.18.1 // indirect github.com/go-playground/validator/v10 v10.14.1 // indirect github.com/goccy/go-json v0.10.2 // indirect + github.com/google/go-cmp v0.5.9 // indirect github.com/hajimehoshi/go-mp3 v0.3.0 // indirect github.com/jinzhu/inflection v1.0.0 // indirect github.com/jinzhu/now v1.1.5 // indirect @@ -45,7 +46,8 @@ require ( github.com/ugorji/go/codec v1.2.11 // indirect golang.org/x/arch v0.4.0 // indirect golang.org/x/crypto v0.11.0 // indirect - golang.org/x/net v0.12.0 // indirect + golang.org/x/exp v0.0.0-20221208152030-732eee02a75a // indirect + golang.org/x/net v0.13.0 // indirect golang.org/x/sys v0.10.0 // indirect golang.org/x/text v0.11.0 // indirect google.golang.org/protobuf v1.31.0 // indirect diff --git a/go.sum b/go.sum index 5cbf28e..ff24e7f 100644 --- a/go.sum +++ b/go.sum @@ -46,8 +46,9 @@ github.com/goccy/go-json v0.10.2/go.mod h1:6MelG93GURQebXPDq3khkgXZkazVtN9CRI+MG github.com/golang-jwt/jwt v3.2.2+incompatible h1:IfV12K8xAKAnZqdXVzCZ+TOjboZ2keLg81eXfW3O+oY= github.com/golang-jwt/jwt v3.2.2+incompatible/go.mod h1:8pz2t5EyA70fFQQSrl6XZXzqecmYZeUEB8OUGHkxJ+I= github.com/golang/protobuf v1.5.0/go.mod 
h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk= -github.com/google/go-cmp v0.5.5 h1:Khx7svrCpmxxtHBq5j2mp/xVjsi8hQMfNLvJFAlrGgU= github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= +github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38= +github.com/google/go-cmp v0.5.9/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= github.com/google/pprof v0.0.0-20221118152302-e6195bd50e26 h1:Xim43kblpZXfIBQsbuBVKCudVG457BR2GZFIz3uw3hQ= github.com/google/uuid v1.3.0 h1:t6JiXgmwXMjEs8VusXIJk2BXHsn+wx8BZdTaoZ5fu7I= @@ -118,12 +119,14 @@ golang.org/x/arch v0.4.0/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8= golang.org/x/crypto v0.11.0 h1:6Ewdq3tDic1mg5xRO4milcWCfMVQhI4NkqWWvqejpuA= golang.org/x/crypto v0.11.0/go.mod h1:xgJhtzW8F9jGdVFWZESrid1U1bjeNy4zgy5cRr/CIio= golang.org/x/exp v0.0.0-20190306152737-a1d7652674e8/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= +golang.org/x/exp v0.0.0-20221208152030-732eee02a75a h1:4iLhBPcpqFmylhnkbY3W0ONLUYYkDAW9xMFLfxgsvCw= +golang.org/x/exp v0.0.0-20221208152030-732eee02a75a/go.mod h1:CxIveKay+FTh1D0yPZemJVgC/95VzuuOLq5Qi4xnoYc= golang.org/x/image v0.0.0-20190220214146-31aff87c08e9/go.mod h1:kZ7UVZpmo3dzQBMxlp+ypCbDeSB+sBbTgSJuh5dn5js= golang.org/x/image v0.0.0-20190227222117-0694c2d4d067/go.mod h1:kZ7UVZpmo3dzQBMxlp+ypCbDeSB+sBbTgSJuh5dn5js= golang.org/x/mobile v0.0.0-20190415191353-3e0bab5405d6/go.mod h1:E/iHnbuqvinMTCcRqshq8CkpyQDoeVncDDYHnLhea+o= golang.org/x/net v0.0.0-20190213061140-3a22650c66bd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= -golang.org/x/net v0.12.0 h1:cfawfvKITfUsFCeJIHJrbSxpeu/E81khclypR0GVT50= -golang.org/x/net v0.12.0/go.mod h1:zEVYFnQC7m/vmpQFELhcD1EWkZlX69l4oqgmer6hfKA= +golang.org/x/net v0.13.0 h1:Nvo8UFsZ8X3BhAC9699Z1j7XQ3rsZnUUm7jfBEk1ueY= +golang.org/x/net v0.13.0/go.mod h1:zEVYFnQC7m/vmpQFELhcD1EWkZlX69l4oqgmer6hfKA= golang.org/x/sys 
v0.0.0-20190312061237-fead79001313/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
 golang.org/x/sys v0.0.0-20190429190828-d89cdac9e872/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
 golang.org/x/sys v0.0.0-20190626150813-e07cf5db2756/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
@@ -135,7 +138,6 @@ golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
 golang.org/x/text v0.11.0 h1:LAntKIrcmeSKERyiOh0XMV39LXS8IE9UL2yP7+f5ij4=
 golang.org/x/text v0.11.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE=
 golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
-golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1 h1:go1bK/D/BFZV2I8cIQd1NKEZ+0owSTG1fDTci4IqFcE=
 google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw=
 google.golang.org/protobuf v1.31.0 h1:g0LDEJHgrBl9N9r17Ru3sqWhkIx2NB67okBHPwC7hs8=
 google.golang.org/protobuf v1.31.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I=
diff --git a/pkg/openai/dall-e.go b/pkg/openai/dall-e.go
new file mode 100644
index 0000000..280ab87
--- /dev/null
+++ b/pkg/openai/dall-e.go
@@ -0,0 +1,172 @@
+package openai
+
+import (
+	"bytes"
+	"encoding/json"
+	"fmt"
+	"io"
+	"log"
+	"net/http"
+	"net/http/httputil"
+	"net/url"
+	"opencatd-open/pkg/tokenizer"
+	"opencatd-open/store"
+	"strconv"
+
+	"github.com/duke-git/lancet/v2/slice"
+	"github.com/gin-gonic/gin"
+)
+
+// OpenAI image API endpoints. DalleHandler forwards to DalleEndpoint only.
+const (
+	DalleEndpoint          = "https://api.openai.com/v1/images/generations"
+	DalleEditEndpoint      = "https://api.openai.com/v1/images/edits"
+	DalleVariationEndpoint = "https://api.openai.com/v1/images/variations"
+)
+
+// DallERequest is the image-generation request bound from the client and
+// re-encoded as the JSON body sent upstream.
+type DallERequest struct {
+	Model          string `json:"model"`
+	Prompt         string `json:"prompt"`
+	N              int    `form:"n" json:"n,omitempty"`
+	Size           string `form:"size" json:"size,omitempty"`
+	Quality        string `json:"quality,omitempty"`         // standard,hd
+	Style          string `json:"style,omitempty"`           // vivid,natural
+	ResponseFormat string `json:"response_format,omitempty"` // url or b64_json
+}
+
+// writeDalleError replies with an OpenAI-style {"error": {"message": ...}} body.
+func writeDalleError(c *gin.Context, status int, message string) {
+	c.JSON(status, gin.H{
+		"error": gin.H{
+			"message": message,
+		},
+	})
+}
+
+// DalleHandler validates an image-generation request, forwards it to
+// DalleEndpoint using a key from the "openai" key cache, and records usage
+// for the calling user when the upstream replies 200 OK.
+//
+// Usage is recorded under the pricing key "<model>.<size>[.HD]" (for example
+// "dall-e-3.1024x1024.HD"), the form tokenizer.Cost switches on.
+func DalleHandler(c *gin.Context) {
+	var dalleRequest DallERequest
+	if err := c.ShouldBind(&dalleRequest); err != nil {
+		c.JSON(400, gin.H{"error": err.Error()})
+		return
+	}
+	if dalleRequest.N == 0 {
+		dalleRequest.N = 1
+	}
+	if dalleRequest.Size == "" {
+		dalleRequest.Size = "512x512"
+	}
+
+	// Validate against the bare model name. The size suffix is appended only
+	// afterwards: comparing the suffixed key with "dall-e-2"/"dall-e-3" can
+	// never match, which rejected every request as an invalid model.
+	baseModel := dalleRequest.Model
+	if baseModel == "dall-e" {
+		baseModel = "dall-e-2"
+	}
+
+	var allowedSizes []string
+	switch baseModel {
+	case "dall-e-2":
+		allowedSizes = []string{"256x256", "512x512", "1024x1024"}
+	case "dall-e-3":
+		allowedSizes = []string{"256x256", "512x512", "1024x1024", "1792x1024", "1024x1792"}
+	default:
+		writeDalleError(c, http.StatusBadRequest, fmt.Sprintf("Invalid model: %s", dalleRequest.Model))
+		return
+	}
+	if !slice.Contain(allowedSizes, dalleRequest.Size) {
+		writeDalleError(c, http.StatusBadRequest, fmt.Sprintf("Invalid size: %s for %s", dalleRequest.Size, dalleRequest.Model))
+		return
+	}
+
+	// The quality value documented on DallERequest is lowercase "hd", while the
+	// pricing table keys end in ".HD"; accept both spellings.
+	priceKey := baseModel + "." + dalleRequest.Size
+	if baseModel == "dall-e-3" && (dalleRequest.Quality == "hd" || dalleRequest.Quality == "HD") {
+		priceKey += ".HD"
+	}
+
+	var chatlog store.Tokens
+	chatlog.Model = priceKey
+	chatlog.CompletionCount = dalleRequest.N
+
+	// A missing or non-string "localuser" value must not panic the handler.
+	rawToken, _ := c.Get("localuser")
+	token, ok := rawToken.(string)
+	if !ok {
+		writeDalleError(c, http.StatusUnauthorized, "missing user token")
+		return
+	}
+
+	lu, err := store.GetUserByToken(token)
+	if err != nil {
+		writeDalleError(c, http.StatusInternalServerError, err.Error())
+		return
+	}
+	chatlog.UserID = int(lu.ID)
+
+	key, err := store.SelectKeyCache("openai")
+	if err != nil {
+		writeDalleError(c, http.StatusInternalServerError, err.Error())
+		return
+	}
+
+	targetURL, err := url.Parse(DalleEndpoint)
+	if err != nil {
+		writeDalleError(c, http.StatusInternalServerError, err.Error())
+		return
+	}
+
+	// Encode the upstream body once, before proxying, so an encoding failure is
+	// reported to the client instead of being dropped inside the Director.
+	body, err := json.Marshal(dalleRequest)
+	if err != nil {
+		writeDalleError(c, http.StatusInternalServerError, err.Error())
+		return
+	}
+
+	proxy := httputil.NewSingleHostReverseProxy(targetURL)
+	proxy.Director = func(req *http.Request) {
+		req.Header.Set("Authorization", "Bearer "+key.Key)
+		req.Header.Set("Content-Type", "application/json")
+
+		req.Host = targetURL.Host
+		req.URL.Scheme = targetURL.Scheme
+		req.URL.Host = targetURL.Host
+		req.URL.Path = targetURL.Path
+		req.URL.RawPath = targetURL.RawPath
+		req.URL.RawQuery = targetURL.RawQuery
+
+		// Send the normalized request (defaults applied), not the client's bytes.
+		req.Body = io.NopCloser(bytes.NewReader(body))
+		req.ContentLength = int64(len(body))
+		req.Header.Set("Content-Length", strconv.Itoa(len(body)))
+	}
+
+	proxy.ModifyResponse = func(resp *http.Response) error {
+		if resp.StatusCode != http.StatusOK {
+			return nil
+		}
+		chatlog.TotalTokens = chatlog.PromptCount + chatlog.CompletionCount
+		chatlog.Cost = fmt.Sprintf("%.6f", tokenizer.Cost(chatlog.Model, chatlog.PromptCount, chatlog.CompletionCount))
+		// Accounting failures are logged only; the upstream response is still
+		// delivered to the client.
+		if err := store.Record(&chatlog); err != nil {
+			log.Println(err)
+		}
+		if err := store.SumDaily(chatlog.UserID); err != nil {
+			log.Println(err)
+		}
+		return nil
+	}
+
+	proxy.ServeHTTP(c.Writer, c.Request)
+}
diff --git a/pkg/tokenizer/tokenizer.go b/pkg/tokenizer/tokenizer.go
index aaa9c98..c7240e0 100644
--- a/pkg/tokenizer/tokenizer.go
+++ b/pkg/tokenizer/tokenizer.go
@@ -102,6 +102,29 @@ func Cost(model string, promptCount, completionCount int) float64 {
 		cost = 0.015 * float64(prompt+completion)
 	case "tts-1-hd":
 		cost = 0.03 * float64(prompt+completion)
+	case "dall-e-2.256x256":
+		cost = float64(0.016 * completion)
+	case "dall-e-2.512x512":
+		cost = float64(0.018 * completion)
+	case "dall-e-2.1024x1024":
+		cost = float64(0.02 * completion)
+	case "dall-e-3.256x256":
+		cost = float64(0.04 * completion)
+	case "dall-e-3.512x512":
+		cost = float64(0.04 * completion)
+	case "dall-e-3.1024x1024":
+		cost = float64(0.04 * completion)
+	case "dall-e-3.1024x1792", "dall-e-3.1792x1024":
+		cost = float64(0.08 * completion)
+	case "dall-e-3.256x256.HD":
+		cost = float64(0.08 * completion)
+	case "dall-e-3.512x512.HD":
+		cost = float64(0.08 * completion)
+	case
"dall-e-3.1024x1024.HD": + cost = float64(0.08 * completion) + case "dall-e-3.1024x1792.HD", "dall-e-3.1792x1024.HD": + cost = float64(0.12 * completion) + // claude /million tokens case "claude-v1", "claude-v1-100k": cost = 11.02/1000000*float64(prompt) + (32.68/1000000)*float64(completion) diff --git a/router/router.go b/router/router.go index b619d4c..dd3b5fb 100644 --- a/router/router.go +++ b/router/router.go @@ -486,6 +486,11 @@ func HandleProy(c *gin.Context) { return } + if c.Request.URL.Path == "/v1/images/generations" { + oai.DalleHandler(c) + return + } + if c.Request.URL.Path == "/v1/chat/completions" && localuser { if store.KeysCache.ItemCount() == 0 { c.JSON(http.StatusBadGateway, gin.H{"error": gin.H{