From 0bd6233facf0864803d2c88d4893e57dbdf8d0af Mon Sep 17 00:00:00 2001 From: Miguel de la Cruz Date: Mon, 19 Jul 2021 11:34:17 +0200 Subject: [PATCH] Refactor the websockets connection messages and lifecycle (#749) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Refactor the websockets connection messages and lifecycle * Modify frontend to subscribe to a workspace instead of a set of blocks * Fixing linter errors Co-authored-by: Jesús Espino --- server/ws/websockets.go | 548 +++++++++++++++++++++-------------- server/ws/websockets_test.go | 231 +++++++++++++++ webapp/src/octoListener.ts | 26 +- 3 files changed, 585 insertions(+), 220 deletions(-) create mode 100644 server/ws/websockets_test.go diff --git a/server/ws/websockets.go b/server/ws/websockets.go index 249a8a873..75813f860 100644 --- a/server/ws/websockets.go +++ b/server/ws/websockets.go @@ -15,8 +15,13 @@ import ( "github.com/mattermost/focalboard/server/services/store" ) -// IsValidSessionToken authenticates session tokens. -type IsValidSessionToken func(token string) bool +const ( + singleUserID = "single-user-id" + websocketActionAuth = "AUTH" + websocketActionSubscribeWorkspace = "SUBSCRIBE_WORKSPACE" + websocketActionUnsubscribeWorkspace = "UNSUBSCRIBE_WORKSPACE" + websocketActionSubscribeBlocks = "SUBSCRIBE_BLOCKS" +) type Hub interface { SendWSMessage(data []byte) @@ -25,7 +30,9 @@ type Hub interface { type wsClient struct { *websocket.Conn - lock *sync.RWMutex + lock *sync.Mutex + workspaces []string + blocks []string } func (c *wsClient) WriteJSON(v interface{}) error { @@ -35,16 +42,38 @@ func (c *wsClient) WriteJSON(v interface{}) error { return err } +func (c *wsClient) isSubscribedToWorkspace(workspaceID string) bool { + for _, id := range c.workspaces { + if id == workspaceID { + return true + } + } + + return false +} + +func (c *wsClient) isSubscribedToBlock(blockID string) bool { + for _, id := range c.blocks { + if id == blockID { + return true + } + } + + return false +} + // Server is a WebSocket server. type Server struct { - upgrader websocket.Upgrader - listeners map[string][]*wsClient - mu sync.RWMutex - auth *auth.Auth - hub Hub - singleUserToken string - isMattermostAuth bool - logger *mlog.Logger + upgrader websocket.Upgrader + listeners map[*wsClient]bool + listenersByWorkspace map[string][]*wsClient + listenersByBlock map[string][]*wsClient + mu sync.RWMutex + auth *auth.Auth + hub Hub + singleUserToken string + isMattermostAuth bool + logger *mlog.Logger } // UpdateMsg is sent on block updates. @@ -60,11 +89,6 @@ type clusterUpdateMsg struct { WorkspaceID string `json:"workspace_id"` } -// ErrorMsg is sent on errors. -type ErrorMsg struct { - Error string `json:"error"` -} - // WebsocketCommand is an incoming command from the client. type WebsocketCommand struct { Action string `json:"action"` @@ -75,15 +99,20 @@ type WebsocketCommand struct { } type websocketSession struct { - client *wsClient - isAuthenticated bool - workspaceID string + client *wsClient + userID string +} + +func (wss *websocketSession) isAuthenticated() bool { + return wss.userID != "" } // NewServer creates a new Server. func NewServer(auth *auth.Auth, singleUserToken string, isMattermostAuth bool, logger *mlog.Logger) *Server { return &Server{ - listeners: make(map[string][]*wsClient), + listeners: make(map[*wsClient]bool), + listenersByWorkspace: make(map[string][]*wsClient), + listenersByBlock: make(map[string][]*wsClient), upgrader: websocket.Upgrader{ CheckOrigin: func(r *http.Request) bool { return true @@ -98,10 +127,10 @@ func NewServer(auth *auth.Auth, singleUserToken string, isMattermostAuth bool, l // RegisterRoutes registers routes. func (ws *Server) RegisterRoutes(r *mux.Router) { - r.HandleFunc("/ws/onchange", ws.handleWebSocketOnChange) + r.HandleFunc("/ws", ws.handleWebSocket) } -func (ws *Server) handleWebSocketOnChange(w http.ResponseWriter, r *http.Request) { +func (ws *Server) handleWebSocket(w http.ResponseWriter, r *http.Request) { // Upgrade initial GET request to a websocket client, err := ws.upgrader.Upgrade(w, r, nil) if err != nil { @@ -109,19 +138,21 @@ func (ws *Server) handleWebSocketOnChange(w http.ResponseWriter, r *http.Request return } - userID := "" - if ws.isMattermostAuth { - userID = r.Header.Get("Mattermost-User-Id") + // create an empty session with websocket client + wsSession := websocketSession{ + client: &wsClient{client, &sync.Mutex{}, []string{}, []string{}}, + userID: "", } - wsSession := websocketSession{ - client: &wsClient{client, &sync.RWMutex{}}, - isAuthenticated: userID != "", + if ws.isMattermostAuth { + wsSession.userID = r.Header.Get("Mattermost-User-Id") } + ws.addListener(wsSession.client) + // Make sure we close the connection when the function returns defer func() { - ws.logger.Debug("DISCONNECT WebSocket onChange", mlog.Stringer("client", wsSession.client.RemoteAddr())) + ws.logger.Debug("DISCONNECT WebSocket", mlog.Stringer("client", wsSession.client.RemoteAddr())) // Remove client from listeners ws.removeListener(wsSession.client) @@ -132,7 +163,7 @@ func (ws *Server) handleWebSocketOnChange(w http.ResponseWriter, r *http.Request for { _, p, err := wsSession.client.ReadMessage() if err != nil { - ws.logger.Error("ERROR WebSocket onChange", + ws.logger.Error("ERROR WebSocket", mlog.Stringer("client", wsSession.client.RemoteAddr()), mlog.Err(err), ) @@ -150,199 +181,277 @@ func (ws *Server) handleWebSocketOnChange(w http.ResponseWriter, r *http.Request continue } - if userID != "" { - if ws.auth.DoesUserHaveWorkspaceAccess(userID, command.WorkspaceID) { - wsSession.workspaceID = command.WorkspaceID - } else { - ws.logger.Error(`ERROR User doesn't have permissions to read the workspace`, mlog.String("workspaceID", command.WorkspaceID)) + if command.Action == websocketActionAuth { + ws.logger.Debug(`Command: AUTH`, mlog.Stringer("client", wsSession.client.RemoteAddr())) + ws.authenticateListener(&wsSession, command.Token) + + continue + } + + // if the client wants to subscribe to a set of blocks and it + // is sending a read token, we don't need to check for + // authentication + if command.Action == websocketActionSubscribeBlocks { + ws.logger.Debug(`Command: SUBSCRIBE_BLOCKS`, + mlog.String("workspaceID", command.WorkspaceID), + mlog.Stringer("client", wsSession.client.RemoteAddr()), + ) + + if !ws.isCommandReadTokenValid(command) { + ws.logger.Error(`Rejected invalid read token`, + mlog.Stringer("client", wsSession.client.RemoteAddr()), + mlog.String("action", command.Action), + mlog.String("readToken", command.ReadToken), + ) + continue } + + ws.subscribeListenerToBlocks(wsSession.client, command.BlockIDs) + continue + } + + // if the command is not authenticated at this point, it will + // not be processed + if !wsSession.isAuthenticated() { + ws.logger.Error(`Rejected unauthenticated message`, + mlog.Stringer("client", wsSession.client.RemoteAddr()), + mlog.String("action", command.Action), + ) + + continue } switch command.Action { - case "AUTH": - ws.logger.Debug(`Command: AUTH`, mlog.Stringer("client", wsSession.client.RemoteAddr())) - ws.authenticateListener(&wsSession, command.WorkspaceID, command.Token) - case "ADD": - ws.logger.Debug(`Command: ADD`, - mlog.String("workspaceID", wsSession.workspaceID), - mlog.Array("blockIDs", command.BlockIDs), - mlog.Stringer("client", wsSession.client.RemoteAddr()), - ) - ws.addListener(&wsSession, &command) - case "REMOVE": - ws.logger.Debug(`Command: REMOVE`, - mlog.String("workspaceID", wsSession.workspaceID), - mlog.Array("blockIDs", command.BlockIDs), + case websocketActionSubscribeWorkspace: + ws.logger.Debug(`Command: SUBSCRIBE_WORKSPACE`, + mlog.String("workspaceID", command.WorkspaceID), mlog.Stringer("client", wsSession.client.RemoteAddr()), ) - ws.removeListenerFromBlocks(&wsSession, &command) + // if single user mode, check that the userID is valid and + // assume that the user has permission if so + if len(ws.singleUserToken) != 0 { + if wsSession.userID != singleUserID { + continue + } + + // if not in single user mode validate that the session + // has permissions to the workspace + } else { + if !ws.auth.DoesUserHaveWorkspaceAccess(wsSession.userID, command.WorkspaceID) { + continue + } + } + + ws.subscribeListenerToWorkspace(wsSession.client, command.WorkspaceID) + case websocketActionUnsubscribeWorkspace: + ws.logger.Debug(`Command: UNSUBSCRIBE_WORKSPACE`, + mlog.String("workspaceID", command.WorkspaceID), + mlog.Stringer("client", wsSession.client.RemoteAddr()), + ) + + ws.unsubscribeListenerFromWorkspace(wsSession.client, command.WorkspaceID) default: ws.logger.Error(`ERROR webSocket command, invalid action`, mlog.String("action", command.Action)) } } } -func (ws *Server) isValidSessionToken(token, workspaceID string) bool { +// isCommandReadTokenValid ensures that a command contains a read +// token and a set of block ids that said token is valid for. +func (ws *Server) isCommandReadTokenValid(command WebsocketCommand) bool { + if len(command.WorkspaceID) == 0 { + return false + } + + container := store.Container{WorkspaceID: command.WorkspaceID} + + if len(command.ReadToken) != 0 && len(command.BlockIDs) != 0 { + // Read token must be valid for all block IDs + for _, blockID := range command.BlockIDs { + isValid, _ := ws.auth.IsValidReadToken(container, blockID, command.ReadToken) + if !isValid { + return false + } + } + return true + } + + return false +} + +// addListener adds a listener to the websocket server. The listener +// should not receive any update from the server until it subscribes +// itself to some entity changes. Adding a listener to the server +// doesn't mean that it's authenticated in any way. +func (ws *Server) addListener(client *wsClient) { + ws.mu.Lock() + defer ws.mu.Unlock() + ws.listeners[client] = true +} + +// removeListener removes a listener and all its subscriptions, if +// any, from the websockets server. +func (ws *Server) removeListener(client *wsClient) { + ws.mu.Lock() + defer ws.mu.Unlock() + + // remove the listener from its subscriptions, if any + + // workspace subscriptions + for _, workspace := range client.workspaces { + ws.removeListenerFromWorkspace(client, workspace) + } + + // block subscriptions + for _, block := range client.blocks { + ws.removeListenerFromBlock(client, block) + } + + delete(ws.listeners, client) +} + +// subscribeListenerToWorkspace safely modifies the listener and the +// server to subscribe the listener to a given workspace updates. +func (ws *Server) subscribeListenerToWorkspace(client *wsClient, workspaceID string) { + if client.isSubscribedToWorkspace(workspaceID) { + return + } + + ws.mu.Lock() + defer ws.mu.Unlock() + + ws.listenersByWorkspace[workspaceID] = append(ws.listenersByWorkspace[workspaceID], client) + client.workspaces = append(client.workspaces, workspaceID) +} + +// unsubscribeListenerFromWorkspace safely modifies the listener and +// the server data structures to remove the link between the listener +// and a given workspace ID. +func (ws *Server) unsubscribeListenerFromWorkspace(client *wsClient, workspaceID string) { + if !client.isSubscribedToWorkspace(workspaceID) { + return + } + + ws.mu.Lock() + defer ws.mu.Unlock() + + ws.removeListenerFromWorkspace(client, workspaceID) +} + +// subscribeListenerToBlocks safely modifies the listener and the +// server to subscribe the listener to a given set of block updates. +func (ws *Server) subscribeListenerToBlocks(client *wsClient, blockIDs []string) { + ws.mu.Lock() + defer ws.mu.Unlock() + + for _, blockID := range blockIDs { + if client.isSubscribedToBlock(blockID) { + continue + } + + ws.listenersByBlock[blockID] = append(ws.listenersByBlock[blockID], client) + client.blocks = append(client.blocks, blockID) + } +} + +// unsubscribeListenerFromBlocks safely modifies the listener and the +// server data structures to remove the link between the listener and +// a given set of block IDs. +func (ws *Server) unsubscribeListenerFromBlocks(client *wsClient, blockIDs []string) { + ws.mu.Lock() + defer ws.mu.Unlock() + + for _, blockID := range blockIDs { + if client.isSubscribedToBlock(blockID) { + ws.removeListenerFromBlock(client, blockID) + } + } +} + +// removeListenerFromWorkspace removes the listener from both its own +// block subscribed list and the server listeners by workspace map. +func (ws *Server) removeListenerFromWorkspace(client *wsClient, workspaceID string) { + // we remove the listener from the workspace index + newWorkspaceListeners := []*wsClient{} + for _, listener := range ws.listenersByWorkspace[workspaceID] { + if listener != client { + newWorkspaceListeners = append(newWorkspaceListeners, listener) + } + } + ws.listenersByWorkspace[workspaceID] = newWorkspaceListeners + + // we remove the workspace from the listener subscription list + newClientWorkspaces := []string{} + for _, id := range client.workspaces { + if id != workspaceID { + newClientWorkspaces = append(newClientWorkspaces, id) + } + } + client.workspaces = newClientWorkspaces +} + +// removeListenerFromBlock removes the listener from both its own +// block subscribed list and the server listeners by block map. +func (ws *Server) removeListenerFromBlock(client *wsClient, blockID string) { + // we remove the listener from the block index + newBlockListeners := []*wsClient{} + for _, listener := range ws.listenersByBlock[blockID] { + if listener != client { + newBlockListeners = append(newBlockListeners, listener) + } + } + ws.listenersByBlock[blockID] = newBlockListeners + + // we remove the block from the listener subscription list + newClientBlocks := []string{} + for _, id := range client.blocks { + if id != blockID { + newClientBlocks = append(newClientBlocks, id) + } + } + client.blocks = newClientBlocks +} + +func (ws *Server) getUserIDForToken(token string) string { if len(ws.singleUserToken) > 0 { - return token == ws.singleUserToken + if token == ws.singleUserToken { + return singleUserID + } else { + return "" + } } session, err := ws.auth.GetSession(token) if session == nil || err != nil { - return false + return "" } - // Check workspace permission - return ws.auth.DoesUserHaveWorkspaceAccess(session.UserID, workspaceID) + return session.UserID } -func (ws *Server) authenticateListener(wsSession *websocketSession, workspaceID, token string) { - if wsSession.isAuthenticated { +func (ws *Server) authenticateListener(wsSession *websocketSession, token string) { + if wsSession.isAuthenticated() { // Do not allow multiple auth calls (for security) - ws.logger.Debug("authenticateListener: Ignoring already authenticated session", mlog.String("workspaceID", workspaceID)) + ws.logger.Debug( + "authenticateListener: Ignoring already authenticated session", + mlog.String("userID", wsSession.userID), + mlog.Stringer("client", wsSession.client.RemoteAddr()), + ) return } // Authenticate session - isValidSession := ws.isValidSessionToken(token, workspaceID) - if !isValidSession { + userID := ws.getUserIDForToken(token) + if userID == "" { wsSession.client.Close() return } // Authenticated - - wsSession.workspaceID = workspaceID - wsSession.isAuthenticated = true - ws.logger.Debug("authenticateListener: Authenticated", mlog.String("workspaceID", workspaceID)) -} - -type AuthWorkspaceError struct { - msg string -} - -func (awe AuthWorkspaceError) Error() string { - return awe.msg -} - -func (ws *Server) getAuthenticatedWorkspaceID(wsSession *websocketSession, command *WebsocketCommand) (string, error) { - if wsSession.isAuthenticated { - return wsSession.workspaceID, nil - } - - // If not authenticated, try to authenticate the read token against the supplied workspaceID - workspaceID := command.WorkspaceID - if len(workspaceID) == 0 { - ws.logger.Error("getAuthenticatedWorkspaceID: No workspace") - return "", AuthWorkspaceError{"no workspace"} - } - - container := store.Container{ - WorkspaceID: workspaceID, - } - - if len(command.ReadToken) > 0 { - // Read token must be valid for all block IDs - for _, blockID := range command.BlockIDs { - isValid, _ := ws.auth.IsValidReadToken(container, blockID, command.ReadToken) - if !isValid { - return "", AuthWorkspaceError{"invalid read token for workspace"} - } - } - return workspaceID, nil - } - - return "", AuthWorkspaceError{"no read token"} -} - -// TODO: Refactor workspace hashing. -func makeItemID(workspaceID, blockID string) string { - return workspaceID + "-" + blockID -} - -// addListener adds a listener for a block's change. -func (ws *Server) addListener(wsSession *websocketSession, command *WebsocketCommand) { - workspaceID, err := ws.getAuthenticatedWorkspaceID(wsSession, command) - if err != nil { - ws.logger.Error("addListener: NOT AUTHENTICATED", mlog.Err(err)) - ws.sendError(wsSession.client, "not authenticated") - return - } - - ws.mu.Lock() - for _, blockID := range command.BlockIDs { - itemID := makeItemID(workspaceID, blockID) - if ws.listeners[itemID] == nil { - ws.listeners[itemID] = []*wsClient{} - } - - ws.listeners[itemID] = append(ws.listeners[itemID], wsSession.client) - } - ws.mu.Unlock() -} - -// removeListener removes a webSocket listener from all blocks. -func (ws *Server) removeListener(client *wsClient) { - ws.mu.Lock() - for key, clients := range ws.listeners { - listeners := []*wsClient{} - - for _, existingClient := range clients { - if client != existingClient { - listeners = append(listeners, existingClient) - } - } - - ws.listeners[key] = listeners - } - ws.mu.Unlock() -} - -// removeListenerFromBlocks removes a webSocket listener from a set of blocks. -func (ws *Server) removeListenerFromBlocks(wsSession *websocketSession, command *WebsocketCommand) { - workspaceID, err := ws.getAuthenticatedWorkspaceID(wsSession, command) - if err != nil { - ws.logger.Error("addListener: NOT AUTHENTICATED", mlog.Err(err)) - ws.sendError(wsSession.client, "not authenticated") - return - } - - ws.mu.Lock() - for _, blockID := range command.BlockIDs { - itemID := makeItemID(workspaceID, blockID) - listeners := ws.listeners[itemID] - if listeners == nil { - return - } - - // Remove the first instance of this client that's listening to this block - // Note: A client can listen multiple times to the same block - for index, listener := range listeners { - if wsSession.client == listener { - newListeners := listeners[:index] - newListeners = append(newListeners, listeners[index+1:]...) - ws.listeners[itemID] = newListeners - - break - } - } - } - - ws.mu.Unlock() -} - -func (ws *Server) sendError(wsClient *wsClient, message string) { - errorMsg := ErrorMsg{ - Error: message, - } - - err := wsClient.WriteJSON(errorMsg) - if err != nil { - ws.logger.Error("sendError error", mlog.Err(err)) - wsClient.Close() - } + wsSession.userID = userID + ws.logger.Debug("authenticateListener: Authenticated", mlog.String("userID", userID), mlog.Stringer("client", wsSession.client.RemoteAddr())) } func (ws *Server) SetHub(hub Hub) { @@ -355,8 +464,10 @@ func (ws *Server) SetHub(hub Hub) { return } - listeners := ws.getListeners(msg.WorkspaceID, msg.BlockID) + listeners := ws.getListenersForBlock(msg.BlockID) log.Printf("%d listener(s) for blockID: %s", len(listeners), msg.BlockID) + listeners = append(listeners, ws.getListenersForWorkspace(msg.WorkspaceID)...) + log.Printf("%d listener(s) for workspaceID: %s", len(listeners), msg.WorkspaceID) message := UpdateMsg{ Action: msg.Action, @@ -375,14 +486,16 @@ func (ws *Server) SetHub(hub Hub) { }) } -// getListeners returns the listeners to a blockID's changes. -func (ws *Server) getListeners(workspaceID string, blockID string) []*wsClient { - ws.mu.Lock() - itemID := makeItemID(workspaceID, blockID) - listeners := ws.listeners[itemID] - ws.mu.Unlock() +// getListenersForBlock returns the listeners subscribed to a +// block changes. +func (ws *Server) getListenersForBlock(blockID string) []*wsClient { + return ws.listenersByBlock[blockID] +} - return listeners +// getListenersForWorkspace returns the listeners subscribed to a +// workspace changes. +func (ws *Server) getListenersForWorkspace(workspaceID string) []*wsClient { + return ws.listenersByWorkspace[workspaceID] } // BroadcastBlockDelete broadcasts delete messages to clients. @@ -401,17 +514,24 @@ func (ws *Server) BroadcastBlockDelete(workspaceID, blockID, parentID string) { func (ws *Server) BroadcastBlockChange(workspaceID string, block model.Block) { blockIDsToNotify := []string{block.ID, block.ParentID} + message := UpdateMsg{ + Action: "UPDATE_BLOCK", + Block: block, + } + + listeners := ws.getListenersForWorkspace(workspaceID) + ws.logger.Debug("listener(s) for workspaceID", + mlog.Int("listener_count", len(listeners)), + mlog.String("workspaceID", workspaceID), + ) + for _, blockID := range blockIDsToNotify { - listeners := ws.getListeners(workspaceID, blockID) + listeners = append(listeners, ws.getListenersForBlock(blockID)...) ws.logger.Debug("listener(s) for blockID", mlog.Int("listener_count", len(listeners)), mlog.String("blockID", blockID), ) - message := UpdateMsg{ - Action: "UPDATE_BLOCK", - Block: block, - } if ws.hub != nil { data, err := json.Marshal(clusterUpdateMsg{UpdateMsg: message, WorkspaceID: workspaceID, BlockID: blockID}) if err != nil { @@ -419,19 +539,19 @@ func (ws *Server) BroadcastBlockChange(workspaceID string, block model.Block) { } ws.hub.SendWSMessage(data) } + } - for _, listener := range listeners { - ws.logger.Debug("Broadcast change", - mlog.String("workspaceID", workspaceID), - mlog.String("blockID", blockID), - mlog.Stringer("remoteAddr", listener.RemoteAddr()), - ) + for _, listener := range listeners { + ws.logger.Debug("Broadcast change", + mlog.String("workspaceID", workspaceID), + mlog.String("blockID", block.ID), + mlog.Stringer("remoteAddr", listener.RemoteAddr()), + ) - err := listener.WriteJSON(message) - if err != nil { - ws.logger.Error("broadcast error", mlog.Err(err)) - listener.Close() - } + err := listener.WriteJSON(message) + if err != nil { + ws.logger.Error("broadcast error", mlog.Err(err)) + listener.Close() } } } diff --git a/server/ws/websockets_test.go b/server/ws/websockets_test.go new file mode 100644 index 000000000..b2e329f6e --- /dev/null +++ b/server/ws/websockets_test.go @@ -0,0 +1,231 @@ +package ws + +import ( + "sync" + "testing" + + "github.com/mattermost/focalboard/server/auth" + "github.com/mattermost/focalboard/server/services/mlog" + + "github.com/gorilla/websocket" + "github.com/stretchr/testify/require" +) + +func TestWorkspaceSubscription(t *testing.T) { + server := NewServer(&auth.Auth{}, "token", false, &mlog.Logger{}) + client := &wsClient{&websocket.Conn{}, &sync.Mutex{}, []string{}, []string{}} + session := &websocketSession{client: client} + workspaceID := "fake-workspace-id" + + t.Run("Should correctly add a session", func(t *testing.T) { + server.addListener(session.client) + require.Len(t, server.listeners, 1) + require.Empty(t, server.listenersByWorkspace) + require.Empty(t, client.workspaces) + }) + + t.Run("Should correctly subscribe to a workspace", func(t *testing.T) { + require.False(t, client.isSubscribedToWorkspace(workspaceID)) + + server.subscribeListenerToWorkspace(client, workspaceID) + + require.Len(t, server.listenersByWorkspace[workspaceID], 1) + require.Contains(t, server.listenersByWorkspace[workspaceID], client) + require.Len(t, client.workspaces, 1) + require.Contains(t, client.workspaces, workspaceID) + + require.True(t, client.isSubscribedToWorkspace(workspaceID)) + }) + + t.Run("Subscribing again to a subscribed workspace would have no effect", func(t *testing.T) { + require.True(t, client.isSubscribedToWorkspace(workspaceID)) + + server.subscribeListenerToWorkspace(client, workspaceID) + + require.Len(t, server.listenersByWorkspace[workspaceID], 1) + require.Contains(t, server.listenersByWorkspace[workspaceID], client) + require.Len(t, client.workspaces, 1) + require.Contains(t, client.workspaces, workspaceID) + + require.True(t, client.isSubscribedToWorkspace(workspaceID)) + }) + + t.Run("Should correctly unsubscribe to a workspace", func(t *testing.T) { + require.True(t, client.isSubscribedToWorkspace(workspaceID)) + + server.unsubscribeListenerFromWorkspace(client, workspaceID) + + require.Empty(t, server.listenersByWorkspace[workspaceID]) + require.Empty(t, client.workspaces) + + require.False(t, client.isSubscribedToWorkspace(workspaceID)) + }) + + t.Run("Unsubscribing again to an unsubscribed workspace would have no effect", func(t *testing.T) { + require.False(t, client.isSubscribedToWorkspace(workspaceID)) + + server.unsubscribeListenerFromWorkspace(client, workspaceID) + + require.Empty(t, server.listenersByWorkspace[workspaceID]) + require.Empty(t, client.workspaces) + + require.False(t, client.isSubscribedToWorkspace(workspaceID)) + }) + + t.Run("Should correctly be removed from the server", func(t *testing.T) { + server.removeListener(client) + + require.Empty(t, server.listeners) + }) + + t.Run("If subscribed to workspaces and removed, should be removed from the workspaces subscription list", func(t *testing.T) { + workspaceID2 := "other-fake-workspace-id" + + server.addListener(session.client) + server.subscribeListenerToWorkspace(client, workspaceID) + server.subscribeListenerToWorkspace(client, workspaceID2) + + require.Len(t, server.listeners, 1) + require.Contains(t, server.listenersByWorkspace[workspaceID], client) + require.Contains(t, server.listenersByWorkspace[workspaceID2], client) + + server.removeListener(client) + + require.Empty(t, server.listeners) + require.Empty(t, server.listenersByWorkspace[workspaceID]) + require.Empty(t, server.listenersByWorkspace[workspaceID2]) + }) +} + +func TestBlocksSubscription(t *testing.T) { + server := NewServer(&auth.Auth{}, "token", false, &mlog.Logger{}) + client := &wsClient{&websocket.Conn{}, &sync.Mutex{}, []string{}, []string{}} + session := &websocketSession{client: client} + blockID1 := "block1" + blockID2 := "block2" + blockID3 := "block3" + blockIDs := []string{blockID1, blockID2, blockID3} + + t.Run("Should correctly add a session", func(t *testing.T) { + server.addListener(session.client) + require.Len(t, server.listeners, 1) + require.Empty(t, server.listenersByWorkspace) + require.Empty(t, client.workspaces) + }) + + t.Run("Should correctly subscribe to a set of blocks", func(t *testing.T) { + require.False(t, client.isSubscribedToBlock(blockID1)) + require.False(t, client.isSubscribedToBlock(blockID2)) + require.False(t, client.isSubscribedToBlock(blockID3)) + + server.subscribeListenerToBlocks(client, blockIDs) + + require.Len(t, server.listenersByBlock[blockID1], 1) + require.Contains(t, server.listenersByBlock[blockID1], client) + require.Len(t, server.listenersByBlock[blockID2], 1) + require.Contains(t, server.listenersByBlock[blockID2], client) + require.Len(t, server.listenersByBlock[blockID3], 1) + require.Contains(t, server.listenersByBlock[blockID3], client) + require.Len(t, client.blocks, 3) + require.ElementsMatch(t, blockIDs, client.blocks) + + require.True(t, client.isSubscribedToBlock(blockID1)) + require.True(t, client.isSubscribedToBlock(blockID2)) + require.True(t, client.isSubscribedToBlock(blockID3)) + + t.Run("Subscribing again to a subscribed block would have no effect", func(t *testing.T) { + require.True(t, client.isSubscribedToBlock(blockID1)) + require.True(t, client.isSubscribedToBlock(blockID2)) + require.True(t, client.isSubscribedToBlock(blockID3)) + + server.subscribeListenerToBlocks(client, blockIDs) + + require.Len(t, server.listenersByBlock[blockID1], 1) + require.Contains(t, server.listenersByBlock[blockID1], client) + require.Len(t, server.listenersByBlock[blockID2], 1) + require.Contains(t, server.listenersByBlock[blockID2], client) + require.Len(t, server.listenersByBlock[blockID3], 1) + require.Contains(t, server.listenersByBlock[blockID3], client) + require.Len(t, client.blocks, 3) + require.ElementsMatch(t, blockIDs, client.blocks) + + require.True(t, client.isSubscribedToBlock(blockID1)) + require.True(t, client.isSubscribedToBlock(blockID2)) + require.True(t, client.isSubscribedToBlock(blockID3)) + }) + }) + + t.Run("Should correctly unsubscribe to a set of blocks", func(t *testing.T) { + require.True(t, client.isSubscribedToBlock(blockID1)) + require.True(t, client.isSubscribedToBlock(blockID2)) + require.True(t, client.isSubscribedToBlock(blockID3)) + + server.unsubscribeListenerFromBlocks(client, blockIDs) + + require.Empty(t, server.listenersByBlock[blockID1]) + require.Empty(t, server.listenersByBlock[blockID2]) + require.Empty(t, server.listenersByBlock[blockID3]) + require.Empty(t, client.blocks) + + require.False(t, client.isSubscribedToBlock(blockID1)) + require.False(t, client.isSubscribedToBlock(blockID2)) + require.False(t, client.isSubscribedToBlock(blockID3)) + }) + + t.Run("Unsubscribing again to an unsubscribed block would have no effect", func(t *testing.T) { + require.False(t, client.isSubscribedToBlock(blockID1)) + + server.unsubscribeListenerFromBlocks(client, []string{blockID1}) + + require.Empty(t, server.listenersByBlock[blockID1]) + require.Empty(t, client.blocks) + + require.False(t, client.isSubscribedToBlock(blockID1)) + }) + + t.Run("Should correctly be removed from the server", func(t *testing.T) { + server.removeListener(client) + + require.Empty(t, server.listeners) + }) + + t.Run("If subscribed to blocks and removed, should be removed from the blocks subscription list", func(t *testing.T) { + server.addListener(session.client) + server.subscribeListenerToBlocks(client, blockIDs) + + require.Len(t, server.listeners, 1) + require.Len(t, server.listenersByBlock[blockID1], 1) + require.Contains(t, server.listenersByBlock[blockID1], client) + require.Len(t, server.listenersByBlock[blockID2], 1) + require.Contains(t, server.listenersByBlock[blockID2], client) + require.Len(t, server.listenersByBlock[blockID3], 1) + require.Contains(t, server.listenersByBlock[blockID3], client) + require.Len(t, client.blocks, 3) + require.ElementsMatch(t, blockIDs, client.blocks) + + server.removeListener(client) + + require.Empty(t, server.listeners) + require.Empty(t, server.listenersByBlock[blockID1]) + require.Empty(t, server.listenersByBlock[blockID2]) + require.Empty(t, server.listenersByBlock[blockID3]) + }) +} + +func TestGetUserIDForTokenInSingleUserMode(t *testing.T) { + singleUserToken := "single-user-token" + server := NewServer(&auth.Auth{}, "token", false, &mlog.Logger{}) + server.singleUserToken = singleUserToken + + t.Run("Should return nothing if the token is empty", func(t *testing.T) { + require.Empty(t, server.getUserIDForToken("")) + }) + + t.Run("Should return nothing if the token is invalid", func(t *testing.T) { + require.Empty(t, server.getUserIDForToken("invalid-token")) + }) + + t.Run("Should return the single user ID if the token is correct", func(t *testing.T) { + require.Equal(t, singleUserID, server.getUserIDForToken(singleUserToken)) + }) +} diff --git a/webapp/src/octoListener.ts b/webapp/src/octoListener.ts index 52984fef5..741cad633 100644 --- a/webapp/src/octoListener.ts +++ b/webapp/src/octoListener.ts @@ -8,7 +8,7 @@ type WSCommand = { action: string workspaceId?: string readToken?: string - blockIds: string[] + blockIds?: string[] } // These are messages from the server @@ -67,7 +67,7 @@ class OctoListener { const url = new URL(this.serverUrl) const protocol = (url.protocol === 'https:') ? 'wss:' : 'ws:' - const wsServerUrl = `${protocol}//${url.host}${url.pathname.replace(/\/$/, '')}/ws/onchange` + const wsServerUrl = `${protocol}//${url.host}${url.pathname.replace(/\/$/, '')}/ws` Utils.log(`OctoListener open: ${wsServerUrl}`) const ws = new WebSocket(wsServerUrl) this.ws = ws @@ -75,7 +75,7 @@ class OctoListener { ws.onopen = () => { Utils.log('OctoListener webSocket opened.') this.authenticate(workspaceId) - this.addBlocks(blockIds) + this.subscribeToWorkspace(workspaceId) this.isInitialized = true onStateChange?.('open') } @@ -166,14 +166,14 @@ class OctoListener { this.ws.send(JSON.stringify(command)) } - private addBlocks(blockIds: string[]): void { + private subscribeToBlocks(blockIds: string[]): void { if (!this.ws) { - Utils.assertFailure('OctoListener.addBlocks: ws is not open') + Utils.assertFailure('OctoListener.subscribeToBlocks: ws is not open') return } const command: WSCommand = { - action: 'ADD', + action: 'SUBSCRIBE_BLOCKS', blockIds, workspaceId: this.workspaceId, readToken: this.readToken, @@ -183,6 +183,20 @@ class OctoListener { this.blockIds.push(...blockIds) } + private subscribeToWorkspace(workspaceId: string): void { + if (!this.ws) { + Utils.assertFailure('OctoListener.subscribeToWorkspace: ws is not open') + return + } + + const command: WSCommand = { + action: 'SUBSCRIBE_WORKSPACE', + workspaceId, + } + + this.ws.send(JSON.stringify(command)) + } + private removeBlocks(blockIds: string[]): void { if (!this.ws) { Utils.assertFailure('OctoListener.removeBlocks: ws is not open')