Compare commits
3 Commits
1bc9c6a924
...
36296b6af4
| Author | SHA1 | Date | |
|---|---|---|---|
| 36296b6af4 | |||
| 1dbcc03e46 | |||
| 53555a31c0 |
68
cmd/client/main.go
Normal file
68
cmd/client/main.go
Normal file
@ -0,0 +1,68 @@
|
|||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"log"
|
||||||
|
|
||||||
|
"github.com/coder/websocket"
|
||||||
|
"github.com/coder/websocket/wsjson"
|
||||||
|
|
||||||
|
"git.jinshen.cn/remilia/push-server/interval/protocol"
|
||||||
|
)
|
||||||
|
|
||||||
|
func main() {
|
||||||
|
log.Println("This is a test client.")
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
c, _, err := websocket.Dial(ctx, "ws://localhost:8080/ws", nil)
|
||||||
|
if err != nil {
|
||||||
|
log.Fatal("dial error:", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
defer func() { _ = c.CloseNow() }()
|
||||||
|
|
||||||
|
initMsg := protocol.ControlMessage{
|
||||||
|
Type: protocol.MsgInit,
|
||||||
|
Topics: []protocol.Topic{"news", "sports"},
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Println("Sending init message:", initMsg)
|
||||||
|
|
||||||
|
err = wsjson.Write(ctx, c, initMsg)
|
||||||
|
if err != nil {
|
||||||
|
if websocket.CloseStatus(err) != websocket.StatusNormalClosure {
|
||||||
|
log.Printf("init write failed: %v", err)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// typ, msg, err := c.Read(ctx)
|
||||||
|
// if err != nil {
|
||||||
|
// log.Println("read error:", err)
|
||||||
|
// return
|
||||||
|
// }
|
||||||
|
//
|
||||||
|
// switch typ {
|
||||||
|
// case websocket.MessageText:
|
||||||
|
// log.Printf("Received text message: %s", string(msg))
|
||||||
|
// case websocket.MessageBinary:
|
||||||
|
// log.Printf("Received binary message: %v", msg)
|
||||||
|
// }
|
||||||
|
go ReadBroadCastLoop(ctx, c)
|
||||||
|
<-ctx.Done()
|
||||||
|
|
||||||
|
_ = c.Close(websocket.StatusNormalClosure, "test client finished")
|
||||||
|
}
|
||||||
|
|
||||||
|
func ReadBroadCastLoop(ctx context.Context, c *websocket.Conn) {
|
||||||
|
for {
|
||||||
|
var msg protocol.BroadcastMessage
|
||||||
|
if err := wsjson.Read(ctx, c, &msg); err != nil {
|
||||||
|
log.Println("Read broadcast error:", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Printf("Received broadcast message: [%s] %s", msg.Topic, msg.Payload)
|
||||||
|
}
|
||||||
|
}
|
||||||
@ -9,9 +9,9 @@ import (
|
|||||||
"syscall"
|
"syscall"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"git.jinshen.cn/remilia/push-server/interval/api"
|
"git.jinshen.cn/remilia/push-server/interval/server/api"
|
||||||
"git.jinshen.cn/remilia/push-server/interval/server"
|
"git.jinshen.cn/remilia/push-server/interval/server/httpserver"
|
||||||
"git.jinshen.cn/remilia/push-server/interval/ws"
|
"git.jinshen.cn/remilia/push-server/interval/server/ws"
|
||||||
)
|
)
|
||||||
|
|
||||||
func main() {
|
func main() {
|
||||||
@ -27,7 +27,7 @@ func main() {
|
|||||||
h := ws.NewHub()
|
h := ws.NewHub()
|
||||||
go h.Run(serverCtx)
|
go h.Run(serverCtx)
|
||||||
|
|
||||||
httpServer := server.NewHTTPServer(":8080", api.NewRouter(h, serverCtx))
|
httpServer := httpserver.NewHTTPServer(":8080", api.NewRouter(h, serverCtx))
|
||||||
|
|
||||||
go func() {
|
go func() {
|
||||||
log.Println("Starting HTTP server on :8080")
|
log.Println("Starting HTTP server on :8080")
|
||||||
|
|||||||
@ -1,43 +0,0 @@
|
|||||||
package main
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"log"
|
|
||||||
|
|
||||||
"github.com/coder/websocket"
|
|
||||||
"github.com/coder/websocket/wsjson"
|
|
||||||
)
|
|
||||||
|
|
||||||
func main() {
|
|
||||||
log.Println("This is a test client.")
|
|
||||||
ctx, cancel := context.WithCancel(context.Background())
|
|
||||||
defer cancel()
|
|
||||||
|
|
||||||
c, _, err := websocket.Dial(ctx, "ws://localhost:8080/ws", nil)
|
|
||||||
if err != nil {
|
|
||||||
log.Fatal("dial error:", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
defer c.CloseNow()
|
|
||||||
|
|
||||||
err = wsjson.Write(ctx, c, "Hello, WebSocket server!")
|
|
||||||
if err != nil {
|
|
||||||
log.Fatal("write error:", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
typ, msg, err := c.Read(ctx)
|
|
||||||
if err != nil {
|
|
||||||
log.Println("read error:", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
switch typ {
|
|
||||||
case websocket.MessageText:
|
|
||||||
log.Printf("Received text message: %s", string(msg))
|
|
||||||
case websocket.MessageBinary:
|
|
||||||
log.Printf("Received binary message: %v", msg)
|
|
||||||
}
|
|
||||||
|
|
||||||
c.Close(websocket.StatusNormalClosure, "test client finished")
|
|
||||||
}
|
|
||||||
@ -1,2 +0,0 @@
|
|||||||
// Package dto contains data transfer objects used in the interval API.
|
|
||||||
package dto
|
|
||||||
@ -1,17 +0,0 @@
|
|||||||
package dto
|
|
||||||
|
|
||||||
import (
|
|
||||||
"git.jinshen.cn/remilia/push-server/interval/model"
|
|
||||||
)
|
|
||||||
|
|
||||||
type Message struct {
|
|
||||||
Topic string `json:"topic"`
|
|
||||||
Content string `json:"content"`
|
|
||||||
}
|
|
||||||
|
|
||||||
func MessageFromModel(m model.Message) Message {
|
|
||||||
return Message{
|
|
||||||
Topic: string(m.Topic),
|
|
||||||
Content: string(m.Content),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@ -1,5 +0,0 @@
|
|||||||
package dto
|
|
||||||
|
|
||||||
type PublishRequest struct {
|
|
||||||
Content string `json:"content"`
|
|
||||||
}
|
|
||||||
@ -1,17 +0,0 @@
|
|||||||
package dto
|
|
||||||
|
|
||||||
import (
|
|
||||||
"git.jinshen.cn/remilia/push-server/interval/model"
|
|
||||||
)
|
|
||||||
|
|
||||||
type Subscription struct {
|
|
||||||
Topic string `json:"topic"`
|
|
||||||
ClientID string `json:"client_id"`
|
|
||||||
}
|
|
||||||
|
|
||||||
func SubscriptionFromModel(s model.Subscription) Subscription {
|
|
||||||
return Subscription{
|
|
||||||
Topic: string(s.Topic),
|
|
||||||
ClientID: string(s.ClientID),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@ -1,2 +0,0 @@
|
|||||||
// Package hub implements the message distribution core of the push service.
|
|
||||||
package hub
|
|
||||||
@ -1,2 +0,0 @@
|
|||||||
// Package model defines core domain models of the push service.
|
|
||||||
package model
|
|
||||||
@ -1,7 +0,0 @@
|
|||||||
package model
|
|
||||||
|
|
||||||
type Message struct {
|
|
||||||
Topic Topic
|
|
||||||
Content []byte
|
|
||||||
Timestamp int64
|
|
||||||
}
|
|
||||||
@ -1,7 +0,0 @@
|
|||||||
package model
|
|
||||||
|
|
||||||
type Topic string
|
|
||||||
|
|
||||||
func (t Topic) Valid() bool {
|
|
||||||
return t != ""
|
|
||||||
}
|
|
||||||
40
interval/protocol/message.go
Normal file
40
interval/protocol/message.go
Normal file
@ -0,0 +1,40 @@
|
|||||||
|
package protocol
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
|
)
|
||||||
|
|
||||||
|
type ControlMessage struct {
|
||||||
|
Type MessageType `json:"type"`
|
||||||
|
Topic Topic `json:"topic,omitempty"`
|
||||||
|
Topics []Topic `json:"topics,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type BroadcastMessage struct {
|
||||||
|
Type MessageType `json:"type"`
|
||||||
|
Topic Topic `json:"topic"`
|
||||||
|
Payload json.RawMessage `json:"payload"`
|
||||||
|
Timestamp int64 `json:"timestamp"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m ControlMessage) Validate() error {
|
||||||
|
switch m.Type {
|
||||||
|
case MsgInit:
|
||||||
|
if len(m.Topics) == 0 {
|
||||||
|
return errors.New("init requires topics")
|
||||||
|
}
|
||||||
|
case MsgSubscribe:
|
||||||
|
if m.Topic == "" {
|
||||||
|
return errors.New("subscribe requires topic")
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
return errors.New("unknown message type")
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var (
|
||||||
|
ErrInvalidMessage = errors.New("invalid message")
|
||||||
|
ErrPolicyViolation = errors.New("policy violation")
|
||||||
|
)
|
||||||
@ -1,6 +1,6 @@
|
|||||||
package model
|
package protocol
|
||||||
|
|
||||||
type Subscription struct {
|
type Subscription struct {
|
||||||
Topic Topic
|
|
||||||
ClientID string
|
ClientID string
|
||||||
|
Topic Topic
|
||||||
}
|
}
|
||||||
17
interval/protocol/types.go
Normal file
17
interval/protocol/types.go
Normal file
@ -0,0 +1,17 @@
|
|||||||
|
package protocol
|
||||||
|
|
||||||
|
type MessageType string
|
||||||
|
|
||||||
|
const (
|
||||||
|
MsgInit MessageType = "init"
|
||||||
|
MsgSubscribe MessageType = "subscribe"
|
||||||
|
MsgUnsubscribe MessageType = "unsubscribe"
|
||||||
|
MsgBroadcast MessageType = "broadcast"
|
||||||
|
MsgError MessageType = "error"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Topic string
|
||||||
|
|
||||||
|
func (t Topic) Valid() bool {
|
||||||
|
return t != ""
|
||||||
|
}
|
||||||
@ -5,34 +5,41 @@ import (
|
|||||||
"net/http"
|
"net/http"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"git.jinshen.cn/remilia/push-server/interval/api/dto"
|
"git.jinshen.cn/remilia/push-server/interval/protocol"
|
||||||
"git.jinshen.cn/remilia/push-server/interval/model"
|
"git.jinshen.cn/remilia/push-server/interval/server/ws"
|
||||||
"git.jinshen.cn/remilia/push-server/interval/ws"
|
|
||||||
"github.com/go-chi/chi/v5"
|
"github.com/go-chi/chi/v5"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
type PublishRequest struct {
|
||||||
|
Payload json.RawMessage `json:"payload"`
|
||||||
|
}
|
||||||
|
|
||||||
func PushHandler(hub *ws.Hub) http.HandlerFunc {
|
func PushHandler(hub *ws.Hub) http.HandlerFunc {
|
||||||
return func(w http.ResponseWriter, r *http.Request) {
|
return func(w http.ResponseWriter, r *http.Request) {
|
||||||
topicStr := chi.URLParam(r, "topic")
|
topicStr := chi.URLParam(r, "topic")
|
||||||
topic := model.Topic(topicStr)
|
topic := protocol.Topic(topicStr)
|
||||||
if !topic.Valid() {
|
if !topic.Valid() {
|
||||||
http.Error(w, "invalid topic", http.StatusBadRequest)
|
http.Error(w, "invalid topic", http.StatusBadRequest)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
var req dto.PublishRequest
|
var req PublishRequest
|
||||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
dec := json.NewDecoder(r.Body)
|
||||||
|
dec.DisallowUnknownFields()
|
||||||
|
if err := dec.Decode(&req); err != nil {
|
||||||
http.Error(w, "invalid request body", http.StatusBadRequest)
|
http.Error(w, "invalid request body", http.StatusBadRequest)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if req.Content == "" {
|
|
||||||
|
if len(req.Payload) == 0 {
|
||||||
http.Error(w, "content cannot be empty", http.StatusBadRequest)
|
http.Error(w, "content cannot be empty", http.StatusBadRequest)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
msg := model.Message{
|
msg := protocol.BroadcastMessage{
|
||||||
|
Type: protocol.MsgBroadcast,
|
||||||
Topic: topic,
|
Topic: topic,
|
||||||
Content: []byte(req.Content),
|
Payload: req.Payload,
|
||||||
Timestamp: time.Now().Unix(),
|
Timestamp: time.Now().Unix(),
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -4,8 +4,8 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
|
||||||
"git.jinshen.cn/remilia/push-server/interval/api/handler"
|
"git.jinshen.cn/remilia/push-server/interval/server/api/handler"
|
||||||
"git.jinshen.cn/remilia/push-server/interval/ws"
|
"git.jinshen.cn/remilia/push-server/interval/server/ws"
|
||||||
"github.com/go-chi/chi/v5"
|
"github.com/go-chi/chi/v5"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -1,2 +0,0 @@
|
|||||||
// Package server provides HTTP server abstractions.
|
|
||||||
package server
|
|
||||||
2
interval/server/httpserver/doc.go
Normal file
2
interval/server/httpserver/doc.go
Normal file
@ -0,0 +1,2 @@
|
|||||||
|
// Package httpserver provides HTTP server abstractions.
|
||||||
|
package httpserver
|
||||||
@ -1,4 +1,4 @@
|
|||||||
package server
|
package httpserver
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
124
interval/server/ws/client.go
Normal file
124
interval/server/ws/client.go
Normal file
@ -0,0 +1,124 @@
|
|||||||
|
package ws
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"log"
|
||||||
|
|
||||||
|
"git.jinshen.cn/remilia/push-server/interval/protocol"
|
||||||
|
"github.com/coder/websocket"
|
||||||
|
"github.com/coder/websocket/wsjson"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Client represents a connected client in the hub.
|
||||||
|
type Client struct {
|
||||||
|
ID string
|
||||||
|
Conn *websocket.Conn
|
||||||
|
SendChan chan protocol.BroadcastMessage
|
||||||
|
Hub *Hub
|
||||||
|
|
||||||
|
Ctx context.Context
|
||||||
|
Cancel context.CancelFunc
|
||||||
|
|
||||||
|
inited bool
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewClient(id string, conn *websocket.Conn, hub *Hub, parentCtx context.Context) *Client {
|
||||||
|
ctx, cancel := context.WithCancel(parentCtx)
|
||||||
|
|
||||||
|
return &Client{
|
||||||
|
ID: id,
|
||||||
|
Conn: conn,
|
||||||
|
SendChan: make(chan protocol.BroadcastMessage, 32),
|
||||||
|
Hub: hub,
|
||||||
|
|
||||||
|
Ctx: ctx,
|
||||||
|
Cancel: cancel,
|
||||||
|
|
||||||
|
inited: false,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
func (c *Client) ReadLoop() {
|
||||||
|
defer c.Close()
|
||||||
|
|
||||||
|
for {
|
||||||
|
var msg protocol.ControlMessage
|
||||||
|
|
||||||
|
err := wsjson.Read(c.Ctx, c.Conn, &msg)
|
||||||
|
if err != nil {
|
||||||
|
if websocket.CloseStatus(err) == websocket.StatusNormalClosure {
|
||||||
|
log.Println("WebSocket closed normally:", err)
|
||||||
|
} else {
|
||||||
|
log.Println("[Server Client.ReadLoop] WebSocket read error:", err)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := msg.Validate(); err != nil {
|
||||||
|
_ = c.Conn.Close(websocket.StatusPolicyViolation, "invalid message")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := c.handleControlMessage(msg); err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Client) WriteLoop() {
|
||||||
|
defer c.Close()
|
||||||
|
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-c.Ctx.Done():
|
||||||
|
return
|
||||||
|
case msg, ok := <-c.SendChan:
|
||||||
|
if !ok {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
log.Printf("Sending message to client %s: %+v", c.ID, msg)
|
||||||
|
err := wsjson.Write(c.Ctx, c.Conn, msg)
|
||||||
|
if err != nil {
|
||||||
|
log.Println("[Server Client.WriteLoop] WebSocket write error:", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Client) handleControlMessage(msg protocol.ControlMessage) error {
|
||||||
|
switch msg.Type {
|
||||||
|
case protocol.MsgInit:
|
||||||
|
if c.inited {
|
||||||
|
return errors.New("already initialized")
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Printf("Client %s initializing with topics: %v", c.ID, msg.Topics)
|
||||||
|
c.Hub.RegisterClient(c)
|
||||||
|
|
||||||
|
for _, t := range msg.Topics {
|
||||||
|
c.Hub.Subscribe(protocol.Subscription{
|
||||||
|
ClientID: c.ID,
|
||||||
|
Topic: protocol.Topic(t),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
c.inited = true
|
||||||
|
|
||||||
|
return nil
|
||||||
|
default:
|
||||||
|
if !c.inited {
|
||||||
|
return errors.New("client not initialized")
|
||||||
|
}
|
||||||
|
log.Println("Unknown control message type:", msg.Type)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Client) Close() {
|
||||||
|
c.Cancel()
|
||||||
|
c.Hub.UnregisterClient(c)
|
||||||
|
if c.Conn != nil {
|
||||||
|
_ = c.Conn.Close(websocket.StatusNormalClosure, "client closed")
|
||||||
|
}
|
||||||
|
}
|
||||||
32
interval/server/ws/handler.go
Normal file
32
interval/server/ws/handler.go
Normal file
@ -0,0 +1,32 @@
|
|||||||
|
package ws
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"log"
|
||||||
|
"net/http"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/coder/websocket"
|
||||||
|
)
|
||||||
|
|
||||||
|
func Handler(ctx context.Context, h *Hub) http.HandlerFunc {
|
||||||
|
return func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
conn, err := websocket.Accept(w, r, &websocket.AcceptOptions{
|
||||||
|
OriginPatterns: []string{"*"},
|
||||||
|
})
|
||||||
|
|
||||||
|
log.Println("New WebSocket connection from", r.RemoteAddr, "at", time.Now().Format(time.RFC3339))
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
log.Println("WebSocket accept error:", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
c := NewClient(r.RemoteAddr, conn, h, ctx)
|
||||||
|
log.Println("Client", r.RemoteAddr, "connected.")
|
||||||
|
|
||||||
|
go c.ReadLoop()
|
||||||
|
go c.WriteLoop()
|
||||||
|
go heartbeat(c)
|
||||||
|
}
|
||||||
|
}
|
||||||
@ -4,7 +4,7 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"log"
|
"log"
|
||||||
|
|
||||||
"git.jinshen.cn/remilia/push-server/interval/model"
|
"git.jinshen.cn/remilia/push-server/interval/protocol"
|
||||||
"github.com/coder/websocket"
|
"github.com/coder/websocket"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -17,28 +17,28 @@ type Hub struct {
|
|||||||
// Invariant:
|
// Invariant:
|
||||||
// - clientsByTopic contains only topics with at least one active subscriber.
|
// - clientsByTopic contains only topics with at least one active subscriber.
|
||||||
// - A topic key must not exist if it has zero clients.
|
// - A topic key must not exist if it has zero clients.
|
||||||
clientsByTopic map[model.Topic]map[string]*Client
|
clientsByTopic map[protocol.Topic]map[string]*Client
|
||||||
// clientID -> topic -> struct{}
|
// clientID -> topic -> struct{}
|
||||||
topicsByClients map[string]map[model.Topic]struct{}
|
topicsByClients map[string]map[protocol.Topic]struct{}
|
||||||
|
|
||||||
register chan *Client
|
register chan *Client
|
||||||
unregister chan *Client
|
unregister chan *Client
|
||||||
subscribe chan model.Subscription
|
subscribe chan protocol.Subscription
|
||||||
unsubscribe chan model.Subscription
|
unsubscribe chan protocol.Subscription
|
||||||
broadcast chan model.Message
|
broadcast chan protocol.BroadcastMessage
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewHub() *Hub {
|
func NewHub() *Hub {
|
||||||
return &Hub{
|
return &Hub{
|
||||||
clientsByID: make(map[string]*Client),
|
clientsByID: make(map[string]*Client),
|
||||||
clientsByTopic: make(map[model.Topic]map[string]*Client),
|
clientsByTopic: make(map[protocol.Topic]map[string]*Client),
|
||||||
topicsByClients: make(map[string]map[model.Topic]struct{}),
|
topicsByClients: make(map[string]map[protocol.Topic]struct{}),
|
||||||
|
|
||||||
register: make(chan *Client),
|
register: make(chan *Client),
|
||||||
unregister: make(chan *Client),
|
unregister: make(chan *Client),
|
||||||
subscribe: make(chan model.Subscription, 8),
|
subscribe: make(chan protocol.Subscription, 8),
|
||||||
unsubscribe: make(chan model.Subscription, 8),
|
unsubscribe: make(chan protocol.Subscription, 8),
|
||||||
broadcast: make(chan model.Message, 64),
|
broadcast: make(chan protocol.BroadcastMessage, 64),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -50,15 +50,15 @@ func (h *Hub) UnregisterClient(client *Client) {
|
|||||||
h.unregister <- client
|
h.unregister <- client
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *Hub) Subscribe(sub model.Subscription) {
|
func (h *Hub) Subscribe(sub protocol.Subscription) {
|
||||||
h.subscribe <- sub
|
h.subscribe <- sub
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *Hub) Unsubscribe(sub model.Subscription) {
|
func (h *Hub) Unsubscribe(sub protocol.Subscription) {
|
||||||
h.unsubscribe <- sub
|
h.unsubscribe <- sub
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *Hub) BroadcastMessage(ctx context.Context, msg model.Message) error {
|
func (h *Hub) BroadcastMessage(ctx context.Context, msg protocol.BroadcastMessage) error {
|
||||||
select {
|
select {
|
||||||
case h.broadcast <- msg:
|
case h.broadcast <- msg:
|
||||||
return nil
|
return nil
|
||||||
@ -103,7 +103,7 @@ func (h *Hub) getClient(id string) (*Client, bool) {
|
|||||||
// Create a new entry for the client in topicsByClients map when it registers.
|
// Create a new entry for the client in topicsByClients map when it registers.
|
||||||
func (h *Hub) onRegister(c *Client) {
|
func (h *Hub) onRegister(c *Client) {
|
||||||
h.clientsByID[c.ID] = c
|
h.clientsByID[c.ID] = c
|
||||||
h.topicsByClients[c.ID] = make(map[model.Topic]struct{})
|
h.topicsByClients[c.ID] = make(map[protocol.Topic]struct{})
|
||||||
log.Printf("Current clientList: %v\n", h.clientsByID)
|
log.Printf("Current clientList: %v\n", h.clientsByID)
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -122,7 +122,7 @@ func (h *Hub) onUnregister(c *Client) {
|
|||||||
delete(h.clientsByID, c.ID)
|
delete(h.clientsByID, c.ID)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *Hub) onSubscribe(s model.Subscription) {
|
func (h *Hub) onSubscribe(s protocol.Subscription) {
|
||||||
c, ok := h.getClient(s.ClientID)
|
c, ok := h.getClient(s.ClientID)
|
||||||
if !ok {
|
if !ok {
|
||||||
// If the client does not exist, log an error and return.
|
// If the client does not exist, log an error and return.
|
||||||
@ -135,9 +135,10 @@ func (h *Hub) onSubscribe(s model.Subscription) {
|
|||||||
|
|
||||||
h.clientsByTopic[s.Topic][s.ClientID] = c
|
h.clientsByTopic[s.Topic][s.ClientID] = c
|
||||||
h.topicsByClients[s.ClientID][s.Topic] = struct{}{}
|
h.topicsByClients[s.ClientID][s.Topic] = struct{}{}
|
||||||
|
log.Printf("Client %s subscribed to topic %s", s.ClientID, s.Topic)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *Hub) onUnsubscribe(s model.Subscription) {
|
func (h *Hub) onUnsubscribe(s protocol.Subscription) {
|
||||||
if clients, ok := h.clientsByTopic[s.Topic]; ok {
|
if clients, ok := h.clientsByTopic[s.Topic]; ok {
|
||||||
delete(clients, s.ClientID)
|
delete(clients, s.ClientID)
|
||||||
if len(clients) == 0 {
|
if len(clients) == 0 {
|
||||||
@ -150,15 +151,16 @@ func (h *Hub) onUnsubscribe(s model.Subscription) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *Hub) onBroadcast(msg model.Message) {
|
func (h *Hub) onBroadcast(msg protocol.BroadcastMessage) {
|
||||||
if !msg.Topic.Valid() {
|
if !msg.Topic.Valid() {
|
||||||
log.Printf("Broadcast failed: invalid topic")
|
log.Printf("Broadcast failed: invalid topic")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
log.Printf("Receiving message for topic %s: %s", msg.Topic, string(msg.Content))
|
log.Printf("Receiving message for topic %s: %s", msg.Topic, string(msg.Payload))
|
||||||
for _, c := range h.clientsByTopic[msg.Topic] {
|
for _, c := range h.clientsByTopic[msg.Topic] {
|
||||||
select {
|
select {
|
||||||
case c.SendChan <- msg.Content:
|
case c.SendChan <- msg:
|
||||||
|
log.Printf("[%d] Sending message to client %s: [%s] %v", msg.Timestamp, c.ID, msg.Type, string(msg.Payload))
|
||||||
default:
|
default:
|
||||||
h.UnregisterClient(c)
|
h.UnregisterClient(c)
|
||||||
if c.Conn != nil {
|
if c.Conn != nil {
|
||||||
@ -1,58 +0,0 @@
|
|||||||
package ws
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"log"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/coder/websocket"
|
|
||||||
)
|
|
||||||
|
|
||||||
// Client represents a connected client in the hub.
|
|
||||||
type Client struct {
|
|
||||||
ID string
|
|
||||||
Conn *websocket.Conn
|
|
||||||
SendChan chan []byte
|
|
||||||
|
|
||||||
Ctx context.Context
|
|
||||||
Cancel context.CancelFunc
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewClient(id string, conn *websocket.Conn, parentCtx context.Context) *Client {
|
|
||||||
ctx, cancel := context.WithCancel(parentCtx)
|
|
||||||
|
|
||||||
return &Client{
|
|
||||||
ID: id,
|
|
||||||
Conn: conn,
|
|
||||||
SendChan: make(chan []byte, 32),
|
|
||||||
Ctx: ctx,
|
|
||||||
Cancel: cancel,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Write message to websocket connection.
|
|
||||||
func (c *Client) WriteMessage() {
|
|
||||||
defer func() {
|
|
||||||
_ = c.Conn.Close(websocket.StatusNormalClosure, "client writer closed")
|
|
||||||
}()
|
|
||||||
|
|
||||||
for {
|
|
||||||
select {
|
|
||||||
case <-c.Ctx.Done():
|
|
||||||
return
|
|
||||||
case msg, ok := <-c.SendChan:
|
|
||||||
if !ok {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
writeCtx, cancel := context.WithTimeout(c.Ctx, 5*time.Second)
|
|
||||||
err := c.Conn.Write(writeCtx, websocket.MessageText, msg)
|
|
||||||
cancel()
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
log.Println("WebSocket write error:", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@ -1,75 +0,0 @@
|
|||||||
package ws
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"io"
|
|
||||||
"log"
|
|
||||||
"net/http"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/coder/websocket"
|
|
||||||
)
|
|
||||||
|
|
||||||
func Handler(ctx context.Context, h *Hub) http.HandlerFunc {
|
|
||||||
return func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
conn, err := websocket.Accept(w, r, &websocket.AcceptOptions{
|
|
||||||
OriginPatterns: []string{"*"},
|
|
||||||
})
|
|
||||||
|
|
||||||
log.Println("New WebSocket connection from", r.RemoteAddr, "at", time.Now().Format(time.RFC3339))
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
log.Println("WebSocket accept error:", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
c := NewClient(r.RemoteAddr, conn, ctx)
|
|
||||||
log.Println("Client", r.RemoteAddr, "connected.")
|
|
||||||
h.RegisterClient(c)
|
|
||||||
|
|
||||||
go echo(c, h)
|
|
||||||
go heartbeat(c)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func echo(c *Client, h *Hub) {
|
|
||||||
defer func() {
|
|
||||||
if c.Conn != nil {
|
|
||||||
log.Println("Closing WebSocket connection")
|
|
||||||
h.UnregisterClient(c)
|
|
||||||
c.Cancel()
|
|
||||||
_ = c.Conn.Close(websocket.StatusNormalClosure, "echo finished")
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
for {
|
|
||||||
typ, r, err := c.Conn.Reader(c.Ctx)
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
if websocket.CloseStatus(err) == websocket.StatusNormalClosure {
|
|
||||||
log.Println("WebSocket connection closed normally")
|
|
||||||
} else {
|
|
||||||
log.Println("WebSocket reader error:", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
w, err := c.Conn.Writer(c.Ctx, typ)
|
|
||||||
if err != nil {
|
|
||||||
log.Println("WebSocket writer error:", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
_, err = io.Copy(w, r)
|
|
||||||
if err != nil {
|
|
||||||
log.Println("WebSocket copy error:", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if err = w.Close(); err != nil {
|
|
||||||
log.Println("WebSocket writer close error:", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
Reference in New Issue
Block a user