github.com/status-im/status-go@v1.1.0/services/wallet/async/scheduler_test.go (about)

     1  package async
     2  
     3  import (
     4  	"context"
     5  	"errors"
     6  	"sync"
     7  	"testing"
     8  	"time"
     9  
    10  	"github.com/stretchr/testify/require"
    11  )
    12  
// Markers exchanged over the tests' call channels to record which callback ran.
const (
	// noActionPerformed is the initial "last result" value used before any callback has reported.
	noActionPerformed = "no action performed"
	// taskCalled is sent by a task function when it runs.
	taskCalled        = "task called"
	// taskResultCalled is sent by a result callback when it runs.
	taskResultCalled  = "task result called"
)
    18  
    19  func TestScheduler_Enqueue_Simple(t *testing.T) {
    20  	s := NewScheduler()
    21  	callChan := make(chan string, 10)
    22  
    23  	testFunction := func(policy ReplacementPolicy, failTest bool) {
    24  		testTask := TaskType{1, policy}
    25  		ignored := s.Enqueue(testTask, func(ctx context.Context) (interface{}, error) {
    26  			callChan <- taskCalled
    27  			if failTest {
    28  				return nil, errors.New("test error")
    29  			}
    30  			return 123, nil
    31  		}, func(res interface{}, taskType TaskType, err error) {
    32  			if failTest {
    33  				require.Error(t, err)
    34  				require.Nil(t, res)
    35  			} else {
    36  				require.NoError(t, err)
    37  				require.Equal(t, 123, res)
    38  			}
    39  			require.Equal(t, testTask, taskType)
    40  			callChan <- taskResultCalled
    41  		})
    42  		require.False(t, ignored)
    43  
    44  		lastRes := noActionPerformed
    45  		done := false
    46  		for !done {
    47  			select {
    48  			case callRes := <-callChan:
    49  				if callRes == taskCalled {
    50  					require.Equal(t, noActionPerformed, lastRes)
    51  				} else if callRes == taskResultCalled {
    52  					require.Equal(t, taskCalled, lastRes)
    53  					done = true
    54  				} else {
    55  					require.Fail(t, "unexpected result", `"%s" for policy %d`, callRes, policy)
    56  				}
    57  				lastRes = callRes
    58  			case <-time.After(1 * time.Second):
    59  				require.Fail(t, "test not completed in time", `last result: "%s" for policy %d`, lastRes, policy)
    60  			}
    61  		}
    62  
    63  		require.Equal(t, taskResultCalled, lastRes)
    64  	}
    65  
    66  	testFailed := false
    67  	for i := 0; i < 2; i++ {
    68  		testFailed = (i == 0)
    69  		for policy := range []ReplacementPolicy{ReplacementPolicyCancelOld, ReplacementPolicyIgnoreNew} {
    70  			testFunction(policy, testFailed)
    71  		}
    72  	}
    73  }
    74  
    75  // Validate the task is cancelled when a new one is scheduled and that the third one will overwrite the second one
    76  func TestScheduler_Enqueue_VerifyReplacementPolicyCancelOld(t *testing.T) {
    77  	s := NewScheduler()
    78  
    79  	type testStage string
    80  	const (
    81  		stage1FirstTaskStarted                testStage = "First task started"
    82  		stage2ThirdEnqueueOverwroteSecondTask testStage = "Third Enqueue overwrote second task"
    83  		stage3ExitingFirstCancelledTask       testStage = "Exiting first cancelled task"
    84  		stage5ThirdTaskRunning                testStage = "Third task running"
    85  		stage6ThirdTaskResponse               testStage = "Third task response"
    86  	)
    87  
    88  	testStages := []testStage{
    89  		stage1FirstTaskStarted,
    90  		stage2ThirdEnqueueOverwroteSecondTask,
    91  		stage3ExitingFirstCancelledTask,
    92  		stage5ThirdTaskRunning,
    93  		stage6ThirdTaskResponse,
    94  	}
    95  
    96  	callChan := make(chan testStage, len(testStages))
    97  	var firstRunWG, secondRunWG, thirdRunWG sync.WaitGroup
    98  
    99  	firstRunWG.Add(1)
   100  	secondRunWG.Add(1)
   101  	thirdRunWG.Add(1)
   102  
   103  	stage4AsyncFirstTaskCanceledResponse := false
   104  
   105  	testTask := TaskType{1, ReplacementPolicyCancelOld}
   106  	for i := 0; i < 2; i++ {
   107  		currentIndex := i
   108  		ignored := s.Enqueue(testTask, func(workCtx context.Context) (interface{}, error) {
   109  			callChan <- stage1FirstTaskStarted
   110  
   111  			// Mark first task running so that the second Enqueue will cancel this one and overwrite it
   112  			firstRunWG.Done()
   113  
   114  			// Wait for the first task to be cancelled by the second one
   115  			select {
   116  			case <-workCtx.Done():
   117  				require.ErrorAs(t, workCtx.Err(), &context.Canceled)
   118  
   119  				// Unblock the third Enqueue call
   120  				secondRunWG.Done()
   121  
   122  				// Block the second task from running until the third one is overwriting the second one that didn't run
   123  				thirdRunWG.Wait()
   124  				callChan <- stage3ExitingFirstCancelledTask
   125  			case <-time.After(1 * time.Second):
   126  				require.Fail(t, "task not cancelled in time")
   127  			}
   128  			return nil, workCtx.Err()
   129  		}, func(res interface{}, taskType TaskType, err error) {
   130  			switch currentIndex {
   131  			case 0:
   132  				// First task was cancelled by the second one Enqueue call
   133  				stage4AsyncFirstTaskCanceledResponse = true
   134  
   135  				require.ErrorAs(t, err, &context.Canceled)
   136  			case 1:
   137  				callChan <- stage2ThirdEnqueueOverwroteSecondTask
   138  
   139  				// Unblock the first task from blocking execution of the third one
   140  				// also validate that the third Enqueue call overwrote running the second one
   141  				thirdRunWG.Done()
   142  
   143  				require.True(t, errors.Is(err, ErrTaskOverwritten))
   144  			}
   145  		})
   146  		require.False(t, ignored)
   147  
   148  		// Wait first task to run
   149  		firstRunWG.Wait()
   150  	}
   151  	// Wait for the second task to be cancelled before running the third one
   152  	secondRunWG.Wait()
   153  
   154  	ignored := s.Enqueue(testTask, func(ctx context.Context) (interface{}, error) {
   155  		callChan <- stage5ThirdTaskRunning
   156  		return 123, errors.New("test error")
   157  	}, func(res interface{}, taskType TaskType, err error) {
   158  		require.Error(t, err)
   159  		require.Equal(t, testTask, taskType)
   160  		require.Equal(t, 123, res)
   161  
   162  		callChan <- stage6ThirdTaskResponse
   163  	})
   164  	require.False(t, ignored)
   165  
   166  	lastRes := noActionPerformed
   167  	expectedTestStageIndex := 0
   168  	for i := 0; i < len(testStages); i++ {
   169  		select {
   170  		case callRes := <-callChan:
   171  			require.Equal(t, testStages[expectedTestStageIndex], callRes, "task stage out of order; expected %s, got %s", testStages[expectedTestStageIndex], callRes)
   172  			expectedTestStageIndex++
   173  		case <-time.After(1 * time.Second):
   174  			require.Fail(t, "test not completed in time", `last result: "%s" for cancel task policy`, lastRes)
   175  		}
   176  	}
   177  	require.True(t, stage4AsyncFirstTaskCanceledResponse)
   178  }
   179  
   180  func TestScheduler_Enqueue_VerifyReplacementPolicyIgnoreNew(t *testing.T) {
   181  	s := NewScheduler()
   182  	callChan := make(chan string, 10)
   183  	workloadWG := sync.WaitGroup{}
   184  	taskCallCount := 0
   185  	resultCallCount := 0
   186  
   187  	workloadWG.Add(1)
   188  	testTask := TaskType{1, ReplacementPolicyIgnoreNew}
   189  	ignored := s.Enqueue(testTask, func(workCtx context.Context) (interface{}, error) {
   190  		workloadWG.Wait()
   191  		require.NoError(t, workCtx.Err())
   192  		taskCallCount++
   193  		callChan <- taskCalled
   194  		return 123, nil
   195  	}, func(res interface{}, taskType TaskType, err error) {
   196  		require.NoError(t, err)
   197  		require.Equal(t, testTask, taskType)
   198  		require.Equal(t, 123, res)
   199  		resultCallCount++
   200  		callChan <- taskResultCalled
   201  	})
   202  	require.False(t, ignored)
   203  
   204  	ignored = s.Enqueue(testTask, func(ctx context.Context) (interface{}, error) {
   205  		require.Fail(t, "unexpected call")
   206  		return nil, errors.New("unexpected call")
   207  	}, func(res interface{}, taskType TaskType, err error) {
   208  		require.Fail(t, "unexpected result call")
   209  	})
   210  	require.True(t, ignored)
   211  	workloadWG.Done()
   212  
   213  	lastRes := noActionPerformed
   214  	done := false
   215  	for !done {
   216  		select {
   217  		case callRes := <-callChan:
   218  			if callRes == taskCalled {
   219  				require.Equal(t, noActionPerformed, lastRes)
   220  			} else if callRes == taskResultCalled {
   221  				require.Equal(t, taskCalled, lastRes)
   222  				done = true
   223  			} else {
   224  				require.Fail(t, "unexpected result", `"%s" for ignore task policy`, callRes)
   225  			}
   226  			lastRes = callRes
   227  		case <-time.After(1 * time.Second):
   228  			require.Fail(t, "test not completed in time", `last result: "%s" for ignore task policy`, lastRes)
   229  		}
   230  	}
   231  
   232  	require.Equal(t, 1, resultCallCount)
   233  	require.Equal(t, 1, taskCallCount)
   234  
   235  	require.Equal(t, taskResultCalled, lastRes)
   236  }
   237  
   238  func TestScheduler_Enqueue_ValidateOrder(t *testing.T) {
   239  	s := NewScheduler()
   240  	waitEnqueueAll := sync.WaitGroup{}
   241  
   242  	type failType bool
   243  	const (
   244  		fail failType = true
   245  		pass failType = false
   246  	)
   247  
   248  	type enqueueParams struct {
   249  		taskType   TaskType
   250  		taskAction failType
   251  		callIndex  int
   252  	}
   253  	testTask1 := TaskType{1, ReplacementPolicyCancelOld}
   254  	testTask2 := TaskType{2, ReplacementPolicyCancelOld}
   255  	testTask3 := TaskType{3, ReplacementPolicyIgnoreNew}
   256  	// Task type, ReplacementPolicy: CancelOld if true IgnoreNew if false, task fail or success, index
   257  	enqueueSequence := []enqueueParams{
   258  		{testTask1, pass, 0}, // 1 task event
   259  		{testTask2, pass, 0}, // 0 task event
   260  		{testTask3, fail, 0}, // 1 task event
   261  		{testTask3, pass, 0}, // 0 task event
   262  		{testTask2, pass, 0}, // 1 task event
   263  		{testTask1, pass, 0}, // 1 task event
   264  		{testTask3, fail, 0}, // 0 run event
   265  	}
   266  	const taskEventCount = 4
   267  
   268  	taskSuccessChan := make(chan enqueueParams, len(enqueueSequence))
   269  	taskCanceledChan := make(chan enqueueParams, len(enqueueSequence))
   270  	taskFailedChan := make(chan enqueueParams, len(enqueueSequence))
   271  	resChan := make(chan enqueueParams, len(enqueueSequence))
   272  
   273  	firstIgnoreNewProcessed := make(map[TaskType]bool)
   274  
   275  	ignoredCount := 0
   276  
   277  	waitEnqueueAll.Add(1)
   278  	for i := 0; i < len(enqueueSequence); i++ {
   279  		enqueueSequence[i].callIndex = i
   280  
   281  		p := enqueueSequence[i]
   282  
   283  		currentIndex := i
   284  
   285  		ignored := s.Enqueue(p.taskType, func(ctx context.Context) (interface{}, error) {
   286  			waitEnqueueAll.Wait()
   287  
   288  			if p.taskType.Policy == ReplacementPolicyCancelOld && ctx.Err() != nil && errors.Is(ctx.Err(), context.Canceled) {
   289  				taskCanceledChan <- p
   290  				t.Logf("task canceled, task seq: %d, task type: %+v", currentIndex, p.taskType)
   291  				return nil, ctx.Err()
   292  			}
   293  
   294  			if p.taskAction == fail {
   295  				taskFailedChan <- p
   296  				return nil, errors.New("test error")
   297  			}
   298  			taskSuccessChan <- p
   299  			t.Logf("task executed successfully, task seq: %d, task type: %+v", currentIndex, p.taskType)
   300  			return 10 * (currentIndex + 1), nil
   301  		}, func(res interface{}, taskType TaskType, err error) {
   302  			require.Equal(t, p.taskType, taskType)
   303  			resChan <- p
   304  			t.Logf("response invoked, task seq: %d, task type: %+v, result: %+v", currentIndex, taskType, res)
   305  		})
   306  
   307  		if ignored {
   308  			t.Logf("task ignored, task seq: %d, task type: %+v", currentIndex, p.taskType)
   309  			ignoredCount++
   310  		}
   311  
   312  		if _, ok := firstIgnoreNewProcessed[p.taskType]; !ok {
   313  			require.False(t, ignored)
   314  			firstIgnoreNewProcessed[p.taskType] = p.taskType.Policy == ReplacementPolicyCancelOld
   315  		} else {
   316  			if p.taskType.Policy == ReplacementPolicyIgnoreNew {
   317  				require.True(t, ignored)
   318  			} else {
   319  				require.False(t, ignored)
   320  			}
   321  		}
   322  	}
   323  
   324  	waitEnqueueAll.Done()
   325  
   326  	taskSuccessCount := make(map[TaskType]int)
   327  	taskCanceledCount := make(map[TaskType]int)
   328  	taskFailedCount := make(map[TaskType]int)
   329  	resChanCount := make(map[TaskType]int)
   330  
   331  	// Only ignored don't generate result events
   332  	expectedEventsCount := len(enqueueSequence) - ignoredCount + taskEventCount
   333  	for i := 0; i < expectedEventsCount; i++ {
   334  		// Loop for run and result calls
   335  		select {
   336  		case p := <-taskSuccessChan:
   337  			taskSuccessCount[p.taskType]++
   338  		case p := <-taskCanceledChan:
   339  			taskCanceledCount[p.taskType]++
   340  		case p := <-taskFailedChan:
   341  			taskFailedCount[p.taskType]++
   342  		case p := <-resChan:
   343  			resChanCount[p.taskType]++
   344  		case <-time.After(1 * time.Second):
   345  			require.Fail(t, "test not completed in time")
   346  		}
   347  	}
   348  
   349  	require.Equal(t, 1, taskSuccessCount[testTask1], "expected one task call for type: %d had %d", 1, taskSuccessCount[testTask1])
   350  	require.Equal(t, 1, taskSuccessCount[testTask2], "expected one task call for type: %d had %d", 2, taskSuccessCount[testTask2])
   351  	require.Equal(t, 0, taskSuccessCount[testTask3], "expected no task call for type: %d had %d", 3, taskSuccessCount[testTask3])
   352  
   353  	require.Equal(t, 1, taskCanceledCount[testTask1], "expected one task call for type: %d had %d", 1, taskSuccessCount[testTask1])
   354  	require.Equal(t, 0, taskCanceledCount[testTask2], "expected no task call for type: %d had %d", 2, taskSuccessCount[testTask2])
   355  	require.Equal(t, 0, taskCanceledCount[testTask3], "expected no task call for type: %d had %d", 3, taskSuccessCount[testTask3])
   356  
   357  	require.Equal(t, 0, taskFailedCount[testTask1], "expected no task call for type: %d had %d", 1, taskSuccessCount[testTask1])
   358  	require.Equal(t, 0, taskFailedCount[testTask2], "expected no task call for type: %d had %d", 2, taskSuccessCount[testTask2])
   359  	require.Equal(t, 1, taskFailedCount[testTask3], "expected one task call for type: %d had %d", 3, taskSuccessCount[testTask3])
   360  
   361  	require.Equal(t, 2, resChanCount[testTask1], "expected two task call for type: %d had %d", 1, taskSuccessCount[testTask1])
   362  	require.Equal(t, 2, resChanCount[testTask2], "expected two task call for type: %d had %d", 2, taskSuccessCount[testTask2])
   363  	require.Equal(t, 1, resChanCount[testTask3], "expected one task call for type: %d had %d", 3, taskSuccessCount[testTask3])
   364  }
   365  
   366  func TestScheduler_Enqueue_InResult(t *testing.T) {
   367  	s := NewScheduler()
   368  	callChan := make(chan int, 6)
   369  
   370  	s.Enqueue(TaskType{ID: 1, Policy: ReplacementPolicyCancelOld},
   371  		func(ctx context.Context) (interface{}, error) {
   372  			callChan <- 0
   373  			return nil, nil
   374  		}, func(res interface{}, taskType TaskType, err error) {
   375  			callChan <- 1
   376  			s.Enqueue(TaskType{1, ReplacementPolicyCancelOld}, func(ctx context.Context) (interface{}, error) {
   377  				callChan <- 2
   378  				return nil, nil
   379  			}, func(res interface{}, taskType TaskType, err error) {
   380  				callChan <- 3
   381  				s.Enqueue(TaskType{1, ReplacementPolicyCancelOld}, func(ctx context.Context) (interface{}, error) {
   382  					callChan <- 4
   383  					return nil, nil
   384  				}, func(res interface{}, taskType TaskType, err error) {
   385  					callChan <- 5
   386  				})
   387  			})
   388  		},
   389  	)
   390  	for i := 0; i < 6; i++ {
   391  		select {
   392  		case res := <-callChan:
   393  			require.Equal(t, i, res)
   394  		case <-time.After(1 * time.Second):
   395  			require.Fail(t, "test not completed in time")
   396  		}
   397  	}
   398  }
   399  
   400  func TestScheduler_Enqueue_Quick_Stop(t *testing.T) {
   401  	scheduler := NewScheduler()
   402  
   403  	var wg sync.WaitGroup
   404  	wg.Add(2)
   405  
   406  	longRunningTask := func(ctx context.Context) (interface{}, error) {
   407  		defer wg.Done()
   408  		select {
   409  		case <-ctx.Done():
   410  			// we should reach here rather than other condition branch as Stop() canceled the context quickly
   411  			return nil, ctx.Err()
   412  		case <-time.After(10 * time.Second):
   413  			return "task completed", nil
   414  		}
   415  	}
   416  
   417  	resFn := func(res interface{}, taskType TaskType, err error) {
   418  		require.Error(t, err)
   419  		require.ErrorIs(t, err, context.Canceled)
   420  		wg.Done()
   421  	}
   422  
   423  	scheduler.Enqueue(TaskType{ID: 1, Policy: ReplacementPolicyCancelOld}, longRunningTask, resFn)
   424  
   425  	require.NotPanics(t, func() {
   426  		scheduler.Stop()
   427  		wg.Wait()
   428  	})
   429  }