mirror of
https://github.com/zhengkai/orca.git
synced 2026-02-04 13:32:27 +08:00
feat: vertexai text
This commit is contained in:
14
misc/test/va/text.sh
Executable file
14
misc/test/va/text.sh
Executable file
@@ -0,0 +1,14 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
|
||||||
|
URL="http://localhost:22035/va/text"
|
||||||
|
|
||||||
|
curl -s "$URL" \
|
||||||
|
-H "Content-Type: application/json" \
|
||||||
|
-H "VA-TOKEN: ${ORCA_VA_TOKEN}" \
|
||||||
|
-d '{
|
||||||
|
"prompt":"翻译下列语言为中文:\n\n3 2 1 Hello, world!",
|
||||||
|
"debug":true
|
||||||
|
}' | tee tmp-text.json
|
||||||
|
echo
|
||||||
|
|
||||||
|
jq . tmp-text.json
|
||||||
@@ -9,14 +9,22 @@ message VaChatReq {
|
|||||||
VaParam param = 4;
|
VaParam param = 4;
|
||||||
bool debug = 5;
|
bool debug = 5;
|
||||||
}
|
}
|
||||||
message VaChatRsp {
|
|
||||||
|
message VaTextReq {
|
||||||
|
string prompt = 1;
|
||||||
|
bool noCache = 3;
|
||||||
|
VaParam param = 4;
|
||||||
|
bool debug = 5;
|
||||||
|
}
|
||||||
|
|
||||||
|
message VaRsp {
|
||||||
string content = 1;
|
string content = 1;
|
||||||
bool blocked = 2;
|
bool blocked = 2;
|
||||||
}
|
}
|
||||||
|
|
||||||
message VaChatWebRsp {
|
message VaWebRsp {
|
||||||
bool ok = 1;
|
bool ok = 1;
|
||||||
VaChatRsp data = 2;
|
VaRsp data = 2;
|
||||||
VaDebug debug = 3;
|
VaDebug debug = 3;
|
||||||
string error = 4;
|
string error = 4;
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,68 +1,39 @@
|
|||||||
package vertexai
|
package vertexai
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"crypto/md5"
|
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"errors"
|
|
||||||
"fmt"
|
|
||||||
"project/metrics"
|
|
||||||
"project/pb"
|
"project/pb"
|
||||||
"project/util"
|
"project/util"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
aiplatform "cloud.google.com/go/aiplatform/apiv1"
|
|
||||||
"cloud.google.com/go/aiplatform/apiv1/aiplatformpb"
|
|
||||||
"github.com/zhengkai/coral/v2"
|
"github.com/zhengkai/coral/v2"
|
||||||
"github.com/zhengkai/life-go"
|
|
||||||
"google.golang.org/protobuf/types/known/structpb"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
var chatClient *aiplatform.PredictionClient
|
|
||||||
|
|
||||||
var errEmptyAnswer = errors.New(`empty answer`)
|
|
||||||
var errBlocked = errors.New(`blocked by google`)
|
|
||||||
|
|
||||||
var chatCache = coral.NewLRU(loadChatForCoral, 1000, 100)
|
var chatCache = coral.NewLRU(loadChatForCoral, 1000, 100)
|
||||||
|
|
||||||
type chatKey struct {
|
const chatModel = `chat-bison@001`
|
||||||
System string `json:"system"`
|
|
||||||
User string `json:"user"`
|
|
||||||
Temperature float32 `json:"temperature"`
|
|
||||||
MaxOutputTokens uint32 `json:"maxOutputTokens"`
|
|
||||||
TopP float32 `json:"topP"`
|
|
||||||
TopK uint32 `json:"topK"`
|
|
||||||
}
|
|
||||||
|
|
||||||
// ChatRsp ...
|
type chatKey struct {
|
||||||
type ChatRsp struct {
|
paramKey
|
||||||
Answer *pb.VaChatRsp `json:"answer,omitempty"`
|
System string `json:"system"`
|
||||||
Raw *aiplatformpb.PredictResponse `json:"raw"`
|
User string `json:"user"`
|
||||||
CostMs uint32 `json:"costMs"`
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Chat ...
|
// Chat ...
|
||||||
func Chat(req *pb.VaChatReq) (*ChatRsp, error) {
|
func Chat(req *pb.VaChatReq) (*Rsp, error) {
|
||||||
|
|
||||||
p := req.Param
|
|
||||||
if p == nil {
|
|
||||||
p = defaultParam
|
|
||||||
}
|
|
||||||
|
|
||||||
k := chatKey{
|
k := chatKey{
|
||||||
System: req.System,
|
System: req.System,
|
||||||
User: req.User,
|
User: req.User,
|
||||||
Temperature: p.Temperature,
|
|
||||||
MaxOutputTokens: p.MaxOutputTokens,
|
|
||||||
TopP: p.TopP,
|
|
||||||
TopK: p.TopK,
|
|
||||||
}
|
}
|
||||||
|
k.load(req.Param)
|
||||||
|
|
||||||
if req.NoCache {
|
if req.NoCache {
|
||||||
chatCache.Delete(k)
|
chatCache.Delete(k)
|
||||||
} else {
|
} else {
|
||||||
ab, err := util.ReadFile(chatCacheFile(k) + `.json`)
|
ab, err := util.ReadFile(cacheFile(`chat`, k) + `.json`)
|
||||||
if err == nil && len(ab) > 2 {
|
if err == nil && len(ab) > 2 {
|
||||||
rsp := &ChatRsp{}
|
rsp := &Rsp{}
|
||||||
err = json.Unmarshal(ab, rsp)
|
err = json.Unmarshal(ab, rsp)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
return rsp, nil
|
return rsp, nil
|
||||||
@@ -73,7 +44,7 @@ func Chat(req *pb.VaChatReq) (*ChatRsp, error) {
|
|||||||
return chatCache.Get(k)
|
return chatCache.Get(k)
|
||||||
}
|
}
|
||||||
|
|
||||||
func buildChatReq(k chatKey) (*aiplatformpb.PredictRequest, error) {
|
func loadChat(k chatKey) (*Rsp, error) {
|
||||||
|
|
||||||
m := map[string]any{
|
m := map[string]any{
|
||||||
`context`: k.System,
|
`context`: k.System,
|
||||||
@@ -87,178 +58,26 @@ func buildChatReq(k chatKey) (*aiplatformpb.PredictRequest, error) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
inst, err := structpb.NewStruct(m)
|
req, err := buildReq(k.paramKey, m, chatModel)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
p, err := structpb.NewStruct(map[string]any{
|
rsp, err := doRequest(req)
|
||||||
`temperature`: k.Temperature,
|
|
||||||
`maxOutputTokens`: k.MaxOutputTokens,
|
|
||||||
`topP`: k.TopP,
|
|
||||||
`topK`: k.TopK,
|
|
||||||
})
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
req := &aiplatformpb.PredictRequest{
|
go func() {
|
||||||
Endpoint: `projects/aigc-llm/locations/us-central1/publishers/google/models/chat-bison@001`,
|
rsp.save(cacheFile(`chat`, k))
|
||||||
Instances: []*structpb.Value{
|
}()
|
||||||
structpb.NewStructValue(inst),
|
return rsp, nil
|
||||||
},
|
|
||||||
Parameters: structpb.NewStructValue(p),
|
|
||||||
}
|
|
||||||
|
|
||||||
return req, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func getChatVal(rsp *aiplatformpb.PredictResponse) (string, error) {
|
func loadChatForCoral(k chatKey) (*Rsp, *time.Time, error) {
|
||||||
|
|
||||||
p := rsp.GetPredictions()
|
|
||||||
if len(p) == 0 {
|
|
||||||
return ``, errEmptyAnswer
|
|
||||||
}
|
|
||||||
|
|
||||||
p0 := p[0]
|
|
||||||
|
|
||||||
if isBlocked(p0) {
|
|
||||||
return ``, errBlocked
|
|
||||||
}
|
|
||||||
|
|
||||||
ca := SpbMap(p0, `candidates`).GetListValue().GetValues()
|
|
||||||
if len(ca) == 0 {
|
|
||||||
return ``, errEmptyAnswer
|
|
||||||
}
|
|
||||||
|
|
||||||
s := SpbMap(ca[0], `content`).GetStringValue()
|
|
||||||
if s == `` {
|
|
||||||
return ``, errEmptyAnswer
|
|
||||||
}
|
|
||||||
|
|
||||||
return s, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func isBlocked(o *structpb.Value) bool {
|
|
||||||
sa := SpbMap(o, `safetyAttributes`).GetListValue().GetValues()
|
|
||||||
for _, v := range sa {
|
|
||||||
if SpbMap(v, `blocked`).GetBoolValue() {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
func loadChat(k chatKey) (*ChatRsp, error) {
|
|
||||||
|
|
||||||
req, err := buildChatReq(k)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
ctx, cancel := life.CTXTimeout(10 * time.Second)
|
|
||||||
t := time.Now()
|
|
||||||
rsp, err := chatClient.Predict(ctx, req)
|
|
||||||
|
|
||||||
cancel()
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
r := &ChatRsp{
|
|
||||||
Raw: rsp,
|
|
||||||
CostMs: uint32(time.Since(t) / time.Millisecond),
|
|
||||||
}
|
|
||||||
metrics.VaTime(r.CostMs)
|
|
||||||
|
|
||||||
answer := &pb.VaChatRsp{}
|
|
||||||
answer.Content, err = getChatVal(rsp)
|
|
||||||
if err == errBlocked {
|
|
||||||
err = nil
|
|
||||||
answer.Blocked = true
|
|
||||||
}
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
r.Answer = answer
|
|
||||||
|
|
||||||
go chatSaveCache(k, r)
|
|
||||||
return r, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func loadChatForCoral(k chatKey) (*ChatRsp, *time.Time, error) {
|
|
||||||
r, err := loadChat(k)
|
r, err := loadChat(k)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, err
|
return nil, nil, err
|
||||||
}
|
}
|
||||||
return r, nil, nil
|
return r, nil, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func chatCacheFile(k chatKey) string {
|
|
||||||
ab, _ := json.Marshal(k)
|
|
||||||
h := md5.Sum(ab)
|
|
||||||
file := fmt.Sprintf(`vertexai/chat/%02x/%02x/%02x/%x`, h[0], h[1], h[2], h[3:])
|
|
||||||
return file
|
|
||||||
}
|
|
||||||
|
|
||||||
func chatSaveCache(k chatKey, rsp *ChatRsp) {
|
|
||||||
file := chatCacheFile(k)
|
|
||||||
util.Mkdir(file)
|
|
||||||
util.WriteJSON(file+`.json`, rsp)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Debug ...
|
|
||||||
func (rsp *ChatRsp) Debug() *pb.VaDebug {
|
|
||||||
d := &pb.VaDebug{
|
|
||||||
CostMs: rsp.CostMs,
|
|
||||||
}
|
|
||||||
|
|
||||||
getToken(d, rsp.Raw)
|
|
||||||
getSafety(d, rsp.Raw)
|
|
||||||
|
|
||||||
return d
|
|
||||||
}
|
|
||||||
|
|
||||||
func getToken(d *pb.VaDebug, raw *aiplatformpb.PredictResponse) {
|
|
||||||
tm := SpbMap(raw.GetMetadata(), `tokenMetadata`)
|
|
||||||
if tm == nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
input := SpbMap(tm, `inputTokenCount`)
|
|
||||||
if input != nil {
|
|
||||||
d.InputChar = uint32(SpbMap(input, `totalBillableCharacters`).GetNumberValue())
|
|
||||||
d.InputToken = uint32(SpbMap(input, `totalTokens`).GetNumberValue())
|
|
||||||
}
|
|
||||||
output := SpbMap(tm, `outputTokenCount`)
|
|
||||||
if output != nil {
|
|
||||||
d.OutputChar = uint32(SpbMap(output, `totalBillableCharacters`).GetNumberValue())
|
|
||||||
d.OutputToken = uint32(SpbMap(output, `totalTokens`).GetNumberValue())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
func getSafety(d *pb.VaDebug, raw *aiplatformpb.PredictResponse) {
|
|
||||||
p := raw.GetPredictions()
|
|
||||||
if len(p) == 0 {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
sa := SpbMap(p[0], `safetyAttributes`).GetListValue().GetValues()
|
|
||||||
|
|
||||||
for _, v := range sa {
|
|
||||||
c := SpbMap(v, `categories`).GetListValue().GetValues()
|
|
||||||
if len(c) == 0 {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
s := SpbMap(v, `scores`).GetListValue().GetValues()
|
|
||||||
if len(s) != len(c) {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
for i, cv := range c {
|
|
||||||
row := &pb.VaSafety{
|
|
||||||
Category: cv.GetStringValue(),
|
|
||||||
Score: float32(s[i].GetNumberValue()),
|
|
||||||
}
|
|
||||||
d.Safety = append(d.Safety, row)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|||||||
195
server/src/vertexai/common.go
Normal file
195
server/src/vertexai/common.go
Normal file
@@ -0,0 +1,195 @@
|
|||||||
|
package vertexai
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/md5"
|
||||||
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"project/metrics"
|
||||||
|
"project/pb"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
aiplatform "cloud.google.com/go/aiplatform/apiv1"
|
||||||
|
"cloud.google.com/go/aiplatform/apiv1/aiplatformpb"
|
||||||
|
"github.com/zhengkai/life-go"
|
||||||
|
"google.golang.org/protobuf/types/known/structpb"
|
||||||
|
)
|
||||||
|
|
||||||
|
var theClient *aiplatform.PredictionClient
|
||||||
|
var errEmptyAnswer = errors.New(`empty answer`)
|
||||||
|
var errBlocked = errors.New(`blocked by google`)
|
||||||
|
|
||||||
|
type canDebug interface {
|
||||||
|
GetDebug() bool
|
||||||
|
}
|
||||||
|
|
||||||
|
// SpbMap ...
|
||||||
|
func SpbMap(o *structpb.Value, key string) *structpb.Value {
|
||||||
|
m := o.GetStructValue().GetFields()
|
||||||
|
if m == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
v, ok := m[key]
|
||||||
|
if !ok {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return v
|
||||||
|
}
|
||||||
|
|
||||||
|
func cacheFile[T chatKey | textKey](t string, k T) string {
|
||||||
|
ab, _ := json.Marshal(k)
|
||||||
|
h := md5.Sum(ab)
|
||||||
|
file := fmt.Sprintf(`vertexai/%s/%02x/%02x/%02x/%x`, t, h[0], h[1], h[2], h[3:])
|
||||||
|
return file
|
||||||
|
}
|
||||||
|
|
||||||
|
func buildReq(k paramKey, m map[string]any, model string) (*aiplatformpb.PredictRequest, error) {
|
||||||
|
|
||||||
|
inst, err := structpb.NewStruct(m)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
p, err := structpb.NewStruct(map[string]any{
|
||||||
|
`temperature`: k.Temperature,
|
||||||
|
`maxOutputTokens`: k.MaxOutputTokens,
|
||||||
|
`topP`: k.TopP,
|
||||||
|
`topK`: k.TopK,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
endpoint := fmt.Sprintf(
|
||||||
|
`projects/aigc-llm/locations/us-central1/publishers/google/models/%s`,
|
||||||
|
model,
|
||||||
|
)
|
||||||
|
|
||||||
|
req := &aiplatformpb.PredictRequest{
|
||||||
|
Endpoint: endpoint,
|
||||||
|
Instances: []*structpb.Value{
|
||||||
|
structpb.NewStructValue(inst),
|
||||||
|
},
|
||||||
|
Parameters: structpb.NewStructValue(p),
|
||||||
|
}
|
||||||
|
|
||||||
|
return req, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func getToken(d *pb.VaDebug, raw *aiplatformpb.PredictResponse) {
|
||||||
|
tm := SpbMap(raw.GetMetadata(), `tokenMetadata`)
|
||||||
|
if tm == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
input := SpbMap(tm, `inputTokenCount`)
|
||||||
|
if input != nil {
|
||||||
|
d.InputChar = uint32(SpbMap(input, `totalBillableCharacters`).GetNumberValue())
|
||||||
|
d.InputToken = uint32(SpbMap(input, `totalTokens`).GetNumberValue())
|
||||||
|
}
|
||||||
|
output := SpbMap(tm, `outputTokenCount`)
|
||||||
|
if output != nil {
|
||||||
|
d.OutputChar = uint32(SpbMap(output, `totalBillableCharacters`).GetNumberValue())
|
||||||
|
d.OutputToken = uint32(SpbMap(output, `totalTokens`).GetNumberValue())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
func getSafety(d *pb.VaDebug, raw *aiplatformpb.PredictResponse) {
|
||||||
|
p := raw.GetPredictions()
|
||||||
|
if len(p) == 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
sa := SpbMap(p[0], `safetyAttributes`).GetListValue().GetValues()
|
||||||
|
|
||||||
|
for _, v := range sa {
|
||||||
|
c := SpbMap(v, `categories`).GetListValue().GetValues()
|
||||||
|
if len(c) == 0 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
s := SpbMap(v, `scores`).GetListValue().GetValues()
|
||||||
|
if len(s) != len(c) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
for i, cv := range c {
|
||||||
|
row := &pb.VaSafety{
|
||||||
|
Category: cv.GetStringValue(),
|
||||||
|
Score: float32(s[i].GetNumberValue()),
|
||||||
|
}
|
||||||
|
d.Safety = append(d.Safety, row)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func isBlocked(o *structpb.Value) bool {
|
||||||
|
sa := SpbMap(o, `safetyAttributes`).GetListValue().GetValues()
|
||||||
|
for _, v := range sa {
|
||||||
|
if SpbMap(v, `blocked`).GetBoolValue() {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
func getVal(rsp *aiplatformpb.PredictResponse) (string, error) {
|
||||||
|
|
||||||
|
p := rsp.GetPredictions()
|
||||||
|
if len(p) == 0 {
|
||||||
|
return ``, errEmptyAnswer
|
||||||
|
}
|
||||||
|
|
||||||
|
p0 := p[0]
|
||||||
|
|
||||||
|
if isBlocked(p0) {
|
||||||
|
return ``, errBlocked
|
||||||
|
}
|
||||||
|
|
||||||
|
// for text
|
||||||
|
c := SpbMap(p0, `content`).GetStringValue()
|
||||||
|
if c != `` {
|
||||||
|
return c, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// for chat
|
||||||
|
ca := SpbMap(p0, `candidates`).GetListValue().GetValues()
|
||||||
|
if len(ca) == 0 {
|
||||||
|
return ``, errEmptyAnswer
|
||||||
|
}
|
||||||
|
|
||||||
|
s := SpbMap(ca[0], `content`).GetStringValue()
|
||||||
|
if s == `` {
|
||||||
|
return ``, errEmptyAnswer
|
||||||
|
}
|
||||||
|
|
||||||
|
return s, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func doRequest(req *aiplatformpb.PredictRequest) (*Rsp, error) {
|
||||||
|
|
||||||
|
ctx, cancel := life.CTXTimeout(10 * time.Second)
|
||||||
|
t := time.Now()
|
||||||
|
rsp, err := theClient.Predict(ctx, req)
|
||||||
|
|
||||||
|
cancel()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
r := &Rsp{
|
||||||
|
Raw: rsp,
|
||||||
|
CostMs: uint32(time.Since(t) / time.Millisecond),
|
||||||
|
}
|
||||||
|
metrics.VaTime(r.CostMs)
|
||||||
|
|
||||||
|
answer := &pb.VaRsp{}
|
||||||
|
answer.Content, err = getVal(rsp)
|
||||||
|
if err == errBlocked {
|
||||||
|
err = nil
|
||||||
|
answer.Blocked = true
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
r.Answer = answer
|
||||||
|
|
||||||
|
return r, nil
|
||||||
|
}
|
||||||
@@ -12,7 +12,7 @@ import (
|
|||||||
func init() {
|
func init() {
|
||||||
|
|
||||||
var err error
|
var err error
|
||||||
chatClient, err = aiplatform.NewPredictionClient(
|
theClient, err = aiplatform.NewPredictionClient(
|
||||||
life.CTX,
|
life.CTX,
|
||||||
option.WithEndpoint(`us-central1-aiplatform.googleapis.com:443`),
|
option.WithEndpoint(`us-central1-aiplatform.googleapis.com:443`),
|
||||||
option.WithCredentialsFile(util.Static(`aigc-llm-730bb179e13c.json`)),
|
option.WithCredentialsFile(util.Static(`aigc-llm-730bb179e13c.json`)),
|
||||||
|
|||||||
32
server/src/vertexai/param.go
Normal file
32
server/src/vertexai/param.go
Normal file
@@ -0,0 +1,32 @@
|
|||||||
|
package vertexai
|
||||||
|
|
||||||
|
import "project/pb"
|
||||||
|
|
||||||
|
var defaultParam = &pb.VaParam{
|
||||||
|
Temperature: 0.2,
|
||||||
|
MaxOutputTokens: 0,
|
||||||
|
TopP: 1,
|
||||||
|
TopK: 40,
|
||||||
|
}
|
||||||
|
|
||||||
|
// paramKey ...
|
||||||
|
type paramKey struct {
|
||||||
|
Temperature float32 `json:"temperature"`
|
||||||
|
MaxOutputTokens uint32 `json:"maxOutputTokens"`
|
||||||
|
TopP float32 `json:"topP"`
|
||||||
|
TopK uint32 `json:"topK"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (k *paramKey) load(p *pb.VaParam) {
|
||||||
|
if p == nil {
|
||||||
|
p = defaultParam
|
||||||
|
}
|
||||||
|
k.Temperature = p.Temperature
|
||||||
|
k.MaxOutputTokens = p.MaxOutputTokens
|
||||||
|
k.TopP = p.TopP
|
||||||
|
k.TopK = p.TopK
|
||||||
|
|
||||||
|
if k.MaxOutputTokens == 0 {
|
||||||
|
k.MaxOutputTokens = 512
|
||||||
|
}
|
||||||
|
}
|
||||||
31
server/src/vertexai/rsp.go
Normal file
31
server/src/vertexai/rsp.go
Normal file
@@ -0,0 +1,31 @@
|
|||||||
|
package vertexai
|
||||||
|
|
||||||
|
import (
|
||||||
|
"project/pb"
|
||||||
|
"project/util"
|
||||||
|
|
||||||
|
"cloud.google.com/go/aiplatform/apiv1/aiplatformpb"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Rsp ...
|
||||||
|
type Rsp struct {
|
||||||
|
Answer *pb.VaRsp `json:"answer,omitempty"`
|
||||||
|
Raw *aiplatformpb.PredictResponse `json:"raw"`
|
||||||
|
CostMs uint32 `json:"costMs"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// Debug ...
|
||||||
|
func (rsp *Rsp) Debug() *pb.VaDebug {
|
||||||
|
d := &pb.VaDebug{
|
||||||
|
CostMs: rsp.CostMs,
|
||||||
|
}
|
||||||
|
|
||||||
|
getToken(d, rsp.Raw)
|
||||||
|
getSafety(d, rsp.Raw)
|
||||||
|
return d
|
||||||
|
}
|
||||||
|
|
||||||
|
func (rsp *Rsp) save(file string) {
|
||||||
|
util.Mkdir(file)
|
||||||
|
util.WriteJSON(file+`.json`, rsp)
|
||||||
|
}
|
||||||
73
server/src/vertexai/text.go
Normal file
73
server/src/vertexai/text.go
Normal file
@@ -0,0 +1,73 @@
|
|||||||
|
package vertexai
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"project/pb"
|
||||||
|
"project/util"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/zhengkai/coral/v2"
|
||||||
|
)
|
||||||
|
|
||||||
|
var textCache = coral.NewLRU(loadTextForCoral, 1000, 100)
|
||||||
|
|
||||||
|
const textModel = `text-bison@001`
|
||||||
|
|
||||||
|
type textKey struct {
|
||||||
|
paramKey
|
||||||
|
Prompt string `json:"system"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// Text ...
|
||||||
|
func Text(req *pb.VaTextReq) (*Rsp, error) {
|
||||||
|
|
||||||
|
k := textKey{
|
||||||
|
Prompt: req.Prompt,
|
||||||
|
}
|
||||||
|
k.load(req.Param)
|
||||||
|
|
||||||
|
if req.NoCache {
|
||||||
|
textCache.Delete(k)
|
||||||
|
} else {
|
||||||
|
ab, err := util.ReadFile(cacheFile(`text`, k) + `.json`)
|
||||||
|
if err == nil && len(ab) > 2 {
|
||||||
|
rsp := &Rsp{}
|
||||||
|
err = json.Unmarshal(ab, rsp)
|
||||||
|
if err == nil {
|
||||||
|
return rsp, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return textCache.Get(k)
|
||||||
|
}
|
||||||
|
|
||||||
|
func loadText(k textKey) (*Rsp, error) {
|
||||||
|
|
||||||
|
m := map[string]any{
|
||||||
|
`prompt`: k.Prompt,
|
||||||
|
}
|
||||||
|
|
||||||
|
req, err := buildReq(k.paramKey, m, textModel)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
rsp, err := doRequest(req)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
rsp.save(cacheFile(`text`, k))
|
||||||
|
}()
|
||||||
|
return rsp, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func loadTextForCoral(k textKey) (*Rsp, *time.Time, error) {
|
||||||
|
r, err := loadText(k)
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, err
|
||||||
|
}
|
||||||
|
return r, nil, nil
|
||||||
|
}
|
||||||
@@ -1,27 +0,0 @@
|
|||||||
package vertexai
|
|
||||||
|
|
||||||
import (
|
|
||||||
"project/pb"
|
|
||||||
|
|
||||||
"google.golang.org/protobuf/types/known/structpb"
|
|
||||||
)
|
|
||||||
|
|
||||||
var defaultParam = &pb.VaParam{
|
|
||||||
Temperature: 0.2,
|
|
||||||
MaxOutputTokens: 0,
|
|
||||||
TopP: 1,
|
|
||||||
TopK: 40,
|
|
||||||
}
|
|
||||||
|
|
||||||
// SpbMap ...
|
|
||||||
func SpbMap(o *structpb.Value, key string) *structpb.Value {
|
|
||||||
m := o.GetStructValue().GetFields()
|
|
||||||
if m == nil {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
v, ok := m[key]
|
|
||||||
if !ok {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
return v
|
|
||||||
}
|
|
||||||
@@ -10,13 +10,21 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
// ChatHandle ...
|
// web type
|
||||||
func ChatHandle(w http.ResponseWriter, r *http.Request) {
|
const (
|
||||||
|
WebText Web = iota + 1
|
||||||
|
WebChat
|
||||||
|
)
|
||||||
|
|
||||||
|
// Web ...
|
||||||
|
type Web int
|
||||||
|
|
||||||
|
func (web Web) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||||
|
|
||||||
t := time.Now()
|
t := time.Now()
|
||||||
|
|
||||||
data, debug, err := chatHandle(w, r)
|
data, debug, err := handle(web, w, r)
|
||||||
o := &pb.VaChatWebRsp{
|
o := &pb.VaWebRsp{
|
||||||
Data: data,
|
Data: data,
|
||||||
Debug: debug,
|
Debug: debug,
|
||||||
}
|
}
|
||||||
@@ -36,7 +44,7 @@ func ChatHandle(w http.ResponseWriter, r *http.Request) {
|
|||||||
w.Write(ab)
|
w.Write(ab)
|
||||||
}
|
}
|
||||||
|
|
||||||
func chatHandleInput(w http.ResponseWriter, r *http.Request) (*pb.VaChatReq, error) {
|
func chatHandleInput(w http.ResponseWriter, r *http.Request) ([]byte, error) {
|
||||||
|
|
||||||
if r.Method != `POST` {
|
if r.Method != `POST` {
|
||||||
w.WriteHeader(http.StatusBadRequest)
|
w.WriteHeader(http.StatusBadRequest)
|
||||||
@@ -59,39 +67,66 @@ func chatHandleInput(w http.ResponseWriter, r *http.Request) (*pb.VaChatReq, err
|
|||||||
return nil, errors.New(`token not match`)
|
return nil, errors.New(`token not match`)
|
||||||
}
|
}
|
||||||
|
|
||||||
o := &pb.VaChatReq{}
|
|
||||||
|
|
||||||
ab, err := io.ReadAll(r.Body)
|
ab, err := io.ReadAll(r.Body)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
w.WriteHeader(http.StatusBadRequest)
|
w.WriteHeader(http.StatusBadRequest)
|
||||||
return nil, errors.New(`read body fail`)
|
return nil, errors.New(`read body fail`)
|
||||||
}
|
}
|
||||||
err = json.Unmarshal(ab, o)
|
|
||||||
if err != nil {
|
|
||||||
w.WriteHeader(http.StatusBadRequest)
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return o, nil
|
return ab, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func chatHandle(w http.ResponseWriter, r *http.Request) (*pb.VaChatRsp, *pb.VaDebug, error) {
|
func handle(web Web, w http.ResponseWriter, r *http.Request) (*pb.VaRsp, *pb.VaDebug, error) {
|
||||||
|
|
||||||
req, err := chatHandleInput(w, r)
|
ab, err := chatHandleInput(w, r)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, err
|
return nil, nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
rsp, err := Chat(req)
|
var rsp *Rsp
|
||||||
|
isDebug := false
|
||||||
|
|
||||||
|
if web == WebChat {
|
||||||
|
|
||||||
|
req := &pb.VaChatReq{}
|
||||||
|
isDebug, err = unmarshalWebReq(w, ab, req)
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
rsp, err = Chat(req)
|
||||||
|
|
||||||
|
} else {
|
||||||
|
|
||||||
|
req := &pb.VaTextReq{}
|
||||||
|
isDebug, err = unmarshalWebReq(w, ab, req)
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
rsp, err = Text(req)
|
||||||
|
}
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
w.WriteHeader(http.StatusInternalServerError)
|
w.WriteHeader(http.StatusInternalServerError)
|
||||||
return nil, nil, err
|
return nil, nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
var debug *pb.VaDebug
|
var debug *pb.VaDebug
|
||||||
if req.Debug {
|
if isDebug {
|
||||||
debug = rsp.Debug()
|
debug = rsp.Debug()
|
||||||
}
|
}
|
||||||
|
|
||||||
return rsp.Answer, debug, nil
|
return rsp.Answer, debug, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func unmarshalWebReq(w http.ResponseWriter, ab []byte, req canDebug) (bool, error) {
|
||||||
|
|
||||||
|
err := json.Unmarshal(ab, req)
|
||||||
|
if err != nil {
|
||||||
|
w.WriteHeader(http.StatusBadRequest)
|
||||||
|
return false, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return req.GetDebug(), nil
|
||||||
|
}
|
||||||
|
|||||||
@@ -20,7 +20,8 @@ func Server() {
|
|||||||
mux.Handle(`/v1/moderations`, core.NewCore())
|
mux.Handle(`/v1/moderations`, core.NewCore())
|
||||||
mux.Handle(`/v1/completions`, core.NewCore())
|
mux.Handle(`/v1/completions`, core.NewCore())
|
||||||
mux.Handle(`/v1/chat/completions`, core.NewCore())
|
mux.Handle(`/v1/chat/completions`, core.NewCore())
|
||||||
mux.HandleFunc(`/va/chat`, vertexai.ChatHandle)
|
mux.Handle(`/va/chat`, vertexai.WebChat)
|
||||||
|
mux.Handle(`/va/text`, vertexai.WebText)
|
||||||
|
|
||||||
s := &http.Server{
|
s := &http.Server{
|
||||||
Addr: config.WebAddr,
|
Addr: config.WebAddr,
|
||||||
|
|||||||
Reference in New Issue
Block a user