package ws

import (
	"context"
	"errors"
	"log"
	"sync"

	"git.jinshen.cn/remilia/push-server/interval/model"
	"github.com/coder/websocket"
)

// ErrHubClosed is returned by BroadcastMessage once the hub's Run loop has
// shut down and can no longer accept messages.
var ErrHubClosed = errors.New("ws: hub is closed")

// Hub is the central message distribution hub.
//
// The index maps are not guarded by a mutex: they are only touched by the
// on*/shutdown handlers, which Run invokes from its own goroutine. Other
// goroutines must go through the channel-based methods (RegisterClient,
// UnregisterClient, Subscribe, Unsubscribe, BroadcastMessage).
type Hub struct {
	// ClientID -> *Client
	clientsByID map[string]*Client

	// topic -> clientID -> *Client
	//
	// Invariant:
	//   - clientsByTopic contains only topics with at least one active subscriber.
	//   - A topic key must not exist if it has zero clients.
	clientsByTopic map[model.Topic]map[string]*Client

	// clientID -> topic -> struct{}
	topicsByClients map[string]map[model.Topic]struct{}

	register    chan *Client
	unregister  chan *Client
	subscribe   chan model.Subscription
	unsubscribe chan model.Subscription
	broadcast   chan model.Message

	// done is closed when Run shuts down, so that callers blocked on the
	// channels above return instead of waiting forever for a receiver.
	done chan struct{}
}

// NewHub returns a Hub with empty indexes. Call Run to start processing.
func NewHub() *Hub {
	return &Hub{
		clientsByID:     make(map[string]*Client),
		clientsByTopic:  make(map[model.Topic]map[string]*Client),
		topicsByClients: make(map[string]map[model.Topic]struct{}),
		register:        make(chan *Client),
		unregister:      make(chan *Client),
		subscribe:       make(chan model.Subscription, 8),
		unsubscribe:     make(chan model.Subscription, 8),
		broadcast:       make(chan model.Message, 64),
		done:            make(chan struct{}),
	}
}

// RegisterClient hands client to the Run loop, blocking until it is received.
// If the hub has already shut down, it returns without registering.
func (h *Hub) RegisterClient(client *Client) {
	select {
	case h.register <- client:
	case <-h.done:
	}
}

// UnregisterClient hands client to the Run loop for removal, blocking until it
// is received. If the hub has already shut down, it returns immediately.
func (h *Hub) UnregisterClient(client *Client) {
	select {
	case h.unregister <- client:
	case <-h.done:
	}
}

// Subscribe queues sub for the Run loop. It blocks while the subscription
// buffer is full, and returns without effect once the hub has shut down.
func (h *Hub) Subscribe(sub model.Subscription) {
	select {
	case h.subscribe <- sub:
	case <-h.done:
	}
}

// Unsubscribe queues sub for removal by the Run loop. It blocks while the
// buffer is full, and returns without effect once the hub has shut down.
func (h *Hub) Unsubscribe(sub model.Subscription) {
	select {
	case h.unsubscribe <- sub:
	case <-h.done:
	}
}

// BroadcastMessage queues msg for delivery to every subscriber of msg.Topic.
// It returns nil once msg is queued, ctx.Err() if ctx is done first, or
// ErrHubClosed if the hub has shut down. Messages still queued when the hub
// shuts down are not delivered.
func (h *Hub) BroadcastMessage(ctx context.Context, msg model.Message) error {
	select {
	case h.broadcast <- msg:
		return nil
	case <-ctx.Done():
		return ctx.Err()
	case <-h.done:
		return ErrHubClosed
	}
}

// Run processes registrations, subscriptions and broadcasts until ctx is
// done, then shuts the hub down. It must be called at most once: a second
// shutdown would panic when closing done.
func (h *Hub) Run(ctx context.Context) {
	log.Println("Hub is running")
	for {
		select {
		case c := <-h.register:
			h.onRegister(c)
		case c := <-h.unregister:
			h.onUnregister(c)
		case s := <-h.subscribe:
			h.onSubscribe(s)
		case s := <-h.unsubscribe:
			h.onUnsubscribe(s)
		case msg := <-h.broadcast:
			h.onBroadcast(msg)
		case <-ctx.Done():
			h.shutdown()
			return
		}
	}
}

// getClient returns the client registered under id, if any.
func (h *Hub) getClient(id string) (*Client, bool) {
	c, ok := h.clientsByID[id]
	return c, ok
}

// onRegister records c under its ID with an empty subscription set.
//
// If the ID is already registered (a reconnect reusing the ID, or the same
// client registering twice), the previous registration and all of its topic
// subscriptions are removed first. Otherwise clientsByTopic would keep
// entries that topicsByClients no longer tracks, and those entries would
// never be cleaned up.
func (h *Hub) onRegister(c *Client) {
	if old, ok := h.clientsByID[c.ID]; ok {
		h.onUnregister(old)
	}
	h.clientsByID[c.ID] = c
	h.topicsByClients[c.ID] = make(map[model.Topic]struct{})
	log.Printf("Current clientList: %v\n", h.clientsByID)
}

// onUnregister removes c and all of its topic subscriptions, deleting any
// topic left with no subscribers.
//
// The call is ignored when c is not the client currently registered under
// c.ID. This happens when c was already removed, or when a newer connection
// has taken over the ID. In the latter case, a late unregister from the old
// connection must not tear down the new one.
func (h *Hub) onUnregister(c *Client) {
	if cur, ok := h.clientsByID[c.ID]; !ok || cur != c {
		return
	}
	for t := range h.topicsByClients[c.ID] {
		if clients, ok := h.clientsByTopic[t]; ok {
			delete(clients, c.ID)
			if len(clients) == 0 {
				delete(h.clientsByTopic, t)
			}
		}
	}
	delete(h.topicsByClients, c.ID)
	delete(h.clientsByID, c.ID)
}

// onSubscribe adds the subscription s. Subscriptions for unknown clients are
// logged and dropped.
func (h *Hub) onSubscribe(s model.Subscription) {
	c, ok := h.getClient(s.ClientID)
	if !ok {
		// If the client does not exist, log an error and return.
		log.Printf("Subscribe failed: client %s not found", s.ClientID)
		return
	}
	if h.clientsByTopic[s.Topic] == nil {
		h.clientsByTopic[s.Topic] = make(map[string]*Client)
	}
	h.clientsByTopic[s.Topic][s.ClientID] = c
	// Present for every registered client (created in onRegister).
	h.topicsByClients[s.ClientID][s.Topic] = struct{}{}
}

// onUnsubscribe removes the subscription s if present, deleting the topic
// when its last subscriber leaves.
func (h *Hub) onUnsubscribe(s model.Subscription) {
	if clients, ok := h.clientsByTopic[s.Topic]; ok {
		delete(clients, s.ClientID)
		if len(clients) == 0 {
			delete(h.clientsByTopic, s.Topic)
		}
	}
	if topics, ok := h.topicsByClients[s.ClientID]; ok {
		delete(topics, s.Topic)
	}
}

// onBroadcast delivers msg.Content to every subscriber of msg.Topic without
// blocking.
//
// A subscriber whose SendChan cannot accept the message immediately is
// treated as a slow consumer: it is unregistered and its connection is
// closed with StatusPolicyViolation.
func (h *Hub) onBroadcast(msg model.Message) {
	if !msg.Topic.Valid() {
		log.Printf("Broadcast failed: invalid topic")
		return
	}
	log.Printf("Receiving message for topic %s: %s", msg.Topic, string(msg.Content))
	// Deleting map entries during range is permitted by the Go spec, so
	// onUnregister may safely prune the map being iterated.
	for _, c := range h.clientsByTopic[msg.Topic] {
		select {
		case c.SendChan <- msg.Content:
		default:
			// Must call onUnregister directly. This code runs on the Run
			// goroutine, the only receiver of h.unregister, so sending on
			// that unbuffered channel via UnregisterClient would deadlock
			// the hub.
			h.onUnregister(c)
			if conn := c.Conn; conn != nil {
				// Close performs a network close handshake; run it off the
				// hub goroutine so one unresponsive peer cannot stall
				// delivery to everyone else.
				go func() {
					_ = conn.Close(websocket.StatusPolicyViolation, "Slow consumer")
				}()
			}
		}
	}
}

// shutdown marks the hub closed, closes every registered client's SendChan,
// and closes their connections concurrently. It returns after all Close
// calls have returned.
func (h *Hub) shutdown() {
	// Close done first so that goroutines reacting to the connection closes
	// below (e.g. by calling UnregisterClient) return instead of blocking.
	close(h.done)

	var wg sync.WaitGroup
	for _, c := range h.clientsByID {
		close(c.SendChan)
		conn := c.Conn
		if conn == nil {
			continue
		}
		wg.Add(1)
		go func() {
			defer wg.Done()
			_ = conn.Close(websocket.StatusNormalClosure, "Server shutdown")
		}()
	}
	wg.Wait()
}