From daf83691b13eb1da4854ec71abcd95f72ab7117f Mon Sep 17 00:00:00 2001 From: Chen-I Lim Date: Fri, 5 Feb 2021 10:28:52 -0800 Subject: [PATCH] Refactor API to use middleware --- server/api/api.go | 54 +++++++++++++++++++++++++---------------------- 1 file changed, 29 insertions(+), 25 deletions(-) diff --git a/server/api/api.go b/server/api/api.go index 320af0ef1..4953504b3 100644 --- a/server/api/api.go +++ b/server/api/api.go @@ -40,49 +40,53 @@ func (a *API) app() *app.App { } func (a *API) RegisterRoutes(r *mux.Router) { - a.addHandler(r, "/api/v1/blocks", "GET", a.sessionRequired(a.handleGetBlocks)) - a.addHandler(r, "/api/v1/blocks", "POST", a.sessionRequired(a.handlePostBlocks)) - a.addHandler(r, "/api/v1/blocks/{blockID}", "DELETE", a.sessionRequired(a.handleDeleteBlock)) - a.addHandler(r, "/api/v1/blocks/{blockID}/subtree", "GET", a.attachSession(a.handleGetSubTree, false)) + apiv1 := r.PathPrefix("/api/v1").Subrouter() + apiv1.Use(a.requireCSRFToken) - a.addHandler(r, "/api/v1/users/me", "GET", a.sessionRequired(a.handleGetMe)) - a.addHandler(r, "/api/v1/users/{userID}", "GET", a.sessionRequired(a.handleGetUser)) - a.addHandler(r, "/api/v1/users/{userID}/changepassword", "POST", a.sessionRequired(a.handleChangePassword)) + apiv1.HandleFunc("/blocks", a.sessionRequired(a.handleGetBlocks)).Methods("GET") + apiv1.HandleFunc("/blocks", a.sessionRequired(a.handlePostBlocks)).Methods("POST") + apiv1.HandleFunc("/blocks/{blockID}", a.sessionRequired(a.handleDeleteBlock)).Methods("DELETE") + apiv1.HandleFunc("/blocks/{blockID}/subtree", a.attachSession(a.handleGetSubTree, false)).Methods("GET") - a.addHandler(r, "/api/v1/login", "POST", a.sessionRequired(a.handleLogin)) - a.addHandler(r, "/api/v1/register", "POST", a.sessionRequired(a.handleRegister)) + apiv1.HandleFunc("/users/me", a.sessionRequired(a.handleGetMe)).Methods("GET") + apiv1.HandleFunc("/users/{userID}", a.sessionRequired(a.handleGetUser)).Methods("GET") + apiv1.HandleFunc("/users/{userID}/changepassword", a.sessionRequired(a.handleChangePassword)).Methods("POST") - a.addHandler(r, "api/v1/files", "POST", a.sessionRequired(a.handleUploadFile)) - a.addHandler(r, "/files/{filename}", "GET", a.sessionRequired(a.handleServeFile)) + apiv1.HandleFunc("/login", a.handleLogin).Methods("POST") + apiv1.HandleFunc("/register", a.handleRegister).Methods("POST") - a.addHandler(r, "/api/v1/blocks/export", "GET", a.sessionRequired(a.handleExport)) - a.addHandler(r, "/api/v1/blocks/import", "POST", a.sessionRequired(a.handleImport)) + apiv1.HandleFunc("/blocks/export", a.sessionRequired(a.handleExport)).Methods("GET") + apiv1.HandleFunc("/blocks/import", a.sessionRequired(a.handleImport)).Methods("POST") - a.addHandler(r, "/api/v1/sharing/{rootID}", "POST", a.sessionRequired(a.handlePostSharing)) - a.addHandler(r, "/api/v1/sharing/{rootID}", "GET", a.sessionRequired(a.handleGetSharing)) + apiv1.HandleFunc("/sharing/{rootID}", a.sessionRequired(a.handlePostSharing)).Methods("POST") + apiv1.HandleFunc("/sharing/{rootID}", a.sessionRequired(a.handleGetSharing)).Methods("GET") - a.addHandler(r, "/api/v1/workspace", "GET", a.sessionRequired(a.handleGetWorkspace)) - a.addHandler(r, "/api/v1/workspace/regenerate_signup_token", "POST", a.sessionRequired(a.handlePostWorkspaceRegenerateSignupToken)) + apiv1.HandleFunc("/workspace", a.sessionRequired(a.handleGetWorkspace)).Methods("GET") + apiv1.HandleFunc("/workspace/regenerate_signup_token", a.sessionRequired(a.handlePostWorkspaceRegenerateSignupToken)).Methods("POST") + + // Files API + + files := r.PathPrefix("/files/").Subrouter() + files.Use(a.requireCSRFToken) + + files.HandleFunc("/", a.sessionRequired(a.handleUploadFile)).Methods("POST") + files.HandleFunc("/{filename}", a.sessionRequired(a.handleServeFile)).Methods("GET") } func (a *API) RegisterAdminRoutes(r *mux.Router) { r.HandleFunc("/api/v1/admin/users/{username}/password", a.adminRequired(a.handleAdminSetPassword)).Methods("POST") } -func (a *API) addHandler(r *mux.Router, path string, method string, f func(http.ResponseWriter, *http.Request)) { - r.HandleFunc(path, a.preHandle(f)).Methods(method) -} - -func (a *API) preHandle(handler func(w http.ResponseWriter, r *http.Request)) func(w http.ResponseWriter, r *http.Request) { - return func(w http.ResponseWriter, r *http.Request) { +func (a *API) requireCSRFToken(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if !a.checkCSRFToken(r) { log.Println("checkCSRFToken FAILED") errorResponse(w, http.StatusBadRequest, nil, nil) return } - handler(w, r) - } + next.ServeHTTP(w, r) + }) } func (a *API) checkCSRFToken(r *http.Request) bool {