github.com/pingcap/tiflow@v0.0.0-20240520035814-5bf52d54e205/dm/pkg/parser/common.go (about)

     1  // Copyright 2019 PingCAP, Inc.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // See the License for the specific language governing permissions and
    12  // limitations under the License.
    13  
    14  package parser
    15  
    16  import (
    17  	"bytes"
    18  
    19  	"github.com/pingcap/tidb/pkg/parser"
    20  	"github.com/pingcap/tidb/pkg/parser/ast"
    21  	"github.com/pingcap/tidb/pkg/parser/format"
    22  	"github.com/pingcap/tidb/pkg/parser/model"
    23  	_ "github.com/pingcap/tidb/pkg/types/parser_driver" // for import parser driver
    24  	"github.com/pingcap/tidb/pkg/util/filter"
    25  	"github.com/pingcap/tiflow/dm/pkg/conn"
    26  	"github.com/pingcap/tiflow/dm/pkg/log"
    27  	"github.com/pingcap/tiflow/dm/pkg/terror"
    28  	"github.com/pingcap/tiflow/dm/pkg/utils"
    29  	"go.uber.org/zap"
    30  )
    31  
    32  const (
    33  	// SingleRenameTableNameNum stands for number of TableNames in a single table renaming. it's 2 after
    34  	// https://github.com/pingcap/parser/pull/1021
    35  	SingleRenameTableNameNum = 2
    36  )
    37  
    38  // Parse wraps parser.Parse(), makes `parser` suitable for dm.
    39  func Parse(p *parser.Parser, sql, charset, collation string) (stmt []ast.StmtNode, err error) {
    40  	stmts, warnings, err := p.Parse(sql, charset, collation)
    41  	if len(warnings) > 0 {
    42  		log.L().Warn("parse statement", zap.String("sql", sql), zap.Errors("warning messages", warnings))
    43  	}
    44  
    45  	return stmts, terror.ErrParseSQL.Delegate(err)
    46  }
    47  
    48  // ref: https://github.com/pingcap/tidb/blob/09feccb529be2830944e11f5fed474020f50370f/server/sql_info_fetcher.go#L46
    49  type tableNameExtractor struct {
    50  	curDB  string
    51  	flavor conn.LowerCaseTableNamesFlavor
    52  	names  []*filter.Table
    53  }
    54  
    55  func (tne *tableNameExtractor) Enter(in ast.Node) (ast.Node, bool) {
    56  	if _, ok := in.(*ast.ReferenceDef); ok {
    57  		return in, true
    58  	}
    59  	if t, ok := in.(*ast.TableName); ok {
    60  		var tb *filter.Table
    61  		if tne.flavor == conn.LCTableNamesSensitive {
    62  			tb = &filter.Table{Schema: t.Schema.O, Name: t.Name.O}
    63  		} else {
    64  			tb = &filter.Table{Schema: t.Schema.L, Name: t.Name.L}
    65  		}
    66  
    67  		if tb.Schema == "" {
    68  			tb.Schema = tne.curDB
    69  		}
    70  		tne.names = append(tne.names, tb)
    71  		return in, true
    72  	}
    73  	return in, false
    74  }
    75  
    76  func (tne *tableNameExtractor) Leave(in ast.Node) (ast.Node, bool) {
    77  	return in, true
    78  }
    79  
    80  // FetchDDLTables returns tables in ddl the result contains many tables.
    81  // Because we use visitor pattern, first tableName is always upper-most table in ast
    82  // specifically, for `create table like` DDL, result contains [sourceTable, sourceRefTable]
    83  // for rename table ddl, result contains [old1, new1, old2, new2, old3, new3, ...] because of TiDB parser
    84  // for other DDL, order of tableName is the node visit order.
    85  func FetchDDLTables(schema string, stmt ast.StmtNode, flavor conn.LowerCaseTableNamesFlavor) ([]*filter.Table, error) {
    86  	switch stmt.(type) {
    87  	case ast.DDLNode:
    88  	default:
    89  		return nil, terror.ErrUnknownTypeDDL.Generate(stmt)
    90  	}
    91  
    92  	// special cases: schema related SQLs doesn't have tableName
    93  	// todo: pass .O or .L of table name depends on flavor
    94  	switch v := stmt.(type) {
    95  	case *ast.AlterDatabaseStmt:
    96  		return []*filter.Table{genTableName(v.Name.O, "")}, nil
    97  	case *ast.CreateDatabaseStmt:
    98  		return []*filter.Table{genTableName(v.Name.O, "")}, nil
    99  	case *ast.DropDatabaseStmt:
   100  		return []*filter.Table{genTableName(v.Name.O, "")}, nil
   101  	}
   102  
   103  	e := &tableNameExtractor{
   104  		curDB:  schema,
   105  		flavor: flavor,
   106  		names:  make([]*filter.Table, 0),
   107  	}
   108  	stmt.Accept(e)
   109  
   110  	return e.names, nil
   111  }
   112  
   113  type tableRenameVisitor struct {
   114  	targetNames []*filter.Table
   115  	i           int
   116  	hasErr      bool
   117  }
   118  
   119  func (v *tableRenameVisitor) Enter(in ast.Node) (ast.Node, bool) {
   120  	if v.hasErr {
   121  		return in, true
   122  	}
   123  	if _, ok := in.(*ast.ReferenceDef); ok {
   124  		return in, true
   125  	}
   126  	if t, ok := in.(*ast.TableName); ok {
   127  		if v.i >= len(v.targetNames) {
   128  			v.hasErr = true
   129  			return in, true
   130  		}
   131  		t.Schema = model.NewCIStr(v.targetNames[v.i].Schema)
   132  		t.Name = model.NewCIStr(v.targetNames[v.i].Name)
   133  		v.i++
   134  		return in, true
   135  	}
   136  	return in, false
   137  }
   138  
   139  func (v *tableRenameVisitor) Leave(in ast.Node) (ast.Node, bool) {
   140  	if v.hasErr {
   141  		return in, false
   142  	}
   143  	return in, true
   144  }
   145  
   146  // RenameDDLTable renames tables in ddl by given `targetTables`
   147  // argument `targetTables` is same with return value of FetchDDLTables
   148  // returned DDL is formatted like StringSingleQuotes, KeyWordUppercase and NameBackQuotes.
   149  func RenameDDLTable(stmt ast.StmtNode, targetTables []*filter.Table) (string, error) {
   150  	switch stmt.(type) {
   151  	case ast.DDLNode:
   152  	default:
   153  		return "", terror.ErrUnknownTypeDDL.Generate(stmt)
   154  	}
   155  
   156  	switch v := stmt.(type) {
   157  	case *ast.AlterDatabaseStmt:
   158  		v.Name = model.NewCIStr(targetTables[0].Schema)
   159  	case *ast.CreateDatabaseStmt:
   160  		v.Name = model.NewCIStr(targetTables[0].Schema)
   161  	case *ast.DropDatabaseStmt:
   162  		v.Name = model.NewCIStr(targetTables[0].Schema)
   163  	default:
   164  		visitor := &tableRenameVisitor{
   165  			targetNames: targetTables,
   166  		}
   167  		stmt.Accept(visitor)
   168  		if visitor.hasErr {
   169  			return "", terror.ErrRewriteSQL.Generate(stmt, targetTables)
   170  		}
   171  	}
   172  
   173  	var b []byte
   174  	bf := bytes.NewBuffer(b)
   175  	err := stmt.Restore(&format.RestoreCtx{
   176  		Flags: format.DefaultRestoreFlags | format.RestoreTiDBSpecialComment | format.RestoreStringWithoutDefaultCharset,
   177  		In:    bf,
   178  	})
   179  	if err != nil {
   180  		return "", terror.ErrRestoreASTNode.Delegate(err)
   181  	}
   182  
   183  	return bf.String(), nil
   184  }
   185  
   186  // SplitDDL splits multiple operations in one DDL statement into multiple DDL statements
   187  // returned DDL is formatted like StringSingleQuotes, KeyWordUppercase and NameBackQuotes
   188  // if fail to restore, it would not restore the value of `stmt` (it changes it's values if `stmt` is one of  DropTableStmt, RenameTableStmt, AlterTableStmt).
   189  func SplitDDL(stmt ast.StmtNode, schema string) (sqls []string, err error) {
   190  	var (
   191  		schemaName = model.NewCIStr(schema) // fill schema name
   192  		bf         = new(bytes.Buffer)
   193  		ctx        = &format.RestoreCtx{
   194  			Flags: format.DefaultRestoreFlags | format.RestoreTiDBSpecialComment | format.RestoreStringWithoutDefaultCharset,
   195  			In:    bf,
   196  		}
   197  	)
   198  
   199  	switch v := stmt.(type) {
   200  	case *ast.CreateSequenceStmt:
   201  	case *ast.AlterSequenceStmt:
   202  	case *ast.DropSequenceStmt:
   203  	case *ast.AlterDatabaseStmt:
   204  	case *ast.CreateDatabaseStmt:
   205  		v.IfNotExists = true
   206  	case *ast.DropDatabaseStmt:
   207  		v.IfExists = true
   208  	case *ast.DropTableStmt:
   209  		v.IfExists = true
   210  
   211  		tables := v.Tables
   212  		for _, t := range tables {
   213  			if t.Schema.O == "" {
   214  				t.Schema = schemaName
   215  			}
   216  
   217  			v.Tables = []*ast.TableName{t}
   218  			bf.Reset()
   219  			err = stmt.Restore(ctx)
   220  			if err != nil {
   221  				v.Tables = tables
   222  				return nil, terror.ErrRestoreASTNode.Delegate(err)
   223  			}
   224  
   225  			sqls = append(sqls, bf.String())
   226  		}
   227  		v.Tables = tables
   228  
   229  		return sqls, nil
   230  	case *ast.CreateTableStmt:
   231  		v.IfNotExists = true
   232  		if v.Table.Schema.O == "" {
   233  			v.Table.Schema = schemaName
   234  		}
   235  
   236  		if v.ReferTable != nil && v.ReferTable.Schema.O == "" {
   237  			v.ReferTable.Schema = schemaName
   238  		}
   239  	case *ast.TruncateTableStmt:
   240  		if v.Table.Schema.O == "" {
   241  			v.Table.Schema = schemaName
   242  		}
   243  	case *ast.DropIndexStmt:
   244  		v.IfExists = true
   245  		if v.Table.Schema.O == "" {
   246  			v.Table.Schema = schemaName
   247  		}
   248  	case *ast.CreateIndexStmt:
   249  		if v.Table.Schema.O == "" {
   250  			v.Table.Schema = schemaName
   251  		}
   252  	case *ast.RenameTableStmt:
   253  		t2ts := v.TableToTables
   254  		for _, t2t := range t2ts {
   255  			if t2t.OldTable.Schema.O == "" {
   256  				t2t.OldTable.Schema = schemaName
   257  			}
   258  			if t2t.NewTable.Schema.O == "" {
   259  				t2t.NewTable.Schema = schemaName
   260  			}
   261  
   262  			v.TableToTables = []*ast.TableToTable{t2t}
   263  
   264  			bf.Reset()
   265  			err = stmt.Restore(ctx)
   266  			if err != nil {
   267  				v.TableToTables = t2ts
   268  				return nil, terror.ErrRestoreASTNode.Delegate(err)
   269  			}
   270  
   271  			sqls = append(sqls, bf.String())
   272  		}
   273  		v.TableToTables = t2ts
   274  
   275  		return sqls, nil
   276  	case *ast.AlterTableStmt:
   277  		specs := v.Specs
   278  		table := v.Table
   279  
   280  		if v.Table.Schema.O == "" {
   281  			v.Table.Schema = schemaName
   282  		}
   283  
   284  		for _, spec := range specs {
   285  			if spec.Tp == ast.AlterTableRenameTable {
   286  				if spec.NewTable.Schema.O == "" {
   287  					spec.NewTable.Schema = schemaName
   288  				}
   289  			}
   290  
   291  			v.Specs = []*ast.AlterTableSpec{spec}
   292  
   293  			// handle `alter table t1 add column (c1 int, c2 int)`
   294  			if spec.Tp == ast.AlterTableAddColumns && len(spec.NewColumns) > 1 {
   295  				columns := spec.NewColumns
   296  				spec.Position = &ast.ColumnPosition{
   297  					Tp: ast.ColumnPositionNone, // otherwise restore will become "alter table t1 add column (c1 int)"
   298  				}
   299  				for _, c := range columns {
   300  					spec.NewColumns = []*ast.ColumnDef{c}
   301  					bf.Reset()
   302  					err = stmt.Restore(ctx)
   303  					if err != nil {
   304  						v.Specs = specs
   305  						v.Table = table
   306  						return nil, terror.ErrRestoreASTNode.Delegate(err)
   307  					}
   308  					sqls = append(sqls, bf.String())
   309  				}
   310  				// we have restore SQL for every columns, skip below general restoring and continue on next spec
   311  				continue
   312  			}
   313  
   314  			bf.Reset()
   315  			err = stmt.Restore(ctx)
   316  			if err != nil {
   317  				v.Specs = specs
   318  				v.Table = table
   319  				return nil, terror.ErrRestoreASTNode.Delegate(err)
   320  			}
   321  			sqls = append(sqls, bf.String())
   322  
   323  			if spec.Tp == ast.AlterTableRenameTable {
   324  				v.Table = spec.NewTable
   325  			}
   326  		}
   327  		v.Specs = specs
   328  		v.Table = table
   329  
   330  		return sqls, nil
   331  	default:
   332  		return nil, terror.ErrUnknownTypeDDL.Generate(stmt)
   333  	}
   334  
   335  	bf.Reset()
   336  	err = stmt.Restore(ctx)
   337  	if err != nil {
   338  		return nil, terror.ErrRestoreASTNode.Delegate(err)
   339  	}
   340  	sqls = append(sqls, bf.String())
   341  
   342  	return sqls, nil
   343  }
   344  
   345  func genTableName(schema string, table string) *filter.Table {
   346  	return &filter.Table{Schema: schema, Name: table}
   347  }
   348  
   349  // CheckIsDDL checks input SQL whether is a valid DDL statement.
   350  func CheckIsDDL(sql string, p *parser.Parser) bool {
   351  	// fast path for begin/comit
   352  	if sql == "BEGIN" || sql == "COMMIT" {
   353  		return false
   354  	}
   355  	sql = utils.TrimCtrlChars(sql)
   356  
   357  	if utils.IsBuildInSkipDDL(sql) {
   358  		return false
   359  	}
   360  
   361  	// if parse error, treat it as not a DDL
   362  	stmts, err := Parse(p, sql, "", "")
   363  	if err != nil || len(stmts) == 0 {
   364  		return false
   365  	}
   366  
   367  	stmt := stmts[0]
   368  	switch stmt.(type) {
   369  	case ast.DDLNode:
   370  		return true
   371  	default:
   372  		// other thing this like `BEGIN`
   373  		return false
   374  	}
   375  }