// File: api-proxy/online/main.go
// Commit 2cf484cdf1 ("all") by Sakurasan, 2023-03-24 22:51:53 +08:00
// Original listing: 167 lines, 3.5 KiB, Go.

package main
import (
"flag"
"fmt"
"log"
"net/http"
"sync"
"github.com/gorilla/websocket"
)
// Client is one websocket connection taking part in the chat room.
type Client struct {
	// conn is the underlying websocket connection.
	conn *websocket.Conn
	// mutex is a read/write lock used to make access to the client safe
	// across goroutines.
	mutex sync.RWMutex
	// id is the client's numeric ID, shown in the hub's join/leave notices.
	id int
	// done is closed to signal that this client's connection should shut down.
	done chan struct{}
	// send queues outgoing messages for this client.
	send chan []byte
	// hub is the connection manager this client is registered with.
	hub *ClientHub
}
// Write queues msg on the client's send channel without blocking. If the send
// buffer is full, the client is treated as too slow and its done channel is
// closed to signal shutdown.
//
// Write never closes done twice. The mutex serialises concurrent Write calls,
// and the inner select skips the close if done is already closed (for example
// by the hub). Closing an already-closed channel would panic.
//
// NOTE(review): Write must not be called after the hub has closed c.send;
// sending on a closed channel panics.
func (c *Client) Write(msg []byte) {
	c.mutex.Lock()
	defer c.mutex.Unlock()
	select {
	case c.send <- msg:
	default:
		// Send buffer is full: signal shutdown, at most once.
		select {
		case <-c.done:
			// Already closed; nothing to do.
		default:
			close(c.done)
		}
	}
}
// Read is the client's receive loop. It logs every incoming websocket
// message and answers each one by queueing "pong" on the send channel.
//
// The loop ends when ReadMessage fails (peer disconnected or connection
// closed) or when done is closed while Read is waiting to queue a reply.
// On exit the client unregisters from the hub and closes the connection.
func (c *Client) Read() {
	defer func() {
		c.hub.unregister <- c
		c.conn.Close()
	}()
	for {
		_, msg, err := c.conn.ReadMessage()
		if err != nil {
			return
		}
		// %q prints the payload as (escaped) text rather than as a list of
		// decimal byte values, which is what Println does for a []byte.
		log.Printf("client %d: %q", c.id, msg)
		// Never block forever on a full send buffer that nobody drains:
		// give up as soon as the client has been told to shut down.
		select {
		case c.send <- []byte("pong"):
		case <-c.done:
			return
		}
	}
}
// WriteLoop is the client's send loop. It writes every message queued on
// c.send to the websocket as a text frame.
//
// It returns when:
//   - c.send is closed (the hub unregistered the client); a close frame is
//     sent on a best-effort basis first;
//   - c.done is closed (the client was told to shut down); or
//   - a write fails.
//
// It closes the connection on the way out, so a concurrent Read blocked in
// ReadMessage also gets an error and returns.
//
// gorilla/websocket supports one concurrent reader and one concurrent writer.
// Read is the reader, so this loop must be the connection's only writer.
func (c *Client) WriteLoop() {
	defer c.conn.Close()
	for {
		select {
		case msg, ok := <-c.send:
			if !ok {
				// Best-effort close frame; the connection may already be gone,
				// so the error is deliberately ignored.
				_ = c.conn.WriteMessage(websocket.CloseMessage, []byte{})
				return
			}
			if err := c.conn.WriteMessage(websocket.TextMessage, msg); err != nil {
				return
			}
		case <-c.done:
			return
		}
	}
}
// ClientHub is the connection manager. It keeps the set of connected clients
// and the channels through which registration, unregistration and broadcast
// requests reach its run loop.
type ClientHub struct {
	// clients is the set of connected clients; the stored value is always true.
	clients map[*Client]bool
	// broadcast carries messages to fan out to every client.
	broadcast chan []byte
	// register receives newly connected clients.
	register chan *Client
	// unregister receives clients that have disconnected.
	unregister chan *Client
	// onlineCount is the number of users currently online.
	onlineCount int
}
// run is the hub's event loop and never returns. It handles new connections
// (register), disconnections (unregister) and broadcast requests one at a
// time.
//
// add, remove and broadcastMessage are only invoked from here, so every
// mutation of h.clients happens on this goroutine.
func (h *ClientHub) run() {
	for {
		select {
		case client := <-h.register:
			h.add(client)
		case client := <-h.unregister:
			h.remove(client)
		case message := <-h.broadcast:
			// Reuse the single fan-out implementation (including slow-client
			// eviction) instead of duplicating it here.
			h.broadcastMessage(message)
		}
	}
}
// add registers client with the hub, increments the online count, and then
// announces the newcomer, with the updated count, to every registered client
// (the newcomer included).
func (h *ClientHub) add(client *Client) {
	h.onlineCount++
	h.clients[client] = true
	h.broadcastMessage([]byte(fmt.Sprintf("欢迎 %d 来到聊天室,当前在线人数 %d 人", client.id, h.onlineCount)))
}
// remove unregisters client, closes its send channel and decrements the
// online count. It then tells the remaining clients who left and how many are
// still online.
//
// It is a no-op if client is not registered, for example because
// broadcastMessage already dropped it as too slow. This also guarantees that
// client.send is closed at most once.
func (h *ClientHub) remove(client *Client) {
	if _, ok := h.clients[client]; !ok {
		return
	}
	// Drop the client before announcing. The reported count then reflects the
	// departure, and the departing client is not sent its own farewell.
	delete(h.clients, client)
	close(client.send)
	h.onlineCount--
	msg := fmt.Sprintf("%d 离开了聊天室,当前在线人数 %d 人", client.id, h.onlineCount)
	h.broadcastMessage([]byte(msg))
}
// broadcastMessage queues msg on every registered client's send channel
// without blocking.
//
// A client whose send buffer is full is treated as too slow. Its done channel
// is closed (once), it is removed from the hub, and onlineCount is decremented
// so the count stays in step with h.clients.
//
// In this file it is only reached from the hub goroutine (via run, add and
// remove).
func (h *ClientHub) broadcastMessage(msg []byte) {
	for client := range h.clients {
		select {
		case client.send <- msg:
		default:
			// Client.Write may also close done under the client's mutex, so
			// take the same lock and skip the close if it already happened.
			// Closing a closed channel panics.
			client.mutex.Lock()
			select {
			case <-client.done:
			default:
				close(client.done)
			}
			client.mutex.Unlock()
			// Deleting the current key while ranging over a map is allowed in Go.
			delete(h.clients, client)
			h.onlineCount--
		}
	}
}
// main parses the -addr flag and starts the hub's event loop. It serves
// ./index.html on "/" and the chat websocket endpoint on "/ws".
func main() {
	var addr = flag.String("addr", ":8080", "http service address")
	flag.Parse()
	log.SetFlags(0)

	hub := &ClientHub{
		clients:    make(map[*Client]bool),
		broadcast:  make(chan []byte),
		register:   make(chan *Client),
		unregister: make(chan *Client),
	}
	go hub.run()

	http.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
		http.ServeFile(w, r, "./index.html")
	})

	var upgrader = websocket.Upgrader{
		// NOTE(review): this accepts every Origin, i.e. cross-site websocket
		// connections are allowed. Restrict it before exposing the service
		// publicly.
		CheckOrigin: func(r *http.Request) bool {
			return true // allow requests from any origin
		},
	}

	// Client IDs come from a mutex-guarded counter. Reading len(hub.clients)
	// from the HTTP handler would race with the hub goroutine that mutates the
	// map, and would hand out duplicate IDs after clients disconnect.
	var (
		idMu   sync.Mutex
		lastID int
	)
	nextID := func() int {
		idMu.Lock()
		defer idMu.Unlock()
		lastID++
		return lastID
	}

	http.HandleFunc("/ws", func(w http.ResponseWriter, r *http.Request) {
		conn, err := upgrader.Upgrade(w, r, nil)
		if err != nil {
			log.Println(err)
			return
		}
		client := &Client{
			conn: conn, // mutex: zero value is ready to use
			id:   nextID(),
			done: make(chan struct{}),
			send: make(chan []byte, 256),
			hub:  hub,
		}
		hub.register <- client
		// One writer goroutine and one reader goroutine per connection.
		// Without the writer, nothing drains client.send.
		go client.WriteLoop()
		go client.Read()
	})

	log.Printf("listening on %s", *addr)
	if err := http.ListenAndServe(*addr, nil); err != nil {
		log.Fatal("ListenAndServe: ", err.Error())
	}
}