package ws import ( "ChatRoom/internal/models" "ChatRoom/internal/rabbitmq" "encoding/json" "log" "net/http" "time" "github.com/gorilla/websocket" ) // 常量 const ( writeWait = 10 * time.Second //写入等待时间 pongWait = 60 * time.Second //pong等待时间 pingPeriod = (pongWait * 9) / 10 //ping周期 maxMessageSize = 512 //最大消息大小 ) // 升级器 var upgrader = websocket.Upgrader{ ReadBufferSize: 1024, //读取缓冲区大小 WriteBufferSize: 1024, //写入缓冲区大小 CheckOrigin: func(r *http.Request) bool { return true // 允许跨域(开发阶段)允许所有来源 }, } // 连接 type Connection struct { wsConn *websocket.Conn //websocket连接 rmqClient *rabbitmq.Client //rabbitmq客户端 queueName string //队列名称 userID string //用户ID send chan []byte //发送通道 } func NewConnection(w http.ResponseWriter, r *http.Request, rmq *rabbitmq.Client) { // 1. 升级 HTTP 到 WebSocket wsConn, err := upgrader.Upgrade(w, r, nil) if err != nil { log.Printf("WebSocket 升级失败: %v", err) return } // 2. 从 URL 获取用户ID userID := r.URL.Query().Get("user") if userID == "" { wsConn.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(4001, "user required")) wsConn.Close() return } // 3. 为用户创建 RabbitMQ 队列 queueName, err := rmq.DeclareQueue() if err != nil { log.Printf("RabbitMQ 队列创建失败: %v", err) wsConn.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(5000, "queue error")) wsConn.Close() return } // 4. 绑定基础路由(全局/私聊/事件) bindings := []string{ "chat.global", // 全体广播 "chat.user." + userID, // 私聊 "chat.event.*", // 上下线事件 "chat.system", // 系统通知 } for _, rk := range bindings { if err := rmq.BindQueue(queueName, rk); err != nil { log.Printf("绑定失败 [%s]: %v", rk, err) // 可选择继续或关闭连接 } } // 5. 创建 Connection 对象 conn := &Connection{ wsConn: wsConn, rmqClient: rmq, queueName: queueName, userID: userID, send: make(chan []byte, 256), } // 发送登录成功消息 systemMsg := &models.Message{ Type: models.MsgTypeSystem, User: "system", Content: "登录成功", Time: time.Now().UTC().Format(time.RFC3339), // 服务端覆盖 Time } if data, _ := json.Marshal(systemMsg); len(data) > 0 { conn.send <- data } go conn.writePump() go conn.readPump() go conn.consumeFromRabbitMQ() } func (c *Connection) readPump() { defer c.wsConn.Close() c.wsConn.SetReadLimit(maxMessageSize) c.wsConn.SetReadDeadline(time.Now().Add(pongWait)) c.wsConn.SetPongHandler(func(string) error { c.wsConn.SetReadDeadline(time.Now().Add(pongWait)) return nil }) for { _, rawMsg, err := c.wsConn.ReadMessage() if err != nil { if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseAbnormalClosure, websocket.CloseNoStatusReceived) { log.Printf("WebSocket 读取错误: %v", err) } break } var msg models.Message if err := json.Unmarshal(rawMsg, &msg); err != nil { c.sendError("无效消息格式") continue } // 校验必填字段:content 必填,user 由服务端覆盖(防伪造) if msg.Content == "" { c.sendError("缺少 content 字段") continue } msg.User = c.userID msg.Time = time.Now().UTC().Format(time.RFC3339) // 关键!服务端统一时间 // 生成路由键 routingKey := c.getRoutingKey(&msg) if routingKey == "" { c.sendError("不支持的消息类型或缺少必要字段") continue } // 序列化并发布 body, _ := json.Marshal(msg) if err := c.rmqClient.Publish(c.rmqClient.ExchangeName, routingKey, body); err != nil { log.Printf("RabbitMQ 发布失败 [%s]: %v", routingKey, err) c.sendError("消息发送失败") continue } } } func (c *Connection) getRoutingKey(msg *models.Message) string { switch msg.Type { case models.MsgTypeBroadcast: return "chat.global" case models.MsgTypeRoom: if msg.Room == "" { return "" } return "chat.room." + msg.Room case models.MsgTypePrivate: if msg.To == "" { return "" } return "chat.user." + msg.To default: return "" } } func (c *Connection) sendError(content string) { errMsg := &models.Message{ Type: models.MsgTypeError, User: "system", Content: content, Time: time.Now().UTC().Format(time.RFC3339), // 服务端覆盖 Time } if data, _ := json.Marshal(errMsg); len(data) > 0 { select { case c.send <- data: default: } } } func (c *Connection) writePump() { ticker := time.NewTicker(pingPeriod) defer func() { ticker.Stop() c.wsConn.Close() }() for { select { case message, ok := <-c.send: c.wsConn.SetWriteDeadline(time.Now().Add(writeWait)) if !ok { c.wsConn.WriteMessage(websocket.CloseMessage, []byte{}) return } w, err := c.wsConn.NextWriter(websocket.TextMessage) if err != nil { return } w.Write(message) if err := w.Close(); err != nil { return } case <-ticker.C: c.wsConn.SetWriteDeadline(time.Now().Add(writeWait)) if err := c.wsConn.WriteMessage(websocket.PingMessage, nil); err != nil { return } } } } func (c *Connection) consumeFromRabbitMQ() { deliveries, err := c.rmqClient.Consume(c.queueName) if err != nil { log.Printf("RabbitMQ 消费启动失败: %v", err) return } for delivery := range deliveries { select { case c.send <- delivery.Body: // 消息体已是 JSON(含服务端设置的 Time) delivery.Ack(false) default: delivery.Nack(false, true) } } }