github.com/pingcap/tiflow@v0.0.0-20240520035814-5bf52d54e205/dm/pkg/conn/db_test.go (about) 1 // Copyright 2019 PingCAP, Inc. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // See the License for the specific language governing permissions and 12 // limitations under the License. 13 14 package conn 15 16 import ( 17 "context" 18 "fmt" 19 "strconv" 20 "testing" 21 "time" 22 23 "github.com/DATA-DOG/go-sqlmock" 24 "github.com/coreos/go-semver/semver" 25 gmysql "github.com/go-mysql-org/go-mysql/mysql" 26 "github.com/go-sql-driver/mysql" 27 "github.com/pingcap/errors" 28 tmysql "github.com/pingcap/tidb/pkg/parser/mysql" 29 "github.com/pingcap/tidb/pkg/util/filter" 30 regexprrouter "github.com/pingcap/tidb/pkg/util/regexpr-router" 31 router "github.com/pingcap/tidb/pkg/util/table-router" 32 tcontext "github.com/pingcap/tiflow/dm/pkg/context" 33 "github.com/pingcap/tiflow/dm/pkg/gtid" 34 "github.com/pingcap/tiflow/dm/pkg/log" 35 "github.com/stretchr/testify/require" 36 ) 37 38 func TestGetFlavor(t *testing.T) { 39 t.Parallel() 40 41 db, mock, err := sqlmock.New() 42 require.NoError(t, err) 43 44 // MySQL 45 mock.ExpectQuery(`SHOW GLOBAL VARIABLES LIKE 'version';`).WillReturnRows(sqlmock.NewRows([]string{"Variable_name", "Value"}).AddRow("version", "5.7.31-log")) 46 flavor, err := GetFlavor(context.Background(), NewBaseDBForTest(db)) 47 require.NoError(t, err) 48 require.Equal(t, "mysql", flavor) 49 require.NoError(t, mock.ExpectationsWereMet()) 50 51 // MariaDB 52 mock.ExpectQuery(`SHOW GLOBAL VARIABLES LIKE 'version';`).WillReturnRows(sqlmock.NewRows([]string{"Variable_name", "Value"}).AddRow("version", "10.13.1-MariaDB-1~wheezy")) 53 flavor, err = GetFlavor(context.Background(), NewBaseDBForTest(db)) 54 require.NoError(t, err) 55 require.Equal(t, "mariadb", flavor) 56 require.NoError(t, mock.ExpectationsWereMet()) 57 58 // others 59 mock.ExpectQuery(`SHOW GLOBAL VARIABLES LIKE 'version';`).WillReturnRows(sqlmock.NewRows([]string{"Variable_name", "Value"}).AddRow("version", "unknown")) 60 flavor, err = GetFlavor(context.Background(), NewBaseDBForTest(db)) 61 require.NoError(t, err) 62 require.Equal(t, "mysql", flavor) // as MySQL 63 require.NoError(t, mock.ExpectationsWereMet()) 64 } 65 66 func TestGetRandomServerID(t *testing.T) { 67 t.Parallel() 68 69 db, mock, err := sqlmock.New() 70 require.NoError(t, err) 71 72 tctx := tcontext.NewContext(context.Background(), log.L()) 73 createMockResult(mock, 1, []uint32{100, 101}, "mysql") 74 serverID, err := GetRandomServerID(tctx, NewBaseDBForTest(db)) 75 require.NoError(t, err) 76 require.Greater(t, serverID, uint32(0)) 77 require.NoError(t, mock.ExpectationsWereMet()) 78 require.NotEqual(t, 1, serverID) 79 require.NotEqual(t, 100, serverID) 80 require.NotEqual(t, 101, serverID) 81 } 82 83 func TestGetMariaDBGtidDomainID(t *testing.T) { 84 t.Parallel() 85 86 ctx, cancel := context.WithTimeout(context.Background(), DefaultDBTimeout) 87 defer cancel() 88 tctx := tcontext.NewContext(ctx, log.L()) 89 90 db, mock, err := sqlmock.New() 91 require.NoError(t, err) 92 93 rows := mock.NewRows([]string{"Variable_name", "Value"}).AddRow("gtid_domain_id", 101) 94 mock.ExpectQuery(`SHOW GLOBAL VARIABLES LIKE 'gtid_domain_id'`).WillReturnRows(rows) 95 96 dID, err := GetMariaDBGtidDomainID(tctx, NewBaseDBForTest(db)) 97 require.NoError(t, err) 98 require.Equal(t, uint32(101), dID) 99 require.NoError(t, mock.ExpectationsWereMet()) 100 } 101 102 func TestGetServerUUID(t *testing.T) { 103 t.Parallel() 104 105 ctx, cancel := context.WithTimeout(context.Background(), DefaultDBTimeout) 106 defer cancel() 107 tctx := tcontext.NewContext(ctx, log.L()) 108 109 db, mock, err := sqlmock.New() 110 require.NoError(t, err) 111 112 // MySQL 113 rows := mock.NewRows([]string{"Variable_name", "Value"}).AddRow("server_uuid", "074be7f4-f0f1-11ea-95bd-0242ac120002") 114 mock.ExpectQuery(`SHOW GLOBAL VARIABLES LIKE 'server_uuid'`).WillReturnRows(rows) 115 uuid, err := GetServerUUID(tctx, NewBaseDBForTest(db), "mysql") 116 require.NoError(t, err) 117 require.Equal(t, "074be7f4-f0f1-11ea-95bd-0242ac120002", uuid) 118 require.NoError(t, mock.ExpectationsWereMet()) 119 120 // MariaDB 121 rows = mock.NewRows([]string{"Variable_name", "Value"}).AddRow("gtid_domain_id", 123) 122 mock.ExpectQuery(`SHOW GLOBAL VARIABLES LIKE 'gtid_domain_id'`).WillReturnRows(rows) 123 rows = mock.NewRows([]string{"Variable_name", "Value"}).AddRow("server_id", 456) 124 mock.ExpectQuery(`SHOW GLOBAL VARIABLES LIKE 'server_id'`).WillReturnRows(rows) 125 uuid, err = GetServerUUID(tctx, NewBaseDBForTest(db), "mariadb") 126 require.NoError(t, err) 127 require.Equal(t, "123-456", uuid) 128 require.NoError(t, mock.ExpectationsWereMet()) 129 } 130 131 func TestGetServerUnixTS(t *testing.T) { 132 t.Parallel() 133 134 ctx := context.Background() 135 136 db, mock, err := sqlmock.New() 137 require.NoError(t, err) 138 139 ts := time.Now().Unix() 140 rows := sqlmock.NewRows([]string{"UNIX_TIMESTAMP()"}).AddRow(strconv.FormatInt(ts, 10)) 141 mock.ExpectQuery("SELECT UNIX_TIMESTAMP()").WillReturnRows(rows) 142 143 ts2, err := GetServerUnixTS(ctx, NewBaseDBForTest(db)) 144 require.NoError(t, err) 145 require.Equal(t, ts2, ts) 146 require.NoError(t, mock.ExpectationsWereMet()) 147 } 148 149 func TestGetParser(t *testing.T) { 150 t.Parallel() 151 152 ctx, cancel := context.WithTimeout(context.Background(), DefaultDBTimeout) 153 defer cancel() 154 tctx := tcontext.NewContext(ctx, log.L()) 155 156 var ( 157 DDL1 = `ALTER TABLE tbl ADD COLUMN c1 INT` 158 DDL2 = `ALTER TABLE tbl ADD COLUMN 'c1' INT` 159 DDL3 = `ALTER TABLE tbl ADD COLUMN "c1" INT` 160 ) 161 162 db, mock, err := sqlmock.New() 163 require.NoError(t, err) 164 165 // no `ANSI_QUOTES` 166 rows := mock.NewRows([]string{"Variable_name", "Value"}).AddRow("sql_mode", "") 167 mock.ExpectQuery(`SHOW VARIABLES LIKE 'sql_mode'`).WillReturnRows(rows) 168 p, err := GetParser(tctx, NewBaseDBForTest(db)) 169 require.NoError(t, err) 170 _, err = p.ParseOneStmt(DDL1, "", "") 171 require.NoError(t, err) 172 _, err = p.ParseOneStmt(DDL2, "", "") 173 require.Error(t, err) 174 _, err = p.ParseOneStmt(DDL3, "", "") 175 require.Error(t, err) 176 require.NoError(t, mock.ExpectationsWereMet()) 177 178 // `ANSI_QUOTES` 179 rows = mock.NewRows([]string{"Variable_name", "Value"}).AddRow("sql_mode", "ANSI_QUOTES") 180 mock.ExpectQuery(`SHOW VARIABLES LIKE 'sql_mode'`).WillReturnRows(rows) 181 p, err = GetParser(tctx, NewBaseDBForTest(db)) 182 require.NoError(t, err) 183 _, err = p.ParseOneStmt(DDL1, "", "") 184 require.NoError(t, err) 185 _, err = p.ParseOneStmt(DDL2, "", "") 186 require.Error(t, err) 187 _, err = p.ParseOneStmt(DDL3, "", "") 188 require.NoError(t, err) 189 require.NoError(t, mock.ExpectationsWereMet()) 190 } 191 192 func TestGetGTID(t *testing.T) { 193 t.Parallel() 194 195 ctx, cancel := context.WithTimeout(context.Background(), DefaultDBTimeout) 196 defer cancel() 197 tctx := tcontext.NewContext(ctx, log.L()) 198 199 db, mock, err := sqlmock.New() 200 require.NoError(t, err) 201 202 rows := mock.NewRows([]string{"Variable_name", "Value"}).AddRow("GTID_MODE", "ON") 203 mock.ExpectQuery(`SHOW GLOBAL VARIABLES LIKE 'GTID_MODE'`).WillReturnRows(rows) 204 mode, err := GetGTIDMode(tctx, NewBaseDBForTest(db)) 205 require.NoError(t, err) 206 require.Equal(t, "ON", mode) 207 require.NoError(t, mock.ExpectationsWereMet()) 208 } 209 210 func TestMySQLError(t *testing.T) { 211 t.Parallel() 212 213 err := newMysqlErr(tmysql.ErrNoSuchThread, "Unknown thread id: 111") 214 require.Equal(t, true, IsNoSuchThreadError(err)) 215 216 err = newMysqlErr(tmysql.ErrMasterFatalErrorReadingBinlog, "binlog purged error") 217 require.Equal(t, true, IsErrBinlogPurged(err)) 218 219 err = newMysqlErr(tmysql.ErrDupEntry, "Duplicate entry '123456' for key 'index'") 220 require.Equal(t, true, IsErrDuplicateEntry(err)) 221 } 222 223 func TestGetAllServerID(t *testing.T) { 224 t.Parallel() 225 226 testCases := []struct { 227 masterID uint32 228 serverIDs []uint32 229 }{ 230 { 231 1, 232 []uint32{2, 3, 4}, 233 }, { 234 2, 235 []uint32{}, 236 }, { 237 4294967295, // max server-id. 238 []uint32{}, 239 }, 240 } 241 242 db, mock, err := sqlmock.New() 243 require.NoError(t, err) 244 245 flavors := []string{gmysql.MariaDBFlavor, gmysql.MySQLFlavor} 246 247 tctx := tcontext.NewContext(context.Background(), log.L()) 248 for _, testCase := range testCases { 249 for _, flavor := range flavors { 250 createMockResult(mock, testCase.masterID, testCase.serverIDs, flavor) 251 serverIDs, err2 := GetAllServerID(tctx, NewBaseDBForTest(db)) 252 require.NoError(t, err2) 253 254 for _, serverID := range testCase.serverIDs { 255 _, ok := serverIDs[serverID] 256 require.True(t, ok) 257 } 258 259 _, ok := serverIDs[testCase.masterID] 260 require.True(t, ok) 261 } 262 } 263 264 err = mock.ExpectationsWereMet() 265 require.NoError(t, err) 266 } 267 268 func createMockResult(mock sqlmock.Sqlmock, masterID uint32, serverIDs []uint32, flavor string) { 269 expectQuery := mock.ExpectQuery("SHOW SLAVE HOSTS") 270 271 host := "test" 272 port := 3306 273 slaveUUID := "test" 274 275 if flavor == gmysql.MariaDBFlavor { 276 rows := sqlmock.NewRows([]string{"Server_id", "Host", "Port", "Master_id"}) 277 for _, serverID := range serverIDs { 278 rows.AddRow(serverID, host, port, masterID) 279 } 280 expectQuery.WillReturnRows(rows) 281 } else { 282 rows := sqlmock.NewRows([]string{"Server_id", "Host", "Port", "Master_id", "Slave_UUID"}) 283 for _, serverID := range serverIDs { 284 rows.AddRow(serverID, host, port, masterID, slaveUUID) 285 } 286 expectQuery.WillReturnRows(rows) 287 } 288 289 mock.ExpectQuery("SHOW GLOBAL VARIABLES LIKE 'server_id'").WillReturnRows(sqlmock.NewRows([]string{"Variable_name", "Value"}).AddRow("server_id", masterID)) 290 } 291 292 func newMysqlErr(number uint16, message string) *mysql.MySQLError { 293 return &mysql.MySQLError{ 294 Number: number, 295 Message: message, 296 } 297 } 298 299 func TestTiDBVersion(t *testing.T) { 300 t.Parallel() 301 302 testCases := []struct { 303 version string 304 result *semver.Version 305 err error 306 }{ 307 { 308 "wrong-version", 309 semver.New("0.0.0"), 310 errors.Errorf("not a valid TiDB version: %s", "wrong-version"), 311 }, { 312 "5.7.31-log", 313 semver.New("0.0.0"), 314 errors.Errorf("not a valid TiDB version: %s", "5.7.31-log"), 315 }, { 316 "5.7.25-TiDB-v3.1.2", 317 semver.New("3.1.2"), 318 nil, 319 }, { 320 "5.7.25-TiDB-v4.0.0-beta.2-1293-g0843f32c0-dirty", 321 semver.New("4.0.00-beta.2"), 322 nil, 323 }, 324 } 325 326 for _, tc := range testCases { 327 tidbVer, err := ExtractTiDBVersion(tc.version) 328 if tc.err != nil { 329 require.Error(t, err) 330 require.Equal(t, tc.err.Error(), err.Error()) 331 } else { 332 require.Equal(t, tc.result, tidbVer) 333 } 334 } 335 } 336 337 func getGSetFromString(t *testing.T, s string) gmysql.GTIDSet { 338 t.Helper() 339 gSet, err := gtid.ParserGTID("mysql", s) 340 require.NoError(t, err) 341 return gSet 342 } 343 344 func TestAddGSetWithPurged(t *testing.T) { 345 t.Parallel() 346 347 db, mock, err := sqlmock.New() 348 require.NoError(t, err) 349 mariaGTID, err := gtid.ParserGTID("mariadb", "1-2-100") 350 require.NoError(t, err) 351 ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) 352 defer cancel() 353 baseDB := NewBaseDBForTest(db) 354 conn, err := baseDB.GetBaseConn(ctx) 355 require.NoError(t, err) 356 defer baseDB.ForceCloseConnWithoutErr(conn) 357 358 testCases := []struct { 359 originGSet gmysql.GTIDSet 360 purgedSet gmysql.GTIDSet 361 expectedSet gmysql.GTIDSet 362 err error 363 }{ 364 { 365 getGSetFromString(t, "3ccc475b-2343-11e7-be21-6c0b84d59f30:6-14"), 366 getGSetFromString(t, "3ccc475b-2343-11e7-be21-6c0b84d59f30:1-5"), 367 getGSetFromString(t, "3ccc475b-2343-11e7-be21-6c0b84d59f30:1-14"), 368 nil, 369 }, { 370 getGSetFromString(t, "3ccc475b-2343-11e7-be21-6c0b84d59f30:2-6"), 371 getGSetFromString(t, "3ccc475b-2343-11e7-be21-6c0b84d59f30:1"), 372 getGSetFromString(t, "3ccc475b-2343-11e7-be21-6c0b84d59f30:1-6"), 373 nil, 374 }, { 375 getGSetFromString(t, "3ccc475b-2343-11e7-be21-6c0b84d59f30:1-6"), 376 getGSetFromString(t, "53bfca22-690d-11e7-8a62-18ded7a37b78:1-495"), 377 getGSetFromString(t, "3ccc475b-2343-11e7-be21-6c0b84d59f30:1-6,53bfca22-690d-11e7-8a62-18ded7a37b78:1-495"), 378 nil, 379 }, { 380 getGSetFromString(t, "3ccc475b-2343-11e7-be21-6c0b84d59f30:6-14"), 381 mariaGTID, 382 nil, 383 errors.New("invalid GTID format, must UUID:interval[:interval]"), 384 }, 385 } 386 387 for _, tc := range testCases { 388 mock.ExpectQuery("select @@GLOBAL.gtid_purged").WillReturnRows( 389 sqlmock.NewRows([]string{"@@GLOBAL.gtid_purged"}).AddRow(tc.purgedSet.String())) 390 originSet := tc.originGSet.Clone() 391 newSet, err := AddGSetWithPurged(ctx, originSet, conn) 392 require.True(t, errors.ErrorEqual(err, tc.err)) 393 require.Equal(t, tc.expectedSet, newSet) 394 // make sure origin gSet hasn't changed 395 require.Equal(t, tc.originGSet, originSet) 396 } 397 } 398 399 func TestGetMaxConnections(t *testing.T) { 400 t.Parallel() 401 402 ctx, cancel := context.WithTimeout(context.Background(), DefaultDBTimeout) 403 defer cancel() 404 tctx := tcontext.NewContext(ctx, log.L()) 405 406 db, mock, err := sqlmock.New() 407 require.NoError(t, err) 408 409 rows := mock.NewRows([]string{"Variable_name", "Value"}).AddRow("max_connections", "151") 410 mock.ExpectQuery(`SHOW VARIABLES LIKE 'max_connections'`).WillReturnRows(rows) 411 maxConnections, err := GetMaxConnections(tctx, NewBaseDBForTest(db)) 412 require.NoError(t, err) 413 require.Equal(t, 151, maxConnections) 414 require.NoError(t, mock.ExpectationsWereMet()) 415 } 416 417 func TestIsMariaDB(t *testing.T) { 418 t.Parallel() 419 420 require.True(t, IsMariaDB("5.5.50-MariaDB-1~wheezy")) 421 require.False(t, IsMariaDB("5.7.19-17-log")) 422 } 423 424 func TestCreateTableSQLToOneRow(t *testing.T) { 425 t.Parallel() 426 427 input := "CREATE TABLE `t1` (\n `id` bigint(20) NOT NULL,\n `c1` varchar(20) DEFAULT NULL,\n `c2` varchar(20) DEFAULT NULL,\n PRIMARY KEY (`id`) /*T![clustered_index] NONCLUSTERED */\n) ENGINE=InnoDB DEFAULT CHARSET=latin1 COLLATE=latin1_bin" 428 expected := "CREATE TABLE `t1` ( `id` bigint(20) NOT NULL, `c1` varchar(20) DEFAULT NULL, `c2` varchar(20) DEFAULT NULL, PRIMARY KEY (`id`) /*T![clustered_index] NONCLUSTERED */) ENGINE=InnoDB DEFAULT CHARSET=latin1 COLLATE=latin1_bin" 429 require.Equal(t, expected, CreateTableSQLToOneRow(input)) 430 } 431 432 func TestGetSlaveServerID(t *testing.T) { 433 t.Parallel() 434 435 db, mock, err := sqlmock.New() 436 require.NoError(t, err) 437 438 cases := []struct { 439 rows *sqlmock.Rows 440 results map[uint32]struct{} 441 }{ 442 // For MySQL 443 { 444 sqlmock.NewRows([]string{"Server_id", "Host", "Port", "Master_id", "Slave_UUID"}). 445 AddRow(192168010, "iconnect2", 3306, 192168011, "14cb6624-7f93-11e0-b2c0-c80aa9429562"). 446 AddRow(1921680101, "athena", 3306, 192168011, "07af4990-f41f-11df-a566-7ac56fdaf645"), 447 map[uint32]struct{}{ 448 192168010: {}, 1921680101: {}, 449 }, 450 }, 451 // For MariaDB 452 { 453 sqlmock.NewRows([]string{"Server_id", "Host", "Port", "Master_id"}). 454 AddRow(192168010, "iconnect2", 3306, 192168011). 455 AddRow(1921680101, "athena", 3306, 192168011), 456 map[uint32]struct{}{ 457 192168010: {}, 1921680101: {}, 458 }, 459 }, 460 // For MariaDB, with Server_id greater than 2^31, to test uint conversion 461 { 462 sqlmock.NewRows([]string{"Server_id", "Host", "Port", "Master_id"}). 463 AddRow(2147483649, "iconnect2", 3306, 192168011). 464 AddRow(2147483650, "athena", 3306, 192168011), 465 map[uint32]struct{}{ 466 2147483649: {}, 2147483650: {}, 467 }, 468 }, 469 } 470 471 tctx := tcontext.NewContext(context.Background(), log.L()) 472 for _, ca := range cases { 473 mock.ExpectQuery("SHOW SLAVE HOSTS").WillReturnRows(ca.rows) 474 results, err2 := GetSlaveServerID(tctx, NewBaseDBForTest(db)) 475 require.NoError(t, err2) 476 require.Equal(t, ca.results, results) 477 } 478 } 479 480 func TestFetchAllDoTables(t *testing.T) { 481 t.Parallel() 482 483 db, mock, err := sqlmock.New() 484 require.NoError(t, err) 485 486 // empty filter, exclude system schemas 487 ba, err := filter.New(false, nil) 488 require.NoError(t, err) 489 490 // no schemas need to do. 491 mock.ExpectQuery(`SHOW DATABASES`).WillReturnRows(sqlmock.NewRows([]string{"Database"})) 492 got, err := FetchAllDoTables(context.Background(), NewBaseDBForTest(db), ba) 493 require.NoError(t, err) 494 require.Len(t, got, 0) 495 require.NoError(t, mock.ExpectationsWereMet()) 496 497 // only system schemas exist, still no need to do. 498 schemas := []string{"information_schema", "mysql", "performance_schema", "sys", filter.DMHeartbeatSchema} 499 rows := sqlmock.NewRows([]string{"Database"}) 500 addRowsForSchemas(rows, schemas) 501 mock.ExpectQuery(`SHOW DATABASES`).WillReturnRows(rows) 502 got, err = FetchAllDoTables(context.Background(), NewBaseDBForTest(db), ba) 503 require.NoError(t, err) 504 require.Len(t, got, 0) 505 require.NoError(t, mock.ExpectationsWereMet()) 506 507 // schemas without tables in them. 508 doSchema := "test_db" 509 schemas = []string{"information_schema", "mysql", "performance_schema", "sys", filter.DMHeartbeatSchema, doSchema} 510 rows = sqlmock.NewRows([]string{"Database"}) 511 addRowsForSchemas(rows, schemas) 512 mock.ExpectQuery(`SHOW DATABASES`).WillReturnRows(rows) 513 mock.ExpectQuery(fmt.Sprintf("SHOW FULL TABLES IN `%s` WHERE Table_Type != 'VIEW'", doSchema)).WillReturnRows( 514 sqlmock.NewRows([]string{fmt.Sprintf("Tables_in_%s", doSchema), "Table_type"})) 515 got, err = FetchAllDoTables(context.Background(), NewBaseDBForTest(db), ba) 516 require.NoError(t, err) 517 require.Len(t, got, 0) 518 require.NoError(t, mock.ExpectationsWereMet()) 519 520 // do all tables under the schema. 521 rows = sqlmock.NewRows([]string{"Database"}) 522 addRowsForSchemas(rows, schemas) 523 mock.ExpectQuery(`SHOW DATABASES`).WillReturnRows(rows) 524 tables := []string{"tbl1", "tbl2", "exclude_tbl"} 525 rows = sqlmock.NewRows([]string{fmt.Sprintf("Tables_in_%s", doSchema), "Table_type"}) 526 addRowsForTables(rows, tables) 527 mock.ExpectQuery(fmt.Sprintf("SHOW FULL TABLES IN `%s` WHERE Table_Type != 'VIEW'", doSchema)).WillReturnRows(rows) 528 got, err = FetchAllDoTables(context.Background(), NewBaseDBForTest(db), ba) 529 require.NoError(t, err) 530 require.Len(t, got, 1) 531 require.Equal(t, tables, got[doSchema]) 532 require.NoError(t, mock.ExpectationsWereMet()) 533 534 // use a block-allow-list to fiter some tables 535 ba, err = filter.New(false, &filter.Rules{ 536 DoDBs: []string{doSchema}, 537 DoTables: []*filter.Table{ 538 {Schema: doSchema, Name: "tbl1"}, 539 {Schema: doSchema, Name: "tbl2"}, 540 }, 541 }) 542 require.NoError(t, err) 543 544 rows = sqlmock.NewRows([]string{"Database"}) 545 addRowsForSchemas(rows, schemas) 546 mock.ExpectQuery(`SHOW DATABASES`).WillReturnRows(rows) 547 rows = sqlmock.NewRows([]string{fmt.Sprintf("Tables_in_%s", doSchema), "Table_type"}) 548 addRowsForTables(rows, tables) 549 mock.ExpectQuery(fmt.Sprintf("SHOW FULL TABLES IN `%s` WHERE Table_Type != 'VIEW'", doSchema)).WillReturnRows(rows) 550 got, err = FetchAllDoTables(context.Background(), NewBaseDBForTest(db), ba) 551 require.NoError(t, err) 552 require.Len(t, got, 1) 553 require.Equal(t, []string{"tbl1", "tbl2"}, got[doSchema]) 554 require.NoError(t, mock.ExpectationsWereMet()) 555 } 556 557 func TestFetchTargetDoTables(t *testing.T) { 558 t.Parallel() 559 560 db, mock, err := sqlmock.New() 561 require.NoError(t, err) 562 563 // empty filter and router, just as upstream. 564 ba, err := filter.New(false, nil) 565 require.NoError(t, err) 566 r, err := regexprrouter.NewRegExprRouter(false, nil) 567 require.NoError(t, err) 568 569 schemas := []string{"shard1"} 570 rows := sqlmock.NewRows([]string{"Database"}) 571 addRowsForSchemas(rows, schemas) 572 mock.ExpectQuery(`SHOW DATABASES`).WillReturnRows(rows) 573 574 tablesM := map[string][]string{ 575 "shard1": {"tbl1", "tbl2"}, 576 } 577 for schema, tables := range tablesM { 578 rows = sqlmock.NewRows([]string{fmt.Sprintf("Tables_in_%s", schema), "Table_type"}) 579 addRowsForTables(rows, tables) 580 mock.ExpectQuery(fmt.Sprintf("SHOW FULL TABLES IN `%s` WHERE Table_Type != 'VIEW'", schema)).WillReturnRows(rows) 581 } 582 583 tablesMap, extendedCols, err := FetchTargetDoTables(context.Background(), "", NewBaseDBForTest(db), ba, r) 584 require.NoError(t, err) 585 require.Equal(t, map[filter.Table][]filter.Table{ 586 {Schema: "shard1", Name: "tbl1"}: {{Schema: "shard1", Name: "tbl1"}}, 587 {Schema: "shard1", Name: "tbl2"}: {{Schema: "shard1", Name: "tbl2"}}, 588 }, tablesMap) 589 require.Len(t, extendedCols, 0) 590 require.NoError(t, mock.ExpectationsWereMet()) 591 592 // route to the same downstream. 593 r, err = regexprrouter.NewRegExprRouter(false, []*router.TableRule{ 594 {SchemaPattern: "shard*", TablePattern: "tbl*", TargetSchema: "shard", TargetTable: "tbl"}, 595 }) 596 require.NoError(t, err) 597 598 rows = sqlmock.NewRows([]string{"Database"}) 599 addRowsForSchemas(rows, schemas) 600 mock.ExpectQuery(`SHOW DATABASES`).WillReturnRows(rows) 601 for schema, tables := range tablesM { 602 rows = sqlmock.NewRows([]string{fmt.Sprintf("Tables_in_%s", schema), "Table_type"}) 603 addRowsForTables(rows, tables) 604 mock.ExpectQuery(fmt.Sprintf("SHOW FULL TABLES IN `%s` WHERE Table_Type != 'VIEW'", schema)).WillReturnRows(rows) 605 } 606 607 tablesMap, extendedCols, err = FetchTargetDoTables(context.Background(), "", NewBaseDBForTest(db), ba, r) 608 require.NoError(t, err) 609 require.Equal(t, map[filter.Table][]filter.Table{ 610 {Schema: "shard", Name: "tbl"}: { 611 {Schema: "shard1", Name: "tbl1"}, 612 {Schema: "shard1", Name: "tbl2"}, 613 }, 614 }, tablesMap) 615 require.Len(t, extendedCols, 0) 616 require.NoError(t, mock.ExpectationsWereMet()) 617 } 618 619 func addRowsForSchemas(rows *sqlmock.Rows, schemas []string) { 620 for _, d := range schemas { 621 rows.AddRow(d) 622 } 623 } 624 625 func addRowsForTables(rows *sqlmock.Rows, tables []string) { 626 for _, table := range tables { 627 rows.AddRow(table, "BASE TABLE") 628 } 629 }