diff --git a/cmd/server/main.go b/cmd/server/main.go index 3682192..bb88cae 100644 --- a/cmd/server/main.go +++ b/cmd/server/main.go @@ -10,20 +10,24 @@ import ( "time" "git.jinshen.cn/remilia/push-server/interval/api" - "git.jinshen.cn/remilia/push-server/interval/hub" "git.jinshen.cn/remilia/push-server/interval/server" + "git.jinshen.cn/remilia/push-server/interval/ws" ) func main() { - serverCtx, serverCancel := context.WithCancel(context.Background()) + serverCtx, stop := signal.NotifyContext( + context.Background(), + os.Interrupt, + syscall.SIGTERM, + ) defer func() { - serverCancel() + stop() }() - h := hub.NewHub() + h := ws.NewHub() go h.Run(serverCtx) - httpServer := server.NewHTTPServer(":8080", api.NewRouter(h)) + httpServer := server.NewHTTPServer(":8080", api.NewRouter(h, serverCtx)) go func() { log.Println("Starting HTTP server on :8080") @@ -32,16 +36,14 @@ func main() { } }() - sig := make(chan os.Signal, 1) - signal.Notify(sig, syscall.SIGINT, syscall.SIGTERM) - <-sig + <-serverCtx.Done() log.Println("Shutting down server...") - serverCancel() - shutdownCtx, shutdownCancel := context.WithTimeout(context.Background(), time.Second*10) defer shutdownCancel() - httpServer.Shutdown(shutdownCtx) + if err := httpServer.Shutdown(shutdownCtx); err != nil { + log.Printf("HTTP server shutdown error: %v", err) + } } diff --git a/cmd/test-client/main.go b/cmd/test-client/main.go new file mode 100644 index 0000000..44fa85b --- /dev/null +++ b/cmd/test-client/main.go @@ -0,0 +1,43 @@ +package main + +import ( + "context" + "log" + + "github.com/coder/websocket" + "github.com/coder/websocket/wsjson" +) + +func main() { + log.Println("This is a test client.") + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + c, _, err := websocket.Dial(ctx, "ws://localhost:8080/ws", nil) + if err != nil { + log.Fatal("dial error:", err) + return + } + defer c.CloseNow() + + err = wsjson.Write(ctx, c, "Hello, WebSocket server!") + if err != nil { + log.Fatal("write error:", err) + return + } + + typ, msg, err := c.Read(ctx) + if err != nil { + log.Println("read error:", err) + return + } + + switch typ { + case websocket.MessageText: + log.Printf("Received text message: %s", string(msg)) + case websocket.MessageBinary: + log.Printf("Received binary message: %v", msg) + } + + c.Close(websocket.StatusNormalClosure, "test client finished") +} diff --git a/go.mod b/go.mod index 1eb274b..6bcbb74 100644 --- a/go.mod +++ b/go.mod @@ -4,4 +4,4 @@ go 1.25.5 require github.com/coder/websocket v1.8.14 -require github.com/go-chi/chi/v5 v5.2.3 // indirect +require github.com/go-chi/chi/v5 v5.2.3 diff --git a/interval/api/handler/health.go b/interval/api/handler/health.go index 1460723..4ab39d8 100644 --- a/interval/api/handler/health.go +++ b/interval/api/handler/health.go @@ -4,5 +4,7 @@ import "net/http" func Health(w http.ResponseWriter, _ *http.Request) { w.WriteHeader(http.StatusOK) - w.Write([]byte("OK")) + if _, err := w.Write([]byte("OK")); err != nil { + http.Error(w, "failed to write response", http.StatusInternalServerError) + } } diff --git a/interval/api/handler/push.go b/interval/api/handler/push.go index 5140847..2289b9a 100644 --- a/interval/api/handler/push.go +++ b/interval/api/handler/push.go @@ -6,12 +6,12 @@ import ( "time" "git.jinshen.cn/remilia/push-server/interval/api/dto" - "git.jinshen.cn/remilia/push-server/interval/hub" "git.jinshen.cn/remilia/push-server/interval/model" + "git.jinshen.cn/remilia/push-server/interval/ws" "github.com/go-chi/chi/v5" ) -func PushHandler(hub *hub.Hub) http.HandlerFunc { +func PushHandler(hub *ws.Hub) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { topicStr := chi.URLParam(r, "topic") topic := model.Topic(topicStr) diff --git a/interval/api/router.go b/interval/api/router.go index 5f200ea..1047a5d 100644 --- a/interval/api/router.go +++ b/interval/api/router.go @@ -1,16 +1,19 @@ package api import ( + "context" "net/http" "git.jinshen.cn/remilia/push-server/interval/api/handler" - "git.jinshen.cn/remilia/push-server/interval/hub" + "git.jinshen.cn/remilia/push-server/interval/ws" "github.com/go-chi/chi/v5" ) -func NewRouter(h *hub.Hub) http.Handler { +func NewRouter(h *ws.Hub, ctx context.Context) http.Handler { r := chi.NewRouter() + r.Get("/ws", ws.Handler(ctx, h)) + r.Post("/health", handler.Health) r.Post("/push/{topic}", handler.PushHandler(h)) diff --git a/interval/ws/client.go b/interval/ws/client.go new file mode 100644 index 0000000..08ad94f --- /dev/null +++ b/interval/ws/client.go @@ -0,0 +1,58 @@ +package ws + +import ( + "context" + "log" + "time" + + "github.com/coder/websocket" +) + +// Client represents a connected client in the hub. +type Client struct { + ID string + Conn *websocket.Conn + SendChan chan []byte + + Ctx context.Context + Cancel context.CancelFunc +} + +func NewClient(id string, conn *websocket.Conn, parentCtx context.Context) *Client { + ctx, cancel := context.WithCancel(parentCtx) + + return &Client{ + ID: id, + Conn: conn, + SendChan: make(chan []byte, 32), + Ctx: ctx, + Cancel: cancel, + } +} + +// Write message to websocket connection. +func (c *Client) WriteMessage() { + defer func() { + _ = c.Conn.Close(websocket.StatusNormalClosure, "client writer closed") + }() + + for { + select { + case <-c.Ctx.Done(): + return + case msg, ok := <-c.SendChan: + if !ok { + return + } + + writeCtx, cancel := context.WithTimeout(c.Ctx, 5*time.Second) + err := c.Conn.Write(writeCtx, websocket.MessageText, msg) + cancel() + + if err != nil { + log.Println("WebSocket write error:", err) + return + } + } + } +} diff --git a/interval/ws/doc.go b/interval/ws/doc.go new file mode 100644 index 0000000..423a21e --- /dev/null +++ b/interval/ws/doc.go @@ -0,0 +1,2 @@ +// Package ws implements the Websocket handler for the push service. +package ws diff --git a/interval/ws/handler.go b/interval/ws/handler.go new file mode 100644 index 0000000..b82ed23 --- /dev/null +++ b/interval/ws/handler.go @@ -0,0 +1,75 @@ +package ws + +import ( + "context" + "io" + "log" + "net/http" + "time" + + "github.com/coder/websocket" +) + +func Handler(ctx context.Context, h *Hub) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + conn, err := websocket.Accept(w, r, &websocket.AcceptOptions{ + OriginPatterns: []string{"*"}, + }) + + log.Println("New WebSocket connection from", r.RemoteAddr, "at", time.Now().Format(time.RFC3339)) + + if err != nil { + log.Println("WebSocket accept error:", err) + return + } + + c := NewClient(r.RemoteAddr, conn, ctx) + log.Println("Client", r.RemoteAddr, "connected.") + h.RegisterClient(c) + + go echo(c, h) + go heartbeat(c) + } +} + +func echo(c *Client, h *Hub) { + defer func() { + if c.Conn != nil { + log.Println("Closing WebSocket connection") + h.UnregisterClient(c) + c.Cancel() + _ = c.Conn.Close(websocket.StatusNormalClosure, "echo finished") + } + }() + + for { + typ, r, err := c.Conn.Reader(c.Ctx) + + if err != nil { + if websocket.CloseStatus(err) == websocket.StatusNormalClosure { + log.Println("WebSocket connection closed normally") + } else { + log.Println("WebSocket reader error:", err) + } + + return + } + + w, err := c.Conn.Writer(c.Ctx, typ) + if err != nil { + log.Println("WebSocket writer error:", err) + return + } + + _, err = io.Copy(w, r) + if err != nil { + log.Println("WebSocket copy error:", err) + return + } + + if err = w.Close(); err != nil { + log.Println("WebSocket writer close error:", err) + return + } + } +} diff --git a/interval/ws/heartbeat.go b/interval/ws/heartbeat.go new file mode 100644 index 0000000..d8c3616 --- /dev/null +++ b/interval/ws/heartbeat.go @@ -0,0 +1,35 @@ +package ws + +import ( + "context" + + "github.com/coder/websocket" + "log" + "time" +) + +func heartbeat(c *Client) { + ticker := time.NewTicker(30 * time.Second) + defer ticker.Stop() + + defer func() { + c.Cancel() + _ = c.Conn.Close(websocket.StatusNormalClosure, "heartbeat stopped") + }() + + for { + select { + case <-c.Ctx.Done(): + return + case <-ticker.C: + pingCtx, pingCancel := context.WithTimeout(c.Ctx, 5*time.Second) + err := c.Conn.Ping(pingCtx) + pingCancel() + + if err != nil { + log.Println("Ping filed: ", err) + return + } + } + } +} diff --git a/interval/hub/hub.go b/interval/ws/hub.go similarity index 93% rename from interval/hub/hub.go rename to interval/ws/hub.go index 6b77443..7df38bc 100644 --- a/interval/hub/hub.go +++ b/interval/ws/hub.go @@ -1,4 +1,4 @@ -package hub +package ws import ( "context" @@ -8,13 +8,6 @@ import ( "github.com/coder/websocket" ) -// Client represents a connected client in the hub. -type Client struct { - ID string - Conn *websocket.Conn - SendChan chan []byte -} - // Hub is the central message distribution hub. type Hub struct { // ClientID -> *Client @@ -43,9 +36,9 @@ func NewHub() *Hub { register: make(chan *Client), unregister: make(chan *Client), - subscribe: make(chan model.Subscription), - unsubscribe: make(chan model.Subscription), - broadcast: make(chan model.Message), + subscribe: make(chan model.Subscription, 8), + unsubscribe: make(chan model.Subscription, 8), + broadcast: make(chan model.Message, 64), } } @@ -111,6 +104,7 @@ func (h *Hub) getClient(id string) (*Client, bool) { func (h *Hub) onRegister(c *Client) { h.clientsByID[c.ID] = c h.topicsByClients[c.ID] = make(map[model.Topic]struct{}) + log.Printf("Current clientList: %v\n", h.clientsByID) } // Delete all topic subscriptions for the client when it unregisters. @@ -123,7 +117,6 @@ func (h *Hub) onUnregister(c *Client) { delete(h.clientsByTopic, t) } } - } delete(h.topicsByClients, c.ID) delete(h.clientsByID, c.ID)