feat: 基本的websocket echo服务
This commit is contained in:
@ -10,20 +10,24 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
|
|
||||||
"git.jinshen.cn/remilia/push-server/interval/api"
|
"git.jinshen.cn/remilia/push-server/interval/api"
|
||||||
"git.jinshen.cn/remilia/push-server/interval/hub"
|
|
||||||
"git.jinshen.cn/remilia/push-server/interval/server"
|
"git.jinshen.cn/remilia/push-server/interval/server"
|
||||||
|
"git.jinshen.cn/remilia/push-server/interval/ws"
|
||||||
)
|
)
|
||||||
|
|
||||||
func main() {
|
func main() {
|
||||||
serverCtx, serverCancel := context.WithCancel(context.Background())
|
serverCtx, stop := signal.NotifyContext(
|
||||||
|
context.Background(),
|
||||||
|
os.Interrupt,
|
||||||
|
syscall.SIGTERM,
|
||||||
|
)
|
||||||
defer func() {
|
defer func() {
|
||||||
serverCancel()
|
stop()
|
||||||
}()
|
}()
|
||||||
|
|
||||||
h := hub.NewHub()
|
h := ws.NewHub()
|
||||||
go h.Run(serverCtx)
|
go h.Run(serverCtx)
|
||||||
|
|
||||||
httpServer := server.NewHTTPServer(":8080", api.NewRouter(h))
|
httpServer := server.NewHTTPServer(":8080", api.NewRouter(h, serverCtx))
|
||||||
|
|
||||||
go func() {
|
go func() {
|
||||||
log.Println("Starting HTTP server on :8080")
|
log.Println("Starting HTTP server on :8080")
|
||||||
@ -32,16 +36,14 @@ func main() {
|
|||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
sig := make(chan os.Signal, 1)
|
<-serverCtx.Done()
|
||||||
signal.Notify(sig, syscall.SIGINT, syscall.SIGTERM)
|
|
||||||
<-sig
|
|
||||||
|
|
||||||
log.Println("Shutting down server...")
|
log.Println("Shutting down server...")
|
||||||
|
|
||||||
serverCancel()
|
|
||||||
|
|
||||||
shutdownCtx, shutdownCancel := context.WithTimeout(context.Background(), time.Second*10)
|
shutdownCtx, shutdownCancel := context.WithTimeout(context.Background(), time.Second*10)
|
||||||
defer shutdownCancel()
|
defer shutdownCancel()
|
||||||
|
|
||||||
httpServer.Shutdown(shutdownCtx)
|
if err := httpServer.Shutdown(shutdownCtx); err != nil {
|
||||||
|
log.Printf("HTTP server shutdown error: %v", err)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
43
cmd/test-client/main.go
Normal file
43
cmd/test-client/main.go
Normal file
@ -0,0 +1,43 @@
|
|||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"log"
|
||||||
|
|
||||||
|
"github.com/coder/websocket"
|
||||||
|
"github.com/coder/websocket/wsjson"
|
||||||
|
)
|
||||||
|
|
||||||
|
func main() {
|
||||||
|
log.Println("This is a test client.")
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
c, _, err := websocket.Dial(ctx, "ws://localhost:8080/ws", nil)
|
||||||
|
if err != nil {
|
||||||
|
log.Fatal("dial error:", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
defer c.CloseNow()
|
||||||
|
|
||||||
|
err = wsjson.Write(ctx, c, "Hello, WebSocket server!")
|
||||||
|
if err != nil {
|
||||||
|
log.Fatal("write error:", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
typ, msg, err := c.Read(ctx)
|
||||||
|
if err != nil {
|
||||||
|
log.Println("read error:", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
switch typ {
|
||||||
|
case websocket.MessageText:
|
||||||
|
log.Printf("Received text message: %s", string(msg))
|
||||||
|
case websocket.MessageBinary:
|
||||||
|
log.Printf("Received binary message: %v", msg)
|
||||||
|
}
|
||||||
|
|
||||||
|
c.Close(websocket.StatusNormalClosure, "test client finished")
|
||||||
|
}
|
||||||
2
go.mod
2
go.mod
@ -4,4 +4,4 @@ go 1.25.5
|
|||||||
|
|
||||||
require github.com/coder/websocket v1.8.14
|
require github.com/coder/websocket v1.8.14
|
||||||
|
|
||||||
require github.com/go-chi/chi/v5 v5.2.3 // indirect
|
require github.com/go-chi/chi/v5 v5.2.3
|
||||||
|
|||||||
@ -4,5 +4,7 @@ import "net/http"
|
|||||||
|
|
||||||
func Health(w http.ResponseWriter, _ *http.Request) {
|
func Health(w http.ResponseWriter, _ *http.Request) {
|
||||||
w.WriteHeader(http.StatusOK)
|
w.WriteHeader(http.StatusOK)
|
||||||
w.Write([]byte("OK"))
|
if _, err := w.Write([]byte("OK")); err != nil {
|
||||||
|
http.Error(w, "failed to write response", http.StatusInternalServerError)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@ -6,12 +6,12 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
|
|
||||||
"git.jinshen.cn/remilia/push-server/interval/api/dto"
|
"git.jinshen.cn/remilia/push-server/interval/api/dto"
|
||||||
"git.jinshen.cn/remilia/push-server/interval/hub"
|
|
||||||
"git.jinshen.cn/remilia/push-server/interval/model"
|
"git.jinshen.cn/remilia/push-server/interval/model"
|
||||||
|
"git.jinshen.cn/remilia/push-server/interval/ws"
|
||||||
"github.com/go-chi/chi/v5"
|
"github.com/go-chi/chi/v5"
|
||||||
)
|
)
|
||||||
|
|
||||||
func PushHandler(hub *hub.Hub) http.HandlerFunc {
|
func PushHandler(hub *ws.Hub) http.HandlerFunc {
|
||||||
return func(w http.ResponseWriter, r *http.Request) {
|
return func(w http.ResponseWriter, r *http.Request) {
|
||||||
topicStr := chi.URLParam(r, "topic")
|
topicStr := chi.URLParam(r, "topic")
|
||||||
topic := model.Topic(topicStr)
|
topic := model.Topic(topicStr)
|
||||||
|
|||||||
@ -1,16 +1,19 @@
|
|||||||
package api
|
package api
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
|
||||||
"git.jinshen.cn/remilia/push-server/interval/api/handler"
|
"git.jinshen.cn/remilia/push-server/interval/api/handler"
|
||||||
"git.jinshen.cn/remilia/push-server/interval/hub"
|
"git.jinshen.cn/remilia/push-server/interval/ws"
|
||||||
"github.com/go-chi/chi/v5"
|
"github.com/go-chi/chi/v5"
|
||||||
)
|
)
|
||||||
|
|
||||||
func NewRouter(h *hub.Hub) http.Handler {
|
func NewRouter(h *ws.Hub, ctx context.Context) http.Handler {
|
||||||
r := chi.NewRouter()
|
r := chi.NewRouter()
|
||||||
|
|
||||||
|
r.Get("/ws", ws.Handler(ctx, h))
|
||||||
|
|
||||||
r.Post("/health", handler.Health)
|
r.Post("/health", handler.Health)
|
||||||
r.Post("/push/{topic}", handler.PushHandler(h))
|
r.Post("/push/{topic}", handler.PushHandler(h))
|
||||||
|
|
||||||
|
|||||||
58
interval/ws/client.go
Normal file
58
interval/ws/client.go
Normal file
@ -0,0 +1,58 @@
|
|||||||
|
package ws
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"log"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/coder/websocket"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Client represents a connected client in the hub.
|
||||||
|
type Client struct {
|
||||||
|
ID string
|
||||||
|
Conn *websocket.Conn
|
||||||
|
SendChan chan []byte
|
||||||
|
|
||||||
|
Ctx context.Context
|
||||||
|
Cancel context.CancelFunc
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewClient(id string, conn *websocket.Conn, parentCtx context.Context) *Client {
|
||||||
|
ctx, cancel := context.WithCancel(parentCtx)
|
||||||
|
|
||||||
|
return &Client{
|
||||||
|
ID: id,
|
||||||
|
Conn: conn,
|
||||||
|
SendChan: make(chan []byte, 32),
|
||||||
|
Ctx: ctx,
|
||||||
|
Cancel: cancel,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Write message to websocket connection.
|
||||||
|
func (c *Client) WriteMessage() {
|
||||||
|
defer func() {
|
||||||
|
_ = c.Conn.Close(websocket.StatusNormalClosure, "client writer closed")
|
||||||
|
}()
|
||||||
|
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-c.Ctx.Done():
|
||||||
|
return
|
||||||
|
case msg, ok := <-c.SendChan:
|
||||||
|
if !ok {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
writeCtx, cancel := context.WithTimeout(c.Ctx, 5*time.Second)
|
||||||
|
err := c.Conn.Write(writeCtx, websocket.MessageText, msg)
|
||||||
|
cancel()
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
log.Println("WebSocket write error:", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
2
interval/ws/doc.go
Normal file
2
interval/ws/doc.go
Normal file
@ -0,0 +1,2 @@
|
|||||||
|
// Package ws implements the Websocket handler for the push service.
|
||||||
|
package ws
|
||||||
75
interval/ws/handler.go
Normal file
75
interval/ws/handler.go
Normal file
@ -0,0 +1,75 @@
|
|||||||
|
package ws
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"io"
|
||||||
|
"log"
|
||||||
|
"net/http"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/coder/websocket"
|
||||||
|
)
|
||||||
|
|
||||||
|
func Handler(ctx context.Context, h *Hub) http.HandlerFunc {
|
||||||
|
return func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
conn, err := websocket.Accept(w, r, &websocket.AcceptOptions{
|
||||||
|
OriginPatterns: []string{"*"},
|
||||||
|
})
|
||||||
|
|
||||||
|
log.Println("New WebSocket connection from", r.RemoteAddr, "at", time.Now().Format(time.RFC3339))
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
log.Println("WebSocket accept error:", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
c := NewClient(r.RemoteAddr, conn, ctx)
|
||||||
|
log.Println("Client", r.RemoteAddr, "connected.")
|
||||||
|
h.RegisterClient(c)
|
||||||
|
|
||||||
|
go echo(c, h)
|
||||||
|
go heartbeat(c)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func echo(c *Client, h *Hub) {
|
||||||
|
defer func() {
|
||||||
|
if c.Conn != nil {
|
||||||
|
log.Println("Closing WebSocket connection")
|
||||||
|
h.UnregisterClient(c)
|
||||||
|
c.Cancel()
|
||||||
|
_ = c.Conn.Close(websocket.StatusNormalClosure, "echo finished")
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
for {
|
||||||
|
typ, r, err := c.Conn.Reader(c.Ctx)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
if websocket.CloseStatus(err) == websocket.StatusNormalClosure {
|
||||||
|
log.Println("WebSocket connection closed normally")
|
||||||
|
} else {
|
||||||
|
log.Println("WebSocket reader error:", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
w, err := c.Conn.Writer(c.Ctx, typ)
|
||||||
|
if err != nil {
|
||||||
|
log.Println("WebSocket writer error:", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err = io.Copy(w, r)
|
||||||
|
if err != nil {
|
||||||
|
log.Println("WebSocket copy error:", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if err = w.Close(); err != nil {
|
||||||
|
log.Println("WebSocket writer close error:", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
35
interval/ws/heartbeat.go
Normal file
35
interval/ws/heartbeat.go
Normal file
@ -0,0 +1,35 @@
|
|||||||
|
package ws
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
|
||||||
|
"github.com/coder/websocket"
|
||||||
|
"log"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
func heartbeat(c *Client) {
|
||||||
|
ticker := time.NewTicker(30 * time.Second)
|
||||||
|
defer ticker.Stop()
|
||||||
|
|
||||||
|
defer func() {
|
||||||
|
c.Cancel()
|
||||||
|
_ = c.Conn.Close(websocket.StatusNormalClosure, "heartbeat stopped")
|
||||||
|
}()
|
||||||
|
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-c.Ctx.Done():
|
||||||
|
return
|
||||||
|
case <-ticker.C:
|
||||||
|
pingCtx, pingCancel := context.WithTimeout(c.Ctx, 5*time.Second)
|
||||||
|
err := c.Conn.Ping(pingCtx)
|
||||||
|
pingCancel()
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
log.Println("Ping filed: ", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
@ -1,4 +1,4 @@
|
|||||||
package hub
|
package ws
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
@ -8,13 +8,6 @@ import (
|
|||||||
"github.com/coder/websocket"
|
"github.com/coder/websocket"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Client represents a connected client in the hub.
|
|
||||||
type Client struct {
|
|
||||||
ID string
|
|
||||||
Conn *websocket.Conn
|
|
||||||
SendChan chan []byte
|
|
||||||
}
|
|
||||||
|
|
||||||
// Hub is the central message distribution hub.
|
// Hub is the central message distribution hub.
|
||||||
type Hub struct {
|
type Hub struct {
|
||||||
// ClientID -> *Client
|
// ClientID -> *Client
|
||||||
@ -43,9 +36,9 @@ func NewHub() *Hub {
|
|||||||
|
|
||||||
register: make(chan *Client),
|
register: make(chan *Client),
|
||||||
unregister: make(chan *Client),
|
unregister: make(chan *Client),
|
||||||
subscribe: make(chan model.Subscription),
|
subscribe: make(chan model.Subscription, 8),
|
||||||
unsubscribe: make(chan model.Subscription),
|
unsubscribe: make(chan model.Subscription, 8),
|
||||||
broadcast: make(chan model.Message),
|
broadcast: make(chan model.Message, 64),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -111,6 +104,7 @@ func (h *Hub) getClient(id string) (*Client, bool) {
|
|||||||
func (h *Hub) onRegister(c *Client) {
|
func (h *Hub) onRegister(c *Client) {
|
||||||
h.clientsByID[c.ID] = c
|
h.clientsByID[c.ID] = c
|
||||||
h.topicsByClients[c.ID] = make(map[model.Topic]struct{})
|
h.topicsByClients[c.ID] = make(map[model.Topic]struct{})
|
||||||
|
log.Printf("Current clientList: %v\n", h.clientsByID)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Delete all topic subscriptions for the client when it unregisters.
|
// Delete all topic subscriptions for the client when it unregisters.
|
||||||
@ -123,7 +117,6 @@ func (h *Hub) onUnregister(c *Client) {
|
|||||||
delete(h.clientsByTopic, t)
|
delete(h.clientsByTopic, t)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
delete(h.topicsByClients, c.ID)
|
delete(h.topicsByClients, c.ID)
|
||||||
delete(h.clientsByID, c.ID)
|
delete(h.clientsByID, c.ID)
|
||||||
Reference in New Issue
Block a user