github.com/matrixorigin/matrixone@v0.7.0/pkg/sql/plan/flatten_subquery.go (about) 1 // Copyright 2022 Matrix Origin 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 package plan 16 17 import ( 18 "github.com/matrixorigin/matrixone/pkg/common/moerr" 19 "github.com/matrixorigin/matrixone/pkg/container/types" 20 "github.com/matrixorigin/matrixone/pkg/pb/plan" 21 "github.com/matrixorigin/matrixone/pkg/sql/plan/function" 22 ) 23 24 var ( 25 constTrue = &plan.Expr{ 26 Expr: &plan.Expr_C{ 27 C: &plan.Const{ 28 Isnull: false, 29 Value: &plan.Const_Bval{ 30 Bval: true, 31 }, 32 }, 33 }, 34 Typ: &plan.Type{ 35 Id: int32(types.T_bool), 36 NotNullable: true, 37 Size: 1, 38 }, 39 } 40 ) 41 42 func (builder *QueryBuilder) flattenSubqueries(nodeID int32, expr *plan.Expr, ctx *BindContext) (int32, *plan.Expr, error) { 43 var err error 44 45 switch exprImpl := expr.Expr.(type) { 46 case *plan.Expr_F: 47 for i, arg := range exprImpl.F.Args { 48 nodeID, exprImpl.F.Args[i], err = builder.flattenSubqueries(nodeID, arg, ctx) 49 if err != nil { 50 return 0, nil, err 51 } 52 } 53 54 case *plan.Expr_Sub: 55 nodeID, expr, err = builder.flattenSubquery(nodeID, exprImpl.Sub, ctx) 56 } 57 58 return nodeID, expr, err 59 } 60 61 func (builder *QueryBuilder) flattenSubquery(nodeID int32, subquery *plan.SubqueryRef, ctx *BindContext) (int32, *plan.Expr, error) { 62 subID := subquery.NodeId 63 subCtx := builder.ctxByNode[subID] 64 65 subID, preds, err := builder.pullupCorrelatedPredicates(subID, subCtx) 66 if err != nil { 67 return 0, nil, err 68 } 69 70 if subquery.Typ == plan.SubqueryRef_SCALAR && len(subCtx.aggregates) > 0 && builder.findNonEqPred(preds) { 71 return 0, nil, moerr.NewNYI(builder.GetContext(), "aggregation with non equal predicate in %s subquery will be supported in future version", subquery.Typ.String()) 72 } 73 74 filterPreds, joinPreds := decreaseDepthAndDispatch(preds) 75 76 if len(filterPreds) > 0 && subquery.Typ >= plan.SubqueryRef_SCALAR { 77 return 0, nil, moerr.NewNYI(builder.GetContext(), "correlated columns in %s subquery deeper than 1 level will be supported in future version", subquery.Typ.String()) 78 } 79 80 switch subquery.Typ { 81 case plan.SubqueryRef_SCALAR: 82 var rewrite bool 83 // Uncorrelated subquery 84 if len(joinPreds) == 0 { 85 joinPreds = append(joinPreds, constTrue) 86 } else if builder.findAggrCount(subCtx.aggregates) { 87 rewrite = true 88 } 89 90 joinType := plan.Node_SINGLE 91 if subCtx.hasSingleRow { 92 joinType = plan.Node_LEFT 93 } 94 95 nodeID = builder.appendNode(&plan.Node{ 96 NodeType: plan.Node_JOIN, 97 Children: []int32{nodeID, subID}, 98 JoinType: joinType, 99 OnList: joinPreds, 100 }, ctx) 101 102 if len(filterPreds) > 0 { 103 nodeID = builder.appendNode(&plan.Node{ 104 NodeType: plan.Node_FILTER, 105 Children: []int32{nodeID}, 106 FilterList: filterPreds, 107 }, ctx) 108 } 109 110 retExpr := &plan.Expr{ 111 Typ: subCtx.results[0].Typ, 112 Expr: &plan.Expr_Col{ 113 Col: &plan.ColRef{ 114 RelPos: subCtx.rootTag(), 115 ColPos: 0, 116 }, 117 }, 118 } 119 if rewrite { 120 argsType := make([]types.Type, 1) 121 argsType[0] = makeTypeByPlan2Expr(retExpr) 122 funcID, returnType, _, _ := function.GetFunctionByName(builder.GetContext(), "isnull", argsType) 123 isNullExpr := &Expr{ 124 Expr: &plan.Expr_F{ 125 F: &plan.Function{ 126 Func: getFunctionObjRef(funcID, "isnull"), 127 Args: []*Expr{retExpr}, 128 }, 129 }, 130 Typ: makePlan2Type(&returnType), 131 } 132 zeroExpr := makePlan2Int64ConstExprWithType(0) 133 argsType = make([]types.Type, 3) 134 argsType[0] = makeTypeByPlan2Expr(isNullExpr) 135 argsType[1] = makeTypeByPlan2Expr(zeroExpr) 136 argsType[2] = makeTypeByPlan2Expr(retExpr) 137 funcID, returnType, _, _ = function.GetFunctionByName(builder.GetContext(), "case", argsType) 138 retExpr = &Expr{ 139 Expr: &plan.Expr_F{ 140 F: &plan.Function{ 141 Func: getFunctionObjRef(funcID, "case"), 142 Args: []*Expr{isNullExpr, zeroExpr, DeepCopyExpr(retExpr)}, 143 }, 144 }, 145 Typ: makePlan2Type(&returnType), 146 } 147 } 148 return nodeID, retExpr, nil 149 150 case plan.SubqueryRef_EXISTS: 151 // Uncorrelated subquery 152 if len(joinPreds) == 0 { 153 joinPreds = append(joinPreds, constTrue) 154 } 155 156 return builder.insertMarkJoin(nodeID, subID, joinPreds, nil, false, ctx) 157 158 case plan.SubqueryRef_NOT_EXISTS: 159 // Uncorrelated subquery 160 if len(joinPreds) == 0 { 161 joinPreds = append(joinPreds, constTrue) 162 } 163 164 return builder.insertMarkJoin(nodeID, subID, joinPreds, nil, true, ctx) 165 166 case plan.SubqueryRef_IN: 167 outerPred, err := builder.generateComparison("=", subquery.Child, subCtx) 168 if err != nil { 169 return 0, nil, err 170 } 171 172 return builder.insertMarkJoin(nodeID, subID, joinPreds, outerPred, false, ctx) 173 174 case plan.SubqueryRef_NOT_IN: 175 outerPred, err := builder.generateComparison("=", subquery.Child, subCtx) 176 if err != nil { 177 return 0, nil, err 178 } 179 180 return builder.insertMarkJoin(nodeID, subID, joinPreds, outerPred, true, ctx) 181 182 case plan.SubqueryRef_ANY: 183 outerPred, err := builder.generateComparison(subquery.Op, subquery.Child, subCtx) 184 if err != nil { 185 return 0, nil, err 186 } 187 188 return builder.insertMarkJoin(nodeID, subID, joinPreds, outerPred, false, ctx) 189 190 case plan.SubqueryRef_ALL: 191 outerPred, err := builder.generateComparison(subquery.Op, subquery.Child, subCtx) 192 if err != nil { 193 return 0, nil, err 194 } 195 196 outerPred, err = bindFuncExprImplByPlanExpr(builder.GetContext(), "not", []*plan.Expr{outerPred}) 197 if err != nil { 198 return 0, nil, err 199 } 200 201 return builder.insertMarkJoin(nodeID, subID, joinPreds, outerPred, true, ctx) 202 203 default: 204 return 0, nil, moerr.NewNotSupported(builder.GetContext(), "%s subquery not supported", subquery.Typ.String()) 205 } 206 } 207 208 func (builder *QueryBuilder) insertMarkJoin(left, right int32, joinPreds []*plan.Expr, outerPred *plan.Expr, negate bool, ctx *BindContext) (nodeID int32, markExpr *plan.Expr, err error) { 209 markTag := builder.genNewTag() 210 211 for i, pred := range joinPreds { 212 if !pred.Typ.NotNullable { 213 joinPreds[i], err = bindFuncExprImplByPlanExpr(builder.GetContext(), "istrue", []*plan.Expr{pred}) 214 if err != nil { 215 return 216 } 217 } 218 } 219 220 notNull := true 221 222 if outerPred != nil { 223 joinPreds = append(joinPreds, outerPred) 224 notNull = outerPred.Typ.NotNullable 225 } 226 227 nodeID = builder.appendNode(&plan.Node{ 228 NodeType: plan.Node_JOIN, 229 Children: []int32{left, right}, 230 BindingTags: []int32{markTag}, 231 JoinType: plan.Node_MARK, 232 OnList: joinPreds, 233 }, ctx) 234 235 markExpr = &plan.Expr{ 236 Typ: &plan.Type{ 237 Id: int32(types.T_bool), 238 NotNullable: notNull, 239 Size: 1, 240 }, 241 Expr: &plan.Expr_Col{ 242 Col: &plan.ColRef{ 243 RelPos: markTag, 244 ColPos: 0, 245 }, 246 }, 247 } 248 249 if negate { 250 markExpr, err = bindFuncExprImplByPlanExpr(builder.GetContext(), "not", []*plan.Expr{markExpr}) 251 } 252 253 return 254 } 255 256 func (builder *QueryBuilder) generateComparison(op string, child *plan.Expr, ctx *BindContext) (*plan.Expr, error) { 257 switch childImpl := child.Expr.(type) { 258 case *plan.Expr_List: 259 childList := childImpl.List.List 260 switch op { 261 case "=": 262 leftExpr, err := bindFuncExprImplByPlanExpr(builder.GetContext(), op, []*plan.Expr{ 263 childList[0], 264 { 265 Typ: ctx.results[0].Typ, 266 Expr: &plan.Expr_Col{ 267 Col: &plan.ColRef{ 268 RelPos: ctx.rootTag(), 269 ColPos: 0, 270 }, 271 }, 272 }, 273 }) 274 if err != nil { 275 return nil, err 276 } 277 278 for i := 1; i < len(childList); i++ { 279 rightExpr, err := bindFuncExprImplByPlanExpr(builder.GetContext(), op, []*plan.Expr{ 280 childList[i], 281 { 282 Typ: ctx.results[i].Typ, 283 Expr: &plan.Expr_Col{ 284 Col: &plan.ColRef{ 285 RelPos: ctx.rootTag(), 286 ColPos: int32(i), 287 }, 288 }, 289 }, 290 }) 291 if err != nil { 292 return nil, err 293 } 294 295 leftExpr, err = bindFuncExprImplByPlanExpr(builder.GetContext(), "and", []*plan.Expr{leftExpr, rightExpr}) 296 if err != nil { 297 return nil, err 298 } 299 } 300 301 return leftExpr, nil 302 303 case "<>": 304 leftExpr, err := bindFuncExprImplByPlanExpr(builder.GetContext(), op, []*plan.Expr{ 305 childList[0], 306 { 307 Typ: ctx.results[0].Typ, 308 Expr: &plan.Expr_Col{ 309 Col: &plan.ColRef{ 310 RelPos: ctx.rootTag(), 311 ColPos: 0, 312 }, 313 }, 314 }, 315 }) 316 if err != nil { 317 return nil, err 318 } 319 320 for i := 1; i < len(childList); i++ { 321 rightExpr, err := bindFuncExprImplByPlanExpr(builder.GetContext(), op, []*plan.Expr{ 322 childList[i], 323 { 324 Typ: ctx.results[i].Typ, 325 Expr: &plan.Expr_Col{ 326 Col: &plan.ColRef{ 327 RelPos: ctx.rootTag(), 328 ColPos: int32(i), 329 }, 330 }, 331 }, 332 }) 333 if err != nil { 334 return nil, err 335 } 336 337 leftExpr, err = bindFuncExprImplByPlanExpr(builder.GetContext(), "or", []*plan.Expr{leftExpr, rightExpr}) 338 if err != nil { 339 return nil, err 340 } 341 } 342 343 return leftExpr, nil 344 345 case "<", "<=", ">", ">=": 346 projList := make([]*plan.Expr, len(childList)) 347 for i := range projList { 348 projList[i] = &plan.Expr{ 349 Typ: ctx.results[i].Typ, 350 Expr: &plan.Expr_Col{ 351 Col: &plan.ColRef{ 352 RelPos: ctx.rootTag(), 353 ColPos: int32(i), 354 }, 355 }, 356 } 357 } 358 359 nonEqOp := op[:1] // <= -> <, >= -> > 360 return unwindTupleComparison(builder.GetContext(), nonEqOp, op, childList, projList, 0) 361 362 default: 363 return nil, moerr.NewNotSupported(builder.GetContext(), "row constructor only support comparison operators") 364 } 365 366 default: 367 return bindFuncExprImplByPlanExpr(builder.GetContext(), op, []*plan.Expr{ 368 child, 369 { 370 Typ: ctx.results[0].Typ, 371 Expr: &plan.Expr_Col{ 372 Col: &plan.ColRef{ 373 RelPos: ctx.rootTag(), 374 ColPos: 0, 375 }, 376 }, 377 }, 378 }) 379 } 380 } 381 382 func (builder *QueryBuilder) findAggrCount(aggrs []*plan.Expr) bool { 383 for _, aggr := range aggrs { 384 switch exprImpl := aggr.Expr.(type) { 385 case *plan.Expr_F: 386 if exprImpl.F.Func.ObjName == "count" || exprImpl.F.Func.ObjName == "starcount" { 387 return true 388 } 389 } 390 } 391 return false 392 } 393 394 func (builder *QueryBuilder) findNonEqPred(preds []*plan.Expr) bool { 395 for _, pred := range preds { 396 switch exprImpl := pred.Expr.(type) { 397 case *plan.Expr_F: 398 if exprImpl.F.Func.ObjName != "=" { 399 return true 400 } 401 } 402 } 403 return false 404 } 405 406 func (builder *QueryBuilder) pullupCorrelatedPredicates(nodeID int32, ctx *BindContext) (int32, []*plan.Expr, error) { 407 node := builder.qry.Nodes[nodeID] 408 409 var preds []*plan.Expr 410 var err error 411 412 var subPreds []*plan.Expr 413 for i, childID := range node.Children { 414 node.Children[i], subPreds, err = builder.pullupCorrelatedPredicates(childID, ctx) 415 if err != nil { 416 return 0, nil, err 417 } 418 419 preds = append(preds, subPreds...) 420 } 421 422 switch node.NodeType { 423 case plan.Node_AGG: 424 groupTag := node.BindingTags[0] 425 for _, pred := range preds { 426 builder.pullupThroughAgg(ctx, node, groupTag, pred) 427 } 428 429 case plan.Node_PROJECT: 430 projectTag := node.BindingTags[0] 431 for _, pred := range preds { 432 builder.pullupThroughProj(ctx, node, projectTag, pred) 433 } 434 435 case plan.Node_FILTER: 436 var newFilterList []*plan.Expr 437 for _, cond := range node.FilterList { 438 if hasCorrCol(cond) { 439 //cond, err = bindFuncExprImplByPlanExpr("is", []*plan.Expr{cond, DeepCopyExpr(constTrue)}) 440 if err != nil { 441 return 0, nil, err 442 } 443 preds = append(preds, cond) 444 } else { 445 newFilterList = append(newFilterList, cond) 446 } 447 } 448 449 if len(newFilterList) == 0 { 450 nodeID = node.Children[0] 451 } else { 452 node.FilterList = newFilterList 453 } 454 } 455 456 return nodeID, preds, err 457 } 458 459 func (builder *QueryBuilder) pullupThroughAgg(ctx *BindContext, node *plan.Node, tag int32, expr *plan.Expr) *plan.Expr { 460 if !hasCorrCol(expr) { 461 switch expr.Expr.(type) { 462 case *plan.Expr_Col, *plan.Expr_F: 463 break 464 465 default: 466 return expr 467 } 468 469 colPos := int32(len(node.GroupBy)) 470 node.GroupBy = append(node.GroupBy, expr) 471 472 if colRef, ok := expr.Expr.(*plan.Expr_Col); ok { 473 oldMapId := [2]int32{colRef.Col.RelPos, colRef.Col.ColPos} 474 newMapId := [2]int32{tag, colPos} 475 476 builder.nameByColRef[newMapId] = builder.nameByColRef[oldMapId] 477 } 478 479 return &plan.Expr{ 480 Typ: expr.Typ, 481 Expr: &plan.Expr_Col{ 482 Col: &plan.ColRef{ 483 RelPos: tag, 484 ColPos: colPos, 485 }, 486 }, 487 } 488 } 489 490 switch exprImpl := expr.Expr.(type) { 491 case *plan.Expr_F: 492 for i, arg := range exprImpl.F.Args { 493 exprImpl.F.Args[i] = builder.pullupThroughAgg(ctx, node, tag, arg) 494 } 495 } 496 497 return expr 498 } 499 500 func (builder *QueryBuilder) pullupThroughProj(ctx *BindContext, node *plan.Node, tag int32, expr *plan.Expr) *plan.Expr { 501 if !hasCorrCol(expr) { 502 switch expr.Expr.(type) { 503 case *plan.Expr_Col, *plan.Expr_F: 504 break 505 506 default: 507 return expr 508 } 509 510 colPos := int32(len(node.ProjectList)) 511 node.ProjectList = append(node.ProjectList, expr) 512 513 if colRef, ok := expr.Expr.(*plan.Expr_Col); ok { 514 oldMapId := [2]int32{colRef.Col.RelPos, colRef.Col.ColPos} 515 newMapId := [2]int32{tag, colPos} 516 517 builder.nameByColRef[newMapId] = builder.nameByColRef[oldMapId] 518 } 519 520 return &plan.Expr{ 521 Typ: expr.Typ, 522 Expr: &plan.Expr_Col{ 523 Col: &plan.ColRef{ 524 RelPos: tag, 525 ColPos: colPos, 526 }, 527 }, 528 } 529 } 530 531 switch exprImpl := expr.Expr.(type) { 532 case *plan.Expr_F: 533 for i, arg := range exprImpl.F.Args { 534 exprImpl.F.Args[i] = builder.pullupThroughProj(ctx, node, tag, arg) 535 } 536 } 537 538 return expr 539 }