support dall-e
This commit is contained in:
149
pkg/openai/dall-e.go
Normal file
149
pkg/openai/dall-e.go
Normal file
@@ -0,0 +1,149 @@
|
||||
package openai
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"net/http"
|
||||
"net/http/httputil"
|
||||
"net/url"
|
||||
"opencatd-open/pkg/tokenizer"
|
||||
"opencatd-open/store"
|
||||
"strconv"
|
||||
|
||||
"github.com/duke-git/lancet/v2/slice"
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
const (
|
||||
DalleEndpoint = "https://api.openai.com/v1/images/generations"
|
||||
DalleEditEndpoint = "https://api.openai.com/v1/images/edits"
|
||||
DalleVariationEndpoint = "https://api.openai.com/v1/images/variations"
|
||||
)
|
||||
|
||||
type DallERequest struct {
|
||||
Model string `json:"model"`
|
||||
Prompt string `json:"prompt"`
|
||||
N int `form:"n" json:"n,omitempty"`
|
||||
Size string `form:"size" json:"size,omitempty"`
|
||||
Quality string `json:"quality,omitempty"` // standard,hd
|
||||
Style string `json:"style,omitempty"` // vivid,natural
|
||||
ResponseFormat string `json:"response_format,omitempty"` // url or b64_json
|
||||
}
|
||||
|
||||
func DalleHandler(c *gin.Context) {
|
||||
|
||||
var dalleRequest DallERequest
|
||||
if err := c.ShouldBind(&dalleRequest); err != nil {
|
||||
c.JSON(400, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
if dalleRequest.N == 0 {
|
||||
dalleRequest.N = 1
|
||||
}
|
||||
|
||||
if dalleRequest.Size == "" {
|
||||
dalleRequest.Size = "512x512"
|
||||
}
|
||||
|
||||
model := dalleRequest.Model
|
||||
|
||||
var chatlog store.Tokens
|
||||
chatlog.Model = model
|
||||
chatlog.CompletionCount = dalleRequest.N
|
||||
|
||||
if model == "dall-e" {
|
||||
model = "dall-e-2"
|
||||
}
|
||||
model = model + "." + dalleRequest.Size
|
||||
|
||||
if model == "dall-e-2" {
|
||||
if !slice.Contain([]string{"256x256", "512x512", "1024x1024"}, dalleRequest.Size) {
|
||||
c.JSON(http.StatusBadRequest, gin.H{
|
||||
"error": gin.H{
|
||||
"message": fmt.Sprintf("Invalid size: %s for %s", dalleRequest.Size, dalleRequest.Model),
|
||||
},
|
||||
})
|
||||
return
|
||||
}
|
||||
} else if model == "dall-e-3" {
|
||||
if !slice.Contain([]string{"256x256", "512x512", "1024x1024", "1792x1024", "1024x1792"}, dalleRequest.Size) {
|
||||
c.JSON(http.StatusBadRequest, gin.H{
|
||||
"error": gin.H{
|
||||
"message": fmt.Sprintf("Invalid size: %s for %s", dalleRequest.Size, dalleRequest.Model),
|
||||
},
|
||||
})
|
||||
return
|
||||
}
|
||||
if dalleRequest.Quality == "HD" {
|
||||
model = model + ".HD"
|
||||
}
|
||||
} else {
|
||||
c.JSON(http.StatusBadRequest, gin.H{
|
||||
"error": gin.H{
|
||||
"message": fmt.Sprintf("Invalid model: %s", dalleRequest.Model),
|
||||
},
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
token, _ := c.Get("localuser")
|
||||
|
||||
lu, err := store.GetUserByToken(token.(string))
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{
|
||||
"error": gin.H{
|
||||
"message": err.Error(),
|
||||
},
|
||||
})
|
||||
return
|
||||
}
|
||||
chatlog.UserID = int(lu.ID)
|
||||
|
||||
key, err := store.SelectKeyCache("openai")
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{
|
||||
"error": gin.H{
|
||||
"message": err.Error(),
|
||||
},
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
targetURL, _ := url.Parse(DalleEndpoint)
|
||||
proxy := httputil.NewSingleHostReverseProxy(targetURL)
|
||||
proxy.Director = func(req *http.Request) {
|
||||
req.Header.Set("Authorization", "Bearer "+key.Key)
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
req.Host = targetURL.Host
|
||||
req.URL.Scheme = targetURL.Scheme
|
||||
req.URL.Host = targetURL.Host
|
||||
req.URL.Path = targetURL.Path
|
||||
req.URL.RawPath = targetURL.RawPath
|
||||
req.URL.RawQuery = targetURL.RawQuery
|
||||
|
||||
bytebody, _ := json.Marshal(dalleRequest)
|
||||
req.Body = io.NopCloser(bytes.NewBuffer(bytebody))
|
||||
req.ContentLength = int64(len(bytebody))
|
||||
req.Header.Set("Content-Length", strconv.Itoa(len(bytebody)))
|
||||
}
|
||||
|
||||
proxy.ModifyResponse = func(resp *http.Response) error {
|
||||
if resp.StatusCode == http.StatusOK {
|
||||
chatlog.TotalTokens = chatlog.PromptCount + chatlog.CompletionCount
|
||||
chatlog.Cost = fmt.Sprintf("%.6f", tokenizer.Cost(chatlog.Model, chatlog.PromptCount, chatlog.CompletionCount))
|
||||
if err := store.Record(&chatlog); err != nil {
|
||||
log.Println(err)
|
||||
}
|
||||
if err := store.SumDaily(chatlog.UserID); err != nil {
|
||||
log.Println(err)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
proxy.ServeHTTP(c.Writer, c.Request)
|
||||
}
|
||||
@@ -102,6 +102,29 @@ func Cost(model string, promptCount, completionCount int) float64 {
|
||||
cost = 0.015 * float64(prompt+completion)
|
||||
case "tts-1-hd":
|
||||
cost = 0.03 * float64(prompt+completion)
|
||||
case "dall-e-2.256x256":
|
||||
cost = float64(0.016 * completion)
|
||||
case "dall-e-2.512x512":
|
||||
cost = float64(0.018 * completion)
|
||||
case "dall-e-2.1024x1024":
|
||||
cost = float64(0.02 * completion)
|
||||
case "dall-e-3.256x256":
|
||||
cost = float64(0.04 * completion)
|
||||
case "dall-e-3.512x512":
|
||||
cost = float64(0.04 * completion)
|
||||
case "dall-e-3.1024x1024":
|
||||
cost = float64(0.04 * completion)
|
||||
case "dall-e-3.1024x1792", "dall-e-3.1792x1024":
|
||||
cost = float64(0.08 * completion)
|
||||
case "dall-e-3.256x256.HD":
|
||||
cost = float64(0.08 * completion)
|
||||
case "dall-e-3.512x512.HD":
|
||||
cost = float64(0.08 * completion)
|
||||
case "dall-e-3.1024x1024.HD":
|
||||
cost = float64(0.08 * completion)
|
||||
case "dall-e-3.1024x1792.HD", "dall-e-3.1792x1024.HD":
|
||||
cost = float64(0.12 * completion)
|
||||
|
||||
// claude /million tokens
|
||||
case "claude-v1", "claude-v1-100k":
|
||||
cost = 11.02/1000000*float64(prompt) + (32.68/1000000)*float64(completion)
|
||||
|
||||
Reference in New Issue
Block a user