github.com/iasthc/atlas/cmd/atlas@v0.0.0-20230523071841-73246df3f88d/internal/sqlparse/pgparse/pgparse.go (about) 1 // Copyright 2021-present The Atlas Authors. All rights reserved. 2 // This source code is licensed under the Apache 2.0 license found 3 // in the LICENSE file in the root directory of this source tree. 4 5 package pgparse 6 7 import ( 8 "fmt" 9 10 "github.com/iasthc/atlas/cmd/atlas/internal/sqlparse/parseutil" 11 "github.com/iasthc/atlas/sql/migrate" 12 "github.com/iasthc/atlas/sql/postgres" 13 "github.com/iasthc/atlas/sql/schema" 14 15 "github.com/auxten/postgresql-parser/pkg/sql/parser" 16 "github.com/auxten/postgresql-parser/pkg/sql/sem/tree" 17 "golang.org/x/exp/slices" 18 ) 19 20 // Parser implements the sqlparse.Parser 21 type Parser struct{} 22 23 // ColumnFilledBefore checks if the column was filled before the given position. 24 func (p *Parser) ColumnFilledBefore(f migrate.File, t *schema.Table, c *schema.Column, pos int) (bool, error) { 25 return parseutil.MatchStmtBefore(f, pos, func(s *migrate.Stmt) (bool, error) { 26 stmt, err := parser.ParseOne(s.Text) 27 if err != nil { 28 return false, err 29 } 30 u, ok := stmt.AST.(*tree.Update) 31 if !ok || !tableUpdated(u, t) { 32 return false, nil 33 } 34 // Accept UPDATE that fills all rows or those with NULL values as we cannot 35 // determine if NULL values were filled in case there is a custom filtering. 36 affectC := func() bool { 37 if u.Where == nil { 38 return true 39 } 40 x, ok := u.Where.Expr.(*tree.ComparisonExpr) 41 if !ok || x.Operator != tree.IsNotDistinctFrom || x.SubOperator != tree.EQ { 42 return false 43 } 44 return x.Left.String() == c.Name && x.Right == tree.DNull 45 }() 46 idx := slices.IndexFunc(u.Exprs, func(x *tree.UpdateExpr) bool { 47 return slices.Contains(x.Names, tree.Name(c.Name)) && x.Expr != tree.DNull 48 }) 49 // Ensure the column was filled. 50 return affectC && idx != -1, nil 51 }) 52 } 53 54 // CreateViewAfter checks if a view was created after the position with the given name to a table. 55 func (p *Parser) CreateViewAfter(f migrate.File, old, new string, pos int) (bool, error) { 56 return parseutil.MatchStmtAfter(f, pos, func(s *migrate.Stmt) (bool, error) { 57 stmt, err := parser.ParseOne(s.Text) 58 if err != nil { 59 return false, err 60 } 61 v, ok := stmt.AST.(*tree.CreateView) 62 if !ok || v.AsSource == nil || v.Name.String() != old { 63 return false, nil 64 } 65 sc, ok := v.AsSource.Select.(*tree.SelectClause) 66 if !ok || len(sc.From.Tables) != 1 { 67 return false, nil 68 } 69 return tree.AsString(sc.From.Tables[0]) == new, nil 70 }) 71 } 72 73 // FixChange fixes the changes according to the given statement. 74 func (p *Parser) FixChange(_ migrate.Driver, s string, changes schema.Changes) (schema.Changes, error) { 75 stmt, err := parser.ParseOne(s) 76 if err != nil { 77 return nil, err 78 } 79 switch stmt := stmt.AST.(type) { 80 case *tree.AlterTable: 81 if r, ok := renameColumn(stmt); ok { 82 modify, err := expectModify(changes) 83 if err != nil { 84 return nil, err 85 } 86 parseutil.RenameColumn(modify, r) 87 } 88 case *tree.RenameIndex: 89 modify, err := expectModify(changes) 90 if err != nil { 91 return nil, err 92 } 93 parseutil.RenameIndex(modify, &parseutil.Rename{ 94 From: stmt.Index.String(), 95 To: stmt.NewName.String(), 96 }) 97 case *tree.RenameTable: 98 changes = parseutil.RenameTable(changes, &parseutil.Rename{ 99 From: stmt.Name.String(), 100 To: stmt.NewName.String(), 101 }) 102 case *tree.CreateIndex: 103 modify, err := expectModify(changes) 104 if err != nil { 105 return nil, err 106 } 107 i := schema.Changes(modify.Changes).IndexAddIndex(stmt.Name.String()) 108 if i == -1 { 109 return nil, fmt.Errorf("AddIndex %q command not found", stmt.Name) 110 } 111 add := modify.Changes[i].(*schema.AddIndex) 112 if slices.IndexFunc(add.Extra, func(c schema.Clause) bool { 113 _, ok := c.(*postgres.Concurrently) 114 return ok 115 }) == -1 && stmt.Concurrently { 116 add.Extra = append(add.Extra, &postgres.Concurrently{}) 117 } 118 } 119 return changes, nil 120 } 121 122 // renameColumn returns the renamed column exists in the statement, is any. 123 func renameColumn(stmt *tree.AlterTable) (*parseutil.Rename, bool) { 124 for _, c := range stmt.Cmds { 125 if r, ok := c.(*tree.AlterTableRenameColumn); ok { 126 return &parseutil.Rename{ 127 From: r.Column.String(), 128 To: r.NewName.String(), 129 }, true 130 } 131 } 132 return nil, false 133 } 134 135 func expectModify(changes schema.Changes) (*schema.ModifyTable, error) { 136 if len(changes) != 1 { 137 return nil, fmt.Errorf("unexected number fo changes: %d", len(changes)) 138 } 139 modify, ok := changes[0].(*schema.ModifyTable) 140 if !ok { 141 return nil, fmt.Errorf("expected modify-table change for alter-table statement, but got: %T", changes[0]) 142 } 143 return modify, nil 144 } 145 146 // tableUpdated checks if the table was updated in the statement. 147 func tableUpdated(u *tree.Update, t *schema.Table) bool { 148 at, ok := u.Table.(*tree.AliasedTableExpr) 149 if !ok { 150 return false 151 } 152 n, ok := at.Expr.(*tree.TableName) 153 return ok && n.Table() == t.Name && (n.Schema() == "" || n.Schema() == t.Schema.Name) 154 }