github.com/koko1123/flow-go-1@v0.29.6/module/component/component_manager_test.go (about) 1 package component_test 2 3 import ( 4 "context" 5 "fmt" 6 "testing" 7 "time" 8 9 "github.com/hashicorp/go-multierror" 10 "github.com/stretchr/testify/assert" 11 "github.com/stretchr/testify/require" 12 "pgregory.net/rapid" 13 14 "github.com/koko1123/flow-go-1/module/component" 15 "github.com/koko1123/flow-go-1/module/irrecoverable" 16 "github.com/koko1123/flow-go-1/module/util" 17 "github.com/koko1123/flow-go-1/utils/unittest" 18 ) 19 20 const CHANNEL_CLOSE_LATENCY_ALLOWANCE = 25 * time.Millisecond 21 22 type WorkerState int 23 24 const ( 25 UnknownWorkerState = iota 26 WorkerStartingUp // worker is starting up 27 WorkerStartupShuttingDown // worker was canceled during startup and is shutting down 28 WorkerStartupCanceled // worker has exited after being canceled during startup 29 WorkerStartupEncounteredFatal // worker encountered a fatal error during startup 30 WorkerRunning // worker has started up and is running normally 31 WorkerShuttingDown // worker was canceled and is shutting down 32 WorkerCanceled // worker has exited after being canceled 33 WorkerEncounteredFatal // worker encountered a fatal error 34 WorkerDone // worker has shut down after running normally 35 36 ) 37 38 func (s WorkerState) String() string { 39 switch s { 40 case WorkerStartingUp: 41 return "WORKER_STARTING_UP" 42 case WorkerStartupShuttingDown: 43 return "WORKER_STARTUP_SHUTTING_DOWN" 44 case WorkerStartupCanceled: 45 return "WORKER_STARTUP_CANCELED" 46 case WorkerStartupEncounteredFatal: 47 return "WORKER_STARTUP_ENCOUNTERED_FATAL" 48 case WorkerRunning: 49 return "WORKER_RUNNING" 50 case WorkerShuttingDown: 51 return "WORKER_SHUTTING_DOWN" 52 case WorkerCanceled: 53 return "WORKER_CANCELED" 54 case WorkerEncounteredFatal: 55 return "WORKER_ENCOUNTERED_FATAL" 56 case WorkerDone: 57 return "WORKER_DONE" 58 default: 59 return "UNKNOWN" 60 } 61 } 62 63 type WorkerStateList []WorkerState 64 65 func (wsl WorkerStateList) Contains(ws WorkerState) bool { 66 for _, s := range wsl { 67 if s == ws { 68 return true 69 } 70 } 71 return false 72 } 73 74 type WorkerStateTransition int 75 76 const ( 77 UnknownWorkerStateTransition WorkerStateTransition = iota 78 WorkerCheckCtxAndShutdown // check context and shutdown if canceled 79 WorkerCheckCtxAndExit // check context and exit immediately if canceled 80 WorkerFinishStartup // finish starting up 81 WorkerDoWork // do work 82 WorkerExit // exit 83 WorkerThrowError // throw error 84 ) 85 86 func (wst WorkerStateTransition) String() string { 87 switch wst { 88 case WorkerCheckCtxAndShutdown: 89 return "WORKER_CHECK_CTX_AND_SHUTDOWN" 90 case WorkerCheckCtxAndExit: 91 return "WORKER_CHECK_CTX_AND_EXIT" 92 case WorkerFinishStartup: 93 return "WORKER_FINISH_STARTUP" 94 case WorkerDoWork: 95 return "WORKER_DO_WORK" 96 case WorkerExit: 97 return "WORKER_EXIT" 98 case WorkerThrowError: 99 return "WORKER_THROW_ERROR" 100 default: 101 return "UNKNOWN" 102 } 103 } 104 105 // WorkerStateTransitions is a map from worker state to valid state transitions 106 var WorkerStateTransitions = map[WorkerState][]WorkerStateTransition{ 107 WorkerStartingUp: {WorkerCheckCtxAndExit, WorkerCheckCtxAndShutdown, WorkerDoWork, WorkerFinishStartup}, 108 WorkerStartupShuttingDown: {WorkerDoWork, WorkerExit, WorkerThrowError}, 109 WorkerRunning: {WorkerCheckCtxAndExit, WorkerCheckCtxAndShutdown, WorkerDoWork, WorkerExit, WorkerThrowError}, 110 WorkerShuttingDown: {WorkerDoWork, WorkerExit, WorkerThrowError}, 111 } 112 113 // CheckWorkerStateTransition checks the validity of a worker state transition 114 func CheckWorkerStateTransition(t *rapid.T, id int, start, end WorkerState, transition WorkerStateTransition, canceled bool) { 115 if !(func() bool { 116 switch start { 117 case WorkerStartingUp: 118 switch transition { 119 case WorkerCheckCtxAndExit: 120 return (canceled && end == WorkerStartupCanceled) || (!canceled && end == WorkerStartingUp) 121 case WorkerCheckCtxAndShutdown: 122 return (canceled && end == WorkerStartupShuttingDown) || (!canceled && end == WorkerStartingUp) 123 case WorkerDoWork: 124 return end == WorkerStartingUp 125 case WorkerFinishStartup: 126 return end == WorkerRunning 127 } 128 case WorkerStartupShuttingDown: 129 switch transition { 130 case WorkerDoWork: 131 return end == WorkerStartupShuttingDown 132 case WorkerExit: 133 return end == WorkerStartupCanceled 134 case WorkerThrowError: 135 return end == WorkerStartupEncounteredFatal 136 } 137 case WorkerRunning: 138 switch transition { 139 case WorkerCheckCtxAndExit: 140 return (canceled && end == WorkerCanceled) || (!canceled && end == WorkerRunning) 141 case WorkerCheckCtxAndShutdown: 142 return (canceled && end == WorkerShuttingDown) || (!canceled && end == WorkerRunning) 143 case WorkerDoWork: 144 return end == WorkerRunning 145 case WorkerExit: 146 return end == WorkerDone 147 case WorkerThrowError: 148 return end == WorkerEncounteredFatal 149 } 150 case WorkerShuttingDown: 151 switch transition { 152 case WorkerDoWork: 153 return end == WorkerShuttingDown 154 case WorkerExit: 155 return end == WorkerCanceled 156 case WorkerThrowError: 157 return end == WorkerEncounteredFatal 158 } 159 } 160 161 return false 162 }()) { 163 require.Fail(t, "invalid worker state transition", "[worker %v] start=%v, canceled=%v, transition=%v, end=%v", id, start, canceled, transition, end) 164 } 165 } 166 167 type WSTConsumer func(WorkerStateTransition) WorkerState 168 type WSTProvider func(WorkerState) WorkerStateTransition 169 170 // MakeWorkerTransitionFuncs creates a WorkerStateTransition Consumer / Provider pair. 171 // The Consumer is called by the worker to notify the test code of the completion of a state transition 172 // and receive the next state transition instruction. 173 // The Provider is called by the test code to send the next state transition instruction and get the 174 // resulting end state. 175 func MakeWorkerTransitionFuncs() (WSTConsumer, WSTProvider) { 176 var started bool 177 stateChan := make(chan WorkerState, 1) 178 transitionChan := make(chan WorkerStateTransition) 179 180 consumer := func(wst WorkerStateTransition) WorkerState { 181 transitionChan <- wst 182 return <-stateChan 183 } 184 185 provider := func(ws WorkerState) WorkerStateTransition { 186 if started { 187 stateChan <- ws 188 } else { 189 started = true 190 } 191 192 if _, ok := WorkerStateTransitions[ws]; !ok { 193 return UnknownWorkerStateTransition 194 } 195 196 return <-transitionChan 197 } 198 199 return consumer, provider 200 } 201 202 func ComponentWorker(t *rapid.T, id int, next WSTProvider) component.ComponentWorker { 203 unexpectedStateTransition := func(s WorkerState, wst WorkerStateTransition) { 204 panic(fmt.Sprintf("[worker %v] unexpected state transition: received %v for state %v", id, wst, s)) 205 } 206 207 log := func(msg string) { 208 t.Logf("[worker %v] %v\n", id, msg) 209 } 210 211 return func(ctx irrecoverable.SignalerContext, ready component.ReadyFunc) { 212 var state WorkerState 213 goto startingUp 214 215 startingUp: 216 log("starting up") 217 state = WorkerStartingUp 218 switch transition := next(state); transition { 219 case WorkerCheckCtxAndExit: 220 if util.CheckClosed(ctx.Done()) { 221 goto startupCanceled 222 } 223 goto startingUp 224 case WorkerCheckCtxAndShutdown: 225 if util.CheckClosed(ctx.Done()) { 226 goto startupShuttingDown 227 } 228 goto startingUp 229 case WorkerDoWork: 230 goto startingUp 231 case WorkerFinishStartup: 232 ready() 233 goto running 234 default: 235 unexpectedStateTransition(state, transition) 236 } 237 238 startupShuttingDown: 239 log("startup shutting down") 240 state = WorkerStartupShuttingDown 241 switch transition := next(state); transition { 242 case WorkerDoWork: 243 goto startupShuttingDown 244 case WorkerExit: 245 goto startupCanceled 246 case WorkerThrowError: 247 goto startupEncounteredFatal 248 default: 249 unexpectedStateTransition(state, transition) 250 } 251 252 startupCanceled: 253 log("startup canceled") 254 state = WorkerStartupCanceled 255 next(state) 256 return 257 258 startupEncounteredFatal: 259 log("startup encountered fatal") 260 state = WorkerStartupEncounteredFatal 261 defer next(state) 262 ctx.Throw(&WorkerError{id}) 263 264 running: 265 log("running") 266 state = WorkerRunning 267 switch transition := next(state); transition { 268 case WorkerCheckCtxAndExit: 269 if util.CheckClosed(ctx.Done()) { 270 goto canceled 271 } 272 goto running 273 case WorkerCheckCtxAndShutdown: 274 if util.CheckClosed(ctx.Done()) { 275 goto shuttingDown 276 } 277 goto running 278 case WorkerDoWork: 279 goto running 280 case WorkerExit: 281 goto done 282 case WorkerThrowError: 283 goto encounteredFatal 284 default: 285 unexpectedStateTransition(state, transition) 286 } 287 288 shuttingDown: 289 log("shutting down") 290 state = WorkerShuttingDown 291 switch transition := next(state); transition { 292 case WorkerDoWork: 293 goto shuttingDown 294 case WorkerExit: 295 goto canceled 296 case WorkerThrowError: 297 goto encounteredFatal 298 default: 299 unexpectedStateTransition(state, transition) 300 } 301 302 canceled: 303 log("canceled") 304 state = WorkerCanceled 305 next(state) 306 return 307 308 encounteredFatal: 309 log("encountered fatal") 310 state = WorkerEncounteredFatal 311 defer next(state) 312 ctx.Throw(&WorkerError{id}) 313 314 done: 315 log("done") 316 state = WorkerDone 317 next(state) 318 } 319 } 320 321 type WorkerError struct { 322 id int 323 } 324 325 func (e *WorkerError) Is(target error) bool { 326 if t, ok := target.(*WorkerError); ok { 327 return t.id == e.id 328 } 329 return false 330 } 331 332 func (e *WorkerError) Error() string { 333 return fmt.Sprintf("[worker %v] irrecoverable error", e.id) 334 } 335 336 // StartStateTransition returns a pair of functions AddTransition and ExecuteTransitions. 337 // AddTransition is called to add a state transition step, and then ExecuteTransitions shuffles 338 // all of the added steps and executes them in a random order. 339 func StartStateTransition() (func(t func()), func(*rapid.T)) { 340 var transitions []func() 341 342 addTransition := func(t func()) { 343 transitions = append(transitions, t) 344 } 345 346 executeTransitions := func(t *rapid.T) { 347 for i := 0; i < len(transitions); i++ { 348 j := rapid.IntRange(0, len(transitions)-i-1).Draw(t, "").(int) 349 transitions[i], transitions[j+i] = transitions[j+i], transitions[i] 350 transitions[i]() 351 } 352 // TODO: is this simpler? 353 // executionOrder := rapid.SliceOfNDistinct( 354 // rapid.IntRange(0, len(transitions)-1), len(transitions), len(transitions), nil, 355 // ).Draw(t, "transition_execution_order").([]int) 356 // for _, i := range executionOrder { 357 // transitions[i]() 358 // } 359 } 360 361 return addTransition, executeTransitions 362 } 363 364 type StateTransition struct { 365 cancel bool 366 workerIDs []int 367 workerTransitions []WorkerStateTransition 368 } 369 370 func (st *StateTransition) String() string { 371 return fmt.Sprintf( 372 "stateTransition{ cancel=%v, workerIDs=%v, workerTransitions=%v }", 373 st.cancel, st.workerIDs, st.workerTransitions, 374 ) 375 } 376 377 type ComponentManagerMachine struct { 378 cm *component.ComponentManager 379 380 cancel context.CancelFunc 381 workerTransitionConsumers []WSTConsumer 382 383 canceled bool 384 workerErrors error 385 workerStates []WorkerState 386 387 resetChannelReadTimeout func() 388 assertClosed func(t *rapid.T, ch <-chan struct{}, msgAndArgs ...interface{}) 389 assertNotClosed func(t *rapid.T, ch <-chan struct{}, msgAndArgs ...interface{}) 390 assertErrorThrownMatches func(t *rapid.T, err error, msgAndArgs ...interface{}) 391 assertErrorNotThrown func(t *rapid.T) 392 393 cancelGenerator *rapid.Generator 394 drawStateTransition func(t *rapid.T) *StateTransition 395 } 396 397 func (c *ComponentManagerMachine) Init(t *rapid.T) { 398 numWorkers := rapid.IntRange(0, 5).Draw(t, "num_workers").(int) 399 pCancel := rapid.Float64Range(0, 100).Draw(t, "p_cancel").(float64) 400 401 c.cancelGenerator = rapid.Float64Range(0, 100). 402 Map(func(n float64) bool { 403 return pCancel == 100 || n < pCancel 404 }) 405 406 c.drawStateTransition = func(t *rapid.T) *StateTransition { 407 st := &StateTransition{} 408 409 if !c.canceled { 410 st.cancel = c.cancelGenerator.Draw(t, "cancel").(bool) 411 } 412 413 for workerId, state := range c.workerStates { 414 if allowedTransitions, ok := WorkerStateTransitions[state]; ok { 415 label := fmt.Sprintf("worker_transition_%v", workerId) 416 st.workerIDs = append(st.workerIDs, workerId) 417 st.workerTransitions = append(st.workerTransitions, rapid.SampledFrom(allowedTransitions).Draw(t, label).(WorkerStateTransition)) 418 } 419 } 420 421 return rapid.Just(st).Draw(t, "state_transition").(*StateTransition) 422 } 423 424 ctx, cancel := context.WithCancel(context.Background()) 425 c.cancel = cancel 426 427 signalerCtx, errChan := irrecoverable.WithSignaler(ctx) 428 429 var channelReadTimeout <-chan struct{} 430 var signalerErr error 431 432 c.resetChannelReadTimeout = func() { 433 ctx, cancel := context.WithTimeout(context.Background(), CHANNEL_CLOSE_LATENCY_ALLOWANCE) 434 _ = cancel 435 channelReadTimeout = ctx.Done() 436 } 437 438 c.assertClosed = func(t *rapid.T, ch <-chan struct{}, msgAndArgs ...interface{}) { 439 select { 440 case <-ch: 441 default: 442 select { 443 case <-channelReadTimeout: 444 assert.Fail(t, "channel is not closed", msgAndArgs...) 445 case <-ch: 446 } 447 } 448 } 449 450 c.assertNotClosed = func(t *rapid.T, ch <-chan struct{}, msgAndArgs ...interface{}) { 451 select { 452 case <-ch: 453 assert.Fail(t, "channel is closed", msgAndArgs...) 454 default: 455 select { 456 case <-ch: 457 assert.Fail(t, "channel is closed", msgAndArgs...) 458 case <-channelReadTimeout: 459 } 460 } 461 } 462 463 c.assertErrorThrownMatches = func(t *rapid.T, err error, msgAndArgs ...interface{}) { 464 if signalerErr == nil { 465 select { 466 case signalerErr = <-errChan: 467 default: 468 select { 469 case <-channelReadTimeout: 470 assert.Fail(t, "error was not thrown") 471 return 472 case signalerErr = <-errChan: 473 } 474 } 475 } 476 477 assert.ErrorIs(t, err, signalerErr, msgAndArgs...) 478 } 479 480 c.assertErrorNotThrown = func(t *rapid.T) { 481 if signalerErr == nil { 482 select { 483 case signalerErr = <-errChan: 484 default: 485 select { 486 case signalerErr = <-errChan: 487 case <-channelReadTimeout: 488 return 489 } 490 } 491 } 492 493 assert.Fail(t, "error was thrown: %v", signalerErr) 494 } 495 496 c.workerTransitionConsumers = make([]WSTConsumer, numWorkers) 497 c.workerStates = make([]WorkerState, numWorkers) 498 499 cmb := component.NewComponentManagerBuilder() 500 501 for i := 0; i < numWorkers; i++ { 502 wtc, wtp := MakeWorkerTransitionFuncs() 503 c.workerTransitionConsumers[i] = wtc 504 cmb.AddWorker(ComponentWorker(t, i, wtp)) 505 } 506 507 c.cm = cmb.Build() 508 c.cm.Start(signalerCtx) 509 510 for i := 0; i < numWorkers; i++ { 511 c.workerStates[i] = WorkerStartingUp 512 } 513 } 514 515 func (c *ComponentManagerMachine) ExecuteStateTransition(t *rapid.T) { 516 st := c.drawStateTransition(t) 517 518 t.Logf("drew state transition: %v\n", st) 519 520 var errors *multierror.Error 521 522 addTransition, executeTransitionsInRandomOrder := StartStateTransition() 523 524 if st.cancel { 525 addTransition(func() { 526 t.Log("executing cancel transition\n") 527 c.cancel() 528 c.canceled = true 529 c.resetChannelReadTimeout() 530 c.assertClosed(t, c.cm.ShutdownSignal()) 531 }) 532 } 533 534 for i, workerId := range st.workerIDs { 535 i := i 536 workerId := workerId 537 addTransition(func() { 538 wst := st.workerTransitions[i] 539 t.Logf("executing worker %v transition: %v\n", workerId, wst) 540 endState := c.workerTransitionConsumers[workerId](wst) 541 CheckWorkerStateTransition(t, workerId, c.workerStates[workerId], endState, wst, c.canceled) 542 c.workerStates[workerId] = endState 543 544 if (WorkerStateList{WorkerStartupEncounteredFatal, WorkerEncounteredFatal}).Contains(endState) { 545 err := &WorkerError{workerId} 546 require.NotErrorIs(t, c.workerErrors, err) 547 require.NotErrorIs(t, errors, err) 548 errors = multierror.Append(errors, err) 549 c.canceled = true 550 c.resetChannelReadTimeout() 551 c.assertClosed(t, c.cm.ShutdownSignal()) 552 } 553 }) 554 } 555 556 executeTransitionsInRandomOrder(t) 557 558 if c.workerErrors == nil { 559 c.workerErrors = errors.ErrorOrNil() 560 } 561 562 t.Logf("end state: { canceled=%v, workerErrors=%v, workerStates=%v }\n", c.canceled, c.workerErrors, c.workerStates) 563 } 564 565 func (c *ComponentManagerMachine) Check(t *rapid.T) { 566 c.resetChannelReadTimeout() 567 568 if c.canceled { 569 c.assertClosed(t, c.cm.ShutdownSignal(), "context is canceled but component manager shutdown signal is not closed") 570 } 571 572 allWorkersReady := true 573 allWorkersDone := true 574 575 for workerID, state := range c.workerStates { 576 if (WorkerStateList{ 577 WorkerStartingUp, 578 WorkerStartupShuttingDown, 579 WorkerStartupCanceled, 580 WorkerStartupEncounteredFatal, 581 }).Contains(state) { 582 allWorkersReady = false 583 c.assertNotClosed(t, c.cm.Ready(), "worker %v has not finished startup but component manager ready channel is closed", workerID) 584 } 585 586 if !(WorkerStateList{ 587 WorkerStartupCanceled, 588 WorkerStartupEncounteredFatal, 589 WorkerCanceled, 590 WorkerEncounteredFatal, 591 WorkerDone, 592 }).Contains(state) { 593 allWorkersDone = false 594 c.assertNotClosed(t, c.cm.Done(), "worker %v has not exited but component manager done channel is closed", workerID) 595 } 596 597 if (WorkerStateList{ 598 WorkerStartupShuttingDown, 599 WorkerStartupCanceled, 600 WorkerStartupEncounteredFatal, 601 WorkerShuttingDown, 602 WorkerCanceled, 603 WorkerEncounteredFatal, 604 }).Contains(state) { 605 c.assertClosed(t, c.cm.ShutdownSignal(), "worker %v has been canceled or encountered a fatal error but component manager shutdown signal is not closed", workerID) 606 } 607 } 608 609 if allWorkersReady { 610 c.assertClosed(t, c.cm.Ready(), "all workers are ready but component manager ready channel is not closed") 611 } 612 613 if allWorkersDone { 614 c.assertClosed(t, c.cm.Done(), "all workers are done but component manager done channel is not closed") 615 } 616 617 if c.workerErrors != nil { 618 c.assertErrorThrownMatches(t, c.workerErrors, "error received by signaler does not match any of the ones thrown") 619 c.assertClosed(t, c.cm.ShutdownSignal(), "fatal error thrown but context is not canceled") 620 } else { 621 c.assertErrorNotThrown(t) 622 } 623 } 624 625 func TestComponentManager(t *testing.T) { 626 unittest.SkipUnless(t, unittest.TEST_LONG_RUNNING, "skip because this test takes too long") 627 628 rapid.Check(t, rapid.Run(&ComponentManagerMachine{})) 629 } 630 631 func TestComponentManagerShutdown(t *testing.T) { 632 mgr := component.NewComponentManagerBuilder(). 633 AddWorker(func(ctx irrecoverable.SignalerContext, ready component.ReadyFunc) { 634 ready() 635 <-ctx.Done() 636 }).Build() 637 638 parent, cancel := context.WithCancel(context.Background()) 639 ctx, _ := irrecoverable.WithSignaler(parent) 640 641 mgr.Start(ctx) 642 unittest.AssertClosesBefore(t, mgr.Ready(), 10*time.Millisecond) 643 cancel() 644 645 // ShutdownSignal indicates we have started shutdown, Done indicates we have completed 646 // shutdown. If we have completed shutdown, we must have started shutdown. 647 unittest.AssertClosesBefore(t, mgr.Done(), 10*time.Millisecond) 648 closed := util.CheckClosed(mgr.ShutdownSignal()) 649 assert.True(t, closed) 650 } 651 652 // run the test many times to reproduce consistently 653 func TestComponentManagerShutdown_100(t *testing.T) { 654 for i := 0; i < 100; i++ { 655 TestComponentManagerShutdown(t) 656 } 657 }