github.com/pingcap/tiflow@v0.0.0-20240520035814-5bf52d54e205/tests/integration_tests/cdc/dailytest/db.go (about)

     1  // Copyright 2020 PingCAP, Inc.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // See the License for the specific language governing permissions and
    12  // limitations under the License.
    13  
    14  package dailytest
    15  
    16  import (
    17  	"bytes"
    18  	"database/sql"
    19  	"fmt"
    20  	"math"
    21  	"strconv"
    22  	"time"
    23  
    24  	"github.com/pingcap/errors"
    25  	"github.com/pingcap/log"
    26  	"github.com/pingcap/tidb/pkg/parser/mysql"
    27  	"github.com/pingcap/tiflow/tests/integration_tests/util"
    28  	"go.uber.org/zap/zapcore"
    29  )
    30  
    31  func intRangeValue(column *column, min int64, max int64) (int64, int64) {
    32  	var err error
    33  	if len(column.min) > 0 {
    34  		min, err = strconv.ParseInt(column.min, 10, 64)
    35  		if err != nil {
    36  			log.S().Fatal(err)
    37  		}
    38  
    39  		if len(column.max) > 0 {
    40  			max, err = strconv.ParseInt(column.max, 10, 64)
    41  			if err != nil {
    42  				log.S().Fatal(err)
    43  			}
    44  		}
    45  	}
    46  
    47  	return min, max
    48  }
    49  
    50  func randInt64Value(column *column, min int64, max int64) int64 {
    51  	if len(column.set) > 0 {
    52  		idx := randInt(0, len(column.set)-1)
    53  		data, _ := strconv.ParseInt(column.set[idx], 10, 64)
    54  		return data
    55  	}
    56  
    57  	min, max = intRangeValue(column, min, max)
    58  	return randInt64(min, max)
    59  }
    60  
    61  func uniqInt64Value(column *column, max int64) int64 {
    62  	min, max := intRangeValue(column, 0, max)
    63  	column.data.setInitInt64Value(column.step, min, max)
    64  	return column.data.uniqInt64()
    65  }
    66  
    67  func queryCount(table *table, db *sql.DB) (int, error) {
    68  	rows, err := db.Query(fmt.Sprintf("SELECT COUNT(*) as count FROM %s", table.name))
    69  	if err != nil {
    70  		return 0, errors.Trace(err)
    71  	}
    72  
    73  	var nums int
    74  	for rows.Next() {
    75  		err = rows.Scan(&nums)
    76  		if err != nil {
    77  			return 0, errors.Trace(err)
    78  		}
    79  	}
    80  
    81  	return nums, nil
    82  }
    83  
    84  func genDeleteSqls(table *table, db *sql.DB, count int) ([]string, [][]interface{}, error) {
    85  	nums, err := queryCount(table, db)
    86  	if err != nil {
    87  		return nil, nil, errors.Trace(err)
    88  	}
    89  
    90  	var sqls []string
    91  	var args [][]interface{}
    92  
    93  	if nums == 0 || nums-count < 1 {
    94  		return sqls, args, nil
    95  	}
    96  
    97  	start := randInt(1, nums-count)
    98  	length := len(table.columns)
    99  	where := genWhere(table.columns)
   100  
   101  	rows, err := db.Query(fmt.Sprintf("SELECT * FROM %s limit %d, %d", table.name, start, count))
   102  	if err != nil {
   103  		return nil, nil, errors.Trace(err)
   104  	}
   105  
   106  	for rows.Next() {
   107  		data := make([]interface{}, length)
   108  		dbArgs := make([]interface{}, length)
   109  
   110  		for i := 0; i < length; i++ {
   111  			dbArgs[i] = &data[i]
   112  		}
   113  
   114  		err = rows.Scan(dbArgs...)
   115  		if err != nil {
   116  			return nil, nil, errors.Trace(err)
   117  		}
   118  
   119  		sqls = append(sqls, fmt.Sprintf("delete from %s where %s", table.name, where))
   120  		args = append(args, data)
   121  	}
   122  
   123  	return sqls, args, nil
   124  }
   125  
   126  func genUpdateSqls(table *table, db *sql.DB, count int) ([]string, [][]interface{}, error) {
   127  	nums, err := queryCount(table, db)
   128  	if err != nil {
   129  		return nil, nil, errors.Trace(err)
   130  	}
   131  
   132  	var sqls []string
   133  	var args [][]interface{}
   134  
   135  	if nums == 0 || nums-count < 1 {
   136  		return sqls, args, nil
   137  	}
   138  
   139  	start := randInt(1, nums-count)
   140  	length := len(table.columns)
   141  	where := genWhere(table.columns)
   142  
   143  	rows, err := db.Query(fmt.Sprintf("SELECT * FROM %s limit %d, %d", table.name, start, count))
   144  	if err != nil {
   145  		return nil, nil, errors.Trace(err)
   146  	}
   147  
   148  	for rows.Next() {
   149  		data := make([]interface{}, length)
   150  		dbArgs := make([]interface{}, length)
   151  
   152  		for i := 0; i < length; i++ {
   153  			dbArgs[i] = &data[i]
   154  		}
   155  
   156  		err = rows.Scan(dbArgs...)
   157  		if err != nil {
   158  			return nil, nil, errors.Trace(err)
   159  		}
   160  
   161  		index := randInt(2, length-1)
   162  		column := table.columns[index]
   163  		updateData, err := genColumnData(table, column)
   164  		if err != nil {
   165  			return nil, nil, errors.Trace(err)
   166  		}
   167  
   168  		sqls = append(sqls, fmt.Sprintf("update %s set `%s` = %s where %s", table.name, column.name, updateData, where))
   169  		args = append(args, data)
   170  	}
   171  
   172  	return sqls, args, nil
   173  }
   174  
   175  func genInsertSqls(table *table, count int) ([]string, [][]interface{}, error) {
   176  	datas := make([]string, 0, count)
   177  	args := make([][]interface{}, 0, count)
   178  	for i := 0; i < count; i++ {
   179  		data, err := genRowData(table)
   180  		if err != nil {
   181  			return nil, nil, errors.Trace(err)
   182  		}
   183  		datas = append(datas, data)
   184  		args = append(args, nil)
   185  	}
   186  
   187  	return datas, args, nil
   188  }
   189  
   190  func genWhere(columns []*column) string {
   191  	var kvs bytes.Buffer
   192  	for i := range columns {
   193  		if i == len(columns)-1 {
   194  			fmt.Fprintf(&kvs, "`%s` = ?", columns[i].name)
   195  		} else {
   196  			fmt.Fprintf(&kvs, "`%s` = ? and ", columns[i].name)
   197  		}
   198  	}
   199  
   200  	return kvs.String()
   201  }
   202  
   203  func genRowData(table *table) (string, error) {
   204  	var values []byte
   205  	for _, column := range table.columns {
   206  		data, err := genColumnData(table, column)
   207  		if err != nil {
   208  			return "", errors.Trace(err)
   209  		}
   210  		values = append(values, []byte(data)...)
   211  		values = append(values, ',')
   212  	}
   213  
   214  	values = values[:len(values)-1]
   215  	sql := fmt.Sprintf("insert into %s  values (%s);", table.name, string(values))
   216  	return sql, nil
   217  }
   218  
   219  func genColumnData(table *table, column *column) (string, error) {
   220  	tp := column.tp
   221  	_, isUnique := table.uniqIndices[column.name]
   222  	isUnsigned := mysql.HasUnsignedFlag(tp.GetFlag())
   223  
   224  	switch tp.GetType() {
   225  	case mysql.TypeTiny:
   226  		var data int64
   227  		if isUnique {
   228  			data = uniqInt64Value(column, math.MaxUint8)
   229  		} else {
   230  			if isUnsigned {
   231  				data = randInt64Value(column, 0, math.MaxUint8)
   232  			} else {
   233  				data = randInt64Value(column, math.MinInt8, math.MaxInt8)
   234  			}
   235  		}
   236  		return strconv.FormatInt(data, 10), nil
   237  	case mysql.TypeShort:
   238  		var data int64
   239  		if isUnique {
   240  			data = uniqInt64Value(column, math.MaxUint16)
   241  		} else {
   242  			if isUnsigned {
   243  				data = randInt64Value(column, 0, math.MaxUint16)
   244  			} else {
   245  				data = randInt64Value(column, math.MinInt16, math.MaxInt16)
   246  			}
   247  		}
   248  		return strconv.FormatInt(data, 10), nil
   249  	case mysql.TypeLong:
   250  		var data int64
   251  		if isUnique {
   252  			data = uniqInt64Value(column, math.MaxUint32)
   253  		} else {
   254  			if isUnsigned {
   255  				data = randInt64Value(column, 0, math.MaxUint32)
   256  			} else {
   257  				data = randInt64Value(column, math.MinInt32, math.MaxInt32)
   258  			}
   259  		}
   260  		return strconv.FormatInt(data, 10), nil
   261  	case mysql.TypeLonglong:
   262  		var data int64
   263  		if isUnique {
   264  			data = uniqInt64Value(column, math.MaxInt64)
   265  		} else {
   266  			if isUnsigned {
   267  				data = randInt64Value(column, 0, math.MaxInt64)
   268  			} else {
   269  				data = randInt64Value(column, math.MinInt32, math.MaxInt32)
   270  			}
   271  		}
   272  		return strconv.FormatInt(data, 10), nil
   273  	case mysql.TypeVarchar, mysql.TypeString, mysql.TypeTinyBlob, mysql.TypeBlob, mysql.TypeMediumBlob, mysql.TypeLongBlob:
   274  		data := []byte{'\''}
   275  		if isUnique {
   276  			data = append(data, []byte(column.data.uniqString(tp.GetFlen()))...)
   277  		} else {
   278  			data = append(data, []byte(randString(randInt(1, tp.GetFlen())))...)
   279  		}
   280  
   281  		data = append(data, '\'')
   282  		return string(data), nil
   283  	case mysql.TypeFloat, mysql.TypeDouble:
   284  		var data float64
   285  		if isUnique {
   286  			data = float64(uniqInt64Value(column, math.MaxInt64))
   287  		} else {
   288  			if isUnsigned {
   289  				data = float64(randInt64Value(column, 0, math.MaxInt64))
   290  			} else {
   291  				data = float64(randInt64Value(column, math.MinInt32, math.MaxInt32))
   292  			}
   293  		}
   294  		return strconv.FormatFloat(data, 'f', -1, 64), nil
   295  	case mysql.TypeDate:
   296  		data := []byte{'\''}
   297  		if isUnique {
   298  			data = append(data, []byte(column.data.uniqDate())...)
   299  		} else {
   300  			data = append(data, []byte(randDate(column.min, column.max))...)
   301  		}
   302  
   303  		data = append(data, '\'')
   304  		return string(data), nil
   305  	case mysql.TypeDatetime, mysql.TypeTimestamp:
   306  		data := []byte{'\''}
   307  		if isUnique {
   308  			data = append(data, []byte(column.data.uniqTimestamp())...)
   309  		} else {
   310  			data = append(data, []byte(randTimestamp(column.min, column.max))...)
   311  		}
   312  
   313  		data = append(data, '\'')
   314  		return string(data), nil
   315  	case mysql.TypeDuration:
   316  		data := []byte{'\''}
   317  		if isUnique {
   318  			data = append(data, []byte(column.data.uniqTime())...)
   319  		} else {
   320  			data = append(data, []byte(randTime(column.min, column.max))...)
   321  		}
   322  
   323  		data = append(data, '\'')
   324  		return string(data), nil
   325  	case mysql.TypeYear:
   326  		data := []byte{'\''}
   327  		if isUnique {
   328  			data = append(data, []byte(column.data.uniqYear())...)
   329  		} else {
   330  			data = append(data, []byte(randYear(column.min, column.max))...)
   331  		}
   332  
   333  		data = append(data, '\'')
   334  		return string(data), nil
   335  	default:
   336  		return "", errors.Errorf("unsupported column type - %v", column)
   337  	}
   338  }
   339  
   340  func execSQLs(db *sql.DB, sqls []string) error {
   341  	for _, sql := range sqls {
   342  		err := execSQL(db, sql)
   343  		if err != nil {
   344  			return errors.Trace(err)
   345  		}
   346  	}
   347  	return nil
   348  }
   349  
   350  func execSQL(db *sql.DB, sql string) error {
   351  	if len(sql) == 0 {
   352  		return nil
   353  	}
   354  
   355  	_, err := db.Exec(sql)
   356  	if err != nil {
   357  		return errors.Trace(err)
   358  	}
   359  
   360  	return nil
   361  }
   362  
   363  // RunTest will call writeSrc and check if src is contisitent with dst
   364  func RunTest(src *sql.DB, dst *sql.DB, schema string, writeSrc func(src *sql.DB)) {
   365  	writeSrc(src)
   366  
   367  	tick := time.NewTicker(time.Second * 5)
   368  	defer tick.Stop()
   369  	timeout := time.After(time.Second * 240)
   370  
   371  	oldLevel := log.GetLevel()
   372  	defer log.SetLevel(oldLevel)
   373  
   374  	for {
   375  		select {
   376  		case <-tick.C:
   377  			log.SetLevel(zapcore.WarnLevel)
   378  			if util.CheckSyncState(src, dst, schema) {
   379  				return
   380  			}
   381  		case <-timeout:
   382  			// check last time
   383  			log.SetLevel(zapcore.InfoLevel)
   384  			if !util.CheckSyncState(src, dst, schema) {
   385  				log.S().Fatal("sourceDB don't equal targetDB")
   386  			}
   387  
   388  			return
   389  		}
   390  	}
   391  }