github.com/ngicks/gokugen@v0.0.5/scheduler/task_test.go (about)

     1  package scheduler_test
     2  
     3  import (
     4  	"context"
     5  	"math/rand"
     6  	"sync"
     7  	"sync/atomic"
     8  	"testing"
     9  	"time"
    10  
    11  	"github.com/ngicks/gokugen/scheduler"
    12  )
    13  
    14  type mockTaskSet struct {
    15  	task                              *scheduler.Task
    16  	workCallCount, isContextCancelled *int32
    17  	selectCh                          chan struct{}
    18  	cancel                            func()
    19  }
    20  
    21  func (s *mockTaskSet) WorkCallCount() int32 {
    22  	return atomic.LoadInt32(s.workCallCount)
    23  }
    24  func (s *mockTaskSet) IsContextCancelled() bool {
    25  	return atomic.LoadInt32(s.isContextCancelled) == 1
    26  }
    27  func (s *mockTaskSet) GetSelectCh() <-chan struct{} {
    28  	return s.selectCh
    29  }
    30  func (s *mockTaskSet) Close() {
    31  	s.cancel()
    32  }
    33  func (s *mockTaskSet) Task() *scheduler.Task {
    34  	return s.task
    35  }
    36  
    37  func mockTaskFactory() *mockTaskSet {
    38  	var workCallCount, isContextCancelled int32
    39  
    40  	selectCh := make(chan struct{})
    41  	ctx, cancel := context.WithCancel(context.Background())
    42  
    43  	testSet := &mockTaskSet{
    44  		workCallCount:      &workCallCount,
    45  		isContextCancelled: &isContextCancelled,
    46  		selectCh:           selectCh,
    47  		cancel:             cancel,
    48  	}
    49  	t := scheduler.NewTask(time.Now(), func(taskCtx context.Context, scheduled time.Time) {
    50  		atomic.AddInt32(&workCallCount, 1)
    51  		testSet.selectCh <- struct{}{}
    52  		go func() {
    53  			select {
    54  			case <-ctx.Done():
    55  				return
    56  			case <-taskCtx.Done():
    57  				atomic.StoreInt32(&isContextCancelled, 1)
    58  				testSet.selectCh <- struct{}{}
    59  			}
    60  			return
    61  		}()
    62  		<-ctx.Done()
    63  	})
    64  	testSet.task = t
    65  	return testSet
    66  }
    67  
    68  func exhaustSelectChan(ch <-chan struct{}) {
    69  	timer := time.NewTimer(time.Millisecond)
    70  loop:
    71  	for {
    72  		select {
    73  		case <-timer.C:
    74  			break loop
    75  		default:
    76  			{
    77  				select {
    78  				case <-timer.C:
    79  					break loop
    80  				case <-ch:
    81  				default:
    82  				}
    83  			}
    84  		}
    85  	}
    86  }
    87  
    88  func TestTask(t *testing.T) {
    89  	t.Run("cancel", func(t *testing.T) {
    90  		taskSet := mockTaskFactory()
    91  		defer exhaustSelectChan(taskSet.GetSelectCh())
    92  
    93  		task := taskSet.Task()
    94  
    95  		if task.IsCancelled() {
    96  			t.Fatalf("IsCancelled must be false")
    97  		}
    98  		if !task.Cancel() {
    99  			t.Fatalf("closed must be true")
   100  		}
   101  		for i := 0; i < 10; i++ {
   102  			// This does not block. Bacause task is already cancelled, internal work will no be called.
   103  			task.Do(context.TODO())
   104  			if !task.IsCancelled() {
   105  				t.Fatalf("IsCancelled must be true")
   106  			}
   107  			if task.Cancel() {
   108  				t.Fatalf("closed must be false")
   109  			}
   110  		}
   111  
   112  		if taskSet.WorkCallCount() != 0 {
   113  			t.Fatalf("work must be called")
   114  		}
   115  	})
   116  
   117  	t.Run("do and cancel", func(t *testing.T) {
   118  		taskSet := mockTaskFactory()
   119  		defer exhaustSelectChan(taskSet.GetSelectCh())
   120  		task := taskSet.Task()
   121  
   122  		if taskSet.WorkCallCount() != 0 {
   123  			t.Fatalf("work must not be called at this point")
   124  		}
   125  		if task.IsDone() {
   126  			t.Fatalf("IsDone must be false")
   127  		}
   128  
   129  		wg := sync.WaitGroup{}
   130  		wg.Add(1)
   131  		go func() {
   132  			task.Do(context.TODO())
   133  			wg.Done()
   134  		}()
   135  
   136  		<-taskSet.GetSelectCh()
   137  
   138  		taskSet.Close()
   139  		wg.Wait()
   140  
   141  		if taskSet.WorkCallCount() != 1 {
   142  			t.Fatalf("work call count is not correct")
   143  		}
   144  
   145  		// This does not block. Because if it's done, it does not call internal work
   146  		task.Do(context.TODO())
   147  
   148  		if !task.IsDone() {
   149  			t.Fatalf("IsDone must be true")
   150  		}
   151  
   152  		if !task.Cancel() {
   153  			t.Fatalf("closed must be true")
   154  		}
   155  		if !task.IsCancelled() {
   156  			t.Fatalf("IsCancelled must be true")
   157  		}
   158  
   159  		task.Do(context.TODO())
   160  
   161  		if taskSet.WorkCallCount() != 1 {
   162  			t.Fatalf("work call count is not correct")
   163  		}
   164  	})
   165  
   166  	t.Run("passing already closed chan to Do", func(t *testing.T) {
   167  		taskSet := mockTaskFactory()
   168  		defer exhaustSelectChan(taskSet.GetSelectCh())
   169  		task := taskSet.Task()
   170  
   171  		if task.IsDone() {
   172  			t.Fatalf("IsDone must be false")
   173  		}
   174  
   175  		ctx, cancel := context.WithCancel(context.Background())
   176  		cancel()
   177  		// This does not block.
   178  		task.Do(ctx)
   179  		if !task.IsDone() {
   180  			t.Fatalf("IsDone must be true")
   181  		}
   182  	})
   183  
   184  	t.Run("cancelling task and closing chan passed to Do", func(t *testing.T) {
   185  		taskSet := mockTaskFactory()
   186  		defer exhaustSelectChan(taskSet.GetSelectCh())
   187  		task := taskSet.Task()
   188  
   189  		ctx := context.Background()
   190  		wg := sync.WaitGroup{}
   191  		wg.Add(1)
   192  		go func() {
   193  			task.Do(ctx)
   194  			wg.Done()
   195  		}()
   196  
   197  		if taskSet.IsContextCancelled() {
   198  			t.Fatalf("ctx must NOT be cancelled at this point")
   199  		}
   200  		selectCh := taskSet.GetSelectCh()
   201  		// waiting for Do to start
   202  		<-selectCh
   203  
   204  		go func() {
   205  			task.Cancel()
   206  		}()
   207  		<-selectCh
   208  		if !taskSet.IsContextCancelled() {
   209  			t.Fatalf("ctx must be cancelled")
   210  		}
   211  
   212  		taskSet.Close()
   213  		wg.Wait()
   214  	})
   215  
   216  	t.Run("GetScheduledTime", func(t *testing.T) {
   217  		for i := 0; i < 10; i++ {
   218  			n := time.Now().Add(time.Duration(rand.Int()))
   219  			task := scheduler.NewTask(n, func(taskCtx context.Context, scheduled time.Time) {})
   220  			if n != task.GetScheduledTime() {
   221  				t.Errorf("time mismatched! passed=%s, received=%s", n, task.GetScheduledTime())
   222  			}
   223  		}
   224  	})
   225  }