
     1  // Copyright 2022 PingCAP, Inc.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
//     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    11  // See the License for the specific language governing permissions and
    12  // limitations under the License.
    14  package dm
    16  import (
    17  	"context"
    18  	"database/sql"
    19  	"fmt"
    20  	"math/rand"
    21  	"os"
    22  	"time"
    24  	""
    25  	""
    26  	""
    27  	""
    28  	""
    29  	""
    30  	sqlconfig ""
    31  	""
    32  	sqlgen ""
    33  	pb ""
    34  	""
    35  	""
    36  	""
    37  	""
    38  	""
    39  )
    41  const (
    42  	tableNum = 5
    43  	rowNum   = 1000
    44  	batch    = 100
    45  	// 5 minutes
    46  	diffTimes    = 150
    47  	diffInterval = 2 * time.Second
    48  )
// Case is a DM (data migration) chaos test case with one or more upstream
// sources and a single downstream target.
type Case struct {
	// addr is the address the job is submitted to (see createJob).
	addr     string
	// cfgBytes holds the raw job config file; it is decoded in NewCase and
	// sent unchanged when creating the job.
	cfgBytes []byte
	// sources holds one connection per upstream listed in the job config.
	sources  []*dbConn
	// target is the connection to the downstream database.
	target   *dbConn
	// tables holds the generated table names (tb0, tb1, ...) used when diffing
	// and when picking tables for incremental DML.
	tables   []string
	// jobID is the ID returned by createJob; empty before the job is created.
	jobID    string
	// name is the case name; it is also passed to newDBConn.
	name     string

	// source -> table -> mcp
	// Each pool tracks the unique keys currently present in that source table;
	// the pool sizes summed over all sources give the expected downstream row count.
	mcps []map[string]*mcp.ModificationCandidatePool
	// source -> table -> generator
	generators []map[string]sqlgen.SQLGenerator
	// table -> key -> struct{}
	// Unique-key hashes used by any source, so inserts never reuse a key across
	// sources. Keys deleted during genIncrData stay here until it returns.
	keySet map[string]map[string]struct{}

	// result counts generated DML statements: [0] inserts, [1] updates, [2] deletes.
	result []int
}
    70  // NewCase creates a new test case.
    71  func NewCase(ctx context.Context, addr string, name string, cfgPath string) (*Case, error) {
    72  	cfgBytes, err := os.ReadFile(cfgPath)
    73  	if err != nil {
    74  		return nil, err
    75  	}
    77  	var jobCfg config.JobCfg
    78  	if err := jobCfg.Decode(cfgBytes); err != nil {
    79  		return nil, err
    80  	}
    82  	c := &Case{
    83  		sources:    make([]*dbConn, 0, len(jobCfg.Upstreams)),
    84  		cfgBytes:   cfgBytes,
    85  		addr:       addr,
    86  		name:       name,
    87  		mcps:       make([]map[string]*mcp.ModificationCandidatePool, 0, 3),
    88  		generators: make([]map[string]sqlgen.SQLGenerator, 0, 3),
    89  		keySet:     make(map[string]map[string]struct{}, tableNum),
    90  		result:     make([]int, 3),
    91  	}
    92  	for _, upstream := range jobCfg.Upstreams {
    93  		source, err := newDBConn(ctx, conn.UpstreamDBConfig(upstream.DBCfg), name)
    94  		if err != nil {
    95  			return nil, err
    96  		}
    97  		c.sources = append(c.sources, source)
    98  	}
    99  	target, err := newDBConn(ctx, conn.DownstreamDBConfig(jobCfg.TargetDB), name)
   100  	if err != nil {
   101  		return nil, err
   102  	}
   103 = target
   105  	// init table config
   106  	for range c.sources {
   107  		generators := make(map[string]sqlgen.SQLGenerator)
   108  		mcps := make(map[string]*mcp.ModificationCandidatePool)
   109  		for i := 0; i < tableNum; i++ {
   110  			tableName := fmt.Sprintf("tb%d", i)
   111  			tableConfig := &sqlconfig.TableConfig{
   112  				DatabaseName:,
   113  				TableName:    tableName,
   114  				Columns: []*sqlconfig.ColumnDefinition{
   115  					{
   116  						ColumnName: "id",
   117  						DataType:   "int",
   118  						DataLen:    11,
   119  					},
   120  					{
   121  						ColumnName: "name",
   122  						DataType:   "varchar",
   123  						DataLen:    255,
   124  					},
   125  					{
   126  						ColumnName: "age",
   127  						DataType:   "int",
   128  						DataLen:    11,
   129  					},
   130  					{
   131  						ColumnName: "team_id",
   132  						DataType:   "int",
   133  						DataLen:    11,
   134  					},
   135  				},
   136  				UniqueKeyColumnNames: []string{"id"},
   137  			}
   138  			generators[tableName] = sqlgen.NewSQLGeneratorImpl(tableConfig)
   139  			mcps[tableName] = mcp.NewModificationCandidatePool(100000000)
   140  			c.keySet[tableName] = make(map[string]struct{})
   141  			c.tables = append(c.tables, tableName)
   142  		}
   143  		c.generators = append(c.generators, generators)
   144  		c.mcps = append(c.mcps, mcps)
   145  	}
   147  	return c, nil
   148  }
   150  // Run runs a test case.
   151  func (c *Case) Run(ctx context.Context) error {
   152  	defer func() {
   153  		log.L().Info("finish run case", zap.String("name",, zap.String("job_id", c.jobID), zap.Int("insert", c.result[0]), zap.Int("update", c.result[1]), zap.Int("delete", c.result[2]))
   154  	}()
   155  	if err := c.genFullData(); err != nil {
   156  		return err
   157  	}
   158  	if err := c.createJob(ctx); err != nil {
   159  		return err
   160  	}
   161  	if err := c.diffDataLoop(ctx); err != nil {
   162  		return err
   163  	}
   164  	log.L().Info("full mode of the task has completed", zap.String("name",, zap.String("job_id", c.jobID))
   165  	return c.incrLoop(ctx)
   166  }
   168  func (c *Case) createJob(ctx context.Context) error {
   169  	return retry.Do(ctx, func() error {
   170  		jobID, err := e2e.CreateJobViaHTTP(ctx, c.addr, "chaos-dm-test", "project-dm", pb.Job_DM, c.cfgBytes)
   171  		if err != nil {
   172  			log.L().Error("create job failed", zap.String("name",, zap.Error(err))
   173  			return err
   174  		}
   175  		c.jobID = jobID
   176  		return nil
   177  	},
   178  		retry.WithBackoffBaseDelay(1000 /* 1 second */),
   179  		retry.WithBackoffMaxDelay(8000 /* 8 seconds */),
   180  		retry.WithMaxTries(15 /* fail after 103 seconds */),
   181  	)
   182  }
   184  func (c *Case) genFullData() error {
   185  	log.L().Info("start generate full data", zap.String("name",, zap.String("job_id", c.jobID))
   186  	for source, generators := range c.generators {
   187  		for table, generator := range generators {
   188  			if _, err := c.sources[source].ExecuteSQLs(
   189  				"CREATE DATABASE IF NOT EXISTS "" CHARSET utf8mb4 COLLATE utf8mb4_general_ci",
   190  				"USE "; err != nil {
   191  				return err
   192  			}
   193  			if _, err := c.sources[source].ExecuteSQLs(generator.GenCreateTable()); err != nil {
   194  				return err
   195  			}
   196  			sqls := make([]string, 0, rowNum)
   197  			for j := 0; j < rowNum; j++ {
   198  				sql, uk, err := generator.GenInsertRow()
   199  				if err != nil {
   200  					return err
   201  				}
   202  				// key already exists
   203  				if _, ok := c.keySet[table][uk.GetValueHash()]; ok {
   204  					continue
   205  				}
   206  				if err := c.mcps[source][table].AddUK(uk); err != nil {
   207  					return err
   208  				}
   209  				c.keySet[table][uk.GetValueHash()] = struct{}{}
   210  				sqls = append(sqls, sql)
   211  			}
   212  			if _, err := c.sources[source].ExecuteSQLs(sqls...); err != nil {
   213  				return err
   214  			}
   215  		}
   216  	}
   217  	return nil
   218  }
   220  func (c *Case) diffData(ctx context.Context) (bool, error) {
   221  	log.L().Info("start diff data", zap.String("name",, zap.String("job_id", c.jobID))
   222  	for _, tableName := range c.tables {
   223  		row :=, fmt.Sprintf("SELECT count(1) FROM %s", dbutil.TableName(, tableName)))
   224  		if row.Err() != nil {
   225  			if row.Err() == context.DeadlineExceeded {
   226  				return false, nil
   227  			}
   228  			return false, row.Err()
   229  		}
   230  		var count int
   231  		if err := row.Scan(&count); err != nil {
   232  			return false, err
   233  		}
   234  		var totalCount int
   235  		for _, mcps := range c.mcps {
   236  			totalCount += mcps[tableName].Len()
   237  		}
   238  		if count != totalCount {
   239  			log.Error("data is not same", zap.String("name",, zap.String("job_id", c.jobID), zap.Int("downstream", count), zap.Int("upstream", totalCount))
   240  			return false, nil
   241  		}
   242  	}
   243  	return true, nil
   244  }
   246  func (c *Case) diffDataLoop(ctx context.Context) error {
   247  	for i := 0; i < diffTimes; i++ {
   248  		select {
   249  		case <-ctx.Done():
   250  			return nil
   251  		case <-time.After(diffInterval):
   252  			if same, err := c.diffData(ctx); err != nil {
   253  				if ignoreErrNoSuchTable(err) {
   254  					continue
   255  				}
   256  				return err
   257  			} else if same {
   258  				return nil
   259  			}
   260  		}
   261  	}
   262  	sourceDBs := make([]*sql.DB, 0, len(c.sources))
   263  	for _, s := range c.sources {
   264  		sourceDBs = append(sourceDBs, s.db.DB)
   265  	}
   266  	return syncDiffInspector(ctx,, c.tables,, sourceDBs...)
   267  }
   269  // randDML generates DML (INSERT, UPDATE or DELETE).
   270  func (c *Case) randDML(source int, table string, deleteKeys map[string][]string) (string, error) {
   271  	generator := c.generators[source][table]
   272  	mcp := c.mcps[source][table]
   273  	t := rand.Intn(3)
   274  	key := mcp.NextUK()
   275  	// no rows
   276  	if key == nil {
   277  		t = 0
   278  	}
   279  	c.result[t]++
   280  	switch t {
   281  	case 0:
   282  		sql, uk, err := generator.GenInsertRow()
   283  		if err != nil {
   284  			return "", err
   285  		}
   286  		_, ok := c.keySet[table][uk.GetValueHash()]
   287  		for ok {
   288  			sql, uk, err = generator.GenInsertRow()
   289  			if err != nil {
   290  				return "", err
   291  			}
   292  			_, ok = c.keySet[table][uk.GetValueHash()]
   293  		}
   294  		if err := c.mcps[source][table].AddUK(uk); err != nil {
   295  			return "", err
   296  		}
   297  		c.keySet[table][uk.GetValueHash()] = struct{}{}
   298  		return sql, nil
   299  	case 1:
   300  		return generator.GenUpdateRow(key)
   301  	default:
   302  		sql, err := generator.GenDeleteRow(key)
   303  		if err != nil {
   304  			return "", err
   305  		}
   306  		deleteKeys[table] = append(deleteKeys[table], key.GetValueHash())
   307  		err = mcp.DeleteUK(key)
   308  		return sql, err
   309  	}
   310  }
   312  func (c *Case) genIncrData(ctx context.Context) error {
   313  	log.L().Info("start generate incremental data", zap.String("name",, zap.String("job_id", c.jobID))
   314  	deleteKeys := make(map[string][]string)
   315  	defer func() {
   316  		for tb, keys := range deleteKeys {
   317  			for _, k := range keys {
   318  				delete(c.keySet[tb], k)
   319  			}
   320  		}
   321  	}()
   322  	for {
   323  		select {
   324  		case <-ctx.Done():
   325  			return nil
   326  		default:
   327  		}
   328  		source := rand.Intn(len(c.sources))
   329  		tableName := c.tables[rand.Intn(tableNum)]
   331  		sqls := make([]string, 0, batch)
   332  		for i := 0; i < batch; i++ {
   333  			sql, err := c.randDML(source, tableName, deleteKeys)
   334  			if err != nil {
   335  				return err
   336  			}
   337  			sqls = append(sqls, sql)
   338  		}
   339  		if _, err := c.sources[source].ExecuteSQLs(sqls...); err != nil {
   340  			return err
   341  		}
   342  	}
   343  }
   345  func (c *Case) incrLoop(ctx context.Context) error {
   346  	for {
   347  		select {
   348  		case <-ctx.Done():
   349  			return nil
   350  		default:
   351  		}
   352  		ctx2, cancel := context.WithTimeout(ctx, time.Second*10)
   353  		err := c.genIncrData(ctx2)
   354  		cancel()
   355  		if err != nil {
   356  			return err
   357  		}
   358  		if err := c.diffDataLoop(ctx); err != nil {
   359  			return err
   360  		}
   361  	}
   362  }
   364  func ignoreErrNoSuchTable(err error) bool {
   365  	err = errors.Cause(err)
   366  	mysqlErr, ok := err.(*mysql.MySQLError)
   367  	if !ok {
   368  		return false
   369  	}
   371  	switch mysqlErr.Number {
   372  	case errno.ErrNoSuchTable:
   373  		return true
   374  	default:
   375  		return false
   376  	}
   377  }
   379  func syncDiffInspector(ctx context.Context, schema string, tables []string, targetDB *sql.DB, sourceDBs ...*sql.DB) error {
   380  	for _, table := range tables {
   381  		sourceTables := make([]*diff.TableInstance, 0, len(sourceDBs))
   382  		for i, sourceDB := range sourceDBs {
   383  			sourceTables = append(sourceTables, &diff.TableInstance{
   384  				Conn:       sourceDB,
   385  				Schema:     schema,
   386  				Table:      table,
   387  				InstanceID: fmt.Sprintf("source-%d", i),
   388  			})
   389  		}
   391  		targetTable := &diff.TableInstance{
   392  			Conn:       targetDB,
   393  			Schema:     schema,
   394  			Table:      table,
   395  			InstanceID: "target",
   396  		}
   398  		td := &diff.TableDiff{
   399  			SourceTables:     sourceTables,
   400  			TargetTable:      targetTable,
   401  			ChunkSize:        1000,
   402  			Sample:           100,
   403  			CheckThreadCount: 1,
   404  			UseChecksum:      true,
   405  			TiDBStatsSource:  targetTable,
   406  			CpDB:             targetDB,
   407  		}
   409  		structEqual, dataEqual, err := td.Equal(ctx, func(dml string) error {
   410  			return nil
   411  		})
   413  		if errors.Cause(err) == context.Canceled || errors.Cause(err) == context.DeadlineExceeded {
   414  			return nil
   415  		}
   416  		if !structEqual {
   417  			return errors.Errorf("different struct for table %s", dbutil.TableName(schema, table))
   418  		} else if !dataEqual {
   419  			return errors.Errorf("different data for table %s", dbutil.TableName(schema, table))
   420  		}
   421  		log.L().Info("data equal for table", zap.String("schema", schema), zap.String("table", table))
   422  	}
   424  	return nil
   425  }