github.com/tetratelabs/wazero@v1.7.3-0.20240513003603-48f702e154b5/internal/engine/wazevo/ssa/pass.go (about)

     1  package ssa
     2  
     3  import (
     4  	"fmt"
     5  
     6  	"github.com/tetratelabs/wazero/internal/engine/wazevo/wazevoapi"
     7  )
     8  
     9  // RunPasses implements Builder.RunPasses.
    10  //
    11  // The order here matters; some pass depends on the previous ones.
    12  //
    13  // Note that passes suffixed with "Opt" are the optimization passes, meaning that they edit the instructions and blocks
    14  // while the other passes are not, like passEstimateBranchProbabilities does not edit them, but only calculates the additional information.
    15  func (b *builder) RunPasses() {
    16  	b.runPreBlockLayoutPasses()
    17  	b.runBlockLayoutPass()
    18  	b.runPostBlockLayoutPasses()
    19  	b.runFinalizingPasses()
    20  }
    21  
    22  func (b *builder) runPreBlockLayoutPasses() {
    23  	passSortSuccessors(b)
    24  	passDeadBlockEliminationOpt(b)
    25  	passRedundantPhiEliminationOpt(b)
    26  	// The result of passCalculateImmediateDominators will be used by various passes below.
    27  	passCalculateImmediateDominators(b)
    28  	passNopInstElimination(b)
    29  
    30  	// TODO: implement either conversion of irreducible CFG into reducible one, or irreducible CFG detection where we panic.
    31  	// 	WebAssembly program shouldn't result in irreducible CFG, but we should handle it properly in just in case.
    32  	// 	See FixIrreducible pass in LLVM: https://llvm.org/doxygen/FixIrreducible_8cpp_source.html
    33  
    34  	// TODO: implement more optimization passes like:
    35  	// 	block coalescing.
    36  	// 	Copy-propagation.
    37  	// 	Constant folding.
    38  	// 	Common subexpression elimination.
    39  	// 	Arithmetic simplifications.
    40  	// 	and more!
    41  
    42  	// passDeadCodeEliminationOpt could be more accurate if we do this after other optimizations.
    43  	passDeadCodeEliminationOpt(b)
    44  	b.donePreBlockLayoutPasses = true
    45  }
    46  
    47  func (b *builder) runBlockLayoutPass() {
    48  	if !b.donePreBlockLayoutPasses {
    49  		panic("runBlockLayoutPass must be called after all pre passes are done")
    50  	}
    51  	passLayoutBlocks(b)
    52  	b.doneBlockLayout = true
    53  }
    54  
    55  // runPostBlockLayoutPasses runs the post block layout passes. After this point, CFG is somewhat stable,
    56  // but still can be modified before finalizing passes. At this point, critical edges are split by passLayoutBlocks.
    57  func (b *builder) runPostBlockLayoutPasses() {
    58  	if !b.doneBlockLayout {
    59  		panic("runPostBlockLayoutPasses must be called after block layout pass is done")
    60  	}
    61  	// TODO: Do more. e.g. tail duplication, loop unrolling, etc.
    62  
    63  	b.donePostBlockLayoutPasses = true
    64  }
    65  
    66  // runFinalizingPasses runs the finalizing passes. After this point, CFG should not be modified.
    67  func (b *builder) runFinalizingPasses() {
    68  	if !b.donePostBlockLayoutPasses {
    69  		panic("runFinalizingPasses must be called after post block layout passes are done")
    70  	}
    71  	// Critical edges are split, so we fix the loop nesting forest.
    72  	passBuildLoopNestingForest(b)
    73  	passBuildDominatorTree(b)
    74  	// Now that we know the final placement of the blocks, we can explicitly mark the fallthrough jumps.
    75  	b.markFallthroughJumps()
    76  }
    77  
    78  // passDeadBlockEliminationOpt searches the unreachable blocks, and sets the basicBlock.invalid flag true if so.
    79  func passDeadBlockEliminationOpt(b *builder) {
    80  	entryBlk := b.entryBlk()
    81  	b.clearBlkVisited()
    82  	b.blkStack = append(b.blkStack, entryBlk)
    83  	for len(b.blkStack) > 0 {
    84  		reachableBlk := b.blkStack[len(b.blkStack)-1]
    85  		b.blkStack = b.blkStack[:len(b.blkStack)-1]
    86  		b.blkVisited[reachableBlk] = 0 // the value won't be used in this pass.
    87  
    88  		if !reachableBlk.sealed && !reachableBlk.ReturnBlock() {
    89  			panic(fmt.Sprintf("%s is not sealed", reachableBlk))
    90  		}
    91  
    92  		if wazevoapi.SSAValidationEnabled {
    93  			reachableBlk.validate(b)
    94  		}
    95  
    96  		for _, succ := range reachableBlk.success {
    97  			if _, ok := b.blkVisited[succ]; ok {
    98  				continue
    99  			}
   100  			b.blkStack = append(b.blkStack, succ)
   101  		}
   102  	}
   103  
   104  	for blk := b.blockIteratorBegin(); blk != nil; blk = b.blockIteratorNext() {
   105  		if _, ok := b.blkVisited[blk]; !ok {
   106  			blk.invalid = true
   107  		}
   108  	}
   109  }
   110  
   111  // passRedundantPhiEliminationOpt eliminates the redundant PHIs (in our terminology, parameters of a block).
   112  func passRedundantPhiEliminationOpt(b *builder) {
   113  	redundantParameterIndexes := b.ints[:0] // reuse the slice from previous iterations.
   114  
   115  	// TODO: this might be costly for large programs, but at least, as far as I did the experiment, it's almost the
   116  	//  same as the single iteration version in terms of the overall compilation time. That *might be* mostly thanks to the fact
   117  	//  that removing many PHIs results in the reduction of the total instructions, not because of this indefinite iteration is
   118  	//  relatively small. For example, sqlite speedtest binary results in the large number of redundant PHIs,
   119  	//  the maximum number of iteration was 22, which seems to be acceptable but not that small either since the
   120  	//  complexity here is O(BlockNum * Iterations) at the worst case where BlockNum might be the order of thousands.
   121  	for {
   122  		changed := false
   123  		_ = b.blockIteratorBegin() // skip entry block!
   124  		// Below, we intentionally use the named iteration variable name, as this comes with inevitable nested for loops!
   125  		for blk := b.blockIteratorNext(); blk != nil; blk = b.blockIteratorNext() {
   126  			paramNum := len(blk.params)
   127  
   128  			for paramIndex := 0; paramIndex < paramNum; paramIndex++ {
   129  				phiValue := blk.params[paramIndex].value
   130  				redundant := true
   131  
   132  				nonSelfReferencingValue := ValueInvalid
   133  				for predIndex := range blk.preds {
   134  					br := blk.preds[predIndex].branch
   135  					// Resolve the alias in the arguments so that we could use the previous iteration's result.
   136  					b.resolveArgumentAlias(br)
   137  					pred := br.vs.View()[paramIndex]
   138  					if pred == phiValue {
   139  						// This is self-referencing: PHI from the same PHI.
   140  						continue
   141  					}
   142  
   143  					if !nonSelfReferencingValue.Valid() {
   144  						nonSelfReferencingValue = pred
   145  						continue
   146  					}
   147  
   148  					if nonSelfReferencingValue != pred {
   149  						redundant = false
   150  						break
   151  					}
   152  				}
   153  
   154  				if !nonSelfReferencingValue.Valid() {
   155  					// This shouldn't happen, and must be a bug in builder.go.
   156  					panic("BUG: params added but only self-referencing")
   157  				}
   158  
   159  				if redundant {
   160  					b.redundantParameterIndexToValue[paramIndex] = nonSelfReferencingValue
   161  					redundantParameterIndexes = append(redundantParameterIndexes, paramIndex)
   162  				}
   163  			}
   164  
   165  			if len(b.redundantParameterIndexToValue) == 0 {
   166  				continue
   167  			}
   168  			changed = true
   169  
   170  			// Remove the redundant PHIs from the argument list of branching instructions.
   171  			for predIndex := range blk.preds {
   172  				var cur int
   173  				predBlk := blk.preds[predIndex]
   174  				branchInst := predBlk.branch
   175  				view := branchInst.vs.View()
   176  				for argIndex, value := range view {
   177  					if _, ok := b.redundantParameterIndexToValue[argIndex]; !ok {
   178  						view[cur] = value
   179  						cur++
   180  					}
   181  				}
   182  				branchInst.vs.Cut(cur)
   183  			}
   184  
   185  			// Still need to have the definition of the value of the PHI (previously as the parameter).
   186  			for _, redundantParamIndex := range redundantParameterIndexes {
   187  				phiValue := blk.params[redundantParamIndex].value
   188  				onlyValue := b.redundantParameterIndexToValue[redundantParamIndex]
   189  				// Create an alias in this block from the only phi argument to the phi value.
   190  				b.alias(phiValue, onlyValue)
   191  			}
   192  
   193  			// Finally, Remove the param from the blk.
   194  			var cur int
   195  			for paramIndex := 0; paramIndex < paramNum; paramIndex++ {
   196  				param := blk.params[paramIndex]
   197  				if _, ok := b.redundantParameterIndexToValue[paramIndex]; !ok {
   198  					blk.params[cur] = param
   199  					cur++
   200  				}
   201  			}
   202  			blk.params = blk.params[:cur]
   203  
   204  			// Clears the map for the next iteration.
   205  			for _, paramIndex := range redundantParameterIndexes {
   206  				delete(b.redundantParameterIndexToValue, paramIndex)
   207  			}
   208  			redundantParameterIndexes = redundantParameterIndexes[:0]
   209  		}
   210  
   211  		if !changed {
   212  			break
   213  		}
   214  	}
   215  
   216  	// Reuse the slice for the future passes.
   217  	b.ints = redundantParameterIndexes
   218  }
   219  
   220  // passDeadCodeEliminationOpt traverses all the instructions, and calculates the reference count of each Value, and
   221  // eliminates all the unnecessary instructions whose ref count is zero.
   222  // The results are stored at builder.valueRefCounts. This also assigns a InstructionGroupID to each Instruction
   223  // during the process. This is the last SSA-level optimization pass and after this,
   224  // the SSA function is ready to be used by backends.
   225  //
   226  // TODO: the algorithm here might not be efficient. Get back to this later.
   227  func passDeadCodeEliminationOpt(b *builder) {
   228  	nvid := int(b.nextValueID)
   229  	if nvid >= len(b.valueRefCounts) {
   230  		b.valueRefCounts = append(b.valueRefCounts, make([]int, b.nextValueID)...)
   231  	}
   232  	if nvid >= len(b.valueIDToInstruction) {
   233  		b.valueIDToInstruction = append(b.valueIDToInstruction, make([]*Instruction, b.nextValueID)...)
   234  	}
   235  
   236  	// First, we gather all the instructions with side effects.
   237  	liveInstructions := b.instStack[:0]
   238  	// During the process, we will assign InstructionGroupID to each instruction, which is not
   239  	// relevant to dead code elimination, but we need in the backend.
   240  	var gid InstructionGroupID
   241  	for blk := b.blockIteratorBegin(); blk != nil; blk = b.blockIteratorNext() {
   242  		for cur := blk.rootInstr; cur != nil; cur = cur.next {
   243  			cur.gid = gid
   244  			switch cur.sideEffect() {
   245  			case sideEffectTraps:
   246  				// The trappable should always be alive.
   247  				liveInstructions = append(liveInstructions, cur)
   248  			case sideEffectStrict:
   249  				liveInstructions = append(liveInstructions, cur)
   250  				// The strict side effect should create different instruction groups.
   251  				gid++
   252  			}
   253  
   254  			r1, rs := cur.Returns()
   255  			if r1.Valid() {
   256  				b.valueIDToInstruction[r1.ID()] = cur
   257  			}
   258  			for _, r := range rs {
   259  				b.valueIDToInstruction[r.ID()] = cur
   260  			}
   261  		}
   262  	}
   263  
   264  	// Find all the instructions referenced by live instructions transitively.
   265  	for len(liveInstructions) > 0 {
   266  		tail := len(liveInstructions) - 1
   267  		live := liveInstructions[tail]
   268  		liveInstructions = liveInstructions[:tail]
   269  		if live.live {
   270  			// If it's already marked alive, this is referenced multiple times,
   271  			// so we can skip it.
   272  			continue
   273  		}
   274  		live.live = true
   275  
   276  		// Before we walk, we need to resolve the alias first.
   277  		b.resolveArgumentAlias(live)
   278  
   279  		v1, v2, v3, vs := live.Args()
   280  		if v1.Valid() {
   281  			producingInst := b.valueIDToInstruction[v1.ID()]
   282  			if producingInst != nil {
   283  				liveInstructions = append(liveInstructions, producingInst)
   284  			}
   285  		}
   286  
   287  		if v2.Valid() {
   288  			producingInst := b.valueIDToInstruction[v2.ID()]
   289  			if producingInst != nil {
   290  				liveInstructions = append(liveInstructions, producingInst)
   291  			}
   292  		}
   293  
   294  		if v3.Valid() {
   295  			producingInst := b.valueIDToInstruction[v3.ID()]
   296  			if producingInst != nil {
   297  				liveInstructions = append(liveInstructions, producingInst)
   298  			}
   299  		}
   300  
   301  		for _, v := range vs {
   302  			producingInst := b.valueIDToInstruction[v.ID()]
   303  			if producingInst != nil {
   304  				liveInstructions = append(liveInstructions, producingInst)
   305  			}
   306  		}
   307  	}
   308  
   309  	// Now that all the live instructions are flagged as live=true, we eliminate all dead instructions.
   310  	for blk := b.blockIteratorBegin(); blk != nil; blk = b.blockIteratorNext() {
   311  		for cur := blk.rootInstr; cur != nil; cur = cur.next {
   312  			if !cur.live {
   313  				// Remove the instruction from the list.
   314  				if prev := cur.prev; prev != nil {
   315  					prev.next = cur.next
   316  				} else {
   317  					blk.rootInstr = cur.next
   318  				}
   319  				if next := cur.next; next != nil {
   320  					next.prev = cur.prev
   321  				}
   322  				continue
   323  			}
   324  
   325  			// If the value alive, we can be sure that arguments are used definitely.
   326  			// Hence, we can increment the value reference counts.
   327  			v1, v2, v3, vs := cur.Args()
   328  			if v1.Valid() {
   329  				b.incRefCount(v1.ID(), cur)
   330  			}
   331  			if v2.Valid() {
   332  				b.incRefCount(v2.ID(), cur)
   333  			}
   334  			if v3.Valid() {
   335  				b.incRefCount(v3.ID(), cur)
   336  			}
   337  			for _, v := range vs {
   338  				b.incRefCount(v.ID(), cur)
   339  			}
   340  		}
   341  	}
   342  
   343  	b.instStack = liveInstructions // we reuse the stack for the next iteration.
   344  }
   345  
   346  func (b *builder) incRefCount(id ValueID, from *Instruction) {
   347  	if wazevoapi.SSALoggingEnabled {
   348  		fmt.Printf("v%d referenced from %v\n", id, from.Format(b))
   349  	}
   350  	b.valueRefCounts[id]++
   351  }
   352  
   353  // clearBlkVisited clears the b.blkVisited map so that we can reuse it for multiple places.
   354  func (b *builder) clearBlkVisited() {
   355  	b.blkStack2 = b.blkStack2[:0]
   356  	for key := range b.blkVisited {
   357  		b.blkStack2 = append(b.blkStack2, key)
   358  	}
   359  	for _, blk := range b.blkStack2 {
   360  		delete(b.blkVisited, blk)
   361  	}
   362  	b.blkStack2 = b.blkStack2[:0]
   363  }
   364  
   365  // passNopInstElimination eliminates the instructions which is essentially a no-op.
   366  func passNopInstElimination(b *builder) {
   367  	if int(b.nextValueID) >= len(b.valueIDToInstruction) {
   368  		b.valueIDToInstruction = append(b.valueIDToInstruction, make([]*Instruction, b.nextValueID)...)
   369  	}
   370  
   371  	for blk := b.blockIteratorBegin(); blk != nil; blk = b.blockIteratorNext() {
   372  		for cur := blk.rootInstr; cur != nil; cur = cur.next {
   373  			r1, rs := cur.Returns()
   374  			if r1.Valid() {
   375  				b.valueIDToInstruction[r1.ID()] = cur
   376  			}
   377  			for _, r := range rs {
   378  				b.valueIDToInstruction[r.ID()] = cur
   379  			}
   380  		}
   381  	}
   382  
   383  	for blk := b.blockIteratorBegin(); blk != nil; blk = b.blockIteratorNext() {
   384  		for cur := blk.rootInstr; cur != nil; cur = cur.next {
   385  			switch cur.Opcode() {
   386  			// TODO: add more logics here.
   387  			case OpcodeIshl, OpcodeSshr, OpcodeUshr:
   388  				x, amount := cur.Arg2()
   389  				definingInst := b.valueIDToInstruction[amount.ID()]
   390  				if definingInst == nil {
   391  					// If there's no defining instruction, that means the amount is coming from the parameter.
   392  					continue
   393  				}
   394  				if definingInst.Constant() {
   395  					v := definingInst.ConstantVal()
   396  
   397  					if x.Type().Bits() == 64 {
   398  						v = v % 64
   399  					} else {
   400  						v = v % 32
   401  					}
   402  					if v == 0 {
   403  						b.alias(cur.Return(), x)
   404  					}
   405  				}
   406  			}
   407  		}
   408  	}
   409  }
   410  
   411  // passSortSuccessors sorts the successors of each block in the natural program order.
   412  func passSortSuccessors(b *builder) {
   413  	for i := 0; i < b.basicBlocksPool.Allocated(); i++ {
   414  		blk := b.basicBlocksPool.View(i)
   415  		sortBlocks(blk.success)
   416  	}
   417  }