github.com/grailbio/base@v0.0.11/file/file_test.go (about)

     1  // Copyright 2018 GRAIL, Inc. All rights reserved.
     2  // Use of this source code is governed by the Apache-2.0
     3  // license that can be found in the LICENSE file.
     4  
     5  package file_test
     6  
     7  import (
     8  	"context"
     9  	"errors"
    10  	"flag"
    11  	"fmt"
    12  	"io"
    13  	"math/rand"
    14  	"sync"
    15  	"testing"
    16  	"time"
    17  
    18  	"github.com/grailbio/base/file"
    19  	"github.com/grailbio/base/file/s3file"
    20  	"github.com/grailbio/testutil"
    21  	"github.com/grailbio/testutil/assert"
    22  )
    23  
    24  type errFile struct {
    25  	err error
    26  }
    27  
    28  func (f *errFile) String() string { return f.err.Error() }
    29  
    30  func (f *errFile) Open(ctx context.Context, path string, opts ...file.Opts) (file.File, error) {
    31  	return nil, f.err
    32  }
    33  
    34  func (f *errFile) Create(ctx context.Context, path string, opts ...file.Opts) (file.File, error) {
    35  	return nil, f.err
    36  }
    37  
    38  func (f *errFile) List(ctx context.Context, dir string, recursive bool) file.Lister {
    39  	return nil
    40  }
    41  
    42  func (f *errFile) Stat(ctx context.Context, path string, opts ...file.Opts) (file.Info, error) {
    43  	return nil, f.err
    44  }
    45  
    46  func (f *errFile) Remove(ctx context.Context, path string) error {
    47  	return f.err
    48  }
    49  
    50  func (f *errFile) Presign(ctx context.Context, path, method string, expiry time.Duration) (string, error) {
    51  	return "", f.err
    52  }
    53  
    54  func (f *errFile) Close(ctx context.Context) error {
    55  	return f.err
    56  }
    57  
    58  func TestRegistration(t *testing.T) {
    59  	testImpl := &errFile{errors.New("test")}
    60  	file.RegisterImplementation("foo", func() file.Implementation { return testImpl })
    61  	assert.True(t, file.FindImplementation("") != nil)
    62  	assert.True(t, file.FindImplementation("foo") == testImpl)
    63  	assert.True(t, file.FindImplementation("foo2") == nil)
    64  }
    65  
    66  func doReadFile(ctx context.Context, path string) string {
    67  	got, err := file.ReadFile(ctx, path)
    68  	if err != nil {
    69  		return err.Error()
    70  	}
    71  	return string(got)
    72  }
    73  
    74  func TestReadWriteFile(t *testing.T) {
    75  	tempDir, cleanup := testutil.TempDir(t, "", "")
    76  	defer cleanup()
    77  
    78  	ctx := context.Background()
    79  	path := file.Join(tempDir, "test.txt")
    80  	data := "Hello, olleh"
    81  	assert.NoError(t, file.WriteFile(ctx, path, []byte(data)))
    82  	assert.EQ(t, data, doReadFile(ctx, path))
    83  }
    84  
    85  func TestRemoveAllNonexistent(t *testing.T) {
    86  	tempDir, cleanup := testutil.TempDir(t, "", "")
    87  	defer cleanup()
    88  	ctx := context.Background()
    89  	assert.NoError(t, file.RemoveAll(ctx, file.Join(tempDir, "baddir")))
    90  }
    91  
    92  func TestRemoveAllRegularFile(t *testing.T) {
    93  	tempDir, cleanup := testutil.TempDir(t, "", "")
    94  	defer cleanup()
    95  	ctx := context.Background()
    96  
    97  	path := file.Join(tempDir, "test.txt")
    98  	data := "Hello, olleh"
    99  	assert.NoError(t, file.WriteFile(ctx, path, []byte(data)))
   100  	assert.EQ(t, data, doReadFile(ctx, path))
   101  	assert.NoError(t, file.RemoveAll(ctx, path))
   102  	assert.Regexp(t, doReadFile(ctx, path), "no such file")
   103  }
   104  
   105  func TestRemoveAllRecursive(t *testing.T) {
   106  	tempDir, cleanup := testutil.TempDir(t, "", "")
   107  	defer cleanup()
   108  	ctx := context.Background()
   109  
   110  	dir := file.Join(tempDir, "d")
   111  	data := "Hello, olleh"
   112  	assert.NoError(t, file.WriteFile(ctx, file.Join(dir, "file.txt"), []byte(data)))
   113  	assert.NoError(t, file.WriteFile(ctx, file.Join(dir, "e/file.txt"), []byte(data)))
   114  	assert.NoError(t, file.RemoveAll(ctx, dir))
   115  	assert.Regexp(t, doReadFile(ctx, file.Join(dir, "file.txt")), "no such file")
   116  	assert.Regexp(t, doReadFile(ctx, file.Join(dir, "e/file.txt")), "no such file")
   117  }
   118  
   119  func ExampleParsePath() {
   120  	parse := func(path string) {
   121  		scheme, suffix, err := file.ParsePath(path)
   122  		if err != nil {
   123  			fmt.Printf("%s 🢥 error %v\n", path, err)
   124  			return
   125  		}
   126  		fmt.Printf("%s 🢥 scheme \"%s\", suffix \"%s\"\n", path, scheme, suffix)
   127  	}
   128  	parse("/tmp/test")
   129  	parse("foo://bar")
   130  	parse("foo:///bar")
   131  	parse("foo:bar")
   132  	parse("/foo:bar")
   133  	// Output:
   134  	// /tmp/test 🢥 scheme "", suffix "/tmp/test"
   135  	// foo://bar 🢥 scheme "foo", suffix "bar"
   136  	// foo:///bar 🢥 scheme "foo", suffix "/bar"
   137  	// foo:bar 🢥 error parsepath foo:bar: a URL must start with 'scheme://'
   138  	// /foo:bar 🢥 scheme "", suffix "/foo:bar"
   139  }
   140  
   141  func ExampleBase() {
   142  	fmt.Println(file.Base(""))
   143  	fmt.Println(file.Base("foo1"))
   144  	fmt.Println(file.Base("foo2/"))
   145  	fmt.Println(file.Base("/"))
   146  	fmt.Println(file.Base("s3://"))
   147  	fmt.Println(file.Base("s3://blah1"))
   148  	fmt.Println(file.Base("s3://blah2/"))
   149  	fmt.Println(file.Base("s3://foo/blah3//"))
   150  	// Output:
   151  	// .
   152  	// foo1
   153  	// foo2
   154  	// /
   155  	// s3://
   156  	// blah1
   157  	// blah2
   158  	// blah3
   159  }
   160  
   161  func ExampleDir() {
   162  	fmt.Println(file.Dir("foo"))
   163  	fmt.Println(file.Dir("."))
   164  	fmt.Println(file.Dir("/a/b"))
   165  	fmt.Println(file.Dir("a/b"))
   166  	fmt.Println(file.Dir("s3://ab/cd"))
   167  	fmt.Println(file.Dir("s3://ab//cd"))
   168  	fmt.Println(file.Dir("s3://a/b/"))
   169  	fmt.Println(file.Dir("s3://a/b//"))
   170  	fmt.Println(file.Dir("s3://a//b//"))
   171  	fmt.Println(file.Dir("s3://a"))
   172  	// Output:
   173  	// .
   174  	// .
   175  	// /a
   176  	// a
   177  	// s3://ab
   178  	// s3://ab
   179  	// s3://a/b
   180  	// s3://a/b
   181  	// s3://a//b
   182  	// s3://
   183  }
   184  
   185  func ExampleJoin() {
   186  	fmt.Println(file.Join())
   187  	fmt.Println(file.Join(""))
   188  	fmt.Println(file.Join("foo", "bar"))
   189  	fmt.Println(file.Join("foo", ""))
   190  	fmt.Println(file.Join("foo", "/bar/"))
   191  	fmt.Println(file.Join(".", "foo:bar"))
   192  	fmt.Println(file.Join("s3://foo"))
   193  	fmt.Println(file.Join("s3://foo", "/bar/"))
   194  	fmt.Println(file.Join("s3://foo", "", "bar"))
   195  	fmt.Println(file.Join("s3://foo", "0"))
   196  	fmt.Println(file.Join("s3://foo", "abc"))
   197  	fmt.Println(file.Join("s3://foo//bar", "/", "/baz"))
   198  	// Output:
   199  	// foo/bar
   200  	// foo
   201  	// foo/bar
   202  	// ./foo:bar
   203  	// s3://foo
   204  	// s3://foo/bar
   205  	// s3://foo/bar
   206  	// s3://foo/0
   207  	// s3://foo/abc
   208  	// s3://foo//bar/baz
   209  }
   210  
   211  func ExampleIsAbs() {
   212  	fmt.Println(file.IsAbs("foo"))
   213  	fmt.Println(file.IsAbs("/foo"))
   214  	fmt.Println(file.IsAbs("s3://foo"))
   215  	// Output:
   216  	// false
   217  	// true
   218  	// true
   219  }
   220  
   221  var once = sync.Once{}
   222  
   223  func initBenchmark() {
   224  	once.Do(func() {
   225  		file.RegisterImplementation("s3",
   226  			func() file.Implementation {
   227  				return s3file.NewImplementation(s3file.NewDefaultProvider(), s3file.Options{})
   228  			})
   229  	})
   230  }
   231  
   232  var (
   233  	writeFlag  = flag.String("write", "", "Path of the file used by write benchmark.")
   234  	sizeFlag   = flag.Int64("size", 16<<20, "# of bytes to write during benchmark.")
   235  	verifyFlag = flag.Bool("verify", true, "Verify contents of the file created by the write benchmark")
   236  )
   237  
   238  func BenchmarkWrite(b *testing.B) {
   239  	initBenchmark()
   240  	if *writeFlag == "" {
   241  		b.Skip("--write flag not set")
   242  	}
   243  
   244  	// buf := make([]byte, 64<<10)
   245  	ctx := context.Background()
   246  	b.Logf("Writing %d bytes to %s", *sizeFlag, *writeFlag)
   247  	buf := make([]byte, 64<<10)
   248  	for i := 0; i < b.N; i++ {
   249  		rnd := rand.New(rand.NewSource(0))
   250  		f, err := file.Create(ctx, *writeFlag)
   251  		assert.NoError(b, err)
   252  		w := f.Writer(ctx)
   253  
   254  		total := int64(0)
   255  		for total < *sizeFlag {
   256  			b.StopTimer()
   257  			_, err := io.ReadFull(rnd, buf)
   258  			assert.NoError(b, err)
   259  			b.StartTimer()
   260  
   261  			n, err := w.Write(buf)
   262  			assert.NoError(b, err)
   263  			assert.EQ(b, n, len(buf))
   264  			total += int64(n)
   265  		}
   266  		assert.NoError(b, f.Close(ctx))
   267  	}
   268  
   269  	if *verifyFlag {
   270  		rnd := rand.New(rand.NewSource(0))
   271  		b.StopTimer()
   272  		f, err := file.Open(ctx, *writeFlag)
   273  		assert.NoError(b, err)
   274  		r := f.Reader(ctx)
   275  		total := int64(0)
   276  		expected := make([]byte, 64<<10)
   277  		got := make([]byte, 64<<10)
   278  
   279  		for total < *sizeFlag {
   280  			_, err := io.ReadFull(rnd, expected)
   281  			assert.NoError(b, err)
   282  			_, err = io.ReadFull(r, got)
   283  			assert.NoError(b, err)
   284  			assert.EQ(b, expected, got)
   285  			total += int64(len(got))
   286  		}
   287  		assert.NoError(b, f.Close(ctx))
   288  		b.StartTimer()
   289  	}
   290  }