feat: vertexai text

Add a /va/text endpoint backed by text-bison@001 alongside the existing
/va/chat endpoint, extract the shared request building, prediction call,
caching and debug helpers into common code, and add a test script
(misc/test/va/text.sh).

Zheng Kai
2023-09-07 18:10:02 +08:00
parent 87ab452a6a
commit 8259b8b164
11 changed files with 430 additions and 249 deletions

misc/test/va/text.sh Executable file
View File

@@ -0,0 +1,14 @@
#!/bin/bash
URL="http://localhost:22035/va/text"
curl -s "$URL" \
-H "Content-Type: application/json" \
-H "VA-TOKEN: ${ORCA_VA_TOKEN}" \
-d '{
"prompt":"翻译下列语言为中文:\n\n3 2 1 Hello, world!",
"debug":true
}' | tee tmp-text.json
echo
jq . tmp-text.json
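
The same request can be issued from Go; a minimal sketch, assuming the service address and the ORCA_VA_TOKEN environment variable used by the script above (the client code itself is illustrative, not part of this commit):

package main

import (
    "bytes"
    "fmt"
    "io"
    "net/http"
    "os"
)

func main() {
    // Same payload as text.sh; "debug" asks the server to attach VaDebug info.
    body := []byte(`{"prompt":"翻译下列语言为中文:\n\n3 2 1 Hello, world!","debug":true}`)

    req, err := http.NewRequest(http.MethodPost, "http://localhost:22035/va/text", bytes.NewReader(body))
    if err != nil {
        panic(err)
    }
    req.Header.Set("Content-Type", "application/json")
    req.Header.Set("VA-TOKEN", os.Getenv("ORCA_VA_TOKEN"))

    rsp, err := http.DefaultClient.Do(req)
    if err != nil {
        panic(err)
    }
    defer rsp.Body.Close()

    ab, _ := io.ReadAll(rsp.Body)
    fmt.Println(string(ab)) // JSON-encoded VaWebRsp: ok / data / debug / error
}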

View File

@@ -9,14 +9,22 @@ message VaChatReq {
VaParam param = 4;
bool debug = 5;
}
message VaChatRsp {
message VaTextReq {
string prompt = 1;
bool noCache = 3;
VaParam param = 4;
bool debug = 5;
}
message VaRsp {
string content = 1;
bool blocked = 2;
}
message VaChatWebRsp {
message VaWebRsp {
bool ok = 1;
VaChatRsp data = 2;
VaRsp data = 2;
VaDebug debug = 3;
string error = 4;
}

View File

@@ -1,68 +1,39 @@
package vertexai
import (
"crypto/md5"
"encoding/json"
"errors"
"fmt"
"project/metrics"
"project/pb"
"project/util"
"time"
aiplatform "cloud.google.com/go/aiplatform/apiv1"
"cloud.google.com/go/aiplatform/apiv1/aiplatformpb"
"github.com/zhengkai/coral/v2"
"github.com/zhengkai/life-go"
"google.golang.org/protobuf/types/known/structpb"
)
var chatClient *aiplatform.PredictionClient
var errEmptyAnswer = errors.New(`empty answer`)
var errBlocked = errors.New(`blocked by google`)
var chatCache = coral.NewLRU(loadChatForCoral, 1000, 100)
type chatKey struct {
System string `json:"system"`
User string `json:"user"`
Temperature float32 `json:"temperature"`
MaxOutputTokens uint32 `json:"maxOutputTokens"`
TopP float32 `json:"topP"`
TopK uint32 `json:"topK"`
}
const chatModel = `chat-bison@001`
// ChatRsp ...
type ChatRsp struct {
Answer *pb.VaChatRsp `json:"answer,omitempty"`
Raw *aiplatformpb.PredictResponse `json:"raw"`
CostMs uint32 `json:"costMs"`
type chatKey struct {
paramKey
System string `json:"system"`
User string `json:"user"`
}
// Chat ...
func Chat(req *pb.VaChatReq) (*ChatRsp, error) {
p := req.Param
if p == nil {
p = defaultParam
}
func Chat(req *pb.VaChatReq) (*Rsp, error) {
k := chatKey{
System: req.System,
User: req.User,
Temperature: p.Temperature,
MaxOutputTokens: p.MaxOutputTokens,
TopP: p.TopP,
TopK: p.TopK,
System: req.System,
User: req.User,
}
k.load(req.Param)
if req.NoCache {
chatCache.Delete(k)
} else {
ab, err := util.ReadFile(chatCacheFile(k) + `.json`)
ab, err := util.ReadFile(cacheFile(`chat`, k) + `.json`)
if err == nil && len(ab) > 2 {
rsp := &ChatRsp{}
rsp := &Rsp{}
err = json.Unmarshal(ab, rsp)
if err == nil {
return rsp, nil
@@ -73,7 +44,7 @@ func Chat(req *pb.VaChatReq) (*ChatRsp, error) {
return chatCache.Get(k)
}
func buildChatReq(k chatKey) (*aiplatformpb.PredictRequest, error) {
func loadChat(k chatKey) (*Rsp, error) {
m := map[string]any{
`context`: k.System,
@@ -87,178 +58,26 @@ func buildChatReq(k chatKey) (*aiplatformpb.PredictRequest, error) {
}
}
inst, err := structpb.NewStruct(m)
req, err := buildReq(k.paramKey, m, chatModel)
if err != nil {
return nil, err
}
p, err := structpb.NewStruct(map[string]any{
`temperature`: k.Temperature,
`maxOutputTokens`: k.MaxOutputTokens,
`topP`: k.TopP,
`topK`: k.TopK,
})
rsp, err := doRequest(req)
if err != nil {
return nil, err
}
req := &aiplatformpb.PredictRequest{
Endpoint: `projects/aigc-llm/locations/us-central1/publishers/google/models/chat-bison@001`,
Instances: []*structpb.Value{
structpb.NewStructValue(inst),
},
Parameters: structpb.NewStructValue(p),
}
return req, nil
go func() {
rsp.save(cacheFile(`chat`, k))
}()
return rsp, nil
}
func getChatVal(rsp *aiplatformpb.PredictResponse) (string, error) {
p := rsp.GetPredictions()
if len(p) == 0 {
return ``, errEmptyAnswer
}
p0 := p[0]
if isBlocked(p0) {
return ``, errBlocked
}
ca := SpbMap(p0, `candidates`).GetListValue().GetValues()
if len(ca) == 0 {
return ``, errEmptyAnswer
}
s := SpbMap(ca[0], `content`).GetStringValue()
if s == `` {
return ``, errEmptyAnswer
}
return s, nil
}
func isBlocked(o *structpb.Value) bool {
sa := SpbMap(o, `safetyAttributes`).GetListValue().GetValues()
for _, v := range sa {
if SpbMap(v, `blocked`).GetBoolValue() {
return true
}
}
return false
}
func loadChat(k chatKey) (*ChatRsp, error) {
req, err := buildChatReq(k)
if err != nil {
return nil, err
}
ctx, cancel := life.CTXTimeout(10 * time.Second)
t := time.Now()
rsp, err := chatClient.Predict(ctx, req)
cancel()
if err != nil {
return nil, err
}
r := &ChatRsp{
Raw: rsp,
CostMs: uint32(time.Since(t) / time.Millisecond),
}
metrics.VaTime(r.CostMs)
answer := &pb.VaChatRsp{}
answer.Content, err = getChatVal(rsp)
if err == errBlocked {
err = nil
answer.Blocked = true
}
if err != nil {
return nil, err
}
r.Answer = answer
go chatSaveCache(k, r)
return r, nil
}
func loadChatForCoral(k chatKey) (*ChatRsp, *time.Time, error) {
func loadChatForCoral(k chatKey) (*Rsp, *time.Time, error) {
r, err := loadChat(k)
if err != nil {
return nil, nil, err
}
return r, nil, nil
}
func chatCacheFile(k chatKey) string {
ab, _ := json.Marshal(k)
h := md5.Sum(ab)
file := fmt.Sprintf(`vertexai/chat/%02x/%02x/%02x/%x`, h[0], h[1], h[2], h[3:])
return file
}
func chatSaveCache(k chatKey, rsp *ChatRsp) {
file := chatCacheFile(k)
util.Mkdir(file)
util.WriteJSON(file+`.json`, rsp)
}
// Debug ...
func (rsp *ChatRsp) Debug() *pb.VaDebug {
d := &pb.VaDebug{
CostMs: rsp.CostMs,
}
getToken(d, rsp.Raw)
getSafety(d, rsp.Raw)
return d
}
func getToken(d *pb.VaDebug, raw *aiplatformpb.PredictResponse) {
tm := SpbMap(raw.GetMetadata(), `tokenMetadata`)
if tm == nil {
return
}
input := SpbMap(tm, `inputTokenCount`)
if input != nil {
d.InputChar = uint32(SpbMap(input, `totalBillableCharacters`).GetNumberValue())
d.InputToken = uint32(SpbMap(input, `totalTokens`).GetNumberValue())
}
output := SpbMap(tm, `outputTokenCount`)
if output != nil {
d.OutputChar = uint32(SpbMap(output, `totalBillableCharacters`).GetNumberValue())
d.OutputToken = uint32(SpbMap(output, `totalTokens`).GetNumberValue())
}
}
func getSafety(d *pb.VaDebug, raw *aiplatformpb.PredictResponse) {
p := raw.GetPredictions()
if len(p) == 0 {
return
}
sa := SpbMap(p[0], `safetyAttributes`).GetListValue().GetValues()
for _, v := range sa {
c := SpbMap(v, `categories`).GetListValue().GetValues()
if len(c) == 0 {
continue
}
s := SpbMap(v, `scores`).GetListValue().GetValues()
if len(s) != len(c) {
continue
}
for i, cv := range c {
row := &pb.VaSafety{
Category: cv.GetStringValue(),
Score: float32(s[i].GetNumberValue()),
}
d.Safety = append(d.Safety, row)
}
}
}

View File

@@ -0,0 +1,195 @@
package vertexai
import (
"crypto/md5"
"encoding/json"
"errors"
"fmt"
"project/metrics"
"project/pb"
"time"
aiplatform "cloud.google.com/go/aiplatform/apiv1"
"cloud.google.com/go/aiplatform/apiv1/aiplatformpb"
"github.com/zhengkai/life-go"
"google.golang.org/protobuf/types/known/structpb"
)
var theClient *aiplatform.PredictionClient
var errEmptyAnswer = errors.New(`empty answer`)
var errBlocked = errors.New(`blocked by google`)
type canDebug interface {
GetDebug() bool
}
// SpbMap ...
func SpbMap(o *structpb.Value, key string) *structpb.Value {
m := o.GetStructValue().GetFields()
if m == nil {
return nil
}
v, ok := m[key]
if !ok {
return nil
}
return v
}
func cacheFile[T chatKey | textKey](t string, k T) string {
ab, _ := json.Marshal(k)
h := md5.Sum(ab)
file := fmt.Sprintf(`vertexai/%s/%02x/%02x/%02x/%x`, t, h[0], h[1], h[2], h[3:])
return file
}
func buildReq(k paramKey, m map[string]any, model string) (*aiplatformpb.PredictRequest, error) {
inst, err := structpb.NewStruct(m)
if err != nil {
return nil, err
}
p, err := structpb.NewStruct(map[string]any{
`temperature`: k.Temperature,
`maxOutputTokens`: k.MaxOutputTokens,
`topP`: k.TopP,
`topK`: k.TopK,
})
if err != nil {
return nil, err
}
endpoint := fmt.Sprintf(
`projects/aigc-llm/locations/us-central1/publishers/google/models/%s`,
model,
)
req := &aiplatformpb.PredictRequest{
Endpoint: endpoint,
Instances: []*structpb.Value{
structpb.NewStructValue(inst),
},
Parameters: structpb.NewStructValue(p),
}
return req, nil
}
func getToken(d *pb.VaDebug, raw *aiplatformpb.PredictResponse) {
tm := SpbMap(raw.GetMetadata(), `tokenMetadata`)
if tm == nil {
return
}
input := SpbMap(tm, `inputTokenCount`)
if input != nil {
d.InputChar = uint32(SpbMap(input, `totalBillableCharacters`).GetNumberValue())
d.InputToken = uint32(SpbMap(input, `totalTokens`).GetNumberValue())
}
output := SpbMap(tm, `outputTokenCount`)
if output != nil {
d.OutputChar = uint32(SpbMap(output, `totalBillableCharacters`).GetNumberValue())
d.OutputToken = uint32(SpbMap(output, `totalTokens`).GetNumberValue())
}
}
func getSafety(d *pb.VaDebug, raw *aiplatformpb.PredictResponse) {
p := raw.GetPredictions()
if len(p) == 0 {
return
}
sa := SpbMap(p[0], `safetyAttributes`).GetListValue().GetValues()
for _, v := range sa {
c := SpbMap(v, `categories`).GetListValue().GetValues()
if len(c) == 0 {
continue
}
s := SpbMap(v, `scores`).GetListValue().GetValues()
if len(s) != len(c) {
continue
}
for i, cv := range c {
row := &pb.VaSafety{
Category: cv.GetStringValue(),
Score: float32(s[i].GetNumberValue()),
}
d.Safety = append(d.Safety, row)
}
}
}
func isBlocked(o *structpb.Value) bool {
sa := SpbMap(o, `safetyAttributes`).GetListValue().GetValues()
for _, v := range sa {
if SpbMap(v, `blocked`).GetBoolValue() {
return true
}
}
return false
}
func getVal(rsp *aiplatformpb.PredictResponse) (string, error) {
p := rsp.GetPredictions()
if len(p) == 0 {
return ``, errEmptyAnswer
}
p0 := p[0]
if isBlocked(p0) {
return ``, errBlocked
}
// for text
c := SpbMap(p0, `content`).GetStringValue()
if c != `` {
return c, nil
}
// for chat
ca := SpbMap(p0, `candidates`).GetListValue().GetValues()
if len(ca) == 0 {
return ``, errEmptyAnswer
}
s := SpbMap(ca[0], `content`).GetStringValue()
if s == `` {
return ``, errEmptyAnswer
}
return s, nil
}
func doRequest(req *aiplatformpb.PredictRequest) (*Rsp, error) {
ctx, cancel := life.CTXTimeout(10 * time.Second)
t := time.Now()
rsp, err := theClient.Predict(ctx, req)
cancel()
if err != nil {
return nil, err
}
r := &Rsp{
Raw: rsp,
CostMs: uint32(time.Since(t) / time.Millisecond),
}
metrics.VaTime(r.CostMs)
answer := &pb.VaRsp{}
answer.Content, err = getVal(rsp)
if err == errBlocked {
err = nil
answer.Blocked = true
}
if err != nil {
return nil, err
}
r.Answer = answer
return r, nil
}

View File

@@ -12,7 +12,7 @@ import (
func init() {
var err error
chatClient, err = aiplatform.NewPredictionClient(
theClient, err = aiplatform.NewPredictionClient(
life.CTX,
option.WithEndpoint(`us-central1-aiplatform.googleapis.com:443`),
option.WithCredentialsFile(util.Static(`aigc-llm-730bb179e13c.json`)),

View File

@@ -0,0 +1,32 @@
package vertexai
import "project/pb"
var defaultParam = &pb.VaParam{
Temperature: 0.2,
MaxOutputTokens: 0,
TopP: 1,
TopK: 40,
}
// paramKey ...
type paramKey struct {
Temperature float32 `json:"temperature"`
MaxOutputTokens uint32 `json:"maxOutputTokens"`
TopP float32 `json:"topP"`
TopK uint32 `json:"topK"`
}
func (k *paramKey) load(p *pb.VaParam) {
if p == nil {
p = defaultParam
}
k.Temperature = p.Temperature
k.MaxOutputTokens = p.MaxOutputTokens
k.TopP = p.TopP
k.TopK = p.TopK
if k.MaxOutputTokens == 0 {
k.MaxOutputTokens = 512
}
}
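
A quick way to read the defaulting rules above: a nil *pb.VaParam falls back to defaultParam, and a zero MaxOutputTokens is raised to 512. A hypothetical in-package test (not part of this commit) that pins that behaviour down:

package vertexai

import "testing"

// TestParamKeyDefaults illustrates paramKey.load: a nil *pb.VaParam falls
// back to defaultParam, and MaxOutputTokens == 0 is raised to 512.
func TestParamKeyDefaults(t *testing.T) {
    var k paramKey
    k.load(nil)
    if k.Temperature != 0.2 || k.TopP != 1 || k.TopK != 40 {
        t.Fatalf("unexpected defaults: %+v", k)
    }
    if k.MaxOutputTokens != 512 {
        t.Fatalf("MaxOutputTokens = %d, want 512", k.MaxOutputTokens)
    }
}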

View File

@@ -0,0 +1,31 @@
package vertexai
import (
"project/pb"
"project/util"
"cloud.google.com/go/aiplatform/apiv1/aiplatformpb"
)
// Rsp ...
type Rsp struct {
Answer *pb.VaRsp `json:"answer,omitempty"`
Raw *aiplatformpb.PredictResponse `json:"raw"`
CostMs uint32 `json:"costMs"`
}
// Debug ...
func (rsp *Rsp) Debug() *pb.VaDebug {
d := &pb.VaDebug{
CostMs: rsp.CostMs,
}
getToken(d, rsp.Raw)
getSafety(d, rsp.Raw)
return d
}
func (rsp *Rsp) save(file string) {
util.Mkdir(file)
util.WriteJSON(file+`.json`, rsp)
}

View File

@@ -0,0 +1,73 @@
package vertexai
import (
"encoding/json"
"project/pb"
"project/util"
"time"
"github.com/zhengkai/coral/v2"
)
var textCache = coral.NewLRU(loadTextForCoral, 1000, 100)
const textModel = `text-bison@001`
type textKey struct {
paramKey
Prompt string `json:"prompt"`
}
// Text ...
func Text(req *pb.VaTextReq) (*Rsp, error) {
k := textKey{
Prompt: req.Prompt,
}
k.load(req.Param)
if req.NoCache {
textCache.Delete(k)
} else {
ab, err := util.ReadFile(cacheFile(`text`, k) + `.json`)
if err == nil && len(ab) > 2 {
rsp := &Rsp{}
err = json.Unmarshal(ab, rsp)
if err == nil {
return rsp, nil
}
}
}
return textCache.Get(k)
}
func loadText(k textKey) (*Rsp, error) {
m := map[string]any{
`prompt`: k.Prompt,
}
req, err := buildReq(k.paramKey, m, textModel)
if err != nil {
return nil, err
}
rsp, err := doRequest(req)
if err != nil {
return nil, err
}
go func() {
rsp.save(cacheFile(`text`, k))
}()
return rsp, nil
}
func loadTextForCoral(k textKey) (*Rsp, *time.Time, error) {
r, err := loadText(k)
if err != nil {
return nil, nil, err
}
return r, nil, nil
}
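
Together with param.go and rsp.go, the call surface for text is small. A hedged sketch of a caller (the package name, the import path of the vertexai package, and the prompt/parameter values are assumptions, not taken from this commit):

package example

import (
    "fmt"

    "project/pb"
    "project/vertexai" // package path assumed; adjust to its actual location in this repo
)

// summarize is a hypothetical caller of the new Text API.
func summarize(prompt string) (string, error) {
    req := &pb.VaTextReq{
        Prompt: prompt,
        Param: &pb.VaParam{
            Temperature:     0.2,
            MaxOutputTokens: 256,
            TopP:            1,
            TopK:            40,
        },
        Debug: true,
    }
    rsp, err := vertexai.Text(req)
    if err != nil {
        return "", err
    }
    if rsp.Answer.Blocked {
        return "", fmt.Errorf("blocked by safety filter")
    }
    if req.Debug {
        d := rsp.Debug() // *pb.VaDebug: cost, token counts, safety scores
        fmt.Printf("cost=%dms tokens in/out=%d/%d\n", d.CostMs, d.InputToken, d.OutputToken)
    }
    return rsp.Answer.Content, nil
}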

View File

@@ -1,27 +0,0 @@
package vertexai
import (
"project/pb"
"google.golang.org/protobuf/types/known/structpb"
)
var defaultParam = &pb.VaParam{
Temperature: 0.2,
MaxOutputTokens: 0,
TopP: 1,
TopK: 40,
}
// SpbMap ...
func SpbMap(o *structpb.Value, key string) *structpb.Value {
m := o.GetStructValue().GetFields()
if m == nil {
return nil
}
v, ok := m[key]
if !ok {
return nil
}
return v
}

View File

@@ -10,13 +10,21 @@ import (
"time"
)
// ChatHandle ...
func ChatHandle(w http.ResponseWriter, r *http.Request) {
// web type
const (
WebText Web = iota + 1
WebChat
)
// Web ...
type Web int
func (web Web) ServeHTTP(w http.ResponseWriter, r *http.Request) {
t := time.Now()
data, debug, err := chatHandle(w, r)
o := &pb.VaChatWebRsp{
data, debug, err := handle(web, w, r)
o := &pb.VaWebRsp{
Data: data,
Debug: debug,
}
@@ -36,7 +44,7 @@ func ChatHandle(w http.ResponseWriter, r *http.Request) {
w.Write(ab)
}
func chatHandleInput(w http.ResponseWriter, r *http.Request) (*pb.VaChatReq, error) {
func chatHandleInput(w http.ResponseWriter, r *http.Request) ([]byte, error) {
if r.Method != `POST` {
w.WriteHeader(http.StatusBadRequest)
@@ -59,39 +67,66 @@ func chatHandleInput(w http.ResponseWriter, r *http.Request) (*pb.VaChatReq, err
return nil, errors.New(`token not match`)
}
o := &pb.VaChatReq{}
ab, err := io.ReadAll(r.Body)
if err != nil {
w.WriteHeader(http.StatusBadRequest)
return nil, errors.New(`read body fail`)
}
err = json.Unmarshal(ab, o)
if err != nil {
w.WriteHeader(http.StatusBadRequest)
return nil, err
}
return o, nil
return ab, nil
}
func chatHandle(w http.ResponseWriter, r *http.Request) (*pb.VaChatRsp, *pb.VaDebug, error) {
func handle(web Web, w http.ResponseWriter, r *http.Request) (*pb.VaRsp, *pb.VaDebug, error) {
req, err := chatHandleInput(w, r)
ab, err := chatHandleInput(w, r)
if err != nil {
return nil, nil, err
}
rsp, err := Chat(req)
var rsp *Rsp
isDebug := false
if web == WebChat {
req := &pb.VaChatReq{}
isDebug, err = unmarshalWebReq(w, ab, req)
if err != nil {
return nil, nil, err
}
rsp, err = Chat(req)
} else {
req := &pb.VaTextReq{}
isDebug, err = unmarshalWebReq(w, ab, req)
if err != nil {
return nil, nil, err
}
rsp, err = Text(req)
}
if err != nil {
w.WriteHeader(http.StatusInternalServerError)
return nil, nil, err
}
var debug *pb.VaDebug
if req.Debug {
if isDebug {
debug = rsp.Debug()
}
return rsp.Answer, debug, nil
}
func unmarshalWebReq(w http.ResponseWriter, ab []byte, req canDebug) (bool, error) {
err := json.Unmarshal(ab, req)
if err != nil {
w.WriteHeader(http.StatusBadRequest)
return false, err
}
return req.GetDebug(), nil
}

View File

@@ -20,7 +20,8 @@ func Server() {
mux.Handle(`/v1/moderations`, core.NewCore())
mux.Handle(`/v1/completions`, core.NewCore())
mux.Handle(`/v1/chat/completions`, core.NewCore())
mux.HandleFunc(`/va/chat`, vertexai.ChatHandle)
mux.Handle(`/va/chat`, vertexai.WebChat)
mux.Handle(`/va/text`, vertexai.WebText)
s := &http.Server{
Addr: config.WebAddr,