github.com/wasilibs/wazerox@v0.0.0-20240124024944-4923be63ab5f/internal/engine/wazevo/ssa/pass.go (about)

     1  package ssa
     2  
     3  import (
     4  	"fmt"
     5  	"sort"
     6  
     7  	"github.com/wasilibs/wazerox/internal/engine/wazevo/wazevoapi"
     8  )
     9  
    10  // RunPasses implements Builder.RunPasses.
    11  //
    12  // The order here matters; some pass depends on the previous ones.
    13  //
    14  // Note that passes suffixed with "Opt" are the optimization passes, meaning that they edit the instructions and blocks
    15  // while the other passes are not, like passEstimateBranchProbabilities does not edit them, but only calculates the additional information.
    16  func (b *builder) RunPasses() {
    17  	passSortSuccessors(b)
    18  	passDeadBlockEliminationOpt(b)
    19  	passRedundantPhiEliminationOpt(b)
    20  	// The result of passCalculateImmediateDominators will be used by various passes below.
    21  	passCalculateImmediateDominators(b)
    22  	passNopInstElimination(b)
    23  
    24  	// TODO: implement either conversion of irreducible CFG into reducible one, or irreducible CFG detection where we panic.
    25  	// 	WebAssembly program shouldn't result in irreducible CFG, but we should handle it properly in just in case.
    26  	// 	See FixIrreducible pass in LLVM: https://llvm.org/doxygen/FixIrreducible_8cpp_source.html
    27  
    28  	// TODO: implement more optimization passes like:
    29  	// 	block coalescing.
    30  	// 	Copy-propagation.
    31  	// 	Constant folding.
    32  	// 	Common subexpression elimination.
    33  	// 	Arithmetic simplifications.
    34  	// 	and more!
    35  
    36  	// passDeadCodeEliminationOpt could be more accurate if we do this after other optimizations.
    37  	passDeadCodeEliminationOpt(b)
    38  	b.donePasses = true
    39  }
    40  
    41  // passDeadBlockEliminationOpt searches the unreachable blocks, and sets the basicBlock.invalid flag true if so.
    42  func passDeadBlockEliminationOpt(b *builder) {
    43  	entryBlk := b.entryBlk()
    44  	b.clearBlkVisited()
    45  	b.blkStack = append(b.blkStack, entryBlk)
    46  	for len(b.blkStack) > 0 {
    47  		reachableBlk := b.blkStack[len(b.blkStack)-1]
    48  		b.blkStack = b.blkStack[:len(b.blkStack)-1]
    49  		b.blkVisited[reachableBlk] = 0 // the value won't be used in this pass.
    50  
    51  		if !reachableBlk.sealed && !reachableBlk.ReturnBlock() {
    52  			panic(fmt.Sprintf("%s is not sealed", reachableBlk))
    53  		}
    54  
    55  		if wazevoapi.SSAValidationEnabled {
    56  			reachableBlk.validate(b)
    57  		}
    58  
    59  		for _, succ := range reachableBlk.success {
    60  			if _, ok := b.blkVisited[succ]; ok {
    61  				continue
    62  			}
    63  			b.blkStack = append(b.blkStack, succ)
    64  		}
    65  	}
    66  
    67  	for blk := b.blockIteratorBegin(); blk != nil; blk = b.blockIteratorNext() {
    68  		if _, ok := b.blkVisited[blk]; !ok {
    69  			blk.invalid = true
    70  		}
    71  	}
    72  }
    73  
// passRedundantPhiEliminationOpt eliminates the redundant PHIs (in our terminology, parameters of a block).
//
// A parameter is redundant when every predecessor passes it either the parameter itself (a self-reference)
// or one single other value. For each such parameter this pass:
//   - removes the corresponding argument from every predecessor's branch instruction,
//   - aliases the parameter's value to that single incoming value, and
//   - removes the parameter from the block.
//
// The entry block is skipped. The pass panics if a parameter has no non-self-referencing incoming value,
// which includes the case where the block has no predecessors at all.
//
// b.ints is borrowed as the scratch slice of redundant indexes and handed back at the end.
// b.redundantParameterIndexToValue is used as a per-block scratch map and emptied after each block.
//
// NOTE(review): this does one sweep over the blocks and compares raw Values without resolving aliases.
// A parameter that only becomes redundant after another parameter has been aliased away may therefore be
// left in place.
func passRedundantPhiEliminationOpt(b *builder) {
	redundantParameterIndexes := b.ints[:0] // reuse the slice from previous iterations.

	_ = b.blockIteratorBegin() // skip entry block!
	// Below, we intentionally use the named iteration variable name, as this comes with inevitable nested for loops!
	for blk := b.blockIteratorNext(); blk != nil; blk = b.blockIteratorNext() {
		paramNum := len(blk.params)

		for paramIndex := 0; paramIndex < paramNum; paramIndex++ {
			phiValue := blk.params[paramIndex].value
			redundant := true

			// The first incoming value seen that is not the parameter itself.
			nonSelfReferencingValue := ValueInvalid
			for predIndex := range blk.preds {
				// The argument this predecessor's branch passes for paramIndex.
				pred := blk.preds[predIndex].branch.vs[paramIndex]
				if pred == phiValue {
					// This is self-referencing: PHI from the same PHI.
					continue
				}

				if !nonSelfReferencingValue.Valid() {
					nonSelfReferencingValue = pred
					continue
				}

				if nonSelfReferencingValue != pred {
					// Two distinct incoming values: this is a real PHI.
					redundant = false
					break
				}
			}

			if !nonSelfReferencingValue.Valid() {
				// This shouldn't happen, and must be a bug in builder.go.
				panic("BUG: params added but only self-referencing")
			}

			if redundant {
				b.redundantParameterIndexToValue[paramIndex] = nonSelfReferencingValue
				redundantParameterIndexes = append(redundantParameterIndexes, paramIndex)
			}
		}

		if len(b.redundantParameterIndexToValue) == 0 {
			continue
		}

		// Remove the redundant PHIs from the argument list of branching instructions.
		// This compacts vs in place, keeping only the arguments whose index is not redundant.
		for predIndex := range blk.preds {
			var cur int
			predBlk := blk.preds[predIndex]
			branchInst := predBlk.branch
			for argIndex, value := range branchInst.vs {
				if _, ok := b.redundantParameterIndexToValue[argIndex]; !ok {
					branchInst.vs[cur] = value
					cur++
				}
			}
			branchInst.vs = branchInst.vs[:cur]
		}

		// Still need to have the definition of the value of the PHI (previously as the parameter).
		for _, redundantParamIndex := range redundantParameterIndexes {
			phiValue := blk.params[redundantParamIndex].value
			onlyValue := b.redundantParameterIndexToValue[redundantParamIndex]
			// Create an alias in this block from the only phi argument to the phi value.
			b.alias(phiValue, onlyValue)
		}

		// Finally, remove the redundant params from the blk, compacting blk.params in place.
		var cur int
		for paramIndex := 0; paramIndex < paramNum; paramIndex++ {
			param := blk.params[paramIndex]
			if _, ok := b.redundantParameterIndexToValue[paramIndex]; !ok {
				blk.params[cur] = param
				cur++
			}
		}
		blk.params = blk.params[:cur]

		// Clears the map for the next iteration.
		for _, paramIndex := range redundantParameterIndexes {
			delete(b.redundantParameterIndexToValue, paramIndex)
		}
		redundantParameterIndexes = redundantParameterIndexes[:0]
	}

	// Reuse the slice for the future passes.
	b.ints = redundantParameterIndexes
}
   164  
   165  // passDeadCodeEliminationOpt traverses all the instructions, and calculates the reference count of each Value, and
   166  // eliminates all the unnecessary instructions whose ref count is zero.
   167  // The results are stored at builder.valueRefCounts. This also assigns a InstructionGroupID to each Instruction
   168  // during the process. This is the last SSA-level optimization pass and after this,
   169  // the SSA function is ready to be used by backends.
   170  //
   171  // TODO: the algorithm here might not be efficient. Get back to this later.
   172  func passDeadCodeEliminationOpt(b *builder) {
   173  	nvid := int(b.nextValueID)
   174  	if nvid >= len(b.valueRefCounts) {
   175  		b.valueRefCounts = append(b.valueRefCounts, make([]int, b.nextValueID)...)
   176  	}
   177  	if nvid >= len(b.valueIDToInstruction) {
   178  		b.valueIDToInstruction = append(b.valueIDToInstruction, make([]*Instruction, b.nextValueID)...)
   179  	}
   180  
   181  	// First, we gather all the instructions with side effects.
   182  	liveInstructions := b.instStack[:0]
   183  	// During the process, we will assign InstructionGroupID to each instruction, which is not
   184  	// relevant to dead code elimination, but we need in the backend.
   185  	var gid InstructionGroupID
   186  	for blk := b.blockIteratorBegin(); blk != nil; blk = b.blockIteratorNext() {
   187  		for cur := blk.rootInstr; cur != nil; cur = cur.next {
   188  			cur.gid = gid
   189  			switch cur.sideEffect() {
   190  			case sideEffectTraps:
   191  				// The trappable should always be alive.
   192  				liveInstructions = append(liveInstructions, cur)
   193  			case sideEffectStrict:
   194  				liveInstructions = append(liveInstructions, cur)
   195  				// The strict side effect should create different instruction groups.
   196  				gid++
   197  			}
   198  
   199  			r1, rs := cur.Returns()
   200  			if r1.Valid() {
   201  				b.valueIDToInstruction[r1.ID()] = cur
   202  			}
   203  			for _, r := range rs {
   204  				b.valueIDToInstruction[r.ID()] = cur
   205  			}
   206  		}
   207  	}
   208  
   209  	// Find all the instructions referenced by live instructions transitively.
   210  	for len(liveInstructions) > 0 {
   211  		tail := len(liveInstructions) - 1
   212  		live := liveInstructions[tail]
   213  		liveInstructions = liveInstructions[:tail]
   214  		if live.live {
   215  			// If it's already marked alive, this is referenced multiple times,
   216  			// so we can skip it.
   217  			continue
   218  		}
   219  		live.live = true
   220  
   221  		// Before we walk, we need to resolve the alias first.
   222  		b.resolveArgumentAlias(live)
   223  
   224  		v1, v2, v3, vs := live.Args()
   225  		if v1.Valid() {
   226  			producingInst := b.valueIDToInstruction[v1.ID()]
   227  			if producingInst != nil {
   228  				liveInstructions = append(liveInstructions, producingInst)
   229  			}
   230  		}
   231  
   232  		if v2.Valid() {
   233  			producingInst := b.valueIDToInstruction[v2.ID()]
   234  			if producingInst != nil {
   235  				liveInstructions = append(liveInstructions, producingInst)
   236  			}
   237  		}
   238  
   239  		if v3.Valid() {
   240  			producingInst := b.valueIDToInstruction[v3.ID()]
   241  			if producingInst != nil {
   242  				liveInstructions = append(liveInstructions, producingInst)
   243  			}
   244  		}
   245  
   246  		for _, v := range vs {
   247  			producingInst := b.valueIDToInstruction[v.ID()]
   248  			if producingInst != nil {
   249  				liveInstructions = append(liveInstructions, producingInst)
   250  			}
   251  		}
   252  	}
   253  
   254  	// Now that all the live instructions are flagged as live=true, we eliminate all dead instructions.
   255  	for blk := b.blockIteratorBegin(); blk != nil; blk = b.blockIteratorNext() {
   256  		for cur := blk.rootInstr; cur != nil; cur = cur.next {
   257  			if !cur.live {
   258  				// Remove the instruction from the list.
   259  				if prev := cur.prev; prev != nil {
   260  					prev.next = cur.next
   261  				} else {
   262  					blk.rootInstr = cur.next
   263  				}
   264  				if next := cur.next; next != nil {
   265  					next.prev = cur.prev
   266  				}
   267  				continue
   268  			}
   269  
   270  			// If the value alive, we can be sure that arguments are used definitely.
   271  			// Hence, we can increment the value reference counts.
   272  			v1, v2, v3, vs := cur.Args()
   273  			if v1.Valid() {
   274  				b.incRefCount(v1.ID(), cur)
   275  			}
   276  			if v2.Valid() {
   277  				b.incRefCount(v2.ID(), cur)
   278  			}
   279  			if v3.Valid() {
   280  				b.incRefCount(v3.ID(), cur)
   281  			}
   282  			for _, v := range vs {
   283  				b.incRefCount(v.ID(), cur)
   284  			}
   285  		}
   286  	}
   287  
   288  	b.instStack = liveInstructions // we reuse the stack for the next iteration.
   289  }
   290  
   291  func (b *builder) incRefCount(id ValueID, from *Instruction) {
   292  	if wazevoapi.SSALoggingEnabled {
   293  		fmt.Printf("v%d referenced from %v\n", id, from.Format(b))
   294  	}
   295  	b.valueRefCounts[id]++
   296  }
   297  
   298  // clearBlkVisited clears the b.blkVisited map so that we can reuse it for multiple places.
   299  func (b *builder) clearBlkVisited() {
   300  	b.blkStack2 = b.blkStack2[:0]
   301  	for key := range b.blkVisited {
   302  		b.blkStack2 = append(b.blkStack2, key)
   303  	}
   304  	for _, blk := range b.blkStack2 {
   305  		delete(b.blkVisited, blk)
   306  	}
   307  	b.blkStack2 = b.blkStack2[:0]
   308  }
   309  
   310  // passNopInstElimination eliminates the instructions which is essentially a no-op.
   311  func passNopInstElimination(b *builder) {
   312  	if int(b.nextValueID) >= len(b.valueIDToInstruction) {
   313  		b.valueIDToInstruction = append(b.valueIDToInstruction, make([]*Instruction, b.nextValueID)...)
   314  	}
   315  
   316  	for blk := b.blockIteratorBegin(); blk != nil; blk = b.blockIteratorNext() {
   317  		for cur := blk.rootInstr; cur != nil; cur = cur.next {
   318  			r1, rs := cur.Returns()
   319  			if r1.Valid() {
   320  				b.valueIDToInstruction[r1.ID()] = cur
   321  			}
   322  			for _, r := range rs {
   323  				b.valueIDToInstruction[r.ID()] = cur
   324  			}
   325  		}
   326  	}
   327  
   328  	for blk := b.blockIteratorBegin(); blk != nil; blk = b.blockIteratorNext() {
   329  		for cur := blk.rootInstr; cur != nil; cur = cur.next {
   330  			switch cur.Opcode() {
   331  			// TODO: add more logics here.
   332  			case OpcodeIshl, OpcodeSshr, OpcodeUshr:
   333  				x, amount := cur.Arg2()
   334  				definingInst := b.valueIDToInstruction[amount.ID()]
   335  				if definingInst == nil {
   336  					// If there's no defining instruction, that means the amount is coming from the parameter.
   337  					continue
   338  				}
   339  				if definingInst.Constant() {
   340  					v := definingInst.ConstantVal()
   341  
   342  					if x.Type().Bits() == 64 {
   343  						v = v % 64
   344  					} else {
   345  						v = v % 32
   346  					}
   347  					if v == 0 {
   348  						b.alias(cur.Return(), x)
   349  					}
   350  				}
   351  			}
   352  		}
   353  	}
   354  }
   355  
   356  // passSortSuccessors sorts the successors of each block in the natural program order.
   357  func passSortSuccessors(b *builder) {
   358  	for i := 0; i < b.basicBlocksPool.Allocated(); i++ {
   359  		blk := b.basicBlocksPool.View(i)
   360  		sort.SliceStable(blk.success, func(i, j int) bool {
   361  			iBlk, jBlk := blk.success[i], blk.success[j]
   362  			if jBlk.ReturnBlock() {
   363  				return true
   364  			}
   365  			if iBlk.ReturnBlock() {
   366  				return false
   367  			}
   368  			iRoot, jRoot := iBlk.rootInstr, jBlk.rootInstr
   369  			if iRoot == nil || jRoot == nil { // For testing.
   370  				return true
   371  			}
   372  			return iBlk.rootInstr.id < jBlk.rootInstr.id
   373  		})
   374  	}
   375  }