diff --git a/Makefile b/Makefile index 413eee9ee..ad97e081a 100644 --- a/Makefile +++ b/Makefile @@ -25,11 +25,14 @@ server-lint: echo "golangci-lint is not installed. Please see https://github.com/golangci/golangci-lint#install for installation instructions."; \ exit 1; \ fi; \ - cd server; golangci-lint run ./... + cd server; golangci-lint run -p format -p unused -p complexity -p bugs -p performance -E asciicheck -E depguard -E dogsled -E dupl -E funlen -E gochecknoglobals -E gochecknoinits -E goconst -E gocritic -E godot -E godox -E goerr113 -E goheader -E golint -E gomnd -E gomodguard -E goprintffuncname -E gosimple -E interfacer -E lll -E misspell -E nlreturn -E nolintlint -E stylecheck -E unconvert -E whitespace -E wsl --skip-dirs services/store/sqlstore/migrations/ ./... server-test: cd server; go test ./... +server-doc: + cd server; go doc ./... + watch-server: cd server; modd diff --git a/server/api/api.go b/server/api/api.go index 17536fa7d..168cae071 100644 --- a/server/api/api.go +++ b/server/api/api.go @@ -83,7 +83,7 @@ func (a *API) handlePostBlocks(w http.ResponseWriter, r *http.Request) { }() var blocks []model.Block - err = json.Unmarshal([]byte(requestBody), &blocks) + err = json.Unmarshal(requestBody, &blocks) if err != nil { errorResponse(w, http.StatusInternalServerError, ``) return @@ -95,15 +95,16 @@ func (a *API) handlePostBlocks(w http.ResponseWriter, r *http.Request) { errorResponse(w, http.StatusInternalServerError, fmt.Sprintf(`{"description": "missing type", "id": "%s"}`, block.ID)) return } + if block.CreateAt < 1 { errorResponse(w, http.StatusInternalServerError, fmt.Sprintf(`{"description": "invalid createAt", "id": "%s"}`, block.ID)) return } + if block.UpdateAt < 1 { errorResponse(w, http.StatusInternalServerError, fmt.Sprintf(`{"description": "invalid updateAt", "id": "%s"}`, block.ID)) return } - } err = a.app().InsertBlocks(blocks) @@ -190,7 +191,7 @@ func (a *API) handleImport(w http.ResponseWriter, r *http.Request) { }() var blocks []model.Block - err = json.Unmarshal([]byte(requestBody), &blocks) + err = json.Unmarshal(requestBody, &blocks) if err != nil { errorResponse(w, http.StatusInternalServerError, ``) return @@ -229,6 +230,7 @@ func (a *API) handleServeFile(w http.ResponseWriter, r *http.Request) { func (a *API) handleUploadFile(w http.ResponseWriter, r *http.Request) { fmt.Println(`handleUploadFile`) + file, handle, err := r.FormFile("file") if err != nil { fmt.Fprintf(w, "%v", err) @@ -243,6 +245,7 @@ func (a *API) handleUploadFile(w http.ResponseWriter, r *http.Request) { jsonStringResponse(w, http.StatusInternalServerError, `{}`) return } + log.Printf(`saveFile, url: %s`, url) json := fmt.Sprintf(`{ "url": "%s" }`, url) jsonStringResponse(w, http.StatusOK, json) diff --git a/server/app/app.go b/server/app/app.go index ddb687da8..9c7f9c943 100644 --- a/server/app/app.go +++ b/server/app/app.go @@ -10,10 +10,10 @@ import ( type App struct { config *config.Configuration store store.Store - wsServer *ws.WSServer + wsServer *ws.Server filesBackend filesstore.FileBackend } -func New(config *config.Configuration, store store.Store, wsServer *ws.WSServer, filesBackend filesstore.FileBackend) *App { +func New(config *config.Configuration, store store.Store, wsServer *ws.Server, filesBackend filesstore.FileBackend) *App { return &App{config: config, store: store, wsServer: wsServer, filesBackend: filesBackend} } diff --git a/server/app/blocks.go b/server/app/blocks.go index 14d384185..84faab2ad 100644 --- a/server/app/blocks.go +++ b/server/app/blocks.go @@ -24,12 +24,14 @@ func (a *App) InsertBlock(block model.Block) error { func (a *App) InsertBlocks(blocks []model.Block) error { var blockIDsToNotify = []string{} + uniqueBlockIDs := make(map[string]bool) for _, block := range blocks { if !uniqueBlockIDs[block.ID] { blockIDsToNotify = append(blockIDsToNotify, block.ID) } + if len(block.ParentID) > 0 && !uniqueBlockIDs[block.ParentID] { blockIDsToNotify = append(blockIDsToNotify, block.ParentID) } diff --git a/server/app/blocks_test.go b/server/app/blocks_test.go index 3fa5a7600..ef448d9ff 100644 --- a/server/app/blocks_test.go +++ b/server/app/blocks_test.go @@ -16,14 +16,16 @@ func TestGetParentID(t *testing.T) { ctrl := gomock.NewController(t) defer ctrl.Finish() store := mockstore.NewMockStore(ctrl) - wsserver := ws.NewWSServer() + wsserver := ws.NewServer() app := New(&config.Configuration{}, store, wsserver, &mocks.FileBackend{}) + t.Run("success query", func(t *testing.T) { store.EXPECT().GetParentID(gomock.Eq("test-id")).Return("test-parent-id", nil) result, err := app.GetParentID("test-id") require.NoError(t, err) require.Equal(t, "test-parent-id", result) }) + t.Run("fail query", func(t *testing.T) { store.EXPECT().GetParentID(gomock.Eq("test-id")).Return("", errors.New("block-not-found")) _, err := app.GetParentID("test-id") diff --git a/server/main/main.go b/server/main/main.go index 46de4c820..bb9ab5d8e 100644 --- a/server/main/main.go +++ b/server/main/main.go @@ -15,6 +15,10 @@ import ( // ---------------------------------------------------------------------------------------------------- // WebSocket OnChange listener +const ( + timeBetweenPidMonitoringChecks = 2 * time.Second +) + func isProcessRunning(pid int) bool { process, err := os.FindProcess(pid) if err != nil { @@ -27,13 +31,15 @@ func isProcessRunning(pid int) bool { func monitorPid(pid int) { log.Printf("Monitoring PID: %d", pid) + go func() { for { if !isProcessRunning(pid) { log.Printf("Monitored process not found, exiting.") os.Exit(1) } - time.Sleep(2 * time.Second) + + time.Sleep(timeBetweenPidMonitoringChecks) } }() } diff --git a/server/server/server.go b/server/server/server.go index e21c47d49..368dbfeb6 100644 --- a/server/server/server.go +++ b/server/server/server.go @@ -26,11 +26,11 @@ const CurrentVersion = "0.0.1" type Server struct { config *config.Configuration - wsServer *ws.WSServer + wsServer *ws.Server webServer *web.WebServer store store.Store filesBackend filesstore.FileBackend - telemetry *telemetry.TelemetryService + telemetry *telemetry.Service logger *zap.Logger } @@ -46,7 +46,7 @@ func New(cfg *config.Configuration) (*Server, error) { return nil, err } - wsServer := ws.NewWSServer() + wsServer := ws.NewServer() filesBackendSettings := model.FileSettings{} filesBackendSettings.SetDefaults(false) @@ -67,11 +67,13 @@ func New(cfg *config.Configuration) (*Server, error) { // Ctrl+C handling handler := make(chan os.Signal, 1) signal.Notify(handler, os.Interrupt) + go func() { for sig := range handler { // sig is a ^C, handle it if sig == os.Interrupt { os.Exit(1) + break } } diff --git a/server/services/scheduler/scheduler.go b/server/services/scheduler/scheduler.go index 8d9439926..52731fb17 100644 --- a/server/services/scheduler/scheduler.go +++ b/server/services/scheduler/scheduler.go @@ -41,9 +41,7 @@ func createTask(name string, function TaskFunc, interval time.Duration, recurrin defer close(task.cancelled) ticker := time.NewTicker(interval) - defer func() { - ticker.Stop() - }() + defer ticker.Stop() for { select { diff --git a/server/services/scheduler/scheduler_test.go b/server/services/scheduler/scheduler_test.go index 314d1f8bd..33b922278 100644 --- a/server/services/scheduler/scheduler_test.go +++ b/server/services/scheduler/scheduler_test.go @@ -12,67 +12,67 @@ import ( ) func TestCreateTask(t *testing.T) { - TASK_NAME := "Test Task" - TASK_TIME := time.Millisecond * 200 - TASK_WAIT := time.Millisecond * 100 + taskName := "Test Task" + taskTime := time.Millisecond * 200 + taskWait := time.Millisecond * 100 executionCount := new(int32) testFunc := func() { atomic.AddInt32(executionCount, 1) } - task := CreateTask(TASK_NAME, testFunc, TASK_TIME) + task := CreateTask(taskName, testFunc, taskTime) assert.EqualValues(t, 0, atomic.LoadInt32(executionCount)) - time.Sleep(TASK_TIME + TASK_WAIT) + time.Sleep(taskTime + taskWait) assert.EqualValues(t, 1, atomic.LoadInt32(executionCount)) - assert.Equal(t, TASK_NAME, task.Name) - assert.Equal(t, TASK_TIME, task.Interval) + assert.Equal(t, taskName, task.Name) + assert.Equal(t, taskTime, task.Interval) assert.False(t, task.Recurring) } func TestCreateRecurringTask(t *testing.T) { - TASK_NAME := "Test Recurring Task" - TASK_TIME := time.Millisecond * 200 - TASK_WAIT := time.Millisecond * 100 + taskName := "Test Recurring Task" + taskTime := time.Millisecond * 200 + taskWait := time.Millisecond * 100 executionCount := new(int32) testFunc := func() { atomic.AddInt32(executionCount, 1) } - task := CreateRecurringTask(TASK_NAME, testFunc, TASK_TIME) + task := CreateRecurringTask(taskName, testFunc, taskTime) assert.EqualValues(t, 0, atomic.LoadInt32(executionCount)) - time.Sleep(TASK_TIME + TASK_WAIT) + time.Sleep(taskTime + taskWait) assert.EqualValues(t, 1, atomic.LoadInt32(executionCount)) - time.Sleep(TASK_TIME) + time.Sleep(taskTime) assert.EqualValues(t, 2, atomic.LoadInt32(executionCount)) - assert.Equal(t, TASK_NAME, task.Name) - assert.Equal(t, TASK_TIME, task.Interval) + assert.Equal(t, taskName, task.Name) + assert.Equal(t, taskTime, task.Interval) assert.True(t, task.Recurring) task.Cancel() } func TestCancelTask(t *testing.T) { - TASK_NAME := "Test Task" - TASK_TIME := time.Millisecond * 100 - TASK_WAIT := time.Millisecond * 100 + taskName := "Test Task" + taskTime := time.Millisecond * 100 + taskWait := time.Millisecond * 100 executionCount := new(int32) testFunc := func() { atomic.AddInt32(executionCount, 1) } - task := CreateTask(TASK_NAME, testFunc, TASK_TIME) + task := CreateTask(taskName, testFunc, taskTime) assert.EqualValues(t, 0, atomic.LoadInt32(executionCount)) task.Cancel() - time.Sleep(TASK_TIME + TASK_WAIT) + time.Sleep(taskTime + taskWait) assert.EqualValues(t, 0, atomic.LoadInt32(executionCount)) } diff --git a/server/services/store/sqlstore/blocks.go b/server/services/store/sqlstore/blocks.go index 19962adb1..2868276d4 100644 --- a/server/services/store/sqlstore/blocks.go +++ b/server/services/store/sqlstore/blocks.go @@ -17,7 +17,8 @@ func (s *SQLStore) latestsBlocksSubquery() sq.SelectBuilder { } func (s *SQLStore) GetBlocksWithParentAndType(parentID string, blockType string) ([]model.Block, error) { - query := s.getQueryBuilder().Select("id", "parent_id", "schema", "type", "title", "COALESCE(\"fields\", '{}')", "create_at", "update_at", "delete_at"). + query := s.getQueryBuilder(). + Select("id", "parent_id", "schema", "type", "title", "COALESCE(\"fields\", '{}')", "create_at", "update_at", "delete_at"). FromSelect(s.latestsBlocksSubquery(), "latest"). Where(sq.Eq{"delete_at": 0}). Where(sq.Eq{"parent_id": parentID}). @@ -25,6 +26,7 @@ func (s *SQLStore) GetBlocksWithParentAndType(parentID string, blockType string) rows, err := query.Query() if err != nil { log.Printf(`getBlocksWithParentAndType ERROR: %v`, err) + return nil, err } @@ -32,7 +34,8 @@ func (s *SQLStore) GetBlocksWithParentAndType(parentID string, blockType string) } func (s *SQLStore) GetBlocksWithParent(parentID string) ([]model.Block, error) { - query := s.getQueryBuilder().Select("id", "parent_id", "schema", "type", "title", "COALESCE(\"fields\", '{}')", "create_at", "update_at", "delete_at"). + query := s.getQueryBuilder(). + Select("id", "parent_id", "schema", "type", "title", "COALESCE(\"fields\", '{}')", "create_at", "update_at", "delete_at"). FromSelect(s.latestsBlocksSubquery(), "latest"). Where(sq.Eq{"delete_at": 0}). Where(sq.Eq{"parent_id": parentID}) @@ -47,7 +50,8 @@ func (s *SQLStore) GetBlocksWithParent(parentID string) ([]model.Block, error) { } func (s *SQLStore) GetBlocksWithType(blockType string) ([]model.Block, error) { - query := s.getQueryBuilder().Select("id", "parent_id", "schema", "type", "title", "COALESCE(\"fields\", '{}')", "create_at", "update_at", "delete_at"). + query := s.getQueryBuilder(). + Select("id", "parent_id", "schema", "type", "title", "COALESCE(\"fields\", '{}')", "create_at", "update_at", "delete_at"). FromSelect(s.latestsBlocksSubquery(), "latest"). Where(sq.Eq{"delete_at": 0}). Where(sq.Eq{"type": blockType}) @@ -61,7 +65,8 @@ func (s *SQLStore) GetBlocksWithType(blockType string) ([]model.Block, error) { } func (s *SQLStore) GetSubTree(blockID string) ([]model.Block, error) { - query := s.getQueryBuilder().Select("id", "parent_id", "schema", "type", "title", "COALESCE(\"fields\", '{}')", "create_at", "update_at", "delete_at"). + query := s.getQueryBuilder(). + Select("id", "parent_id", "schema", "type", "title", "COALESCE(\"fields\", '{}')", "create_at", "update_at", "delete_at"). FromSelect(s.latestsBlocksSubquery(), "latest"). Where(sq.Eq{"delete_at": 0}). Where(sq.Or{sq.Eq{"id": blockID}, sq.Eq{"parent_id": blockID}}) @@ -76,7 +81,8 @@ func (s *SQLStore) GetSubTree(blockID string) ([]model.Block, error) { } func (s *SQLStore) GetAllBlocks() ([]model.Block, error) { - query := s.getQueryBuilder().Select("id", "parent_id", "schema", "type", "title", "COALESCE(\"fields\", '{}')", "create_at", "update_at", "delete_at"). + query := s.getQueryBuilder(). + Select("id", "parent_id", "schema", "type", "title", "COALESCE(\"fields\", '{}')", "create_at", "update_at", "delete_at"). FromSelect(s.latestsBlocksSubquery(), "latest"). Where(sq.Eq{"delete_at": 0}) @@ -97,6 +103,7 @@ func blocksFromRows(rows *sql.Rows) ([]model.Block, error) { for rows.Next() { var block model.Block var fieldsJSON string + err := rows.Scan( &block.ID, &block.ParentID, diff --git a/server/services/store/sqlstore/migrate.go b/server/services/store/sqlstore/migrate.go index 7dbec692b..60d662cf2 100644 --- a/server/services/store/sqlstore/migrate.go +++ b/server/services/store/sqlstore/migrate.go @@ -1,6 +1,8 @@ package sqlstore import ( + "errors" + "github.com/golang-migrate/migrate/v4" "github.com/golang-migrate/migrate/v4/database" "github.com/golang-migrate/migrate/v4/database/postgres" @@ -13,28 +15,24 @@ import ( ) func (s *SQLStore) Migrate() error { + var bresource *bindata.AssetSource var driver database.Driver var err error - var bresource *bindata.AssetSource + if s.dbType == "sqlite3" { driver, err = sqlite3.WithInstance(s.db, &sqlite3.Config{}) if err != nil { return err } - bresource = bindata.Resource(sqlite.AssetNames(), - func(name string) ([]byte, error) { - return sqlite.Asset(name) - }) + bresource = bindata.Resource(sqlite.AssetNames(), sqlite.Asset) } + if s.dbType == "postgres" { driver, err = postgres.WithInstance(s.db, &postgres.Config{}) if err != nil { return err } - bresource = bindata.Resource(pgmigrations.AssetNames(), - func(name string) ([]byte, error) { - return pgmigrations.Asset(name) - }) + bresource = bindata.Resource(pgmigrations.AssetNames(), pgmigrations.Asset) } d, err := bindata.WithInstance(bresource) @@ -48,8 +46,9 @@ func (s *SQLStore) Migrate() error { } err = m.Up() - if err != nil && err != migrate.ErrNoChange { + if err != nil && errors.Is(err, migrate.ErrNoChange) { return err } + return nil } diff --git a/server/services/telemetry/telemetry.go b/server/services/telemetry/telemetry.go index 3b3c9b87e..a74a2fb50 100644 --- a/server/services/telemetry/telemetry.go +++ b/server/services/telemetry/telemetry.go @@ -14,19 +14,15 @@ import ( ) const ( - DAY_MILLISECONDS = 24 * 60 * 60 * 1000 - MONTH_MILLISECONDS = 31 * DAY_MILLISECONDS - - RUDDER_KEY = "placeholder_rudder_key" - RUDDER_DATAPLANE_URL = "placeholder_rudder_dataplane_url" - - TRACK_CONFIG = "config" + rudderKey = "placeholder_rudder_key" + rudderDataplaneURL = "placeholder_rudder_dataplane_url" + timeBetweenTelemetryChecks = 10 * time.Minute ) -type telemetryTracker func() map[string]interface{} +type Tracker func() map[string]interface{} -type TelemetryService struct { - trackers map[string]telemetryTracker +type Service struct { + trackers map[string]Tracker log *log.Logger rudderClient rudder.Client telemetryID string @@ -35,25 +31,25 @@ type TelemetryService struct { type RudderConfig struct { RudderKey string - DataplaneUrl string + DataplaneURL string } -func New(telemetryID string, log *log.Logger) *TelemetryService { - service := &TelemetryService{ +func New(telemetryID string, log *log.Logger) *Service { + service := &Service{ log: log, telemetryID: telemetryID, - trackers: map[string]telemetryTracker{}, + trackers: map[string]Tracker{}, } return service } -func (ts *TelemetryService) RegisterTracker(name string, tracker telemetryTracker) { +func (ts *Service) RegisterTracker(name string, tracker Tracker) { ts.trackers[name] = tracker } -func (ts *TelemetryService) getRudderConfig() RudderConfig { - if !strings.Contains(RUDDER_KEY, "placeholder") && !strings.Contains(RUDDER_DATAPLANE_URL, "placeholder") { - return RudderConfig{RUDDER_KEY, RUDDER_DATAPLANE_URL} +func (ts *Service) getRudderConfig() RudderConfig { + if !strings.Contains(rudderKey, "placeholder") && !strings.Contains(rudderDataplaneURL, "placeholder") { + return RudderConfig{rudderKey, rudderDataplaneURL} } else if os.Getenv("RUDDER_KEY") != "" && os.Getenv("RUDDER_DATAPLANE_URL") != "" { return RudderConfig{os.Getenv("RUDDER_KEY"), os.Getenv("RUDDER_DATAPLANE_URL")} } else { @@ -61,17 +57,18 @@ func (ts *TelemetryService) getRudderConfig() RudderConfig { } } -func (ts *TelemetryService) sendDailyTelemetry(override bool) { +func (ts *Service) sendDailyTelemetry(override bool) { config := ts.getRudderConfig() - if (config.DataplaneUrl != "" && config.RudderKey != "") || override { - ts.initRudder(config.DataplaneUrl, config.RudderKey) + if (config.DataplaneURL != "" && config.RudderKey != "") || override { + ts.initRudder(config.DataplaneURL, config.RudderKey) + for name, tracker := range ts.trackers { ts.sendTelemetry(name, tracker()) } } } -func (ts *TelemetryService) sendTelemetry(event string, properties map[string]interface{}) { +func (ts *Service) sendTelemetry(event string, properties map[string]interface{}) { if ts.rudderClient != nil { var context *rudder.Context ts.rudderClient.Enqueue(rudder.Track{ @@ -83,13 +80,13 @@ func (ts *TelemetryService) sendTelemetry(event string, properties map[string]in } } -func (ts *TelemetryService) initRudder(endpoint string, rudderKey string) { +func (ts *Service) initRudder(endpoint string, rudderKey string) { if ts.rudderClient == nil { config := rudder.Config{} config.Logger = rudder.StdLogger(ts.log) config.Endpoint = endpoint // For testing - if endpoint != RUDDER_DATAPLANE_URL { + if endpoint != rudderDataplaneURL { config.Verbose = true config.BatchSize = 1 } @@ -106,7 +103,7 @@ func (ts *TelemetryService) initRudder(endpoint string, rudderKey string) { } } -func (ts *TelemetryService) doTelemetryIfNeeded(firstRun time.Time) { +func (ts *Service) doTelemetryIfNeeded(firstRun time.Time) { hoursSinceFirstServerRun := time.Since(firstRun).Hours() // Send once every 10 minutes for the first hour // Send once every hour thereafter for the first 12 hours @@ -120,21 +117,21 @@ func (ts *TelemetryService) doTelemetryIfNeeded(firstRun time.Time) { } } -func (ts *TelemetryService) RunTelemetryJob(firstRun int64) { +func (ts *Service) RunTelemetryJob(firstRun int64) { // Send on boot ts.doTelemetry() scheduler.CreateRecurringTask("Telemetry", func() { ts.doTelemetryIfNeeded(time.Unix(0, firstRun*int64(time.Millisecond))) - }, time.Minute*10) + }, timeBetweenTelemetryChecks) } -func (ts *TelemetryService) doTelemetry() { +func (ts *Service) doTelemetry() { ts.timestampLastTelemetrySent = time.Now() ts.sendDailyTelemetry(false) } // Shutdown closes the telemetry client. -func (ts *TelemetryService) Shutdown() error { +func (ts *Service) Shutdown() error { if ts.rudderClient != nil { return ts.rudderClient.Close() } diff --git a/server/web/webserver.go b/server/web/webserver.go index 5f641ef37..4483ad85f 100644 --- a/server/web/webserver.go +++ b/server/web/webserver.go @@ -53,19 +53,23 @@ func (ws *WebServer) Start() error { urlPort := fmt.Sprintf(`:%d`, ws.port) var isSSL = ws.ssl && fileExists("./cert/cert.pem") && fileExists("./cert/key.pem") + if isSSL { log.Println("https server started on ", urlPort) err := http.ListenAndServeTLS(urlPort, "./cert/cert.pem", "./cert/key.pem", nil) if err != nil { return err } + return nil } + log.Println("http server started on ", urlPort) err := http.ListenAndServe(urlPort, nil) if err != nil { return err } + return nil } @@ -75,5 +79,6 @@ func fileExists(path string) bool { if os.IsNotExist(err) { return false } + return err == nil } diff --git a/server/ws/websockets.go b/server/ws/websockets.go index ae16608d1..64962318b 100644 --- a/server/ws/websockets.go +++ b/server/ws/websockets.go @@ -11,27 +11,29 @@ import ( ) // RegisterRoutes registeres routes -func (ws *WSServer) RegisterRoutes(r *mux.Router) { +func (ws *Server) RegisterRoutes(r *mux.Router) { r.HandleFunc("/ws/onchange", ws.handleWebSocketOnChange) } // AddListener adds a listener for a block's change -func (ws *WSServer) AddListener(client *websocket.Conn, blockIDs []string) { +func (ws *Server) AddListener(client *websocket.Conn, blockIDs []string) { ws.mu.Lock() for _, blockID := range blockIDs { if ws.listeners[blockID] == nil { ws.listeners[blockID] = []*websocket.Conn{} } + ws.listeners[blockID] = append(ws.listeners[blockID], client) } ws.mu.Unlock() } // RemoveListener removes a webSocket listener from all blocks -func (ws *WSServer) RemoveListener(client *websocket.Conn) { +func (ws *Server) RemoveListener(client *websocket.Conn) { ws.mu.Lock() for key, clients := range ws.listeners { var listeners = []*websocket.Conn{} + for _, existingClient := range clients { if client != existingClient { listeners = append(listeners, existingClient) @@ -43,7 +45,7 @@ func (ws *WSServer) RemoveListener(client *websocket.Conn) { } // RemoveListenerFromBlocks removes a webSocket listener from a set of block -func (ws *WSServer) RemoveListenerFromBlocks(client *websocket.Conn, blockIDs []string) { +func (ws *Server) RemoveListenerFromBlocks(client *websocket.Conn, blockIDs []string) { ws.mu.Lock() for _, blockID := range blockIDs { @@ -58,6 +60,7 @@ func (ws *WSServer) RemoveListenerFromBlocks(client *websocket.Conn, blockIDs [] if client == listener { newListeners := append(listeners[:index], listeners[index+1:]...) ws.listeners[blockID] = newListeners + break } } @@ -67,7 +70,7 @@ func (ws *WSServer) RemoveListenerFromBlocks(client *websocket.Conn, blockIDs [] } // GetListeners returns the listeners to a blockID's changes -func (ws *WSServer) GetListeners(blockID string) []*websocket.Conn { +func (ws *Server) GetListeners(blockID string) []*websocket.Conn { ws.mu.Lock() listeners := ws.listeners[blockID] ws.mu.Unlock() @@ -75,16 +78,16 @@ func (ws *WSServer) GetListeners(blockID string) []*websocket.Conn { return listeners } -// WSServer is a WebSocket server -type WSServer struct { +// Server is a WebSocket server +type Server struct { upgrader websocket.Upgrader listeners map[string][]*websocket.Conn mu sync.RWMutex } -// NewWSServer creates a new WSServer -func NewWSServer() *WSServer { - return &WSServer{ +// NewServer creates a new Server +func NewServer() *Server { + return &Server{ listeners: make(map[string][]*websocket.Conn), upgrader: websocket.Upgrader{ CheckOrigin: func(r *http.Request) bool { @@ -106,7 +109,7 @@ type WebsocketCommand struct { BlockIDs []string `json:"blockIds"` } -func (ws *WSServer) handleWebSocketOnChange(w http.ResponseWriter, r *http.Request) { +func (ws *Server) handleWebSocketOnChange(w http.ResponseWriter, r *http.Request) { // Upgrade initial GET request to a websocket client, err := ws.upgrader.Upgrade(w, r, nil) if err != nil { @@ -133,6 +136,7 @@ func (ws *WSServer) handleWebSocketOnChange(w http.ResponseWriter, r *http.Reque if err != nil { log.Printf("ERROR WebSocket onChange, client: %s, err: %v", client.RemoteAddr(), err) ws.RemoveListener(client) + break } @@ -141,6 +145,7 @@ func (ws *WSServer) handleWebSocketOnChange(w http.ResponseWriter, r *http.Reque if err != nil { // handle this error log.Printf(`ERROR webSocket parsing command JSON: %v`, string(p)) + continue } @@ -148,9 +153,11 @@ func (ws *WSServer) handleWebSocketOnChange(w http.ResponseWriter, r *http.Reque case "ADD": log.Printf(`Command: Add blockID: %v, client: %s`, command.BlockIDs, client.RemoteAddr()) ws.AddListener(client, command.BlockIDs) + case "REMOVE": log.Printf(`Command: Remove blockID: %v, client: %s`, command.BlockIDs, client.RemoteAddr()) ws.RemoveListenerFromBlocks(client, command.BlockIDs) + default: log.Printf(`ERROR webSocket command, invalid action: %v`, command.Action) } @@ -158,7 +165,7 @@ func (ws *WSServer) handleWebSocketOnChange(w http.ResponseWriter, r *http.Reque } // BroadcastBlockChangeToWebsocketClients broadcasts change to clients -func (ws *WSServer) BroadcastBlockChangeToWebsocketClients(blockIDs []string) { +func (ws *Server) BroadcastBlockChangeToWebsocketClients(blockIDs []string) { for _, blockID := range blockIDs { listeners := ws.GetListeners(blockID) log.Printf("%d listener(s) for blockID: %s", len(listeners), blockID) @@ -168,6 +175,7 @@ func (ws *WSServer) BroadcastBlockChangeToWebsocketClients(blockIDs []string) { Action: "UPDATE_BLOCK", BlockID: blockID, } + for _, listener := range listeners { log.Printf("Broadcast change, blockID: %s, remoteAddr: %s", blockID, listener.RemoteAddr()) err := listener.WriteJSON(message)