diff --git a/server/app/app.go b/server/app/app.go index bc7aaa7a4..14363935a 100644 --- a/server/app/app.go +++ b/server/app/app.go @@ -1,6 +1,7 @@ package app import ( + "github.com/mattermost/focalboard/server/auth" "github.com/mattermost/focalboard/server/services/config" "github.com/mattermost/focalboard/server/services/store" "github.com/mattermost/focalboard/server/services/webhook" @@ -11,11 +12,26 @@ import ( type App struct { config *config.Configuration store store.Store + auth *auth.Auth wsServer *ws.Server filesBackend filesstore.FileBackend webhook *webhook.Client } -func New(config *config.Configuration, store store.Store, wsServer *ws.Server, filesBackend filesstore.FileBackend, webhook *webhook.Client) *App { - return &App{config: config, store: store, wsServer: wsServer, filesBackend: filesBackend, webhook: webhook} +func New( + config *config.Configuration, + store store.Store, + auth *auth.Auth, + wsServer *ws.Server, + filesBackend filesstore.FileBackend, + webhook *webhook.Client, +) *App { + return &App{ + config: config, + store: store, + auth: auth, + wsServer: wsServer, + filesBackend: filesBackend, + webhook: webhook, + } } diff --git a/server/app/auth.go b/server/app/auth.go index 51f923ab2..6e6760c6a 100644 --- a/server/app/auth.go +++ b/server/app/auth.go @@ -2,7 +2,6 @@ package app import ( "log" - "time" "github.com/google/uuid" "github.com/mattermost/focalboard/server/model" @@ -13,18 +12,7 @@ import ( // GetSession Get a user active session and refresh the session if is needed func (a *App) GetSession(token string) (*model.Session, error) { - if len(token) < 1 { - return nil, errors.New("no session token") - } - - session, err := a.store.GetSession(token, a.config.SessionExpireTime) - if err != nil { - return nil, errors.Wrap(err, "unable to get the session for the token") - } - if session.UpdateAt < (time.Now().Unix() - a.config.SessionRefreshTime) { - a.store.RefreshSession(session) - } - return session, nil + return a.auth.GetSession(token) } // GetRegisteredUserCount returns the number of registered users diff --git a/server/app/blocks_test.go b/server/app/blocks_test.go index 382137873..105417e26 100644 --- a/server/app/blocks_test.go +++ b/server/app/blocks_test.go @@ -5,6 +5,7 @@ import ( "testing" "github.com/golang/mock/gomock" + "github.com/mattermost/focalboard/server/auth" "github.com/mattermost/focalboard/server/services/config" "github.com/mattermost/focalboard/server/services/store/mockstore" "github.com/mattermost/focalboard/server/services/webhook" @@ -18,9 +19,10 @@ func TestGetParentID(t *testing.T) { defer ctrl.Finish() cfg := config.Configuration{} store := mockstore.NewMockStore(ctrl) - wsserver := ws.NewServer() + auth := auth.New(&cfg, store) + wsserver := ws.NewServer(auth, true) webhook := webhook.NewClient(&cfg) - app := New(&cfg, store, wsserver, &mocks.FileBackend{}, webhook) + app := New(&cfg, store, auth, wsserver, &mocks.FileBackend{}, webhook) t.Run("success query", func(t *testing.T) { store.EXPECT().GetParentID(gomock.Eq("test-id")).Return("test-parent-id", nil) diff --git a/server/auth/auth.go b/server/auth/auth.go new file mode 100644 index 000000000..7ed1fb9ba --- /dev/null +++ b/server/auth/auth.go @@ -0,0 +1,37 @@ +package auth + +import ( + "time" + + "github.com/mattermost/focalboard/server/model" + "github.com/mattermost/focalboard/server/services/config" + "github.com/mattermost/focalboard/server/services/store" + "github.com/pkg/errors" +) + +// Auth authenticates sessions +type Auth struct { + config *config.Configuration + store store.Store +} + +// New returns a new Auth +func New(config *config.Configuration, store store.Store) *Auth { + return &Auth{config: config, store: store} +} + +// GetSession Get a user active session and refresh the session if is needed +func (a *Auth) GetSession(token string) (*model.Session, error) { + if len(token) < 1 { + return nil, errors.New("no session token") + } + + session, err := a.store.GetSession(token, a.config.SessionExpireTime) + if err != nil { + return nil, errors.Wrap(err, "unable to get the session for the token") + } + if session.UpdateAt < (time.Now().Unix() - a.config.SessionRefreshTime) { + a.store.RefreshSession(session) + } + return session, nil +} diff --git a/server/server/server.go b/server/server/server.go index 25324e1ec..4a683a76b 100644 --- a/server/server/server.go +++ b/server/server/server.go @@ -17,6 +17,7 @@ import ( "github.com/mattermost/focalboard/server/api" "github.com/mattermost/focalboard/server/app" + "github.com/mattermost/focalboard/server/auth" "github.com/mattermost/focalboard/server/context" appModel "github.com/mattermost/focalboard/server/model" "github.com/mattermost/focalboard/server/services/config" @@ -60,7 +61,9 @@ func New(cfg *config.Configuration, singleUser bool) (*Server, error) { return nil, err } - wsServer := ws.NewServer() + auth := auth.New(cfg, store) + + wsServer := ws.NewServer(auth, singleUser) filesBackendSettings := model.FileSettings{} filesBackendSettings.SetDefaults(false) @@ -74,7 +77,7 @@ func New(cfg *config.Configuration, singleUser bool) (*Server, error) { webhookClient := webhook.NewClient(cfg) - appBuilder := func() *app.App { return app.New(cfg, store, wsServer, filesBackend, webhookClient) } + appBuilder := func() *app.App { return app.New(cfg, store, auth, wsServer, filesBackend, webhookClient) } api := api.NewAPI(appBuilder, singleUser) // Local router for admin APIs diff --git a/server/ws/websockets.go b/server/ws/websockets.go index 739fe9bb1..ac90073e9 100644 --- a/server/ws/websockets.go +++ b/server/ws/websockets.go @@ -9,95 +9,20 @@ import ( "github.com/gorilla/mux" "github.com/gorilla/websocket" + "github.com/mattermost/focalboard/server/auth" "github.com/mattermost/focalboard/server/model" ) -// RegisterRoutes registers routes. -func (ws *Server) RegisterRoutes(r *mux.Router) { - r.HandleFunc("/ws/onchange", ws.handleWebSocketOnChange) -} - -// AddListener adds a listener for a block's change. -func (ws *Server) AddListener(client *websocket.Conn, blockIDs []string) { - ws.mu.Lock() - for _, blockID := range blockIDs { - if ws.listeners[blockID] == nil { - ws.listeners[blockID] = []*websocket.Conn{} - } - - ws.listeners[blockID] = append(ws.listeners[blockID], client) - } - ws.mu.Unlock() -} - -// RemoveListener removes a webSocket listener from all blocks. -func (ws *Server) RemoveListener(client *websocket.Conn) { - ws.mu.Lock() - for key, clients := range ws.listeners { - listeners := []*websocket.Conn{} - - for _, existingClient := range clients { - if client != existingClient { - listeners = append(listeners, existingClient) - } - } - - ws.listeners[key] = listeners - } - ws.mu.Unlock() -} - -// RemoveListenerFromBlocks removes a webSocket listener from a set of block. -func (ws *Server) RemoveListenerFromBlocks(client *websocket.Conn, blockIDs []string) { - ws.mu.Lock() - - for _, blockID := range blockIDs { - listeners := ws.listeners[blockID] - if listeners == nil { - return - } - - // Remove the first instance of this client that's listening to this block - // Note: A client can listen multiple times to the same block - for index, listener := range listeners { - if client == listener { - newListeners := append(listeners[:index], listeners[index+1:]...) - ws.listeners[blockID] = newListeners - - break - } - } - } - - ws.mu.Unlock() -} - -// GetListeners returns the listeners to a blockID's changes. -func (ws *Server) GetListeners(blockID string) []*websocket.Conn { - ws.mu.Lock() - listeners := ws.listeners[blockID] - ws.mu.Unlock() - - return listeners -} +// IsValidSessionToken authenticates session tokens +type IsValidSessionToken func(token string) bool // Server is a WebSocket server. type Server struct { - upgrader websocket.Upgrader - listeners map[string][]*websocket.Conn - mu sync.RWMutex -} - -// NewServer creates a new Server. -func NewServer() *Server { - return &Server{ - listeners: make(map[string][]*websocket.Conn), - upgrader: websocket.Upgrader{ - CheckOrigin: func(r *http.Request) bool { - return true - }, - }, - } + upgrader websocket.Upgrader + listeners map[string][]*websocket.Conn + mu sync.RWMutex + auth *auth.Auth + singleUser bool } // UpdateMsg is sent on block updates @@ -106,12 +31,42 @@ type UpdateMsg struct { Block model.Block `json:"block"` } +// ErrorMsg is sent on errors +type ErrorMsg struct { + Error string `json:"error"` +} + // WebsocketCommand is an incoming command from the client. type WebsocketCommand struct { Action string `json:"action"` + Token string `json:"token"` BlockIDs []string `json:"blockIds"` } +type websocketSession struct { + client *websocket.Conn + isAuthenticated bool +} + +// NewServer creates a new Server. +func NewServer(auth *auth.Auth, singleUser bool) *Server { + return &Server{ + listeners: make(map[string][]*websocket.Conn), + upgrader: websocket.Upgrader{ + CheckOrigin: func(r *http.Request) bool { + return true + }, + }, + auth: auth, + singleUser: singleUser, + } +} + +// RegisterRoutes registers routes. +func (ws *Server) RegisterRoutes(r *mux.Router) { + r.HandleFunc("/ws/onchange", ws.handleWebSocketOnChange) +} + func (ws *Server) handleWebSocketOnChange(w http.ResponseWriter, r *http.Request) { // Upgrade initial GET request to a websocket client, err := ws.upgrader.Upgrade(w, r, nil) @@ -128,17 +83,22 @@ func (ws *Server) handleWebSocketOnChange(w http.ResponseWriter, r *http.Request log.Printf("DISCONNECT WebSocket onChange, client: %s", client.RemoteAddr()) // Remove client from listeners - ws.RemoveListener(client) + ws.removeListener(client) client.Close() }() + wsSession := websocketSession{ + client: client, + isAuthenticated: ws.singleUser, + } + // Simple message handling loop for { _, p, err := client.ReadMessage() if err != nil { log.Printf("ERROR WebSocket onChange, client: %s, err: %v", client.RemoteAddr(), err) - ws.RemoveListener(client) + ws.removeListener(client) break } @@ -154,13 +114,17 @@ func (ws *Server) handleWebSocketOnChange(w http.ResponseWriter, r *http.Request } switch command.Action { + case "AUTH": + log.Printf(`Command: AUTH, client: %s`, client.RemoteAddr()) + ws.authenticateListener(&wsSession, command.Token) + case "ADD": log.Printf(`Command: Add blockID: %v, client: %s`, command.BlockIDs, client.RemoteAddr()) - ws.AddListener(client, command.BlockIDs) + ws.addListener(&wsSession, command.BlockIDs) case "REMOVE": log.Printf(`Command: Remove blockID: %v, client: %s`, command.BlockIDs, client.RemoteAddr()) - ws.RemoveListenerFromBlocks(client, command.BlockIDs) + ws.removeListenerFromBlocks(&wsSession, command.BlockIDs) default: log.Printf(`ERROR webSocket command, invalid action: %v`, command.Action) @@ -168,6 +132,119 @@ func (ws *Server) handleWebSocketOnChange(w http.ResponseWriter, r *http.Request } } +func (ws *Server) isValidSessionToken(token string) bool { + if ws.singleUser { + return true + } + + session, err := ws.auth.GetSession(token) + if session == nil || err != nil { + return false + } + + return true +} + +func (ws *Server) authenticateListener(wsSession *websocketSession, token string) { + isValidSession := ws.isValidSessionToken(token) + if !isValidSession { + wsSession.client.Close() + return + } + + // Authenticated + wsSession.isAuthenticated = true + log.Printf("authenticateListener: Authenticated") +} + +// addListener adds a listener for a block's change. +func (ws *Server) addListener(wsSession *websocketSession, blockIDs []string) { + if !wsSession.isAuthenticated { + log.Printf("addListener: NOT AUTHENTICATED") + sendError(wsSession.client, "not authenticated") + return + } + + ws.mu.Lock() + for _, blockID := range blockIDs { + if ws.listeners[blockID] == nil { + ws.listeners[blockID] = []*websocket.Conn{} + } + + ws.listeners[blockID] = append(ws.listeners[blockID], wsSession.client) + } + ws.mu.Unlock() +} + +// removeListener removes a webSocket listener from all blocks. +func (ws *Server) removeListener(client *websocket.Conn) { + ws.mu.Lock() + for key, clients := range ws.listeners { + listeners := []*websocket.Conn{} + + for _, existingClient := range clients { + if client != existingClient { + listeners = append(listeners, existingClient) + } + } + + ws.listeners[key] = listeners + } + ws.mu.Unlock() +} + +// removeListenerFromBlocks removes a webSocket listener from a set of block. +func (ws *Server) removeListenerFromBlocks(wsSession *websocketSession, blockIDs []string) { + if !wsSession.isAuthenticated { + log.Printf("removeListenerFromBlocks: NOT AUTHENTICATED") + sendError(wsSession.client, "not authenticated") + return + } + + ws.mu.Lock() + + for _, blockID := range blockIDs { + listeners := ws.listeners[blockID] + if listeners == nil { + return + } + + // Remove the first instance of this client that's listening to this block + // Note: A client can listen multiple times to the same block + for index, listener := range listeners { + if wsSession.client == listener { + newListeners := append(listeners[:index], listeners[index+1:]...) + ws.listeners[blockID] = newListeners + + break + } + } + } + + ws.mu.Unlock() +} + +func sendError(conn *websocket.Conn, message string) { + errorMsg := ErrorMsg{ + Error: message, + } + + err := conn.WriteJSON(errorMsg) + if err != nil { + log.Printf("sendError error: %v", err) + conn.Close() + } +} + +// getListeners returns the listeners to a blockID's changes. +func (ws *Server) getListeners(blockID string) []*websocket.Conn { + ws.mu.Lock() + listeners := ws.listeners[blockID] + ws.mu.Unlock() + + return listeners +} + // BroadcastBlockDelete broadcasts delete messages to clients func (ws *Server) BroadcastBlockDelete(blockID string, parentID string) { now := time.Now().Unix() @@ -185,7 +262,7 @@ func (ws *Server) BroadcastBlockChange(block model.Block) { blockIDsToNotify := []string{block.ID, block.ParentID} for _, blockID := range blockIDsToNotify { - listeners := ws.GetListeners(blockID) + listeners := ws.getListeners(blockID) log.Printf("%d listener(s) for blockID: %s", len(listeners), blockID) if listeners != nil { diff --git a/webapp/src/octoListener.ts b/webapp/src/octoListener.ts index 2e38be09f..335c922a2 100644 --- a/webapp/src/octoListener.ts +++ b/webapp/src/octoListener.ts @@ -11,9 +11,9 @@ type WSCommand = { // These are messages from the server type WSMessage = { - action: string - blockId: string - block: IBlock + action?: string + block?: IBlock + error?: string } type OnChangeHandler = (blocks: IBlock[]) => void @@ -27,6 +27,7 @@ class OctoListener { } readonly serverUrl: string + private token: string private ws?: WebSocket private blockIds: string[] = [] private isInitialized = false @@ -38,14 +39,13 @@ class OctoListener { notificationDelay = 100 reopenDelay = 3000 - constructor(serverUrl?: string) { + constructor(serverUrl?: string, token?: string) { this.serverUrl = serverUrl || window.location.origin + this.token = token || localStorage.getItem('sessionId') || '' Utils.log(`OctoListener serverUrl: ${this.serverUrl}`) } open(blockIds: string[], onChange: OnChangeHandler, onReconnect: () => void): void { - let timeoutId: NodeJS.Timeout - if (this.ws) { this.close() } @@ -61,6 +61,7 @@ class OctoListener { ws.onopen = () => { Utils.log('OctoListener webSocket opened.') + this.authenticate() this.addBlocks(blockIds) this.isInitialized = true } @@ -91,13 +92,15 @@ class OctoListener { try { const message = JSON.parse(e.data) as WSMessage + if (message.error) { + Utils.logError(`Listener websocket error: ${message.error}`) + return + } + switch (message.action) { case 'UPDATE_BLOCK': - if (timeoutId) { - clearTimeout(timeoutId) - } Utils.log(`OctoListener update block: ${message.block?.id}`) - this.queueUpdateNotification(message.block) + this.queueUpdateNotification(message.block!) break default: Utils.logError(`Unexpected action: ${message.action}`) @@ -124,6 +127,20 @@ class OctoListener { ws.close() } + authenticate(): void { + if (!this.ws) { + Utils.assertFailure('OctoListener.addBlocks: ws is not open') + return + } + + const command = { + action: 'AUTH', + token: this.token, + } + + this.ws.send(JSON.stringify(command)) + } + addBlocks(blockIds: string[]): void { if (!this.ws) { Utils.assertFailure('OctoListener.addBlocks: ws is not open')