github.com/iasthc/atlas/cmd/atlas@v0.0.0-20230523071841-73246df3f88d/internal/sqlparse/myparse/myparse.go (about) 1 // Copyright 2021-present The Atlas Authors. All rights reserved. 2 // This source code is licensed under the Apache 2.0 license found 3 // in the LICENSE file in the root directory of this source tree. 4 5 package myparse 6 7 import ( 8 "fmt" 9 "strconv" 10 11 "github.com/iasthc/atlas/cmd/atlas/internal/sqlparse/parseutil" 12 "github.com/iasthc/atlas/sql/migrate" 13 "github.com/iasthc/atlas/sql/schema" 14 15 "github.com/pingcap/tidb/parser" 16 "github.com/pingcap/tidb/parser/ast" 17 "github.com/pingcap/tidb/parser/mysql" 18 "github.com/pingcap/tidb/parser/opcode" 19 "github.com/pingcap/tidb/parser/test_driver" 20 "golang.org/x/exp/slices" 21 ) 22 23 // Parser implements the sqlparse.Parser 24 type Parser struct{} 25 26 // ColumnFilledBefore checks if the column was filled before the given position. 27 func (p *Parser) ColumnFilledBefore(f migrate.File, t *schema.Table, c *schema.Column, pos int) (bool, error) { 28 pr := parser.New() 29 return parseutil.MatchStmtBefore(f, pos, func(s *migrate.Stmt) (bool, error) { 30 stmt, err := pr.ParseOneStmt(s.Text, "", "") 31 if err != nil { 32 return false, err 33 } 34 u, ok := stmt.(*ast.UpdateStmt) 35 // Ensure the table was updated. 36 if !ok || !tableUpdated(u, t) { 37 return false, nil 38 } 39 // Accept UPDATE that fills all rows or those with NULL values as we cannot 40 // determine if NULL values were filled in case there is a custom filtering. 41 affectC := func() bool { 42 if u.Where == nil { 43 return true 44 } 45 is, ok := u.Where.(*ast.IsNullExpr) 46 if !ok || is.Not { 47 return false 48 } 49 n, ok := is.Expr.(*ast.ColumnNameExpr) 50 return ok && n.Name.Name.O == c.Name 51 }() 52 idx := slices.IndexFunc(u.List, func(a *ast.Assignment) bool { 53 return a.Column.Name.String() == c.Name && a.Expr != nil && a.Expr.GetType().GetType() != mysql.TypeNull 54 }) 55 // Ensure the column was filled. 56 return affectC && idx != -1, nil 57 }) 58 } 59 60 // ColumnFilledAfter checks if the column that matches the given value was filled after the position. 61 func (p *Parser) ColumnFilledAfter(f migrate.File, t *schema.Table, c *schema.Column, pos int, match any) (bool, error) { 62 stmts, err := f.StmtDecls() 63 if err != nil { 64 return false, err 65 } 66 switch i := slices.IndexFunc(stmts, func(s *migrate.Stmt) bool { 67 return s.Pos >= pos 68 }); i { 69 case -1: 70 return false, nil 71 default: 72 stmts = stmts[i:] 73 } 74 pr := parser.New() 75 for _, s := range stmts { 76 stmt, err := pr.ParseOneStmt(s.Text, "", "") 77 if err != nil { 78 return false, err 79 } 80 u, ok := stmt.(*ast.UpdateStmt) 81 // Ensure the table was updated. 82 if !ok || !tableUpdated(u, t) { 83 continue 84 } 85 // Accept UPDATE that fills all rows or those with NULL values as we cannot 86 // determine if NULL values were filled in case there is a custom filtering. 87 affectC := func() bool { 88 if u.Where == nil { 89 return true 90 } 91 switch match.(type) { 92 case nil: 93 is, ok := u.Where.(*ast.IsNullExpr) 94 if !ok || is.Not { 95 return false 96 } 97 n, ok := is.Expr.(*ast.ColumnNameExpr) 98 return ok && n.Name.Name.O == c.Name 99 default: 100 bin, ok := u.Where.(*ast.BinaryOperationExpr) 101 if !ok || bin.Op != opcode.EQ { 102 return false 103 } 104 l, r := bin.L, bin.R 105 if _, ok := bin.L.(*ast.ColumnNameExpr); !ok { 106 l, r = r, l 107 } 108 n, ok1 := l.(*ast.ColumnNameExpr) 109 v, ok2 := r.(*test_driver.ValueExpr) 110 if !ok1 || !ok2 || n.Name.Name.O != c.Name { 111 return false 112 } 113 x := fmt.Sprint(match) 114 if u, err := strconv.Unquote(x); err == nil { 115 x = u 116 } 117 // String representations should be ~equal. 118 return fmt.Sprint(v.Datum.GetValue()) == x 119 } 120 }() 121 idx := slices.IndexFunc(u.List, func(a *ast.Assignment) bool { 122 return a.Column.Name.String() == c.Name && a.Expr != nil && a.Expr.GetType().GetType() != mysql.TypeNull 123 }) 124 // Ensure the column was filled. 125 if affectC && idx != -1 { 126 return true, nil 127 } 128 } 129 return false, nil 130 } 131 132 // ColumnHasReferences checks if the column has an inline REFERENCES clause in the given CREATE or ALTER statement. 133 func (p *Parser) ColumnHasReferences(stmt *migrate.Stmt, c1 *schema.Column) (bool, error) { 134 if stmt == nil { 135 return false, nil 136 } 137 s, err := parser.New().ParseOneStmt(stmt.Text, "", "") 138 if err != nil { 139 return false, err 140 } 141 check := func(c2 *ast.ColumnDef) bool { 142 idxR := slices.IndexFunc(c2.Options, func(o *ast.ColumnOption) bool { 143 return o.Tp == ast.ColumnOptionReference 144 }) 145 return c1.Name == c2.Name.Name.String() && idxR != -1 146 } 147 switch s := s.(type) { 148 case *ast.CreateTableStmt: 149 return slices.IndexFunc(s.Cols, check) != -1, nil 150 case *ast.AlterTableStmt: 151 return slices.IndexFunc(s.Specs, func(s *ast.AlterTableSpec) bool { 152 return s.Tp == ast.AlterTableAddColumns && slices.IndexFunc(s.NewColumns, check) != -1 153 }) != -1, nil 154 } 155 return false, nil 156 } 157 158 // CreateViewAfter checks if a view was created after the position with the given name to a table. 159 func (p *Parser) CreateViewAfter(f migrate.File, old, new string, pos int) (bool, error) { 160 pr := parser.New() 161 return parseutil.MatchStmtAfter(f, pos, func(s *migrate.Stmt) (bool, error) { 162 stmt, err := pr.ParseOneStmt(s.Text, "", "") 163 if err != nil { 164 return false, err 165 } 166 v, ok := stmt.(*ast.CreateViewStmt) 167 if !ok || v.ViewName.Name.O != old { 168 return false, nil 169 } 170 sc, ok := v.Select.(*ast.SelectStmt) 171 if !ok || sc.From == nil || sc.From.TableRefs == nil || sc.From.TableRefs.Left == nil || sc.From.TableRefs.Right != nil { 172 return false, nil 173 } 174 t, ok := sc.From.TableRefs.Left.(*ast.TableSource) 175 if !ok || t.Source == nil { 176 return false, nil 177 } 178 name, ok := t.Source.(*ast.TableName) 179 return ok && name.Name.O == new, nil 180 }) 181 } 182 183 // FixChange fixes the changes according to the given statement. 184 func (p *Parser) FixChange(d migrate.Driver, s string, changes schema.Changes) (schema.Changes, error) { 185 stmt, err := parser.New().ParseOneStmt(s, "", "") 186 if err != nil { 187 return nil, err 188 } 189 if len(changes) == 0 { 190 return changes, nil 191 } 192 switch stmt := stmt.(type) { 193 case *ast.AlterTableStmt: 194 if changes, err = renameTable(d, stmt, changes); err != nil { 195 return nil, err 196 } 197 modify, ok := changes[0].(*schema.ModifyTable) 198 if !ok { 199 return nil, fmt.Errorf("expected modify-table change for alter-table statement, but got: %T", changes[0]) 200 } 201 for _, r := range renameColumns(stmt) { 202 parseutil.RenameColumn(modify, r) 203 } 204 for _, r := range renameIndexes(stmt) { 205 parseutil.RenameIndex(modify, r) 206 } 207 case *ast.RenameTableStmt: 208 for _, t := range stmt.TableToTables { 209 changes = parseutil.RenameTable( 210 changes, 211 &parseutil.Rename{ 212 From: t.OldTable.Name.O, 213 To: t.NewTable.Name.O, 214 }) 215 } 216 } 217 return changes, nil 218 } 219 220 // renameColumns returns all renamed columns that exist in the statement. 221 func renameColumns(stmt *ast.AlterTableStmt) (rename []*parseutil.Rename) { 222 for _, s := range stmt.Specs { 223 if s.Tp == ast.AlterTableRenameColumn { 224 rename = append(rename, &parseutil.Rename{ 225 From: s.OldColumnName.Name.O, 226 To: s.NewColumnName.Name.O, 227 }) 228 } 229 } 230 return 231 } 232 233 // renameIndexes returns all renamed indexes that exist in the statement. 234 func renameIndexes(stmt *ast.AlterTableStmt) (rename []*parseutil.Rename) { 235 for _, s := range stmt.Specs { 236 if s.Tp == ast.AlterTableRenameIndex { 237 rename = append(rename, &parseutil.Rename{ 238 From: s.FromKey.O, 239 To: s.ToKey.O, 240 }) 241 } 242 } 243 return 244 } 245 246 // renameTable fixes the changes from ALTER command with RENAME into ModifyTable and RenameTable. 247 func renameTable(drv migrate.Driver, stmt *ast.AlterTableStmt, changes schema.Changes) (schema.Changes, error) { 248 var r *ast.AlterTableSpec 249 for _, s := range stmt.Specs { 250 if s.Tp == ast.AlterTableRenameTable { 251 r = s 252 break 253 } 254 } 255 if r == nil { 256 return changes, nil 257 } 258 if len(changes) != 2 { 259 return nil, fmt.Errorf("unexected number fo changes for ALTER command with RENAME clause: %d", len(changes)) 260 } 261 i, j := changes.IndexDropTable(stmt.Table.Name.O), changes.IndexAddTable(r.NewTable.Name.O) 262 if i == -1 { 263 return nil, fmt.Errorf("DropTable %q change was not found in changes", stmt.Table.Name) 264 } 265 if j == -1 { 266 return nil, fmt.Errorf("AddTable %q change was not found in changes", r.NewTable.Name) 267 } 268 fromT, toT := changes[0].(*schema.DropTable).T, changes[1].(*schema.AddTable).T 269 fromT.Name = toT.Name 270 diff, err := drv.TableDiff(fromT, toT) 271 if err != nil { 272 return nil, err 273 } 274 changeT := *toT 275 changeT.Name = stmt.Table.Name.O 276 return schema.Changes{ 277 // Modify the table first. 278 &schema.ModifyTable{T: &changeT, Changes: diff}, 279 // Then, apply the RENAME. 280 &schema.RenameTable{From: &changeT, To: toT}, 281 }, nil 282 } 283 284 // tableUpdated checks if the table was updated in the statement. 285 func tableUpdated(u *ast.UpdateStmt, t *schema.Table) bool { 286 if u.TableRefs == nil || u.TableRefs.TableRefs == nil || u.TableRefs.TableRefs.Left == nil { 287 return false 288 } 289 ts, ok := u.TableRefs.TableRefs.Left.(*ast.TableSource) 290 if !ok { 291 return false 292 } 293 n, ok := ts.Source.(*ast.TableName) 294 return ok && n.Name.O == t.Name && (n.Schema.O == "" || n.Schema.O == t.Schema.Name) 295 }