github.com/matrixorigin/matrixone@v1.2.0/pkg/sql/plan/flatten_subquery.go (about)

     1  // Copyright 2022 Matrix Origin
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package plan
    16  
    17  import (
    18  	"github.com/matrixorigin/matrixone/pkg/common/moerr"
    19  	"github.com/matrixorigin/matrixone/pkg/container/types"
    20  	"github.com/matrixorigin/matrixone/pkg/pb/plan"
    21  	"github.com/matrixorigin/matrixone/pkg/sql/plan/function"
    22  )
    23  
    24  var (
    25  	constTrue = &plan.Expr{
    26  		Expr: &plan.Expr_Lit{
    27  			Lit: &plan.Literal{
    28  				Value: &plan.Literal_Bval{
    29  					Bval: true,
    30  				},
    31  			},
    32  		},
    33  		Typ: plan.Type{
    34  			Id:          int32(types.T_bool),
    35  			NotNullable: true,
    36  		},
    37  	}
    38  
    39  	constFalse = &plan.Expr{
    40  		Expr: &plan.Expr_Lit{
    41  			Lit: &plan.Literal{
    42  				Value: &plan.Literal_Bval{
    43  					Bval: false,
    44  				},
    45  			},
    46  		},
    47  		Typ: plan.Type{
    48  			Id:          int32(types.T_bool),
    49  			NotNullable: true,
    50  		},
    51  	}
    52  )
    53  
    54  func (builder *QueryBuilder) flattenSubqueries(nodeID int32, expr *plan.Expr, ctx *BindContext) (int32, *plan.Expr, error) {
    55  	var err error
    56  
    57  	switch exprImpl := expr.Expr.(type) {
    58  	case *plan.Expr_F:
    59  		for i, arg := range exprImpl.F.Args {
    60  			nodeID, exprImpl.F.Args[i], err = builder.flattenSubqueries(nodeID, arg, ctx)
    61  			if err != nil {
    62  				return 0, nil, err
    63  			}
    64  		}
    65  
    66  	case *plan.Expr_Sub:
    67  		if builder.isForUpdate {
    68  			return 0, nil, moerr.NewInternalError(builder.GetContext(), "not support subquery for update")
    69  		}
    70  		nodeID, expr, err = builder.flattenSubquery(nodeID, exprImpl.Sub, ctx)
    71  	}
    72  
    73  	return nodeID, expr, err
    74  }
    75  
    76  func (builder *QueryBuilder) flattenSubquery(nodeID int32, subquery *plan.SubqueryRef, ctx *BindContext) (int32, *plan.Expr, error) {
    77  	if subquery.Child != nil && hasSubquery(subquery.Child) {
    78  		return 0, nil, moerr.NewNotSupported(builder.GetContext(), "a quantified subquery's left operand can't contain subquery")
    79  	}
    80  
    81  	subID := subquery.NodeId
    82  	subCtx := builder.ctxByNode[subID]
    83  
    84  	// Strip unnecessary subqueries which have no FROM clause
    85  	subNode := builder.qry.Nodes[subID]
    86  	if subNode.NodeType == plan.Node_PROJECT && builder.qry.Nodes[subNode.Children[0]].NodeType == plan.Node_VALUE_SCAN {
    87  		switch subquery.Typ {
    88  		case plan.SubqueryRef_SCALAR:
    89  			newProj, _ := decreaseDepth(subNode.ProjectList[0])
    90  			return nodeID, newProj, nil
    91  
    92  		case plan.SubqueryRef_EXISTS:
    93  			return nodeID, constTrue, nil
    94  
    95  		case plan.SubqueryRef_NOT_EXISTS:
    96  			return nodeID, constFalse, nil
    97  
    98  		case plan.SubqueryRef_IN:
    99  			newExpr, err := builder.generateRowComparison("=", subquery.Child, subCtx, true)
   100  			if err != nil {
   101  				return 0, nil, err
   102  			}
   103  
   104  			return nodeID, newExpr, nil
   105  
   106  		case plan.SubqueryRef_NOT_IN:
   107  			newExpr, err := builder.generateRowComparison("<>", subquery.Child, subCtx, true)
   108  			if err != nil {
   109  				return 0, nil, err
   110  			}
   111  
   112  			return nodeID, newExpr, nil
   113  
   114  		case plan.SubqueryRef_ANY, plan.SubqueryRef_ALL:
   115  			newExpr, err := builder.generateRowComparison(subquery.Op, subquery.Child, subCtx, true)
   116  			if err != nil {
   117  				return 0, nil, err
   118  			}
   119  
   120  			return nodeID, newExpr, nil
   121  
   122  		default:
   123  			return 0, nil, moerr.NewNotSupported(builder.GetContext(), "%s subquery not supported", subquery.Typ.String())
   124  		}
   125  	}
   126  
   127  	subID, preds, err := builder.pullupCorrelatedPredicates(subID, subCtx)
   128  	if err != nil {
   129  		return 0, nil, err
   130  	}
   131  
   132  	if subquery.Typ == plan.SubqueryRef_SCALAR && len(subCtx.aggregates) > 0 && builder.findNonEqPred(preds) {
   133  		return 0, nil, moerr.NewNYI(builder.GetContext(), "aggregation with non equal predicate in %s subquery  will be supported in future version", subquery.Typ.String())
   134  	}
   135  
   136  	filterPreds, joinPreds := decreaseDepthAndDispatch(preds)
   137  
   138  	if len(filterPreds) > 0 && subquery.Typ >= plan.SubqueryRef_SCALAR {
   139  		return 0, nil, moerr.NewNYI(builder.GetContext(), "correlated columns in %s subquery deeper than 1 level will be supported in future version", subquery.Typ.String())
   140  	}
   141  
   142  	switch subquery.Typ {
   143  	case plan.SubqueryRef_SCALAR:
   144  		var rewrite bool
   145  		// Uncorrelated subquery
   146  		if len(joinPreds) > 0 && builder.findAggrCount(subCtx.aggregates) {
   147  			rewrite = true
   148  		}
   149  
   150  		joinType := plan.Node_SINGLE
   151  		if subCtx.hasSingleRow {
   152  			joinType = plan.Node_LEFT
   153  		}
   154  
   155  		nodeID = builder.appendNode(&plan.Node{
   156  			NodeType: plan.Node_JOIN,
   157  			Children: []int32{nodeID, subID},
   158  			JoinType: joinType,
   159  			OnList:   joinPreds,
   160  		}, ctx)
   161  
   162  		if len(filterPreds) > 0 {
   163  			nodeID = builder.appendNode(&plan.Node{
   164  				NodeType:   plan.Node_FILTER,
   165  				Children:   []int32{nodeID},
   166  				FilterList: filterPreds,
   167  			}, ctx)
   168  		}
   169  
   170  		retExpr := &plan.Expr{
   171  			Typ: subCtx.results[0].Typ,
   172  			Expr: &plan.Expr_Col{
   173  				Col: &plan.ColRef{
   174  					RelPos: subCtx.topTag(),
   175  					ColPos: 0,
   176  				},
   177  			},
   178  		}
   179  		if rewrite {
   180  			argsType := make([]types.Type, 1)
   181  			argsType[0] = makeTypeByPlan2Expr(retExpr)
   182  			fGet, err := function.GetFunctionByName(builder.GetContext(), "isnull", argsType)
   183  			if err != nil {
   184  				return nodeID, retExpr, err
   185  			}
   186  			funcID, returnType := fGet.GetEncodedOverloadID(), fGet.GetReturnType()
   187  			isNullExpr := &Expr{
   188  				Expr: &plan.Expr_F{
   189  					F: &plan.Function{
   190  						Func: getFunctionObjRef(funcID, "isnull"),
   191  						Args: []*Expr{retExpr},
   192  					},
   193  				},
   194  				Typ: makePlan2Type(&returnType),
   195  			}
   196  			zeroExpr := makePlan2Int64ConstExprWithType(0)
   197  			argsType = make([]types.Type, 3)
   198  			argsType[0] = makeTypeByPlan2Expr(isNullExpr)
   199  			argsType[1] = makeTypeByPlan2Expr(zeroExpr)
   200  			argsType[2] = makeTypeByPlan2Expr(retExpr)
   201  			fGet, err = function.GetFunctionByName(builder.GetContext(), "case", argsType)
   202  			if err != nil {
   203  				return nodeID, retExpr, nil
   204  			}
   205  			funcID, returnType = fGet.GetEncodedOverloadID(), fGet.GetReturnType()
   206  			retExpr = &Expr{
   207  				Expr: &plan.Expr_F{
   208  					F: &plan.Function{
   209  						Func: getFunctionObjRef(funcID, "case"),
   210  						Args: []*Expr{isNullExpr, zeroExpr, DeepCopyExpr(retExpr)},
   211  					},
   212  				},
   213  				Typ: makePlan2Type(&returnType),
   214  			}
   215  		}
   216  		return nodeID, retExpr, nil
   217  
   218  	case plan.SubqueryRef_EXISTS:
   219  		// Uncorrelated subquery
   220  		if len(joinPreds) == 0 {
   221  			joinPreds = append(joinPreds, constTrue)
   222  		}
   223  
   224  		return builder.insertMarkJoin(nodeID, subID, joinPreds, nil, false, ctx)
   225  
   226  	case plan.SubqueryRef_NOT_EXISTS:
   227  		// Uncorrelated subquery
   228  		if len(joinPreds) == 0 {
   229  			joinPreds = append(joinPreds, constTrue)
   230  		}
   231  
   232  		return builder.insertMarkJoin(nodeID, subID, joinPreds, nil, true, ctx)
   233  
   234  	case plan.SubqueryRef_IN:
   235  		outerPred, err := builder.generateRowComparison("=", subquery.Child, subCtx, false)
   236  		if err != nil {
   237  			return 0, nil, err
   238  		}
   239  
   240  		return builder.insertMarkJoin(nodeID, subID, joinPreds, outerPred, false, ctx)
   241  
   242  	case plan.SubqueryRef_NOT_IN:
   243  		outerPred, err := builder.generateRowComparison("=", subquery.Child, subCtx, false)
   244  		if err != nil {
   245  			return 0, nil, err
   246  		}
   247  
   248  		return builder.insertMarkJoin(nodeID, subID, joinPreds, outerPred, true, ctx)
   249  
   250  	case plan.SubqueryRef_ANY:
   251  		outerPred, err := builder.generateRowComparison(subquery.Op, subquery.Child, subCtx, false)
   252  		if err != nil {
   253  			return 0, nil, err
   254  		}
   255  
   256  		return builder.insertMarkJoin(nodeID, subID, joinPreds, outerPred, false, ctx)
   257  
   258  	case plan.SubqueryRef_ALL:
   259  		outerPred, err := builder.generateRowComparison(subquery.Op, subquery.Child, subCtx, false)
   260  		if err != nil {
   261  			return 0, nil, err
   262  		}
   263  
   264  		outerPred, err = BindFuncExprImplByPlanExpr(builder.GetContext(), "not", []*plan.Expr{outerPred})
   265  		if err != nil {
   266  			return 0, nil, err
   267  		}
   268  
   269  		return builder.insertMarkJoin(nodeID, subID, joinPreds, outerPred, true, ctx)
   270  
   271  	default:
   272  		return 0, nil, moerr.NewNotSupported(builder.GetContext(), "%s subquery not supported", subquery.Typ.String())
   273  	}
   274  }
   275  
   276  func (builder *QueryBuilder) insertMarkJoin(left, right int32, joinPreds []*plan.Expr, outerPred *plan.Expr, negate bool, ctx *BindContext) (nodeID int32, markExpr *plan.Expr, err error) {
   277  	markTag := builder.genNewTag()
   278  
   279  	for i, pred := range joinPreds {
   280  		if !pred.Typ.NotNullable {
   281  			joinPreds[i], err = BindFuncExprImplByPlanExpr(builder.GetContext(), "istrue", []*plan.Expr{pred})
   282  			if err != nil {
   283  				return
   284  			}
   285  		}
   286  	}
   287  
   288  	notNull := true
   289  
   290  	if outerPred != nil {
   291  		joinPreds = append(joinPreds, outerPred)
   292  		notNull = outerPred.Typ.NotNullable
   293  	}
   294  
   295  	nodeID = builder.appendNode(&plan.Node{
   296  		NodeType:    plan.Node_JOIN,
   297  		Children:    []int32{left, right},
   298  		BindingTags: []int32{markTag},
   299  		JoinType:    plan.Node_MARK,
   300  		OnList:      joinPreds,
   301  	}, ctx)
   302  
   303  	markExpr = &plan.Expr{
   304  		Typ: plan.Type{
   305  			Id:          int32(types.T_bool),
   306  			NotNullable: notNull,
   307  		},
   308  		Expr: &plan.Expr_Col{
   309  			Col: &plan.ColRef{
   310  				RelPos: markTag,
   311  				ColPos: 0,
   312  			},
   313  		},
   314  	}
   315  
   316  	if negate {
   317  		markExpr, err = BindFuncExprImplByPlanExpr(builder.GetContext(), "not", []*plan.Expr{markExpr})
   318  	}
   319  
   320  	return
   321  }
   322  
   323  func getProjectExpr(idx int, ctx *BindContext, strip bool) *plan.Expr {
   324  	if strip {
   325  		newProj, _ := decreaseDepth(ctx.results[idx])
   326  		return newProj
   327  	} else {
   328  		return &plan.Expr{
   329  			Typ: ctx.results[idx].Typ,
   330  			Expr: &plan.Expr_Col{
   331  				Col: &plan.ColRef{
   332  					RelPos: ctx.rootTag(),
   333  					ColPos: int32(idx),
   334  				},
   335  			},
   336  		}
   337  	}
   338  }
   339  
   340  func (builder *QueryBuilder) generateRowComparison(op string, child *plan.Expr, ctx *BindContext, strip bool) (*plan.Expr, error) {
   341  	switch childImpl := child.Expr.(type) {
   342  	case *plan.Expr_List:
   343  		childList := childImpl.List.List
   344  		switch op {
   345  		case "=":
   346  			leftExpr, err := BindFuncExprImplByPlanExpr(builder.GetContext(), op, []*plan.Expr{
   347  				childList[0],
   348  				getProjectExpr(0, ctx, strip),
   349  			})
   350  			if err != nil {
   351  				return nil, err
   352  			}
   353  
   354  			for i := 1; i < len(childList); i++ {
   355  				rightExpr, err := BindFuncExprImplByPlanExpr(builder.GetContext(), op, []*plan.Expr{
   356  					childList[i],
   357  					getProjectExpr(i, ctx, strip),
   358  				})
   359  				if err != nil {
   360  					return nil, err
   361  				}
   362  
   363  				leftExpr, err = BindFuncExprImplByPlanExpr(builder.GetContext(), "and", []*plan.Expr{leftExpr, rightExpr})
   364  				if err != nil {
   365  					return nil, err
   366  				}
   367  			}
   368  
   369  			return leftExpr, nil
   370  
   371  		case "<>":
   372  			leftExpr, err := BindFuncExprImplByPlanExpr(builder.GetContext(), op, []*plan.Expr{
   373  				childList[0],
   374  				getProjectExpr(0, ctx, strip),
   375  			})
   376  			if err != nil {
   377  				return nil, err
   378  			}
   379  
   380  			for i := 1; i < len(childList); i++ {
   381  				rightExpr, err := BindFuncExprImplByPlanExpr(builder.GetContext(), op, []*plan.Expr{
   382  					childList[i],
   383  					getProjectExpr(i, ctx, strip),
   384  				})
   385  				if err != nil {
   386  					return nil, err
   387  				}
   388  
   389  				leftExpr, err = BindFuncExprImplByPlanExpr(builder.GetContext(), "or", []*plan.Expr{leftExpr, rightExpr})
   390  				if err != nil {
   391  					return nil, err
   392  				}
   393  			}
   394  
   395  			return leftExpr, nil
   396  
   397  		case "<", "<=", ">", ">=":
   398  			projList := make([]*plan.Expr, len(childList))
   399  			for i := range projList {
   400  				projList[i] = getProjectExpr(i, ctx, strip)
   401  
   402  			}
   403  
   404  			nonEqOp := op[:1] // <= -> <, >= -> >
   405  			return unwindTupleComparison(builder.GetContext(), nonEqOp, op, childList, projList, 0)
   406  
   407  		default:
   408  			return nil, moerr.NewNotSupported(builder.GetContext(), "row constructor only support comparison operators")
   409  		}
   410  
   411  	default:
   412  		return BindFuncExprImplByPlanExpr(builder.GetContext(), op, []*plan.Expr{
   413  			child,
   414  			getProjectExpr(0, ctx, strip),
   415  		})
   416  	}
   417  }
   418  
   419  func (builder *QueryBuilder) findAggrCount(aggrs []*plan.Expr) bool {
   420  	for _, aggr := range aggrs {
   421  		switch exprImpl := aggr.Expr.(type) {
   422  		case *plan.Expr_F:
   423  			if exprImpl.F.Func.ObjName == "count" || exprImpl.F.Func.ObjName == "starcount" {
   424  				return true
   425  			}
   426  		}
   427  	}
   428  	return false
   429  }
   430  
   431  func (builder *QueryBuilder) findNonEqPred(preds []*plan.Expr) bool {
   432  	for _, pred := range preds {
   433  		switch exprImpl := pred.Expr.(type) {
   434  		case *plan.Expr_F:
   435  			if exprImpl.F.Func.ObjName != "=" {
   436  				return true
   437  			}
   438  		}
   439  	}
   440  	return false
   441  }
   442  
   443  func (builder *QueryBuilder) pullupCorrelatedPredicates(nodeID int32, ctx *BindContext) (int32, []*plan.Expr, error) {
   444  	node := builder.qry.Nodes[nodeID]
   445  
   446  	var preds []*plan.Expr
   447  	var err error
   448  
   449  	var subPreds []*plan.Expr
   450  	for i, childID := range node.Children {
   451  		node.Children[i], subPreds, err = builder.pullupCorrelatedPredicates(childID, ctx)
   452  		if err != nil {
   453  			return 0, nil, err
   454  		}
   455  
   456  		preds = append(preds, subPreds...)
   457  	}
   458  
   459  	switch node.NodeType {
   460  	case plan.Node_AGG:
   461  		groupTag := node.BindingTags[0]
   462  		for _, pred := range preds {
   463  			builder.pullupThroughAgg(ctx, node, groupTag, pred)
   464  		}
   465  
   466  	case plan.Node_PROJECT:
   467  		projectTag := node.BindingTags[0]
   468  		for _, pred := range preds {
   469  			builder.pullupThroughProj(ctx, node, projectTag, pred)
   470  		}
   471  
   472  	case plan.Node_FILTER:
   473  		var newFilterList []*plan.Expr
   474  		for _, cond := range node.FilterList {
   475  			if hasCorrCol(cond) {
   476  				//cond, err = bindFuncExprImplByPlanExpr("is", []*plan.Expr{cond, DeepCopyExpr(constTrue)})
   477  				if err != nil {
   478  					return 0, nil, err
   479  				}
   480  				preds = append(preds, cond)
   481  			} else {
   482  				newFilterList = append(newFilterList, cond)
   483  			}
   484  		}
   485  
   486  		if len(newFilterList) == 0 {
   487  			nodeID = node.Children[0]
   488  		} else {
   489  			node.FilterList = newFilterList
   490  		}
   491  	}
   492  
   493  	return nodeID, preds, err
   494  }
   495  
   496  func (builder *QueryBuilder) pullupThroughAgg(ctx *BindContext, node *plan.Node, tag int32, expr *plan.Expr) *plan.Expr {
   497  	if !hasCorrCol(expr) {
   498  		switch expr.Expr.(type) {
   499  		case *plan.Expr_Col, *plan.Expr_F:
   500  			break
   501  
   502  		default:
   503  			return expr
   504  		}
   505  
   506  		colPos := int32(len(node.GroupBy))
   507  		node.GroupBy = append(node.GroupBy, expr)
   508  
   509  		if colRef, ok := expr.Expr.(*plan.Expr_Col); ok {
   510  			oldMapId := [2]int32{colRef.Col.RelPos, colRef.Col.ColPos}
   511  			newMapId := [2]int32{tag, colPos}
   512  
   513  			builder.nameByColRef[newMapId] = builder.nameByColRef[oldMapId]
   514  		}
   515  
   516  		return &plan.Expr{
   517  			Typ: expr.Typ,
   518  			Expr: &plan.Expr_Col{
   519  				Col: &plan.ColRef{
   520  					RelPos: tag,
   521  					ColPos: colPos,
   522  				},
   523  			},
   524  		}
   525  	}
   526  
   527  	switch exprImpl := expr.Expr.(type) {
   528  	case *plan.Expr_F:
   529  		for i, arg := range exprImpl.F.Args {
   530  			exprImpl.F.Args[i] = builder.pullupThroughAgg(ctx, node, tag, arg)
   531  		}
   532  	}
   533  
   534  	return expr
   535  }
   536  
   537  func (builder *QueryBuilder) pullupThroughProj(ctx *BindContext, node *plan.Node, tag int32, expr *plan.Expr) *plan.Expr {
   538  	if !hasCorrCol(expr) {
   539  		switch expr.Expr.(type) {
   540  		case *plan.Expr_Col, *plan.Expr_F:
   541  			break
   542  
   543  		default:
   544  			return expr
   545  		}
   546  
   547  		colPos := int32(len(node.ProjectList))
   548  		node.ProjectList = append(node.ProjectList, expr)
   549  
   550  		if colRef, ok := expr.Expr.(*plan.Expr_Col); ok {
   551  			oldMapId := [2]int32{colRef.Col.RelPos, colRef.Col.ColPos}
   552  			newMapId := [2]int32{tag, colPos}
   553  
   554  			builder.nameByColRef[newMapId] = builder.nameByColRef[oldMapId]
   555  		}
   556  
   557  		return &plan.Expr{
   558  			Typ: expr.Typ,
   559  			Expr: &plan.Expr_Col{
   560  				Col: &plan.ColRef{
   561  					RelPos: tag,
   562  					ColPos: colPos,
   563  				},
   564  			},
   565  		}
   566  	}
   567  
   568  	switch exprImpl := expr.Expr.(type) {
   569  	case *plan.Expr_F:
   570  		for i, arg := range exprImpl.F.Args {
   571  			exprImpl.F.Args[i] = builder.pullupThroughProj(ctx, node, tag, arg)
   572  		}
   573  	}
   574  
   575  	return expr
   576  }