github.com/grailbio/bigslice@v0.0.0-20230519005545-30c4c12152ad/exec/eval_test.go (about)

     1  // Copyright 2018 GRAIL, Inc. All rights reserved.
     2  // Use of this source code is governed by the Apache 2.0
     3  // license that can be found in the LICENSE file.
     4  
     5  package exec
     6  
     7  import (
     8  	"context"
     9  	goerrors "errors"
    10  	"flag"
    11  	"fmt"
    12  	"net/http"
    13  	"strings"
    14  	"sync"
    15  	"testing"
    16  	"time"
    17  
    18  	"github.com/grailbio/base/errors"
    19  	"github.com/grailbio/base/eventlog"
    20  	"github.com/grailbio/bigslice"
    21  	"github.com/grailbio/bigslice/sliceio"
    22  	"golang.org/x/sync/errgroup"
    23  )
    24  
    25  type testExecutor struct{}
    26  
    27  func (testExecutor) Name() string {
    28  	return "test"
    29  }
    30  
    31  func (testExecutor) Start(*Session) (shutdown func()) {
    32  	return func() {}
    33  }
    34  
    35  func (t testExecutor) Run(task *Task) {
    36  	task.Lock()
    37  	task.state = TaskRunning
    38  	task.Broadcast()
    39  	task.Unlock()
    40  }
    41  
    42  func (testExecutor) Reader(*Task, int) sliceio.ReadCloser {
    43  	panic("not implemented")
    44  }
    45  
    46  func (testExecutor) Discard(context.Context, *Task) {}
    47  
    48  func (testExecutor) Eventer() eventlog.Eventer {
    49  	return eventlog.Nop{}
    50  }
    51  
    52  func (testExecutor) HandleDebug(handler *http.ServeMux) {
    53  	panic("not implemented")
    54  }
    55  
    56  // constEvalTest sets up a 2-root-node task graph.
    57  type constEvalTest struct {
    58  	Tasks []*Task
    59  
    60  	wg      sync.WaitGroup
    61  	evalErr error
    62  }
    63  
    64  func (c *constEvalTest) Go(t *testing.T) {
    65  	t.Helper()
    66  	c.Tasks, _, _ = compileFunc(func() bigslice.Slice {
    67  		return bigslice.Const(2, []int{1, 2, 3})
    68  	})
    69  	ctx := context.Background()
    70  	c.wg.Add(1)
    71  	go func() {
    72  		c.evalErr = Eval(ctx, testExecutor{}, c.Tasks, nil)
    73  		c.wg.Done()
    74  	}()
    75  }
    76  
    77  func (c *constEvalTest) EvalErr() error {
    78  	c.wg.Wait()
    79  	return c.evalErr
    80  }
    81  
    82  // SimpleEvalTest sets up a simple, 2-node task graph.
    83  type simpleEvalTest struct {
    84  	Tasks []*Task
    85  
    86  	ConstTask, CogroupTask *Task
    87  
    88  	wg      sync.WaitGroup
    89  	evalErr error
    90  }
    91  
    92  func (s *simpleEvalTest) Go(t *testing.T) {
    93  	t.Helper()
    94  	s.Tasks, _, _ = compileFunc(func() bigslice.Slice {
    95  		slice := bigslice.Const(1, []int{1, 2, 3})
    96  		slice = bigslice.Cogroup(slice)
    97  		return slice
    98  	})
    99  	s.ConstTask = s.Tasks[0].Deps[0].Task(0)
   100  	s.CogroupTask = s.Tasks[0]
   101  	ctx := context.Background()
   102  	s.wg.Add(1)
   103  	go func() {
   104  		s.evalErr = Eval(ctx, testExecutor{}, s.Tasks, nil)
   105  		s.wg.Done()
   106  	}()
   107  }
   108  
   109  func (s *simpleEvalTest) EvalErr() error {
   110  	s.wg.Wait()
   111  	return s.evalErr
   112  }
   113  
   114  func waitState(t *testing.T, task *Task, state TaskState) {
   115  	t.Helper()
   116  	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
   117  	defer cancel()
   118  	task.Lock()
   119  	defer task.Unlock()
   120  	for task.state != state {
   121  		if err := task.Wait(ctx); err != nil {
   122  			t.Fatalf("task %v (state %v) did not reach desired state %v", task.Name, task.state, state)
   123  		}
   124  	}
   125  }
   126  
   127  // TestTaskErr verifies that a task evaluation error (TaskErr) causes Eval to
   128  // return a corresponding error.
   129  func TestTaskErr(t *testing.T) {
   130  	var (
   131  		test simpleEvalTest
   132  		ctx  = context.Background()
   133  	)
   134  	test.Go(t)
   135  	state, err := test.ConstTask.WaitState(ctx, TaskRunning)
   136  	if err != nil {
   137  		t.Fatal(err)
   138  	}
   139  	if got, want := state, TaskRunning; got != want {
   140  		t.Fatalf("got %v, want %v", got, want)
   141  	}
   142  	if got, want := test.CogroupTask.State(), TaskInit; got != want {
   143  		t.Fatalf("got %v, want %v: %v", got, want, test.CogroupTask)
   144  	}
   145  	test.ConstTask.Error(goerrors.New("const task error"))
   146  
   147  	err = test.EvalErr()
   148  	if err == nil {
   149  		t.Fatal("expected error")
   150  	}
   151  	if got, want := strings.Contains(err.Error(), "const task error"), true; got != want {
   152  		t.Errorf("got %v, want %v", got, want)
   153  	}
   154  	if got, want := test.CogroupTask.State(), TaskInit; got != want {
   155  		t.Fatalf("got %v, want %v", got, want)
   156  	}
   157  }
   158  
   159  // TestAllRootsEvaluated verifies that all roots are evaluated at the moment
   160  // Eval returns.
   161  func TestAllRootsEvaluated(t *testing.T) {
   162  	var (
   163  		test constEvalTest
   164  		ctx  = context.Background()
   165  	)
   166  	test.Go(t)
   167  	// We have two root tasks, task0 and task1. task0 is evaluated
   168  	// successfully. While task1 runs, task0 is lost. Verify that Eval only
   169  	// returns once task0 is re-evaluated successfully.
   170  	var (
   171  		task0 = test.Tasks[0]
   172  		task1 = test.Tasks[1]
   173  	)
   174  	// task0 is evaluated successfully.
   175  	task0.Lock()
   176  	for task0.state != TaskRunning {
   177  		if err := task0.Wait(ctx); err != nil {
   178  			t.Fatal(err)
   179  		}
   180  	}
   181  	task0.state = TaskOk
   182  	task0.Broadcast()
   183  	task0.Unlock()
   184  	// While task1 runs, task0 is lost.
   185  	task1.Lock()
   186  	for task1.state != TaskRunning {
   187  		if err := task1.Wait(ctx); err != nil {
   188  			t.Fatal(err)
   189  		}
   190  	}
   191  	task1.Unlock()
   192  	// Allow time for evaluation to notice task0's TaskOk state before marking
   193  	// it lost.
   194  	// TODO: Though this seems to work reliably in my environment, consider a
   195  	// non-racy way of doing this. Note that this shouldn't ever cause the test
   196  	// to falsely fail. It just means that this will test the Running -> Lost
   197  	// path instead of the Running -> Ok -> Lost path, as the evaluator might
   198  	// not see the transient Ok state.
   199  	time.Sleep(1 * time.Millisecond)
   200  	task0.Lock()
   201  	task0.state = TaskLost
   202  	task0.Broadcast()
   203  	task0.Unlock()
   204  	// task1 is successfully evaluated.
   205  	task1.Lock()
   206  	task1.state = TaskOk
   207  	task1.Broadcast()
   208  	task1.Unlock()
   209  	task0.Lock()
   210  	// Expect task0 to be resubmitted. Eval should not return until all roots
   211  	// are successfully evaluated.
   212  	for task0.state != TaskRunning {
   213  		if err := task0.Wait(ctx); err != nil {
   214  			t.Fatal(err)
   215  		}
   216  	}
   217  	task0.state = TaskOk
   218  	task0.Broadcast()
   219  	task0.Unlock()
   220  	if err := test.EvalErr(); err != nil {
   221  		t.Fatal(err)
   222  	}
   223  }
   224  
   225  func TestResubmitLostTask(t *testing.T) {
   226  	var (
   227  		test simpleEvalTest
   228  		ctx  = context.Background()
   229  	)
   230  	test.Go(t)
   231  	var (
   232  		fst = test.ConstTask
   233  		snd = test.CogroupTask
   234  	)
   235  	fst.Lock()
   236  	for fst.state != TaskRunning {
   237  		if err := fst.Wait(ctx); err != nil {
   238  			t.Fatal(err)
   239  		}
   240  	}
   241  	fst.state = TaskLost
   242  	fst.Broadcast()
   243  	for fst.state == TaskLost {
   244  		if err := fst.Wait(ctx); err != nil {
   245  			t.Fatal(err)
   246  		}
   247  	}
   248  	// The evaluator should have resubmitted it.
   249  	if got, want := fst.state, TaskRunning; got != want {
   250  		t.Errorf("got %v, want %v", got, want)
   251  	}
   252  
   253  	// Now we lose both of them while the second is running.
   254  	// The evaluator should resubmit both.
   255  	fst.state = TaskOk
   256  	fst.Broadcast()
   257  	fst.Unlock()
   258  
   259  	snd.Lock()
   260  	for snd.state != TaskRunning {
   261  		if err := snd.Wait(ctx); err != nil {
   262  			t.Fatal(err)
   263  		}
   264  	}
   265  	fst.Lock()
   266  	snd.state = TaskLost
   267  	snd.Broadcast()
   268  	snd.Unlock()
   269  	fst.state = TaskLost
   270  	fst.Broadcast()
   271  
   272  	for fst.state < TaskRunning {
   273  		if err := fst.Wait(ctx); err != nil {
   274  			t.Fatal(err)
   275  		}
   276  	}
   277  	if got, want := snd.State(), TaskLost; got != want {
   278  		t.Errorf("got %v, want %v", got, want)
   279  	}
   280  	fst.state = TaskOk
   281  	fst.Broadcast()
   282  	fst.Unlock()
   283  
   284  	snd.Lock()
   285  	for snd.state < TaskRunning {
   286  		if err := snd.Wait(ctx); err != nil {
   287  			t.Fatal(err)
   288  		}
   289  	}
   290  	snd.state = TaskOk
   291  	snd.Broadcast()
   292  	snd.Unlock()
   293  
   294  	if err := test.EvalErr(); err != nil {
   295  		t.Fatal(err)
   296  	}
   297  }
   298  
   299  func TestResubmitLostInteriorTask(t *testing.T) {
   300  	for _, parallel := range []int{1, 10} {
   301  		parallel := parallel
   302  		t.Run(fmt.Sprintf("parallel=%v", parallel), func(t *testing.T) {
   303  			ctx, cancel := context.WithCancel(context.Background())
   304  			defer cancel()
   305  			tasks, _, _ := compileFunc(func() (slice bigslice.Slice) {
   306  				slice = bigslice.Const(2, []int{1, 2, 3, 4, 5, 6, 7, 8, 9, 10})
   307  				slice = bigslice.Cogroup(slice)
   308  				return
   309  			})
   310  
   311  			var g errgroup.Group
   312  			for i := 0; i < parallel; i++ {
   313  				g.Go(func() error { return Eval(ctx, testExecutor{}, tasks, nil) })
   314  			}
   315  
   316  			var (
   317  				const0   = tasks[0].Deps[0].Task(0)
   318  				const1   = tasks[0].Deps[0].Task(1)
   319  				cogroup0 = tasks[0]
   320  				cogroup1 = tasks[1]
   321  			)
   322  			waitState(t, const0, TaskRunning)
   323  			const0.Set(TaskOk)
   324  			waitState(t, const1, TaskRunning)
   325  			const1.Set(TaskOk)
   326  
   327  			waitState(t, cogroup0, TaskRunning)
   328  			waitState(t, cogroup1, TaskRunning)
   329  			const0.Set(TaskLost)
   330  			cogroup0.Set(TaskLost)
   331  			cogroup1.Set(TaskLost)
   332  
   333  			// Now, the evaluator must first recompute const0.
   334  			waitState(t, const0, TaskRunning)
   335  			// ... and then each of the cogroup tasks
   336  			const0.Set(TaskOk)
   337  			waitState(t, cogroup0, TaskRunning)
   338  			waitState(t, cogroup1, TaskRunning)
   339  			cogroup0.Set(TaskOk)
   340  			cogroup1.Set(TaskOk)
   341  
   342  			if err := g.Wait(); err != nil {
   343  				t.Fatal(err)
   344  			}
   345  		})
   346  	}
   347  }
   348  
   349  // TestPersistentTaskLoss verifies that the evaluator will abandon evaluation
   350  // with a task that is repeatedly lost on attempts to run it, as it is unable to
   351  // make meaningful progress.
   352  func TestPersistentTaskLoss(t *testing.T) {
   353  	var (
   354  		test        simpleEvalTest
   355  		ctx, cancel = context.WithTimeout(context.Background(), 10*time.Second)
   356  	)
   357  	defer cancel()
   358  	test.Go(t)
   359  	fst := test.ConstTask
   360  	for {
   361  		if err := ctx.Err(); err != nil {
   362  			t.Fatal(err)
   363  		}
   364  		fst.Lock()
   365  		for fst.state != TaskRunning {
   366  			if err := fst.Wait(ctx); err != nil {
   367  				t.Fatal(err)
   368  			}
   369  		}
   370  		fst.state = TaskLost
   371  		fst.Broadcast()
   372  		for fst.state == TaskLost {
   373  			if err := fst.Wait(ctx); err != nil {
   374  				t.Fatal(err)
   375  			}
   376  		}
   377  		isErr := fst.state == TaskErr
   378  		fst.Unlock()
   379  		if isErr {
   380  			// The evaluator has given up on the task.
   381  			break
   382  		}
   383  	}
   384  	err := test.EvalErr()
   385  	if !errors.Is(errors.TooManyTries, err) {
   386  		t.Errorf("expected TooManyTries error, got: %v", err)
   387  	}
   388  }
   389  
   390  func multiPhaseCompile(nshard, nstage int) []*Task {
   391  	tasks, _, _ := compileFunc(func() bigslice.Slice {
   392  		keys := make([]string, nshard*2)
   393  		for i := range keys {
   394  			keys[i] = fmt.Sprint(i)
   395  		}
   396  		values := make([]int, nshard*2)
   397  		for i := range values {
   398  			values[i] = i
   399  		}
   400  
   401  		slice := bigslice.Const(nshard, keys, values)
   402  		for stage := 0; stage < nstage; stage++ {
   403  			slice = bigslice.Reduce(slice, func(i, j int) int { return i + j })
   404  		}
   405  		return slice
   406  	})
   407  	return tasks
   408  }
   409  
   410  func TestMultiPhaseEval(t *testing.T) {
   411  	const (
   412  		S = 1000
   413  		P = 10
   414  	)
   415  	tasks := multiPhaseCompile(S, P)
   416  	if got, want := len(tasks), S; got != want {
   417  		t.Fatalf("got %v, want %v", got, want)
   418  	}
   419  	var phases [][]*Task
   420  	for task := tasks[0].Deps[0].Task(0); ; {
   421  		phases = append(phases, task.Group)
   422  		if len(task.Deps) == 0 {
   423  			break
   424  		}
   425  		task = task.Deps[0].Task(0)
   426  	}
   427  	if got, want := len(phases), P; got != want {
   428  		t.Fatalf("got %v, want %v", got, want)
   429  	}
   430  	for _, group := range phases {
   431  		if got, want := len(group), S; got != want {
   432  			t.Errorf("got %v, want %v", got, want)
   433  		}
   434  	}
   435  
   436  	eval := func() (wait func()) {
   437  		var g errgroup.Group
   438  		g.Go(func() error {
   439  			t.Helper()
   440  			return Eval(context.Background(), testExecutor{}, tasks, nil)
   441  		})
   442  		return func() {
   443  			t.Helper()
   444  			if err := g.Wait(); err != nil {
   445  				t.Fatal(err)
   446  			}
   447  		}
   448  	}
   449  
   450  	wait := eval()
   451  
   452  	for i := len(phases) - 1; i >= 0; i-- {
   453  		group := phases[i]
   454  		for _, task := range group {
   455  			waitState(t, task, TaskRunning)
   456  		}
   457  		// Make sure no other tasks are waiting or running.
   458  		for j := i - 1; j >= 0; j-- {
   459  			otherGroup := phases[j]
   460  			for _, task := range otherGroup {
   461  				if task.State() != TaskInit {
   462  					t.Fatal(task, ": wrong state")
   463  				}
   464  			}
   465  		}
   466  		for _, task := range group {
   467  			task.Set(TaskOk)
   468  		}
   469  	}
   470  
   471  	for _, task := range tasks {
   472  		waitState(t, task, TaskRunning)
   473  		task.Set(TaskOk)
   474  	}
   475  	wait()
   476  
   477  	mustState := func(task *Task, state TaskState) {
   478  		t.Helper()
   479  		if got, want := task.State(), state; got != want {
   480  			t.Fatalf("%v: got %v, want %v", task, got, want)
   481  		}
   482  	}
   483  
   484  	mustStates := func(def TaskState, states map[*Task]TaskState) {
   485  		t.Helper()
   486  		for _, group := range phases {
   487  			for _, task := range group {
   488  				state, ok := states[task]
   489  				if !ok {
   490  					state = def
   491  				}
   492  				mustState(task, state)
   493  			}
   494  		}
   495  		for _, task := range tasks {
   496  			state, ok := states[task]
   497  			if !ok {
   498  				state = def
   499  			}
   500  			mustState(task, state)
   501  		}
   502  	}
   503  
   504  	// An exterior task failure means a single resubmit.
   505  	tasks[S/2].Set(TaskLost)
   506  	wait = eval()
   507  
   508  	waitState(t, tasks[S/2], TaskRunning)
   509  	mustStates(TaskOk, map[*Task]TaskState{
   510  		tasks[S/2]: TaskRunning,
   511  	})
   512  	tasks[S/2].Set(TaskOk)
   513  	wait()
   514  
   515  	// A reachable path of interior task failures get resubmitted.
   516  	lost := []*Task{
   517  		tasks[S/2],
   518  		phases[0][S/2],
   519  		phases[1][S/2],
   520  	}
   521  	unreachable := phases[3][S/2]
   522  	for _, task := range lost {
   523  		task.Set(TaskLost)
   524  	}
   525  	unreachable.Set(TaskLost)
   526  	wait = eval()
   527  	waitState(t, lost[len(lost)-1], TaskRunning)
   528  	mustStates(TaskOk, map[*Task]TaskState{
   529  		unreachable: TaskLost,
   530  		lost[0]:     TaskLost,
   531  		lost[1]:     TaskLost,
   532  		lost[2]:     TaskRunning,
   533  	})
   534  	lost[2].Set(TaskOk)
   535  	waitState(t, lost[1], TaskRunning)
   536  	mustStates(TaskOk, map[*Task]TaskState{
   537  		unreachable: TaskLost,
   538  		lost[0]:     TaskLost,
   539  		lost[1]:     TaskRunning,
   540  	})
   541  	lost[1].Set(TaskOk)
   542  	waitState(t, lost[0], TaskRunning)
   543  	mustStates(TaskOk, map[*Task]TaskState{
   544  		unreachable: TaskLost,
   545  		lost[0]:     TaskRunning,
   546  	})
   547  	lost[0].Set(TaskOk)
   548  	mustStates(TaskOk, map[*Task]TaskState{
   549  		unreachable: TaskLost,
   550  	})
   551  	wait()
   552  }
   553  
   554  type benchExecutor struct{ *testing.B }
   555  
   556  func (benchExecutor) Start(*Session) (shutdown func()) {
   557  	return func() {}
   558  }
   559  
   560  func (b benchExecutor) Run(task *Task) {
   561  	task.Lock()
   562  	task.state = TaskOk
   563  	task.Broadcast()
   564  	task.Unlock()
   565  }
   566  
   567  func (benchExecutor) Reader(*Task, int) sliceio.ReadCloser {
   568  	panic("not implemented")
   569  }
   570  
   571  func (benchExecutor) Discard(context.Context, *Task) {}
   572  
   573  func (benchExecutor) Eventer() eventlog.Eventer {
   574  	return eventlog.Nop{}
   575  }
   576  
   577  func (benchExecutor) HandleDebug(handler *http.ServeMux) {
   578  	panic("not implemented")
   579  }
   580  
   581  var evalStages = flag.Int("eval.bench.stages", 5, "number of stages for eval benchmark")
   582  
   583  func BenchmarkEval(b *testing.B) {
   584  	for _, nshard := range []int{10, 100, 1000, 5000 /*, 100000*/} {
   585  		b.Run(fmt.Sprintf("eval.%d", nshard), func(b *testing.B) {
   586  			ctx := context.Background()
   587  			for i := 0; i < b.N; i++ {
   588  				b.StopTimer()
   589  				tasks := multiPhaseCompile(nshard, *evalStages)
   590  				if i == 0 {
   591  					b.Log("ntask=", len(tasks))
   592  				}
   593  				b.StartTimer()
   594  				if err := Eval(ctx, benchExecutor{b}, tasks, nil); err != nil {
   595  					b.Fatal(err)
   596  				}
   597  			}
   598  		})
   599  	}
   600  }
   601  
   602  func BenchmarkEnqueue(b *testing.B) {
   603  	for _, nshard := range []int{10, 100, 1000, 5000 /*, 100000*/} {
   604  		b.Run(fmt.Sprintf("enqueue.%d", nshard), func(b *testing.B) {
   605  			for i := 0; i < b.N; i++ {
   606  				b.StopTimer()
   607  				tasks := multiPhaseCompile(nshard, *evalStages)
   608  				if i == 0 {
   609  					b.Log("ntask=", len(tasks))
   610  				}
   611  				state := newState()
   612  				b.StartTimer()
   613  
   614  				for _, task := range tasks {
   615  					state.Enqueue(task)
   616  				}
   617  			}
   618  		})
   619  	}
   620  }