support dall-e

This commit is contained in:
Sakurasan
2023-11-23 18:03:15 +08:00
parent e8ed9aaa78
commit c838e7be16
6 changed files with 188 additions and 5 deletions

View File

@@ -10,6 +10,8 @@ OpenCat for Team的开源实现
~~基本~~实现了opencatd的全部功能 ~~基本~~实现了opencatd的全部功能
(openai附属能力:whisper,tts,dall-e(text to image)...)
## Extra Support: ## Extra Support:
| 任务 | 完成情况 | | 任务 | 完成情况 |

4
go.mod
View File

@@ -29,6 +29,7 @@ require (
github.com/go-playground/universal-translator v0.18.1 // indirect github.com/go-playground/universal-translator v0.18.1 // indirect
github.com/go-playground/validator/v10 v10.14.1 // indirect github.com/go-playground/validator/v10 v10.14.1 // indirect
github.com/goccy/go-json v0.10.2 // indirect github.com/goccy/go-json v0.10.2 // indirect
github.com/google/go-cmp v0.5.9 // indirect
github.com/hajimehoshi/go-mp3 v0.3.0 // indirect github.com/hajimehoshi/go-mp3 v0.3.0 // indirect
github.com/jinzhu/inflection v1.0.0 // indirect github.com/jinzhu/inflection v1.0.0 // indirect
github.com/jinzhu/now v1.1.5 // indirect github.com/jinzhu/now v1.1.5 // indirect
@@ -45,7 +46,8 @@ require (
github.com/ugorji/go/codec v1.2.11 // indirect github.com/ugorji/go/codec v1.2.11 // indirect
golang.org/x/arch v0.4.0 // indirect golang.org/x/arch v0.4.0 // indirect
golang.org/x/crypto v0.11.0 // indirect golang.org/x/crypto v0.11.0 // indirect
golang.org/x/net v0.12.0 // indirect golang.org/x/exp v0.0.0-20221208152030-732eee02a75a // indirect
golang.org/x/net v0.13.0 // indirect
golang.org/x/sys v0.10.0 // indirect golang.org/x/sys v0.10.0 // indirect
golang.org/x/text v0.11.0 // indirect golang.org/x/text v0.11.0 // indirect
google.golang.org/protobuf v1.31.0 // indirect google.golang.org/protobuf v1.31.0 // indirect

10
go.sum
View File

@@ -46,8 +46,9 @@ github.com/goccy/go-json v0.10.2/go.mod h1:6MelG93GURQebXPDq3khkgXZkazVtN9CRI+MG
github.com/golang-jwt/jwt v3.2.2+incompatible h1:IfV12K8xAKAnZqdXVzCZ+TOjboZ2keLg81eXfW3O+oY= github.com/golang-jwt/jwt v3.2.2+incompatible h1:IfV12K8xAKAnZqdXVzCZ+TOjboZ2keLg81eXfW3O+oY=
github.com/golang-jwt/jwt v3.2.2+incompatible/go.mod h1:8pz2t5EyA70fFQQSrl6XZXzqecmYZeUEB8OUGHkxJ+I= github.com/golang-jwt/jwt v3.2.2+incompatible/go.mod h1:8pz2t5EyA70fFQQSrl6XZXzqecmYZeUEB8OUGHkxJ+I=
github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk= github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk=
github.com/google/go-cmp v0.5.5 h1:Khx7svrCpmxxtHBq5j2mp/xVjsi8hQMfNLvJFAlrGgU=
github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38=
github.com/google/go-cmp v0.5.9/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
github.com/google/pprof v0.0.0-20221118152302-e6195bd50e26 h1:Xim43kblpZXfIBQsbuBVKCudVG457BR2GZFIz3uw3hQ= github.com/google/pprof v0.0.0-20221118152302-e6195bd50e26 h1:Xim43kblpZXfIBQsbuBVKCudVG457BR2GZFIz3uw3hQ=
github.com/google/uuid v1.3.0 h1:t6JiXgmwXMjEs8VusXIJk2BXHsn+wx8BZdTaoZ5fu7I= github.com/google/uuid v1.3.0 h1:t6JiXgmwXMjEs8VusXIJk2BXHsn+wx8BZdTaoZ5fu7I=
@@ -118,12 +119,14 @@ golang.org/x/arch v0.4.0/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8=
golang.org/x/crypto v0.11.0 h1:6Ewdq3tDic1mg5xRO4milcWCfMVQhI4NkqWWvqejpuA= golang.org/x/crypto v0.11.0 h1:6Ewdq3tDic1mg5xRO4milcWCfMVQhI4NkqWWvqejpuA=
golang.org/x/crypto v0.11.0/go.mod h1:xgJhtzW8F9jGdVFWZESrid1U1bjeNy4zgy5cRr/CIio= golang.org/x/crypto v0.11.0/go.mod h1:xgJhtzW8F9jGdVFWZESrid1U1bjeNy4zgy5cRr/CIio=
golang.org/x/exp v0.0.0-20190306152737-a1d7652674e8/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= golang.org/x/exp v0.0.0-20190306152737-a1d7652674e8/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
golang.org/x/exp v0.0.0-20221208152030-732eee02a75a h1:4iLhBPcpqFmylhnkbY3W0ONLUYYkDAW9xMFLfxgsvCw=
golang.org/x/exp v0.0.0-20221208152030-732eee02a75a/go.mod h1:CxIveKay+FTh1D0yPZemJVgC/95VzuuOLq5Qi4xnoYc=
golang.org/x/image v0.0.0-20190220214146-31aff87c08e9/go.mod h1:kZ7UVZpmo3dzQBMxlp+ypCbDeSB+sBbTgSJuh5dn5js= golang.org/x/image v0.0.0-20190220214146-31aff87c08e9/go.mod h1:kZ7UVZpmo3dzQBMxlp+ypCbDeSB+sBbTgSJuh5dn5js=
golang.org/x/image v0.0.0-20190227222117-0694c2d4d067/go.mod h1:kZ7UVZpmo3dzQBMxlp+ypCbDeSB+sBbTgSJuh5dn5js= golang.org/x/image v0.0.0-20190227222117-0694c2d4d067/go.mod h1:kZ7UVZpmo3dzQBMxlp+ypCbDeSB+sBbTgSJuh5dn5js=
golang.org/x/mobile v0.0.0-20190415191353-3e0bab5405d6/go.mod h1:E/iHnbuqvinMTCcRqshq8CkpyQDoeVncDDYHnLhea+o= golang.org/x/mobile v0.0.0-20190415191353-3e0bab5405d6/go.mod h1:E/iHnbuqvinMTCcRqshq8CkpyQDoeVncDDYHnLhea+o=
golang.org/x/net v0.0.0-20190213061140-3a22650c66bd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20190213061140-3a22650c66bd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
golang.org/x/net v0.12.0 h1:cfawfvKITfUsFCeJIHJrbSxpeu/E81khclypR0GVT50= golang.org/x/net v0.13.0 h1:Nvo8UFsZ8X3BhAC9699Z1j7XQ3rsZnUUm7jfBEk1ueY=
golang.org/x/net v0.12.0/go.mod h1:zEVYFnQC7m/vmpQFELhcD1EWkZlX69l4oqgmer6hfKA= golang.org/x/net v0.13.0/go.mod h1:zEVYFnQC7m/vmpQFELhcD1EWkZlX69l4oqgmer6hfKA=
golang.org/x/sys v0.0.0-20190312061237-fead79001313/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20190312061237-fead79001313/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20190429190828-d89cdac9e872/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20190429190828-d89cdac9e872/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20190626150813-e07cf5db2756/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20190626150813-e07cf5db2756/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
@@ -135,7 +138,6 @@ golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.11.0 h1:LAntKIrcmeSKERyiOh0XMV39LXS8IE9UL2yP7+f5ij4= golang.org/x/text v0.11.0 h1:LAntKIrcmeSKERyiOh0XMV39LXS8IE9UL2yP7+f5ij4=
golang.org/x/text v0.11.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE= golang.org/x/text v0.11.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE=
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1 h1:go1bK/D/BFZV2I8cIQd1NKEZ+0owSTG1fDTci4IqFcE=
google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw= google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw=
google.golang.org/protobuf v1.31.0 h1:g0LDEJHgrBl9N9r17Ru3sqWhkIx2NB67okBHPwC7hs8= google.golang.org/protobuf v1.31.0 h1:g0LDEJHgrBl9N9r17Ru3sqWhkIx2NB67okBHPwC7hs8=
google.golang.org/protobuf v1.31.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I= google.golang.org/protobuf v1.31.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I=

149
pkg/openai/dall-e.go Normal file
View File

@@ -0,0 +1,149 @@
package openai
import (
"bytes"
"encoding/json"
"fmt"
"io"
"log"
"net/http"
"net/http/httputil"
"net/url"
"opencatd-open/pkg/tokenizer"
"opencatd-open/store"
"strconv"
"github.com/duke-git/lancet/v2/slice"
"github.com/gin-gonic/gin"
)
const (
DalleEndpoint = "https://api.openai.com/v1/images/generations"
DalleEditEndpoint = "https://api.openai.com/v1/images/edits"
DalleVariationEndpoint = "https://api.openai.com/v1/images/variations"
)
type DallERequest struct {
Model string `json:"model"`
Prompt string `json:"prompt"`
N int `form:"n" json:"n,omitempty"`
Size string `form:"size" json:"size,omitempty"`
Quality string `json:"quality,omitempty"` // standard,hd
Style string `json:"style,omitempty"` // vivid,natural
ResponseFormat string `json:"response_format,omitempty"` // url or b64_json
}
func DalleHandler(c *gin.Context) {
var dalleRequest DallERequest
if err := c.ShouldBind(&dalleRequest); err != nil {
c.JSON(400, gin.H{"error": err.Error()})
return
}
if dalleRequest.N == 0 {
dalleRequest.N = 1
}
if dalleRequest.Size == "" {
dalleRequest.Size = "512x512"
}
model := dalleRequest.Model
var chatlog store.Tokens
chatlog.Model = model
chatlog.CompletionCount = dalleRequest.N
if model == "dall-e" {
model = "dall-e-2"
}
model = model + "." + dalleRequest.Size
if model == "dall-e-2" {
if !slice.Contain([]string{"256x256", "512x512", "1024x1024"}, dalleRequest.Size) {
c.JSON(http.StatusBadRequest, gin.H{
"error": gin.H{
"message": fmt.Sprintf("Invalid size: %s for %s", dalleRequest.Size, dalleRequest.Model),
},
})
return
}
} else if model == "dall-e-3" {
if !slice.Contain([]string{"256x256", "512x512", "1024x1024", "1792x1024", "1024x1792"}, dalleRequest.Size) {
c.JSON(http.StatusBadRequest, gin.H{
"error": gin.H{
"message": fmt.Sprintf("Invalid size: %s for %s", dalleRequest.Size, dalleRequest.Model),
},
})
return
}
if dalleRequest.Quality == "HD" {
model = model + ".HD"
}
} else {
c.JSON(http.StatusBadRequest, gin.H{
"error": gin.H{
"message": fmt.Sprintf("Invalid model: %s", dalleRequest.Model),
},
})
return
}
token, _ := c.Get("localuser")
lu, err := store.GetUserByToken(token.(string))
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{
"error": gin.H{
"message": err.Error(),
},
})
return
}
chatlog.UserID = int(lu.ID)
key, err := store.SelectKeyCache("openai")
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{
"error": gin.H{
"message": err.Error(),
},
})
return
}
targetURL, _ := url.Parse(DalleEndpoint)
proxy := httputil.NewSingleHostReverseProxy(targetURL)
proxy.Director = func(req *http.Request) {
req.Header.Set("Authorization", "Bearer "+key.Key)
req.Header.Set("Content-Type", "application/json")
req.Host = targetURL.Host
req.URL.Scheme = targetURL.Scheme
req.URL.Host = targetURL.Host
req.URL.Path = targetURL.Path
req.URL.RawPath = targetURL.RawPath
req.URL.RawQuery = targetURL.RawQuery
bytebody, _ := json.Marshal(dalleRequest)
req.Body = io.NopCloser(bytes.NewBuffer(bytebody))
req.ContentLength = int64(len(bytebody))
req.Header.Set("Content-Length", strconv.Itoa(len(bytebody)))
}
proxy.ModifyResponse = func(resp *http.Response) error {
if resp.StatusCode == http.StatusOK {
chatlog.TotalTokens = chatlog.PromptCount + chatlog.CompletionCount
chatlog.Cost = fmt.Sprintf("%.6f", tokenizer.Cost(chatlog.Model, chatlog.PromptCount, chatlog.CompletionCount))
if err := store.Record(&chatlog); err != nil {
log.Println(err)
}
if err := store.SumDaily(chatlog.UserID); err != nil {
log.Println(err)
}
}
return nil
}
proxy.ServeHTTP(c.Writer, c.Request)
}

View File

@@ -102,6 +102,29 @@ func Cost(model string, promptCount, completionCount int) float64 {
cost = 0.015 * float64(prompt+completion) cost = 0.015 * float64(prompt+completion)
case "tts-1-hd": case "tts-1-hd":
cost = 0.03 * float64(prompt+completion) cost = 0.03 * float64(prompt+completion)
case "dall-e-2.256x256":
cost = float64(0.016 * completion)
case "dall-e-2.512x512":
cost = float64(0.018 * completion)
case "dall-e-2.1024x1024":
cost = float64(0.02 * completion)
case "dall-e-3.256x256":
cost = float64(0.04 * completion)
case "dall-e-3.512x512":
cost = float64(0.04 * completion)
case "dall-e-3.1024x1024":
cost = float64(0.04 * completion)
case "dall-e-3.1024x1792", "dall-e-3.1792x1024":
cost = float64(0.08 * completion)
case "dall-e-3.256x256.HD":
cost = float64(0.08 * completion)
case "dall-e-3.512x512.HD":
cost = float64(0.08 * completion)
case "dall-e-3.1024x1024.HD":
cost = float64(0.08 * completion)
case "dall-e-3.1024x1792.HD", "dall-e-3.1792x1024.HD":
cost = float64(0.12 * completion)
// claude /million tokens // claude /million tokens
case "claude-v1", "claude-v1-100k": case "claude-v1", "claude-v1-100k":
cost = 11.02/1000000*float64(prompt) + (32.68/1000000)*float64(completion) cost = 11.02/1000000*float64(prompt) + (32.68/1000000)*float64(completion)

View File

@@ -486,6 +486,11 @@ func HandleProy(c *gin.Context) {
return return
} }
if c.Request.URL.Path == "/v1/images/generations" {
oai.DalleHandler(c)
return
}
if c.Request.URL.Path == "/v1/chat/completions" && localuser { if c.Request.URL.Path == "/v1/chat/completions" && localuser {
if store.KeysCache.ItemCount() == 0 { if store.KeysCache.ItemCount() == 0 {
c.JSON(http.StatusBadGateway, gin.H{"error": gin.H{ c.JSON(http.StatusBadGateway, gin.H{"error": gin.H{