refact & update new model
This commit is contained in:
@@ -10,8 +10,10 @@ import (
|
||||
"io"
|
||||
"log"
|
||||
"net/http"
|
||||
"opencatd-open/pkg/error"
|
||||
"opencatd-open/pkg/openai"
|
||||
"opencatd-open/pkg/tokenizer"
|
||||
"opencatd-open/pkg/vertexai"
|
||||
"opencatd-open/store"
|
||||
"strings"
|
||||
|
||||
@@ -27,14 +29,15 @@ func ChatTextCompletions(c *gin.Context, chatReq *openai.ChatCompletionRequest)
|
||||
}
|
||||
|
||||
type ChatRequest struct {
|
||||
Model string `json:"model,omitempty"`
|
||||
Messages any `json:"messages,omitempty"`
|
||||
MaxTokens int `json:"max_tokens,omitempty"`
|
||||
Stream bool `json:"stream,omitempty"`
|
||||
System string `json:"system,omitempty"`
|
||||
TopK int `json:"top_k,omitempty"`
|
||||
TopP float64 `json:"top_p,omitempty"`
|
||||
Temperature float64 `json:"temperature,omitempty"`
|
||||
Model string `json:"model,omitempty"`
|
||||
Messages any `json:"messages,omitempty"`
|
||||
MaxTokens int `json:"max_tokens,omitempty"`
|
||||
Stream bool `json:"stream,omitempty"`
|
||||
System string `json:"system,omitempty"`
|
||||
TopK int `json:"top_k,omitempty"`
|
||||
TopP float64 `json:"top_p,omitempty"`
|
||||
Temperature float64 `json:"temperature,omitempty"`
|
||||
AnthropicVersion string `json:"anthropic_version,omitempty"`
|
||||
}
|
||||
|
||||
func (c *ChatRequest) ByteJson() []byte {
|
||||
@@ -117,8 +120,12 @@ type ClaudeStreamResponse struct {
|
||||
}
|
||||
|
||||
func ChatMessages(c *gin.Context, chatReq *openai.ChatCompletionRequest) {
|
||||
var (
|
||||
req *http.Request
|
||||
targetURL = ClaudeMessageEndpoint
|
||||
)
|
||||
|
||||
onekey, err := store.SelectKeyCache("claude")
|
||||
apiKey, err := store.SelectKeyCache("claude")
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
@@ -131,6 +138,10 @@ func ChatMessages(c *gin.Context, chatReq *openai.ChatCompletionRequest) {
|
||||
// claudReq.Temperature = chatReq.Temperature
|
||||
claudReq.TopP = chatReq.TopP
|
||||
claudReq.MaxTokens = 4096
|
||||
if apiKey.ApiType == "vertex" {
|
||||
claudReq.AnthropicVersion = "vertex-2023-10-16"
|
||||
claudReq.Model = ""
|
||||
}
|
||||
|
||||
var prompt string
|
||||
|
||||
@@ -181,10 +192,40 @@ func ChatMessages(c *gin.Context, chatReq *openai.ChatCompletionRequest) {
|
||||
|
||||
usagelog.PromptCount = tokenizer.NumTokensFromStr(prompt, chatReq.Model)
|
||||
|
||||
req, _ := http.NewRequest("POST", MessageEndpoint, bytes.NewReader(claudReq.ByteJson()))
|
||||
req.Header.Set("x-api-key", onekey.Key)
|
||||
req.Header.Set("anthropic-version", "2023-06-01")
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
if apiKey.ApiType == "vertex" {
|
||||
var vertexSecret vertexai.VertexSecretKey
|
||||
if err := json.Unmarshal([]byte(apiKey.ApiSecret), &vertexSecret); err != nil {
|
||||
c.JSON(http.StatusInternalServerError, error.ErrorData(err.Error()))
|
||||
return
|
||||
}
|
||||
|
||||
vcmodel, ok := vertexai.VertexClaudeModelMap[chatReq.Model]
|
||||
if !ok {
|
||||
c.JSON(http.StatusInternalServerError, error.ErrorData("Model not found"))
|
||||
return
|
||||
}
|
||||
|
||||
// 获取gcloud token,临时放置在apiKey.Key中
|
||||
gcloudToken, err := vertexai.GcloudAuth(vertexSecret.ClientEmail, vertexSecret.PrivateKey)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, error.ErrorData(err.Error()))
|
||||
return
|
||||
}
|
||||
|
||||
// 拼接vertex的请求地址
|
||||
targetURL = fmt.Sprintf("https://%s-aiplatform.googleapis.com/v1/projects/%s/locations/%s/publishers/anthropic/models/%s:streamRawPredict", vcmodel.Region, vertexSecret.ProjectID, vcmodel.Region, vcmodel.VertexName)
|
||||
|
||||
req, _ = http.NewRequest("POST", targetURL, bytes.NewReader(claudReq.ByteJson()))
|
||||
req.Header.Set("Authorization", "Bearer "+gcloudToken)
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Accept", "text/event-stream")
|
||||
req.Header.Set("Accept-Encoding", "identity")
|
||||
} else {
|
||||
req, _ = http.NewRequest("POST", targetURL, bytes.NewReader(claudReq.ByteJson()))
|
||||
req.Header.Set("x-api-key", apiKey.Key)
|
||||
req.Header.Set("anthropic-version", "2023-06-01")
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
}
|
||||
|
||||
client := http.DefaultClient
|
||||
rsp, err := client.Do(req)
|
||||
|
||||
@@ -68,8 +68,8 @@ import (
|
||||
)
|
||||
|
||||
var (
|
||||
ClaudeUrl = "https://api.anthropic.com/v1/complete"
|
||||
MessageEndpoint = "https://api.anthropic.com/v1/messages"
|
||||
ClaudeUrl = "https://api.anthropic.com/v1/complete"
|
||||
ClaudeMessageEndpoint = "https://api.anthropic.com/v1/messages"
|
||||
)
|
||||
|
||||
type MessageModule struct {
|
||||
|
||||
Reference in New Issue
Block a user