github.com/matrixorigin/matrixone@v1.2.0/pkg/sql/plan/join_order.go (about)

     1  // Copyright 2022 Matrix Origin
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package plan
    16  
    17  import (
    18  	"sort"
    19  
    20  	"github.com/matrixorigin/matrixone/pkg/pb/plan"
    21  	"github.com/matrixorigin/matrixone/pkg/sql/plan/function"
    22  )
    23  
// joinEdge accumulates the join-predicate columns that connect one pair of
// vertices in the join graph. leftCols and rightCols are appended to in
// lockstep by getJoinGraph, so entry i of each slice comes from the same
// predicate; "left" is the endpoint with the smaller vertex index.
type joinEdge struct {
	leftCols  []int32 // ColPos values taken from the lower-indexed vertex
	rightCols []int32 // ColPos values taken from the higher-indexed vertex
}
    28  
// joinVertex is one vertex of the join graph built by getJoinGraph. Vertices
// are linked into parent/child relationships by setParent and unsetParent,
// using indices into the shared vertex slice.
type joinVertex struct {
	node     *plan.Node     // plan node for this vertex; buildSubJoinTree replaces it with the join it builds on top
	children map[int32]bool // set of vertex indices whose parent is this vertex
	parent   int32          // index of the parent vertex, or -1 when there is none
	joined   bool           // set once buildSubJoinTree has visited this vertex
}
    35  
    36  func (builder *QueryBuilder) pushdownSemiAntiJoins(nodeID int32) int32 {
    37  	if builder.optimizerHints != nil && builder.optimizerHints.pushDownSemiAntiJoins != 0 {
    38  		return nodeID
    39  	}
    40  	// TODO: handle SEMI/ANTI joins in join order
    41  	node := builder.qry.Nodes[nodeID]
    42  
    43  	for i, childID := range node.Children {
    44  		node.Children[i] = builder.pushdownSemiAntiJoins(childID)
    45  	}
    46  
    47  	if node.NodeType != plan.Node_JOIN || (node.JoinType != plan.Node_SEMI && node.JoinType != plan.Node_ANTI) {
    48  		return nodeID
    49  	}
    50  
    51  	var targetNode *plan.Node
    52  	var targetSide int32
    53  
    54  	joinNode := builder.qry.Nodes[node.Children[0]]
    55  
    56  	semiAntiStat := builder.qry.Nodes[node.Children[1]].Stats
    57  
    58  	for {
    59  		if joinNode.NodeType != plan.Node_JOIN {
    60  			break
    61  		}
    62  
    63  		leftTags := make(map[int32]bool)
    64  		for _, tag := range builder.enumerateTags(joinNode.Children[0]) {
    65  			leftTags[tag] = true
    66  		}
    67  
    68  		rightTags := make(map[int32]bool)
    69  		for _, tag := range builder.enumerateTags(joinNode.Children[1]) {
    70  			rightTags[tag] = true
    71  		}
    72  
    73  		var joinSide int8
    74  		for _, cond := range node.OnList {
    75  			joinSide |= getJoinSide(cond, leftTags, rightTags, 0)
    76  		}
    77  
    78  		// TODO: This logic is problematic. Use this threshold right now just for TPC-H
    79  		ratio := 2.0
    80  		if joinNode.JoinType == plan.Node_SEMI || joinNode.JoinType == plan.Node_ANTI {
    81  			ratio = 1.0
    82  		}
    83  
    84  		if joinSide == JoinSideLeft {
    85  			if semiAntiStat.Selectivity*ratio > builder.qry.Nodes[joinNode.Children[1]].Stats.Selectivity {
    86  				break
    87  			}
    88  			targetNode = joinNode
    89  			targetSide = 0
    90  			joinNode = builder.qry.Nodes[joinNode.Children[0]]
    91  		} else if joinNode.JoinType == plan.Node_INNER && joinSide == JoinSideRight {
    92  			if semiAntiStat.Selectivity*ratio > builder.qry.Nodes[joinNode.Children[0]].Stats.Selectivity {
    93  				break
    94  			}
    95  			targetNode = joinNode
    96  			targetSide = 1
    97  			joinNode = builder.qry.Nodes[joinNode.Children[1]]
    98  		} else {
    99  			break
   100  		}
   101  	}
   102  
   103  	if targetNode != nil {
   104  		nodeID = node.Children[0]
   105  		node.Children[0] = targetNode.Children[targetSide]
   106  		targetNode.Children[targetSide] = node.NodeId
   107  	}
   108  
   109  	return nodeID
   110  }
   111  
   112  func (builder *QueryBuilder) IsEquiJoin(node *plan.Node) bool {
   113  	if node.NodeType != plan.Node_JOIN {
   114  		return false
   115  	}
   116  
   117  	leftTags := make(map[int32]bool)
   118  	for _, tag := range builder.enumerateTags(node.Children[0]) {
   119  		leftTags[tag] = true
   120  	}
   121  
   122  	rightTags := make(map[int32]bool)
   123  	for _, tag := range builder.enumerateTags(node.Children[1]) {
   124  		rightTags[tag] = true
   125  	}
   126  
   127  	for _, expr := range node.OnList {
   128  		if equi := isEquiCond(expr, leftTags, rightTags); equi {
   129  			return true
   130  		}
   131  	}
   132  	return false
   133  }
   134  
   135  func isEquiCond(expr *plan.Expr, leftTags, rightTags map[int32]bool) bool {
   136  	if e, ok := expr.Expr.(*plan.Expr_F); ok {
   137  		if !IsEqualFunc(e.F.Func.GetObj()) {
   138  			return false
   139  		}
   140  
   141  		lside, rside := getJoinSide(e.F.Args[0], leftTags, rightTags, 0), getJoinSide(e.F.Args[1], leftTags, rightTags, 0)
   142  		if lside == JoinSideLeft && rside == JoinSideRight {
   143  			return true
   144  		} else if lside == JoinSideRight && rside == JoinSideLeft {
   145  			// swap to make sure left and right is in order
   146  			e.F.Args[0], e.F.Args[1] = e.F.Args[1], e.F.Args[0]
   147  			return true
   148  		}
   149  	}
   150  
   151  	return false
   152  }
   153  
   154  // IsEquiJoin2 Judge whether a join node is equi-join (after column remapping)
   155  // Can only be used after optimizer!!!
   156  func IsEquiJoin2(exprs []*plan.Expr) bool {
   157  	for _, expr := range exprs {
   158  		if e, ok := expr.Expr.(*plan.Expr_F); ok {
   159  			if !IsEqualFunc(e.F.Func.GetObj()) {
   160  				continue
   161  			}
   162  			lpos, rpos := HasColExpr(e.F.Args[0], -1), HasColExpr(e.F.Args[1], -1)
   163  			if lpos == -1 || rpos == -1 || (lpos == rpos) {
   164  				continue
   165  			}
   166  			return true
   167  		}
   168  	}
   169  	return false
   170  }
   171  
   172  func IsEqualFunc(id int64) bool {
   173  	fid, _ := function.DecodeOverloadID(id)
   174  	return fid == function.EQUAL
   175  }
   176  
   177  func HasColExpr(expr *plan.Expr, pos int32) int32 {
   178  	switch e := expr.Expr.(type) {
   179  	case *plan.Expr_Col:
   180  		if pos == -1 {
   181  			return e.Col.RelPos
   182  		}
   183  		if pos != e.Col.RelPos {
   184  			return -1
   185  		}
   186  		return pos
   187  	case *plan.Expr_F:
   188  		for i := range e.F.Args {
   189  			pos0 := HasColExpr(e.F.Args[i], pos)
   190  			switch {
   191  			case pos0 == -1:
   192  			case pos == -1:
   193  				pos = pos0
   194  			case pos != pos0:
   195  				return -1
   196  			}
   197  		}
   198  		return pos
   199  	default:
   200  		return pos
   201  	}
   202  }
   203  
// determineJoinOrder rebuilds the tree of INNER joins rooted at nodeID and
// returns the ID of the node that roots the reordered plan.
//
// Outline:
//  1. Flatten the nested INNER joins into leaves and join conditions.
//  2. Build the join graph and collapse each root vertex with its descendants
//     into a sub-tree (buildSubJoinTree).
//  3. Sort the sub-trees with compareStats and join them greedily, always
//     preferring a sub-tree that shares a condition with what is joined so far.
//  4. Push the collected conditions back down; whatever pushdownFilters
//     returns as left over is placed in a FILTER node on top.
//
// Nodes that are not INNER joins are left in place and only their children are
// processed. A non-zero joinOrdering optimizer hint disables the rewrite.
func (builder *QueryBuilder) determineJoinOrder(nodeID int32) int32 {
	if builder.optimizerHints != nil && builder.optimizerHints.joinOrdering != 0 {
		return nodeID
	}
	node := builder.qry.Nodes[nodeID]

	// Not an INNER join: keep the node and recurse into its children only.
	if node.NodeType != plan.Node_JOIN || node.JoinType != plan.Node_INNER {
		if len(node.Children) > 0 {
			for i, child := range node.Children {
				node.Children[i] = builder.determineJoinOrder(child)
			}
		}
		return nodeID
	}

	// A join whose right input is a FUNCTION_SCAN is returned untouched
	// (its children are not visited either).
	if builder.qry.Nodes[node.Children[1]].NodeType == plan.Node_FUNCTION_SCAN {
		return nodeID
	}

	// Flatten the join tree, then extend the condition list with whatever
	// deduceNewOnList derives from it.
	leaves, conds := builder.gatherJoinLeavesAndConds(node, nil, nil)
	newConds := deduceNewOnList(conds)
	conds = append(conds, newConds...)
	vertices := builder.getJoinGraph(leaves, conds)

	// Every vertex without a parent roots a sub-tree built from its descendants.
	subTrees := make([]*plan.Node, 0, len(leaves))
	for i, vertex := range vertices {
		// TODO handle cycles in the "dimension -> fact" DAG
		if vertex.parent == -1 {
			builder.buildSubJoinTree(vertices, int32(i))
			subTrees = append(subTrees, vertex.node)
		}
	}
	// Vertices never reached from a root (buildSubJoinTree did not mark them
	// joined) are kept as stand-alone sub-trees.
	for _, vertex := range vertices {
		if !vertex.joined {
			subTrees = append(subTrees, vertex.node)
		}
	}

	sort.Slice(subTrees, func(i, j int) bool { return compareStats(subTrees[i].Stats, subTrees[j].Stats) })

	// Map every binding tag to the index of the sub-tree that contains it.
	leafByTag := make(map[int32]int32)

	for i, leaf := range subTrees {
		tags := builder.enumerateTags(leaf.NodeId)

		for _, tag := range tags {
			leafByTag[tag] = int32(i)
		}
	}

	nLeaf := int32(len(subTrees))

	// adjMat is an nLeaf x nLeaf adjacency matrix stored row-major.
	// firstConnected stays at nLeaf until some condition yields a hyper-edge.
	adjMat := make([]bool, nLeaf*nLeaf)
	firstConnected := nLeaf
	visited := make([]bool, nLeaf)

	for _, cond := range conds {
		// hyperEdge is filled by getHyperEdgeFromExpr with sub-tree indices
		// (presumably the sub-trees the condition references — confirm
		// against getHyperEdgeFromExpr).
		hyperEdge := make(map[int32]bool)
		getHyperEdgeFromExpr(cond, leafByTag, hyperEdge)

		// Mark every pair of sub-trees in the hyper-edge as adjacent and
		// remember the smallest index seen in any hyper-edge.
		for i := range hyperEdge {
			if i < firstConnected {
				firstConnected = i
			}
			for j := range hyperEdge {
				adjMat[int32(nLeaf)*i+j] = true
			}
		}
	}

	if firstConnected < nLeaf {
		nodeID = subTrees[firstConnected].NodeId
		visited[firstConnected] = true

		// eligible aliases the adjMat row of firstConnected, so the updates
		// below write into adjMat as well.
		eligible := adjMat[firstConnected*nLeaf : (firstConnected+1)*nLeaf]

		for {
			// Pick the lowest-index unvisited sub-tree adjacent to anything
			// joined so far; nLeaf means none is left.
			nextSibling := nLeaf
			for i := range eligible {
				if !visited[i] && eligible[i] {
					nextSibling = int32(i)
					break
				}
			}

			if nextSibling == nLeaf {
				break
			}

			visited[nextSibling] = true

			children := []int32{nodeID, subTrees[nextSibling].NodeId}
			nodeID = builder.appendNode(&plan.Node{
				NodeType: plan.Node_JOIN,
				Children: children,
				JoinType: plan.Node_INNER,
			}, nil)

			// Neighbours of the newly joined sub-tree become eligible too.
			for i, adj := range adjMat[nextSibling*nLeaf : (nextSibling+1)*nLeaf] {
				eligible[i] = eligible[i] || adj
			}
		}

		// Sub-trees that were never reached through an adjacency are joined
		// last, in index order.
		for i := range visited {
			if !visited[i] {
				nodeID = builder.appendNode(&plan.Node{
					NodeType: plan.Node_JOIN,
					Children: []int32{nodeID, subTrees[i].NodeId},
					JoinType: plan.Node_INNER,
				}, nil)
			}
		}
	} else {
		// No condition produced a hyper-edge: chain the sub-trees
		// left-deep in sorted order.
		newNode := subTrees[0]
		nodeID = newNode.NodeId

		for i := 1; i < len(subTrees); i++ {
			children := []int32{nodeID, subTrees[i].NodeId}
			nodeID = builder.appendNode(&plan.Node{
				NodeType: plan.Node_JOIN,
				Children: children,
				JoinType: plan.Node_INNER,
			}, nil)
		}
	}

	// The new join nodes were created without ON lists; hand the conditions
	// to pushdownFilters and keep any it returns in a FILTER node on top.
	nodeID, conds = builder.pushdownFilters(nodeID, conds, true)
	if len(conds) > 0 {
		nodeID = builder.appendNode(&plan.Node{
			NodeType:   plan.Node_FILTER,
			Children:   []int32{nodeID},
			FilterList: conds,
		}, nil)
	}
	return nodeID
}
   340  
   341  func (builder *QueryBuilder) gatherJoinLeavesAndConds(joinNode *plan.Node, leaves []*plan.Node, conds []*plan.Expr) ([]*plan.Node, []*plan.Expr) {
   342  	if joinNode.NodeType != plan.Node_JOIN || joinNode.JoinType != plan.Node_INNER || joinNode.Limit != nil {
   343  		nodeID := builder.determineJoinOrder(joinNode.NodeId)
   344  		leaves = append(leaves, builder.qry.Nodes[nodeID])
   345  		return leaves, conds
   346  	}
   347  
   348  	for _, childID := range joinNode.Children {
   349  		leaves, conds = builder.gatherJoinLeavesAndConds(builder.qry.Nodes[childID], leaves, conds)
   350  	}
   351  
   352  	conds = append(conds, joinNode.OnList...)
   353  
   354  	return leaves, conds
   355  }
   356  
// getJoinGraph builds one joinVertex per leaf and links the vertices into
// parent/child relationships based on the join conditions in conds. The
// returned slice is index-aligned with leaves.
//
// Only conditions accepted by checkStrictJoinPred whose two columns both map
// to a known vertex are considered. For such a condition, an endpoint whose
// accumulated edge columns satisfy isHighNdvCols may become the child of the
// other endpoint.
func (builder *QueryBuilder) getJoinGraph(leaves []*plan.Node, conds []*plan.Expr) []*joinVertex {
	vertices := make([]*joinVertex, len(leaves))
	// tag2Vert maps each binding tag to the index of the vertex containing it.
	tag2Vert := make(map[int32]int32)

	for i, node := range leaves {
		vertices[i] = &joinVertex{
			node:     node,
			children: make(map[int32]bool),
			parent:   -1,
		}

		for _, tag := range builder.enumerateTags(node.NodeId) {
			tag2Vert[tag] = int32(i)
		}
	}

	// edgeMap is keyed by the vertex pair {smaller index, larger index}.
	edgeMap := make(map[[2]int32]*joinEdge)

	// The conditions are processed twice. Edge columns are only collected in
	// the first pass (i == 0); parent assignment runs in both passes, so the
	// second pass sees the complete column lists of every edge.
	for i := 0; i < 2; i++ {
		for _, cond := range conds {
			ok, leftCol, rightCol := checkStrictJoinPred(cond)
			if !ok {
				continue
			}
			var leftId, rightId int32
			if leftId, ok = tag2Vert[leftCol.RelPos]; !ok {
				continue
			}
			if rightId, ok = tag2Vert[rightCol.RelPos]; !ok {
				continue
			}

			// Normalize so that leftId <= rightId, matching the edgeMap key order.
			if leftId > rightId {
				leftId, rightId = rightId, leftId
				leftCol, rightCol = rightCol, leftCol
			}

			edge := edgeMap[[2]int32{leftId, rightId}]
			if i == 0 {
				if edge == nil {
					edge = &joinEdge{}
				}
				edge.leftCols = append(edge.leftCols, leftCol.ColPos)
				edge.rightCols = append(edge.rightCols, rightCol.ColPos)
				edgeMap[[2]int32{leftId, rightId}] = edge
			}

			// Try to make the left vertex a child of the right vertex.
			leftParent := vertices[leftId].parent
			if isHighNdvCols(edge.leftCols, builder.tag2Table[leftCol.RelPos], builder) {
				if leftParent == -1 || shouldChangeParent(leftId, leftParent, rightId, vertices) {
					if vertices[rightId].parent != leftId {
						setParent(leftId, rightId, vertices)
					} else if vertices[leftId].node.Stats.Outcnt < vertices[rightId].node.Stats.Outcnt {
						// The right vertex is currently the child of the left
						// one; flip the direction when the left vertex has
						// the smaller Outcnt.
						unsetParent(rightId, leftId, vertices)
						setParent(leftId, rightId, vertices)
					}
				}
			}
			// Same attempt in the opposite direction: right vertex as a child
			// of the left vertex.
			rightParent := vertices[rightId].parent
			if isHighNdvCols(edge.rightCols, builder.tag2Table[rightCol.RelPos], builder) {
				if rightParent == -1 || shouldChangeParent(rightId, rightParent, leftId, vertices) {
					if vertices[leftId].parent != rightId {
						setParent(rightId, leftId, vertices)
					} else if vertices[rightId].node.Stats.Outcnt < vertices[leftId].node.Stats.Outcnt {
						unsetParent(leftId, rightId, vertices)
						setParent(rightId, leftId, vertices)
					}
				}
			}
		}
	}
	return vertices
}
   430  
   431  func setParent(child, parent int32, vertices []*joinVertex) {
   432  	if child == -1 || parent == -1 {
   433  		return
   434  	}
   435  	unsetParent(child, vertices[child].parent, vertices)
   436  	vertices[child].parent = parent
   437  	vertices[parent].children[child] = true
   438  }
   439  
   440  func unsetParent(child, parent int32, vertices []*joinVertex) {
   441  	if child == -1 || parent == -1 {
   442  		return
   443  	}
   444  	if vertices[child].parent == parent {
   445  		vertices[child].parent = -1
   446  		delete(vertices[parent].children, child)
   447  	}
   448  }
   449  
   450  func findSelectivityInChildren(self int32, vertices []*joinVertex) bool {
   451  	if vertices[self].node.Stats.Selectivity < 0.9 {
   452  		return true
   453  	}
   454  	for child := range vertices[self].children {
   455  		if findSelectivityInChildren(child, vertices) {
   456  			return true
   457  		}
   458  	}
   459  	return false
   460  }
   461  
   462  func findParent(self, target int32, vertices []*joinVertex) bool {
   463  	parent := vertices[self].parent
   464  	if parent == target {
   465  		return true
   466  	} else if parent != -1 {
   467  		return findParent(parent, target, vertices)
   468  	}
   469  	return false
   470  }
   471  
   472  func shouldChangeParent(self, currentParent, nextParent int32, vertices []*joinVertex) bool {
   473  	selfStats := vertices[self].node.Stats
   474  	currentParentStats := vertices[currentParent].node.Stats
   475  	nextParentStats := vertices[nextParent].node.Stats
   476  	if currentParentStats.Cost > selfStats.Cost && currentParentStats.Cost > nextParentStats.Cost {
   477  		// current Parent is the biggest node
   478  		if findParent(nextParent, currentParent, vertices) {
   479  			return true
   480  		}
   481  		if findSelectivityInChildren(self, vertices) {
   482  			return false
   483  		}
   484  	}
   485  	if nextParentStats.Cost > selfStats.Cost && nextParentStats.Cost > currentParentStats.Cost {
   486  		// next Parent is the biggest node
   487  		if findParent(currentParent, nextParent, vertices) {
   488  			return false
   489  		}
   490  		if findSelectivityInChildren(self, vertices) {
   491  			return true
   492  		}
   493  	}
   494  	// self is the biggest node
   495  	return compareStats(nextParentStats, currentParentStats)
   496  }
   497  
   498  // buildSubJoinTree build sub- join tree for a fact table and all its dimension tables
   499  func (builder *QueryBuilder) buildSubJoinTree(vertices []*joinVertex, vid int32) {
   500  	vertex := vertices[vid]
   501  	vertex.joined = true
   502  
   503  	if len(vertex.children) == 0 {
   504  		return
   505  	}
   506  
   507  	dimensions := make([]*joinVertex, 0, len(vertex.children))
   508  	for child := range vertex.children {
   509  		if vertices[child].joined {
   510  			continue
   511  		}
   512  		builder.buildSubJoinTree(vertices, child)
   513  		dimensions = append(dimensions, vertices[child])
   514  	}
   515  	sort.Slice(dimensions, func(i, j int) bool { return compareStats(dimensions[i].node.Stats, dimensions[j].node.Stats) })
   516  
   517  	for _, child := range dimensions {
   518  
   519  		children := []int32{vertex.node.NodeId, child.node.NodeId}
   520  		nodeID := builder.appendNode(&plan.Node{
   521  			NodeType: plan.Node_JOIN,
   522  			Children: children,
   523  			JoinType: plan.Node_INNER,
   524  		}, nil)
   525  
   526  		vertex.node = builder.qry.Nodes[nodeID]
   527  	}
   528  }
   529  
   530  func containsAllPKs(cols []int32, tableDef *plan.TableDef) bool {
   531  	pkNames := tableDef.Pkey.Names
   532  	pks := make([]int32, len(pkNames))
   533  	for i := range pkNames {
   534  		pks[i] = tableDef.Name2ColIndex[pkNames[i]]
   535  	}
   536  	if len(pks) == 0 {
   537  		return false
   538  	}
   539  	for _, pk := range pks {
   540  		found := false
   541  		for _, col := range cols {
   542  			if col == pk {
   543  				found = true
   544  				break
   545  			}
   546  		}
   547  		if !found {
   548  			return false
   549  		}
   550  	}
   551  	return true
   552  }
   553  
   554  func (builder *QueryBuilder) enumerateTags(nodeID int32) []int32 {
   555  	var tags []int32
   556  
   557  	node := builder.qry.Nodes[nodeID]
   558  	if len(node.BindingTags) > 0 {
   559  		tags = append(tags, node.BindingTags...)
   560  		if node.NodeType != plan.Node_JOIN {
   561  			return tags
   562  		}
   563  	}
   564  
   565  	for _, childID := range builder.qry.Nodes[nodeID].Children {
   566  		tags = append(tags, builder.enumerateTags(childID)...)
   567  	}
   568  
   569  	return tags
   570  }