github.com/matrixorigin/matrixone@v0.7.0/pkg/sql/plan/join_order.go (about)

     1  // Copyright 2022 Matrix Origin
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package plan
    16  
    17  import (
    18  	"math"
    19  	"sort"
    20  
    21  	"github.com/matrixorigin/matrixone/pkg/pb/plan"
    22  )
    23  
// joinEdge accumulates the column pairs of all column-to-column equality
// conditions found between one pair of join-graph vertices. Entries at the
// same index in leftCols and rightCols come from the same condition.
type joinEdge struct {
	// leftCols holds the column positions (ColPos) on the vertex with the
	// lower index of the pair.
	leftCols []int32
	// rightCols holds the column positions (ColPos) on the vertex with the
	// higher index of the pair.
	rightCols []int32
}
    28  
// joinVertex is one vertex of the join graph built by getJoinGraph. It starts
// out wrapping a single join leaf and, once buildSubJoinTree has run on it,
// wraps the join subtree rooted at that leaf.
type joinVertex struct {
	// node is the plan node currently representing this vertex; it is
	// replaced by the newly appended join node each time a child is joined in.
	node *plan.Node
	// pks holds the primary-key column positions as resolved through the
	// binding; it is only filled for table-scan leaves.
	pks []int32
	// selectivity and outcnt are seeded from node.Stats when the vertex is
	// created; outcnt is scaled by each joined child's pkSelRate.
	selectivity float64
	outcnt      float64
	// pkSelRate starts at 1.0, is multiplied by 0.1 for every filter accepted
	// by filterOnPK, and by the pkSelRate of every child joined into it.
	pkSelRate float64

	// children is used as a set of vertex indices (values are always nil)
	// that have this vertex as their parent.
	children map[int32]any
	// parent is the index of the parent vertex, or -1 when there is none.
	parent int32

	// joined is set by buildSubJoinTree once the vertex has been visited.
	joined bool
}
    41  
    42  func (builder *QueryBuilder) pushdownSemiAntiJoins(nodeID int32) int32 {
    43  	// TODO: handle SEMI/ANTI joins in join order
    44  	node := builder.qry.Nodes[nodeID]
    45  
    46  	for i, childID := range node.Children {
    47  		node.Children[i] = builder.pushdownSemiAntiJoins(childID)
    48  	}
    49  
    50  	if node.NodeType != plan.Node_JOIN {
    51  		return nodeID
    52  	}
    53  
    54  	if node.JoinType != plan.Node_SEMI && node.JoinType != plan.Node_ANTI {
    55  		return nodeID
    56  	}
    57  
    58  	for _, filter := range node.OnList {
    59  		if f, ok := filter.Expr.(*plan.Expr_F); ok {
    60  			if f.F.Func.ObjName != "=" {
    61  				return nodeID
    62  			}
    63  		}
    64  	}
    65  
    66  	var targetNode *plan.Node
    67  	var targetSide int32
    68  
    69  	joinNode := builder.qry.Nodes[node.Children[0]]
    70  
    71  	for {
    72  		if joinNode.NodeType != plan.Node_JOIN {
    73  			break
    74  		}
    75  
    76  		if joinNode.JoinType != plan.Node_INNER && joinNode.JoinType != plan.Node_LEFT {
    77  			break
    78  		}
    79  
    80  		leftTags := make(map[int32]*Binding)
    81  		for _, tag := range builder.enumerateTags(joinNode.Children[0]) {
    82  			leftTags[tag] = nil
    83  		}
    84  
    85  		rightTags := make(map[int32]*Binding)
    86  		for _, tag := range builder.enumerateTags(joinNode.Children[1]) {
    87  			rightTags[tag] = nil
    88  		}
    89  
    90  		var joinSide int8
    91  		for _, cond := range node.OnList {
    92  			joinSide |= getJoinSide(cond, leftTags, rightTags, 0)
    93  		}
    94  
    95  		if joinSide == JoinSideLeft {
    96  			targetNode = joinNode
    97  			targetSide = 0
    98  			joinNode = builder.qry.Nodes[joinNode.Children[0]]
    99  		} else if joinNode.JoinType == plan.Node_INNER && joinSide == JoinSideRight {
   100  			targetNode = joinNode
   101  			targetSide = 1
   102  			joinNode = builder.qry.Nodes[joinNode.Children[1]]
   103  		} else {
   104  			break
   105  		}
   106  	}
   107  
   108  	if targetNode != nil {
   109  		nodeID = node.Children[0]
   110  		node.Children[0] = targetNode.Children[targetSide]
   111  		targetNode.Children[targetSide] = node.NodeId
   112  	}
   113  
   114  	return nodeID
   115  }
   116  
   117  func (builder *QueryBuilder) swapJoinOrderByStats(children []int32) []int32 {
   118  	left := builder.qry.Nodes[children[0]].Stats.Outcnt
   119  	right := builder.qry.Nodes[children[1]].Stats.Outcnt
   120  	if left < right {
   121  		return []int32{children[1], children[0]}
   122  	} else {
   123  		return children
   124  	}
   125  }
   126  func (builder *QueryBuilder) determineJoinOrder(nodeID int32) int32 {
   127  	node := builder.qry.Nodes[nodeID]
   128  
   129  	if node.NodeType != plan.Node_JOIN || node.JoinType != plan.Node_INNER {
   130  		if len(node.Children) > 0 {
   131  			for i, child := range node.Children {
   132  				node.Children[i] = builder.determineJoinOrder(child)
   133  			}
   134  		}
   135  		return nodeID
   136  	}
   137  
   138  	leaves, conds := builder.gatherJoinLeavesAndConds(node, nil, nil)
   139  
   140  	vertices := builder.getJoinGraph(leaves, conds)
   141  	subTrees := make([]*plan.Node, 0, len(leaves))
   142  	for i, vertex := range vertices {
   143  		// TODO handle cycles in the "dimension -> fact" DAG
   144  		if vertex.parent == -1 {
   145  			builder.buildSubJoinTree(vertices, int32(i))
   146  			subTrees = append(subTrees, vertex.node)
   147  		}
   148  	}
   149  	for _, vertex := range vertices {
   150  		if !vertex.joined {
   151  			subTrees = append(subTrees, vertex.node)
   152  		}
   153  	}
   154  
   155  	sort.Slice(subTrees, func(i, j int) bool {
   156  		if subTrees[j].Stats == nil {
   157  			return false
   158  		}
   159  		if subTrees[i].Stats == nil {
   160  			return true
   161  		}
   162  		if math.Abs(subTrees[i].Stats.Selectivity-subTrees[j].Stats.Selectivity) > 0.01 {
   163  			return subTrees[i].Stats.Selectivity < subTrees[j].Stats.Selectivity
   164  		} else {
   165  			return subTrees[i].Stats.Outcnt < subTrees[j].Stats.Outcnt
   166  		}
   167  	})
   168  
   169  	leafByTag := make(map[int32]int32)
   170  
   171  	for i, leaf := range subTrees {
   172  		tags := builder.enumerateTags(leaf.NodeId)
   173  
   174  		for _, tag := range tags {
   175  			leafByTag[tag] = int32(i)
   176  		}
   177  	}
   178  
   179  	nLeaf := int32(len(subTrees))
   180  
   181  	adjMat := make([]bool, nLeaf*nLeaf)
   182  	firstConnected := nLeaf
   183  	visited := make([]bool, nLeaf)
   184  
   185  	for _, cond := range conds {
   186  		hyperEdge := make(map[int32]any)
   187  		getHyperEdgeFromExpr(cond, leafByTag, hyperEdge)
   188  
   189  		for i := range hyperEdge {
   190  			if i < firstConnected {
   191  				firstConnected = i
   192  			}
   193  			for j := range hyperEdge {
   194  				adjMat[int32(nLeaf)*i+j] = true
   195  			}
   196  		}
   197  	}
   198  
   199  	if firstConnected < nLeaf {
   200  		nodeID = subTrees[firstConnected].NodeId
   201  		visited[firstConnected] = true
   202  
   203  		eligible := adjMat[firstConnected*nLeaf : (firstConnected+1)*nLeaf]
   204  
   205  		for {
   206  			nextSibling := nLeaf
   207  			for i := range eligible {
   208  				if !visited[i] && eligible[i] {
   209  					nextSibling = int32(i)
   210  					break
   211  				}
   212  			}
   213  
   214  			if nextSibling == nLeaf {
   215  				break
   216  			}
   217  
   218  			visited[nextSibling] = true
   219  
   220  			children := []int32{nodeID, subTrees[nextSibling].NodeId}
   221  			children = builder.swapJoinOrderByStats(children)
   222  			nodeID = builder.appendNode(&plan.Node{
   223  				NodeType: plan.Node_JOIN,
   224  				Children: children,
   225  				JoinType: plan.Node_INNER,
   226  			}, nil)
   227  
   228  			for i, adj := range adjMat[nextSibling*nLeaf : (nextSibling+1)*nLeaf] {
   229  				eligible[i] = eligible[i] || adj
   230  			}
   231  		}
   232  
   233  		for i := range visited {
   234  			if !visited[i] {
   235  				nodeID = builder.appendNode(&plan.Node{
   236  					NodeType: plan.Node_JOIN,
   237  					Children: []int32{nodeID, subTrees[i].NodeId},
   238  					JoinType: plan.Node_INNER,
   239  				}, nil)
   240  			}
   241  		}
   242  	} else {
   243  		newNode := subTrees[0]
   244  		nodeID = newNode.NodeId
   245  
   246  		for i := 1; i < len(subTrees); i++ {
   247  			children := []int32{nodeID, subTrees[i].NodeId}
   248  			children = builder.swapJoinOrderByStats(children)
   249  			nodeID = builder.appendNode(&plan.Node{
   250  				NodeType: plan.Node_JOIN,
   251  				Children: children,
   252  				JoinType: plan.Node_INNER,
   253  			}, nil)
   254  		}
   255  	}
   256  
   257  	nodeID, _ = builder.pushdownFilters(nodeID, conds)
   258  	ReCalcNodeStats(nodeID, builder, true)
   259  
   260  	return nodeID
   261  }
   262  
   263  func (builder *QueryBuilder) gatherJoinLeavesAndConds(joinNode *plan.Node, leaves []*plan.Node, conds []*plan.Expr) ([]*plan.Node, []*plan.Expr) {
   264  	if joinNode.NodeType != plan.Node_JOIN || joinNode.JoinType != plan.Node_INNER {
   265  		nodeID := builder.determineJoinOrder(joinNode.NodeId)
   266  		leaves = append(leaves, builder.qry.Nodes[nodeID])
   267  		return leaves, conds
   268  	}
   269  
   270  	for _, childID := range joinNode.Children {
   271  		leaves, conds = builder.gatherJoinLeavesAndConds(builder.qry.Nodes[childID], leaves, conds)
   272  	}
   273  
   274  	conds = append(conds, joinNode.OnList...)
   275  
   276  	return leaves, conds
   277  }
   278  
   279  func (builder *QueryBuilder) getJoinGraph(leaves []*plan.Node, conds []*plan.Expr) []*joinVertex {
   280  	vertices := make([]*joinVertex, len(leaves))
   281  	tag2Vert := make(map[int32]int32)
   282  
   283  	for i, node := range leaves {
   284  		vertices[i] = &joinVertex{
   285  			node:        node,
   286  			selectivity: node.Stats.Selectivity,
   287  			outcnt:      node.Stats.Outcnt,
   288  			pkSelRate:   1.0,
   289  			children:    make(map[int32]any),
   290  			parent:      -1,
   291  		}
   292  
   293  		if node.NodeType == plan.Node_TABLE_SCAN {
   294  			binding := builder.ctxByNode[node.NodeId].bindingByTag[node.BindingTags[0]]
   295  			pkDef := builder.compCtx.GetPrimaryKeyDef(node.ObjRef.SchemaName, node.ObjRef.ObjName)
   296  			pks := make([]int32, len(pkDef))
   297  			for i, pk := range pkDef {
   298  				pks[i] = binding.FindColumn(pk.Name)
   299  			}
   300  			vertices[i].pks = pks
   301  			tag2Vert[node.BindingTags[0]] = int32(i)
   302  		}
   303  
   304  		for _, filter := range node.FilterList {
   305  			if builder.filterOnPK(filter, vertices[i].pks) {
   306  				vertices[i].pkSelRate *= 0.1
   307  			}
   308  		}
   309  	}
   310  
   311  	edgeMap := make(map[[2]int32]*joinEdge)
   312  
   313  	for _, cond := range conds {
   314  		if f, ok := cond.Expr.(*plan.Expr_F); ok {
   315  			if f.F.Func.ObjName != "=" {
   316  				continue
   317  			}
   318  			if _, ok = f.F.Args[0].Expr.(*plan.Expr_Col); !ok {
   319  				continue
   320  			}
   321  			if _, ok = f.F.Args[1].Expr.(*plan.Expr_Col); !ok {
   322  				continue
   323  			}
   324  
   325  			var leftId, rightId int32
   326  
   327  			leftCol := f.F.Args[0].Expr.(*plan.Expr_Col).Col
   328  			rightCol := f.F.Args[1].Expr.(*plan.Expr_Col).Col
   329  			if leftId, ok = tag2Vert[leftCol.RelPos]; !ok {
   330  				continue
   331  			}
   332  			if rightId, ok = tag2Vert[rightCol.RelPos]; !ok {
   333  				continue
   334  			}
   335  			if vertices[leftId].parent != -1 && vertices[rightId].parent != -1 {
   336  				continue
   337  			}
   338  
   339  			if leftId > rightId {
   340  				leftId, rightId = rightId, leftId
   341  				leftCol, rightCol = rightCol, leftCol
   342  			}
   343  
   344  			edge := edgeMap[[2]int32{leftId, rightId}]
   345  			if edge == nil {
   346  				edge = &joinEdge{}
   347  			}
   348  			edge.leftCols = append(edge.leftCols, leftCol.ColPos)
   349  			edge.rightCols = append(edge.rightCols, rightCol.ColPos)
   350  			edgeMap[[2]int32{leftId, rightId}] = edge
   351  
   352  			if vertices[leftId].parent == -1 && containsAllPKs(edge.leftCols, vertices[leftId].pks) {
   353  				if vertices[rightId].parent != leftId {
   354  					vertices[leftId].parent = rightId
   355  					vertices[rightId].children[leftId] = nil
   356  				}
   357  			}
   358  			if vertices[rightId].parent == -1 && containsAllPKs(edge.rightCols, vertices[rightId].pks) {
   359  				if vertices[leftId].parent != rightId {
   360  					vertices[rightId].parent = leftId
   361  					vertices[leftId].children[rightId] = nil
   362  				}
   363  			}
   364  		}
   365  	}
   366  
   367  	return vertices
   368  }
   369  
   370  // buildSubJoinTree build sub- join tree for a fact table and all its dimension tables
   371  func (builder *QueryBuilder) buildSubJoinTree(vertices []*joinVertex, vid int32) {
   372  	vertex := vertices[vid]
   373  	vertex.joined = true
   374  
   375  	if len(vertex.children) == 0 {
   376  		return
   377  	}
   378  
   379  	dimensions := make([]*joinVertex, 0, len(vertex.children))
   380  	for child := range vertex.children {
   381  		if vertices[child].joined {
   382  			continue
   383  		}
   384  		builder.buildSubJoinTree(vertices, child)
   385  		dimensions = append(dimensions, vertices[child])
   386  	}
   387  	sort.Slice(dimensions, func(i, j int) bool {
   388  		if dimensions[i].pkSelRate < dimensions[j].pkSelRate {
   389  			return true
   390  		} else if dimensions[i].pkSelRate > dimensions[j].pkSelRate {
   391  			return false
   392  		} else {
   393  			//if math.Abs(dimensions[i].selectivity-dimensions[j].selectivity) > 0.01 {
   394  			//	return dimensions[i].selectivity < dimensions[j].selectivity
   395  			//} else {
   396  			return dimensions[i].outcnt < dimensions[j].outcnt
   397  			//}
   398  		}
   399  	})
   400  
   401  	for _, child := range dimensions {
   402  
   403  		children := []int32{vertex.node.NodeId, child.node.NodeId}
   404  		children = builder.swapJoinOrderByStats(children)
   405  		nodeId := builder.appendNode(&plan.Node{
   406  			NodeType: plan.Node_JOIN,
   407  			Children: children,
   408  			JoinType: plan.Node_INNER,
   409  		}, nil)
   410  
   411  		vertex.outcnt *= child.pkSelRate
   412  		vertex.pkSelRate *= child.pkSelRate
   413  		vertex.node = builder.qry.Nodes[nodeId]
   414  		vertex.node.Stats.Outcnt = vertex.outcnt
   415  	}
   416  }
   417  
   418  func containsAllPKs(cols, pks []int32) bool {
   419  	if len(pks) == 0 {
   420  		return false
   421  	}
   422  
   423  	for _, pk := range pks {
   424  		found := false
   425  		for _, col := range cols {
   426  			if col == pk {
   427  				found = true
   428  				break
   429  			}
   430  		}
   431  
   432  		if !found {
   433  			return false
   434  		}
   435  	}
   436  
   437  	return true
   438  }
   439  
   440  func (builder *QueryBuilder) filterOnPK(filter *plan.Expr, pks []int32) bool {
   441  	// FIXME better handle expressions
   442  	return len(pks) > 0
   443  }
   444  
   445  func (builder *QueryBuilder) enumerateTags(nodeID int32) []int32 {
   446  	var tags []int32
   447  
   448  	node := builder.qry.Nodes[nodeID]
   449  	if len(node.BindingTags) > 0 {
   450  		tags = append(tags, node.BindingTags...)
   451  		if node.NodeType != plan.Node_JOIN {
   452  			return tags
   453  		}
   454  	}
   455  
   456  	for _, childID := range builder.qry.Nodes[nodeID].Children {
   457  		tags = append(tags, builder.enumerateTags(childID)...)
   458  	}
   459  
   460  	return tags
   461  }