Compare commits

..

3 Commits

Author SHA1 Message Date
36296b6af4 fix: 解决json无法正常解析的问题 2025-12-17 15:54:26 +08:00
1dbcc03e46 feat: 基本广播服务
- 由Hub接收/push/{topic}的请求并解析信息体广播到对应的Client
2025-12-17 14:49:59 +08:00
53555a31c0 refactor: 重构项目结构
- 将server端相关依赖单独防止在server中
2025-12-17 12:34:03 +08:00
29 changed files with 330 additions and 275 deletions

68
cmd/client/main.go Normal file
View File

@ -0,0 +1,68 @@
package main
import (
"context"
"log"
"github.com/coder/websocket"
"github.com/coder/websocket/wsjson"
"git.jinshen.cn/remilia/push-server/interval/protocol"
)
func main() {
log.Println("This is a test client.")
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
c, _, err := websocket.Dial(ctx, "ws://localhost:8080/ws", nil)
if err != nil {
log.Fatal("dial error:", err)
return
}
defer func() { _ = c.CloseNow() }()
initMsg := protocol.ControlMessage{
Type: protocol.MsgInit,
Topics: []protocol.Topic{"news", "sports"},
}
log.Println("Sending init message:", initMsg)
err = wsjson.Write(ctx, c, initMsg)
if err != nil {
if websocket.CloseStatus(err) != websocket.StatusNormalClosure {
log.Printf("init write failed: %v", err)
}
return
}
// typ, msg, err := c.Read(ctx)
// if err != nil {
// log.Println("read error:", err)
// return
// }
//
// switch typ {
// case websocket.MessageText:
// log.Printf("Received text message: %s", string(msg))
// case websocket.MessageBinary:
// log.Printf("Received binary message: %v", msg)
// }
go ReadBroadCastLoop(ctx, c)
<-ctx.Done()
_ = c.Close(websocket.StatusNormalClosure, "test client finished")
}
func ReadBroadCastLoop(ctx context.Context, c *websocket.Conn) {
for {
var msg protocol.BroadcastMessage
if err := wsjson.Read(ctx, c, &msg); err != nil {
log.Println("Read broadcast error:", err)
return
}
log.Printf("Received broadcast message: [%s] %s", msg.Topic, msg.Payload)
}
}

View File

@ -9,9 +9,9 @@ import (
"syscall" "syscall"
"time" "time"
"git.jinshen.cn/remilia/push-server/interval/api" "git.jinshen.cn/remilia/push-server/interval/server/api"
"git.jinshen.cn/remilia/push-server/interval/server" "git.jinshen.cn/remilia/push-server/interval/server/httpserver"
"git.jinshen.cn/remilia/push-server/interval/ws" "git.jinshen.cn/remilia/push-server/interval/server/ws"
) )
func main() { func main() {
@ -27,7 +27,7 @@ func main() {
h := ws.NewHub() h := ws.NewHub()
go h.Run(serverCtx) go h.Run(serverCtx)
httpServer := server.NewHTTPServer(":8080", api.NewRouter(h, serverCtx)) httpServer := httpserver.NewHTTPServer(":8080", api.NewRouter(h, serverCtx))
go func() { go func() {
log.Println("Starting HTTP server on :8080") log.Println("Starting HTTP server on :8080")

View File

@ -1,43 +0,0 @@
package main
import (
"context"
"log"
"github.com/coder/websocket"
"github.com/coder/websocket/wsjson"
)
func main() {
log.Println("This is a test client.")
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
c, _, err := websocket.Dial(ctx, "ws://localhost:8080/ws", nil)
if err != nil {
log.Fatal("dial error:", err)
return
}
defer c.CloseNow()
err = wsjson.Write(ctx, c, "Hello, WebSocket server!")
if err != nil {
log.Fatal("write error:", err)
return
}
typ, msg, err := c.Read(ctx)
if err != nil {
log.Println("read error:", err)
return
}
switch typ {
case websocket.MessageText:
log.Printf("Received text message: %s", string(msg))
case websocket.MessageBinary:
log.Printf("Received binary message: %v", msg)
}
c.Close(websocket.StatusNormalClosure, "test client finished")
}

View File

@ -1,2 +0,0 @@
// Package dto contains data transfer objects used in the interval API.
package dto

View File

@ -1,17 +0,0 @@
package dto
import (
"git.jinshen.cn/remilia/push-server/interval/model"
)
type Message struct {
Topic string `json:"topic"`
Content string `json:"content"`
}
func MessageFromModel(m model.Message) Message {
return Message{
Topic: string(m.Topic),
Content: string(m.Content),
}
}

View File

@ -1,5 +0,0 @@
package dto
type PublishRequest struct {
Content string `json:"content"`
}

View File

@ -1,17 +0,0 @@
package dto
import (
"git.jinshen.cn/remilia/push-server/interval/model"
)
type Subscription struct {
Topic string `json:"topic"`
ClientID string `json:"client_id"`
}
func SubscriptionFromModel(s model.Subscription) Subscription {
return Subscription{
Topic: string(s.Topic),
ClientID: string(s.ClientID),
}
}

View File

@ -1,2 +0,0 @@
// Package hub implements the message distribution core of the push service.
package hub

View File

@ -1,2 +0,0 @@
// Package model defines core domain models of the push service.
package model

View File

@ -1,7 +0,0 @@
package model
type Message struct {
Topic Topic
Content []byte
Timestamp int64
}

View File

@ -1,7 +0,0 @@
package model
type Topic string
func (t Topic) Valid() bool {
return t != ""
}

View File

@ -0,0 +1,40 @@
package protocol
import (
"encoding/json"
"errors"
)
type ControlMessage struct {
Type MessageType `json:"type"`
Topic Topic `json:"topic,omitempty"`
Topics []Topic `json:"topics,omitempty"`
}
type BroadcastMessage struct {
Type MessageType `json:"type"`
Topic Topic `json:"topic"`
Payload json.RawMessage `json:"payload"`
Timestamp int64 `json:"timestamp"`
}
func (m ControlMessage) Validate() error {
switch m.Type {
case MsgInit:
if len(m.Topics) == 0 {
return errors.New("init requires topics")
}
case MsgSubscribe:
if m.Topic == "" {
return errors.New("subscribe requires topic")
}
default:
return errors.New("unknown message type")
}
return nil
}
var (
ErrInvalidMessage = errors.New("invalid message")
ErrPolicyViolation = errors.New("policy violation")
)

View File

@ -1,6 +1,6 @@
package model package protocol
type Subscription struct { type Subscription struct {
Topic Topic
ClientID string ClientID string
Topic Topic
} }

View File

@ -0,0 +1,17 @@
package protocol
type MessageType string
const (
MsgInit MessageType = "init"
MsgSubscribe MessageType = "subscribe"
MsgUnsubscribe MessageType = "unsubscribe"
MsgBroadcast MessageType = "broadcast"
MsgError MessageType = "error"
)
type Topic string
func (t Topic) Valid() bool {
return t != ""
}

View File

@ -5,34 +5,41 @@ import (
"net/http" "net/http"
"time" "time"
"git.jinshen.cn/remilia/push-server/interval/api/dto" "git.jinshen.cn/remilia/push-server/interval/protocol"
"git.jinshen.cn/remilia/push-server/interval/model" "git.jinshen.cn/remilia/push-server/interval/server/ws"
"git.jinshen.cn/remilia/push-server/interval/ws"
"github.com/go-chi/chi/v5" "github.com/go-chi/chi/v5"
) )
type PublishRequest struct {
Payload json.RawMessage `json:"payload"`
}
func PushHandler(hub *ws.Hub) http.HandlerFunc { func PushHandler(hub *ws.Hub) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) { return func(w http.ResponseWriter, r *http.Request) {
topicStr := chi.URLParam(r, "topic") topicStr := chi.URLParam(r, "topic")
topic := model.Topic(topicStr) topic := protocol.Topic(topicStr)
if !topic.Valid() { if !topic.Valid() {
http.Error(w, "invalid topic", http.StatusBadRequest) http.Error(w, "invalid topic", http.StatusBadRequest)
return return
} }
var req dto.PublishRequest var req PublishRequest
if err := json.NewDecoder(r.Body).Decode(&req); err != nil { dec := json.NewDecoder(r.Body)
dec.DisallowUnknownFields()
if err := dec.Decode(&req); err != nil {
http.Error(w, "invalid request body", http.StatusBadRequest) http.Error(w, "invalid request body", http.StatusBadRequest)
return return
} }
if req.Content == "" {
if len(req.Payload) == 0 {
http.Error(w, "content cannot be empty", http.StatusBadRequest) http.Error(w, "content cannot be empty", http.StatusBadRequest)
return return
} }
msg := model.Message{ msg := protocol.BroadcastMessage{
Type: protocol.MsgBroadcast,
Topic: topic, Topic: topic,
Content: []byte(req.Content), Payload: req.Payload,
Timestamp: time.Now().Unix(), Timestamp: time.Now().Unix(),
} }

View File

@ -4,8 +4,8 @@ import (
"context" "context"
"net/http" "net/http"
"git.jinshen.cn/remilia/push-server/interval/api/handler" "git.jinshen.cn/remilia/push-server/interval/server/api/handler"
"git.jinshen.cn/remilia/push-server/interval/ws" "git.jinshen.cn/remilia/push-server/interval/server/ws"
"github.com/go-chi/chi/v5" "github.com/go-chi/chi/v5"
) )

View File

@ -1,2 +0,0 @@
// Package server provides HTTP server abstractions.
package server

View File

@ -0,0 +1,2 @@
// Package httpserver provides HTTP server abstractions.
package httpserver

View File

@ -1,4 +1,4 @@
package server package httpserver
import ( import (
"context" "context"

View File

@ -0,0 +1,124 @@
package ws
import (
"context"
"errors"
"log"
"git.jinshen.cn/remilia/push-server/interval/protocol"
"github.com/coder/websocket"
"github.com/coder/websocket/wsjson"
)
// Client represents a connected client in the hub.
type Client struct {
ID string
Conn *websocket.Conn
SendChan chan protocol.BroadcastMessage
Hub *Hub
Ctx context.Context
Cancel context.CancelFunc
inited bool
}
func NewClient(id string, conn *websocket.Conn, hub *Hub, parentCtx context.Context) *Client {
ctx, cancel := context.WithCancel(parentCtx)
return &Client{
ID: id,
Conn: conn,
SendChan: make(chan protocol.BroadcastMessage, 32),
Hub: hub,
Ctx: ctx,
Cancel: cancel,
inited: false,
}
}
func (c *Client) ReadLoop() {
defer c.Close()
for {
var msg protocol.ControlMessage
err := wsjson.Read(c.Ctx, c.Conn, &msg)
if err != nil {
if websocket.CloseStatus(err) == websocket.StatusNormalClosure {
log.Println("WebSocket closed normally:", err)
} else {
log.Println("[Server Client.ReadLoop] WebSocket read error:", err)
}
return
}
if err := msg.Validate(); err != nil {
_ = c.Conn.Close(websocket.StatusPolicyViolation, "invalid message")
return
}
if err := c.handleControlMessage(msg); err != nil {
return
}
}
}
func (c *Client) WriteLoop() {
defer c.Close()
for {
select {
case <-c.Ctx.Done():
return
case msg, ok := <-c.SendChan:
if !ok {
return
}
log.Printf("Sending message to client %s: %+v", c.ID, msg)
err := wsjson.Write(c.Ctx, c.Conn, msg)
if err != nil {
log.Println("[Server Client.WriteLoop] WebSocket write error:", err)
return
}
}
}
}
func (c *Client) handleControlMessage(msg protocol.ControlMessage) error {
switch msg.Type {
case protocol.MsgInit:
if c.inited {
return errors.New("already initialized")
}
log.Printf("Client %s initializing with topics: %v", c.ID, msg.Topics)
c.Hub.RegisterClient(c)
for _, t := range msg.Topics {
c.Hub.Subscribe(protocol.Subscription{
ClientID: c.ID,
Topic: protocol.Topic(t),
})
}
c.inited = true
return nil
default:
if !c.inited {
return errors.New("client not initialized")
}
log.Println("Unknown control message type:", msg.Type)
return nil
}
}
func (c *Client) Close() {
c.Cancel()
c.Hub.UnregisterClient(c)
if c.Conn != nil {
_ = c.Conn.Close(websocket.StatusNormalClosure, "client closed")
}
}

View File

@ -0,0 +1,32 @@
package ws
import (
"context"
"log"
"net/http"
"time"
"github.com/coder/websocket"
)
func Handler(ctx context.Context, h *Hub) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
conn, err := websocket.Accept(w, r, &websocket.AcceptOptions{
OriginPatterns: []string{"*"},
})
log.Println("New WebSocket connection from", r.RemoteAddr, "at", time.Now().Format(time.RFC3339))
if err != nil {
log.Println("WebSocket accept error:", err)
return
}
c := NewClient(r.RemoteAddr, conn, h, ctx)
log.Println("Client", r.RemoteAddr, "connected.")
go c.ReadLoop()
go c.WriteLoop()
go heartbeat(c)
}
}

View File

@ -4,7 +4,7 @@ import (
"context" "context"
"log" "log"
"git.jinshen.cn/remilia/push-server/interval/model" "git.jinshen.cn/remilia/push-server/interval/protocol"
"github.com/coder/websocket" "github.com/coder/websocket"
) )
@ -17,28 +17,28 @@ type Hub struct {
// Invariant: // Invariant:
// - clientsByTopic contains only topics with at least one active subscriber. // - clientsByTopic contains only topics with at least one active subscriber.
// - A topic key must not exist if it has zero clients. // - A topic key must not exist if it has zero clients.
clientsByTopic map[model.Topic]map[string]*Client clientsByTopic map[protocol.Topic]map[string]*Client
// clientID -> topic -> struct{} // clientID -> topic -> struct{}
topicsByClients map[string]map[model.Topic]struct{} topicsByClients map[string]map[protocol.Topic]struct{}
register chan *Client register chan *Client
unregister chan *Client unregister chan *Client
subscribe chan model.Subscription subscribe chan protocol.Subscription
unsubscribe chan model.Subscription unsubscribe chan protocol.Subscription
broadcast chan model.Message broadcast chan protocol.BroadcastMessage
} }
func NewHub() *Hub { func NewHub() *Hub {
return &Hub{ return &Hub{
clientsByID: make(map[string]*Client), clientsByID: make(map[string]*Client),
clientsByTopic: make(map[model.Topic]map[string]*Client), clientsByTopic: make(map[protocol.Topic]map[string]*Client),
topicsByClients: make(map[string]map[model.Topic]struct{}), topicsByClients: make(map[string]map[protocol.Topic]struct{}),
register: make(chan *Client), register: make(chan *Client),
unregister: make(chan *Client), unregister: make(chan *Client),
subscribe: make(chan model.Subscription, 8), subscribe: make(chan protocol.Subscription, 8),
unsubscribe: make(chan model.Subscription, 8), unsubscribe: make(chan protocol.Subscription, 8),
broadcast: make(chan model.Message, 64), broadcast: make(chan protocol.BroadcastMessage, 64),
} }
} }
@ -50,15 +50,15 @@ func (h *Hub) UnregisterClient(client *Client) {
h.unregister <- client h.unregister <- client
} }
func (h *Hub) Subscribe(sub model.Subscription) { func (h *Hub) Subscribe(sub protocol.Subscription) {
h.subscribe <- sub h.subscribe <- sub
} }
func (h *Hub) Unsubscribe(sub model.Subscription) { func (h *Hub) Unsubscribe(sub protocol.Subscription) {
h.unsubscribe <- sub h.unsubscribe <- sub
} }
func (h *Hub) BroadcastMessage(ctx context.Context, msg model.Message) error { func (h *Hub) BroadcastMessage(ctx context.Context, msg protocol.BroadcastMessage) error {
select { select {
case h.broadcast <- msg: case h.broadcast <- msg:
return nil return nil
@ -103,7 +103,7 @@ func (h *Hub) getClient(id string) (*Client, bool) {
// Create a new entry for the client in topicsByClients map when it registers. // Create a new entry for the client in topicsByClients map when it registers.
func (h *Hub) onRegister(c *Client) { func (h *Hub) onRegister(c *Client) {
h.clientsByID[c.ID] = c h.clientsByID[c.ID] = c
h.topicsByClients[c.ID] = make(map[model.Topic]struct{}) h.topicsByClients[c.ID] = make(map[protocol.Topic]struct{})
log.Printf("Current clientList: %v\n", h.clientsByID) log.Printf("Current clientList: %v\n", h.clientsByID)
} }
@ -122,7 +122,7 @@ func (h *Hub) onUnregister(c *Client) {
delete(h.clientsByID, c.ID) delete(h.clientsByID, c.ID)
} }
func (h *Hub) onSubscribe(s model.Subscription) { func (h *Hub) onSubscribe(s protocol.Subscription) {
c, ok := h.getClient(s.ClientID) c, ok := h.getClient(s.ClientID)
if !ok { if !ok {
// If the client does not exist, log an error and return. // If the client does not exist, log an error and return.
@ -135,9 +135,10 @@ func (h *Hub) onSubscribe(s model.Subscription) {
h.clientsByTopic[s.Topic][s.ClientID] = c h.clientsByTopic[s.Topic][s.ClientID] = c
h.topicsByClients[s.ClientID][s.Topic] = struct{}{} h.topicsByClients[s.ClientID][s.Topic] = struct{}{}
log.Printf("Client %s subscribed to topic %s", s.ClientID, s.Topic)
} }
func (h *Hub) onUnsubscribe(s model.Subscription) { func (h *Hub) onUnsubscribe(s protocol.Subscription) {
if clients, ok := h.clientsByTopic[s.Topic]; ok { if clients, ok := h.clientsByTopic[s.Topic]; ok {
delete(clients, s.ClientID) delete(clients, s.ClientID)
if len(clients) == 0 { if len(clients) == 0 {
@ -150,15 +151,16 @@ func (h *Hub) onUnsubscribe(s model.Subscription) {
} }
} }
func (h *Hub) onBroadcast(msg model.Message) { func (h *Hub) onBroadcast(msg protocol.BroadcastMessage) {
if !msg.Topic.Valid() { if !msg.Topic.Valid() {
log.Printf("Broadcast failed: invalid topic") log.Printf("Broadcast failed: invalid topic")
return return
} }
log.Printf("Receiving message for topic %s: %s", msg.Topic, string(msg.Content)) log.Printf("Receiving message for topic %s: %s", msg.Topic, string(msg.Payload))
for _, c := range h.clientsByTopic[msg.Topic] { for _, c := range h.clientsByTopic[msg.Topic] {
select { select {
case c.SendChan <- msg.Content: case c.SendChan <- msg:
log.Printf("[%d] Sending message to client %s: [%s] %v", msg.Timestamp, c.ID, msg.Type, string(msg.Payload))
default: default:
h.UnregisterClient(c) h.UnregisterClient(c)
if c.Conn != nil { if c.Conn != nil {

View File

@ -1,58 +0,0 @@
package ws
import (
"context"
"log"
"time"
"github.com/coder/websocket"
)
// Client represents a connected client in the hub.
type Client struct {
ID string
Conn *websocket.Conn
SendChan chan []byte
Ctx context.Context
Cancel context.CancelFunc
}
func NewClient(id string, conn *websocket.Conn, parentCtx context.Context) *Client {
ctx, cancel := context.WithCancel(parentCtx)
return &Client{
ID: id,
Conn: conn,
SendChan: make(chan []byte, 32),
Ctx: ctx,
Cancel: cancel,
}
}
// Write message to websocket connection.
func (c *Client) WriteMessage() {
defer func() {
_ = c.Conn.Close(websocket.StatusNormalClosure, "client writer closed")
}()
for {
select {
case <-c.Ctx.Done():
return
case msg, ok := <-c.SendChan:
if !ok {
return
}
writeCtx, cancel := context.WithTimeout(c.Ctx, 5*time.Second)
err := c.Conn.Write(writeCtx, websocket.MessageText, msg)
cancel()
if err != nil {
log.Println("WebSocket write error:", err)
return
}
}
}
}

View File

@ -1,75 +0,0 @@
package ws
import (
"context"
"io"
"log"
"net/http"
"time"
"github.com/coder/websocket"
)
func Handler(ctx context.Context, h *Hub) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
conn, err := websocket.Accept(w, r, &websocket.AcceptOptions{
OriginPatterns: []string{"*"},
})
log.Println("New WebSocket connection from", r.RemoteAddr, "at", time.Now().Format(time.RFC3339))
if err != nil {
log.Println("WebSocket accept error:", err)
return
}
c := NewClient(r.RemoteAddr, conn, ctx)
log.Println("Client", r.RemoteAddr, "connected.")
h.RegisterClient(c)
go echo(c, h)
go heartbeat(c)
}
}
func echo(c *Client, h *Hub) {
defer func() {
if c.Conn != nil {
log.Println("Closing WebSocket connection")
h.UnregisterClient(c)
c.Cancel()
_ = c.Conn.Close(websocket.StatusNormalClosure, "echo finished")
}
}()
for {
typ, r, err := c.Conn.Reader(c.Ctx)
if err != nil {
if websocket.CloseStatus(err) == websocket.StatusNormalClosure {
log.Println("WebSocket connection closed normally")
} else {
log.Println("WebSocket reader error:", err)
}
return
}
w, err := c.Conn.Writer(c.Ctx, typ)
if err != nil {
log.Println("WebSocket writer error:", err)
return
}
_, err = io.Copy(w, r)
if err != nil {
log.Println("WebSocket copy error:", err)
return
}
if err = w.Close(); err != nil {
log.Println("WebSocket writer close error:", err)
return
}
}
}