github.com/iasthc/atlas/cmd/atlas@v0.0.0-20230523071841-73246df3f88d/internal/sqlparse/sqliteparse/sqliteparse.go (about)

     1  // Copyright 2021-present The Atlas Authors. All rights reserved.
     2  // This source code is licensed under the Apache 2.0 license found
     3  // in the LICENSE file in the root directory of this source tree.
     4  
     5  package sqliteparse
     6  
     7  import (
     8  	"errors"
     9  	"fmt"
    10  	"strconv"
    11  	"strings"
    12  
    13  	"github.com/iasthc/atlas/cmd/atlas/internal/sqlparse/parseutil"
    14  	"github.com/iasthc/atlas/sql/migrate"
    15  	"github.com/iasthc/atlas/sql/schema"
    16  
    17  	"github.com/antlr/antlr4/runtime/Go/antlr"
    18  	"golang.org/x/exp/slices"
    19  )
    20  
    21  type (
    22  	// Stmt provides extended functionality
    23  	// to ANTLR parsed statements.
    24  	Stmt struct {
    25  		stmt  antlr.ParseTree
    26  		input string
    27  		err   error
    28  	}
    29  
    30  	// listenError catches parse errors.
    31  	listenError struct {
    32  		antlr.DefaultErrorListener
    33  		err  error
    34  		text string
    35  	}
    36  )
    37  
    38  // SyntaxError implements ErrorListener.SyntaxError.
    39  func (l *listenError) SyntaxError(_ antlr.Recognizer, _ any, line, column int, msg string, _ antlr.RecognitionException) {
    40  	if idx := strings.Index(msg, " expecting "); idx != -1 {
    41  		msg = msg[:idx]
    42  	}
    43  	l.err = fmt.Errorf("line %d:%d: %s", line, column+1, msg)
    44  }
    45  
    46  // ParseStmt parses a statement.
    47  func ParseStmt(text string) (stmt *Stmt, err error) {
    48  	l := &listenError{text: text}
    49  	defer func() {
    50  		if l.err != nil {
    51  			err = l.err
    52  			stmt = nil
    53  		} else if perr := recover(); perr != nil {
    54  			m := fmt.Sprint(perr)
    55  			if v, ok := err.(antlr.RecognitionException); ok {
    56  				m = v.GetMessage()
    57  			}
    58  			err = errors.New(m)
    59  			stmt = nil
    60  		}
    61  	}()
    62  	lex := NewLexer(antlr.NewInputStream(text))
    63  	lex.RemoveErrorListeners()
    64  	lex.AddErrorListener(l)
    65  	p := NewParser(
    66  		antlr.NewCommonTokenStream(lex, 0),
    67  	)
    68  	p.RemoveErrorListeners()
    69  	p.AddErrorListener(l)
    70  	p.BuildParseTrees = true
    71  	stmt = &Stmt{
    72  		stmt: p.Sql_stmt(),
    73  	}
    74  	return
    75  }
    76  
    77  // IsAlterTable reports if the statement is type ALTER TABLE.
    78  func (s *Stmt) IsAlterTable() bool {
    79  	if s.stmt.GetChildCount() != 1 {
    80  		return false
    81  	}
    82  	_, ok := s.stmt.GetChild(0).(*Alter_table_stmtContext)
    83  	return ok
    84  }
    85  
    86  // RenameColumn returns the renamed column information from the statement, if exists.
    87  func (s *Stmt) RenameColumn() (*parseutil.Rename, bool) {
    88  	if !s.IsAlterTable() {
    89  		return nil, false
    90  	}
    91  	alter := s.stmt.GetChild(0).(*Alter_table_stmtContext)
    92  	if alter.old_column_name == nil || alter.new_column_name == nil {
    93  		return nil, false
    94  	}
    95  	return &parseutil.Rename{
    96  		From: unquote(alter.old_column_name.GetText()),
    97  		To:   unquote(alter.new_column_name.GetText()),
    98  	}, true
    99  }
   100  
   101  // RenameTable returns the renamed table information from the statement, if exists.
   102  func (s *Stmt) RenameTable() (*parseutil.Rename, bool) {
   103  	if !s.IsAlterTable() {
   104  		return nil, false
   105  	}
   106  	alter := s.stmt.GetChild(0).(*Alter_table_stmtContext)
   107  	if alter.new_table_name == nil {
   108  		return nil, false
   109  	}
   110  	return &parseutil.Rename{
   111  		From: unquote(alter.Table_name(0).GetText()),
   112  		To:   unquote(alter.new_table_name.GetText()),
   113  	}, true
   114  }
   115  
   116  // TableUpdate reports if the statement is an UPDATE command for the given table.
   117  func (s *Stmt) TableUpdate(t *schema.Table) (*Update_stmtContext, bool) {
   118  	if s.stmt.GetChildCount() != 1 {
   119  		return nil, false
   120  	}
   121  	u, ok := s.stmt.GetChild(0).(*Update_stmtContext)
   122  	if !ok {
   123  		return nil, false
   124  	}
   125  	name, ok := u.Qualified_table_name().(*Qualified_table_nameContext)
   126  	if !ok || unquote(name.Table_name().GetText()) != t.Name {
   127  		return nil, false
   128  	}
   129  	return u, true
   130  }
   131  
   132  // CreateView reports if the statement is a CREATE VIEW command with the given name.
   133  func (s *Stmt) CreateView(name string) (*Create_view_stmtContext, bool) {
   134  	if s.stmt.GetChildCount() != 1 {
   135  		return nil, false
   136  	}
   137  	v, ok := s.stmt.GetChild(0).(*Create_view_stmtContext)
   138  	if !ok || unquote(v.View_name().GetText()) != name {
   139  		return nil, false
   140  	}
   141  	return v, true
   142  }
   143  
   144  // FileParser implements the sqlparse.Parser
   145  type FileParser struct{}
   146  
   147  // ColumnFilledBefore checks if the column was filled before the given position.
   148  func (p *FileParser) ColumnFilledBefore(f migrate.File, t *schema.Table, c *schema.Column, pos int) (bool, error) {
   149  	return parseutil.MatchStmtBefore(f, pos, func(s *migrate.Stmt) (bool, error) {
   150  		stmt, err := ParseStmt(s.Text)
   151  		if err != nil {
   152  			return false, err
   153  		}
   154  		u, ok := stmt.TableUpdate(t)
   155  		if !ok {
   156  			return false, nil
   157  		}
   158  		// Accept UPDATE that fills all rows or those with NULL values as we cannot
   159  		// determine if NULL values were filled in case there is a custom filtering.
   160  		affectC := func() bool {
   161  			x := u.GetWhere()
   162  			if x == nil {
   163  				return true
   164  			}
   165  			if x.GetChildCount() != 3 {
   166  				return false
   167  			}
   168  			x1, ok := x.GetChild(0).(*ExprContext)
   169  			if !ok || unquote(x1.GetText()) != c.Name {
   170  				return false
   171  			}
   172  			x2, ok := x.GetChild(1).(*antlr.TerminalNodeImpl)
   173  			if !ok || x2.GetSymbol().GetTokenType() != ParserIS_ {
   174  				return false
   175  			}
   176  			return isnull(x.GetChild(2))
   177  		}()
   178  		list, ok := u.Assignment_list().(*Assignment_listContext)
   179  		if !ok {
   180  			return false, nil
   181  		}
   182  		idx := slices.IndexFunc(list.AllAssignment(), func(a IAssignmentContext) bool {
   183  			as, ok := a.(*AssignmentContext)
   184  			return ok && unquote(as.Column_name().GetText()) == c.Name && !isnull(as.Expr())
   185  		})
   186  		// Ensure the column was filled.
   187  		return affectC && idx != -1, nil
   188  	})
   189  }
   190  
   191  // CreateViewAfter checks if a view was created after the position with the given name to a table.
   192  func (p *FileParser) CreateViewAfter(f migrate.File, old, new string, pos int) (bool, error) {
   193  	return parseutil.MatchStmtAfter(f, pos, func(s *migrate.Stmt) (bool, error) {
   194  		stmt, err := ParseStmt(s.Text)
   195  		if err != nil {
   196  			return false, err
   197  		}
   198  		v, ok := stmt.CreateView(old)
   199  		if !ok {
   200  			return false, nil
   201  		}
   202  		sc, ok := v.Select_stmt().(*Select_stmtContext)
   203  		if !ok {
   204  			return false, nil
   205  		}
   206  		idx := slices.IndexFunc(sc.Select_core(0).GetChildren(), func(t antlr.Tree) bool {
   207  			ts, ok := t.(*Table_or_subqueryContext)
   208  			return ok && unquote(ts.GetText()) == new
   209  		})
   210  		return idx != -1, nil
   211  	})
   212  }
   213  
   214  // FixChange fixes the changes according to the given statement.
   215  func (p *FileParser) FixChange(_ migrate.Driver, s string, changes schema.Changes) (schema.Changes, error) {
   216  	stmt, err := ParseStmt(s)
   217  	if err != nil {
   218  		return nil, err
   219  	}
   220  	if !stmt.IsAlterTable() {
   221  		return changes, nil
   222  	}
   223  	if r, ok := stmt.RenameColumn(); ok {
   224  		if len(changes) != 1 {
   225  			return nil, fmt.Errorf("unexected number fo changes: %d", len(changes))
   226  		}
   227  		modify, ok := changes[0].(*schema.ModifyTable)
   228  		if !ok {
   229  			return nil, fmt.Errorf("expected modify-table change for alter-table statement, but got: %T", changes[0])
   230  		}
   231  		// ALTER COLUMN cannot be combined with additional commands.
   232  		if len(changes) > 2 {
   233  			return nil, fmt.Errorf("unexpected number of changes found: %d", len(changes))
   234  		}
   235  		parseutil.RenameColumn(modify, r)
   236  	}
   237  	if r, ok := stmt.RenameTable(); ok {
   238  		changes = parseutil.RenameTable(changes, r)
   239  	}
   240  	return changes, nil
   241  }
   242  
   243  func isnull(t antlr.Tree) bool {
   244  	x, ok := t.(*ExprContext)
   245  	if !ok || x.GetChildCount() != 1 {
   246  		return false
   247  	}
   248  	l, ok := x.GetChild(0).(*Literal_valueContext)
   249  	return ok && l.GetChildCount() == 1 && len(l.GetTokens(ParserNULL_)) > 0
   250  }
   251  
   252  func unquote(s string) string {
   253  	switch {
   254  	case len(s) < 2:
   255  	case s[0] == '`' && s[len(s)-1] == '`', s[0] == '"' && s[len(s)-1] == '"':
   256  		if u, err := strconv.Unquote(s); err == nil {
   257  			return u
   258  		}
   259  	}
   260  	return s
   261  }