github.com/grailbio/bigslice@v0.0.0-20230519005545-30c4c12152ad/exec/local.go

// Copyright 2018 GRAIL, Inc. All rights reserved.
// Use of this source code is governed by the Apache 2.0
// license that can be found in the LICENSE file.

package exec

import (
	"context"
	"fmt"
	"net/http"
	"runtime/debug"
	"sync"

	"github.com/grailbio/base/backgroundcontext"
	"github.com/grailbio/base/errors"
	"github.com/grailbio/base/eventlog"
	"github.com/grailbio/base/limiter"
	"github.com/grailbio/base/log"
	"github.com/grailbio/bigslice/frame"
	"github.com/grailbio/bigslice/metrics"
	"github.com/grailbio/bigslice/sliceio"
)
// localExecutor is an executor that runs tasks in-process in
// separate goroutines. All output is buffered in memory.
type localExecutor struct {
	mu      sync.Mutex
	state   map[*Task]TaskState
	buffers map[*Task]taskBuffer
	limiter *limiter.Limiter
	sess    *Session
}

func newLocalExecutor() *localExecutor {
	return &localExecutor{
		state:   make(map[*Task]TaskState),
		buffers: make(map[*Task]taskBuffer),
		limiter: limiter.New(),
	}
}

func (*localExecutor) Name() string {
	return "local"
}

func (l *localExecutor) Start(sess *Session) (shutdown func()) {
	l.sess = sess
	l.limiter.Release(sess.p)
	return
}
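
// Start seeds the limiter with sess.p tokens, so at most sess.p tasks run at
// once; Run acquires one token per ordinary task and all sess.p tokens for a
// task whose pragma is Exclusive. A minimal sketch of this admission scheme
// (illustrative only; ctx and the token counts here are assumptions, not part
// of this package):
//
//	lim := limiter.New()
//	lim.Release(4)          // seed with session parallelism p = 4
//	_ = lim.Acquire(ctx, 1) // ordinary task: one token
//	lim.Release(1)          // returned when the task finishes
//	_ = lim.Acquire(ctx, 4) // exclusive task: waits until it holds every token
//	lim.Release(4)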

// Run executes task in this process: it acquires admission from the session
// limiter, opens readers over the task's dependencies, runs the task, and
// buffers its output in memory, updating the task's state as it progresses.
func (l *localExecutor) Run(task *Task) {
	ctx := backgroundcontext.Get()
	n := 1
	if task.Pragma.Exclusive() {
		n = l.sess.p
	}
	if err := l.limiter.Acquire(ctx, n); err != nil {
		// The only errors we should encounter here are context errors,
		// in which case there is no more work to do.
		if err != context.Canceled && err != context.DeadlineExceeded {
			log.Panicf("exec.Local: unexpected error: %v", err)
		}
		return
	}
	defer l.limiter.Release(n)
	in, err := l.depReaders(ctx, task)
	if err != nil {
		if errors.Match(fatalErr, err) {
			task.Error(err)
		} else {
			task.Set(TaskLost)
		}
		return
	}
	task.Set(TaskRunning)

	// Start execution, then place output in a task buffer. We also plumb a
	// metrics scope in here so we can store and aggregate metrics.
	task.Scope.Reset(nil)
	out := task.Do(in)
	buf, err := bufferOutput(metrics.ScopedContext(ctx, &task.Scope), task, out)
	task.Lock()
	if err == nil {
		l.mu.Lock()
		l.buffers[task] = buf
		l.mu.Unlock()
		task.state = TaskOk
	} else {
		if errors.Match(fatalErr, err) {
			task.state = TaskErr
		} else {
			task.state = TaskLost
		}
		task.err = err
	}
	task.Broadcast()
	task.Unlock()
}
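
// Run classifies failures by severity: errors matching fatalErr (those tagged
// errors.Fatal) become TaskErr and are surfaced to the user, while all other
// errors mark the task TaskLost so the evaluator may retry it. A small sketch
// of that classification, using the errors package the same way this file
// does (the message text is illustrative):
//
//	err := errors.E(errors.Fatal, "cannot construct combiner")
//	if errors.Match(fatalErr, err) {
//		// Permanent failure: record the error and stop evaluation.
//	} else {
//		// Transient failure: mark the task lost and allow recomputation.
//	}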

// depReaders returns a reader for each of task's dependencies, performing
// in-line combination for dependencies that define a combiner.
func (l *localExecutor) depReaders(ctx context.Context, task *Task) ([]sliceio.Reader, error) {
	in := make([]sliceio.Reader, 0, len(task.Deps))
	for _, dep := range task.Deps {
		reader := new(multiReader)
		reader.q = make([]sliceio.Reader, dep.NumTask())
		for j := 0; j < dep.NumTask(); j++ {
			reader.q[j] = l.Reader(dep.Task(j), dep.Partition)
		}
		if dep.NumTask() > 0 && !dep.Task(0).Combiner.IsNil() {
			// Perform input combination in-line, one for each partition.
			combineKey := task.Name
			if task.CombineKey != "" {
				combineKey = TaskName{Op: task.CombineKey}
			}
			combiner, err := newCombiner(dep.Task(0), combineKey.String(), dep.Task(0).Combiner, *defaultChunksize*100)
			if err != nil {
				return nil, errors.E(errors.Fatal, "could not make combiner for %v", dep.Task(0).String(), err)
			}
			buf := frame.Make(dep.Task(0), *defaultChunksize, *defaultChunksize)
			for {
				var n int
				n, err = reader.Read(ctx, buf)
				if err != nil && err != sliceio.EOF {
					return nil, errors.E("error reading %v", dep.Task(0).String(), err)
				}
				if combineErr := combiner.Combine(ctx, buf.Slice(0, n)); combineErr != nil {
					return nil, errors.E(errors.Fatal, "failed to combine %v", dep.Task(0).String(), combineErr)
				}
				if err == sliceio.EOF {
					break
				}
			}
			reader, err := combiner.Reader()
			if err != nil {
				return nil, errors.E(errors.Fatal, "failed to start reading combiner for %v", dep.Task(0).String(), err)
			}
			in = append(in, reader)
		} else if dep.Expand {
			in = append(in, reader.q...)
		} else {
			in = append(in, reader)
		}
	}
	return in, nil
}
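
// The dependency wiring above takes one of three shapes: by default a
// dependency's per-task readers are concatenated into a single stream by
// multiReader; an Expand dependency instead contributes one reader per task;
// and a dependency with a combiner is drained eagerly, its combiner's reader
// replacing the raw stream. The default shape, shown with illustrative names
// only:
//
//	readers := make([]sliceio.Reader, dep.NumTask())
//	for j := range readers {
//		readers[j] = l.Reader(dep.Task(j), dep.Partition)
//	}
//	in = append(in, &multiReader{q: readers})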

// Reader returns a reader over the given partition of task's buffered output.
// If the executor holds no buffer for task, the returned reader reports an
// error on read.
func (l *localExecutor) Reader(task *Task, partition int) sliceio.ReadCloser {
	l.mu.Lock()
	buf, ok := l.buffers[task]
	l.mu.Unlock()
	if !ok {
		return sliceio.ReaderWithCloseFunc{
			Reader:    sliceio.ErrReader(fmt.Errorf("no data for %v", task)),
			CloseFunc: func() error { return nil },
		}
	}
	return buf.Reader(partition)
}
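
// Buffered output is consumed one partition at a time through Reader. A
// hedged usage sketch (task, ctx, and consume are placeholders, not part of
// this file):
//
//	r := l.Reader(task, 0) // partition 0 of task's buffered output
//	defer r.Close()
//	f := frame.Make(task, *defaultChunksize, *defaultChunksize)
//	for {
//		n, err := r.Read(ctx, f)
//		if err != nil && err != sliceio.EOF {
//			return err // e.g. the "no data" error when nothing is buffered
//		}
//		consume(f.Slice(0, n))
//		if err == sliceio.EOF {
//			break
//		}
//	}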

// Discard drops task's buffered output and marks the task TaskLost so that it
// can be recomputed if its result is needed again. Tasks that write to a
// shared combiner are not yet handled.
func (l *localExecutor) Discard(_ context.Context, task *Task) {
	if !task.Combiner.IsNil() && task.CombineKey != "" {
		// We do not yet handle tasks with shared combiners.
		return
	}
	task.Lock()
	if task.state == TaskOk {
		l.mu.Lock()
		delete(l.buffers, task)
		l.mu.Unlock()
		task.state = TaskLost
		task.Broadcast()
		task.Unlock()
		return
	}
	task.Unlock()
}

func (l *localExecutor) Eventer() eventlog.Eventer {
	return l.sess.eventer
}

func (*localExecutor) HandleDebug(*http.ServeMux) {}

// bufferOutput reads the output from out and places it in a
// task buffer. If the output is partitioned, bufferOutput invokes
// the task's partitioner in order to determine the correct partition.
func bufferOutput(ctx context.Context, task *Task, out sliceio.Reader) (buf taskBuffer, err error) {
	if task.NumOut() == 0 {
		_, err = out.Read(ctx, frame.Empty)
		if err == sliceio.EOF {
			err = nil
		}
		return nil, err
	}
	buf = make(taskBuffer, task.NumPartition)
	var in frame.Frame
	defer func() {
		if e := recover(); e != nil {
			stack := debug.Stack()
			err = fmt.Errorf("panic while evaluating slice: %v\n%s", e, string(stack))
			err = errors.E(err, errors.Fatal)
		}
	}()
	shards := make([]int, *defaultChunksize)
	for {
		if in.IsZero() {
			in = frame.Make(task, *defaultChunksize, *defaultChunksize)
		}
		n, err := out.Read(ctx, in)
		if err != nil && err != sliceio.EOF {
			return nil, err
		}
		// If the output needs to be partitioned, we ask the partitioner to
		// assign partitions to each input element, and then append the
		// elements in their respective partitions. In this case, we just
		// maintain buffer slices of defaultChunksize each.
		if task.NumPartition > 1 {
			task.Partitioner(ctx, in, task.NumPartition, shards[:n])
			for i := 0; i < n; i++ {
				p := shards[i]
				// If we don't yet have a buffer or the current one is at capacity,
				// create a new one.
				m := len(buf[p])
				if m == 0 || buf[p][m-1].Cap() == buf[p][m-1].Len() {
					frame := frame.Make(task, 0, *defaultChunksize)
					buf[p] = append(buf[p], frame)
					m++
				}
				buf[p][m-1] = frame.AppendFrame(buf[p][m-1], in.Slice(i, i+1))
			}
		} else if n > 0 {
			in = in.Slice(0, n)
			buf[0] = append(buf[0], in)
			in = frame.Frame{}
		}
		if err == sliceio.EOF {
			break
		}
	}
	return buf, nil
}
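
// bufferOutput keeps, per partition, a list of frames holding at most
// *defaultChunksize rows each, starting a fresh frame whenever the last one
// is full. The same chunked-append idiom with plain slices (a sketch, not the
// frame API; chunkSize is an illustrative constant):
//
//	var chunks [][]int
//	appendRow := func(v int) {
//		m := len(chunks)
//		if m == 0 || len(chunks[m-1]) == cap(chunks[m-1]) {
//			chunks = append(chunks, make([]int, 0, chunkSize)) // chunkSize: assumed constant
//			m++
//		}
//		chunks[m-1] = append(chunks[m-1], v)
//	}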

// multiReader reads from its readers in sequence, reading each to EOF before
// moving on to the next.
type multiReader struct {
	q   []sliceio.Reader
	err error
}

func (m *multiReader) Read(ctx context.Context, out frame.Frame) (n int, err error) {
	if m.err != nil {
		return 0, m.err
	}
	for len(m.q) > 0 {
		n, err := m.q[0].Read(ctx, out)
		switch {
		case err == sliceio.EOF:
			m.q = m.q[1:]
		case err != nil:
			m.err = err
			return n, err
		case n > 0:
			return n, err
		}
	}
	return 0, sliceio.EOF
}
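
// multiReader reports sliceio.EOF only once every reader in q is exhausted;
// any other error is sticky and is returned on all subsequent calls. A hedged
// drain sketch (readers, typ, chunk, ctx, and process are placeholders):
//
//	m := &multiReader{q: readers}
//	f := frame.Make(typ, chunk, chunk)
//	for {
//		n, err := m.Read(ctx, f)
//		if err != nil && err != sliceio.EOF {
//			return err
//		}
//		process(f.Slice(0, n))
//		if err == sliceio.EOF {
//			break
//		}
//	}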