Compare commits
11 Commits
8a1d56e45f
...
main
| Author | SHA1 | Date | |
|---|---|---|---|
| 36296b6af4 | |||
| 1dbcc03e46 | |||
| 53555a31c0 | |||
| 1bc9c6a924 | |||
| b824dc3792 | |||
| 18874711ea | |||
| a72a46838e | |||
| e7a769c1b7 | |||
| 9bac821750 | |||
| 736d4f550c | |||
| 6a386eb9a0 |
68
cmd/client/main.go
Normal file
68
cmd/client/main.go
Normal file
@ -0,0 +1,68 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"log"
|
||||
|
||||
"github.com/coder/websocket"
|
||||
"github.com/coder/websocket/wsjson"
|
||||
|
||||
"git.jinshen.cn/remilia/push-server/interval/protocol"
|
||||
)
|
||||
|
||||
func main() {
|
||||
log.Println("This is a test client.")
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
c, _, err := websocket.Dial(ctx, "ws://localhost:8080/ws", nil)
|
||||
if err != nil {
|
||||
log.Fatal("dial error:", err)
|
||||
return
|
||||
}
|
||||
defer func() { _ = c.CloseNow() }()
|
||||
|
||||
initMsg := protocol.ControlMessage{
|
||||
Type: protocol.MsgInit,
|
||||
Topics: []protocol.Topic{"news", "sports"},
|
||||
}
|
||||
|
||||
log.Println("Sending init message:", initMsg)
|
||||
|
||||
err = wsjson.Write(ctx, c, initMsg)
|
||||
if err != nil {
|
||||
if websocket.CloseStatus(err) != websocket.StatusNormalClosure {
|
||||
log.Printf("init write failed: %v", err)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// typ, msg, err := c.Read(ctx)
|
||||
// if err != nil {
|
||||
// log.Println("read error:", err)
|
||||
// return
|
||||
// }
|
||||
//
|
||||
// switch typ {
|
||||
// case websocket.MessageText:
|
||||
// log.Printf("Received text message: %s", string(msg))
|
||||
// case websocket.MessageBinary:
|
||||
// log.Printf("Received binary message: %v", msg)
|
||||
// }
|
||||
go ReadBroadCastLoop(ctx, c)
|
||||
<-ctx.Done()
|
||||
|
||||
_ = c.Close(websocket.StatusNormalClosure, "test client finished")
|
||||
}
|
||||
|
||||
func ReadBroadCastLoop(ctx context.Context, c *websocket.Conn) {
|
||||
for {
|
||||
var msg protocol.BroadcastMessage
|
||||
if err := wsjson.Read(ctx, c, &msg); err != nil {
|
||||
log.Println("Read broadcast error:", err)
|
||||
return
|
||||
}
|
||||
|
||||
log.Printf("Received broadcast message: [%s] %s", msg.Topic, msg.Payload)
|
||||
}
|
||||
}
|
||||
49
cmd/server/main.go
Normal file
49
cmd/server/main.go
Normal file
@ -0,0 +1,49 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"log"
|
||||
"net/http"
|
||||
"os"
|
||||
"os/signal"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
"git.jinshen.cn/remilia/push-server/interval/server/api"
|
||||
"git.jinshen.cn/remilia/push-server/interval/server/httpserver"
|
||||
"git.jinshen.cn/remilia/push-server/interval/server/ws"
|
||||
)
|
||||
|
||||
func main() {
|
||||
serverCtx, stop := signal.NotifyContext(
|
||||
context.Background(),
|
||||
os.Interrupt,
|
||||
syscall.SIGTERM,
|
||||
)
|
||||
defer func() {
|
||||
stop()
|
||||
}()
|
||||
|
||||
h := ws.NewHub()
|
||||
go h.Run(serverCtx)
|
||||
|
||||
httpServer := httpserver.NewHTTPServer(":8080", api.NewRouter(h, serverCtx))
|
||||
|
||||
go func() {
|
||||
log.Println("Starting HTTP server on :8080")
|
||||
if err := httpServer.Start(); err != nil && err != http.ErrServerClosed {
|
||||
log.Fatalf("HTTP server error: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
<-serverCtx.Done()
|
||||
|
||||
log.Println("Shutting down server...")
|
||||
|
||||
shutdownCtx, shutdownCancel := context.WithTimeout(context.Background(), time.Second*10)
|
||||
defer shutdownCancel()
|
||||
|
||||
if err := httpServer.Shutdown(shutdownCtx); err != nil {
|
||||
log.Printf("HTTP server shutdown error: %v", err)
|
||||
}
|
||||
}
|
||||
4
go.mod
4
go.mod
@ -1,3 +1,7 @@
|
||||
module git.jinshen.cn/remilia/push-server
|
||||
|
||||
go 1.25.5
|
||||
|
||||
require github.com/coder/websocket v1.8.14
|
||||
|
||||
require github.com/go-chi/chi/v5 v5.2.3
|
||||
|
||||
4
go.sum
Normal file
4
go.sum
Normal file
@ -0,0 +1,4 @@
|
||||
github.com/coder/websocket v1.8.14 h1:9L0p0iKiNOibykf283eHkKUHHrpG7f65OE3BhhO7v9g=
|
||||
github.com/coder/websocket v1.8.14/go.mod h1:NX3SzP+inril6yawo5CQXx8+fk145lPDC6pumgx0mVg=
|
||||
github.com/go-chi/chi/v5 v5.2.3 h1:WQIt9uxdsAbgIYgid+BpYc+liqQZGMHRaUwp0JUcvdE=
|
||||
github.com/go-chi/chi/v5 v5.2.3/go.mod h1:L2yAIGWB3H+phAw1NxKwWM+7eUH/lU8pOMm5hHcoops=
|
||||
40
interval/protocol/message.go
Normal file
40
interval/protocol/message.go
Normal file
@ -0,0 +1,40 @@
|
||||
package protocol
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
)
|
||||
|
||||
type ControlMessage struct {
|
||||
Type MessageType `json:"type"`
|
||||
Topic Topic `json:"topic,omitempty"`
|
||||
Topics []Topic `json:"topics,omitempty"`
|
||||
}
|
||||
|
||||
type BroadcastMessage struct {
|
||||
Type MessageType `json:"type"`
|
||||
Topic Topic `json:"topic"`
|
||||
Payload json.RawMessage `json:"payload"`
|
||||
Timestamp int64 `json:"timestamp"`
|
||||
}
|
||||
|
||||
func (m ControlMessage) Validate() error {
|
||||
switch m.Type {
|
||||
case MsgInit:
|
||||
if len(m.Topics) == 0 {
|
||||
return errors.New("init requires topics")
|
||||
}
|
||||
case MsgSubscribe:
|
||||
if m.Topic == "" {
|
||||
return errors.New("subscribe requires topic")
|
||||
}
|
||||
default:
|
||||
return errors.New("unknown message type")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
var (
|
||||
ErrInvalidMessage = errors.New("invalid message")
|
||||
ErrPolicyViolation = errors.New("policy violation")
|
||||
)
|
||||
6
interval/protocol/subscription.go
Normal file
6
interval/protocol/subscription.go
Normal file
@ -0,0 +1,6 @@
|
||||
package protocol
|
||||
|
||||
type Subscription struct {
|
||||
ClientID string
|
||||
Topic Topic
|
||||
}
|
||||
17
interval/protocol/types.go
Normal file
17
interval/protocol/types.go
Normal file
@ -0,0 +1,17 @@
|
||||
package protocol
|
||||
|
||||
type MessageType string
|
||||
|
||||
const (
|
||||
MsgInit MessageType = "init"
|
||||
MsgSubscribe MessageType = "subscribe"
|
||||
MsgUnsubscribe MessageType = "unsubscribe"
|
||||
MsgBroadcast MessageType = "broadcast"
|
||||
MsgError MessageType = "error"
|
||||
)
|
||||
|
||||
type Topic string
|
||||
|
||||
func (t Topic) Valid() bool {
|
||||
return t != ""
|
||||
}
|
||||
2
interval/server/api/doc.go
Normal file
2
interval/server/api/doc.go
Normal file
@ -0,0 +1,2 @@
|
||||
// Package api defines the HTTP control plane of the push service
|
||||
package api
|
||||
2
interval/server/api/handler/doc.go
Normal file
2
interval/server/api/handler/doc.go
Normal file
@ -0,0 +1,2 @@
|
||||
// Package handler contains HTTP request handlers for the REST API.
|
||||
package handler
|
||||
10
interval/server/api/handler/health.go
Normal file
10
interval/server/api/handler/health.go
Normal file
@ -0,0 +1,10 @@
|
||||
package handler
|
||||
|
||||
import "net/http"
|
||||
|
||||
func Health(w http.ResponseWriter, _ *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
if _, err := w.Write([]byte("OK")); err != nil {
|
||||
http.Error(w, "failed to write response", http.StatusInternalServerError)
|
||||
}
|
||||
}
|
||||
54
interval/server/api/handler/push.go
Normal file
54
interval/server/api/handler/push.go
Normal file
@ -0,0 +1,54 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"git.jinshen.cn/remilia/push-server/interval/protocol"
|
||||
"git.jinshen.cn/remilia/push-server/interval/server/ws"
|
||||
"github.com/go-chi/chi/v5"
|
||||
)
|
||||
|
||||
type PublishRequest struct {
|
||||
Payload json.RawMessage `json:"payload"`
|
||||
}
|
||||
|
||||
func PushHandler(hub *ws.Hub) http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
topicStr := chi.URLParam(r, "topic")
|
||||
topic := protocol.Topic(topicStr)
|
||||
if !topic.Valid() {
|
||||
http.Error(w, "invalid topic", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
var req PublishRequest
|
||||
dec := json.NewDecoder(r.Body)
|
||||
dec.DisallowUnknownFields()
|
||||
if err := dec.Decode(&req); err != nil {
|
||||
http.Error(w, "invalid request body", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
if len(req.Payload) == 0 {
|
||||
http.Error(w, "content cannot be empty", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
msg := protocol.BroadcastMessage{
|
||||
Type: protocol.MsgBroadcast,
|
||||
Topic: topic,
|
||||
Payload: req.Payload,
|
||||
Timestamp: time.Now().Unix(),
|
||||
}
|
||||
|
||||
if err := hub.BroadcastMessage(r.Context(), msg); err != nil {
|
||||
http.Error(w, "request cancelled", http.StatusRequestTimeout)
|
||||
return
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusAccepted)
|
||||
}
|
||||
}
|
||||
21
interval/server/api/router.go
Normal file
21
interval/server/api/router.go
Normal file
@ -0,0 +1,21 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
|
||||
"git.jinshen.cn/remilia/push-server/interval/server/api/handler"
|
||||
"git.jinshen.cn/remilia/push-server/interval/server/ws"
|
||||
"github.com/go-chi/chi/v5"
|
||||
)
|
||||
|
||||
func NewRouter(h *ws.Hub, ctx context.Context) http.Handler {
|
||||
r := chi.NewRouter()
|
||||
|
||||
r.Get("/ws", ws.Handler(ctx, h))
|
||||
|
||||
r.Post("/health", handler.Health)
|
||||
r.Post("/push/{topic}", handler.PushHandler(h))
|
||||
|
||||
return r
|
||||
}
|
||||
2
interval/server/httpserver/doc.go
Normal file
2
interval/server/httpserver/doc.go
Normal file
@ -0,0 +1,2 @@
|
||||
// Package httpserver provides HTTP server abstractions.
|
||||
package httpserver
|
||||
27
interval/server/httpserver/http.go
Normal file
27
interval/server/httpserver/http.go
Normal file
@ -0,0 +1,27 @@
|
||||
package httpserver
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
type HTTPServer struct {
|
||||
server *http.Server
|
||||
}
|
||||
|
||||
func NewHTTPServer(addr string, handler http.Handler) *HTTPServer {
|
||||
return &HTTPServer{
|
||||
server: &http.Server{
|
||||
Addr: addr,
|
||||
Handler: handler,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func (s *HTTPServer) Start() error {
|
||||
return s.server.ListenAndServe()
|
||||
}
|
||||
|
||||
func (s *HTTPServer) Shutdown(ctx context.Context) error {
|
||||
return s.server.Shutdown(ctx)
|
||||
}
|
||||
124
interval/server/ws/client.go
Normal file
124
interval/server/ws/client.go
Normal file
@ -0,0 +1,124 @@
|
||||
package ws
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"log"
|
||||
|
||||
"git.jinshen.cn/remilia/push-server/interval/protocol"
|
||||
"github.com/coder/websocket"
|
||||
"github.com/coder/websocket/wsjson"
|
||||
)
|
||||
|
||||
// Client represents a connected client in the hub.
|
||||
type Client struct {
|
||||
ID string
|
||||
Conn *websocket.Conn
|
||||
SendChan chan protocol.BroadcastMessage
|
||||
Hub *Hub
|
||||
|
||||
Ctx context.Context
|
||||
Cancel context.CancelFunc
|
||||
|
||||
inited bool
|
||||
}
|
||||
|
||||
func NewClient(id string, conn *websocket.Conn, hub *Hub, parentCtx context.Context) *Client {
|
||||
ctx, cancel := context.WithCancel(parentCtx)
|
||||
|
||||
return &Client{
|
||||
ID: id,
|
||||
Conn: conn,
|
||||
SendChan: make(chan protocol.BroadcastMessage, 32),
|
||||
Hub: hub,
|
||||
|
||||
Ctx: ctx,
|
||||
Cancel: cancel,
|
||||
|
||||
inited: false,
|
||||
}
|
||||
}
|
||||
func (c *Client) ReadLoop() {
|
||||
defer c.Close()
|
||||
|
||||
for {
|
||||
var msg protocol.ControlMessage
|
||||
|
||||
err := wsjson.Read(c.Ctx, c.Conn, &msg)
|
||||
if err != nil {
|
||||
if websocket.CloseStatus(err) == websocket.StatusNormalClosure {
|
||||
log.Println("WebSocket closed normally:", err)
|
||||
} else {
|
||||
log.Println("[Server Client.ReadLoop] WebSocket read error:", err)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
if err := msg.Validate(); err != nil {
|
||||
_ = c.Conn.Close(websocket.StatusPolicyViolation, "invalid message")
|
||||
return
|
||||
}
|
||||
|
||||
if err := c.handleControlMessage(msg); err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Client) WriteLoop() {
|
||||
defer c.Close()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-c.Ctx.Done():
|
||||
return
|
||||
case msg, ok := <-c.SendChan:
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
log.Printf("Sending message to client %s: %+v", c.ID, msg)
|
||||
err := wsjson.Write(c.Ctx, c.Conn, msg)
|
||||
if err != nil {
|
||||
log.Println("[Server Client.WriteLoop] WebSocket write error:", err)
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Client) handleControlMessage(msg protocol.ControlMessage) error {
|
||||
switch msg.Type {
|
||||
case protocol.MsgInit:
|
||||
if c.inited {
|
||||
return errors.New("already initialized")
|
||||
}
|
||||
|
||||
log.Printf("Client %s initializing with topics: %v", c.ID, msg.Topics)
|
||||
c.Hub.RegisterClient(c)
|
||||
|
||||
for _, t := range msg.Topics {
|
||||
c.Hub.Subscribe(protocol.Subscription{
|
||||
ClientID: c.ID,
|
||||
Topic: protocol.Topic(t),
|
||||
})
|
||||
}
|
||||
|
||||
c.inited = true
|
||||
|
||||
return nil
|
||||
default:
|
||||
if !c.inited {
|
||||
return errors.New("client not initialized")
|
||||
}
|
||||
log.Println("Unknown control message type:", msg.Type)
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Client) Close() {
|
||||
c.Cancel()
|
||||
c.Hub.UnregisterClient(c)
|
||||
if c.Conn != nil {
|
||||
_ = c.Conn.Close(websocket.StatusNormalClosure, "client closed")
|
||||
}
|
||||
}
|
||||
2
interval/server/ws/doc.go
Normal file
2
interval/server/ws/doc.go
Normal file
@ -0,0 +1,2 @@
|
||||
// Package ws implements the Websocket handler for the push service.
|
||||
package ws
|
||||
32
interval/server/ws/handler.go
Normal file
32
interval/server/ws/handler.go
Normal file
@ -0,0 +1,32 @@
|
||||
package ws
|
||||
|
||||
import (
|
||||
"context"
|
||||
"log"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/coder/websocket"
|
||||
)
|
||||
|
||||
func Handler(ctx context.Context, h *Hub) http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
conn, err := websocket.Accept(w, r, &websocket.AcceptOptions{
|
||||
OriginPatterns: []string{"*"},
|
||||
})
|
||||
|
||||
log.Println("New WebSocket connection from", r.RemoteAddr, "at", time.Now().Format(time.RFC3339))
|
||||
|
||||
if err != nil {
|
||||
log.Println("WebSocket accept error:", err)
|
||||
return
|
||||
}
|
||||
|
||||
c := NewClient(r.RemoteAddr, conn, h, ctx)
|
||||
log.Println("Client", r.RemoteAddr, "connected.")
|
||||
|
||||
go c.ReadLoop()
|
||||
go c.WriteLoop()
|
||||
go heartbeat(c)
|
||||
}
|
||||
}
|
||||
35
interval/server/ws/heartbeat.go
Normal file
35
interval/server/ws/heartbeat.go
Normal file
@ -0,0 +1,35 @@
|
||||
package ws
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/coder/websocket"
|
||||
"log"
|
||||
"time"
|
||||
)
|
||||
|
||||
func heartbeat(c *Client) {
|
||||
ticker := time.NewTicker(30 * time.Second)
|
||||
defer ticker.Stop()
|
||||
|
||||
defer func() {
|
||||
c.Cancel()
|
||||
_ = c.Conn.Close(websocket.StatusNormalClosure, "heartbeat stopped")
|
||||
}()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-c.Ctx.Done():
|
||||
return
|
||||
case <-ticker.C:
|
||||
pingCtx, pingCancel := context.WithTimeout(c.Ctx, 5*time.Second)
|
||||
err := c.Conn.Ping(pingCtx)
|
||||
pingCancel()
|
||||
|
||||
if err != nil {
|
||||
log.Println("Ping filed: ", err)
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
180
interval/server/ws/hub.go
Normal file
180
interval/server/ws/hub.go
Normal file
@ -0,0 +1,180 @@
|
||||
package ws
|
||||
|
||||
import (
|
||||
"context"
|
||||
"log"
|
||||
|
||||
"git.jinshen.cn/remilia/push-server/interval/protocol"
|
||||
"github.com/coder/websocket"
|
||||
)
|
||||
|
||||
// Hub is the central message distribution hub.
|
||||
type Hub struct {
|
||||
// ClientID -> *Client
|
||||
clientsByID map[string]*Client
|
||||
// topic -> clientID -> *Client
|
||||
//
|
||||
// Invariant:
|
||||
// - clientsByTopic contains only topics with at least one active subscriber.
|
||||
// - A topic key must not exist if it has zero clients.
|
||||
clientsByTopic map[protocol.Topic]map[string]*Client
|
||||
// clientID -> topic -> struct{}
|
||||
topicsByClients map[string]map[protocol.Topic]struct{}
|
||||
|
||||
register chan *Client
|
||||
unregister chan *Client
|
||||
subscribe chan protocol.Subscription
|
||||
unsubscribe chan protocol.Subscription
|
||||
broadcast chan protocol.BroadcastMessage
|
||||
}
|
||||
|
||||
func NewHub() *Hub {
|
||||
return &Hub{
|
||||
clientsByID: make(map[string]*Client),
|
||||
clientsByTopic: make(map[protocol.Topic]map[string]*Client),
|
||||
topicsByClients: make(map[string]map[protocol.Topic]struct{}),
|
||||
|
||||
register: make(chan *Client),
|
||||
unregister: make(chan *Client),
|
||||
subscribe: make(chan protocol.Subscription, 8),
|
||||
unsubscribe: make(chan protocol.Subscription, 8),
|
||||
broadcast: make(chan protocol.BroadcastMessage, 64),
|
||||
}
|
||||
}
|
||||
|
||||
func (h *Hub) RegisterClient(client *Client) {
|
||||
h.register <- client
|
||||
}
|
||||
|
||||
func (h *Hub) UnregisterClient(client *Client) {
|
||||
h.unregister <- client
|
||||
}
|
||||
|
||||
func (h *Hub) Subscribe(sub protocol.Subscription) {
|
||||
h.subscribe <- sub
|
||||
}
|
||||
|
||||
func (h *Hub) Unsubscribe(sub protocol.Subscription) {
|
||||
h.unsubscribe <- sub
|
||||
}
|
||||
|
||||
func (h *Hub) BroadcastMessage(ctx context.Context, msg protocol.BroadcastMessage) error {
|
||||
select {
|
||||
case h.broadcast <- msg:
|
||||
return nil
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func (h *Hub) Run(ctx context.Context) {
|
||||
log.Println("Hub is running")
|
||||
for {
|
||||
select {
|
||||
case c := <-h.register:
|
||||
h.onRegister(c)
|
||||
|
||||
case c := <-h.unregister:
|
||||
h.onUnregister(c)
|
||||
|
||||
case c := <-h.subscribe:
|
||||
h.onSubscribe(c)
|
||||
|
||||
case s := <-h.unsubscribe:
|
||||
h.onUnsubscribe(s)
|
||||
|
||||
case msg := <-h.broadcast:
|
||||
h.onBroadcast(msg)
|
||||
|
||||
case <-ctx.Done():
|
||||
h.shutdown()
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Get a client by its ClientID
|
||||
func (h *Hub) getClient(id string) (*Client, bool) {
|
||||
c, ok := h.clientsByID[id]
|
||||
return c, ok
|
||||
}
|
||||
|
||||
// Create a new entry for the client in topicsByClients map when it registers.
|
||||
func (h *Hub) onRegister(c *Client) {
|
||||
h.clientsByID[c.ID] = c
|
||||
h.topicsByClients[c.ID] = make(map[protocol.Topic]struct{})
|
||||
log.Printf("Current clientList: %v\n", h.clientsByID)
|
||||
}
|
||||
|
||||
// Delete all topic subscriptions for the client when it unregisters.
|
||||
func (h *Hub) onUnregister(c *Client) {
|
||||
topics := h.topicsByClients[c.ID]
|
||||
for t := range topics {
|
||||
if clients, ok := h.clientsByTopic[t]; ok {
|
||||
delete(clients, c.ID)
|
||||
if len(clients) == 0 {
|
||||
delete(h.clientsByTopic, t)
|
||||
}
|
||||
}
|
||||
}
|
||||
delete(h.topicsByClients, c.ID)
|
||||
delete(h.clientsByID, c.ID)
|
||||
}
|
||||
|
||||
func (h *Hub) onSubscribe(s protocol.Subscription) {
|
||||
c, ok := h.getClient(s.ClientID)
|
||||
if !ok {
|
||||
// If the client does not exist, log an error and return.
|
||||
log.Printf("Subscribe failed: client %s not found", s.ClientID)
|
||||
return
|
||||
}
|
||||
if h.clientsByTopic[s.Topic] == nil {
|
||||
h.clientsByTopic[s.Topic] = make(map[string]*Client)
|
||||
}
|
||||
|
||||
h.clientsByTopic[s.Topic][s.ClientID] = c
|
||||
h.topicsByClients[s.ClientID][s.Topic] = struct{}{}
|
||||
log.Printf("Client %s subscribed to topic %s", s.ClientID, s.Topic)
|
||||
}
|
||||
|
||||
func (h *Hub) onUnsubscribe(s protocol.Subscription) {
|
||||
if clients, ok := h.clientsByTopic[s.Topic]; ok {
|
||||
delete(clients, s.ClientID)
|
||||
if len(clients) == 0 {
|
||||
delete(h.clientsByTopic, s.Topic)
|
||||
}
|
||||
}
|
||||
|
||||
if topics, ok := h.topicsByClients[s.ClientID]; ok {
|
||||
delete(topics, s.Topic)
|
||||
}
|
||||
}
|
||||
|
||||
func (h *Hub) onBroadcast(msg protocol.BroadcastMessage) {
|
||||
if !msg.Topic.Valid() {
|
||||
log.Printf("Broadcast failed: invalid topic")
|
||||
return
|
||||
}
|
||||
log.Printf("Receiving message for topic %s: %s", msg.Topic, string(msg.Payload))
|
||||
for _, c := range h.clientsByTopic[msg.Topic] {
|
||||
select {
|
||||
case c.SendChan <- msg:
|
||||
log.Printf("[%d] Sending message to client %s: [%s] %v", msg.Timestamp, c.ID, msg.Type, string(msg.Payload))
|
||||
default:
|
||||
h.UnregisterClient(c)
|
||||
if c.Conn != nil {
|
||||
_ = c.Conn.Close(websocket.StatusPolicyViolation, "Slow consumer")
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (h *Hub) shutdown() {
|
||||
for _, c := range h.clientsByID {
|
||||
close(c.SendChan)
|
||||
if c.Conn != nil {
|
||||
_ = c.Conn.Close(websocket.StatusNormalClosure, "Server shutdown")
|
||||
}
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user