diff --git a/internal/share/webdav/webdav.go b/internal/share/webdav/webdav.go index 1b76cb3bc..4b7af6b8a 100644 --- a/internal/share/webdav/webdav.go +++ b/internal/share/webdav/webdav.go @@ -32,9 +32,17 @@ func Connect(url, user, pass string) Client { return result } +func (c Client) readDir(path string) ([]os.FileInfo, error) { + if path == "" { + path = "/" + } + + return c.client.ReadDir(path) +} + // Files returns all files in path as string slice. func (c Client) Files(path string) (result []string, err error) { - files, err := c.client.ReadDir(path) + files, err := c.readDir(path) if err != nil { return result, err @@ -49,8 +57,8 @@ func (c Client) Files(path string) (result []string, err error) { } // Directories returns all sub directories in path as string slice. -func (c Client) Directories(path string) (result []string, err error) { - files, err := c.client.ReadDir(path) +func (c Client) Directories(path string, recursive bool) (result []string, err error) { + files, err := c.readDir(path) if err != nil { return result, err @@ -58,7 +66,19 @@ func (c Client) Directories(path string) (result []string, err error) { for _, file := range files { if !file.Mode().IsDir() { continue } - result = append(result, fmt.Sprintf("%s/%s", path, file.Name())) + + dir := fmt.Sprintf("%s/%s", path, file.Name()) + result = append(result, dir) + + if recursive { + subDirs, err := c.Directories(dir, true) + + if err != nil { + return result, err + } + + result = append(result, subDirs...) + } } return result, nil @@ -118,7 +138,7 @@ func (c Client) DownloadDir(from, to string, recursive bool) (errs []error) { return errs } - dirs, err := c.Directories(from) + dirs, err := c.Directories(from, false) for _, dir := range dirs { errs = append(errs, c.DownloadDir(dir, to, true)...) diff --git a/internal/share/webdav/webdav_test.go b/internal/share/webdav/webdav_test.go index ac141619b..cb420be72 100644 --- a/internal/share/webdav/webdav_test.go +++ b/internal/share/webdav/webdav_test.go @@ -37,6 +37,38 @@ func TestClient_Files(t *testing.T) { } } +func TestClient_Directories(t *testing.T) { + c := Connect(testUrl, testUser, testPass) + + assert.IsType(t, Client{}, c) + + t.Run("non-recursive", func(t *testing.T) { + dirs, err := c.Directories("", false) + + if err != nil { + t.Fatal(err) + } + + if len(dirs) == 0 { + t.Fatal("no directories found") + } + + assert.Equal(t, "/Photos", dirs[0]) + }) + + t.Run("recursive", func(t *testing.T) { + dirs, err := c.Directories("", true) + + if err != nil { + t.Fatal(err) + } + + if len(dirs) < 2 { + t.Fatal("at least 2 directories expected") + } + }) +} + func TestClient_Download(t *testing.T) { c := Connect(testUrl, testUser, testPass)