diff --git a/server/main/main.go b/server/main/main.go index 04af07421..b4b0b05b7 100644 --- a/server/main/main.go +++ b/server/main/main.go @@ -21,6 +21,7 @@ import ( ) var config Configuration +var store *SQLStore // WebsocketMsg is send on block changes type WebsocketMsg struct { @@ -64,9 +65,9 @@ func handleGetBlocks(w http.ResponseWriter, r *http.Request) { var blocks []string if len(blockType) > 0 { - blocks = getBlocksWithParentAndType(parentID, blockType) + blocks = store.getBlocksWithParentAndType(parentID, blockType) } else { - blocks = getBlocksWithParent(parentID) + blocks = store.getBlocksWithParent(parentID) } log.Printf("GetBlocks parentID: %s, %d result(s)", parentID, len(blocks)) response := `[` + strings.Join(blocks[:], ",") + `]` @@ -127,7 +128,7 @@ func handlePostBlocks(w http.ResponseWriter, r *http.Request) { return } - insertBlock(block, string(jsonBytes)) + store.insertBlock(block, string(jsonBytes)) } broadcastBlockChangeToWebsocketClients(blockIDsToNotify) @@ -142,13 +143,13 @@ func handleDeleteBlock(w http.ResponseWriter, r *http.Request) { var blockIDsToNotify = []string{blockID} - parentID := getParentID(blockID) + parentID := store.getParentID(blockID) if len(parentID) > 0 { blockIDsToNotify = append(blockIDsToNotify, parentID) } - deleteBlock(blockID) + store.deleteBlock(blockID) broadcastBlockChangeToWebsocketClients(blockIDsToNotify) @@ -160,7 +161,7 @@ func handleGetSubTree(w http.ResponseWriter, r *http.Request) { vars := mux.Vars(r) blockID := vars["blockID"] - blocks := getSubTree(blockID) + blocks := store.getSubTree(blockID) log.Printf("GetSubTree blockID: %s, %d result(s)", blockID, len(blocks)) response := `[` + strings.Join(blocks[:], ",") + `]` @@ -168,7 +169,7 @@ func handleGetSubTree(w http.ResponseWriter, r *http.Request) { } func handleExport(w http.ResponseWriter, r *http.Request) { - blocks := getAllBlocks() + blocks := store.getAllBlocks() log.Printf("EXPORT Blocks, %d result(s)", len(blocks)) response := `[` + strings.Join(blocks[:], ",") + `]` @@ -191,25 +192,24 @@ func handleImport(w http.ResponseWriter, r *http.Request) { } }() - var blockMaps []map[string]interface{} - err = json.Unmarshal([]byte(requestBody), &blockMaps) + var blocks []Block + err = json.Unmarshal([]byte(requestBody), &blocks) if err != nil { errorResponse(w, http.StatusInternalServerError, ``) return } - for _, blockMap := range blockMaps { - jsonBytes, err := json.Marshal(blockMap) + for _, block := range blocks { + jsonBytes, err := json.Marshal(block) if err != nil { errorResponse(w, http.StatusInternalServerError, `{}`) return } - block := blockFromMap(blockMap) - insertBlock(block, string(jsonBytes)) + store.insertBlock(block, string(jsonBytes)) } - log.Printf("IMPORT Blocks %d block(s)", len(blockMaps)) + log.Printf("IMPORT Blocks %d block(s)", len(blocks)) jsonResponse(w, http.StatusOK, "{}") } @@ -434,7 +434,12 @@ func main() { http.Handle("/", r) - connectDatabase(config.DBType, config.DBConfigString) + var err error + store, err = NewSQLStore(config.DBType, config.DBConfigString) + if err != nil { + log.Fatal("Unable to start the database", err) + panic(err) + } // Ctrl+C handling handler := make(chan os.Signal, 1) diff --git a/server/main/octoDatabase.go b/server/main/octoDatabase.go index 4a4568fa1..fbed7a58b 100644 --- a/server/main/octoDatabase.go +++ b/server/main/octoDatabase.go @@ -10,25 +10,38 @@ import ( _ "github.com/mattn/go-sqlite3" ) -var db *sql.DB +type SQLStore struct { + db *sql.DB + dbType string +} -func connectDatabase(dbType string, connectionString string) { +func NewSQLStore(dbType, connectionString string) (*SQLStore, error) { log.Println("connectDatabase") var err error - db, err = sql.Open(dbType, connectionString) + db, err := sql.Open(dbType, connectionString) if err != nil { log.Fatal("connectDatabase: ", err) - panic(err) + return nil, err } err = db.Ping() if err != nil { log.Println(`Database Ping failed`) - panic(err) + return nil, err } - createTablesIfNotExists(dbType) + store := &SQLStore{ + db: db, + dbType: dbType, + } + + err = store.createTablesIfNotExists() + if err != nil { + log.Println(`Table creation failed`) + return nil, err + } + return store, nil } // Block is the basic data unit @@ -41,11 +54,11 @@ type Block struct { DeleteAt int64 `json:"deleteAt"` } -func createTablesIfNotExists(dbType string) { +func (s *SQLStore) createTablesIfNotExists() error { // TODO: Add update_by with the user's ID // TODO: Consolidate insert_at and update_at, decide if the server of DB should set it var query string - if dbType == "sqlite3" { + if s.dbType == "sqlite3" { query = `CREATE TABLE IF NOT EXISTS blocks ( id VARCHAR(36), insert_at DATETIME NOT NULL DEFAULT current_timestamp, @@ -71,12 +84,13 @@ func createTablesIfNotExists(dbType string) { );` } - _, err := db.Exec(query) + _, err := s.db.Exec(query) if err != nil { log.Fatal("createTablesIfNotExists: ", err) - panic(err) + return err } - log.Printf("createTablesIfNotExists(%s)", dbType) + log.Printf("createTablesIfNotExists(%s)", s.dbType) + return nil } func blockFromMap(m map[string]interface{}) Block { @@ -103,7 +117,7 @@ func blockFromMap(m map[string]interface{}) Block { return b } -func getBlocksWithParentAndType(parentID string, blockType string) []string { +func (s *SQLStore) getBlocksWithParentAndType(parentID string, blockType string) []string { query := `WITH latest AS ( SELECT * FROM @@ -120,7 +134,7 @@ func getBlocksWithParentAndType(parentID string, blockType string) []string { FROM latest WHERE delete_at = 0 and parent_id = $1 and type = $2` - rows, err := db.Query(query, parentID, blockType) + rows, err := s.db.Query(query, parentID, blockType) if err != nil { log.Printf(`getBlocksWithParentAndType ERROR: %v`, err) panic(err) @@ -129,7 +143,7 @@ func getBlocksWithParentAndType(parentID string, blockType string) []string { return blocksFromRows(rows) } -func getBlocksWithParent(parentID string) []string { +func (s *SQLStore) getBlocksWithParent(parentID string) []string { query := `WITH latest AS ( SELECT * FROM @@ -146,7 +160,7 @@ func getBlocksWithParent(parentID string) []string { FROM latest WHERE delete_at = 0 and parent_id = $1` - rows, err := db.Query(query, parentID) + rows, err := s.db.Query(query, parentID) if err != nil { log.Printf(`getBlocksWithParent ERROR: %v`, err) panic(err) @@ -155,7 +169,7 @@ func getBlocksWithParent(parentID string) []string { return blocksFromRows(rows) } -func getSubTree(blockID string) []string { +func (s *SQLStore) getSubTree(blockID string) []string { query := `WITH latest AS ( SELECT * FROM @@ -174,7 +188,7 @@ func getSubTree(blockID string) []string { AND (id = $1 OR parent_id = $1)` - rows, err := db.Query(query, blockID) + rows, err := s.db.Query(query, blockID) if err != nil { log.Printf(`getSubTree ERROR: %v`, err) panic(err) @@ -183,7 +197,7 @@ func getSubTree(blockID string) []string { return blocksFromRows(rows) } -func getAllBlocks() []string { +func (s *SQLStore) getAllBlocks() []string { query := `WITH latest AS ( SELECT * FROM @@ -200,7 +214,7 @@ func getAllBlocks() []string { FROM latest WHERE delete_at = 0` - rows, err := db.Query(query) + rows, err := s.db.Query(query) if err != nil { log.Printf(`getAllBlocks ERROR: %v`, err) panic(err) @@ -229,7 +243,7 @@ func blocksFromRows(rows *sql.Rows) []string { return results } -func getParentID(blockID string) string { +func (s *SQLStore) getParentID(blockID string) string { statement := `WITH latest AS ( @@ -248,7 +262,7 @@ func getParentID(blockID string) string { WHERE delete_at = 0 AND id = $1` - row := db.QueryRow(statement, blockID) + row := s.db.QueryRow(statement, blockID) var parentID string err := row.Scan(&parentID) @@ -259,19 +273,19 @@ func getParentID(blockID string) string { return parentID } -func insertBlock(block Block, json string) { +func (s *SQLStore) insertBlock(block Block, json string) { statement := `INSERT INTO blocks(id, parent_id, type, json, create_at, update_at, delete_at) VALUES($1, $2, $3, $4, $5, $6, $7)` - _, err := db.Exec(statement, block.ID, block.ParentID, block.Type, json, block.CreateAt, block.UpdateAt, block.DeleteAt) + _, err := s.db.Exec(statement, block.ID, block.ParentID, block.Type, json, block.CreateAt, block.UpdateAt, block.DeleteAt) if err != nil { panic(err) } } -func deleteBlock(blockID string) { +func (s *SQLStore) deleteBlock(blockID string) { now := time.Now().Unix() json := fmt.Sprintf(`{"id":"%s","updateAt":%d,"deleteAt":%d}`, blockID, now, now) statement := `INSERT INTO blocks(id, json, update_at, delete_at) VALUES($1, $2, $3, $4)` - _, err := db.Exec(statement, blockID, json, now, now) + _, err := s.db.Exec(statement, blockID, json, now, now) if err != nil { panic(err) }