mirror of
https://github.com/zhengkai/orca.git
synced 2025-12-19 18:32:23 +08:00
feat: vertexai text
This commit is contained in:
14
misc/test/va/text.sh
Executable file
14
misc/test/va/text.sh
Executable file
@@ -0,0 +1,14 @@
|
||||
#!/bin/bash
|
||||
|
||||
URL="http://localhost:22035/va/text"
|
||||
|
||||
curl -s "$URL" \
|
||||
-H "Content-Type: application/json" \
|
||||
-H "VA-TOKEN: ${ORCA_VA_TOKEN}" \
|
||||
-d '{
|
||||
"prompt":"翻译下列语言为中文:\n\n3 2 1 Hello, world!",
|
||||
"debug":true
|
||||
}' | tee tmp-text.json
|
||||
echo
|
||||
|
||||
jq . tmp-text.json
|
||||
@@ -9,14 +9,22 @@ message VaChatReq {
|
||||
VaParam param = 4;
|
||||
bool debug = 5;
|
||||
}
|
||||
message VaChatRsp {
|
||||
|
||||
message VaTextReq {
|
||||
string prompt = 1;
|
||||
bool noCache = 3;
|
||||
VaParam param = 4;
|
||||
bool debug = 5;
|
||||
}
|
||||
|
||||
message VaRsp {
|
||||
string content = 1;
|
||||
bool blocked = 2;
|
||||
}
|
||||
|
||||
message VaChatWebRsp {
|
||||
message VaWebRsp {
|
||||
bool ok = 1;
|
||||
VaChatRsp data = 2;
|
||||
VaRsp data = 2;
|
||||
VaDebug debug = 3;
|
||||
string error = 4;
|
||||
}
|
||||
|
||||
@@ -1,68 +1,39 @@
|
||||
package vertexai
|
||||
|
||||
import (
|
||||
"crypto/md5"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"project/metrics"
|
||||
"project/pb"
|
||||
"project/util"
|
||||
"time"
|
||||
|
||||
aiplatform "cloud.google.com/go/aiplatform/apiv1"
|
||||
"cloud.google.com/go/aiplatform/apiv1/aiplatformpb"
|
||||
"github.com/zhengkai/coral/v2"
|
||||
"github.com/zhengkai/life-go"
|
||||
"google.golang.org/protobuf/types/known/structpb"
|
||||
)
|
||||
|
||||
var chatClient *aiplatform.PredictionClient
|
||||
|
||||
var errEmptyAnswer = errors.New(`empty answer`)
|
||||
var errBlocked = errors.New(`blocked by google`)
|
||||
|
||||
var chatCache = coral.NewLRU(loadChatForCoral, 1000, 100)
|
||||
|
||||
type chatKey struct {
|
||||
System string `json:"system"`
|
||||
User string `json:"user"`
|
||||
Temperature float32 `json:"temperature"`
|
||||
MaxOutputTokens uint32 `json:"maxOutputTokens"`
|
||||
TopP float32 `json:"topP"`
|
||||
TopK uint32 `json:"topK"`
|
||||
}
|
||||
const chatModel = `chat-bison@001`
|
||||
|
||||
// ChatRsp ...
|
||||
type ChatRsp struct {
|
||||
Answer *pb.VaChatRsp `json:"answer,omitempty"`
|
||||
Raw *aiplatformpb.PredictResponse `json:"raw"`
|
||||
CostMs uint32 `json:"costMs"`
|
||||
type chatKey struct {
|
||||
paramKey
|
||||
System string `json:"system"`
|
||||
User string `json:"user"`
|
||||
}
|
||||
|
||||
// Chat ...
|
||||
func Chat(req *pb.VaChatReq) (*ChatRsp, error) {
|
||||
|
||||
p := req.Param
|
||||
if p == nil {
|
||||
p = defaultParam
|
||||
}
|
||||
func Chat(req *pb.VaChatReq) (*Rsp, error) {
|
||||
|
||||
k := chatKey{
|
||||
System: req.System,
|
||||
User: req.User,
|
||||
Temperature: p.Temperature,
|
||||
MaxOutputTokens: p.MaxOutputTokens,
|
||||
TopP: p.TopP,
|
||||
TopK: p.TopK,
|
||||
System: req.System,
|
||||
User: req.User,
|
||||
}
|
||||
k.load(req.Param)
|
||||
|
||||
if req.NoCache {
|
||||
chatCache.Delete(k)
|
||||
} else {
|
||||
ab, err := util.ReadFile(chatCacheFile(k) + `.json`)
|
||||
ab, err := util.ReadFile(cacheFile(`chat`, k) + `.json`)
|
||||
if err == nil && len(ab) > 2 {
|
||||
rsp := &ChatRsp{}
|
||||
rsp := &Rsp{}
|
||||
err = json.Unmarshal(ab, rsp)
|
||||
if err == nil {
|
||||
return rsp, nil
|
||||
@@ -73,7 +44,7 @@ func Chat(req *pb.VaChatReq) (*ChatRsp, error) {
|
||||
return chatCache.Get(k)
|
||||
}
|
||||
|
||||
func buildChatReq(k chatKey) (*aiplatformpb.PredictRequest, error) {
|
||||
func loadChat(k chatKey) (*Rsp, error) {
|
||||
|
||||
m := map[string]any{
|
||||
`context`: k.System,
|
||||
@@ -87,178 +58,26 @@ func buildChatReq(k chatKey) (*aiplatformpb.PredictRequest, error) {
|
||||
}
|
||||
}
|
||||
|
||||
inst, err := structpb.NewStruct(m)
|
||||
req, err := buildReq(k.paramKey, m, chatModel)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
p, err := structpb.NewStruct(map[string]any{
|
||||
`temperature`: k.Temperature,
|
||||
`maxOutputTokens`: k.MaxOutputTokens,
|
||||
`topP`: k.TopP,
|
||||
`topK`: k.TopK,
|
||||
})
|
||||
rsp, err := doRequest(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
req := &aiplatformpb.PredictRequest{
|
||||
Endpoint: `projects/aigc-llm/locations/us-central1/publishers/google/models/chat-bison@001`,
|
||||
Instances: []*structpb.Value{
|
||||
structpb.NewStructValue(inst),
|
||||
},
|
||||
Parameters: structpb.NewStructValue(p),
|
||||
}
|
||||
|
||||
return req, nil
|
||||
go func() {
|
||||
rsp.save(cacheFile(`chat`, k))
|
||||
}()
|
||||
return rsp, nil
|
||||
}
|
||||
|
||||
func getChatVal(rsp *aiplatformpb.PredictResponse) (string, error) {
|
||||
|
||||
p := rsp.GetPredictions()
|
||||
if len(p) == 0 {
|
||||
return ``, errEmptyAnswer
|
||||
}
|
||||
|
||||
p0 := p[0]
|
||||
|
||||
if isBlocked(p0) {
|
||||
return ``, errBlocked
|
||||
}
|
||||
|
||||
ca := SpbMap(p0, `candidates`).GetListValue().GetValues()
|
||||
if len(ca) == 0 {
|
||||
return ``, errEmptyAnswer
|
||||
}
|
||||
|
||||
s := SpbMap(ca[0], `content`).GetStringValue()
|
||||
if s == `` {
|
||||
return ``, errEmptyAnswer
|
||||
}
|
||||
|
||||
return s, nil
|
||||
}
|
||||
|
||||
func isBlocked(o *structpb.Value) bool {
|
||||
sa := SpbMap(o, `safetyAttributes`).GetListValue().GetValues()
|
||||
for _, v := range sa {
|
||||
if SpbMap(v, `blocked`).GetBoolValue() {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func loadChat(k chatKey) (*ChatRsp, error) {
|
||||
|
||||
req, err := buildChatReq(k)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
ctx, cancel := life.CTXTimeout(10 * time.Second)
|
||||
t := time.Now()
|
||||
rsp, err := chatClient.Predict(ctx, req)
|
||||
|
||||
cancel()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
r := &ChatRsp{
|
||||
Raw: rsp,
|
||||
CostMs: uint32(time.Since(t) / time.Millisecond),
|
||||
}
|
||||
metrics.VaTime(r.CostMs)
|
||||
|
||||
answer := &pb.VaChatRsp{}
|
||||
answer.Content, err = getChatVal(rsp)
|
||||
if err == errBlocked {
|
||||
err = nil
|
||||
answer.Blocked = true
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
r.Answer = answer
|
||||
|
||||
go chatSaveCache(k, r)
|
||||
return r, nil
|
||||
}
|
||||
|
||||
func loadChatForCoral(k chatKey) (*ChatRsp, *time.Time, error) {
|
||||
func loadChatForCoral(k chatKey) (*Rsp, *time.Time, error) {
|
||||
r, err := loadChat(k)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
return r, nil, nil
|
||||
}
|
||||
|
||||
func chatCacheFile(k chatKey) string {
|
||||
ab, _ := json.Marshal(k)
|
||||
h := md5.Sum(ab)
|
||||
file := fmt.Sprintf(`vertexai/chat/%02x/%02x/%02x/%x`, h[0], h[1], h[2], h[3:])
|
||||
return file
|
||||
}
|
||||
|
||||
func chatSaveCache(k chatKey, rsp *ChatRsp) {
|
||||
file := chatCacheFile(k)
|
||||
util.Mkdir(file)
|
||||
util.WriteJSON(file+`.json`, rsp)
|
||||
}
|
||||
|
||||
// Debug ...
|
||||
func (rsp *ChatRsp) Debug() *pb.VaDebug {
|
||||
d := &pb.VaDebug{
|
||||
CostMs: rsp.CostMs,
|
||||
}
|
||||
|
||||
getToken(d, rsp.Raw)
|
||||
getSafety(d, rsp.Raw)
|
||||
|
||||
return d
|
||||
}
|
||||
|
||||
func getToken(d *pb.VaDebug, raw *aiplatformpb.PredictResponse) {
|
||||
tm := SpbMap(raw.GetMetadata(), `tokenMetadata`)
|
||||
if tm == nil {
|
||||
return
|
||||
}
|
||||
input := SpbMap(tm, `inputTokenCount`)
|
||||
if input != nil {
|
||||
d.InputChar = uint32(SpbMap(input, `totalBillableCharacters`).GetNumberValue())
|
||||
d.InputToken = uint32(SpbMap(input, `totalTokens`).GetNumberValue())
|
||||
}
|
||||
output := SpbMap(tm, `outputTokenCount`)
|
||||
if output != nil {
|
||||
d.OutputChar = uint32(SpbMap(output, `totalBillableCharacters`).GetNumberValue())
|
||||
d.OutputToken = uint32(SpbMap(output, `totalTokens`).GetNumberValue())
|
||||
}
|
||||
}
|
||||
func getSafety(d *pb.VaDebug, raw *aiplatformpb.PredictResponse) {
|
||||
p := raw.GetPredictions()
|
||||
if len(p) == 0 {
|
||||
return
|
||||
}
|
||||
sa := SpbMap(p[0], `safetyAttributes`).GetListValue().GetValues()
|
||||
|
||||
for _, v := range sa {
|
||||
c := SpbMap(v, `categories`).GetListValue().GetValues()
|
||||
if len(c) == 0 {
|
||||
continue
|
||||
}
|
||||
s := SpbMap(v, `scores`).GetListValue().GetValues()
|
||||
if len(s) != len(c) {
|
||||
continue
|
||||
}
|
||||
|
||||
for i, cv := range c {
|
||||
row := &pb.VaSafety{
|
||||
Category: cv.GetStringValue(),
|
||||
Score: float32(s[i].GetNumberValue()),
|
||||
}
|
||||
d.Safety = append(d.Safety, row)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
195
server/src/vertexai/common.go
Normal file
195
server/src/vertexai/common.go
Normal file
@@ -0,0 +1,195 @@
|
||||
package vertexai
|
||||
|
||||
import (
|
||||
"crypto/md5"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"project/metrics"
|
||||
"project/pb"
|
||||
"time"
|
||||
|
||||
aiplatform "cloud.google.com/go/aiplatform/apiv1"
|
||||
"cloud.google.com/go/aiplatform/apiv1/aiplatformpb"
|
||||
"github.com/zhengkai/life-go"
|
||||
"google.golang.org/protobuf/types/known/structpb"
|
||||
)
|
||||
|
||||
var theClient *aiplatform.PredictionClient
|
||||
var errEmptyAnswer = errors.New(`empty answer`)
|
||||
var errBlocked = errors.New(`blocked by google`)
|
||||
|
||||
type canDebug interface {
|
||||
GetDebug() bool
|
||||
}
|
||||
|
||||
// SpbMap ...
|
||||
func SpbMap(o *structpb.Value, key string) *structpb.Value {
|
||||
m := o.GetStructValue().GetFields()
|
||||
if m == nil {
|
||||
return nil
|
||||
}
|
||||
v, ok := m[key]
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
return v
|
||||
}
|
||||
|
||||
func cacheFile[T chatKey | textKey](t string, k T) string {
|
||||
ab, _ := json.Marshal(k)
|
||||
h := md5.Sum(ab)
|
||||
file := fmt.Sprintf(`vertexai/%s/%02x/%02x/%02x/%x`, t, h[0], h[1], h[2], h[3:])
|
||||
return file
|
||||
}
|
||||
|
||||
func buildReq(k paramKey, m map[string]any, model string) (*aiplatformpb.PredictRequest, error) {
|
||||
|
||||
inst, err := structpb.NewStruct(m)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
p, err := structpb.NewStruct(map[string]any{
|
||||
`temperature`: k.Temperature,
|
||||
`maxOutputTokens`: k.MaxOutputTokens,
|
||||
`topP`: k.TopP,
|
||||
`topK`: k.TopK,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
endpoint := fmt.Sprintf(
|
||||
`projects/aigc-llm/locations/us-central1/publishers/google/models/%s`,
|
||||
model,
|
||||
)
|
||||
|
||||
req := &aiplatformpb.PredictRequest{
|
||||
Endpoint: endpoint,
|
||||
Instances: []*structpb.Value{
|
||||
structpb.NewStructValue(inst),
|
||||
},
|
||||
Parameters: structpb.NewStructValue(p),
|
||||
}
|
||||
|
||||
return req, nil
|
||||
}
|
||||
|
||||
func getToken(d *pb.VaDebug, raw *aiplatformpb.PredictResponse) {
|
||||
tm := SpbMap(raw.GetMetadata(), `tokenMetadata`)
|
||||
if tm == nil {
|
||||
return
|
||||
}
|
||||
input := SpbMap(tm, `inputTokenCount`)
|
||||
if input != nil {
|
||||
d.InputChar = uint32(SpbMap(input, `totalBillableCharacters`).GetNumberValue())
|
||||
d.InputToken = uint32(SpbMap(input, `totalTokens`).GetNumberValue())
|
||||
}
|
||||
output := SpbMap(tm, `outputTokenCount`)
|
||||
if output != nil {
|
||||
d.OutputChar = uint32(SpbMap(output, `totalBillableCharacters`).GetNumberValue())
|
||||
d.OutputToken = uint32(SpbMap(output, `totalTokens`).GetNumberValue())
|
||||
}
|
||||
}
|
||||
func getSafety(d *pb.VaDebug, raw *aiplatformpb.PredictResponse) {
|
||||
p := raw.GetPredictions()
|
||||
if len(p) == 0 {
|
||||
return
|
||||
}
|
||||
sa := SpbMap(p[0], `safetyAttributes`).GetListValue().GetValues()
|
||||
|
||||
for _, v := range sa {
|
||||
c := SpbMap(v, `categories`).GetListValue().GetValues()
|
||||
if len(c) == 0 {
|
||||
continue
|
||||
}
|
||||
s := SpbMap(v, `scores`).GetListValue().GetValues()
|
||||
if len(s) != len(c) {
|
||||
continue
|
||||
}
|
||||
|
||||
for i, cv := range c {
|
||||
row := &pb.VaSafety{
|
||||
Category: cv.GetStringValue(),
|
||||
Score: float32(s[i].GetNumberValue()),
|
||||
}
|
||||
d.Safety = append(d.Safety, row)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func isBlocked(o *structpb.Value) bool {
|
||||
sa := SpbMap(o, `safetyAttributes`).GetListValue().GetValues()
|
||||
for _, v := range sa {
|
||||
if SpbMap(v, `blocked`).GetBoolValue() {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func getVal(rsp *aiplatformpb.PredictResponse) (string, error) {
|
||||
|
||||
p := rsp.GetPredictions()
|
||||
if len(p) == 0 {
|
||||
return ``, errEmptyAnswer
|
||||
}
|
||||
|
||||
p0 := p[0]
|
||||
|
||||
if isBlocked(p0) {
|
||||
return ``, errBlocked
|
||||
}
|
||||
|
||||
// for text
|
||||
c := SpbMap(p0, `content`).GetStringValue()
|
||||
if c != `` {
|
||||
return c, nil
|
||||
}
|
||||
|
||||
// for chat
|
||||
ca := SpbMap(p0, `candidates`).GetListValue().GetValues()
|
||||
if len(ca) == 0 {
|
||||
return ``, errEmptyAnswer
|
||||
}
|
||||
|
||||
s := SpbMap(ca[0], `content`).GetStringValue()
|
||||
if s == `` {
|
||||
return ``, errEmptyAnswer
|
||||
}
|
||||
|
||||
return s, nil
|
||||
}
|
||||
|
||||
func doRequest(req *aiplatformpb.PredictRequest) (*Rsp, error) {
|
||||
|
||||
ctx, cancel := life.CTXTimeout(10 * time.Second)
|
||||
t := time.Now()
|
||||
rsp, err := theClient.Predict(ctx, req)
|
||||
|
||||
cancel()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
r := &Rsp{
|
||||
Raw: rsp,
|
||||
CostMs: uint32(time.Since(t) / time.Millisecond),
|
||||
}
|
||||
metrics.VaTime(r.CostMs)
|
||||
|
||||
answer := &pb.VaRsp{}
|
||||
answer.Content, err = getVal(rsp)
|
||||
if err == errBlocked {
|
||||
err = nil
|
||||
answer.Blocked = true
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
r.Answer = answer
|
||||
|
||||
return r, nil
|
||||
}
|
||||
@@ -12,7 +12,7 @@ import (
|
||||
func init() {
|
||||
|
||||
var err error
|
||||
chatClient, err = aiplatform.NewPredictionClient(
|
||||
theClient, err = aiplatform.NewPredictionClient(
|
||||
life.CTX,
|
||||
option.WithEndpoint(`us-central1-aiplatform.googleapis.com:443`),
|
||||
option.WithCredentialsFile(util.Static(`aigc-llm-730bb179e13c.json`)),
|
||||
|
||||
32
server/src/vertexai/param.go
Normal file
32
server/src/vertexai/param.go
Normal file
@@ -0,0 +1,32 @@
|
||||
package vertexai
|
||||
|
||||
import "project/pb"
|
||||
|
||||
var defaultParam = &pb.VaParam{
|
||||
Temperature: 0.2,
|
||||
MaxOutputTokens: 0,
|
||||
TopP: 1,
|
||||
TopK: 40,
|
||||
}
|
||||
|
||||
// paramKey ...
|
||||
type paramKey struct {
|
||||
Temperature float32 `json:"temperature"`
|
||||
MaxOutputTokens uint32 `json:"maxOutputTokens"`
|
||||
TopP float32 `json:"topP"`
|
||||
TopK uint32 `json:"topK"`
|
||||
}
|
||||
|
||||
func (k *paramKey) load(p *pb.VaParam) {
|
||||
if p == nil {
|
||||
p = defaultParam
|
||||
}
|
||||
k.Temperature = p.Temperature
|
||||
k.MaxOutputTokens = p.MaxOutputTokens
|
||||
k.TopP = p.TopP
|
||||
k.TopK = p.TopK
|
||||
|
||||
if k.MaxOutputTokens == 0 {
|
||||
k.MaxOutputTokens = 512
|
||||
}
|
||||
}
|
||||
31
server/src/vertexai/rsp.go
Normal file
31
server/src/vertexai/rsp.go
Normal file
@@ -0,0 +1,31 @@
|
||||
package vertexai
|
||||
|
||||
import (
|
||||
"project/pb"
|
||||
"project/util"
|
||||
|
||||
"cloud.google.com/go/aiplatform/apiv1/aiplatformpb"
|
||||
)
|
||||
|
||||
// Rsp ...
|
||||
type Rsp struct {
|
||||
Answer *pb.VaRsp `json:"answer,omitempty"`
|
||||
Raw *aiplatformpb.PredictResponse `json:"raw"`
|
||||
CostMs uint32 `json:"costMs"`
|
||||
}
|
||||
|
||||
// Debug ...
|
||||
func (rsp *Rsp) Debug() *pb.VaDebug {
|
||||
d := &pb.VaDebug{
|
||||
CostMs: rsp.CostMs,
|
||||
}
|
||||
|
||||
getToken(d, rsp.Raw)
|
||||
getSafety(d, rsp.Raw)
|
||||
return d
|
||||
}
|
||||
|
||||
func (rsp *Rsp) save(file string) {
|
||||
util.Mkdir(file)
|
||||
util.WriteJSON(file+`.json`, rsp)
|
||||
}
|
||||
73
server/src/vertexai/text.go
Normal file
73
server/src/vertexai/text.go
Normal file
@@ -0,0 +1,73 @@
|
||||
package vertexai
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"project/pb"
|
||||
"project/util"
|
||||
"time"
|
||||
|
||||
"github.com/zhengkai/coral/v2"
|
||||
)
|
||||
|
||||
var textCache = coral.NewLRU(loadTextForCoral, 1000, 100)
|
||||
|
||||
const textModel = `text-bison@001`
|
||||
|
||||
type textKey struct {
|
||||
paramKey
|
||||
Prompt string `json:"system"`
|
||||
}
|
||||
|
||||
// Text ...
|
||||
func Text(req *pb.VaTextReq) (*Rsp, error) {
|
||||
|
||||
k := textKey{
|
||||
Prompt: req.Prompt,
|
||||
}
|
||||
k.load(req.Param)
|
||||
|
||||
if req.NoCache {
|
||||
textCache.Delete(k)
|
||||
} else {
|
||||
ab, err := util.ReadFile(cacheFile(`text`, k) + `.json`)
|
||||
if err == nil && len(ab) > 2 {
|
||||
rsp := &Rsp{}
|
||||
err = json.Unmarshal(ab, rsp)
|
||||
if err == nil {
|
||||
return rsp, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return textCache.Get(k)
|
||||
}
|
||||
|
||||
func loadText(k textKey) (*Rsp, error) {
|
||||
|
||||
m := map[string]any{
|
||||
`prompt`: k.Prompt,
|
||||
}
|
||||
|
||||
req, err := buildReq(k.paramKey, m, textModel)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
rsp, err := doRequest(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
go func() {
|
||||
rsp.save(cacheFile(`text`, k))
|
||||
}()
|
||||
return rsp, nil
|
||||
}
|
||||
|
||||
func loadTextForCoral(k textKey) (*Rsp, *time.Time, error) {
|
||||
r, err := loadText(k)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
return r, nil, nil
|
||||
}
|
||||
@@ -1,27 +0,0 @@
|
||||
package vertexai
|
||||
|
||||
import (
|
||||
"project/pb"
|
||||
|
||||
"google.golang.org/protobuf/types/known/structpb"
|
||||
)
|
||||
|
||||
var defaultParam = &pb.VaParam{
|
||||
Temperature: 0.2,
|
||||
MaxOutputTokens: 0,
|
||||
TopP: 1,
|
||||
TopK: 40,
|
||||
}
|
||||
|
||||
// SpbMap ...
|
||||
func SpbMap(o *structpb.Value, key string) *structpb.Value {
|
||||
m := o.GetStructValue().GetFields()
|
||||
if m == nil {
|
||||
return nil
|
||||
}
|
||||
v, ok := m[key]
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
return v
|
||||
}
|
||||
@@ -10,13 +10,21 @@ import (
|
||||
"time"
|
||||
)
|
||||
|
||||
// ChatHandle ...
|
||||
func ChatHandle(w http.ResponseWriter, r *http.Request) {
|
||||
// web type
|
||||
const (
|
||||
WebText Web = iota + 1
|
||||
WebChat
|
||||
)
|
||||
|
||||
// Web ...
|
||||
type Web int
|
||||
|
||||
func (web Web) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
t := time.Now()
|
||||
|
||||
data, debug, err := chatHandle(w, r)
|
||||
o := &pb.VaChatWebRsp{
|
||||
data, debug, err := handle(web, w, r)
|
||||
o := &pb.VaWebRsp{
|
||||
Data: data,
|
||||
Debug: debug,
|
||||
}
|
||||
@@ -36,7 +44,7 @@ func ChatHandle(w http.ResponseWriter, r *http.Request) {
|
||||
w.Write(ab)
|
||||
}
|
||||
|
||||
func chatHandleInput(w http.ResponseWriter, r *http.Request) (*pb.VaChatReq, error) {
|
||||
func chatHandleInput(w http.ResponseWriter, r *http.Request) ([]byte, error) {
|
||||
|
||||
if r.Method != `POST` {
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
@@ -59,39 +67,66 @@ func chatHandleInput(w http.ResponseWriter, r *http.Request) (*pb.VaChatReq, err
|
||||
return nil, errors.New(`token not match`)
|
||||
}
|
||||
|
||||
o := &pb.VaChatReq{}
|
||||
|
||||
ab, err := io.ReadAll(r.Body)
|
||||
if err != nil {
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
return nil, errors.New(`read body fail`)
|
||||
}
|
||||
err = json.Unmarshal(ab, o)
|
||||
if err != nil {
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return o, nil
|
||||
return ab, nil
|
||||
}
|
||||
|
||||
func chatHandle(w http.ResponseWriter, r *http.Request) (*pb.VaChatRsp, *pb.VaDebug, error) {
|
||||
func handle(web Web, w http.ResponseWriter, r *http.Request) (*pb.VaRsp, *pb.VaDebug, error) {
|
||||
|
||||
req, err := chatHandleInput(w, r)
|
||||
ab, err := chatHandleInput(w, r)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
rsp, err := Chat(req)
|
||||
var rsp *Rsp
|
||||
isDebug := false
|
||||
|
||||
if web == WebChat {
|
||||
|
||||
req := &pb.VaChatReq{}
|
||||
isDebug, err = unmarshalWebReq(w, ab, req)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
rsp, err = Chat(req)
|
||||
|
||||
} else {
|
||||
|
||||
req := &pb.VaTextReq{}
|
||||
isDebug, err = unmarshalWebReq(w, ab, req)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
rsp, err = Text(req)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
var debug *pb.VaDebug
|
||||
if req.Debug {
|
||||
if isDebug {
|
||||
debug = rsp.Debug()
|
||||
}
|
||||
|
||||
return rsp.Answer, debug, nil
|
||||
}
|
||||
|
||||
func unmarshalWebReq(w http.ResponseWriter, ab []byte, req canDebug) (bool, error) {
|
||||
|
||||
err := json.Unmarshal(ab, req)
|
||||
if err != nil {
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
return false, err
|
||||
}
|
||||
|
||||
return req.GetDebug(), nil
|
||||
}
|
||||
|
||||
@@ -20,7 +20,8 @@ func Server() {
|
||||
mux.Handle(`/v1/moderations`, core.NewCore())
|
||||
mux.Handle(`/v1/completions`, core.NewCore())
|
||||
mux.Handle(`/v1/chat/completions`, core.NewCore())
|
||||
mux.HandleFunc(`/va/chat`, vertexai.ChatHandle)
|
||||
mux.Handle(`/va/chat`, vertexai.WebChat)
|
||||
mux.Handle(`/va/text`, vertexai.WebText)
|
||||
|
||||
s := &http.Server{
|
||||
Addr: config.WebAddr,
|
||||
|
||||
Reference in New Issue
Block a user