github.com/pingcap/tiflow@v0.0.0-20240520035814-5bf52d54e205/dm/syncer/validator_cond_test.go (about) 1 // Copyright 2022 PingCAP, Inc. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // See the License for the specific language governing permissions and 12 // limitations under the License. 13 14 package syncer 15 16 import ( 17 "database/sql" 18 "fmt" 19 "strconv" 20 "testing" 21 22 "github.com/DATA-DOG/go-sqlmock" 23 "github.com/pingcap/tidb/pkg/parser" 24 "github.com/pingcap/tidb/pkg/parser/model" 25 "github.com/pingcap/tidb/pkg/util/dbutil/dbutiltest" 26 "github.com/pingcap/tidb/pkg/util/filter" 27 "github.com/stretchr/testify/require" 28 ) 29 30 func genValidateTableInfo(t *testing.T, creatSQL string) *model.TableInfo { 31 t.Helper() 32 var ( 33 err error 34 parser2 *parser.Parser 35 tableInfo *model.TableInfo 36 ) 37 parser2 = parser.New() 38 require.NoError(t, err) 39 tableInfo, err = dbutiltest.GetTableInfoBySQL(creatSQL, parser2) 40 require.NoError(t, err) 41 return tableInfo 42 } 43 44 func genValidationCond(t *testing.T, schemaName, tblName, creatSQL string, pkvs [][]string) *Cond { 45 t.Helper() 46 tbl := filter.Table{Schema: schemaName, Name: tblName} 47 tblInfo := genValidateTableInfo(t, creatSQL) 48 return &Cond{ 49 TargetTbl: tbl.String(), 50 Columns: tblInfo.Columns, 51 PK: tblInfo.Indices[0], 52 PkValues: pkvs, 53 } 54 } 55 56 func TestValidatorCondSelectMultiKey(t *testing.T) { 57 var res *sql.Rows 58 db, mock, err := sqlmock.New() 59 require.NoError(t, err) 60 defer db.Close() 61 creatTbl := "create table if not exists `test_cond`.`test1`(" + 62 "a int," + 63 "b int," + 64 "c int," + 65 "primary key(a, b)" + 66 ");" 67 // get table diff 68 pkValues := make([][]string, 0) 69 for i := 0; i < 3; i++ { 70 // 3 primary key 71 key1, key2 := strconv.Itoa(i+1), strconv.Itoa(i+2) 72 pkValues = append(pkValues, []string{key1, key2}) 73 } 74 cond := genValidationCond(t, "test_cond", "test1", creatTbl, pkValues) 75 // format query string 76 rowsQuery := fmt.Sprintf("SELECT COUNT(*) FROM %s WHERE %s;", "`test_cond`.`test1`", cond.GetWhere()) 77 mock.ExpectQuery( 78 "SELECT COUNT\\(\\*\\) FROM `test_cond`.`test1` WHERE \\(a,b\\) in \\(\\(\\?,\\?\\),\\(\\?,\\?\\),\\(\\?,\\?\\)\\);", 79 ).WithArgs( 80 "1", "2", "2", "3", "3", "4", 81 ).WillReturnRows(mock.NewRows([]string{"COUNT(*)"}).AddRow("3")) 82 require.NoError(t, err) 83 res, err = db.Query(rowsQuery, cond.GetArgs()...) 84 require.NoError(t, err) 85 defer res.Close() 86 var cnt int 87 if res.Next() { 88 err = res.Scan(&cnt) 89 } 90 require.NoError(t, err) 91 require.Equal(t, 3, cnt) 92 require.NoError(t, res.Err()) 93 } 94 95 func TestValidatorCondGetWhereArgs(t *testing.T) { 96 db, _, err := sqlmock.New() 97 require.NoError(t, err) 98 defer db.Close() 99 type testCase struct { 100 creatTbl string 101 pks [][]string 102 tblName string 103 schemaName string 104 args []string 105 where string 106 } 107 cases := []testCase{ 108 { 109 creatTbl: `create table if not exists test_cond.test2( 110 a char(10), 111 b int, 112 c int, 113 primary key(a) 114 );`, // single primary key, 115 pks: [][]string{ 116 {"10a0"}, {"200"}, {"abc"}, 117 }, 118 tblName: "test2", 119 schemaName: "test_cond", 120 where: "a in (?,?,?)", 121 args: []string{ 122 "10a0", "200", "abc", 123 }, 124 }, 125 { 126 creatTbl: `create table if not exists test_cond.test3( 127 a int, 128 b char(10), 129 c varchar(10), 130 primary key(a, b, c) 131 );`, // multi primary key 132 pks: [][]string{ 133 {"10", "abc", "ef"}, 134 {"9897", "afdkiefkjg", "acdee"}, 135 }, 136 tblName: "test3", 137 schemaName: "test_cond", 138 where: "(a,b,c) in ((?,?,?),(?,?,?))", 139 args: []string{ 140 "10", "abc", "ef", "9897", "afdkiefkjg", "acdee", 141 }, 142 }, 143 } 144 for i := 0; i < len(cases); i++ { 145 cond := genValidationCond(t, cases[i].schemaName, cases[i].tblName, cases[i].creatTbl, cases[i].pks) 146 require.Equal(t, cases[i].where, cond.GetWhere()) 147 rawArgs := cond.GetArgs() 148 for j := 0; j < 3; j++ { 149 curData := fmt.Sprintf("%v", rawArgs[j]) 150 require.Equal(t, cases[i].args[j], curData) 151 } 152 } 153 }