Compare commits

...

2 Commits

Author SHA1 Message Date
Sakurasan
1cd06ea1df up 2023-04-19 22:50:39 +08:00
Sakurasan
499edbb0fd add Completion 2023-04-19 02:06:09 +08:00
2 changed files with 112 additions and 17 deletions

View File

@@ -59,9 +59,12 @@ func main() {
// 初始化用户
r.POST("/1/users/init", router.Handleinit)
r.POST("/v1/chat/completions", router.HandleProy)
r.GET("/v1/models", router.HandleProy)
r.GET("/v1/dashboard/billing/subscription", router.HandleProy)
r.Any("/v1/*proxypath", router.HandleProy)
// r.POST("/v1/chat/completions", router.HandleProy)
// r.GET("/v1/models", router.HandleProy)
// r.GET("/v1/dashboard/billing/subscription", router.HandleProy)
r.GET("/", func(c *gin.Context) {
c.Writer.WriteHeader(http.StatusOK)
c.Writer.WriteString(`<h1><a href="https://github.com/mirrors2/opencatd-open" >opencatd-open</a> available</h1>Api-Keys:<a href=https://platform.openai.com/account/api-keys >https://platform.openai.com/account/api-keys</a>`)

View File

@@ -1,7 +1,7 @@
package router
import (
"bytes"
"bufio"
"crypto/tls"
"errors"
"fmt"
@@ -11,6 +11,7 @@ import (
"net/http"
"net/http/httputil"
"opencatd-open/store"
"strings"
"time"
"github.com/Sakurasan/to"
@@ -20,8 +21,10 @@ import (
)
var (
rootToken string
baseUrl = "https://api.openai.com"
rootToken string
baseUrl = "https://api.openai.com"
GPT3Dot5Turbo = "gpt-3.5-turbo"
GPT4 = "gpt-4"
)
type User struct {
@@ -41,6 +44,46 @@ type Key struct {
CreatedAt string `json:"createdAt,omitempty"`
}
type ChatCompletionMessage struct {
Role string `json:"role"`
Content string `json:"content"`
Name string `json:"name,omitempty"`
}
type ChatCompletionRequest struct {
Model string `json:"model"`
Messages []ChatCompletionMessage `json:"messages"`
MaxTokens int `json:"max_tokens,omitempty"`
Temperature float32 `json:"temperature,omitempty"`
TopP float32 `json:"top_p,omitempty"`
N int `json:"n,omitempty"`
Stream bool `json:"stream,omitempty"`
Stop []string `json:"stop,omitempty"`
PresencePenalty float32 `json:"presence_penalty,omitempty"`
FrequencyPenalty float32 `json:"frequency_penalty,omitempty"`
LogitBias map[string]int `json:"logit_bias,omitempty"`
User string `json:"user,omitempty"`
}
type ChatCompletionChoice struct {
Index int `json:"index"`
Message ChatCompletionMessage `json:"message"`
FinishReason string `json:"finish_reason"`
}
type ChatCompletionResponse struct {
ID string `json:"id"`
Object string `json:"object"`
Created int64 `json:"created"`
Model string `json:"model"`
Choices []ChatCompletionChoice `json:"choices"`
Usage struct {
PromptTokens int `json:"prompt_tokens"`
CompletionTokens int `json:"completion_tokens"`
TotalTokens int `json:"total_tokens"`
} `json:"usage"`
}
func AuthMiddleware() gin.HandlerFunc {
return func(c *gin.Context) {
if rootToken == "" {
@@ -251,6 +294,7 @@ func GenerateToken() string {
func HandleProy(c *gin.Context) {
var localuser bool
var isStream bool
auth := c.Request.Header.Get("Authorization")
if len(auth) > 7 && auth[:7] == "Bearer " {
localuser = store.IsExistAuthCache(auth[7:])
@@ -271,8 +315,18 @@ func HandleProy(c *gin.Context) {
}
client.Transport = tr
if c.Request.URL.Path == "/v1/chat/completions" {
var chatreq = ChatCompletionRequest{}
if err := c.BindJSON(&chatreq); err != nil {
return
// c.AbortWithError(http.StatusBadRequest,)
}
isStream = chatreq.Stream
}
// 创建 API 请求
req, err := http.NewRequest(c.Request.Method, baseUrl+c.Request.URL.Path, c.Request.Body)
req, err := http.NewRequest(c.Request.Method, baseUrl+c.Request.RequestURI, c.Request.Body)
if err != nil {
log.Println(err)
c.JSON(http.StatusOK, gin.H{"error": err.Error()})
@@ -315,17 +369,18 @@ func HandleProy(c *gin.Context) {
resp.Header.Del("content-security-policy-report-only")
resp.Header.Del("clear-site-data")
bodyRes, err := io.ReadAll(resp.Body)
if err != nil {
log.Println(err)
c.JSON(http.StatusUnauthorized, gin.H{"error": err.Error()})
return
// bodyRes, err := io.ReadAll(resp.Body)
// if err != nil {
// log.Println(err)
// c.JSON(http.StatusUnauthorized, gin.H{"error": err.Error()})
// return
// }
reader := bufio.NewReader(resp.Body)
if resp.StatusCode == 200 && isStream {
//todo
}
if resp.StatusCode == 200 {
// todo
log.Println(string(bodyRes))
}
resbody := io.NopCloser(bytes.NewReader(bodyRes))
resbody := io.NopCloser(reader)
// 返回 API 响应主体
c.Writer.WriteHeader(resp.StatusCode)
if _, err := io.Copy(c.Writer, resbody); err != nil {
@@ -407,3 +462,40 @@ func HandleUsage(c *gin.Context) {
c.JSON(200, usage)
}
// todo
func streamEvent(reader *bufio.Reader) error {
lineChan := make(chan string, 1)
timeout := time.AfterFunc(60*time.Second, func() {
lineChan <- ""
})
go func() {
line, err := reader.ReadString('\n')
if err == nil {
lineChan <- line
}
}()
line := <-lineChan
timeout.Stop()
if line == "" {
}
if strings.HasPrefix(line, "data:") {
line = strings.TrimSpace(strings.TrimPrefix(line, "data:"))
//log.Println("Received data:", line)
if line == "[DONE]" {
}
line = strings.TrimSpace(line)
}
return nil
}