github.com/grailbio/base@v0.0.11/file/gfilefs/write_test.go (about)

     1  // Copyright 2022 GRAIL, Inc. All rights reserved.
     2  // Use of this source code is governed by the Apache-2.0
     3  // license that can be found in the LICENSE file.
     4  
     5  package gfilefs_test
     6  
     7  import (
     8  	"context"
     9  	"flag"
    10  	gofs "io/fs"
    11  	"io/ioutil"
    12  	"log"
    13  	"math/rand"
    14  	"os"
    15  	"path/filepath"
    16  	"strings"
    17  	"sync"
    18  	"testing"
    19  
    20  	"github.com/grailbio/base/errors"
    21  	"github.com/grailbio/base/file"
    22  	"github.com/grailbio/base/file/fsnodefuse"
    23  	"github.com/grailbio/base/file/gfilefs"
    24  	"github.com/grailbio/base/file/s3file"
    25  	"github.com/grailbio/testutil"
    26  	"github.com/hanwen/go-fuse/v2/fs"
    27  	"github.com/hanwen/go-fuse/v2/fuse"
    28  	"github.com/stretchr/testify/assert"
    29  	"github.com/stretchr/testify/require"
    30  )
    31  
    32  func init() {
    33  	file.RegisterImplementation("s3", func() file.Implementation {
    34  		return s3file.NewImplementation(
    35  			s3file.NewDefaultProvider(), s3file.Options{},
    36  		)
    37  	})
    38  }
    39  
    40  // s3RootFlag sets an S3 root directory to use for test files.  When set to a
    41  // non-empty S3 path, e.g. "s3://some-bucket/some-writable/prefix", tests will
    42  // run with a mount point with this root.  These tests will run in addition to
    43  // the normal local root testing.
    44  var s3RootFlag = flag.String(
    45  	"s3-root",
    46  	"",
    47  	"optional S3 root directory to use for testing, in addition to the local root",
    48  )
    49  
    50  // TestCreateEmpty verifies that we can create an new empty file using various
    51  // flag parameters when opening, e.g. O_TRUNC.
    52  func TestCreateEmpty(t *testing.T) {
    53  	flagElements := [][]int{
    54  		{os.O_RDONLY, os.O_RDWR, os.O_WRONLY},
    55  		{0, os.O_TRUNC},
    56  		{0, os.O_EXCL},
    57  	}
    58  	// combos produces the flag parameters to test (less O_CREATE, which is
    59  	// applied below).
    60  	var combos func(elems [][]int) []int
    61  	combos = func(elems [][]int) []int {
    62  		if len(elems) == 1 {
    63  			return elems[0]
    64  		}
    65  		var result []int
    66  		for _, elem := range elems[0] {
    67  			for _, flag := range combos(elems[1:]) {
    68  				flag |= elem
    69  				result = append(result, flag)
    70  			}
    71  		}
    72  		return result
    73  	}
    74  	// name generates a nice name for a subtest for a given flag.
    75  	name := func(flags int) string {
    76  		var (
    77  			parts  []string
    78  			access string
    79  		)
    80  		switch {
    81  		case flags&os.O_RDWR == os.O_RDWR:
    82  			access = "RDWR"
    83  		case flags&os.O_WRONLY == os.O_WRONLY:
    84  			access = "WRONLY"
    85  		default:
    86  			access = "RDONLY"
    87  		}
    88  		parts = append(parts, access)
    89  		if flags&os.O_TRUNC == os.O_TRUNC {
    90  			parts = append(parts, "TRUNC")
    91  		}
    92  		if flags&os.O_EXCL == os.O_EXCL {
    93  			parts = append(parts, "EXCL")
    94  		}
    95  		return strings.Join(parts, "_")
    96  	}
    97  	for _, flag := range combos(flagElements) {
    98  		withTestMounts(t, func(m testMount) {
    99  			t.Run(name(flag), func(t *testing.T) {
   100  				path := filepath.Join(m.mountPoint, "test")
   101  				flag |= os.O_CREATE
   102  				f, err := os.OpenFile(path, flag, 0666)
   103  				require.NoError(t, err, "creating file")
   104  				require.NoError(t, f.Close(), "closing file")
   105  
   106  				info, err := os.Stat(path)
   107  				require.NoError(t, err, "stat of file")
   108  				assert.Equal(t, int64(0), info.Size(), "file should have zero size")
   109  
   110  				bs, err := ioutil.ReadFile(path)
   111  				require.NoError(t, err, "reading file")
   112  				assert.Empty(t, bs, "file should be empty")
   113  			})
   114  		})
   115  	}
   116  }
   117  
   118  // TestCreate verifies that we can create a new file, write content to it, and
   119  // read the same content back.
   120  func TestCreate(t *testing.T) {
   121  	withTestMounts(t, func(m testMount) {
   122  		var (
   123  			r        = rand.New(rand.NewSource(0))
   124  			path     = filepath.Join(m.mountPoint, "test")
   125  			rootPath = file.Join(m.root, "test")
   126  		)
   127  		assertRoundTrip(t, path, rootPath, r, 10*(1<<20))
   128  		assertRoundTrip(t, path, rootPath, r, 10*(1<<16))
   129  	})
   130  }
   131  
   132  // TestOverwrite verifies that we can overwrite the same file repeatedly, and
   133  // that the updated content is correct.
   134  func TestOverwrite(t *testing.T) {
   135  	withTestMounts(t, func(m testMount) {
   136  		const NumOverwrites = 20
   137  		var (
   138  			r        = rand.New(rand.NewSource(0))
   139  			path     = filepath.Join(m.mountPoint, "test")
   140  			rootPath = file.Join(m.root, "test")
   141  		)
   142  		for i := 0; i < NumOverwrites+1; i++ {
   143  			// Each iteration uses a random size between 5 and 10 MiB.
   144  			n := 5 + r.Intn(10)
   145  			n *= 1 << 20
   146  			assertRoundTrip(t, path, rootPath, r, n)
   147  		}
   148  	})
   149  }
   150  
   151  // TestTruncFlag verifies that opening with O_TRUNC truncates the file.
   152  func TestTruncFlag(t *testing.T) {
   153  	t.Run("WRONLY", func(t *testing.T) {
   154  		testTruncFlag(t, os.O_WRONLY)
   155  	})
   156  	t.Run("RDWR", func(t *testing.T) {
   157  		testTruncFlag(t, os.O_RDWR)
   158  	})
   159  }
   160  
   161  func testTruncFlag(t *testing.T, flag int) {
   162  	withTestMounts(t, func(m testMount) {
   163  		path := filepath.Join(m.mountPoint, "test")
   164  		// Write the file we will truncate to test.
   165  		err := ioutil.WriteFile(path, []byte{0, 1, 2}, 0644)
   166  		require.NoError(t, err, "writing file")
   167  
   168  		f, err := os.OpenFile(path, flag|os.O_TRUNC, 0666)
   169  		require.NoError(t, err, "opening for truncation")
   170  		func() {
   171  			defer func() {
   172  				require.NoError(t, f.Close())
   173  			}()
   174  			var info gofs.FileInfo
   175  			info, err = f.Stat()
   176  			require.NoError(t, err, "getting file stats")
   177  			assert.Equal(t, int64(0), info.Size(), "truncated file should be zero bytes")
   178  		}()
   179  
   180  		// Verify that reading the truncated file yields zero bytes.
   181  		bsRead, err := ioutil.ReadFile(path)
   182  		require.NoError(t, err, "reading truncated file")
   183  		assert.Empty(t, bsRead, "reading truncated file should yield no data")
   184  	})
   185  }
   186  
   187  // TestTruncateZero verifies that truncation to zero works.
   188  func TestTruncateZero(t *testing.T) {
   189  	t.Run("WRONLY", func(t *testing.T) {
   190  		testTruncateZero(t, os.O_WRONLY)
   191  	})
   192  	t.Run("RDWR", func(t *testing.T) {
   193  		testTruncateZero(t, os.O_RDWR)
   194  	})
   195  }
   196  
   197  func testTruncateZero(t *testing.T, flag int) {
   198  	withTestMounts(t, func(m testMount) {
   199  		path := filepath.Join(m.mountPoint, "test")
   200  		// Write the file we will truncate to test.
   201  		err := ioutil.WriteFile(path, []byte{0, 1, 2}, 0644)
   202  		require.NoError(t, err, "writing file")
   203  
   204  		f, err := os.OpenFile(path, os.O_WRONLY, 0666)
   205  		require.NoError(t, err, "opening for truncation")
   206  
   207  		func() {
   208  			defer func() {
   209  				require.NoError(t, f.Close(), "closing")
   210  			}()
   211  			// Sanity check that the initial file handle is the correct size.
   212  			var info gofs.FileInfo
   213  			info, err = f.Stat()
   214  			require.NoError(t, err, "getting file stats")
   215  			assert.Equal(t, int64(3), info.Size(), "file to truncate should be three bytes")
   216  
   217  			require.NoError(t, f.Truncate(0), "truncating")
   218  
   219  			// Verify that the file handle is actually truncated.
   220  			info, err = f.Stat()
   221  			require.NoError(t, err, "getting file stats")
   222  			assert.Equal(t, int64(0), info.Size(), "truncated file should be zero bytes")
   223  		}()
   224  
   225  		// Verify that an independent stat shows zero size.
   226  		info, err := os.Stat(path)
   227  		require.NoError(t, err, "getting file stats")
   228  		assert.Equal(t, int64(0), info.Size(), "truncated file should be zero bytes")
   229  
   230  		// Verify that reading the truncated file yields zero bytes.
   231  		bsRead, err := ioutil.ReadFile(path)
   232  		require.NoError(t, err, "reading truncated file")
   233  		assert.Empty(t, bsRead, "reading truncated file should yield no data")
   234  	})
   235  }
   236  
   237  // TestRemove verifies that we can remove a file.
   238  func TestRemove(t *testing.T) {
   239  	withTestMounts(t, func(m testMount) {
   240  		var (
   241  			r        = rand.New(rand.NewSource(0))
   242  			path     = filepath.Join(m.mountPoint, "test")
   243  			rootPath = file.Join(m.root, "test")
   244  		)
   245  		bs := make([]byte, 1*(1<<20))
   246  		_, err := r.Read(bs)
   247  		require.NoError(t, err, "making random data")
   248  		err = ioutil.WriteFile(path, bs, 0644)
   249  		require.NoError(t, err, "writing file")
   250  		err = os.Remove(path)
   251  		require.NoError(t, err, "removing file")
   252  		_, err = os.Stat(path)
   253  		require.True(t, os.IsNotExist(err), "file was not removed")
   254  		_, err = os.Stat(rootPath)
   255  		require.True(t, os.IsNotExist(err), "file was not removed in root")
   256  	})
   257  }
   258  
   259  // TestDirListing verifies that the directory listing of a file is updated when
   260  // the file is modified.
   261  func TestDirListing(t *testing.T) {
   262  	withTestMounts(t, func(m testMount) {
   263  		path := file.Join(m.mountPoint, "test")
   264  		// assertSize asserts that the listed FileInfo of the file at path reports
   265  		// the given size.
   266  		assertSize := func(size int64) {
   267  			infos, err := ioutil.ReadDir(m.mountPoint)
   268  			require.NoError(t, err, "listing directory")
   269  			require.Equal(t, 1, len(infos), "should only be one file in directory")
   270  			assert.Equal(t, size, infos[0].Size(), "file should be 3 bytes")
   271  		}
   272  
   273  		// Write a 3-byte file, and verify that its listing has the correct size.
   274  		require.NoError(t, ioutil.WriteFile(path, make([]byte, 3), 0644), "writing file")
   275  		assertSize(3)
   276  
   277  		// Overwrite it to be 1 byte, and verify that the listing is updated.
   278  		require.NoError(t, ioutil.WriteFile(path, make([]byte, 1), 0644), "overwriting file")
   279  		assertSize(1)
   280  
   281  		// Append 3 bytes, and verify that the listing is updated.
   282  		f, err := os.OpenFile(path, os.O_WRONLY|os.O_APPEND, 0644)
   283  		require.NoError(t, err, "opening file for append")
   284  		_, err = f.Write(make([]byte, 3))
   285  		require.NoError(t, err, "appending to file")
   286  		require.NoError(t, f.Close(), "closing file")
   287  		assertSize(4)
   288  	})
   289  }
   290  
   291  // TestMkdir verifies that we can make a directory.
   292  func TestMkdir(t *testing.T) {
   293  	withTestMounts(t, func(m testMount) {
   294  		var (
   295  			r    = rand.New(rand.NewSource(0))
   296  			path = filepath.Join(m.mountPoint, "test-dir")
   297  		)
   298  		err := os.Mkdir(path, 0775)
   299  		require.NoError(t, err, "making directory")
   300  
   301  		filePath := filepath.Join(path, "test")
   302  		rootFilePath := file.Join(m.root, "test-dir", "test")
   303  		assertRoundTrip(t, filePath, rootFilePath, r, 0)
   304  
   305  		info, err := os.Stat(path)
   306  		require.NoError(t, err, "getting file info of created directory")
   307  		require.True(t, info.IsDir(), "created directory is not a directory")
   308  	})
   309  }
   310  
   311  func withTestMounts(t *testing.T, f func(m testMount)) {
   312  	type makeRootFunc func(*testing.T) (string, func())
   313  	makeRoots := map[string]makeRootFunc{
   314  		"local": func(t *testing.T) (string, func()) {
   315  			return testutil.TempDir(t, "", "gfilefs-mnt")
   316  		},
   317  	}
   318  	if *s3RootFlag != "" {
   319  		makeRoots["s3"] = func(t *testing.T) (string, func()) {
   320  			ctx := context.Background()
   321  			lister := file.List(ctx, *s3RootFlag, true)
   322  			exists := lister.Scan()
   323  			if exists {
   324  				t.Logf("path exists: %s", lister.Path())
   325  			}
   326  			require.NoErrorf(t, lister.Err(), "listing %s", *s3RootFlag)
   327  			require.False(t, exists)
   328  			return *s3RootFlag, func() {
   329  				err := forEachFile(ctx, *s3RootFlag, func(path string) error {
   330  					return file.Remove(ctx, path)
   331  				})
   332  				require.NoError(t, err, "cleaning up test root")
   333  			}
   334  		}
   335  	}
   336  	for name, makeRoot := range makeRoots {
   337  		t.Run(name, func(t *testing.T) {
   338  			root, rootCleanUp := makeRoot(t)
   339  			defer rootCleanUp()
   340  			mountPoint, mountPointCleanUp := testutil.TempDir(t, "", "gfilefs-mnt")
   341  			defer mountPointCleanUp()
   342  			server, err := fs.Mount(
   343  				mountPoint,
   344  				fsnodefuse.NewRoot(gfilefs.New(root, "root")),
   345  				// TODO: Set fsnodefuse.ConfigureRequiredMountOptions.
   346  				&fs.Options{
   347  					MountOptions: fuse.MountOptions{
   348  						FsName:        "test",
   349  						DisableXAttrs: true,
   350  						Debug:         true,
   351  						MaxBackground: 1024,
   352  					},
   353  				},
   354  			)
   355  			require.NoError(t, err, "mounting %q", mountPoint)
   356  			defer func() {
   357  				log.Printf("unmounting %q", mountPoint)
   358  				assert.NoError(t, server.Unmount(),
   359  					"unmount of FUSE mounted at %q failed; may need manual cleanup",
   360  					mountPoint,
   361  				)
   362  				log.Printf("unmounted %q", mountPoint)
   363  			}()
   364  			f(testMount{root: root, mountPoint: mountPoint})
   365  		})
   366  	}
   367  }
   368  
   369  type testMount struct {
   370  	// root is the root path that is mounted at dir.
   371  	root string
   372  	// mountPoint is the FUSE mount point.
   373  	mountPoint string
   374  }
   375  
   376  // forEachFile runs the callback for every file under the directory in
   377  // parallel.  It returns any of the errors returned by the callback.  It is
   378  // cribbed from github.com/grailbio/base/cmd/grail-file/cmd.
   379  func forEachFile(ctx context.Context, dir string, callback func(path string) error) error {
   380  	const parallelism = 32
   381  	err := errors.Once{}
   382  	wg := sync.WaitGroup{}
   383  	ch := make(chan string, parallelism*100)
   384  	for i := 0; i < parallelism; i++ {
   385  		wg.Add(1)
   386  		go func() {
   387  			for path := range ch {
   388  				err.Set(callback(path))
   389  			}
   390  			wg.Done()
   391  		}()
   392  	}
   393  
   394  	lister := file.List(ctx, dir, true /*recursive*/)
   395  	for lister.Scan() {
   396  		if !lister.IsDir() {
   397  			ch <- lister.Path()
   398  		}
   399  	}
   400  	close(ch)
   401  	err.Set(lister.Err())
   402  	wg.Wait()
   403  	return err.Err()
   404  }
   405  
   406  func assertRoundTrip(t *testing.T, path, rootPath string, r *rand.Rand, size int) {
   407  	bs := make([]byte, size)
   408  	_, err := r.Read(bs)
   409  	require.NoError(t, err, "making random data")
   410  	err = ioutil.WriteFile(path, bs, 0644)
   411  	require.NoError(t, err, "writing file")
   412  
   413  	got, err := ioutil.ReadFile(path)
   414  	require.NoError(t, err, "reading file back")
   415  	assert.Equal(t, bs, got, "data read != data written")
   416  
   417  	info, err := os.Stat(path)
   418  	require.NoError(t, err, "stat of file")
   419  	assert.Equal(t, int64(len(bs)), info.Size(), "len(data read) != len(data written)")
   420  
   421  	// Verify that the file is written correctly to mounted root.
   422  	got, err = file.ReadFile(context.Background(), rootPath)
   423  	require.NoErrorf(t, err, "reading file in root %s back", rootPath)
   424  	assert.Equal(t, bs, got, "data read != data written")
   425  }