github.com/grailbio/bigslice@v0.0.0-20230519005545-30c4c12152ad/cache_test.go

// Copyright 2018 GRAIL, Inc. All rights reserved.
// Use of this source code is governed by the Apache 2.0
// license that can be found in the LICENSE file.
package bigslice_test

import (
	"context"
	"fmt"
	"io/ioutil"
	"os"
	"path/filepath"
	"reflect"
	"sort"
	"sync/atomic"
	"testing"

	"github.com/grailbio/base/errors"
	"github.com/grailbio/base/file"
	"github.com/grailbio/base/log"
	"github.com/grailbio/bigslice"
	"github.com/grailbio/bigslice/exec"
	"github.com/grailbio/bigslice/sliceio"
	"github.com/grailbio/bigslice/slicetest"
	"github.com/grailbio/bigslice/slicetype"
	"github.com/grailbio/testutil"
)

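// TestCache verifies that Cache writes per-shard cache files on first
// evaluation and that a subsequent evaluation is served from the cache without
// recomputation.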
func TestCache(t *testing.T) {
	makeSlice := func(n, nShard int, dir string, computeAllowed bool) bigslice.Slice {
		input := make([]int, n)
		for i := range input {
			input[i] = i
		}
		slice := bigslice.Const(nShard, input)
		slice = bigslice.Map(slice, func(i int) int {
			if !computeAllowed {
				panic("compute not allowed")
			}
			return i * 2
		})
		ctx := context.Background()
		slice = bigslice.Cache(ctx, slice, filepath.Join(dir, "cached"))
		return slice
	}
	runTestCache(t, makeSlice)
}

// TestCacheDeps verifies that caching works when pipelined tasks have non-empty
// dependencies. When the cache is valid, we do not need to read from these
// dependencies. Verify that this does not break compilation or execution (e.g.
// empty dependencies given to tasks that expect non-empty dependencies).
func TestCacheDeps(t *testing.T) {
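	// Disabling shuffle readers makes this test exercise caching of tasks
	// whose dependencies are non-empty; see the comment on TestCacheDeps.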
	exec.DoShuffleReaders = false
	makeSlice := func(n, nShard int, dir string, computeAllowed bool) bigslice.Slice {
		input := make([]int, n)
		for i := range input {
			input[i] = i
		}
		slice := bigslice.Const(nShard, input)
		// This shuffle causes a break in the pipeline, so the pipelined task
		// will have a dependency on the Const slice tasks. Caching should cause
		// compilation/execution to eliminate these dependencies safely.
		slice = bigslice.Reshuffle(slice)
		slice = bigslice.Map(slice, func(i int) int {
			if !computeAllowed {
				panic("compute not allowed")
			}
			return i * 2
		})
		ctx := context.Background()
		slice = bigslice.Cache(ctx, slice, filepath.Join(dir, "cached"))
		return slice
	}
	runTestCache(t, makeSlice)
}

// runTestCache verifies that the caching in the slice returned by makeSlice
// behaves as expected. See usage in TestCache.
func runTestCache(t *testing.T, makeSlice func(n, nShard int, dir string, computeAllowed bool) bigslice.Slice) {
	dir, cleanUp := testutil.TempDir(t, "", "")
	defer cleanUp()
	ctx := context.Background()

	const (
		N      = 10000
		Nshard = 10
	)
	slice1 := makeSlice(N, Nshard, dir, true)
	if got, want := len(ls1(t, dir)), 0; got != want {
		t.Errorf("got %v, want %v", got, want)
	}
	scan1 := runLocal(ctx, t, slice1)
	defer scan1.Close()
	if got, want := len(ls1(t, dir)), Nshard; got != want {
		t.Errorf("got %v [%v], want %v", got, ls1(t, dir), want)
	}

	// Recompute the slice to pick up the cached results.
	slice2 := makeSlice(N, Nshard, dir, false)
	scan2 := runLocal(ctx, t, slice2)
	defer scan2.Close()
	if got, want := len(ls1(t, dir)), Nshard; got != want {
		t.Errorf("got %v [%v], want %v", got, ls1(t, dir), want)
	}

	v1 := scanInts(ctx, t, scan1)
	v2 := scanInts(ctx, t, scan2)
	if got, want := len(v1), N; got != want {
		t.Errorf("got %v, want %v", got, want)
	}
	if !reflect.DeepEqual(v1, v2) {
		t.Errorf("corrupt cache")
	}
}

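// TestCacheIncremental verifies that Cache reuses a complete cache without
// recomputation and that deleting cache entries triggers recomputation.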
func TestCacheIncremental(t *testing.T) {
	dir, cleanUp := testutil.TempDir(t, "", "")
	defer cleanUp()
	ctx := context.Background()

	const (
		N      = 10000
		Nshard = 10
	)

	rowsRan := make([]bool, N)

	input := make([]int, N)
	for i := range input {
		input[i] = i
	}
	makeSlice := func() bigslice.Slice {
		slice := bigslice.Const(Nshard, input)
		slice = bigslice.Map(slice, func(i int) int {
			rowsRan[i] = true
			return i * 2
		})
		slice = bigslice.Cache(ctx, slice, filepath.Join(dir, "cached"))
		return slice
	}

	// Run and populate the cache.
	_ = runLocal(ctx, t, makeSlice())
	if got, want := len(ls1(t, dir)), Nshard; got != want {
		t.Errorf("got %v [%v], want %v", got, ls1(t, dir), want)
	}

	// Run and ensure there's no new computation.
	for i := range rowsRan {
		rowsRan[i] = false
	}
	_ = runLocal(ctx, t, makeSlice())
	if got, want := len(ls1(t, dir)), Nshard; got != want {
		t.Errorf("got %v [%v], want %v", got, ls1(t, dir), want)
	}
	for _, ran := range rowsRan {
		if ran {
			t.Error("want cache use")
		}
	}

	// Delete some cache entries and ensure there's recomputation.
	for i, f := range ls1(t, dir) {
		if i%2 == 0 {
			continue
		}
		if err := os.Remove(filepath.Join(dir, f)); err != nil {
			t.Error(err)
		}
	}
	for i := range rowsRan {
		rowsRan[i] = false
	}
	_ = runLocal(ctx, t, makeSlice())
	if got, want := len(ls1(t, dir)), Nshard; got != want {
		t.Errorf("got %v [%v], want %v", got, ls1(t, dir), want)
	}
	var nRans int
	for _, ran := range rowsRan {
		if ran {
			nRans++
		}
	}
	if nRans < Nshard {
		t.Error("want full recomputation")
	}
}

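// TestCachePartialIncremental verifies that CachePartial reuses a complete
// cache without recomputation and, when some cache entries are deleted,
// recomputes only part of the input.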
func TestCachePartialIncremental(t *testing.T) {
	dir, cleanUp := testutil.TempDir(t, "", "")
	defer cleanUp()
	ctx := context.Background()

	const (
		N      = 10000
		Nshard = 10
	)

	rowsRan := make([]bool, N)

	input := make([]int, N)
	for i := range input {
		input[i] = i
	}
	makeSlice := func() bigslice.Slice {
		slice := bigslice.Const(Nshard, input)
		slice = bigslice.Map(slice, func(i int) int {
			rowsRan[i] = true
			return i * 2
		})
		slice = bigslice.CachePartial(ctx, slice, filepath.Join(dir, "cached"))
		return slice
	}

	// Run and populate the cache.
	_ = runLocal(ctx, t, makeSlice())
	if got, want := len(ls1(t, dir)), Nshard; got != want {
		t.Errorf("got %v [%v], want %v", got, ls1(t, dir), want)
	}

	// Run and ensure there's no new computation.
	for i := range rowsRan {
		rowsRan[i] = false
	}
	_ = runLocal(ctx, t, makeSlice())
	if got, want := len(ls1(t, dir)), Nshard; got != want {
		t.Errorf("got %v [%v], want %v", got, ls1(t, dir), want)
	}
	for _, ran := range rowsRan {
		if ran {
			t.Error("want cache use")
		}
	}

	// Delete some cache entries and ensure there's partial recomputation.
	for i, f := range ls1(t, dir) {
		if i%2 == 0 {
			continue
		}
		if err := os.Remove(filepath.Join(dir, f)); err != nil {
			t.Error(err)
		}
	}
	for i := range rowsRan {
		rowsRan[i] = false
	}
	_ = runLocal(ctx, t, makeSlice())
	if got, want := len(ls1(t, dir)), Nshard; got != want {
		t.Errorf("got %v [%v], want %v", got, ls1(t, dir), want)
	}
	var nRowsRan int
	for _, ran := range rowsRan {
		if ran {
			nRowsRan++
		}
	}
	if nRowsRan == 0 || nRowsRan >= N {
		t.Errorf("want partial recomputation, got %d of %d rows", nRowsRan, N)
	}
}

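// TestCacheErr verifies that a computation error is propagated to the caller
// and that the computation is rerun, rather than served from cache, on the
// next evaluation.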
func TestCacheErr(t *testing.T) {
	dir, cleanUp := testutil.TempDir(t, "", "")
	defer cleanUp()
	ctx := context.Background()

	computeRan := false

	makeSlice := func() bigslice.Slice {
		slice := bigslice.ReaderFunc(1, func(shard int, state *bool, ints []int) (n int, err error) {
			if *state {
				return 0, errors.New("random error")
			}
			for i := range ints {
				ints[i] = i
			}
			*state = true
			computeRan = true
			return len(ints), nil
		})
		slice = bigslice.Cache(ctx, slice, file.Join(dir, "cached"))
		return slice
	}
	if err := slicetest.RunErr(makeSlice()); err == nil {
		t.Error("expected error")
	}
	if !computeRan {
		t.Error("expected computation to run")
	}
	// Ensure computation is rerun after error.
	computeRan = false
	if err := slicetest.RunErr(makeSlice()); err == nil {
		t.Error("expected error")
	}
	if !computeRan {
		t.Error("expected computation to be rerun after error")
	}
}

// TestReadCache verifies that ReadCache successfully reads from an existing cache.
func TestReadCache(t *testing.T) {
	dir, cleanUp := testutil.TempDir(t, "", "")
	defer cleanUp()
	prefix := filepath.Join(dir, "cached")
	ctx := context.Background()

	const (
		N      = 10000
		Nshard = 10
	)
	input := make([]int, N)
	for i := range input {
		input[i] = i
	}
	slice1 := bigslice.Const(Nshard, input)
	slice1 = bigslice.Cache(ctx, slice1, prefix)
	scan1 := runLocal(ctx, t, slice1)
	defer scan1.Close()

	// We now have a populated cache. Read from it, and make sure we get the
	// same results.
	slice2 := bigslice.ReadCache(ctx, slice1, slice1.NumShard(), prefix)
	scan2 := runLocal(ctx, t, slice2)
	defer scan2.Close()

	v1 := scanInts(ctx, t, scan1)
	v2 := scanInts(ctx, t, scan2)
	if got, want := len(v1), N; got != want {
		t.Errorf("got %v, want %v", got, want)
	}
	if !reflect.DeepEqual(v1, v2) {
		t.Errorf("corrupt cache")
	}
}

// TestReadCacheError verifies that a ReadCache reader returns an error if the
// cache does not exist.
func TestReadCacheError(t *testing.T) {
	dir, cleanUp := testutil.TempDir(t, "", "")
	defer cleanUp()
	var (
		prefix = filepath.Join(dir, "cached")
		ctx    = context.Background()
		slice  = bigslice.ReadCache(ctx, slicetype.New(reflect.TypeOf(0)), 1, prefix)
		fn     = bigslice.Func(func() bigslice.Slice { return slice })
		sess   = exec.Start(exec.Local)
	)
	defer sess.Shutdown()
	_, err := sess.Run(ctx, fn)
	if err == nil {
		t.Errorf("expected error when reading from non-existent cache")
	}
}

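// ls1 returns the sorted names of the entries in dir, failing the test on
// error.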
func ls1(t *testing.T, dir string) []string {
	t.Helper()
	d, err := os.Open(dir)
	if err != nil {
		t.Fatal(err)
	}
	defer d.Close()
	infos, err := d.Readdir(-1)
	if err != nil {
		t.Fatal(err)
	}
	paths := make([]string, len(infos))
	for i := range paths {
		paths[i] = infos[i].Name()
	}
	sort.Strings(paths)
	return paths
}

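// runLocal runs slice on a local session and returns a scanner over the
// result, failing the test on error.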
func runLocal(ctx context.Context, t *testing.T, slice bigslice.Slice) *sliceio.Scanner {
	t.Helper()
	fn := bigslice.Func(func() bigslice.Slice { return slice })
	sess := exec.Start(exec.Local)
	defer sess.Shutdown()
	res, err := sess.Run(ctx, fn)
	if err != nil {
		t.Fatalf("error running func: %v", err)
	}
	return res.Scanner()
}

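// scanInts reads all ints from scan and returns them in sorted order, failing
// the test on a scan error.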
func scanInts(ctx context.Context, t *testing.T, scan *sliceio.Scanner) []int {
	t.Helper()
	var (
		v  int
		vs []int
	)
	for scan.Scan(ctx, &v) {
		vs = append(vs, v)
	}
	if err := scan.Err(); err != nil {
		t.Fatalf("scan error: %v", err)
	}
	sort.Ints(vs)
	return vs
}

func ExampleCache() {
	// Compute a slice that performs a mapping computation and uses Cache to
	// cache the result, showing that we had to execute the mapping computation.
	// Compute another slice that uses the cache, showing that we produced the
	// same result without executing the mapping computation again.
	dir, err := ioutil.TempDir("", "example-cache")
	if err != nil {
		log.Fatalf("could not create temp directory: %v", err)
	}
	defer os.RemoveAll(dir)
	slice := bigslice.Const(2, []int{0, 1, 2, 3})
	// slicetest.Print uses local evaluation, so we can use shared memory across
	// all shard computations.
	var computed atomic.Value
	computed.Store(false)
	slice = bigslice.Map(slice, func(x int) int {
		computed.Store(true)
		return x
	})
	// The first evaluation causes the map to be evaluated.
	slice0 := bigslice.Cache(context.Background(), slice, dir+"/")
	fmt.Println("# first evaluation")
	slicetest.Print(slice0)
	fmt.Printf("computed: %t\n", computed.Load().(bool))

	// Reset the computed state for our second evaluation. The second evaluation
	// will read from the cache that was written by the first evaluation, so the
	// map will not be evaluated.
	computed.Store(false)
	slice1 := bigslice.Cache(context.Background(), slice, dir+"/")
	fmt.Println("# second evaluation")
	slicetest.Print(slice1)
	fmt.Printf("computed: %t\n", computed.Load().(bool))
	// Output:
	// # first evaluation
	// 0
	// 1
	// 2
	// 3
	// computed: true
	// # second evaluation
	// 0
	// 1
	// 2
	// 3
	// computed: false
}

func ExampleCachePartial() {
	// Compute a slice that performs a mapping computation and uses CachePartial
	// to cache the result, showing that we had to execute the mapping
	// computation for each row. Manually remove only part of the cached data.
	// Compute another slice that uses the cache, showing that we produced the
	// same result, only executing the mapping computation on the rows whose
	// data we removed from the cache.
	dir, err := ioutil.TempDir("", "example-cache-partial")
	if err != nil {
		log.Fatalf("could not create temp directory: %v", err)
	}
	defer os.RemoveAll(dir)
	slice := bigslice.Const(2, []int{0, 1, 2, 3})
	// slicetest.Print uses local evaluation, so we can use shared memory across
	// all shard computations.
	var computed int32
	slice = bigslice.Map(slice, func(x int) int {
		atomic.AddInt32(&computed, 1)
		return x
	})
	// The first evaluation causes the map to be evaluated.
	slice0 := bigslice.CachePartial(context.Background(), slice, dir+"/")
	fmt.Println("# first evaluation")
	slicetest.Print(slice0)
	fmt.Printf("computed: %d\n", computed)

	// Remove one of the cache files. This will leave us with a partial cache,
	// i.e. a cache with only some shards cached.
	infos, err := ioutil.ReadDir(dir)
	if err != nil {
		log.Fatalf("error reading temp dir %s: %v", dir, err)
	}
	path := filepath.Join(dir, infos[0].Name())
	if err = os.Remove(path); err != nil {
		log.Fatalf("error removing cache file %s: %v", path, err)
	}

	// Reset the computed state for our second evaluation. The second evaluation
	// will read from the partial cache that was written by the first
	// evaluation, so only some rows will need recomputation.
	computed = 0
	slice1 := bigslice.CachePartial(context.Background(), slice, dir+"/")
	fmt.Println("# second evaluation")
	slicetest.Print(slice1)
	fmt.Printf("computed: %d\n", computed)

	// Note that this example is fragile for a couple of reasons. First, it
	// relies on how the cache is stored in files. If that changes, we may need
	// to change how we construct a partial cache. Second, it relies on the
	// stability of the shard allocation. If that changes, we may end up with
	// different sharding and a different number of rows needing computation.

	// Output:
	// # first evaluation
	// 0
	// 1
	// 2
	// 3
	// computed: 4
	// # second evaluation
	// 0
	// 1
	// 2
	// 3
	// computed: 2
}

func ExampleReadCache() {
	// Compute a slice that uses Cache to cache the result. Use ReadCache to
	// read from that same cache. Observe that we get the same data.
	const numShards = 2
	dir, err := ioutil.TempDir("", "example-cache")
	if err != nil {
		log.Fatalf("could not create temp directory: %v", err)
	}
	defer os.RemoveAll(dir)
	slice0 := bigslice.Const(numShards, []int{0, 1, 2, 3})
	slice0 = bigslice.Cache(context.Background(), slice0, dir+"/")
	fmt.Println("# build cache")
	slicetest.Print(slice0)

	slice1 := bigslice.ReadCache(context.Background(), slice0, numShards, dir+"/")
	fmt.Println("# use ReadCache to read cache")
	slicetest.Print(slice1)
	// Output:
	// # build cache
	// 0
	// 1
	// 2
	// 3
	// # use ReadCache to read cache
	// 0
	// 1
	// 2
	// 3
}