github.com/wasilibs/wazerox@v0.0.0-20240124024944-4923be63ab5f/internal/engine/wazevo/ssa/pass.go (about) 1 package ssa 2 3 import ( 4 "fmt" 5 "sort" 6 7 "github.com/wasilibs/wazerox/internal/engine/wazevo/wazevoapi" 8 ) 9 10 // RunPasses implements Builder.RunPasses. 11 // 12 // The order here matters; some pass depends on the previous ones. 13 // 14 // Note that passes suffixed with "Opt" are the optimization passes, meaning that they edit the instructions and blocks 15 // while the other passes are not, like passEstimateBranchProbabilities does not edit them, but only calculates the additional information. 16 func (b *builder) RunPasses() { 17 passSortSuccessors(b) 18 passDeadBlockEliminationOpt(b) 19 passRedundantPhiEliminationOpt(b) 20 // The result of passCalculateImmediateDominators will be used by various passes below. 21 passCalculateImmediateDominators(b) 22 passNopInstElimination(b) 23 24 // TODO: implement either conversion of irreducible CFG into reducible one, or irreducible CFG detection where we panic. 25 // WebAssembly program shouldn't result in irreducible CFG, but we should handle it properly in just in case. 26 // See FixIrreducible pass in LLVM: https://llvm.org/doxygen/FixIrreducible_8cpp_source.html 27 28 // TODO: implement more optimization passes like: 29 // block coalescing. 30 // Copy-propagation. 31 // Constant folding. 32 // Common subexpression elimination. 33 // Arithmetic simplifications. 34 // and more! 35 36 // passDeadCodeEliminationOpt could be more accurate if we do this after other optimizations. 37 passDeadCodeEliminationOpt(b) 38 b.donePasses = true 39 } 40 41 // passDeadBlockEliminationOpt searches the unreachable blocks, and sets the basicBlock.invalid flag true if so. 42 func passDeadBlockEliminationOpt(b *builder) { 43 entryBlk := b.entryBlk() 44 b.clearBlkVisited() 45 b.blkStack = append(b.blkStack, entryBlk) 46 for len(b.blkStack) > 0 { 47 reachableBlk := b.blkStack[len(b.blkStack)-1] 48 b.blkStack = b.blkStack[:len(b.blkStack)-1] 49 b.blkVisited[reachableBlk] = 0 // the value won't be used in this pass. 50 51 if !reachableBlk.sealed && !reachableBlk.ReturnBlock() { 52 panic(fmt.Sprintf("%s is not sealed", reachableBlk)) 53 } 54 55 if wazevoapi.SSAValidationEnabled { 56 reachableBlk.validate(b) 57 } 58 59 for _, succ := range reachableBlk.success { 60 if _, ok := b.blkVisited[succ]; ok { 61 continue 62 } 63 b.blkStack = append(b.blkStack, succ) 64 } 65 } 66 67 for blk := b.blockIteratorBegin(); blk != nil; blk = b.blockIteratorNext() { 68 if _, ok := b.blkVisited[blk]; !ok { 69 blk.invalid = true 70 } 71 } 72 } 73 74 // passRedundantPhiEliminationOpt eliminates the redundant PHIs (in our terminology, parameters of a block). 75 func passRedundantPhiEliminationOpt(b *builder) { 76 redundantParameterIndexes := b.ints[:0] // reuse the slice from previous iterations. 77 78 _ = b.blockIteratorBegin() // skip entry block! 79 // Below, we intentionally use the named iteration variable name, as this comes with inevitable nested for loops! 80 for blk := b.blockIteratorNext(); blk != nil; blk = b.blockIteratorNext() { 81 paramNum := len(blk.params) 82 83 for paramIndex := 0; paramIndex < paramNum; paramIndex++ { 84 phiValue := blk.params[paramIndex].value 85 redundant := true 86 87 nonSelfReferencingValue := ValueInvalid 88 for predIndex := range blk.preds { 89 pred := blk.preds[predIndex].branch.vs[paramIndex] 90 if pred == phiValue { 91 // This is self-referencing: PHI from the same PHI. 92 continue 93 } 94 95 if !nonSelfReferencingValue.Valid() { 96 nonSelfReferencingValue = pred 97 continue 98 } 99 100 if nonSelfReferencingValue != pred { 101 redundant = false 102 break 103 } 104 } 105 106 if !nonSelfReferencingValue.Valid() { 107 // This shouldn't happen, and must be a bug in builder.go. 108 panic("BUG: params added but only self-referencing") 109 } 110 111 if redundant { 112 b.redundantParameterIndexToValue[paramIndex] = nonSelfReferencingValue 113 redundantParameterIndexes = append(redundantParameterIndexes, paramIndex) 114 } 115 } 116 117 if len(b.redundantParameterIndexToValue) == 0 { 118 continue 119 } 120 121 // Remove the redundant PHIs from the argument list of branching instructions. 122 for predIndex := range blk.preds { 123 var cur int 124 predBlk := blk.preds[predIndex] 125 branchInst := predBlk.branch 126 for argIndex, value := range branchInst.vs { 127 if _, ok := b.redundantParameterIndexToValue[argIndex]; !ok { 128 branchInst.vs[cur] = value 129 cur++ 130 } 131 } 132 branchInst.vs = branchInst.vs[:cur] 133 } 134 135 // Still need to have the definition of the value of the PHI (previously as the parameter). 136 for _, redundantParamIndex := range redundantParameterIndexes { 137 phiValue := blk.params[redundantParamIndex].value 138 onlyValue := b.redundantParameterIndexToValue[redundantParamIndex] 139 // Create an alias in this block from the only phi argument to the phi value. 140 b.alias(phiValue, onlyValue) 141 } 142 143 // Finally, Remove the param from the blk. 144 var cur int 145 for paramIndex := 0; paramIndex < paramNum; paramIndex++ { 146 param := blk.params[paramIndex] 147 if _, ok := b.redundantParameterIndexToValue[paramIndex]; !ok { 148 blk.params[cur] = param 149 cur++ 150 } 151 } 152 blk.params = blk.params[:cur] 153 154 // Clears the map for the next iteration. 155 for _, paramIndex := range redundantParameterIndexes { 156 delete(b.redundantParameterIndexToValue, paramIndex) 157 } 158 redundantParameterIndexes = redundantParameterIndexes[:0] 159 } 160 161 // Reuse the slice for the future passes. 162 b.ints = redundantParameterIndexes 163 } 164 165 // passDeadCodeEliminationOpt traverses all the instructions, and calculates the reference count of each Value, and 166 // eliminates all the unnecessary instructions whose ref count is zero. 167 // The results are stored at builder.valueRefCounts. This also assigns a InstructionGroupID to each Instruction 168 // during the process. This is the last SSA-level optimization pass and after this, 169 // the SSA function is ready to be used by backends. 170 // 171 // TODO: the algorithm here might not be efficient. Get back to this later. 172 func passDeadCodeEliminationOpt(b *builder) { 173 nvid := int(b.nextValueID) 174 if nvid >= len(b.valueRefCounts) { 175 b.valueRefCounts = append(b.valueRefCounts, make([]int, b.nextValueID)...) 176 } 177 if nvid >= len(b.valueIDToInstruction) { 178 b.valueIDToInstruction = append(b.valueIDToInstruction, make([]*Instruction, b.nextValueID)...) 179 } 180 181 // First, we gather all the instructions with side effects. 182 liveInstructions := b.instStack[:0] 183 // During the process, we will assign InstructionGroupID to each instruction, which is not 184 // relevant to dead code elimination, but we need in the backend. 185 var gid InstructionGroupID 186 for blk := b.blockIteratorBegin(); blk != nil; blk = b.blockIteratorNext() { 187 for cur := blk.rootInstr; cur != nil; cur = cur.next { 188 cur.gid = gid 189 switch cur.sideEffect() { 190 case sideEffectTraps: 191 // The trappable should always be alive. 192 liveInstructions = append(liveInstructions, cur) 193 case sideEffectStrict: 194 liveInstructions = append(liveInstructions, cur) 195 // The strict side effect should create different instruction groups. 196 gid++ 197 } 198 199 r1, rs := cur.Returns() 200 if r1.Valid() { 201 b.valueIDToInstruction[r1.ID()] = cur 202 } 203 for _, r := range rs { 204 b.valueIDToInstruction[r.ID()] = cur 205 } 206 } 207 } 208 209 // Find all the instructions referenced by live instructions transitively. 210 for len(liveInstructions) > 0 { 211 tail := len(liveInstructions) - 1 212 live := liveInstructions[tail] 213 liveInstructions = liveInstructions[:tail] 214 if live.live { 215 // If it's already marked alive, this is referenced multiple times, 216 // so we can skip it. 217 continue 218 } 219 live.live = true 220 221 // Before we walk, we need to resolve the alias first. 222 b.resolveArgumentAlias(live) 223 224 v1, v2, v3, vs := live.Args() 225 if v1.Valid() { 226 producingInst := b.valueIDToInstruction[v1.ID()] 227 if producingInst != nil { 228 liveInstructions = append(liveInstructions, producingInst) 229 } 230 } 231 232 if v2.Valid() { 233 producingInst := b.valueIDToInstruction[v2.ID()] 234 if producingInst != nil { 235 liveInstructions = append(liveInstructions, producingInst) 236 } 237 } 238 239 if v3.Valid() { 240 producingInst := b.valueIDToInstruction[v3.ID()] 241 if producingInst != nil { 242 liveInstructions = append(liveInstructions, producingInst) 243 } 244 } 245 246 for _, v := range vs { 247 producingInst := b.valueIDToInstruction[v.ID()] 248 if producingInst != nil { 249 liveInstructions = append(liveInstructions, producingInst) 250 } 251 } 252 } 253 254 // Now that all the live instructions are flagged as live=true, we eliminate all dead instructions. 255 for blk := b.blockIteratorBegin(); blk != nil; blk = b.blockIteratorNext() { 256 for cur := blk.rootInstr; cur != nil; cur = cur.next { 257 if !cur.live { 258 // Remove the instruction from the list. 259 if prev := cur.prev; prev != nil { 260 prev.next = cur.next 261 } else { 262 blk.rootInstr = cur.next 263 } 264 if next := cur.next; next != nil { 265 next.prev = cur.prev 266 } 267 continue 268 } 269 270 // If the value alive, we can be sure that arguments are used definitely. 271 // Hence, we can increment the value reference counts. 272 v1, v2, v3, vs := cur.Args() 273 if v1.Valid() { 274 b.incRefCount(v1.ID(), cur) 275 } 276 if v2.Valid() { 277 b.incRefCount(v2.ID(), cur) 278 } 279 if v3.Valid() { 280 b.incRefCount(v3.ID(), cur) 281 } 282 for _, v := range vs { 283 b.incRefCount(v.ID(), cur) 284 } 285 } 286 } 287 288 b.instStack = liveInstructions // we reuse the stack for the next iteration. 289 } 290 291 func (b *builder) incRefCount(id ValueID, from *Instruction) { 292 if wazevoapi.SSALoggingEnabled { 293 fmt.Printf("v%d referenced from %v\n", id, from.Format(b)) 294 } 295 b.valueRefCounts[id]++ 296 } 297 298 // clearBlkVisited clears the b.blkVisited map so that we can reuse it for multiple places. 299 func (b *builder) clearBlkVisited() { 300 b.blkStack2 = b.blkStack2[:0] 301 for key := range b.blkVisited { 302 b.blkStack2 = append(b.blkStack2, key) 303 } 304 for _, blk := range b.blkStack2 { 305 delete(b.blkVisited, blk) 306 } 307 b.blkStack2 = b.blkStack2[:0] 308 } 309 310 // passNopInstElimination eliminates the instructions which is essentially a no-op. 311 func passNopInstElimination(b *builder) { 312 if int(b.nextValueID) >= len(b.valueIDToInstruction) { 313 b.valueIDToInstruction = append(b.valueIDToInstruction, make([]*Instruction, b.nextValueID)...) 314 } 315 316 for blk := b.blockIteratorBegin(); blk != nil; blk = b.blockIteratorNext() { 317 for cur := blk.rootInstr; cur != nil; cur = cur.next { 318 r1, rs := cur.Returns() 319 if r1.Valid() { 320 b.valueIDToInstruction[r1.ID()] = cur 321 } 322 for _, r := range rs { 323 b.valueIDToInstruction[r.ID()] = cur 324 } 325 } 326 } 327 328 for blk := b.blockIteratorBegin(); blk != nil; blk = b.blockIteratorNext() { 329 for cur := blk.rootInstr; cur != nil; cur = cur.next { 330 switch cur.Opcode() { 331 // TODO: add more logics here. 332 case OpcodeIshl, OpcodeSshr, OpcodeUshr: 333 x, amount := cur.Arg2() 334 definingInst := b.valueIDToInstruction[amount.ID()] 335 if definingInst == nil { 336 // If there's no defining instruction, that means the amount is coming from the parameter. 337 continue 338 } 339 if definingInst.Constant() { 340 v := definingInst.ConstantVal() 341 342 if x.Type().Bits() == 64 { 343 v = v % 64 344 } else { 345 v = v % 32 346 } 347 if v == 0 { 348 b.alias(cur.Return(), x) 349 } 350 } 351 } 352 } 353 } 354 } 355 356 // passSortSuccessors sorts the successors of each block in the natural program order. 357 func passSortSuccessors(b *builder) { 358 for i := 0; i < b.basicBlocksPool.Allocated(); i++ { 359 blk := b.basicBlocksPool.View(i) 360 sort.SliceStable(blk.success, func(i, j int) bool { 361 iBlk, jBlk := blk.success[i], blk.success[j] 362 if jBlk.ReturnBlock() { 363 return true 364 } 365 if iBlk.ReturnBlock() { 366 return false 367 } 368 iRoot, jRoot := iBlk.rootInstr, jBlk.rootInstr 369 if iRoot == nil || jRoot == nil { // For testing. 370 return true 371 } 372 return iBlk.rootInstr.id < jBlk.rootInstr.id 373 }) 374 } 375 }