From bdf34ab55b228e880a4afd9a8344d920b4ff1b21 Mon Sep 17 00:00:00 2001 From: Zheng Kai Date: Fri, 1 Sep 2023 17:33:51 +0800 Subject: [PATCH] feat: vertexai --- misc/test/va/.gitignore | 1 + misc/test/va/Makefile | 2 + misc/test/va/chat.sh | 15 +++ proto/vertexai.proto | 30 +++++- server/build/sample-env.sh | 1 + server/src/config/config.go | 2 + server/src/config/init.go | 1 + server/src/core/cache.go | 6 ++ server/src/core/fetch.go | 2 +- server/src/core/req.go | 1 + server/src/go.mod | 1 + server/src/go.sum | 2 + server/src/util/file.go | 25 +++-- server/src/vertexai/chat.go | 195 +++++++++++++++++++++++++++++++++--- server/src/vertexai/init.go | 23 +++++ server/src/vertexai/util.go | 13 ++- server/src/vertexai/web.go | 97 ++++++++++++++++++ server/src/web/server.go | 2 + 18 files changed, 394 insertions(+), 25 deletions(-) create mode 100644 misc/test/va/.gitignore create mode 100644 misc/test/va/Makefile create mode 100755 misc/test/va/chat.sh create mode 100644 server/src/vertexai/init.go create mode 100644 server/src/vertexai/web.go diff --git a/misc/test/va/.gitignore b/misc/test/va/.gitignore new file mode 100644 index 0000000..a6c57f5 --- /dev/null +++ b/misc/test/va/.gitignore @@ -0,0 +1 @@ +*.json diff --git a/misc/test/va/Makefile b/misc/test/va/Makefile new file mode 100644 index 0000000..0df8075 --- /dev/null +++ b/misc/test/va/Makefile @@ -0,0 +1,2 @@ +chat: + ./chat.sh diff --git a/misc/test/va/chat.sh b/misc/test/va/chat.sh new file mode 100755 index 0000000..3fd7726 --- /dev/null +++ b/misc/test/va/chat.sh @@ -0,0 +1,15 @@ +#!/bin/bash + +URL="http://localhost:22035/va/chat" + +curl -s "$URL" \ + -H "Content-Type: application/json" \ + -H "VA-TOKEN: ${ORCA_VA_TOKEN}" \ + -d '{ + "system":"翻译下列语言为中文:", + "user":"Hello, world!", + "debug":true +}' | tee tmp-chat.json +echo + +jq . tmp-chat.json diff --git a/proto/vertexai.proto b/proto/vertexai.proto index 08fd628..b50b126 100644 --- a/proto/vertexai.proto +++ b/proto/vertexai.proto @@ -4,9 +4,21 @@ package pb; message VaChatReq { string system = 1; - repeated string user = 2; + string user = 2; bool noCache = 3; VaParam param = 4; + bool debug = 5; +} +message VaChatRsp { + string content = 1; + bool blocked = 2; +} + +message VaChatWebRsp { + bool ok = 1; + VaChatRsp data = 2; + VaDebug debug = 3; + string error = 4; } message VaParam { @@ -15,3 +27,19 @@ message VaParam { float topP = 3; uint32 topK = 4; } + +message VaDebug { + uint32 costMs = 1; + string cahceFile = 2; + uint32 inputChar = 3; + uint32 inputToken = 4; + uint32 outputChar = 5; + uint32 outputToken = 6; + repeated VaSafety safety = 7; + uint32 totalMs = 8; +} + +message VaSafety { + string category = 1; + float score = 2; +} diff --git a/server/build/sample-env.sh b/server/build/sample-env.sh index 1e1682a..6d3fa63 100644 --- a/server/build/sample-env.sh +++ b/server/build/sample-env.sh @@ -7,3 +7,4 @@ export ORCA_WEB=":22035" export ORCA_ES_ADDR="https://10.0.84.49:9200/" export ORCA_ES_USER="" export ORCA_ES_PASS="" +export ORCA_VA_TOKEN="" diff --git a/server/src/config/config.go b/server/src/config/config.go index fd606f0..f6ecb03 100644 --- a/server/src/config/config.go +++ b/server/src/config/config.go @@ -17,4 +17,6 @@ var ( ESAddr = `` ESUser = `` ESPass = `` + + VAToken = `` ) diff --git a/server/src/config/init.go b/server/src/config/init.go index e724133..1ab1bdb 100644 --- a/server/src/config/init.go +++ b/server/src/config/init.go @@ -20,6 +20,7 @@ func init() { `ORCA_ES_ADDR`: &ESAddr, `ORCA_ES_USER`: &ESUser, `ORCA_ES_PASS`: &ESPass, + `ORCA_VA_TOKEN`: &VAToken, } for k, v := range list { s := os.Getenv(k) diff --git a/server/src/core/cache.go b/server/src/core/cache.go index 2fecef1..9c768f4 100644 --- a/server/src/core/cache.go +++ b/server/src/core/cache.go @@ -1,6 +1,7 @@ package core import ( + "fmt" "project/pb" "project/util" ) @@ -20,3 +21,8 @@ func tryCache(p *pb.Req) ([]byte, bool) { func rspCacheFile(r *pb.Req) string { return util.CacheName(r.Hash()) + `-rsp.json` } + +func cacheFile(hash [16]byte) string { + s := fmt.Sprintf(`cache/%02x/%02x/%02x/%x`, hash[0], hash[1], hash[2], hash[3:]) + return s +} diff --git a/server/src/core/fetch.go b/server/src/core/fetch.go index c12a0dc..1a5ae7f 100644 --- a/server/src/core/fetch.go +++ b/server/src/core/fetch.go @@ -78,6 +78,6 @@ func (pr *row) fetchRemote() (ab []byte, err error) { func writeFailLog(hash [16]byte, ab []byte) { date := time.Now().Format(`0102/150405`) file := fmt.Sprintf(`fail/%s-%x.txt`, date, hash) - os.MkdirAll(path.Dir(util.StaticFile(file)), 0755) + os.MkdirAll(path.Dir(util.Static(file)), 0755) util.WriteFile(file, ab) } diff --git a/server/src/core/req.go b/server/src/core/req.go index f73732a..a41aeb3 100644 --- a/server/src/core/req.go +++ b/server/src/core/req.go @@ -32,6 +32,7 @@ func (c *Core) getAB(p *pb.Req, r *http.Request) (ab []byte, cached bool, pr *ro go func() { reqFile := util.CacheName(p.Hash()) + `-req.json` if !util.FileExists(reqFile) { + util.Mkdir(reqFile) util.WriteFile(reqFile, p.Body) } }() diff --git a/server/src/go.mod b/server/src/go.mod index 6fd8fbe..1d0ce66 100644 --- a/server/src/go.mod +++ b/server/src/go.mod @@ -31,6 +31,7 @@ require ( github.com/prometheus/client_model v0.4.0 // indirect github.com/prometheus/common v0.44.0 // indirect github.com/prometheus/procfs v0.11.1 // indirect + github.com/zhengkai/coral/v2 v2.0.3 // indirect go.opencensus.io v0.24.0 // indirect golang.org/x/crypto v0.9.0 // indirect golang.org/x/net v0.10.0 // indirect diff --git a/server/src/go.sum b/server/src/go.sum index 7783138..473f827 100644 --- a/server/src/go.sum +++ b/server/src/go.sum @@ -102,6 +102,8 @@ github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/ github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY= +github.com/zhengkai/coral/v2 v2.0.3 h1:SB/uWDpPsOgsexmH8qZdgpDK7bwErEpb5XQ+7m/jhps= +github.com/zhengkai/coral/v2 v2.0.3/go.mod h1:3gGfB8tumy+OZPFFOeu/5ykeTgDc01sAdvOPPRxwIzw= github.com/zhengkai/life-go v1.0.3 h1:rzm+Hb8H4He5trWx3lthFEQPf3sHpns0bDZ7vubT6sI= github.com/zhengkai/life-go v1.0.3/go.mod h1:e2RGLfk+uRzjhRrMQash9X4iY3jAuGj99r0qj5JS7m4= github.com/zhengkai/zog v1.0.3 h1:dkJdXJKRjbqqlseFycA1d80AUU6HAZrPe4WplpmwTo4= diff --git a/server/src/util/file.go b/server/src/util/file.go index 89ec901..8d17b10 100644 --- a/server/src/util/file.go +++ b/server/src/util/file.go @@ -19,10 +19,14 @@ type DownloadFunc func(url string) (ab []byte, err error) // CacheName ... func CacheName(hash [16]byte) string { s := fmt.Sprintf(`cache/%02x/%02x/%02x/%x`, hash[0], hash[1], hash[2], hash[3:]) - os.MkdirAll(StaticFile(filepath.Dir(s)), 0755) return s } +// Mkdir ... +func Mkdir(filename string) { + os.MkdirAll(Static(filepath.Dir(filename)), 0755) +} + // FileExists ... func FileExists(filename string) bool { filename = fmt.Sprintf(`%s/%s`, config.StaticDir, filename) @@ -36,12 +40,12 @@ func IsURL(s string) bool { // ReadFile ... func ReadFile(file string) (ab []byte, err error) { - ab, err = os.ReadFile(StaticFile(file)) + ab, err = os.ReadFile(Static(file)) return } -// StaticFile ... -func StaticFile(file string) string { +// Static ... +func Static(file string) string { file = strings.TrimPrefix(file, config.StaticDir+`/`) return fmt.Sprintf(`%s/%s`, config.StaticDir, file) } @@ -67,8 +71,6 @@ func SaveData(name string, p proto.Message) (err error) { // WriteFile ... func WriteFile(file string, ab []byte) (err error) { - file = StaticFile(file) - defer zj.Watch(&err) f, err := os.CreateTemp(config.StaticDir+`/tmp`, `wr-*`) @@ -86,10 +88,19 @@ func WriteFile(file string, ab []byte) (err error) { return } - err = os.Rename(tmpName, file) + err = os.Rename(tmpName, Static(file)) if err != nil { return } return } + +// WriteJSON ... +func WriteJSON(file string, d any) error { + ab, err := json.Marshal(d) + if err != nil { + return err + } + return WriteFile(file, ab) +} diff --git a/server/src/vertexai/chat.go b/server/src/vertexai/chat.go index bb4822d..dd8124c 100644 --- a/server/src/vertexai/chat.go +++ b/server/src/vertexai/chat.go @@ -1,11 +1,18 @@ package vertexai import ( + "crypto/md5" + "encoding/json" "errors" + "fmt" "project/pb" + "project/util" + "time" aiplatform "cloud.google.com/go/aiplatform/apiv1" "cloud.google.com/go/aiplatform/apiv1/aiplatformpb" + "github.com/zhengkai/coral/v2" + "github.com/zhengkai/life-go" "google.golang.org/protobuf/types/known/structpb" ) @@ -14,24 +21,69 @@ var chatClient *aiplatform.PredictionClient var errEmptyAnswer = errors.New(`empty answer`) var errBlocked = errors.New(`blocked by google`) -// Chat ... -func Chat(req *pb.VaChatReq) { +var chatCache = coral.NewLRU(loadChatForCoral, 1000, 100) + +type chatKey struct { + System string `json:"system"` + User string `json:"user"` + Temperature float32 `json:"temperature"` + MaxOutputTokens uint32 `json:"maxOutputTokens"` + TopP float32 `json:"topP"` + TopK uint32 `json:"topK"` } -func buildChatReq(system string, user []string, param *pb.VaParam) (*aiplatformpb.PredictRequest, error) { +// ChatRsp ... +type ChatRsp struct { + Answer *pb.VaChatRsp `json:"answer,omitempty"` + Raw *aiplatformpb.PredictResponse `json:"raw"` + CostMs uint32 `json:"costMs"` +} + +// Chat ... +func Chat(req *pb.VaChatReq) (*ChatRsp, error) { + + p := req.Param + if p == nil { + p = defaultParam + } + + k := chatKey{ + System: req.System, + User: req.User, + Temperature: p.Temperature, + MaxOutputTokens: p.MaxOutputTokens, + TopP: p.TopP, + TopK: p.TopK, + } + + if req.NoCache { + chatCache.Delete(k) + } else { + ab, err := util.ReadFile(chatCacheFile(k) + `.json`) + if err == nil && len(ab) > 2 { + rsp := &ChatRsp{} + err = json.Unmarshal(ab, rsp) + if err == nil { + return rsp, nil + } + } + } + + return chatCache.Get(k) +} + +func buildChatReq(k chatKey) (*aiplatformpb.PredictRequest, error) { m := map[string]any{ - `context`: system, + `context`: k.System, } - if len(user) > 0 { - var li []any - for _, v := range user { - li = append(li, map[string]any{ + if k.User != `` { + m[`messages`] = []any{ + map[string]any{ `author`: `user`, - `content`: v, - }) + `content`: k.User, + }, } - m[`messages`] = li } inst, err := structpb.NewStruct(m) @@ -40,10 +92,10 @@ func buildChatReq(system string, user []string, param *pb.VaParam) (*aiplatformp } p, err := structpb.NewStruct(map[string]any{ - `temperature`: param.Temperature, - `maxOutputTokens`: param.MaxOutputTokens, - `topP`: param.TopP, - `topK`: param.TopK, + `temperature`: k.Temperature, + `maxOutputTokens`: k.MaxOutputTokens, + `topP`: k.TopP, + `topK`: k.TopK, }) if err != nil { return nil, err @@ -95,3 +147,116 @@ func isBlocked(o *structpb.Value) bool { } return false } + +func loadChat(k chatKey) (*ChatRsp, error) { + + req, err := buildChatReq(k) + if err != nil { + return nil, err + } + + ctx, cancel := life.CTXTimeout(10 * time.Second) + t := time.Now() + rsp, err := chatClient.Predict(ctx, req) + + cancel() + if err != nil { + return nil, err + } + + r := &ChatRsp{ + Raw: rsp, + CostMs: uint32(time.Since(t) / time.Millisecond), + } + + answer := &pb.VaChatRsp{} + answer.Content, err = getChatVal(rsp) + if err == errBlocked { + err = nil + answer.Blocked = true + } + if err != nil { + return nil, err + } + + r.Answer = answer + + go chatSaveCache(k, r) + return r, nil +} + +func loadChatForCoral(k chatKey) (*ChatRsp, *time.Time, error) { + r, err := loadChat(k) + if err != nil { + return nil, nil, err + } + return r, nil, nil +} + +func chatCacheFile(k chatKey) string { + ab, _ := json.Marshal(k) + h := md5.Sum(ab) + file := fmt.Sprintf(`vertexai/chat/%02x/%02x/%02x/%x`, h[0], h[1], h[2], h[3:]) + return file +} + +func chatSaveCache(k chatKey, rsp *ChatRsp) { + file := chatCacheFile(k) + util.Mkdir(file) + util.WriteJSON(file+`.json`, rsp) +} + +// Debug ... +func (rsp *ChatRsp) Debug() *pb.VaDebug { + d := &pb.VaDebug{ + CostMs: rsp.CostMs, + } + + getToken(d, rsp.Raw) + getSafety(d, rsp.Raw) + + return d +} + +func getToken(d *pb.VaDebug, raw *aiplatformpb.PredictResponse) { + tm := SpbMap(raw.GetMetadata(), `tokenMetadata`) + if tm == nil { + return + } + input := SpbMap(tm, `inputTokenCount`) + if input != nil { + d.InputChar = uint32(SpbMap(input, `totalBillableCharacters`).GetNumberValue()) + d.InputToken = uint32(SpbMap(input, `totalTokens`).GetNumberValue()) + } + output := SpbMap(tm, `outputTokenCount`) + if output != nil { + d.OutputChar = uint32(SpbMap(output, `totalBillableCharacters`).GetNumberValue()) + d.OutputToken = uint32(SpbMap(output, `totalTokens`).GetNumberValue()) + } +} +func getSafety(d *pb.VaDebug, raw *aiplatformpb.PredictResponse) { + p := raw.GetPredictions() + if len(p) == 0 { + return + } + sa := SpbMap(p[0], `safetyAttributes`).GetListValue().GetValues() + + for _, v := range sa { + c := SpbMap(v, `categories`).GetListValue().GetValues() + if len(c) == 0 { + continue + } + s := SpbMap(v, `scores`).GetListValue().GetValues() + if len(s) != len(c) { + continue + } + + for i, cv := range c { + row := &pb.VaSafety{ + Category: cv.GetStringValue(), + Score: float32(s[i].GetNumberValue()), + } + d.Safety = append(d.Safety, row) + } + } +} diff --git a/server/src/vertexai/init.go b/server/src/vertexai/init.go new file mode 100644 index 0000000..03da0fe --- /dev/null +++ b/server/src/vertexai/init.go @@ -0,0 +1,23 @@ +package vertexai + +import ( + "project/util" + "project/zj" + + aiplatform "cloud.google.com/go/aiplatform/apiv1" + "github.com/zhengkai/life-go" + "google.golang.org/api/option" +) + +func init() { + + var err error + chatClient, err = aiplatform.NewPredictionClient( + life.CTX, + option.WithEndpoint(`us-central1-aiplatform.googleapis.com:443`), + option.WithCredentialsFile(util.Static(`aigc-llm-730bb179e13c.json`)), + ) + if err != nil { + zj.W(err) + } +} diff --git a/server/src/vertexai/util.go b/server/src/vertexai/util.go index ea950a0..b423f68 100644 --- a/server/src/vertexai/util.go +++ b/server/src/vertexai/util.go @@ -1,6 +1,17 @@ package vertexai -import "google.golang.org/protobuf/types/known/structpb" +import ( + "project/pb" + + "google.golang.org/protobuf/types/known/structpb" +) + +var defaultParam = &pb.VaParam{ + Temperature: 0.2, + MaxOutputTokens: 0, + TopP: 1, + TopK: 40, +} // SpbMap ... func SpbMap(o *structpb.Value, key string) *structpb.Value { diff --git a/server/src/vertexai/web.go b/server/src/vertexai/web.go new file mode 100644 index 0000000..f5baf5c --- /dev/null +++ b/server/src/vertexai/web.go @@ -0,0 +1,97 @@ +package vertexai + +import ( + "encoding/json" + "errors" + "io" + "net/http" + "project/config" + "project/pb" + "time" +) + +// ChatHandle ... +func ChatHandle(w http.ResponseWriter, r *http.Request) { + + t := time.Now() + + data, debug, err := chatHandle(w, r) + o := &pb.VaChatWebRsp{ + Data: data, + Debug: debug, + } + if err == nil { + o.Ok = true + } else { + o.Error = err.Error() + } + if debug != nil { + i := uint32(time.Since(t) / time.Millisecond) + if i < 1 { + i = 1 + } + debug.TotalMs = i + } + ab, _ := json.Marshal(o) + w.Write(ab) +} + +func chatHandleInput(w http.ResponseWriter, r *http.Request) (*pb.VaChatReq, error) { + + if r.Method != `POST` { + w.WriteHeader(http.StatusBadRequest) + return nil, errors.New(`method not allow`) + } + + token := r.Header.Get(`VA-TOKEN`) + + if token == `` { + w.WriteHeader(http.StatusNonAuthoritativeInfo) + return nil, errors.New(`no token`) + } + if config.VAToken == `` { + w.WriteHeader(http.StatusInternalServerError) + err := errors.New(`no token in server`) + return nil, err + } + if config.VAToken != token { + w.WriteHeader(http.StatusUnauthorized) + return nil, errors.New(`token not match`) + } + + o := &pb.VaChatReq{} + + ab, err := io.ReadAll(r.Body) + if err != nil { + w.WriteHeader(http.StatusBadRequest) + return nil, errors.New(`read body fail`) + } + err = json.Unmarshal(ab, o) + if err != nil { + w.WriteHeader(http.StatusBadRequest) + return nil, err + } + + return o, nil +} + +func chatHandle(w http.ResponseWriter, r *http.Request) (*pb.VaChatRsp, *pb.VaDebug, error) { + + req, err := chatHandleInput(w, r) + if err != nil { + return nil, nil, err + } + + rsp, err := Chat(req) + if err != nil { + w.WriteHeader(http.StatusInternalServerError) + return nil, nil, err + } + + var debug *pb.VaDebug + if req.Debug { + debug = rsp.Debug() + } + + return rsp.Answer, debug, nil +} diff --git a/server/src/web/server.go b/server/src/web/server.go index 1b5a89f..5b9d330 100644 --- a/server/src/web/server.go +++ b/server/src/web/server.go @@ -4,6 +4,7 @@ import ( "net/http" "project/config" "project/core" + "project/vertexai" "project/zj" "time" @@ -19,6 +20,7 @@ func Server() { mux.Handle(`/v1/moderations`, core.NewCore()) mux.Handle(`/v1/completions`, core.NewCore()) mux.Handle(`/v1/chat/completions`, core.NewCore()) + mux.HandleFunc(`/va/chat`, vertexai.ChatHandle) s := &http.Server{ Addr: config.WebAddr,