github.com/grailbio/bigslice@v0.0.0-20230519005545-30c4c12152ad/exec/task.go (about)

     1  // Copyright 2018 GRAIL, Inc. All rights reserved.
     2  // Use of this source code is governed by the Apache 2.0
     3  // license that can be found in the LICENSE file.
     4  
     5  package exec
     6  
     7  import (
     8  	"bytes"
     9  	"context"
    10  	"errors"
    11  	"fmt"
    12  	"io"
    13  	"sort"
    14  	"strings"
    15  	"sync"
    16  	"text/tabwriter"
    17  
    18  	"github.com/grailbio/base/status"
    19  	"github.com/grailbio/base/sync/ctxsync"
    20  	"github.com/grailbio/bigslice"
    21  	"github.com/grailbio/bigslice/metrics"
    22  	"github.com/grailbio/bigslice/slicefunc"
    23  	"github.com/grailbio/bigslice/sliceio"
    24  	"github.com/grailbio/bigslice/slicetype"
    25  )
    26  
    27  func init() {
    28  	close(closedc)
    29  }
    30  
    31  // closedc is closed in init which can be used any time we just want a closed
    32  // channel (i.e. a channel that is always ready and receives a zero value).
    33  var closedc = make(chan struct{})
    34  
// ErrTaskLost indicates that a Task was in TaskLost state. It is the error
// returned by (*Task).Err for tasks in that state.
var ErrTaskLost = errors.New("task was lost")
    37  
    38  // TaskState represents the runtime state of a Task. TaskState
    39  // values are defined so that their magnitudes correspond with
    40  // task progression.
    41  type TaskState int
    42  
    43  const (
    44  	// TaskInit is the initial state of a task. Tasks in state TaskInit
    45  	// have usually not yet been seen by an executor.
    46  	TaskInit TaskState = iota
    47  
    48  	// TaskWaiting indicates that a task has been scheduled for
    49  	// execution (it is runnable) but has not yet been allocated
    50  	// resources by the executor.
    51  	TaskWaiting
    52  	// TaskRunning is the state of a task that's currently being run or
    53  	// discarded. After a task is in state TaskRunning, it can only enter a
    54  	// larger-valued state.
    55  	TaskRunning
    56  
    57  	// TaskOk indicates that a task has successfully completed;
    58  	// the task's results are available to dependent tasks.
    59  	//
    60  	// All TaskState values greater than TaskOk indicate task
    61  	// errors.
    62  	TaskOk
    63  
    64  	// TaskErr indicates that the task experienced a failure while
    65  	// running.
    66  	TaskErr
    67  	// TaskLost indicates that the task was lost, usually because
    68  	// the machine to which the task was assigned failed.
    69  	TaskLost
    70  
    71  	maxState
    72  )
    73  
    74  var states = [...]string{
    75  	TaskInit:    "INIT",
    76  	TaskWaiting: "WAITING",
    77  	TaskRunning: "RUNNING",
    78  	TaskOk:      "OK",
    79  	TaskErr:     "ERROR",
    80  	TaskLost:    "LOST",
    81  }
    82  
    83  // String returns the task's state as an upper-case string.
    84  func (s TaskState) String() string {
    85  	return states[s]
    86  }
    87  
// A TaskDep describes a single dependency for a task. A dependency
// comprises one or more tasks and the partition number of the task
// set that must be read at run time.
type TaskDep struct {
	// Head holds the underlying task that represents this dependency.
	// For shuffle dependencies, that task is the head task of the
	// phase, and the evaluator must expand the phase.
	Head *Task
	// Partition is the partition number of the dependency's task set
	// that must be read at run time.
	Partition int

	// Expand indicates that the task's dependencies for a given
	// partition should not be merged, but rather passed individually to
	// the task implementation.
	Expand bool

	// CombineKey is an optional label that names the combination key to
	// be used by this dependency. It is used to name a single combiner
	// buffer from which is read a number of combined tasks.
	//
	// CombineKeys must be provided to tasks that contain combiners.
	CombineKey string
}
   110  
   111  // NumTask returns the number of tasks that are comprised by this dependency.
   112  func (d TaskDep) NumTask() int {
   113  	if d.Head == nil {
   114  		return 0
   115  	}
   116  	if n := len(d.Head.Group); n > 0 {
   117  		return n
   118  	}
   119  	return 1
   120  }
   121  
   122  // Task returns the i'th task comprised by this dependency.
   123  func (d TaskDep) Task(i int) *Task {
   124  	if i == 0 {
   125  		return d.Head
   126  	}
   127  	return d.Head.Group[i]
   128  }
   129  
   130  // A TaskName uniquely names a task by its constituent components.
   131  // Tasks with 0 shards are taken to be combiner tasks: they are
   132  // machine-local buffers of combiner outputs for some (non-overlapping)
   133  // subset of shards for a task.
   134  type TaskName struct {
   135  	// InvIndex is the index of the invocation for which the task was compiled.
   136  	InvIndex uint64
   137  	// Op is a unique string describing the operation that is provided
   138  	// by the task.
   139  	Op string
   140  	// Shard and NumShard describe the shard processed by this task
   141  	// and the total number of shards to be processed.
   142  	Shard, NumShard int
   143  }
   144  
   145  // String returns a canonical representation of the task name,
   146  // formatted as:
   147  //
   148  //	{n.Op}@{n.NumShard}:{n.Shard}
   149  //	{n.Op}_combiner
   150  func (n TaskName) String() string {
   151  	if n.NumShard == 0 {
   152  		return n.Op + "_combiner"
   153  	}
   154  	return fmt.Sprintf("%s@%d:%d", n.Op, n.NumShard, n.Shard)
   155  }
   156  
   157  // IsCombiner returns whether the named task is a combiner task.
   158  func (n TaskName) IsCombiner() bool {
   159  	return n.NumShard == 0
   160  }
   161  
// TaskSubscriber is subscribed to a Task using Subscribe. It is then notified
// whenever the Task state changes. This is useful for efficiently observing the
// state changes of many tasks.
type TaskSubscriber struct {
	// Mutex guards tasks and is the lock with which cond is constructed.
	sync.Mutex
	// cond is broadcast on each call to Notify.
	cond *ctxsync.Cond

	// tasks holds the set of tasks that has changed since the last call to
	// Tasks.
	tasks map[*Task]struct{}
}
   173  
   174  // NewTaskSubscriber returns a new TaskSubscriber. It needs to be subscribed to
   175  // a Task with Subscribe for it to be notified of task state changes.
   176  func NewTaskSubscriber() *TaskSubscriber {
   177  	s := &TaskSubscriber{tasks: make(map[*Task]struct{})}
   178  	s.cond = ctxsync.NewCond(s)
   179  	return s
   180  }
   181  
   182  // Notify notifies s of a task whose state has changed.
   183  func (s *TaskSubscriber) Notify(task *Task) {
   184  	s.Lock()
   185  	defer s.Unlock()
   186  	s.tasks[task] = struct{}{}
   187  	s.cond.Broadcast()
   188  }
   189  
// Ready returns a channel that is closed when there are changed tasks
// pending, i.e. when a subsequent call to Tasks will have tasks to return.
func (s *TaskSubscriber) Ready() <-chan struct{} {
	s.Lock()
	if len(s.tasks) > 0 {
		// Changes are already pending, so hand back a channel that is
		// already closed.
		s.Unlock()
		return closedc
	}
	// NOTE(review): this path does not unlock s explicitly; presumably
	// ctxsync.Cond.Done releases the lock before returning — confirm
	// against the ctxsync documentation.
	return s.cond.Done()
}
   200  
   201  // Tasks returns the tasks whose state has changed since the last call to Tasks.
   202  func (s *TaskSubscriber) Tasks() []*Task {
   203  	s.Lock()
   204  	defer s.Unlock()
   205  	tasks := make([]*Task, 0, len(s.tasks))
   206  	for task := range s.tasks {
   207  		tasks = append(tasks, task)
   208  	}
   209  	s.tasks = make(map[*Task]struct{})
   210  	return tasks
   211  }
   212  
// A Task represents a concrete computational task. Tasks form graphs
// through dependencies; task graphs are compiled from slices.
//
// Tasks also maintain executor state, and are used to coordinate
// execution between concurrent evaluators and a single executor
// (which may be evaluating many tasks concurrently). Tasks thus
// embed a mutex for coordination and provide a context-aware
// conditional variable to coordinate runtime state changes.
type Task struct {
	slicetype.Type
	// Invocation is the task's invocation, i.e. the Func invocation
	// from which this task was compiled.
	Invocation execInvocation
	// Name is the name of the task. Tasks are named uniquely inside each
	// Bigslice session.
	Name TaskName
	// Do starts computation for this task, returning a reader that
	// computes batches of values on demand. Do is invoked with readers
	// for the task's dependencies.
	Do func([]sliceio.Reader) sliceio.Reader
	// Deps are the task's dependencies. See TaskDep for details.
	Deps []TaskDep

	// Partitioner is used to partition the task's output. It will only
	// be called when NumPartition > 1.
	Partitioner bigslice.Partitioner
	// NumPartition is the number of partitions that are output by this task.
	// If NumPartition > 1, then the task must also define a partitioner.
	NumPartition int

	// Combiner specifies an (optional) combiner to use for this task's output.
	// If a Combiner is not Nil, CombineKey names the combine buffer used:
	// each combine buffer contains combiner outputs from multiple tasks.
	// If CombineKey is not set, then per-task buffers are used instead.
	Combiner   slicefunc.Func
	CombineKey string

	// Pragma comprises the pragmas of all slice operations that
	// are pipelined into this task.
	bigslice.Pragma

	// Slices is the set of slices to which this task directly contributes.
	Slices []bigslice.Slice

	// Group stores an ordered list of peer tasks. If Group is nonempty,
	// it is guaranteed that these sets of tasks constitute a shuffle
	// dependency, and share a set of shuffle dependencies. This allows
	// the evaluator to perform optimizations while tracking such
	// dependencies.
	Group []*Task

	// Scope is the metrics scope for this task. It is populated with the
	// metrics produced during execution of this task.
	Scope metrics.Scope

	// subs is the set of subscribers to which this task will be sent whenever
	// its state changes.
	subs []*TaskSubscriber

	// The following are used to coordinate runtime execution.

	sync.Mutex
	// waitc, when non-nil, is the channel that Broadcast closes to wake
	// callers blocked in Wait. It is created lazily by Wait and reset to
	// nil by Broadcast.
	waitc chan struct{}

	// state is the task's state. It is protected by the task's lock
	// and state changes are also broadcast on the task's condition
	// variable.
	state TaskState
	// err is defined when state == TaskErr.
	err error

	// consecutiveLost is the number of times this task has been run and lost
	// consecutively. See maxConsecutiveLost.
	consecutiveLost int

	// Status is a status object to which task status is reported.
	Status *status.Task
}
   291  
   292  // Phase returns the phase to which this task belongs.
   293  func (t *Task) Phase() []*Task {
   294  	if len(t.Group) == 0 {
   295  		return []*Task{t}
   296  	}
   297  	return t.Group
   298  }
   299  
   300  // Head returns the head task of this task's phase. If the task does
   301  // not belong to a phase, Head returns the task t.
   302  func (t *Task) Head() *Task {
   303  	if len(t.Group) == 0 {
   304  		return t
   305  	}
   306  	return t.Group[0]
   307  }
   308  
   309  // String returns a short, human-readable string describing the
   310  // task's state.
   311  func (t *Task) String() string {
   312  	// We play fast-and-loose with concurrency here (we read state and
   313  	// err without holding the task's mutex) so that it is safe to call
   314  	// String even when the lock is held.
   315  	var b bytes.Buffer
   316  	fmt.Fprintf(&b, "task %s [%d] %s", t.Name, t.Invocation.Index, t.state)
   317  	if t.err != nil {
   318  		fmt.Fprintf(&b, ": %v", t.err)
   319  	}
   320  	return b.String()
   321  }
   322  
   323  // Set sets the task's state to the provided state and notifies
   324  // any waiters.
   325  func (t *Task) Set(state TaskState) {
   326  	t.Lock()
   327  	t.state = state
   328  	t.Broadcast()
   329  	t.Unlock()
   330  }
   331  
   332  // Error sets the task's state to TaskErr and its error to the
   333  // provided error. Waiters are notified.
   334  func (t *Task) Error(err error) {
   335  	t.Lock()
   336  	t.state = TaskErr
   337  	t.err = err
   338  	t.Status.Printf(err.Error())
   339  	t.Broadcast()
   340  	t.Unlock()
   341  }
   342  
   343  // Errorf formats an error message using fmt.Errorf, sets the task's
   344  // state to TaskErr and its err to the resulting error message.
   345  func (t *Task) Errorf(format string, v ...interface{}) {
   346  	t.Error(fmt.Errorf(format, v...))
   347  }
   348  
   349  // Err returns an error if the task's state is >= TaskErr. When the
   350  // state is > TaskErr, Err returns an error describing the task's
   351  // failed state, otherwise, t.err is returned.
   352  func (t *Task) Err() error {
   353  	t.Lock()
   354  	defer t.Unlock()
   355  	switch t.state {
   356  	case TaskErr:
   357  		if t.err == nil {
   358  			panic("TaskErr without an err")
   359  		}
   360  		return t.err
   361  	case TaskLost:
   362  		return ErrTaskLost
   363  	}
   364  	if t.state >= TaskErr {
   365  		panic("unhandled state")
   366  	}
   367  	return nil
   368  }
   369  
   370  // State returns the task's current state.
   371  func (t *Task) State() TaskState {
   372  	t.Lock()
   373  	state := t.state
   374  	t.Unlock()
   375  	return state
   376  }
   377  
   378  // Broadcast notifies waiters of a state change. Broadcast must only
   379  // be called while the task's lock is held.
   380  func (t *Task) Broadcast() {
   381  	if t.waitc != nil {
   382  		close(t.waitc)
   383  		t.waitc = nil
   384  	}
   385  	for _, sub := range t.subs {
   386  		sub.Notify(t)
   387  	}
   388  }
   389  
// Wait returns after the next call to Broadcast, or if the context
// is complete. The task's lock must be held when calling Wait; it is
// released while waiting and re-acquired before Wait returns.
//
// Wait returns nil when woken by Broadcast, and ctx.Err() when the context
// completed first.
func (t *Task) Wait(ctx context.Context) error {
	// Lazily create the channel that the next Broadcast will close.
	if t.waitc == nil {
		t.waitc = make(chan struct{})
	}
	// Capture the channel while still holding the lock: Broadcast resets
	// t.waitc to nil once it has closed it.
	waitc := t.waitc
	t.Unlock()
	var err error
	select {
	case <-waitc:
	case <-ctx.Done():
		err = ctx.Err()
	}
	// Re-acquire the lock so that the caller again holds it on return.
	t.Lock()
	return err
}
   407  
   408  // WaitState returns when the task's state is at least the provided state,
   409  // or else when the context is done.
   410  func (t *Task) WaitState(ctx context.Context, state TaskState) (TaskState, error) {
   411  	t.Lock()
   412  	defer t.Unlock()
   413  	var err error
   414  	for t.state < state && err == nil {
   415  		err = t.Wait(ctx)
   416  	}
   417  	return t.state, err
   418  }
   419  
   420  // Subscribe subscribes s to be notified of any changes to t's state. If s has
   421  // already been subscribed, no-op.
   422  func (t *Task) Subscribe(s *TaskSubscriber) {
   423  	t.Lock()
   424  	defer t.Unlock()
   425  	for _, sub := range t.subs {
   426  		if s == sub {
   427  			// It is already registered.
   428  			return
   429  		}
   430  	}
   431  	t.subs = append(t.subs, s)
   432  }
   433  
   434  // Unsubscribe unsubscribes previously subscribe s. s will on longer receive
   435  // task state change notifications. No-op if s was never subscribed.
   436  func (t *Task) Unsubscribe(s *TaskSubscriber) {
   437  	t.Lock()
   438  	defer t.Unlock()
   439  	subs := t.subs[:0]
   440  	for _, sub := range t.subs {
   441  		if s == sub {
   442  			continue
   443  		}
   444  		subs = append(subs, sub)
   445  	}
   446  	t.subs = subs
   447  }
   448  
   449  // GraphString returns a schematic string of the task graph rooted at t.
   450  func (t *Task) GraphString() string {
   451  	var b bytes.Buffer
   452  	t.WriteGraph(&b)
   453  	return b.String()
   454  }
   455  
   456  // WriteGraph writes a schematic string of the task graph rooted at t into w.
   457  func (t *Task) WriteGraph(w io.Writer) {
   458  	var tw tabwriter.Writer
   459  	tw.Init(w, 4, 4, 1, ' ', 0)
   460  	fmt.Fprintln(&tw, "tasks:")
   461  	for _, task := range t.All() {
   462  		out := make([]string, task.NumOut())
   463  		for i := range out {
   464  			out[i] = fmt.Sprint(task.Out(i))
   465  		}
   466  		outstr := strings.Join(out, ",")
   467  		fmt.Fprintf(&tw, "\t%s\t%s\t%d [%s]\n", task.Name, outstr, task.NumPartition, task.State())
   468  	}
   469  	tw.Flush()
   470  	fmt.Fprintln(&tw, "dependencies:")
   471  	t.writeDeps(&tw)
   472  	tw.Flush()
   473  }
   474  
   475  func (t *Task) writeDeps(w io.Writer) {
   476  	for _, dep := range t.Deps {
   477  		for i := 0; i < dep.NumTask(); i++ {
   478  			task := dep.Task(i)
   479  			fmt.Fprintf(w, "\t%s:\t%s[%d]\n", t.Name, task.Name, dep.Partition)
   480  			task.writeDeps(w)
   481  		}
   482  	}
   483  }
   484  
   485  // All returns all tasks reachable from t. The returned
   486  // set of tasks is unique.
   487  func (t *Task) All() []*Task {
   488  	all := make(map[*Task]bool)
   489  	t.all(all)
   490  	var tasks []*Task
   491  	for task := range all {
   492  		tasks = append(tasks, task)
   493  	}
   494  	sort.Slice(tasks, func(i, j int) bool {
   495  		return tasks[i].Name.String() < tasks[j].Name.String()
   496  	})
   497  	return tasks
   498  }
   499  
   500  func (t *Task) all(tasks map[*Task]bool) {
   501  	if tasks[t] {
   502  		return
   503  	}
   504  	tasks[t] = true
   505  	for _, dep := range t.Deps {
   506  		for i := 0; i < dep.NumTask(); i++ {
   507  			dep.Task(i).all(tasks)
   508  		}
   509  	}
   510  }