Files
api-proxy/agent/main.go
Sakurasan 2cf484cdf1 all
2023-03-24 22:51:53 +08:00

123 lines
2.9 KiB
Go

package main
import (
"bytes"
"crypto/tls"
"fmt"
"io"
"net"
"net/http"
"time"
)
// Runtime configuration shared by main and HandleProxy.
var (
	// baseUrl is the upstream origin; HandleProxy prefixes the incoming
	// request path with it to build the forwarded URL.
	baseUrl = "https://api.openai.com"

	// masterUrl is the endpoint HandleProxy calls to validate a caller's
	// bearer token. It has no initial value here.
	masterUrl string

	// OPENAI_API_KEY is the key HandleProxy sends upstream in place of the
	// caller's own token. It has no initial value here.
	OPENAI_API_KEY string
)
// main registers HandleProxy for every path and serves it on port 80.
// It panics if the listener cannot be started or the server stops with an
// error.
func main() {
	router := http.NewServeMux()
	// Forward every route to the proxy handler.
	router.HandleFunc("/", HandleProxy)

	srv := &http.Server{
		Addr:    ":80",
		Handler: router,
		// Bound how long a client may take to send its request headers so
		// that idle or deliberately slow connections cannot be held open
		// forever. Body and write timeouts are intentionally left unset
		// because proxied API calls may legitimately run for a long time.
		ReadHeaderTimeout: 10 * time.Second,
	}

	// Start the proxy server.
	fmt.Println("API proxy server is listening on port 80")
	if err := srv.ListenAndServe(); err != nil {
		panic(err)
	}
}
func HandleProxy(w http.ResponseWriter, r *http.Request) {
auth := r.Header.Get("Authorization")
if auth[:7] == "Bearer " {
if len(auth[7:]) < 1 {
http.Error(w, "Unauthorized", http.StatusUnauthorized)
return
}
req, _ := http.NewRequest(http.MethodGet, masterUrl, nil)
req.Header.Set("Authorization", auth[7:])
resp, err := http.DefaultClient.Do(req)
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
defer resp.Body.Close()
if resp.StatusCode != 200 {
http.Error(w, "Unauthorized", http.StatusUnauthorized)
return
}
} else {
http.Error(w, "Unauthorized", http.StatusUnauthorized)
return
}
client := http.DefaultClient
tr := &http.Transport{
Proxy: http.ProxyFromEnvironment,
DialContext: (&net.Dialer{
Timeout: 30 * time.Second,
KeepAlive: 30 * time.Second,
}).DialContext,
ForceAttemptHTTP2: true,
MaxIdleConns: 100,
IdleConnTimeout: 90 * time.Second,
TLSHandshakeTimeout: 10 * time.Second,
ExpectContinueTimeout: 1 * time.Second,
TLSClientConfig: &tls.Config{InsecureSkipVerify: true},
}
client.Transport = tr
// 创建 API 请求
req, err := http.NewRequest(r.Method, baseUrl+r.URL.Path, r.Body)
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
req.Header = r.Header
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", OPENAI_API_KEY))
resp, err := client.Do(req)
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
defer resp.Body.Close()
// 复制 API 响应头部
for name, values := range resp.Header {
for _, value := range values {
w.Header().Add(name, value)
}
}
head := map[string]string{
"Cache-Control": "no-store",
"access-control-allow-origin": "*",
"access-control-allow-credentials": "true",
}
for k, v := range head {
if _, ok := resp.Header[k]; !ok {
w.Header().Set(k, v)
}
}
resp.Header.Del("content-security-policy")
resp.Header.Del("content-security-policy-report-only")
resp.Header.Del("clear-site-data")
bodyRes, err := io.ReadAll(resp.Body)
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
if resp.StatusCode == 200 {
// todo
}
resbody := io.NopCloser(bytes.NewReader(bodyRes))
// 返回 API 响应主体
w.WriteHeader(resp.StatusCode)
if _, err := io.Copy(w, resbody); err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
}