vitess.io/vitess@v0.16.2/go/test/endtoend/schemadiff/vrepl/schemadiff_vrepl_suite_test.go (about)

     1  /*
     2  Copyright 2022 The Vitess Authors.
     3  
     4  Licensed under the Apache License, Version 2.0 (the "License");
     5  you may not use this file except in compliance with the License.
     6  You may obtain a copy of the License at
     7  
     8      http://www.apache.org/licenses/LICENSE-2.0
     9  
    10  Unless required by applicable law or agreed to in writing, software
    11  distributed under the License is distributed on an "AS IS" BASIS,
    12  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13  See the License for the specific language governing permissions and
    14  limitations under the License.
    15  */
    16  
    17  package vreplsuite
    18  
    19  import (
    20  	"context"
    21  	"flag"
    22  	"fmt"
    23  	"os"
    24  	"path"
    25  	"regexp"
    26  	"strings"
    27  	"testing"
    28  
    29  	"github.com/stretchr/testify/assert"
    30  	"github.com/stretchr/testify/require"
    31  
    32  	"vitess.io/vitess/go/mysql"
    33  	"vitess.io/vitess/go/sqltypes"
    34  	"vitess.io/vitess/go/test/endtoend/cluster"
    35  	"vitess.io/vitess/go/test/endtoend/onlineddl"
    36  	"vitess.io/vitess/go/vt/schemadiff"
    37  	"vitess.io/vitess/go/vt/sqlparser"
    38  )
    39  
    40  var (
    41  	clusterInstance      *cluster.LocalProcessCluster
    42  	vtParams             mysql.ConnParams
    43  	evaluatedMysqlParams *mysql.ConnParams
    44  
    45  	hostname              = "localhost"
    46  	keyspaceName          = "ks"
    47  	cell                  = "zone1"
    48  	schemaChangeDirectory = ""
    49  	tableName             = `onlineddl_test`
    50  	eventName             = `onlineddl_test`
    51  )
    52  
    53  const (
    54  	testDataPath   = "../../onlineddl/vrepl_suite/testdata"
    55  	defaultSQLMode = "ONLY_FULL_GROUP_BY,STRICT_TRANS_TABLES,ERROR_FOR_DIVISION_BY_ZERO,NO_ENGINE_SUBSTITUTION"
    56  )
    57  
    58  type testTableSchema struct {
    59  	testName    string
    60  	tableSchema string
    61  }
    62  
    63  var (
    64  	fromTestTableSchemas []*testTableSchema
    65  	toTestTableSchemas   []*testTableSchema
    66  	autoIncrementRegexp  = regexp.MustCompile(`(?i) auto_increment[\s]*[=]?[\s]*([0-9]+)`)
    67  )
    68  
    69  func TestMain(m *testing.M) {
    70  	defer cluster.PanicHandler(nil)
    71  	flag.Parse()
    72  
    73  	exitcode, err := func() (int, error) {
    74  		clusterInstance = cluster.NewCluster(cell, hostname)
    75  		schemaChangeDirectory = path.Join("/tmp", fmt.Sprintf("schema_change_dir_%d", clusterInstance.GetAndReserveTabletUID()))
    76  		defer os.RemoveAll(schemaChangeDirectory)
    77  		defer clusterInstance.Teardown()
    78  
    79  		if _, err := os.Stat(schemaChangeDirectory); os.IsNotExist(err) {
    80  			_ = os.Mkdir(schemaChangeDirectory, 0700)
    81  		}
    82  
    83  		clusterInstance.VtctldExtraArgs = []string{
    84  			"--schema_change_dir", schemaChangeDirectory,
    85  			"--schema_change_controller", "local",
    86  			"--schema_change_check_interval", "1",
    87  		}
    88  
    89  		clusterInstance.VtTabletExtraArgs = []string{
    90  			"--enable-lag-throttler",
    91  			"--throttle_threshold", "1s",
    92  			"--heartbeat_enable",
    93  			"--heartbeat_interval", "250ms",
    94  			"--heartbeat_on_demand_duration", "5s",
    95  			"--migration_check_interval", "5s",
    96  			"--watch_replication_stream",
    97  		}
    98  
    99  		if err := clusterInstance.StartTopo(); err != nil {
   100  			return 1, err
   101  		}
   102  
   103  		// Start keyspace
   104  		keyspace := &cluster.Keyspace{
   105  			Name: keyspaceName,
   106  		}
   107  
   108  		// No need for replicas in this stress test
   109  		if err := clusterInstance.StartKeyspace(*keyspace, []string{"1"}, 0, false); err != nil {
   110  			return 1, err
   111  		}
   112  
   113  		vtgateInstance := clusterInstance.NewVtgateInstance()
   114  		// Start vtgate
   115  		if err := vtgateInstance.Setup(); err != nil {
   116  			return 1, err
   117  		}
   118  		// ensure it is torn down during cluster TearDown
   119  		clusterInstance.VtgateProcess = *vtgateInstance
   120  		vtParams = mysql.ConnParams{
   121  			Host: clusterInstance.Hostname,
   122  			Port: clusterInstance.VtgateMySQLPort,
   123  		}
   124  
   125  		return m.Run(), nil
   126  	}()
   127  	if err != nil {
   128  		fmt.Printf("%v\n", err)
   129  		os.Exit(1)
   130  	} else {
   131  		os.Exit(exitcode)
   132  	}
   133  
   134  }
   135  
   136  func TestSchemaChange(t *testing.T) {
   137  	defer cluster.PanicHandler(t)
   138  
   139  	shards := clusterInstance.Keyspaces[0].Shards
   140  	require.Equal(t, 1, len(shards))
   141  
   142  	files, err := os.ReadDir(testDataPath)
   143  	require.NoError(t, err)
   144  	for _, f := range files {
   145  		if !f.IsDir() {
   146  			continue
   147  		}
   148  		// this is a test!
   149  		t.Run(f.Name(), func(t *testing.T) {
   150  			testSingle(t, f.Name())
   151  		})
   152  	}
   153  }
   154  
   155  func readTestFile(t *testing.T, testName string, fileName string) (content string, exists bool) {
   156  	filePath := path.Join(testDataPath, testName, fileName)
   157  	_, err := os.Stat(filePath)
   158  	if os.IsNotExist(err) {
   159  		return "", false
   160  	}
   161  	require.NoError(t, err)
   162  	b, err := os.ReadFile(filePath)
   163  	require.NoError(t, err)
   164  	return strings.TrimSpace(string(b)), true
   165  }
   166  
   167  // testSingle is the main testing function for a single test in the suite.
   168  // It prepares the grounds, creates the test data, runs a migration, expects results/error, cleans up.
   169  func testSingle(t *testing.T, testName string) {
   170  	if ignoreVersions, exists := readTestFile(t, testName, "ignore_versions"); exists {
   171  		// ignoreVersions is a regexp
   172  		re, err := regexp.Compile(ignoreVersions)
   173  		require.NoError(t, err)
   174  
   175  		rs := mysqlExec(t, "select @@version as ver", "")
   176  		row := rs.Named().Row()
   177  		require.NotNil(t, row)
   178  		mysqlVersion := row["ver"].ToString()
   179  
   180  		if re.MatchString(mysqlVersion) {
   181  			t.Skipf("Skipping test due to ignore_versions=%s", ignoreVersions)
   182  			return
   183  		}
   184  	}
   185  	if _, exists := readTestFile(t, testName, "expect_query_failure"); exists {
   186  		// VTGate failure is expected!
   187  		// irrelevant to this suite.
   188  		// We only want to test actual migrations
   189  		t.Skip("expect_query_failure found. Irrelevant to this suite")
   190  		return
   191  	}
   192  	if _, exists := readTestFile(t, testName, "expect_failure"); exists {
   193  		// irrelevant to this suite.
   194  		// We only want to test actual migrations
   195  		t.Skip("expect_failure found. Irrelevant to this suite")
   196  		return
   197  	}
   198  	if _, exists := readTestFile(t, testName, "skip_schemadiff"); exists {
   199  		// irrelevant to this suite.
   200  		t.Skip("skip_schemadiff found. Irrelevant to this suite")
   201  		return
   202  	}
   203  
   204  	sqlModeQuery := fmt.Sprintf("set @@global.sql_mode='%s'", defaultSQLMode)
   205  	_ = mysqlExec(t, sqlModeQuery, "")
   206  	_ = mysqlExec(t, "set @@global.event_scheduler=0", "")
   207  
   208  	_ = mysqlExec(t, fmt.Sprintf("drop table if exists %s", tableName), "")
   209  	_ = mysqlExec(t, fmt.Sprintf("drop event if exists %s", eventName), "")
   210  
   211  	var fromCreateTable string
   212  	var toCreateTable string
   213  	{
   214  		// create
   215  		f := "create.sql"
   216  		_, exists := readTestFile(t, testName, f)
   217  		require.True(t, exists)
   218  		onlineddl.MysqlClientExecFile(t, mysqlParams(), testDataPath, testName, f)
   219  		// ensure test table has been created:
   220  		// read the create statement
   221  		fromCreateTable = getCreateTableStatement(t, tableName)
   222  		require.NotEmpty(t, fromCreateTable)
   223  	}
   224  	defer func() {
   225  		// destroy
   226  		f := "destroy.sql"
   227  		if _, exists := readTestFile(t, testName, f); exists {
   228  			onlineddl.MysqlClientExecFile(t, mysqlParams(), testDataPath, testName, f)
   229  		}
   230  	}()
   231  
   232  	// Run test
   233  	alterClause := "engine=innodb"
   234  	if content, exists := readTestFile(t, testName, "alter"); exists {
   235  		alterClause = content
   236  	}
   237  	alterStatement := fmt.Sprintf("alter table %s %s", tableName, alterClause)
   238  	// Run the DDL!
   239  	onlineddl.VtgateExecQuery(t, &vtParams, alterStatement, "")
   240  	// migration is complete
   241  	// read the table structure of modified table:
   242  	toCreateTable = getCreateTableStatement(t, tableName)
   243  	require.NotEmpty(t, toCreateTable)
   244  
   245  	if content, exists := readTestFile(t, testName, "expect_table_structure"); exists {
   246  		switch {
   247  		case strings.HasPrefix(testName, "autoinc"):
   248  			// In schemadiff_vrepl test, we run a direct ALTER TABLE. This is as opposed to
   249  			// vrepl_suite runnign a vreplication Online DDL. This matters, because AUTO_INCREMENT
   250  			// values in the resulting table are different between the two approaches!
   251  			// So for schemadiff_vrepl tests we ignore any AUTO_INCREMENT requirements,
   252  			// they're just not interesting for this test.
   253  		default:
   254  			assert.Regexpf(t, content, toCreateTable, "expected SHOW CREATE TABLE to match text in 'expect_table_structure' file")
   255  		}
   256  	}
   257  
   258  	fromTestTableSchemas = append(fromTestTableSchemas, &testTableSchema{
   259  		testName:    testName,
   260  		tableSchema: fromCreateTable,
   261  	})
   262  	toTestTableSchemas = append(toTestTableSchemas, &testTableSchema{
   263  		testName:    testName,
   264  		tableSchema: toCreateTable,
   265  	})
   266  
   267  	hints := &schemadiff.DiffHints{}
   268  	if strings.Contains(alterClause, "AUTO_INCREMENT") {
   269  		hints.AutoIncrementStrategy = schemadiff.AutoIncrementApplyAlways
   270  	}
   271  	t.Run("validate diff", func(t *testing.T) {
   272  		_, allowSchemadiffNormalization := readTestFile(t, testName, "allow_schemadiff_normalization")
   273  		validateDiff(t, fromCreateTable, toCreateTable, allowSchemadiffNormalization, hints)
   274  	})
   275  }
   276  
   277  // func TestRandomSchemaChanges(t *testing.T) {
   278  // 	defer cluster.PanicHandler(t)
   279  
   280  // 	hints := &schemadiff.DiffHints{AutoIncrementStrategy: schemadiff.AutoIncrementIgnore}
   281  // 	// count := 20
   282  // 	// for i := 0; i < count; i++ {
   283  // 	// 	fromTestTableSchema := fromTestTableSchemas[rand.Intn(len(fromTestTableSchemas))]
   284  // 	// 	toTestTableSchema := toTestTableSchemas[rand.Intn(len(toTestTableSchemas))]
   285  // 	// 	testName := fmt.Sprintf("%s/%s", fromTestTableSchema.testName, toTestTableSchema.testName)
   286  // 	// 	t.Run(testName, func(t *testing.T) {
   287  // 	// 		validateDiff(t, fromTestTableSchema.tableSchema, toTestTableSchema.tableSchema, hints)
   288  // 	// 	})
   289  // 	// }
   290  // 	for i := range rand.Perm(len(fromTestTableSchemas)) {
   291  // 		fromTestTableSchema := fromTestTableSchemas[i]
   292  // 		for j := range rand.Perm(len(toTestTableSchemas)) {
   293  // 			toTestTableSchema := toTestTableSchemas[j]
   294  // 			testName := fmt.Sprintf("%s:%s", fromTestTableSchema.testName, toTestTableSchema.testName)
   295  // 			t.Run(testName, func(t *testing.T) {
   296  // 				validateDiff(t, fromTestTableSchema.tableSchema, toTestTableSchema.tableSchema, hints)
   297  // 			})
   298  // 		}
   299  // 	}
   300  // }
   301  
   302  func TestIgnoreAutoIncrementRegexp(t *testing.T) {
   303  	// validate the validation function we use in our tests...
   304  	tt := []struct {
   305  		statement string
   306  		expect    string
   307  	}{
   308  		{
   309  			statement: "CREATE TABLE t(id int auto_increment primary key)",
   310  			expect:    "CREATE TABLE t(id int auto_increment primary key)",
   311  		},
   312  		{
   313  			statement: "CREATE TABLE t(id int auto_increment primary key) auto_increment=3",
   314  			expect:    "CREATE TABLE t(id int auto_increment primary key)",
   315  		},
   316  		{
   317  			statement: "CREATE TABLE t(id int auto_increment primary key) AUTO_INCREMENT=3 default charset=utf8",
   318  			expect:    "CREATE TABLE t(id int auto_increment primary key) default charset=utf8",
   319  		},
   320  		{
   321  			statement: "CREATE TABLE t(id int auto_increment primary key) default charset=utf8 auto_increment=3",
   322  			expect:    "CREATE TABLE t(id int auto_increment primary key) default charset=utf8",
   323  		},
   324  		{
   325  			statement: "CREATE TABLE t(id int auto_increment primary key) default charset=utf8 auto_increment=3 engine=innodb",
   326  			expect:    "CREATE TABLE t(id int auto_increment primary key) default charset=utf8 engine=innodb",
   327  		},
   328  	}
   329  	for _, tc := range tt {
   330  		t.Run(tc.statement, func(t *testing.T) {
   331  			ignored := ignoreAutoIncrement(t, tc.statement)
   332  			assert.Equal(t, tc.expect, ignored)
   333  		})
   334  	}
   335  }
   336  
   337  func ignoreAutoIncrement(t *testing.T, createTable string) string {
   338  	result := autoIncrementRegexp.ReplaceAllString(createTable, "")
   339  	// sanity:
   340  	require.Contains(t, result, "CREATE TABLE")
   341  	require.Contains(t, result, ")")
   342  	return result
   343  }
   344  
   345  func validateDiff(t *testing.T, fromCreateTable string, toCreateTable string, allowSchemadiffNormalization bool, hints *schemadiff.DiffHints) {
   346  	// turn the "from" and "to" create statement strings (which we just read via SHOW CREATE TABLE into sqlparser.CreateTable statement)
   347  	fromStmt, err := sqlparser.ParseStrictDDL(fromCreateTable)
   348  	require.NoError(t, err)
   349  	fromCreateTableStatement, ok := fromStmt.(*sqlparser.CreateTable)
   350  	require.True(t, ok)
   351  
   352  	toStmt, err := sqlparser.ParseStrictDDL(toCreateTable)
   353  	require.NoError(t, err)
   354  	toCreateTableStatement, ok := toStmt.(*sqlparser.CreateTable)
   355  	require.True(t, ok)
   356  
   357  	// The actual diff logic here!
   358  	diff, err := schemadiff.DiffTables(fromCreateTableStatement, toCreateTableStatement, hints)
   359  	assert.NoError(t, err)
   360  
   361  	// The diff can be empty or there can be an actual ALTER TABLE statement
   362  	diffedAlterQuery := ""
   363  	if diff != nil && !diff.IsEmpty() {
   364  		diffedAlterQuery = diff.CanonicalStatementString()
   365  	}
   366  
   367  	// Validate the diff! The way we do it is:
   368  	// Recreate the original table
   369  	// Alter the table directly using our evaluated diff (if empty we do nothing)
   370  	// Review the resulted table structure (via SHOW CREATE TABLE)
   371  	// Expect it to be identical to the structure generated by the suite earlier on (toCreateTable)
   372  	_ = mysqlExec(t, fmt.Sprintf("drop table if exists %s", tableName), "")
   373  	onlineddl.VtgateExecQuery(t, &vtParams, fromCreateTable, "")
   374  	if diffedAlterQuery != "" {
   375  		onlineddl.VtgateExecQuery(t, &vtParams, diffedAlterQuery, "")
   376  	}
   377  	resultCreateTable := getCreateTableStatement(t, tableName)
   378  	if hints.AutoIncrementStrategy == schemadiff.AutoIncrementIgnore {
   379  		toCreateTable = ignoreAutoIncrement(t, toCreateTable)
   380  		resultCreateTable = ignoreAutoIncrement(t, resultCreateTable)
   381  	}
   382  
   383  	// Next, the big test: does the result table, applied by schemadiff's evaluated ALTER, look exactly like
   384  	// the table generated by the test's own ALTER statement?
   385  
   386  	// But wait, there's caveats.
   387  
   388  	if toCreateTable != resultCreateTable {
   389  		// schemadiff's ALTER statement can normalize away CHARACTER SET and COLLATION definitions:
   390  		// when altering a column's CHARTSET&COLLATION into the table's values, schemadiff just strips the
   391  		// CHARSET and COLLATION clauses out of the `MODIFY COLUMN ...` statement. This is valid.
   392  		// However, MySQL outputs two different SHOW CREATE TABLE statements, even though the table
   393  		// structure is identical. And so we accept that there can be a normalization issue.
   394  		if allowSchemadiffNormalization {
   395  			{
   396  				stmt, err := sqlparser.ParseStrictDDL(toCreateTable)
   397  				require.NoError(t, err)
   398  				createTableStatement, ok := stmt.(*sqlparser.CreateTable)
   399  				require.True(t, ok)
   400  				c, err := schemadiff.NewCreateTableEntity(createTableStatement)
   401  				require.NoError(t, err)
   402  				toCreateTable = c.Create().CanonicalStatementString()
   403  			}
   404  			{
   405  				stmt, err := sqlparser.ParseStrictDDL(resultCreateTable)
   406  				require.NoError(t, err)
   407  				createTableStatement, ok := stmt.(*sqlparser.CreateTable)
   408  				require.True(t, ok)
   409  				c, err := schemadiff.NewCreateTableEntity(createTableStatement)
   410  				require.NoError(t, err)
   411  				resultCreateTable = c.Create().CanonicalStatementString()
   412  			}
   413  		}
   414  	}
   415  
   416  	// The actual validation test here:
   417  	assert.Equal(t, toCreateTable, resultCreateTable, "mismatched table structure. ALTER query was: %s", diffedAlterQuery)
   418  
   419  	// Also, let's see that our diff agrees there's no change:
   420  	resultStmt, err := sqlparser.ParseStrictDDL(resultCreateTable)
   421  	require.NoError(t, err)
   422  	resultCreateTableStatement, ok := resultStmt.(*sqlparser.CreateTable)
   423  	require.True(t, ok)
   424  
   425  	resultDiff, err := schemadiff.DiffTables(toCreateTableStatement, resultCreateTableStatement, hints)
   426  	assert.NoError(t, err)
   427  	assert.Nil(t, resultDiff)
   428  }
   429  
   430  func getTablet() *cluster.Vttablet {
   431  	return clusterInstance.Keyspaces[0].Shards[0].Vttablets[0]
   432  }
   433  
   434  func mysqlParams() *mysql.ConnParams {
   435  	if evaluatedMysqlParams != nil {
   436  		return evaluatedMysqlParams
   437  	}
   438  	evaluatedMysqlParams = &mysql.ConnParams{
   439  		Uname:      "vt_dba",
   440  		UnixSocket: path.Join(os.Getenv("VTDATAROOT"), fmt.Sprintf("/vt_%010d", getTablet().TabletUID), "/mysql.sock"),
   441  		DbName:     fmt.Sprintf("vt_%s", keyspaceName),
   442  	}
   443  	return evaluatedMysqlParams
   444  }
   445  
   446  // VtgateExecDDL executes a DDL query with given strategy
   447  func mysqlExec(t *testing.T, sql string, expectError string) *sqltypes.Result {
   448  	t.Helper()
   449  
   450  	ctx := context.Background()
   451  	conn, err := mysql.Connect(ctx, mysqlParams())
   452  	require.Nil(t, err)
   453  	defer conn.Close()
   454  
   455  	qr, err := conn.ExecuteFetch(sql, 100000, true)
   456  	if expectError == "" {
   457  		require.NoError(t, err)
   458  	} else {
   459  		require.Error(t, err, "error should not be nil")
   460  		require.Contains(t, err.Error(), expectError, "Unexpected error")
   461  	}
   462  	return qr
   463  }
   464  
   465  // getCreateTableStatement returns the CREATE TABLE statement for a given table
   466  func getCreateTableStatement(t *testing.T, tableName string) (statement string) {
   467  	queryResult, err := getTablet().VttabletProcess.QueryTablet(fmt.Sprintf("show create table %s", tableName), keyspaceName, true)
   468  	require.Nil(t, err)
   469  
   470  	assert.Equal(t, len(queryResult.Rows), 1)
   471  	assert.Equal(t, len(queryResult.Rows[0]), 2) // table name, create statement
   472  	statement = queryResult.Rows[0][1].ToString()
   473  	return statement
   474  }