github.com/dolthub/go-mysql-server@v0.18.0/sql/planbuilder/scope.go (about) 1 // Copyright 2023 Dolthub, Inc. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 package planbuilder 16 17 import ( 18 "fmt" 19 "strings" 20 21 ast "github.com/dolthub/vitess/go/vt/sqlparser" 22 23 "github.com/dolthub/go-mysql-server/sql" 24 "github.com/dolthub/go-mysql-server/sql/expression" 25 "github.com/dolthub/go-mysql-server/sql/plan" 26 "github.com/dolthub/go-mysql-server/sql/transform" 27 ) 28 29 // scope tracks relational dependencies necessary to type check expressions, 30 // resolve name definitions, and build relational nodes. 31 type scope struct { 32 b *Builder 33 parent *scope 34 ast ast.SQLNode 35 node sql.Node 36 37 activeSubquery *subquery 38 refsSubquery bool 39 40 // cols are definitions provided by this scope 41 cols []scopeColumn 42 colset sql.ColSet 43 // extraCols are auxillary output columns required 44 // for sorting or grouping 45 extraCols []scopeColumn 46 // redirectCol is used for using and natural joins right-table 47 // attributes that redirect to the left table intersection 48 redirectCol map[string]scopeColumn 49 // tables are the list of table definitions in this scope 50 tables map[string]sql.TableId 51 // ctes are common table expressions defined in this scope 52 // TODO these should be case-sensitive 53 ctes map[string]*scope 54 // groupBy collects aggregation functions and inputs 55 groupBy *groupBy 56 // windowFuncs is a list of window functions in the current scope 57 windowFuncs []scopeColumn 58 windowDefs map[string]*sql.WindowDefinition 59 // exprs collects unique expression ids for reference 60 exprs map[string]columnId 61 proc *procCtx 62 } 63 64 // resolveColumn matches a variable use to a column definition with a unique 65 // expression id. |chooseFirst| is indicated for accepting ambiguous having and 66 // group by columns. 67 func (s *scope) resolveColumn(db, table, col string, checkParent, chooseFirst bool) (scopeColumn, bool) { 68 // procedure params take precedence 69 if table == "" && checkParent && s.procActive() { 70 col, ok := s.proc.GetVar(col) 71 if ok { 72 return col, true 73 } 74 } 75 76 // Unqualified columns that have been redirected should return early to avoid ambiguous column errors. 77 if table == "" && s.redirectCol != nil { 78 if rCol, ok := s.redirectCol[col]; ok { 79 return rCol, true 80 } 81 } 82 83 var found scopeColumn 84 var foundCand bool 85 for _, c := range s.cols { 86 if strings.EqualFold(c.col, col) && (strings.EqualFold(c.table, table) || table == "") && (strings.EqualFold(c.db, db) || db == "") { 87 if foundCand { 88 if found.equals(c) { 89 continue 90 } 91 92 if !s.b.TriggerCtx().Call && len(s.b.TriggerCtx().UnresolvedTables) > 0 { 93 c, ok := s.triggerCol(table, col) 94 if ok { 95 return c, true 96 } 97 } 98 if c.table == OnDupValuesPrefix { 99 return found, true 100 } else if found.table == OnDupValuesPrefix { 101 return c, true 102 } 103 err := sql.ErrAmbiguousColumnName.New(col, []string{c.table, found.table}) 104 if c.table == "" { 105 err = sql.ErrAmbiguousColumnOrAliasName.New(c.col) 106 } 107 s.handleErr(err) 108 } 109 if chooseFirst || s.groupBy != nil { 110 return c, true 111 } 112 found = c 113 foundCand = true 114 } 115 } 116 if foundCand { 117 return found, true 118 } 119 120 if s.groupBy != nil { 121 if c, ok := s.groupBy.outScope.resolveColumn(db, table, col, false, false); ok { 122 return c, true 123 } 124 } 125 126 if !s.b.TriggerCtx().Call && len(s.b.TriggerCtx().UnresolvedTables) > 0 { 127 c, ok := s.triggerCol(table, col) 128 if ok { 129 return c, true 130 } 131 } 132 133 if !checkParent || s.parent == nil { 134 return scopeColumn{}, false 135 } 136 137 c, foundCand := s.parent.resolveColumn(db, table, col, true, false) 138 if !foundCand { 139 return scopeColumn{}, false 140 } 141 142 if s.parent.activeSubquery != nil { 143 s.parent.activeSubquery.addOutOfScope(c.id) 144 } 145 return c, true 146 } 147 148 func (s *scope) hasTable(table string) bool { 149 _, ok := s.tables[strings.ToLower(table)] 150 if ok { 151 return true 152 } 153 if s.parent != nil { 154 return s.parent.hasTable(table) 155 } 156 return false 157 } 158 159 // triggerCol is used to hallucinate a new column during trigger DDL 160 // when we fail a resolveColumn. 161 func (s *scope) triggerCol(table, col string) (scopeColumn, bool) { 162 // hallucinate tablecol 163 dbName := "" 164 if s.b.currentDatabase != nil { 165 dbName = s.b.currentDatabase.Name() 166 } 167 for _, t := range s.b.TriggerCtx().UnresolvedTables { 168 if strings.EqualFold(t, table) { 169 col := scopeColumn{db: dbName, table: table, col: col} 170 id := s.newColumn(col) 171 col.id = id 172 return col, true 173 } 174 } 175 if table == "" { 176 col := scopeColumn{db: dbName, table: table, col: col} 177 id := s.newColumn(col) 178 col.id = id 179 return col, true 180 } 181 return scopeColumn{}, false 182 } 183 184 // getExpr returns a columnId if the given expression has 185 // been built. 186 func (s *scope) getExpr(name string, checkCte bool) (columnId, bool) { 187 n := strings.ToLower(name) 188 id, ok := s.exprs[n] 189 if !ok && s.groupBy != nil { 190 id, ok = s.groupBy.outScope.getExpr(n, checkCte) 191 } 192 if !ok && checkCte && s.ctes != nil { 193 for _, cte := range s.ctes { 194 id, ok = cte.getExpr(n, false) 195 if ok { 196 break 197 } 198 } 199 } 200 // TODO: possibly want to look in parent scopes 201 if !ok && s.parent != nil { 202 return s.parent.getExpr(name, checkCte) 203 } 204 return id, ok 205 } 206 207 func (s *scope) procActive() bool { 208 return s.proc != nil 209 } 210 211 func (s *scope) initProc() { 212 s.proc = &procCtx{ 213 s: s, 214 conditions: make(map[string]*plan.DeclareCondition), 215 cursors: make(map[string]struct{}), 216 vars: make(map[string]scopeColumn), 217 labels: make(map[string]bool), 218 lastState: dsVariable, 219 } 220 } 221 222 // initGroupBy creates a container scope for aggregation 223 // functions and function inputs. 224 func (s *scope) initGroupBy() { 225 s.groupBy = &groupBy{outScope: s.replace()} 226 } 227 228 // pushSubquery creates a new scope with the subquery already initialized. 229 func (s *scope) pushSubquery() *scope { 230 newScope := s.push() 231 newScope.activeSubquery = &subquery{parent: s.nearestSubquery()} 232 return newScope 233 } 234 235 // replaceSubquery creates a new scope with the subquery already initialized. 236 func (s *scope) replaceSubquery() *scope { 237 newScope := s.replace() 238 newScope.activeSubquery = &subquery{parent: s.nearestSubquery()} 239 return newScope 240 } 241 242 // initSubquery creates a container for tracking out of scope 243 // column references and volatile functions. 244 func (s *scope) initSubquery() { 245 s.activeSubquery = &subquery{} 246 } 247 248 func (s *scope) correlated() sql.ColSet { 249 if s.activeSubquery == nil { 250 return sql.ColSet{} 251 } 252 return s.activeSubquery.correlated 253 } 254 255 func (s *scope) volatile() bool { 256 if s.activeSubquery == nil { 257 return false 258 } 259 return s.activeSubquery.volatile 260 } 261 262 func (s *scope) nearestSubquery() *subquery { 263 n := s 264 for n != nil { 265 if n.activeSubquery != nil { 266 return n.activeSubquery 267 } 268 n = n.parent 269 } 270 return nil 271 } 272 273 // setTableAlias updates column definitions in this scope to 274 // appear sourced from a new table name. 275 func (s *scope) setTableAlias(t string) { 276 t = strings.ToLower(t) 277 var oldTable string 278 for i := range s.cols { 279 beforeColStr := s.cols[i].String() 280 if oldTable == "" { 281 oldTable = s.cols[i].table 282 } 283 s.cols[i].table = t 284 id, ok := s.getExpr(beforeColStr, true) 285 if ok { 286 // todo better way to do projections 287 delete(s.exprs, beforeColStr) 288 } 289 s.exprs[strings.ToLower(s.cols[i].String())] = id 290 } 291 id, ok := s.tables[oldTable] 292 if !ok { 293 return 294 } 295 delete(s.tables, oldTable) 296 if s.tables == nil { 297 s.tables = make(map[string]sql.TableId) 298 } 299 s.tables[t] = id 300 } 301 302 // setColAlias updates the column name definitions for this scope 303 // to the names in the input list. 304 func (s *scope) setColAlias(cols []string) { 305 if len(cols) != len(s.cols) { 306 err := sql.ErrColumnCountMismatch.New() 307 s.b.handleErr(err) 308 } 309 ids := make([]columnId, len(cols)) 310 for i := range s.cols { 311 beforeColStr := s.cols[i].String() 312 id, ok := s.getExpr(beforeColStr, true) 313 if ok { 314 // todo better way to do projections 315 delete(s.exprs, beforeColStr) 316 } 317 ids[i] = id 318 delete(s.exprs, beforeColStr) 319 } 320 for i := range s.cols { 321 name := strings.ToLower(cols[i]) 322 s.cols[i].col = name 323 s.exprs[s.cols[i].String()] = ids[i] 324 } 325 } 326 327 // push creates a new scope referencing the current scope as a 328 // parent. Variables in the new scope will have name visibility 329 // into this scope. 330 func (s *scope) push() *scope { 331 new := &scope{ 332 b: s.b, 333 parent: s, 334 } 335 if s.procActive() { 336 new.initProc() 337 } 338 return new 339 } 340 341 // replace creates a new scope with the same parent definition 342 // visibility as the current scope. Useful for groupby and subqueries 343 // that have more complex naming hierarchy. 344 func (s *scope) replace() *scope { 345 if s == nil { 346 return &scope{} 347 } 348 return &scope{ 349 b: s.b, 350 parent: s.parent, 351 } 352 } 353 354 // aliasCte copies a scope, but increments the column and table ids 355 // for the new relation. 356 func (s *scope) aliasCte(alias string) *scope { 357 if s == nil { 358 return nil 359 } 360 outScope := s.copy() 361 if _, ok := s.tables[alias]; ok || alias == "" { 362 return outScope 363 } 364 365 sq, _ := outScope.node.(*plan.SubqueryAlias) 366 367 tabId := outScope.addTable(alias) 368 outScope.cols = nil 369 var colSet sql.ColSet 370 scopeMapping := make(map[sql.ColumnId]sql.Expression) 371 for _, c := range s.cols { 372 id := outScope.newColumn(scopeColumn{ 373 tableId: tabId, 374 db: c.db, 375 table: alias, 376 col: c.col, 377 originalCol: c.originalCol, 378 id: 0, 379 typ: c.typ, 380 nullable: c.nullable, 381 }) 382 colSet.Add(sql.ColumnId(id)) 383 // todo double scope mapping 384 if sq != nil { 385 scopeMapping[sql.ColumnId(id)] = sq.ScopeMapping[sql.ColumnId(c.id)] 386 } 387 } 388 389 if sq != nil { 390 outScope.node = sq.WithScopeMapping(scopeMapping).WithColumns(colSet).WithId(tabId) 391 } 392 return outScope 393 } 394 395 // copy produces an identical scope with copied references. 396 func (s *scope) copy() *scope { 397 if s == nil { 398 return nil 399 } 400 401 ret := *s 402 if ret.node != nil { 403 ret.node, _ = DeepCopyNode(s.node) 404 } 405 if s.tables != nil { 406 ret.tables = make(map[string]sql.TableId, len(s.tables)) 407 for k, v := range s.tables { 408 ret.tables[k] = v 409 } 410 } 411 if s.ctes != nil { 412 ret.ctes = make(map[string]*scope, len(s.ctes)) 413 for k, v := range s.ctes { 414 ret.ctes[k] = v 415 } 416 } 417 if s.exprs != nil { 418 ret.exprs = make(map[string]columnId, len(s.exprs)) 419 for k, v := range s.exprs { 420 ret.exprs[k] = v 421 } 422 } 423 if s.groupBy != nil { 424 gbCopy := *s.groupBy 425 ret.groupBy = &gbCopy 426 } 427 if s.cols != nil { 428 ret.cols = make([]scopeColumn, len(s.cols)) 429 copy(ret.cols, s.cols) 430 } 431 if !s.colset.Empty() { 432 ret.colset = s.colset.Copy() 433 } 434 435 return &ret 436 } 437 438 // DeepCopyNode copies a sql.Node. 439 func DeepCopyNode(node sql.Node) (sql.Node, error) { 440 n, _, err := transform.NodeExprs(node, func(e sql.Expression) (sql.Expression, transform.TreeIdentity, error) { 441 e, err := transform.Clone(e) 442 return e, transform.NewTree, err 443 }) 444 return n, err 445 } 446 447 // addCte adds a cte definition to this scope for table resolution. 448 func (s *scope) addCte(name string, cteScope *scope) { 449 if s.ctes == nil { 450 s.ctes = make(map[string]*scope) 451 } 452 s.ctes[name] = cteScope 453 s.addTable(name) 454 } 455 456 // getCte attempts to resolve a table name as a cte definition. 457 func (s *scope) getCte(name string) *scope { 458 checkScope := s 459 for checkScope != nil { 460 if checkScope.ctes != nil { 461 cte, ok := checkScope.ctes[strings.ToLower(name)] 462 if ok { 463 return cte 464 } 465 } 466 checkScope = checkScope.parent 467 } 468 return nil 469 } 470 471 // redirect overwrites a definition with an alias rewrite, 472 // without preventing us from resolving the original column. 473 // This is used for resolving natural join projections. 474 func (s *scope) redirect(from, to scopeColumn) { 475 if s.redirectCol == nil { 476 s.redirectCol = make(map[string]scopeColumn) 477 } 478 s.redirectCol[from.String()] = to 479 } 480 481 // addColumn interns and saves the given column to this scope. 482 // todo: new IR should absorb interning and use bitmaps for 483 // column identity 484 func (s *scope) addColumn(col scopeColumn) { 485 s.cols = append(s.cols, col) 486 s.colset.Add(sql.ColumnId(col.id)) 487 if s.exprs == nil { 488 s.exprs = make(map[string]columnId) 489 } 490 s.exprs[strings.ToLower(col.String())] = col.id 491 return 492 } 493 494 // newColumn adds the column to the current scope and assigns a 495 // new columnId for referencing. newColumn builds a new expression 496 // reference, whereas addColumn only adds a preexisting expression 497 // definition to a given scope. 498 func (s *scope) newColumn(col scopeColumn) columnId { 499 s.b.colId++ 500 col.id = s.b.colId 501 if col.table != "" { 502 tabId := s.addTable(col.table) 503 col.tableId = tabId 504 } 505 s.addColumn(col) 506 return col.id 507 } 508 509 // addTable records adds a table name defined in this scope 510 func (s *scope) addTable(name string) sql.TableId { 511 if name == "" { 512 return 0 513 } 514 name = strings.ToLower(name) 515 if s.tables == nil { 516 s.tables = make(map[string]sql.TableId) 517 } 518 if _, ok := s.tables[name]; !ok { 519 s.b.tabId++ 520 s.tables[name] = s.b.tabId 521 } 522 return s.tables[name] 523 } 524 525 // addExtraColumn marks an auxiliary column used in an 526 // aggregation, sorting, or having clause. 527 func (s *scope) addExtraColumn(col scopeColumn) { 528 s.extraCols = append(s.extraCols, col) 529 } 530 531 func (s *scope) addColumns(cols []scopeColumn) { 532 s.cols = append(s.cols, cols...) 533 } 534 535 // appendColumnsFromScope merges column definitions for 536 // multi-relational expressions. 537 func (s *scope) appendColumnsFromScope(src *scope) { 538 s.cols = append(s.cols, src.cols...) 539 if len(src.exprs) > 0 && s.exprs == nil { 540 s.exprs = make(map[string]columnId) 541 } 542 for k, v := range src.exprs { 543 s.exprs[k] = v 544 } 545 if len(src.redirectCol) > 0 && s.redirectCol == nil { 546 s.redirectCol = make(map[string]scopeColumn) 547 } 548 for k, v := range src.redirectCol { 549 s.redirectCol[k] = v 550 } 551 if len(src.tables) > 0 && s.tables == nil { 552 s.tables = make(map[string]sql.TableId) 553 } 554 for k, v := range src.tables { 555 s.tables[k] = v 556 } 557 // these become pass-through columns in the new scope. 558 for i := len(src.cols); i < len(s.cols); i++ { 559 s.cols[i].scalar = nil 560 } 561 } 562 563 func (s *scope) handleErr(err error) { 564 panic(parseErr{err}) 565 } 566 567 // tableId and columnId are temporary ways to track expression 568 // and name uniqueness. 569 // todo: the plan format should track these 570 type tableId uint16 571 type columnId uint16 572 573 type scopeColumn struct { 574 nullable bool 575 descending bool 576 outOfScope bool 577 id columnId 578 typ sql.Type 579 scalar sql.Expression 580 tableId sql.TableId 581 db string 582 table string 583 col string 584 originalCol string 585 } 586 587 // empty returns true if a scopeColumn is the null value 588 func (c scopeColumn) empty() bool { 589 return c.id == 0 590 } 591 592 func (c scopeColumn) equals(other scopeColumn) bool { 593 if c.id == other.id { 594 return true 595 } 596 if c.unwrapGetFieldAliasId() == other.unwrapGetFieldAliasId() { 597 return true 598 } 599 return false 600 } 601 602 func (c scopeColumn) unwrapGetFieldAliasId() columnId { 603 if c.scalar != nil { 604 if a, ok := c.scalar.(*expression.Alias); ok { 605 if gf, ok := a.Child.(*expression.GetField); ok { 606 return columnId(gf.Id()) 607 } 608 } 609 } 610 return c.id 611 } 612 613 func (c scopeColumn) withOriginal(col string) scopeColumn { 614 if !strings.EqualFold(c.db, sql.InformationSchemaDatabaseName) { 615 // info schema columns always presented as uppercase 616 c.originalCol = col 617 } 618 return c 619 } 620 621 // scalarGf returns a getField reference to this column's expression. 622 func (c scopeColumn) scalarGf() sql.Expression { 623 if c.scalar != nil { 624 if p, ok := c.scalar.(*expression.ProcedureParam); ok { 625 return p 626 } 627 } 628 if c.originalCol != "" { 629 return expression.NewGetFieldWithTable(int(c.id), int(c.tableId), c.typ, c.db, c.table, c.originalCol, c.nullable) 630 } 631 return expression.NewGetFieldWithTable(int(c.id), int(c.tableId), c.typ, c.db, c.table, c.col, c.nullable) 632 } 633 634 func (c scopeColumn) String() string { 635 if c.table == "" { 636 return c.col 637 } else { 638 return fmt.Sprintf("%s.%s", c.table, c.col) 639 } 640 }