focalboard/server/ws/websockets.go
2020-10-22 15:22:36 +02:00

192 lines
4.6 KiB
Go

package ws
import (
"encoding/json"
"log"
"net/http"
"sync"
"github.com/gorilla/mux"
"github.com/gorilla/websocket"
)
// RegisterRoutes registers routes.
func (ws *Server) RegisterRoutes(r *mux.Router) {
r.HandleFunc("/ws/onchange", ws.handleWebSocketOnChange)
}
// AddListener adds a listener for a block's change.
func (ws *Server) AddListener(client *websocket.Conn, blockIDs []string) {
ws.mu.Lock()
for _, blockID := range blockIDs {
if ws.listeners[blockID] == nil {
ws.listeners[blockID] = []*websocket.Conn{}
}
ws.listeners[blockID] = append(ws.listeners[blockID], client)
}
ws.mu.Unlock()
}
// RemoveListener removes a webSocket listener from all blocks.
func (ws *Server) RemoveListener(client *websocket.Conn) {
ws.mu.Lock()
for key, clients := range ws.listeners {
listeners := []*websocket.Conn{}
for _, existingClient := range clients {
if client != existingClient {
listeners = append(listeners, existingClient)
}
}
ws.listeners[key] = listeners
}
ws.mu.Unlock()
}
// RemoveListenerFromBlocks removes a webSocket listener from a set of block.
func (ws *Server) RemoveListenerFromBlocks(client *websocket.Conn, blockIDs []string) {
ws.mu.Lock()
for _, blockID := range blockIDs {
listeners := ws.listeners[blockID]
if listeners == nil {
return
}
// Remove the first instance of this client that's listening to this block
// Note: A client can listen multiple times to the same block
for index, listener := range listeners {
if client == listener {
newListeners := append(listeners[:index], listeners[index+1:]...)
ws.listeners[blockID] = newListeners
break
}
}
}
ws.mu.Unlock()
}
// GetListeners returns the listeners to a blockID's changes.
func (ws *Server) GetListeners(blockID string) []*websocket.Conn {
ws.mu.Lock()
listeners := ws.listeners[blockID]
ws.mu.Unlock()
return listeners
}
// Server is a WebSocket server.
type Server struct {
upgrader websocket.Upgrader
listeners map[string][]*websocket.Conn
mu sync.RWMutex
}
// NewServer creates a new Server.
func NewServer() *Server {
return &Server{
listeners: make(map[string][]*websocket.Conn),
upgrader: websocket.Upgrader{
CheckOrigin: func(r *http.Request) bool {
return true
},
},
}
}
// WebsocketMsg is sent on block changes.
type WebsocketMsg struct {
Action string `json:"action"`
BlockID string `json:"blockId"`
}
// WebsocketCommand is an incoming command from the client.
type WebsocketCommand struct {
Action string `json:"action"`
BlockIDs []string `json:"blockIds"`
}
func (ws *Server) handleWebSocketOnChange(w http.ResponseWriter, r *http.Request) {
// Upgrade initial GET request to a websocket
client, err := ws.upgrader.Upgrade(w, r, nil)
if err != nil {
log.Fatal(err)
}
// TODO: Auth
log.Printf("CONNECT WebSocket onChange, client: %s", client.RemoteAddr())
// Make sure we close the connection when the function returns
defer func() {
log.Printf("DISCONNECT WebSocket onChange, client: %s", client.RemoteAddr())
// Remove client from listeners
ws.RemoveListener(client)
client.Close()
}()
// Simple message handling loop
for {
_, p, err := client.ReadMessage()
if err != nil {
log.Printf("ERROR WebSocket onChange, client: %s, err: %v", client.RemoteAddr(), err)
ws.RemoveListener(client)
break
}
var command WebsocketCommand
err = json.Unmarshal(p, &command)
if err != nil {
// handle this error
log.Printf(`ERROR webSocket parsing command JSON: %v`, string(p))
continue
}
switch command.Action {
case "ADD":
log.Printf(`Command: Add blockID: %v, client: %s`, command.BlockIDs, client.RemoteAddr())
ws.AddListener(client, command.BlockIDs)
case "REMOVE":
log.Printf(`Command: Remove blockID: %v, client: %s`, command.BlockIDs, client.RemoteAddr())
ws.RemoveListenerFromBlocks(client, command.BlockIDs)
default:
log.Printf(`ERROR webSocket command, invalid action: %v`, command.Action)
}
}
}
// BroadcastBlockChangeToWebsocketClients broadcasts change to clients.
func (ws *Server) BroadcastBlockChangeToWebsocketClients(blockIDs []string) {
for _, blockID := range blockIDs {
listeners := ws.GetListeners(blockID)
log.Printf("%d listener(s) for blockID: %s", len(listeners), blockID)
if listeners != nil {
message := WebsocketMsg{
Action: "UPDATE_BLOCK",
BlockID: blockID,
}
for _, listener := range listeners {
log.Printf("Broadcast change, blockID: %s, remoteAddr: %s", blockID, listener.RemoteAddr())
err := listener.WriteJSON(message)
if err != nil {
log.Printf("broadcast error: %v", err)
listener.Close()
}
}
}
}
}