github.com/iasthc/atlas/cmd/atlas@v0.0.0-20230523071841-73246df3f88d/internal/sqlparse/sqliteparse/sqliteparse.go (about) 1 // Copyright 2021-present The Atlas Authors. All rights reserved. 2 // This source code is licensed under the Apache 2.0 license found 3 // in the LICENSE file in the root directory of this source tree. 4 5 package sqliteparse 6 7 import ( 8 "errors" 9 "fmt" 10 "strconv" 11 "strings" 12 13 "github.com/iasthc/atlas/cmd/atlas/internal/sqlparse/parseutil" 14 "github.com/iasthc/atlas/sql/migrate" 15 "github.com/iasthc/atlas/sql/schema" 16 17 "github.com/antlr/antlr4/runtime/Go/antlr" 18 "golang.org/x/exp/slices" 19 ) 20 21 type ( 22 // Stmt provides extended functionality 23 // to ANTLR parsed statements. 24 Stmt struct { 25 stmt antlr.ParseTree 26 input string 27 err error 28 } 29 30 // listenError catches parse errors. 31 listenError struct { 32 antlr.DefaultErrorListener 33 err error 34 text string 35 } 36 ) 37 38 // SyntaxError implements ErrorListener.SyntaxError. 39 func (l *listenError) SyntaxError(_ antlr.Recognizer, _ any, line, column int, msg string, _ antlr.RecognitionException) { 40 if idx := strings.Index(msg, " expecting "); idx != -1 { 41 msg = msg[:idx] 42 } 43 l.err = fmt.Errorf("line %d:%d: %s", line, column+1, msg) 44 } 45 46 // ParseStmt parses a statement. 47 func ParseStmt(text string) (stmt *Stmt, err error) { 48 l := &listenError{text: text} 49 defer func() { 50 if l.err != nil { 51 err = l.err 52 stmt = nil 53 } else if perr := recover(); perr != nil { 54 m := fmt.Sprint(perr) 55 if v, ok := err.(antlr.RecognitionException); ok { 56 m = v.GetMessage() 57 } 58 err = errors.New(m) 59 stmt = nil 60 } 61 }() 62 lex := NewLexer(antlr.NewInputStream(text)) 63 lex.RemoveErrorListeners() 64 lex.AddErrorListener(l) 65 p := NewParser( 66 antlr.NewCommonTokenStream(lex, 0), 67 ) 68 p.RemoveErrorListeners() 69 p.AddErrorListener(l) 70 p.BuildParseTrees = true 71 stmt = &Stmt{ 72 stmt: p.Sql_stmt(), 73 } 74 return 75 } 76 77 // IsAlterTable reports if the statement is type ALTER TABLE. 78 func (s *Stmt) IsAlterTable() bool { 79 if s.stmt.GetChildCount() != 1 { 80 return false 81 } 82 _, ok := s.stmt.GetChild(0).(*Alter_table_stmtContext) 83 return ok 84 } 85 86 // RenameColumn returns the renamed column information from the statement, if exists. 87 func (s *Stmt) RenameColumn() (*parseutil.Rename, bool) { 88 if !s.IsAlterTable() { 89 return nil, false 90 } 91 alter := s.stmt.GetChild(0).(*Alter_table_stmtContext) 92 if alter.old_column_name == nil || alter.new_column_name == nil { 93 return nil, false 94 } 95 return &parseutil.Rename{ 96 From: unquote(alter.old_column_name.GetText()), 97 To: unquote(alter.new_column_name.GetText()), 98 }, true 99 } 100 101 // RenameTable returns the renamed table information from the statement, if exists. 102 func (s *Stmt) RenameTable() (*parseutil.Rename, bool) { 103 if !s.IsAlterTable() { 104 return nil, false 105 } 106 alter := s.stmt.GetChild(0).(*Alter_table_stmtContext) 107 if alter.new_table_name == nil { 108 return nil, false 109 } 110 return &parseutil.Rename{ 111 From: unquote(alter.Table_name(0).GetText()), 112 To: unquote(alter.new_table_name.GetText()), 113 }, true 114 } 115 116 // TableUpdate reports if the statement is an UPDATE command for the given table. 117 func (s *Stmt) TableUpdate(t *schema.Table) (*Update_stmtContext, bool) { 118 if s.stmt.GetChildCount() != 1 { 119 return nil, false 120 } 121 u, ok := s.stmt.GetChild(0).(*Update_stmtContext) 122 if !ok { 123 return nil, false 124 } 125 name, ok := u.Qualified_table_name().(*Qualified_table_nameContext) 126 if !ok || unquote(name.Table_name().GetText()) != t.Name { 127 return nil, false 128 } 129 return u, true 130 } 131 132 // CreateView reports if the statement is a CREATE VIEW command with the given name. 133 func (s *Stmt) CreateView(name string) (*Create_view_stmtContext, bool) { 134 if s.stmt.GetChildCount() != 1 { 135 return nil, false 136 } 137 v, ok := s.stmt.GetChild(0).(*Create_view_stmtContext) 138 if !ok || unquote(v.View_name().GetText()) != name { 139 return nil, false 140 } 141 return v, true 142 } 143 144 // FileParser implements the sqlparse.Parser 145 type FileParser struct{} 146 147 // ColumnFilledBefore checks if the column was filled before the given position. 148 func (p *FileParser) ColumnFilledBefore(f migrate.File, t *schema.Table, c *schema.Column, pos int) (bool, error) { 149 return parseutil.MatchStmtBefore(f, pos, func(s *migrate.Stmt) (bool, error) { 150 stmt, err := ParseStmt(s.Text) 151 if err != nil { 152 return false, err 153 } 154 u, ok := stmt.TableUpdate(t) 155 if !ok { 156 return false, nil 157 } 158 // Accept UPDATE that fills all rows or those with NULL values as we cannot 159 // determine if NULL values were filled in case there is a custom filtering. 160 affectC := func() bool { 161 x := u.GetWhere() 162 if x == nil { 163 return true 164 } 165 if x.GetChildCount() != 3 { 166 return false 167 } 168 x1, ok := x.GetChild(0).(*ExprContext) 169 if !ok || unquote(x1.GetText()) != c.Name { 170 return false 171 } 172 x2, ok := x.GetChild(1).(*antlr.TerminalNodeImpl) 173 if !ok || x2.GetSymbol().GetTokenType() != ParserIS_ { 174 return false 175 } 176 return isnull(x.GetChild(2)) 177 }() 178 list, ok := u.Assignment_list().(*Assignment_listContext) 179 if !ok { 180 return false, nil 181 } 182 idx := slices.IndexFunc(list.AllAssignment(), func(a IAssignmentContext) bool { 183 as, ok := a.(*AssignmentContext) 184 return ok && unquote(as.Column_name().GetText()) == c.Name && !isnull(as.Expr()) 185 }) 186 // Ensure the column was filled. 187 return affectC && idx != -1, nil 188 }) 189 } 190 191 // CreateViewAfter checks if a view was created after the position with the given name to a table. 192 func (p *FileParser) CreateViewAfter(f migrate.File, old, new string, pos int) (bool, error) { 193 return parseutil.MatchStmtAfter(f, pos, func(s *migrate.Stmt) (bool, error) { 194 stmt, err := ParseStmt(s.Text) 195 if err != nil { 196 return false, err 197 } 198 v, ok := stmt.CreateView(old) 199 if !ok { 200 return false, nil 201 } 202 sc, ok := v.Select_stmt().(*Select_stmtContext) 203 if !ok { 204 return false, nil 205 } 206 idx := slices.IndexFunc(sc.Select_core(0).GetChildren(), func(t antlr.Tree) bool { 207 ts, ok := t.(*Table_or_subqueryContext) 208 return ok && unquote(ts.GetText()) == new 209 }) 210 return idx != -1, nil 211 }) 212 } 213 214 // FixChange fixes the changes according to the given statement. 215 func (p *FileParser) FixChange(_ migrate.Driver, s string, changes schema.Changes) (schema.Changes, error) { 216 stmt, err := ParseStmt(s) 217 if err != nil { 218 return nil, err 219 } 220 if !stmt.IsAlterTable() { 221 return changes, nil 222 } 223 if r, ok := stmt.RenameColumn(); ok { 224 if len(changes) != 1 { 225 return nil, fmt.Errorf("unexected number fo changes: %d", len(changes)) 226 } 227 modify, ok := changes[0].(*schema.ModifyTable) 228 if !ok { 229 return nil, fmt.Errorf("expected modify-table change for alter-table statement, but got: %T", changes[0]) 230 } 231 // ALTER COLUMN cannot be combined with additional commands. 232 if len(changes) > 2 { 233 return nil, fmt.Errorf("unexpected number of changes found: %d", len(changes)) 234 } 235 parseutil.RenameColumn(modify, r) 236 } 237 if r, ok := stmt.RenameTable(); ok { 238 changes = parseutil.RenameTable(changes, r) 239 } 240 return changes, nil 241 } 242 243 func isnull(t antlr.Tree) bool { 244 x, ok := t.(*ExprContext) 245 if !ok || x.GetChildCount() != 1 { 246 return false 247 } 248 l, ok := x.GetChild(0).(*Literal_valueContext) 249 return ok && l.GetChildCount() == 1 && len(l.GetTokens(ParserNULL_)) > 0 250 } 251 252 func unquote(s string) string { 253 switch { 254 case len(s) < 2: 255 case s[0] == '`' && s[len(s)-1] == '`', s[0] == '"' && s[len(s)-1] == '"': 256 if u, err := strconv.Unquote(s); err == nil { 257 return u 258 } 259 } 260 return s 261 }