github.com/iasthc/atlas/cmd/atlas@v0.0.0-20230523071841-73246df3f88d/internal/sqlparse/parseutil/parseutil.go (about)

     1  // Copyright 2021-present The Atlas Authors. All rights reserved.
     2  // This source code is licensed under the Apache 2.0 license found
     3  // in the LICENSE file in the root directory of this source tree.
     4  
     5  // Package parseutil exposes shared functions used by the different parsers.
     6  package parseutil
     7  
     8  import (
     9  	"github.com/iasthc/atlas/sql/migrate"
    10  	"github.com/iasthc/atlas/sql/schema"
    11  
    12  	"golang.org/x/exp/slices"
    13  )
    14  
    15  // Rename describes rename of a resource.
    16  type Rename struct {
    17  	From, To string
    18  }
    19  
    20  // RenameColumn patches DROP/ADD column commands to RENAME.
    21  func RenameColumn(modify *schema.ModifyTable, r *Rename) {
    22  	changes := schema.Changes(modify.Changes)
    23  	switch i, j := changes.IndexDropColumn(r.From), changes.IndexAddColumn(r.To); {
    24  	case j == -1:
    25  	// Rename column.
    26  	case i != -1:
    27  		changes[max(i, j)] = &schema.RenameColumn{
    28  			From: changes[i].(*schema.DropColumn).C,
    29  			To:   changes[j].(*schema.AddColumn).C,
    30  		}
    31  		changes.RemoveIndex(min(i, j))
    32  	// Rename column and add the previous name back.
    33  	default:
    34  		if i = changes.IndexModifyColumn(r.From); i == -1 {
    35  			// ADD COLUMN must come after the RENAME.
    36  			return
    37  		}
    38  		modify := changes[i].(*schema.ModifyColumn)
    39  		changes[min(i, j)] = &schema.RenameColumn{
    40  			From: modify.From,
    41  			To:   changes[j].(*schema.AddColumn).C,
    42  		}
    43  		changes[max(i, j)] = &schema.AddColumn{
    44  			C: modify.To,
    45  		}
    46  	}
    47  	modify.Changes = changes
    48  }
    49  
    50  // RenameIndex patches DROP/ADD index commands to RENAME.
    51  func RenameIndex(modify *schema.ModifyTable, r *Rename) {
    52  	changes := schema.Changes(modify.Changes)
    53  	i := changes.IndexDropIndex(r.From)
    54  	j := changes.IndexAddIndex(r.To)
    55  	if i != -1 && j != -1 {
    56  		changes[max(i, j)] = &schema.RenameIndex{
    57  			From: changes[i].(*schema.DropIndex).I,
    58  			To:   changes[j].(*schema.AddIndex).I,
    59  		}
    60  		changes.RemoveIndex(min(i, j))
    61  		modify.Changes = changes
    62  	}
    63  }
    64  
    65  // RenameTable patches DROP/ADD table commands to RENAME.
    66  func RenameTable(changes schema.Changes, r *Rename) schema.Changes {
    67  	i := changes.LastIndexDropTable(r.From)
    68  	j := changes.LastIndexAddTable(r.To)
    69  	if i != -1 && j != -1 {
    70  		changes[max(i, j)] = &schema.RenameTable{
    71  			From: changes[i].(*schema.DropTable).T,
    72  			To:   changes[j].(*schema.AddTable).T,
    73  		}
    74  		changes.RemoveIndex(min(i, j))
    75  	}
    76  	return changes
    77  }
    78  
    79  // MatchStmtBefore reports if the file contains any statement that matches the predicate before the given position.
    80  func MatchStmtBefore(f migrate.File, pos int, p func(*migrate.Stmt) (bool, error)) (bool, error) {
    81  	stmts, err := f.StmtDecls()
    82  	if err != nil {
    83  		return false, err
    84  	}
    85  	i := slices.IndexFunc(stmts, func(s *migrate.Stmt) bool {
    86  		return s.Pos >= pos
    87  	})
    88  	if i != -1 {
    89  		stmts = stmts[:i]
    90  	}
    91  	for _, s := range stmts {
    92  		m, err := p(s)
    93  		if err != nil {
    94  			return false, err
    95  		}
    96  		if m {
    97  			return true, nil
    98  		}
    99  	}
   100  	return false, nil
   101  }
   102  
   103  // MatchStmtAfter reports if the file contains any statement that matches the predicate after the given position.
   104  func MatchStmtAfter(f migrate.File, pos int, p func(*migrate.Stmt) (bool, error)) (bool, error) {
   105  	stmts, err := f.StmtDecls()
   106  	if err != nil {
   107  		return false, err
   108  	}
   109  	i := slices.IndexFunc(stmts, func(s *migrate.Stmt) bool {
   110  		return s.Pos > pos
   111  	})
   112  	if i == -1 {
   113  		return false, nil
   114  	}
   115  	stmts = stmts[i:]
   116  	for _, s := range stmts {
   117  		m, err := p(s)
   118  		if err != nil {
   119  			return false, err
   120  		}
   121  		if m {
   122  			return true, nil
   123  		}
   124  	}
   125  	return false, nil
   126  }
   127  
   128  func max(i, j int) int {
   129  	if i > j {
   130  		return i
   131  	}
   132  	return j
   133  }
   134  
   135  func min(i, j int) int {
   136  	if i < j {
   137  		return i
   138  	}
   139  	return j
   140  }