github.com/pingcap/br@v5.3.0-alpha.0.20220125034240-ec59c7b6ce30+incompatible/pkg/restore/client.go

     1  // Copyright 2020 PingCAP, Inc. Licensed under Apache-2.0.
     2  
     3  package restore
     4  
     5  import (
     6  	"bytes"
     7  	"context"
     8  	"crypto/tls"
     9  	"encoding/hex"
    10  	"encoding/json"
    11  	"fmt"
    12  	"sort"
    13  	"strconv"
    14  	"strings"
    15  	"time"
    16  
    17  	"github.com/pingcap/br/pkg/metautil"
    18  
    19  	"github.com/opentracing/opentracing-go"
    20  	"github.com/pingcap/errors"
    21  	backuppb "github.com/pingcap/kvproto/pkg/backup"
    22  	"github.com/pingcap/kvproto/pkg/import_sstpb"
    23  	"github.com/pingcap/kvproto/pkg/metapb"
    24  	"github.com/pingcap/log"
    25  	"github.com/pingcap/parser/model"
    26  	"github.com/pingcap/tidb/domain"
    27  	"github.com/pingcap/tidb/kv"
    28  	"github.com/pingcap/tidb/statistics/handle"
    29  	"github.com/pingcap/tidb/tablecodec"
    30  	"github.com/pingcap/tidb/util/codec"
    31  	"github.com/tikv/client-go/v2/oracle"
    32  	pd "github.com/tikv/pd/client"
    33  	"github.com/tikv/pd/server/schedule/placement"
    34  	"go.uber.org/zap"
    35  	"golang.org/x/sync/errgroup"
    36  	"google.golang.org/grpc"
    37  	"google.golang.org/grpc/backoff"
    38  	"google.golang.org/grpc/credentials"
    39  	"google.golang.org/grpc/keepalive"
    40  
    41  	"github.com/pingcap/br/pkg/checksum"
    42  	"github.com/pingcap/br/pkg/conn"
    43  	berrors "github.com/pingcap/br/pkg/errors"
    44  	"github.com/pingcap/br/pkg/glue"
    45  	"github.com/pingcap/br/pkg/logutil"
    46  	"github.com/pingcap/br/pkg/pdutil"
    47  	"github.com/pingcap/br/pkg/redact"
    48  	"github.com/pingcap/br/pkg/storage"
    49  	"github.com/pingcap/br/pkg/summary"
    50  	"github.com/pingcap/br/pkg/utils"
    51  )
    52  
    53  // defaultChecksumConcurrency is the default number of concurrent
    54  // checksum tasks.
    55  const defaultChecksumConcurrency = 64
    56  
    57  // Client sends requests to restore files.
    58  type Client struct {
    59  	pdClient      pd.Client
    60  	toolClient    SplitClient
    61  	fileImporter  FileImporter
    62  	workerPool    *utils.WorkerPool
    63  	tlsConf       *tls.Config
    64  	keepaliveConf keepalive.ClientParameters
    65  
    66  	databases  map[string]*utils.Database
    67  	ddlJobs    []*model.Job
    68  	backupMeta *backuppb.BackupMeta
    69  	// TODO: remove this field or replace it with a []*DB,
    70  	// since https://github.com/pingcap/br/pull/377 needs more DBs to speed up DDL execution.
    71  	// For now, we must inject a pool of DBs into `Client.GoCreateTables`, otherwise there would be a race condition.
    72  	// This is dirty: why do we need DBs from different sources?
    73  	// By replacing it with a []*DB, we can remove the dirty parameter of `Client.GoCreateTable`,
    74  	// along with the corresponding parameters in some private functions.
    75  	// Before you do it, first read the discussion at
    76  	// https://github.com/pingcap/br/pull/377#discussion_r446594501;
    77  	// this probably isn't as easy as it seems (though not that hard either :D).
    78  	db              *DB
    79  	rateLimit       uint64
    80  	isOnline        bool
    81  	noSchema        bool
    82  	hasSpeedLimited bool
    83  
    84  	restoreStores []uint64
    85  
    86  	storage            storage.ExternalStorage
    87  	backend            *backuppb.StorageBackend
    88  	switchModeInterval time.Duration
    89  	switchCh           chan struct{}
    90  
    91  	// statsHandler and dom are used to analyze tables after restore.
    92  	// Stats are backed up with #dump.DumpStatsToJSON
    93  	// and restored with #dump.LoadStatsFromJSON.
    94  	statsHandler *handle.Handle
    95  	dom          *domain.Domain
    96  }
    97  
    98  // NewRestoreClient returns a new RestoreClient.
    99  func NewRestoreClient(
   100  	g glue.Glue,
   101  	pdClient pd.Client,
   102  	store kv.Storage,
   103  	tlsConf *tls.Config,
   104  	keepaliveConf keepalive.ClientParameters,
   105  ) (*Client, error) {
   106  	db, err := NewDB(g, store)
   107  	if err != nil {
   108  		return nil, errors.Trace(err)
   109  	}
   110  	dom, err := g.GetDomain(store)
   111  	if err != nil {
   112  		return nil, errors.Trace(err)
   113  	}
   114  
   115  	var statsHandle *handle.Handle
   116  	// tikv.Glue returns nil; tidb.Glue returns an available domain.
   117  	if dom != nil {
   118  		statsHandle = dom.StatsHandle()
   119  	}
   120  
   121  	return &Client{
   122  		pdClient:      pdClient,
   123  		toolClient:    NewSplitClient(pdClient, tlsConf),
   124  		db:            db,
   125  		tlsConf:       tlsConf,
   126  		keepaliveConf: keepaliveConf,
   127  		switchCh:      make(chan struct{}),
   128  		dom:           dom,
   129  		statsHandler:  statsHandle,
   130  	}, nil
   131  }
   132  
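// exampleClientSetup is a minimal sketch (not part of the original file) showing how a
// restore client is typically wired up; the glue, PD client, storage and TLS/keepalive
// configuration are assumed to be prepared by the caller, and the concrete values passed
// to the setters below are hypothetical.
func exampleClientSetup(
	g glue.Glue,
	pdClient pd.Client,
	store kv.Storage,
	tlsConf *tls.Config,
	keepaliveConf keepalive.ClientParameters,
) (*Client, error) {
	client, err := NewRestoreClient(g, pdClient, store, tlsConf, keepaliveConf)
	if err != nil {
		return nil, errors.Trace(err)
	}
	client.SetRateLimit(128 << 20)                // hypothetical 128 MiB/s download limit per store
	client.SetConcurrency(64)                     // hypothetical number of concurrent file-restore workers
	client.SetSwitchModeInterval(5 * time.Minute) // ping TiKV well within its 10-minute auto-switch window
	return client, nil
}
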
   133  // SetRateLimit sets the rate limit.
   134  func (rc *Client) SetRateLimit(rateLimit uint64) {
   135  	rc.rateLimit = rateLimit
   136  }
   137  
   138  // SetStorage sets the ExternalStorage for the client.
   139  func (rc *Client) SetStorage(ctx context.Context, backend *backuppb.StorageBackend, opts *storage.ExternalStorageOptions) error {
   140  	var err error
   141  	rc.storage, err = storage.New(ctx, backend, opts)
   142  	if err != nil {
   143  		return errors.Trace(err)
   144  	}
   145  	rc.backend = backend
   146  	return nil
   147  }
   148  
   149  // GetPDClient returns a pd client.
   150  func (rc *Client) GetPDClient() pd.Client {
   151  	return rc.pdClient
   152  }
   153  
   154  // IsOnline tells whether it's an online restore.
   155  func (rc *Client) IsOnline() bool {
   156  	return rc.isOnline
   157  }
   158  
   159  // SetSwitchModeInterval sets the switch mode interval for the client.
   160  func (rc *Client) SetSwitchModeInterval(interval time.Duration) {
   161  	rc.switchModeInterval = interval
   162  }
   163  
   164  // Close a client.
   165  func (rc *Client) Close() {
   166  	// rc.db can be nil in raw kv mode.
   167  	if rc.db != nil {
   168  		rc.db.Close()
   169  	}
   170  	log.Info("Restore client closed")
   171  }
   172  
   173  // InitBackupMeta loads schemas from BackupMeta to initialize RestoreClient.
   174  func (rc *Client) InitBackupMeta(c context.Context, backupMeta *backuppb.BackupMeta, backend *backuppb.StorageBackend, externalStorage storage.ExternalStorage, reader *metautil.MetaReader) error {
   175  	if !backupMeta.IsRawKv {
   176  		databases, err := utils.LoadBackupTables(c, reader)
   177  		if err != nil {
   178  			return errors.Trace(err)
   179  		}
   180  		rc.databases = databases
   181  
   182  		var ddlJobs []*model.Job
   183  		// ddls is the JSON-marshaled bytes of the DDL jobs
   184  		ddls, err := reader.ReadDDLs(c)
   185  		if err != nil {
   186  			return errors.Trace(err)
   187  		}
   188  		if len(ddls) != 0 {
   189  			err = json.Unmarshal(ddls, &ddlJobs)
   190  			if err != nil {
   191  				return errors.Trace(err)
   192  			}
   193  		}
   194  		rc.ddlJobs = ddlJobs
   195  	}
   196  	rc.backupMeta = backupMeta
   197  	log.Info("load backupmeta", zap.Int("databases", len(rc.databases)), zap.Int("jobs", len(rc.ddlJobs)))
   198  
   199  	metaClient := NewSplitClient(rc.pdClient, rc.tlsConf)
   200  	importCli := NewImportClient(metaClient, rc.tlsConf, rc.keepaliveConf)
   201  	rc.fileImporter = NewFileImporter(metaClient, importCli, backend, rc.backupMeta.IsRawKv, rc.rateLimit)
   202  	return rc.fileImporter.CheckMultiIngestSupport(c, rc.pdClient)
   203  }
   204  
   205  // IsRawKvMode checks whether the backup data is in raw kv format, in which case transactional restore is forbidden.
   206  func (rc *Client) IsRawKvMode() bool {
   207  	return rc.backupMeta.IsRawKv
   208  }
   209  
   210  // GetFilesInRawRange gets all files that are in the given range or intersects with the given range.
   211  func (rc *Client) GetFilesInRawRange(startKey []byte, endKey []byte, cf string) ([]*backuppb.File, error) {
   212  	if !rc.IsRawKvMode() {
   213  		return nil, errors.Annotate(berrors.ErrRestoreModeMismatch, "the backup data is not in raw kv mode")
   214  	}
   215  
   216  	for _, rawRange := range rc.backupMeta.RawRanges {
   217  		// First check whether the given range was backed up. If not, we cannot perform the restore.
   218  		if rawRange.Cf != cf {
   219  			continue
   220  		}
   221  
   222  		if (len(rawRange.EndKey) > 0 && bytes.Compare(startKey, rawRange.EndKey) >= 0) ||
   223  			(len(endKey) > 0 && bytes.Compare(rawRange.StartKey, endKey) >= 0) {
   224  			// The restoring range is totally out of the current range. Skip it.
   225  			continue
   226  		}
   227  
   228  		if bytes.Compare(startKey, rawRange.StartKey) < 0 ||
   229  			utils.CompareEndKey(endKey, rawRange.EndKey) > 0 {
   230  			// Only part of the restoring range is in the current backed-up range, so the given range can't be fully
   231  			// restored.
   232  			return nil, errors.Annotatef(berrors.ErrRestoreRangeMismatch,
   233  				"the given range to restore [%s, %s) is not fully covered by the range that was backed up [%s, %s)",
   234  				redact.Key(startKey), redact.Key(endKey), redact.Key(rawRange.StartKey), redact.Key(rawRange.EndKey),
   235  			)
   236  		}
   237  
   238  		// We have found the range that contains the given range. Find all necessary files.
   239  		files := make([]*backuppb.File, 0)
   240  
   241  		for _, file := range rc.backupMeta.Files {
   242  			if file.Cf != cf {
   243  				continue
   244  			}
   245  
   246  			if len(file.EndKey) > 0 && bytes.Compare(file.EndKey, startKey) < 0 {
   247  				// The file is before the range to be restored.
   248  				continue
   249  			}
   250  			if len(endKey) > 0 && bytes.Compare(endKey, file.StartKey) <= 0 {
   251  				// The file is after the range to be restored.
   252  				// The specified endKey is exclusive, so when it equals a file's startKey, the file is still skipped.
   253  				continue
   254  			}
   255  
   256  			files = append(files, file)
   257  		}
   258  
   259  		// There should be at most one backed up range that covers the restoring range.
   260  		return files, nil
   261  	}
   262  
   263  	return nil, errors.Annotate(berrors.ErrRestoreRangeMismatch, "no backup data in the range")
   264  }
   265  
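// rawRangeCovers is a standalone sketch (not part of the original file) of the coverage
// condition used above: a backed-up raw range [rangeStart, rangeEnd) fully covers the
// restore range [start, end) only when rangeStart <= start and end <= rangeEnd, where an
// empty end key is treated as unbounded.
func rawRangeCovers(start, end, rangeStart, rangeEnd []byte) bool {
	return bytes.Compare(start, rangeStart) >= 0 && utils.CompareEndKey(end, rangeEnd) <= 0
}
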
   266  // SetConcurrency sets the concurrency of the worker pool used to restore files.
   267  func (rc *Client) SetConcurrency(c uint) {
   268  	rc.workerPool = utils.NewWorkerPool(c, "file")
   269  }
   270  
   271  // EnableOnline sets the mode of restore to online.
   272  func (rc *Client) EnableOnline() {
   273  	rc.isOnline = true
   274  }
   275  
   276  // GetTLSConfig returns the tls config.
   277  func (rc *Client) GetTLSConfig() *tls.Config {
   278  	return rc.tlsConf
   279  }
   280  
   281  // GetTS gets a new timestamp from PD.
   282  func (rc *Client) GetTS(ctx context.Context) (uint64, error) {
   283  	p, l, err := rc.pdClient.GetTS(ctx)
   284  	if err != nil {
   285  		return 0, errors.Trace(err)
   286  	}
   287  	restoreTS := oracle.ComposeTS(p, l)
   288  	return restoreTS, nil
   289  }
   290  
   291  // ResetTS resets the timestamp of PD to a larger value.
   292  func (rc *Client) ResetTS(ctx context.Context, pdAddrs []string) error {
   293  	restoreTS := rc.backupMeta.GetEndVersion()
   294  	log.Info("reset pd timestamp", zap.Uint64("ts", restoreTS))
   295  	i := 0
   296  	return utils.WithRetry(ctx, func() error {
   297  		idx := i % len(pdAddrs)
   298  		i++
   299  		return pdutil.ResetTS(ctx, pdAddrs[idx], restoreTS, rc.tlsConf)
   300  	}, newPDReqBackoffer())
   301  }
   302  
   303  // GetPlacementRules returns the current placement rules.
   304  func (rc *Client) GetPlacementRules(ctx context.Context, pdAddrs []string) ([]placement.Rule, error) {
   305  	var placementRules []placement.Rule
   306  	i := 0
   307  	errRetry := utils.WithRetry(ctx, func() error {
   308  		var err error
   309  		idx := i % len(pdAddrs)
   310  		i++
   311  		placementRules, err = pdutil.GetPlacementRules(ctx, pdAddrs[idx], rc.tlsConf)
   312  		return errors.Trace(err)
   313  	}, newPDReqBackoffer())
   314  	return placementRules, errors.Trace(errRetry)
   315  }
   316  
   317  // GetDatabases returns all databases.
   318  func (rc *Client) GetDatabases() []*utils.Database {
   319  	dbs := make([]*utils.Database, 0, len(rc.databases))
   320  	for _, db := range rc.databases {
   321  		dbs = append(dbs, db)
   322  	}
   323  	return dbs
   324  }
   325  
   326  // GetDatabase returns a database by name.
   327  func (rc *Client) GetDatabase(name string) *utils.Database {
   328  	return rc.databases[name]
   329  }
   330  
   331  // GetDDLJobs returns ddl jobs.
   332  func (rc *Client) GetDDLJobs() []*model.Job {
   333  	return rc.ddlJobs
   334  }
   335  
   336  // GetTableSchema returns the schema of a table from TiDB.
   337  func (rc *Client) GetTableSchema(
   338  	dom *domain.Domain,
   339  	dbName model.CIStr,
   340  	tableName model.CIStr,
   341  ) (*model.TableInfo, error) {
   342  	info := dom.InfoSchema()
   343  	table, err := info.TableByName(dbName, tableName)
   344  	if err != nil {
   345  		return nil, errors.Trace(err)
   346  	}
   347  	return table.Meta(), nil
   348  }
   349  
   350  // CreateDatabase creates a database.
   351  func (rc *Client) CreateDatabase(ctx context.Context, db *model.DBInfo) error {
   352  	if rc.IsSkipCreateSQL() {
   353  		log.Info("skip create database", zap.Stringer("database", db.Name))
   354  		return nil
   355  	}
   356  	return rc.db.CreateDatabase(ctx, db)
   357  }
   358  
   359  // CreateTables creates multiple tables, and returns their rewrite rules.
   360  func (rc *Client) CreateTables(
   361  	dom *domain.Domain,
   362  	tables []*metautil.Table,
   363  	newTS uint64,
   364  ) (*RewriteRules, []*model.TableInfo, error) {
   365  	rewriteRules := &RewriteRules{
   366  		Data: make([]*import_sstpb.RewriteRule, 0),
   367  	}
   368  	newTables := make([]*model.TableInfo, 0, len(tables))
   369  	errCh := make(chan error, 1)
   370  	tbMapping := map[string]int{}
   371  	for i, t := range tables {
   372  		tbMapping[t.Info.Name.String()] = i
   373  	}
   374  	dataCh := rc.GoCreateTables(context.TODO(), dom, tables, newTS, nil, errCh)
   375  	for et := range dataCh {
   376  		rules := et.RewriteRule
   377  		rewriteRules.Data = append(rewriteRules.Data, rules.Data...)
   378  		newTables = append(newTables, et.Table)
   379  	}
   380  	// Let's ensure that it won't break the original order.
   381  	sort.Slice(newTables, func(i, j int) bool {
   382  		return tbMapping[newTables[i].Name.String()] < tbMapping[newTables[j].Name.String()]
   383  	})
   384  
   385  	select {
   386  	case err, ok := <-errCh:
   387  		if ok {
   388  			return nil, nil, errors.Trace(err)
   389  		}
   390  	default:
   391  	}
   392  	return rewriteRules, newTables, nil
   393  }
   394  
   395  func (rc *Client) createTable(
   396  	ctx context.Context,
   397  	db *DB,
   398  	dom *domain.Domain,
   399  	table *metautil.Table,
   400  	newTS uint64,
   401  ) (CreatedTable, error) {
   402  	if rc.IsSkipCreateSQL() {
   403  		log.Info("skip create table and alter autoIncID", zap.Stringer("table", table.Info.Name))
   404  	} else {
   405  		err := db.CreateTable(ctx, table)
   406  		if err != nil {
   407  			return CreatedTable{}, errors.Trace(err)
   408  		}
   409  	}
   410  	newTableInfo, err := rc.GetTableSchema(dom, table.DB.Name, table.Info.Name)
   411  	if err != nil {
   412  		return CreatedTable{}, errors.Trace(err)
   413  	}
   414  	if newTableInfo.IsCommonHandle != table.Info.IsCommonHandle {
   415  		return CreatedTable{}, errors.Annotatef(berrors.ErrRestoreModeMismatch,
   416  			"Clustered index option mismatch. Restored cluster's @@tidb_enable_clustered_index should be %v (backup table = %v, created table = %v).",
   417  			transferBoolToValue(table.Info.IsCommonHandle),
   418  			table.Info.IsCommonHandle,
   419  			newTableInfo.IsCommonHandle)
   420  	}
   421  	rules := GetRewriteRules(newTableInfo, table.Info, newTS)
   422  	et := CreatedTable{
   423  		RewriteRule: rules,
   424  		Table:       newTableInfo,
   425  		OldTable:    table,
   426  	}
   427  	return et, nil
   428  }
   429  
   430  // GoCreateTables creates tables and generates their information.
   431  // This function uses as many workers as there are DBs in dbPool;
   432  // leave dbPool nil to send DDLs sequentially.
   433  func (rc *Client) GoCreateTables(
   434  	ctx context.Context,
   435  	dom *domain.Domain,
   436  	tables []*metautil.Table,
   437  	newTS uint64,
   438  	dbPool []*DB,
   439  	errCh chan<- error,
   440  ) <-chan CreatedTable {
   441  	// Could we have a smaller size of tables?
   442  	log.Info("start create tables")
   443  
   444  	if span := opentracing.SpanFromContext(ctx); span != nil && span.Tracer() != nil {
   445  		span1 := span.Tracer().StartSpan("Client.GoCreateTables", opentracing.ChildOf(span.Context()))
   446  		defer span1.Finish()
   447  		ctx = opentracing.ContextWithSpan(ctx, span1)
   448  	}
   449  	outCh := make(chan CreatedTable, len(tables))
   450  	rater := logutil.TraceRateOver(logutil.MetricTableCreatedCounter)
   451  	createOneTable := func(c context.Context, db *DB, t *metautil.Table) error {
   452  		select {
   453  		case <-c.Done():
   454  			return c.Err()
   455  		default:
   456  		}
   457  		rt, err := rc.createTable(c, db, dom, t, newTS)
   458  		if err != nil {
   459  			log.Error("create table failed",
   460  				zap.Error(err),
   461  				zap.Stringer("db", t.DB.Name),
   462  				zap.Stringer("table", t.Info.Name))
   463  			return errors.Trace(err)
   464  		}
   465  		log.Debug("table created and send to next",
   466  			zap.Int("output chan size", len(outCh)),
   467  			zap.Stringer("table", t.Info.Name),
   468  			zap.Stringer("database", t.DB.Name))
   469  		outCh <- rt
   470  		rater.Inc()
   471  		rater.L().Info("table created",
   472  			zap.Stringer("table", t.Info.Name),
   473  			zap.Stringer("database", t.DB.Name))
   474  		return nil
   475  	}
   476  	go func() {
   477  		defer close(outCh)
   478  		defer log.Debug("all tables are created")
   479  		var err error
   480  		if len(dbPool) > 0 {
   481  			err = rc.createTablesWithDBPool(ctx, createOneTable, tables, dbPool)
   482  		} else {
   483  			err = rc.createTablesWithSoleDB(ctx, createOneTable, tables)
   484  		}
   485  		if err != nil {
   486  			errCh <- err
   487  		}
   488  	}()
   489  	return outCh
   490  }
   491  
   492  func (rc *Client) createTablesWithSoleDB(ctx context.Context,
   493  	createOneTable func(ctx context.Context, db *DB, t *metautil.Table) error,
   494  	tables []*metautil.Table) error {
   495  	for _, t := range tables {
   496  		if err := createOneTable(ctx, rc.db, t); err != nil {
   497  			return errors.Trace(err)
   498  		}
   499  	}
   500  	return nil
   501  }
   502  
   503  func (rc *Client) createTablesWithDBPool(ctx context.Context,
   504  	createOneTable func(ctx context.Context, db *DB, t *metautil.Table) error,
   505  	tables []*metautil.Table, dbPool []*DB) error {
   506  	eg, ectx := errgroup.WithContext(ctx)
   507  	workers := utils.NewWorkerPool(uint(len(dbPool)), "DDL workers")
   508  	for _, t := range tables {
   509  		table := t
   510  		workers.ApplyWithIDInErrorGroup(eg, func(id uint64) error {
   511  			db := dbPool[id%uint64(len(dbPool))]
   512  			return createOneTable(ectx, db, table)
   513  		})
   514  	}
   515  	return eg.Wait()
   516  }
   517  
   518  // ExecDDLs executes the queries of the ddl jobs.
   519  func (rc *Client) ExecDDLs(ctx context.Context, ddlJobs []*model.Job) error {
   520  	// Sort the ddl jobs by schema version in ascending order.
   521  	sort.Slice(ddlJobs, func(i, j int) bool {
   522  		return ddlJobs[i].BinlogInfo.SchemaVersion < ddlJobs[j].BinlogInfo.SchemaVersion
   523  	})
   524  
   525  	for _, job := range ddlJobs {
   526  		err := rc.db.ExecDDL(ctx, job)
   527  		if err != nil {
   528  			return errors.Trace(err)
   529  		}
   530  		log.Info("execute ddl query",
   531  			zap.String("db", job.SchemaName),
   532  			zap.String("query", job.Query),
   533  			zap.Int64("historySchemaVersion", job.BinlogInfo.SchemaVersion))
   534  	}
   535  	return nil
   536  }
   537  
   538  func (rc *Client) setSpeedLimit(ctx context.Context) error {
   539  	if !rc.hasSpeedLimited && rc.rateLimit != 0 {
   540  		stores, err := conn.GetAllTiKVStores(ctx, rc.pdClient, conn.SkipTiFlash)
   541  		if err != nil {
   542  			return errors.Trace(err)
   543  		}
   544  		for _, store := range stores {
   545  			err = rc.fileImporter.setDownloadSpeedLimit(ctx, store.GetId())
   546  			if err != nil {
   547  				return errors.Trace(err)
   548  			}
   549  		}
   550  		rc.hasSpeedLimited = true
   551  	}
   552  	return nil
   553  }
   554  
   555  // isFilesBelongToSameRange checks whether two files belong to the same range but have different cf.
   556  func isFilesBelongToSameRange(f1, f2 string) bool {
   557  	// the backup data file pattern is `{store_id}_{region_id}_{epoch_version}_{key}_{ts}_{cf}.sst`,
   558  	// so we need to compare without the `_{cf}.sst` suffix
   559  	idx1 := strings.LastIndex(f1, "_")
   560  	idx2 := strings.LastIndex(f2, "_")
   561  
   562  	if idx1 < 0 || idx2 < 0 {
   563  		panic(fmt.Sprintf("invalid backup data file name: '%s', '%s'", f1, f2))
   564  	}
   565  
   566  	return f1[:idx1] == f2[:idx2]
   567  }
   568  
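// The illustrative calls below (hypothetical file names, not from the original file) show
// how the suffix-stripping comparison behaves: only the trailing `_{cf}.sst` part may differ.
func sameRangeFileNamesExample() {
	_ = isFilesBelongToSameRange("1_10_5_aa_100_default.sst", "1_10_5_aa_100_write.sst") // true: same data range, different cf
	_ = isFilesBelongToSameRange("1_10_5_aa_100_default.sst", "1_11_5_bb_100_write.sst") // false: different region
}
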
   569  func drainFilesByRange(files []*backuppb.File, supportMulti bool) ([]*backuppb.File, []*backuppb.File) {
   570  	if len(files) == 0 {
   571  		return nil, nil
   572  	}
   573  	if !supportMulti {
   574  		return files[:1], files[1:]
   575  	}
   576  	idx := 1
   577  	for idx < len(files) {
   578  		if !isFilesBelongToSameRange(files[idx-1].Name, files[idx].Name) {
   579  			break
   580  		}
   581  		idx++
   582  	}
   583  
   584  	return files[:idx], files[idx:]
   585  }
   586  
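// batchFilesByRange is a minimal sketch (not part of the original file) showing how
// drainFilesByRange is meant to be consumed: repeatedly peel off one batch of files that
// belong to the same data range until nothing is left, mirroring the loop in RestoreFiles below.
func batchFilesByRange(files []*backuppb.File, supportMulti bool) [][]*backuppb.File {
	var batches [][]*backuppb.File
	for batch, rest := drainFilesByRange(files, supportMulti); len(batch) != 0; batch, rest = drainFilesByRange(rest, supportMulti) {
		batches = append(batches, batch)
	}
	return batches
}
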
   587  // RestoreFiles tries to restore the files.
   588  func (rc *Client) RestoreFiles(
   589  	ctx context.Context,
   590  	files []*backuppb.File,
   591  	rewriteRules *RewriteRules,
   592  	updateCh glue.Progress,
   593  ) (err error) {
   594  	start := time.Now()
   595  	defer func() {
   596  		elapsed := time.Since(start)
   597  		if err == nil {
   598  			log.Info("Restore files", zap.Duration("take", elapsed), logutil.Files(files))
   599  			summary.CollectSuccessUnit("files", len(files), elapsed)
   600  		}
   601  	}()
   602  
   603  	log.Debug("start to restore files", zap.Int("files", len(files)))
   604  
   605  	if span := opentracing.SpanFromContext(ctx); span != nil && span.Tracer() != nil {
   606  		span1 := span.Tracer().StartSpan("Client.RestoreFiles", opentracing.ChildOf(span.Context()))
   607  		defer span1.Finish()
   608  		ctx = opentracing.ContextWithSpan(ctx, span1)
   609  	}
   610  
   611  	eg, ectx := errgroup.WithContext(ctx)
   612  	err = rc.setSpeedLimit(ctx)
   613  	if err != nil {
   614  		return errors.Trace(err)
   615  	}
   616  
   617  	var rangeFiles []*backuppb.File
   618  	var leftFiles []*backuppb.File
   619  	for rangeFiles, leftFiles = drainFilesByRange(files, rc.fileImporter.supportMultiIngest); len(rangeFiles) != 0; rangeFiles, leftFiles = drainFilesByRange(leftFiles, rc.fileImporter.supportMultiIngest) {
   620  		filesReplica := rangeFiles
   621  		rc.workerPool.ApplyOnErrorGroup(eg,
   622  			func() error {
   623  				fileStart := time.Now()
   624  				defer func() {
   625  					log.Info("import files done", logutil.Files(filesReplica),
   626  						zap.Duration("take", time.Since(fileStart)))
   627  					updateCh.Inc()
   628  				}()
   629  				return rc.fileImporter.Import(ectx, filesReplica, rewriteRules)
   630  			})
   631  	}
   632  
   633  	if err := eg.Wait(); err != nil {
   634  		summary.CollectFailureUnit("file", err)
   635  		log.Error(
   636  			"restore files failed",
   637  			zap.Error(err),
   638  		)
   639  		return errors.Trace(err)
   640  	}
   641  	return nil
   642  }
   643  
   644  // RestoreRaw tries to restore raw keys in the specified range.
   645  func (rc *Client) RestoreRaw(
   646  	ctx context.Context, startKey []byte, endKey []byte, files []*backuppb.File, updateCh glue.Progress,
   647  ) error {
   648  	start := time.Now()
   649  	defer func() {
   650  		elapsed := time.Since(start)
   651  		log.Info("Restore Raw",
   652  			logutil.Key("startKey", startKey),
   653  			logutil.Key("endKey", endKey),
   654  			zap.Duration("take", elapsed))
   655  	}()
   656  	errCh := make(chan error, len(files))
   657  	eg, ectx := errgroup.WithContext(ctx)
   658  	defer close(errCh)
   659  
   660  	err := rc.fileImporter.SetRawRange(startKey, endKey)
   661  	if err != nil {
   662  		return errors.Trace(err)
   663  	}
   664  
   665  	for _, file := range files {
   666  		fileReplica := file
   667  		rc.workerPool.ApplyOnErrorGroup(eg,
   668  			func() error {
   669  				defer updateCh.Inc()
   670  				return rc.fileImporter.Import(ectx, []*backuppb.File{fileReplica}, EmptyRewriteRule())
   671  			})
   672  	}
   673  	if err := eg.Wait(); err != nil {
   674  		log.Error(
   675  			"restore raw range failed",
   676  			logutil.Key("startKey", startKey),
   677  			logutil.Key("endKey", endKey),
   678  			zap.Error(err),
   679  		)
   680  		return errors.Trace(err)
   681  	}
   682  	log.Info(
   683  		"finish to restore raw range",
   684  		logutil.Key("startKey", startKey),
   685  		logutil.Key("endKey", endKey),
   686  	)
   687  	return nil
   688  }
   689  
   690  // SwitchToImportMode switches the TiKV cluster to import mode.
   691  func (rc *Client) SwitchToImportMode(ctx context.Context) {
   692  	// TiKV automatically switches back to normal mode every 10 minutes,
   693  	// so we need to ping TiKV at an interval shorter than 10 minutes.
   694  	go func() {
   695  		tick := time.NewTicker(rc.switchModeInterval)
   696  		defer tick.Stop()
   697  
   698  		// [important!] switch TiKV into import mode at the beginning
   699  		log.Info("switch to import mode at beginning")
   700  		err := rc.switchTiKVMode(ctx, import_sstpb.SwitchMode_Import)
   701  		if err != nil {
   702  			log.Warn("switch to import mode failed", zap.Error(err))
   703  		}
   704  
   705  		for {
   706  			select {
   707  			case <-ctx.Done():
   708  				return
   709  			case <-tick.C:
   710  				log.Info("switch to import mode")
   711  				err := rc.switchTiKVMode(ctx, import_sstpb.SwitchMode_Import)
   712  				if err != nil {
   713  					log.Warn("switch to import mode failed", zap.Error(err))
   714  				}
   715  			case <-rc.switchCh:
   716  				log.Info("stop automatic switch to import mode")
   717  				return
   718  			}
   719  		}
   720  	}()
   721  }
   722  
   723  // SwitchToNormalMode switches the TiKV cluster to normal mode.
   724  func (rc *Client) SwitchToNormalMode(ctx context.Context) error {
   725  	close(rc.switchCh)
   726  	return rc.switchTiKVMode(ctx, import_sstpb.SwitchMode_Normal)
   727  }
   728  
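// exampleImportModeScope is a minimal sketch (not part of the original file) of the
// intended pairing of the two functions above: keep TiKV in import mode while files are
// being restored, then switch back to normal mode once the restore finishes.
func exampleImportModeScope(ctx context.Context, rc *Client, files []*backuppb.File, rules *RewriteRules, updateCh glue.Progress) error {
	rc.SwitchToImportMode(ctx)
	defer func() {
		if err := rc.SwitchToNormalMode(ctx); err != nil {
			log.Warn("switch to normal mode failed", zap.Error(err))
		}
	}()
	return rc.RestoreFiles(ctx, files, rules, updateCh)
}
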
   729  func (rc *Client) switchTiKVMode(ctx context.Context, mode import_sstpb.SwitchMode) error {
   730  	stores, err := conn.GetAllTiKVStores(ctx, rc.pdClient, conn.SkipTiFlash)
   731  	if err != nil {
   732  		return errors.Trace(err)
   733  	}
   734  	bfConf := backoff.DefaultConfig
   735  	bfConf.MaxDelay = time.Second * 3
   736  	for _, store := range stores {
   737  		opt := grpc.WithInsecure()
   738  		if rc.tlsConf != nil {
   739  			opt = grpc.WithTransportCredentials(credentials.NewTLS(rc.tlsConf))
   740  		}
   741  		gctx, cancel := context.WithTimeout(ctx, time.Second*5)
   742  		connection, err := grpc.DialContext(
   743  			gctx,
   744  			store.GetAddress(),
   745  			opt,
   746  			grpc.WithConnectParams(grpc.ConnectParams{Backoff: bfConf}),
   747  			// we don't need to set a keepalive timeout here, because the connection lives
   748  			// at most 5s (shorter than the minimal keepalive time!).
   749  		)
   750  		cancel()
   751  		if err != nil {
   752  			return errors.Trace(err)
   753  		}
   754  		client := import_sstpb.NewImportSSTClient(connection)
   755  		_, err = client.SwitchMode(ctx, &import_sstpb.SwitchModeRequest{
   756  			Mode: mode,
   757  		})
   758  		if err != nil {
   759  			return errors.Trace(err)
   760  		}
   761  		err = connection.Close()
   762  		if err != nil {
   763  			log.Error("close grpc connection failed in switch mode", zap.Error(err))
   764  			continue
   765  		}
   766  	}
   767  	return nil
   768  }
   769  
   770  // GoValidateChecksum forks a goroutine to validate checksums after restore.
   771  // It returns a channel that fires a struct{} when all the checksum work is done.
   772  func (rc *Client) GoValidateChecksum(
   773  	ctx context.Context,
   774  	tableStream <-chan CreatedTable,
   775  	kvClient kv.Client,
   776  	errCh chan<- error,
   777  	updateCh glue.Progress,
   778  	concurrency uint,
   779  ) <-chan struct{} {
   780  	log.Info("Start to validate checksum")
   781  	outCh := make(chan struct{}, 1)
   782  	workers := utils.NewWorkerPool(defaultChecksumConcurrency, "RestoreChecksum")
   783  	go func() {
   784  		wg, ectx := errgroup.WithContext(ctx)
   785  		defer func() {
   786  			log.Info("all checksum ended")
   787  			if err := wg.Wait(); err != nil {
   788  				errCh <- err
   789  			}
   790  			outCh <- struct{}{}
   791  			close(outCh)
   792  		}()
   793  		for {
   794  			select {
   795  			// if we used ectx here, cancellation might mask the real error.
   796  			case <-ctx.Done():
   797  				errCh <- ctx.Err()
   798  			case tbl, ok := <-tableStream:
   799  				if !ok {
   800  					return
   801  				}
   802  				workers.ApplyOnErrorGroup(wg, func() error {
   803  					start := time.Now()
   804  					defer func() {
   805  						elapsed := time.Since(start)
   806  						summary.CollectDuration("restore checksum", elapsed)
   807  						summary.CollectSuccessUnit("table checksum", 1, elapsed)
   808  					}()
   809  					err := rc.execChecksum(ectx, tbl, kvClient, concurrency)
   810  					if err != nil {
   811  						return errors.Trace(err)
   812  					}
   813  					updateCh.Inc()
   814  					return nil
   815  				})
   816  			}
   817  		}
   818  	}()
   819  	return outCh
   820  }
   821  
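// exampleChecksumPipeline is a minimal sketch (not part of the original file) of how the
// checksum stage is wired after table creation: feed the CreatedTable stream into
// GoValidateChecksum and wait for either the first error or the done signal. The kvClient,
// updateCh and concurrency arguments are assumed to be prepared by the caller.
func exampleChecksumPipeline(
	ctx context.Context,
	rc *Client,
	tableStream <-chan CreatedTable,
	kvClient kv.Client,
	updateCh glue.Progress,
	concurrency uint,
) error {
	errCh := make(chan error, 1)
	doneCh := rc.GoValidateChecksum(ctx, tableStream, kvClient, errCh, updateCh, concurrency)
	select {
	case err := <-errCh:
		return errors.Trace(err)
	case <-doneCh:
		return nil
	}
}
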
   822  func (rc *Client) execChecksum(ctx context.Context, tbl CreatedTable, kvClient kv.Client, concurrency uint) error {
   823  	logger := log.With(
   824  		zap.String("db", tbl.OldTable.DB.Name.O),
   825  		zap.String("table", tbl.OldTable.Info.Name.O),
   826  	)
   827  
   828  	if tbl.OldTable.NoChecksum() {
   829  		logger.Warn("table has no checksum, skipping checksum")
   830  		return nil
   831  	}
   832  
   833  	if span := opentracing.SpanFromContext(ctx); span != nil && span.Tracer() != nil {
   834  		span1 := span.Tracer().StartSpan("Client.execChecksum", opentracing.ChildOf(span.Context()))
   835  		defer span1.Finish()
   836  		ctx = opentracing.ContextWithSpan(ctx, span1)
   837  	}
   838  
   839  	startTS, err := rc.GetTS(ctx)
   840  	if err != nil {
   841  		return errors.Trace(err)
   842  	}
   843  	exe, err := checksum.NewExecutorBuilder(tbl.Table, startTS).
   844  		SetOldTable(tbl.OldTable).
   845  		SetConcurrency(concurrency).
   846  		Build()
   847  	if err != nil {
   848  		return errors.Trace(err)
   849  	}
   850  	checksumResp, err := exe.Execute(ctx, kvClient, func() {
   851  		// TODO: update progress here.
   852  	})
   853  	if err != nil {
   854  		return errors.Trace(err)
   855  	}
   856  
   857  	table := tbl.OldTable
   858  	if checksumResp.Checksum != table.Crc64Xor ||
   859  		checksumResp.TotalKvs != table.TotalKvs ||
   860  		checksumResp.TotalBytes != table.TotalBytes {
   861  		logger.Error("failed in validate checksum",
   862  			zap.Uint64("origin tidb crc64", table.Crc64Xor),
   863  			zap.Uint64("calculated crc64", checksumResp.Checksum),
   864  			zap.Uint64("origin tidb total kvs", table.TotalKvs),
   865  			zap.Uint64("calculated total kvs", checksumResp.TotalKvs),
   866  			zap.Uint64("origin tidb total bytes", table.TotalBytes),
   867  			zap.Uint64("calculated total bytes", checksumResp.TotalBytes),
   868  		)
   869  		return errors.Annotate(berrors.ErrRestoreChecksumMismatch, "failed to validate checksum")
   870  	}
   871  	if table.Stats != nil {
   872  		logger.Info("start loads analyze after validate checksum",
   873  			zap.Int64("old id", tbl.OldTable.Info.ID),
   874  			zap.Int64("new id", tbl.Table.ID),
   875  		)
   876  		if err := rc.statsHandler.LoadStatsFromJSON(rc.dom.InfoSchema(), table.Stats); err != nil {
   877  			logger.Error("analyze table failed", zap.Any("table", table.Stats), zap.Error(err))
   878  		}
   879  	}
   880  	return nil
   881  }
   882  
   883  const (
   884  	restoreLabelKey   = "exclusive"
   885  	restoreLabelValue = "restore"
   886  )
   887  
   888  // LoadRestoreStores loads the stores used to restore data.
   889  func (rc *Client) LoadRestoreStores(ctx context.Context) error {
   890  	if !rc.isOnline {
   891  		return nil
   892  	}
   893  	if span := opentracing.SpanFromContext(ctx); span != nil && span.Tracer() != nil {
   894  		span1 := span.Tracer().StartSpan("Client.LoadRestoreStores", opentracing.ChildOf(span.Context()))
   895  		defer span1.Finish()
   896  		ctx = opentracing.ContextWithSpan(ctx, span1)
   897  	}
   898  
   899  	stores, err := rc.pdClient.GetAllStores(ctx)
   900  	if err != nil {
   901  		return errors.Trace(err)
   902  	}
   903  	for _, s := range stores {
   904  		if s.GetState() != metapb.StoreState_Up {
   905  			continue
   906  		}
   907  		for _, l := range s.GetLabels() {
   908  			if l.GetKey() == restoreLabelKey && l.GetValue() == restoreLabelValue {
   909  				rc.restoreStores = append(rc.restoreStores, s.GetId())
   910  				break
   911  			}
   912  		}
   913  	}
   914  	log.Info("load restore stores", zap.Uint64s("store-ids", rc.restoreStores))
   915  	return nil
   916  }
   917  
   918  // ResetRestoreLabels removes the exclusive labels of the restore stores.
   919  func (rc *Client) ResetRestoreLabels(ctx context.Context) error {
   920  	if !rc.isOnline {
   921  		return nil
   922  	}
   923  	log.Info("start resetting store labels")
   924  	return rc.toolClient.SetStoresLabel(ctx, rc.restoreStores, restoreLabelKey, "")
   925  }
   926  
   927  // SetupPlacementRules sets rules for the tables' regions.
   928  func (rc *Client) SetupPlacementRules(ctx context.Context, tables []*model.TableInfo) error {
   929  	if !rc.isOnline || len(rc.restoreStores) == 0 {
   930  		return nil
   931  	}
   932  	log.Info("start setting placement rules")
   933  	rule, err := rc.toolClient.GetPlacementRule(ctx, "pd", "default")
   934  	if err != nil {
   935  		return errors.Trace(err)
   936  	}
   937  	rule.Index = 100
   938  	rule.Override = true
   939  	rule.LabelConstraints = append(rule.LabelConstraints, placement.LabelConstraint{
   940  		Key:    restoreLabelKey,
   941  		Op:     "in",
   942  		Values: []string{restoreLabelValue},
   943  	})
   944  	for _, t := range tables {
   945  		rule.ID = rc.getRuleID(t.ID)
   946  		rule.StartKeyHex = hex.EncodeToString(codec.EncodeBytes([]byte{}, tablecodec.EncodeTablePrefix(t.ID)))
   947  		rule.EndKeyHex = hex.EncodeToString(codec.EncodeBytes([]byte{}, tablecodec.EncodeTablePrefix(t.ID+1)))
   948  		err = rc.toolClient.SetPlacementRule(ctx, rule)
   949  		if err != nil {
   950  			return errors.Trace(err)
   951  		}
   952  	}
   953  	log.Info("finish setting placement rules")
   954  	return nil
   955  }
   956  
   957  // WaitPlacementSchedule waits for PD to schedule the tables' regions to the restore stores.
   958  func (rc *Client) WaitPlacementSchedule(ctx context.Context, tables []*model.TableInfo) error {
   959  	if !rc.isOnline || len(rc.restoreStores) == 0 {
   960  		return nil
   961  	}
   962  	log.Info("start waiting placement schedule")
   963  	ticker := time.NewTicker(time.Second * 10)
   964  	defer ticker.Stop()
   965  	for {
   966  		select {
   967  		case <-ticker.C:
   968  			ok, progress, err := rc.checkRegions(ctx, tables)
   969  			if err != nil {
   970  				return errors.Trace(err)
   971  			}
   972  			if ok {
   973  				log.Info("finish waiting placement schedule")
   974  				return nil
   975  			}
   976  			log.Info("placement schedule progress: " + progress)
   977  		case <-ctx.Done():
   978  			return ctx.Err()
   979  		}
   980  	}
   981  }
   982  
   983  func (rc *Client) checkRegions(ctx context.Context, tables []*model.TableInfo) (bool, string, error) {
   984  	for i, t := range tables {
   985  		start := codec.EncodeBytes([]byte{}, tablecodec.EncodeTablePrefix(t.ID))
   986  		end := codec.EncodeBytes([]byte{}, tablecodec.EncodeTablePrefix(t.ID+1))
   987  		ok, regionProgress, err := rc.checkRange(ctx, start, end)
   988  		if err != nil {
   989  			return false, "", errors.Trace(err)
   990  		}
   991  		if !ok {
   992  			return false, fmt.Sprintf("table %v/%v, %s", i, len(tables), regionProgress), nil
   993  		}
   994  	}
   995  	return true, "", nil
   996  }
   997  
   998  func (rc *Client) checkRange(ctx context.Context, start, end []byte) (bool, string, error) {
   999  	regions, err := rc.toolClient.ScanRegions(ctx, start, end, -1)
  1000  	if err != nil {
  1001  		return false, "", errors.Trace(err)
  1002  	}
  1003  	for i, r := range regions {
  1004  	NEXT_PEER:
  1005  		for _, p := range r.Region.GetPeers() {
  1006  			for _, storeID := range rc.restoreStores {
  1007  				if p.GetStoreId() == storeID {
  1008  					continue NEXT_PEER
  1009  				}
  1010  			}
  1011  			return false, fmt.Sprintf("region %v/%v", i, len(regions)), nil
  1012  		}
  1013  	}
  1014  	return true, "", nil
  1015  }
  1016  
  1017  // ResetPlacementRules removes placement rules for tables.
  1018  func (rc *Client) ResetPlacementRules(ctx context.Context, tables []*model.TableInfo) error {
  1019  	if !rc.isOnline || len(rc.restoreStores) == 0 {
  1020  		return nil
  1021  	}
  1022  	log.Info("start resetting placement rules")
  1023  	var failedTables []int64
  1024  	for _, t := range tables {
  1025  		err := rc.toolClient.DeletePlacementRule(ctx, "pd", rc.getRuleID(t.ID))
  1026  		if err != nil {
  1027  			log.Info("failed to delete placement rule for table", zap.Int64("table-id", t.ID))
  1028  			failedTables = append(failedTables, t.ID)
  1029  		}
  1030  	}
  1031  	if len(failedTables) > 0 {
  1032  		return errors.Annotatef(berrors.ErrPDInvalidResponse, "failed to delete placement rules for tables %v", failedTables)
  1033  	}
  1034  	return nil
  1035  }
  1036  
  1037  func (rc *Client) getRuleID(tableID int64) string {
  1038  	return "restore-t" + strconv.FormatInt(tableID, 10)
  1039  }
  1040  
  1041  // IsIncremental returns whether this backup is incremental.
  1042  func (rc *Client) IsIncremental() bool {
  1043  	return !(rc.backupMeta.StartVersion == rc.backupMeta.EndVersion ||
  1044  		rc.backupMeta.StartVersion == 0)
  1045  }
  1046  
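// The sketch below (hypothetical TSO values, not from the original file) illustrates the
// predicate above: a zero StartVersion, or a StartVersion equal to EndVersion, marks a
// full backup; only 0 < StartVersion < EndVersion is treated as incremental.
func isIncrementalExample() (full, incremental bool) {
	full = !(&Client{backupMeta: &backuppb.BackupMeta{StartVersion: 0, EndVersion: 400}}).IsIncremental()
	incremental = (&Client{backupMeta: &backuppb.BackupMeta{StartVersion: 300, EndVersion: 400}}).IsIncremental()
	return full, incremental
}
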
  1047  // EnableSkipCreateSQL enables skipping the creation of schemas and tables during restore.
  1048  func (rc *Client) EnableSkipCreateSQL() {
  1049  	rc.noSchema = true
  1050  }
  1051  
  1052  // IsSkipCreateSQL returns whether creating schemas and tables should be skipped during restore.
  1053  func (rc *Client) IsSkipCreateSQL() bool {
  1054  	return rc.noSchema
  1055  }
  1056  
  1057  // PreCheckTableTiFlashReplica checks whether each table's TiFlash replica count exceeds the number of TiFlash nodes, and drops the replica setting if it cannot be satisfied.
  1058  func (rc *Client) PreCheckTableTiFlashReplica(
  1059  	ctx context.Context,
  1060  	tables []*metautil.Table,
  1061  ) error {
  1062  	tiFlashStores, err := conn.GetAllTiKVStores(ctx, rc.pdClient, conn.TiFlashOnly)
  1063  	if err != nil {
  1064  		return errors.Trace(err)
  1065  	}
  1066  	tiFlashStoreCount := len(tiFlashStores)
  1067  	for _, table := range tables {
  1068  		if table.Info.TiFlashReplica != nil && table.Info.TiFlashReplica.Count > uint64(tiFlashStoreCount) {
  1069  			// the restore cluster cannot satisfy the TiFlash replica requirement, so we
  1070  			// set TiFlashReplica to nil in tableInfo, to prevent TiDB from planning queries against TiFlash replicas that don't exist.
  1071  			// see details at https://github.com/pingcap/br/issues/931
  1072  			table.Info.TiFlashReplica = nil
  1073  		}
  1074  	}
  1075  	return nil
  1076  }
  1077  
  1078  // PreCheckTableClusterIndex checks whether the backed-up tables and existing tables have different clustered index options.
  1079  func (rc *Client) PreCheckTableClusterIndex(
  1080  	tables []*metautil.Table,
  1081  	ddlJobs []*model.Job,
  1082  	dom *domain.Domain,
  1083  ) error {
  1084  	for _, table := range tables {
  1085  		oldTableInfo, err := rc.GetTableSchema(dom, table.DB.Name, table.Info.Name)
  1086  		// table exists in database
  1087  		if err == nil {
  1088  			if table.Info.IsCommonHandle != oldTableInfo.IsCommonHandle {
  1089  				return errors.Annotatef(berrors.ErrRestoreModeMismatch,
  1090  					"Clustered index option mismatch. Restored cluster's @@tidb_enable_clustered_index should be %v (backup table = %v, created table = %v).",
  1091  					transferBoolToValue(table.Info.IsCommonHandle),
  1092  					table.Info.IsCommonHandle,
  1093  					oldTableInfo.IsCommonHandle)
  1094  			}
  1095  		}
  1096  	}
  1097  	for _, job := range ddlJobs {
  1098  		if job.Type == model.ActionCreateTable {
  1099  			tableInfo := job.BinlogInfo.TableInfo
  1100  			if tableInfo != nil {
  1101  				oldTableInfo, err := rc.GetTableSchema(dom, model.NewCIStr(job.SchemaName), tableInfo.Name)
  1102  				// table exists in database
  1103  				if err == nil {
  1104  					if tableInfo.IsCommonHandle != oldTableInfo.IsCommonHandle {
  1105  						return errors.Annotatef(berrors.ErrRestoreModeMismatch,
  1106  							"Clustered index option mismatch. Restored cluster's @@tidb_enable_clustered_index should be %v (backup table = %v, created table = %v).",
  1107  							transferBoolToValue(tableInfo.IsCommonHandle),
  1108  							tableInfo.IsCommonHandle,
  1109  							oldTableInfo.IsCommonHandle)
  1110  					}
  1111  				}
  1112  			}
  1113  		}
  1114  	}
  1115  	return nil
  1116  }
  1117  
  1118  func transferBoolToValue(enable bool) string {
  1119  	if enable {
  1120  		return "ON"
  1121  	}
  1122  	return "OFF"
  1123  }