feat: 基本广播服务

- 由Hub接收/push/{topic}的请求并解析信息体广播到对应的Client
This commit is contained in:
2025-12-17 14:49:59 +08:00
parent 53555a31c0
commit 1dbcc03e46
16 changed files with 245 additions and 191 deletions

View File

@ -2,10 +2,12 @@ package ws
import (
"context"
"errors"
"log"
"time"
"git.jinshen.cn/remilia/push-server/interval/protocol"
"github.com/coder/websocket"
"github.com/coder/websocket/wsjson"
)
// Client represents a connected client in the hub.
@ -13,28 +15,58 @@ type Client struct {
ID string
Conn *websocket.Conn
SendChan chan []byte
Hub *Hub
Ctx context.Context
Cancel context.CancelFunc
inited bool
}
func NewClient(id string, conn *websocket.Conn, parentCtx context.Context) *Client {
func NewClient(id string, conn *websocket.Conn, hub *Hub, parentCtx context.Context) *Client {
ctx, cancel := context.WithCancel(parentCtx)
return &Client{
ID: id,
Conn: conn,
SendChan: make(chan []byte, 32),
Ctx: ctx,
Cancel: cancel,
Hub: hub,
Ctx: ctx,
Cancel: cancel,
inited: false,
}
}
func (c *Client) ReadLoop() {
defer c.Close()
for {
var msg protocol.ControlMessage
err := wsjson.Read(c.Ctx, c.Conn, &msg)
if err != nil {
if websocket.CloseStatus(err) == websocket.StatusNormalClosure {
log.Println("WebSocket closed normally:", err)
} else {
log.Println("WebSocket read error:", err)
}
return
}
if err := msg.Validate(); err != nil {
_ = c.Conn.Close(websocket.StatusPolicyViolation, "invalid message")
return
}
if err := c.handleControlMessage(msg); err != nil {
return
}
}
}
// Write message to websocket connection.
func (c *Client) WriteMessage() {
defer func() {
_ = c.Conn.Close(websocket.StatusNormalClosure, "client writer closed")
}()
func (c *Client) WriteLoop() {
defer c.Close()
for {
select {
@ -44,11 +76,7 @@ func (c *Client) WriteMessage() {
if !ok {
return
}
writeCtx, cancel := context.WithTimeout(c.Ctx, 5*time.Second)
err := c.Conn.Write(writeCtx, websocket.MessageText, msg)
cancel()
err := wsjson.Write(c.Ctx, c.Conn, msg)
if err != nil {
log.Println("WebSocket write error:", err)
return
@ -56,3 +84,40 @@ func (c *Client) WriteMessage() {
}
}
}
func (c *Client) handleControlMessage(msg protocol.ControlMessage) error {
switch msg.Type {
case protocol.MsgInit:
if c.inited {
return errors.New("already initialized")
}
log.Printf("Client %s initializing with topics: %v", c.ID, msg.Topics)
c.Hub.RegisterClient(c)
for _, t := range msg.Topics {
c.Hub.Subscribe(protocol.Subscription{
ClientID: c.ID,
Topic: protocol.Topic(t),
})
}
c.inited = true
return nil
default:
if !c.inited {
return errors.New("client not initialized")
}
log.Println("Unknown control message type:", msg.Type)
return nil
}
}
func (c *Client) Close() {
c.Cancel()
c.Hub.UnregisterClient(c)
if c.Conn != nil {
_ = c.Conn.Close(websocket.StatusNormalClosure, "client closed")
}
}

View File

@ -2,7 +2,6 @@ package ws
import (
"context"
"io"
"log"
"net/http"
"time"
@ -23,53 +22,11 @@ func Handler(ctx context.Context, h *Hub) http.HandlerFunc {
return
}
c := NewClient(r.RemoteAddr, conn, ctx)
c := NewClient(r.RemoteAddr, conn, h, ctx)
log.Println("Client", r.RemoteAddr, "connected.")
h.RegisterClient(c)
go echo(c, h)
go c.ReadLoop()
go c.WriteLoop()
go heartbeat(c)
}
}
func echo(c *Client, h *Hub) {
defer func() {
if c.Conn != nil {
log.Println("Closing WebSocket connection")
h.UnregisterClient(c)
c.Cancel()
_ = c.Conn.Close(websocket.StatusNormalClosure, "echo finished")
}
}()
for {
typ, r, err := c.Conn.Reader(c.Ctx)
if err != nil {
if websocket.CloseStatus(err) == websocket.StatusNormalClosure {
log.Println("WebSocket connection closed normally")
} else {
log.Println("WebSocket reader error:", err)
}
return
}
w, err := c.Conn.Writer(c.Ctx, typ)
if err != nil {
log.Println("WebSocket writer error:", err)
return
}
_, err = io.Copy(w, r)
if err != nil {
log.Println("WebSocket copy error:", err)
return
}
if err = w.Close(); err != nil {
log.Println("WebSocket writer close error:", err)
return
}
}
}

View File

@ -4,7 +4,7 @@ import (
"context"
"log"
"git.jinshen.cn/remilia/push-server/interval/server/model"
"git.jinshen.cn/remilia/push-server/interval/protocol"
"github.com/coder/websocket"
)
@ -17,28 +17,28 @@ type Hub struct {
// Invariant:
// - clientsByTopic contains only topics with at least one active subscriber.
// - A topic key must not exist if it has zero clients.
clientsByTopic map[model.Topic]map[string]*Client
clientsByTopic map[protocol.Topic]map[string]*Client
// clientID -> topic -> struct{}
topicsByClients map[string]map[model.Topic]struct{}
topicsByClients map[string]map[protocol.Topic]struct{}
register chan *Client
unregister chan *Client
subscribe chan model.Subscription
unsubscribe chan model.Subscription
broadcast chan model.Message
subscribe chan protocol.Subscription
unsubscribe chan protocol.Subscription
broadcast chan protocol.BroadcastMessage
}
func NewHub() *Hub {
return &Hub{
clientsByID: make(map[string]*Client),
clientsByTopic: make(map[model.Topic]map[string]*Client),
topicsByClients: make(map[string]map[model.Topic]struct{}),
clientsByTopic: make(map[protocol.Topic]map[string]*Client),
topicsByClients: make(map[string]map[protocol.Topic]struct{}),
register: make(chan *Client),
unregister: make(chan *Client),
subscribe: make(chan model.Subscription, 8),
unsubscribe: make(chan model.Subscription, 8),
broadcast: make(chan model.Message, 64),
subscribe: make(chan protocol.Subscription, 8),
unsubscribe: make(chan protocol.Subscription, 8),
broadcast: make(chan protocol.BroadcastMessage, 64),
}
}
@ -50,15 +50,15 @@ func (h *Hub) UnregisterClient(client *Client) {
h.unregister <- client
}
func (h *Hub) Subscribe(sub model.Subscription) {
func (h *Hub) Subscribe(sub protocol.Subscription) {
h.subscribe <- sub
}
func (h *Hub) Unsubscribe(sub model.Subscription) {
func (h *Hub) Unsubscribe(sub protocol.Subscription) {
h.unsubscribe <- sub
}
func (h *Hub) BroadcastMessage(ctx context.Context, msg model.Message) error {
func (h *Hub) BroadcastMessage(ctx context.Context, msg protocol.BroadcastMessage) error {
select {
case h.broadcast <- msg:
return nil
@ -103,7 +103,7 @@ func (h *Hub) getClient(id string) (*Client, bool) {
// Create a new entry for the client in topicsByClients map when it registers.
func (h *Hub) onRegister(c *Client) {
h.clientsByID[c.ID] = c
h.topicsByClients[c.ID] = make(map[model.Topic]struct{})
h.topicsByClients[c.ID] = make(map[protocol.Topic]struct{})
log.Printf("Current clientList: %v\n", h.clientsByID)
}
@ -122,7 +122,7 @@ func (h *Hub) onUnregister(c *Client) {
delete(h.clientsByID, c.ID)
}
func (h *Hub) onSubscribe(s model.Subscription) {
func (h *Hub) onSubscribe(s protocol.Subscription) {
c, ok := h.getClient(s.ClientID)
if !ok {
// If the client does not exist, log an error and return.
@ -135,9 +135,10 @@ func (h *Hub) onSubscribe(s model.Subscription) {
h.clientsByTopic[s.Topic][s.ClientID] = c
h.topicsByClients[s.ClientID][s.Topic] = struct{}{}
log.Printf("Client %s subscribed to topic %s", s.ClientID, s.Topic)
}
func (h *Hub) onUnsubscribe(s model.Subscription) {
func (h *Hub) onUnsubscribe(s protocol.Subscription) {
if clients, ok := h.clientsByTopic[s.Topic]; ok {
delete(clients, s.ClientID)
if len(clients) == 0 {
@ -150,15 +151,16 @@ func (h *Hub) onUnsubscribe(s model.Subscription) {
}
}
func (h *Hub) onBroadcast(msg model.Message) {
func (h *Hub) onBroadcast(msg protocol.BroadcastMessage) {
if !msg.Topic.Valid() {
log.Printf("Broadcast failed: invalid topic")
return
}
log.Printf("Receiving message for topic %s: %s", msg.Topic, string(msg.Content))
log.Printf("Receiving message for topic %s: %s", msg.Topic, string(msg.Payload))
for _, c := range h.clientsByTopic[msg.Topic] {
select {
case c.SendChan <- msg.Content:
case c.SendChan <- msg.Payload:
log.Printf("Sending message to client %s: %s", c.ID, string(msg.Payload))
default:
h.UnregisterClient(c)
if c.Conn != nil {