Merge pull request #46 from mattermost/websocket_auth

Websocket auth
This commit is contained in:
Chen-I Lim 2021-02-02 13:53:24 -08:00 committed by GitHub
commit 7256fb4b5a
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
7 changed files with 257 additions and 117 deletions

View file

@ -1,6 +1,7 @@
package app package app
import ( import (
"github.com/mattermost/focalboard/server/auth"
"github.com/mattermost/focalboard/server/services/config" "github.com/mattermost/focalboard/server/services/config"
"github.com/mattermost/focalboard/server/services/store" "github.com/mattermost/focalboard/server/services/store"
"github.com/mattermost/focalboard/server/services/webhook" "github.com/mattermost/focalboard/server/services/webhook"
@ -11,11 +12,26 @@ import (
type App struct { type App struct {
config *config.Configuration config *config.Configuration
store store.Store store store.Store
auth *auth.Auth
wsServer *ws.Server wsServer *ws.Server
filesBackend filesstore.FileBackend filesBackend filesstore.FileBackend
webhook *webhook.Client webhook *webhook.Client
} }
func New(config *config.Configuration, store store.Store, wsServer *ws.Server, filesBackend filesstore.FileBackend, webhook *webhook.Client) *App { func New(
return &App{config: config, store: store, wsServer: wsServer, filesBackend: filesBackend, webhook: webhook} config *config.Configuration,
store store.Store,
auth *auth.Auth,
wsServer *ws.Server,
filesBackend filesstore.FileBackend,
webhook *webhook.Client,
) *App {
return &App{
config: config,
store: store,
auth: auth,
wsServer: wsServer,
filesBackend: filesBackend,
webhook: webhook,
}
} }

View file

@ -2,7 +2,6 @@ package app
import ( import (
"log" "log"
"time"
"github.com/google/uuid" "github.com/google/uuid"
"github.com/mattermost/focalboard/server/model" "github.com/mattermost/focalboard/server/model"
@ -13,18 +12,7 @@ import (
// GetSession Get a user active session and refresh the session if is needed // GetSession Get a user active session and refresh the session if is needed
func (a *App) GetSession(token string) (*model.Session, error) { func (a *App) GetSession(token string) (*model.Session, error) {
if len(token) < 1 { return a.auth.GetSession(token)
return nil, errors.New("no session token")
}
session, err := a.store.GetSession(token, a.config.SessionExpireTime)
if err != nil {
return nil, errors.Wrap(err, "unable to get the session for the token")
}
if session.UpdateAt < (time.Now().Unix() - a.config.SessionRefreshTime) {
a.store.RefreshSession(session)
}
return session, nil
} }
// GetRegisteredUserCount returns the number of registered users // GetRegisteredUserCount returns the number of registered users

View file

@ -5,6 +5,7 @@ import (
"testing" "testing"
"github.com/golang/mock/gomock" "github.com/golang/mock/gomock"
"github.com/mattermost/focalboard/server/auth"
"github.com/mattermost/focalboard/server/services/config" "github.com/mattermost/focalboard/server/services/config"
"github.com/mattermost/focalboard/server/services/store/mockstore" "github.com/mattermost/focalboard/server/services/store/mockstore"
"github.com/mattermost/focalboard/server/services/webhook" "github.com/mattermost/focalboard/server/services/webhook"
@ -18,9 +19,10 @@ func TestGetParentID(t *testing.T) {
defer ctrl.Finish() defer ctrl.Finish()
cfg := config.Configuration{} cfg := config.Configuration{}
store := mockstore.NewMockStore(ctrl) store := mockstore.NewMockStore(ctrl)
wsserver := ws.NewServer() auth := auth.New(&cfg, store)
wsserver := ws.NewServer(auth, true)
webhook := webhook.NewClient(&cfg) webhook := webhook.NewClient(&cfg)
app := New(&cfg, store, wsserver, &mocks.FileBackend{}, webhook) app := New(&cfg, store, auth, wsserver, &mocks.FileBackend{}, webhook)
t.Run("success query", func(t *testing.T) { t.Run("success query", func(t *testing.T) {
store.EXPECT().GetParentID(gomock.Eq("test-id")).Return("test-parent-id", nil) store.EXPECT().GetParentID(gomock.Eq("test-id")).Return("test-parent-id", nil)

37
server/auth/auth.go Normal file
View file

@ -0,0 +1,37 @@
package auth
import (
"time"
"github.com/mattermost/focalboard/server/model"
"github.com/mattermost/focalboard/server/services/config"
"github.com/mattermost/focalboard/server/services/store"
"github.com/pkg/errors"
)
// Auth authenticates sessions
type Auth struct {
config *config.Configuration
store store.Store
}
// New returns a new Auth
func New(config *config.Configuration, store store.Store) *Auth {
return &Auth{config: config, store: store}
}
// GetSession Get a user active session and refresh the session if is needed
func (a *Auth) GetSession(token string) (*model.Session, error) {
if len(token) < 1 {
return nil, errors.New("no session token")
}
session, err := a.store.GetSession(token, a.config.SessionExpireTime)
if err != nil {
return nil, errors.Wrap(err, "unable to get the session for the token")
}
if session.UpdateAt < (time.Now().Unix() - a.config.SessionRefreshTime) {
a.store.RefreshSession(session)
}
return session, nil
}

View file

@ -17,6 +17,7 @@ import (
"github.com/mattermost/focalboard/server/api" "github.com/mattermost/focalboard/server/api"
"github.com/mattermost/focalboard/server/app" "github.com/mattermost/focalboard/server/app"
"github.com/mattermost/focalboard/server/auth"
"github.com/mattermost/focalboard/server/context" "github.com/mattermost/focalboard/server/context"
appModel "github.com/mattermost/focalboard/server/model" appModel "github.com/mattermost/focalboard/server/model"
"github.com/mattermost/focalboard/server/services/config" "github.com/mattermost/focalboard/server/services/config"
@ -60,7 +61,9 @@ func New(cfg *config.Configuration, singleUser bool) (*Server, error) {
return nil, err return nil, err
} }
wsServer := ws.NewServer() auth := auth.New(cfg, store)
wsServer := ws.NewServer(auth, singleUser)
filesBackendSettings := model.FileSettings{} filesBackendSettings := model.FileSettings{}
filesBackendSettings.SetDefaults(false) filesBackendSettings.SetDefaults(false)
@ -74,7 +77,7 @@ func New(cfg *config.Configuration, singleUser bool) (*Server, error) {
webhookClient := webhook.NewClient(cfg) webhookClient := webhook.NewClient(cfg)
appBuilder := func() *app.App { return app.New(cfg, store, wsServer, filesBackend, webhookClient) } appBuilder := func() *app.App { return app.New(cfg, store, auth, wsServer, filesBackend, webhookClient) }
api := api.NewAPI(appBuilder, singleUser) api := api.NewAPI(appBuilder, singleUser)
// Local router for admin APIs // Local router for admin APIs

View file

@ -9,95 +9,20 @@ import (
"github.com/gorilla/mux" "github.com/gorilla/mux"
"github.com/gorilla/websocket" "github.com/gorilla/websocket"
"github.com/mattermost/focalboard/server/auth"
"github.com/mattermost/focalboard/server/model" "github.com/mattermost/focalboard/server/model"
) )
// RegisterRoutes registers routes. // IsValidSessionToken authenticates session tokens
func (ws *Server) RegisterRoutes(r *mux.Router) { type IsValidSessionToken func(token string) bool
r.HandleFunc("/ws/onchange", ws.handleWebSocketOnChange)
}
// AddListener adds a listener for a block's change.
func (ws *Server) AddListener(client *websocket.Conn, blockIDs []string) {
ws.mu.Lock()
for _, blockID := range blockIDs {
if ws.listeners[blockID] == nil {
ws.listeners[blockID] = []*websocket.Conn{}
}
ws.listeners[blockID] = append(ws.listeners[blockID], client)
}
ws.mu.Unlock()
}
// RemoveListener removes a webSocket listener from all blocks.
func (ws *Server) RemoveListener(client *websocket.Conn) {
ws.mu.Lock()
for key, clients := range ws.listeners {
listeners := []*websocket.Conn{}
for _, existingClient := range clients {
if client != existingClient {
listeners = append(listeners, existingClient)
}
}
ws.listeners[key] = listeners
}
ws.mu.Unlock()
}
// RemoveListenerFromBlocks removes a webSocket listener from a set of block.
func (ws *Server) RemoveListenerFromBlocks(client *websocket.Conn, blockIDs []string) {
ws.mu.Lock()
for _, blockID := range blockIDs {
listeners := ws.listeners[blockID]
if listeners == nil {
return
}
// Remove the first instance of this client that's listening to this block
// Note: A client can listen multiple times to the same block
for index, listener := range listeners {
if client == listener {
newListeners := append(listeners[:index], listeners[index+1:]...)
ws.listeners[blockID] = newListeners
break
}
}
}
ws.mu.Unlock()
}
// GetListeners returns the listeners to a blockID's changes.
func (ws *Server) GetListeners(blockID string) []*websocket.Conn {
ws.mu.Lock()
listeners := ws.listeners[blockID]
ws.mu.Unlock()
return listeners
}
// Server is a WebSocket server. // Server is a WebSocket server.
type Server struct { type Server struct {
upgrader websocket.Upgrader upgrader websocket.Upgrader
listeners map[string][]*websocket.Conn listeners map[string][]*websocket.Conn
mu sync.RWMutex mu sync.RWMutex
} auth *auth.Auth
singleUser bool
// NewServer creates a new Server.
func NewServer() *Server {
return &Server{
listeners: make(map[string][]*websocket.Conn),
upgrader: websocket.Upgrader{
CheckOrigin: func(r *http.Request) bool {
return true
},
},
}
} }
// UpdateMsg is sent on block updates // UpdateMsg is sent on block updates
@ -106,12 +31,42 @@ type UpdateMsg struct {
Block model.Block `json:"block"` Block model.Block `json:"block"`
} }
// ErrorMsg is sent on errors
type ErrorMsg struct {
Error string `json:"error"`
}
// WebsocketCommand is an incoming command from the client. // WebsocketCommand is an incoming command from the client.
type WebsocketCommand struct { type WebsocketCommand struct {
Action string `json:"action"` Action string `json:"action"`
Token string `json:"token"`
BlockIDs []string `json:"blockIds"` BlockIDs []string `json:"blockIds"`
} }
type websocketSession struct {
client *websocket.Conn
isAuthenticated bool
}
// NewServer creates a new Server.
func NewServer(auth *auth.Auth, singleUser bool) *Server {
return &Server{
listeners: make(map[string][]*websocket.Conn),
upgrader: websocket.Upgrader{
CheckOrigin: func(r *http.Request) bool {
return true
},
},
auth: auth,
singleUser: singleUser,
}
}
// RegisterRoutes registers routes.
func (ws *Server) RegisterRoutes(r *mux.Router) {
r.HandleFunc("/ws/onchange", ws.handleWebSocketOnChange)
}
func (ws *Server) handleWebSocketOnChange(w http.ResponseWriter, r *http.Request) { func (ws *Server) handleWebSocketOnChange(w http.ResponseWriter, r *http.Request) {
// Upgrade initial GET request to a websocket // Upgrade initial GET request to a websocket
client, err := ws.upgrader.Upgrade(w, r, nil) client, err := ws.upgrader.Upgrade(w, r, nil)
@ -128,17 +83,22 @@ func (ws *Server) handleWebSocketOnChange(w http.ResponseWriter, r *http.Request
log.Printf("DISCONNECT WebSocket onChange, client: %s", client.RemoteAddr()) log.Printf("DISCONNECT WebSocket onChange, client: %s", client.RemoteAddr())
// Remove client from listeners // Remove client from listeners
ws.RemoveListener(client) ws.removeListener(client)
client.Close() client.Close()
}() }()
wsSession := websocketSession{
client: client,
isAuthenticated: ws.singleUser,
}
// Simple message handling loop // Simple message handling loop
for { for {
_, p, err := client.ReadMessage() _, p, err := client.ReadMessage()
if err != nil { if err != nil {
log.Printf("ERROR WebSocket onChange, client: %s, err: %v", client.RemoteAddr(), err) log.Printf("ERROR WebSocket onChange, client: %s, err: %v", client.RemoteAddr(), err)
ws.RemoveListener(client) ws.removeListener(client)
break break
} }
@ -154,13 +114,17 @@ func (ws *Server) handleWebSocketOnChange(w http.ResponseWriter, r *http.Request
} }
switch command.Action { switch command.Action {
case "AUTH":
log.Printf(`Command: AUTH, client: %s`, client.RemoteAddr())
ws.authenticateListener(&wsSession, command.Token)
case "ADD": case "ADD":
log.Printf(`Command: Add blockID: %v, client: %s`, command.BlockIDs, client.RemoteAddr()) log.Printf(`Command: Add blockID: %v, client: %s`, command.BlockIDs, client.RemoteAddr())
ws.AddListener(client, command.BlockIDs) ws.addListener(&wsSession, command.BlockIDs)
case "REMOVE": case "REMOVE":
log.Printf(`Command: Remove blockID: %v, client: %s`, command.BlockIDs, client.RemoteAddr()) log.Printf(`Command: Remove blockID: %v, client: %s`, command.BlockIDs, client.RemoteAddr())
ws.RemoveListenerFromBlocks(client, command.BlockIDs) ws.removeListenerFromBlocks(&wsSession, command.BlockIDs)
default: default:
log.Printf(`ERROR webSocket command, invalid action: %v`, command.Action) log.Printf(`ERROR webSocket command, invalid action: %v`, command.Action)
@ -168,6 +132,119 @@ func (ws *Server) handleWebSocketOnChange(w http.ResponseWriter, r *http.Request
} }
} }
func (ws *Server) isValidSessionToken(token string) bool {
if ws.singleUser {
return true
}
session, err := ws.auth.GetSession(token)
if session == nil || err != nil {
return false
}
return true
}
func (ws *Server) authenticateListener(wsSession *websocketSession, token string) {
isValidSession := ws.isValidSessionToken(token)
if !isValidSession {
wsSession.client.Close()
return
}
// Authenticated
wsSession.isAuthenticated = true
log.Printf("authenticateListener: Authenticated")
}
// addListener adds a listener for a block's change.
func (ws *Server) addListener(wsSession *websocketSession, blockIDs []string) {
if !wsSession.isAuthenticated {
log.Printf("addListener: NOT AUTHENTICATED")
sendError(wsSession.client, "not authenticated")
return
}
ws.mu.Lock()
for _, blockID := range blockIDs {
if ws.listeners[blockID] == nil {
ws.listeners[blockID] = []*websocket.Conn{}
}
ws.listeners[blockID] = append(ws.listeners[blockID], wsSession.client)
}
ws.mu.Unlock()
}
// removeListener removes a webSocket listener from all blocks.
func (ws *Server) removeListener(client *websocket.Conn) {
ws.mu.Lock()
for key, clients := range ws.listeners {
listeners := []*websocket.Conn{}
for _, existingClient := range clients {
if client != existingClient {
listeners = append(listeners, existingClient)
}
}
ws.listeners[key] = listeners
}
ws.mu.Unlock()
}
// removeListenerFromBlocks removes a webSocket listener from a set of block.
func (ws *Server) removeListenerFromBlocks(wsSession *websocketSession, blockIDs []string) {
if !wsSession.isAuthenticated {
log.Printf("removeListenerFromBlocks: NOT AUTHENTICATED")
sendError(wsSession.client, "not authenticated")
return
}
ws.mu.Lock()
for _, blockID := range blockIDs {
listeners := ws.listeners[blockID]
if listeners == nil {
return
}
// Remove the first instance of this client that's listening to this block
// Note: A client can listen multiple times to the same block
for index, listener := range listeners {
if wsSession.client == listener {
newListeners := append(listeners[:index], listeners[index+1:]...)
ws.listeners[blockID] = newListeners
break
}
}
}
ws.mu.Unlock()
}
func sendError(conn *websocket.Conn, message string) {
errorMsg := ErrorMsg{
Error: message,
}
err := conn.WriteJSON(errorMsg)
if err != nil {
log.Printf("sendError error: %v", err)
conn.Close()
}
}
// getListeners returns the listeners to a blockID's changes.
func (ws *Server) getListeners(blockID string) []*websocket.Conn {
ws.mu.Lock()
listeners := ws.listeners[blockID]
ws.mu.Unlock()
return listeners
}
// BroadcastBlockDelete broadcasts delete messages to clients // BroadcastBlockDelete broadcasts delete messages to clients
func (ws *Server) BroadcastBlockDelete(blockID string, parentID string) { func (ws *Server) BroadcastBlockDelete(blockID string, parentID string) {
now := time.Now().Unix() now := time.Now().Unix()
@ -185,7 +262,7 @@ func (ws *Server) BroadcastBlockChange(block model.Block) {
blockIDsToNotify := []string{block.ID, block.ParentID} blockIDsToNotify := []string{block.ID, block.ParentID}
for _, blockID := range blockIDsToNotify { for _, blockID := range blockIDsToNotify {
listeners := ws.GetListeners(blockID) listeners := ws.getListeners(blockID)
log.Printf("%d listener(s) for blockID: %s", len(listeners), blockID) log.Printf("%d listener(s) for blockID: %s", len(listeners), blockID)
if listeners != nil { if listeners != nil {

View file

@ -11,9 +11,9 @@ type WSCommand = {
// These are messages from the server // These are messages from the server
type WSMessage = { type WSMessage = {
action: string action?: string
blockId: string block?: IBlock
block: IBlock error?: string
} }
type OnChangeHandler = (blocks: IBlock[]) => void type OnChangeHandler = (blocks: IBlock[]) => void
@ -27,6 +27,7 @@ class OctoListener {
} }
readonly serverUrl: string readonly serverUrl: string
private token: string
private ws?: WebSocket private ws?: WebSocket
private blockIds: string[] = [] private blockIds: string[] = []
private isInitialized = false private isInitialized = false
@ -38,14 +39,13 @@ class OctoListener {
notificationDelay = 100 notificationDelay = 100
reopenDelay = 3000 reopenDelay = 3000
constructor(serverUrl?: string) { constructor(serverUrl?: string, token?: string) {
this.serverUrl = serverUrl || window.location.origin this.serverUrl = serverUrl || window.location.origin
this.token = token || localStorage.getItem('sessionId') || ''
Utils.log(`OctoListener serverUrl: ${this.serverUrl}`) Utils.log(`OctoListener serverUrl: ${this.serverUrl}`)
} }
open(blockIds: string[], onChange: OnChangeHandler, onReconnect: () => void): void { open(blockIds: string[], onChange: OnChangeHandler, onReconnect: () => void): void {
let timeoutId: NodeJS.Timeout
if (this.ws) { if (this.ws) {
this.close() this.close()
} }
@ -61,6 +61,7 @@ class OctoListener {
ws.onopen = () => { ws.onopen = () => {
Utils.log('OctoListener webSocket opened.') Utils.log('OctoListener webSocket opened.')
this.authenticate()
this.addBlocks(blockIds) this.addBlocks(blockIds)
this.isInitialized = true this.isInitialized = true
} }
@ -91,13 +92,15 @@ class OctoListener {
try { try {
const message = JSON.parse(e.data) as WSMessage const message = JSON.parse(e.data) as WSMessage
if (message.error) {
Utils.logError(`Listener websocket error: ${message.error}`)
return
}
switch (message.action) { switch (message.action) {
case 'UPDATE_BLOCK': case 'UPDATE_BLOCK':
if (timeoutId) {
clearTimeout(timeoutId)
}
Utils.log(`OctoListener update block: ${message.block?.id}`) Utils.log(`OctoListener update block: ${message.block?.id}`)
this.queueUpdateNotification(message.block) this.queueUpdateNotification(message.block!)
break break
default: default:
Utils.logError(`Unexpected action: ${message.action}`) Utils.logError(`Unexpected action: ${message.action}`)
@ -124,6 +127,20 @@ class OctoListener {
ws.close() ws.close()
} }
authenticate(): void {
if (!this.ws) {
Utils.assertFailure('OctoListener.addBlocks: ws is not open')
return
}
const command = {
action: 'AUTH',
token: this.token,
}
this.ws.send(JSON.stringify(command))
}
addBlocks(blockIds: string[]): void { addBlocks(blockIds: string[]): void {
if (!this.ws) { if (!this.ws) {
Utils.assertFailure('OctoListener.addBlocks: ws is not open') Utils.assertFailure('OctoListener.addBlocks: ws is not open')