github.com/pingcap/tiflow@v0.0.0-20240520035814-5bf52d54e205/dm/syncer/validator_cond_test.go (about)

     1  // Copyright 2022 PingCAP, Inc.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // See the License for the specific language governing permissions and
    12  // limitations under the License.
    13  
    14  package syncer
    15  
    16  import (
    17  	"database/sql"
    18  	"fmt"
    19  	"strconv"
    20  	"testing"
    21  
    22  	"github.com/DATA-DOG/go-sqlmock"
    23  	"github.com/pingcap/tidb/pkg/parser"
    24  	"github.com/pingcap/tidb/pkg/parser/model"
    25  	"github.com/pingcap/tidb/pkg/util/dbutil/dbutiltest"
    26  	"github.com/pingcap/tidb/pkg/util/filter"
    27  	"github.com/stretchr/testify/require"
    28  )
    29  
    30  func genValidateTableInfo(t *testing.T, creatSQL string) *model.TableInfo {
    31  	t.Helper()
    32  	var (
    33  		err       error
    34  		parser2   *parser.Parser
    35  		tableInfo *model.TableInfo
    36  	)
    37  	parser2 = parser.New()
    38  	require.NoError(t, err)
    39  	tableInfo, err = dbutiltest.GetTableInfoBySQL(creatSQL, parser2)
    40  	require.NoError(t, err)
    41  	return tableInfo
    42  }
    43  
    44  func genValidationCond(t *testing.T, schemaName, tblName, creatSQL string, pkvs [][]string) *Cond {
    45  	t.Helper()
    46  	tbl := filter.Table{Schema: schemaName, Name: tblName}
    47  	tblInfo := genValidateTableInfo(t, creatSQL)
    48  	return &Cond{
    49  		TargetTbl: tbl.String(),
    50  		Columns:   tblInfo.Columns,
    51  		PK:        tblInfo.Indices[0],
    52  		PkValues:  pkvs,
    53  	}
    54  }
    55  
    56  func TestValidatorCondSelectMultiKey(t *testing.T) {
    57  	var res *sql.Rows
    58  	db, mock, err := sqlmock.New()
    59  	require.NoError(t, err)
    60  	defer db.Close()
    61  	creatTbl := "create table if not exists `test_cond`.`test1`(" +
    62  		"a int," +
    63  		"b int," +
    64  		"c int," +
    65  		"primary key(a, b)" +
    66  		");"
    67  	// get table diff
    68  	pkValues := make([][]string, 0)
    69  	for i := 0; i < 3; i++ {
    70  		// 3 primary key
    71  		key1, key2 := strconv.Itoa(i+1), strconv.Itoa(i+2)
    72  		pkValues = append(pkValues, []string{key1, key2})
    73  	}
    74  	cond := genValidationCond(t, "test_cond", "test1", creatTbl, pkValues)
    75  	// format query string
    76  	rowsQuery := fmt.Sprintf("SELECT COUNT(*) FROM %s WHERE %s;", "`test_cond`.`test1`", cond.GetWhere())
    77  	mock.ExpectQuery(
    78  		"SELECT COUNT\\(\\*\\) FROM `test_cond`.`test1` WHERE \\(a,b\\) in \\(\\(\\?,\\?\\),\\(\\?,\\?\\),\\(\\?,\\?\\)\\);",
    79  	).WithArgs(
    80  		"1", "2", "2", "3", "3", "4",
    81  	).WillReturnRows(mock.NewRows([]string{"COUNT(*)"}).AddRow("3"))
    82  	require.NoError(t, err)
    83  	res, err = db.Query(rowsQuery, cond.GetArgs()...)
    84  	require.NoError(t, err)
    85  	defer res.Close()
    86  	var cnt int
    87  	if res.Next() {
    88  		err = res.Scan(&cnt)
    89  	}
    90  	require.NoError(t, err)
    91  	require.Equal(t, 3, cnt)
    92  	require.NoError(t, res.Err())
    93  }
    94  
    95  func TestValidatorCondGetWhereArgs(t *testing.T) {
    96  	db, _, err := sqlmock.New()
    97  	require.NoError(t, err)
    98  	defer db.Close()
    99  	type testCase struct {
   100  		creatTbl   string
   101  		pks        [][]string
   102  		tblName    string
   103  		schemaName string
   104  		args       []string
   105  		where      string
   106  	}
   107  	cases := []testCase{
   108  		{
   109  			creatTbl: `create table if not exists test_cond.test2(
   110  				a char(10),
   111  				b int,
   112  				c int,
   113  				primary key(a)
   114  				);`, // single primary key,
   115  			pks: [][]string{
   116  				{"10a0"}, {"200"}, {"abc"},
   117  			},
   118  			tblName:    "test2",
   119  			schemaName: "test_cond",
   120  			where:      "a in (?,?,?)",
   121  			args: []string{
   122  				"10a0", "200", "abc",
   123  			},
   124  		},
   125  		{
   126  			creatTbl: `create table if not exists test_cond.test3(
   127  				a int,
   128  				b char(10),
   129  				c varchar(10),
   130  				primary key(a, b, c)
   131  				);`, // multi primary key
   132  			pks: [][]string{
   133  				{"10", "abc", "ef"},
   134  				{"9897", "afdkiefkjg", "acdee"},
   135  			},
   136  			tblName:    "test3",
   137  			schemaName: "test_cond",
   138  			where:      "(a,b,c) in ((?,?,?),(?,?,?))",
   139  			args: []string{
   140  				"10", "abc", "ef", "9897", "afdkiefkjg", "acdee",
   141  			},
   142  		},
   143  	}
   144  	for i := 0; i < len(cases); i++ {
   145  		cond := genValidationCond(t, cases[i].schemaName, cases[i].tblName, cases[i].creatTbl, cases[i].pks)
   146  		require.Equal(t, cases[i].where, cond.GetWhere())
   147  		rawArgs := cond.GetArgs()
   148  		for j := 0; j < 3; j++ {
   149  			curData := fmt.Sprintf("%v", rawArgs[j])
   150  			require.Equal(t, cases[i].args[j], curData)
   151  		}
   152  	}
   153  }