github.com/onflow/flow-go@v0.35.7-crescendo-preview.23-atree-inlining/module/component/component_manager_test.go

package component_test

import (
	"context"
	"fmt"
	"testing"
	"time"

	"github.com/hashicorp/go-multierror"
	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
	"pgregory.net/rapid"

	"github.com/onflow/flow-go/module/component"
	"github.com/onflow/flow-go/module/irrecoverable"
	"github.com/onflow/flow-go/module/util"
	"github.com/onflow/flow-go/utils/unittest"
)

const CHANNEL_CLOSE_LATENCY_ALLOWANCE = 25 * time.Millisecond

type WorkerState int

const (
	UnknownWorkerState WorkerState = iota
	WorkerStartingUp              // worker is starting up
	WorkerStartupShuttingDown     // worker was canceled during startup and is shutting down
	WorkerStartupCanceled         // worker has exited after being canceled during startup
	WorkerStartupEncounteredFatal // worker encountered a fatal error during startup
	WorkerRunning                 // worker has started up and is running normally
	WorkerShuttingDown            // worker was canceled and is shutting down
	WorkerCanceled                // worker has exited after being canceled
	WorkerEncounteredFatal        // worker encountered a fatal error
	WorkerDone                    // worker has shut down after running normally
)

func (s WorkerState) String() string {
	switch s {
	case WorkerStartingUp:
		return "WORKER_STARTING_UP"
	case WorkerStartupShuttingDown:
		return "WORKER_STARTUP_SHUTTING_DOWN"
	case WorkerStartupCanceled:
		return "WORKER_STARTUP_CANCELED"
	case WorkerStartupEncounteredFatal:
		return "WORKER_STARTUP_ENCOUNTERED_FATAL"
	case WorkerRunning:
		return "WORKER_RUNNING"
	case WorkerShuttingDown:
		return "WORKER_SHUTTING_DOWN"
	case WorkerCanceled:
		return "WORKER_CANCELED"
	case WorkerEncounteredFatal:
		return "WORKER_ENCOUNTERED_FATAL"
	case WorkerDone:
		return "WORKER_DONE"
	default:
		return "UNKNOWN"
	}
}

type WorkerStateList []WorkerState

func (wsl WorkerStateList) Contains(ws WorkerState) bool {
	for _, s := range wsl {
		if s == ws {
			return true
		}
	}
	return false
}

type WorkerStateTransition int

const (
	UnknownWorkerStateTransition WorkerStateTransition = iota
	WorkerCheckCtxAndShutdown // check context and shut down if canceled
	WorkerCheckCtxAndExit     // check context and exit immediately if canceled
	WorkerFinishStartup       // finish starting up
	WorkerDoWork              // do work
	WorkerExit                // exit
	WorkerThrowError          // throw error
)

func (wst WorkerStateTransition) String() string {
	switch wst {
	case WorkerCheckCtxAndShutdown:
		return "WORKER_CHECK_CTX_AND_SHUTDOWN"
	case WorkerCheckCtxAndExit:
		return "WORKER_CHECK_CTX_AND_EXIT"
	case WorkerFinishStartup:
		return "WORKER_FINISH_STARTUP"
	case WorkerDoWork:
		return "WORKER_DO_WORK"
	case WorkerExit:
		return "WORKER_EXIT"
	case WorkerThrowError:
		return "WORKER_THROW_ERROR"
	default:
		return "UNKNOWN"
	}
}

// WorkerStateTransitions maps each worker state to the transitions that are
// valid from it. States with no entry (e.g. WorkerDone, WorkerCanceled) are terminal.
var WorkerStateTransitions = map[WorkerState][]WorkerStateTransition{
	WorkerStartingUp:          {WorkerCheckCtxAndExit, WorkerCheckCtxAndShutdown, WorkerDoWork, WorkerFinishStartup},
	WorkerStartupShuttingDown: {WorkerDoWork, WorkerExit, WorkerThrowError},
	WorkerRunning:             {WorkerCheckCtxAndExit, WorkerCheckCtxAndShutdown, WorkerDoWork, WorkerExit, WorkerThrowError},
	WorkerShuttingDown:        {WorkerDoWork, WorkerExit, WorkerThrowError},
}
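// Editor's sketch (not part of the original test): how the transitions map is
// meant to be consumed. From any non-terminal state, one of the allowed
// transitions is sampled uniformly; terminal states have no entry and end the
// walk. The helper name sampleNextTransition is hypothetical; drawStateTransition
// further down inlines the same logic.
func sampleNextTransition(t *rapid.T, state WorkerState) (WorkerStateTransition, bool) {
	allowed, ok := WorkerStateTransitions[state]
	if !ok {
		// terminal state: no further transitions to draw
		return UnknownWorkerStateTransition, false
	}
	return rapid.SampledFrom(allowed).Draw(t, "next_transition"), true
}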
// CheckWorkerStateTransition checks the validity of a worker state transition:
// given the start state, the transition instruction, and whether the context was
// canceled, it fails the test if the observed end state is not a legal outcome.
func CheckWorkerStateTransition(t *rapid.T, id int, start, end WorkerState, transition WorkerStateTransition, canceled bool) {
	if !(func() bool {
		switch start {
		case WorkerStartingUp:
			switch transition {
			case WorkerCheckCtxAndExit:
				return (canceled && end == WorkerStartupCanceled) || (!canceled && end == WorkerStartingUp)
			case WorkerCheckCtxAndShutdown:
				return (canceled && end == WorkerStartupShuttingDown) || (!canceled && end == WorkerStartingUp)
			case WorkerDoWork:
				return end == WorkerStartingUp
			case WorkerFinishStartup:
				return end == WorkerRunning
			}
		case WorkerStartupShuttingDown:
			switch transition {
			case WorkerDoWork:
				return end == WorkerStartupShuttingDown
			case WorkerExit:
				return end == WorkerStartupCanceled
			case WorkerThrowError:
				return end == WorkerStartupEncounteredFatal
			}
		case WorkerRunning:
			switch transition {
			case WorkerCheckCtxAndExit:
				return (canceled && end == WorkerCanceled) || (!canceled && end == WorkerRunning)
			case WorkerCheckCtxAndShutdown:
				return (canceled && end == WorkerShuttingDown) || (!canceled && end == WorkerRunning)
			case WorkerDoWork:
				return end == WorkerRunning
			case WorkerExit:
				return end == WorkerDone
			case WorkerThrowError:
				return end == WorkerEncounteredFatal
			}
		case WorkerShuttingDown:
			switch transition {
			case WorkerDoWork:
				return end == WorkerShuttingDown
			case WorkerExit:
				return end == WorkerCanceled
			case WorkerThrowError:
				return end == WorkerEncounteredFatal
			}
		}

		return false
	}()) {
		require.Fail(t, "invalid worker state transition", "[worker %v] start=%v, canceled=%v, transition=%v, end=%v", id, start, canceled, transition, end)
	}
}

type WSTConsumer func(WorkerStateTransition) WorkerState
type WSTProvider func(WorkerState) WorkerStateTransition

// MakeWorkerTransitionFuncs creates a WorkerStateTransition Consumer / Provider pair.
// The Provider is called by the worker to notify the test code of the completion of
// a state transition and receive the next state transition instruction.
// The Consumer is called by the test code to send the next state transition
// instruction and get the resulting end state.
func MakeWorkerTransitionFuncs() (WSTConsumer, WSTProvider) {
	var started bool
	stateChan := make(chan WorkerState, 1)
	transitionChan := make(chan WorkerStateTransition)

	consumer := func(wst WorkerStateTransition) WorkerState {
		transitionChan <- wst
		return <-stateChan
	}

	provider := func(ws WorkerState) WorkerStateTransition {
		if started {
			stateChan <- ws
		} else {
			started = true
		}

		if _, ok := WorkerStateTransitions[ws]; !ok {
			// terminal state: report it, but don't wait for another instruction
			return UnknownWorkerStateTransition
		}

		return <-transitionChan
	}

	return consumer, provider
}
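// Editor's sketch (not part of the original test): the lock-step rendezvous
// between the pair. Each provider call after the first reports the state reached
// by the previous instruction, then blocks for the next instruction; a terminal
// state returns immediately. The function name handshakeSketch is hypothetical.
func handshakeSketch() {
	consume, provide := MakeWorkerTransitionFuncs()

	go func() {
		if provide(WorkerStartingUp) == WorkerFinishStartup { // blocks for instruction #1
			if provide(WorkerRunning) == WorkerExit { // reports WorkerRunning, blocks for #2
				provide(WorkerDone) // terminal: reports WorkerDone, returns immediately
			}
		}
	}()

	fmt.Println(consume(WorkerFinishStartup)) // WORKER_RUNNING
	fmt.Println(consume(WorkerExit))          // WORKER_DONE
}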
// ComponentWorker returns a component.ComponentWorker instrumented as a state
// machine: at each state it asks the provider for the next transition
// instruction and follows it, so the test can drive the worker step by step.
func ComponentWorker(t *rapid.T, id int, next WSTProvider) component.ComponentWorker {
	unexpectedStateTransition := func(s WorkerState, wst WorkerStateTransition) {
		panic(fmt.Sprintf("[worker %v] unexpected state transition: received %v for state %v", id, wst, s))
	}

	log := func(msg string) {
		t.Logf("[worker %v] %v\n", id, msg)
	}

	return func(ctx irrecoverable.SignalerContext, ready component.ReadyFunc) {
		var state WorkerState
		goto startingUp

	startingUp:
		log("starting up")
		state = WorkerStartingUp
		switch transition := next(state); transition {
		case WorkerCheckCtxAndExit:
			if util.CheckClosed(ctx.Done()) {
				goto startupCanceled
			}
			goto startingUp
		case WorkerCheckCtxAndShutdown:
			if util.CheckClosed(ctx.Done()) {
				goto startupShuttingDown
			}
			goto startingUp
		case WorkerDoWork:
			goto startingUp
		case WorkerFinishStartup:
			ready()
			goto running
		default:
			unexpectedStateTransition(state, transition)
		}

	startupShuttingDown:
		log("startup shutting down")
		state = WorkerStartupShuttingDown
		switch transition := next(state); transition {
		case WorkerDoWork:
			goto startupShuttingDown
		case WorkerExit:
			goto startupCanceled
		case WorkerThrowError:
			goto startupEncounteredFatal
		default:
			unexpectedStateTransition(state, transition)
		}

	startupCanceled:
		log("startup canceled")
		state = WorkerStartupCanceled
		next(state)
		return

	startupEncounteredFatal:
		log("startup encountered fatal")
		state = WorkerStartupEncounteredFatal
		// Throw never returns; the deferred next(state) still reports the final state
		defer next(state)
		ctx.Throw(&WorkerError{id})

	running:
		log("running")
		state = WorkerRunning
		switch transition := next(state); transition {
		case WorkerCheckCtxAndExit:
			if util.CheckClosed(ctx.Done()) {
				goto canceled
			}
			goto running
		case WorkerCheckCtxAndShutdown:
			if util.CheckClosed(ctx.Done()) {
				goto shuttingDown
			}
			goto running
		case WorkerDoWork:
			goto running
		case WorkerExit:
			goto done
		case WorkerThrowError:
			goto encounteredFatal
		default:
			unexpectedStateTransition(state, transition)
		}

	shuttingDown:
		log("shutting down")
		state = WorkerShuttingDown
		switch transition := next(state); transition {
		case WorkerDoWork:
			goto shuttingDown
		case WorkerExit:
			goto canceled
		case WorkerThrowError:
			goto encounteredFatal
		default:
			unexpectedStateTransition(state, transition)
		}

	canceled:
		log("canceled")
		state = WorkerCanceled
		next(state)
		return

	encounteredFatal:
		log("encountered fatal")
		state = WorkerEncounteredFatal
		defer next(state)
		ctx.Throw(&WorkerError{id})

	done:
		log("done")
		state = WorkerDone
		next(state)
	}
}
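// Editor's sketch (not part of the original test): for contrast with the
// goto-driven, externally scripted worker above, this is the shape of an
// ordinary ComponentWorker. exampleWorker and its tick interval are hypothetical
// and purely illustrative.
func exampleWorker(ctx irrecoverable.SignalerContext, ready component.ReadyFunc) {
	// startup work would happen here, before signaling readiness
	ready()

	for {
		select {
		case <-ctx.Done():
			// graceful shutdown: clean up and return
			return
		case <-time.After(10 * time.Millisecond):
			// do one unit of work; on an unrecoverable failure, call
			// ctx.Throw(err) instead of returning an error
		}
	}
}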
type WorkerError struct {
	id int
}

func (e *WorkerError) Is(target error) bool {
	if t, ok := target.(*WorkerError); ok {
		return t.id == e.id
	}
	return false
}

func (e *WorkerError) Error() string {
	return fmt.Sprintf("[worker %v] irrecoverable error", e.id)
}

// StartStateTransition returns a pair of functions AddTransition and ExecuteTransitions.
// AddTransition is called to add a state transition step, and ExecuteTransitions then
// shuffles all of the added steps and executes them in a random order.
func StartStateTransition() (func(t func()), func(*rapid.T)) {
	var transitions []func()

	addTransition := func(t func()) {
		transitions = append(transitions, t)
	}

	executeTransitions := func(t *rapid.T) {
		// in-place Fisher-Yates shuffle driven by rapid, executing each step as
		// it is placed: draw a random remaining step, swap it into position i, run it
		for i := 0; i < len(transitions); i++ {
			j := rapid.IntRange(0, len(transitions)-i-1).Draw(t, "")
			transitions[i], transitions[j+i] = transitions[j+i], transitions[i]
			transitions[i]()
		}
		// TODO: is this simpler?
		// executionOrder := rapid.SliceOfNDistinct(
		// 	rapid.IntRange(0, len(transitions)-1), len(transitions), len(transitions), nil,
		// ).Draw(t, "transition_execution_order").([]int)
		// for _, i := range executionOrder {
		// 	transitions[i]()
		// }
	}

	return addTransition, executeTransitions
}
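// Editor's sketch (not part of the original test): intended usage of
// StartStateTransition inside a rapid property. Steps that are logically
// concurrent are queued first and then executed in an order drawn from the
// test's random stream, so rapid can explore and shrink failing interleavings.
// shuffleSketch is a hypothetical name.
func shuffleSketch(t *rapid.T) {
	add, execute := StartStateTransition()

	for i := 0; i < 3; i++ {
		i := i // capture the loop variable for the closure
		add(func() { t.Logf("step %v executed\n", i) })
	}

	// runs the three steps in a randomly drawn permutation
	execute(t)
}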
type StateTransition struct {
	cancel            bool
	workerIDs         []int
	workerTransitions []WorkerStateTransition
}

func (st *StateTransition) String() string {
	return fmt.Sprintf(
		"stateTransition{ cancel=%v, workerIDs=%v, workerTransitions=%v }",
		st.cancel, st.workerIDs, st.workerTransitions,
	)
}

type ComponentManagerMachine struct {
	cm *component.ComponentManager

	cancel                    context.CancelFunc
	workerTransitionConsumers []WSTConsumer

	canceled     bool
	workerErrors error
	workerStates []WorkerState

	resetChannelReadTimeout  func()
	assertClosed             func(t *rapid.T, ch <-chan struct{}, msgAndArgs ...interface{})
	assertNotClosed          func(t *rapid.T, ch <-chan struct{}, msgAndArgs ...interface{})
	assertErrorThrownMatches func(t *rapid.T, err error, msgAndArgs ...interface{})
	assertErrorNotThrown     func(t *rapid.T)

	cancelGenerator     *rapid.Generator[bool]
	drawStateTransition func(t *rapid.T) *StateTransition
}

func (c *ComponentManagerMachine) init(t *rapid.T) {
	numWorkers := rapid.IntRange(0, 5).Draw(t, "num_workers")
	pCancel := rapid.Float64Range(0, 100).Draw(t, "p_cancel")

	// each drawn step cancels the context with probability pCancel percent
	c.cancelGenerator = rapid.Map(rapid.Float64Range(0, 100), func(n float64) bool {
		return pCancel == 100 || n < pCancel
	})

	c.drawStateTransition = func(t *rapid.T) *StateTransition {
		st := &StateTransition{}

		if !c.canceled {
			st.cancel = c.cancelGenerator.Draw(t, "cancel")
		}

		for workerId, state := range c.workerStates {
			if allowedTransitions, ok := WorkerStateTransitions[state]; ok {
				label := fmt.Sprintf("worker_transition_%v", workerId)
				st.workerIDs = append(st.workerIDs, workerId)
				st.workerTransitions = append(st.workerTransitions, rapid.SampledFrom(allowedTransitions).Draw(t, label))
			}
		}

		// drawing from Just records the composite transition in rapid's log
		return rapid.Just(st).Draw(t, "state_transition")
	}

	ctx, cancel := context.WithCancel(context.Background())
	c.cancel = cancel

	signalerCtx, errChan := irrecoverable.WithSignaler(ctx)

	var channelReadTimeout <-chan struct{}
	var signalerErr error

	c.resetChannelReadTimeout = func() {
		ctx, cancel := context.WithTimeout(context.Background(), CHANNEL_CLOSE_LATENCY_ALLOWANCE)
		_ = cancel // deliberately discarded: the timeout itself releases the context
		channelReadTimeout = ctx.Done()
	}

	c.assertClosed = func(t *rapid.T, ch <-chan struct{}, msgAndArgs ...interface{}) {
		select {
		case <-ch:
		default:
			select {
			case <-channelReadTimeout:
				assert.Fail(t, "channel is not closed", msgAndArgs...)
			case <-ch:
			}
		}
	}

	c.assertNotClosed = func(t *rapid.T, ch <-chan struct{}, msgAndArgs ...interface{}) {
		select {
		case <-ch:
			assert.Fail(t, "channel is closed", msgAndArgs...)
		default:
			select {
			case <-ch:
				assert.Fail(t, "channel is closed", msgAndArgs...)
			case <-channelReadTimeout:
			}
		}
	}

	c.assertErrorThrownMatches = func(t *rapid.T, err error, msgAndArgs ...interface{}) {
		if signalerErr == nil {
			select {
			case signalerErr = <-errChan:
			default:
				select {
				case <-channelReadTimeout:
					assert.Fail(t, "error was not thrown")
					return
				case signalerErr = <-errChan:
				}
			}
		}

		assert.ErrorIs(t, err, signalerErr, msgAndArgs...)
	}

	c.assertErrorNotThrown = func(t *rapid.T) {
		if signalerErr == nil {
			select {
			case signalerErr = <-errChan:
			default:
				select {
				case signalerErr = <-errChan:
				case <-channelReadTimeout:
					return
				}
			}
		}

		assert.Fail(t, fmt.Sprintf("error was thrown: %v", signalerErr))
	}

	c.workerTransitionConsumers = make([]WSTConsumer, numWorkers)
	c.workerStates = make([]WorkerState, numWorkers)

	cmb := component.NewComponentManagerBuilder()

	for i := 0; i < numWorkers; i++ {
		wtc, wtp := MakeWorkerTransitionFuncs()
		c.workerTransitionConsumers[i] = wtc
		cmb.AddWorker(ComponentWorker(t, i, wtp))
	}

	c.cm = cmb.Build()
	c.cm.Start(signalerCtx)

	for i := 0; i < numWorkers; i++ {
		c.workerStates[i] = WorkerStartingUp
	}
}

func (c *ComponentManagerMachine) ExecuteStateTransition(t *rapid.T) {
	st := c.drawStateTransition(t)

	t.Logf("drew state transition: %v\n", st)

	var errors *multierror.Error

	addTransition, executeTransitionsInRandomOrder := StartStateTransition()

	if st.cancel {
		addTransition(func() {
			t.Log("executing cancel transition\n")
			c.cancel()
			c.canceled = true
			c.resetChannelReadTimeout()
			c.assertClosed(t, c.cm.ShutdownSignal())
		})
	}

	for i, workerId := range st.workerIDs {
		i := i
		workerId := workerId
		addTransition(func() {
			wst := st.workerTransitions[i]
			t.Logf("executing worker %v transition: %v\n", workerId, wst)
			endState := c.workerTransitionConsumers[workerId](wst)
			CheckWorkerStateTransition(t, workerId, c.workerStates[workerId], endState, wst, c.canceled)
			c.workerStates[workerId] = endState

			if (WorkerStateList{WorkerStartupEncounteredFatal, WorkerEncounteredFatal}).Contains(endState) {
				err := &WorkerError{workerId}
				require.NotErrorIs(t, c.workerErrors, err)
				require.NotErrorIs(t, errors, err)
				errors = multierror.Append(errors, err)
				c.canceled = true
				c.resetChannelReadTimeout()
				c.assertClosed(t, c.cm.ShutdownSignal())
			}
		})
	}

	executeTransitionsInRandomOrder(t)

	if c.workerErrors == nil {
		c.workerErrors = errors.ErrorOrNil()
	}

	t.Logf("end state: { canceled=%v, workerErrors=%v, workerStates=%v }\n", c.canceled, c.workerErrors, c.workerStates)
}
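// Editor's sketch (not part of the original test): the two-phase select pattern
// behind assertClosed / assertNotClosed above, extracted for clarity. The fast
// path is a non-blocking read; only if the channel is not yet closed do we race
// it against the latency-allowance timeout. isClosedWithin is a hypothetical
// helper name.
func isClosedWithin(ch, timeout <-chan struct{}) bool {
	select {
	case <-ch:
		return true // already closed: no timer involved
	default:
		select {
		case <-ch:
			return true // closed within the allowance
		case <-timeout:
			return false // still open when the allowance expired
		}
	}
}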
// Check verifies the invariants that must hold between the aggregate worker
// states and the component manager's Ready, Done, and ShutdownSignal channels.
func (c *ComponentManagerMachine) Check(t *rapid.T) {
	c.resetChannelReadTimeout()

	if c.canceled {
		c.assertClosed(t, c.cm.ShutdownSignal(), "context is canceled but component manager shutdown signal is not closed")
	}

	allWorkersReady := true
	allWorkersDone := true

	for workerID, state := range c.workerStates {
		if (WorkerStateList{
			WorkerStartingUp,
			WorkerStartupShuttingDown,
			WorkerStartupCanceled,
			WorkerStartupEncounteredFatal,
		}).Contains(state) {
			allWorkersReady = false
			c.assertNotClosed(t, c.cm.Ready(), "worker %v has not finished startup but component manager ready channel is closed", workerID)
		}

		if !(WorkerStateList{
			WorkerStartupCanceled,
			WorkerStartupEncounteredFatal,
			WorkerCanceled,
			WorkerEncounteredFatal,
			WorkerDone,
		}).Contains(state) {
			allWorkersDone = false
			c.assertNotClosed(t, c.cm.Done(), "worker %v has not exited but component manager done channel is closed", workerID)
		}

		if (WorkerStateList{
			WorkerStartupShuttingDown,
			WorkerStartupCanceled,
			WorkerStartupEncounteredFatal,
			WorkerShuttingDown,
			WorkerCanceled,
			WorkerEncounteredFatal,
		}).Contains(state) {
			c.assertClosed(t, c.cm.ShutdownSignal(), "worker %v has been canceled or encountered a fatal error but component manager shutdown signal is not closed", workerID)
		}
	}

	if allWorkersReady {
		c.assertClosed(t, c.cm.Ready(), "all workers are ready but component manager ready channel is not closed")
	}

	if allWorkersDone {
		c.assertClosed(t, c.cm.Done(), "all workers are done but component manager done channel is not closed")
	}

	if c.workerErrors != nil {
		c.assertErrorThrownMatches(t, c.workerErrors, "error received by signaler does not match any of the ones thrown")
		c.assertClosed(t, c.cm.ShutdownSignal(), "fatal error thrown but context is not canceled")
	} else {
		c.assertErrorNotThrown(t)
	}
}

func TestComponentManager(t *testing.T) {
	unittest.SkipUnless(t, unittest.TEST_LONG_RUNNING, "skip because this test takes too long")

	rapid.Check(t, func(t *rapid.T) {
		sm := new(ComponentManagerMachine)
		sm.init(t)
		t.Repeat(rapid.StateMachineActions(sm))
	})
}

func TestComponentManagerShutdown(t *testing.T) {
	mgr := component.NewComponentManagerBuilder().
		AddWorker(func(ctx irrecoverable.SignalerContext, ready component.ReadyFunc) {
			ready()
			<-ctx.Done()
		}).Build()

	parent, cancel := context.WithCancel(context.Background())
	ctx, _ := irrecoverable.WithSignaler(parent)

	mgr.Start(ctx)
	unittest.AssertClosesBefore(t, mgr.Ready(), 10*time.Millisecond)
	cancel()

	// ShutdownSignal indicates we have started shutdown, Done indicates we have
	// completed shutdown. If we have completed shutdown, we must have started it.
	unittest.AssertClosesBefore(t, mgr.Done(), 10*time.Millisecond)
	closed := util.CheckClosed(mgr.ShutdownSignal())
	assert.True(t, closed)
}

// Run the shutdown test many times in one process to reproduce ordering-related
// flakes consistently.
func TestComponentManagerShutdown_100(t *testing.T) {
	for i := 0; i < 100; i++ {
		TestComponentManagerShutdown(t)
	}
}
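// Editor's note (assumption about rapid's CLI, not part of the original test):
// pgregory.net/rapid properties such as TestComponentManager can typically be
// tuned and replayed from the command line, e.g.
//
//	go test -run TestComponentManager -rapid.checks=100 -rapid.seed=<seed>
//
// where <seed> is the seed rapid prints for a failing run.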