Prevented concurrent writes to websocket (#658)

* retained individual connection objects

* unlocking lock in defer

* Completely abstracted internal connection object

* Completely removed direct use of WS connection
This commit is contained in:
Harshil Sharma 2021-07-01 11:41:29 +05:30 committed by GitHub
parent ba69c8b083
commit 1020c03924
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23

View file

@ -24,10 +24,22 @@ type Hub interface {
SetReceiveWSMessage(func(data []byte))
}
type wsClient struct {
*websocket.Conn
lock *sync.RWMutex
}
func (c *wsClient) WriteJSON(v interface{}) error {
c.lock.Lock()
defer c.lock.Unlock()
err := c.Conn.WriteJSON(v)
return err
}
// Server is a WebSocket server.
type Server struct {
upgrader websocket.Upgrader
listeners map[string][]*websocket.Conn
listeners map[string][]*wsClient
mu sync.RWMutex
auth *auth.Auth
hub Hub
@ -64,7 +76,7 @@ type WebsocketCommand struct {
}
type websocketSession struct {
client *websocket.Conn
client *wsClient
isAuthenticated bool
workspaceID string
}
@ -72,7 +84,7 @@ type websocketSession struct {
// NewServer creates a new Server.
func NewServer(auth *auth.Auth, singleUserToken string, isMattermostAuth bool, logger *mlog.Logger) *Server {
return &Server{
listeners: make(map[string][]*websocket.Conn),
listeners: make(map[string][]*wsClient),
upgrader: websocket.Upgrader{
CheckOrigin: func(r *http.Request) bool {
return true
@ -98,36 +110,34 @@ func (ws *Server) handleWebSocketOnChange(w http.ResponseWriter, r *http.Request
return
}
// Make sure we close the connection when the function returns
defer func() {
ws.logger.Debug("DISCONNECT WebSocket onChange", mlog.Stringer("client", client.RemoteAddr()))
// Remove client from listeners
ws.removeListener(client)
client.Close()
}()
userID := ""
if ws.isMattermostAuth {
userID = r.Header.Get("Mattermost-User-Id")
}
wsSession := websocketSession{
client: client,
client: &wsClient{client, &sync.RWMutex{}},
isAuthenticated: userID != "",
}
// Make sure we close the connection when the function returns
defer func() {
ws.logger.Debug("DISCONNECT WebSocket onChange", mlog.Stringer("client", wsSession.client.RemoteAddr()))
// Remove client from listeners
ws.removeListener(wsSession.client)
wsSession.client.Close()
}()
// Simple message handling loop
for {
_, p, err := client.ReadMessage()
_, p, err := wsSession.client.ReadMessage()
if err != nil {
ws.logger.Error("ERROR WebSocket onChange",
mlog.Stringer("client", client.RemoteAddr()),
mlog.Stringer("client", wsSession.client.RemoteAddr()),
mlog.Err(err),
)
ws.removeListener(client)
ws.removeListener(wsSession.client)
break
}
@ -152,20 +162,20 @@ func (ws *Server) handleWebSocketOnChange(w http.ResponseWriter, r *http.Request
switch command.Action {
case "AUTH":
ws.logger.Debug(`Command: AUTH`, mlog.Stringer("client", client.RemoteAddr()))
ws.logger.Debug(`Command: AUTH`, mlog.Stringer("client", wsSession.client.RemoteAddr()))
ws.authenticateListener(&wsSession, command.WorkspaceID, command.Token)
case "ADD":
ws.logger.Debug(`Command: ADD`,
mlog.String("workspaceID", wsSession.workspaceID),
mlog.Array("blockIDs", command.BlockIDs),
mlog.Stringer("client", client.RemoteAddr()),
mlog.Stringer("client", wsSession.client.RemoteAddr()),
)
ws.addListener(&wsSession, &command)
case "REMOVE":
ws.logger.Debug(`Command: REMOVE`,
mlog.String("workspaceID", wsSession.workspaceID),
mlog.Array("blockIDs", command.BlockIDs),
mlog.Stringer("client", client.RemoteAddr()),
mlog.Stringer("client", wsSession.client.RemoteAddr()),
)
ws.removeListenerFromBlocks(&wsSession, &command)
@ -258,7 +268,7 @@ func (ws *Server) addListener(wsSession *websocketSession, command *WebsocketCom
for _, blockID := range command.BlockIDs {
itemID := makeItemID(workspaceID, blockID)
if ws.listeners[itemID] == nil {
ws.listeners[itemID] = []*websocket.Conn{}
ws.listeners[itemID] = []*wsClient{}
}
ws.listeners[itemID] = append(ws.listeners[itemID], wsSession.client)
@ -267,10 +277,10 @@ func (ws *Server) addListener(wsSession *websocketSession, command *WebsocketCom
}
// removeListener removes a webSocket listener from all blocks.
func (ws *Server) removeListener(client *websocket.Conn) {
func (ws *Server) removeListener(client *wsClient) {
ws.mu.Lock()
for key, clients := range ws.listeners {
listeners := []*websocket.Conn{}
listeners := []*wsClient{}
for _, existingClient := range clients {
if client != existingClient {
@ -315,15 +325,15 @@ func (ws *Server) removeListenerFromBlocks(wsSession *websocketSession, command
ws.mu.Unlock()
}
func (ws *Server) sendError(conn *websocket.Conn, message string) {
func (ws *Server) sendError(wsClient *wsClient, message string) {
errorMsg := ErrorMsg{
Error: message,
}
err := conn.WriteJSON(errorMsg)
err := wsClient.WriteJSON(errorMsg)
if err != nil {
ws.logger.Error("sendError error", mlog.Err(err))
conn.Close()
wsClient.Close()
}
}
@ -358,7 +368,7 @@ func (ws *Server) SetHub(hub Hub) {
}
// getListeners returns the listeners to a blockID's changes.
func (ws *Server) getListeners(workspaceID string, blockID string) []*websocket.Conn {
func (ws *Server) getListeners(workspaceID string, blockID string) []*wsClient {
ws.mu.Lock()
itemID := makeItemID(workspaceID, blockID)
listeners := ws.listeners[itemID]