github.com/pingcap/tiflow@v0.0.0-20240520035814-5bf52d54e205/engine/jobmaster/dm/worker_manager_test.go (about)

     1  // Copyright 2022 PingCAP, Inc.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // See the License for the specific language governing permissions and
    12  // limitations under the License.
    13  
    14  package dm
    15  
    16  import (
    17  	"context"
    18  	"sync"
    19  	"time"
    20  
    21  	"github.com/pingcap/log"
    22  	dmconfig "github.com/pingcap/tiflow/dm/config"
    23  	"github.com/pingcap/tiflow/engine/framework"
    24  	frameModel "github.com/pingcap/tiflow/engine/framework/model"
    25  	"github.com/pingcap/tiflow/engine/jobmaster/dm/config"
    26  	"github.com/pingcap/tiflow/engine/jobmaster/dm/metadata"
    27  	"github.com/pingcap/tiflow/engine/jobmaster/dm/runtime"
    28  	dmpkg "github.com/pingcap/tiflow/engine/pkg/dm"
    29  	resModel "github.com/pingcap/tiflow/engine/pkg/externalresource/model"
    30  	kvmock "github.com/pingcap/tiflow/engine/pkg/meta/mock"
    31  	"github.com/pingcap/tiflow/pkg/errors"
    32  	"github.com/stretchr/testify/mock"
    33  	"github.com/stretchr/testify/require"
    34  )
    35  
    36  func (t *testDMJobmasterSuite) TestUpdateWorkerStatus() {
    37  	jobCfg := &config.JobCfg{}
    38  	require.NoError(t.T(), jobCfg.DecodeFile(jobTemplatePath))
    39  	job := metadata.NewJob(jobCfg)
    40  	jobStore := metadata.NewJobStore(kvmock.NewMetaMock(), log.L())
    41  	unitStore := metadata.NewUnitStateStore(kvmock.NewMetaMock())
    42  	require.NoError(t.T(), jobStore.Put(context.Background(), job))
    43  	workerManager := NewWorkerManager("job_id", nil, jobStore, unitStore, nil, nil, nil, log.L(), resModel.ResourceTypeLocalFile)
    44  
    45  	require.Len(t.T(), workerManager.WorkerStatus(), 0)
    46  
    47  	source1 := jobCfg.Upstreams[0].SourceID
    48  	source2 := jobCfg.Upstreams[1].SourceID
    49  	workerStatus1 := runtime.InitWorkerStatus(source1, frameModel.WorkerDMDump, "worker-id-1")
    50  	workerStatus2 := runtime.InitWorkerStatus(source2, frameModel.WorkerDMDump, "worker-id-2")
    51  	require.True(t.T(), workerManager.allTombStone())
    52  
    53  	// Creating
    54  	workerManager.UpdateWorkerStatus(workerStatus1)
    55  	workerManager.UpdateWorkerStatus(workerStatus2)
    56  	workerStatusMap := workerManager.WorkerStatus()
    57  	require.Len(t.T(), workerStatusMap, 2)
    58  	require.Contains(t.T(), workerStatusMap, source1)
    59  	require.Contains(t.T(), workerStatusMap, source2)
    60  	require.Equal(t.T(), workerStatusMap[source1], workerStatus1)
    61  	require.Equal(t.T(), workerStatusMap[source2], workerStatus2)
    62  	require.False(t.T(), workerManager.allTombStone())
    63  
    64  	// Online
    65  	workerStatus1.Stage = runtime.WorkerOnline
    66  	workerStatus2.Stage = runtime.WorkerOnline
    67  	workerManager.UpdateWorkerStatus(workerStatus1)
    68  	workerManager.UpdateWorkerStatus(workerStatus2)
    69  	workerStatusMap = workerManager.WorkerStatus()
    70  	require.Len(t.T(), workerStatusMap, 2)
    71  	require.Contains(t.T(), workerStatusMap, source1)
    72  	require.Contains(t.T(), workerStatusMap, source2)
    73  	require.Equal(t.T(), workerStatusMap[source1], workerStatus1)
    74  	require.Equal(t.T(), workerStatusMap[source2], workerStatus2)
    75  	require.False(t.T(), workerManager.allTombStone())
    76  
    77  	// Offline
    78  	workerStatus1.Stage = runtime.WorkerOffline
    79  	workerManager.UpdateWorkerStatus(workerStatus1)
    80  	workerStatusMap = workerManager.WorkerStatus()
    81  	require.Len(t.T(), workerStatusMap, 2)
    82  	require.Contains(t.T(), workerStatusMap, source1)
    83  	require.Contains(t.T(), workerStatusMap, source2)
    84  	require.Equal(t.T(), workerStatusMap[source1], workerStatus1)
    85  	require.Equal(t.T(), workerStatusMap[source2], workerStatus2)
    86  	require.False(t.T(), workerManager.allTombStone())
    87  
    88  	// Finished
    89  	workerStatus1.Stage = runtime.WorkerFinished
    90  	workerManager.UpdateWorkerStatus(workerStatus1)
    91  	workerStatusMap = workerManager.WorkerStatus()
    92  	require.Len(t.T(), workerStatusMap, 2)
    93  	require.Contains(t.T(), workerStatusMap, source1)
    94  	require.Contains(t.T(), workerStatusMap, source2)
    95  	require.Equal(t.T(), workerStatusMap[source1], workerStatus1)
    96  	require.Equal(t.T(), workerStatusMap[source2], workerStatus2)
    97  	require.False(t.T(), workerManager.allTombStone())
    98  
    99  	// mock jobmaster recover
   100  	workerStatus1.Stage = runtime.WorkerOnline
   101  	workerStatus2.Stage = runtime.WorkerOnline
   102  	workerStatusList := []runtime.WorkerStatus{workerStatus1, workerStatus2}
   103  	workerManager = NewWorkerManager("job_id", workerStatusList, jobStore, nil, nil, nil, nil, log.L(), resModel.ResourceTypeLocalFile)
   104  	workerStatusMap = workerManager.WorkerStatus()
   105  	require.Len(t.T(), workerStatusMap, 2)
   106  	require.Contains(t.T(), workerStatusMap, source1)
   107  	require.Contains(t.T(), workerStatusMap, source2)
   108  	require.Equal(t.T(), workerStatusMap[source1], workerStatus1)
   109  	require.Equal(t.T(), workerStatusMap[source2], workerStatus2)
   110  	require.False(t.T(), workerManager.allTombStone())
   111  
   112  	// mock dispatch error
   113  	workerManager.removeWorkerStatusByWorkerID("worker-not-exist")
   114  	workerStatusMap = workerManager.WorkerStatus()
   115  	require.Len(t.T(), workerStatusMap, 2)
   116  	require.Contains(t.T(), workerStatusMap, source1)
   117  	require.Contains(t.T(), workerStatusMap, source2)
   118  	require.Equal(t.T(), workerStatusMap[source1], workerStatus1)
   119  	require.Equal(t.T(), workerStatusMap[source2], workerStatus2)
   120  	workerManager.removeWorkerStatusByWorkerID(workerStatus1.ID)
   121  	workerStatusMap = workerManager.WorkerStatus()
   122  	require.Len(t.T(), workerStatusMap, 1)
   123  	require.Contains(t.T(), workerStatusMap, source2)
   124  	require.Equal(t.T(), workerStatusMap[source2], workerStatus2)
   125  	require.False(t.T(), workerManager.allTombStone())
   126  
   127  	workerStatus2.Stage = runtime.WorkerFinished
   128  	workerManager.UpdateWorkerStatus(workerStatus2)
   129  	require.True(t.T(), workerManager.allTombStone())
   130  }
   131  
   132  func (t *testDMJobmasterSuite) TestClearWorkerStatus() {
   133  	messageAgent := &dmpkg.MockMessageAgent{}
   134  	ctx, cancel := context.WithCancel(context.Background())
   135  	defer cancel()
   136  	source1 := "source1"
   137  	source2 := "source2"
   138  	workerStatus1 := runtime.InitWorkerStatus(source1, frameModel.WorkerDMDump, "worker-id-1")
   139  	workerStatus2 := runtime.InitWorkerStatus(source2, frameModel.WorkerDMDump, "worker-id-2")
   140  
   141  	workerManager := NewWorkerManager("job_id", []runtime.WorkerStatus{workerStatus1, workerStatus2}, nil, nil, nil, messageAgent, nil,
   142  		log.L(), resModel.ResourceTypeLocalFile)
   143  	require.Len(t.T(), workerManager.WorkerStatus(), 2)
   144  
   145  	workerManager.removeOfflineWorkers()
   146  	require.Len(t.T(), workerManager.WorkerStatus(), 2)
   147  
   148  	workerStatus1.Stage = runtime.WorkerOffline
   149  	workerStatus2.Stage = runtime.WorkerOnline
   150  	workerManager.UpdateWorkerStatus(workerStatus1)
   151  	workerManager.UpdateWorkerStatus(workerStatus2)
   152  	workerManager.removeOfflineWorkers()
   153  	require.Len(t.T(), workerManager.WorkerStatus(), 1)
   154  
   155  	job := metadata.NewJob(&config.JobCfg{})
   156  	destroyError := errors.New("destroy error")
   157  
   158  	job.Tasks[source2] = metadata.NewTask(&config.TaskCfg{})
   159  	require.NoError(t.T(), workerManager.stopOutdatedWorkers(context.Background(), job))
   160  	messageAgent.On("SendMessage").Return(destroyError).Once()
   161  	jobCfg := &config.JobCfg{ModRevision: 1}
   162  	taskCfg := jobCfg.ToTaskCfg()
   163  	job.Tasks[source2] = metadata.NewTask(taskCfg)
   164  	require.EqualError(t.T(), workerManager.stopOutdatedWorkers(context.Background(), job), destroyError.Error())
   165  	messageAgent.On("SendMessage").Return(nil).Once()
   166  	job.Tasks[source2] = metadata.NewTask(taskCfg)
   167  	require.NoError(t.T(), workerManager.stopOutdatedWorkers(context.Background(), job))
   168  
   169  	job = metadata.NewJob(&config.JobCfg{})
   170  	job.Tasks[source2] = metadata.NewTask(&config.TaskCfg{})
   171  	err := workerManager.stopUnneededWorkers(ctx, job)
   172  	require.NoError(t.T(), err)
   173  	require.Len(t.T(), workerManager.WorkerStatus(), 1)
   174  
   175  	delete(job.Tasks, source2)
   176  	messageAgent.On("SendMessage").Return(destroyError).Once()
   177  	err = workerManager.stopUnneededWorkers(ctx, job)
   178  	require.EqualError(t.T(), err, destroyError.Error())
   179  	require.Len(t.T(), workerManager.WorkerStatus(), 1)
   180  
   181  	messageAgent.On("SendMessage").Return(nil).Once()
   182  	err = workerManager.stopUnneededWorkers(ctx, job)
   183  	require.NoError(t.T(), err)
   184  	require.Len(t.T(), workerManager.WorkerStatus(), 1)
   185  
   186  	workerStatus2.Stage = runtime.WorkerOffline
   187  	workerManager.UpdateWorkerStatus(workerStatus2)
   188  	require.NoError(t.T(), workerManager.stopUnneededWorkers(ctx, job))
   189  	workerManager.removeOfflineWorkers()
   190  	require.Len(t.T(), workerManager.WorkerStatus(), 0)
   191  
   192  	err = workerManager.onJobDel(context.Background())
   193  	require.NoError(t.T(), err)
   194  	require.Len(t.T(), workerManager.WorkerStatus(), 0)
   195  
   196  	workerStatus1.Stage = runtime.WorkerFinished
   197  	workerManager.UpdateWorkerStatus(workerStatus1)
   198  	workerStatus2.Stage = runtime.WorkerOnline
   199  	workerManager.UpdateWorkerStatus(workerStatus2)
   200  
   201  	messageAgent.On("SendMessage").Return(destroyError).Once()
   202  	messageAgent.On("SendMessage").Return(nil).Once()
   203  	workerManager.UpdateWorkerStatus(workerStatus1)
   204  	workerManager.UpdateWorkerStatus(workerStatus2)
   205  	require.Len(t.T(), workerManager.WorkerStatus(), 2)
   206  	err = workerManager.onJobDel(context.Background())
   207  	require.EqualError(t.T(), err, destroyError.Error())
   208  	require.Len(t.T(), workerManager.WorkerStatus(), 2)
   209  	workerStatus1.Stage = runtime.WorkerOffline
   210  	workerManager.UpdateWorkerStatus(workerStatus1)
   211  	workerManager.removeOfflineWorkers()
   212  	require.Len(t.T(), workerManager.WorkerStatus(), 1)
   213  
   214  	workerManager.UpdateWorkerStatus(runtime.InitWorkerStatus("task", frameModel.WorkerDMDump, "worker-id"))
   215  	require.Len(t.T(), workerManager.WorkerStatus(), 2)
   216  	workerManager.removeOfflineWorkers()
   217  	require.Len(t.T(), workerManager.WorkerStatus(), 2)
   218  	require.Eventually(t.T(), func() bool {
   219  		workerManager.removeOfflineWorkers()
   220  		return len(workerManager.WorkerStatus()) == 1
   221  	}, 10*time.Second, 200*time.Millisecond)
   222  }
   223  
   224  func (t *testDMJobmasterSuite) TestCreateWorker() {
   225  	mockAgent := &MockWorkerAgent{}
   226  	unitStore := metadata.NewUnitStateStore(kvmock.NewMetaMock())
   227  	workerManager := NewWorkerManager("job_id", nil, nil, unitStore, mockAgent, nil, nil, log.L(), resModel.ResourceTypeLocalFile)
   228  
   229  	jobCfg := &config.JobCfg{}
   230  	require.NoError(t.T(), jobCfg.DecodeFile(jobTemplatePath))
   231  	taskCfgs := jobCfg.ToTaskCfgs()
   232  	task1 := jobCfg.Upstreams[0].SourceID
   233  	worker1 := "worker1"
   234  	createError := errors.New("create error")
   235  	mockAgent.On("CreateWorker").Return("", createError).Once()
   236  	require.EqualError(t.T(), workerManager.createWorker(context.Background(), task1, frameModel.WorkerDMDump, taskCfgs[task1]), createError.Error())
   237  	require.Len(t.T(), workerManager.WorkerStatus(), 0)
   238  
   239  	workerStatus1 := runtime.InitWorkerStatus(task1, frameModel.WorkerDMDump, worker1)
   240  	mockAgent.On("CreateWorker").Return(worker1, createError).Once()
   241  	require.EqualError(t.T(), workerManager.createWorker(context.Background(), task1, frameModel.WorkerDMDump, taskCfgs[task1]), createError.Error())
   242  	workerStatusMap := workerManager.WorkerStatus()
   243  	require.Len(t.T(), workerStatusMap, 1)
   244  	require.Contains(t.T(), workerStatusMap, task1)
   245  	require.Equal(t.T(), workerStatusMap[task1].ID, workerStatus1.ID)
   246  
   247  	task2 := jobCfg.Upstreams[1].SourceID
   248  	worker2 := "worker2"
   249  	workerStatus2 := runtime.InitWorkerStatus(task2, frameModel.WorkerDMLoad, worker2)
   250  	mockAgent.On("CreateWorker").Return(worker2, nil).Once()
   251  	require.NoError(t.T(), workerManager.createWorker(context.Background(), task2, frameModel.WorkerDMLoad, taskCfgs[task2]))
   252  	workerStatusMap = workerManager.WorkerStatus()
   253  	require.Len(t.T(), workerStatusMap, 2)
   254  	require.Contains(t.T(), workerStatusMap, task1)
   255  	require.Contains(t.T(), workerStatusMap, task2)
   256  	require.Equal(t.T(), workerStatusMap[task1].ID, workerStatus1.ID)
   257  	require.Equal(t.T(), workerStatusMap[task2].ID, workerStatus2.ID)
   258  }
   259  
   260  func (t *testDMJobmasterSuite) TestGetUnit() {
   261  	ctx, cancel := context.WithCancel(context.Background())
   262  	defer cancel()
   263  	mockAgent := &MockCheckpointAgent{}
   264  	task := &metadata.Task{Cfg: &config.TaskCfg{}}
   265  	task.Cfg.TaskMode = dmconfig.ModeFull
   266  	workerManager := NewWorkerManager("job_id", nil, nil, nil, nil, nil, mockAgent, log.L(), resModel.ResourceTypeLocalFile)
   267  
   268  	workerStatus := runtime.NewWorkerStatus("source", frameModel.WorkerDMDump, "worker-id-1", runtime.WorkerOnline, 0)
   269  	require.Equal(t.T(), getNextUnit(task, workerStatus), frameModel.WorkerDMDump)
   270  	workerStatus.Stage = runtime.WorkerFinished
   271  	require.Equal(t.T(), getNextUnit(task, workerStatus), frameModel.WorkerDMLoad)
   272  	workerStatus.Stage = runtime.WorkerOnline
   273  	workerStatus.Unit = frameModel.WorkerDMLoad
   274  	require.Equal(t.T(), getNextUnit(task, workerStatus), frameModel.WorkerDMLoad)
   275  	workerStatus.Stage = runtime.WorkerFinished
   276  	require.Equal(t.T(), getNextUnit(task, workerStatus), frameModel.WorkerDMLoad)
   277  
   278  	task.Cfg.TaskMode = dmconfig.ModeAll
   279  	workerStatus.Unit = frameModel.WorkerDMDump
   280  	require.Equal(t.T(), getNextUnit(task, workerStatus), frameModel.WorkerDMLoad)
   281  	workerStatus.Unit = frameModel.WorkerDMLoad
   282  	require.Equal(t.T(), getNextUnit(task, workerStatus), frameModel.WorkerDMSync)
   283  	workerStatus.Unit = frameModel.WorkerDMSync
   284  	workerStatus.Stage = runtime.WorkerOnline
   285  	require.Equal(t.T(), getNextUnit(task, workerStatus), frameModel.WorkerDMSync)
   286  
   287  	task.Cfg.TaskMode = dmconfig.ModeIncrement
   288  	require.Equal(t.T(), getNextUnit(task, workerStatus), frameModel.WorkerDMSync)
   289  
   290  	task.Cfg.TaskMode = dmconfig.ModeFull
   291  	mockAgent.On("IsFresh", mock.Anything, mock.Anything, mock.Anything).Return(false, errors.New("checkpoint error")).Once()
   292  	unit, isFresh, err := workerManager.getCurrentUnit(ctx, task)
   293  	require.Error(t.T(), err)
   294  	require.Equal(t.T(), unit, frameModel.WorkerType(0))
   295  	require.False(t.T(), isFresh)
   296  	mockAgent.On("IsFresh", mock.Anything, mock.Anything, mock.Anything).Return(true, nil).Twice()
   297  	unit, isFresh, err = workerManager.getCurrentUnit(ctx, task)
   298  	require.NoError(t.T(), err)
   299  	require.Equal(t.T(), unit, frameModel.WorkerDMDump)
   300  	require.True(t.T(), isFresh)
   301  	mockAgent.On("IsFresh", mock.Anything, mock.Anything, mock.Anything).Return(false, nil).Once()
   302  	unit, isFresh, err = workerManager.getCurrentUnit(ctx, task)
   303  	require.NoError(t.T(), err)
   304  	require.Equal(t.T(), unit, frameModel.WorkerDMLoad)
   305  	require.False(t.T(), isFresh)
   306  
   307  	task.Cfg.TaskMode = dmconfig.ModeAll
   308  	mockAgent.On("IsFresh", mock.Anything, mock.Anything, mock.Anything).Return(true, nil).Times(3)
   309  	unit, isFresh, err = workerManager.getCurrentUnit(ctx, task)
   310  	require.NoError(t.T(), err)
   311  	require.Equal(t.T(), unit, frameModel.WorkerDMDump)
   312  	require.True(t.T(), isFresh)
   313  	mockAgent.On("IsFresh", mock.Anything, mock.Anything, mock.Anything).Return(true, nil).Once()
   314  	mockAgent.On("IsFresh", mock.Anything, mock.Anything, mock.Anything).Return(false, nil).Once()
   315  	unit, isFresh, err = workerManager.getCurrentUnit(ctx, task)
   316  	require.NoError(t.T(), err)
   317  	require.Equal(t.T(), unit, frameModel.WorkerDMLoad)
   318  	require.False(t.T(), isFresh)
   319  	mockAgent.On("IsFresh", mock.Anything, mock.Anything, mock.Anything).Return(false, nil).Once()
   320  	unit, isFresh, err = workerManager.getCurrentUnit(ctx, task)
   321  	require.NoError(t.T(), err)
   322  	require.Equal(t.T(), unit, frameModel.WorkerDMSync)
   323  	require.False(t.T(), isFresh)
   324  
   325  	task.Cfg.TaskMode = dmconfig.ModeIncrement
   326  	mockAgent.On("IsFresh", mock.Anything, mock.Anything, mock.Anything).Return(true, nil).Once()
   327  	unit, isFresh, err = workerManager.getCurrentUnit(ctx, task)
   328  	require.NoError(t.T(), err)
   329  	require.Equal(t.T(), unit, frameModel.WorkerDMSync)
   330  	require.True(t.T(), isFresh)
   331  	mockAgent.On("IsFresh", mock.Anything, mock.Anything, mock.Anything).Return(false, nil).Once()
   332  	unit, isFresh, err = workerManager.getCurrentUnit(ctx, task)
   333  	require.NoError(t.T(), err)
   334  	require.Equal(t.T(), unit, frameModel.WorkerDMSync)
   335  	require.False(t.T(), isFresh)
   336  }
   337  
   338  func (t *testDMJobmasterSuite) TestCheckAndScheduleWorkers() {
   339  	jobCfg := &config.JobCfg{}
   340  	require.NoError(t.T(), jobCfg.DecodeFile(jobTemplatePath))
   341  	jobCfg.TaskMode = dmconfig.ModeFull
   342  	job := metadata.NewJob(jobCfg)
   343  	checkpointAgent := &MockCheckpointAgent{}
   344  	workerAgent := &MockWorkerAgent{}
   345  	unitStore := metadata.NewUnitStateStore(kvmock.NewMetaMock())
   346  	workerManager := NewWorkerManager("job_id", nil, nil, unitStore, workerAgent, nil, checkpointAgent, log.L(), resModel.ResourceTypeLocalFile)
   347  
   348  	// new tasks
   349  	worker1 := "worker1"
   350  	worker2 := "worker2"
   351  	source1 := jobCfg.Upstreams[0].SourceID
   352  	source2 := jobCfg.Upstreams[1].SourceID
   353  	checkpointError := errors.New("checkpoint error")
   354  	createError := errors.New("create error")
   355  
   356  	getCurrentStatus := func() map[string]*metadata.UnitStatus {
   357  		state, err := workerManager.unitStore.Get(context.Background())
   358  		require.NoError(t.T(), err)
   359  		unitState, ok := state.(*metadata.UnitState)
   360  		require.True(t.T(), ok)
   361  		return unitState.CurrentUnitStatus
   362  	}
   363  	var currentStatus map[string]*metadata.UnitStatus
   364  
   365  	getTaskID := func() (string, string) {
   366  		if _, ok := currentStatus[source1]; ok {
   367  			return source1, source2
   368  		}
   369  		if _, ok := currentStatus[source2]; ok {
   370  			return source2, source1
   371  		}
   372  		return "", ""
   373  	}
   374  
   375  	checkpointAgent.On("IsFresh", mock.Anything, mock.Anything, mock.Anything).Return(true, nil).Times(3)
   376  	checkpointAgent.On("IsFresh", mock.Anything, mock.Anything, mock.Anything).Return(false, checkpointError).Once()
   377  	workerAgent.On("CreateWorker").Return(worker1, nil).Once()
   378  	require.EqualError(t.T(), workerManager.checkAndScheduleWorkers(context.Background(), job), checkpointError.Error())
   379  	wokerStatusMap := workerManager.WorkerStatus()
   380  	require.Len(t.T(), wokerStatusMap, 1)
   381  
   382  	currentStatus = getCurrentStatus()
   383  	taskID1, taskID2 := getTaskID()
   384  	require.Len(t.T(), currentStatus, 1)
   385  	require.Contains(t.T(), currentStatus, taskID1)
   386  	require.True(t.T(), time.Since(currentStatus[taskID1].CreatedTime).Seconds() < float64(time.Second))
   387  
   388  	// check again
   389  	checkpointAgent.On("IsFresh", mock.Anything, mock.Anything, mock.Anything).Return(true, nil).Times(3)
   390  	workerAgent.On("CreateWorker").Return(worker2, createError).Once()
   391  	require.EqualError(t.T(), workerManager.checkAndScheduleWorkers(context.Background(), job), createError.Error())
   392  	wokerStatusMap = workerManager.WorkerStatus()
   393  	require.Len(t.T(), wokerStatusMap, 2)
   394  	require.Contains(t.T(), wokerStatusMap, source1)
   395  	require.Contains(t.T(), wokerStatusMap, source2)
   396  	workerStatus1 := wokerStatusMap[source1]
   397  	workerStatus2 := wokerStatusMap[source2]
   398  	currentStatus = getCurrentStatus()
   399  	require.Len(t.T(), currentStatus, 2)
   400  	require.Contains(t.T(), currentStatus, taskID2)
   401  	require.True(t.T(), time.Since(currentStatus[taskID2].CreatedTime).Seconds() < float64(time.Second))
   402  
   403  	// expected
   404  	workerStatus1.Stage = runtime.WorkerOnline
   405  	workerStatus2.Stage = runtime.WorkerOnline
   406  	workerManager.UpdateWorkerStatus(workerStatus1)
   407  	workerManager.UpdateWorkerStatus(workerStatus2)
   408  	require.NoError(t.T(), workerManager.checkAndScheduleWorkers(context.Background(), job))
   409  	wokerStatusMap = workerManager.WorkerStatus()
   410  	require.Len(t.T(), wokerStatusMap, 2)
   411  	require.Contains(t.T(), wokerStatusMap, source1)
   412  	require.Contains(t.T(), wokerStatusMap, source2)
   413  	require.Equal(t.T(), wokerStatusMap[source1], workerStatus1)
   414  	require.Equal(t.T(), wokerStatusMap[source2], workerStatus2)
   415  
   416  	// switch unit
   417  	worker3 := "worker3"
   418  	workerStatus1.Stage = runtime.WorkerFinished
   419  	workerStatus3 := runtime.InitWorkerStatus(source1, frameModel.WorkerDMLoad, worker3)
   420  	workerManager.UpdateWorkerStatus(workerStatus1)
   421  	workerStatus1.Stage = runtime.WorkerFinished
   422  	workerAgent.On("CreateWorker").Return(worker3, nil).Once()
   423  	require.NoError(t.T(), workerManager.checkAndScheduleWorkers(context.Background(), job))
   424  	wokerStatusMap = workerManager.WorkerStatus()
   425  	require.Len(t.T(), wokerStatusMap, 2)
   426  	require.Contains(t.T(), wokerStatusMap, source1)
   427  	require.Contains(t.T(), wokerStatusMap, source2)
   428  	require.Equal(t.T(), wokerStatusMap[source1].ID, workerStatus3.ID)
   429  	require.Equal(t.T(), wokerStatusMap[source2].ID, workerStatus2.ID)
   430  	currentStatus = getCurrentStatus()
   431  	require.Contains(t.T(), currentStatus, source1)
   432  	require.True(t.T(), time.Since(currentStatus[source1].CreatedTime).Seconds() < float64(time.Second))
   433  	require.Equal(t.T(), frameModel.WorkerDMLoad, currentStatus[source1].Unit)
   434  
   435  	// unexpected
   436  	worker4 := "worker3"
   437  	workerStatus3.Stage = runtime.WorkerOffline
   438  	workerStatus4 := runtime.InitWorkerStatus(source1, frameModel.WorkerDMLoad, worker4)
   439  	workerManager.UpdateWorkerStatus(workerStatus3)
   440  	workerAgent.On("CreateWorker").Return(worker4, nil).Once()
   441  	require.NoError(t.T(), workerManager.checkAndScheduleWorkers(context.Background(), job))
   442  	wokerStatusMap = workerManager.WorkerStatus()
   443  	require.Len(t.T(), wokerStatusMap, 2)
   444  	require.Contains(t.T(), wokerStatusMap, source1)
   445  	require.Contains(t.T(), wokerStatusMap, source2)
   446  	require.Equal(t.T(), wokerStatusMap[source1].ID, workerStatus4.ID)
   447  	require.Equal(t.T(), wokerStatusMap[source2].ID, workerStatus2.ID)
   448  
   449  	// finished
   450  	workerStatus4.Stage = runtime.WorkerFinished
   451  	workerManager.UpdateWorkerStatus(workerStatus4)
   452  	require.NoError(t.T(), workerManager.checkAndScheduleWorkers(context.Background(), job))
   453  	wokerStatusMap = workerManager.WorkerStatus()
   454  	require.Len(t.T(), wokerStatusMap, 2)
   455  	require.Contains(t.T(), wokerStatusMap, source1)
   456  	require.Contains(t.T(), wokerStatusMap, source2)
   457  	require.Equal(t.T(), wokerStatusMap[source1].ID, workerStatus4.ID)
   458  	require.Equal(t.T(), wokerStatusMap[source2].ID, workerStatus2.ID)
   459  }
   460  
   461  func (t *testDMJobmasterSuite) TestWorkerManager() {
   462  	jobCfg := &config.JobCfg{}
   463  	require.NoError(t.T(), jobCfg.DecodeFile(jobTemplatePath))
   464  	job := metadata.NewJob(jobCfg)
   465  	jobStore := metadata.NewJobStore(kvmock.NewMetaMock(), log.L())
   466  	require.NoError(t.T(), jobStore.Put(context.Background(), job))
   467  
   468  	unitStore := metadata.NewUnitStateStore(kvmock.NewMetaMock())
   469  	checkpointAgent := &MockCheckpointAgent{}
   470  	workerAgent := &MockWorkerAgent{}
   471  	messageAgent := &dmpkg.MockMessageAgent{}
   472  	workerManager := NewWorkerManager("job_id", nil, jobStore, unitStore, workerAgent, messageAgent, checkpointAgent, log.L(), resModel.ResourceTypeLocalFile)
   473  	source1 := jobCfg.Upstreams[0].SourceID
   474  	source2 := jobCfg.Upstreams[1].SourceID
   475  
   476  	ctx, cancel := context.WithCancel(context.Background())
   477  	defer cancel()
   478  
   479  	worker1 := "worker1"
   480  	worker2 := "worker2"
   481  	checkpointError := errors.New("checkpoint error")
   482  	checkpointAgent.On("IsFresh", mock.Anything, mock.Anything, mock.Anything).Return(true, nil).Times(3)
   483  	checkpointAgent.On("IsFresh", mock.Anything, mock.Anything, mock.Anything).Return(false, checkpointError).Once()
   484  	checkpointAgent.On("IsFresh", mock.Anything, mock.Anything, mock.Anything).Return(true, nil).Times(6)
   485  	createError := errors.New("create error")
   486  	workerAgent.On("CreateWorker").Return(worker1, nil).Once()
   487  	workerAgent.On("CreateWorker").Return("", createError).Once()
   488  	workerAgent.On("CreateWorker").Return(worker2, nil).Once()
   489  
   490  	var wg sync.WaitGroup
   491  	wg.Add(1)
   492  	// run worker manager
   493  	go func() {
   494  		defer wg.Done()
   495  		t := time.NewTicker(50 * time.Millisecond)
   496  		for {
   497  			select {
   498  			case <-ctx.Done():
   499  				return
   500  			case <-t.C:
   501  				workerManager.DoTick(ctx)
   502  			}
   503  		}
   504  	}()
   505  
   506  	// first check
   507  	require.Eventually(t.T(), func() bool {
   508  		return len(workerManager.WorkerStatus()) == 2
   509  	}, 5*time.Second, 100*time.Millisecond)
   510  
   511  	workerStatus1 := workerManager.WorkerStatus()[source1]
   512  	workerStatus2 := workerManager.WorkerStatus()[source2]
   513  	require.Equal(t.T(), runtime.WorkerCreating, workerStatus1.Stage)
   514  	require.Equal(t.T(), runtime.WorkerCreating, workerStatus2.Stage)
   515  
   516  	// worker online
   517  	workerStatus1.Stage = runtime.WorkerOnline
   518  	workerStatus2.Stage = runtime.WorkerOnline
   519  	workerManager.UpdateWorkerStatus(workerStatus1)
   520  	workerManager.UpdateWorkerStatus(workerStatus2)
   521  
   522  	// mock check by interval
   523  	workerManager.SetNextCheckTime(time.Now().Add(10 * time.Millisecond))
   524  
   525  	// expected, no panic in mock agent
   526  	time.Sleep(1 * time.Second)
   527  
   528  	// worker2 offline
   529  	source := workerStatus2.TaskID
   530  	worker3 := "worker3"
   531  	workerStatus2.Stage = runtime.WorkerOffline
   532  	workerStatus3 := runtime.InitWorkerStatus(source, frameModel.WorkerDMDump, worker3)
   533  	// check by offline
   534  	workerManager.UpdateWorkerStatus(workerStatus2)
   535  	workerManager.SetNextCheckTime(time.Now())
   536  	checkpointAgent.On("IsFresh", mock.Anything, mock.Anything, mock.Anything).Return(true, nil).Times(3)
   537  	workerAgent.On("CreateWorker").Return(worker3, nil).Once()
   538  
   539  	// scheduled eventually
   540  	require.Eventually(t.T(), func() bool {
   541  		return workerManager.WorkerStatus()[source].ID == workerStatus3.ID
   542  	}, 5*time.Second, 100*time.Millisecond)
   543  	workerStatus3.Stage = runtime.WorkerOnline
   544  	workerManager.UpdateWorkerStatus(workerStatus3)
   545  
   546  	// mock remove task2 by update-job
   547  	delete(job.Tasks, source2)
   548  	job.Tasks[source1].Cfg.ModRevision++
   549  	jobStore.Put(context.Background(), job)
   550  	messageAgent.On("SendMessage").Return(nil).Twice()
   551  	// check by update-job, task2 stops, task1 restarts
   552  	workerManager.SetNextCheckTime(time.Now())
   553  	// both task removed eventually
   554  	require.Eventually(t.T(), func() bool {
   555  		messageAgent.Lock()
   556  		defer messageAgent.Unlock()
   557  		return len(messageAgent.Calls) == 2
   558  	}, 5*time.Second, 100*time.Millisecond)
   559  	workerStatus1.Stage = runtime.WorkerOffline
   560  	workerStatus3.Stage = runtime.WorkerOffline
   561  	workerManager.UpdateWorkerStatus(workerStatus1)
   562  	workerManager.UpdateWorkerStatus(workerStatus3)
   563  
   564  	// task1 eventually restarts
   565  	checkpointAgent.On("IsFresh", mock.Anything, mock.Anything, mock.Anything).Return(true, nil).Times(3)
   566  	workerAgent.On("CreateWorker").Return(worker1, nil).Once()
   567  	workerManager.SetNextCheckTime(time.Now())
   568  	require.Eventually(t.T(), func() bool {
   569  		return len(workerManager.WorkerStatus()) == 1
   570  	}, 5*time.Second, 100*time.Millisecond)
   571  
   572  	// mock task1 finished
   573  	worker4 := "worker4"
   574  	workerStatus := workerManager.WorkerStatus()[source1]
   575  	workerStatus.Stage = runtime.WorkerFinished
   576  	workerManager.UpdateWorkerStatus(workerStatus)
   577  	workerAgent.On("CreateWorker").Return(worker4, nil).Once()
   578  	// check by finished
   579  	workerManager.SetNextCheckTime(time.Now())
   580  	// scheduled eventually
   581  	require.Eventually(t.T(), func() bool {
   582  		return workerManager.WorkerStatus()[source1].ID == worker4
   583  	}, 5*time.Second, 100*time.Millisecond)
   584  
   585  	// mock deleting job
   586  	jobStore.MarkDeleting(ctx)
   587  	destroyError := errors.New("destroy error")
   588  	messageAgent.On("SendMessage").Return(destroyError).Once()
   589  	messageAgent.On("SendMessage").Return(nil).Once()
   590  
   591  	// check by delete
   592  	workerManager.SetNextCheckTime(time.Now())
   593  	require.Eventually(t.T(), func() bool {
   594  		messageAgent.Lock()
   595  		defer messageAgent.Unlock()
   596  		return len(messageAgent.Calls) == 4
   597  	}, 5*time.Second, 100*time.Millisecond)
   598  
   599  	workerStatus.Stage = runtime.WorkerOffline
   600  	workerManager.UpdateWorkerStatus(workerStatus)
   601  	// deleted eventually
   602  	workerManager.SetNextCheckTime(time.Now())
   603  	require.Eventually(t.T(), func() bool {
   604  		return len(workerManager.WorkerStatus()) == 0
   605  	}, 5*time.Second, 100*time.Millisecond)
   606  
   607  	cancel()
   608  	wg.Wait()
   609  
   610  	checkpointAgent.AssertExpectations(t.T())
   611  	workerAgent.AssertExpectations(t.T())
   612  	messageAgent.AssertExpectations(t.T())
   613  }
   614  
   615  type MockWorkerAgent struct {
   616  	sync.Mutex
   617  	mock.Mock
   618  }
   619  
   620  func (mockAgent *MockWorkerAgent) CreateWorker(
   621  	workerType framework.WorkerType, taskCfg interface{},
   622  	opts ...framework.CreateWorkerOpt,
   623  ) (frameModel.WorkerID, error) {
   624  	mockAgent.Lock()
   625  	defer mockAgent.Unlock()
   626  	args := mockAgent.Called()
   627  	return args.Get(0).(frameModel.WorkerID), args.Error(1)
   628  }