Files
ChatRoom/internal/ws/connection.go
2026-02-04 13:06:11 +08:00

272 lines
6.7 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package ws
import (
"ChatRoom/internal/models"
"ChatRoom/internal/rabbitmq"
"ChatRoom/internal/redis"
"encoding/json"
"log"
"net/http"
"time"
"github.com/gorilla/websocket"
)
// Timing and size limits applied to every WebSocket connection in this file.
const (
writeWait = 10 * time.Second // deadline set before each write to the peer (see writePump)
pongWait = 60 * time.Second // read deadline; renewed each time a pong arrives (see readPump)
pingPeriod = (pongWait * 9) / 10 // interval between pings; kept below pongWait so a pong can arrive before the read deadline lapses
maxMessageSize = 512 // limit passed to SetReadLimit for incoming messages
)
// upgrader turns incoming HTTP requests into WebSocket connections.
var upgrader = websocket.Upgrader{
ReadBufferSize: 1024, // read buffer size
WriteBufferSize: 1024, // write buffer size
CheckOrigin: func(r *http.Request) bool {
return true // accepts every origin (meant for development). NOTE(review): restrict before production to avoid cross-site WebSocket hijacking.
},
}
// Connection holds the state of one client's WebSocket session.
type Connection struct {
wsConn *websocket.Conn // underlying WebSocket connection
rmqClient *rabbitmq.Client // RabbitMQ client used to publish and consume messages
redisClient *redis.Client // Redis client used for the online-user sorted set
queueName string // name of the queue returned by DeclareQueue for this connection
userID string // user ID taken from the "user" query parameter
send chan []byte // outbound messages, drained by writePump
}
// NewConnection upgrades an HTTP request to a WebSocket session for the user
// named by the "user" query parameter and starts the goroutines that serve it.
//
// Steps, in order:
//  1. Upgrade the request; on failure the error is logged and nothing else runs.
//  2. Read and validate the user ID; a missing or unsafe ID closes the socket
//     with application close code 4001.
//  3. Record the user in the Redis online set.
//  4. Declare a RabbitMQ queue and bind the base routing keys to it.
//  5. Queue a "login succeeded" system message, broadcast the online count and
//     start writePump, readPump and consumeFromRabbitMQ.
//
// Nothing is returned: every failure is logged and, once the upgrade has
// succeeded, reported to the client through a close frame.
func NewConnection(w http.ResponseWriter, r *http.Request, rmq *rabbitmq.Client, redisClient *redis.Client) {
	// 1. Upgrade HTTP to WebSocket.
	wsConn, err := upgrader.Upgrade(w, r, nil)
	if err != nil {
		log.Printf("WebSocket 升级失败: %v", err)
		return
	}

	// 2. Take the user ID from the URL.
	userID := r.URL.Query().Get("user")
	if userID == "" {
		rejectConnection(wsConn, 4001, "user required")
		return
	}
	// The ID becomes part of the binding key "chat.user.<id>". Assuming the
	// exchange is a topic exchange (the "chat.event.*" binding below suggests
	// so), an ID with a "*" or "#" word would subscribe this client to other
	// users' private messages, so such IDs are refused.
	if hasTopicWildcard(userID) {
		rejectConnection(wsConn, 4001, "invalid user")
		return
	}

	// 3. Record the user as online in Redis.
	if err := redisClient.AddUserToZSet(redis.UserZSet, userID, time.Now().Unix()); err != nil {
		log.Printf("Redis 用户记录失败: %v", err)
		// 1011 (internal server error) replaces the former code 5000, which
		// lies outside the 1000-4999 range RFC 6455 allows in a close frame.
		rejectConnection(wsConn, websocket.CloseInternalServerErr, "redis error")
		return
	}

	// 4. Create this connection's RabbitMQ queue.
	queueName, err := rmq.DeclareQueue()
	if err != nil {
		log.Printf("RabbitMQ 队列创建失败: %v", err)
		// Roll back step 3 so the user is not left marked online without a
		// live session. Best effort: the result is not inspected, matching
		// the cleanup in readPump.
		redisClient.RemoveUserFromZSet(redis.UserZSet, userID)
		rejectConnection(wsConn, websocket.CloseInternalServerErr, "queue error")
		return
	}

	// Bind the base routing keys (global / private / events / system).
	bindings := []string{
		"chat.global",         // broadcast to everyone
		"chat.user." + userID, // private messages for this user
		"chat.event.*",        // online/offline events
		"chat.system",         // system notices
	}
	for _, rk := range bindings {
		if err := rmq.BindQueue(queueName, rk); err != nil {
			// A failed binding is only logged; the session continues with
			// whatever bindings succeeded.
			log.Printf("绑定失败 [%s]: %v", rk, err)
		}
	}

	// 5. Build the Connection object.
	conn := &Connection{
		wsConn:      wsConn,
		rmqClient:   rmq,
		redisClient: redisClient,
		queueName:   queueName,
		userID:      userID,
		send:        make(chan []byte, 256),
	}

	// Queue the "login succeeded" message. The channel is freshly created
	// with capacity 256, so this send cannot block.
	systemMsg := &models.Message{
		Type:    models.MsgTypeSystem,
		User:    "system",
		Content: "登录成功",
		Time:    time.Now().UTC().Format(time.RFC3339), // timestamp is always set by the server
	}
	data, err := json.Marshal(systemMsg)
	switch {
	case err != nil:
		log.Printf("登录消息序列化失败: %v", err)
	case len(data) > 0:
		conn.send <- data
	}

	// Tell everyone the new online count, then start serving.
	conn.broadcastUserCount()
	go conn.writePump()
	go conn.readPump()
	go conn.consumeFromRabbitMQ()
}

// rejectConnection sends a close frame carrying code and reason, then closes
// the socket. The frame is best effort: the peer may already be gone, and the
// socket is closed either way.
func rejectConnection(wsConn *websocket.Conn, code int, reason string) {
	_ = wsConn.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(code, reason))
	wsConn.Close()
}

// hasTopicWildcard reports whether any dot-separated word of key is exactly
// "*" or "#", the two wildcard words of an AMQP topic binding key.
func hasTopicWildcard(key string) bool {
	start := 0
	for i := 0; i <= len(key); i++ {
		if i < len(key) && key[i] != '.' {
			continue
		}
		if word := key[start:i]; word == "*" || word == "#" {
			return true
		}
		start = i + 1
	}
	return false
}
// readPump reads client messages until the socket fails or is closed, and
// publishes every valid message to RabbitMQ.
//
// For each incoming frame it decodes the JSON, requires a non-empty content
// field, overwrites User and Time with server-side values, derives a routing
// key via getRoutingKey and publishes the re-encoded message. Invalid input
// and publish failures are reported back to the sender with sendError and do
// not end the loop.
//
// On exit it removes the user from the Redis online set, broadcasts the new
// online count and closes the socket.
//
// NOTE(review): nothing here signals writePump or consumeFromRabbitMQ to stop.
// writePump presumably ends once a later write fails on the closed socket;
// no visible code stops the consumer — confirm how it is shut down.
func (c *Connection) readPump() {
	defer func() {
		c.redisClient.RemoveUserFromZSet(redis.UserZSet, c.userID)
		c.broadcastUserCount()
		c.wsConn.Close()
	}()

	c.wsConn.SetReadLimit(maxMessageSize)
	c.wsConn.SetReadDeadline(time.Now().Add(pongWait))
	// Every pong pushes the read deadline forward, so a responsive peer is
	// never timed out.
	c.wsConn.SetPongHandler(func(string) error {
		c.wsConn.SetReadDeadline(time.Now().Add(pongWait))
		return nil
	})

	for {
		_, rawMsg, err := c.wsConn.ReadMessage()
		if err != nil {
			// Only close errors outside the listed codes are logged.
			if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseAbnormalClosure, websocket.CloseNoStatusReceived) {
				log.Printf("WebSocket 读取错误: %v", err)
			}
			break
		}

		var msg models.Message
		if err := json.Unmarshal(rawMsg, &msg); err != nil {
			c.sendError("无效消息格式")
			continue
		}

		// content is required; user is overwritten below so a client cannot
		// impersonate someone else.
		if msg.Content == "" {
			c.sendError("缺少 content 字段")
			continue
		}
		msg.User = c.userID
		msg.Time = time.Now().UTC().Format(time.RFC3339) // one clock for all messages: the server's

		// Derive the routing key; an empty key means the type is unsupported
		// or a required field is missing.
		routingKey := c.getRoutingKey(&msg)
		if routingKey == "" {
			c.sendError("不支持的消息类型或缺少必要字段")
			continue
		}

		// Encode and publish. An encoding failure is handled explicitly so an
		// empty body is never published to other clients.
		body, err := json.Marshal(msg)
		if err != nil {
			log.Printf("消息序列化失败: %v", err)
			c.sendError("消息发送失败")
			continue
		}
		if err := c.rmqClient.Publish(c.rmqClient.ExchangeName, routingKey, body); err != nil {
			log.Printf("RabbitMQ 发布失败 [%s]: %v", routingKey, err)
			c.sendError("消息发送失败")
			continue
		}
	}
}
// getRoutingKey maps a message to the routing key it should be published
// with, based on its type:
//
//   - broadcast: "chat.global"
//   - room:      "chat.room.<Room>", when Room is set
//   - private:   "chat.user.<To>", when To is set
//
// It returns "" for any other type, or when the field the type needs is empty.
func (c *Connection) getRoutingKey(msg *models.Message) string {
	switch {
	case msg.Type == models.MsgTypeBroadcast:
		return "chat.global"
	case msg.Type == models.MsgTypeRoom && msg.Room != "":
		return "chat.room." + msg.Room
	case msg.Type == models.MsgTypePrivate && msg.To != "":
		return "chat.user." + msg.To
	}
	return ""
}
// sendError queues an error message, authored as "system", for this client
// only.
//
// Delivery is best effort: if the message cannot be encoded, or the outbound
// buffer is full, the notice is dropped without blocking the caller.
func (c *Connection) sendError(content string) {
	payload, err := json.Marshal(&models.Message{
		Type:    models.MsgTypeError,
		User:    "system",
		Content: content,
		Time:    time.Now().UTC().Format(time.RFC3339), // timestamp is always set by the server
	})
	if err != nil || len(payload) == 0 {
		return
	}

	select {
	case c.send <- payload:
	default:
		// Buffer full: drop the notice rather than stall the caller.
	}
}
// writePump is the writer for the socket within this file's goroutines: it
// forwards everything queued on c.send to the client as text messages and
// sends a ping every pingPeriod.
//
// Each write gets a fresh deadline of writeWait. The loop ends on the first
// write error, or when c.send is closed (a close frame is attempted first).
// On exit the ticker is stopped and the socket closed.
func (c *Connection) writePump() {
	ticker := time.NewTicker(pingPeriod)
	defer func() {
		ticker.Stop()
		c.wsConn.Close()
	}()

	for {
		select {
		case message, ok := <-c.send:
			c.wsConn.SetWriteDeadline(time.Now().Add(writeWait))
			if !ok {
				// Channel closed: say goodbye, best effort, and stop.
				c.wsConn.WriteMessage(websocket.CloseMessage, []byte{})
				return
			}
			// WriteMessage obtains a writer, writes the payload and closes
			// the writer, returning the first error, so a failed write is
			// no longer ignored.
			if err := c.wsConn.WriteMessage(websocket.TextMessage, message); err != nil {
				return
			}
		case <-ticker.C:
			c.wsConn.SetWriteDeadline(time.Now().Add(writeWait))
			if err := c.wsConn.WriteMessage(websocket.PingMessage, nil); err != nil {
				return
			}
		}
	}
}
// consumeFromRabbitMQ forwards deliveries from this connection's queue to the
// outbound channel.
//
// Each delivery body is offered to c.send without blocking. If it is accepted
// the delivery is acked; if the buffer is full it is nacked with arguments
// (false, true), which presumably asks the broker to requeue it — confirm
// against the delivery type. The loop ends only when the delivery stream ends.
//
// NOTE(review): nothing visible stops this loop when the WebSocket closes.
// Once c.send is full and no longer drained, every delivery is nacked, which
// may cause a redelivery spin — confirm how the consumer is cancelled.
func (c *Connection) consumeFromRabbitMQ() {
	stream, err := c.rmqClient.Consume(c.queueName)
	if err != nil {
		log.Printf("RabbitMQ 消费启动失败: %v", err)
		return
	}

	for msg := range stream {
		// The body is already JSON and carries the server-set Time.
		queued := false
		select {
		case c.send <- msg.Body:
			queued = true
		default:
		}

		if queued {
			msg.Ack(false)
		} else {
			msg.Nack(false, true)
		}
	}
}
// broadcastUserCount publishes the current number of online users to every
// client.
//
// The count is read from the Redis online-user set and published as a
// user-count message, authored as "system", on the "chat.global" routing key.
// Any failure (reading the count, encoding, publishing) is logged and the
// broadcast is skipped; nothing is returned to the caller.
func (c *Connection) broadcastUserCount() {
	count, err := c.redisClient.CountUsers(redis.UserZSet)
	if err != nil {
		log.Printf("获取在线人数失败: %v", err)
		return
	}

	msg := &models.Message{
		Type:  models.MsgTypeUserCount,
		User:  "system",
		Count: count,
		Time:  time.Now().UTC().Format(time.RFC3339),
	}
	// An encoding failure is handled explicitly so an empty body is never
	// broadcast to all clients.
	body, err := json.Marshal(msg)
	if err != nil {
		log.Printf("在线人数消息序列化失败: %v", err)
		return
	}
	if err := c.rmqClient.Publish(c.rmqClient.ExchangeName, "chat.global", body); err != nil {
		log.Printf("广播在线人数失败: %v", err)
	}
}