// Source: github.com/pingcap/tiflow@v0.0.0-20240520035814-5bf52d54e205/dm/pkg/parser/common.go

// Copyright 2019 PingCAP, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// See the License for the specific language governing permissions and
// limitations under the License.

package parser

import (
	"bytes"

	"github.com/pingcap/tidb/pkg/parser"
	"github.com/pingcap/tidb/pkg/parser/ast"
	"github.com/pingcap/tidb/pkg/parser/format"
	"github.com/pingcap/tidb/pkg/parser/model"
	_ "github.com/pingcap/tidb/pkg/types/parser_driver" // imported only for its side effect of installing the parser driver
	"github.com/pingcap/tidb/pkg/util/filter"
	"github.com/pingcap/tiflow/dm/pkg/conn"
	"github.com/pingcap/tiflow/dm/pkg/log"
	"github.com/pingcap/tiflow/dm/pkg/terror"
	"github.com/pingcap/tiflow/dm/pkg/utils"
	"go.uber.org/zap"
)

const (
	// SingleRenameTableNameNum is the number of TableNames produced by a single
	// table rename: the old name followed by the new name. It has been 2 since
	// https://github.com/pingcap/parser/pull/1021.
	SingleRenameTableNameNum = 2
)
39 func Parse(p *parser.Parser, sql, charset, collation string) (stmt []ast.StmtNode, err error) { 40 stmts, warnings, err := p.Parse(sql, charset, collation) 41 if len(warnings) > 0 { 42 log.L().Warn("parse statement", zap.String("sql", sql), zap.Errors("warning messages", warnings)) 43 } 44 45 return stmts, terror.ErrParseSQL.Delegate(err) 46 } 47 48 // ref: https://github.com/pingcap/tidb/blob/09feccb529be2830944e11f5fed474020f50370f/server/sql_info_fetcher.go#L46 49 type tableNameExtractor struct { 50 curDB string 51 flavor conn.LowerCaseTableNamesFlavor 52 names []*filter.Table 53 } 54 55 func (tne *tableNameExtractor) Enter(in ast.Node) (ast.Node, bool) { 56 if _, ok := in.(*ast.ReferenceDef); ok { 57 return in, true 58 } 59 if t, ok := in.(*ast.TableName); ok { 60 var tb *filter.Table 61 if tne.flavor == conn.LCTableNamesSensitive { 62 tb = &filter.Table{Schema: t.Schema.O, Name: t.Name.O} 63 } else { 64 tb = &filter.Table{Schema: t.Schema.L, Name: t.Name.L} 65 } 66 67 if tb.Schema == "" { 68 tb.Schema = tne.curDB 69 } 70 tne.names = append(tne.names, tb) 71 return in, true 72 } 73 return in, false 74 } 75 76 func (tne *tableNameExtractor) Leave(in ast.Node) (ast.Node, bool) { 77 return in, true 78 } 79 80 // FetchDDLTables returns tables in ddl the result contains many tables. 81 // Because we use visitor pattern, first tableName is always upper-most table in ast 82 // specifically, for `create table like` DDL, result contains [sourceTable, sourceRefTable] 83 // for rename table ddl, result contains [old1, new1, old2, new2, old3, new3, ...] because of TiDB parser 84 // for other DDL, order of tableName is the node visit order. 
85 func FetchDDLTables(schema string, stmt ast.StmtNode, flavor conn.LowerCaseTableNamesFlavor) ([]*filter.Table, error) { 86 switch stmt.(type) { 87 case ast.DDLNode: 88 default: 89 return nil, terror.ErrUnknownTypeDDL.Generate(stmt) 90 } 91 92 // special cases: schema related SQLs doesn't have tableName 93 // todo: pass .O or .L of table name depends on flavor 94 switch v := stmt.(type) { 95 case *ast.AlterDatabaseStmt: 96 return []*filter.Table{genTableName(v.Name.O, "")}, nil 97 case *ast.CreateDatabaseStmt: 98 return []*filter.Table{genTableName(v.Name.O, "")}, nil 99 case *ast.DropDatabaseStmt: 100 return []*filter.Table{genTableName(v.Name.O, "")}, nil 101 } 102 103 e := &tableNameExtractor{ 104 curDB: schema, 105 flavor: flavor, 106 names: make([]*filter.Table, 0), 107 } 108 stmt.Accept(e) 109 110 return e.names, nil 111 } 112 113 type tableRenameVisitor struct { 114 targetNames []*filter.Table 115 i int 116 hasErr bool 117 } 118 119 func (v *tableRenameVisitor) Enter(in ast.Node) (ast.Node, bool) { 120 if v.hasErr { 121 return in, true 122 } 123 if _, ok := in.(*ast.ReferenceDef); ok { 124 return in, true 125 } 126 if t, ok := in.(*ast.TableName); ok { 127 if v.i >= len(v.targetNames) { 128 v.hasErr = true 129 return in, true 130 } 131 t.Schema = model.NewCIStr(v.targetNames[v.i].Schema) 132 t.Name = model.NewCIStr(v.targetNames[v.i].Name) 133 v.i++ 134 return in, true 135 } 136 return in, false 137 } 138 139 func (v *tableRenameVisitor) Leave(in ast.Node) (ast.Node, bool) { 140 if v.hasErr { 141 return in, false 142 } 143 return in, true 144 } 145 146 // RenameDDLTable renames tables in ddl by given `targetTables` 147 // argument `targetTables` is same with return value of FetchDDLTables 148 // returned DDL is formatted like StringSingleQuotes, KeyWordUppercase and NameBackQuotes. 
149 func RenameDDLTable(stmt ast.StmtNode, targetTables []*filter.Table) (string, error) { 150 switch stmt.(type) { 151 case ast.DDLNode: 152 default: 153 return "", terror.ErrUnknownTypeDDL.Generate(stmt) 154 } 155 156 switch v := stmt.(type) { 157 case *ast.AlterDatabaseStmt: 158 v.Name = model.NewCIStr(targetTables[0].Schema) 159 case *ast.CreateDatabaseStmt: 160 v.Name = model.NewCIStr(targetTables[0].Schema) 161 case *ast.DropDatabaseStmt: 162 v.Name = model.NewCIStr(targetTables[0].Schema) 163 default: 164 visitor := &tableRenameVisitor{ 165 targetNames: targetTables, 166 } 167 stmt.Accept(visitor) 168 if visitor.hasErr { 169 return "", terror.ErrRewriteSQL.Generate(stmt, targetTables) 170 } 171 } 172 173 var b []byte 174 bf := bytes.NewBuffer(b) 175 err := stmt.Restore(&format.RestoreCtx{ 176 Flags: format.DefaultRestoreFlags | format.RestoreTiDBSpecialComment | format.RestoreStringWithoutDefaultCharset, 177 In: bf, 178 }) 179 if err != nil { 180 return "", terror.ErrRestoreASTNode.Delegate(err) 181 } 182 183 return bf.String(), nil 184 } 185 186 // SplitDDL splits multiple operations in one DDL statement into multiple DDL statements 187 // returned DDL is formatted like StringSingleQuotes, KeyWordUppercase and NameBackQuotes 188 // if fail to restore, it would not restore the value of `stmt` (it changes it's values if `stmt` is one of DropTableStmt, RenameTableStmt, AlterTableStmt). 
189 func SplitDDL(stmt ast.StmtNode, schema string) (sqls []string, err error) { 190 var ( 191 schemaName = model.NewCIStr(schema) // fill schema name 192 bf = new(bytes.Buffer) 193 ctx = &format.RestoreCtx{ 194 Flags: format.DefaultRestoreFlags | format.RestoreTiDBSpecialComment | format.RestoreStringWithoutDefaultCharset, 195 In: bf, 196 } 197 ) 198 199 switch v := stmt.(type) { 200 case *ast.CreateSequenceStmt: 201 case *ast.AlterSequenceStmt: 202 case *ast.DropSequenceStmt: 203 case *ast.AlterDatabaseStmt: 204 case *ast.CreateDatabaseStmt: 205 v.IfNotExists = true 206 case *ast.DropDatabaseStmt: 207 v.IfExists = true 208 case *ast.DropTableStmt: 209 v.IfExists = true 210 211 tables := v.Tables 212 for _, t := range tables { 213 if t.Schema.O == "" { 214 t.Schema = schemaName 215 } 216 217 v.Tables = []*ast.TableName{t} 218 bf.Reset() 219 err = stmt.Restore(ctx) 220 if err != nil { 221 v.Tables = tables 222 return nil, terror.ErrRestoreASTNode.Delegate(err) 223 } 224 225 sqls = append(sqls, bf.String()) 226 } 227 v.Tables = tables 228 229 return sqls, nil 230 case *ast.CreateTableStmt: 231 v.IfNotExists = true 232 if v.Table.Schema.O == "" { 233 v.Table.Schema = schemaName 234 } 235 236 if v.ReferTable != nil && v.ReferTable.Schema.O == "" { 237 v.ReferTable.Schema = schemaName 238 } 239 case *ast.TruncateTableStmt: 240 if v.Table.Schema.O == "" { 241 v.Table.Schema = schemaName 242 } 243 case *ast.DropIndexStmt: 244 v.IfExists = true 245 if v.Table.Schema.O == "" { 246 v.Table.Schema = schemaName 247 } 248 case *ast.CreateIndexStmt: 249 if v.Table.Schema.O == "" { 250 v.Table.Schema = schemaName 251 } 252 case *ast.RenameTableStmt: 253 t2ts := v.TableToTables 254 for _, t2t := range t2ts { 255 if t2t.OldTable.Schema.O == "" { 256 t2t.OldTable.Schema = schemaName 257 } 258 if t2t.NewTable.Schema.O == "" { 259 t2t.NewTable.Schema = schemaName 260 } 261 262 v.TableToTables = []*ast.TableToTable{t2t} 263 264 bf.Reset() 265 err = stmt.Restore(ctx) 266 if err != nil 
{ 267 v.TableToTables = t2ts 268 return nil, terror.ErrRestoreASTNode.Delegate(err) 269 } 270 271 sqls = append(sqls, bf.String()) 272 } 273 v.TableToTables = t2ts 274 275 return sqls, nil 276 case *ast.AlterTableStmt: 277 specs := v.Specs 278 table := v.Table 279 280 if v.Table.Schema.O == "" { 281 v.Table.Schema = schemaName 282 } 283 284 for _, spec := range specs { 285 if spec.Tp == ast.AlterTableRenameTable { 286 if spec.NewTable.Schema.O == "" { 287 spec.NewTable.Schema = schemaName 288 } 289 } 290 291 v.Specs = []*ast.AlterTableSpec{spec} 292 293 // handle `alter table t1 add column (c1 int, c2 int)` 294 if spec.Tp == ast.AlterTableAddColumns && len(spec.NewColumns) > 1 { 295 columns := spec.NewColumns 296 spec.Position = &ast.ColumnPosition{ 297 Tp: ast.ColumnPositionNone, // otherwise restore will become "alter table t1 add column (c1 int)" 298 } 299 for _, c := range columns { 300 spec.NewColumns = []*ast.ColumnDef{c} 301 bf.Reset() 302 err = stmt.Restore(ctx) 303 if err != nil { 304 v.Specs = specs 305 v.Table = table 306 return nil, terror.ErrRestoreASTNode.Delegate(err) 307 } 308 sqls = append(sqls, bf.String()) 309 } 310 // we have restore SQL for every columns, skip below general restoring and continue on next spec 311 continue 312 } 313 314 bf.Reset() 315 err = stmt.Restore(ctx) 316 if err != nil { 317 v.Specs = specs 318 v.Table = table 319 return nil, terror.ErrRestoreASTNode.Delegate(err) 320 } 321 sqls = append(sqls, bf.String()) 322 323 if spec.Tp == ast.AlterTableRenameTable { 324 v.Table = spec.NewTable 325 } 326 } 327 v.Specs = specs 328 v.Table = table 329 330 return sqls, nil 331 default: 332 return nil, terror.ErrUnknownTypeDDL.Generate(stmt) 333 } 334 335 bf.Reset() 336 err = stmt.Restore(ctx) 337 if err != nil { 338 return nil, terror.ErrRestoreASTNode.Delegate(err) 339 } 340 sqls = append(sqls, bf.String()) 341 342 return sqls, nil 343 } 344 345 func genTableName(schema string, table string) *filter.Table { 346 return 
&filter.Table{Schema: schema, Name: table} 347 } 348 349 // CheckIsDDL checks input SQL whether is a valid DDL statement. 350 func CheckIsDDL(sql string, p *parser.Parser) bool { 351 // fast path for begin/comit 352 if sql == "BEGIN" || sql == "COMMIT" { 353 return false 354 } 355 sql = utils.TrimCtrlChars(sql) 356 357 if utils.IsBuildInSkipDDL(sql) { 358 return false 359 } 360 361 // if parse error, treat it as not a DDL 362 stmts, err := Parse(p, sql, "", "") 363 if err != nil || len(stmts) == 0 { 364 return false 365 } 366 367 stmt := stmts[0] 368 switch stmt.(type) { 369 case ast.DDLNode: 370 return true 371 default: 372 // other thing this like `BEGIN` 373 return false 374 } 375 }