167 lines
3.5 KiB
Go
167 lines
3.5 KiB
Go
package main
|
|
|
|
import (
|
|
"flag"
|
|
"fmt"
|
|
"log"
|
|
"net/http"
|
|
"sync"
|
|
|
|
"github.com/gorilla/websocket"
|
|
)
|
|
|
|
type Client struct {
|
|
conn *websocket.Conn // websocket 连接
|
|
mutex sync.RWMutex // 读写锁,用于多线程安全
|
|
id int // 客户端 ID
|
|
done chan struct{} // 用于关闭客户端连接
|
|
send chan []byte // 发送消息的 channel
|
|
hub *ClientHub // 客户端连接管理器
|
|
}
|
|
|
|
func (c *Client) Write(msg []byte) {
|
|
c.mutex.Lock()
|
|
defer c.mutex.Unlock()
|
|
|
|
select {
|
|
case c.send <- msg:
|
|
default:
|
|
// 如果 send channel 已满,则关闭连接
|
|
close(c.done)
|
|
}
|
|
}
|
|
// Read is the client's read pump. It blocks reading messages from the
// websocket connection, logs each one and queues a "pong" reply on c.send.
//
// It returns when reading fails (peer disconnected or the connection was
// closed) or when c.done is closed while waiting for room in c.send. On return
// the client is sent to the hub's unregister channel and the connection is
// closed.
func (c *Client) Read() {
	defer func() {
		c.hub.unregister <- c
		c.conn.Close()
	}()

	for {
		_, msg, err := c.conn.ReadMessage()
		if err != nil {
			// Normal closes are expected; only report anything else.
			if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseNormalClosure) {
				log.Printf("client %d: read error: %v", c.id, err)
			}
			return
		}
		// %q prints the payload as text (Println would print raw byte values)
		// and escapes any control characters from the untrusted peer.
		log.Printf("client %d: %q", c.id, msg)

		// Never block forever on a full send buffer: give up once the client
		// has been told to shut down.
		select {
		case c.send <- []byte("pong"):
		case <-c.done:
			return
		}
	}
}
|
|
|
|
// WriteLoop is the client's write pump. It forwards every message queued on
// c.send to the websocket connection as a text frame.
//
// It returns when:
//   - c.send is closed (the hub removed the client); a close frame is sent first;
//   - c.done is closed (the client was flagged as too slow);
//   - writing to the connection fails.
//
// On return the connection is closed, which also unblocks a concurrent Read.
// Run at most one WriteLoop per client: a websocket connection supports only
// one concurrent writer.
func (c *Client) WriteLoop() {
	defer c.conn.Close()

	for {
		select {
		case msg, ok := <-c.send:
			if !ok {
				// Best effort: the peer may already be gone, so the error is
				// intentionally ignored.
				_ = c.conn.WriteMessage(websocket.CloseMessage, []byte{})
				return
			}
			if err := c.conn.WriteMessage(websocket.TextMessage, msg); err != nil {
				log.Printf("client %d: write error: %v", c.id, err)
				return
			}
		case <-c.done:
			return
		}
	}
}
|
|
|
|
type ClientHub struct {
|
|
clients map[*Client]bool // 连接的客户端
|
|
broadcast chan []byte // 广播通道
|
|
register chan *Client // 新连接通道
|
|
unregister chan *Client // 断开连接通道
|
|
onlineCount int // 在线人数
|
|
}
|
|
|
|
func (h *ClientHub) run() {
|
|
for {
|
|
select {
|
|
case client := <-h.register: // 新连接
|
|
h.add(client)
|
|
case client := <-h.unregister: // 断开连接
|
|
h.remove(client)
|
|
case message := <-h.broadcast: // 广播消息
|
|
for client := range h.clients {
|
|
select {
|
|
case client.send <- message:
|
|
default:
|
|
// 如果 send channel 已满,则关闭连接
|
|
close(client.done)
|
|
delete(h.clients, client)
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
// add registers client with the hub, increments the online counter and
// announces the arrival (with the new head count) to every registered client,
// the newcomer included.
func (h *ClientHub) add(client *Client) {
	h.onlineCount++
	h.clients[client] = true

	welcome := fmt.Sprintf("欢迎 %d 来到聊天室,当前在线人数 %d 人", client.id, h.onlineCount)
	h.broadcastMessage([]byte(welcome))
}
|
|
|
|
// remove unregisters client if the hub still tracks it. It drops the client
// from h.clients, closes its send channel so whatever drains it observes the
// close, decrements the online counter and announces the departure to the
// remaining clients.
//
// Clients that are no longer tracked (for example, already evicted by
// broadcastMessage) are ignored, so send is never closed twice.
func (h *ClientHub) remove(client *Client) {
	if _, ok := h.clients[client]; !ok {
		return
	}

	delete(h.clients, client)
	close(client.send)
	h.onlineCount--

	// Announce only after the bookkeeping: the reported count must exclude the
	// departing client, and the farewell must not be queued for that client.
	msg := fmt.Sprintf("%d 离开了聊天室,当前在线人数 %d 人", client.id, h.onlineCount)
	h.broadcastMessage([]byte(msg))
}
|
|
func (h *ClientHub) broadcastMessage(msg []byte) {
|
|
for client := range h.clients {
|
|
select {
|
|
case client.send <- msg:
|
|
default:
|
|
close(client.done)
|
|
delete(h.clients, client)
|
|
}
|
|
}
|
|
}
|
|
|
|
// main starts the chat server. It parses the -addr flag, launches the hub's
// event loop, serves ./index.html at "/", and upgrades requests on "/ws" to
// websocket connections that are registered with the hub and served by a
// read pump and a write pump.
func main() {
	var addr = flag.String("addr", ":8080", "http service address")
	flag.Parse()
	log.SetFlags(0)

	hub := &ClientHub{
		clients:    make(map[*Client]bool),
		broadcast:  make(chan []byte),
		register:   make(chan *Client),
		unregister: make(chan *Client),
	}
	go hub.run()

	http.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
		http.ServeFile(w, r, "./index.html")
	})

	var upgrader = websocket.Upgrader{
		// NOTE(review): accepting every Origin lets any web page open a
		// websocket to this server (cross-site websocket hijacking). Restrict
		// this to trusted origins before exposing the server publicly.
		CheckOrigin: func(r *http.Request) bool {
			return true // allow all requests
		},
	}

	// Client IDs come from a dedicated counter. Using len(hub.clients) would
	// read the map that hub.run mutates on another goroutine (a data race), and
	// would also hand out duplicate IDs after clients disconnect.
	var (
		idMu   sync.Mutex
		lastID int
	)
	nextID := func() int {
		idMu.Lock()
		defer idMu.Unlock()
		lastID++
		return lastID
	}

	http.HandleFunc("/ws", func(w http.ResponseWriter, r *http.Request) {
		conn, err := upgrader.Upgrade(w, r, nil)
		if err != nil {
			log.Println(err)
			return
		}
		client := &Client{
			conn: conn,
			id:   nextID(),
			done: make(chan struct{}),
			send: make(chan []byte, 256),
			hub:  hub,
		}
		hub.register <- client

		// One goroutine drains client.send to the socket and one reads from it;
		// a websocket connection allows one concurrent writer and one
		// concurrent reader. Without the write pump nothing queued on send
		// would ever reach the peer.
		go client.WriteLoop()
		go client.Read()
	})

	log.Printf("listening on %s", *addr)
	if err := http.ListenAndServe(*addr, nil); err != nil {
		log.Fatal("ListenAndServe: ", err.Error())
	}
}
|