Prevented concurrent writes to websocket (#658)
* retained individual connection objects * unlocking lock in defer * Completely abstracted internal connection object * Completely removed direct use of WS connection
This commit is contained in:
parent
ba69c8b083
commit
1020c03924
1 changed files with 38 additions and 28 deletions
|
@ -24,10 +24,22 @@ type Hub interface {
|
|||
SetReceiveWSMessage(func(data []byte))
|
||||
}
|
||||
|
||||
type wsClient struct {
|
||||
*websocket.Conn
|
||||
lock *sync.RWMutex
|
||||
}
|
||||
|
||||
func (c *wsClient) WriteJSON(v interface{}) error {
|
||||
c.lock.Lock()
|
||||
defer c.lock.Unlock()
|
||||
err := c.Conn.WriteJSON(v)
|
||||
return err
|
||||
}
|
||||
|
||||
// Server is a WebSocket server.
|
||||
type Server struct {
|
||||
upgrader websocket.Upgrader
|
||||
listeners map[string][]*websocket.Conn
|
||||
listeners map[string][]*wsClient
|
||||
mu sync.RWMutex
|
||||
auth *auth.Auth
|
||||
hub Hub
|
||||
|
@ -64,7 +76,7 @@ type WebsocketCommand struct {
|
|||
}
|
||||
|
||||
type websocketSession struct {
|
||||
client *websocket.Conn
|
||||
client *wsClient
|
||||
isAuthenticated bool
|
||||
workspaceID string
|
||||
}
|
||||
|
@ -72,7 +84,7 @@ type websocketSession struct {
|
|||
// NewServer creates a new Server.
|
||||
func NewServer(auth *auth.Auth, singleUserToken string, isMattermostAuth bool, logger *mlog.Logger) *Server {
|
||||
return &Server{
|
||||
listeners: make(map[string][]*websocket.Conn),
|
||||
listeners: make(map[string][]*wsClient),
|
||||
upgrader: websocket.Upgrader{
|
||||
CheckOrigin: func(r *http.Request) bool {
|
||||
return true
|
||||
|
@ -98,36 +110,34 @@ func (ws *Server) handleWebSocketOnChange(w http.ResponseWriter, r *http.Request
|
|||
return
|
||||
}
|
||||
|
||||
// Make sure we close the connection when the function returns
|
||||
defer func() {
|
||||
ws.logger.Debug("DISCONNECT WebSocket onChange", mlog.Stringer("client", client.RemoteAddr()))
|
||||
|
||||
// Remove client from listeners
|
||||
ws.removeListener(client)
|
||||
|
||||
client.Close()
|
||||
}()
|
||||
|
||||
userID := ""
|
||||
if ws.isMattermostAuth {
|
||||
userID = r.Header.Get("Mattermost-User-Id")
|
||||
}
|
||||
|
||||
wsSession := websocketSession{
|
||||
client: client,
|
||||
client: &wsClient{client, &sync.RWMutex{}},
|
||||
isAuthenticated: userID != "",
|
||||
}
|
||||
|
||||
// Make sure we close the connection when the function returns
|
||||
defer func() {
|
||||
ws.logger.Debug("DISCONNECT WebSocket onChange", mlog.Stringer("client", wsSession.client.RemoteAddr()))
|
||||
|
||||
// Remove client from listeners
|
||||
ws.removeListener(wsSession.client)
|
||||
wsSession.client.Close()
|
||||
}()
|
||||
|
||||
// Simple message handling loop
|
||||
for {
|
||||
_, p, err := client.ReadMessage()
|
||||
_, p, err := wsSession.client.ReadMessage()
|
||||
if err != nil {
|
||||
ws.logger.Error("ERROR WebSocket onChange",
|
||||
mlog.Stringer("client", client.RemoteAddr()),
|
||||
mlog.Stringer("client", wsSession.client.RemoteAddr()),
|
||||
mlog.Err(err),
|
||||
)
|
||||
ws.removeListener(client)
|
||||
|
||||
ws.removeListener(wsSession.client)
|
||||
break
|
||||
}
|
||||
|
||||
|
@ -152,20 +162,20 @@ func (ws *Server) handleWebSocketOnChange(w http.ResponseWriter, r *http.Request
|
|||
|
||||
switch command.Action {
|
||||
case "AUTH":
|
||||
ws.logger.Debug(`Command: AUTH`, mlog.Stringer("client", client.RemoteAddr()))
|
||||
ws.logger.Debug(`Command: AUTH`, mlog.Stringer("client", wsSession.client.RemoteAddr()))
|
||||
ws.authenticateListener(&wsSession, command.WorkspaceID, command.Token)
|
||||
case "ADD":
|
||||
ws.logger.Debug(`Command: ADD`,
|
||||
mlog.String("workspaceID", wsSession.workspaceID),
|
||||
mlog.Array("blockIDs", command.BlockIDs),
|
||||
mlog.Stringer("client", client.RemoteAddr()),
|
||||
mlog.Stringer("client", wsSession.client.RemoteAddr()),
|
||||
)
|
||||
ws.addListener(&wsSession, &command)
|
||||
case "REMOVE":
|
||||
ws.logger.Debug(`Command: REMOVE`,
|
||||
mlog.String("workspaceID", wsSession.workspaceID),
|
||||
mlog.Array("blockIDs", command.BlockIDs),
|
||||
mlog.Stringer("client", client.RemoteAddr()),
|
||||
mlog.Stringer("client", wsSession.client.RemoteAddr()),
|
||||
)
|
||||
|
||||
ws.removeListenerFromBlocks(&wsSession, &command)
|
||||
|
@ -258,7 +268,7 @@ func (ws *Server) addListener(wsSession *websocketSession, command *WebsocketCom
|
|||
for _, blockID := range command.BlockIDs {
|
||||
itemID := makeItemID(workspaceID, blockID)
|
||||
if ws.listeners[itemID] == nil {
|
||||
ws.listeners[itemID] = []*websocket.Conn{}
|
||||
ws.listeners[itemID] = []*wsClient{}
|
||||
}
|
||||
|
||||
ws.listeners[itemID] = append(ws.listeners[itemID], wsSession.client)
|
||||
|
@ -267,10 +277,10 @@ func (ws *Server) addListener(wsSession *websocketSession, command *WebsocketCom
|
|||
}
|
||||
|
||||
// removeListener removes a webSocket listener from all blocks.
|
||||
func (ws *Server) removeListener(client *websocket.Conn) {
|
||||
func (ws *Server) removeListener(client *wsClient) {
|
||||
ws.mu.Lock()
|
||||
for key, clients := range ws.listeners {
|
||||
listeners := []*websocket.Conn{}
|
||||
listeners := []*wsClient{}
|
||||
|
||||
for _, existingClient := range clients {
|
||||
if client != existingClient {
|
||||
|
@ -315,15 +325,15 @@ func (ws *Server) removeListenerFromBlocks(wsSession *websocketSession, command
|
|||
ws.mu.Unlock()
|
||||
}
|
||||
|
||||
func (ws *Server) sendError(conn *websocket.Conn, message string) {
|
||||
func (ws *Server) sendError(wsClient *wsClient, message string) {
|
||||
errorMsg := ErrorMsg{
|
||||
Error: message,
|
||||
}
|
||||
|
||||
err := conn.WriteJSON(errorMsg)
|
||||
err := wsClient.WriteJSON(errorMsg)
|
||||
if err != nil {
|
||||
ws.logger.Error("sendError error", mlog.Err(err))
|
||||
conn.Close()
|
||||
wsClient.Close()
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -358,7 +368,7 @@ func (ws *Server) SetHub(hub Hub) {
|
|||
}
|
||||
|
||||
// getListeners returns the listeners to a blockID's changes.
|
||||
func (ws *Server) getListeners(workspaceID string, blockID string) []*websocket.Conn {
|
||||
func (ws *Server) getListeners(workspaceID string, blockID string) []*wsClient {
|
||||
ws.mu.Lock()
|
||||
itemID := makeItemID(workspaceID, blockID)
|
||||
listeners := ws.listeners[itemID]
|
||||
|
|
Loading…
Reference in a new issue