github.com/tetratelabs/wazero@v1.7.3-0.20240513003603-48f702e154b5/internal/engine/wazevo/ssa/pass_cfg.go (about)

     1  package ssa
     2  
     3  import (
     4  	"fmt"
     5  	"math"
     6  	"strings"
     7  
     8  	"github.com/tetratelabs/wazero/internal/engine/wazevo/wazevoapi"
     9  )
    10  
    11  // passCalculateImmediateDominators calculates immediate dominators for each basic block.
    12  // The result is stored in b.dominators. This make it possible for the following passes to
    13  // use builder.isDominatedBy to check if a block is dominated by another block.
    14  //
    15  // At the last of pass, this function also does the loop detection and sets the basicBlock.loop flag.
    16  func passCalculateImmediateDominators(b *builder) {
    17  	reversePostOrder := b.reversePostOrderedBasicBlocks[:0]
    18  	exploreStack := b.blkStack[:0]
    19  	b.clearBlkVisited()
    20  
    21  	entryBlk := b.entryBlk()
    22  
    23  	// Store the reverse postorder from the entrypoint into reversePostOrder slice.
    24  	// This calculation of reverse postorder is not described in the paper,
    25  	// so we use heuristic to calculate it so that we could potentially handle arbitrary
    26  	// complex CFGs under the assumption that success is sorted in program's natural order.
    27  	// That means blk.success[i] always appears before blk.success[i+1] in the source program,
    28  	// which is a reasonable assumption as long as SSA Builder is properly used.
    29  	//
    30  	// First we push blocks in postorder iteratively visit successors of the entry block.
    31  	exploreStack = append(exploreStack, entryBlk)
    32  	const visitStateUnseen, visitStateSeen, visitStateDone = 0, 1, 2
    33  	b.blkVisited[entryBlk] = visitStateSeen
    34  	for len(exploreStack) > 0 {
    35  		tail := len(exploreStack) - 1
    36  		blk := exploreStack[tail]
    37  		exploreStack = exploreStack[:tail]
    38  		switch b.blkVisited[blk] {
    39  		case visitStateUnseen:
    40  			// This is likely a bug in the frontend.
    41  			panic("BUG: unsupported CFG")
    42  		case visitStateSeen:
    43  			// This is the first time to pop this block, and we have to see the successors first.
    44  			// So push this block again to the stack.
    45  			exploreStack = append(exploreStack, blk)
    46  			// And push the successors to the stack if necessary.
    47  			for _, succ := range blk.success {
    48  				if succ.ReturnBlock() || succ.invalid {
    49  					continue
    50  				}
    51  				if b.blkVisited[succ] == visitStateUnseen {
    52  					b.blkVisited[succ] = visitStateSeen
    53  					exploreStack = append(exploreStack, succ)
    54  				}
    55  			}
    56  			// Finally, we could pop this block once we pop all of its successors.
    57  			b.blkVisited[blk] = visitStateDone
    58  		case visitStateDone:
    59  			// Note: at this point we push blk in postorder despite its name.
    60  			reversePostOrder = append(reversePostOrder, blk)
    61  		}
    62  	}
    63  	// At this point, reversePostOrder has postorder actually, so we reverse it.
    64  	for i := len(reversePostOrder)/2 - 1; i >= 0; i-- {
    65  		j := len(reversePostOrder) - 1 - i
    66  		reversePostOrder[i], reversePostOrder[j] = reversePostOrder[j], reversePostOrder[i]
    67  	}
    68  
    69  	for i, blk := range reversePostOrder {
    70  		blk.reversePostOrder = i
    71  	}
    72  
    73  	// Reuse the dominators slice if possible from the previous computation of function.
    74  	b.dominators = b.dominators[:cap(b.dominators)]
    75  	if len(b.dominators) < b.basicBlocksPool.Allocated() {
    76  		// Generously reserve space in the slice because the slice will be reused future allocation.
    77  		b.dominators = append(b.dominators, make([]*basicBlock, b.basicBlocksPool.Allocated())...)
    78  	}
    79  	calculateDominators(reversePostOrder, b.dominators)
    80  
    81  	// Reuse the slices for the future use.
    82  	b.blkStack = exploreStack
    83  
    84  	// For the following passes.
    85  	b.reversePostOrderedBasicBlocks = reversePostOrder
    86  
    87  	// Ready to detect loops!
    88  	subPassLoopDetection(b)
    89  }
    90  
    91  // calculateDominators calculates the immediate dominator of each node in the CFG, and store the result in `doms`.
    92  // The algorithm is based on the one described in the paper "A Simple, Fast Dominance Algorithm"
    93  // https://www.cs.rice.edu/~keith/EMBED/dom.pdf which is a faster/simple alternative to the well known Lengauer-Tarjan algorithm.
    94  //
    95  // The following code almost matches the pseudocode in the paper with one exception (see the code comment below).
    96  //
    97  // The result slice `doms` must be pre-allocated with the size larger than the size of dfsBlocks.
    98  func calculateDominators(reversePostOrderedBlks []*basicBlock, doms []*basicBlock) {
    99  	entry, reversePostOrderedBlks := reversePostOrderedBlks[0], reversePostOrderedBlks[1: /* skips entry point */]
   100  	for _, blk := range reversePostOrderedBlks {
   101  		doms[blk.id] = nil
   102  	}
   103  	doms[entry.id] = entry
   104  
   105  	changed := true
   106  	for changed {
   107  		changed = false
   108  		for _, blk := range reversePostOrderedBlks {
   109  			var u *basicBlock
   110  			for i := range blk.preds {
   111  				pred := blk.preds[i].blk
   112  				// Skip if this pred is not reachable yet. Note that this is not described in the paper,
   113  				// but it is necessary to handle nested loops etc.
   114  				if doms[pred.id] == nil {
   115  					continue
   116  				}
   117  
   118  				if u == nil {
   119  					u = pred
   120  					continue
   121  				} else {
   122  					u = intersect(doms, u, pred)
   123  				}
   124  			}
   125  			if doms[blk.id] != u {
   126  				doms[blk.id] = u
   127  				changed = true
   128  			}
   129  		}
   130  	}
   131  }
   132  
   133  // intersect returns the common dominator of blk1 and blk2.
   134  //
   135  // This is the `intersect` function in the paper.
   136  func intersect(doms []*basicBlock, blk1 *basicBlock, blk2 *basicBlock) *basicBlock {
   137  	finger1, finger2 := blk1, blk2
   138  	for finger1 != finger2 {
   139  		// Move the 'finger1' upwards to its immediate dominator.
   140  		for finger1.reversePostOrder > finger2.reversePostOrder {
   141  			finger1 = doms[finger1.id]
   142  		}
   143  		// Move the 'finger2' upwards to its immediate dominator.
   144  		for finger2.reversePostOrder > finger1.reversePostOrder {
   145  			finger2 = doms[finger2.id]
   146  		}
   147  	}
   148  	return finger1
   149  }
   150  
   151  // subPassLoopDetection detects loops in the function using the immediate dominators.
   152  //
   153  // This is run at the last of passCalculateImmediateDominators.
   154  func subPassLoopDetection(b *builder) {
   155  	for blk := b.blockIteratorBegin(); blk != nil; blk = b.blockIteratorNext() {
   156  		for i := range blk.preds {
   157  			pred := blk.preds[i].blk
   158  			if pred.invalid {
   159  				continue
   160  			}
   161  			if b.isDominatedBy(pred, blk) {
   162  				blk.loopHeader = true
   163  			}
   164  		}
   165  	}
   166  }
   167  
   168  // buildLoopNestingForest builds the loop nesting forest for the function.
   169  // This must be called after branch splitting since it relies on the CFG.
   170  func passBuildLoopNestingForest(b *builder) {
   171  	ent := b.entryBlk()
   172  	doms := b.dominators
   173  	for _, blk := range b.reversePostOrderedBasicBlocks {
   174  		n := doms[blk.id]
   175  		for !n.loopHeader && n != ent {
   176  			n = doms[n.id]
   177  		}
   178  
   179  		if n == ent && blk.loopHeader {
   180  			b.loopNestingForestRoots = append(b.loopNestingForestRoots, blk)
   181  		} else if n == ent {
   182  		} else if n.loopHeader {
   183  			n.loopNestingForestChildren = append(n.loopNestingForestChildren, blk)
   184  		}
   185  	}
   186  
   187  	if wazevoapi.SSALoggingEnabled {
   188  		for _, root := range b.loopNestingForestRoots {
   189  			printLoopNestingForest(root.(*basicBlock), 0)
   190  		}
   191  	}
   192  }
   193  
   194  func printLoopNestingForest(root *basicBlock, depth int) {
   195  	fmt.Println(strings.Repeat("\t", depth), "loop nesting forest root:", root.ID())
   196  	for _, child := range root.loopNestingForestChildren {
   197  		fmt.Println(strings.Repeat("\t", depth+1), "child:", child.ID())
   198  		if child.LoopHeader() {
   199  			printLoopNestingForest(child.(*basicBlock), depth+2)
   200  		}
   201  	}
   202  }
   203  
   204  type dominatorSparseTree struct {
   205  	time         int
   206  	euler        []*basicBlock
   207  	first, depth []int
   208  	table        [][]int
   209  }
   210  
   211  // passBuildDominatorTree builds the dominator tree for the function, and constructs builder.sparseTree.
   212  func passBuildDominatorTree(b *builder) {
   213  	// First we materialize the children of each node in the dominator tree.
   214  	idoms := b.dominators
   215  	for _, blk := range b.reversePostOrderedBasicBlocks {
   216  		parent := idoms[blk.id]
   217  		if parent == nil {
   218  			panic("BUG")
   219  		} else if parent == blk {
   220  			// This is the entry block.
   221  			continue
   222  		}
   223  		if prev := parent.child; prev == nil {
   224  			parent.child = blk
   225  		} else {
   226  			parent.child = blk
   227  			blk.sibling = prev
   228  		}
   229  	}
   230  
   231  	// Reset the state from the previous computation.
   232  	n := b.basicBlocksPool.Allocated()
   233  	st := &b.sparseTree
   234  	st.euler = append(st.euler[:0], make([]*basicBlock, 2*n-1)...)
   235  	st.first = append(st.first[:0], make([]int, n)...)
   236  	for i := range st.first {
   237  		st.first[i] = -1
   238  	}
   239  	st.depth = append(st.depth[:0], make([]int, 2*n-1)...)
   240  	st.time = 0
   241  
   242  	// Start building the sparse tree.
   243  	st.eulerTour(b.entryBlk(), 0)
   244  	st.buildSparseTable()
   245  }
   246  
   247  func (dt *dominatorSparseTree) eulerTour(node *basicBlock, height int) {
   248  	if wazevoapi.SSALoggingEnabled {
   249  		fmt.Println(strings.Repeat("\t", height), "euler tour:", node.ID())
   250  	}
   251  	dt.euler[dt.time] = node
   252  	dt.depth[dt.time] = height
   253  	if dt.first[node.id] == -1 {
   254  		dt.first[node.id] = dt.time
   255  	}
   256  	dt.time++
   257  
   258  	for child := node.child; child != nil; child = child.sibling {
   259  		dt.eulerTour(child, height+1)
   260  		dt.euler[dt.time] = node // add the current node again after visiting a child
   261  		dt.depth[dt.time] = height
   262  		dt.time++
   263  	}
   264  }
   265  
   266  // buildSparseTable builds a sparse table for RMQ queries.
   267  func (dt *dominatorSparseTree) buildSparseTable() {
   268  	n := len(dt.depth)
   269  	k := int(math.Log2(float64(n))) + 1
   270  	table := dt.table
   271  
   272  	if n >= len(table) {
   273  		table = append(table, make([][]int, n+1)...)
   274  	}
   275  	for i := range table {
   276  		if len(table[i]) < k {
   277  			table[i] = append(table[i], make([]int, k)...)
   278  		}
   279  		table[i][0] = i
   280  	}
   281  
   282  	for j := 1; 1<<j <= n; j++ {
   283  		for i := 0; i+(1<<j)-1 < n; i++ {
   284  			if dt.depth[table[i][j-1]] < dt.depth[table[i+(1<<(j-1))][j-1]] {
   285  				table[i][j] = table[i][j-1]
   286  			} else {
   287  				table[i][j] = table[i+(1<<(j-1))][j-1]
   288  			}
   289  		}
   290  	}
   291  	dt.table = table
   292  }
   293  
   294  // rmq performs a range minimum query on the sparse table.
   295  func (dt *dominatorSparseTree) rmq(l, r int) int {
   296  	table := dt.table
   297  	depth := dt.depth
   298  	j := int(math.Log2(float64(r - l + 1)))
   299  	if depth[table[l][j]] <= depth[table[r-(1<<j)+1][j]] {
   300  		return table[l][j]
   301  	}
   302  	return table[r-(1<<j)+1][j]
   303  }
   304  
   305  // findLCA finds the LCA using the Euler tour and RMQ.
   306  func (dt *dominatorSparseTree) findLCA(u, v BasicBlockID) *basicBlock {
   307  	first := dt.first
   308  	if first[u] > first[v] {
   309  		u, v = v, u
   310  	}
   311  	return dt.euler[dt.rmq(first[u], first[v])]
   312  }