github.com/grailbio/bigslice@v0.0.0-20230519005545-30c4c12152ad/exec/eval_test.go

// Copyright 2018 GRAIL, Inc. All rights reserved.
// Use of this source code is governed by the Apache 2.0
// license that can be found in the LICENSE file.

package exec

import (
	"context"
	goerrors "errors"
	"flag"
	"fmt"
	"net/http"
	"strings"
	"sync"
	"testing"
	"time"

	"github.com/grailbio/base/errors"
	"github.com/grailbio/base/eventlog"
	"github.com/grailbio/bigslice"
	"github.com/grailbio/bigslice/sliceio"
	"golang.org/x/sync/errgroup"
)

// testExecutor is a stub executor whose Run only marks tasks as running;
// tests drive all further state transitions by hand.
type testExecutor struct{}

func (testExecutor) Name() string {
	return "test"
}

func (testExecutor) Start(*Session) (shutdown func()) {
	return func() {}
}

func (t testExecutor) Run(task *Task) {
	task.Lock()
	task.state = TaskRunning
	task.Broadcast()
	task.Unlock()
}

func (testExecutor) Reader(*Task, int) sliceio.ReadCloser {
	panic("not implemented")
}

func (testExecutor) Discard(context.Context, *Task) {}

func (testExecutor) Eventer() eventlog.Eventer {
	return eventlog.Nop{}
}

func (testExecutor) HandleDebug(handler *http.ServeMux) {
	panic("not implemented")
}

// constEvalTest sets up a 2-root-node task graph.
type constEvalTest struct {
	Tasks []*Task

	wg      sync.WaitGroup
	evalErr error
}

func (c *constEvalTest) Go(t *testing.T) {
	t.Helper()
	c.Tasks, _, _ = compileFunc(func() bigslice.Slice {
		return bigslice.Const(2, []int{1, 2, 3})
	})
	ctx := context.Background()
	c.wg.Add(1)
	go func() {
		c.evalErr = Eval(ctx, testExecutor{}, c.Tasks, nil)
		c.wg.Done()
	}()
}

func (c *constEvalTest) EvalErr() error {
	c.wg.Wait()
	return c.evalErr
}

// simpleEvalTest sets up a simple, 2-node task graph.
type simpleEvalTest struct {
	Tasks []*Task

	ConstTask, CogroupTask *Task

	wg      sync.WaitGroup
	evalErr error
}

func (s *simpleEvalTest) Go(t *testing.T) {
	t.Helper()
	s.Tasks, _, _ = compileFunc(func() bigslice.Slice {
		slice := bigslice.Const(1, []int{1, 2, 3})
		slice = bigslice.Cogroup(slice)
		return slice
	})
	s.ConstTask = s.Tasks[0].Deps[0].Task(0)
	s.CogroupTask = s.Tasks[0]
	ctx := context.Background()
	s.wg.Add(1)
	go func() {
		s.evalErr = Eval(ctx, testExecutor{}, s.Tasks, nil)
		s.wg.Done()
	}()
}

func (s *simpleEvalTest) EvalErr() error {
	s.wg.Wait()
	return s.evalErr
}

// waitState blocks until task reaches the given state, failing the test if
// it does not do so within 5 seconds.
func waitState(t *testing.T, task *Task, state TaskState) {
	t.Helper()
	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
	defer cancel()
	task.Lock()
	defer task.Unlock()
	for task.state != state {
		if err := task.Wait(ctx); err != nil {
			t.Fatalf("task %v (state %v) did not reach desired state %v", task.Name, task.state, state)
		}
	}
}
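// setState is an illustrative helper, not used by the tests in this file. It
// spells out the idiom the tests use to drive a task's state machine by hand:
// take the task's lock, mutate its state, then Broadcast so that goroutines
// blocked in Task.Wait observe the transition. The Task.Set calls made by
// later tests are assumed to encapsulate an equivalent pattern.
func setState(task *Task, state TaskState) {
	task.Lock()
	task.state = state
	task.Broadcast()
	task.Unlock()
}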
// TestTaskErr verifies that a task evaluation error (TaskErr) causes Eval to
// return a corresponding error.
func TestTaskErr(t *testing.T) {
	var (
		test simpleEvalTest
		ctx  = context.Background()
	)
	test.Go(t)
	state, err := test.ConstTask.WaitState(ctx, TaskRunning)
	if err != nil {
		t.Fatal(err)
	}
	if got, want := state, TaskRunning; got != want {
		t.Fatalf("got %v, want %v", got, want)
	}
	if got, want := test.CogroupTask.State(), TaskInit; got != want {
		t.Fatalf("got %v, want %v: %v", got, want, test.CogroupTask)
	}
	test.ConstTask.Error(goerrors.New("const task error"))

	err = test.EvalErr()
	if err == nil {
		t.Fatal("expected error")
	}
	if got, want := strings.Contains(err.Error(), "const task error"), true; got != want {
		t.Errorf("got %v, want %v", got, want)
	}
	if got, want := test.CogroupTask.State(), TaskInit; got != want {
		t.Fatalf("got %v, want %v", got, want)
	}
}

// TestAllRootsEvaluated verifies that all roots have been evaluated by the
// time Eval returns.
func TestAllRootsEvaluated(t *testing.T) {
	var (
		test constEvalTest
		ctx  = context.Background()
	)
	test.Go(t)
	// We have two root tasks, task0 and task1. task0 is evaluated
	// successfully. While task1 runs, task0 is lost. Verify that Eval only
	// returns once task0 is re-evaluated successfully.
	var (
		task0 = test.Tasks[0]
		task1 = test.Tasks[1]
	)
	// task0 is evaluated successfully.
	task0.Lock()
	for task0.state != TaskRunning {
		if err := task0.Wait(ctx); err != nil {
			t.Fatal(err)
		}
	}
	task0.state = TaskOk
	task0.Broadcast()
	task0.Unlock()
	// While task1 runs, task0 is lost.
	task1.Lock()
	for task1.state != TaskRunning {
		if err := task1.Wait(ctx); err != nil {
			t.Fatal(err)
		}
	}
	task1.Unlock()
	// Allow time for evaluation to notice task0's TaskOk state before marking
	// it lost.
	// TODO: Though this seems to work reliably in my environment, consider a
	// non-racy way of doing this. Note that this shouldn't ever cause the test
	// to falsely fail. It just means that the test will exercise the
	// Running -> Lost path instead of the Running -> Ok -> Lost path, as the
	// evaluator might not see the transient Ok state.
	time.Sleep(1 * time.Millisecond)
	task0.Lock()
	task0.state = TaskLost
	task0.Broadcast()
	task0.Unlock()
	// task1 is evaluated successfully.
	task1.Lock()
	task1.state = TaskOk
	task1.Broadcast()
	task1.Unlock()
	task0.Lock()
	// Expect task0 to be resubmitted. Eval should not return until all roots
	// are successfully evaluated.
	for task0.state != TaskRunning {
		if err := task0.Wait(ctx); err != nil {
			t.Fatal(err)
		}
	}
	task0.state = TaskOk
	task0.Broadcast()
	task0.Unlock()
	if err := test.EvalErr(); err != nil {
		t.Fatal(err)
	}
}
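// loseTask is a hypothetical convenience wrapper, shown only to summarize the
// choreography used above and in the tests that follow: wait until the task
// is running, then mark it lost so the evaluator is forced to resubmit it.
// The tests in this file inline this sequence rather than calling a helper.
func loseTask(ctx context.Context, t *testing.T, task *Task) {
	t.Helper()
	task.Lock()
	defer task.Unlock()
	for task.state != TaskRunning {
		if err := task.Wait(ctx); err != nil {
			t.Fatal(err)
		}
	}
	task.state = TaskLost
	task.Broadcast()
}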
// TestResubmitLostTask verifies that the evaluator resubmits lost tasks, both
// a single lost task and a task lost together with its running dependent.
func TestResubmitLostTask(t *testing.T) {
	var (
		test simpleEvalTest
		ctx  = context.Background()
	)
	test.Go(t)
	var (
		fst = test.ConstTask
		snd = test.CogroupTask
	)
	fst.Lock()
	for fst.state != TaskRunning {
		if err := fst.Wait(ctx); err != nil {
			t.Fatal(err)
		}
	}
	fst.state = TaskLost
	fst.Broadcast()
	for fst.state == TaskLost {
		if err := fst.Wait(ctx); err != nil {
			t.Fatal(err)
		}
	}
	// The evaluator should have resubmitted it.
	if got, want := fst.state, TaskRunning; got != want {
		t.Errorf("got %v, want %v", got, want)
	}

	// Now we lose both tasks while the second is running.
	// The evaluator should resubmit both.
	fst.state = TaskOk
	fst.Broadcast()
	fst.Unlock()

	snd.Lock()
	for snd.state != TaskRunning {
		if err := snd.Wait(ctx); err != nil {
			t.Fatal(err)
		}
	}
	fst.Lock()
	snd.state = TaskLost
	snd.Broadcast()
	snd.Unlock()
	fst.state = TaskLost
	fst.Broadcast()

	for fst.state < TaskRunning {
		if err := fst.Wait(ctx); err != nil {
			t.Fatal(err)
		}
	}
	if got, want := snd.State(), TaskLost; got != want {
		t.Errorf("got %v, want %v", got, want)
	}
	fst.state = TaskOk
	fst.Broadcast()
	fst.Unlock()

	snd.Lock()
	for snd.state < TaskRunning {
		if err := snd.Wait(ctx); err != nil {
			t.Fatal(err)
		}
	}
	snd.state = TaskOk
	snd.Broadcast()
	snd.Unlock()

	if err := test.EvalErr(); err != nil {
		t.Fatal(err)
	}
}

// TestResubmitLostInteriorTask verifies that lost dependencies are recomputed
// before their lost dependents are resubmitted, including under multiple
// concurrent evaluations of the same graph.
func TestResubmitLostInteriorTask(t *testing.T) {
	for _, parallel := range []int{1, 10} {
		parallel := parallel
		t.Run(fmt.Sprintf("parallel=%v", parallel), func(t *testing.T) {
			ctx, cancel := context.WithCancel(context.Background())
			defer cancel()
			tasks, _, _ := compileFunc(func() (slice bigslice.Slice) {
				slice = bigslice.Const(2, []int{1, 2, 3, 4, 5, 6, 7, 8, 9, 10})
				slice = bigslice.Cogroup(slice)
				return
			})

			var g errgroup.Group
			for i := 0; i < parallel; i++ {
				g.Go(func() error { return Eval(ctx, testExecutor{}, tasks, nil) })
			}

			var (
				const0   = tasks[0].Deps[0].Task(0)
				const1   = tasks[0].Deps[0].Task(1)
				cogroup0 = tasks[0]
				cogroup1 = tasks[1]
			)
			waitState(t, const0, TaskRunning)
			const0.Set(TaskOk)
			waitState(t, const1, TaskRunning)
			const1.Set(TaskOk)

			waitState(t, cogroup0, TaskRunning)
			waitState(t, cogroup1, TaskRunning)
			const0.Set(TaskLost)
			cogroup0.Set(TaskLost)
			cogroup1.Set(TaskLost)

			// Now the evaluator must first recompute const0...
			waitState(t, const0, TaskRunning)
			const0.Set(TaskOk)
			// ... and then each of the cogroup tasks.
			waitState(t, cogroup0, TaskRunning)
			waitState(t, cogroup1, TaskRunning)
			cogroup0.Set(TaskOk)
			cogroup1.Set(TaskOk)

			if err := g.Wait(); err != nil {
				t.Fatal(err)
			}
		})
	}
}
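// waitAtLeast is an illustrative variant of waitState, not used by these
// tests. The "state < TaskRunning" loops in TestResubmitLostTask above rely
// on TaskState values being ordered, so a task that has already moved past
// the target state (say, to TaskOk) also ends the wait. A minimal sketch of
// that idiom, under the same 5-second deadline as waitState:
func waitAtLeast(t *testing.T, task *Task, state TaskState) {
	t.Helper()
	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
	defer cancel()
	task.Lock()
	defer task.Unlock()
	for task.state < state {
		if err := task.Wait(ctx); err != nil {
			t.Fatalf("task %v (state %v) did not reach state %v", task.Name, task.state, state)
		}
	}
}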
// TestPersistentTaskLoss verifies that the evaluator abandons evaluation of a
// task that is repeatedly lost whenever it is run, as no meaningful progress
// can be made.
func TestPersistentTaskLoss(t *testing.T) {
	var (
		test        simpleEvalTest
		ctx, cancel = context.WithTimeout(context.Background(), 10*time.Second)
	)
	defer cancel()
	test.Go(t)
	fst := test.ConstTask
	for {
		if err := ctx.Err(); err != nil {
			t.Fatal(err)
		}
		fst.Lock()
		for fst.state != TaskRunning {
			if err := fst.Wait(ctx); err != nil {
				t.Fatal(err)
			}
		}
		fst.state = TaskLost
		fst.Broadcast()
		for fst.state == TaskLost {
			if err := fst.Wait(ctx); err != nil {
				t.Fatal(err)
			}
		}
		isErr := fst.state == TaskErr
		fst.Unlock()
		if isErr {
			// The evaluator has given up on the task.
			break
		}
	}
	err := test.EvalErr()
	if !errors.Is(errors.TooManyTries, err) {
		t.Errorf("expected TooManyTries error, got: %v", err)
	}
}

// multiPhaseCompile compiles a pipeline of nstage Reduce stages over nshard
// shards and returns its root tasks.
func multiPhaseCompile(nshard, nstage int) []*Task {
	tasks, _, _ := compileFunc(func() bigslice.Slice {
		keys := make([]string, nshard*2)
		for i := range keys {
			keys[i] = fmt.Sprint(i)
		}
		values := make([]int, nshard*2)
		for i := range values {
			values[i] = i
		}

		slice := bigslice.Const(nshard, keys, values)
		for stage := 0; stage < nstage; stage++ {
			slice = bigslice.Reduce(slice, func(i, j int) int { return i + j })
		}
		return slice
	})
	return tasks
}
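// multiPhaseShape is an illustrative, unused helper that makes the expected
// shape of a multiPhaseCompile graph explicit: one root task per shard, plus
// one phase of tasks per earlier stage, discoverable by walking the first
// dependency chain. TestMultiPhaseEval below performs this same walk.
func multiPhaseShape(tasks []*Task) (nroots, nphases int) {
	nroots = len(tasks)
	for task := tasks[0].Deps[0].Task(0); ; task = task.Deps[0].Task(0) {
		nphases++
		if len(task.Deps) == 0 {
			break
		}
	}
	return nroots, nphases
}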
// TestMultiPhaseEval verifies that a multi-phase task graph is evaluated one
// phase at a time, dependencies first, and that lost interior tasks are
// recomputed only along dependency paths reachable from the roots.
func TestMultiPhaseEval(t *testing.T) {
	const (
		S = 1000
		P = 10
	)
	tasks := multiPhaseCompile(S, P)
	if got, want := len(tasks), S; got != want {
		t.Fatalf("got %v, want %v", got, want)
	}
	var phases [][]*Task
	for task := tasks[0].Deps[0].Task(0); ; {
		phases = append(phases, task.Group)
		if len(task.Deps) == 0 {
			break
		}
		task = task.Deps[0].Task(0)
	}
	if got, want := len(phases), P; got != want {
		t.Fatalf("got %v, want %v", got, want)
	}
	for _, group := range phases {
		if got, want := len(group), S; got != want {
			t.Errorf("got %v, want %v", got, want)
		}
	}

	eval := func() (wait func()) {
		var g errgroup.Group
		g.Go(func() error {
			t.Helper()
			return Eval(context.Background(), testExecutor{}, tasks, nil)
		})
		return func() {
			t.Helper()
			if err := g.Wait(); err != nil {
				t.Fatal(err)
			}
		}
	}

	wait := eval()

	for i := len(phases) - 1; i >= 0; i-- {
		group := phases[i]
		for _, task := range group {
			waitState(t, task, TaskRunning)
		}
		// Make sure no other tasks are waiting or running.
		for j := i - 1; j >= 0; j-- {
			otherGroup := phases[j]
			for _, task := range otherGroup {
				if task.State() != TaskInit {
					t.Fatal(task, ": wrong state")
				}
			}
		}
		for _, task := range group {
			task.Set(TaskOk)
		}
	}

	for _, task := range tasks {
		waitState(t, task, TaskRunning)
		task.Set(TaskOk)
	}
	wait()

	mustState := func(task *Task, state TaskState) {
		t.Helper()
		if got, want := task.State(), state; got != want {
			t.Fatalf("%v: got %v, want %v", task, got, want)
		}
	}

	mustStates := func(def TaskState, states map[*Task]TaskState) {
		t.Helper()
		for _, group := range phases {
			for _, task := range group {
				state, ok := states[task]
				if !ok {
					state = def
				}
				mustState(task, state)
			}
		}
		for _, task := range tasks {
			state, ok := states[task]
			if !ok {
				state = def
			}
			mustState(task, state)
		}
	}

	// An exterior task failure means a single resubmission.
	tasks[S/2].Set(TaskLost)
	wait = eval()

	waitState(t, tasks[S/2], TaskRunning)
	mustStates(TaskOk, map[*Task]TaskState{
		tasks[S/2]: TaskRunning,
	})
	tasks[S/2].Set(TaskOk)
	wait()

	// A reachable path of interior task failures gets resubmitted.
	lost := []*Task{
		tasks[S/2],
		phases[0][S/2],
		phases[1][S/2],
	}
	unreachable := phases[3][S/2]
	for _, task := range lost {
		task.Set(TaskLost)
	}
	unreachable.Set(TaskLost)
	wait = eval()
	waitState(t, lost[len(lost)-1], TaskRunning)
	mustStates(TaskOk, map[*Task]TaskState{
		unreachable: TaskLost,
		lost[0]:     TaskLost,
		lost[1]:     TaskLost,
		lost[2]:     TaskRunning,
	})
	lost[2].Set(TaskOk)
	waitState(t, lost[1], TaskRunning)
	mustStates(TaskOk, map[*Task]TaskState{
		unreachable: TaskLost,
		lost[0]:     TaskLost,
		lost[1]:     TaskRunning,
	})
	lost[1].Set(TaskOk)
	waitState(t, lost[0], TaskRunning)
	mustStates(TaskOk, map[*Task]TaskState{
		unreachable: TaskLost,
		lost[0]:     TaskRunning,
	})
	lost[0].Set(TaskOk)
	mustStates(TaskOk, map[*Task]TaskState{
		unreachable: TaskLost,
	})
	wait()
}

// benchExecutor is a stub executor whose Run marks tasks TaskOk immediately,
// so the benchmarks below measure only evaluator overhead.
type benchExecutor struct{ *testing.B }

func (benchExecutor) Start(*Session) (shutdown func()) {
	return func() {}
}

func (b benchExecutor) Run(task *Task) {
	task.Lock()
	task.state = TaskOk
	task.Broadcast()
	task.Unlock()
}

func (benchExecutor) Reader(*Task, int) sliceio.ReadCloser {
	panic("not implemented")
}

func (benchExecutor) Discard(context.Context, *Task) {}

func (benchExecutor) Eventer() eventlog.Eventer {
	return eventlog.Nop{}
}

func (benchExecutor) HandleDebug(handler *http.ServeMux) {
	panic("not implemented")
}

var evalStages = flag.Int("eval.bench.stages", 5, "number of stages for eval benchmark")

func BenchmarkEval(b *testing.B) {
	for _, nshard := range []int{10, 100, 1000, 5000 /*, 100000*/} {
		b.Run(fmt.Sprintf("eval.%d", nshard), func(b *testing.B) {
			ctx := context.Background()
			for i := 0; i < b.N; i++ {
				b.StopTimer()
				tasks := multiPhaseCompile(nshard, *evalStages)
				if i == 0 {
					b.Log("ntask=", len(tasks))
				}
				b.StartTimer()
				if err := Eval(ctx, benchExecutor{b}, tasks, nil); err != nil {
					b.Fatal(err)
				}
			}
		})
	}
}

func BenchmarkEnqueue(b *testing.B) {
	for _, nshard := range []int{10, 100, 1000, 5000 /*, 100000*/} {
		b.Run(fmt.Sprintf("enqueue.%d", nshard), func(b *testing.B) {
			for i := 0; i < b.N; i++ {
				b.StopTimer()
				tasks := multiPhaseCompile(nshard, *evalStages)
				if i == 0 {
					b.Log("ntask=", len(tasks))
				}
				state := newState()
				b.StartTimer()

				for _, task := range tasks {
					state.Enqueue(task)
				}
			}
		})
	}
}
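// Usage note (illustrative): the eval.bench.stages flag defined above
// controls the depth of the compiled graph for both benchmarks. Flags that
// "go test" does not recognize are passed through to the test binary, so an
// invocation might look like:
//
//	go test -run '^$' -bench 'Eval|Enqueue' -eval.bench.stages=10 ./exec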