support dall-e
This commit is contained in:
@@ -10,6 +10,8 @@ OpenCat for Team的开源实现
|
||||
|
||||
~~基本~~实现了opencatd的全部功能
|
||||
|
||||
(openai附属能力:whisper,tts,dall-e(text to image)...)
|
||||
|
||||
## Extra Support:
|
||||
|
||||
| 任务 | 完成情况 |
|
||||
|
||||
4
go.mod
4
go.mod
@@ -29,6 +29,7 @@ require (
|
||||
github.com/go-playground/universal-translator v0.18.1 // indirect
|
||||
github.com/go-playground/validator/v10 v10.14.1 // indirect
|
||||
github.com/goccy/go-json v0.10.2 // indirect
|
||||
github.com/google/go-cmp v0.5.9 // indirect
|
||||
github.com/hajimehoshi/go-mp3 v0.3.0 // indirect
|
||||
github.com/jinzhu/inflection v1.0.0 // indirect
|
||||
github.com/jinzhu/now v1.1.5 // indirect
|
||||
@@ -45,7 +46,8 @@ require (
|
||||
github.com/ugorji/go/codec v1.2.11 // indirect
|
||||
golang.org/x/arch v0.4.0 // indirect
|
||||
golang.org/x/crypto v0.11.0 // indirect
|
||||
golang.org/x/net v0.12.0 // indirect
|
||||
golang.org/x/exp v0.0.0-20221208152030-732eee02a75a // indirect
|
||||
golang.org/x/net v0.13.0 // indirect
|
||||
golang.org/x/sys v0.10.0 // indirect
|
||||
golang.org/x/text v0.11.0 // indirect
|
||||
google.golang.org/protobuf v1.31.0 // indirect
|
||||
|
||||
10
go.sum
10
go.sum
@@ -46,8 +46,9 @@ github.com/goccy/go-json v0.10.2/go.mod h1:6MelG93GURQebXPDq3khkgXZkazVtN9CRI+MG
|
||||
github.com/golang-jwt/jwt v3.2.2+incompatible h1:IfV12K8xAKAnZqdXVzCZ+TOjboZ2keLg81eXfW3O+oY=
|
||||
github.com/golang-jwt/jwt v3.2.2+incompatible/go.mod h1:8pz2t5EyA70fFQQSrl6XZXzqecmYZeUEB8OUGHkxJ+I=
|
||||
github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk=
|
||||
github.com/google/go-cmp v0.5.5 h1:Khx7svrCpmxxtHBq5j2mp/xVjsi8hQMfNLvJFAlrGgU=
|
||||
github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
|
||||
github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38=
|
||||
github.com/google/go-cmp v0.5.9/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
|
||||
github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
|
||||
github.com/google/pprof v0.0.0-20221118152302-e6195bd50e26 h1:Xim43kblpZXfIBQsbuBVKCudVG457BR2GZFIz3uw3hQ=
|
||||
github.com/google/uuid v1.3.0 h1:t6JiXgmwXMjEs8VusXIJk2BXHsn+wx8BZdTaoZ5fu7I=
|
||||
@@ -118,12 +119,14 @@ golang.org/x/arch v0.4.0/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8=
|
||||
golang.org/x/crypto v0.11.0 h1:6Ewdq3tDic1mg5xRO4milcWCfMVQhI4NkqWWvqejpuA=
|
||||
golang.org/x/crypto v0.11.0/go.mod h1:xgJhtzW8F9jGdVFWZESrid1U1bjeNy4zgy5cRr/CIio=
|
||||
golang.org/x/exp v0.0.0-20190306152737-a1d7652674e8/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
|
||||
golang.org/x/exp v0.0.0-20221208152030-732eee02a75a h1:4iLhBPcpqFmylhnkbY3W0ONLUYYkDAW9xMFLfxgsvCw=
|
||||
golang.org/x/exp v0.0.0-20221208152030-732eee02a75a/go.mod h1:CxIveKay+FTh1D0yPZemJVgC/95VzuuOLq5Qi4xnoYc=
|
||||
golang.org/x/image v0.0.0-20190220214146-31aff87c08e9/go.mod h1:kZ7UVZpmo3dzQBMxlp+ypCbDeSB+sBbTgSJuh5dn5js=
|
||||
golang.org/x/image v0.0.0-20190227222117-0694c2d4d067/go.mod h1:kZ7UVZpmo3dzQBMxlp+ypCbDeSB+sBbTgSJuh5dn5js=
|
||||
golang.org/x/mobile v0.0.0-20190415191353-3e0bab5405d6/go.mod h1:E/iHnbuqvinMTCcRqshq8CkpyQDoeVncDDYHnLhea+o=
|
||||
golang.org/x/net v0.0.0-20190213061140-3a22650c66bd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
|
||||
golang.org/x/net v0.12.0 h1:cfawfvKITfUsFCeJIHJrbSxpeu/E81khclypR0GVT50=
|
||||
golang.org/x/net v0.12.0/go.mod h1:zEVYFnQC7m/vmpQFELhcD1EWkZlX69l4oqgmer6hfKA=
|
||||
golang.org/x/net v0.13.0 h1:Nvo8UFsZ8X3BhAC9699Z1j7XQ3rsZnUUm7jfBEk1ueY=
|
||||
golang.org/x/net v0.13.0/go.mod h1:zEVYFnQC7m/vmpQFELhcD1EWkZlX69l4oqgmer6hfKA=
|
||||
golang.org/x/sys v0.0.0-20190312061237-fead79001313/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20190429190828-d89cdac9e872/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20190626150813-e07cf5db2756/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
@@ -135,7 +138,6 @@ golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
|
||||
golang.org/x/text v0.11.0 h1:LAntKIrcmeSKERyiOh0XMV39LXS8IE9UL2yP7+f5ij4=
|
||||
golang.org/x/text v0.11.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE=
|
||||
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||
golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1 h1:go1bK/D/BFZV2I8cIQd1NKEZ+0owSTG1fDTci4IqFcE=
|
||||
google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw=
|
||||
google.golang.org/protobuf v1.31.0 h1:g0LDEJHgrBl9N9r17Ru3sqWhkIx2NB67okBHPwC7hs8=
|
||||
google.golang.org/protobuf v1.31.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I=
|
||||
|
||||
149
pkg/openai/dall-e.go
Normal file
149
pkg/openai/dall-e.go
Normal file
@@ -0,0 +1,149 @@
|
||||
package openai
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"net/http"
|
||||
"net/http/httputil"
|
||||
"net/url"
|
||||
"opencatd-open/pkg/tokenizer"
|
||||
"opencatd-open/store"
|
||||
"strconv"
|
||||
|
||||
"github.com/duke-git/lancet/v2/slice"
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
const (
|
||||
DalleEndpoint = "https://api.openai.com/v1/images/generations"
|
||||
DalleEditEndpoint = "https://api.openai.com/v1/images/edits"
|
||||
DalleVariationEndpoint = "https://api.openai.com/v1/images/variations"
|
||||
)
|
||||
|
||||
type DallERequest struct {
|
||||
Model string `json:"model"`
|
||||
Prompt string `json:"prompt"`
|
||||
N int `form:"n" json:"n,omitempty"`
|
||||
Size string `form:"size" json:"size,omitempty"`
|
||||
Quality string `json:"quality,omitempty"` // standard,hd
|
||||
Style string `json:"style,omitempty"` // vivid,natural
|
||||
ResponseFormat string `json:"response_format,omitempty"` // url or b64_json
|
||||
}
|
||||
|
||||
func DalleHandler(c *gin.Context) {
|
||||
|
||||
var dalleRequest DallERequest
|
||||
if err := c.ShouldBind(&dalleRequest); err != nil {
|
||||
c.JSON(400, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
if dalleRequest.N == 0 {
|
||||
dalleRequest.N = 1
|
||||
}
|
||||
|
||||
if dalleRequest.Size == "" {
|
||||
dalleRequest.Size = "512x512"
|
||||
}
|
||||
|
||||
model := dalleRequest.Model
|
||||
|
||||
var chatlog store.Tokens
|
||||
chatlog.Model = model
|
||||
chatlog.CompletionCount = dalleRequest.N
|
||||
|
||||
if model == "dall-e" {
|
||||
model = "dall-e-2"
|
||||
}
|
||||
model = model + "." + dalleRequest.Size
|
||||
|
||||
if model == "dall-e-2" {
|
||||
if !slice.Contain([]string{"256x256", "512x512", "1024x1024"}, dalleRequest.Size) {
|
||||
c.JSON(http.StatusBadRequest, gin.H{
|
||||
"error": gin.H{
|
||||
"message": fmt.Sprintf("Invalid size: %s for %s", dalleRequest.Size, dalleRequest.Model),
|
||||
},
|
||||
})
|
||||
return
|
||||
}
|
||||
} else if model == "dall-e-3" {
|
||||
if !slice.Contain([]string{"256x256", "512x512", "1024x1024", "1792x1024", "1024x1792"}, dalleRequest.Size) {
|
||||
c.JSON(http.StatusBadRequest, gin.H{
|
||||
"error": gin.H{
|
||||
"message": fmt.Sprintf("Invalid size: %s for %s", dalleRequest.Size, dalleRequest.Model),
|
||||
},
|
||||
})
|
||||
return
|
||||
}
|
||||
if dalleRequest.Quality == "HD" {
|
||||
model = model + ".HD"
|
||||
}
|
||||
} else {
|
||||
c.JSON(http.StatusBadRequest, gin.H{
|
||||
"error": gin.H{
|
||||
"message": fmt.Sprintf("Invalid model: %s", dalleRequest.Model),
|
||||
},
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
token, _ := c.Get("localuser")
|
||||
|
||||
lu, err := store.GetUserByToken(token.(string))
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{
|
||||
"error": gin.H{
|
||||
"message": err.Error(),
|
||||
},
|
||||
})
|
||||
return
|
||||
}
|
||||
chatlog.UserID = int(lu.ID)
|
||||
|
||||
key, err := store.SelectKeyCache("openai")
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{
|
||||
"error": gin.H{
|
||||
"message": err.Error(),
|
||||
},
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
targetURL, _ := url.Parse(DalleEndpoint)
|
||||
proxy := httputil.NewSingleHostReverseProxy(targetURL)
|
||||
proxy.Director = func(req *http.Request) {
|
||||
req.Header.Set("Authorization", "Bearer "+key.Key)
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
req.Host = targetURL.Host
|
||||
req.URL.Scheme = targetURL.Scheme
|
||||
req.URL.Host = targetURL.Host
|
||||
req.URL.Path = targetURL.Path
|
||||
req.URL.RawPath = targetURL.RawPath
|
||||
req.URL.RawQuery = targetURL.RawQuery
|
||||
|
||||
bytebody, _ := json.Marshal(dalleRequest)
|
||||
req.Body = io.NopCloser(bytes.NewBuffer(bytebody))
|
||||
req.ContentLength = int64(len(bytebody))
|
||||
req.Header.Set("Content-Length", strconv.Itoa(len(bytebody)))
|
||||
}
|
||||
|
||||
proxy.ModifyResponse = func(resp *http.Response) error {
|
||||
if resp.StatusCode == http.StatusOK {
|
||||
chatlog.TotalTokens = chatlog.PromptCount + chatlog.CompletionCount
|
||||
chatlog.Cost = fmt.Sprintf("%.6f", tokenizer.Cost(chatlog.Model, chatlog.PromptCount, chatlog.CompletionCount))
|
||||
if err := store.Record(&chatlog); err != nil {
|
||||
log.Println(err)
|
||||
}
|
||||
if err := store.SumDaily(chatlog.UserID); err != nil {
|
||||
log.Println(err)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
proxy.ServeHTTP(c.Writer, c.Request)
|
||||
}
|
||||
@@ -102,6 +102,29 @@ func Cost(model string, promptCount, completionCount int) float64 {
|
||||
cost = 0.015 * float64(prompt+completion)
|
||||
case "tts-1-hd":
|
||||
cost = 0.03 * float64(prompt+completion)
|
||||
case "dall-e-2.256x256":
|
||||
cost = float64(0.016 * completion)
|
||||
case "dall-e-2.512x512":
|
||||
cost = float64(0.018 * completion)
|
||||
case "dall-e-2.1024x1024":
|
||||
cost = float64(0.02 * completion)
|
||||
case "dall-e-3.256x256":
|
||||
cost = float64(0.04 * completion)
|
||||
case "dall-e-3.512x512":
|
||||
cost = float64(0.04 * completion)
|
||||
case "dall-e-3.1024x1024":
|
||||
cost = float64(0.04 * completion)
|
||||
case "dall-e-3.1024x1792", "dall-e-3.1792x1024":
|
||||
cost = float64(0.08 * completion)
|
||||
case "dall-e-3.256x256.HD":
|
||||
cost = float64(0.08 * completion)
|
||||
case "dall-e-3.512x512.HD":
|
||||
cost = float64(0.08 * completion)
|
||||
case "dall-e-3.1024x1024.HD":
|
||||
cost = float64(0.08 * completion)
|
||||
case "dall-e-3.1024x1792.HD", "dall-e-3.1792x1024.HD":
|
||||
cost = float64(0.12 * completion)
|
||||
|
||||
// claude /million tokens
|
||||
case "claude-v1", "claude-v1-100k":
|
||||
cost = 11.02/1000000*float64(prompt) + (32.68/1000000)*float64(completion)
|
||||
|
||||
@@ -486,6 +486,11 @@ func HandleProy(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
if c.Request.URL.Path == "/v1/images/generations" {
|
||||
oai.DalleHandler(c)
|
||||
return
|
||||
}
|
||||
|
||||
if c.Request.URL.Path == "/v1/chat/completions" && localuser {
|
||||
if store.KeysCache.ItemCount() == 0 {
|
||||
c.JSON(http.StatusBadGateway, gin.H{"error": gin.H{
|
||||
|
||||
Reference in New Issue
Block a user