mirror of
https://github.com/zhengkai/orca.git
synced 2026-03-01 00:35:36 +08:00
feat: vertexai
This commit is contained in:
@@ -0,0 +1 @@
|
|||||||
|
*.json
|
||||||
@@ -0,0 +1,2 @@
|
|||||||
|
chat:
|
||||||
|
./chat.sh
|
||||||
Executable
+15
@@ -0,0 +1,15 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
|
||||||
|
URL="http://localhost:22035/va/chat"
|
||||||
|
|
||||||
|
curl -s "$URL" \
|
||||||
|
-H "Content-Type: application/json" \
|
||||||
|
-H "VA-TOKEN: ${ORCA_VA_TOKEN}" \
|
||||||
|
-d '{
|
||||||
|
"system":"翻译下列语言为中文:",
|
||||||
|
"user":"Hello, world!",
|
||||||
|
"debug":true
|
||||||
|
}' | tee tmp-chat.json
|
||||||
|
echo
|
||||||
|
|
||||||
|
jq . tmp-chat.json
|
||||||
+29
-1
@@ -4,9 +4,21 @@ package pb;
|
|||||||
|
|
||||||
message VaChatReq {
|
message VaChatReq {
|
||||||
string system = 1;
|
string system = 1;
|
||||||
repeated string user = 2;
|
string user = 2;
|
||||||
bool noCache = 3;
|
bool noCache = 3;
|
||||||
VaParam param = 4;
|
VaParam param = 4;
|
||||||
|
bool debug = 5;
|
||||||
|
}
|
||||||
|
message VaChatRsp {
|
||||||
|
string content = 1;
|
||||||
|
bool blocked = 2;
|
||||||
|
}
|
||||||
|
|
||||||
|
message VaChatWebRsp {
|
||||||
|
bool ok = 1;
|
||||||
|
VaChatRsp data = 2;
|
||||||
|
VaDebug debug = 3;
|
||||||
|
string error = 4;
|
||||||
}
|
}
|
||||||
|
|
||||||
message VaParam {
|
message VaParam {
|
||||||
@@ -15,3 +27,19 @@ message VaParam {
|
|||||||
float topP = 3;
|
float topP = 3;
|
||||||
uint32 topK = 4;
|
uint32 topK = 4;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
message VaDebug {
|
||||||
|
uint32 costMs = 1;
|
||||||
|
string cahceFile = 2;
|
||||||
|
uint32 inputChar = 3;
|
||||||
|
uint32 inputToken = 4;
|
||||||
|
uint32 outputChar = 5;
|
||||||
|
uint32 outputToken = 6;
|
||||||
|
repeated VaSafety safety = 7;
|
||||||
|
uint32 totalMs = 8;
|
||||||
|
}
|
||||||
|
|
||||||
|
message VaSafety {
|
||||||
|
string category = 1;
|
||||||
|
float score = 2;
|
||||||
|
}
|
||||||
|
|||||||
@@ -7,3 +7,4 @@ export ORCA_WEB=":22035"
|
|||||||
export ORCA_ES_ADDR="https://10.0.84.49:9200/"
|
export ORCA_ES_ADDR="https://10.0.84.49:9200/"
|
||||||
export ORCA_ES_USER=""
|
export ORCA_ES_USER=""
|
||||||
export ORCA_ES_PASS=""
|
export ORCA_ES_PASS=""
|
||||||
|
export ORCA_VA_TOKEN=""
|
||||||
|
|||||||
@@ -17,4 +17,6 @@ var (
|
|||||||
ESAddr = ``
|
ESAddr = ``
|
||||||
ESUser = ``
|
ESUser = ``
|
||||||
ESPass = ``
|
ESPass = ``
|
||||||
|
|
||||||
|
VAToken = ``
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -20,6 +20,7 @@ func init() {
|
|||||||
`ORCA_ES_ADDR`: &ESAddr,
|
`ORCA_ES_ADDR`: &ESAddr,
|
||||||
`ORCA_ES_USER`: &ESUser,
|
`ORCA_ES_USER`: &ESUser,
|
||||||
`ORCA_ES_PASS`: &ESPass,
|
`ORCA_ES_PASS`: &ESPass,
|
||||||
|
`ORCA_VA_TOKEN`: &VAToken,
|
||||||
}
|
}
|
||||||
for k, v := range list {
|
for k, v := range list {
|
||||||
s := os.Getenv(k)
|
s := os.Getenv(k)
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
package core
|
package core
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"fmt"
|
||||||
"project/pb"
|
"project/pb"
|
||||||
"project/util"
|
"project/util"
|
||||||
)
|
)
|
||||||
@@ -20,3 +21,8 @@ func tryCache(p *pb.Req) ([]byte, bool) {
|
|||||||
func rspCacheFile(r *pb.Req) string {
|
func rspCacheFile(r *pb.Req) string {
|
||||||
return util.CacheName(r.Hash()) + `-rsp.json`
|
return util.CacheName(r.Hash()) + `-rsp.json`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func cacheFile(hash [16]byte) string {
|
||||||
|
s := fmt.Sprintf(`cache/%02x/%02x/%02x/%x`, hash[0], hash[1], hash[2], hash[3:])
|
||||||
|
return s
|
||||||
|
}
|
||||||
|
|||||||
@@ -78,6 +78,6 @@ func (pr *row) fetchRemote() (ab []byte, err error) {
|
|||||||
func writeFailLog(hash [16]byte, ab []byte) {
|
func writeFailLog(hash [16]byte, ab []byte) {
|
||||||
date := time.Now().Format(`0102/150405`)
|
date := time.Now().Format(`0102/150405`)
|
||||||
file := fmt.Sprintf(`fail/%s-%x.txt`, date, hash)
|
file := fmt.Sprintf(`fail/%s-%x.txt`, date, hash)
|
||||||
os.MkdirAll(path.Dir(util.StaticFile(file)), 0755)
|
os.MkdirAll(path.Dir(util.Static(file)), 0755)
|
||||||
util.WriteFile(file, ab)
|
util.WriteFile(file, ab)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -32,6 +32,7 @@ func (c *Core) getAB(p *pb.Req, r *http.Request) (ab []byte, cached bool, pr *ro
|
|||||||
go func() {
|
go func() {
|
||||||
reqFile := util.CacheName(p.Hash()) + `-req.json`
|
reqFile := util.CacheName(p.Hash()) + `-req.json`
|
||||||
if !util.FileExists(reqFile) {
|
if !util.FileExists(reqFile) {
|
||||||
|
util.Mkdir(reqFile)
|
||||||
util.WriteFile(reqFile, p.Body)
|
util.WriteFile(reqFile, p.Body)
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|||||||
@@ -31,6 +31,7 @@ require (
|
|||||||
github.com/prometheus/client_model v0.4.0 // indirect
|
github.com/prometheus/client_model v0.4.0 // indirect
|
||||||
github.com/prometheus/common v0.44.0 // indirect
|
github.com/prometheus/common v0.44.0 // indirect
|
||||||
github.com/prometheus/procfs v0.11.1 // indirect
|
github.com/prometheus/procfs v0.11.1 // indirect
|
||||||
|
github.com/zhengkai/coral/v2 v2.0.3 // indirect
|
||||||
go.opencensus.io v0.24.0 // indirect
|
go.opencensus.io v0.24.0 // indirect
|
||||||
golang.org/x/crypto v0.9.0 // indirect
|
golang.org/x/crypto v0.9.0 // indirect
|
||||||
golang.org/x/net v0.10.0 // indirect
|
golang.org/x/net v0.10.0 // indirect
|
||||||
|
|||||||
@@ -102,6 +102,8 @@ github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/
|
|||||||
github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU=
|
github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU=
|
||||||
github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4=
|
github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4=
|
||||||
github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY=
|
github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY=
|
||||||
|
github.com/zhengkai/coral/v2 v2.0.3 h1:SB/uWDpPsOgsexmH8qZdgpDK7bwErEpb5XQ+7m/jhps=
|
||||||
|
github.com/zhengkai/coral/v2 v2.0.3/go.mod h1:3gGfB8tumy+OZPFFOeu/5ykeTgDc01sAdvOPPRxwIzw=
|
||||||
github.com/zhengkai/life-go v1.0.3 h1:rzm+Hb8H4He5trWx3lthFEQPf3sHpns0bDZ7vubT6sI=
|
github.com/zhengkai/life-go v1.0.3 h1:rzm+Hb8H4He5trWx3lthFEQPf3sHpns0bDZ7vubT6sI=
|
||||||
github.com/zhengkai/life-go v1.0.3/go.mod h1:e2RGLfk+uRzjhRrMQash9X4iY3jAuGj99r0qj5JS7m4=
|
github.com/zhengkai/life-go v1.0.3/go.mod h1:e2RGLfk+uRzjhRrMQash9X4iY3jAuGj99r0qj5JS7m4=
|
||||||
github.com/zhengkai/zog v1.0.3 h1:dkJdXJKRjbqqlseFycA1d80AUU6HAZrPe4WplpmwTo4=
|
github.com/zhengkai/zog v1.0.3 h1:dkJdXJKRjbqqlseFycA1d80AUU6HAZrPe4WplpmwTo4=
|
||||||
|
|||||||
+18
-7
@@ -19,10 +19,14 @@ type DownloadFunc func(url string) (ab []byte, err error)
|
|||||||
// CacheName ...
|
// CacheName ...
|
||||||
func CacheName(hash [16]byte) string {
|
func CacheName(hash [16]byte) string {
|
||||||
s := fmt.Sprintf(`cache/%02x/%02x/%02x/%x`, hash[0], hash[1], hash[2], hash[3:])
|
s := fmt.Sprintf(`cache/%02x/%02x/%02x/%x`, hash[0], hash[1], hash[2], hash[3:])
|
||||||
os.MkdirAll(StaticFile(filepath.Dir(s)), 0755)
|
|
||||||
return s
|
return s
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Mkdir ...
|
||||||
|
func Mkdir(filename string) {
|
||||||
|
os.MkdirAll(Static(filepath.Dir(filename)), 0755)
|
||||||
|
}
|
||||||
|
|
||||||
// FileExists ...
|
// FileExists ...
|
||||||
func FileExists(filename string) bool {
|
func FileExists(filename string) bool {
|
||||||
filename = fmt.Sprintf(`%s/%s`, config.StaticDir, filename)
|
filename = fmt.Sprintf(`%s/%s`, config.StaticDir, filename)
|
||||||
@@ -36,12 +40,12 @@ func IsURL(s string) bool {
|
|||||||
|
|
||||||
// ReadFile ...
|
// ReadFile ...
|
||||||
func ReadFile(file string) (ab []byte, err error) {
|
func ReadFile(file string) (ab []byte, err error) {
|
||||||
ab, err = os.ReadFile(StaticFile(file))
|
ab, err = os.ReadFile(Static(file))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// StaticFile ...
|
// Static ...
|
||||||
func StaticFile(file string) string {
|
func Static(file string) string {
|
||||||
file = strings.TrimPrefix(file, config.StaticDir+`/`)
|
file = strings.TrimPrefix(file, config.StaticDir+`/`)
|
||||||
return fmt.Sprintf(`%s/%s`, config.StaticDir, file)
|
return fmt.Sprintf(`%s/%s`, config.StaticDir, file)
|
||||||
}
|
}
|
||||||
@@ -67,8 +71,6 @@ func SaveData(name string, p proto.Message) (err error) {
|
|||||||
// WriteFile ...
|
// WriteFile ...
|
||||||
func WriteFile(file string, ab []byte) (err error) {
|
func WriteFile(file string, ab []byte) (err error) {
|
||||||
|
|
||||||
file = StaticFile(file)
|
|
||||||
|
|
||||||
defer zj.Watch(&err)
|
defer zj.Watch(&err)
|
||||||
|
|
||||||
f, err := os.CreateTemp(config.StaticDir+`/tmp`, `wr-*`)
|
f, err := os.CreateTemp(config.StaticDir+`/tmp`, `wr-*`)
|
||||||
@@ -86,10 +88,19 @@ func WriteFile(file string, ab []byte) (err error) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
err = os.Rename(tmpName, file)
|
err = os.Rename(tmpName, Static(file))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// WriteJSON ...
|
||||||
|
func WriteJSON(file string, d any) error {
|
||||||
|
ab, err := json.Marshal(d)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return WriteFile(file, ab)
|
||||||
|
}
|
||||||
|
|||||||
+180
-15
@@ -1,11 +1,18 @@
|
|||||||
package vertexai
|
package vertexai
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"crypto/md5"
|
||||||
|
"encoding/json"
|
||||||
"errors"
|
"errors"
|
||||||
|
"fmt"
|
||||||
"project/pb"
|
"project/pb"
|
||||||
|
"project/util"
|
||||||
|
"time"
|
||||||
|
|
||||||
aiplatform "cloud.google.com/go/aiplatform/apiv1"
|
aiplatform "cloud.google.com/go/aiplatform/apiv1"
|
||||||
"cloud.google.com/go/aiplatform/apiv1/aiplatformpb"
|
"cloud.google.com/go/aiplatform/apiv1/aiplatformpb"
|
||||||
|
"github.com/zhengkai/coral/v2"
|
||||||
|
"github.com/zhengkai/life-go"
|
||||||
"google.golang.org/protobuf/types/known/structpb"
|
"google.golang.org/protobuf/types/known/structpb"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -14,24 +21,69 @@ var chatClient *aiplatform.PredictionClient
|
|||||||
var errEmptyAnswer = errors.New(`empty answer`)
|
var errEmptyAnswer = errors.New(`empty answer`)
|
||||||
var errBlocked = errors.New(`blocked by google`)
|
var errBlocked = errors.New(`blocked by google`)
|
||||||
|
|
||||||
// Chat ...
|
var chatCache = coral.NewLRU(loadChatForCoral, 1000, 100)
|
||||||
func Chat(req *pb.VaChatReq) {
|
|
||||||
|
type chatKey struct {
|
||||||
|
System string `json:"system"`
|
||||||
|
User string `json:"user"`
|
||||||
|
Temperature float32 `json:"temperature"`
|
||||||
|
MaxOutputTokens uint32 `json:"maxOutputTokens"`
|
||||||
|
TopP float32 `json:"topP"`
|
||||||
|
TopK uint32 `json:"topK"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func buildChatReq(system string, user []string, param *pb.VaParam) (*aiplatformpb.PredictRequest, error) {
|
// ChatRsp ...
|
||||||
|
type ChatRsp struct {
|
||||||
|
Answer *pb.VaChatRsp `json:"answer,omitempty"`
|
||||||
|
Raw *aiplatformpb.PredictResponse `json:"raw"`
|
||||||
|
CostMs uint32 `json:"costMs"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// Chat ...
|
||||||
|
func Chat(req *pb.VaChatReq) (*ChatRsp, error) {
|
||||||
|
|
||||||
|
p := req.Param
|
||||||
|
if p == nil {
|
||||||
|
p = defaultParam
|
||||||
|
}
|
||||||
|
|
||||||
|
k := chatKey{
|
||||||
|
System: req.System,
|
||||||
|
User: req.User,
|
||||||
|
Temperature: p.Temperature,
|
||||||
|
MaxOutputTokens: p.MaxOutputTokens,
|
||||||
|
TopP: p.TopP,
|
||||||
|
TopK: p.TopK,
|
||||||
|
}
|
||||||
|
|
||||||
|
if req.NoCache {
|
||||||
|
chatCache.Delete(k)
|
||||||
|
} else {
|
||||||
|
ab, err := util.ReadFile(chatCacheFile(k) + `.json`)
|
||||||
|
if err == nil && len(ab) > 2 {
|
||||||
|
rsp := &ChatRsp{}
|
||||||
|
err = json.Unmarshal(ab, rsp)
|
||||||
|
if err == nil {
|
||||||
|
return rsp, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return chatCache.Get(k)
|
||||||
|
}
|
||||||
|
|
||||||
|
func buildChatReq(k chatKey) (*aiplatformpb.PredictRequest, error) {
|
||||||
|
|
||||||
m := map[string]any{
|
m := map[string]any{
|
||||||
`context`: system,
|
`context`: k.System,
|
||||||
}
|
}
|
||||||
if len(user) > 0 {
|
if k.User != `` {
|
||||||
var li []any
|
m[`messages`] = []any{
|
||||||
for _, v := range user {
|
map[string]any{
|
||||||
li = append(li, map[string]any{
|
|
||||||
`author`: `user`,
|
`author`: `user`,
|
||||||
`content`: v,
|
`content`: k.User,
|
||||||
})
|
},
|
||||||
}
|
}
|
||||||
m[`messages`] = li
|
|
||||||
}
|
}
|
||||||
|
|
||||||
inst, err := structpb.NewStruct(m)
|
inst, err := structpb.NewStruct(m)
|
||||||
@@ -40,10 +92,10 @@ func buildChatReq(system string, user []string, param *pb.VaParam) (*aiplatformp
|
|||||||
}
|
}
|
||||||
|
|
||||||
p, err := structpb.NewStruct(map[string]any{
|
p, err := structpb.NewStruct(map[string]any{
|
||||||
`temperature`: param.Temperature,
|
`temperature`: k.Temperature,
|
||||||
`maxOutputTokens`: param.MaxOutputTokens,
|
`maxOutputTokens`: k.MaxOutputTokens,
|
||||||
`topP`: param.TopP,
|
`topP`: k.TopP,
|
||||||
`topK`: param.TopK,
|
`topK`: k.TopK,
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
@@ -95,3 +147,116 @@ func isBlocked(o *structpb.Value) bool {
|
|||||||
}
|
}
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func loadChat(k chatKey) (*ChatRsp, error) {
|
||||||
|
|
||||||
|
req, err := buildChatReq(k)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx, cancel := life.CTXTimeout(10 * time.Second)
|
||||||
|
t := time.Now()
|
||||||
|
rsp, err := chatClient.Predict(ctx, req)
|
||||||
|
|
||||||
|
cancel()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
r := &ChatRsp{
|
||||||
|
Raw: rsp,
|
||||||
|
CostMs: uint32(time.Since(t) / time.Millisecond),
|
||||||
|
}
|
||||||
|
|
||||||
|
answer := &pb.VaChatRsp{}
|
||||||
|
answer.Content, err = getChatVal(rsp)
|
||||||
|
if err == errBlocked {
|
||||||
|
err = nil
|
||||||
|
answer.Blocked = true
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
r.Answer = answer
|
||||||
|
|
||||||
|
go chatSaveCache(k, r)
|
||||||
|
return r, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func loadChatForCoral(k chatKey) (*ChatRsp, *time.Time, error) {
|
||||||
|
r, err := loadChat(k)
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, err
|
||||||
|
}
|
||||||
|
return r, nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func chatCacheFile(k chatKey) string {
|
||||||
|
ab, _ := json.Marshal(k)
|
||||||
|
h := md5.Sum(ab)
|
||||||
|
file := fmt.Sprintf(`vertexai/chat/%02x/%02x/%02x/%x`, h[0], h[1], h[2], h[3:])
|
||||||
|
return file
|
||||||
|
}
|
||||||
|
|
||||||
|
func chatSaveCache(k chatKey, rsp *ChatRsp) {
|
||||||
|
file := chatCacheFile(k)
|
||||||
|
util.Mkdir(file)
|
||||||
|
util.WriteJSON(file+`.json`, rsp)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Debug ...
|
||||||
|
func (rsp *ChatRsp) Debug() *pb.VaDebug {
|
||||||
|
d := &pb.VaDebug{
|
||||||
|
CostMs: rsp.CostMs,
|
||||||
|
}
|
||||||
|
|
||||||
|
getToken(d, rsp.Raw)
|
||||||
|
getSafety(d, rsp.Raw)
|
||||||
|
|
||||||
|
return d
|
||||||
|
}
|
||||||
|
|
||||||
|
func getToken(d *pb.VaDebug, raw *aiplatformpb.PredictResponse) {
|
||||||
|
tm := SpbMap(raw.GetMetadata(), `tokenMetadata`)
|
||||||
|
if tm == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
input := SpbMap(tm, `inputTokenCount`)
|
||||||
|
if input != nil {
|
||||||
|
d.InputChar = uint32(SpbMap(input, `totalBillableCharacters`).GetNumberValue())
|
||||||
|
d.InputToken = uint32(SpbMap(input, `totalTokens`).GetNumberValue())
|
||||||
|
}
|
||||||
|
output := SpbMap(tm, `outputTokenCount`)
|
||||||
|
if output != nil {
|
||||||
|
d.OutputChar = uint32(SpbMap(output, `totalBillableCharacters`).GetNumberValue())
|
||||||
|
d.OutputToken = uint32(SpbMap(output, `totalTokens`).GetNumberValue())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
func getSafety(d *pb.VaDebug, raw *aiplatformpb.PredictResponse) {
|
||||||
|
p := raw.GetPredictions()
|
||||||
|
if len(p) == 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
sa := SpbMap(p[0], `safetyAttributes`).GetListValue().GetValues()
|
||||||
|
|
||||||
|
for _, v := range sa {
|
||||||
|
c := SpbMap(v, `categories`).GetListValue().GetValues()
|
||||||
|
if len(c) == 0 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
s := SpbMap(v, `scores`).GetListValue().GetValues()
|
||||||
|
if len(s) != len(c) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
for i, cv := range c {
|
||||||
|
row := &pb.VaSafety{
|
||||||
|
Category: cv.GetStringValue(),
|
||||||
|
Score: float32(s[i].GetNumberValue()),
|
||||||
|
}
|
||||||
|
d.Safety = append(d.Safety, row)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -0,0 +1,23 @@
|
|||||||
|
package vertexai
|
||||||
|
|
||||||
|
import (
|
||||||
|
"project/util"
|
||||||
|
"project/zj"
|
||||||
|
|
||||||
|
aiplatform "cloud.google.com/go/aiplatform/apiv1"
|
||||||
|
"github.com/zhengkai/life-go"
|
||||||
|
"google.golang.org/api/option"
|
||||||
|
)
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
|
||||||
|
var err error
|
||||||
|
chatClient, err = aiplatform.NewPredictionClient(
|
||||||
|
life.CTX,
|
||||||
|
option.WithEndpoint(`us-central1-aiplatform.googleapis.com:443`),
|
||||||
|
option.WithCredentialsFile(util.Static(`aigc-llm-730bb179e13c.json`)),
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
zj.W(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -1,6 +1,17 @@
|
|||||||
package vertexai
|
package vertexai
|
||||||
|
|
||||||
import "google.golang.org/protobuf/types/known/structpb"
|
import (
|
||||||
|
"project/pb"
|
||||||
|
|
||||||
|
"google.golang.org/protobuf/types/known/structpb"
|
||||||
|
)
|
||||||
|
|
||||||
|
var defaultParam = &pb.VaParam{
|
||||||
|
Temperature: 0.2,
|
||||||
|
MaxOutputTokens: 0,
|
||||||
|
TopP: 1,
|
||||||
|
TopK: 40,
|
||||||
|
}
|
||||||
|
|
||||||
// SpbMap ...
|
// SpbMap ...
|
||||||
func SpbMap(o *structpb.Value, key string) *structpb.Value {
|
func SpbMap(o *structpb.Value, key string) *structpb.Value {
|
||||||
|
|||||||
@@ -0,0 +1,97 @@
|
|||||||
|
package vertexai
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"project/config"
|
||||||
|
"project/pb"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ChatHandle ...
|
||||||
|
func ChatHandle(w http.ResponseWriter, r *http.Request) {
|
||||||
|
|
||||||
|
t := time.Now()
|
||||||
|
|
||||||
|
data, debug, err := chatHandle(w, r)
|
||||||
|
o := &pb.VaChatWebRsp{
|
||||||
|
Data: data,
|
||||||
|
Debug: debug,
|
||||||
|
}
|
||||||
|
if err == nil {
|
||||||
|
o.Ok = true
|
||||||
|
} else {
|
||||||
|
o.Error = err.Error()
|
||||||
|
}
|
||||||
|
if debug != nil {
|
||||||
|
i := uint32(time.Since(t) / time.Millisecond)
|
||||||
|
if i < 1 {
|
||||||
|
i = 1
|
||||||
|
}
|
||||||
|
debug.TotalMs = i
|
||||||
|
}
|
||||||
|
ab, _ := json.Marshal(o)
|
||||||
|
w.Write(ab)
|
||||||
|
}
|
||||||
|
|
||||||
|
func chatHandleInput(w http.ResponseWriter, r *http.Request) (*pb.VaChatReq, error) {
|
||||||
|
|
||||||
|
if r.Method != `POST` {
|
||||||
|
w.WriteHeader(http.StatusBadRequest)
|
||||||
|
return nil, errors.New(`method not allow`)
|
||||||
|
}
|
||||||
|
|
||||||
|
token := r.Header.Get(`VA-TOKEN`)
|
||||||
|
|
||||||
|
if token == `` {
|
||||||
|
w.WriteHeader(http.StatusNonAuthoritativeInfo)
|
||||||
|
return nil, errors.New(`no token`)
|
||||||
|
}
|
||||||
|
if config.VAToken == `` {
|
||||||
|
w.WriteHeader(http.StatusInternalServerError)
|
||||||
|
err := errors.New(`no token in server`)
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if config.VAToken != token {
|
||||||
|
w.WriteHeader(http.StatusUnauthorized)
|
||||||
|
return nil, errors.New(`token not match`)
|
||||||
|
}
|
||||||
|
|
||||||
|
o := &pb.VaChatReq{}
|
||||||
|
|
||||||
|
ab, err := io.ReadAll(r.Body)
|
||||||
|
if err != nil {
|
||||||
|
w.WriteHeader(http.StatusBadRequest)
|
||||||
|
return nil, errors.New(`read body fail`)
|
||||||
|
}
|
||||||
|
err = json.Unmarshal(ab, o)
|
||||||
|
if err != nil {
|
||||||
|
w.WriteHeader(http.StatusBadRequest)
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return o, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func chatHandle(w http.ResponseWriter, r *http.Request) (*pb.VaChatRsp, *pb.VaDebug, error) {
|
||||||
|
|
||||||
|
req, err := chatHandleInput(w, r)
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
rsp, err := Chat(req)
|
||||||
|
if err != nil {
|
||||||
|
w.WriteHeader(http.StatusInternalServerError)
|
||||||
|
return nil, nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
var debug *pb.VaDebug
|
||||||
|
if req.Debug {
|
||||||
|
debug = rsp.Debug()
|
||||||
|
}
|
||||||
|
|
||||||
|
return rsp.Answer, debug, nil
|
||||||
|
}
|
||||||
@@ -4,6 +4,7 @@ import (
|
|||||||
"net/http"
|
"net/http"
|
||||||
"project/config"
|
"project/config"
|
||||||
"project/core"
|
"project/core"
|
||||||
|
"project/vertexai"
|
||||||
"project/zj"
|
"project/zj"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@@ -19,6 +20,7 @@ func Server() {
|
|||||||
mux.Handle(`/v1/moderations`, core.NewCore())
|
mux.Handle(`/v1/moderations`, core.NewCore())
|
||||||
mux.Handle(`/v1/completions`, core.NewCore())
|
mux.Handle(`/v1/completions`, core.NewCore())
|
||||||
mux.Handle(`/v1/chat/completions`, core.NewCore())
|
mux.Handle(`/v1/chat/completions`, core.NewCore())
|
||||||
|
mux.HandleFunc(`/va/chat`, vertexai.ChatHandle)
|
||||||
|
|
||||||
s := &http.Server{
|
s := &http.Server{
|
||||||
Addr: config.WebAddr,
|
Addr: config.WebAddr,
|
||||||
|
|||||||
Reference in New Issue
Block a user