// Copyright 2021 Dolthub, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package analyzer

import (
	"fmt"
	"strings"

	"github.com/dolthub/go-mysql-server/sql"
	"github.com/dolthub/go-mysql-server/sql/expression"
	"github.com/dolthub/go-mysql-server/sql/plan"
	"github.com/dolthub/go-mysql-server/sql/transform"
)

// applyJoin describes one filter conjunct that unnestInSubqueries has
// selected for rewriting as a join against a subquery.
type applyJoin struct {
	// l is the outer (left-hand) operand of the subquery predicate.
	l sql.Expression
	// r is the subquery on the right-hand side of the predicate.
	r *plan.Subquery
	// op is the join type to build: semi for a plain predicate, anti when
	// the predicate was wrapped in NOT.
	op plan.JoinType
	// filter is the comparison template for the join condition: an empty
	// equality for IN predicates, or the comparison itself for scalar
	// comparisons. Its children are replaced with l and the subquery's
	// highest projection.
	filter sql.Expression
	// original is the predicate to put back into the residual filter when
	// the join cannot be built.
	original sql.Expression
	// max1 is set for scalar comparisons against a subquery.
	// NOTE(review): presumably "subquery must return at most one row"; it is
	// not consulted when the join is built — confirm this is intended.
	max1 bool
}

// unnestInSubqueries converts filter predicates of the form
// `expr IN (subquery)` (*plan.InSubquery) and `expr <op> (subquery)` (an
// expression.Comparer whose right side is a *plan.Subquery), optionally
// negated with NOT, into semi (or anti) joins. The match conditions include:
// 1) subquery is cacheable, 2) the top-level subquery projection is a get
// field with a sql.ColumnId and sql.TableId (to support join reordering).
40 // TODO decorrelate lhs too 41 // TODO non-null-rejecting with dual table 42 func unnestInSubqueries(ctx *sql.Context, a *Analyzer, n sql.Node, scope *plan.Scope, sel RuleSelector) (sql.Node, transform.TreeIdentity, error) { 43 switch n.(type) { 44 case *plan.DeleteFrom, *plan.InsertInto: 45 return n, transform.SameTree, nil 46 } 47 48 var unnested bool 49 var aliases map[string]int 50 51 ret := n 52 var err error 53 same := transform.NewTree 54 for !same { 55 // simplifySubqExpr can merge two scopes, requiring us to either 56 // recurse on the merged scope or perform a fixed-point iteration. 57 ret, same, err = transform.Node(ret, func(n sql.Node) (sql.Node, transform.TreeIdentity, error) { 58 var filters []sql.Expression 59 var child sql.Node 60 switch n := n.(type) { 61 case *plan.Filter: 62 child = n.Child 63 filters = expression.SplitConjunction(n.Expression) 64 default: 65 } 66 67 if sel == nil { 68 return n, transform.SameTree, nil 69 } 70 71 var matches []applyJoin 72 var newFilters []sql.Expression 73 74 // separate decorrelation candidates 75 for _, e := range filters { 76 if !plan.IsNullRejecting(e) { 77 // TODO: rewrite dual table to permit in-scope joins, 78 // which aren't possible when values are projected 79 // above join filter 80 rt := getResolvedTable(n) 81 if rt == nil || plan.IsDualTable(rt.Table) { 82 newFilters = append(newFilters, e) 83 continue 84 } 85 } 86 87 candE := e 88 op := plan.JoinTypeSemi 89 if n, ok := e.(*expression.Not); ok { 90 candE = n.Child 91 op = plan.JoinTypeAnti 92 } 93 94 var sq *plan.Subquery 95 var l sql.Expression 96 var joinF sql.Expression 97 var max1 bool 98 switch e := candE.(type) { 99 case *plan.InSubquery: 100 sq, _ = e.RightChild.(*plan.Subquery) 101 l = e.LeftChild 102 103 joinF = expression.NewEquals(nil, nil) 104 case expression.Comparer: 105 sq, _ = e.Right().(*plan.Subquery) 106 l = e.Left() 107 joinF = e 108 max1 = true 109 default: 110 } 111 if sq != nil && sq.CanCacheResults() { 112 matches = 
append(matches, applyJoin{l: l, r: sq, op: op, filter: joinF, max1: max1, original: candE}) 113 } else { 114 newFilters = append(newFilters, e) 115 } 116 } 117 if len(matches) == 0 { 118 return n, transform.SameTree, nil 119 } 120 121 ret := child 122 for _, m := range matches { 123 // A successful candidate is built with: 124 // (1) Semi or anti join between the outer scope and (2) conditioned on (3). 125 // (2) Simplified or unnested subquery (table alias). 126 // (3) Join condition synthesized from the original correlated expression 127 // normalized to match changes to (2). 128 subq := m.r 129 130 if aliases == nil { 131 aliases = make(map[string]int) 132 ta, err := getTableAliases(n, scope) 133 if err != nil { 134 return n, transform.SameTree, err 135 } 136 for k, _ := range ta.aliases { 137 aliases[k] = 0 138 } 139 } 140 141 var newSubq sql.Node 142 newSubq, aliases, err = disambiguateTables(aliases, subq.Query) 143 if err != nil { 144 return ret, transform.SameTree, nil 145 } 146 147 rightF, ok, err := getHighestProjection(newSubq) 148 if err != nil { 149 return n, transform.SameTree, err 150 } 151 if !ok { 152 newFilters = append(newFilters, m.original) 153 continue 154 } 155 156 filter, err := m.filter.WithChildren(m.l, rightF) 157 if err != nil { 158 return n, transform.SameTree, err 159 } 160 var comment string 161 if c, ok := ret.(sql.CommentedNode); ok { 162 comment = c.Comment() 163 } 164 unnested = true 165 newJoin := plan.NewJoin(ret, newSubq, m.op, filter) 166 ret = newJoin.WithComment(comment) 167 } 168 169 if len(newFilters) == 0 { 170 return ret, transform.NewTree, nil 171 } 172 if len(newFilters) == len(filters) { 173 return n, transform.SameTree, nil 174 } 175 return plan.NewFilter(expression.JoinAnd(newFilters...), ret), transform.NewTree, nil 176 }) 177 if err != nil { 178 return n, transform.SameTree, err 179 } 180 } 181 return ret, transform.TreeIdentity(!unnested), nil 182 } 183 184 // returns an updated sql.Node with aliases 
de-duplicated, and an 185 // updated alias mapping with new conflicts and tables added. 186 func disambiguateTables(used map[string]int, n sql.Node) (sql.Node, map[string]int, error) { 187 rename := make(map[sql.TableId]string) 188 n, _, err := transform.NodeWithCtx(n, nil, func(c transform.Context) (sql.Node, transform.TreeIdentity, error) { 189 switch n := c.Node.(type) { 190 case sql.RenameableNode: 191 name := strings.ToLower(n.Name()) 192 if _, ok := c.Parent.(sql.RenameableNode); ok { 193 // skip checking when: TableAlias(ResolvedTable) 194 return n, transform.SameTree, nil 195 } 196 if cnt, ok := used[name]; ok { 197 used[name] = cnt + 1 198 newName := name 199 for ok { 200 cnt++ 201 newName = fmt.Sprintf("%s_%d", name, cnt) 202 _, ok = used[newName] 203 204 } 205 used[newName] = 0 206 207 tin, ok := n.(plan.TableIdNode) 208 if !ok { 209 return n, transform.SameTree, fmt.Errorf("expected sql.Renameable to implement plan.TableIdNode") 210 } 211 rename[tin.Id()] = newName 212 return n.WithName(newName), transform.NewTree, nil 213 } else { 214 used[name] = 0 215 } 216 return n, transform.NewTree, nil 217 default: 218 return transform.NodeExprs(n, func(e sql.Expression) (sql.Expression, transform.TreeIdentity, error) { 219 switch e := e.(type) { 220 case *expression.GetField: 221 if cnt, ok := used[strings.ToLower(e.Table())]; ok && cnt > 0 { 222 return e.WithTable(fmt.Sprintf("%s_%d", e.Table(), cnt)), transform.NewTree, nil 223 } 224 default: 225 } 226 return e, transform.NewTree, nil 227 }) 228 } 229 }) 230 if err != nil { 231 return nil, nil, err 232 } 233 if len(rename) > 0 { 234 n, _, err = renameExpressionTables(n, rename) 235 } 236 return n, used, err 237 } 238 239 // renameExpressionTables renames table references recursively. We use 240 // table ids to avoid improperly renaming tables in lower scopes with the 241 // same name. 
242 func renameExpressionTables(n sql.Node, rename map[sql.TableId]string) (sql.Node, transform.TreeIdentity, error) { 243 return transform.NodeExprs(n, func(e sql.Expression) (sql.Expression, transform.TreeIdentity, error) { 244 switch e := e.(type) { 245 case *expression.GetField: 246 if to, ok := rename[e.TableId()]; ok { 247 return e.WithTable(to), transform.NewTree, nil 248 } 249 case *plan.Subquery: 250 newQ, same, err := renameExpressionTables(e.Query, rename) 251 if !same || err != nil { 252 return e, same, err 253 } 254 return e.WithQuery(newQ), transform.NewTree, nil 255 default: 256 } 257 return e, transform.NewTree, nil 258 }) 259 } 260 261 // getHighestProjection returns a set of projection expressions responsible 262 // for the input node's schema, or false if an aggregate or set type is 263 // found (which we cannot generate named projections for yet). 264 func getHighestProjection(n sql.Node) (sql.Expression, bool, error) { 265 sch := n.Schema() 266 for n != nil { 267 if !sch.Equals(n.Schema()) { 268 break 269 } 270 var proj []sql.Expression 271 switch nn := n.(type) { 272 case *plan.Project: 273 proj = nn.Projections 274 case *plan.JoinNode: 275 left, ok, err := getHighestProjection(nn.Left()) 276 if err != nil { 277 return nil, false, err 278 } 279 if !ok { 280 return nil, false, nil 281 } 282 right, ok, err := getHighestProjection(nn.Right()) 283 if err != nil { 284 return nil, false, err 285 } 286 if !ok { 287 return nil, false, nil 288 } 289 switch e := left.(type) { 290 case expression.Tuple: 291 proj = append(proj, e.Children()...) 292 default: 293 proj = append(proj, e) 294 } 295 switch e := right.(type) { 296 case expression.Tuple: 297 proj = append(proj, e.Children()...) 
298 default: 299 proj = append(proj, e) 300 } 301 case *plan.GroupBy: 302 // todo(max): could make better effort to get column ids from these, 303 // but real fix is also giving synthesized projection column ids 304 // in binder 305 proj = nn.SelectedExprs 306 case *plan.Window: 307 proj = nn.SelectExprs 308 case *plan.SetOp: 309 return nil, false, nil 310 case plan.TableIdNode: 311 colset := nn.Columns() 312 idx := 0 313 sch := n.Schema() 314 for id, hasNext := colset.Next(1); hasNext; id, hasNext = colset.Next(id + 1) { 315 col := sch[idx] 316 proj = append(proj, expression.NewGetFieldWithTable(int(id), int(nn.Id()), col.Type, col.DatabaseSource, col.Source, col.Name, col.Nullable)) 317 idx++ 318 } 319 default: 320 if len(nn.Children()) == 1 { 321 n = nn.Children()[0] 322 continue 323 } 324 } 325 if proj == nil { 326 break 327 } 328 projCopy := make([]sql.Expression, len(proj)) 329 copy(projCopy, proj) 330 for i, p := range projCopy { 331 if a, ok := p.(*expression.Alias); ok { 332 if a.Unreferencable() || a.Id() == 0 { 333 return nil, false, nil 334 } 335 projCopy[i] = expression.NewGetField(int(a.Id()), a.Type(), a.Name(), a.IsNullable()) 336 } 337 } 338 if len(projCopy) == 1 { 339 return projCopy[0], true, nil 340 } 341 return expression.NewTuple(projCopy...), true, nil 342 } 343 return nil, false, fmt.Errorf("failed to find decorrelation projection") 344 }