github.com/matrixorigin/matrixone@v0.7.0/pkg/sql/plan/flatten_subquery.go (about)

     1  // Copyright 2022 Matrix Origin
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package plan
    16  
    17  import (
    18  	"github.com/matrixorigin/matrixone/pkg/common/moerr"
    19  	"github.com/matrixorigin/matrixone/pkg/container/types"
    20  	"github.com/matrixorigin/matrixone/pkg/pb/plan"
    21  	"github.com/matrixorigin/matrixone/pkg/sql/plan/function"
    22  )
    23  
    24  var (
    25  	constTrue = &plan.Expr{
    26  		Expr: &plan.Expr_C{
    27  			C: &plan.Const{
    28  				Isnull: false,
    29  				Value: &plan.Const_Bval{
    30  					Bval: true,
    31  				},
    32  			},
    33  		},
    34  		Typ: &plan.Type{
    35  			Id:          int32(types.T_bool),
    36  			NotNullable: true,
    37  			Size:        1,
    38  		},
    39  	}
    40  )
    41  
    42  func (builder *QueryBuilder) flattenSubqueries(nodeID int32, expr *plan.Expr, ctx *BindContext) (int32, *plan.Expr, error) {
    43  	var err error
    44  
    45  	switch exprImpl := expr.Expr.(type) {
    46  	case *plan.Expr_F:
    47  		for i, arg := range exprImpl.F.Args {
    48  			nodeID, exprImpl.F.Args[i], err = builder.flattenSubqueries(nodeID, arg, ctx)
    49  			if err != nil {
    50  				return 0, nil, err
    51  			}
    52  		}
    53  
    54  	case *plan.Expr_Sub:
    55  		nodeID, expr, err = builder.flattenSubquery(nodeID, exprImpl.Sub, ctx)
    56  	}
    57  
    58  	return nodeID, expr, err
    59  }
    60  
    61  func (builder *QueryBuilder) flattenSubquery(nodeID int32, subquery *plan.SubqueryRef, ctx *BindContext) (int32, *plan.Expr, error) {
    62  	subID := subquery.NodeId
    63  	subCtx := builder.ctxByNode[subID]
    64  
    65  	subID, preds, err := builder.pullupCorrelatedPredicates(subID, subCtx)
    66  	if err != nil {
    67  		return 0, nil, err
    68  	}
    69  
    70  	if subquery.Typ == plan.SubqueryRef_SCALAR && len(subCtx.aggregates) > 0 && builder.findNonEqPred(preds) {
    71  		return 0, nil, moerr.NewNYI(builder.GetContext(), "aggregation with non equal predicate in %s subquery  will be supported in future version", subquery.Typ.String())
    72  	}
    73  
    74  	filterPreds, joinPreds := decreaseDepthAndDispatch(preds)
    75  
    76  	if len(filterPreds) > 0 && subquery.Typ >= plan.SubqueryRef_SCALAR {
    77  		return 0, nil, moerr.NewNYI(builder.GetContext(), "correlated columns in %s subquery deeper than 1 level will be supported in future version", subquery.Typ.String())
    78  	}
    79  
    80  	switch subquery.Typ {
    81  	case plan.SubqueryRef_SCALAR:
    82  		var rewrite bool
    83  		// Uncorrelated subquery
    84  		if len(joinPreds) == 0 {
    85  			joinPreds = append(joinPreds, constTrue)
    86  		} else if builder.findAggrCount(subCtx.aggregates) {
    87  			rewrite = true
    88  		}
    89  
    90  		joinType := plan.Node_SINGLE
    91  		if subCtx.hasSingleRow {
    92  			joinType = plan.Node_LEFT
    93  		}
    94  
    95  		nodeID = builder.appendNode(&plan.Node{
    96  			NodeType: plan.Node_JOIN,
    97  			Children: []int32{nodeID, subID},
    98  			JoinType: joinType,
    99  			OnList:   joinPreds,
   100  		}, ctx)
   101  
   102  		if len(filterPreds) > 0 {
   103  			nodeID = builder.appendNode(&plan.Node{
   104  				NodeType:   plan.Node_FILTER,
   105  				Children:   []int32{nodeID},
   106  				FilterList: filterPreds,
   107  			}, ctx)
   108  		}
   109  
   110  		retExpr := &plan.Expr{
   111  			Typ: subCtx.results[0].Typ,
   112  			Expr: &plan.Expr_Col{
   113  				Col: &plan.ColRef{
   114  					RelPos: subCtx.rootTag(),
   115  					ColPos: 0,
   116  				},
   117  			},
   118  		}
   119  		if rewrite {
   120  			argsType := make([]types.Type, 1)
   121  			argsType[0] = makeTypeByPlan2Expr(retExpr)
   122  			funcID, returnType, _, _ := function.GetFunctionByName(builder.GetContext(), "isnull", argsType)
   123  			isNullExpr := &Expr{
   124  				Expr: &plan.Expr_F{
   125  					F: &plan.Function{
   126  						Func: getFunctionObjRef(funcID, "isnull"),
   127  						Args: []*Expr{retExpr},
   128  					},
   129  				},
   130  				Typ: makePlan2Type(&returnType),
   131  			}
   132  			zeroExpr := makePlan2Int64ConstExprWithType(0)
   133  			argsType = make([]types.Type, 3)
   134  			argsType[0] = makeTypeByPlan2Expr(isNullExpr)
   135  			argsType[1] = makeTypeByPlan2Expr(zeroExpr)
   136  			argsType[2] = makeTypeByPlan2Expr(retExpr)
   137  			funcID, returnType, _, _ = function.GetFunctionByName(builder.GetContext(), "case", argsType)
   138  			retExpr = &Expr{
   139  				Expr: &plan.Expr_F{
   140  					F: &plan.Function{
   141  						Func: getFunctionObjRef(funcID, "case"),
   142  						Args: []*Expr{isNullExpr, zeroExpr, DeepCopyExpr(retExpr)},
   143  					},
   144  				},
   145  				Typ: makePlan2Type(&returnType),
   146  			}
   147  		}
   148  		return nodeID, retExpr, nil
   149  
   150  	case plan.SubqueryRef_EXISTS:
   151  		// Uncorrelated subquery
   152  		if len(joinPreds) == 0 {
   153  			joinPreds = append(joinPreds, constTrue)
   154  		}
   155  
   156  		return builder.insertMarkJoin(nodeID, subID, joinPreds, nil, false, ctx)
   157  
   158  	case plan.SubqueryRef_NOT_EXISTS:
   159  		// Uncorrelated subquery
   160  		if len(joinPreds) == 0 {
   161  			joinPreds = append(joinPreds, constTrue)
   162  		}
   163  
   164  		return builder.insertMarkJoin(nodeID, subID, joinPreds, nil, true, ctx)
   165  
   166  	case plan.SubqueryRef_IN:
   167  		outerPred, err := builder.generateComparison("=", subquery.Child, subCtx)
   168  		if err != nil {
   169  			return 0, nil, err
   170  		}
   171  
   172  		return builder.insertMarkJoin(nodeID, subID, joinPreds, outerPred, false, ctx)
   173  
   174  	case plan.SubqueryRef_NOT_IN:
   175  		outerPred, err := builder.generateComparison("=", subquery.Child, subCtx)
   176  		if err != nil {
   177  			return 0, nil, err
   178  		}
   179  
   180  		return builder.insertMarkJoin(nodeID, subID, joinPreds, outerPred, true, ctx)
   181  
   182  	case plan.SubqueryRef_ANY:
   183  		outerPred, err := builder.generateComparison(subquery.Op, subquery.Child, subCtx)
   184  		if err != nil {
   185  			return 0, nil, err
   186  		}
   187  
   188  		return builder.insertMarkJoin(nodeID, subID, joinPreds, outerPred, false, ctx)
   189  
   190  	case plan.SubqueryRef_ALL:
   191  		outerPred, err := builder.generateComparison(subquery.Op, subquery.Child, subCtx)
   192  		if err != nil {
   193  			return 0, nil, err
   194  		}
   195  
   196  		outerPred, err = bindFuncExprImplByPlanExpr(builder.GetContext(), "not", []*plan.Expr{outerPred})
   197  		if err != nil {
   198  			return 0, nil, err
   199  		}
   200  
   201  		return builder.insertMarkJoin(nodeID, subID, joinPreds, outerPred, true, ctx)
   202  
   203  	default:
   204  		return 0, nil, moerr.NewNotSupported(builder.GetContext(), "%s subquery not supported", subquery.Typ.String())
   205  	}
   206  }
   207  
   208  func (builder *QueryBuilder) insertMarkJoin(left, right int32, joinPreds []*plan.Expr, outerPred *plan.Expr, negate bool, ctx *BindContext) (nodeID int32, markExpr *plan.Expr, err error) {
   209  	markTag := builder.genNewTag()
   210  
   211  	for i, pred := range joinPreds {
   212  		if !pred.Typ.NotNullable {
   213  			joinPreds[i], err = bindFuncExprImplByPlanExpr(builder.GetContext(), "istrue", []*plan.Expr{pred})
   214  			if err != nil {
   215  				return
   216  			}
   217  		}
   218  	}
   219  
   220  	notNull := true
   221  
   222  	if outerPred != nil {
   223  		joinPreds = append(joinPreds, outerPred)
   224  		notNull = outerPred.Typ.NotNullable
   225  	}
   226  
   227  	nodeID = builder.appendNode(&plan.Node{
   228  		NodeType:    plan.Node_JOIN,
   229  		Children:    []int32{left, right},
   230  		BindingTags: []int32{markTag},
   231  		JoinType:    plan.Node_MARK,
   232  		OnList:      joinPreds,
   233  	}, ctx)
   234  
   235  	markExpr = &plan.Expr{
   236  		Typ: &plan.Type{
   237  			Id:          int32(types.T_bool),
   238  			NotNullable: notNull,
   239  			Size:        1,
   240  		},
   241  		Expr: &plan.Expr_Col{
   242  			Col: &plan.ColRef{
   243  				RelPos: markTag,
   244  				ColPos: 0,
   245  			},
   246  		},
   247  	}
   248  
   249  	if negate {
   250  		markExpr, err = bindFuncExprImplByPlanExpr(builder.GetContext(), "not", []*plan.Expr{markExpr})
   251  	}
   252  
   253  	return
   254  }
   255  
   256  func (builder *QueryBuilder) generateComparison(op string, child *plan.Expr, ctx *BindContext) (*plan.Expr, error) {
   257  	switch childImpl := child.Expr.(type) {
   258  	case *plan.Expr_List:
   259  		childList := childImpl.List.List
   260  		switch op {
   261  		case "=":
   262  			leftExpr, err := bindFuncExprImplByPlanExpr(builder.GetContext(), op, []*plan.Expr{
   263  				childList[0],
   264  				{
   265  					Typ: ctx.results[0].Typ,
   266  					Expr: &plan.Expr_Col{
   267  						Col: &plan.ColRef{
   268  							RelPos: ctx.rootTag(),
   269  							ColPos: 0,
   270  						},
   271  					},
   272  				},
   273  			})
   274  			if err != nil {
   275  				return nil, err
   276  			}
   277  
   278  			for i := 1; i < len(childList); i++ {
   279  				rightExpr, err := bindFuncExprImplByPlanExpr(builder.GetContext(), op, []*plan.Expr{
   280  					childList[i],
   281  					{
   282  						Typ: ctx.results[i].Typ,
   283  						Expr: &plan.Expr_Col{
   284  							Col: &plan.ColRef{
   285  								RelPos: ctx.rootTag(),
   286  								ColPos: int32(i),
   287  							},
   288  						},
   289  					},
   290  				})
   291  				if err != nil {
   292  					return nil, err
   293  				}
   294  
   295  				leftExpr, err = bindFuncExprImplByPlanExpr(builder.GetContext(), "and", []*plan.Expr{leftExpr, rightExpr})
   296  				if err != nil {
   297  					return nil, err
   298  				}
   299  			}
   300  
   301  			return leftExpr, nil
   302  
   303  		case "<>":
   304  			leftExpr, err := bindFuncExprImplByPlanExpr(builder.GetContext(), op, []*plan.Expr{
   305  				childList[0],
   306  				{
   307  					Typ: ctx.results[0].Typ,
   308  					Expr: &plan.Expr_Col{
   309  						Col: &plan.ColRef{
   310  							RelPos: ctx.rootTag(),
   311  							ColPos: 0,
   312  						},
   313  					},
   314  				},
   315  			})
   316  			if err != nil {
   317  				return nil, err
   318  			}
   319  
   320  			for i := 1; i < len(childList); i++ {
   321  				rightExpr, err := bindFuncExprImplByPlanExpr(builder.GetContext(), op, []*plan.Expr{
   322  					childList[i],
   323  					{
   324  						Typ: ctx.results[i].Typ,
   325  						Expr: &plan.Expr_Col{
   326  							Col: &plan.ColRef{
   327  								RelPos: ctx.rootTag(),
   328  								ColPos: int32(i),
   329  							},
   330  						},
   331  					},
   332  				})
   333  				if err != nil {
   334  					return nil, err
   335  				}
   336  
   337  				leftExpr, err = bindFuncExprImplByPlanExpr(builder.GetContext(), "or", []*plan.Expr{leftExpr, rightExpr})
   338  				if err != nil {
   339  					return nil, err
   340  				}
   341  			}
   342  
   343  			return leftExpr, nil
   344  
   345  		case "<", "<=", ">", ">=":
   346  			projList := make([]*plan.Expr, len(childList))
   347  			for i := range projList {
   348  				projList[i] = &plan.Expr{
   349  					Typ: ctx.results[i].Typ,
   350  					Expr: &plan.Expr_Col{
   351  						Col: &plan.ColRef{
   352  							RelPos: ctx.rootTag(),
   353  							ColPos: int32(i),
   354  						},
   355  					},
   356  				}
   357  			}
   358  
   359  			nonEqOp := op[:1] // <= -> <, >= -> >
   360  			return unwindTupleComparison(builder.GetContext(), nonEqOp, op, childList, projList, 0)
   361  
   362  		default:
   363  			return nil, moerr.NewNotSupported(builder.GetContext(), "row constructor only support comparison operators")
   364  		}
   365  
   366  	default:
   367  		return bindFuncExprImplByPlanExpr(builder.GetContext(), op, []*plan.Expr{
   368  			child,
   369  			{
   370  				Typ: ctx.results[0].Typ,
   371  				Expr: &plan.Expr_Col{
   372  					Col: &plan.ColRef{
   373  						RelPos: ctx.rootTag(),
   374  						ColPos: 0,
   375  					},
   376  				},
   377  			},
   378  		})
   379  	}
   380  }
   381  
   382  func (builder *QueryBuilder) findAggrCount(aggrs []*plan.Expr) bool {
   383  	for _, aggr := range aggrs {
   384  		switch exprImpl := aggr.Expr.(type) {
   385  		case *plan.Expr_F:
   386  			if exprImpl.F.Func.ObjName == "count" || exprImpl.F.Func.ObjName == "starcount" {
   387  				return true
   388  			}
   389  		}
   390  	}
   391  	return false
   392  }
   393  
   394  func (builder *QueryBuilder) findNonEqPred(preds []*plan.Expr) bool {
   395  	for _, pred := range preds {
   396  		switch exprImpl := pred.Expr.(type) {
   397  		case *plan.Expr_F:
   398  			if exprImpl.F.Func.ObjName != "=" {
   399  				return true
   400  			}
   401  		}
   402  	}
   403  	return false
   404  }
   405  
   406  func (builder *QueryBuilder) pullupCorrelatedPredicates(nodeID int32, ctx *BindContext) (int32, []*plan.Expr, error) {
   407  	node := builder.qry.Nodes[nodeID]
   408  
   409  	var preds []*plan.Expr
   410  	var err error
   411  
   412  	var subPreds []*plan.Expr
   413  	for i, childID := range node.Children {
   414  		node.Children[i], subPreds, err = builder.pullupCorrelatedPredicates(childID, ctx)
   415  		if err != nil {
   416  			return 0, nil, err
   417  		}
   418  
   419  		preds = append(preds, subPreds...)
   420  	}
   421  
   422  	switch node.NodeType {
   423  	case plan.Node_AGG:
   424  		groupTag := node.BindingTags[0]
   425  		for _, pred := range preds {
   426  			builder.pullupThroughAgg(ctx, node, groupTag, pred)
   427  		}
   428  
   429  	case plan.Node_PROJECT:
   430  		projectTag := node.BindingTags[0]
   431  		for _, pred := range preds {
   432  			builder.pullupThroughProj(ctx, node, projectTag, pred)
   433  		}
   434  
   435  	case plan.Node_FILTER:
   436  		var newFilterList []*plan.Expr
   437  		for _, cond := range node.FilterList {
   438  			if hasCorrCol(cond) {
   439  				//cond, err = bindFuncExprImplByPlanExpr("is", []*plan.Expr{cond, DeepCopyExpr(constTrue)})
   440  				if err != nil {
   441  					return 0, nil, err
   442  				}
   443  				preds = append(preds, cond)
   444  			} else {
   445  				newFilterList = append(newFilterList, cond)
   446  			}
   447  		}
   448  
   449  		if len(newFilterList) == 0 {
   450  			nodeID = node.Children[0]
   451  		} else {
   452  			node.FilterList = newFilterList
   453  		}
   454  	}
   455  
   456  	return nodeID, preds, err
   457  }
   458  
   459  func (builder *QueryBuilder) pullupThroughAgg(ctx *BindContext, node *plan.Node, tag int32, expr *plan.Expr) *plan.Expr {
   460  	if !hasCorrCol(expr) {
   461  		switch expr.Expr.(type) {
   462  		case *plan.Expr_Col, *plan.Expr_F:
   463  			break
   464  
   465  		default:
   466  			return expr
   467  		}
   468  
   469  		colPos := int32(len(node.GroupBy))
   470  		node.GroupBy = append(node.GroupBy, expr)
   471  
   472  		if colRef, ok := expr.Expr.(*plan.Expr_Col); ok {
   473  			oldMapId := [2]int32{colRef.Col.RelPos, colRef.Col.ColPos}
   474  			newMapId := [2]int32{tag, colPos}
   475  
   476  			builder.nameByColRef[newMapId] = builder.nameByColRef[oldMapId]
   477  		}
   478  
   479  		return &plan.Expr{
   480  			Typ: expr.Typ,
   481  			Expr: &plan.Expr_Col{
   482  				Col: &plan.ColRef{
   483  					RelPos: tag,
   484  					ColPos: colPos,
   485  				},
   486  			},
   487  		}
   488  	}
   489  
   490  	switch exprImpl := expr.Expr.(type) {
   491  	case *plan.Expr_F:
   492  		for i, arg := range exprImpl.F.Args {
   493  			exprImpl.F.Args[i] = builder.pullupThroughAgg(ctx, node, tag, arg)
   494  		}
   495  	}
   496  
   497  	return expr
   498  }
   499  
   500  func (builder *QueryBuilder) pullupThroughProj(ctx *BindContext, node *plan.Node, tag int32, expr *plan.Expr) *plan.Expr {
   501  	if !hasCorrCol(expr) {
   502  		switch expr.Expr.(type) {
   503  		case *plan.Expr_Col, *plan.Expr_F:
   504  			break
   505  
   506  		default:
   507  			return expr
   508  		}
   509  
   510  		colPos := int32(len(node.ProjectList))
   511  		node.ProjectList = append(node.ProjectList, expr)
   512  
   513  		if colRef, ok := expr.Expr.(*plan.Expr_Col); ok {
   514  			oldMapId := [2]int32{colRef.Col.RelPos, colRef.Col.ColPos}
   515  			newMapId := [2]int32{tag, colPos}
   516  
   517  			builder.nameByColRef[newMapId] = builder.nameByColRef[oldMapId]
   518  		}
   519  
   520  		return &plan.Expr{
   521  			Typ: expr.Typ,
   522  			Expr: &plan.Expr_Col{
   523  				Col: &plan.ColRef{
   524  					RelPos: tag,
   525  					ColPos: colPos,
   526  				},
   527  			},
   528  		}
   529  	}
   530  
   531  	switch exprImpl := expr.Expr.(type) {
   532  	case *plan.Expr_F:
   533  		for i, arg := range exprImpl.F.Args {
   534  			exprImpl.F.Args[i] = builder.pullupThroughProj(ctx, node, tag, arg)
   535  		}
   536  	}
   537  
   538  	return expr
   539  }