From 1dbcc03e46e992bde0e7fb5247413f077174cd96 Mon Sep 17 00:00:00 2001 From: R2m1liA <15258427350@163.com> Date: Wed, 17 Dec 2025 14:49:59 +0800 Subject: [PATCH] =?UTF-8?q?feat:=20=E5=9F=BA=E6=9C=AC=E5=B9=BF=E6=92=AD?= =?UTF-8?q?=E6=9C=8D=E5=8A=A1?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - 由Hub接收/push/{topic}的请求并解析信息体广播到对应的Client --- cmd/client/main.go | 69 ++++++++++++++ cmd/test-client/main.go | 43 --------- interval/protocol/message.go | 39 ++++++++ .../model => protocol}/subscription.go | 4 +- interval/protocol/types.go | 17 ++++ interval/server/api/dto/doc.go | 2 - interval/server/api/dto/message.go | 17 ---- interval/server/api/dto/publish.go | 5 - interval/server/api/dto/subscription.go | 17 ---- interval/server/api/handler/push.go | 23 +++-- interval/server/model/doc.go | 2 - interval/server/model/message.go | 7 -- interval/server/model/topic.go | 7 -- interval/server/ws/client.go | 93 ++++++++++++++++--- interval/server/ws/handler.go | 49 +--------- interval/server/ws/hub.go | 42 +++++---- 16 files changed, 245 insertions(+), 191 deletions(-) create mode 100644 cmd/client/main.go delete mode 100644 cmd/test-client/main.go create mode 100644 interval/protocol/message.go rename interval/{server/model => protocol}/subscription.go (78%) create mode 100644 interval/protocol/types.go delete mode 100644 interval/server/api/dto/doc.go delete mode 100644 interval/server/api/dto/message.go delete mode 100644 interval/server/api/dto/publish.go delete mode 100644 interval/server/api/dto/subscription.go delete mode 100644 interval/server/model/doc.go delete mode 100644 interval/server/model/message.go delete mode 100644 interval/server/model/topic.go diff --git a/cmd/client/main.go b/cmd/client/main.go new file mode 100644 index 0000000..b31602a --- /dev/null +++ b/cmd/client/main.go @@ -0,0 +1,69 @@ +package main + +import ( + "context" + "log" + + "github.com/coder/websocket" + "github.com/coder/websocket/wsjson" + + "git.jinshen.cn/remilia/push-server/interval/protocol" +) + +func main() { + log.Println("This is a test client.") + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + c, _, err := websocket.Dial(ctx, "ws://localhost:8080/ws", nil) + if err != nil { + log.Fatal("dial error:", err) + return + } + defer func() { _ = c.CloseNow() }() + + initMsg := protocol.ControlMessage{ + Type: protocol.MsgInit, + Topics: []protocol.Topic{"news", "sports"}, + } + + log.Println("Sending init message:", initMsg) + + err = wsjson.Write(ctx, c, initMsg) + if err != nil { + if websocket.CloseStatus(err) != websocket.StatusNormalClosure { + log.Printf("init write failed: %v", err) + } + return + } + + // typ, msg, err := c.Read(ctx) + // if err != nil { + // log.Println("read error:", err) + // return + // } + // + // switch typ { + // case websocket.MessageText: + // log.Printf("Received text message: %s", string(msg)) + // case websocket.MessageBinary: + // log.Printf("Received binary message: %v", msg) + // } + go ReadBroadCastLoop(ctx, c) + <-ctx.Done() + + _ = c.Close(websocket.StatusNormalClosure, "test client finished") +} + +func ReadBroadCastLoop(ctx context.Context, c *websocket.Conn) { + for { + // var msg protocol.BroadcastMessage + var msg []byte + if err := wsjson.Read(ctx, c, &msg); err != nil { + log.Println("read broadcast error:", err) + return + } + + log.Println("Received broadcast message:", string(msg)) + } +} diff --git a/cmd/test-client/main.go b/cmd/test-client/main.go deleted file mode 100644 index 44fa85b..0000000 --- a/cmd/test-client/main.go +++ /dev/null @@ -1,43 +0,0 @@ -package main - -import ( - "context" - "log" - - "github.com/coder/websocket" - "github.com/coder/websocket/wsjson" -) - -func main() { - log.Println("This is a test client.") - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - - c, _, err := websocket.Dial(ctx, "ws://localhost:8080/ws", nil) - if err != nil { - log.Fatal("dial error:", err) - return - } - defer c.CloseNow() - - err = wsjson.Write(ctx, c, "Hello, WebSocket server!") - if err != nil { - log.Fatal("write error:", err) - return - } - - typ, msg, err := c.Read(ctx) - if err != nil { - log.Println("read error:", err) - return - } - - switch typ { - case websocket.MessageText: - log.Printf("Received text message: %s", string(msg)) - case websocket.MessageBinary: - log.Printf("Received binary message: %v", msg) - } - - c.Close(websocket.StatusNormalClosure, "test client finished") -} diff --git a/interval/protocol/message.go b/interval/protocol/message.go new file mode 100644 index 0000000..61d0b6f --- /dev/null +++ b/interval/protocol/message.go @@ -0,0 +1,39 @@ +package protocol + +import ( + "encoding/json" + "errors" +) + +type ControlMessage struct { + Type MessageType `json:"type"` + Topic Topic `json:"topic,omitempty"` + Topics []Topic `json:"topics,omitempty"` +} + +type BroadcastMessage struct { + Type MessageType `json:"type"` + Topic Topic `json:"topic"` + Payload json.RawMessage `json:"payload"` +} + +func (m ControlMessage) Validate() error { + switch m.Type { + case MsgInit: + if len(m.Topics) == 0 { + return errors.New("init requires topics") + } + case MsgSubscribe: + if m.Topic == "" { + return errors.New("subscribe requires topic") + } + default: + return errors.New("unknown message type") + } + return nil +} + +var ( + ErrInvalidMessage = errors.New("invalid message") + ErrPolicyViolation = errors.New("policy violation") +) diff --git a/interval/server/model/subscription.go b/interval/protocol/subscription.go similarity index 78% rename from interval/server/model/subscription.go rename to interval/protocol/subscription.go index 41ad00b..620cf81 100644 --- a/interval/server/model/subscription.go +++ b/interval/protocol/subscription.go @@ -1,6 +1,6 @@ -package model +package protocol type Subscription struct { - Topic Topic ClientID string + Topic Topic } diff --git a/interval/protocol/types.go b/interval/protocol/types.go new file mode 100644 index 0000000..993a6c9 --- /dev/null +++ b/interval/protocol/types.go @@ -0,0 +1,17 @@ +package protocol + +type MessageType string + +const ( + MsgInit MessageType = "init" + MsgSubscribe MessageType = "subscribe" + MsgUnsubscribe MessageType = "unsubscribe" + MsgBroadcast MessageType = "broadcast" + MsgError MessageType = "error" +) + +type Topic string + +func (t Topic) Valid() bool { + return t != "" +} diff --git a/interval/server/api/dto/doc.go b/interval/server/api/dto/doc.go deleted file mode 100644 index d9ac152..0000000 --- a/interval/server/api/dto/doc.go +++ /dev/null @@ -1,2 +0,0 @@ -// Package dto contains data transfer objects used in the interval API. -package dto diff --git a/interval/server/api/dto/message.go b/interval/server/api/dto/message.go deleted file mode 100644 index dfdd9a7..0000000 --- a/interval/server/api/dto/message.go +++ /dev/null @@ -1,17 +0,0 @@ -package dto - -import ( - "git.jinshen.cn/remilia/push-server/interval/server/model" -) - -type Message struct { - Topic string `json:"topic"` - Content string `json:"content"` -} - -func MessageFromModel(m model.Message) Message { - return Message{ - Topic: string(m.Topic), - Content: string(m.Content), - } -} diff --git a/interval/server/api/dto/publish.go b/interval/server/api/dto/publish.go deleted file mode 100644 index 54124cc..0000000 --- a/interval/server/api/dto/publish.go +++ /dev/null @@ -1,5 +0,0 @@ -package dto - -type PublishRequest struct { - Content string `json:"content"` -} diff --git a/interval/server/api/dto/subscription.go b/interval/server/api/dto/subscription.go deleted file mode 100644 index 4b314dd..0000000 --- a/interval/server/api/dto/subscription.go +++ /dev/null @@ -1,17 +0,0 @@ -package dto - -import ( - "git.jinshen.cn/remilia/push-server/interval/server/model" -) - -type Subscription struct { - Topic string `json:"topic"` - ClientID string `json:"client_id"` -} - -func SubscriptionFromModel(s model.Subscription) Subscription { - return Subscription{ - Topic: string(s.Topic), - ClientID: string(s.ClientID), - } -} diff --git a/interval/server/api/handler/push.go b/interval/server/api/handler/push.go index 3dbcfd7..63f11a0 100644 --- a/interval/server/api/handler/push.go +++ b/interval/server/api/handler/push.go @@ -2,25 +2,28 @@ package handler import ( "encoding/json" + "log" "net/http" - "time" - "git.jinshen.cn/remilia/push-server/interval/server/api/dto" - "git.jinshen.cn/remilia/push-server/interval/server/model" + "git.jinshen.cn/remilia/push-server/interval/protocol" "git.jinshen.cn/remilia/push-server/interval/server/ws" "github.com/go-chi/chi/v5" ) +type PublishRequest struct { + Content string `json:"content"` +} + func PushHandler(hub *ws.Hub) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { topicStr := chi.URLParam(r, "topic") - topic := model.Topic(topicStr) + topic := protocol.Topic(topicStr) if !topic.Valid() { http.Error(w, "invalid topic", http.StatusBadRequest) return } - var req dto.PublishRequest + var req PublishRequest if err := json.NewDecoder(r.Body).Decode(&req); err != nil { http.Error(w, "invalid request body", http.StatusBadRequest) return @@ -30,12 +33,14 @@ func PushHandler(hub *ws.Hub) http.HandlerFunc { return } - msg := model.Message{ - Topic: topic, - Content: []byte(req.Content), - Timestamp: time.Now().Unix(), + msg := protocol.BroadcastMessage{ + Type: protocol.MsgBroadcast, + Topic: topic, + Payload: json.RawMessage(req.Content), } + log.Printf("Received push request for topic %s: %s", topic, req.Content) + if err := hub.BroadcastMessage(r.Context(), msg); err != nil { http.Error(w, "request cancelled", http.StatusRequestTimeout) return diff --git a/interval/server/model/doc.go b/interval/server/model/doc.go deleted file mode 100644 index 0535f75..0000000 --- a/interval/server/model/doc.go +++ /dev/null @@ -1,2 +0,0 @@ -// Package model defines core domain models of the push service. -package model diff --git a/interval/server/model/message.go b/interval/server/model/message.go deleted file mode 100644 index b9ab29b..0000000 --- a/interval/server/model/message.go +++ /dev/null @@ -1,7 +0,0 @@ -package model - -type Message struct { - Topic Topic - Content []byte - Timestamp int64 -} diff --git a/interval/server/model/topic.go b/interval/server/model/topic.go deleted file mode 100644 index ab8e05e..0000000 --- a/interval/server/model/topic.go +++ /dev/null @@ -1,7 +0,0 @@ -package model - -type Topic string - -func (t Topic) Valid() bool { - return t != "" -} diff --git a/interval/server/ws/client.go b/interval/server/ws/client.go index 08ad94f..8280557 100644 --- a/interval/server/ws/client.go +++ b/interval/server/ws/client.go @@ -2,10 +2,12 @@ package ws import ( "context" + "errors" "log" - "time" + "git.jinshen.cn/remilia/push-server/interval/protocol" "github.com/coder/websocket" + "github.com/coder/websocket/wsjson" ) // Client represents a connected client in the hub. @@ -13,28 +15,58 @@ type Client struct { ID string Conn *websocket.Conn SendChan chan []byte + Hub *Hub Ctx context.Context Cancel context.CancelFunc + + inited bool } -func NewClient(id string, conn *websocket.Conn, parentCtx context.Context) *Client { +func NewClient(id string, conn *websocket.Conn, hub *Hub, parentCtx context.Context) *Client { ctx, cancel := context.WithCancel(parentCtx) return &Client{ ID: id, Conn: conn, SendChan: make(chan []byte, 32), - Ctx: ctx, - Cancel: cancel, + Hub: hub, + + Ctx: ctx, + Cancel: cancel, + + inited: false, + } +} +func (c *Client) ReadLoop() { + defer c.Close() + + for { + var msg protocol.ControlMessage + + err := wsjson.Read(c.Ctx, c.Conn, &msg) + if err != nil { + if websocket.CloseStatus(err) == websocket.StatusNormalClosure { + log.Println("WebSocket closed normally:", err) + } else { + log.Println("WebSocket read error:", err) + } + return + } + + if err := msg.Validate(); err != nil { + _ = c.Conn.Close(websocket.StatusPolicyViolation, "invalid message") + return + } + + if err := c.handleControlMessage(msg); err != nil { + return + } } } -// Write message to websocket connection. -func (c *Client) WriteMessage() { - defer func() { - _ = c.Conn.Close(websocket.StatusNormalClosure, "client writer closed") - }() +func (c *Client) WriteLoop() { + defer c.Close() for { select { @@ -44,11 +76,7 @@ func (c *Client) WriteMessage() { if !ok { return } - - writeCtx, cancel := context.WithTimeout(c.Ctx, 5*time.Second) - err := c.Conn.Write(writeCtx, websocket.MessageText, msg) - cancel() - + err := wsjson.Write(c.Ctx, c.Conn, msg) if err != nil { log.Println("WebSocket write error:", err) return @@ -56,3 +84,40 @@ func (c *Client) WriteMessage() { } } } + +func (c *Client) handleControlMessage(msg protocol.ControlMessage) error { + switch msg.Type { + case protocol.MsgInit: + if c.inited { + return errors.New("already initialized") + } + + log.Printf("Client %s initializing with topics: %v", c.ID, msg.Topics) + c.Hub.RegisterClient(c) + + for _, t := range msg.Topics { + c.Hub.Subscribe(protocol.Subscription{ + ClientID: c.ID, + Topic: protocol.Topic(t), + }) + } + + c.inited = true + + return nil + default: + if !c.inited { + return errors.New("client not initialized") + } + log.Println("Unknown control message type:", msg.Type) + return nil + } +} + +func (c *Client) Close() { + c.Cancel() + c.Hub.UnregisterClient(c) + if c.Conn != nil { + _ = c.Conn.Close(websocket.StatusNormalClosure, "client closed") + } +} diff --git a/interval/server/ws/handler.go b/interval/server/ws/handler.go index b82ed23..bdf8c3d 100644 --- a/interval/server/ws/handler.go +++ b/interval/server/ws/handler.go @@ -2,7 +2,6 @@ package ws import ( "context" - "io" "log" "net/http" "time" @@ -23,53 +22,11 @@ func Handler(ctx context.Context, h *Hub) http.HandlerFunc { return } - c := NewClient(r.RemoteAddr, conn, ctx) + c := NewClient(r.RemoteAddr, conn, h, ctx) log.Println("Client", r.RemoteAddr, "connected.") - h.RegisterClient(c) - go echo(c, h) + go c.ReadLoop() + go c.WriteLoop() go heartbeat(c) } } - -func echo(c *Client, h *Hub) { - defer func() { - if c.Conn != nil { - log.Println("Closing WebSocket connection") - h.UnregisterClient(c) - c.Cancel() - _ = c.Conn.Close(websocket.StatusNormalClosure, "echo finished") - } - }() - - for { - typ, r, err := c.Conn.Reader(c.Ctx) - - if err != nil { - if websocket.CloseStatus(err) == websocket.StatusNormalClosure { - log.Println("WebSocket connection closed normally") - } else { - log.Println("WebSocket reader error:", err) - } - - return - } - - w, err := c.Conn.Writer(c.Ctx, typ) - if err != nil { - log.Println("WebSocket writer error:", err) - return - } - - _, err = io.Copy(w, r) - if err != nil { - log.Println("WebSocket copy error:", err) - return - } - - if err = w.Close(); err != nil { - log.Println("WebSocket writer close error:", err) - return - } - } -} diff --git a/interval/server/ws/hub.go b/interval/server/ws/hub.go index d91033e..257678d 100644 --- a/interval/server/ws/hub.go +++ b/interval/server/ws/hub.go @@ -4,7 +4,7 @@ import ( "context" "log" - "git.jinshen.cn/remilia/push-server/interval/server/model" + "git.jinshen.cn/remilia/push-server/interval/protocol" "github.com/coder/websocket" ) @@ -17,28 +17,28 @@ type Hub struct { // Invariant: // - clientsByTopic contains only topics with at least one active subscriber. // - A topic key must not exist if it has zero clients. - clientsByTopic map[model.Topic]map[string]*Client + clientsByTopic map[protocol.Topic]map[string]*Client // clientID -> topic -> struct{} - topicsByClients map[string]map[model.Topic]struct{} + topicsByClients map[string]map[protocol.Topic]struct{} register chan *Client unregister chan *Client - subscribe chan model.Subscription - unsubscribe chan model.Subscription - broadcast chan model.Message + subscribe chan protocol.Subscription + unsubscribe chan protocol.Subscription + broadcast chan protocol.BroadcastMessage } func NewHub() *Hub { return &Hub{ clientsByID: make(map[string]*Client), - clientsByTopic: make(map[model.Topic]map[string]*Client), - topicsByClients: make(map[string]map[model.Topic]struct{}), + clientsByTopic: make(map[protocol.Topic]map[string]*Client), + topicsByClients: make(map[string]map[protocol.Topic]struct{}), register: make(chan *Client), unregister: make(chan *Client), - subscribe: make(chan model.Subscription, 8), - unsubscribe: make(chan model.Subscription, 8), - broadcast: make(chan model.Message, 64), + subscribe: make(chan protocol.Subscription, 8), + unsubscribe: make(chan protocol.Subscription, 8), + broadcast: make(chan protocol.BroadcastMessage, 64), } } @@ -50,15 +50,15 @@ func (h *Hub) UnregisterClient(client *Client) { h.unregister <- client } -func (h *Hub) Subscribe(sub model.Subscription) { +func (h *Hub) Subscribe(sub protocol.Subscription) { h.subscribe <- sub } -func (h *Hub) Unsubscribe(sub model.Subscription) { +func (h *Hub) Unsubscribe(sub protocol.Subscription) { h.unsubscribe <- sub } -func (h *Hub) BroadcastMessage(ctx context.Context, msg model.Message) error { +func (h *Hub) BroadcastMessage(ctx context.Context, msg protocol.BroadcastMessage) error { select { case h.broadcast <- msg: return nil @@ -103,7 +103,7 @@ func (h *Hub) getClient(id string) (*Client, bool) { // Create a new entry for the client in topicsByClients map when it registers. func (h *Hub) onRegister(c *Client) { h.clientsByID[c.ID] = c - h.topicsByClients[c.ID] = make(map[model.Topic]struct{}) + h.topicsByClients[c.ID] = make(map[protocol.Topic]struct{}) log.Printf("Current clientList: %v\n", h.clientsByID) } @@ -122,7 +122,7 @@ func (h *Hub) onUnregister(c *Client) { delete(h.clientsByID, c.ID) } -func (h *Hub) onSubscribe(s model.Subscription) { +func (h *Hub) onSubscribe(s protocol.Subscription) { c, ok := h.getClient(s.ClientID) if !ok { // If the client does not exist, log an error and return. @@ -135,9 +135,10 @@ func (h *Hub) onSubscribe(s model.Subscription) { h.clientsByTopic[s.Topic][s.ClientID] = c h.topicsByClients[s.ClientID][s.Topic] = struct{}{} + log.Printf("Client %s subscribed to topic %s", s.ClientID, s.Topic) } -func (h *Hub) onUnsubscribe(s model.Subscription) { +func (h *Hub) onUnsubscribe(s protocol.Subscription) { if clients, ok := h.clientsByTopic[s.Topic]; ok { delete(clients, s.ClientID) if len(clients) == 0 { @@ -150,15 +151,16 @@ func (h *Hub) onUnsubscribe(s model.Subscription) { } } -func (h *Hub) onBroadcast(msg model.Message) { +func (h *Hub) onBroadcast(msg protocol.BroadcastMessage) { if !msg.Topic.Valid() { log.Printf("Broadcast failed: invalid topic") return } - log.Printf("Receiving message for topic %s: %s", msg.Topic, string(msg.Content)) + log.Printf("Receiving message for topic %s: %s", msg.Topic, string(msg.Payload)) for _, c := range h.clientsByTopic[msg.Topic] { select { - case c.SendChan <- msg.Content: + case c.SendChan <- msg.Payload: + log.Printf("Sending message to client %s: %s", c.ID, string(msg.Payload)) default: h.UnregisterClient(c) if c.Conn != nil {