feat: vertexai

This commit is contained in:
Zheng Kai
2023-09-01 17:33:51 +08:00
parent e620651059
commit bdf34ab55b
18 changed files with 394 additions and 25 deletions

1
misc/test/va/.gitignore vendored Normal file
View File

@@ -0,0 +1 @@
*.json

2
misc/test/va/Makefile Normal file
View File

@@ -0,0 +1,2 @@
chat:
./chat.sh

15
misc/test/va/chat.sh Executable file
View File

@@ -0,0 +1,15 @@
#!/bin/bash
URL="http://localhost:22035/va/chat"
curl -s "$URL" \
-H "Content-Type: application/json" \
-H "VA-TOKEN: ${ORCA_VA_TOKEN}" \
-d '{
"system":"翻译下列语言为中文:",
"user":"Hello, world!",
"debug":true
}' | tee tmp-chat.json
echo
jq . tmp-chat.json

View File

@@ -4,9 +4,21 @@ package pb;
message VaChatReq {
string system = 1;
repeated string user = 2;
string user = 2;
bool noCache = 3;
VaParam param = 4;
bool debug = 5;
}
message VaChatRsp {
string content = 1;
bool blocked = 2;
}
message VaChatWebRsp {
bool ok = 1;
VaChatRsp data = 2;
VaDebug debug = 3;
string error = 4;
}
message VaParam {
@@ -15,3 +27,19 @@ message VaParam {
float topP = 3;
uint32 topK = 4;
}
message VaDebug {
uint32 costMs = 1;
string cahceFile = 2;
uint32 inputChar = 3;
uint32 inputToken = 4;
uint32 outputChar = 5;
uint32 outputToken = 6;
repeated VaSafety safety = 7;
uint32 totalMs = 8;
}
message VaSafety {
string category = 1;
float score = 2;
}

View File

@@ -7,3 +7,4 @@ export ORCA_WEB=":22035"
export ORCA_ES_ADDR="https://10.0.84.49:9200/"
export ORCA_ES_USER=""
export ORCA_ES_PASS=""
export ORCA_VA_TOKEN=""

View File

@@ -17,4 +17,6 @@ var (
ESAddr = ``
ESUser = ``
ESPass = ``
VAToken = ``
)

View File

@@ -20,6 +20,7 @@ func init() {
`ORCA_ES_ADDR`: &ESAddr,
`ORCA_ES_USER`: &ESUser,
`ORCA_ES_PASS`: &ESPass,
`ORCA_VA_TOKEN`: &VAToken,
}
for k, v := range list {
s := os.Getenv(k)

View File

@@ -1,6 +1,7 @@
package core
import (
"fmt"
"project/pb"
"project/util"
)
@@ -20,3 +21,8 @@ func tryCache(p *pb.Req) ([]byte, bool) {
func rspCacheFile(r *pb.Req) string {
return util.CacheName(r.Hash()) + `-rsp.json`
}
func cacheFile(hash [16]byte) string {
s := fmt.Sprintf(`cache/%02x/%02x/%02x/%x`, hash[0], hash[1], hash[2], hash[3:])
return s
}

View File

@@ -78,6 +78,6 @@ func (pr *row) fetchRemote() (ab []byte, err error) {
func writeFailLog(hash [16]byte, ab []byte) {
date := time.Now().Format(`0102/150405`)
file := fmt.Sprintf(`fail/%s-%x.txt`, date, hash)
os.MkdirAll(path.Dir(util.StaticFile(file)), 0755)
os.MkdirAll(path.Dir(util.Static(file)), 0755)
util.WriteFile(file, ab)
}

View File

@@ -32,6 +32,7 @@ func (c *Core) getAB(p *pb.Req, r *http.Request) (ab []byte, cached bool, pr *ro
go func() {
reqFile := util.CacheName(p.Hash()) + `-req.json`
if !util.FileExists(reqFile) {
util.Mkdir(reqFile)
util.WriteFile(reqFile, p.Body)
}
}()

View File

@@ -31,6 +31,7 @@ require (
github.com/prometheus/client_model v0.4.0 // indirect
github.com/prometheus/common v0.44.0 // indirect
github.com/prometheus/procfs v0.11.1 // indirect
github.com/zhengkai/coral/v2 v2.0.3 // indirect
go.opencensus.io v0.24.0 // indirect
golang.org/x/crypto v0.9.0 // indirect
golang.org/x/net v0.10.0 // indirect

View File

@@ -102,6 +102,8 @@ github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/
github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU=
github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4=
github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY=
github.com/zhengkai/coral/v2 v2.0.3 h1:SB/uWDpPsOgsexmH8qZdgpDK7bwErEpb5XQ+7m/jhps=
github.com/zhengkai/coral/v2 v2.0.3/go.mod h1:3gGfB8tumy+OZPFFOeu/5ykeTgDc01sAdvOPPRxwIzw=
github.com/zhengkai/life-go v1.0.3 h1:rzm+Hb8H4He5trWx3lthFEQPf3sHpns0bDZ7vubT6sI=
github.com/zhengkai/life-go v1.0.3/go.mod h1:e2RGLfk+uRzjhRrMQash9X4iY3jAuGj99r0qj5JS7m4=
github.com/zhengkai/zog v1.0.3 h1:dkJdXJKRjbqqlseFycA1d80AUU6HAZrPe4WplpmwTo4=

View File

@@ -19,10 +19,14 @@ type DownloadFunc func(url string) (ab []byte, err error)
// CacheName ...
func CacheName(hash [16]byte) string {
s := fmt.Sprintf(`cache/%02x/%02x/%02x/%x`, hash[0], hash[1], hash[2], hash[3:])
os.MkdirAll(StaticFile(filepath.Dir(s)), 0755)
return s
}
// Mkdir ...
func Mkdir(filename string) {
os.MkdirAll(Static(filepath.Dir(filename)), 0755)
}
// FileExists ...
func FileExists(filename string) bool {
filename = fmt.Sprintf(`%s/%s`, config.StaticDir, filename)
@@ -36,12 +40,12 @@ func IsURL(s string) bool {
// ReadFile ...
func ReadFile(file string) (ab []byte, err error) {
ab, err = os.ReadFile(StaticFile(file))
ab, err = os.ReadFile(Static(file))
return
}
// StaticFile ...
func StaticFile(file string) string {
// Static ...
func Static(file string) string {
file = strings.TrimPrefix(file, config.StaticDir+`/`)
return fmt.Sprintf(`%s/%s`, config.StaticDir, file)
}
@@ -67,8 +71,6 @@ func SaveData(name string, p proto.Message) (err error) {
// WriteFile ...
func WriteFile(file string, ab []byte) (err error) {
file = StaticFile(file)
defer zj.Watch(&err)
f, err := os.CreateTemp(config.StaticDir+`/tmp`, `wr-*`)
@@ -86,10 +88,19 @@ func WriteFile(file string, ab []byte) (err error) {
return
}
err = os.Rename(tmpName, file)
err = os.Rename(tmpName, Static(file))
if err != nil {
return
}
return
}
// WriteJSON ...
func WriteJSON(file string, d any) error {
ab, err := json.Marshal(d)
if err != nil {
return err
}
return WriteFile(file, ab)
}

View File

@@ -1,11 +1,18 @@
package vertexai
import (
"crypto/md5"
"encoding/json"
"errors"
"fmt"
"project/pb"
"project/util"
"time"
aiplatform "cloud.google.com/go/aiplatform/apiv1"
"cloud.google.com/go/aiplatform/apiv1/aiplatformpb"
"github.com/zhengkai/coral/v2"
"github.com/zhengkai/life-go"
"google.golang.org/protobuf/types/known/structpb"
)
@@ -14,24 +21,69 @@ var chatClient *aiplatform.PredictionClient
var errEmptyAnswer = errors.New(`empty answer`)
var errBlocked = errors.New(`blocked by google`)
// Chat ...
func Chat(req *pb.VaChatReq) {
var chatCache = coral.NewLRU(loadChatForCoral, 1000, 100)
type chatKey struct {
System string `json:"system"`
User string `json:"user"`
Temperature float32 `json:"temperature"`
MaxOutputTokens uint32 `json:"maxOutputTokens"`
TopP float32 `json:"topP"`
TopK uint32 `json:"topK"`
}
func buildChatReq(system string, user []string, param *pb.VaParam) (*aiplatformpb.PredictRequest, error) {
// ChatRsp ...
type ChatRsp struct {
Answer *pb.VaChatRsp `json:"answer,omitempty"`
Raw *aiplatformpb.PredictResponse `json:"raw"`
CostMs uint32 `json:"costMs"`
}
// Chat ...
func Chat(req *pb.VaChatReq) (*ChatRsp, error) {
p := req.Param
if p == nil {
p = defaultParam
}
k := chatKey{
System: req.System,
User: req.User,
Temperature: p.Temperature,
MaxOutputTokens: p.MaxOutputTokens,
TopP: p.TopP,
TopK: p.TopK,
}
if req.NoCache {
chatCache.Delete(k)
} else {
ab, err := util.ReadFile(chatCacheFile(k) + `.json`)
if err == nil && len(ab) > 2 {
rsp := &ChatRsp{}
err = json.Unmarshal(ab, rsp)
if err == nil {
return rsp, nil
}
}
}
return chatCache.Get(k)
}
func buildChatReq(k chatKey) (*aiplatformpb.PredictRequest, error) {
m := map[string]any{
`context`: system,
`context`: k.System,
}
if len(user) > 0 {
var li []any
for _, v := range user {
li = append(li, map[string]any{
if k.User != `` {
m[`messages`] = []any{
map[string]any{
`author`: `user`,
`content`: v,
})
`content`: k.User,
},
}
m[`messages`] = li
}
inst, err := structpb.NewStruct(m)
@@ -40,10 +92,10 @@ func buildChatReq(system string, user []string, param *pb.VaParam) (*aiplatformp
}
p, err := structpb.NewStruct(map[string]any{
`temperature`: param.Temperature,
`maxOutputTokens`: param.MaxOutputTokens,
`topP`: param.TopP,
`topK`: param.TopK,
`temperature`: k.Temperature,
`maxOutputTokens`: k.MaxOutputTokens,
`topP`: k.TopP,
`topK`: k.TopK,
})
if err != nil {
return nil, err
@@ -95,3 +147,116 @@ func isBlocked(o *structpb.Value) bool {
}
return false
}
func loadChat(k chatKey) (*ChatRsp, error) {
req, err := buildChatReq(k)
if err != nil {
return nil, err
}
ctx, cancel := life.CTXTimeout(10 * time.Second)
t := time.Now()
rsp, err := chatClient.Predict(ctx, req)
cancel()
if err != nil {
return nil, err
}
r := &ChatRsp{
Raw: rsp,
CostMs: uint32(time.Since(t) / time.Millisecond),
}
answer := &pb.VaChatRsp{}
answer.Content, err = getChatVal(rsp)
if err == errBlocked {
err = nil
answer.Blocked = true
}
if err != nil {
return nil, err
}
r.Answer = answer
go chatSaveCache(k, r)
return r, nil
}
func loadChatForCoral(k chatKey) (*ChatRsp, *time.Time, error) {
r, err := loadChat(k)
if err != nil {
return nil, nil, err
}
return r, nil, nil
}
func chatCacheFile(k chatKey) string {
ab, _ := json.Marshal(k)
h := md5.Sum(ab)
file := fmt.Sprintf(`vertexai/chat/%02x/%02x/%02x/%x`, h[0], h[1], h[2], h[3:])
return file
}
func chatSaveCache(k chatKey, rsp *ChatRsp) {
file := chatCacheFile(k)
util.Mkdir(file)
util.WriteJSON(file+`.json`, rsp)
}
// Debug ...
func (rsp *ChatRsp) Debug() *pb.VaDebug {
d := &pb.VaDebug{
CostMs: rsp.CostMs,
}
getToken(d, rsp.Raw)
getSafety(d, rsp.Raw)
return d
}
func getToken(d *pb.VaDebug, raw *aiplatformpb.PredictResponse) {
tm := SpbMap(raw.GetMetadata(), `tokenMetadata`)
if tm == nil {
return
}
input := SpbMap(tm, `inputTokenCount`)
if input != nil {
d.InputChar = uint32(SpbMap(input, `totalBillableCharacters`).GetNumberValue())
d.InputToken = uint32(SpbMap(input, `totalTokens`).GetNumberValue())
}
output := SpbMap(tm, `outputTokenCount`)
if output != nil {
d.OutputChar = uint32(SpbMap(output, `totalBillableCharacters`).GetNumberValue())
d.OutputToken = uint32(SpbMap(output, `totalTokens`).GetNumberValue())
}
}
func getSafety(d *pb.VaDebug, raw *aiplatformpb.PredictResponse) {
p := raw.GetPredictions()
if len(p) == 0 {
return
}
sa := SpbMap(p[0], `safetyAttributes`).GetListValue().GetValues()
for _, v := range sa {
c := SpbMap(v, `categories`).GetListValue().GetValues()
if len(c) == 0 {
continue
}
s := SpbMap(v, `scores`).GetListValue().GetValues()
if len(s) != len(c) {
continue
}
for i, cv := range c {
row := &pb.VaSafety{
Category: cv.GetStringValue(),
Score: float32(s[i].GetNumberValue()),
}
d.Safety = append(d.Safety, row)
}
}
}

View File

@@ -0,0 +1,23 @@
package vertexai
import (
"project/util"
"project/zj"
aiplatform "cloud.google.com/go/aiplatform/apiv1"
"github.com/zhengkai/life-go"
"google.golang.org/api/option"
)
func init() {
var err error
chatClient, err = aiplatform.NewPredictionClient(
life.CTX,
option.WithEndpoint(`us-central1-aiplatform.googleapis.com:443`),
option.WithCredentialsFile(util.Static(`aigc-llm-730bb179e13c.json`)),
)
if err != nil {
zj.W(err)
}
}

View File

@@ -1,6 +1,17 @@
package vertexai
import "google.golang.org/protobuf/types/known/structpb"
import (
"project/pb"
"google.golang.org/protobuf/types/known/structpb"
)
var defaultParam = &pb.VaParam{
Temperature: 0.2,
MaxOutputTokens: 0,
TopP: 1,
TopK: 40,
}
// SpbMap ...
func SpbMap(o *structpb.Value, key string) *structpb.Value {

View File

@@ -0,0 +1,97 @@
package vertexai
import (
"encoding/json"
"errors"
"io"
"net/http"
"project/config"
"project/pb"
"time"
)
// ChatHandle ...
func ChatHandle(w http.ResponseWriter, r *http.Request) {
t := time.Now()
data, debug, err := chatHandle(w, r)
o := &pb.VaChatWebRsp{
Data: data,
Debug: debug,
}
if err == nil {
o.Ok = true
} else {
o.Error = err.Error()
}
if debug != nil {
i := uint32(time.Since(t) / time.Millisecond)
if i < 1 {
i = 1
}
debug.TotalMs = i
}
ab, _ := json.Marshal(o)
w.Write(ab)
}
func chatHandleInput(w http.ResponseWriter, r *http.Request) (*pb.VaChatReq, error) {
if r.Method != `POST` {
w.WriteHeader(http.StatusBadRequest)
return nil, errors.New(`method not allow`)
}
token := r.Header.Get(`VA-TOKEN`)
if token == `` {
w.WriteHeader(http.StatusNonAuthoritativeInfo)
return nil, errors.New(`no token`)
}
if config.VAToken == `` {
w.WriteHeader(http.StatusInternalServerError)
err := errors.New(`no token in server`)
return nil, err
}
if config.VAToken != token {
w.WriteHeader(http.StatusUnauthorized)
return nil, errors.New(`token not match`)
}
o := &pb.VaChatReq{}
ab, err := io.ReadAll(r.Body)
if err != nil {
w.WriteHeader(http.StatusBadRequest)
return nil, errors.New(`read body fail`)
}
err = json.Unmarshal(ab, o)
if err != nil {
w.WriteHeader(http.StatusBadRequest)
return nil, err
}
return o, nil
}
func chatHandle(w http.ResponseWriter, r *http.Request) (*pb.VaChatRsp, *pb.VaDebug, error) {
req, err := chatHandleInput(w, r)
if err != nil {
return nil, nil, err
}
rsp, err := Chat(req)
if err != nil {
w.WriteHeader(http.StatusInternalServerError)
return nil, nil, err
}
var debug *pb.VaDebug
if req.Debug {
debug = rsp.Debug()
}
return rsp.Answer, debug, nil
}

View File

@@ -4,6 +4,7 @@ import (
"net/http"
"project/config"
"project/core"
"project/vertexai"
"project/zj"
"time"
@@ -19,6 +20,7 @@ func Server() {
mux.Handle(`/v1/moderations`, core.NewCore())
mux.Handle(`/v1/completions`, core.NewCore())
mux.Handle(`/v1/chat/completions`, core.NewCore())
mux.HandleFunc(`/va/chat`, vertexai.ChatHandle)
s := &http.Server{
Addr: config.WebAddr,