From 1020c03924d1fb17679d642ade13c48c8315d76f Mon Sep 17 00:00:00 2001 From: Harshil Sharma Date: Thu, 1 Jul 2021 11:41:29 +0530 Subject: [PATCH] Prevented concurrent writes to websocket (#658) * retained individual connection objects * unlocking lock in defer * Completely abstracted internal connection object * Completely removed direct use of WS connection --- server/ws/websockets.go | 66 ++++++++++++++++++++++++----------------- 1 file changed, 38 insertions(+), 28 deletions(-) diff --git a/server/ws/websockets.go b/server/ws/websockets.go index 57c71e7fe..e117fe204 100644 --- a/server/ws/websockets.go +++ b/server/ws/websockets.go @@ -24,10 +24,22 @@ type Hub interface { SetReceiveWSMessage(func(data []byte)) } +type wsClient struct { + *websocket.Conn + lock *sync.RWMutex +} + +func (c *wsClient) WriteJSON(v interface{}) error { + c.lock.Lock() + defer c.lock.Unlock() + err := c.Conn.WriteJSON(v) + return err +} + // Server is a WebSocket server. type Server struct { upgrader websocket.Upgrader - listeners map[string][]*websocket.Conn + listeners map[string][]*wsClient mu sync.RWMutex auth *auth.Auth hub Hub @@ -64,7 +76,7 @@ type WebsocketCommand struct { } type websocketSession struct { - client *websocket.Conn + client *wsClient isAuthenticated bool workspaceID string } @@ -72,7 +84,7 @@ type websocketSession struct { // NewServer creates a new Server. func NewServer(auth *auth.Auth, singleUserToken string, isMattermostAuth bool, logger *mlog.Logger) *Server { return &Server{ - listeners: make(map[string][]*websocket.Conn), + listeners: make(map[string][]*wsClient), upgrader: websocket.Upgrader{ CheckOrigin: func(r *http.Request) bool { return true @@ -98,36 +110,34 @@ func (ws *Server) handleWebSocketOnChange(w http.ResponseWriter, r *http.Request return } - // Make sure we close the connection when the function returns - defer func() { - ws.logger.Debug("DISCONNECT WebSocket onChange", mlog.Stringer("client", client.RemoteAddr())) - - // Remove client from listeners - ws.removeListener(client) - - client.Close() - }() - userID := "" if ws.isMattermostAuth { userID = r.Header.Get("Mattermost-User-Id") } wsSession := websocketSession{ - client: client, + client: &wsClient{client, &sync.RWMutex{}}, isAuthenticated: userID != "", } + // Make sure we close the connection when the function returns + defer func() { + ws.logger.Debug("DISCONNECT WebSocket onChange", mlog.Stringer("client", wsSession.client.RemoteAddr())) + + // Remove client from listeners + ws.removeListener(wsSession.client) + wsSession.client.Close() + }() + // Simple message handling loop for { - _, p, err := client.ReadMessage() + _, p, err := wsSession.client.ReadMessage() if err != nil { ws.logger.Error("ERROR WebSocket onChange", - mlog.Stringer("client", client.RemoteAddr()), + mlog.Stringer("client", wsSession.client.RemoteAddr()), mlog.Err(err), ) - ws.removeListener(client) - + ws.removeListener(wsSession.client) break } @@ -152,20 +162,20 @@ func (ws *Server) handleWebSocketOnChange(w http.ResponseWriter, r *http.Request switch command.Action { case "AUTH": - ws.logger.Debug(`Command: AUTH`, mlog.Stringer("client", client.RemoteAddr())) + ws.logger.Debug(`Command: AUTH`, mlog.Stringer("client", wsSession.client.RemoteAddr())) ws.authenticateListener(&wsSession, command.WorkspaceID, command.Token) case "ADD": ws.logger.Debug(`Command: ADD`, mlog.String("workspaceID", wsSession.workspaceID), mlog.Array("blockIDs", command.BlockIDs), - mlog.Stringer("client", client.RemoteAddr()), + mlog.Stringer("client", wsSession.client.RemoteAddr()), ) ws.addListener(&wsSession, &command) case "REMOVE": ws.logger.Debug(`Command: REMOVE`, mlog.String("workspaceID", wsSession.workspaceID), mlog.Array("blockIDs", command.BlockIDs), - mlog.Stringer("client", client.RemoteAddr()), + mlog.Stringer("client", wsSession.client.RemoteAddr()), ) ws.removeListenerFromBlocks(&wsSession, &command) @@ -258,7 +268,7 @@ func (ws *Server) addListener(wsSession *websocketSession, command *WebsocketCom for _, blockID := range command.BlockIDs { itemID := makeItemID(workspaceID, blockID) if ws.listeners[itemID] == nil { - ws.listeners[itemID] = []*websocket.Conn{} + ws.listeners[itemID] = []*wsClient{} } ws.listeners[itemID] = append(ws.listeners[itemID], wsSession.client) @@ -267,10 +277,10 @@ func (ws *Server) addListener(wsSession *websocketSession, command *WebsocketCom } // removeListener removes a webSocket listener from all blocks. -func (ws *Server) removeListener(client *websocket.Conn) { +func (ws *Server) removeListener(client *wsClient) { ws.mu.Lock() for key, clients := range ws.listeners { - listeners := []*websocket.Conn{} + listeners := []*wsClient{} for _, existingClient := range clients { if client != existingClient { @@ -315,15 +325,15 @@ func (ws *Server) removeListenerFromBlocks(wsSession *websocketSession, command ws.mu.Unlock() } -func (ws *Server) sendError(conn *websocket.Conn, message string) { +func (ws *Server) sendError(wsClient *wsClient, message string) { errorMsg := ErrorMsg{ Error: message, } - err := conn.WriteJSON(errorMsg) + err := wsClient.WriteJSON(errorMsg) if err != nil { ws.logger.Error("sendError error", mlog.Err(err)) - conn.Close() + wsClient.Close() } } @@ -358,7 +368,7 @@ func (ws *Server) SetHub(hub Hub) { } // getListeners returns the listeners to a blockID's changes. -func (ws *Server) getListeners(workspaceID string, blockID string) []*websocket.Conn { +func (ws *Server) getListeners(workspaceID string, blockID string) []*wsClient { ws.mu.Lock() itemID := makeItemID(workspaceID, blockID) listeners := ws.listeners[itemID]