github.com/grailbio/bigslice@v0.0.0-20230519005545-30c4c12152ad/exec/compile.go (about)

     1  // Copyright 2018 GRAIL, Inc. All rights reserved.
     2  // Use of this source code is governed by the Apache 2.0
     3  // license that can be found in the LICENSE file.
     4  
     5  package exec
     6  
     7  import (
     8  	"context"
     9  	"fmt"
    10  	"strings"
    11  
    12  	"github.com/grailbio/base/log"
    13  	"github.com/grailbio/bigslice"
    14  	"github.com/grailbio/bigslice/frame"
    15  	"github.com/grailbio/bigslice/internal/slicecache"
    16  	"github.com/grailbio/bigslice/slicefunc"
    17  	"github.com/grailbio/bigslice/sliceio"
    18  )
    19  
    20  func defaultPartitioner(_ context.Context, frame frame.Frame, nshard int, shards []int) {
    21  	for i := range shards {
    22  		shards[i] = int(frame.Hash(i) % uint32(nshard))
    23  	}
    24  }
    25  
    26  // Pipeline returns the sequence of slices that may be pipelined
    27  // starting from slice. Slices that do not have shuffle dependencies
    28  // may be pipelined together: slices[0] depends on slices[1], and so on.
    29  func pipeline(slice bigslice.Slice) (slices []bigslice.Slice) {
    30  	for {
    31  		// Stop at *Results, so we can re-use previous tasks.
    32  		if _, ok := bigslice.Unwrap(slice).(*Result); ok {
    33  			return
    34  		}
    35  		slices = append(slices, slice)
    36  		if slice.NumDep() != 1 {
    37  			return
    38  		}
    39  		dep := slice.Dep(0)
    40  		if dep.Shuffle {
    41  			return
    42  		}
    43  		if pragma, ok := dep.Slice.(bigslice.Pragma); ok && pragma.Materialize() {
    44  			return
    45  		}
    46  		slice = dep.Slice
    47  	}
    48  }
    49  
    50  // memoKey is the memo key for memoized slice compilations.
    51  type memoKey struct {
    52  	slice bigslice.Slice
    53  	// numPartition is the number of partitions in the output of the memoized
    54  	// compiled tasks.
    55  	numPartition int
    56  }
    57  
    58  // partitioner configures the output partitioning of compiled tasks. The zero
    59  // value indicates that the output of the tasks are not for a shuffle
    60  // dependency.
    61  type partitioner struct {
    62  	// numPartition is the number of partitions in the output for a shuffle
    63  	// dependency, if >1. If 0, the output is not used by a shuffle.
    64  	numPartition int
    65  	partitioner  bigslice.Partitioner
    66  	Combiner     slicefunc.Func
    67  	CombineKey   string
    68  }
    69  
    70  // IsShuffle returns whether the task output is used by a shuffle dependency.
    71  func (p partitioner) IsShuffle() bool {
    72  	return p.numPartition != 0
    73  }
    74  
    75  // Partitioner returns the partitioner to be used to partition the output of
    76  // this task.
    77  func (p partitioner) Partitioner() bigslice.Partitioner {
    78  	if p.partitioner == nil {
    79  		return defaultPartitioner
    80  	}
    81  	return p.partitioner
    82  }
    83  
    84  // NumPartition returns the number of partitions that the task output should
    85  // have. If this is not a shuffle dependency, returns 1.
    86  func (p partitioner) NumPartition() int {
    87  	if p.numPartition == 0 {
    88  		return 1
    89  	}
    90  	return p.numPartition
    91  }
    92  
    93  // Compile compiles the provided slice into a set of task graphs,
    94  // each representing the computation for one shard of the slice. The
    95  // slice is produced by the provided invocation. Compile coalesces
    96  // slice operations that can be pipelined into single tasks, creating
    97  // wide dependencies only at shuffle boundaries. The provided namer
    98  // must mint names that are unique to the session. The order in which
    99  // the namer is invoked is guaranteed to be deterministic.
   100  //
   101  // TODO(marius): we don't currently reuse tasks across compilations,
   102  // even though this could sometimes safely be done (when the number
   103  // of partitions and the kind of partitioner matches at shuffle
   104  // boundaries). We should at least support this use case to avoid
   105  // redundant computations.
   106  //
   107  // TODO(marius): an alternative model for propagating invocations is
   108  // to provide each actual invocation with a "root" slice from where
   109  // all other slices must be derived. This simplifies the
   110  // implementation but may make the API a little confusing.
   111  func compile(inv execInvocation, slice bigslice.Slice, machineCombiners bool) (tasks []*Task, err error) {
   112  	c := compiler{
   113  		namer:            make(taskNamer),
   114  		inv:              inv,
   115  		machineCombiners: machineCombiners,
   116  		memo:             make(map[memoKey][]*Task),
   117  	}
   118  	// Top-level compilation always produces tasks that write single partitions,
   119  	// as they are materialized and will not be used as direct shuffle
   120  	// dependencies.
   121  	tasks, err = c.compile(slice, partitioner{})
   122  	return
   123  }
   124  
   125  type (
   126  	// CompileEnv is the environment for compilation. This environment should
   127  	// capture all external state that can affect compilation of an invocation.
   128  	// It is shared across compilations of the same invocation (e.g. on worker
   129  	// nodes) to guarantee consistent compilation results. This is a
   130  	// requirement of bigslice's computation model, as we assume that all nodes
   131  	// share the same view of the task graph. It must be gob-encodable for
   132  	// transport to workers.
   133  	CompileEnv struct {
   134  		// Writable is true if this environment is writable. It is only
   135  		// exported so that it can be gob-{en,dec}oded.
   136  		Writable bool
   137  		// Cached indicates whether a task operation's results can be read from
   138  		// cache. An "operation" is one of the pipelined elements that a task
   139  		// may perform. It is only exported so that it can be gob-{en,dec}oded.
   140  		Cached map[taskOp]bool
   141  	}
   142  	// taskOp is a (task, operation) pair. It is used as the key of
   143  	// (CompileEnv).Cached.
   144  	taskOp struct {
   145  		// N is the task, specified by name.
   146  		N TaskName
   147  		// OpIdx is the operation, specified by index in the task processing
   148  		// pipeline.
   149  		OpIdx int
   150  	}
   151  )
   152  
   153  // makeCompileEnv returns an empty and writable CompileEnv that can be passed
   154  // to compile.
   155  func makeCompileEnv() CompileEnv {
   156  	return CompileEnv{
   157  		Writable: true,
   158  		Cached:   make(map[taskOp]bool),
   159  	}
   160  }
   161  
   162  // MarkCached marks the (task, operation) given by (n, opIdx) as cached.
   163  func (e CompileEnv) MarkCached(n TaskName, opIdx int) {
   164  	if !e.Writable {
   165  		panic("env not writable")
   166  	}
   167  	e.Cached[taskOp{n, opIdx}] = true
   168  }
   169  
   170  // IsCached returns whether the (task, operation) given by (n, opIdx) is
   171  // cached.
   172  func (e CompileEnv) IsCached(n TaskName, opIdx int) bool {
   173  	return e.Cached[taskOp{n, opIdx}]
   174  }
   175  
   176  // Freeze freezes the state, marking e no longer writable.
   177  func (e *CompileEnv) Freeze() {
   178  	e.Writable = false
   179  }
   180  
   181  // IsWritable returns whether this environment is writable.
   182  func (e CompileEnv) IsWritable() bool {
   183  	return e.Writable
   184  }
   185  
   186  type compiler struct {
   187  	namer            taskNamer
   188  	inv              execInvocation
   189  	machineCombiners bool
   190  	memo             map[memoKey][]*Task
   191  }
   192  
   193  // compile compiles the provided slice into a set of task graphs, memoizing the
   194  // compilation so that tasks can be reused within the invocation.
   195  func (c *compiler) compile(slice bigslice.Slice, part partitioner) (tasks []*Task, err error) {
   196  	// We never reuse combiner tasks, as we currently don't have a way of
   197  	// identifying equivalent combiner functions. Ditto with custom
   198  	// partitioners.
   199  	if part.Combiner.IsNil() && part.partitioner == nil {
   200  		// TODO(jcharumilind): Repartition already-computed data instead of
   201  		// forcing recomputation of the slice if we get a different
   202  		// numPartition.
   203  		key := memoKey{slice: slice, numPartition: part.numPartition}
   204  		if memoTasks, ok := c.memo[key]; ok {
   205  			// We're compiling the same slice with the same number of partitions
   206  			// (and no combiner), so we can safely reuse the tasks.
   207  			return memoTasks, nil
   208  		}
   209  		defer func() {
   210  			if err != nil {
   211  				return
   212  			}
   213  			c.memo[key] = tasks
   214  		}()
   215  	}
   216  	// Beyond this point, any tasks used for shuffles are new and need to have
   217  	// task groups set up for phasic evaluation.
   218  	defer func() {
   219  		if part.IsShuffle() {
   220  			for _, task := range tasks {
   221  				task.Group = tasks
   222  			}
   223  		}
   224  	}()
   225  	// Reuse tasks from a previous invocation.
   226  	if result, ok := bigslice.Unwrap(slice).(*Result); ok {
   227  		for _, task := range result.tasks {
   228  			if !task.Combiner.IsNil() {
   229  				// TODO(marius): we may consider supporting this, but it should
   230  				// be very rare, since it requires the user to explicitly reuse
   231  				// an intermediate slice, which is impossible via the current
   232  				// API.
   233  				return nil, fmt.Errorf("cannot reuse task %s with combine key %s", task, task.CombineKey)
   234  			}
   235  		}
   236  		if !part.IsShuffle() {
   237  			tasks = result.tasks
   238  			return
   239  		}
   240  		// We now insert a set of tasks whose only purpose is (re-)shuffling
   241  		// the output from the previously completed task.
   242  		shuffleOpName := c.namer.New(fmt.Sprintf("%s_shuffle", result.tasks[0].Name.Op))
   243  		tasks = make([]*Task, len(result.tasks))
   244  		for shard, task := range result.tasks {
   245  			tasks[shard] = &Task{
   246  				Type:       slice,
   247  				Invocation: c.inv,
   248  				Name: TaskName{
   249  					InvIndex: c.inv.Index,
   250  					Op:       shuffleOpName,
   251  					Shard:    shard,
   252  					NumShard: len(result.tasks),
   253  				},
   254  				Do:     func(readers []sliceio.Reader) sliceio.Reader { return readers[0] },
   255  				Deps:   []TaskDep{{task, 0, false, ""}},
   256  				Pragma: task.Pragma,
   257  				Slices: task.Slices,
   258  			}
   259  		}
   260  		return
   261  	}
   262  	// Pipeline slices and create a task for each underlying shard, pipelining
   263  	// the eligible computations.
   264  	slices := pipeline(slice)
   265  	defer func() {
   266  		for _, task := range tasks {
   267  			task.Slices = slices
   268  		}
   269  	}()
   270  	var pragmas bigslice.Pragmas
   271  	ops := make([]string, 0, len(slices)+1)
   272  	ops = append(ops, fmt.Sprintf("inv%d", c.inv.Index))
   273  	for i := len(slices) - 1; i >= 0; i-- {
   274  		ops = append(ops, slices[i].Name().Op)
   275  		if pragma, ok := slices[i].(bigslice.Pragma); ok {
   276  			pragmas = append(pragmas, pragma)
   277  		}
   278  	}
   279  	opName := c.namer.New(strings.Join(ops, "_"))
   280  	tasks = make([]*Task, slice.NumShard())
   281  	for i := range tasks {
   282  		tasks[i] = &Task{
   283  			Type: slices[0],
   284  			Name: TaskName{
   285  				InvIndex: c.inv.Index,
   286  				Op:       opName,
   287  				Shard:    i,
   288  				NumShard: len(tasks),
   289  			},
   290  			Invocation:   c.inv,
   291  			Pragma:       pragmas,
   292  			NumPartition: part.NumPartition(),
   293  			Partitioner:  part.Partitioner(),
   294  			Combiner:     part.Combiner,
   295  			CombineKey:   part.CombineKey,
   296  		}
   297  	}
   298  	// Capture the dependencies for this task set; they are encoded in the last
   299  	// slice.
   300  	lastSlice := slices[len(slices)-1]
   301  	for i := 0; i < lastSlice.NumDep(); i++ {
   302  		dep := lastSlice.Dep(i)
   303  		if !dep.Shuffle {
   304  			depTasks, err := c.compile(dep.Slice, partitioner{})
   305  			if err != nil {
   306  				return nil, err
   307  			}
   308  			if len(tasks) != len(depTasks) {
   309  				log.Panicf("tasks:%d deptasks:%d", len(tasks), len(depTasks))
   310  			}
   311  			for shard := range tasks {
   312  				tasks[shard].Deps = append(tasks[shard].Deps,
   313  					TaskDep{depTasks[shard], 0, dep.Expand, ""})
   314  			}
   315  			continue
   316  		}
   317  		var combineKey string
   318  		if !lastSlice.Combiner().IsNil() && c.machineCombiners {
   319  			combineKey = opName
   320  		}
   321  		depPart := partitioner{
   322  			slice.NumShard(), dep.Partitioner,
   323  			lastSlice.Combiner(), combineKey,
   324  		}
   325  		depTasks, err := c.compile(dep.Slice, depPart)
   326  		if err != nil {
   327  			return nil, err
   328  		}
   329  		// Each shard reads different partitions from all of the previous slice's shards.
   330  		for partition := range tasks {
   331  			tasks[partition].Deps = append(tasks[partition].Deps,
   332  				TaskDep{depTasks[0], partition, dep.Expand, combineKey})
   333  		}
   334  	}
   335  	// Pipeline execution, folding multiple frame operations
   336  	// into a single task by composing their readers.
   337  	// Use cache when configured.
   338  	for opIdx := len(slices) - 1; opIdx >= 0; opIdx-- {
   339  		var (
   340  			pprofLabel = fmt.Sprintf("%s(%s)", slices[opIdx].Name(), c.inv.Location)
   341  			reader     = slices[opIdx].Reader
   342  			shardCache = slicecache.Empty
   343  		)
   344  		if c, ok := bigslice.Unwrap(slices[opIdx]).(slicecache.Cacheable); ok {
   345  			shardCache = c.Cache()
   346  		}
   347  		if c.inv.Env.IsWritable() {
   348  			for shard, task := range tasks {
   349  				if shardCache.IsCached(shard) {
   350  					c.inv.Env.MarkCached(task.Name, opIdx)
   351  				}
   352  			}
   353  		}
   354  		for shard, task := range tasks {
   355  			var (
   356  				shard = shard
   357  				prev  = task.Do
   358  			)
   359  			if c.inv.Env.IsCached(task.Name, opIdx) {
   360  				task.Do = func([]sliceio.Reader) sliceio.Reader {
   361  					r := shardCache.CacheReader(shard)
   362  					return &sliceio.PprofReader{Reader: r, Label: pprofLabel}
   363  				}
   364  				// Forget task dependencies for cached shards because we'll read
   365  				// from the cache file.
   366  				task.Deps = nil
   367  				continue
   368  			}
   369  			if prev == nil {
   370  				// First, read the input directly.
   371  				task.Do = func(readers []sliceio.Reader) sliceio.Reader {
   372  					r := reader(shard, readers)
   373  					r = shardCache.WritethroughReader(shard, r)
   374  					return &sliceio.PprofReader{Reader: r, Label: pprofLabel}
   375  				}
   376  			} else {
   377  				// Subsequently, read the previous pipelined slice's output.
   378  				task.Do = func(readers []sliceio.Reader) sliceio.Reader {
   379  					r := reader(shard, []sliceio.Reader{prev(readers)})
   380  					r = shardCache.WritethroughReader(shard, r)
   381  					return &sliceio.PprofReader{Reader: r, Label: pprofLabel}
   382  				}
   383  			}
   384  		}
   385  	}
   386  	return
   387  }
   388  
   389  type taskNamer map[string]int
   390  
   391  func (n taskNamer) New(name string) string {
   392  	c := n[name]
   393  	n[name]++
   394  	if c == 0 {
   395  		return name
   396  	}
   397  	return fmt.Sprintf("%s%d", name, c)
   398  }