
     1  // Copyright 2021 PingCAP, Inc.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // See the License for the specific language governing permissions and
    12  // limitations under the License.
    14  package restore
    16  import (
    17  	"container/heap"
    18  	"context"
    19  	"database/sql"
    20  	"fmt"
    21  	"sync"
    22  	"sync/atomic"
    23  	"time"
    25  	""
    26  	""
    27  	""
    28  	tidbcfg ""
    29  	""
    30  	""
    31  	""
    32  	""
    33  	pd ""
    34  	""
    36  	""
    37  	""
    38  	""
    39  	""
    40  	""
    41  	""
    42  	""
    43  	""
    44  )
const (
	// preUpdateServiceSafePointFactor divides serviceSafePointTTL to obtain the
	// interval at which gcTTLManager.start refreshes the service safe point,
	// i.e. the safe point is refreshed this many times per TTL period.
	preUpdateServiceSafePointFactor = 3
	// maxErrorRetryCount bounds the number of attempts checksumDB makes to
	// execute a checksum request before giving up.
	maxErrorRetryCount = 3
)
var (
	// serviceSafePointTTL is the TTL handed to PD's UpdateServiceGCSafePoint;
	// it is also the base of the refresh interval computed in gcTTLManager.start.
	serviceSafePointTTL int64 = 10 * 60 // 10 min in seconds
	// minDistSQLScanConcurrency is the floor checksumDB applies when it halves
	// the scan concurrency after a retryable failure.
	minDistSQLScanConcurrency = 4
)
// RemoteChecksum represents a checksum result got from tidb.
//
// It is filled either from the columns returned by `ADMIN CHECKSUM TABLE`
// (Db_name, Table_name, Checksum_crc64_xor, Total_kvs, Total_bytes) or from
// the response of a checksum request executed by tikvChecksumManager.
type RemoteChecksum struct {
	// Schema is the database name of the checksummed table.
	Schema string
	// Table is the name of the checksummed table.
	Table string
	// Checksum is the checksum value reported by the remote side.
	Checksum uint64
	// TotalKVs is the number of key-value pairs reported by the remote side.
	TotalKVs uint64
	// TotalBytes is the total size in bytes reported by the remote side.
	TotalBytes uint64
}
// ChecksumManager computes the remote checksum of a single table.
//
// Two implementations live in this file: tidbChecksumExecutor, which issues
// `ADMIN CHECKSUM TABLE` through a SQL connection, and tikvChecksumManager,
// which executes checksum requests through a kv.Client.
type ChecksumManager interface {
	// Checksum returns the remote checksum of the table described by tableInfo.
	Checksum(ctx context.Context, tableInfo *checkpoints.TidbTableInfo) (*RemoteChecksum, error)
}
    71  func newChecksumManager(ctx context.Context, rc *Controller) (ChecksumManager, error) {
    72  	// if we don't need checksum, just return nil
    73  	if rc.cfg.TikvImporter.Backend == config.BackendTiDB || rc.cfg.PostRestore.Checksum == config.OpLevelOff {
    74  		return nil, nil
    75  	}
    77  	pdAddr := rc.cfg.TiDB.PdAddr
    78  	pdVersion, err := pdutil.FetchPDVersion(ctx, rc.tls, pdAddr)
    79  	if err != nil {
    80  		return nil, errors.Trace(err)
    81  	}
    83  	// for v4.0.0 or upper, we can use the gc ttl api
    84  	var manager ChecksumManager
    85  	if pdVersion.Major >= 4 {
    86  		tlsOpt := rc.tls.ToPDSecurityOption()
    87  		pdCli, err := pd.NewClientWithContext(ctx, []string{pdAddr}, tlsOpt)
    88  		if err != nil {
    89  			return nil, errors.Trace(err)
    90  		}
    92  		// TODO: make tikv.Driver{}.Open use arguments instead of global variables
    93  		if tlsOpt.CAPath != "" {
    94  			conf := tidbcfg.GetGlobalConfig()
    95  			conf.Security.ClusterSSLCA = tlsOpt.CAPath
    96  			conf.Security.ClusterSSLCert = tlsOpt.CertPath
    97  			conf.Security.ClusterSSLKey = tlsOpt.KeyPath
    98  			tidbcfg.StoreGlobalConfig(conf)
    99  		}
   100  		store, err := driver.TiKVDriver{}.Open(fmt.Sprintf("tikv://%s?disableGC=true", pdAddr))
   101  		if err != nil {
   102  			return nil, errors.Trace(err)
   103  		}
   105  		manager = newTiKVChecksumManager(store.GetClient(), pdCli, uint(rc.cfg.TiDB.DistSQLScanConcurrency))
   106  	} else {
   107  		db, err := rc.tidbGlue.GetDB()
   108  		if err != nil {
   109  			return nil, errors.Trace(err)
   110  		}
   111  		manager = newTiDBChecksumExecutor(db)
   112  	}
   114  	return manager, nil
   115  }
// tidbChecksumExecutor fetches table checksums through a TiDB SQL connection
// by running `ADMIN CHECKSUM TABLE`.
type tidbChecksumExecutor struct {
	// db is the SQL connection the checksum statement is issued on.
	db *sql.DB
	// manager raises the GC life time while checksum jobs are running and
	// restores it once the last job has finished.
	manager *gcLifeTimeManager
}
   123  func newTiDBChecksumExecutor(db *sql.DB) *tidbChecksumExecutor {
   124  	return &tidbChecksumExecutor{
   125  		db:      db,
   126  		manager: newGCLifeTimeManager(),
   127  	}
   128  }
   130  func (e *tidbChecksumExecutor) Checksum(ctx context.Context, tableInfo *checkpoints.TidbTableInfo) (*RemoteChecksum, error) {
   131  	var err error
   132  	if err = e.manager.addOneJob(ctx, e.db); err != nil {
   133  		return nil, err
   134  	}
   136  	// set it back finally
   137  	defer e.manager.removeOneJob(ctx, e.db)
   139  	tableName := common.UniqueTable(tableInfo.DB, tableInfo.Name)
   141  	task := log.With(zap.String("table", tableName)).Begin(zap.InfoLevel, "remote checksum")
   143  	// ADMIN CHECKSUM TABLE <table>,<table>  example.
   144  	// 	mysql> admin checksum table test.t;
   145  	// +---------+------------+---------------------+-----------+-------------+
   146  	// | Db_name | Table_name | Checksum_crc64_xor  | Total_kvs | Total_bytes |
   147  	// +---------+------------+---------------------+-----------+-------------+
   148  	// | test    | t          | 8520875019404689597 |   7296873 |   357601387 |
   149  	// +---------+------------+---------------------+-----------+-------------+
   151  	cs := RemoteChecksum{}
   152  	err = common.SQLWithRetry{DB: e.db, Logger: task.Logger}.QueryRow(ctx, "compute remote checksum",
   153  		"ADMIN CHECKSUM TABLE "+tableName, &cs.Schema, &cs.Table, &cs.Checksum, &cs.TotalKVs, &cs.TotalBytes,
   154  	)
   155  	dur := task.End(zap.ErrorLevel, err)
   156  	metric.ChecksumSecondsHistogram.Observe(dur.Seconds())
   157  	if err != nil {
   158  		return nil, errors.Trace(err)
   159  	}
   160  	return &cs, nil
   161  }
   163  // DoChecksum do checksum for tables.
   164  // table should be in <db>.<table>, format.  e.g.
   165  func DoChecksum(ctx context.Context, table *checkpoints.TidbTableInfo) (*RemoteChecksum, error) {
   166  	var err error
   167  	manager, ok := ctx.Value(&checksumManagerKey).(ChecksumManager)
   168  	if !ok {
   169  		return nil, errors.New("No gcLifeTimeManager found in context, check context initialization")
   170  	}
   172  	task := log.With(zap.String("table", table.Name)).Begin(zap.InfoLevel, "remote checksum")
   174  	cs, err := manager.Checksum(ctx, table)
   175  	dur := task.End(zap.ErrorLevel, err)
   176  	metric.ChecksumSecondsHistogram.Observe(dur.Seconds())
   178  	return cs, err
   179  }
// gcLifeTimeManager keeps the GC life time raised while at least one checksum
// job is running and restores the original value when the last job finishes.
type gcLifeTimeManager struct {
	// runningJobsLock is held by addOneJob and removeOneJob while they read
	// and update the fields below.
	runningJobsLock sync.Mutex
	// runningJobs counts the jobs added by addOneJob and not yet removed.
	runningJobs int
	// oriGCLifeTime is the GC life time read when the first job was added;
	// it is the value removeOneJob restores.
	oriGCLifeTime string
}
   187  func newGCLifeTimeManager() *gcLifeTimeManager {
   188  	// Default values of three member are enough to initialize this struct
   189  	return &gcLifeTimeManager{}
   190  }
   192  // Pre- and post-condition:
   193  // if m.runningJobs == 0, GC life time has not been increased.
   194  // if m.runningJobs > 0, GC life time has been increased.
   195  // m.runningJobs won't be negative(overflow) since index concurrency is relatively small
   196  func (m *gcLifeTimeManager) addOneJob(ctx context.Context, db *sql.DB) error {
   197  	m.runningJobsLock.Lock()
   198  	defer m.runningJobsLock.Unlock()
   200  	if m.runningJobs == 0 {
   201  		oriGCLifeTime, err := ObtainGCLifeTime(ctx, db)
   202  		if err != nil {
   203  			return err
   204  		}
   205  		m.oriGCLifeTime = oriGCLifeTime
   206  		err = increaseGCLifeTime(ctx, m, db)
   207  		if err != nil {
   208  			return err
   209  		}
   210  	}
   211  	m.runningJobs++
   212  	return nil
   213  }
   215  // Pre- and post-condition:
   216  // if m.runningJobs == 0, GC life time has been tried to recovered. If this try fails, a warning will be printed.
   217  // if m.runningJobs > 0, GC life time has not been recovered.
   218  // m.runningJobs won't minus to negative since removeOneJob follows a successful addOneJob.
   219  func (m *gcLifeTimeManager) removeOneJob(ctx context.Context, db *sql.DB) {
   220  	m.runningJobsLock.Lock()
   221  	defer m.runningJobsLock.Unlock()
   223  	m.runningJobs--
   224  	if m.runningJobs == 0 {
   225  		err := UpdateGCLifeTime(ctx, db, m.oriGCLifeTime)
   226  		if err != nil {
   227  			query := fmt.Sprintf(
   228  				"UPDATE mysql.tidb SET VARIABLE_VALUE = '%s' WHERE VARIABLE_NAME = 'tikv_gc_life_time'",
   229  				m.oriGCLifeTime,
   230  			)
   231  			log.L().Warn("revert GC lifetime failed, please reset the GC lifetime manually after Lightning completed",
   232  				zap.String("query", query),
   233  				log.ShortError(err),
   234  			)
   235  		}
   236  	}
   237  }
   239  func increaseGCLifeTime(ctx context.Context, manager *gcLifeTimeManager, db *sql.DB) (err error) {
   240  	// checksum command usually takes a long time to execute,
   241  	// so here need to increase the gcLifeTime for single transaction.
   242  	var increaseGCLifeTime bool
   243  	if manager.oriGCLifeTime != "" {
   244  		ori, err := time.ParseDuration(manager.oriGCLifeTime)
   245  		if err != nil {
   246  			return errors.Trace(err)
   247  		}
   248  		if ori < defaultGCLifeTime {
   249  			increaseGCLifeTime = true
   250  		}
   251  	} else {
   252  		increaseGCLifeTime = true
   253  	}
   255  	if increaseGCLifeTime {
   256  		err = UpdateGCLifeTime(ctx, db, defaultGCLifeTime.String())
   257  		if err != nil {
   258  			return err
   259  		}
   260  	}
   262  	failpoint.Inject("IncreaseGCUpdateDuration", nil)
   264  	return nil
   265  }
// tikvChecksumManager computes table checksums by executing checksum requests
// through a kv.Client while keeping a PD service GC safe point alive.
type tikvChecksumManager struct {
	// client is the kv client the checksum requests are executed with.
	client kv.Client
	// manager tracks the running checksum jobs and refreshes the service GC
	// safe point in PD; its PD client is also used to fetch timestamps.
	manager gcTTLManager
	// distSQLScanConcurrency is the initial scan concurrency of a checksum
	// request; checksumDB may lower it when retrying.
	distSQLScanConcurrency uint
}
   273  // newTiKVChecksumManager return a new tikv checksum manager
   274  func newTiKVChecksumManager(client kv.Client, pdClient pd.Client, distSQLScanConcurrency uint) *tikvChecksumManager {
   275  	return &tikvChecksumManager{
   276  		client:                 client,
   277  		manager:                newGCTTLManager(pdClient),
   278  		distSQLScanConcurrency: distSQLScanConcurrency,
   279  	}
   280  }
   282  func (e *tikvChecksumManager) checksumDB(ctx context.Context, tableInfo *checkpoints.TidbTableInfo) (*RemoteChecksum, error) {
   283  	physicalTS, logicalTS, err := e.manager.pdClient.GetTS(ctx)
   284  	if err != nil {
   285  		return nil, errors.Annotate(err, "fetch tso from pd failed")
   286  	}
   287  	executor, err := checksum.NewExecutorBuilder(tableInfo.Core, oracle.ComposeTS(physicalTS, logicalTS)).
   288  		SetConcurrency(e.distSQLScanConcurrency).
   289  		Build()
   290  	if err != nil {
   291  		return nil, errors.Trace(err)
   292  	}
   294  	distSQLScanConcurrency := int(e.distSQLScanConcurrency)
   295  	for i := 0; i < maxErrorRetryCount; i++ {
   296  		_ = executor.Each(func(request *kv.Request) error {
   297  			request.Concurrency = distSQLScanConcurrency
   298  			return nil
   299  		})
   300  		var execRes *tipb.ChecksumResponse
   301  		execRes, err = executor.Execute(ctx, e.client, func() {})
   302  		if err == nil {
   303  			return &RemoteChecksum{
   304  				Schema:     tableInfo.DB,
   305  				Table:      tableInfo.Name,
   306  				Checksum:   execRes.Checksum,
   307  				TotalBytes: execRes.TotalBytes,
   308  				TotalKVs:   execRes.TotalKvs,
   309  			}, nil
   310  		}
   312  		log.L().Warn("remote checksum failed", zap.String("db", tableInfo.DB),
   313  			zap.String("table", tableInfo.Name), zap.Error(err),
   314  			zap.Int("concurrency", distSQLScanConcurrency), zap.Int("retry", i))
   316  		// do not retry context.Canceled error
   317  		if !common.IsRetryableError(err) {
   318  			break
   319  		}
   320  		if distSQLScanConcurrency > minDistSQLScanConcurrency {
   321  			distSQLScanConcurrency = utils.MaxInt(distSQLScanConcurrency/2, minDistSQLScanConcurrency)
   322  		}
   323  	}
   325  	return nil, err
   326  }
   328  func (e *tikvChecksumManager) Checksum(ctx context.Context, tableInfo *checkpoints.TidbTableInfo) (*RemoteChecksum, error) {
   329  	tbl := common.UniqueTable(tableInfo.DB, tableInfo.Name)
   330  	err := e.manager.addOneJob(ctx, tbl, oracle.ComposeTS(time.Now().Unix()*1000, 0))
   331  	if err != nil {
   332  		return nil, errors.Trace(err)
   333  	}
   335  	return e.checksumDB(ctx, tableInfo)
   336  }
// tableChecksumTS is one entry of the gcTTLManager heap: an active checksum
// job together with the timestamp that must stay safe from GC for it.
type tableChecksumTS struct {
	// table identifies the job; removeOneJob looks entries up by this value.
	table string
	// gcSafeTS is the timestamp the heap is ordered by (smallest first).
	gcSafeTS uint64
}
   343  // following function are for implement `heap.Interface`
   345  func (m *gcTTLManager) Len() int {
   346  	return len(m.tableGCSafeTS)
   347  }
   349  func (m *gcTTLManager) Less(i, j int) bool {
   350  	return m.tableGCSafeTS[i].gcSafeTS < m.tableGCSafeTS[j].gcSafeTS
   351  }
   353  func (m *gcTTLManager) Swap(i, j int) {
   354  	m.tableGCSafeTS[i], m.tableGCSafeTS[j] = m.tableGCSafeTS[j], m.tableGCSafeTS[i]
   355  }
   357  func (m *gcTTLManager) Push(x interface{}) {
   358  	m.tableGCSafeTS = append(m.tableGCSafeTS, x.(*tableChecksumTS))
   359  }
   361  func (m *gcTTLManager) Pop() interface{} {
   362  	i := m.tableGCSafeTS[len(m.tableGCSafeTS)-1]
   363  	m.tableGCSafeTS = m.tableGCSafeTS[:len(m.tableGCSafeTS)-1]
   364  	return i
   365  }
// gcTTLManager tracks the active checksum jobs and keeps a service GC safe
// point registered in PD for the oldest of them. It implements heap.Interface
// over tableGCSafeTS.
type gcTTLManager struct {
	// lock is held by addOneJob, removeOneJob and updateGCTTL while they
	// access tableGCSafeTS and currentTS.
	lock sync.Mutex
	// pdClient is used to update the service GC safe point and to fetch timestamps.
	pdClient pd.Client
	// tableGCSafeTS is a binary heap that stored active checksum jobs GC safe point ts
	tableGCSafeTS []*tableChecksumTS
	// currentTS is the smallest gcSafeTS among the active jobs, or 0 when
	// there is none; a zero value is never sent to PD.
	currentTS uint64
	// serviceID is the service identifier sent to PD with every safe point update.
	serviceID string
	// 0 for not start, otherwise started
	started uint32
}
   378  func newGCTTLManager(pdClient pd.Client) gcTTLManager {
   379  	return gcTTLManager{
   380  		pdClient:  pdClient,
   381  		serviceID: fmt.Sprintf("lightning-%s", uuid.New()),
   382  	}
   383  }
   385  func (m *gcTTLManager) addOneJob(ctx context.Context, table string, ts uint64) error {
   386  	// start gc ttl loop if not started yet.
   387  	if atomic.CompareAndSwapUint32(&m.started, 0, 1) {
   388  		m.start(ctx)
   389  	}
   390  	m.lock.Lock()
   391  	defer m.lock.Unlock()
   392  	var curTS uint64
   393  	if len(m.tableGCSafeTS) > 0 {
   394  		curTS = m.tableGCSafeTS[0].gcSafeTS
   395  	}
   396  	m.Push(&tableChecksumTS{table: table, gcSafeTS: ts})
   397  	heap.Fix(m, len(m.tableGCSafeTS)-1)
   398  	m.currentTS = m.tableGCSafeTS[0].gcSafeTS
   399  	if curTS == 0 || m.currentTS < curTS {
   400  		return m.doUpdateGCTTL(ctx, m.currentTS)
   401  	}
   402  	return nil
   403  }
   405  func (m *gcTTLManager) removeOneJob(table string) {
   406  	m.lock.Lock()
   407  	defer m.lock.Unlock()
   408  	idx := -1
   409  	for i := 0; i < len(m.tableGCSafeTS); i++ {
   410  		if m.tableGCSafeTS[i].table == table {
   411  			idx = i
   412  			break
   413  		}
   414  	}
   416  	if idx >= 0 {
   417  		l := len(m.tableGCSafeTS)
   418  		m.tableGCSafeTS[idx] = m.tableGCSafeTS[l-1]
   419  		m.tableGCSafeTS = m.tableGCSafeTS[:l-1]
   420  		if l > 1 && idx < l-1 {
   421  			heap.Fix(m, idx)
   422  		}
   423  	}
   425  	var newTS uint64
   426  	if len(m.tableGCSafeTS) > 0 {
   427  		newTS = m.tableGCSafeTS[0].gcSafeTS
   428  	}
   429  	m.currentTS = newTS
   430  }
   432  func (m *gcTTLManager) updateGCTTL(ctx context.Context) error {
   433  	m.lock.Lock()
   434  	currentTS := m.currentTS
   435  	m.lock.Unlock()
   436  	return m.doUpdateGCTTL(ctx, currentTS)
   437  }
   439  func (m *gcTTLManager) doUpdateGCTTL(ctx context.Context, ts uint64) error {
   440  	log.L().Debug("update PD safePoint limit with TTL",
   441  		zap.Uint64("currnet_ts", ts))
   442  	var err error
   443  	if ts > 0 {
   444  		_, err = m.pdClient.UpdateServiceGCSafePoint(ctx,
   445  			m.serviceID, serviceSafePointTTL, ts)
   446  	}
   447  	return err
   448  }
   450  func (m *gcTTLManager) start(ctx context.Context) {
   451  	// It would be OK since TTL won't be zero, so gapTime should > `0.
   452  	updateGapTime := time.Duration(serviceSafePointTTL) * time.Second / preUpdateServiceSafePointFactor
   454  	updateTick := time.NewTicker(updateGapTime)
   456  	updateGCTTL := func() {
   457  		if err := m.updateGCTTL(ctx); err != nil {
   458  			log.L().Warn("failed to update service safe point, checksum may fail if gc triggered", zap.Error(err))
   459  		}
   460  	}
   462  	// trigger a service gc ttl at start
   463  	updateGCTTL()
   464  	go func() {
   465  		defer updateTick.Stop()
   466  		for {
   467  			select {
   468  			case <-ctx.Done():
   469  				log.L().Info("service safe point keeper exited")
   470  				return
   471  			case <-updateTick.C:
   472  				updateGCTTL()
   473  			}
   474  		}
   475  	}()
   476  }