Merge branch 'main' into requested-with-header

This commit is contained in:
Chen-I Lim 2021-02-02 18:24:38 -08:00
commit 2237169fe0
6 changed files with 95 additions and 27 deletions

View file

@ -265,19 +265,13 @@ func (a *API) handleGetSubTree(w http.ResponseWriter, r *http.Request) {
return return
} }
rootID, err := a.app().GetRootID(blockID) isValid, err := a.app().IsValidReadToken(blockID, readToken)
if err != nil { if err != nil {
errorResponse(w, http.StatusInternalServerError, nil, err) errorResponse(w, http.StatusInternalServerError, nil, err)
return return
} }
sharing, err := a.app().GetSharing(rootID) if !isValid {
if err != nil {
errorResponse(w, http.StatusInternalServerError, nil, err)
return
}
if sharing == nil || !(sharing.ID == rootID && sharing.Enabled && sharing.Token == readToken) {
errorResponse(w, http.StatusUnauthorized, nil, nil) errorResponse(w, http.StatusUnauthorized, nil, nil)
return return
} }

View file

@ -15,6 +15,11 @@ func (a *App) GetSession(token string) (*model.Session, error) {
return a.auth.GetSession(token) return a.auth.GetSession(token)
} }
// IsValidReadToken validates the read token for a block
func (a *App) IsValidReadToken(blockID string, readToken string) (bool, error) {
return a.auth.IsValidReadToken(blockID, readToken)
}
// GetRegisteredUserCount returns the number of registered users // GetRegisteredUserCount returns the number of registered users
func (a *App) GetRegisteredUserCount() (int, error) { func (a *App) GetRegisteredUserCount() (int, error) {
return a.store.GetRegisteredUserCount() return a.store.GetRegisteredUserCount()

View file

@ -1,6 +1,7 @@
package auth package auth
import ( import (
"database/sql"
"time" "time"
"github.com/mattermost/focalboard/server/model" "github.com/mattermost/focalboard/server/model"
@ -35,3 +36,25 @@ func (a *Auth) GetSession(token string) (*model.Session, error) {
} }
return session, nil return session, nil
} }
// IsValidReadToken validates the read token for a block
func (a *Auth) IsValidReadToken(blockID string, readToken string) (bool, error) {
rootID, err := a.store.GetRootID(blockID)
if err != nil {
return false, err
}
sharing, err := a.store.GetSharing(rootID)
if err == sql.ErrNoRows {
return false, nil
}
if err != nil {
return false, err
}
if sharing != nil && (sharing.ID == rootID && sharing.Enabled && sharing.Token == readToken) {
return true, nil
}
return false, nil
}

View file

@ -38,9 +38,10 @@ type ErrorMsg struct {
// WebsocketCommand is an incoming command from the client. // WebsocketCommand is an incoming command from the client.
type WebsocketCommand struct { type WebsocketCommand struct {
Action string `json:"action"` Action string `json:"action"`
Token string `json:"token"` Token string `json:"token"`
BlockIDs []string `json:"blockIds"` ReadToken string `json:"readToken"`
BlockIDs []string `json:"blockIds"`
} }
type websocketSession struct { type websocketSession struct {
@ -116,15 +117,15 @@ func (ws *Server) handleWebSocketOnChange(w http.ResponseWriter, r *http.Request
switch command.Action { switch command.Action {
case "AUTH": case "AUTH":
log.Printf(`Command: AUTH, client: %s`, client.RemoteAddr()) log.Printf(`Command: AUTH, client: %s`, client.RemoteAddr())
ws.authenticateListener(&wsSession, command.Token) ws.authenticateListener(&wsSession, command.Token, command.ReadToken)
case "ADD": case "ADD":
log.Printf(`Command: Add blockID: %v, client: %s`, command.BlockIDs, client.RemoteAddr()) log.Printf(`Command: Add blockID: %v, client: %s`, command.BlockIDs, client.RemoteAddr())
ws.addListener(&wsSession, command.BlockIDs) ws.addListener(&wsSession, &command)
case "REMOVE": case "REMOVE":
log.Printf(`Command: Remove blockID: %v, client: %s`, command.BlockIDs, client.RemoteAddr()) log.Printf(`Command: Remove blockID: %v, client: %s`, command.BlockIDs, client.RemoteAddr())
ws.removeListenerFromBlocks(&wsSession, command.BlockIDs) ws.removeListenerFromBlocks(&wsSession, &command)
default: default:
log.Printf(`ERROR webSocket command, invalid action: %v`, command.Action) log.Printf(`ERROR webSocket command, invalid action: %v`, command.Action)
@ -138,14 +139,15 @@ func (ws *Server) isValidSessionToken(token string) bool {
} }
session, err := ws.auth.GetSession(token) session, err := ws.auth.GetSession(token)
if session == nil || err != nil { if session != nil && err == nil {
return false return true
} }
return true return false
} }
func (ws *Server) authenticateListener(wsSession *websocketSession, token string) { func (ws *Server) authenticateListener(wsSession *websocketSession, token string, readToken string) {
// Authenticate session
isValidSession := ws.isValidSessionToken(token) isValidSession := ws.isValidSessionToken(token)
if !isValidSession { if !isValidSession {
wsSession.client.Close() wsSession.client.Close()
@ -157,16 +159,39 @@ func (ws *Server) authenticateListener(wsSession *websocketSession, token string
log.Printf("authenticateListener: Authenticated") log.Printf("authenticateListener: Authenticated")
} }
func (ws *Server) checkAuthentication(wsSession *websocketSession, command *WebsocketCommand) bool {
if ws.singleUser {
return true
}
if wsSession.isAuthenticated {
return true
}
if len(command.ReadToken) > 0 {
// Read token must be valid for all block IDs
for _, blockID := range command.BlockIDs {
isValid, _ := ws.auth.IsValidReadToken(blockID, command.ReadToken)
if !isValid {
return false
}
}
return true
}
return false
}
// addListener adds a listener for a block's change. // addListener adds a listener for a block's change.
func (ws *Server) addListener(wsSession *websocketSession, blockIDs []string) { func (ws *Server) addListener(wsSession *websocketSession, command *WebsocketCommand) {
if !wsSession.isAuthenticated { if !ws.checkAuthentication(wsSession, command) {
log.Printf("addListener: NOT AUTHENTICATED") log.Printf("addListener: NOT AUTHENTICATED")
sendError(wsSession.client, "not authenticated") sendError(wsSession.client, "not authenticated")
return return
} }
ws.mu.Lock() ws.mu.Lock()
for _, blockID := range blockIDs { for _, blockID := range command.BlockIDs {
if ws.listeners[blockID] == nil { if ws.listeners[blockID] == nil {
ws.listeners[blockID] = []*websocket.Conn{} ws.listeners[blockID] = []*websocket.Conn{}
} }
@ -194,8 +219,8 @@ func (ws *Server) removeListener(client *websocket.Conn) {
} }
// removeListenerFromBlocks removes a webSocket listener from a set of block. // removeListenerFromBlocks removes a webSocket listener from a set of block.
func (ws *Server) removeListenerFromBlocks(wsSession *websocketSession, blockIDs []string) { func (ws *Server) removeListenerFromBlocks(wsSession *websocketSession, command *WebsocketCommand) {
if !wsSession.isAuthenticated { if !ws.checkAuthentication(wsSession, command) {
log.Printf("removeListenerFromBlocks: NOT AUTHENTICATED") log.Printf("removeListenerFromBlocks: NOT AUTHENTICATED")
sendError(wsSession.client, "not authenticated") sendError(wsSession.client, "not authenticated")
return return
@ -203,7 +228,7 @@ func (ws *Server) removeListenerFromBlocks(wsSession *websocketSession, blockIDs
ws.mu.Lock() ws.mu.Lock()
for _, blockID := range blockIDs { for _, blockID := range command.BlockIDs {
listeners := ws.listeners[blockID] listeners := ws.listeners[blockID]
if listeners == nil { if listeners == nil {
return return

View file

@ -6,6 +6,7 @@ import {Utils} from './utils'
// These are outgoing commands to the server // These are outgoing commands to the server
type WSCommand = { type WSCommand = {
action: string action: string
readToken?: string
blockIds: string[] blockIds: string[]
} }
@ -28,6 +29,7 @@ class OctoListener {
readonly serverUrl: string readonly serverUrl: string
private token: string private token: string
private readToken: string
private ws?: WebSocket private ws?: WebSocket
private blockIds: string[] = [] private blockIds: string[] = []
private isInitialized = false private isInitialized = false
@ -39,12 +41,19 @@ class OctoListener {
notificationDelay = 100 notificationDelay = 100
reopenDelay = 3000 reopenDelay = 3000
constructor(serverUrl?: string, token?: string) { constructor(serverUrl?: string, token?: string, readToken?: string) {
this.serverUrl = serverUrl || window.location.origin this.serverUrl = serverUrl || window.location.origin
this.token = token || localStorage.getItem('sessionId') || '' this.token = token || localStorage.getItem('sessionId') || ''
this.readToken = readToken || OctoListener.getReadToken()
Utils.log(`OctoListener serverUrl: ${this.serverUrl}`) Utils.log(`OctoListener serverUrl: ${this.serverUrl}`)
} }
static getReadToken(): string {
const queryString = new URLSearchParams(window.location.search)
const readToken = queryString.get('r') || ''
return readToken
}
open(blockIds: string[], onChange: OnChangeHandler, onReconnect: () => void): void { open(blockIds: string[], onChange: OnChangeHandler, onReconnect: () => void): void {
if (this.ws) { if (this.ws) {
this.close() this.close()
@ -133,11 +142,13 @@ class OctoListener {
return return
} }
if (!this.token) {
return
}
const command = { const command = {
action: 'AUTH', action: 'AUTH',
token: this.token, token: this.token,
} }
this.ws.send(JSON.stringify(command)) this.ws.send(JSON.stringify(command))
} }
@ -150,6 +161,7 @@ class OctoListener {
const command: WSCommand = { const command: WSCommand = {
action: 'ADD', action: 'ADD',
blockIds, blockIds,
readToken: this.readToken,
} }
this.ws.send(JSON.stringify(command)) this.ws.send(JSON.stringify(command))
@ -165,6 +177,7 @@ class OctoListener {
const command: WSCommand = { const command: WSCommand = {
action: 'REMOVE', action: 'REMOVE',
blockIds, blockIds,
readToken: this.readToken,
} }
this.ws.send(JSON.stringify(command)) this.ws.send(JSON.stringify(command))

View file

@ -185,9 +185,17 @@ class BoardPage extends React.Component<Props, State> {
const boardIds = [...workspaceTree.boards.map((o) => o.id), ...workspaceTree.boardTemplates.map((o) => o.id)] const boardIds = [...workspaceTree.boards.map((o) => o.id), ...workspaceTree.boardTemplates.map((o) => o.id)]
this.setState({workspaceTree}) this.setState({workspaceTree})
let boardIdsToListen: string[]
if (boardIds.length > 0) {
boardIdsToListen = ['', ...boardIds]
} else {
// Read-only view
boardIdsToListen = [this.state.boardId]
}
// Listen to boards plus all blocks at root (Empty string for parentId) // Listen to boards plus all blocks at root (Empty string for parentId)
this.workspaceListener.open( this.workspaceListener.open(
['', ...boardIds], boardIdsToListen,
async (blocks) => { async (blocks) => {
Utils.log(`workspaceListener.onChanged: ${blocks.length}`) Utils.log(`workspaceListener.onChanged: ${blocks.length}`)
this.incrementalUpdate(blocks) this.incrementalUpdate(blocks)