github.com/bananabytelabs/wazero@v0.0.0-20240105073314-54b22a776da8/internal/engine/wazevo/ssa/pass_cfg.go (about) 1 package ssa 2 3 import ( 4 "fmt" 5 "math" 6 "strings" 7 8 "github.com/bananabytelabs/wazero/internal/engine/wazevo/wazevoapi" 9 ) 10 11 // passCalculateImmediateDominators calculates immediate dominators for each basic block. 12 // The result is stored in b.dominators. This make it possible for the following passes to 13 // use builder.isDominatedBy to check if a block is dominated by another block. 14 // 15 // At the last of pass, this function also does the loop detection and sets the basicBlock.loop flag. 16 func passCalculateImmediateDominators(b *builder) { 17 reversePostOrder := b.reversePostOrderedBasicBlocks[:0] 18 exploreStack := b.blkStack[:0] 19 b.clearBlkVisited() 20 21 entryBlk := b.entryBlk() 22 23 // Store the reverse postorder from the entrypoint into reversePostOrder slice. 24 // This calculation of reverse postorder is not described in the paper, 25 // so we use heuristic to calculate it so that we could potentially handle arbitrary 26 // complex CFGs under the assumption that success is sorted in program's natural order. 27 // That means blk.success[i] always appears before blk.success[i+1] in the source program, 28 // which is a reasonable assumption as long as SSA Builder is properly used. 29 // 30 // First we push blocks in postorder iteratively visit successors of the entry block. 31 exploreStack = append(exploreStack, entryBlk) 32 const visitStateUnseen, visitStateSeen, visitStateDone = 0, 1, 2 33 b.blkVisited[entryBlk] = visitStateSeen 34 for len(exploreStack) > 0 { 35 tail := len(exploreStack) - 1 36 blk := exploreStack[tail] 37 exploreStack = exploreStack[:tail] 38 switch b.blkVisited[blk] { 39 case visitStateUnseen: 40 // This is likely a bug in the frontend. 41 panic("BUG: unsupported CFG") 42 case visitStateSeen: 43 // This is the first time to pop this block, and we have to see the successors first. 44 // So push this block again to the stack. 45 exploreStack = append(exploreStack, blk) 46 // And push the successors to the stack if necessary. 47 for _, succ := range blk.success { 48 if succ.ReturnBlock() || succ.invalid { 49 continue 50 } 51 if b.blkVisited[succ] == visitStateUnseen { 52 b.blkVisited[succ] = visitStateSeen 53 exploreStack = append(exploreStack, succ) 54 } 55 } 56 // Finally, we could pop this block once we pop all of its successors. 57 b.blkVisited[blk] = visitStateDone 58 case visitStateDone: 59 // Note: at this point we push blk in postorder despite its name. 60 reversePostOrder = append(reversePostOrder, blk) 61 } 62 } 63 // At this point, reversePostOrder has postorder actually, so we reverse it. 64 for i := len(reversePostOrder)/2 - 1; i >= 0; i-- { 65 j := len(reversePostOrder) - 1 - i 66 reversePostOrder[i], reversePostOrder[j] = reversePostOrder[j], reversePostOrder[i] 67 } 68 69 for i, blk := range reversePostOrder { 70 blk.reversePostOrder = i 71 } 72 73 // Reuse the dominators slice if possible from the previous computation of function. 74 b.dominators = b.dominators[:cap(b.dominators)] 75 if len(b.dominators) < b.basicBlocksPool.Allocated() { 76 // Generously reserve space in the slice because the slice will be reused future allocation. 77 b.dominators = append(b.dominators, make([]*basicBlock, b.basicBlocksPool.Allocated())...) 78 } 79 calculateDominators(reversePostOrder, b.dominators) 80 81 // Reuse the slices for the future use. 82 b.blkStack = exploreStack 83 84 // For the following passes. 85 b.reversePostOrderedBasicBlocks = reversePostOrder 86 87 // Ready to detect loops! 88 subPassLoopDetection(b) 89 } 90 91 // calculateDominators calculates the immediate dominator of each node in the CFG, and store the result in `doms`. 92 // The algorithm is based on the one described in the paper "A Simple, Fast Dominance Algorithm" 93 // https://www.cs.rice.edu/~keith/EMBED/dom.pdf which is a faster/simple alternative to the well known Lengauer-Tarjan algorithm. 94 // 95 // The following code almost matches the pseudocode in the paper with one exception (see the code comment below). 96 // 97 // The result slice `doms` must be pre-allocated with the size larger than the size of dfsBlocks. 98 func calculateDominators(reversePostOrderedBlks []*basicBlock, doms []*basicBlock) { 99 entry, reversePostOrderedBlks := reversePostOrderedBlks[0], reversePostOrderedBlks[1: /* skips entry point */] 100 for _, blk := range reversePostOrderedBlks { 101 doms[blk.id] = nil 102 } 103 doms[entry.id] = entry 104 105 changed := true 106 for changed { 107 changed = false 108 for _, blk := range reversePostOrderedBlks { 109 var u *basicBlock 110 for i := range blk.preds { 111 pred := blk.preds[i].blk 112 // Skip if this pred is not reachable yet. Note that this is not described in the paper, 113 // but it is necessary to handle nested loops etc. 114 if doms[pred.id] == nil { 115 continue 116 } 117 118 if u == nil { 119 u = pred 120 continue 121 } else { 122 u = intersect(doms, u, pred) 123 } 124 } 125 if doms[blk.id] != u { 126 doms[blk.id] = u 127 changed = true 128 } 129 } 130 } 131 } 132 133 // intersect returns the common dominator of blk1 and blk2. 134 // 135 // This is the `intersect` function in the paper. 136 func intersect(doms []*basicBlock, blk1 *basicBlock, blk2 *basicBlock) *basicBlock { 137 finger1, finger2 := blk1, blk2 138 for finger1 != finger2 { 139 // Move the 'finger1' upwards to its immediate dominator. 140 for finger1.reversePostOrder > finger2.reversePostOrder { 141 finger1 = doms[finger1.id] 142 } 143 // Move the 'finger2' upwards to its immediate dominator. 144 for finger2.reversePostOrder > finger1.reversePostOrder { 145 finger2 = doms[finger2.id] 146 } 147 } 148 return finger1 149 } 150 151 // subPassLoopDetection detects loops in the function using the immediate dominators. 152 // 153 // This is run at the last of passCalculateImmediateDominators. 154 func subPassLoopDetection(b *builder) { 155 for blk := b.blockIteratorBegin(); blk != nil; blk = b.blockIteratorNext() { 156 for i := range blk.preds { 157 pred := blk.preds[i].blk 158 if pred.invalid { 159 continue 160 } 161 if b.isDominatedBy(pred, blk) { 162 blk.loopHeader = true 163 } 164 } 165 } 166 } 167 168 // buildLoopNestingForest builds the loop nesting forest for the function. 169 // This must be called after branch splitting since it relies on the CFG. 170 func buildLoopNestingForest(b *builder) { 171 ent := b.entryBlk() 172 doms := b.dominators 173 for _, blk := range b.reversePostOrderedBasicBlocks { 174 n := doms[blk.id] 175 for !n.loopHeader && n != ent { 176 n = doms[n.id] 177 } 178 179 if n == ent && blk.loopHeader { 180 b.loopNestingForestRoots = append(b.loopNestingForestRoots, blk) 181 } else if n == ent { 182 } else if n.loopHeader { 183 n.loopNestingForestChildren = append(n.loopNestingForestChildren, blk) 184 } 185 } 186 187 if wazevoapi.SSALoggingEnabled { 188 for _, root := range b.loopNestingForestRoots { 189 printLoopNestingForest(root.(*basicBlock), 0) 190 } 191 } 192 } 193 194 func printLoopNestingForest(root *basicBlock, depth int) { 195 fmt.Println(strings.Repeat("\t", depth), "loop nesting forest root:", root.ID()) 196 for _, child := range root.loopNestingForestChildren { 197 fmt.Println(strings.Repeat("\t", depth+1), "child:", child.ID()) 198 if child.LoopHeader() { 199 printLoopNestingForest(child.(*basicBlock), depth+2) 200 } 201 } 202 } 203 204 type dominatorSparseTree struct { 205 time int 206 euler []*basicBlock 207 first, depth []int 208 table [][]int 209 } 210 211 // buildDominatorTree builds the dominator tree for the function, and constructs builder.sparseTree. 212 func buildDominatorTree(b *builder) { 213 // First we materialize the children of each node in the dominator tree. 214 idoms := b.dominators 215 for _, blk := range b.reversePostOrderedBasicBlocks { 216 parent := idoms[blk.id] 217 if parent == nil { 218 panic("BUG") 219 } else if parent == blk { 220 // This is the entry block. 221 continue 222 } 223 if prev := parent.child; prev == nil { 224 parent.child = blk 225 } else { 226 parent.child = blk 227 blk.sibling = prev 228 } 229 } 230 231 // Reset the state from the previous computation. 232 n := b.basicBlocksPool.Allocated() 233 st := &b.sparseTree 234 st.euler = append(st.euler[:0], make([]*basicBlock, 2*n-1)...) 235 st.first = append(st.first[:0], make([]int, n)...) 236 for i := range st.first { 237 st.first[i] = -1 238 } 239 st.depth = append(st.depth[:0], make([]int, 2*n-1)...) 240 st.time = 0 241 242 // Start building the sparse tree. 243 st.eulerTour(b.entryBlk(), 0) 244 st.buildSparseTable() 245 } 246 247 func (dt *dominatorSparseTree) eulerTour(node *basicBlock, height int) { 248 if wazevoapi.SSALoggingEnabled { 249 fmt.Println(strings.Repeat("\t", height), "euler tour:", node.ID()) 250 } 251 dt.euler[dt.time] = node 252 dt.depth[dt.time] = height 253 if dt.first[node.id] == -1 { 254 dt.first[node.id] = dt.time 255 } 256 dt.time++ 257 258 for child := node.child; child != nil; child = child.sibling { 259 dt.eulerTour(child, height+1) 260 dt.euler[dt.time] = node // add the current node again after visiting a child 261 dt.depth[dt.time] = height 262 dt.time++ 263 } 264 } 265 266 // buildSparseTable builds a sparse table for RMQ queries. 267 func (dt *dominatorSparseTree) buildSparseTable() { 268 n := len(dt.depth) 269 k := int(math.Log2(float64(n))) + 1 270 table := dt.table 271 272 if n >= len(table) { 273 table = append(table, make([][]int, n+1)...) 274 } 275 for i := range table { 276 if len(table[i]) < k { 277 table[i] = append(table[i], make([]int, k)...) 278 } 279 table[i][0] = i 280 } 281 282 for j := 1; 1<<j <= n; j++ { 283 for i := 0; i+(1<<j)-1 < n; i++ { 284 if dt.depth[table[i][j-1]] < dt.depth[table[i+(1<<(j-1))][j-1]] { 285 table[i][j] = table[i][j-1] 286 } else { 287 table[i][j] = table[i+(1<<(j-1))][j-1] 288 } 289 } 290 } 291 dt.table = table 292 } 293 294 // rmq performs a range minimum query on the sparse table. 295 func (dt *dominatorSparseTree) rmq(l, r int) int { 296 table := dt.table 297 depth := dt.depth 298 j := int(math.Log2(float64(r - l + 1))) 299 if depth[table[l][j]] <= depth[table[r-(1<<j)+1][j]] { 300 return table[l][j] 301 } 302 return table[r-(1<<j)+1][j] 303 } 304 305 // findLCA finds the LCA using the Euler tour and RMQ. 306 func (dt *dominatorSparseTree) findLCA(u, v BasicBlockID) *basicBlock { 307 first := dt.first 308 if first[u] > first[v] { 309 u, v = v, u 310 } 311 return dt.euler[dt.rmq(first[u], first[v])] 312 }