This commit is contained in:
Sakurasan
2023-03-24 22:51:53 +08:00
parent 4668571cbc
commit 2cf484cdf1
7 changed files with 541 additions and 2 deletions

View File

@@ -0,0 +1,122 @@
package main
import (
"bytes"
"crypto/tls"
"fmt"
"io"
"net"
"net/http"
"time"
)
var (
masterUrl string
OPENAI_API_KEY string
baseUrl = "https://api.openai.com"
)
func main() {
router := http.NewServeMux()
// 路由转发
router.HandleFunc("/", HandleProxy)
// 启动代理服务器
fmt.Println("API proxy server is listening on port 80")
if err := http.ListenAndServe(":80", router); err != nil {
panic(err)
}
}
func HandleProxy(w http.ResponseWriter, r *http.Request) {
auth := r.Header.Get("Authorization")
if auth[:7] == "Bearer " {
if len(auth[7:]) < 1 {
http.Error(w, "Unauthorized", http.StatusUnauthorized)
return
}
req, _ := http.NewRequest(http.MethodGet, masterUrl, nil)
req.Header.Set("Authorization", auth[7:])
resp, err := http.DefaultClient.Do(req)
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
defer resp.Body.Close()
if resp.StatusCode != 200 {
http.Error(w, "Unauthorized", http.StatusUnauthorized)
return
}
} else {
http.Error(w, "Unauthorized", http.StatusUnauthorized)
return
}
client := http.DefaultClient
tr := &http.Transport{
Proxy: http.ProxyFromEnvironment,
DialContext: (&net.Dialer{
Timeout: 30 * time.Second,
KeepAlive: 30 * time.Second,
}).DialContext,
ForceAttemptHTTP2: true,
MaxIdleConns: 100,
IdleConnTimeout: 90 * time.Second,
TLSHandshakeTimeout: 10 * time.Second,
ExpectContinueTimeout: 1 * time.Second,
TLSClientConfig: &tls.Config{InsecureSkipVerify: true},
}
client.Transport = tr
// 创建 API 请求
req, err := http.NewRequest(r.Method, baseUrl+r.URL.Path, r.Body)
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
req.Header = r.Header
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", OPENAI_API_KEY))
resp, err := client.Do(req)
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
defer resp.Body.Close()
// 复制 API 响应头部
for name, values := range resp.Header {
for _, value := range values {
w.Header().Add(name, value)
}
}
head := map[string]string{
"Cache-Control": "no-store",
"access-control-allow-origin": "*",
"access-control-allow-credentials": "true",
}
for k, v := range head {
if _, ok := resp.Header[k]; !ok {
w.Header().Set(k, v)
}
}
resp.Header.Del("content-security-policy")
resp.Header.Del("content-security-policy-report-only")
resp.Header.Del("clear-site-data")
bodyRes, err := io.ReadAll(resp.Body)
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
if resp.StatusCode == 200 {
// todo
}
resbody := io.NopCloser(bytes.NewReader(bodyRes))
// 返回 API 响应主体
w.WriteHeader(resp.StatusCode)
if _, err := io.Copy(w, resbody); err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
}

215
getway/main.go Normal file
View File

@@ -0,0 +1,215 @@
package main
import (
"crypto/rand"
"database/sql"
"encoding/base64"
"log"
"net/http"
"os"
"time"
"github.com/gin-contrib/cors"
"github.com/gin-gonic/gin"
"github.com/go-sql-driver/mysql"
"github.com/golang-jwt/jwt"
"github.com/google/go-github/v32/github"
"golang.org/x/oauth2"
"golang.org/x/oauth2/github"
)
var dbConn *sql.DB
var jwtSecret = []byte(os.Getenv("JWT_SECRET"))
func main() {
initDB()
router := gin.Default()
router.Use(cors.New(cors.Config{
AllowOrigins: []string{"http://localhost:8080"},
AllowMethods: []string{"GET", "POST", "PUT", "DELETE", "OPTIONS", "HEAD"},
AllowHeaders: []string{"Accept", "Authorization", "Content-Type", "X-CSRF-Token"},
AllowCredentials: true,
}))
router.GET("/auth/github", githubLoginHandler)
router.GET("/auth/github/callback", githubCallbackHandler)
router.Run(":8000")
}
func githubLoginHandler(c *gin.Context) {
state := generateState()
code := generateCode()
err := storeStateToDB(state, code)
if err != nil {
log.Println("Error storing state to DB:", err)
c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"message": "Internal server error"})
return
}
oauthConfig := &oauth2.Config{
ClientID: os.Getenv("GITHUB_CLIENT_ID"),
ClientSecret: os.Getenv("GITHUB_CLIENT_SECRET"),
RedirectURL: "http://localhost:8000/auth/github/callback",
Scopes: []string{"user:email"},
Endpoint: github.Endpoint,
}
url := oauthConfig.AuthCodeURL(state)
c.Redirect(http.StatusFound, url)
}
func githubCallbackHandler(c *gin.Context) {
state := c.Query("state")
code := c.Query("code")
if !verifyState(state, code) {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"message": "Invalid state"})
return
}
oauthConfig := &oauth2.Config{
ClientID: os.Getenv("GITHUB_CLIENT_ID"),
ClientSecret: os.Getenv("GITHUB_CLIENT_SECRET"),
RedirectURL: "http://localhost:8000/auth/github/callback",
Scopes: []string{"user:email"},
Endpoint: github.Endpoint,
}
token, err := oauthConfig.Exchange(c.Request.Context(), code)
if err != nil {
log.Println("Error exchanging token:", err)
c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"message": "Internal server error"})
return
}
client := github.NewClient(oauthConfig.Client(c.Request.Context(), token))
user, _, err := client.Users.Get(c.Request.Context(), "")
if err != nil {
log.Println("Error getting user:", err)
c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"message": "Internal server error"})
return
}
err = storeUserToDB(user)
if err != nil {
log.Println("Error storing user to DB:", err)
c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"message": "Internal server error"})
return
}
jwtToken, err := generateJWTToken(user.ID)
if err != nil {
log.Println("Error generating JWT token:", err)
c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"message": "Internal server error"})
return
}
c.SetCookie("token", jwtToken, 60*60*24, "/", "localhost", false, true)
c.Redirect(http.StatusFound, "http://localhost:8080/#/")
}
func initDB() {
cfg, err := mysql.ParseDSN(os.Getenv("MYSQL_DSN"))
if err != nil {
log.Fatal("Error parsing MySQL DSN:", err)
}
dbConn, err = sql.Open("mysql", cfg.FormatDSN())
if err != nil {
log.Fatal("Error opening database:", err)
}
err = dbConn.Ping()
if err != nil {
log.Fatal("Error connecting to database:", err)
}
log.Println("Database connection established")
}
func generateState() string {
return generateRandomString(20)
}
func generateCode() string {
return generateRandomString(20)
}
func generateRandomString(length int) string {
byteArr := make([]byte, length)
_, err := rand.Read(byteArr)
if err != nil {
log.Fatal("Error generating random string:", err)
}
return base64.URLEncoding.EncodeToString(byteArr)
}
func storeStateToDB(state, code string) error {
query := "INSERT INTO oauth_state (state, code) VALUES (?, ?)"
stmt, err := dbConn.Prepare(query)
if err != nil {
return err
}
defer stmt.Close()
_, err = stmt.Exec(state, code)
if err != nil {
return err
}
return nil
}
func verifyState(state, code string) bool {
query := "SELECT COUNT(*) FROM oauth_state WHERE state = ? AND code = ?"
row := dbConn.QueryRow(query, state, code)
var count int
err := row.Scan(&count)
if err != nil {
log.Println("Error verifying state:", err)
return false
}
if count == 0 {
return false
}
return true
}
func storeUserToDB(user *github.User) error {
query := "INSERT INTO users (id, login, email) VALUES (?, ?, ?) ON DUPLICATE KEY UPDATE login = VALUES(login), email = VALUES(email)"
stmt, err := dbConn.Prepare(query)
if err != nil {
return err
}
defer stmt.Close()
_, err = stmt.Exec(user.GetID(), user.GetLogin(), user.GetEmail())
if err != nil {
return err
}
return nil
}
func generateJWTToken(userID int64) (string, error) {
token := jwt.New(jwt.SigningMethodHS256)
claims := token.Claims.(jwt.MapClaims)
claims["userID"] = userID
claims["exp"] = time.Now().Add(time.Hour * 24).Unix()
jwtToken, err := token.SignedString(jwtSecret)
if err != nil {
return "", err
}
return jwtToken, nil
}

1
go.mod
View File

@@ -4,6 +4,7 @@ go 1.19
require (
github.com/google/go-github v17.0.0+incompatible
github.com/gorilla/websocket v1.5.0
golang.org/x/oauth2 v0.6.0
)

2
go.sum
View File

@@ -9,6 +9,8 @@ github.com/google/go-github v17.0.0+incompatible h1:N0LgJ1j65A7kfXrZnUDaYCs/Sf4r
github.com/google/go-github v17.0.0+incompatible/go.mod h1:zLgOLi98H3fifZn+44m+umXrS52loVEgC2AApnigrVQ=
github.com/google/go-querystring v1.1.0 h1:AnCroh3fv4ZBgVIf1Iwtovgjaw/GiKJo8M8yD/fhyJ8=
github.com/google/go-querystring v1.1.0/go.mod h1:Kcdr2DB4koayq7X8pmAG4sNG59So17icRSOU623lUBU=
github.com/gorilla/websocket v1.5.0 h1:PPwGk2jz7EePpoHN/+ClbZu8SPxiqlu12wZP/3sWmnc=
github.com/gorilla/websocket v1.5.0/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
golang.org/x/net v0.0.0-20190603091049-60506f45cf65/go.mod h1:HSz+uSET+XFnRR8LxR5pz3Of3rY3CfYBVs4xY44aLks=
golang.org/x/net v0.8.0 h1:Zrh2ngAOFYneWTAIAPethzeaQLuHwhuBkuV6ZiRnUaQ=

View File

@@ -10,8 +10,8 @@ import (
)
var (
clientID = "YOUR_CLIENT_ID"
clientSecret = "YOUR_CLIENT_SECRET"
clientID = "9f75836d51c1cb447fa5"
clientSecret = "3cbedeb77ffa7593b3cc60985aa4212b2cfc0686"
redirectURL = "http://localhost:8080/callback"
)

33
online/index.html Normal file
View File

@@ -0,0 +1,33 @@
<!DOCTYPE html>
<html>
<head>
<meta charset="UTF-8">
<title>在线人数统计</title>
</head>
<body>
<div id="app">
<p>当前在线人数:{{ onlineCount }}</p>
</div>
<script src="https://cdn.jsdelivr.net/npm/vue"></script>
<script>
let ws = new WebSocket("ws://localhost:8080/ws");
let app = new Vue({
el: '#app',
data: {
onlineCount: 1,
},
mounted() {
let self = this;
ws.addEventListener('message', function(event) {
self.onlineCount = event.data;
});
},
});
</script>
</body>
</html>

166
online/main.go Normal file
View File

@@ -0,0 +1,166 @@
package main
import (
"flag"
"fmt"
"log"
"net/http"
"sync"
"github.com/gorilla/websocket"
)
type Client struct {
conn *websocket.Conn // websocket 连接
mutex sync.RWMutex // 读写锁,用于多线程安全
id int // 客户端 ID
done chan struct{} // 用于关闭客户端连接
send chan []byte // 发送消息的 channel
hub *ClientHub // 客户端连接管理器
}
func (c *Client) Write(msg []byte) {
c.mutex.Lock()
defer c.mutex.Unlock()
select {
case c.send <- msg:
default:
// 如果 send channel 已满,则关闭连接
close(c.done)
}
}
func (c *Client) Read() {
defer func() {
c.hub.unregister <- c
c.conn.Close()
}()
for {
_, msg, err := c.conn.ReadMessage()
if err != nil {
break
}
log.Println(msg)
c.send <- []byte("pong")
}
}
func (c *Client) WriteLoop() {
// for {
// select {
// case msg := <-c.send:
// err := c.conn.WriteMessage(msg.messageType, msg.data)
// if err != nil {
// return err
// }
// case <-c.stop:
// return
// }
// }
}
type ClientHub struct {
clients map[*Client]bool // 连接的客户端
broadcast chan []byte // 广播通道
register chan *Client // 新连接通道
unregister chan *Client // 断开连接通道
onlineCount int // 在线人数
}
func (h *ClientHub) run() {
for {
select {
case client := <-h.register: // 新连接
h.add(client)
case client := <-h.unregister: // 断开连接
h.remove(client)
case message := <-h.broadcast: // 广播消息
for client := range h.clients {
select {
case client.send <- message:
default:
// 如果 send channel 已满,则关闭连接
close(client.done)
delete(h.clients, client)
}
}
}
}
}
func (h *ClientHub) add(client *Client) {
// 添加 client 到连接管理器
h.clients[client] = true
h.onlineCount++
msg := fmt.Sprintf("欢迎 %d 来到聊天室,当前在线人数 %d 人", client.id, h.onlineCount)
h.broadcastMessage([]byte(msg))
}
func (h *ClientHub) remove(client *Client) {
if _, ok := h.clients[client]; ok {
msg := fmt.Sprintf("%d 离开了聊天室,当前在线人数 %d 人", client.id, h.onlineCount)
h.broadcastMessage([]byte(msg))
delete(h.clients, client)
close(client.send)
h.onlineCount--
}
}
func (h *ClientHub) broadcastMessage(msg []byte) {
for client := range h.clients {
select {
case client.send <- msg:
default:
close(client.done)
delete(h.clients, client)
}
}
}
func main() {
var addr = flag.String("addr", ":8080", "http service address")
flag.Parse()
log.SetFlags(0)
hub := &ClientHub{
clients: make(map[*Client]bool),
broadcast: make(chan []byte),
register: make(chan *Client),
unregister: make(chan *Client),
}
go hub.run()
http.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
http.ServeFile(w, r, "./index.html")
})
var upgrader = websocket.Upgrader{
CheckOrigin: func(r *http.Request) bool {
return true // 允许所有的请求
},
}
http.HandleFunc("/ws", func(w http.ResponseWriter, r *http.Request) {
conn, err := upgrader.Upgrade(w, r, nil)
if err != nil {
log.Println(err)
return
}
client := &Client{
conn: conn,
mutex: sync.RWMutex{},
id: len(hub.clients) + 1,
done: make(chan struct{}),
send: make(chan []byte, 256),
hub: hub,
}
hub.register <- client
// go client.WriteLoop()
go client.Read()
})
log.Printf("listening on %s", *addr)
err := http.ListenAndServe(*addr, nil)
if err != nil {
log.Fatal("ListenAndServe: ", err.Error())
}
}