github.com/iasthc/atlas/cmd/atlas@v0.0.0-20230523071841-73246df3f88d/internal/sqlparse/pgparse/pgparse.go (about)

     1  // Copyright 2021-present The Atlas Authors. All rights reserved.
     2  // This source code is licensed under the Apache 2.0 license found
     3  // in the LICENSE file in the root directory of this source tree.
     4  
     5  package pgparse
     6  
     7  import (
     8  	"fmt"
     9  
    10  	"github.com/iasthc/atlas/cmd/atlas/internal/sqlparse/parseutil"
    11  	"github.com/iasthc/atlas/sql/migrate"
    12  	"github.com/iasthc/atlas/sql/postgres"
    13  	"github.com/iasthc/atlas/sql/schema"
    14  
    15  	"github.com/auxten/postgresql-parser/pkg/sql/parser"
    16  	"github.com/auxten/postgresql-parser/pkg/sql/sem/tree"
    17  	"golang.org/x/exp/slices"
    18  )
    19  
    20  // Parser implements the sqlparse.Parser
    21  type Parser struct{}
    22  
    23  // ColumnFilledBefore checks if the column was filled before the given position.
    24  func (p *Parser) ColumnFilledBefore(f migrate.File, t *schema.Table, c *schema.Column, pos int) (bool, error) {
    25  	return parseutil.MatchStmtBefore(f, pos, func(s *migrate.Stmt) (bool, error) {
    26  		stmt, err := parser.ParseOne(s.Text)
    27  		if err != nil {
    28  			return false, err
    29  		}
    30  		u, ok := stmt.AST.(*tree.Update)
    31  		if !ok || !tableUpdated(u, t) {
    32  			return false, nil
    33  		}
    34  		// Accept UPDATE that fills all rows or those with NULL values as we cannot
    35  		// determine if NULL values were filled in case there is a custom filtering.
    36  		affectC := func() bool {
    37  			if u.Where == nil {
    38  				return true
    39  			}
    40  			x, ok := u.Where.Expr.(*tree.ComparisonExpr)
    41  			if !ok || x.Operator != tree.IsNotDistinctFrom || x.SubOperator != tree.EQ {
    42  				return false
    43  			}
    44  			return x.Left.String() == c.Name && x.Right == tree.DNull
    45  		}()
    46  		idx := slices.IndexFunc(u.Exprs, func(x *tree.UpdateExpr) bool {
    47  			return slices.Contains(x.Names, tree.Name(c.Name)) && x.Expr != tree.DNull
    48  		})
    49  		// Ensure the column was filled.
    50  		return affectC && idx != -1, nil
    51  	})
    52  }
    53  
    54  // CreateViewAfter checks if a view was created after the position with the given name to a table.
    55  func (p *Parser) CreateViewAfter(f migrate.File, old, new string, pos int) (bool, error) {
    56  	return parseutil.MatchStmtAfter(f, pos, func(s *migrate.Stmt) (bool, error) {
    57  		stmt, err := parser.ParseOne(s.Text)
    58  		if err != nil {
    59  			return false, err
    60  		}
    61  		v, ok := stmt.AST.(*tree.CreateView)
    62  		if !ok || v.AsSource == nil || v.Name.String() != old {
    63  			return false, nil
    64  		}
    65  		sc, ok := v.AsSource.Select.(*tree.SelectClause)
    66  		if !ok || len(sc.From.Tables) != 1 {
    67  			return false, nil
    68  		}
    69  		return tree.AsString(sc.From.Tables[0]) == new, nil
    70  	})
    71  }
    72  
    73  // FixChange fixes the changes according to the given statement.
    74  func (p *Parser) FixChange(_ migrate.Driver, s string, changes schema.Changes) (schema.Changes, error) {
    75  	stmt, err := parser.ParseOne(s)
    76  	if err != nil {
    77  		return nil, err
    78  	}
    79  	switch stmt := stmt.AST.(type) {
    80  	case *tree.AlterTable:
    81  		if r, ok := renameColumn(stmt); ok {
    82  			modify, err := expectModify(changes)
    83  			if err != nil {
    84  				return nil, err
    85  			}
    86  			parseutil.RenameColumn(modify, r)
    87  		}
    88  	case *tree.RenameIndex:
    89  		modify, err := expectModify(changes)
    90  		if err != nil {
    91  			return nil, err
    92  		}
    93  		parseutil.RenameIndex(modify, &parseutil.Rename{
    94  			From: stmt.Index.String(),
    95  			To:   stmt.NewName.String(),
    96  		})
    97  	case *tree.RenameTable:
    98  		changes = parseutil.RenameTable(changes, &parseutil.Rename{
    99  			From: stmt.Name.String(),
   100  			To:   stmt.NewName.String(),
   101  		})
   102  	case *tree.CreateIndex:
   103  		modify, err := expectModify(changes)
   104  		if err != nil {
   105  			return nil, err
   106  		}
   107  		i := schema.Changes(modify.Changes).IndexAddIndex(stmt.Name.String())
   108  		if i == -1 {
   109  			return nil, fmt.Errorf("AddIndex %q command not found", stmt.Name)
   110  		}
   111  		add := modify.Changes[i].(*schema.AddIndex)
   112  		if slices.IndexFunc(add.Extra, func(c schema.Clause) bool {
   113  			_, ok := c.(*postgres.Concurrently)
   114  			return ok
   115  		}) == -1 && stmt.Concurrently {
   116  			add.Extra = append(add.Extra, &postgres.Concurrently{})
   117  		}
   118  	}
   119  	return changes, nil
   120  }
   121  
   122  // renameColumn returns the renamed column exists in the statement, is any.
   123  func renameColumn(stmt *tree.AlterTable) (*parseutil.Rename, bool) {
   124  	for _, c := range stmt.Cmds {
   125  		if r, ok := c.(*tree.AlterTableRenameColumn); ok {
   126  			return &parseutil.Rename{
   127  				From: r.Column.String(),
   128  				To:   r.NewName.String(),
   129  			}, true
   130  		}
   131  	}
   132  	return nil, false
   133  }
   134  
   135  func expectModify(changes schema.Changes) (*schema.ModifyTable, error) {
   136  	if len(changes) != 1 {
   137  		return nil, fmt.Errorf("unexected number fo changes: %d", len(changes))
   138  	}
   139  	modify, ok := changes[0].(*schema.ModifyTable)
   140  	if !ok {
   141  		return nil, fmt.Errorf("expected modify-table change for alter-table statement, but got: %T", changes[0])
   142  	}
   143  	return modify, nil
   144  }
   145  
   146  // tableUpdated checks if the table was updated in the statement.
   147  func tableUpdated(u *tree.Update, t *schema.Table) bool {
   148  	at, ok := u.Table.(*tree.AliasedTableExpr)
   149  	if !ok {
   150  		return false
   151  	}
   152  	n, ok := at.Expr.(*tree.TableName)
   153  	return ok && n.Table() == t.Name && (n.Schema() == "" || n.Schema() == t.Schema.Name)
   154  }