github.com/grailbio/bigslice@v0.0.0-20230519005545-30c4c12152ad/exec/compile_test.go

// Copyright 2019 GRAIL, Inc. All rights reserved.
// Use of this source code is governed by the Apache 2.0
// license that can be found in the LICENSE file.

package exec

import (
	"bytes"
	"context"
	"fmt"
	"io/ioutil"
	"sort"
	"strings"
	"testing"

	"github.com/grailbio/bigslice"
	"github.com/grailbio/bigslice/frame"
	"github.com/grailbio/bigslice/internal/slicecache"
	"github.com/grailbio/bigslice/slicefunc"
	"github.com/grailbio/bigslice/sliceio"
)

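// TestCompile verifies that compilation produces the expected task graphs
// for a variety of slice topologies by comparing them against the golden
// graphs in testdata/.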
func TestCompile(t *testing.T) {
	for _, c := range []struct {
		name string
		f    func() bigslice.Slice
	}{
		{
			"trivial",
			func() (slice bigslice.Slice) {
				slice = bigslice.Const(3, []int{})
				return
			},
		},
		{
			"shuffle",
			func() (slice bigslice.Slice) {
				slice = bigslice.Const(3, []int{}, []float64{})
				slice = bigslice.Reduce(slice, func(v0, v1 float64) float64 { return v0 + v1 })
				return
			},
		},
		{
			// Branch where both branches pipeline with the subsequent maps.
			"branch",
			func() (slice bigslice.Slice) {
				slice = bigslice.Const(3, []int{})
				slice = bigslice.Map(slice, func(i int) int { return i })
				slice0 := bigslice.Map(slice, func(i int) int { return i })
				slice1 := bigslice.Map(slice, func(i int) int { return i })
				slice = bigslice.Cogroup(slice0, slice1)
				return
			},
		},
		{
			// Branch from a materialized slice, so the subsequent maps are not
			// pipelined through the materialized tasks.
			"branch-materialize",
			func() (slice bigslice.Slice) {
				slice = bigslice.Const(3, []int{})
				slice = bigslice.Map(slice, func(i int) int { return i }, bigslice.ExperimentalMaterialize)
				slice0 := bigslice.Map(slice, func(i int) int { return i })
				slice1 := bigslice.Map(slice, func(i int) int { return i })
				slice = bigslice.Cogroup(slice0, slice1)
				return
			},
		},
		{
			// Branch the const slice with a reduce, which introduces its own
			// shuffle/combiner, so the const slice tasks cannot be reused.
			"branch-shuffle",
			func() (slice bigslice.Slice) {
				slice = bigslice.Const(3, []int{}, []float64{})
				slice0 := bigslice.Reduce(slice, func(v0, v1 float64) float64 { return v0 + v1 })
				slice = bigslice.Cogroup(slice, slice0)
				return
			},
		},
		{
			// Branch where each branch demands the same number of partitions
			// from the branch point slice. In this case, the branch point
			// tasks can be reused.
			"branch-same-partitions",
			func() (slice bigslice.Slice) {
				slice = bigslice.Const(3, []int{})
				slice = bigslice.Map(slice, func(i int) int { return i })
				slice0 := bigslice.Reshard(slice, 2)
				slice1 := bigslice.Reshard(slice, 2)
				slice = bigslice.Cogroup(slice0, slice1)
				return
			},
		},
		{
			// Branch where each branch demands a different number of
			// partitions from the branch point slice. In this case, the
			// branch point tasks cannot be reused.
			"branch-different-partitions",
			func() (slice bigslice.Slice) {
				slice = bigslice.Const(3, []int{})
				slice = bigslice.Map(slice, func(i int) int { return i })
				slice0 := bigslice.Reshard(slice, 1)
				slice1 := bigslice.Reshard(slice, 2)
				slice = bigslice.Cogroup(slice0, slice1)
				return
			},
		},
	} {
		t.Run(c.name, func(t *testing.T) {
			f := bigslice.Func(c.f)
			inv := makeExecInvocation(f.Invocation("<unknown>"))
			inv.Index = 1
			slice := inv.Invoke()
			tasks, err := compile(inv, slice, false)
			if err != nil {
				t.Fatalf("compilation failed: %v", err)
			}
			_ = iterTasks(tasks, func(task *Task) error {
				if task.Pragma == nil {
					t.Errorf("%v has nil task.Pragma", task)
				}
				return nil
			})
			g := makeGraph(tasks)
			want, err := ioutil.ReadFile("testdata/" + c.name + ".graph")
			if err != nil {
				t.Fatalf("error reading graph: %v", err)
			}
			d := lineDiff(g.String(), string(want))
			if d != "" {
				t.Errorf("differs from %s.graph:\n%s", c.name, d)
			}
		})
	}
}

// TestCompileEnv verifies that the compileEnv is used and behaves properly,
// specifically verifying that compilation correctly writes to writable
// environments and reads from non-writable environments.
func TestCompileEnv(t *testing.T) {
	const Nshard = 8

	// cachedShards is set up just before we invoke the Func. It represents the
	// fake cache state from the perspective of that invocation.
	var cachedShards []int
	f := bigslice.Func(func() bigslice.Slice {
		slice := bigslice.Const(Nshard, []int{0, 1, 2, 3, 4, 5, 6, 7, 8, 9})
		// Break the pipeline. We use this to detect the compiled tasks for
		// which compilation considered the cache valid: if the cache is
		// valid, the compiled root task will have no dependencies.
		slice = bigslice.Reshuffle(slice)
		slice = fakeCache(slice, cachedShards)
		return slice
	})
	inv := makeExecInvocation(f.Invocation("<unknown>"))
	inv.Index = 0

	cachedShardsFrozen := []int{1, 4, 5}
	cachedShards = cachedShardsFrozen
	slice0 := inv.Invoke()
	tasks, err := compile(inv, slice0, false)
	if err != nil {
		t.Fatalf("compilation failed: %v", err)
	}
	for _, task := range tasks {
		var cached bool
		for _, shard := range cachedShardsFrozen {
			if shard == task.Name.Shard {
				cached = true
			}
		}
		// Verify that the resulting tasks reflect the cache state.
		if got, want := len(task.Deps) == 0, cached; got != want {
			t.Errorf("shard %d: got %v, want %v", task.Name.Shard, got, want)
		}
	}

	// Freeze the environment, and verify that compilation uses the environment
	// and not the current cache state.
	inv.Env.Freeze()
	cachedShards = []int{2, 4, 7} // different cache state from above.
	slice1 := inv.Invoke()
	tasks, err = compile(inv, slice1, false)
	if err != nil {
		t.Fatalf("compilation failed: %v", err)
	}
	for _, task := range tasks {
		var cached bool
		for _, shard := range cachedShardsFrozen {
			if shard == task.Name.Shard {
				cached = true
			}
		}
		// Verify that the tasks are compiled according to the environment that
		// reflects cachedShardsFrozen, and not the current cache state in
		// cachedShards.
		if got, want := len(task.Deps) == 0, cached; got != want {
			t.Errorf("shard %d: got %v, want %v", task.Name.Shard, got, want)
		}
	}
}

// TestPipelinedCache verifies that cacheable slices that are pipelined for
// execution behave as we expect.
func TestPipelinedCache(t *testing.T) {
	const Nshard = 8
	f := bigslice.Func(func() bigslice.Slice {
		slice := bigslice.Const(Nshard, []int{0, 1, 2, 3, 4, 5, 6, 7, 8, 9})
		// Break the pipeline. We use this to detect the compiled tasks for
		// which compilation considered the cache valid: if the cache is
		// valid, the compiled root task will have no dependencies.
		slice = bigslice.Reshuffle(slice)
		id := func(i int) int { return i }
		// These slices will be pipelined.  We set them up with different
		// shards cached in different slices, with some shards not cached at
		// all.  When we examine the resulting dependencies of the
		// (pipelined) task, we should only see dependencies for shards
		// without any cache, as only those need the upstream results.
		slice = fakeCache(bigslice.Map(slice, id), []int{0, 2})
		slice = bigslice.Map(slice, id)
		slice = fakeCache(bigslice.Map(slice, id), []int{5, 7})
		slice = bigslice.Map(slice, id)
		return slice
	})
	inv := makeExecInvocation(f.Invocation("<unknown>"))
	inv.Index = 0
	slice0 := inv.Invoke()
	tasks, err := compile(inv, slice0, false)
	if err != nil {
		t.Fatalf("compilation failed: %v", err)
	}
	// These are all the shards that we expect to be computable without
	// dependencies, as some part of the (pipelined) computation is cached.
	// This is the union of the shards cached in our fakeCache slices.
	noDeps := []int{0, 2, 5, 7}
	for _, task := range tasks {
		var inNoDeps bool
		for _, shard := range noDeps {
			if shard == task.Name.Shard {
				inNoDeps = true
			}
		}
		// Verify that the resulting tasks reflect the cache state.
		if got, want := len(task.Deps) == 0, inNoDeps; got != want {
			t.Errorf("shard %d: got %v, want %v", task.Name.Shard, got, want)
		}
		// Invoke Do to verify that we can construct our pipelined computation.
		// There have been bugs for which this call would panic.  Note that
		// this is somewhat fragile, as we assume that Do does not access the
		// input readers, instead only composing readers to represent the
		// pipeline.
		task.Do([]sliceio.Reader{sliceio.EmptyReader{}})
	}
}

// makeGraph returns a graph representation of the task graph roots that is
// convenient for printing and comparing. We use this to verify (and debug)
// compilation results.
func makeGraph(roots []*Task) graph {
	var (
		visited = make(map[*Task]bool)
		g       graph
		walk    func(tasks []*Task)
	)
	walk = func(tasks []*Task) {
		if len(tasks) == 0 {
			return
		}
		for _, t := range tasks {
			if visited[t] {
				continue
			}
			visited[t] = true
			g.nodes = append(g.nodes, t.Name.String())
			for _, d := range t.Deps {
				for i := 0; i < d.NumTask(); i++ {
					edge := edge{t.Name.String(), d.Task(i).Name.String()}
					g.edges = append(g.edges, edge)
					walk([]*Task{d.Task(i)})
				}
			}
		}
	}
	walk(roots)
	g.Sort()
	return g
}

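// edge is a directed edge in a graph, pointing from a task to one of its
// dependencies.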
type edge struct {
	src string
	dst string
}

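// graph is a flattened representation of a task graph: the set of node
// names and the src -> dst dependency edges between them.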
type graph struct {
	nodes []string
	edges []edge
}

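// Sort sorts the nodes and edges of g so that String produces deterministic
// output suitable for golden-file comparison.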
func (g graph) Sort() {
	sort.Strings(g.nodes)
	sort.Slice(g.edges, func(i, j int) bool {
		if g.edges[i].src != g.edges[j].src {
			return g.edges[i].src < g.edges[j].src
		}
		return g.edges[i].dst < g.edges[j].dst
	})
}

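// String renders g with one node name per line, followed by one
// "src -> dst" line per edge. For example, a two-task graph in which
// task a depends on task b renders as:
//
//	a
//	b
//	a -> b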
func (g graph) String() string {
	var b bytes.Buffer
	for _, n := range g.nodes {
		fmt.Fprintf(&b, "%s\n", n)
	}
	for _, e := range g.edges {
		fmt.Fprintf(&b, "%s -> %s\n", e.src, e.dst)
	}
	return b.String()
}

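// lineDiff returns a line-by-line diff of lhs and rhs, or the empty string
// if they are identical. Lines prefixed "-" appear only in lhs; lines
// prefixed "+" appear only in rhs. For example, lineDiff("a\nb", "a\nc")
// returns:
//
//	a
//	- b
//	+ c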
func lineDiff(lhs, rhs string) string {
	lhsLines := strings.Split(lhs, "\n")
	rhsLines := strings.Split(rhs, "\n")

	// This is a vanilla Levenshtein distance implementation.
	const (
		editNone = iota
		editAdd
		editDel
		editRep
	)
	type cell struct {
		edit int
		cost int
	}
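	// cells[i][j] holds the minimum number of line edits needed to turn the
	// first i lines of lhs into the first j lines of rhs, together with the
	// final edit of a script achieving that cost.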
	cells := make([][]cell, len(lhsLines)+1)
	for i := range cells {
		cells[i] = make([]cell, len(rhsLines)+1)
	}
	for i := 1; i < len(lhsLines)+1; i++ {
		cells[i][0].edit = editDel
		cells[i][0].cost = i
	}
	for j := 1; j < len(rhsLines)+1; j++ {
		cells[0][j].edit = editAdd
		cells[0][j].cost = j
	}
	for i := 1; i < len(lhsLines)+1; i++ {
		for j := 1; j < len(rhsLines)+1; j++ {
			if lhsLines[i-1] == rhsLines[j-1] {
				cells[i][j].cost = cells[i-1][j-1].cost
				continue
			}
			repCost := cells[i-1][j-1].cost + 1
			minCost := repCost
			delCost := cells[i-1][j].cost + 1
			if delCost < minCost {
				minCost = delCost
			}
			addCost := cells[i][j-1].cost + 1
			if addCost < minCost {
				minCost = addCost
			}
			cells[i][j].cost = minCost
			switch minCost {
			case repCost:
				cells[i][j].edit = editRep
			case addCost:
				cells[i][j].edit = editAdd
			case delCost:
				cells[i][j].edit = editDel
			}
		}
	}
	var (
		d      []string
		differ bool
	)
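	// Walk back from the bottom-right cell, reconstructing the edit script
	// that achieved the minimal cost.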
	for i, j := len(lhsLines), len(rhsLines); i > 0 || j > 0; {
		switch cells[i][j].edit {
		case editNone:
			d = append(d, lhsLines[i-1])
			i--
			j--
		case editAdd:
			d = append(d, "+ "+rhsLines[j-1])
			j--
			differ = true
		case editDel:
			d = append(d, "- "+lhsLines[i-1])
			i--
			differ = true
		case editRep:
			d = append(d, "+ "+rhsLines[j-1])
			d = append(d, "- "+lhsLines[i-1])
			i--
			j--
			differ = true
		}
	}
	if !differ {
		return ""
	}
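	// The walk above built d back-to-front, so reverse it in place.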
	for i := len(d)/2 - 1; i >= 0; i-- {
		opp := len(d) - 1 - i
		d[i], d[opp] = d[opp], d[i]
	}
	var b bytes.Buffer
	for _, dLine := range d {
		b.WriteString(dLine + "\n")
	}
	return b.String()
}

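// fakeShardCache is a slicecache.ShardCache whose set of cached shards is
// fixed at construction. Its readers are inert: writethrough readers pass
// data through unchanged, and cache readers yield no data.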
type fakeShardCache struct {
	cachedSet map[int]bool
}

func (c fakeShardCache) IsCached(shard int) bool { return c.cachedSet[shard] }
func (fakeShardCache) WritethroughReader(shard int, reader sliceio.Reader) sliceio.Reader {
	return reader
}
func (fakeShardCache) CacheReader(shard int) sliceio.Reader {
	return emptyReader{}
}

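// emptyReader is a sliceio.Reader that is exhausted from the start: every
// Read immediately returns EOF.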
type emptyReader struct{}

func (emptyReader) Read(ctx context.Context, frame frame.Frame) (int, error) {
	return 0, sliceio.EOF
}

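// fakeCacheSlice wraps a slice, overriding its shard cache with a fake so
// that tests can control which shards compilation sees as cached.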
type fakeCacheSlice struct {
	name bigslice.Name
	bigslice.Slice
	cache slicecache.ShardCache
}

func (c *fakeCacheSlice) Name() bigslice.Name { return c.name }
func (c *fakeCacheSlice) NumDep() int         { return 1 }
func (c *fakeCacheSlice) Dep(i int) bigslice.Dep {
	return bigslice.Dep{
		Slice:       c.Slice,
		Shuffle:     false,
		Partitioner: nil,
		Expand:      false,
	}
}
func (*fakeCacheSlice) Combiner() slicefunc.Func                                 { return slicefunc.Nil }
func (c *fakeCacheSlice) Reader(shard int, deps []sliceio.Reader) sliceio.Reader { return deps[0] }
func (c *fakeCacheSlice) Cache() slicecache.ShardCache                           { return c.cache }

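// fakeCache returns a slice that is identical to slice except that it
// reports cachedShards as already cached. Tests use it to simulate cache
// state without touching any real storage.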
func fakeCache(slice bigslice.Slice, cachedShards []int) bigslice.Slice {
	cachedSet := make(map[int]bool)
	for _, shard := range cachedShards {
		cachedSet[shard] = true
	}
	return &fakeCacheSlice{bigslice.MakeName("testcache"), slice, fakeShardCache{cachedSet}}
}