github.com/pingcap/tiflow@v0.0.0-20240520035814-5bf52d54e205/tests/utils/checksum_checker/main.go (about)

     1  // Copyright 2021 PingCAP, Inc.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // See the License for the specific language governing permissions and
    12  // limitations under the License.
    13  
    14  package main
    15  
    16  import (
    17  	"database/sql"
    18  	"flag"
    19  	"fmt"
    20  	"os"
    21  	"strings"
    22  	"time"
    23  
    24  	"github.com/pingcap/log"
    25  	"github.com/pingcap/tiflow/cdc/sink/dmlsink/mq/transformer/columnselector"
    26  	cmdUtil "github.com/pingcap/tiflow/pkg/cmd/util"
    27  	"github.com/pingcap/tiflow/pkg/config"
    28  	"github.com/pingcap/tiflow/pkg/errors"
    29  	"go.uber.org/zap"
    30  )
    31  
    32  type options struct {
    33  	upstreamURI   string
    34  	downstreamURI string
    35  	dbNames       string
    36  	configFile    string
    37  }
    38  
    39  func (o *options) validate() error {
    40  	if o.upstreamURI == "" {
    41  		return errors.New("upstreamURI is required")
    42  	}
    43  	if o.downstreamURI == "" {
    44  		return errors.New("downstreamURI is required")
    45  	}
    46  	if len(o.dbNames) == 0 {
    47  		return errors.New("dbNames is required")
    48  	}
    49  	return nil
    50  }
    51  
    52  func main() {
    53  	o := &options{}
    54  
    55  	flags := flag.NewFlagSet(os.Args[0], flag.ExitOnError)
    56  	flags.StringVar(&o.upstreamURI, "upstream-uri", "", "upstream database uri")
    57  	flags.StringVar(&o.downstreamURI, "downstream-uri", "", "downstream database uri")
    58  	flags.StringVar(&o.dbNames, "databases", "", "database names, separate by the `,`")
    59  	flags.StringVar(&o.configFile, "config", "", "config file")
    60  	if err := flags.Parse(os.Args[1:]); err != nil {
    61  		log.Panic("parse args failed", zap.Error(err))
    62  	}
    63  	if err := o.validate(); err != nil {
    64  		log.Panic("invalid options", zap.Error(err))
    65  	}
    66  
    67  	upstreamDB, err := openDB(o.upstreamURI)
    68  	if err != nil {
    69  		log.Panic("cannot open db for the upstream", zap.Error(err))
    70  	}
    71  
    72  	downstreamDB, err := openDB(o.downstreamURI)
    73  	if err != nil {
    74  		log.Panic("cannot open db for the downstream", zap.Error(err))
    75  	}
    76  
    77  	replicaConfig := config.GetDefaultReplicaConfig()
    78  	if o.configFile != "" {
    79  		err = cmdUtil.StrictDecodeFile(o.configFile, "checksum checker", replicaConfig)
    80  		if err != nil {
    81  			log.Panic("cannot decode config file", zap.Error(err))
    82  		}
    83  	}
    84  
    85  	columnFilter, err := columnselector.New(replicaConfig)
    86  	if err != nil {
    87  		log.Panic("cannot create column filter", zap.Error(err))
    88  	}
    89  
    90  	dbNames := strings.Split(o.dbNames, ",")
    91  	err = compareCRC32CheckSum(upstreamDB, downstreamDB, dbNames, columnFilter)
    92  	if err != nil {
    93  		log.Panic("compare checksum failed", zap.Error(err))
    94  	}
    95  	log.Info("compare checksum passed")
    96  }
    97  
    98  func compareCRC32CheckSum(
    99  	upstream, downstream *sql.DB, dbNames []string, selector *columnselector.ColumnSelector,
   100  ) error {
   101  	start := time.Now()
   102  	source, err := getChecksum(upstream, dbNames, selector)
   103  	if err != nil {
   104  		log.Warn("get checksum for the upstream failed", zap.Error(err))
   105  		return errors.Trace(err)
   106  	}
   107  	log.Info("get checksum for the upstream success",
   108  		zap.Duration("elapsed", time.Since(start)))
   109  
   110  	start = time.Now()
   111  	sink, err := getChecksum(downstream, dbNames, selector)
   112  	if err != nil {
   113  		log.Warn("get checksum for the downstream failed", zap.Error(err))
   114  		return errors.Trace(err)
   115  	}
   116  	log.Info("get checksum for the downstream success",
   117  		zap.Duration("elapsed", time.Since(start)))
   118  
   119  	if len(source) != len(sink) {
   120  		log.Error("source and sink have different crc32 size",
   121  			zap.Int("source", len(source)), zap.Int("sink", len(sink)))
   122  		return fmt.Errorf("source and sink have different crc32 size, source: %d, sink: %d",
   123  			len(source), len(sink))
   124  	}
   125  
   126  	for tableName, expected := range source {
   127  		actual, ok := sink[tableName]
   128  		if !ok {
   129  			return fmt.Errorf("table not found at sink, table: %s", tableName)
   130  		}
   131  		if expected != actual {
   132  			log.Error("crc32 mismatch",
   133  				zap.String("table", tableName), zap.Uint32("source", expected), zap.Uint32("sink", actual))
   134  			return fmt.Errorf("crc32 mismatch, table: %s, source: %d, sink: %d", tableName, expected, actual)
   135  		}
   136  	}
   137  	return nil
   138  }
   139  
   140  func getChecksum(
   141  	db *sql.DB, dbNames []string, selector *columnselector.ColumnSelector,
   142  ) (map[string]uint32, error) {
   143  	result := make(map[string]uint32)
   144  	for _, dbName := range dbNames {
   145  		tables, err := getAllTables(db, dbName)
   146  		if err != nil {
   147  			return nil, err
   148  		}
   149  		for _, table := range tables {
   150  			tx, err := db.Begin()
   151  			if err != nil {
   152  				_ = tx.Rollback()
   153  				return nil, errors.Trace(err)
   154  			}
   155  			columns, err := getColumns(tx, dbName, table, selector)
   156  			if err != nil {
   157  				_ = tx.Rollback()
   158  				return nil, errors.Trace(err)
   159  			}
   160  			checksum, err := doChecksum(tx, dbName, table, columns)
   161  			if err != nil {
   162  				_ = tx.Rollback()
   163  				return nil, errors.Trace(err)
   164  			}
   165  			_ = tx.Commit()
   166  			result[dbName+"."+table] = checksum
   167  		}
   168  	}
   169  	return result, nil
   170  }
   171  
   172  func doChecksum(tx *sql.Tx, schema, table string, columns []string) (uint32, error) {
   173  	a := strings.Join(columns, "`,`")
   174  
   175  	concat := fmt.Sprintf("CONCAT_WS(',', `%s`)", a)
   176  	tableName := schema + "." + table
   177  	query := fmt.Sprintf("SELECT BIT_XOR(CRC32(%s)) AS checksum FROM %s", concat, tableName)
   178  	var checkSum uint32
   179  	rows := tx.QueryRow(query)
   180  	err := rows.Scan(&checkSum)
   181  	if err != nil {
   182  		log.Error("get crc32 checksum failed",
   183  			zap.Error(err), zap.String("table", tableName), zap.String("query", query))
   184  		return 0, errors.Trace(err)
   185  	}
   186  	log.Info("do checkSum success", zap.String("table", tableName), zap.Uint32("checkSum", checkSum))
   187  	return checkSum, nil
   188  }
   189  
   190  func getColumns(tx *sql.Tx, schema, table string, selector *columnselector.ColumnSelector) (result []string, err error) {
   191  	rows, err := tx.Query(fmt.Sprintf("SHOW COLUMNS FROM %s", schema+"."+table))
   192  	if err != nil {
   193  		return nil, errors.Trace(err)
   194  	}
   195  	defer func() {
   196  		if err := rows.Close(); err != nil {
   197  			log.Warn("close rows failed", zap.Error(err))
   198  		}
   199  	}()
   200  
   201  	for rows.Next() {
   202  		var t columnInfo
   203  		if err := rows.Scan(&t.Field, &t.Type, &t.Null, &t.Key, &t.Default, &t.Extra); err != nil {
   204  			return result, errors.Trace(err)
   205  		}
   206  		if selector.VerifyColumn(schema, table, t.Field) {
   207  			result = append(result, t.Field)
   208  		}
   209  	}
   210  	return result, nil
   211  }
   212  
   213  type columnInfo struct {
   214  	Field   string
   215  	Type    string
   216  	Null    string
   217  	Key     string
   218  	Default *string
   219  	Extra   string
   220  }
   221  
   222  func getAllTables(db *sql.DB, dbName string) ([]string, error) {
   223  	var result []string
   224  	dbName = strings.TrimSpace(dbName)
   225  	tx, err := db.Begin()
   226  	if err != nil {
   227  		_ = tx.Rollback()
   228  		return nil, errors.Trace(err)
   229  	}
   230  	query := fmt.Sprintf(`show full tables from %s where table_type != "VIEW"`, dbName)
   231  	rows, err := tx.Query(query)
   232  	if err != nil {
   233  		_ = tx.Rollback()
   234  		return nil, errors.Trace(err)
   235  	}
   236  	for rows.Next() {
   237  		var t string
   238  		var tt string
   239  		if err := rows.Scan(&t, &tt); err != nil {
   240  			_ = tx.Rollback()
   241  			return nil, errors.Trace(err)
   242  		}
   243  		result = append(result, t)
   244  	}
   245  	_ = rows.Close()
   246  	_ = tx.Commit()
   247  	return result, nil
   248  }
   249  
   250  func openDB(uri string) (*sql.DB, error) {
   251  	db, err := sql.Open("mysql", uri)
   252  	if err != nil {
   253  		return nil, errors.Trace(err)
   254  	}
   255  
   256  	if err := db.Ping(); err != nil {
   257  		return nil, errors.Trace(err)
   258  	}
   259  	return db, nil
   260  }