github.com/matrixorigin/matrixone@v0.7.0/pkg/taskservice/task_runner_test.go (about)

     1  // Copyright 2022 Matrix Origin
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //      http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package taskservice
    16  
    17  import (
    18  	"context"
    19  	"sync"
    20  	"sync/atomic"
    21  	"testing"
    22  	"time"
    23  
    24  	"github.com/matrixorigin/matrixone/pkg/common/moerr"
    25  	"github.com/matrixorigin/matrixone/pkg/common/runtime"
    26  	"github.com/matrixorigin/matrixone/pkg/logutil"
    27  	"github.com/matrixorigin/matrixone/pkg/pb/task"
    28  	"github.com/stretchr/testify/assert"
    29  	"github.com/stretchr/testify/require"
    30  	"go.uber.org/zap"
    31  )
    32  
    33  func TestRunTask(t *testing.T) {
    34  	runTaskRunnerTest(t, func(r *taskRunner, s TaskService, store TaskStorage) {
    35  		c := make(chan struct{})
    36  		r.RegisterExecutor(0, func(ctx context.Context, task task.Task) error {
    37  			defer close(c)
    38  			return nil
    39  		})
    40  		mustAddTestTask(t, store, 1, newTestTask("t1"))
    41  		mustAllocTestTask(t, s, store, map[string]string{"t1": r.runnerID})
    42  		<-c
    43  	}, WithRunnerParallelism(1),
    44  		WithRunnerFetchInterval(time.Millisecond))
    45  }
    46  
    47  func TestRunTasksInParallel(t *testing.T) {
    48  	runTaskRunnerTest(t, func(r *taskRunner, s TaskService, store TaskStorage) {
    49  		wg := &sync.WaitGroup{}
    50  		wg.Add(2)
    51  		r.RegisterExecutor(0, func(ctx context.Context, task task.Task) error {
    52  			defer wg.Done()
    53  			time.Sleep(time.Millisecond * 200)
    54  			return nil
    55  		})
    56  		mustAddTestTask(t, store, 1, newTestTask("t1"))
    57  		mustAddTestTask(t, store, 1, newTestTask("t2"))
    58  		mustAllocTestTask(t, s, store, map[string]string{"t1": r.runnerID, "t2": r.runnerID})
    59  		wg.Wait()
    60  	}, WithRunnerParallelism(2),
    61  		WithRunnerFetchInterval(time.Millisecond))
    62  }
    63  
    64  func TestTooMuchTasksWillBlockAndEventuallyCanBeExecuted(t *testing.T) {
    65  	runTaskRunnerTest(t, func(r *taskRunner, s TaskService, store TaskStorage) {
    66  		c := make(chan struct{})
    67  		continueC := make(chan struct{})
    68  		v := uint32(0)
    69  		wait := time.Millisecond * 200
    70  		r.RegisterExecutor(0, func(ctx context.Context, task task.Task) error {
    71  			n := atomic.AddUint32(&v, 1)
    72  			if n == 2 {
    73  				defer close(c) // second task close the chan
    74  			}
    75  			if n == 1 {
    76  				time.Sleep(wait) // block first task
    77  				<-continueC
    78  			}
    79  
    80  			return nil
    81  		})
    82  		mustAddTestTask(t, store, 1, newTestTask("t1"))
    83  		mustAddTestTask(t, store, 1, newTestTask("t2"))
    84  		mustAllocTestTask(t, s, store, map[string]string{"t1": r.runnerID, "t2": r.runnerID})
    85  		select {
    86  		case <-c:
    87  			assert.Fail(t, "must block")
    88  		case <-time.After(wait):
    89  			assert.Equal(t, uint32(1), atomic.LoadUint32(&v))
    90  			close(continueC) // second task can be run
    91  		}
    92  		<-c
    93  		assert.Equal(t, uint32(2), atomic.LoadUint32(&v))
    94  	}, WithRunnerParallelism(1),
    95  		WithRunnerFetchInterval(time.Millisecond))
    96  }
    97  
    98  func TestHeartbeatWithRunningTask(t *testing.T) {
    99  	runTaskRunnerTest(t, func(r *taskRunner, s TaskService, store TaskStorage) {
   100  		c := make(chan struct{})
   101  		completeC := make(chan struct{})
   102  		n := uint32(0)
   103  		r.RegisterExecutor(0, func(ctx context.Context, task task.Task) error {
   104  			if atomic.AddUint32(&n, 1) == 2 {
   105  				close(c)
   106  			}
   107  			<-completeC
   108  			return nil
   109  		})
   110  		mustAddTestTask(t, store, 1, newTestTask("t1"))
   111  		mustAddTestTask(t, store, 1, newTestTask("t2"))
   112  		mustAllocTestTask(t, s, store, map[string]string{"t1": r.runnerID, "t2": r.runnerID})
   113  		<-c
   114  		mustWaitTestTaskHasHeartbeat(t, store, 2)
   115  		close(completeC)
   116  	}, WithRunnerParallelism(2),
   117  		WithRunnerHeartbeatInterval(time.Millisecond),
   118  		WithRunnerFetchInterval(time.Millisecond))
   119  }
   120  
   121  func TestRunTaskWithRetry(t *testing.T) {
   122  	runTaskRunnerTest(t, func(r *taskRunner, s TaskService, store TaskStorage) {
   123  		c := make(chan struct{})
   124  		n := uint32(0)
   125  		r.RegisterExecutor(0, func(ctx context.Context, task task.Task) error {
   126  			if atomic.AddUint32(&n, 1) == 1 {
   127  				return moerr.NewInternalError(context.TODO(), "error")
   128  			}
   129  			close(c)
   130  			return nil
   131  		})
   132  		v := newTestTask("t1")
   133  		v.Metadata.Options.MaxRetryTimes = 1
   134  		mustAddTestTask(t, store, 1, v)
   135  		mustAllocTestTask(t, s, store, map[string]string{"t1": r.runnerID})
   136  		<-c
   137  		assert.Equal(t, uint32(2), n)
   138  	}, WithRunnerParallelism(2),
   139  		WithRunnerHeartbeatInterval(time.Millisecond),
   140  		WithRunnerFetchInterval(time.Millisecond))
   141  }
   142  
   143  func TestRunTaskWithDisableRetry(t *testing.T) {
   144  	runTaskRunnerTest(t, func(r *taskRunner, s TaskService, store TaskStorage) {
   145  		c := make(chan struct{})
   146  		n := uint32(0)
   147  		r.RegisterExecutor(0, func(ctx context.Context, task task.Task) error {
   148  			close(c)
   149  			if atomic.AddUint32(&n, 1) == 1 {
   150  				return moerr.NewInternalError(context.TODO(), "error")
   151  			}
   152  			return nil
   153  		})
   154  		v := newTestTask("t1")
   155  		v.Metadata.Options.MaxRetryTimes = 0
   156  		mustAddTestTask(t, store, 1, v)
   157  		mustAllocTestTask(t, s, store, map[string]string{"t1": r.runnerID})
   158  		<-c
   159  		mustWaitTestTaskHasExecuteResult(t, store, 1)
   160  		v = mustGetTestTask(t, store, 1)[0]
   161  		assert.Equal(t, task.ResultCode_Failed, v.ExecuteResult.Code)
   162  	}, WithRunnerParallelism(2),
   163  		WithRunnerHeartbeatInterval(time.Millisecond),
   164  		WithRunnerFetchInterval(time.Millisecond))
   165  }
   166  
   167  func TestCancelRunningTask(t *testing.T) {
   168  	runTaskRunnerTest(t, func(r *taskRunner, s TaskService, store TaskStorage) {
   169  		c := make(chan struct{})
   170  		cancelC := make(chan struct{})
   171  		r.RegisterExecutor(0, func(ctx context.Context, task task.Task) error {
   172  			close(c)
   173  			<-ctx.Done()
   174  			close(cancelC)
   175  			return nil
   176  		})
   177  		v := newTestTask("t1")
   178  		v.Metadata.Options.MaxRetryTimes = 0
   179  		mustAddTestTask(t, store, 1, v)
   180  		mustAllocTestTask(t, s, store, map[string]string{"t1": r.runnerID})
   181  		<-c
   182  		v = mustGetTestTask(t, store, 1)[0]
   183  		v.Epoch++
   184  		mustUpdateTestTask(t, store, 1, []task.Task{v})
   185  		<-cancelC
   186  		r.mu.RLock()
   187  		defer r.mu.RUnlock()
   188  		assert.Equal(t, 0, len(r.mu.runningTasks))
   189  	}, WithRunnerParallelism(2),
   190  		WithRunnerHeartbeatInterval(time.Millisecond),
   191  		WithRunnerFetchInterval(time.Millisecond))
   192  }
   193  
   194  func runTaskRunnerTest(t *testing.T,
   195  	testFunc func(r *taskRunner, s TaskService, store TaskStorage),
   196  	opts ...RunnerOption) {
   197  	store := NewMemTaskStorage()
   198  	s := NewTaskService(runtime.DefaultRuntime(), store)
   199  	defer func() {
   200  		assert.NoError(t, s.Close())
   201  	}()
   202  
   203  	opts = append(opts, WithRunnerLogger(logutil.GetPanicLoggerWithLevel(zap.DebugLevel)))
   204  	r := NewTaskRunner("r1", s, opts...)
   205  
   206  	require.NoError(t, r.Start())
   207  	defer func() {
   208  		require.NoError(t, r.Stop())
   209  	}()
   210  	testFunc(r.(*taskRunner), s, store)
   211  }
   212  
   213  func mustAllocTestTask(t *testing.T, s TaskService, store TaskStorage, alloc map[string]string) {
   214  	ctx, cancel := context.WithTimeout(context.Background(), time.Second*10)
   215  	defer cancel()
   216  
   217  	tasks := mustGetTestTask(t, store, len(alloc), WithTaskStatusCond(EQ, task.TaskStatus_Created))
   218  	n := 0
   219  	for _, v := range tasks {
   220  		if runner, ok := alloc[v.Metadata.ID]; ok {
   221  			require.NoError(t, s.Allocate(ctx, v, runner))
   222  			n++
   223  		}
   224  	}
   225  	if n != len(alloc) {
   226  		require.Fail(t, "task not found")
   227  	}
   228  }
   229  
   230  func mustWaitTestTaskHasHeartbeat(t *testing.T, store TaskStorage, expectHasHeartbeatCount int) {
   231  	ctx, cancel := context.WithTimeout(context.Background(), time.Second*10)
   232  	defer cancel()
   233  
   234  	for {
   235  		select {
   236  		case <-ctx.Done():
   237  			require.Fail(t, "wait heatbeat timeout")
   238  		default:
   239  			tasks := mustGetTestTask(t, store, expectHasHeartbeatCount,
   240  				WithTaskStatusCond(EQ, task.TaskStatus_Running))
   241  			n := 0
   242  			for _, v := range tasks {
   243  				if v.LastHeartbeat > 0 {
   244  					n++
   245  				}
   246  			}
   247  			if n == len(tasks) {
   248  				return
   249  			}
   250  		}
   251  	}
   252  }
   253  
   254  func mustWaitTestTaskHasExecuteResult(t *testing.T, store TaskStorage, expectCount int) {
   255  	ctx, cancel := context.WithTimeout(context.Background(), time.Second*10)
   256  	defer cancel()
   257  
   258  	for {
   259  		select {
   260  		case <-ctx.Done():
   261  			require.Fail(t, "wait execute result timeout")
   262  		default:
   263  			tasks, err := store.Query(ctx, WithTaskStatusCond(EQ, task.TaskStatus_Completed))
   264  			require.NoError(t, err)
   265  			if len(tasks) != expectCount {
   266  				break
   267  			}
   268  			n := 0
   269  			for _, v := range tasks {
   270  				if v.ExecuteResult != nil {
   271  					n++
   272  				}
   273  			}
   274  			if n == len(tasks) {
   275  				return
   276  			}
   277  		}
   278  	}
   279  }