github.com/iasthc/atlas/cmd/atlas@v0.0.0-20230523071841-73246df3f88d/internal/sqlparse/myparse/myparse.go (about)

     1  // Copyright 2021-present The Atlas Authors. All rights reserved.
     2  // This source code is licensed under the Apache 2.0 license found
     3  // in the LICENSE file in the root directory of this source tree.
     4  
     5  package myparse
     6  
     7  import (
     8  	"fmt"
     9  	"strconv"
    10  
    11  	"github.com/iasthc/atlas/cmd/atlas/internal/sqlparse/parseutil"
    12  	"github.com/iasthc/atlas/sql/migrate"
    13  	"github.com/iasthc/atlas/sql/schema"
    14  
    15  	"github.com/pingcap/tidb/parser"
    16  	"github.com/pingcap/tidb/parser/ast"
    17  	"github.com/pingcap/tidb/parser/mysql"
    18  	"github.com/pingcap/tidb/parser/opcode"
    19  	"github.com/pingcap/tidb/parser/test_driver"
    20  	"golang.org/x/exp/slices"
    21  )
    22  
    23  // Parser implements the sqlparse.Parser
    24  type Parser struct{}
    25  
    26  // ColumnFilledBefore checks if the column was filled before the given position.
    27  func (p *Parser) ColumnFilledBefore(f migrate.File, t *schema.Table, c *schema.Column, pos int) (bool, error) {
    28  	pr := parser.New()
    29  	return parseutil.MatchStmtBefore(f, pos, func(s *migrate.Stmt) (bool, error) {
    30  		stmt, err := pr.ParseOneStmt(s.Text, "", "")
    31  		if err != nil {
    32  			return false, err
    33  		}
    34  		u, ok := stmt.(*ast.UpdateStmt)
    35  		// Ensure the table was updated.
    36  		if !ok || !tableUpdated(u, t) {
    37  			return false, nil
    38  		}
    39  		// Accept UPDATE that fills all rows or those with NULL values as we cannot
    40  		// determine if NULL values were filled in case there is a custom filtering.
    41  		affectC := func() bool {
    42  			if u.Where == nil {
    43  				return true
    44  			}
    45  			is, ok := u.Where.(*ast.IsNullExpr)
    46  			if !ok || is.Not {
    47  				return false
    48  			}
    49  			n, ok := is.Expr.(*ast.ColumnNameExpr)
    50  			return ok && n.Name.Name.O == c.Name
    51  		}()
    52  		idx := slices.IndexFunc(u.List, func(a *ast.Assignment) bool {
    53  			return a.Column.Name.String() == c.Name && a.Expr != nil && a.Expr.GetType().GetType() != mysql.TypeNull
    54  		})
    55  		// Ensure the column was filled.
    56  		return affectC && idx != -1, nil
    57  	})
    58  }
    59  
    60  // ColumnFilledAfter checks if the column that matches the given value was filled after the position.
    61  func (p *Parser) ColumnFilledAfter(f migrate.File, t *schema.Table, c *schema.Column, pos int, match any) (bool, error) {
    62  	stmts, err := f.StmtDecls()
    63  	if err != nil {
    64  		return false, err
    65  	}
    66  	switch i := slices.IndexFunc(stmts, func(s *migrate.Stmt) bool {
    67  		return s.Pos >= pos
    68  	}); i {
    69  	case -1:
    70  		return false, nil
    71  	default:
    72  		stmts = stmts[i:]
    73  	}
    74  	pr := parser.New()
    75  	for _, s := range stmts {
    76  		stmt, err := pr.ParseOneStmt(s.Text, "", "")
    77  		if err != nil {
    78  			return false, err
    79  		}
    80  		u, ok := stmt.(*ast.UpdateStmt)
    81  		// Ensure the table was updated.
    82  		if !ok || !tableUpdated(u, t) {
    83  			continue
    84  		}
    85  		// Accept UPDATE that fills all rows or those with NULL values as we cannot
    86  		// determine if NULL values were filled in case there is a custom filtering.
    87  		affectC := func() bool {
    88  			if u.Where == nil {
    89  				return true
    90  			}
    91  			switch match.(type) {
    92  			case nil:
    93  				is, ok := u.Where.(*ast.IsNullExpr)
    94  				if !ok || is.Not {
    95  					return false
    96  				}
    97  				n, ok := is.Expr.(*ast.ColumnNameExpr)
    98  				return ok && n.Name.Name.O == c.Name
    99  			default:
   100  				bin, ok := u.Where.(*ast.BinaryOperationExpr)
   101  				if !ok || bin.Op != opcode.EQ {
   102  					return false
   103  				}
   104  				l, r := bin.L, bin.R
   105  				if _, ok := bin.L.(*ast.ColumnNameExpr); !ok {
   106  					l, r = r, l
   107  				}
   108  				n, ok1 := l.(*ast.ColumnNameExpr)
   109  				v, ok2 := r.(*test_driver.ValueExpr)
   110  				if !ok1 || !ok2 || n.Name.Name.O != c.Name {
   111  					return false
   112  				}
   113  				x := fmt.Sprint(match)
   114  				if u, err := strconv.Unquote(x); err == nil {
   115  					x = u
   116  				}
   117  				// String representations should be ~equal.
   118  				return fmt.Sprint(v.Datum.GetValue()) == x
   119  			}
   120  		}()
   121  		idx := slices.IndexFunc(u.List, func(a *ast.Assignment) bool {
   122  			return a.Column.Name.String() == c.Name && a.Expr != nil && a.Expr.GetType().GetType() != mysql.TypeNull
   123  		})
   124  		// Ensure the column was filled.
   125  		if affectC && idx != -1 {
   126  			return true, nil
   127  		}
   128  	}
   129  	return false, nil
   130  }
   131  
   132  // ColumnHasReferences checks if the column has an inline REFERENCES clause in the given CREATE or ALTER statement.
   133  func (p *Parser) ColumnHasReferences(stmt *migrate.Stmt, c1 *schema.Column) (bool, error) {
   134  	if stmt == nil {
   135  		return false, nil
   136  	}
   137  	s, err := parser.New().ParseOneStmt(stmt.Text, "", "")
   138  	if err != nil {
   139  		return false, err
   140  	}
   141  	check := func(c2 *ast.ColumnDef) bool {
   142  		idxR := slices.IndexFunc(c2.Options, func(o *ast.ColumnOption) bool {
   143  			return o.Tp == ast.ColumnOptionReference
   144  		})
   145  		return c1.Name == c2.Name.Name.String() && idxR != -1
   146  	}
   147  	switch s := s.(type) {
   148  	case *ast.CreateTableStmt:
   149  		return slices.IndexFunc(s.Cols, check) != -1, nil
   150  	case *ast.AlterTableStmt:
   151  		return slices.IndexFunc(s.Specs, func(s *ast.AlterTableSpec) bool {
   152  			return s.Tp == ast.AlterTableAddColumns && slices.IndexFunc(s.NewColumns, check) != -1
   153  		}) != -1, nil
   154  	}
   155  	return false, nil
   156  }
   157  
   158  // CreateViewAfter checks if a view was created after the position with the given name to a table.
   159  func (p *Parser) CreateViewAfter(f migrate.File, old, new string, pos int) (bool, error) {
   160  	pr := parser.New()
   161  	return parseutil.MatchStmtAfter(f, pos, func(s *migrate.Stmt) (bool, error) {
   162  		stmt, err := pr.ParseOneStmt(s.Text, "", "")
   163  		if err != nil {
   164  			return false, err
   165  		}
   166  		v, ok := stmt.(*ast.CreateViewStmt)
   167  		if !ok || v.ViewName.Name.O != old {
   168  			return false, nil
   169  		}
   170  		sc, ok := v.Select.(*ast.SelectStmt)
   171  		if !ok || sc.From == nil || sc.From.TableRefs == nil || sc.From.TableRefs.Left == nil || sc.From.TableRefs.Right != nil {
   172  			return false, nil
   173  		}
   174  		t, ok := sc.From.TableRefs.Left.(*ast.TableSource)
   175  		if !ok || t.Source == nil {
   176  			return false, nil
   177  		}
   178  		name, ok := t.Source.(*ast.TableName)
   179  		return ok && name.Name.O == new, nil
   180  	})
   181  }
   182  
   183  // FixChange fixes the changes according to the given statement.
   184  func (p *Parser) FixChange(d migrate.Driver, s string, changes schema.Changes) (schema.Changes, error) {
   185  	stmt, err := parser.New().ParseOneStmt(s, "", "")
   186  	if err != nil {
   187  		return nil, err
   188  	}
   189  	if len(changes) == 0 {
   190  		return changes, nil
   191  	}
   192  	switch stmt := stmt.(type) {
   193  	case *ast.AlterTableStmt:
   194  		if changes, err = renameTable(d, stmt, changes); err != nil {
   195  			return nil, err
   196  		}
   197  		modify, ok := changes[0].(*schema.ModifyTable)
   198  		if !ok {
   199  			return nil, fmt.Errorf("expected modify-table change for alter-table statement, but got: %T", changes[0])
   200  		}
   201  		for _, r := range renameColumns(stmt) {
   202  			parseutil.RenameColumn(modify, r)
   203  		}
   204  		for _, r := range renameIndexes(stmt) {
   205  			parseutil.RenameIndex(modify, r)
   206  		}
   207  	case *ast.RenameTableStmt:
   208  		for _, t := range stmt.TableToTables {
   209  			changes = parseutil.RenameTable(
   210  				changes,
   211  				&parseutil.Rename{
   212  					From: t.OldTable.Name.O,
   213  					To:   t.NewTable.Name.O,
   214  				})
   215  		}
   216  	}
   217  	return changes, nil
   218  }
   219  
   220  // renameColumns returns all renamed columns that exist in the statement.
   221  func renameColumns(stmt *ast.AlterTableStmt) (rename []*parseutil.Rename) {
   222  	for _, s := range stmt.Specs {
   223  		if s.Tp == ast.AlterTableRenameColumn {
   224  			rename = append(rename, &parseutil.Rename{
   225  				From: s.OldColumnName.Name.O,
   226  				To:   s.NewColumnName.Name.O,
   227  			})
   228  		}
   229  	}
   230  	return
   231  }
   232  
   233  // renameIndexes returns all renamed indexes that exist in the statement.
   234  func renameIndexes(stmt *ast.AlterTableStmt) (rename []*parseutil.Rename) {
   235  	for _, s := range stmt.Specs {
   236  		if s.Tp == ast.AlterTableRenameIndex {
   237  			rename = append(rename, &parseutil.Rename{
   238  				From: s.FromKey.O,
   239  				To:   s.ToKey.O,
   240  			})
   241  		}
   242  	}
   243  	return
   244  }
   245  
   246  // renameTable fixes the changes from ALTER command with RENAME into ModifyTable and RenameTable.
   247  func renameTable(drv migrate.Driver, stmt *ast.AlterTableStmt, changes schema.Changes) (schema.Changes, error) {
   248  	var r *ast.AlterTableSpec
   249  	for _, s := range stmt.Specs {
   250  		if s.Tp == ast.AlterTableRenameTable {
   251  			r = s
   252  			break
   253  		}
   254  	}
   255  	if r == nil {
   256  		return changes, nil
   257  	}
   258  	if len(changes) != 2 {
   259  		return nil, fmt.Errorf("unexected number fo changes for ALTER command with RENAME clause: %d", len(changes))
   260  	}
   261  	i, j := changes.IndexDropTable(stmt.Table.Name.O), changes.IndexAddTable(r.NewTable.Name.O)
   262  	if i == -1 {
   263  		return nil, fmt.Errorf("DropTable %q change was not found in changes", stmt.Table.Name)
   264  	}
   265  	if j == -1 {
   266  		return nil, fmt.Errorf("AddTable %q change was not found in changes", r.NewTable.Name)
   267  	}
   268  	fromT, toT := changes[0].(*schema.DropTable).T, changes[1].(*schema.AddTable).T
   269  	fromT.Name = toT.Name
   270  	diff, err := drv.TableDiff(fromT, toT)
   271  	if err != nil {
   272  		return nil, err
   273  	}
   274  	changeT := *toT
   275  	changeT.Name = stmt.Table.Name.O
   276  	return schema.Changes{
   277  		// Modify the table first.
   278  		&schema.ModifyTable{T: &changeT, Changes: diff},
   279  		// Then, apply the RENAME.
   280  		&schema.RenameTable{From: &changeT, To: toT},
   281  	}, nil
   282  }
   283  
   284  // tableUpdated checks if the table was updated in the statement.
   285  func tableUpdated(u *ast.UpdateStmt, t *schema.Table) bool {
   286  	if u.TableRefs == nil || u.TableRefs.TableRefs == nil || u.TableRefs.TableRefs.Left == nil {
   287  		return false
   288  	}
   289  	ts, ok := u.TableRefs.TableRefs.Left.(*ast.TableSource)
   290  	if !ok {
   291  		return false
   292  	}
   293  	n, ok := ts.Source.(*ast.TableName)
   294  	return ok && n.Name.O == t.Name && (n.Schema.O == "" || n.Schema.O == t.Schema.Name)
   295  }