github.com/pingcap/tiflow@v0.0.0-20240520035814-5bf52d54e205/dm/simulator/sqlgen/impl_test.go (about) 1 // Copyright 2022 PingCAP, Inc. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // See the License for the specific language governing permissions and 12 // limitations under the License. 13 14 package sqlgen 15 16 import ( 17 "fmt" 18 "strings" 19 "testing" 20 21 "github.com/pingcap/tidb/pkg/parser" 22 "github.com/pingcap/tidb/pkg/parser/ast" 23 "github.com/pingcap/tiflow/dm/pkg/log" 24 "github.com/pingcap/tiflow/dm/simulator/config" 25 "github.com/pingcap/tiflow/dm/simulator/mcp" 26 "github.com/stretchr/testify/suite" 27 ) 28 29 type testSQLGenImplSuite struct { 30 suite.Suite 31 tableConfig *config.TableConfig 32 sqlParser *parser.Parser 33 allColNameMap map[string]struct{} 34 } 35 36 func (s *testSQLGenImplSuite) SetupSuite() { 37 s.allColNameMap = make(map[string]struct{}) 38 s.Require().Nil(log.InitLogger(&log.Config{})) 39 s.tableConfig = &config.TableConfig{ 40 DatabaseName: "games", 41 TableName: "members", 42 Columns: []*config.ColumnDefinition{ 43 { 44 ColumnName: "id", 45 DataType: "int", 46 DataLen: 11, 47 }, 48 { 49 ColumnName: "name", 50 DataType: "varchar", 51 DataLen: 255, 52 }, 53 { 54 ColumnName: "age", 55 DataType: "int", 56 DataLen: 11, 57 }, 58 { 59 ColumnName: "team_id", 60 DataType: "int", 61 DataLen: 11, 62 }, 63 }, 64 UniqueKeyColumnNames: []string{"id"}, 65 } 66 for _, colInfo := range s.tableConfig.Columns { 67 s.allColNameMap[fmt.Sprintf("`%s`", colInfo.ColumnName)] = struct{}{} 68 } 69 s.sqlParser = parser.New() 70 } 71 72 func generateUKColNameMap(ukColNames []string) map[string]struct{} { 73 ukColNameMap := make(map[string]struct{}) 74 for _, colName := range ukColNames { 75 ukColNameMap[fmt.Sprintf("`%s`", colName)] = struct{}{} 76 } 77 return ukColNameMap 78 } 79 80 func (s *testSQLGenImplSuite) checkLoadUKsSQL(sql string, ukColNames []string) { 81 var err error 82 theAST, err := s.sqlParser.ParseOneStmt(sql, "", "") 83 if !s.Nilf(err, "parse statement error: %s", sql) { 84 return 85 } 86 selectAST, ok := theAST.(*ast.SelectStmt) 87 if !ok { 88 s.Fail("cannot convert the AST to select AST") 89 return 90 } 91 s.checkTableName(selectAST.From) 92 if !s.Equal(len(s.tableConfig.UniqueKeyColumnNames), len(selectAST.Fields.Fields)) { 93 return 94 } 95 ukColNameMap := generateUKColNameMap(ukColNames) 96 for _, field := range selectAST.Fields.Fields { 97 fieldNameStr, err := outputString(field) 98 if !s.Nil(err) { 99 continue 100 } 101 if _, ok := ukColNameMap[fieldNameStr]; !ok { 102 s.Fail( 103 "the parsed column name cannot be found in the UK names", 104 "parsed column name: %s", fieldNameStr, 105 ) 106 } 107 } 108 } 109 110 func (s *testSQLGenImplSuite) checkTruncateSQL(sql string) { 111 theAST, err := s.sqlParser.ParseOneStmt(sql, "", "") 112 if !s.Nilf(err, "parse statement error: %s", sql) { 113 return 114 } 115 truncateAST, ok := theAST.(*ast.TruncateTableStmt) 116 if !ok { 117 s.Fail("cannot convert the AST to truncate AST") 118 return 119 } 120 s.checkTableName(truncateAST.Table) 121 } 122 123 func (s *testSQLGenImplSuite) checkInsertSQL(sql string) { 124 theAST, err := s.sqlParser.ParseOneStmt(sql, "", "") 125 if !s.Nilf(err, "parse statement error: %s", sql) { 126 return 127 } 128 insertAST, ok := theAST.(*ast.InsertStmt) 129 if !ok { 130 s.Fail("cannot convert the AST to insert AST") 131 return 132 } 133 s.checkTableName(insertAST.Table) 134 if !s.Equal(len(s.tableConfig.Columns), len(insertAST.Columns)) { 135 return 136 } 137 for _, col := range insertAST.Columns { 138 colNameStr, err := outputString(col) 139 if !s.Nil(err) { 140 continue 141 } 142 if _, ok := s.allColNameMap[colNameStr]; !ok { 143 s.Fail( 144 "the parsed column name cannot be found in the column names", 145 "parsed column name: %s", colNameStr, 146 ) 147 } 148 } 149 } 150 151 func (s *testSQLGenImplSuite) checkUpdateSQL(sql string, ukColNames []string) { 152 theAST, err := s.sqlParser.ParseOneStmt(sql, "", "") 153 if !s.Nilf(err, "parse statement error: %s", sql) { 154 return 155 } 156 updateAST, ok := theAST.(*ast.UpdateStmt) 157 if !ok { 158 s.Fail("cannot convert the AST to update AST") 159 return 160 } 161 s.checkTableName(updateAST.TableRefs) 162 s.Greater(len(updateAST.List), 0) 163 s.checkWhereClause(updateAST.Where, ukColNames) 164 } 165 166 func (s *testSQLGenImplSuite) checkDeleteSQL(sql string, ukColNames []string) { 167 theAST, err := s.sqlParser.ParseOneStmt(sql, "", "") 168 if !s.Nilf(err, "parse statement error: %s", sql) { 169 return 170 } 171 deleteAST, ok := theAST.(*ast.DeleteStmt) 172 if !ok { 173 s.Fail("cannot convert the AST to delete AST") 174 return 175 } 176 s.checkTableName(deleteAST.TableRefs) 177 s.checkWhereClause(deleteAST.Where, ukColNames) 178 } 179 180 func (s *testSQLGenImplSuite) checkTableName(astNode ast.Node) { 181 tableNameStr, err := outputString(astNode) 182 if !s.Nil(err) { 183 return 184 } 185 s.Equal( 186 fmt.Sprintf("`%s`.`%s`", s.tableConfig.DatabaseName, s.tableConfig.TableName), 187 tableNameStr, 188 ) 189 } 190 191 func (s *testSQLGenImplSuite) checkWhereClause(astNode ast.Node, ukColNames []string) { 192 whereClauseStr, err := outputString(astNode) 193 if !s.Nil(err) { 194 return 195 } 196 ukColNameMap := generateUKColNameMap(ukColNames) 197 for colName := range ukColNameMap { 198 if !s.Truef( 199 strings.Contains(whereClauseStr, fmt.Sprintf("%s=", colName)) || 200 strings.Contains(whereClauseStr, fmt.Sprintf("%s IS NULL", colName)), 201 "cannot find the column name in the where clause: where clause string: %s; column name: %s", 202 whereClauseStr, colName, 203 ) { 204 continue 205 } 206 } 207 } 208 209 func (s *testSQLGenImplSuite) TestDMLBasic() { 210 var ( 211 err error 212 sql string 213 uk *mcp.UniqueKey 214 ) 215 g := NewSQLGeneratorImpl(s.tableConfig) 216 217 sql, _, err = g.GenLoadUniqueKeySQL() 218 s.Nil(err, "generate load UK SQL error") 219 s.T().Logf("Generated SELECT SQL: %s\n", sql) 220 s.checkLoadUKsSQL(sql, s.tableConfig.UniqueKeyColumnNames) 221 222 sql, err = g.GenTruncateTable() 223 s.Nil(err, "generate truncate table SQL error") 224 s.T().Logf("Generated Truncate Table SQL: %s\n", sql) 225 s.checkTruncateSQL(sql) 226 227 theMCP := mcp.NewModificationCandidatePool(8192) 228 for i := 0; i < 4096; i++ { 229 s.Nil( 230 theMCP.AddUK(mcp.NewUniqueKey(i, map[string]interface{}{ 231 "id": i, 232 })), 233 ) 234 } 235 for i := 0; i < 10; i++ { 236 uk = theMCP.NextUK() 237 sql, err = g.GenUpdateRow(uk) 238 s.Nil(err, "generate update sql error") 239 s.T().Logf("Generated SQL: %s\n", sql) 240 s.checkUpdateSQL(sql, s.tableConfig.UniqueKeyColumnNames) 241 242 sql, uk, err = g.GenInsertRow() 243 s.Nil(err, "generate insert sql error") 244 s.T().Logf("Generated SQL: %s\n; Unique key: %v\n", sql, uk) 245 s.checkInsertSQL(sql) 246 247 uk = theMCP.NextUK() 248 sql, err = g.GenDeleteRow(uk) 249 s.Nil(err, "generate delete sql error") 250 s.T().Logf("Generated SQL: %s\n; Unique key: %v\n", sql, uk) 251 s.checkDeleteSQL(sql, s.tableConfig.UniqueKeyColumnNames) 252 } 253 } 254 255 func (s *testSQLGenImplSuite) TestWhereNULL() { 256 var ( 257 err error 258 sql string 259 ) 260 theTableConfig := &config.TableConfig{ 261 DatabaseName: s.tableConfig.DatabaseName, 262 TableName: s.tableConfig.TableName, 263 Columns: s.tableConfig.Columns, 264 UniqueKeyColumnNames: []string{"name", "team_id"}, 265 } 266 g := NewSQLGeneratorImpl(theTableConfig) 267 theUK := mcp.NewUniqueKey(-1, map[string]interface{}{ 268 "name": "ABCDEFG", 269 "team_id": nil, 270 }) 271 sql, err = g.GenUpdateRow(theUK) 272 s.Require().Nil(err) 273 s.T().Logf("Generated UPDATE SQL: %s\n", sql) 274 s.checkUpdateSQL(sql, theTableConfig.UniqueKeyColumnNames) 275 276 sql, err = g.GenDeleteRow(theUK) 277 s.Require().Nil(err) 278 s.T().Logf("Generated DELETE SQL: %s\n", sql) 279 s.checkDeleteSQL(sql, theTableConfig.UniqueKeyColumnNames) 280 } 281 282 func (s *testSQLGenImplSuite) TestDMLAbnormalUK() { 283 var ( 284 sql string 285 err error 286 uk *mcp.UniqueKey 287 ) 288 g := NewSQLGeneratorImpl(s.tableConfig) 289 uk = mcp.NewUniqueKey(-1, map[string]interface{}{ 290 "abcdefg": 123, 291 }) 292 _, err = g.GenUpdateRow(uk) 293 s.NotNil(err) 294 _, err = g.GenDeleteRow(uk) 295 s.NotNil(err) 296 297 uk = mcp.NewUniqueKey(-1, map[string]interface{}{ 298 "id": 123, 299 "abcdefg": 321, 300 }) 301 sql, err = g.GenUpdateRow(uk) 302 s.Nil(err) 303 s.T().Logf("Generated SQL: %s\n", sql) 304 s.checkUpdateSQL(sql, s.tableConfig.UniqueKeyColumnNames) 305 306 sql, err = g.GenDeleteRow(uk) 307 s.Nil(err) 308 s.T().Logf("Generated SQL: %s\n", sql) 309 s.checkDeleteSQL(sql, s.tableConfig.UniqueKeyColumnNames) 310 311 uk = mcp.NewUniqueKey(-1, map[string]interface{}{}) 312 _, err = g.GenUpdateRow(uk) 313 s.NotNil(err) 314 } 315 316 func (s *testSQLGenImplSuite) TestDMLWithNoUK() { 317 var ( 318 err error 319 sql string 320 theUK *mcp.UniqueKey 321 ) 322 theTableConfig := &config.TableConfig{ 323 DatabaseName: s.tableConfig.DatabaseName, 324 TableName: s.tableConfig.TableName, 325 Columns: s.tableConfig.Columns, 326 UniqueKeyColumnNames: []string{}, 327 } 328 g := NewSQLGeneratorImpl(theTableConfig) 329 330 sql, theUK, err = g.GenInsertRow() 331 s.Nil(err, "generate insert sql error") 332 s.T().Logf("Generated SQL: %s\n; Unique key: %v\n", sql, theUK) 333 s.checkInsertSQL(sql) 334 335 theUK = mcp.NewUniqueKey(-1, map[string]interface{}{}) 336 _, err = g.GenUpdateRow(theUK) 337 s.NotNil(err) 338 _, err = g.GenDeleteRow(theUK) 339 s.NotNil(err) 340 341 theUK = mcp.NewUniqueKey(-1, map[string]interface{}{ 342 "id": 123, // the column is filtered out by the UK configs 343 }) 344 _, err = g.GenUpdateRow(theUK) 345 s.NotNil(err) 346 _, err = g.GenDeleteRow(theUK) 347 s.NotNil(err) 348 } 349 350 func TestSQLGenImplSuite(t *testing.T) { 351 suite.Run(t, &testSQLGenImplSuite{}) 352 }