186 lines
4.1 KiB
Go
186 lines
4.1 KiB
Go
package hub
|
|
|
|
import (
|
|
"context"
|
|
"log"
|
|
|
|
"git.jinshen.cn/remilia/push-server/interval/model"
|
|
"github.com/coder/websocket"
|
|
)
|
|
|
|
// Client represents a connected client in the hub.
type Client struct {
	// ID is the key under which the hub indexes this client (clientsByID,
	// topicsByClients, and the inner maps of clientsByTopic). Subscriptions
	// refer to a client through this value (Subscription.ClientID).
	ID string

	// Conn is the client's WebSocket connection. The hub closes it when the
	// client is dropped as a slow consumer and on shutdown; a nil Conn is
	// tolerated (the hub checks before closing).
	Conn *websocket.Conn

	// SendChan receives message payloads (Message.Content) from the hub.
	// The hub only performs non-blocking sends on it: if a send cannot
	// proceed immediately, the client is dropped as a slow consumer.
	// The hub closes SendChan on shutdown.
	// NOTE(review): presumably drained by a per-client writer goroutine that
	// forwards to Conn — confirm against the connection handler.
	SendChan chan []byte
}
|
|
|
|
// Hub is the central message distribution hub.
|
|
type Hub struct {
|
|
// ClientID -> *Client
|
|
clientsByID map[string]*Client
|
|
// topic -> clientID -> *Client
|
|
//
|
|
// Invariant:
|
|
// - clientsByTopic contains only topics with at least one active subscriber.
|
|
// - A topic key must not exist if it has zero clients.
|
|
clientsByTopic map[model.Topic]map[string]*Client
|
|
// clientID -> topic -> struct{}
|
|
topicsByClients map[string]map[model.Topic]struct{}
|
|
|
|
register chan *Client
|
|
unregister chan *Client
|
|
subscribe chan model.Subscription
|
|
unsubscribe chan model.Subscription
|
|
broadcast chan model.Message
|
|
}
|
|
|
|
func NewHub() *Hub {
|
|
return &Hub{
|
|
clientsByID: make(map[string]*Client),
|
|
clientsByTopic: make(map[model.Topic]map[string]*Client),
|
|
topicsByClients: make(map[string]map[model.Topic]struct{}),
|
|
|
|
register: make(chan *Client),
|
|
unregister: make(chan *Client),
|
|
subscribe: make(chan model.Subscription),
|
|
unsubscribe: make(chan model.Subscription),
|
|
broadcast: make(chan model.Message),
|
|
}
|
|
}
|
|
|
|
func (h *Hub) RegisterClient(client *Client) {
|
|
h.register <- client
|
|
}
|
|
|
|
func (h *Hub) UnregisterClient(client *Client) {
|
|
h.unregister <- client
|
|
}
|
|
|
|
func (h *Hub) Subscribe(sub model.Subscription) {
|
|
h.subscribe <- sub
|
|
}
|
|
|
|
func (h *Hub) Unsubscribe(sub model.Subscription) {
|
|
h.unsubscribe <- sub
|
|
}
|
|
|
|
func (h *Hub) BroadcastMessage(ctx context.Context, msg model.Message) error {
|
|
select {
|
|
case h.broadcast <- msg:
|
|
return nil
|
|
case <-ctx.Done():
|
|
return ctx.Err()
|
|
}
|
|
|
|
}
|
|
|
|
func (h *Hub) Run(ctx context.Context) {
|
|
log.Println("Hub is running")
|
|
for {
|
|
select {
|
|
case c := <-h.register:
|
|
h.onRegister(c)
|
|
|
|
case c := <-h.unregister:
|
|
h.onUnregister(c)
|
|
|
|
case c := <-h.subscribe:
|
|
h.onSubscribe(c)
|
|
|
|
case s := <-h.unsubscribe:
|
|
h.onUnsubscribe(s)
|
|
|
|
case msg := <-h.broadcast:
|
|
h.onBroadcast(msg)
|
|
|
|
case <-ctx.Done():
|
|
h.shutdown()
|
|
return
|
|
}
|
|
}
|
|
}
|
|
|
|
// Get a client by its ClientID
|
|
func (h *Hub) getClient(id string) (*Client, bool) {
|
|
c, ok := h.clientsByID[id]
|
|
return c, ok
|
|
}
|
|
|
|
// Create a new entry for the client in topicsByClients map when it registers.
|
|
func (h *Hub) onRegister(c *Client) {
|
|
h.clientsByID[c.ID] = c
|
|
h.topicsByClients[c.ID] = make(map[model.Topic]struct{})
|
|
}
|
|
|
|
// Delete all topic subscriptions for the client when it unregisters.
|
|
func (h *Hub) onUnregister(c *Client) {
|
|
topics := h.topicsByClients[c.ID]
|
|
for t := range topics {
|
|
if clients, ok := h.clientsByTopic[t]; ok {
|
|
delete(clients, c.ID)
|
|
if len(clients) == 0 {
|
|
delete(h.clientsByTopic, t)
|
|
}
|
|
}
|
|
|
|
}
|
|
delete(h.topicsByClients, c.ID)
|
|
delete(h.clientsByID, c.ID)
|
|
}
|
|
|
|
func (h *Hub) onSubscribe(s model.Subscription) {
|
|
c, ok := h.getClient(s.ClientID)
|
|
if !ok {
|
|
// If the client does not exist, log an error and return.
|
|
log.Printf("Subscribe failed: client %s not found", s.ClientID)
|
|
return
|
|
}
|
|
if h.clientsByTopic[s.Topic] == nil {
|
|
h.clientsByTopic[s.Topic] = make(map[string]*Client)
|
|
}
|
|
|
|
h.clientsByTopic[s.Topic][s.ClientID] = c
|
|
h.topicsByClients[s.ClientID][s.Topic] = struct{}{}
|
|
}
|
|
|
|
func (h *Hub) onUnsubscribe(s model.Subscription) {
|
|
if clients, ok := h.clientsByTopic[s.Topic]; ok {
|
|
delete(clients, s.ClientID)
|
|
if len(clients) == 0 {
|
|
delete(h.clientsByTopic, s.Topic)
|
|
}
|
|
}
|
|
|
|
if topics, ok := h.topicsByClients[s.ClientID]; ok {
|
|
delete(topics, s.Topic)
|
|
}
|
|
}
|
|
|
|
func (h *Hub) onBroadcast(msg model.Message) {
|
|
if !msg.Topic.Valid() {
|
|
log.Printf("Broadcast failed: invalid topic")
|
|
return
|
|
}
|
|
log.Printf("Receiving message for topic %s: %s", msg.Topic, string(msg.Content))
|
|
for _, c := range h.clientsByTopic[msg.Topic] {
|
|
select {
|
|
case c.SendChan <- msg.Content:
|
|
default:
|
|
h.UnregisterClient(c)
|
|
if c.Conn != nil {
|
|
_ = c.Conn.Close(websocket.StatusPolicyViolation, "Slow consumer")
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
func (h *Hub) shutdown() {
|
|
for _, c := range h.clientsByID {
|
|
close(c.SendChan)
|
|
if c.Conn != nil {
|
|
_ = c.Conn.Close(websocket.StatusNormalClosure, "Server shutdown")
|
|
}
|
|
}
|
|
}
|