Auth readToken in websocket

This commit is contained in:
Chen-I Lim 2021-02-02 18:15:03 -08:00
parent 26f73135bc
commit 79b79b35bc
6 changed files with 95 additions and 27 deletions

View file

@ -234,19 +234,13 @@ func (a *API) handleGetSubTree(w http.ResponseWriter, r *http.Request) {
return
}
rootID, err := a.app().GetRootID(blockID)
isValid, err := a.app().IsValidReadToken(blockID, readToken)
if err != nil {
errorResponse(w, http.StatusInternalServerError, nil, err)
return
}
sharing, err := a.app().GetSharing(rootID)
if err != nil {
errorResponse(w, http.StatusInternalServerError, nil, err)
return
}
if sharing == nil || !(sharing.ID == rootID && sharing.Enabled && sharing.Token == readToken) {
if !isValid {
errorResponse(w, http.StatusUnauthorized, nil, nil)
return
}

View file

@ -15,6 +15,11 @@ func (a *App) GetSession(token string) (*model.Session, error) {
return a.auth.GetSession(token)
}
// IsValidReadToken validates the read token for a block
func (a *App) IsValidReadToken(blockID string, readToken string) (bool, error) {
return a.auth.IsValidReadToken(blockID, readToken)
}
// GetRegisteredUserCount returns the number of registered users
func (a *App) GetRegisteredUserCount() (int, error) {
return a.store.GetRegisteredUserCount()

View file

@ -1,6 +1,7 @@
package auth
import (
"database/sql"
"time"
"github.com/mattermost/focalboard/server/model"
@ -35,3 +36,25 @@ func (a *Auth) GetSession(token string) (*model.Session, error) {
}
return session, nil
}
// IsValidReadToken validates the read token for a block
func (a *Auth) IsValidReadToken(blockID string, readToken string) (bool, error) {
rootID, err := a.store.GetRootID(blockID)
if err != nil {
return false, err
}
sharing, err := a.store.GetSharing(rootID)
if err == sql.ErrNoRows {
return false, nil
}
if err != nil {
return false, err
}
if sharing != nil && (sharing.ID == rootID && sharing.Enabled && sharing.Token == readToken) {
return true, nil
}
return false, nil
}

View file

@ -38,9 +38,10 @@ type ErrorMsg struct {
// WebsocketCommand is an incoming command from the client.
type WebsocketCommand struct {
Action string `json:"action"`
Token string `json:"token"`
BlockIDs []string `json:"blockIds"`
Action string `json:"action"`
Token string `json:"token"`
ReadToken string `json:"readToken"`
BlockIDs []string `json:"blockIds"`
}
type websocketSession struct {
@ -116,15 +117,15 @@ func (ws *Server) handleWebSocketOnChange(w http.ResponseWriter, r *http.Request
switch command.Action {
case "AUTH":
log.Printf(`Command: AUTH, client: %s`, client.RemoteAddr())
ws.authenticateListener(&wsSession, command.Token)
ws.authenticateListener(&wsSession, command.Token, command.ReadToken)
case "ADD":
log.Printf(`Command: Add blockID: %v, client: %s`, command.BlockIDs, client.RemoteAddr())
ws.addListener(&wsSession, command.BlockIDs)
ws.addListener(&wsSession, &command)
case "REMOVE":
log.Printf(`Command: Remove blockID: %v, client: %s`, command.BlockIDs, client.RemoteAddr())
ws.removeListenerFromBlocks(&wsSession, command.BlockIDs)
ws.removeListenerFromBlocks(&wsSession, &command)
default:
log.Printf(`ERROR webSocket command, invalid action: %v`, command.Action)
@ -138,14 +139,15 @@ func (ws *Server) isValidSessionToken(token string) bool {
}
session, err := ws.auth.GetSession(token)
if session == nil || err != nil {
return false
if session != nil && err == nil {
return true
}
return true
return false
}
func (ws *Server) authenticateListener(wsSession *websocketSession, token string) {
func (ws *Server) authenticateListener(wsSession *websocketSession, token string, readToken string) {
// Authenticate session
isValidSession := ws.isValidSessionToken(token)
if !isValidSession {
wsSession.client.Close()
@ -157,16 +159,39 @@ func (ws *Server) authenticateListener(wsSession *websocketSession, token string
log.Printf("authenticateListener: Authenticated")
}
func (ws *Server) checkAuthentication(wsSession *websocketSession, command *WebsocketCommand) bool {
if ws.singleUser {
return true
}
if wsSession.isAuthenticated {
return true
}
if len(command.ReadToken) > 0 {
// Read token must be valid for all block IDs
for _, blockID := range command.BlockIDs {
isValid, _ := ws.auth.IsValidReadToken(blockID, command.ReadToken)
if !isValid {
return false
}
}
return true
}
return false
}
// addListener adds a listener for a block's change.
func (ws *Server) addListener(wsSession *websocketSession, blockIDs []string) {
if !wsSession.isAuthenticated {
func (ws *Server) addListener(wsSession *websocketSession, command *WebsocketCommand) {
if !ws.checkAuthentication(wsSession, command) {
log.Printf("addListener: NOT AUTHENTICATED")
sendError(wsSession.client, "not authenticated")
return
}
ws.mu.Lock()
for _, blockID := range blockIDs {
for _, blockID := range command.BlockIDs {
if ws.listeners[blockID] == nil {
ws.listeners[blockID] = []*websocket.Conn{}
}
@ -194,8 +219,8 @@ func (ws *Server) removeListener(client *websocket.Conn) {
}
// removeListenerFromBlocks removes a webSocket listener from a set of block.
func (ws *Server) removeListenerFromBlocks(wsSession *websocketSession, blockIDs []string) {
if !wsSession.isAuthenticated {
func (ws *Server) removeListenerFromBlocks(wsSession *websocketSession, command *WebsocketCommand) {
if !ws.checkAuthentication(wsSession, command) {
log.Printf("removeListenerFromBlocks: NOT AUTHENTICATED")
sendError(wsSession.client, "not authenticated")
return
@ -203,7 +228,7 @@ func (ws *Server) removeListenerFromBlocks(wsSession *websocketSession, blockIDs
ws.mu.Lock()
for _, blockID := range blockIDs {
for _, blockID := range command.BlockIDs {
listeners := ws.listeners[blockID]
if listeners == nil {
return

View file

@ -6,6 +6,7 @@ import {Utils} from './utils'
// These are outgoing commands to the server
type WSCommand = {
action: string
readToken?: string
blockIds: string[]
}
@ -28,6 +29,7 @@ class OctoListener {
readonly serverUrl: string
private token: string
private readToken: string
private ws?: WebSocket
private blockIds: string[] = []
private isInitialized = false
@ -39,12 +41,19 @@ class OctoListener {
notificationDelay = 100
reopenDelay = 3000
constructor(serverUrl?: string, token?: string) {
constructor(serverUrl?: string, token?: string, readToken?: string) {
this.serverUrl = serverUrl || window.location.origin
this.token = token || localStorage.getItem('sessionId') || ''
this.readToken = readToken || OctoListener.getReadToken()
Utils.log(`OctoListener serverUrl: ${this.serverUrl}`)
}
static getReadToken(): string {
const queryString = new URLSearchParams(window.location.search)
const readToken = queryString.get('r') || ''
return readToken
}
open(blockIds: string[], onChange: OnChangeHandler, onReconnect: () => void): void {
if (this.ws) {
this.close()
@ -133,11 +142,13 @@ class OctoListener {
return
}
if (!this.token) {
return
}
const command = {
action: 'AUTH',
token: this.token,
}
this.ws.send(JSON.stringify(command))
}
@ -150,6 +161,7 @@ class OctoListener {
const command: WSCommand = {
action: 'ADD',
blockIds,
readToken: this.readToken,
}
this.ws.send(JSON.stringify(command))
@ -165,6 +177,7 @@ class OctoListener {
const command: WSCommand = {
action: 'REMOVE',
blockIds,
readToken: this.readToken,
}
this.ws.send(JSON.stringify(command))

View file

@ -185,9 +185,17 @@ class BoardPage extends React.Component<Props, State> {
const boardIds = [...workspaceTree.boards.map((o) => o.id), ...workspaceTree.boardTemplates.map((o) => o.id)]
this.setState({workspaceTree})
let boardIdsToListen: string[]
if (boardIds.length > 0) {
boardIdsToListen = ['', ...boardIds]
} else {
// Read-only view
boardIdsToListen = [this.state.boardId]
}
// Listen to boards plus all blocks at root (Empty string for parentId)
this.workspaceListener.open(
['', ...boardIds],
boardIdsToListen,
async (blocks) => {
Utils.log(`workspaceListener.onChanged: ${blocks.length}`)
this.incrementalUpdate(blocks)