vitess.io/vitess@v0.16.2/go/vt/vtgate/semantics/scoper.go (about) 1 /* 2 Copyright 2021 The Vitess Authors. 3 4 Licensed under the Apache License, Version 2.0 (the "License"); 5 you may not use this file except in compliance with the License. 6 You may obtain a copy of the License at 7 8 http://www.apache.org/licenses/LICENSE-2.0 9 10 Unless required by applicable law or agreed to in writing, software 11 distributed under the License is distributed on an "AS IS" BASIS, 12 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 See the License for the specific language governing permissions and 14 limitations under the License. 15 */ 16 17 package semantics 18 19 import ( 20 "reflect" 21 22 vtrpcpb "vitess.io/vitess/go/vt/proto/vtrpc" 23 "vitess.io/vitess/go/vt/vterrors" 24 "vitess.io/vitess/go/vt/vtgate/engine" 25 26 "vitess.io/vitess/go/vt/sqlparser" 27 ) 28 29 type ( 30 // scoper is responsible for figuring out the scoping for the query, 31 // and keeps the current scope when walking the tree 32 scoper struct { 33 rScope map[*sqlparser.Select]*scope 34 wScope map[*sqlparser.Select]*scope 35 scopes []*scope 36 org originable 37 binder *binder 38 39 // These scopes are only used for rewriting ORDER BY 1 and GROUP BY 1 40 specialExprScopes map[*sqlparser.Literal]*scope 41 } 42 43 scope struct { 44 parent *scope 45 stmt sqlparser.Statement 46 tables []TableInfo 47 isUnion bool 48 joinUsing map[string]TableSet 49 stmtScope bool 50 } 51 ) 52 53 func newScoper() *scoper { 54 return &scoper{ 55 rScope: map[*sqlparser.Select]*scope{}, 56 wScope: map[*sqlparser.Select]*scope{}, 57 specialExprScopes: map[*sqlparser.Literal]*scope{}, 58 } 59 } 60 61 func (s *scoper) down(cursor *sqlparser.Cursor) error { 62 node := cursor.Node() 63 switch node := node.(type) { 64 case *sqlparser.Update, *sqlparser.Delete: 65 currScope := newScope(s.currentScope()) 66 currScope.stmtScope = true 67 s.push(currScope) 68 69 currScope.stmt = node.(sqlparser.Statement) 70 case *sqlparser.Select: 71 currScope := newScope(s.currentScope()) 72 currScope.stmtScope = true 73 s.push(currScope) 74 75 // Needed for order by with Literal to find the Expression. 76 currScope.stmt = node 77 78 s.rScope[node] = currScope 79 s.wScope[node] = newScope(nil) 80 case sqlparser.TableExpr: 81 if isParentSelect(cursor) { 82 // when checking the expressions used in JOIN conditions, special rules apply where the ON expression 83 // can only see the two tables involved in the JOIN, and no other tables of that select statement. 84 // They are allowed to see the tables of the outer select query. 85 // To create this special context, we will find the parent scope of the select statement involved. 86 nScope := newScope(s.currentScope().findParentScopeOfStatement()) 87 nScope.stmt = cursor.Parent().(*sqlparser.Select) 88 s.push(nScope) 89 } 90 case sqlparser.SelectExprs: 91 sel, parentIsSelect := cursor.Parent().(*sqlparser.Select) 92 if !parentIsSelect { 93 break 94 } 95 96 // adding a vTableInfo for each SELECT, so it can be used by GROUP BY, HAVING, ORDER BY 97 // the vTableInfo we are creating here should not be confused with derived tables' vTableInfo 98 wScope, exists := s.wScope[sel] 99 if !exists { 100 break 101 } 102 wScope.tables = []TableInfo{createVTableInfoForExpressions(node, s.currentScope().tables, s.org)} 103 case sqlparser.OrderBy: 104 if isParentSelectStatement(cursor) { 105 err := s.createSpecialScopePostProjection(cursor.Parent()) 106 if err != nil { 107 return err 108 } 109 for _, order := range node { 110 lit := keepIntLiteral(order.Expr) 111 if lit != nil { 112 s.specialExprScopes[lit] = s.currentScope() 113 } 114 } 115 } 116 case sqlparser.GroupBy: 117 err := s.createSpecialScopePostProjection(cursor.Parent()) 118 if err != nil { 119 return err 120 } 121 for _, expr := range node { 122 lit := keepIntLiteral(expr) 123 if lit != nil { 124 s.specialExprScopes[lit] = s.currentScope() 125 } 126 } 127 case *sqlparser.Where: 128 if node.Type != sqlparser.HavingClause { 129 break 130 } 131 return s.createSpecialScopePostProjection(cursor.Parent()) 132 case *sqlparser.DerivedTable: 133 if node.Lateral { 134 return vterrors.VT12001("lateral derived tables") 135 } 136 } 137 return nil 138 } 139 140 func keepIntLiteral(e sqlparser.Expr) *sqlparser.Literal { 141 coll, ok := e.(*sqlparser.CollateExpr) 142 if ok { 143 e = coll.Expr 144 } 145 l, ok := e.(*sqlparser.Literal) 146 if !ok { 147 return nil 148 } 149 if l.Type != sqlparser.IntVal { 150 return nil 151 } 152 return l 153 } 154 155 func (s *scoper) up(cursor *sqlparser.Cursor) error { 156 node := cursor.Node() 157 switch node := node.(type) { 158 case sqlparser.OrderBy: 159 if isParentSelectStatement(cursor) { 160 s.popScope() 161 } 162 case *sqlparser.Select, sqlparser.GroupBy, *sqlparser.Update: 163 s.popScope() 164 case *sqlparser.Where: 165 if node.Type != sqlparser.HavingClause { 166 break 167 } 168 s.popScope() 169 case sqlparser.TableExpr: 170 if isParentSelect(cursor) { 171 curScope := s.currentScope() 172 s.popScope() 173 earlierScope := s.currentScope() 174 // copy curScope into the earlierScope 175 for _, table := range curScope.tables { 176 err := earlierScope.addTable(table) 177 if err != nil { 178 return err 179 } 180 } 181 } 182 } 183 return nil 184 } 185 186 func ValidAsMapKey(s sqlparser.SQLNode) bool { 187 return reflect.TypeOf(s).Comparable() 188 } 189 190 // createSpecialScopePostProjection is used for the special projection in ORDER BY, GROUP BY and HAVING 191 func (s *scoper) createSpecialScopePostProjection(parent sqlparser.SQLNode) error { 192 switch parent := parent.(type) { 193 case *sqlparser.Select: 194 // In ORDER BY, GROUP BY and HAVING, we can see both the scope in the FROM part of the query, and the SELECT columns created 195 // so before walking the rest of the tree, we change the scope to match this behaviour 196 incomingScope := s.currentScope() 197 nScope := newScope(incomingScope) 198 nScope.tables = s.wScope[parent].tables 199 nScope.stmt = incomingScope.stmt 200 s.push(nScope) 201 202 if s.rScope[parent] != incomingScope { 203 return vterrors.Errorf(vtrpcpb.Code_INTERNAL, "BUG: scope counts did not match") 204 } 205 case *sqlparser.Union: 206 nScope := newScope(nil) 207 nScope.isUnion = true 208 var tableInfo *vTableInfo 209 210 for i, sel := range sqlparser.GetAllSelects(parent) { 211 if i == 0 { 212 nScope.stmt = sel 213 tableInfo = createVTableInfoForExpressions(sel.SelectExprs, nil /*needed for star expressions*/, s.org) 214 nScope.tables = append(nScope.tables, tableInfo) 215 } 216 thisTableInfo := createVTableInfoForExpressions(sel.SelectExprs, nil /*needed for star expressions*/, s.org) 217 if len(tableInfo.cols) != len(thisTableInfo.cols) { 218 return engine.ErrWrongNumberOfColumnsInSelect 219 } 220 for i, col := range tableInfo.cols { 221 // at this stage, we don't store the actual dependencies, we only store the expressions. 222 // only later will we walk the expression tree and figure out the deps. so, we need to create a 223 // composite expression that contains all the expressions in the SELECTs that this UNION consists of 224 tableInfo.cols[i] = sqlparser.AndExpressions(col, thisTableInfo.cols[i]) 225 } 226 } 227 228 s.push(nScope) 229 } 230 return nil 231 } 232 233 func (s *scoper) currentScope() *scope { 234 size := len(s.scopes) 235 if size == 0 { 236 return nil 237 } 238 return s.scopes[size-1] 239 } 240 241 func (s *scoper) push(sc *scope) { 242 s.scopes = append(s.scopes, sc) 243 } 244 245 func (s *scoper) popScope() { 246 usingMap := s.currentScope().prepareUsingMap() 247 for ts, m := range usingMap { 248 s.binder.usingJoinInfo[ts] = m 249 } 250 l := len(s.scopes) - 1 251 s.scopes = s.scopes[:l] 252 } 253 254 func newScope(parent *scope) *scope { 255 return &scope{ 256 parent: parent, 257 joinUsing: map[string]TableSet{}, 258 } 259 } 260 261 func (s *scope) addTable(info TableInfo) error { 262 name, err := info.Name() 263 if err != nil { 264 return err 265 } 266 tblName := name.Name.String() 267 for _, table := range s.tables { 268 name, err := table.Name() 269 if err != nil { 270 return err 271 } 272 273 if tblName == name.Name.String() { 274 return vterrors.VT03013(name.Name.String()) 275 } 276 } 277 s.tables = append(s.tables, info) 278 return nil 279 } 280 281 func (s *scope) prepareUsingMap() (result map[TableSet]map[string]TableSet) { 282 result = map[TableSet]map[string]TableSet{} 283 for colName, tss := range s.joinUsing { 284 for _, ts := range tss.Constituents() { 285 m := result[ts] 286 if m == nil { 287 m = map[string]TableSet{} 288 } 289 m[colName] = tss 290 result[ts] = m 291 } 292 } 293 return 294 } 295 296 // findParentScopeOfStatement finds the scope that belongs to a statement. 297 func (s *scope) findParentScopeOfStatement() *scope { 298 if s.stmtScope { 299 return s.parent 300 } 301 if s.parent == nil { 302 return nil 303 } 304 return s.parent.findParentScopeOfStatement() 305 }