// Command main runs a small WebSocket chat-room server. It serves index.html
// on "/" and upgrades requests on "/ws" to WebSocket connections that are
// tracked by a ClientHub.
package main

import (
	"flag"
	"fmt"
	"log"
	"net/http"
	"sync"

	"github.com/gorilla/websocket"
)

// Client is one connected WebSocket peer.
type Client struct {
	conn  *websocket.Conn // underlying WebSocket connection
	mutex sync.RWMutex    // guards closing of done so it happens at most once
	id    int             // client ID, assigned by the hub on registration
	done  chan struct{}   // closed to tell WriteLoop to shut the connection down
	send  chan []byte     // buffered queue of outbound messages
	hub   *ClientHub      // hub that manages this client
}

// closeDone closes c.done if it is not closed yet. It is safe to call from
// any goroutine and any number of times; closing an already-closed channel
// would panic, so every shutdown path goes through here.
func (c *Client) closeDone() {
	c.mutex.Lock()
	defer c.mutex.Unlock()
	select {
	case <-c.done:
		// Already closed; nothing to do.
	default:
		close(c.done)
	}
}

// Write queues msg for delivery without blocking. If the send queue is full
// the client is considered too slow and is asked to shut down.
func (c *Client) Write(msg []byte) {
	select {
	case c.send <- msg:
	default:
		c.closeDone()
	}
}

// Read consumes incoming messages until the connection fails, logging each
// one and queueing a "pong" reply. On exit it unregisters the client from the
// hub and closes the connection.
func (c *Client) Read() {
	defer func() {
		c.hub.unregister <- c
		c.conn.Close()
	}()
	for {
		_, msg, err := c.conn.ReadMessage()
		if err != nil {
			if websocket.IsUnexpectedCloseError(err, websocket.CloseNormalClosure, websocket.CloseGoingAway) {
				log.Println(err)
			}
			return
		}
		// Convert to string so the log shows text rather than a byte list.
		log.Println(string(msg))
		// Use the non-blocking Write: a direct channel send would stall this
		// goroutine forever once the queue is full.
		c.Write([]byte("pong"))
	}
}

// WriteLoop is the only goroutine that writes to the connection. It sends
// queued messages as text frames until a write fails or done is closed, then
// closes the connection, which in turn makes Read return and unregister.
func (c *Client) WriteLoop() {
	defer c.conn.Close()
	for {
		select {
		case msg := <-c.send:
			if err := c.conn.WriteMessage(websocket.TextMessage, msg); err != nil {
				return
			}
		case <-c.done:
			return
		}
	}
}

// ClientHub tracks the connected clients. All of its fields are owned by the
// goroutine executing run; other goroutines talk to it through the channels.
type ClientHub struct {
	clients     map[*Client]bool // connected clients
	broadcast   chan []byte      // messages to fan out to every client
	register    chan *Client     // newly connected clients
	unregister  chan *Client     // disconnected clients
	onlineCount int              // number of clients currently online
	lastID      int              // most recently assigned client ID
}

// run processes register, unregister and broadcast events forever. It must
// run in its own goroutine.
func (h *ClientHub) run() {
	for {
		select {
		case client := <-h.register:
			h.add(client)
		case client := <-h.unregister:
			h.remove(client)
		case message := <-h.broadcast:
			h.broadcastMessage(message)
		}
	}
}

// add registers client, assigns it a fresh ID and announces the arrival to
// every client, including the new one. Adding a client twice is a no-op.
func (h *ClientHub) add(client *Client) {
	if h.clients[client] {
		return
	}
	// IDs come from a monotonically increasing counter owned by the hub
	// goroutine, so they never collide and need no locking.
	h.lastID++
	client.id = h.lastID
	h.clients[client] = true
	h.onlineCount++
	msg := fmt.Sprintf("欢迎 %d 来到聊天室,当前在线人数 %d 人", client.id, h.onlineCount)
	h.broadcastMessage([]byte(msg))
}

// remove unregisters client and announces the departure, with the updated
// online count, to the remaining clients. Unknown clients are ignored.
func (h *ClientHub) remove(client *Client) {
	if !h.clients[client] {
		return
	}
	// Drop first so the count in the message is already decremented and the
	// departing client is not sent its own farewell.
	h.drop(client)
	msg := fmt.Sprintf("%d 离开了聊天室,当前在线人数 %d 人", client.id, h.onlineCount)
	h.broadcastMessage([]byte(msg))
}

// drop forgets client, updates the online count and signals the client to
// shut down. The send channel is deliberately never closed: Read and Write
// may still be sending on it, and a send on a closed channel panics.
func (h *ClientHub) drop(client *Client) {
	delete(h.clients, client)
	h.onlineCount--
	client.closeDone()
}

// broadcastMessage queues msg for every client without blocking. A client
// whose queue is full is dropped silently (no departure announcement).
func (h *ClientHub) broadcastMessage(msg []byte) {
	for client := range h.clients {
		select {
		case client.send <- msg:
		default:
			// Deleting map entries while ranging over the map is allowed.
			h.drop(client)
		}
	}
}

func main() {
	var addr = flag.String("addr", ":8080", "http service address")
	flag.Parse()
	log.SetFlags(0)

	hub := &ClientHub{
		clients:    make(map[*Client]bool),
		broadcast:  make(chan []byte),
		register:   make(chan *Client),
		unregister: make(chan *Client),
	}
	go hub.run()

	http.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
		http.ServeFile(w, r, "./index.html")
	})

	var upgrader = websocket.Upgrader{
		CheckOrigin: func(r *http.Request) bool {
			return true // accept requests from every origin
		},
	}
	http.HandleFunc("/ws", func(w http.ResponseWriter, r *http.Request) {
		conn, err := upgrader.Upgrade(w, r, nil)
		if err != nil {
			log.Println(err)
			return
		}
		// The ID is assigned by the hub on registration; reading
		// len(hub.clients) here would race with the hub goroutine.
		client := &Client{
			conn: conn,
			done: make(chan struct{}),
			send: make(chan []byte, 256),
			hub:  hub,
		}
		hub.register <- client
		go client.WriteLoop()
		go client.Read()
	})

	log.Printf("listening on %s", *addr)
	err := http.ListenAndServe(*addr, nil)
	if err != nil {
		log.Fatal("ListenAndServe: ", err.Error())
	}
}