From 7df0b2817c5c4a10ce2a129edd705a7d762f8166 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?c=E8=8F=8C?=
Date: Thu, 14 Sep 2023 22:51:47 +0800
Subject: [PATCH] update

---
 pkg/claude/claude.go | 141 +++++++++++++++++++++++++++++++++++++++++--
 1 file changed, 137 insertions(+), 4 deletions(-)

diff --git a/pkg/claude/claude.go b/pkg/claude/claude.go
index b23c90b..53cb6b7 100644
--- a/pkg/claude/claude.go
+++ b/pkg/claude/claude.go
@@ -44,6 +44,7 @@ import (
 	"net/url"
 	"opencatd-open/store"
 	"strings"
+	"time"
 
 	"github.com/gin-gonic/gin"
 	"github.com/sashabaranov/go-openai"
@@ -81,8 +82,15 @@ type CompleteResponse struct {
 }
 
 func Create() {
+	// Build the request body from the typed struct instead of a hand-written JSON string.
+	complet := CompleteRequest{
+		Model:             "claude-2",
+		Prompt:            "\n\nHuman: Hello, world!\n\nAssistant:",
+		MaxTokensToSample: 256,
+		Stream:            true,
+	}
+	payload := new(bytes.Buffer)
+	json.NewEncoder(payload).Encode(complet)
 
-	payload := strings.NewReader("{\"model\":\"claude-2\",\"prompt\":\"\\n\\nHuman: Hello, world!\\n\\nAssistant:\",\"max_tokens_to_sample\":256}")
+	// payload := strings.NewReader("{\"model\":\"claude-2\",\"prompt\":\"\\n\\nHuman: Hello, world!\\n\\nAssistant:\",\"max_tokens_to_sample\":256}")
 
 	req, _ := http.NewRequest("POST", ClaudeUrl, payload)
@@ -94,9 +102,24 @@ func Create() {
 
 	res, _ := http.DefaultClient.Do(req)
 	defer res.Body.Close()
 
-	body, _ := io.ReadAll(res.Body)
+	// body, _ := io.ReadAll(res.Body)
 
-	fmt.Println(string(body))
+	// fmt.Println(string(body))
+
+	// Read the SSE response line by line and print the data events.
+	reader := bufio.NewReader(res.Body)
+	for {
+		line, err := reader.ReadString('\n')
+		if err != nil {
+			break
+		}
+		if strings.HasPrefix(line, "data:") {
+			fmt.Println(line)
+			// var result CompleteResponse
+			// json.Unmarshal()
+		}
+	}
 }
 
 func ClaudeProxy(c *gin.Context) {
@@ -205,6 +228,116 @@ func ClaudeProxy(c *gin.Context) {
 	proxy.ServeHTTP(c.Writer, c.Request)
 }
 
+// TransReq converts an OpenAI-style chat completion request into a Claude
+// complete request and returns the JSON-encoded body.
+func TransReq(chatreq *openai.ChatCompletionRequest) (*bytes.Buffer, error) {
+	transReq := CompleteRequest{
+		Model:             chatreq.Model,
+		Temperature:       int(chatreq.Temperature),
+		TopP:              int(chatreq.TopP),
+		Stream:            chatreq.Stream,
+		MaxTokensToSample: chatreq.MaxTokens,
+	}
+	var prompt string
+	for _, msg := range chatreq.Messages {
+		switch msg.Role {
+		case "system":
+			prompt += fmt.Sprintf("\n\nSystem:%s", msg.Content)
+		case "user":
+			prompt += fmt.Sprintf("\n\nUser:%s", msg.Content)
+		case "assistant":
+			prompt += fmt.Sprintf("\n\nAssistant:%s", msg.Content)
+		}
+	}
+	transReq.Prompt = prompt + "\n\nAssistant:"
+	payload := bytes.NewBuffer(nil)
+	if err := json.NewEncoder(payload).Encode(transReq); err != nil {
+		return nil, err
+	}
+	return payload, nil
+}
+
+// TransRsp converts a Claude complete response back into an OpenAI-style
+// chat completion response and writes it to the client.
+func TransRsp(c *gin.Context, isStream bool, responseBody *bufio.Reader) {
+	if !isStream {
+		var completersp CompleteResponse
+		var chatrsp openai.ChatCompletionResponse
+		json.NewDecoder(responseBody).Decode(&completersp)
+		chatrsp.Model = completersp.Model
+		chatrsp.ID = completersp.LogID
+		chatrsp.Object = "chat.completion"
+		chatrsp.Created = time.Now().Unix()
+		choice := openai.ChatCompletionChoice{
+			Index:        0,
+			FinishReason: "stop",
+			Message: openai.ChatCompletionMessage{
+				Role:    "assistant",
+				Content: completersp.Completion,
+			},
+		}
+		chatrsp.Choices = append(chatrsp.Choices, choice)
+		payload := new(bytes.Buffer)
+		if err := json.NewEncoder(payload).Encode(chatrsp); err != nil {
+			c.JSON(http.StatusInternalServerError, gin.H{
+				"error": gin.H{
+					"message": err.Error(),
+				},
+			})
+			return
+		}
+		c.Data(http.StatusOK, "application/json", payload.Bytes())
+		return
+	} else {
+		// Streaming: relay translated chunks to the client as they arrive.
+		dataChan := make(chan string)
+		stopChan := make(chan bool)
+		TranslateStream(c, responseBody, dataChan, stopChan)
+		for content := range dataChan {
+			c.Writer.WriteString(content)
+			c.Writer.Flush()
+		}
+	}
+}
+
+// TranslateStream reads the Claude SSE stream in a goroutine, translates each
+// event into an OpenAI chat completion stream chunk, and sends the encoded
+// chunks on dataChan. dataChan is closed when the stream ends.
+func TranslateStream(ctx *gin.Context, reader *bufio.Reader, dataChan chan string, stopChan chan bool) {
+	go func() {
+		defer close(dataChan)
+		for {
+			line, err := reader.ReadString('\n')
+			if err != nil {
+				log.Println("err:", err)
+				break
+			}
+			if !strings.HasPrefix(line, "data: ") {
+				continue
+			}
+			var result CompleteResponse
+			json.NewDecoder(strings.NewReader(line[6:])).Decode(&result)
+			if result.StopReason != "" {
+				log.Println("finish:", result.StopReason)
+				dataChan <- "data: [DONE]\n\n"
+				break
+			}
+			if result.Completion == "" {
+				continue
+			}
+			chatrsp := openai.ChatCompletionStreamResponse{
+				ID:      result.LogID,
+				Model:   result.Model,
+				Object:  "chat.completion",
+				Created: time.Now().Unix(),
+			}
+			choice := openai.ChatCompletionStreamChoice{
+				Delta: openai.ChatCompletionStreamChoiceDelta{
+					Role:    "assistant",
+					Content: result.Completion,
+				},
+				FinishReason: "",
+			}
+			chatrsp.Choices = append(chatrsp.Choices, choice)
+			bytedate, _ := json.Marshal(chatrsp)
+			dataChan <- "data: " + string(bytedate) + "\n\n"
+		}
+	}()
+}
+
 func Translate(c *gin.Context, chatreq *openai.ChatCompletionRequest) {
 	transReq := CompleteRequest{
 		Model: chatreq.Model,
@@ -298,7 +431,7 @@ func Translate(c *gin.Context, chatreq *openai.ChatCompletionRequest) {
 			writer.WriteString("[DONE]")
 			writer.Flush()
 		} else {
-
+			return
 		}
 	}
 }
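
Not part of the patch: a minimal sketch of how the new helpers might be wired together inside pkg/claude/claude.go, reusing the package-level ClaudeUrl and the file's existing imports. The handler name, header values, and key handling below are placeholders, not the repository's actual code.

// Sketch only: chatCompletionSketch is a hypothetical gin handler showing the
// intended call order TransReq -> upstream request -> TransRsp.
func chatCompletionSketch(c *gin.Context) {
	// Parse the incoming OpenAI-style request.
	var chatreq openai.ChatCompletionRequest
	if err := c.ShouldBindJSON(&chatreq); err != nil {
		c.JSON(http.StatusBadRequest, gin.H{"error": gin.H{"message": err.Error()}})
		return
	}

	// Translate it into a Claude complete request body.
	payload, err := TransReq(&chatreq)
	if err != nil {
		c.JSON(http.StatusInternalServerError, gin.H{"error": gin.H{"message": err.Error()}})
		return
	}

	req, _ := http.NewRequest("POST", ClaudeUrl, payload)
	req.Header.Set("Content-Type", "application/json")
	req.Header.Set("x-api-key", "<claude-api-key>") // key lookup elided in this sketch

	res, err := http.DefaultClient.Do(req)
	if err != nil {
		c.JSON(http.StatusBadGateway, gin.H{"error": gin.H{"message": err.Error()}})
		return
	}
	defer res.Body.Close()

	// Translate the Claude response (streaming or not) back into OpenAI format.
	TransRsp(c, chatreq.Stream, bufio.NewReader(res.Body))
}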