refactor & update new model

Sakurasan
2024-09-13 21:32:10 +08:00
parent c11824f5aa
commit 7fd82b43f4
21 changed files with 1094 additions and 975 deletions

View File

@@ -10,8 +10,10 @@ import (
"io"
"log"
"net/http"
"opencatd-open/pkg/error"
"opencatd-open/pkg/openai"
"opencatd-open/pkg/tokenizer"
"opencatd-open/pkg/vertexai"
"opencatd-open/store"
"strings"
@@ -27,14 +29,15 @@ func ChatTextCompletions(c *gin.Context, chatReq *openai.ChatCompletionRequest)
 }
 
 type ChatRequest struct {
-	Model       string  `json:"model,omitempty"`
-	Messages    any     `json:"messages,omitempty"`
-	MaxTokens   int     `json:"max_tokens,omitempty"`
-	Stream      bool    `json:"stream,omitempty"`
-	System      string  `json:"system,omitempty"`
-	TopK        int     `json:"top_k,omitempty"`
-	TopP        float64 `json:"top_p,omitempty"`
-	Temperature float64 `json:"temperature,omitempty"`
+	Model            string  `json:"model,omitempty"`
+	Messages         any     `json:"messages,omitempty"`
+	MaxTokens        int     `json:"max_tokens,omitempty"`
+	Stream           bool    `json:"stream,omitempty"`
+	System           string  `json:"system,omitempty"`
+	TopK             int     `json:"top_k,omitempty"`
+	TopP             float64 `json:"top_p,omitempty"`
+	Temperature      float64 `json:"temperature,omitempty"`
+	AnthropicVersion string  `json:"anthropic_version,omitempty"`
 }
 
 func (c *ChatRequest) ByteJson() []byte {
@@ -117,8 +120,12 @@ type ClaudeStreamResponse struct {
 }
 
 func ChatMessages(c *gin.Context, chatReq *openai.ChatCompletionRequest) {
+	var (
+		req       *http.Request
+		targetURL = ClaudeMessageEndpoint
+	)
-	onekey, err := store.SelectKeyCache("claude")
+	apiKey, err := store.SelectKeyCache("claude")
 	if err != nil {
 		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
 		return
 	}
@@ -131,6 +138,10 @@ func ChatMessages(c *gin.Context, chatReq *openai.ChatCompletionRequest) {
 	// claudReq.Temperature = chatReq.Temperature
 	claudReq.TopP = chatReq.TopP
 	claudReq.MaxTokens = 4096
+	if apiKey.ApiType == "vertex" {
+		claudReq.AnthropicVersion = "vertex-2023-10-16"
+		claudReq.Model = ""
+	}
 	var prompt string
@@ -181,10 +192,40 @@ func ChatMessages(c *gin.Context, chatReq *openai.ChatCompletionRequest) {
 	usagelog.PromptCount = tokenizer.NumTokensFromStr(prompt, chatReq.Model)
-	req, _ := http.NewRequest("POST", MessageEndpoint, bytes.NewReader(claudReq.ByteJson()))
-	req.Header.Set("x-api-key", onekey.Key)
-	req.Header.Set("anthropic-version", "2023-06-01")
-	req.Header.Set("Content-Type", "application/json")
+	if apiKey.ApiType == "vertex" {
+		var vertexSecret vertexai.VertexSecretKey
+		if err := json.Unmarshal([]byte(apiKey.ApiSecret), &vertexSecret); err != nil {
+			c.JSON(http.StatusInternalServerError, error.ErrorData(err.Error()))
+			return
+		}
+		vcmodel, ok := vertexai.VertexClaudeModelMap[chatReq.Model]
+		if !ok {
+			c.JSON(http.StatusInternalServerError, error.ErrorData("Model not found"))
+			return
+		}
+		// fetch a gcloud access token (temporarily stored in apiKey.Key)
+		gcloudToken, err := vertexai.GcloudAuth(vertexSecret.ClientEmail, vertexSecret.PrivateKey)
+		if err != nil {
+			c.JSON(http.StatusInternalServerError, error.ErrorData(err.Error()))
+			return
+		}
+		// assemble the Vertex AI request URL
+		targetURL = fmt.Sprintf("https://%s-aiplatform.googleapis.com/v1/projects/%s/locations/%s/publishers/anthropic/models/%s:streamRawPredict", vcmodel.Region, vertexSecret.ProjectID, vcmodel.Region, vcmodel.VertexName)
+		req, _ = http.NewRequest("POST", targetURL, bytes.NewReader(claudReq.ByteJson()))
+		req.Header.Set("Authorization", "Bearer "+gcloudToken)
+		req.Header.Set("Content-Type", "application/json")
+		req.Header.Set("Accept", "text/event-stream")
+		req.Header.Set("Accept-Encoding", "identity")
+	} else {
+		req, _ = http.NewRequest("POST", targetURL, bytes.NewReader(claudReq.ByteJson()))
+		req.Header.Set("x-api-key", apiKey.Key)
+		req.Header.Set("anthropic-version", "2023-06-01")
+		req.Header.Set("Content-Type", "application/json")
+	}
 	client := http.DefaultClient
 	rsp, err := client.Do(req)
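
Note on the routing change above: when the stored key's ApiType is "vertex", the request goes to Vertex AI's region-scoped streamRawPredict endpoint with an OAuth2 bearer token, and the model is selected by the URL rather than the JSON body (hence claudReq.Model = "" and anthropic_version = "vertex-2023-10-16"). The sketch below only illustrates the resulting request; the vertexModel struct, its field names, and the example project/region/model values are assumptions for illustration, not taken from the repo's vertexai package.

package main

import (
	"fmt"
	"net/http"
	"strings"
)

// vertexModel is a hypothetical stand-in for the entries of
// vertexai.VertexClaudeModelMap; the real field names may differ.
type vertexModel struct {
	Region     string // e.g. "us-east5"
	VertexName string // e.g. "claude-3-5-sonnet@20240620"
}

// buildVertexRequest mirrors the URL pattern and headers used in the diff:
// a region-scoped aiplatform host, the Anthropic publisher path, and the
// streaming raw-predict method, authenticated with a bearer token instead
// of the x-api-key header used for api.anthropic.com.
func buildVertexRequest(projectID, token string, m vertexModel, body string) (*http.Request, error) {
	url := fmt.Sprintf(
		"https://%s-aiplatform.googleapis.com/v1/projects/%s/locations/%s/publishers/anthropic/models/%s:streamRawPredict",
		m.Region, projectID, m.Region, m.VertexName,
	)
	req, err := http.NewRequest("POST", url, strings.NewReader(body))
	if err != nil {
		return nil, err
	}
	req.Header.Set("Authorization", "Bearer "+token)
	req.Header.Set("Content-Type", "application/json")
	req.Header.Set("Accept", "text/event-stream")
	return req, nil
}

func main() {
	m := vertexModel{Region: "us-east5", VertexName: "claude-3-5-sonnet@20240620"}
	// anthropic_version replaces "model" in the body; the model is chosen by the URL.
	body := `{"anthropic_version":"vertex-2023-10-16","max_tokens":4096,"messages":[]}`
	req, _ := buildVertexRequest("my-project", "<access-token>", m, body)
	fmt.Println(req.URL.String())
}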

View File

@@ -68,8 +68,8 @@ import (
 )
 
 var (
-	ClaudeUrl       = "https://api.anthropic.com/v1/complete"
-	MessageEndpoint = "https://api.anthropic.com/v1/messages"
+	ClaudeUrl             = "https://api.anthropic.com/v1/complete"
+	ClaudeMessageEndpoint = "https://api.anthropic.com/v1/messages"
 )
 
 type MessageModule struct {
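
The first file's diff relies on vertexai.GcloudAuth(clientEmail, privateKey) to mint the bearer token from the service-account JSON stored in apiKey.ApiSecret. That helper is not shown in this commit; a minimal sketch of one common way to implement it, the two-legged JWT flow via golang.org/x/oauth2/jwt, follows. This is an assumption about the approach, not the repo's actual implementation.

package main

import (
	"context"
	"fmt"
	"log"

	"golang.org/x/oauth2/google"
	"golang.org/x/oauth2/jwt"
)

// gcloudAuth is a hypothetical stand-in for vertexai.GcloudAuth: it exchanges
// a service account's client_email and private_key for a short-lived OAuth2
// access token with the cloud-platform scope that Vertex AI requires.
func gcloudAuth(clientEmail, privateKey string) (string, error) {
	conf := &jwt.Config{
		Email:      clientEmail,
		PrivateKey: []byte(privateKey),
		Scopes:     []string{"https://www.googleapis.com/auth/cloud-platform"},
		TokenURL:   google.JWTTokenURL,
	}
	tok, err := conf.TokenSource(context.Background()).Token()
	if err != nil {
		return "", err
	}
	return tok.AccessToken, nil
}

func main() {
	// Placeholder credentials; a real PEM-encoded private key is required.
	token, err := gcloudAuth("sa@my-project.iam.gserviceaccount.com", "-----BEGIN PRIVATE KEY-----\n...")
	if err != nil {
		log.Fatal(err)
	}
	fmt.Println("token length:", len(token))
}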