github.com/pingcap/tiflow@v0.0.0-20240520035814-5bf52d54e205/dm/worker/subtask.go (about)

     1  // Copyright 2019 PingCAP, Inc.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // See the License for the specific language governing permissions and
    12  // limitations under the License.
    13  
    14  package worker
    15  
    16  import (
    17  	"context"
    18  	"sync"
    19  	"time"
    20  
    21  	"github.com/go-mysql-org/go-mysql/mysql"
    22  	"github.com/pingcap/failpoint"
    23  	"github.com/pingcap/tiflow/dm/config"
    24  	"github.com/pingcap/tiflow/dm/dumpling"
    25  	"github.com/pingcap/tiflow/dm/loader"
    26  	"github.com/pingcap/tiflow/dm/pb"
    27  	"github.com/pingcap/tiflow/dm/pkg/binlog"
    28  	"github.com/pingcap/tiflow/dm/pkg/gtid"
    29  	"github.com/pingcap/tiflow/dm/pkg/log"
    30  	"github.com/pingcap/tiflow/dm/pkg/shardddl/pessimism"
    31  	"github.com/pingcap/tiflow/dm/pkg/terror"
    32  	"github.com/pingcap/tiflow/dm/pkg/utils"
    33  	"github.com/pingcap/tiflow/dm/relay"
    34  	"github.com/pingcap/tiflow/dm/syncer"
    35  	"github.com/pingcap/tiflow/dm/unit"
    36  	"github.com/prometheus/client_golang/prometheus"
    37  	clientv3 "go.etcd.io/etcd/client/v3"
    38  	"go.uber.org/atomic"
    39  	"go.uber.org/zap"
    40  )
    41  
    42  const (
    43  	// the timeout to wait for relay catchup when switching from load unit to sync unit.
    44  	waitRelayCatchupTimeout = 30 * time.Second
    45  )
    46  
    47  // createRealUnits is subtask units initializer
    48  // it can be used for testing.
    49  var createUnits = createRealUnits
    50  
    51  // createRealUnits creates process units base on task mode.
    52  func createRealUnits(cfg *config.SubTaskConfig, etcdClient *clientv3.Client, workerName string, relay relay.Process) []unit.Unit {
    53  	failpoint.Inject("mockCreateUnitsDumpOnly", func(_ failpoint.Value) {
    54  		log.L().Info("create mock worker units with dump unit only", zap.String("failpoint", "mockCreateUnitsDumpOnly"))
    55  		failpoint.Return([]unit.Unit{dumpling.NewDumpling(cfg)})
    56  	})
    57  
    58  	us := make([]unit.Unit, 0, 3)
    59  	switch cfg.Mode {
    60  	case config.ModeAll:
    61  		us = append(us, dumpling.NewDumpling(cfg))
    62  		us = append(us, loader.NewLightning(cfg, etcdClient, workerName))
    63  		us = append(us, syncer.NewSyncer(cfg, etcdClient, relay))
    64  	case config.ModeFull:
    65  		// NOTE: maybe need another checker in the future?
    66  		us = append(us, dumpling.NewDumpling(cfg))
    67  		us = append(us, loader.NewLightning(cfg, etcdClient, workerName))
    68  	case config.ModeIncrement:
    69  		us = append(us, syncer.NewSyncer(cfg, etcdClient, relay))
    70  	case config.ModeDump:
    71  		us = append(us, dumpling.NewDumpling(cfg))
    72  	case config.ModeLoadSync:
    73  		us = append(us, loader.NewLightning(cfg, etcdClient, workerName))
    74  		us = append(us, syncer.NewSyncer(cfg, etcdClient, relay))
    75  	default:
    76  		log.L().Error("unsupported task mode", zap.String("subtask", cfg.Name), zap.String("task mode", cfg.Mode))
    77  	}
    78  	return us
    79  }
    80  
    81  // SubTask represents a sub task of data migration.
    82  type SubTask struct {
    83  	cfg *config.SubTaskConfig
    84  
    85  	initialized atomic.Bool
    86  
    87  	l log.Logger
    88  
    89  	sync.RWMutex
    90  	// ctx is used for the whole subtask. It will be created only when we new a subtask.
    91  	ctx    context.Context
    92  	cancel context.CancelFunc
    93  	// currCtx is used for one loop. It will be created each time we use st.run/st.Resume
    94  	currCtx    context.Context
    95  	currCancel context.CancelFunc
    96  
    97  	units    []unit.Unit // units do job one by one
    98  	currUnit unit.Unit
    99  	prevUnit unit.Unit
   100  	resultWg sync.WaitGroup
   101  
   102  	stage  pb.Stage          // stage of current sub task
   103  	result *pb.ProcessResult // the process result, nil when is processing
   104  
   105  	etcdClient *clientv3.Client
   106  
   107  	workerName string
   108  
   109  	validator *syncer.DataValidator
   110  }
   111  
   112  // NewSubTask is subtask initializer
   113  // it can be used for testing.
   114  var NewSubTask = NewRealSubTask
   115  
   116  // NewRealSubTask creates a new SubTask.
   117  func NewRealSubTask(cfg *config.SubTaskConfig, etcdClient *clientv3.Client, workerName string) *SubTask {
   118  	return NewSubTaskWithStage(cfg, pb.Stage_New, etcdClient, workerName)
   119  }
   120  
   121  // NewSubTaskWithStage creates a new SubTask with stage.
   122  func NewSubTaskWithStage(cfg *config.SubTaskConfig, stage pb.Stage, etcdClient *clientv3.Client, workerName string) *SubTask {
   123  	ctx, cancel := context.WithCancel(context.Background())
   124  	st := SubTask{
   125  		cfg:        cfg,
   126  		stage:      stage,
   127  		l:          log.With(zap.String("subtask", cfg.Name)),
   128  		ctx:        ctx,
   129  		cancel:     cancel,
   130  		etcdClient: etcdClient,
   131  		workerName: workerName,
   132  	}
   133  	updateTaskMetric(st.cfg.Name, st.cfg.SourceID, st.stage, st.workerName)
   134  	return &st
   135  }
   136  
   137  // initUnits initializes the sub task processing units.
   138  func (st *SubTask) initUnits(relay relay.Process) error {
   139  	st.units = createUnits(st.cfg, st.etcdClient, st.workerName, relay)
   140  	if len(st.units) < 1 {
   141  		return terror.ErrWorkerNoAvailUnits.Generate(st.cfg.Name, st.cfg.Mode)
   142  	}
   143  
   144  	initializeUnitSuccess := true
   145  	// when error occurred, initialized units should be closed
   146  	// when continue sub task from loader / syncer, ahead units should be closed
   147  	var needCloseUnits []unit.Unit
   148  	defer func() {
   149  		for _, u := range needCloseUnits {
   150  			u.Close()
   151  		}
   152  
   153  		st.initialized.Store(initializeUnitSuccess)
   154  	}()
   155  
   156  	// every unit does base initialization in `Init`, and this must pass before start running the sub task
   157  	// other setups can be done in `Process`, like Loader's prepare which depends on Mydumper's output
   158  	// but setups in `Process` should be treated carefully, let it's compatible with Pause / Resume
   159  	for i, u := range st.units {
   160  		ctx, cancel := context.WithTimeout(context.Background(), unit.DefaultInitTimeout)
   161  		err := u.Init(ctx)
   162  		cancel()
   163  		if err != nil {
   164  			initializeUnitSuccess = false
   165  			// when init fail, other units initialized before should be closed
   166  			for j := 0; j < i; j++ {
   167  				needCloseUnits = append(needCloseUnits, st.units[j])
   168  			}
   169  			return terror.Annotatef(err, "fail to initialize unit %s of subtask %s ", u.Type(), st.cfg.Name)
   170  		}
   171  	}
   172  
   173  	// if the sub task ran before, some units may be skipped
   174  	skipIdx := 0
   175  	for i := len(st.units) - 1; i > 0; i-- {
   176  		u := st.units[i]
   177  		ctx, cancel := context.WithTimeout(context.Background(), unit.DefaultInitTimeout)
   178  		isFresh, err := u.IsFreshTask(ctx)
   179  		cancel()
   180  		if err != nil {
   181  			initializeUnitSuccess = false
   182  			return terror.Annotatef(err, "fail to get fresh status of subtask %s %s", st.cfg.Name, u.Type())
   183  		} else if !isFresh {
   184  			skipIdx = i
   185  			st.l.Info("continue unit", zap.Stringer("unit", u.Type()))
   186  			break
   187  		}
   188  	}
   189  
   190  	needCloseUnits = st.units[:skipIdx]
   191  	st.units = st.units[skipIdx:]
   192  
   193  	st.setCurrUnit(st.units[0])
   194  	return nil
   195  }
   196  
   197  // Run runs the sub task.
   198  // TODO: check concurrent problems.
   199  func (st *SubTask) Run(expectStage pb.Stage, expectValidatorStage pb.Stage, relay relay.Process) {
   200  	if st.Stage() == pb.Stage_Finished || st.Stage() == pb.Stage_Running {
   201  		st.l.Warn("prepare to run a subtask with invalid stage",
   202  			zap.Stringer("current stage", st.Stage()),
   203  			zap.Stringer("expected stage", expectStage))
   204  		return
   205  	}
   206  
   207  	if err := st.initUnits(relay); err != nil {
   208  		st.l.Error("fail to initialize subtask", log.ShortError(err))
   209  		st.fail(err)
   210  		return
   211  	}
   212  
   213  	st.StartValidator(expectValidatorStage, true)
   214  
   215  	if expectStage == pb.Stage_Running {
   216  		st.run()
   217  	} else {
   218  		// if not want to run, still need to set the stage.
   219  		st.setStage(expectStage)
   220  	}
   221  }
   222  
   223  func (st *SubTask) run() {
   224  	st.setStageAndResult(pb.Stage_Running, nil) // clear previous result
   225  	ctx, cancel := context.WithCancel(st.ctx)
   226  	st.setCurrCtx(ctx, cancel)
   227  	err := st.unitTransWaitCondition(ctx)
   228  	if err != nil {
   229  		st.l.Error("wait condition", log.ShortError(err))
   230  		st.fail(err)
   231  		return
   232  	} else if ctx.Err() != nil {
   233  		st.l.Error("exit SubTask.run", log.ShortError(ctx.Err()))
   234  		return
   235  	}
   236  
   237  	cu := st.CurrUnit()
   238  	st.l.Info("start to run", zap.Stringer("unit", cu.Type()))
   239  	pr := make(chan pb.ProcessResult, 1)
   240  	st.resultWg.Add(1)
   241  	go st.fetchResultAndUpdateStage(pr)
   242  	go cu.Process(ctx, pr)
   243  }
   244  
   245  func (st *SubTask) StartValidator(expect pb.Stage, startWithSubtask bool) {
   246  	// when validator mode=none
   247  	if expect == pb.Stage_InvalidStage {
   248  		return
   249  	}
   250  	st.Lock()
   251  	defer st.Unlock()
   252  
   253  	if st.cfg.ValidatorCfg.Mode != config.ValidationFast && st.cfg.ValidatorCfg.Mode != config.ValidationFull {
   254  		return
   255  	}
   256  	var syncerObj *syncer.Syncer
   257  	var ok bool
   258  	for _, u := range st.units {
   259  		if syncerObj, ok = u.(*syncer.Syncer); ok {
   260  			break
   261  		}
   262  	}
   263  	if syncerObj == nil {
   264  		st.l.Warn("cannot start validator without syncer")
   265  		return
   266  	}
   267  	if st.validator == nil {
   268  		st.validator = syncer.NewContinuousDataValidator(st.cfg, syncerObj, startWithSubtask)
   269  	}
   270  	st.validator.Start(expect)
   271  }
   272  
   273  func (st *SubTask) StopValidator() {
   274  	st.Lock()
   275  	if st.validator != nil {
   276  		st.validator.Stop()
   277  	}
   278  	st.Unlock()
   279  }
   280  
   281  func (st *SubTask) setCurrCtx(ctx context.Context, cancel context.CancelFunc) {
   282  	st.Lock()
   283  	// call previous cancel func for safety
   284  	if st.currCancel != nil {
   285  		st.currCancel()
   286  	}
   287  	st.currCtx = ctx
   288  	st.currCancel = cancel
   289  	st.Unlock()
   290  }
   291  
   292  func (st *SubTask) callCurrCancel() {
   293  	st.RLock()
   294  	st.currCancel()
   295  	st.RUnlock()
   296  }
   297  
   298  // fetchResultAndUpdateStage fetches process result, call Pause of current unit if needed and updates the stage of subtask.
   299  func (st *SubTask) fetchResultAndUpdateStage(pr chan pb.ProcessResult) {
   300  	defer st.resultWg.Done()
   301  
   302  	result := <-pr
   303  	// filter the context canceled error
   304  	errs := make([]*pb.ProcessError, 0, 2)
   305  	for _, err := range result.Errors {
   306  		if !unit.IsCtxCanceledProcessErr(err) {
   307  			errs = append(errs, err)
   308  		}
   309  	}
   310  	result.Errors = errs
   311  
   312  	st.callCurrCancel() // dm-unit finished, canceled or error occurred, always cancel processing
   313  
   314  	var (
   315  		cu    = st.CurrUnit()
   316  		stage pb.Stage
   317  	)
   318  
   319  	// update the stage according to result
   320  	if len(result.Errors) == 0 {
   321  		switch st.Stage() {
   322  		case pb.Stage_Pausing:
   323  			// paused by st.Pause
   324  			stage = pb.Stage_Paused
   325  		case pb.Stage_Stopping:
   326  			// stopped by st.Close
   327  			stage = pb.Stage_Stopped
   328  		default:
   329  			// process finished with no error
   330  			stage = pb.Stage_Finished
   331  		}
   332  	} else {
   333  		// error occurred, paused
   334  		stage = pb.Stage_Paused
   335  	}
   336  	st.setStageAndResult(stage, &result)
   337  
   338  	st.l.Info("unit process returned", zap.Stringer("unit", cu.Type()), zap.Stringer("stage", stage), zap.String("status", st.StatusJSON()))
   339  
   340  	switch stage {
   341  	case pb.Stage_Finished:
   342  		cu.Close()
   343  		nu := st.getNextUnit()
   344  		if nu == nil {
   345  			// Now, when finished, it only stops the process
   346  			// if needed, we can refine to Close it
   347  			st.l.Info("all process units finished")
   348  		} else {
   349  			st.l.Info("switching to next unit", zap.Stringer("unit", cu.Type()))
   350  			st.setCurrUnit(nu)
   351  			// NOTE: maybe need a Lock mechanism for sharding scenario
   352  			st.run() // re-run for next process unit
   353  		}
   354  	case pb.Stage_Stopped:
   355  		// the caller will close current unit and more units after it, so we don't call cu.Close here.
   356  	case pb.Stage_Paused:
   357  		cu.Pause()
   358  		for _, err := range result.Errors {
   359  			st.l.Error("unit process error", zap.Stringer("unit", cu.Type()), zap.Any("error information", err))
   360  		}
   361  		st.l.Info("paused", zap.Stringer("unit", cu.Type()))
   362  	}
   363  }
   364  
   365  // setCurrUnit set current dm unit to ut.
   366  func (st *SubTask) setCurrUnit(cu unit.Unit) {
   367  	st.Lock()
   368  	defer st.Unlock()
   369  	pu := st.currUnit
   370  	st.currUnit = cu
   371  	st.prevUnit = pu
   372  }
   373  
   374  // CurrUnit returns current dm unit.
   375  func (st *SubTask) CurrUnit() unit.Unit {
   376  	st.RLock()
   377  	defer st.RUnlock()
   378  	return st.currUnit
   379  }
   380  
   381  // PrevUnit returns dm previous unit.
   382  func (st *SubTask) PrevUnit() unit.Unit {
   383  	st.RLock()
   384  	defer st.RUnlock()
   385  	return st.prevUnit
   386  }
   387  
   388  // closeUnits closes all un-closed units (current unit and all the subsequent units).
   389  func (st *SubTask) closeUnits() {
   390  	st.cancel()
   391  	st.resultWg.Wait()
   392  
   393  	var (
   394  		cu  = st.currUnit
   395  		cui = -1
   396  	)
   397  
   398  	for i, u := range st.units {
   399  		if u == cu {
   400  			cui = i
   401  			break
   402  		}
   403  	}
   404  	if cui < 0 {
   405  		return
   406  	}
   407  
   408  	for i := cui; i < len(st.units); i++ {
   409  		u := st.units[i]
   410  		st.l.Info("closing unit process", zap.Stringer("unit", cu.Type()))
   411  		u.Close()
   412  		st.l.Info("closing unit done", zap.Stringer("unit", cu.Type()))
   413  	}
   414  }
   415  
   416  func (st *SubTask) killCurrentUnit() {
   417  	if st.CurrUnit() != nil {
   418  		ut := st.CurrUnit().Type()
   419  		st.l.Info("kill unit", zap.String("task", st.cfg.Name), zap.Stringer("unit", ut))
   420  		st.CurrUnit().Kill()
   421  		st.l.Info("kill unit done", zap.String("task", st.cfg.Name), zap.Stringer("unit", ut))
   422  	}
   423  }
   424  
   425  // getNextUnit gets the next process unit from st.units
   426  // if no next unit, return nil.
   427  func (st *SubTask) getNextUnit() unit.Unit {
   428  	var (
   429  		nu  unit.Unit
   430  		cui = len(st.units)
   431  		cu  = st.CurrUnit()
   432  	)
   433  	for i, u := range st.units {
   434  		if u == cu {
   435  			cui = i
   436  		}
   437  		if i == cui+1 {
   438  			nu = u
   439  			break
   440  		}
   441  	}
   442  	return nu
   443  }
   444  
   445  func (st *SubTask) setStage(stage pb.Stage) {
   446  	st.Lock()
   447  	defer st.Unlock()
   448  	st.stage = stage
   449  	updateTaskMetric(st.cfg.Name, st.cfg.SourceID, st.stage, st.workerName)
   450  }
   451  
   452  func (st *SubTask) setStageAndResult(stage pb.Stage, result *pb.ProcessResult) {
   453  	st.Lock()
   454  	defer st.Unlock()
   455  	st.stage = stage
   456  	updateTaskMetric(st.cfg.Name, st.cfg.SourceID, st.stage, st.workerName)
   457  	st.result = result
   458  }
   459  
   460  // stageCAS sets stage to newStage if its current value is oldStage.
   461  func (st *SubTask) stageCAS(oldStage, newStage pb.Stage) bool {
   462  	st.Lock()
   463  	defer st.Unlock()
   464  
   465  	if st.stage == oldStage {
   466  		st.stage = newStage
   467  		updateTaskMetric(st.cfg.Name, st.cfg.SourceID, st.stage, st.workerName)
   468  		return true
   469  	}
   470  	return false
   471  }
   472  
   473  // setStageIfNotIn sets stage to newStage if its current value is not in oldStages.
   474  func (st *SubTask) setStageIfNotIn(oldStages []pb.Stage, newStage pb.Stage) bool {
   475  	st.Lock()
   476  	defer st.Unlock()
   477  	for _, s := range oldStages {
   478  		if st.stage == s {
   479  			return false
   480  		}
   481  	}
   482  	st.stage = newStage
   483  	updateTaskMetric(st.cfg.Name, st.cfg.SourceID, st.stage, st.workerName)
   484  	return true
   485  }
   486  
   487  // setStageIfNotIn sets stage to newStage if its current value is in oldStages.
   488  func (st *SubTask) setStageIfIn(oldStages []pb.Stage, newStage pb.Stage) bool {
   489  	st.Lock()
   490  	defer st.Unlock()
   491  	for _, s := range oldStages {
   492  		if st.stage == s {
   493  			st.stage = newStage
   494  			updateTaskMetric(st.cfg.Name, st.cfg.SourceID, st.stage, st.workerName)
   495  			return true
   496  		}
   497  	}
   498  	return false
   499  }
   500  
   501  // Stage returns the stage of the sub task.
   502  func (st *SubTask) Stage() pb.Stage {
   503  	st.RLock()
   504  	defer st.RUnlock()
   505  	return st.stage
   506  }
   507  
   508  func (st *SubTask) validatorStage() pb.Stage {
   509  	st.RLock()
   510  	defer st.RUnlock()
   511  	if st.validator != nil {
   512  		return st.validator.Stage()
   513  	}
   514  	return pb.Stage_InvalidStage
   515  }
   516  
   517  // markResultCanceled mark result as canceled if stage is Paused.
   518  // This func is used to pause a task which has been paused by error,
   519  // so the task will not auto resume by task checker.
   520  func (st *SubTask) markResultCanceled() bool {
   521  	st.Lock()
   522  	defer st.Unlock()
   523  	if st.stage == pb.Stage_Paused {
   524  		if st.result != nil && !st.result.IsCanceled {
   525  			st.l.Info("manually pause task which has been paused by errors")
   526  			st.result.IsCanceled = true
   527  			return true
   528  		}
   529  	}
   530  	return false
   531  }
   532  
   533  // Result returns the result of the sub task.
   534  func (st *SubTask) Result() *pb.ProcessResult {
   535  	st.RLock()
   536  	defer st.RUnlock()
   537  	if st.result == nil {
   538  		return nil
   539  	}
   540  	tempProcessResult, _ := st.result.Marshal()
   541  	newProcessResult := &pb.ProcessResult{}
   542  	_ = newProcessResult.Unmarshal(tempProcessResult)
   543  	return newProcessResult
   544  }
   545  
   546  // Close stops the sub task.
   547  func (st *SubTask) Close() {
   548  	st.l.Info("closing")
   549  	if !st.setStageIfNotIn([]pb.Stage{pb.Stage_Stopped, pb.Stage_Stopping, pb.Stage_Finished}, pb.Stage_Stopping) {
   550  		st.l.Info("subTask is already closed, no need to close")
   551  		return
   552  	}
   553  	st.closeUnits() // close all un-closed units
   554  	updateTaskMetric(st.cfg.Name, st.cfg.SourceID, pb.Stage_Stopped, st.workerName)
   555  
   556  	// we can start/stop validator independent of task, so we don't set st.validator = nil inside
   557  	st.StopValidator()
   558  	st.validator = nil
   559  }
   560  
   561  // Kill kill running unit and stop the sub task.
   562  func (st *SubTask) Kill() {
   563  	st.l.Info("killing")
   564  	if !st.setStageIfNotIn([]pb.Stage{pb.Stage_Stopped, pb.Stage_Stopping, pb.Stage_Finished}, pb.Stage_Stopping) {
   565  		st.l.Info("subTask is already closed, no need to close")
   566  		return
   567  	}
   568  	st.killCurrentUnit()
   569  	st.closeUnits() // close all un-closed units
   570  
   571  	cfg := st.getCfg()
   572  	updateTaskMetric(cfg.Name, cfg.SourceID, pb.Stage_Stopped, st.workerName)
   573  
   574  	st.StopValidator()
   575  	st.validator = nil
   576  }
   577  
   578  // Pause pauses a running subtask or a subtask paused by error.
   579  func (st *SubTask) Pause() error {
   580  	if st.markResultCanceled() {
   581  		return nil
   582  	}
   583  
   584  	if !st.stageCAS(pb.Stage_Running, pb.Stage_Pausing) {
   585  		return terror.ErrWorkerNotRunningStage.Generate(st.Stage().String())
   586  	}
   587  
   588  	st.callCurrCancel()
   589  	st.resultWg.Wait() // wait fetchResultAndUpdateStage set Pause stage
   590  
   591  	return nil
   592  }
   593  
   594  // Resume resumes the paused sub task
   595  // TODO: similar to Run, refactor later.
   596  func (st *SubTask) Resume(relay relay.Process) error {
   597  	if !st.initialized.Load() {
   598  		expectValidatorStage, err := getExpectValidatorStage(st.cfg.ValidatorCfg, st.etcdClient, st.cfg.SourceID, st.cfg.Name, 0)
   599  		if err != nil {
   600  			return terror.Annotate(err, "fail to get validator stage from etcd")
   601  		}
   602  		st.Run(pb.Stage_Running, expectValidatorStage, relay)
   603  		return nil
   604  	}
   605  
   606  	if !st.setStageIfIn([]pb.Stage{pb.Stage_Paused, pb.Stage_Stopped}, pb.Stage_Resuming) {
   607  		return terror.ErrWorkerNotPausedStage.Generate(st.Stage().String())
   608  	}
   609  
   610  	ctx, cancel := context.WithCancel(st.ctx)
   611  	st.setCurrCtx(ctx, cancel)
   612  	// NOTE: this may block if user resume a task
   613  	err := st.unitTransWaitCondition(ctx)
   614  	if err != nil {
   615  		st.l.Error("wait condition", log.ShortError(err))
   616  		st.fail(err)
   617  		return err
   618  	} else if ctx.Err() != nil {
   619  		// ctx.Err() != nil means this context is canceled in other go routine,
   620  		// that go routine will change the stage, so don't need to set stage to paused here.
   621  		// nolint:nilerr
   622  		return nil
   623  	}
   624  
   625  	cu := st.CurrUnit()
   626  	st.l.Info("resume with unit", zap.Stringer("unit", cu.Type()))
   627  
   628  	pr := make(chan pb.ProcessResult, 1)
   629  	st.resultWg.Add(1)
   630  	go st.fetchResultAndUpdateStage(pr)
   631  	go cu.Resume(ctx, pr)
   632  
   633  	st.setStageAndResult(pb.Stage_Running, nil) // clear previous result
   634  	return nil
   635  }
   636  
   637  // Update update the sub task's config.
   638  func (st *SubTask) Update(ctx context.Context, cfg *config.SubTaskConfig) error {
   639  	if !st.stageCAS(pb.Stage_Paused, pb.Stage_Paused) { // only test for Paused
   640  		return terror.ErrWorkerUpdateTaskStage.Generate(st.Stage().String())
   641  	}
   642  
   643  	for _, u := range st.units {
   644  		err := u.Update(ctx, cfg)
   645  		if err != nil {
   646  			return err
   647  		}
   648  	}
   649  	st.SetCfg(*cfg)
   650  	return nil
   651  }
   652  
   653  // OperateSchema operates schema for an upstream table.
   654  func (st *SubTask) OperateSchema(ctx context.Context, req *pb.OperateWorkerSchemaRequest) (schema string, err error) {
   655  	switch req.Op {
   656  	case pb.SchemaOp_ListMigrateTargets:
   657  		if st.Stage() != pb.Stage_Running && st.Stage() != pb.Stage_Paused {
   658  			return "", terror.ErrWorkerNotPausedStage.Generate(st.Stage().String())
   659  		}
   660  	default:
   661  		if st.Stage() != pb.Stage_Paused {
   662  			return "", terror.ErrWorkerNotPausedStage.Generate(st.Stage().String())
   663  		}
   664  	}
   665  
   666  	syncUnit, ok := st.currUnit.(*syncer.Syncer)
   667  	if !ok {
   668  		return "", terror.ErrWorkerOperSyncUnitOnly.Generate(st.currUnit.Type())
   669  	}
   670  
   671  	if st.validatorStage() == pb.Stage_Running && req.Op != pb.SchemaOp_ListMigrateTargets {
   672  		return "", terror.ErrWorkerValidatorNotPaused.Generate(pb.Stage_Running.String())
   673  	}
   674  
   675  	return syncUnit.OperateSchema(ctx, req)
   676  }
   677  
   678  // CheckUnit checks whether current unit is sync unit.
   679  func (st *SubTask) CheckUnit() bool {
   680  	st.RLock()
   681  	defer st.RUnlock()
   682  	flag := true
   683  	if _, ok := st.currUnit.(*syncer.Syncer); !ok {
   684  		flag = false
   685  	}
   686  	return flag
   687  }
   688  
   689  // CheckUnitCfgCanUpdate checks this unit cfg can update.
   690  func (st *SubTask) CheckUnitCfgCanUpdate(cfg *config.SubTaskConfig) error {
   691  	st.RLock()
   692  	defer st.RUnlock()
   693  
   694  	if st.currUnit == nil {
   695  		return terror.ErrWorkerUpdateSubTaskConfig.Generate(cfg.Name, pb.UnitType_InvalidUnit)
   696  	}
   697  
   698  	switch st.currUnit.Type() {
   699  	case pb.UnitType_Sync:
   700  		if s, ok := st.currUnit.(*syncer.Syncer); ok {
   701  			return s.CheckCanUpdateCfg(cfg)
   702  		}
   703  		// skip check for mock sync unit
   704  	default:
   705  		return terror.ErrWorkerUpdateSubTaskConfig.Generate(cfg.Name, st.currUnit.Type())
   706  	}
   707  	return nil
   708  }
   709  
   710  // ShardDDLOperation returns the current shard DDL lock operation.
   711  func (st *SubTask) ShardDDLOperation() *pessimism.Operation {
   712  	st.RLock()
   713  	defer st.RUnlock()
   714  
   715  	cu := st.currUnit
   716  	syncer2, ok := cu.(*syncer.Syncer)
   717  	if !ok {
   718  		return nil
   719  	}
   720  
   721  	return syncer2.ShardDDLOperation()
   722  }
   723  
   724  // unitTransWaitCondition waits when transferring from current unit to next unit.
   725  // Currently there is only one wait condition
   726  // from Load unit to Sync unit, wait for relay-log catched up with mydumper binlog position.
   727  func (st *SubTask) unitTransWaitCondition(subTaskCtx context.Context) error {
   728  	var (
   729  		gset1 mysql.GTIDSet
   730  		gset2 mysql.GTIDSet
   731  		pos1  *mysql.Position
   732  		pos2  *mysql.Position
   733  		err   error
   734  	)
   735  	pu := st.PrevUnit()
   736  	cu := st.CurrUnit()
   737  	if pu != nil && pu.Type() == pb.UnitType_Load && cu.Type() == pb.UnitType_Sync {
   738  		st.l.Info("wait condition between two units", zap.Stringer("previous unit", pu.Type()), zap.Stringer("unit", cu.Type()))
   739  		hub := GetConditionHub()
   740  
   741  		if !hub.w.relayEnabled.Load() {
   742  			return nil
   743  		}
   744  
   745  		ctxWait, cancelWait := context.WithTimeout(hub.w.ctx, waitRelayCatchupTimeout)
   746  		defer cancelWait()
   747  
   748  		loadStatus := pu.Status(nil).(*pb.LoadStatus)
   749  
   750  		cfg := st.getCfg()
   751  		if cfg.EnableGTID {
   752  			gset1, err = gtid.ParserGTID(cfg.Flavor, loadStatus.MetaBinlogGTID)
   753  			if err != nil {
   754  				return terror.WithClass(err, terror.ClassDMWorker)
   755  			}
   756  		} else {
   757  			pos1, err = utils.DecodeBinlogPosition(loadStatus.MetaBinlog)
   758  			if err != nil {
   759  				return terror.WithClass(err, terror.ClassDMWorker)
   760  			}
   761  		}
   762  
   763  		for {
   764  			relayStatus := hub.w.relayHolder.Status(nil)
   765  
   766  			if cfg.EnableGTID {
   767  				gset2, err = gtid.ParserGTID(cfg.Flavor, relayStatus.RelayBinlogGtid)
   768  				if err != nil {
   769  					return terror.WithClass(err, terror.ClassDMWorker)
   770  				}
   771  				rc, ok := binlog.CompareGTID(gset1, gset2)
   772  				if !ok {
   773  					return terror.ErrWorkerWaitRelayCatchupGTID.Generate(loadStatus.MetaBinlogGTID, relayStatus.RelayBinlogGtid)
   774  				}
   775  				if rc <= 0 {
   776  					break
   777  				}
   778  			} else {
   779  				pos2, err = utils.DecodeBinlogPosition(relayStatus.RelayBinlog)
   780  				if err != nil {
   781  					return terror.WithClass(err, terror.ClassDMWorker)
   782  				}
   783  				if pos1.Compare(*pos2) <= 0 {
   784  					break
   785  				}
   786  			}
   787  
   788  			st.l.Debug("wait relay to catchup", zap.Bool("enableGTID", cfg.EnableGTID), zap.Stringer("load end position", pos1), zap.String("load end gtid", loadStatus.MetaBinlogGTID), zap.Stringer("relay position", pos2), zap.String("relay gtid", relayStatus.RelayBinlogGtid))
   789  
   790  			select {
   791  			case <-ctxWait.Done():
   792  				if cfg.EnableGTID {
   793  					return terror.ErrWorkerWaitRelayCatchupTimeout.Generate(waitRelayCatchupTimeout, loadStatus.MetaBinlogGTID, relayStatus.RelayBinlogGtid)
   794  				}
   795  				return terror.ErrWorkerWaitRelayCatchupTimeout.Generate(waitRelayCatchupTimeout, pos1, pos2)
   796  			case <-subTaskCtx.Done():
   797  				return nil
   798  			case <-time.After(time.Millisecond * 50):
   799  			}
   800  		}
   801  		st.l.Info("relay binlog pos catchup loader end binlog pos")
   802  	}
   803  	return nil
   804  }
   805  
   806  func (st *SubTask) fail(err error) {
   807  	st.setStageAndResult(pb.Stage_Paused, &pb.ProcessResult{
   808  		Errors: []*pb.ProcessError{
   809  			unit.NewProcessError(err),
   810  		},
   811  	})
   812  }
   813  
   814  // HandleError handle error for syncer unit.
   815  func (st *SubTask) HandleError(ctx context.Context, req *pb.HandleWorkerErrorRequest, relay relay.Process) (string, error) {
   816  	// TODO: do we need lock here?
   817  	syncUnit, ok := st.currUnit.(*syncer.Syncer)
   818  	if !ok {
   819  		return "", terror.ErrWorkerOperSyncUnitOnly.Generate(st.currUnit.Type())
   820  	}
   821  
   822  	msg, err := syncUnit.HandleError(ctx, req)
   823  	if err != nil {
   824  		return "", err
   825  	}
   826  
   827  	if st.Stage() == pb.Stage_Paused && req.Op != pb.ErrorOp_List {
   828  		err = st.Resume(relay)
   829  	}
   830  	return msg, err
   831  }
   832  
   833  func (st *SubTask) getCfg() *config.SubTaskConfig {
   834  	st.RLock()
   835  	defer st.RUnlock()
   836  	return st.cfg
   837  }
   838  
   839  func (st *SubTask) SetCfg(subTaskConfig config.SubTaskConfig) {
   840  	st.Lock()
   841  	st.cfg = &subTaskConfig
   842  	st.Unlock()
   843  }
   844  
   845  func (st *SubTask) UpdateValidatorCfg(validatorCfg config.ValidatorConfig) {
   846  	st.Lock()
   847  	// if user start validator on the fly, we update validator mode and start-time
   848  	st.cfg.ValidatorCfg.Mode = validatorCfg.Mode
   849  	st.cfg.ValidatorCfg.StartTime = validatorCfg.StartTime
   850  	st.Unlock()
   851  }
   852  
   853  func (st *SubTask) getValidatorStage() pb.Stage {
   854  	st.RLock()
   855  	defer st.RUnlock()
   856  
   857  	if st.validator != nil {
   858  		return st.validator.Stage()
   859  	}
   860  	return pb.Stage_InvalidStage
   861  }
   862  
   863  func updateTaskMetric(task, sourceID string, stage pb.Stage, workerName string) {
   864  	if stage == pb.Stage_Stopped || stage == pb.Stage_Finished {
   865  		taskState.DeletePartialMatch(prometheus.Labels{"task": task, "source_id": sourceID})
   866  	} else {
   867  		taskState.WithLabelValues(task, sourceID, workerName).Set(float64(stage))
   868  	}
   869  }
   870  
   871  func (st *SubTask) GetValidatorError(errState pb.ValidateErrorState) ([]*pb.ValidationError, error) {
   872  	if validator := st.getValidator(); validator != nil {
   873  		return validator.GetValidatorError(errState)
   874  	}
   875  	cfg := st.getCfg()
   876  	return nil, terror.ErrValidatorNotFound.Generate(cfg.Name, cfg.SourceID)
   877  }
   878  
   879  func (st *SubTask) OperateValidatorError(op pb.ValidationErrOp, errID uint64, isAll bool) error {
   880  	if validator := st.getValidator(); validator != nil {
   881  		return validator.OperateValidatorError(op, errID, isAll)
   882  	}
   883  	cfg := st.getCfg()
   884  	return terror.ErrValidatorNotFound.Generate(cfg.Name, cfg.SourceID)
   885  }
   886  
   887  func (st *SubTask) UpdateValidator(req *pb.UpdateValidationWorkerRequest) error {
   888  	if validator := st.getValidator(); validator != nil {
   889  		return validator.UpdateValidator(req)
   890  	}
   891  	cfg := st.getCfg()
   892  	return terror.ErrValidatorNotFound.Generate(cfg.Name, cfg.SourceID)
   893  }
   894  
   895  func (st *SubTask) getValidator() *syncer.DataValidator {
   896  	st.RLock()
   897  	defer st.RUnlock()
   898  	return st.validator
   899  }
   900  
   901  func (st *SubTask) GetValidatorStatus() (*pb.ValidationStatus, error) {
   902  	validator := st.getValidator()
   903  	if validator == nil {
   904  		cfg := st.getCfg()
   905  		return nil, terror.ErrValidatorNotFound.Generate(cfg.Name, cfg.SourceID)
   906  	}
   907  	return validator.GetValidatorStatus(), nil
   908  }
   909  
   910  func (st *SubTask) GetValidatorTableStatus(filterStatus pb.Stage) ([]*pb.ValidationTableStatus, error) {
   911  	validator := st.getValidator()
   912  	if validator == nil {
   913  		cfg := st.getCfg()
   914  		return nil, terror.ErrValidatorNotFound.Generate(cfg.Name, cfg.SourceID)
   915  	}
   916  	return validator.GetValidatorTableStatus(filterStatus), nil
   917  }