github.com/pingcap/tiflow@v0.0.0-20240520035814-5bf52d54e205/dm/simulator/sqlgen/impl_test.go (about)

     1  // Copyright 2022 PingCAP, Inc.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // See the License for the specific language governing permissions and
    12  // limitations under the License.
    13  
    14  package sqlgen
    15  
    16  import (
    17  	"fmt"
    18  	"strings"
    19  	"testing"
    20  
    21  	"github.com/pingcap/tidb/pkg/parser"
    22  	"github.com/pingcap/tidb/pkg/parser/ast"
    23  	"github.com/pingcap/tiflow/dm/pkg/log"
    24  	"github.com/pingcap/tiflow/dm/simulator/config"
    25  	"github.com/pingcap/tiflow/dm/simulator/mcp"
    26  	"github.com/stretchr/testify/suite"
    27  )
    28  
    29  type testSQLGenImplSuite struct {
    30  	suite.Suite
    31  	tableConfig   *config.TableConfig
    32  	sqlParser     *parser.Parser
    33  	allColNameMap map[string]struct{}
    34  }
    35  
    36  func (s *testSQLGenImplSuite) SetupSuite() {
    37  	s.allColNameMap = make(map[string]struct{})
    38  	s.Require().Nil(log.InitLogger(&log.Config{}))
    39  	s.tableConfig = &config.TableConfig{
    40  		DatabaseName: "games",
    41  		TableName:    "members",
    42  		Columns: []*config.ColumnDefinition{
    43  			{
    44  				ColumnName: "id",
    45  				DataType:   "int",
    46  				DataLen:    11,
    47  			},
    48  			{
    49  				ColumnName: "name",
    50  				DataType:   "varchar",
    51  				DataLen:    255,
    52  			},
    53  			{
    54  				ColumnName: "age",
    55  				DataType:   "int",
    56  				DataLen:    11,
    57  			},
    58  			{
    59  				ColumnName: "team_id",
    60  				DataType:   "int",
    61  				DataLen:    11,
    62  			},
    63  		},
    64  		UniqueKeyColumnNames: []string{"id"},
    65  	}
    66  	for _, colInfo := range s.tableConfig.Columns {
    67  		s.allColNameMap[fmt.Sprintf("`%s`", colInfo.ColumnName)] = struct{}{}
    68  	}
    69  	s.sqlParser = parser.New()
    70  }
    71  
    72  func generateUKColNameMap(ukColNames []string) map[string]struct{} {
    73  	ukColNameMap := make(map[string]struct{})
    74  	for _, colName := range ukColNames {
    75  		ukColNameMap[fmt.Sprintf("`%s`", colName)] = struct{}{}
    76  	}
    77  	return ukColNameMap
    78  }
    79  
    80  func (s *testSQLGenImplSuite) checkLoadUKsSQL(sql string, ukColNames []string) {
    81  	var err error
    82  	theAST, err := s.sqlParser.ParseOneStmt(sql, "", "")
    83  	if !s.Nilf(err, "parse statement error: %s", sql) {
    84  		return
    85  	}
    86  	selectAST, ok := theAST.(*ast.SelectStmt)
    87  	if !ok {
    88  		s.Fail("cannot convert the AST to select AST")
    89  		return
    90  	}
    91  	s.checkTableName(selectAST.From)
    92  	if !s.Equal(len(s.tableConfig.UniqueKeyColumnNames), len(selectAST.Fields.Fields)) {
    93  		return
    94  	}
    95  	ukColNameMap := generateUKColNameMap(ukColNames)
    96  	for _, field := range selectAST.Fields.Fields {
    97  		fieldNameStr, err := outputString(field)
    98  		if !s.Nil(err) {
    99  			continue
   100  		}
   101  		if _, ok := ukColNameMap[fieldNameStr]; !ok {
   102  			s.Fail(
   103  				"the parsed column name cannot be found in the UK names",
   104  				"parsed column name: %s", fieldNameStr,
   105  			)
   106  		}
   107  	}
   108  }
   109  
   110  func (s *testSQLGenImplSuite) checkTruncateSQL(sql string) {
   111  	theAST, err := s.sqlParser.ParseOneStmt(sql, "", "")
   112  	if !s.Nilf(err, "parse statement error: %s", sql) {
   113  		return
   114  	}
   115  	truncateAST, ok := theAST.(*ast.TruncateTableStmt)
   116  	if !ok {
   117  		s.Fail("cannot convert the AST to truncate AST")
   118  		return
   119  	}
   120  	s.checkTableName(truncateAST.Table)
   121  }
   122  
   123  func (s *testSQLGenImplSuite) checkInsertSQL(sql string) {
   124  	theAST, err := s.sqlParser.ParseOneStmt(sql, "", "")
   125  	if !s.Nilf(err, "parse statement error: %s", sql) {
   126  		return
   127  	}
   128  	insertAST, ok := theAST.(*ast.InsertStmt)
   129  	if !ok {
   130  		s.Fail("cannot convert the AST to insert AST")
   131  		return
   132  	}
   133  	s.checkTableName(insertAST.Table)
   134  	if !s.Equal(len(s.tableConfig.Columns), len(insertAST.Columns)) {
   135  		return
   136  	}
   137  	for _, col := range insertAST.Columns {
   138  		colNameStr, err := outputString(col)
   139  		if !s.Nil(err) {
   140  			continue
   141  		}
   142  		if _, ok := s.allColNameMap[colNameStr]; !ok {
   143  			s.Fail(
   144  				"the parsed column name cannot be found in the column names",
   145  				"parsed column name: %s", colNameStr,
   146  			)
   147  		}
   148  	}
   149  }
   150  
   151  func (s *testSQLGenImplSuite) checkUpdateSQL(sql string, ukColNames []string) {
   152  	theAST, err := s.sqlParser.ParseOneStmt(sql, "", "")
   153  	if !s.Nilf(err, "parse statement error: %s", sql) {
   154  		return
   155  	}
   156  	updateAST, ok := theAST.(*ast.UpdateStmt)
   157  	if !ok {
   158  		s.Fail("cannot convert the AST to update AST")
   159  		return
   160  	}
   161  	s.checkTableName(updateAST.TableRefs)
   162  	s.Greater(len(updateAST.List), 0)
   163  	s.checkWhereClause(updateAST.Where, ukColNames)
   164  }
   165  
   166  func (s *testSQLGenImplSuite) checkDeleteSQL(sql string, ukColNames []string) {
   167  	theAST, err := s.sqlParser.ParseOneStmt(sql, "", "")
   168  	if !s.Nilf(err, "parse statement error: %s", sql) {
   169  		return
   170  	}
   171  	deleteAST, ok := theAST.(*ast.DeleteStmt)
   172  	if !ok {
   173  		s.Fail("cannot convert the AST to delete AST")
   174  		return
   175  	}
   176  	s.checkTableName(deleteAST.TableRefs)
   177  	s.checkWhereClause(deleteAST.Where, ukColNames)
   178  }
   179  
   180  func (s *testSQLGenImplSuite) checkTableName(astNode ast.Node) {
   181  	tableNameStr, err := outputString(astNode)
   182  	if !s.Nil(err) {
   183  		return
   184  	}
   185  	s.Equal(
   186  		fmt.Sprintf("`%s`.`%s`", s.tableConfig.DatabaseName, s.tableConfig.TableName),
   187  		tableNameStr,
   188  	)
   189  }
   190  
   191  func (s *testSQLGenImplSuite) checkWhereClause(astNode ast.Node, ukColNames []string) {
   192  	whereClauseStr, err := outputString(astNode)
   193  	if !s.Nil(err) {
   194  		return
   195  	}
   196  	ukColNameMap := generateUKColNameMap(ukColNames)
   197  	for colName := range ukColNameMap {
   198  		if !s.Truef(
   199  			strings.Contains(whereClauseStr, fmt.Sprintf("%s=", colName)) ||
   200  				strings.Contains(whereClauseStr, fmt.Sprintf("%s IS NULL", colName)),
   201  			"cannot find the column name in the where clause: where clause string: %s; column name: %s",
   202  			whereClauseStr, colName,
   203  		) {
   204  			continue
   205  		}
   206  	}
   207  }
   208  
   209  func (s *testSQLGenImplSuite) TestDMLBasic() {
   210  	var (
   211  		err error
   212  		sql string
   213  		uk  *mcp.UniqueKey
   214  	)
   215  	g := NewSQLGeneratorImpl(s.tableConfig)
   216  
   217  	sql, _, err = g.GenLoadUniqueKeySQL()
   218  	s.Nil(err, "generate load UK SQL error")
   219  	s.T().Logf("Generated SELECT SQL: %s\n", sql)
   220  	s.checkLoadUKsSQL(sql, s.tableConfig.UniqueKeyColumnNames)
   221  
   222  	sql, err = g.GenTruncateTable()
   223  	s.Nil(err, "generate truncate table SQL error")
   224  	s.T().Logf("Generated Truncate Table SQL: %s\n", sql)
   225  	s.checkTruncateSQL(sql)
   226  
   227  	theMCP := mcp.NewModificationCandidatePool(8192)
   228  	for i := 0; i < 4096; i++ {
   229  		s.Nil(
   230  			theMCP.AddUK(mcp.NewUniqueKey(i, map[string]interface{}{
   231  				"id": i,
   232  			})),
   233  		)
   234  	}
   235  	for i := 0; i < 10; i++ {
   236  		uk = theMCP.NextUK()
   237  		sql, err = g.GenUpdateRow(uk)
   238  		s.Nil(err, "generate update sql error")
   239  		s.T().Logf("Generated SQL: %s\n", sql)
   240  		s.checkUpdateSQL(sql, s.tableConfig.UniqueKeyColumnNames)
   241  
   242  		sql, uk, err = g.GenInsertRow()
   243  		s.Nil(err, "generate insert sql error")
   244  		s.T().Logf("Generated SQL: %s\n; Unique key: %v\n", sql, uk)
   245  		s.checkInsertSQL(sql)
   246  
   247  		uk = theMCP.NextUK()
   248  		sql, err = g.GenDeleteRow(uk)
   249  		s.Nil(err, "generate delete sql error")
   250  		s.T().Logf("Generated SQL: %s\n; Unique key: %v\n", sql, uk)
   251  		s.checkDeleteSQL(sql, s.tableConfig.UniqueKeyColumnNames)
   252  	}
   253  }
   254  
   255  func (s *testSQLGenImplSuite) TestWhereNULL() {
   256  	var (
   257  		err error
   258  		sql string
   259  	)
   260  	theTableConfig := &config.TableConfig{
   261  		DatabaseName:         s.tableConfig.DatabaseName,
   262  		TableName:            s.tableConfig.TableName,
   263  		Columns:              s.tableConfig.Columns,
   264  		UniqueKeyColumnNames: []string{"name", "team_id"},
   265  	}
   266  	g := NewSQLGeneratorImpl(theTableConfig)
   267  	theUK := mcp.NewUniqueKey(-1, map[string]interface{}{
   268  		"name":    "ABCDEFG",
   269  		"team_id": nil,
   270  	})
   271  	sql, err = g.GenUpdateRow(theUK)
   272  	s.Require().Nil(err)
   273  	s.T().Logf("Generated UPDATE SQL: %s\n", sql)
   274  	s.checkUpdateSQL(sql, theTableConfig.UniqueKeyColumnNames)
   275  
   276  	sql, err = g.GenDeleteRow(theUK)
   277  	s.Require().Nil(err)
   278  	s.T().Logf("Generated DELETE SQL: %s\n", sql)
   279  	s.checkDeleteSQL(sql, theTableConfig.UniqueKeyColumnNames)
   280  }
   281  
   282  func (s *testSQLGenImplSuite) TestDMLAbnormalUK() {
   283  	var (
   284  		sql string
   285  		err error
   286  		uk  *mcp.UniqueKey
   287  	)
   288  	g := NewSQLGeneratorImpl(s.tableConfig)
   289  	uk = mcp.NewUniqueKey(-1, map[string]interface{}{
   290  		"abcdefg": 123,
   291  	})
   292  	_, err = g.GenUpdateRow(uk)
   293  	s.NotNil(err)
   294  	_, err = g.GenDeleteRow(uk)
   295  	s.NotNil(err)
   296  
   297  	uk = mcp.NewUniqueKey(-1, map[string]interface{}{
   298  		"id":      123,
   299  		"abcdefg": 321,
   300  	})
   301  	sql, err = g.GenUpdateRow(uk)
   302  	s.Nil(err)
   303  	s.T().Logf("Generated SQL: %s\n", sql)
   304  	s.checkUpdateSQL(sql, s.tableConfig.UniqueKeyColumnNames)
   305  
   306  	sql, err = g.GenDeleteRow(uk)
   307  	s.Nil(err)
   308  	s.T().Logf("Generated SQL: %s\n", sql)
   309  	s.checkDeleteSQL(sql, s.tableConfig.UniqueKeyColumnNames)
   310  
   311  	uk = mcp.NewUniqueKey(-1, map[string]interface{}{})
   312  	_, err = g.GenUpdateRow(uk)
   313  	s.NotNil(err)
   314  }
   315  
   316  func (s *testSQLGenImplSuite) TestDMLWithNoUK() {
   317  	var (
   318  		err   error
   319  		sql   string
   320  		theUK *mcp.UniqueKey
   321  	)
   322  	theTableConfig := &config.TableConfig{
   323  		DatabaseName:         s.tableConfig.DatabaseName,
   324  		TableName:            s.tableConfig.TableName,
   325  		Columns:              s.tableConfig.Columns,
   326  		UniqueKeyColumnNames: []string{},
   327  	}
   328  	g := NewSQLGeneratorImpl(theTableConfig)
   329  
   330  	sql, theUK, err = g.GenInsertRow()
   331  	s.Nil(err, "generate insert sql error")
   332  	s.T().Logf("Generated SQL: %s\n; Unique key: %v\n", sql, theUK)
   333  	s.checkInsertSQL(sql)
   334  
   335  	theUK = mcp.NewUniqueKey(-1, map[string]interface{}{})
   336  	_, err = g.GenUpdateRow(theUK)
   337  	s.NotNil(err)
   338  	_, err = g.GenDeleteRow(theUK)
   339  	s.NotNil(err)
   340  
   341  	theUK = mcp.NewUniqueKey(-1, map[string]interface{}{
   342  		"id": 123, // the column is filtered out by the UK configs
   343  	})
   344  	_, err = g.GenUpdateRow(theUK)
   345  	s.NotNil(err)
   346  	_, err = g.GenDeleteRow(theUK)
   347  	s.NotNil(err)
   348  }
   349  
   350  func TestSQLGenImplSuite(t *testing.T) {
   351  	suite.Run(t, &testSQLGenImplSuite{})
   352  }