github.com/pingcap/tiflow@v0.0.0-20240520035814-5bf52d54e205/dm/pkg/conn/db_test.go (about)

     1  // Copyright 2019 PingCAP, Inc.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // See the License for the specific language governing permissions and
    12  // limitations under the License.
    13  
    14  package conn
    15  
    16  import (
    17  	"context"
    18  	"fmt"
    19  	"strconv"
    20  	"testing"
    21  	"time"
    22  
    23  	"github.com/DATA-DOG/go-sqlmock"
    24  	"github.com/coreos/go-semver/semver"
    25  	gmysql "github.com/go-mysql-org/go-mysql/mysql"
    26  	"github.com/go-sql-driver/mysql"
    27  	"github.com/pingcap/errors"
    28  	tmysql "github.com/pingcap/tidb/pkg/parser/mysql"
    29  	"github.com/pingcap/tidb/pkg/util/filter"
    30  	regexprrouter "github.com/pingcap/tidb/pkg/util/regexpr-router"
    31  	router "github.com/pingcap/tidb/pkg/util/table-router"
    32  	tcontext "github.com/pingcap/tiflow/dm/pkg/context"
    33  	"github.com/pingcap/tiflow/dm/pkg/gtid"
    34  	"github.com/pingcap/tiflow/dm/pkg/log"
    35  	"github.com/stretchr/testify/require"
    36  )
    37  
    38  func TestGetFlavor(t *testing.T) {
    39  	t.Parallel()
    40  
    41  	db, mock, err := sqlmock.New()
    42  	require.NoError(t, err)
    43  
    44  	// MySQL
    45  	mock.ExpectQuery(`SHOW GLOBAL VARIABLES LIKE 'version';`).WillReturnRows(sqlmock.NewRows([]string{"Variable_name", "Value"}).AddRow("version", "5.7.31-log"))
    46  	flavor, err := GetFlavor(context.Background(), NewBaseDBForTest(db))
    47  	require.NoError(t, err)
    48  	require.Equal(t, "mysql", flavor)
    49  	require.NoError(t, mock.ExpectationsWereMet())
    50  
    51  	// MariaDB
    52  	mock.ExpectQuery(`SHOW GLOBAL VARIABLES LIKE 'version';`).WillReturnRows(sqlmock.NewRows([]string{"Variable_name", "Value"}).AddRow("version", "10.13.1-MariaDB-1~wheezy"))
    53  	flavor, err = GetFlavor(context.Background(), NewBaseDBForTest(db))
    54  	require.NoError(t, err)
    55  	require.Equal(t, "mariadb", flavor)
    56  	require.NoError(t, mock.ExpectationsWereMet())
    57  
    58  	// others
    59  	mock.ExpectQuery(`SHOW GLOBAL VARIABLES LIKE 'version';`).WillReturnRows(sqlmock.NewRows([]string{"Variable_name", "Value"}).AddRow("version", "unknown"))
    60  	flavor, err = GetFlavor(context.Background(), NewBaseDBForTest(db))
    61  	require.NoError(t, err)
    62  	require.Equal(t, "mysql", flavor) // as MySQL
    63  	require.NoError(t, mock.ExpectationsWereMet())
    64  }
    65  
    66  func TestGetRandomServerID(t *testing.T) {
    67  	t.Parallel()
    68  
    69  	db, mock, err := sqlmock.New()
    70  	require.NoError(t, err)
    71  
    72  	tctx := tcontext.NewContext(context.Background(), log.L())
    73  	createMockResult(mock, 1, []uint32{100, 101}, "mysql")
    74  	serverID, err := GetRandomServerID(tctx, NewBaseDBForTest(db))
    75  	require.NoError(t, err)
    76  	require.Greater(t, serverID, uint32(0))
    77  	require.NoError(t, mock.ExpectationsWereMet())
    78  	require.NotEqual(t, 1, serverID)
    79  	require.NotEqual(t, 100, serverID)
    80  	require.NotEqual(t, 101, serverID)
    81  }
    82  
    83  func TestGetMariaDBGtidDomainID(t *testing.T) {
    84  	t.Parallel()
    85  
    86  	ctx, cancel := context.WithTimeout(context.Background(), DefaultDBTimeout)
    87  	defer cancel()
    88  	tctx := tcontext.NewContext(ctx, log.L())
    89  
    90  	db, mock, err := sqlmock.New()
    91  	require.NoError(t, err)
    92  
    93  	rows := mock.NewRows([]string{"Variable_name", "Value"}).AddRow("gtid_domain_id", 101)
    94  	mock.ExpectQuery(`SHOW GLOBAL VARIABLES LIKE 'gtid_domain_id'`).WillReturnRows(rows)
    95  
    96  	dID, err := GetMariaDBGtidDomainID(tctx, NewBaseDBForTest(db))
    97  	require.NoError(t, err)
    98  	require.Equal(t, uint32(101), dID)
    99  	require.NoError(t, mock.ExpectationsWereMet())
   100  }
   101  
   102  func TestGetServerUUID(t *testing.T) {
   103  	t.Parallel()
   104  
   105  	ctx, cancel := context.WithTimeout(context.Background(), DefaultDBTimeout)
   106  	defer cancel()
   107  	tctx := tcontext.NewContext(ctx, log.L())
   108  
   109  	db, mock, err := sqlmock.New()
   110  	require.NoError(t, err)
   111  
   112  	// MySQL
   113  	rows := mock.NewRows([]string{"Variable_name", "Value"}).AddRow("server_uuid", "074be7f4-f0f1-11ea-95bd-0242ac120002")
   114  	mock.ExpectQuery(`SHOW GLOBAL VARIABLES LIKE 'server_uuid'`).WillReturnRows(rows)
   115  	uuid, err := GetServerUUID(tctx, NewBaseDBForTest(db), "mysql")
   116  	require.NoError(t, err)
   117  	require.Equal(t, "074be7f4-f0f1-11ea-95bd-0242ac120002", uuid)
   118  	require.NoError(t, mock.ExpectationsWereMet())
   119  
   120  	// MariaDB
   121  	rows = mock.NewRows([]string{"Variable_name", "Value"}).AddRow("gtid_domain_id", 123)
   122  	mock.ExpectQuery(`SHOW GLOBAL VARIABLES LIKE 'gtid_domain_id'`).WillReturnRows(rows)
   123  	rows = mock.NewRows([]string{"Variable_name", "Value"}).AddRow("server_id", 456)
   124  	mock.ExpectQuery(`SHOW GLOBAL VARIABLES LIKE 'server_id'`).WillReturnRows(rows)
   125  	uuid, err = GetServerUUID(tctx, NewBaseDBForTest(db), "mariadb")
   126  	require.NoError(t, err)
   127  	require.Equal(t, "123-456", uuid)
   128  	require.NoError(t, mock.ExpectationsWereMet())
   129  }
   130  
   131  func TestGetServerUnixTS(t *testing.T) {
   132  	t.Parallel()
   133  
   134  	ctx := context.Background()
   135  
   136  	db, mock, err := sqlmock.New()
   137  	require.NoError(t, err)
   138  
   139  	ts := time.Now().Unix()
   140  	rows := sqlmock.NewRows([]string{"UNIX_TIMESTAMP()"}).AddRow(strconv.FormatInt(ts, 10))
   141  	mock.ExpectQuery("SELECT UNIX_TIMESTAMP()").WillReturnRows(rows)
   142  
   143  	ts2, err := GetServerUnixTS(ctx, NewBaseDBForTest(db))
   144  	require.NoError(t, err)
   145  	require.Equal(t, ts2, ts)
   146  	require.NoError(t, mock.ExpectationsWereMet())
   147  }
   148  
   149  func TestGetParser(t *testing.T) {
   150  	t.Parallel()
   151  
   152  	ctx, cancel := context.WithTimeout(context.Background(), DefaultDBTimeout)
   153  	defer cancel()
   154  	tctx := tcontext.NewContext(ctx, log.L())
   155  
   156  	var (
   157  		DDL1 = `ALTER TABLE tbl ADD COLUMN c1 INT`
   158  		DDL2 = `ALTER TABLE tbl ADD COLUMN 'c1' INT`
   159  		DDL3 = `ALTER TABLE tbl ADD COLUMN "c1" INT`
   160  	)
   161  
   162  	db, mock, err := sqlmock.New()
   163  	require.NoError(t, err)
   164  
   165  	// no `ANSI_QUOTES`
   166  	rows := mock.NewRows([]string{"Variable_name", "Value"}).AddRow("sql_mode", "")
   167  	mock.ExpectQuery(`SHOW VARIABLES LIKE 'sql_mode'`).WillReturnRows(rows)
   168  	p, err := GetParser(tctx, NewBaseDBForTest(db))
   169  	require.NoError(t, err)
   170  	_, err = p.ParseOneStmt(DDL1, "", "")
   171  	require.NoError(t, err)
   172  	_, err = p.ParseOneStmt(DDL2, "", "")
   173  	require.Error(t, err)
   174  	_, err = p.ParseOneStmt(DDL3, "", "")
   175  	require.Error(t, err)
   176  	require.NoError(t, mock.ExpectationsWereMet())
   177  
   178  	// `ANSI_QUOTES`
   179  	rows = mock.NewRows([]string{"Variable_name", "Value"}).AddRow("sql_mode", "ANSI_QUOTES")
   180  	mock.ExpectQuery(`SHOW VARIABLES LIKE 'sql_mode'`).WillReturnRows(rows)
   181  	p, err = GetParser(tctx, NewBaseDBForTest(db))
   182  	require.NoError(t, err)
   183  	_, err = p.ParseOneStmt(DDL1, "", "")
   184  	require.NoError(t, err)
   185  	_, err = p.ParseOneStmt(DDL2, "", "")
   186  	require.Error(t, err)
   187  	_, err = p.ParseOneStmt(DDL3, "", "")
   188  	require.NoError(t, err)
   189  	require.NoError(t, mock.ExpectationsWereMet())
   190  }
   191  
   192  func TestGetGTID(t *testing.T) {
   193  	t.Parallel()
   194  
   195  	ctx, cancel := context.WithTimeout(context.Background(), DefaultDBTimeout)
   196  	defer cancel()
   197  	tctx := tcontext.NewContext(ctx, log.L())
   198  
   199  	db, mock, err := sqlmock.New()
   200  	require.NoError(t, err)
   201  
   202  	rows := mock.NewRows([]string{"Variable_name", "Value"}).AddRow("GTID_MODE", "ON")
   203  	mock.ExpectQuery(`SHOW GLOBAL VARIABLES LIKE 'GTID_MODE'`).WillReturnRows(rows)
   204  	mode, err := GetGTIDMode(tctx, NewBaseDBForTest(db))
   205  	require.NoError(t, err)
   206  	require.Equal(t, "ON", mode)
   207  	require.NoError(t, mock.ExpectationsWereMet())
   208  }
   209  
   210  func TestMySQLError(t *testing.T) {
   211  	t.Parallel()
   212  
   213  	err := newMysqlErr(tmysql.ErrNoSuchThread, "Unknown thread id: 111")
   214  	require.Equal(t, true, IsNoSuchThreadError(err))
   215  
   216  	err = newMysqlErr(tmysql.ErrMasterFatalErrorReadingBinlog, "binlog purged error")
   217  	require.Equal(t, true, IsErrBinlogPurged(err))
   218  
   219  	err = newMysqlErr(tmysql.ErrDupEntry, "Duplicate entry '123456' for key 'index'")
   220  	require.Equal(t, true, IsErrDuplicateEntry(err))
   221  }
   222  
   223  func TestGetAllServerID(t *testing.T) {
   224  	t.Parallel()
   225  
   226  	testCases := []struct {
   227  		masterID  uint32
   228  		serverIDs []uint32
   229  	}{
   230  		{
   231  			1,
   232  			[]uint32{2, 3, 4},
   233  		}, {
   234  			2,
   235  			[]uint32{},
   236  		}, {
   237  			4294967295, // max server-id.
   238  			[]uint32{},
   239  		},
   240  	}
   241  
   242  	db, mock, err := sqlmock.New()
   243  	require.NoError(t, err)
   244  
   245  	flavors := []string{gmysql.MariaDBFlavor, gmysql.MySQLFlavor}
   246  
   247  	tctx := tcontext.NewContext(context.Background(), log.L())
   248  	for _, testCase := range testCases {
   249  		for _, flavor := range flavors {
   250  			createMockResult(mock, testCase.masterID, testCase.serverIDs, flavor)
   251  			serverIDs, err2 := GetAllServerID(tctx, NewBaseDBForTest(db))
   252  			require.NoError(t, err2)
   253  
   254  			for _, serverID := range testCase.serverIDs {
   255  				_, ok := serverIDs[serverID]
   256  				require.True(t, ok)
   257  			}
   258  
   259  			_, ok := serverIDs[testCase.masterID]
   260  			require.True(t, ok)
   261  		}
   262  	}
   263  
   264  	err = mock.ExpectationsWereMet()
   265  	require.NoError(t, err)
   266  }
   267  
   268  func createMockResult(mock sqlmock.Sqlmock, masterID uint32, serverIDs []uint32, flavor string) {
   269  	expectQuery := mock.ExpectQuery("SHOW SLAVE HOSTS")
   270  
   271  	host := "test"
   272  	port := 3306
   273  	slaveUUID := "test"
   274  
   275  	if flavor == gmysql.MariaDBFlavor {
   276  		rows := sqlmock.NewRows([]string{"Server_id", "Host", "Port", "Master_id"})
   277  		for _, serverID := range serverIDs {
   278  			rows.AddRow(serverID, host, port, masterID)
   279  		}
   280  		expectQuery.WillReturnRows(rows)
   281  	} else {
   282  		rows := sqlmock.NewRows([]string{"Server_id", "Host", "Port", "Master_id", "Slave_UUID"})
   283  		for _, serverID := range serverIDs {
   284  			rows.AddRow(serverID, host, port, masterID, slaveUUID)
   285  		}
   286  		expectQuery.WillReturnRows(rows)
   287  	}
   288  
   289  	mock.ExpectQuery("SHOW GLOBAL VARIABLES LIKE 'server_id'").WillReturnRows(sqlmock.NewRows([]string{"Variable_name", "Value"}).AddRow("server_id", masterID))
   290  }
   291  
   292  func newMysqlErr(number uint16, message string) *mysql.MySQLError {
   293  	return &mysql.MySQLError{
   294  		Number:  number,
   295  		Message: message,
   296  	}
   297  }
   298  
   299  func TestTiDBVersion(t *testing.T) {
   300  	t.Parallel()
   301  
   302  	testCases := []struct {
   303  		version string
   304  		result  *semver.Version
   305  		err     error
   306  	}{
   307  		{
   308  			"wrong-version",
   309  			semver.New("0.0.0"),
   310  			errors.Errorf("not a valid TiDB version: %s", "wrong-version"),
   311  		}, {
   312  			"5.7.31-log",
   313  			semver.New("0.0.0"),
   314  			errors.Errorf("not a valid TiDB version: %s", "5.7.31-log"),
   315  		}, {
   316  			"5.7.25-TiDB-v3.1.2",
   317  			semver.New("3.1.2"),
   318  			nil,
   319  		}, {
   320  			"5.7.25-TiDB-v4.0.0-beta.2-1293-g0843f32c0-dirty",
   321  			semver.New("4.0.00-beta.2"),
   322  			nil,
   323  		},
   324  	}
   325  
   326  	for _, tc := range testCases {
   327  		tidbVer, err := ExtractTiDBVersion(tc.version)
   328  		if tc.err != nil {
   329  			require.Error(t, err)
   330  			require.Equal(t, tc.err.Error(), err.Error())
   331  		} else {
   332  			require.Equal(t, tc.result, tidbVer)
   333  		}
   334  	}
   335  }
   336  
   337  func getGSetFromString(t *testing.T, s string) gmysql.GTIDSet {
   338  	t.Helper()
   339  	gSet, err := gtid.ParserGTID("mysql", s)
   340  	require.NoError(t, err)
   341  	return gSet
   342  }
   343  
   344  func TestAddGSetWithPurged(t *testing.T) {
   345  	t.Parallel()
   346  
   347  	db, mock, err := sqlmock.New()
   348  	require.NoError(t, err)
   349  	mariaGTID, err := gtid.ParserGTID("mariadb", "1-2-100")
   350  	require.NoError(t, err)
   351  	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
   352  	defer cancel()
   353  	baseDB := NewBaseDBForTest(db)
   354  	conn, err := baseDB.GetBaseConn(ctx)
   355  	require.NoError(t, err)
   356  	defer baseDB.ForceCloseConnWithoutErr(conn)
   357  
   358  	testCases := []struct {
   359  		originGSet  gmysql.GTIDSet
   360  		purgedSet   gmysql.GTIDSet
   361  		expectedSet gmysql.GTIDSet
   362  		err         error
   363  	}{
   364  		{
   365  			getGSetFromString(t, "3ccc475b-2343-11e7-be21-6c0b84d59f30:6-14"),
   366  			getGSetFromString(t, "3ccc475b-2343-11e7-be21-6c0b84d59f30:1-5"),
   367  			getGSetFromString(t, "3ccc475b-2343-11e7-be21-6c0b84d59f30:1-14"),
   368  			nil,
   369  		}, {
   370  			getGSetFromString(t, "3ccc475b-2343-11e7-be21-6c0b84d59f30:2-6"),
   371  			getGSetFromString(t, "3ccc475b-2343-11e7-be21-6c0b84d59f30:1"),
   372  			getGSetFromString(t, "3ccc475b-2343-11e7-be21-6c0b84d59f30:1-6"),
   373  			nil,
   374  		}, {
   375  			getGSetFromString(t, "3ccc475b-2343-11e7-be21-6c0b84d59f30:1-6"),
   376  			getGSetFromString(t, "53bfca22-690d-11e7-8a62-18ded7a37b78:1-495"),
   377  			getGSetFromString(t, "3ccc475b-2343-11e7-be21-6c0b84d59f30:1-6,53bfca22-690d-11e7-8a62-18ded7a37b78:1-495"),
   378  			nil,
   379  		}, {
   380  			getGSetFromString(t, "3ccc475b-2343-11e7-be21-6c0b84d59f30:6-14"),
   381  			mariaGTID,
   382  			nil,
   383  			errors.New("invalid GTID format, must UUID:interval[:interval]"),
   384  		},
   385  	}
   386  
   387  	for _, tc := range testCases {
   388  		mock.ExpectQuery("select @@GLOBAL.gtid_purged").WillReturnRows(
   389  			sqlmock.NewRows([]string{"@@GLOBAL.gtid_purged"}).AddRow(tc.purgedSet.String()))
   390  		originSet := tc.originGSet.Clone()
   391  		newSet, err := AddGSetWithPurged(ctx, originSet, conn)
   392  		require.True(t, errors.ErrorEqual(err, tc.err))
   393  		require.Equal(t, tc.expectedSet, newSet)
   394  		// make sure origin gSet hasn't changed
   395  		require.Equal(t, tc.originGSet, originSet)
   396  	}
   397  }
   398  
   399  func TestGetMaxConnections(t *testing.T) {
   400  	t.Parallel()
   401  
   402  	ctx, cancel := context.WithTimeout(context.Background(), DefaultDBTimeout)
   403  	defer cancel()
   404  	tctx := tcontext.NewContext(ctx, log.L())
   405  
   406  	db, mock, err := sqlmock.New()
   407  	require.NoError(t, err)
   408  
   409  	rows := mock.NewRows([]string{"Variable_name", "Value"}).AddRow("max_connections", "151")
   410  	mock.ExpectQuery(`SHOW VARIABLES LIKE 'max_connections'`).WillReturnRows(rows)
   411  	maxConnections, err := GetMaxConnections(tctx, NewBaseDBForTest(db))
   412  	require.NoError(t, err)
   413  	require.Equal(t, 151, maxConnections)
   414  	require.NoError(t, mock.ExpectationsWereMet())
   415  }
   416  
   417  func TestIsMariaDB(t *testing.T) {
   418  	t.Parallel()
   419  
   420  	require.True(t, IsMariaDB("5.5.50-MariaDB-1~wheezy"))
   421  	require.False(t, IsMariaDB("5.7.19-17-log"))
   422  }
   423  
   424  func TestCreateTableSQLToOneRow(t *testing.T) {
   425  	t.Parallel()
   426  
   427  	input := "CREATE TABLE `t1` (\n  `id` bigint(20) NOT NULL,\n  `c1` varchar(20) DEFAULT NULL,\n  `c2` varchar(20) DEFAULT NULL,\n  PRIMARY KEY (`id`) /*T![clustered_index] NONCLUSTERED */\n) ENGINE=InnoDB DEFAULT CHARSET=latin1 COLLATE=latin1_bin"
   428  	expected := "CREATE TABLE `t1` ( `id` bigint(20) NOT NULL, `c1` varchar(20) DEFAULT NULL, `c2` varchar(20) DEFAULT NULL, PRIMARY KEY (`id`) /*T![clustered_index] NONCLUSTERED */) ENGINE=InnoDB DEFAULT CHARSET=latin1 COLLATE=latin1_bin"
   429  	require.Equal(t, expected, CreateTableSQLToOneRow(input))
   430  }
   431  
   432  func TestGetSlaveServerID(t *testing.T) {
   433  	t.Parallel()
   434  
   435  	db, mock, err := sqlmock.New()
   436  	require.NoError(t, err)
   437  
   438  	cases := []struct {
   439  		rows    *sqlmock.Rows
   440  		results map[uint32]struct{}
   441  	}{
   442  		// For MySQL
   443  		{
   444  			sqlmock.NewRows([]string{"Server_id", "Host", "Port", "Master_id", "Slave_UUID"}).
   445  				AddRow(192168010, "iconnect2", 3306, 192168011, "14cb6624-7f93-11e0-b2c0-c80aa9429562").
   446  				AddRow(1921680101, "athena", 3306, 192168011, "07af4990-f41f-11df-a566-7ac56fdaf645"),
   447  			map[uint32]struct{}{
   448  				192168010: {}, 1921680101: {},
   449  			},
   450  		},
   451  		// For MariaDB
   452  		{
   453  			sqlmock.NewRows([]string{"Server_id", "Host", "Port", "Master_id"}).
   454  				AddRow(192168010, "iconnect2", 3306, 192168011).
   455  				AddRow(1921680101, "athena", 3306, 192168011),
   456  			map[uint32]struct{}{
   457  				192168010: {}, 1921680101: {},
   458  			},
   459  		},
   460  		// For MariaDB, with Server_id greater than 2^31, to test uint conversion
   461  		{
   462  			sqlmock.NewRows([]string{"Server_id", "Host", "Port", "Master_id"}).
   463  				AddRow(2147483649, "iconnect2", 3306, 192168011).
   464  				AddRow(2147483650, "athena", 3306, 192168011),
   465  			map[uint32]struct{}{
   466  				2147483649: {}, 2147483650: {},
   467  			},
   468  		},
   469  	}
   470  
   471  	tctx := tcontext.NewContext(context.Background(), log.L())
   472  	for _, ca := range cases {
   473  		mock.ExpectQuery("SHOW SLAVE HOSTS").WillReturnRows(ca.rows)
   474  		results, err2 := GetSlaveServerID(tctx, NewBaseDBForTest(db))
   475  		require.NoError(t, err2)
   476  		require.Equal(t, ca.results, results)
   477  	}
   478  }
   479  
   480  func TestFetchAllDoTables(t *testing.T) {
   481  	t.Parallel()
   482  
   483  	db, mock, err := sqlmock.New()
   484  	require.NoError(t, err)
   485  
   486  	// empty filter, exclude system schemas
   487  	ba, err := filter.New(false, nil)
   488  	require.NoError(t, err)
   489  
   490  	// no schemas need to do.
   491  	mock.ExpectQuery(`SHOW DATABASES`).WillReturnRows(sqlmock.NewRows([]string{"Database"}))
   492  	got, err := FetchAllDoTables(context.Background(), NewBaseDBForTest(db), ba)
   493  	require.NoError(t, err)
   494  	require.Len(t, got, 0)
   495  	require.NoError(t, mock.ExpectationsWereMet())
   496  
   497  	// only system schemas exist, still no need to do.
   498  	schemas := []string{"information_schema", "mysql", "performance_schema", "sys", filter.DMHeartbeatSchema}
   499  	rows := sqlmock.NewRows([]string{"Database"})
   500  	addRowsForSchemas(rows, schemas)
   501  	mock.ExpectQuery(`SHOW DATABASES`).WillReturnRows(rows)
   502  	got, err = FetchAllDoTables(context.Background(), NewBaseDBForTest(db), ba)
   503  	require.NoError(t, err)
   504  	require.Len(t, got, 0)
   505  	require.NoError(t, mock.ExpectationsWereMet())
   506  
   507  	// schemas without tables in them.
   508  	doSchema := "test_db"
   509  	schemas = []string{"information_schema", "mysql", "performance_schema", "sys", filter.DMHeartbeatSchema, doSchema}
   510  	rows = sqlmock.NewRows([]string{"Database"})
   511  	addRowsForSchemas(rows, schemas)
   512  	mock.ExpectQuery(`SHOW DATABASES`).WillReturnRows(rows)
   513  	mock.ExpectQuery(fmt.Sprintf("SHOW FULL TABLES IN `%s` WHERE Table_Type != 'VIEW'", doSchema)).WillReturnRows(
   514  		sqlmock.NewRows([]string{fmt.Sprintf("Tables_in_%s", doSchema), "Table_type"}))
   515  	got, err = FetchAllDoTables(context.Background(), NewBaseDBForTest(db), ba)
   516  	require.NoError(t, err)
   517  	require.Len(t, got, 0)
   518  	require.NoError(t, mock.ExpectationsWereMet())
   519  
   520  	// do all tables under the schema.
   521  	rows = sqlmock.NewRows([]string{"Database"})
   522  	addRowsForSchemas(rows, schemas)
   523  	mock.ExpectQuery(`SHOW DATABASES`).WillReturnRows(rows)
   524  	tables := []string{"tbl1", "tbl2", "exclude_tbl"}
   525  	rows = sqlmock.NewRows([]string{fmt.Sprintf("Tables_in_%s", doSchema), "Table_type"})
   526  	addRowsForTables(rows, tables)
   527  	mock.ExpectQuery(fmt.Sprintf("SHOW FULL TABLES IN `%s` WHERE Table_Type != 'VIEW'", doSchema)).WillReturnRows(rows)
   528  	got, err = FetchAllDoTables(context.Background(), NewBaseDBForTest(db), ba)
   529  	require.NoError(t, err)
   530  	require.Len(t, got, 1)
   531  	require.Equal(t, tables, got[doSchema])
   532  	require.NoError(t, mock.ExpectationsWereMet())
   533  
   534  	// use a block-allow-list to fiter some tables
   535  	ba, err = filter.New(false, &filter.Rules{
   536  		DoDBs: []string{doSchema},
   537  		DoTables: []*filter.Table{
   538  			{Schema: doSchema, Name: "tbl1"},
   539  			{Schema: doSchema, Name: "tbl2"},
   540  		},
   541  	})
   542  	require.NoError(t, err)
   543  
   544  	rows = sqlmock.NewRows([]string{"Database"})
   545  	addRowsForSchemas(rows, schemas)
   546  	mock.ExpectQuery(`SHOW DATABASES`).WillReturnRows(rows)
   547  	rows = sqlmock.NewRows([]string{fmt.Sprintf("Tables_in_%s", doSchema), "Table_type"})
   548  	addRowsForTables(rows, tables)
   549  	mock.ExpectQuery(fmt.Sprintf("SHOW FULL TABLES IN `%s` WHERE Table_Type != 'VIEW'", doSchema)).WillReturnRows(rows)
   550  	got, err = FetchAllDoTables(context.Background(), NewBaseDBForTest(db), ba)
   551  	require.NoError(t, err)
   552  	require.Len(t, got, 1)
   553  	require.Equal(t, []string{"tbl1", "tbl2"}, got[doSchema])
   554  	require.NoError(t, mock.ExpectationsWereMet())
   555  }
   556  
   557  func TestFetchTargetDoTables(t *testing.T) {
   558  	t.Parallel()
   559  
   560  	db, mock, err := sqlmock.New()
   561  	require.NoError(t, err)
   562  
   563  	// empty filter and router, just as upstream.
   564  	ba, err := filter.New(false, nil)
   565  	require.NoError(t, err)
   566  	r, err := regexprrouter.NewRegExprRouter(false, nil)
   567  	require.NoError(t, err)
   568  
   569  	schemas := []string{"shard1"}
   570  	rows := sqlmock.NewRows([]string{"Database"})
   571  	addRowsForSchemas(rows, schemas)
   572  	mock.ExpectQuery(`SHOW DATABASES`).WillReturnRows(rows)
   573  
   574  	tablesM := map[string][]string{
   575  		"shard1": {"tbl1", "tbl2"},
   576  	}
   577  	for schema, tables := range tablesM {
   578  		rows = sqlmock.NewRows([]string{fmt.Sprintf("Tables_in_%s", schema), "Table_type"})
   579  		addRowsForTables(rows, tables)
   580  		mock.ExpectQuery(fmt.Sprintf("SHOW FULL TABLES IN `%s` WHERE Table_Type != 'VIEW'", schema)).WillReturnRows(rows)
   581  	}
   582  
   583  	tablesMap, extendedCols, err := FetchTargetDoTables(context.Background(), "", NewBaseDBForTest(db), ba, r)
   584  	require.NoError(t, err)
   585  	require.Equal(t, map[filter.Table][]filter.Table{
   586  		{Schema: "shard1", Name: "tbl1"}: {{Schema: "shard1", Name: "tbl1"}},
   587  		{Schema: "shard1", Name: "tbl2"}: {{Schema: "shard1", Name: "tbl2"}},
   588  	}, tablesMap)
   589  	require.Len(t, extendedCols, 0)
   590  	require.NoError(t, mock.ExpectationsWereMet())
   591  
   592  	// route to the same downstream.
   593  	r, err = regexprrouter.NewRegExprRouter(false, []*router.TableRule{
   594  		{SchemaPattern: "shard*", TablePattern: "tbl*", TargetSchema: "shard", TargetTable: "tbl"},
   595  	})
   596  	require.NoError(t, err)
   597  
   598  	rows = sqlmock.NewRows([]string{"Database"})
   599  	addRowsForSchemas(rows, schemas)
   600  	mock.ExpectQuery(`SHOW DATABASES`).WillReturnRows(rows)
   601  	for schema, tables := range tablesM {
   602  		rows = sqlmock.NewRows([]string{fmt.Sprintf("Tables_in_%s", schema), "Table_type"})
   603  		addRowsForTables(rows, tables)
   604  		mock.ExpectQuery(fmt.Sprintf("SHOW FULL TABLES IN `%s` WHERE Table_Type != 'VIEW'", schema)).WillReturnRows(rows)
   605  	}
   606  
   607  	tablesMap, extendedCols, err = FetchTargetDoTables(context.Background(), "", NewBaseDBForTest(db), ba, r)
   608  	require.NoError(t, err)
   609  	require.Equal(t, map[filter.Table][]filter.Table{
   610  		{Schema: "shard", Name: "tbl"}: {
   611  			{Schema: "shard1", Name: "tbl1"},
   612  			{Schema: "shard1", Name: "tbl2"},
   613  		},
   614  	}, tablesMap)
   615  	require.Len(t, extendedCols, 0)
   616  	require.NoError(t, mock.ExpectationsWereMet())
   617  }
   618  
   619  func addRowsForSchemas(rows *sqlmock.Rows, schemas []string) {
   620  	for _, d := range schemas {
   621  		rows.AddRow(d)
   622  	}
   623  }
   624  
   625  func addRowsForTables(rows *sqlmock.Rows, tables []string) {
   626  	for _, table := range tables {
   627  		rows.AddRow(table, "BASE TABLE")
   628  	}
   629  }