vitess.io/vitess@v0.16.2/go/vt/vtgate/semantics/semantic_state.go (about) 1 /* 2 Copyright 2020 The Vitess Authors. 3 4 Licensed under the Apache License, Version 2.0 (the "License"); 5 you may not use this file except in compliance with the License. 6 You may obtain a copy of the License at 7 8 http://www.apache.org/licenses/LICENSE-2.0 9 10 Unless required by applicable law or agreed to in writing, software 11 distributed under the License is distributed on an "AS IS" BASIS, 12 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 See the License for the specific language governing permissions and 14 limitations under the License. 15 */ 16 17 package semantics 18 19 import ( 20 "vitess.io/vitess/go/mysql/collations" 21 "vitess.io/vitess/go/sqltypes" 22 "vitess.io/vitess/go/vt/key" 23 querypb "vitess.io/vitess/go/vt/proto/query" 24 topodatapb "vitess.io/vitess/go/vt/proto/topodata" 25 vtrpcpb "vitess.io/vitess/go/vt/proto/vtrpc" 26 "vitess.io/vitess/go/vt/vterrors" 27 "vitess.io/vitess/go/vt/vtgate/evalengine" 28 "vitess.io/vitess/go/vt/vtgate/vindexes" 29 30 "vitess.io/vitess/go/vt/sqlparser" 31 ) 32 33 type ( 34 // TableInfo contains information about tables 35 TableInfo interface { 36 // Name returns the table name 37 Name() (sqlparser.TableName, error) 38 39 // GetVindexTable returns the vschema version of this TableInfo 40 GetVindexTable() *vindexes.Table 41 42 // IsInfSchema returns true if this table is information_schema 43 IsInfSchema() bool 44 45 // matches returns true if the provided table name matches this TableInfo 46 matches(name sqlparser.TableName) bool 47 48 // authoritative is true if we have exhaustive column information 49 authoritative() bool 50 51 // getExpr returns the AST struct behind this table 52 getExpr() *sqlparser.AliasedTableExpr 53 54 // getColumns returns the known column information for this table 55 getColumns() []ColumnInfo 56 57 dependencies(colName string, org originable) (dependencies, error) 58 getExprFor(s string) (sqlparser.Expr, error) 59 getTableSet(org originable) TableSet 60 } 61 62 // ColumnInfo contains information about columns 63 ColumnInfo struct { 64 Name string 65 Type Type 66 } 67 68 // ExprDependencies stores the tables that an expression depends on as a map 69 ExprDependencies map[sqlparser.Expr]TableSet 70 71 // SemTable contains semantic analysis information about the query. 72 SemTable struct { 73 Tables []TableInfo 74 75 // NotSingleRouteErr stores any errors that have to be generated if the query cannot be planned as a single route. 76 NotSingleRouteErr error 77 // NotUnshardedErr stores any errors that have to be generated if the query is not unsharded. 78 NotUnshardedErr error 79 80 // Recursive contains the dependencies from the expression to the actual tables 81 // in the query (i.e. not including derived tables). If an expression is a column on a derived table, 82 // this map will contain the accumulated dependencies for the column expression inside the derived table 83 Recursive ExprDependencies 84 85 // Direct keeps information about the closest dependency for an expression. 86 // It does not recurse inside derived tables and the like to find the original dependencies 87 Direct ExprDependencies 88 89 ExprTypes map[sqlparser.Expr]Type 90 selectScope map[*sqlparser.Select]*scope 91 Comments *sqlparser.ParsedComments 92 SubqueryMap map[sqlparser.Statement][]*sqlparser.ExtractedSubquery 93 SubqueryRef map[*sqlparser.Subquery]*sqlparser.ExtractedSubquery 94 95 // ColumnEqualities is used to enable transitive closures 96 // if a == b and b == c then a == c 97 ColumnEqualities map[columnName][]sqlparser.Expr 98 99 // DefaultCollation is the default collation for this query, which is usually 100 // inherited from the connection's default collation. 101 Collation collations.ID 102 103 Warning string 104 105 // ExpandedColumns is a map of all the added columns for a given table. 106 ExpandedColumns map[sqlparser.TableName][]*sqlparser.ColName 107 108 comparator *sqlparser.Comparator 109 } 110 111 columnName struct { 112 Table TableSet 113 ColumnName string 114 } 115 116 // SchemaInformation is used tp provide table information from Vschema. 117 SchemaInformation interface { 118 FindTableOrVindex(tablename sqlparser.TableName) (*vindexes.Table, vindexes.Vindex, string, topodatapb.TabletType, key.Destination, error) 119 ConnCollation() collations.ID 120 } 121 ) 122 123 var ( 124 // ErrNotSingleTable refers to an error happening when something should be used only for single tables 125 ErrNotSingleTable = vterrors.Errorf(vtrpcpb.Code_INTERNAL, "[BUG] should only be used for single tables") 126 ) 127 128 // CopyDependencies copies the dependencies from one expression into the other 129 func (st *SemTable) CopyDependencies(from, to sqlparser.Expr) { 130 st.Recursive[to] = st.RecursiveDeps(from) 131 st.Direct[to] = st.DirectDeps(from) 132 } 133 134 // EmptySemTable creates a new empty SemTable 135 func EmptySemTable() *SemTable { 136 return &SemTable{ 137 Recursive: map[sqlparser.Expr]TableSet{}, 138 Direct: map[sqlparser.Expr]TableSet{}, 139 ColumnEqualities: map[columnName][]sqlparser.Expr{}, 140 } 141 } 142 143 // TableSetFor returns the bitmask for this particular table 144 func (st *SemTable) TableSetFor(t *sqlparser.AliasedTableExpr) TableSet { 145 for idx, t2 := range st.Tables { 146 if t == t2.getExpr() { 147 return SingleTableSet(idx) 148 } 149 } 150 return EmptyTableSet() 151 } 152 153 // ReplaceTableSetFor replaces the given single TabletSet with the new *sqlparser.AliasedTableExpr 154 func (st *SemTable) ReplaceTableSetFor(id TableSet, t *sqlparser.AliasedTableExpr) { 155 if id.NumberOfTables() != 1 { 156 // This is probably a derived table 157 return 158 } 159 tblOffset := id.TableOffset() 160 if tblOffset > len(st.Tables) { 161 // This should not happen and is probably a bug, but the output query will still work fine 162 return 163 } 164 switch tbl := st.Tables[id.TableOffset()].(type) { 165 case *RealTable: 166 tbl.ASTNode = t 167 case *DerivedTable: 168 tbl.ASTNode = t 169 } 170 } 171 172 // TableInfoFor returns the table info for the table set. It should contains only single table. 173 func (st *SemTable) TableInfoFor(id TableSet) (TableInfo, error) { 174 offset := id.TableOffset() 175 if offset < 0 { 176 return nil, ErrNotSingleTable 177 } 178 return st.Tables[offset], nil 179 } 180 181 // RecursiveDeps return the table dependencies of the expression. 182 func (st *SemTable) RecursiveDeps(expr sqlparser.Expr) TableSet { 183 return st.Recursive.dependencies(expr) 184 } 185 186 // DirectDeps return the table dependencies of the expression. 187 func (st *SemTable) DirectDeps(expr sqlparser.Expr) TableSet { 188 return st.Direct.dependencies(expr) 189 } 190 191 // AddColumnEquality adds a relation of the given colName to the ColumnEqualities map 192 func (st *SemTable) AddColumnEquality(colName *sqlparser.ColName, expr sqlparser.Expr) { 193 ts := st.Direct.dependencies(colName) 194 columnName := columnName{ 195 Table: ts, 196 ColumnName: colName.Name.String(), 197 } 198 elem := st.ColumnEqualities[columnName] 199 elem = append(elem, expr) 200 st.ColumnEqualities[columnName] = elem 201 } 202 203 // GetExprAndEqualities returns a slice containing the given expression, and it's known equalities if any 204 func (st *SemTable) GetExprAndEqualities(expr sqlparser.Expr) []sqlparser.Expr { 205 result := []sqlparser.Expr{expr} 206 switch expr := expr.(type) { 207 case *sqlparser.ColName: 208 table := st.DirectDeps(expr) 209 k := columnName{Table: table, ColumnName: expr.Name.String()} 210 result = append(result, st.ColumnEqualities[k]...) 211 } 212 return result 213 } 214 215 // TableInfoForExpr returns the table info of the table that this expression depends on. 216 // Careful: this only works for expressions that have a single table dependency 217 func (st *SemTable) TableInfoForExpr(expr sqlparser.Expr) (TableInfo, error) { 218 return st.TableInfoFor(st.Direct.dependencies(expr)) 219 } 220 221 // GetSelectTables returns the table in the select. 222 func (st *SemTable) GetSelectTables(node *sqlparser.Select) []TableInfo { 223 scope := st.selectScope[node] 224 return scope.tables 225 } 226 227 // AddExprs adds new select exprs to the SemTable. 228 func (st *SemTable) AddExprs(tbl *sqlparser.AliasedTableExpr, cols sqlparser.SelectExprs) { 229 tableSet := st.TableSetFor(tbl) 230 for _, col := range cols { 231 st.Recursive[col.(*sqlparser.AliasedExpr).Expr] = tableSet 232 } 233 } 234 235 // TypeFor returns the type of expressions in the query 236 func (st *SemTable) TypeFor(e sqlparser.Expr) *querypb.Type { 237 typ, found := st.ExprTypes[e] 238 if found { 239 return &typ.Type 240 } 241 return nil 242 } 243 244 // CollationForExpr returns the collation name of expressions in the query 245 func (st *SemTable) CollationForExpr(e sqlparser.Expr) collations.ID { 246 typ, found := st.ExprTypes[e] 247 if found { 248 return typ.Collation 249 } 250 return collations.Unknown 251 } 252 253 // NeedsWeightString returns true if the given expression needs weight_string to do safe comparisons 254 func (st *SemTable) NeedsWeightString(e sqlparser.Expr) bool { 255 typ, found := st.ExprTypes[e] 256 if !found { 257 return true 258 } 259 return typ.Collation == collations.Unknown && !sqltypes.IsNumber(typ.Type) 260 } 261 262 func (st *SemTable) DefaultCollation() collations.ID { 263 return st.Collation 264 } 265 266 // dependencies return the table dependencies of the expression. This method finds table dependencies recursively 267 func (d ExprDependencies) dependencies(expr sqlparser.Expr) (deps TableSet) { 268 if ValidAsMapKey(expr) { 269 // we have something that could live in the cache 270 var found bool 271 deps, found = d[expr] 272 if found { 273 return deps 274 } 275 defer func() { 276 d[expr] = deps 277 }() 278 } 279 280 // During the original semantic analysis, all ColNames were found and bound to the corresponding tables 281 // Here, we'll walk the expression tree and look to see if we can find any sub-expressions 282 // that have already set dependencies. 283 _ = sqlparser.Walk(func(node sqlparser.SQLNode) (kontinue bool, err error) { 284 expr, ok := node.(sqlparser.Expr) 285 if !ok || !ValidAsMapKey(expr) { 286 // if this is not an expression, or it is an expression we can't use as a map-key, 287 // just carry on down the tree 288 return true, nil 289 } 290 291 if extracted, ok := expr.(*sqlparser.ExtractedSubquery); ok { 292 if extracted.OtherSide != nil { 293 set := d.dependencies(extracted.OtherSide) 294 deps = deps.Merge(set) 295 } 296 return false, nil 297 } 298 set, found := d[expr] 299 deps = deps.Merge(set) 300 301 // if we found a cached value, there is no need to continue down to visit children 302 return !found, nil 303 }, expr) 304 305 return deps 306 } 307 308 // RewriteDerivedTableExpression rewrites all the ColName instances in the supplied expression with 309 // the expressions behind the column definition of the derived table 310 // SELECT foo FROM (SELECT id+42 as foo FROM user) as t 311 // We need `foo` to be translated to `id+42` on the inside of the derived table 312 func RewriteDerivedTableExpression(expr sqlparser.Expr, vt TableInfo) sqlparser.Expr { 313 return sqlparser.CopyOnRewrite(expr, nil, func(cursor *sqlparser.CopyOnWriteCursor) { 314 node, ok := cursor.Node().(*sqlparser.ColName) 315 if !ok { 316 return 317 } 318 exp, err := vt.getExprFor(node.Name.String()) 319 if err == nil { 320 cursor.Replace(exp) 321 return 322 } 323 324 // cloning the expression and removing the qualifier 325 col := *node 326 col.Qualifier = sqlparser.TableName{} 327 cursor.Replace(&col) 328 329 }, nil).(sqlparser.Expr) 330 } 331 332 // FindSubqueryReference goes over the sub queries and searches for it by value equality instead of reference equality 333 func (st *SemTable) FindSubqueryReference(subquery *sqlparser.Subquery) *sqlparser.ExtractedSubquery { 334 for foundSubq, extractedSubquery := range st.SubqueryRef { 335 if sqlparser.Equals.RefOfSubquery(subquery, foundSubq) { 336 return extractedSubquery 337 } 338 } 339 return nil 340 } 341 342 // GetSubqueryNeedingRewrite returns a list of sub-queries that need to be rewritten 343 func (st *SemTable) GetSubqueryNeedingRewrite() []*sqlparser.ExtractedSubquery { 344 var res []*sqlparser.ExtractedSubquery 345 for _, extractedSubquery := range st.SubqueryRef { 346 if extractedSubquery.NeedsRewrite { 347 res = append(res, extractedSubquery) 348 } 349 } 350 return res 351 } 352 353 // CopyExprInfo lookups src in the ExprTypes map and, if a key is found, assign 354 // the corresponding Type value of src to dest. 355 func (st *SemTable) CopyExprInfo(src, dest sqlparser.Expr) { 356 srcType, found := st.ExprTypes[src] 357 if found { 358 st.ExprTypes[dest] = srcType 359 } 360 } 361 362 var _ evalengine.TranslationLookup = (*SemTable)(nil) 363 364 var columnNotSupportedErr = vterrors.Errorf(vtrpcpb.Code_UNIMPLEMENTED, "column access not supported here") 365 366 // ColumnLookup implements the TranslationLookup interface 367 func (st *SemTable) ColumnLookup(*sqlparser.ColName) (int, error) { 368 return 0, columnNotSupportedErr 369 } 370 371 // SingleUnshardedKeyspace returns the single keyspace if all tables in the query are in the same, unsharded keyspace 372 func (st *SemTable) SingleUnshardedKeyspace() (*vindexes.Keyspace, []*vindexes.Table) { 373 var ks *vindexes.Keyspace 374 var tables []*vindexes.Table 375 for _, table := range st.Tables { 376 vindexTable := table.GetVindexTable() 377 378 if vindexTable == nil { 379 _, isDT := table.getExpr().Expr.(*sqlparser.DerivedTable) 380 if isDT { 381 // derived tables are ok, as long as all real tables are from the same unsharded keyspace 382 // we check the real tables inside the derived table as well for same unsharded keyspace. 383 continue 384 } 385 return nil, nil 386 } 387 if vindexTable.Type != "" { 388 // A reference table is not an issue when seeing if a query is going to an unsharded keyspace 389 if vindexTable.Type == vindexes.TypeReference { 390 continue 391 } 392 return nil, nil 393 } 394 name, ok := table.getExpr().Expr.(sqlparser.TableName) 395 if !ok { 396 return nil, nil 397 } 398 if name.Name.String() != vindexTable.Name.String() { 399 // this points to a table alias. safer to not shortcut 400 return nil, nil 401 } 402 this := vindexTable.Keyspace 403 if this == nil || this.Sharded { 404 return nil, nil 405 } 406 if ks == nil { 407 ks = this 408 } else { 409 if ks != this { 410 return nil, nil 411 } 412 } 413 tables = append(tables, vindexTable) 414 } 415 return ks, tables 416 } 417 418 // EqualsExpr compares two expressions using the semantic analysis information. 419 // This means that we use the binding info to recognize that two ColName's can point to the same 420 // table column even though they are written differently. Example would be the `foobar` column in the following query: 421 // `SELECT foobar FROM tbl ORDER BY tbl.foobar` 422 // The expression in the select list is not equal to the one in the ORDER BY, 423 // but they point to the same column and would be considered equal by this method 424 func (st *SemTable) EqualsExpr(a, b sqlparser.Expr) bool { 425 return st.ASTEquals().Expr(a, b) 426 } 427 428 func (st *SemTable) ContainsExpr(e sqlparser.Expr, expres []sqlparser.Expr) bool { 429 for _, expre := range expres { 430 if st.EqualsExpr(e, expre) { 431 return true 432 } 433 } 434 return false 435 } 436 437 // AndExpressions ands together two or more expressions, minimising the expr when possible 438 func (st *SemTable) AndExpressions(exprs ...sqlparser.Expr) sqlparser.Expr { 439 switch len(exprs) { 440 case 0: 441 return nil 442 case 1: 443 return exprs[0] 444 default: 445 result := (sqlparser.Expr)(nil) 446 outer: 447 // we'll loop and remove any duplicates 448 for i, expr := range exprs { 449 if expr == nil { 450 continue 451 } 452 if result == nil { 453 result = expr 454 continue outer 455 } 456 457 for j := 0; j < i; j++ { 458 if st.EqualsExpr(expr, exprs[j]) { 459 continue outer 460 } 461 } 462 result = &sqlparser.AndExpr{Left: result, Right: expr} 463 } 464 return result 465 } 466 } 467 468 // ASTEquals returns a sqlparser.Comparator that uses the semantic information in this SemTable to 469 // explicitly compare column names for equality. 470 func (st *SemTable) ASTEquals() *sqlparser.Comparator { 471 if st.comparator == nil { 472 st.comparator = &sqlparser.Comparator{ 473 RefOfColName_: func(a, b *sqlparser.ColName) bool { 474 aDeps := st.RecursiveDeps(a) 475 bDeps := st.RecursiveDeps(b) 476 if aDeps != bDeps && (aDeps.IsEmpty() || bDeps.IsEmpty()) { 477 // if we don't know, we don't know 478 return sqlparser.Equals.RefOfColName(a, b) 479 } 480 return a.Name.Equal(b.Name) && aDeps == bDeps 481 }, 482 } 483 } 484 return st.comparator 485 }