diff --git a/server/ws/websockets.go b/server/ws/websockets.go index a6d836eaf..6ed562057 100644 --- a/server/ws/websockets.go +++ b/server/ws/websockets.go @@ -2,6 +2,7 @@ package ws import ( "encoding/json" + "errors" "log" "net/http" "sync" @@ -182,22 +183,27 @@ func (ws *Server) authenticateListener(wsSession *websocketSession, workspaceID, if workspaceID == "0" { workspaceID = "" } + wsSession.workspaceID = workspaceID wsSession.isAuthenticated = true log.Printf("authenticateListener: Authenticated, workspaceID: %s", workspaceID) } -func (ws *Server) getAuthenticatedWorkspaceID(wsSession *websocketSession, command *WebsocketCommand) string { +func (ws *Server) getAuthenticatedWorkspaceID(wsSession *websocketSession, command *WebsocketCommand) (string, error) { if wsSession.isAuthenticated { - return wsSession.workspaceID + return wsSession.workspaceID, nil } // If not authenticated, try to authenticate the read token against the supplied workspaceID workspaceID := command.WorkspaceID if len(workspaceID) == 0 { log.Printf("getAuthenticatedWorkspaceID: No workspace") - sendError(wsSession.client, "No workspace") - return "" + return "", errors.New("No workspace") + } + + // Special case: Default workspace is blank + if workspaceID == "0" { + workspaceID = "" } container := store.Container{ @@ -209,13 +215,13 @@ func (ws *Server) getAuthenticatedWorkspaceID(wsSession *websocketSession, comma for _, blockID := range command.BlockIDs { isValid, _ := ws.auth.IsValidReadToken(container, blockID, command.ReadToken) if !isValid { - return "" + return "", errors.New("Invalid read token for workspace") } } - return workspaceID + return workspaceID, nil } - return "" + return "", errors.New("No read token") } // TODO: Refactor workspace hashing @@ -225,9 +231,9 @@ func makeItemID(workspaceID, blockID string) string { // addListener adds a listener for a block's change. func (ws *Server) addListener(wsSession *websocketSession, command *WebsocketCommand) { - workspaceID := ws.getAuthenticatedWorkspaceID(wsSession, command) - if len(workspaceID) == 0 { - log.Printf("addListener: NOT AUTHENTICATED") + workspaceID, err := ws.getAuthenticatedWorkspaceID(wsSession, command) + if err != nil { + log.Printf("addListener: NOT AUTHENTICATED, ERROR: %v", err) sendError(wsSession.client, "not authenticated") return } @@ -263,9 +269,9 @@ func (ws *Server) removeListener(client *websocket.Conn) { // removeListenerFromBlocks removes a webSocket listener from a set of block. func (ws *Server) removeListenerFromBlocks(wsSession *websocketSession, command *WebsocketCommand) { - workspaceID := ws.getAuthenticatedWorkspaceID(wsSession, command) - if len(workspaceID) == 0 { - log.Printf("addListener: NOT AUTHENTICATED") + workspaceID, err := ws.getAuthenticatedWorkspaceID(wsSession, command) + if err != nil { + log.Printf("addListener: NOT AUTHENTICATED, ERROR: %v", err) sendError(wsSession.client, "not authenticated") return }