github.com/pingcap/tiflow@v0.0.0-20240520035814-5bf52d54e205/dm/pkg/schema/tracker_test.go (about)

     1  // Copyright 2019 PingCAP, Inc.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // See the License for the specific language governing permissions and
    12  // limitations under the License.
    13  
    14  package schema
    15  
    16  import (
    17  	"context"
    18  	"encoding/json"
    19  	"fmt"
    20  	"sort"
    21  	"testing"
    22  	"time"
    23  
    24  	"github.com/DATA-DOG/go-sqlmock"
    25  	"github.com/pingcap/tidb/pkg/ddl"
    26  	"github.com/pingcap/tidb/pkg/parser"
    27  	"github.com/pingcap/tidb/pkg/parser/ast"
    28  	"github.com/pingcap/tidb/pkg/parser/model"
    29  	"github.com/pingcap/tidb/pkg/parser/mysql"
    30  	"github.com/pingcap/tidb/pkg/util/filter"
    31  	timock "github.com/pingcap/tidb/pkg/util/mock"
    32  	"github.com/pingcap/tiflow/dm/config"
    33  	"github.com/pingcap/tiflow/dm/pkg/conn"
    34  	tcontext "github.com/pingcap/tiflow/dm/pkg/context"
    35  	dlog "github.com/pingcap/tiflow/dm/pkg/log"
    36  	"github.com/pingcap/tiflow/dm/pkg/terror"
    37  	"github.com/pingcap/tiflow/dm/syncer/dbconn"
    38  	"github.com/stretchr/testify/require"
    39  )
    40  
    41  func parseSQL(t *testing.T, p *parser.Parser, sql string) ast.StmtNode {
    42  	t.Helper()
    43  
    44  	ret, err := p.ParseOneStmt(sql, "", "")
    45  	require.NoError(t, err)
    46  	return ret
    47  }
    48  
    49  func TestNeedSessionCfgInOldImpl(t *testing.T) {
    50  	ctx := context.Background()
    51  	table := &filter.Table{
    52  		Schema: "testdb",
    53  		Name:   "foo",
    54  	}
    55  
    56  	tracker, err := NewTestTracker(context.Background(), "test-tracker", nil, dlog.L())
    57  	require.NoError(t, err)
    58  
    59  	p := parser.New()
    60  
    61  	err = tracker.Exec(context.Background(), "", parseSQL(t, p, "create database testdb;"))
    62  	require.NoError(t, err)
    63  	tracker.Close()
    64  
    65  	createAST := parseSQL(t, p, "create table foo (a varchar(255) primary key, b DATETIME NOT NULL DEFAULT '0000-00-00 00:00:00')")
    66  
    67  	// test create table with zero datetime
    68  	err = tracker.Exec(ctx, "testdb", createAST)
    69  	require.NoError(t, err)
    70  	err = tracker.DropTable(table)
    71  	require.NoError(t, err)
    72  
    73  	// caller should set SQL Mode through status vars.
    74  	p.SetSQLMode(mysql.ModeANSIQuotes)
    75  	createAST = parseSQL(t, p, "create table \"foo\" (a varchar(255) primary key, b DATETIME NOT NULL DEFAULT '0000-00-00 00:00:00')")
    76  
    77  	// Now create the table with ANSI_QUOTES and ZERO_DATE
    78  	err = tracker.Exec(ctx, "testdb", createAST)
    79  	require.NoError(t, err)
    80  
    81  	sql, err := tracker.GetCreateTable(context.Background(), table)
    82  	require.NoError(t, err)
    83  	// the result is not ANSI_QUOTES
    84  	require.Equal(t, "CREATE TABLE `foo` ( `a` varchar(255) NOT NULL, `b` datetime NOT NULL DEFAULT '0000-00-00 00:00:00', PRIMARY KEY (`a`) /*T![clustered_index] NONCLUSTERED */) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_bin", sql)
    85  
    86  	// test alter primary key
    87  	alterAST := parseSQL(t, p, "alter table \"foo\" drop primary key")
    88  	err = tracker.Exec(ctx, "testdb", alterAST)
    89  	require.NoError(t, err)
    90  
    91  	tracker.Close()
    92  }
    93  
    94  func TestDDL(t *testing.T) {
    95  	table := &filter.Table{
    96  		Schema: "testdb",
    97  		Name:   "foo",
    98  	}
    99  
   100  	ctx := context.Background()
   101  	p := parser.New()
   102  
   103  	tracker, err := NewTestTracker(ctx, "test-tracker", nil, dlog.L())
   104  	require.NoError(t, err)
   105  	defer tracker.Close()
   106  
   107  	// Table shouldn't exist before initialization.
   108  	_, err = tracker.GetTableInfo(table)
   109  	require.ErrorContains(t, err, "Unknown database 'testdb'")
   110  	require.True(t, IsTableNotExists(err))
   111  
   112  	_, err = tracker.GetCreateTable(ctx, table)
   113  	require.ErrorContains(t, err, "Unknown database 'testdb'")
   114  	require.True(t, IsTableNotExists(err))
   115  
   116  	err = tracker.Exec(ctx, "", parseSQL(t, p, "create database testdb;"))
   117  	require.NoError(t, err)
   118  
   119  	_, err = tracker.GetTableInfo(table)
   120  	require.ErrorContains(t, err, "Table 'testdb.foo' doesn't exist")
   121  	require.True(t, IsTableNotExists(err))
   122  
   123  	// Now create the table with 3 columns.
   124  	createAST := parseSQL(t, p, "create table foo (a varchar(255) primary key, b varchar(255) as (concat(a, a)), c int)")
   125  	err = tracker.Exec(ctx, "testdb", createAST)
   126  	require.NoError(t, err)
   127  
   128  	// Verify the table has 3 columns.
   129  	ti, err := tracker.GetTableInfo(table)
   130  	require.NoError(t, err)
   131  	require.Len(t, ti.Columns, 3)
   132  	require.Equal(t, "a", ti.Columns[0].Name.O)
   133  	require.False(t, ti.Columns[0].IsGenerated())
   134  	require.Equal(t, "b", ti.Columns[1].Name.O)
   135  	require.True(t, ti.Columns[1].IsGenerated())
   136  	require.Equal(t, "c", ti.Columns[2].Name.O)
   137  	require.False(t, ti.Columns[2].IsGenerated())
   138  
   139  	sql, err := tracker.GetCreateTable(ctx, table)
   140  	require.NoError(t, err)
   141  	require.Equal(t, "CREATE TABLE `foo` ( `a` varchar(255) NOT NULL, `b` varchar(255) GENERATED ALWAYS AS (concat(`a`, `a`)) VIRTUAL, `c` int(11) DEFAULT NULL, PRIMARY KEY (`a`) /*T![clustered_index] NONCLUSTERED */) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_bin", sql)
   142  
   143  	// Drop one column from the table.
   144  	err = tracker.Exec(ctx, "testdb", parseSQL(t, p, "alter table foo drop column b"))
   145  	require.NoError(t, err)
   146  
   147  	// Verify that 2 columns remain.
   148  	ti, err = tracker.GetTableInfo(table)
   149  	require.NoError(t, err)
   150  	require.Len(t, ti.Columns, 2)
   151  	require.Equal(t, "a", ti.Columns[0].Name.O)
   152  	require.False(t, ti.Columns[0].IsGenerated())
   153  	require.Equal(t, "c", ti.Columns[1].Name.O)
   154  	require.False(t, ti.Columns[1].IsGenerated())
   155  
   156  	sql, err = tracker.GetCreateTable(ctx, table)
   157  	require.NoError(t, err)
   158  	require.Equal(t, "CREATE TABLE `foo` ( `a` varchar(255) NOT NULL, `c` int(11) DEFAULT NULL, PRIMARY KEY (`a`) /*T![clustered_index] NONCLUSTERED */) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_bin", sql)
   159  
   160  	// test expression index on tidb_shard.
   161  	createAST = parseSQL(t, p, "CREATE TABLE bar (f_id INT PRIMARY KEY, UNIQUE KEY uniq_order_id ((tidb_shard(f_id)),f_id))")
   162  	err = tracker.Exec(ctx, "testdb", createAST)
   163  	require.NoError(t, err)
   164  }
   165  
   166  func TestGetSingleColumnIndices(t *testing.T) {
   167  	ctx := context.Background()
   168  	p := parser.New()
   169  
   170  	tracker, err := NewTestTracker(ctx, "test-tracker", nil, dlog.L())
   171  	require.NoError(t, err)
   172  	defer tracker.Close()
   173  
   174  	err = tracker.Exec(ctx, "", parseSQL(t, p, "create database testdb;"))
   175  	require.NoError(t, err)
   176  	err = tracker.Exec(ctx, "testdb", parseSQL(t, p, "create table foo (a int, b int, c int)"))
   177  	require.NoError(t, err)
   178  
   179  	// check GetSingleColumnIndices could return all legal indices
   180  	err = tracker.Exec(ctx, "testdb", parseSQL(t, p, "alter table foo add index idx_a1(a)"))
   181  	require.NoError(t, err)
   182  	err = tracker.Exec(ctx, "testdb", parseSQL(t, p, "alter table foo add index idx_a2(a)"))
   183  	require.NoError(t, err)
   184  	infos, err := tracker.GetSingleColumnIndices("testdb", "foo", "a")
   185  	require.NoError(t, err)
   186  	require.Len(t, infos, 2)
   187  	names := []string{infos[0].Name.L, infos[1].Name.L}
   188  	sort.Strings(names)
   189  	require.Equal(t, []string{"idx_a1", "idx_a2"}, names)
   190  
   191  	// check return nothing for both multi-column and single-column indices
   192  	err = tracker.Exec(ctx, "testdb", parseSQL(t, p, "alter table foo add index idx_ab(a, b)"))
   193  	require.NoError(t, err)
   194  	err = tracker.Exec(ctx, "testdb", parseSQL(t, p, "alter table foo add index idx_b(b)"))
   195  	require.NoError(t, err)
   196  	infos, err = tracker.GetSingleColumnIndices("testdb", "foo", "b")
   197  	require.Error(t, err)
   198  	require.Len(t, infos, 0)
   199  
   200  	// check no indices
   201  	infos, err = tracker.GetSingleColumnIndices("testdb", "foo", "c")
   202  	require.NoError(t, err)
   203  	require.Len(t, infos, 0)
   204  }
   205  
   206  func TestCreateSchemaIfNotExists(t *testing.T) {
   207  	ctx := context.Background()
   208  	p := parser.New()
   209  
   210  	tracker, err := NewTestTracker(ctx, "test-tracker", nil, dlog.L())
   211  	require.NoError(t, err)
   212  	defer tracker.Close()
   213  
   214  	// We cannot create a table without a database.
   215  	err = tracker.Exec(ctx, "testdb", parseSQL(t, p, "create table foo(a int)"))
   216  	require.ErrorContains(t, err, "Unknown database 'testdb'")
   217  
   218  	// We can create the database directly.
   219  	err = tracker.CreateSchemaIfNotExists("testdb")
   220  	require.NoError(t, err)
   221  
   222  	// Creating the same database twice is no-op.
   223  	err = tracker.CreateSchemaIfNotExists("testdb")
   224  	require.NoError(t, err)
   225  
   226  	// Now creating a table should be successful
   227  	err = tracker.Exec(ctx, "testdb", parseSQL(t, p, "create table foo(a int)"))
   228  	require.NoError(t, err)
   229  
   230  	ti, err := tracker.GetTableInfo(&filter.Table{Schema: "testdb", Name: "foo"})
   231  	require.NoError(t, err)
   232  	require.Equal(t, "foo", ti.Name.O)
   233  }
   234  
   235  func TestMultiDrop(t *testing.T) {
   236  	ctx := context.Background()
   237  	p := parser.New()
   238  
   239  	tracker, err := NewTestTracker(ctx, "test-tracker", nil, dlog.L())
   240  	require.NoError(t, err)
   241  	defer tracker.Close()
   242  
   243  	err = tracker.CreateSchemaIfNotExists("testdb")
   244  	require.NoError(t, err)
   245  	createAST := parseSQL(t, p, `create table foo(a int, b int, c int)
   246         partition by range( a ) (
   247  			partition p1 values less than (1991),
   248  			partition p2 values less than (1996),
   249  			partition p3 values less than (2001)
   250  	    );`)
   251  	err = tracker.Exec(ctx, "testdb", createAST)
   252  	require.NoError(t, err)
   253  
   254  	err = tracker.Exec(ctx, "testdb", parseSQL(t, p, "alter table foo drop partition p1, p2"))
   255  	require.NoError(t, err)
   256  
   257  	err = tracker.Exec(ctx, "testdb", parseSQL(t, p, "alter table foo drop b, drop c"))
   258  	require.NoError(t, err)
   259  }
   260  
   261  // clearVolatileInfo removes generated information like TS and ID so DeepEquals
   262  // of two compatible schemas can pass.
   263  func clearVolatileInfo(ti *model.TableInfo) {
   264  	ti.ID = 0
   265  	ti.UpdateTS = 0
   266  	if ti.Partition != nil {
   267  		for i := range ti.Partition.Definitions {
   268  			ti.Partition.Definitions[i].ID = 0
   269  		}
   270  	}
   271  }
   272  
   273  // asJSON is a convenient wrapper to print a TableInfo in its JSON representation.
   274  type asJSON struct{ *model.TableInfo }
   275  
   276  func (aj asJSON) String() string {
   277  	b, _ := json.Marshal(aj.TableInfo)
   278  	return string(b)
   279  }
   280  
   281  func TestCreateTableIfNotExists(t *testing.T) {
   282  	table := &filter.Table{
   283  		Schema: "testdb",
   284  		Name:   "foo",
   285  	}
   286  
   287  	ctx := context.Background()
   288  	p := parser.New()
   289  
   290  	tracker, err := NewTestTracker(ctx, "test-tracker", nil, dlog.L())
   291  	require.NoError(t, err)
   292  	defer tracker.Close()
   293  
   294  	// Create some sort of complicated table.
   295  	err = tracker.CreateSchemaIfNotExists("testdb")
   296  	require.NoError(t, err)
   297  
   298  	err = tracker.Exec(ctx, "testdb", parseSQL(t, p, `
   299  		create table foo(
   300  			a int primary key auto_increment,
   301  			b int as (c+1) not null,
   302  			c int comment 'some cmt',
   303  			d text,
   304  			key dk(d(255))
   305  		) comment 'more cmt' partition by range columns (a) (
   306  			partition x41 values less than (41),
   307  			partition x82 values less than (82),
   308  			partition rest values less than maxvalue comment 'part cmt'
   309  		);
   310  	`))
   311  	require.NoError(t, err)
   312  
   313  	// Save the table info
   314  	ti1, err := tracker.GetTableInfo(table)
   315  	require.NoError(t, err)
   316  	require.Equal(t, "foo", ti1.Name.O)
   317  	ti1 = ti1.Clone()
   318  	clearVolatileInfo(ti1)
   319  
   320  	// Remove the table. Should not be found anymore.
   321  	err = tracker.DropTable(table)
   322  	require.NoError(t, err)
   323  
   324  	_, err = tracker.GetTableInfo(table)
   325  	require.ErrorContains(t, err, "Table 'testdb.foo' doesn't exist")
   326  
   327  	// Recover the table using the table info.
   328  	err = tracker.CreateTableIfNotExists(&filter.Table{Schema: "testdb", Name: "foo"}, ti1)
   329  	require.NoError(t, err)
   330  
   331  	// The new table info should be equivalent to the old one except the TS and generated IDs.
   332  	ti2, err := tracker.GetTableInfo(table)
   333  	require.NoError(t, err)
   334  	clearVolatileInfo(ti2)
   335  	require.Equal(t, ti1, ti2, "ti2 = %s\nti1 = %s", asJSON{ti2}, asJSON{ti1})
   336  
   337  	// no error if table already exist
   338  	err = tracker.CreateTableIfNotExists(&filter.Table{Schema: "testdb", Name: "foo"}, ti1)
   339  	require.NoError(t, err)
   340  
   341  	// error if db not exist
   342  	err = tracker.CreateTableIfNotExists(&filter.Table{Schema: "test-another-db", Name: "foo"}, ti1)
   343  	require.ErrorContains(t, err, "Unknown database")
   344  
   345  	// Can use the table info to recover a table using a different name.
   346  	err = tracker.CreateTableIfNotExists(&filter.Table{Schema: "testdb", Name: "bar"}, ti1)
   347  	require.NoError(t, err)
   348  
   349  	ti3, err := tracker.GetTableInfo(&filter.Table{Schema: "testdb", Name: "bar"})
   350  	require.NoError(t, err)
   351  	require.Equal(t, "bar", ti3.Name.O)
   352  	clearVolatileInfo(ti3)
   353  	ti3.Name = ti1.Name
   354  	require.Equal(t, ti1, ti3, "ti3 = %s\nti1 = %s", asJSON{ti3}, asJSON{ti1})
   355  
   356  	start := time.Now()
   357  	for n := 0; n < 100; n++ {
   358  		err = tracker.CreateTableIfNotExists(&filter.Table{Schema: "testdb", Name: fmt.Sprintf("foo-%d", n)}, ti1)
   359  		require.NoError(t, err)
   360  	}
   361  	duration := time.Since(start)
   362  	require.Less(t, duration.Seconds(), float64(30))
   363  }
   364  
   365  func TestBatchCreateTableIfNotExist(t *testing.T) {
   366  	ctx := context.Background()
   367  	p := parser.New()
   368  
   369  	tracker, err := NewTestTracker(ctx, "test-tracker", nil, dlog.L())
   370  	require.NoError(t, err)
   371  	defer tracker.Close()
   372  
   373  	err = tracker.CreateSchemaIfNotExists("testdb")
   374  	require.NoError(t, err)
   375  	err = tracker.CreateSchemaIfNotExists("testdb2")
   376  	require.NoError(t, err)
   377  
   378  	tables := []*filter.Table{
   379  		{
   380  			Schema: "testdb",
   381  			Name:   "foo",
   382  		},
   383  		{
   384  			Schema: "testdb",
   385  			Name:   "foo1",
   386  		},
   387  		{
   388  			Schema: "testdb2",
   389  			Name:   "foo3",
   390  		},
   391  	}
   392  	execStmt := []string{
   393  		`create table foo(
   394  			a int primary key auto_increment,
   395  			b int as (c+1) not null,
   396  			c int comment 'some cmt',
   397  			d text,
   398  			key dk(d(255))
   399  		) comment 'more cmt' partition by range columns (a) (
   400  			partition x41 values less than (41),
   401  			partition x82 values less than (82),
   402  			partition rest values less than maxvalue comment 'part cmt'
   403  		);`,
   404  		`create table foo1(
   405  			a int primary key,
   406  			b text not null,
   407  			d datetime,
   408  			e varchar(5)
   409  		);`,
   410  		`create table foo3(
   411  			a int,
   412  			b int,
   413  			primary key(a));`,
   414  	}
   415  	tiInfos := make([]*model.TableInfo, len(tables))
   416  	for i := range tables {
   417  		err = tracker.Exec(ctx, tables[i].Schema, parseSQL(t, p, execStmt[i]))
   418  		require.NoError(t, err)
   419  		tiInfos[i], err = tracker.GetTableInfo(tables[i])
   420  		require.NoError(t, err)
   421  		require.Equal(t, tables[i].Name, tiInfos[i].Name.O)
   422  		tiInfos[i] = tiInfos[i].Clone()
   423  		clearVolatileInfo(tiInfos[i])
   424  	}
   425  	// drop all tables and recover
   426  	// 1. drop
   427  	for i := range tables {
   428  		err = tracker.DropTable(tables[i])
   429  		require.NoError(t, err)
   430  		_, err = tracker.GetTableInfo(tables[i])
   431  		require.ErrorContains(t, err, "doesn't exist")
   432  	}
   433  	// 2. test empty load
   434  	tablesToCreate := map[string]map[string]*model.TableInfo{}
   435  	tablesToCreate["testdb"] = map[string]*model.TableInfo{}
   436  	tablesToCreate["testdb2"] = map[string]*model.TableInfo{}
   437  	err = tracker.BatchCreateTableIfNotExist(tablesToCreate)
   438  	require.NoError(t, err)
   439  	// 3. recover
   440  	for i := range tables {
   441  		tablesToCreate[tables[i].Schema][tables[i].Name] = tiInfos[i]
   442  	}
   443  	err = tracker.BatchCreateTableIfNotExist(tablesToCreate)
   444  	require.NoError(t, err)
   445  	// 4. check all create success
   446  	for i := range tables {
   447  		var ti *model.TableInfo
   448  		ti, err = tracker.GetTableInfo(tables[i])
   449  		require.NoError(t, err)
   450  		cloneTi := ti.Clone()
   451  		clearVolatileInfo(cloneTi)
   452  		require.Equal(t, tiInfos[i], cloneTi)
   453  	}
   454  
   455  	// drop two tables and create all three
   456  	// expect: silently succeed
   457  	// 1. drop table
   458  	err = tracker.DropTable(tables[2])
   459  	require.NoError(t, err)
   460  	err = tracker.DropTable(tables[0])
   461  	require.NoError(t, err)
   462  	// 2. batch create
   463  	err = tracker.BatchCreateTableIfNotExist(tablesToCreate)
   464  	require.NoError(t, err)
   465  	// 3. check
   466  	for i := range tables {
   467  		var ti *model.TableInfo
   468  		ti, err = tracker.GetTableInfo(tables[i])
   469  		require.NoError(t, err)
   470  		clearVolatileInfo(ti)
   471  		require.Equal(t, tiInfos[i], ti)
   472  	}
   473  
   474  	// BatchCreateTableIfNotExist will also create database
   475  	err = tracker.Exec(ctx, "", parseSQL(t, p, `drop database testdb`))
   476  	require.NoError(t, err)
   477  	err = tracker.BatchCreateTableIfNotExist(tablesToCreate)
   478  	require.NoError(t, err)
   479  }
   480  
   481  func TestAllSchemas(t *testing.T) {
   482  	ctx := context.Background()
   483  	p := parser.New()
   484  
   485  	tracker, err := NewTestTracker(ctx, "test-tracker", nil, dlog.L())
   486  	require.NoError(t, err)
   487  	defer tracker.Close()
   488  
   489  	// nothing should exist...
   490  	require.Len(t, tracker.AllSchemas(), 0)
   491  
   492  	// Create several schemas and tables.
   493  	err = tracker.CreateSchemaIfNotExists("testdb1")
   494  	require.NoError(t, err)
   495  	err = tracker.CreateSchemaIfNotExists("testdb2")
   496  	require.NoError(t, err)
   497  	err = tracker.CreateSchemaIfNotExists("testdb3")
   498  	require.NoError(t, err)
   499  	err = tracker.Exec(ctx, "testdb2", parseSQL(t, p, "create table a(a int)"))
   500  	require.NoError(t, err)
   501  	err = tracker.Exec(ctx, "testdb1", parseSQL(t, p, "create table b(a int)"))
   502  	require.NoError(t, err)
   503  	err = tracker.Exec(ctx, "testdb1", parseSQL(t, p, "create table c(a int)"))
   504  	require.NoError(t, err)
   505  
   506  	// check schema tables
   507  	tables, err := tracker.ListSchemaTables("testdb1")
   508  	require.NoError(t, err)
   509  	sort.Strings(tables)
   510  	require.Equal(t, []string{"b", "c"}, tables)
   511  	// check schema not exists
   512  	notExistSchemaName := "testdb_not_found"
   513  	_, err = tracker.ListSchemaTables(notExistSchemaName)
   514  	require.True(t, terror.ErrSchemaTrackerUnSchemaNotExist.Equal(err))
   515  
   516  	// check that all schemas and tables are present.
   517  	allSchemas := tracker.AllSchemas()
   518  	require.Len(t, allSchemas, 3)
   519  	existingNames := 0
   520  	for _, schema := range allSchemas {
   521  		switch schema {
   522  		case "testdb1":
   523  			existingNames |= 1
   524  			tables, err = tracker.ListSchemaTables(schema)
   525  			require.NoError(t, err)
   526  			require.Len(t, tables, 2)
   527  			for _, table := range tables {
   528  				switch table {
   529  				case "b":
   530  					existingNames |= 8
   531  				case "c":
   532  					existingNames |= 16
   533  				default:
   534  					t.Fatalf("unexpected table testdb1.%s", table)
   535  				}
   536  			}
   537  		case "testdb2":
   538  			existingNames |= 2
   539  			tables, err = tracker.ListSchemaTables(schema)
   540  			require.NoError(t, err)
   541  			require.Len(t, tables, 1)
   542  			table := tables[0]
   543  			require.Equal(t, "a", table)
   544  		case "testdb3":
   545  			existingNames |= 4
   546  		default:
   547  			t.Fatalf("unexpected schema %s", schema)
   548  		}
   549  	}
   550  	require.Equal(t, 31, existingNames)
   551  
   552  	// reset the tracker. all schemas should be gone.
   553  	tracker.Reset()
   554  	require.Len(t, tracker.AllSchemas(), 0)
   555  	_, err = tracker.GetTableInfo(&filter.Table{Schema: "testdb2", Name: "a"})
   556  	require.ErrorContains(t, err, "Unknown database 'testdb2'")
   557  }
   558  
   559  func mockBaseConn(t *testing.T) (*dbconn.DBConn, sqlmock.Sqlmock) {
   560  	t.Helper()
   561  
   562  	db, mock, err := sqlmock.New()
   563  	require.NoError(t, err)
   564  	t.Cleanup(func() {
   565  		db.Close()
   566  	})
   567  	c, err := db.Conn(context.Background())
   568  	require.NoError(t, err)
   569  	baseConn := conn.NewBaseConnForTest(c, nil)
   570  	dbConn := dbconn.NewDBConn(&config.SubTaskConfig{}, baseConn)
   571  	return dbConn, mock
   572  }
   573  
   574  func TestInitDownStreamSQLModeAndParser(t *testing.T) {
   575  	dbConn, mock := mockBaseConn(t)
   576  
   577  	tracker, err := NewTestTracker(context.Background(), "test-tracker", dbConn, dlog.L())
   578  	require.NoError(t, err)
   579  	defer tracker.Close()
   580  
   581  	mock.ExpectBegin()
   582  	mock.ExpectExec(fmt.Sprintf("SET SESSION SQL_MODE = '%s'", mysql.DefaultSQLMode)).WillReturnResult(sqlmock.NewResult(0, 0))
   583  	mock.ExpectCommit()
   584  
   585  	tctx := tcontext.NewContext(context.Background(), dlog.L())
   586  
   587  	err = tracker.downstreamTracker.initDownStreamSQLModeAndParser(tctx)
   588  	require.NoError(t, err)
   589  	require.NotNil(t, tracker.downstreamTracker.stmtParser)
   590  }
   591  
   592  func TestGetDownStreamIndexInfo(t *testing.T) {
   593  	// origin table info
   594  	p := parser.New()
   595  	se := timock.NewContext()
   596  	ctx := context.Background()
   597  	node, err := p.ParseOneStmt("create table t(a int, b int, c varchar(10))", "utf8mb4", "utf8mb4_bin")
   598  	require.NoError(t, err)
   599  	oriTi, err := ddl.MockTableInfo(se, node.(*ast.CreateTableStmt), 1)
   600  	require.NoError(t, err)
   601  
   602  	// tracker and sqlmock
   603  	dbConn, mock := mockBaseConn(t)
   604  	tracker, err := NewTestTracker(ctx, "test-tracker", dbConn, dlog.L())
   605  	require.NoError(t, err)
   606  	defer tracker.Close()
   607  
   608  	mock.ExpectBegin()
   609  	mock.ExpectExec(fmt.Sprintf("SET SESSION SQL_MODE = '%s'", mysql.DefaultSQLMode)).WillReturnResult(sqlmock.NewResult(0, 0))
   610  	mock.ExpectCommit()
   611  
   612  	tableID := "`test`.`test`"
   613  
   614  	mock.ExpectQuery("SHOW CREATE TABLE " + tableID).WillReturnRows(
   615  		sqlmock.NewRows([]string{"Table", "Create Table"}).
   616  			AddRow("test", "create table t(a int, b int, c varchar(20000), primary key(a, b), key(c(20000)))/*!90000 SHARD_ROW_ID_BITS=6 */"))
   617  	dti, err := tracker.GetDownStreamTableInfo(tcontext.Background(), tableID, oriTi)
   618  	require.NoError(t, err)
   619  	require.NotNil(t, dti.WhereHandle.UniqueNotNullIdx)
   620  }
   621  
   622  func TestReTrackDownStreamIndex(t *testing.T) {
   623  	// origin table info
   624  	p := parser.New()
   625  	se := timock.NewContext()
   626  	node, err := p.ParseOneStmt("create table t(a int, b int, c varchar(10))", "utf8mb4", "utf8mb4_bin")
   627  	require.NoError(t, err)
   628  	oriTi, err := ddl.MockTableInfo(se, node.(*ast.CreateTableStmt), 1)
   629  	require.NoError(t, err)
   630  
   631  	dbConn, mock := mockBaseConn(t)
   632  	tracker, err := NewTestTracker(context.Background(), "test-tracker", dbConn, dlog.L())
   633  	require.NoError(t, err)
   634  	defer tracker.Close()
   635  
   636  	mock.ExpectBegin()
   637  	mock.ExpectExec(fmt.Sprintf("SET SESSION SQL_MODE = '%s'", mysql.DefaultSQLMode)).WillReturnResult(sqlmock.NewResult(0, 0))
   638  	mock.ExpectCommit()
   639  
   640  	tableID := "`test`.`test`"
   641  
   642  	mock.ExpectQuery("SHOW CREATE TABLE " + tableID).WillReturnRows(
   643  		sqlmock.NewRows([]string{"Table", "Create Table"}).
   644  			AddRow("test", "create table t(a int, b int, c varchar(10), PRIMARY KEY (a,b))"))
   645  	_, err = tracker.GetDownStreamTableInfo(tcontext.Background(), tableID, oriTi)
   646  	require.NoError(t, err)
   647  	_, ok := tracker.downstreamTracker.tableInfos[tableID]
   648  	require.True(t, ok)
   649  
   650  	// just table
   651  	targetTables := []*filter.Table{{Schema: "test", Name: "a"}, {Schema: "test", Name: "test"}}
   652  	tracker.RemoveDownstreamSchema(tcontext.Background(), targetTables)
   653  	_, ok = tracker.downstreamTracker.tableInfos[tableID]
   654  	require.False(t, ok)
   655  
   656  	mock.ExpectQuery("SHOW CREATE TABLE " + tableID).WillReturnRows(
   657  		sqlmock.NewRows([]string{"Table", "Create Table"}).
   658  			AddRow("test", "create table t(a int, b int, c varchar(10), PRIMARY KEY (a,b))"))
   659  	_, err = tracker.GetDownStreamTableInfo(tcontext.Background(), tableID, oriTi)
   660  	require.NoError(t, err)
   661  	_, ok = tracker.downstreamTracker.tableInfos[tableID]
   662  	require.True(t, ok)
   663  
   664  	tracker.RemoveDownstreamSchema(tcontext.Background(), targetTables)
   665  	_, ok = tracker.downstreamTracker.tableInfos[tableID]
   666  	require.False(t, ok)
   667  
   668  	mock.ExpectQuery("SHOW CREATE TABLE " + tableID).WillReturnRows(
   669  		sqlmock.NewRows([]string{"Table", "Create Table"}).
   670  			AddRow("test", "create table t(a int primary key, b int, c varchar(10))"))
   671  	_, err = tracker.GetDownStreamTableInfo(tcontext.Background(), tableID, oriTi)
   672  	require.NoError(t, err)
   673  	_, ok = tracker.downstreamTracker.tableInfos[tableID]
   674  	require.True(t, ok)
   675  
   676  	// just schema
   677  	targetTables = []*filter.Table{{Schema: "test", Name: "a"}, {Schema: "test", Name: ""}}
   678  	tracker.RemoveDownstreamSchema(tcontext.Background(), targetTables)
   679  	_, ok = tracker.downstreamTracker.tableInfos[tableID]
   680  	require.False(t, ok)
   681  
   682  	mock.ExpectQuery("SHOW CREATE TABLE " + tableID).WillReturnRows(
   683  		sqlmock.NewRows([]string{"Table", "Create Table"}).
   684  			AddRow("test", "create table t(a int, b int, c varchar(10), PRIMARY KEY (a,b))"))
   685  	_, err = tracker.GetDownStreamTableInfo(tcontext.Background(), tableID, oriTi)
   686  	require.NoError(t, err)
   687  	_, ok = tracker.downstreamTracker.tableInfos[tableID]
   688  	require.True(t, ok)
   689  
   690  	tracker.RemoveDownstreamSchema(tcontext.Background(), targetTables)
   691  	_, ok = tracker.downstreamTracker.tableInfos[tableID]
   692  	require.False(t, ok)
   693  
   694  	mock.ExpectQuery("SHOW CREATE TABLE " + tableID).WillReturnRows(
   695  		sqlmock.NewRows([]string{"Table", "Create Table"}).
   696  			AddRow("test", "create table t(a int primary key, b int, c varchar(10))"))
   697  	_, err = tracker.GetDownStreamTableInfo(tcontext.Background(), tableID, oriTi)
   698  	require.NoError(t, err)
   699  	_, ok = tracker.downstreamTracker.tableInfos[tableID]
   700  	require.True(t, ok)
   701  }
   702  
   703  func TestVarchar20000(t *testing.T) {
   704  	// origin table info
   705  	p := parser.New()
   706  	node, err := p.ParseOneStmt("create table t(c varchar(20000)) charset=utf8", "", "")
   707  	require.NoError(t, err)
   708  	oriTi, err := ddl.BuildTableInfoFromAST(node.(*ast.CreateTableStmt))
   709  	require.NoError(t, err)
   710  
   711  	// tracker and sqlmock
   712  	dbConn, mock := mockBaseConn(t)
   713  	tracker, err := NewTestTracker(context.Background(), "test-tracker", dbConn, dlog.L())
   714  	require.NoError(t, err)
   715  	defer tracker.Close()
   716  
   717  	mock.ExpectBegin()
   718  	mock.ExpectExec(fmt.Sprintf("SET SESSION SQL_MODE = '%s'", mysql.DefaultSQLMode)).WillReturnResult(sqlmock.NewResult(0, 0))
   719  	mock.ExpectCommit()
   720  
   721  	tableID := "`test`.`test`"
   722  
   723  	mock.ExpectQuery("SHOW CREATE TABLE " + tableID).WillReturnRows(
   724  		sqlmock.NewRows([]string{"Table", "Create Table"}).
   725  			AddRow("test", "create table t(c varchar(20000)) charset=utf8"))
   726  	_, err = tracker.GetDownStreamTableInfo(tcontext.Background(), tableID, oriTi)
   727  	require.NoError(t, err)
   728  	_, ok := tracker.downstreamTracker.tableInfos[tableID]
   729  	require.True(t, ok)
   730  }
   731  
   732  func TestPlacementRule(t *testing.T) {
   733  	// origin table info
   734  	p := parser.New()
   735  	node, err := p.ParseOneStmt("create table t(c int) charset=utf8mb4", "", "")
   736  	require.NoError(t, err)
   737  	oriTi, err := ddl.BuildTableInfoFromAST(node.(*ast.CreateTableStmt))
   738  	require.NoError(t, err)
   739  
   740  	dbConn, mock := mockBaseConn(t)
   741  	tracker, err := NewTestTracker(context.Background(), "test-tracker", dbConn, dlog.L())
   742  	require.NoError(t, err)
   743  	defer tracker.Close()
   744  
   745  	mock.ExpectBegin()
   746  	mock.ExpectExec(fmt.Sprintf("SET SESSION SQL_MODE = '%s'", mysql.DefaultSQLMode)).WillReturnResult(sqlmock.NewResult(0, 0))
   747  	mock.ExpectCommit()
   748  
   749  	tableID := "`test`.`test`"
   750  
   751  	mock.ExpectQuery("SHOW CREATE TABLE " + tableID).WillReturnRows(
   752  		sqlmock.NewRows([]string{"Table", "Create Table"}).
   753  			AddRow("test", ""+
   754  				"CREATE TABLE `t` ("+
   755  				"   `c` int(11) DEFAULT NULL"+
   756  				") ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_bin /*T![placement] PLACEMENT POLICY=`acdc` */;"))
   757  	_, err = tracker.GetDownStreamTableInfo(tcontext.Background(), tableID, oriTi)
   758  	require.NoError(t, err)
   759  	_, ok := tracker.downstreamTracker.tableInfos[tableID]
   760  	require.True(t, ok)
   761  }
   762  
   763  func TestTimeTypes(t *testing.T) {
   764  	ctx := context.Background()
   765  	p := parser.New()
   766  	p.SetSQLMode(0)
   767  
   768  	tracker, err := NewTestTracker(ctx, "test-tracker", nil, dlog.L())
   769  	require.NoError(t, err)
   770  	defer tracker.Close()
   771  
   772  	// Create some sort of complicated table.
   773  	err = tracker.CreateSchemaIfNotExists("testdb")
   774  	require.NoError(t, err)
   775  
   776  	err = tracker.Exec(ctx, "testdb", parseSQL(t, p, `
   777  		create table foo(
   778  			c0 datetime primary key,
   779  			c1 datetime default current_timestamp,
   780  			c2 datetime default '0000-00-00 00:00:00',
   781  			c3 datetime default '2020-02-02 00:00:00',
   782  			c4 timestamp default current_timestamp,
   783  			c5 timestamp default '0000-00-00 00:00:00',
   784  			c6 timestamp default '2020-02-02 00:00:00'
   785  		);
   786  	`))
   787  	require.NoError(t, err)
   788  }
   789  
   790  func TestNeedRestrictedSQLExecutor(t *testing.T) {
   791  	ctx := context.Background()
   792  	p := parser.New()
   793  
   794  	tracker, err := NewTestTracker(ctx, "test-tracker", nil, dlog.L())
   795  	require.NoError(t, err)
   796  	defer tracker.Close()
   797  
   798  	// Create some sort of complicated table.
   799  	err = tracker.CreateSchemaIfNotExists("testdb")
   800  	require.NoError(t, err)
   801  
   802  	err = tracker.Exec(ctx, "testdb", parseSQL(t, p, `create table testdb.t (a int, b int);`))
   803  	require.NoError(t, err)
   804  
   805  	err = tracker.Exec(ctx, "testdb", parseSQL(t, p, `alter table testdb.t modify column a int not null;`))
   806  	require.NoError(t, err)
   807  }
   808  
   809  func TestMustNotUseMockStore(t *testing.T) {
   810  	ctx := context.Background()
   811  	tracker, err := NewTestTracker(ctx, "test-tracker", nil, dlog.L())
   812  	require.NoError(t, err)
   813  	defer tracker.Close()
   814  
   815  	require.Nil(t, tracker.se.GetStore(), "see https://github.com/pingcap/tiflow/issues/5334")
   816  }