github.com/wasilibs/wazerox@v0.0.0-20240124024944-4923be63ab5f/internal/engine/wazevo/backend/regalloc/regalloc.go (about)

     1  // Package regalloc performs register allocation. The algorithm can work on any ISA by implementing the interfaces in
     2  // api.go.
     3  package regalloc
     4  
     5  // References:
     6  // * https://web.stanford.edu/class/archive/cs/cs143/cs143.1128/lectures/17/Slides17.pdf
     7  // * https://en.wikipedia.org/wiki/Chaitin%27s_algorithm
     8  // * https://llvm.org/ProjectsWithLLVM/2004-Fall-CS426-LS.pdf
     9  // * https://pfalcon.github.io/ssabook/latest/book-full.pdf: Chapter 9. for liveness analysis.
    10  
    11  import (
    12  	"fmt"
    13  	"math"
    14  	"strings"
    15  
    16  	"github.com/wasilibs/wazerox/internal/engine/wazevo/wazevoapi"
    17  )
    18  
    19  // NewAllocator returns a new Allocator.
    20  func NewAllocator(allocatableRegs *RegisterInfo) Allocator {
    21  	a := Allocator{
    22  		regInfo:               allocatableRegs,
    23  		blockLivenessDataPool: wazevoapi.NewPool[blockLivenessData](resetBlockLivenessData),
    24  		phiDefInstListPool:    wazevoapi.NewPool[phiDefInstList](resetPhiDefInstList),
    25  	}
    26  	a.state.reset()
    27  	for _, regs := range allocatableRegs.AllocatableRegisters {
    28  		for _, r := range regs {
    29  			a.allocatableSet = a.allocatableSet.add(r)
    30  		}
    31  	}
    32  	return a
    33  }
    34  
type (
	// RegisterInfo holds the statically-known ISA-specific register information.
	RegisterInfo struct {
		// AllocatableRegisters is a 2D array of allocatable RealReg, indexed by regTypeNum and regNum.
		// The order matters: the first element is the most preferred one when allocating.
		AllocatableRegisters [NumRegType][]RealReg
		// CalleeSavedRegisters and CallerSavedRegisters are fixed-size boolean tables of RealRegsNumMax entries.
		// NOTE(review): presumably indexed by RealReg and read through isCalleeSaved/isCallerSaved, which are
		// defined outside this part of the file — confirm.
		CalleeSavedRegisters [RealRegsNumMax]bool
		CallerSavedRegisters [RealRegsNumMax]bool
		// RealRegToVReg is indexed by RealReg and yields the VReg that is reported to
		// Function.ClobberedRegisters for an allocated callee-saved register.
		RealRegToVReg        []VReg
		// RealRegName returns the name of the given RealReg for debugging.
		RealRegName func(r RealReg) string
		// RealRegType returns the RegType of the given RealReg.
		RealRegType func(r RealReg) RegType
	}

	// Allocator is a register allocator.
	Allocator struct {
		// regInfo is static per ABI/ISA, and is initialized by the machine during Machine.PrepareRegisterAllocator.
		regInfo *RegisterInfo
		// allocatableSet is a set of allocatable RealReg derived from regInfo. Static per ABI/ISA.
		allocatableSet           regSet
		// allocatedCalleeSavedRegs is rebuilt by determineCalleeSavedRealRegs after allocation and
		// handed to Function.ClobberedRegisters.
		allocatedCalleeSavedRegs []VReg
		// blockLivenessDataPool is the pool backing the entries of blockLivenessData.
		blockLivenessDataPool    wazevoapi.Pool[blockLivenessData]
		// blockLivenessData holds the per-block liveness information, indexed by block ID.
		blockLivenessData        [] /* blockID to */ *blockLivenessData
		// vs is a scratch slice passed to Instr.Defs, Instr.Uses and Block.BlockParams.
		vs                       []VReg
		// maxBlockID is the largest block ID seen during livenessAnalysis.
		maxBlockID               int
		// phiDefInstListPool is the pool for the phiDefInstList nodes created while allocating phi definitions.
		phiDefInstListPool       wazevoapi.Pool[phiDefInstList]

		// The following fields are scratch space reused in various places.
		// blks is the work stack of loopTreeDFS, reals holds the per-instruction kill set in allocBlock,
		// and currentOccupants tracks a predecessor's register contents in fixMergeState.
		blks             []Block
		reals            []RealReg
		currentOccupants regInUseSet

		// The following two fields are updated while iterating the blocks in the reverse postorder.
		state       state
		blockStates [] /* blockID to */ blockState
	}

	// blockLivenessData is the per-block information used during the register allocation.
	blockLivenessData struct {
		// seen is set once livenessAnalysis has finished processing the block. A successor that is
		// not yet seen is treated as the target of a back edge.
		seen     bool
		// liveOuts is the set of VRegs that are live at the end of the block.
		liveOuts map[VReg]struct{}
		// liveIns is the set of VRegs that are live at the beginning of the block.
		liveIns  map[VReg]struct{}
	}

	// programCounter represents an opaque index into the program. In this file it is used as the
	// per-block instruction index that records the last use of a VReg.
	programCounter int32

	// state is the allocation state that is carried along while the blocks are being processed.
	state struct {
		// argRealRegs holds the RealReg-backed VRegs that livenessAnalysis found to be copied into phi values
		// (i.e. at the beginning of the function). They are marked as in use at the start of the entry block.
		argRealRegs          []VReg
		// regsInUse maps each RealReg that is currently occupied to the VReg occupying it.
		regsInUse            regInUseSet
		// vrStates holds the per-VReg state, indexed by VReg ID.
		vrStates             []vrState
		// maxVRegIDEncountered is the largest VReg ID passed to getVRegState since the last reset,
		// or -1 if there has been none.
		maxVRegIDEncountered int

		// allocatedRegSet is a set of RealReg that are allocated during the allocation phase. This is reset per function.
		allocatedRegSet regSet
	}

	// blockState is the per-block allocation state, stored in Allocator.blockStates.
	blockState struct {
		// visited is set to true once allocBlock has finished with the block.
		visited            bool
		// startFromPredIndex is the index of the predecessor whose end state this block started from,
		// or -1 if none has been chosen.
		startFromPredIndex int
		// startRegs is a list of RealReg that are used at the beginning of the block. This is used to fix the merge edges.
		startRegs regInUseSet
		// endRegs is a list of RealReg that are used at the end of the block. This is used to fix the merge edges.
		endRegs regInUseSet
		// init is true once getBlockState has reset this entry for use.
		init    bool
	}

	// vrState is the allocation state of a single VReg.
	vrState struct {
		// v is the VReg this state belongs to.
		v VReg
		// r is the RealReg currently assigned to v, or RealRegInvalid if v is not on a register.
		r RealReg
		// defInstr is the instruction that defines this value. If this is a phi value and not in the entry block, this is nil.
		defInstr Instr
		// defBlk is the block that defines this value. If this is a phi value, this is the block whose arguments contain this value.
		defBlk Block
		// spilled is true if this value is spilled, i.e. the value is reloaded from the stack somewhere in the program.
		spilled bool
		// lca = lowest common ancestor. This is the block that is the lowest common ancestor of all the blocks that
		// reload this value. This is used to determine the spill location. Only valid if spilled=true.
		lca Block
		// lastUse is the program counter of the last use of this value. This changes while iterating the block, and
		// should not be used across the blocks as it becomes invalid.
		lastUse programCounter
		// isPhi is true if this is a phi value.
		isPhi bool
		// phiDefInstList is a list of instructions that define this phi value.
		// This is used to determine the spill location, and only valid if isPhi=true.
		*phiDefInstList
	}

	// phiDefInstList is a linked list of instructions that define a phi value.
	phiDefInstList struct {
		// instr is the defining instruction held by this node.
		instr Instr
		// next is the following node, or nil at the end of the list.
		next  *phiDefInstList
	}
)
   130  
   131  func resetPhiDefInstList(l *phiDefInstList) {
   132  	l.instr = nil
   133  	l.next = nil
   134  }
   135  
   136  func (s *state) dump(info *RegisterInfo) { //nolint:unused
   137  	fmt.Println("\t\tstate:")
   138  	fmt.Println("\t\t\targRealRegs:", s.argRealRegs)
   139  	fmt.Println("\t\t\tregsInUse", s.regsInUse.format(info))
   140  	fmt.Println("\t\t\tallocatedRegSet:", s.allocatedRegSet.format(info))
   141  	fmt.Println("\t\t\tused:", s.regsInUse.format(info))
   142  	fmt.Println("\t\t\tmaxVRegIDEncountered:", s.maxVRegIDEncountered)
   143  	var strs []string
   144  	for i, v := range s.vrStates {
   145  		if v.r != RealRegInvalid {
   146  			strs = append(strs, fmt.Sprintf("(v%d: %s)", i, info.RealRegName(v.r)))
   147  		}
   148  	}
   149  	fmt.Println("\t\t\tvrStates:", strings.Join(strs, ", "))
   150  }
   151  
   152  func (s *state) reset() {
   153  	s.argRealRegs = s.argRealRegs[:0]
   154  	for i, l := 0, len(s.vrStates); i <= s.maxVRegIDEncountered && i < l; i++ {
   155  		s.vrStates[i].reset()
   156  	}
   157  	s.maxVRegIDEncountered = -1
   158  	s.allocatedRegSet = regSet(0)
   159  	s.regsInUse.reset()
   160  }
   161  
   162  func (a *Allocator) getBlockState(bID int) *blockState {
   163  	if bID >= len(a.blockStates) {
   164  		a.blockStates = append(a.blockStates, make([]blockState, (bID+1)-len(a.blockStates))...)
   165  		a.blockStates = a.blockStates[:cap(a.blockStates)]
   166  	}
   167  	ret := &a.blockStates[bID]
   168  	if !ret.init {
   169  		ret.reset()
   170  		ret.init = true
   171  	}
   172  	return ret
   173  }
   174  
   175  func (s *state) setVRegState(v VReg, r RealReg) {
   176  	id := int(v.ID())
   177  	if id >= len(s.vrStates) {
   178  		s.vrStates = append(s.vrStates, make([]vrState, id+1-len(s.vrStates))...)
   179  		s.vrStates = s.vrStates[:cap(s.vrStates)]
   180  	}
   181  
   182  	st := &s.vrStates[id]
   183  	st.r = r
   184  	st.v = v
   185  }
   186  
   187  func (vs *vrState) reset() {
   188  	vs.r = RealRegInvalid
   189  	vs.defInstr = nil
   190  	vs.defBlk = nil
   191  	vs.spilled = false
   192  	vs.lca = nil
   193  	vs.isPhi = false
   194  	vs.phiDefInstList = nil
   195  }
   196  
   197  func (s *state) getVRegState(v VReg) *vrState {
   198  	id := int(v.ID())
   199  	if id >= len(s.vrStates) {
   200  		s.setVRegState(v, RealRegInvalid)
   201  	}
   202  	if s.maxVRegIDEncountered < id {
   203  		s.maxVRegIDEncountered = id
   204  	}
   205  	return &s.vrStates[id]
   206  }
   207  
   208  func (s *state) useRealReg(r RealReg, v VReg) {
   209  	if s.regsInUse.has(r) {
   210  		panic("BUG: useRealReg: the given real register is already used")
   211  	}
   212  	s.regsInUse.add(r, v)
   213  	s.setVRegState(v, r)
   214  	s.allocatedRegSet = s.allocatedRegSet.add(r)
   215  }
   216  
   217  func (s *state) releaseRealReg(r RealReg) {
   218  	current := s.regsInUse.get(r)
   219  	if current.Valid() {
   220  		s.regsInUse.remove(r)
   221  		s.setVRegState(current, RealRegInvalid)
   222  	}
   223  }
   224  
   225  // recordReload records that the given VReg is reloaded in the given block.
   226  // This is used to determine the spill location by tracking the lowest common ancestor of all the blocks that reloads the value.
   227  func (vs *vrState) recordReload(f Function, blk Block) {
   228  	vs.spilled = true
   229  	if vs.lca == nil {
   230  		if wazevoapi.RegAllocLoggingEnabled {
   231  			fmt.Printf("\t\tv%d is reloaded in blk%d,\n", vs.v.ID(), blk.ID())
   232  		}
   233  		vs.lca = blk
   234  	} else {
   235  		if wazevoapi.RegAllocLoggingEnabled {
   236  			fmt.Printf("\t\tv%d is reloaded in blk%d, lca=%d\n", vs.v.ID(), blk.ID(), vs.lca.ID())
   237  		}
   238  		vs.lca = f.LowestCommonAncestor(vs.lca, blk)
   239  		if wazevoapi.RegAllocLoggingEnabled {
   240  			fmt.Printf("updated lca=%d\n", vs.lca.ID())
   241  		}
   242  	}
   243  }
   244  
   245  func (s *state) findOrSpillAllocatable(a *Allocator, allocatable []RealReg, forbiddenMask regSet) (r RealReg) {
   246  	r = RealRegInvalid
   247  	var lastUseAt programCounter = math.MinInt32
   248  	var spillVReg VReg
   249  	for _, candidateReal := range allocatable {
   250  		if forbiddenMask.has(candidateReal) {
   251  			continue
   252  		}
   253  
   254  		using := s.regsInUse.get(candidateReal)
   255  		if using == VRegInvalid {
   256  			// This is not used at this point.
   257  			return candidateReal
   258  		}
   259  
   260  		if last := s.getVRegState(using).lastUse; last > lastUseAt {
   261  			lastUseAt = last
   262  			r = candidateReal
   263  			spillVReg = using
   264  		}
   265  	}
   266  
   267  	if r == RealRegInvalid {
   268  		panic("not found any allocatable register")
   269  	}
   270  
   271  	if wazevoapi.RegAllocLoggingEnabled {
   272  		fmt.Printf("\tspilling v%d when: %s\n", spillVReg.ID(), forbiddenMask.format(a.regInfo))
   273  	}
   274  	s.releaseRealReg(r)
   275  	return r
   276  }
   277  
   278  func (s *state) findAllocatable(allocatable []RealReg, forbiddenMask regSet) RealReg {
   279  	for _, r := range allocatable {
   280  		if !s.regsInUse.has(r) && !forbiddenMask.has(r) {
   281  			return r
   282  		}
   283  	}
   284  	return RealRegInvalid
   285  }
   286  
   287  func (s *state) resetAt(bs *blockState, liveIns map[VReg]struct{}) {
   288  	s.regsInUse.range_(func(_ RealReg, vr VReg) {
   289  		s.setVRegState(vr, RealRegInvalid)
   290  	})
   291  	s.regsInUse.reset()
   292  	bs.endRegs.range_(func(r RealReg, v VReg) {
   293  		if _, ok := liveIns[v]; ok {
   294  			s.regsInUse.add(r, v)
   295  			s.setVRegState(v, r)
   296  		}
   297  	})
   298  }
   299  
   300  func (b *blockState) reset() {
   301  	b.visited = false
   302  	b.endRegs.reset()
   303  	b.startRegs.reset()
   304  	b.startFromPredIndex = -1
   305  	b.init = false
   306  }
   307  
   308  func (b *blockState) dump(a *RegisterInfo) {
   309  	fmt.Println("\t\tblockState:")
   310  	fmt.Println("\t\t\tstartRegs:", b.startRegs.format(a))
   311  	fmt.Println("\t\t\tendRegs:", b.endRegs.format(a))
   312  	fmt.Println("\t\t\tstartFromPredIndex:", b.startFromPredIndex)
   313  	fmt.Println("\t\t\tvisited:", b.visited)
   314  	fmt.Println("\t\t\tinit:", b.init)
   315  }
   316  
   317  // DoAllocation performs register allocation on the given Function.
   318  func (a *Allocator) DoAllocation(f Function) {
   319  	a.livenessAnalysis(f)
   320  	a.alloc(f)
   321  	a.determineCalleeSavedRealRegs(f)
   322  	f.Done()
   323  }
   324  
   325  func (a *Allocator) determineCalleeSavedRealRegs(f Function) {
   326  	a.allocatedCalleeSavedRegs = a.allocatedCalleeSavedRegs[:0]
   327  	a.state.allocatedRegSet.range_(func(allocatedRealReg RealReg) {
   328  		if a.regInfo.isCalleeSaved(allocatedRealReg) {
   329  			a.allocatedCalleeSavedRegs = append(a.allocatedCalleeSavedRegs, a.regInfo.RealRegToVReg[allocatedRealReg])
   330  		}
   331  	})
   332  	f.ClobberedRegisters(a.allocatedCalleeSavedRegs)
   333  }
   334  
   335  // phiBlk returns the block that defines the given phi value, nil otherwise.
   336  func (s *state) phiBlk(v VReg) Block {
   337  	vs := s.getVRegState(v)
   338  	if vs.isPhi {
   339  		return vs.defBlk
   340  	}
   341  	return nil
   342  }
   343  
   344  // liveAnalysis constructs Allocator.blockLivenessData.
   345  // The algorithm here is described in https://pfalcon.github.io/ssabook/latest/book-full.pdf Chapter 9.2.
   346  func (a *Allocator) livenessAnalysis(f Function) {
   347  	// First, we need to allocate blockLivenessData.
   348  	s := &a.state
   349  	for blk := f.PostOrderBlockIteratorBegin(); blk != nil; blk = f.PostOrderBlockIteratorNext() { // Order doesn't matter.
   350  		a.allocateBlockLivenessData(blk.ID())
   351  
   352  		// We should gather phi value data.
   353  		for _, p := range blk.BlockParams(&a.vs) {
   354  			vs := s.getVRegState(p)
   355  			vs.isPhi = true
   356  			vs.defBlk = blk
   357  		}
   358  		if blk.ID() > a.maxBlockID {
   359  			a.maxBlockID = blk.ID()
   360  		}
   361  	}
   362  
   363  	// Run the Algorithm 9.2 in the bool.
   364  	for blk := f.PostOrderBlockIteratorBegin(); blk != nil; blk = f.PostOrderBlockIteratorNext() {
   365  		blkID := blk.ID()
   366  		info := a.livenessDataAt(blkID)
   367  
   368  		ns := blk.Succs()
   369  		for i := 0; i < ns; i++ {
   370  			succ := blk.Succ(i)
   371  			if succ == nil {
   372  				continue
   373  			}
   374  
   375  			succID := succ.ID()
   376  			succInfo := a.livenessDataAt(succID)
   377  			if !succInfo.seen { // This means the back edge.
   378  				continue
   379  			}
   380  
   381  			for v := range succInfo.liveIns {
   382  				if s.phiBlk(v) != succ {
   383  					info.liveOuts[v] = struct{}{}
   384  					info.liveIns[v] = struct{}{}
   385  				}
   386  			}
   387  		}
   388  
   389  		for instr := blk.InstrRevIteratorBegin(); instr != nil; instr = blk.InstrRevIteratorNext() {
   390  
   391  			var use, def VReg
   392  			for _, def = range instr.Defs(&a.vs) {
   393  				if !def.IsRealReg() {
   394  					delete(info.liveIns, def)
   395  				}
   396  			}
   397  			for _, use = range instr.Uses(&a.vs) {
   398  				if !use.IsRealReg() {
   399  					info.liveIns[use] = struct{}{}
   400  				}
   401  			}
   402  
   403  			// If the destination is a phi value, and ...
   404  			if def.Valid() && s.phiBlk(def) != nil {
   405  				if use.Valid() && use.IsRealReg() {
   406  					// If the source is a real register, this is the beginning of the function.
   407  					a.state.argRealRegs = append(a.state.argRealRegs, use)
   408  				} else {
   409  					// Otherwise, this is the definition of the phi value for the successor block.
   410  					// So we need to make it outlive the block.
   411  					info.liveOuts[def] = struct{}{}
   412  				}
   413  			}
   414  		}
   415  		info.seen = true
   416  	}
   417  
   418  	nrs := f.LoopNestingForestRoots()
   419  	for i := 0; i < nrs; i++ {
   420  		root := f.LoopNestingForestRoot(i)
   421  		a.loopTreeDFS(root)
   422  	}
   423  }
   424  
   425  // loopTreeDFS implements the Algorithm 9.3 in the book in an iterative way.
   426  func (a *Allocator) loopTreeDFS(entry Block) {
   427  	a.blks = a.blks[:0]
   428  	a.blks = append(a.blks, entry)
   429  
   430  	s := &a.state
   431  	for len(a.blks) > 0 {
   432  		tail := len(a.blks) - 1
   433  		loop := a.blks[tail]
   434  		a.blks = a.blks[:tail]
   435  		a.vs = a.vs[:0]
   436  
   437  		info := a.livenessDataAt(loop.ID())
   438  		for v := range info.liveIns {
   439  			if s.phiBlk(v) != loop {
   440  				a.vs = append(a.vs, v)
   441  				info.liveOuts[v] = struct{}{}
   442  			}
   443  		}
   444  
   445  		cn := loop.LoopNestingForestChildren()
   446  		for i := 0; i < cn; i++ {
   447  			child := loop.LoopNestingForestChild(i)
   448  			childID := child.ID()
   449  			childInfo := a.livenessDataAt(childID)
   450  			for _, v := range a.vs {
   451  				childInfo.liveIns[v] = struct{}{}
   452  				childInfo.liveOuts[v] = struct{}{}
   453  			}
   454  			if child.LoopHeader() {
   455  				a.blks = append(a.blks, child)
   456  			}
   457  		}
   458  	}
   459  }
   460  
   461  // alloc allocates registers for the given function by iterating the blocks in the reverse postorder.
   462  // The algorithm here is derived from the Go compiler's allocator https://github.com/golang/go/blob/release-branch.go1.21/src/cmd/compile/internal/ssa/regalloc.go
   463  // In short, this is a simply linear scan register allocation where each block inherits the register allocation state from
   464  // one of its predecessors. Each block inherits the selected state and starts allocation from there.
   465  // If there's a discrepancy in the end states between predecessors, the adjustments are made to ensure consistency after allocation is done (which we call "fixing merge state").
   466  // The spill instructions (store into the dedicated slots) are inserted after all the allocations and fixing merge states. That is because
   467  // at the point, we all know where the reloads happen, and therefore we can know the best place to spill the values. More precisely,
   468  // the spill happens in the block that is the lowest common ancestor of all the blocks that reloads the value.
   469  //
   470  // All of these logics are almost the same as Go's compiler which has a dedicated description in the source file ^^.
   471  func (a *Allocator) alloc(f Function) {
   472  	// First we allocate each block in the reverse postorder (at least one predecessor should be allocated for each block).
   473  	for blk := f.ReversePostOrderBlockIteratorBegin(); blk != nil; blk = f.ReversePostOrderBlockIteratorNext() {
   474  		if wazevoapi.RegAllocLoggingEnabled {
   475  			fmt.Printf("========== allocating blk%d ========\n", blk.ID())
   476  		}
   477  		a.allocBlock(f, blk)
   478  	}
   479  	// After the allocation, we all know the start and end state of each block. So we can fix the merge states.
   480  	for blk := f.ReversePostOrderBlockIteratorBegin(); blk != nil; blk = f.ReversePostOrderBlockIteratorNext() {
   481  		a.fixMergeState(f, blk)
   482  	}
   483  	// Finally, we insert the spill instructions as we know all the places where the reloads happen.
   484  	a.scheduleSpills(f)
   485  }
   486  
   487  func (a *Allocator) allocBlock(f Function, blk Block) {
   488  	bID := blk.ID()
   489  	liveness := a.livenessDataAt(bID)
   490  	s := &a.state
   491  	currentBlkState := a.getBlockState(bID)
   492  
   493  	preds := blk.Preds()
   494  	var predState *blockState
   495  	switch preds {
   496  	case 0: // This is the entry block.
   497  	case 1:
   498  		predID := blk.Pred(0).ID()
   499  		predState = a.getBlockState(predID)
   500  		currentBlkState.startFromPredIndex = 0
   501  	default:
   502  		// TODO: there should be some better heuristic to choose the predecessor.
   503  		for i := 0; i < preds; i++ {
   504  			predID := blk.Pred(i).ID()
   505  			if _predState := a.getBlockState(predID); _predState.visited {
   506  				predState = _predState
   507  				currentBlkState.startFromPredIndex = i
   508  				break
   509  			}
   510  		}
   511  	}
   512  	if predState == nil {
   513  		if !blk.Entry() {
   514  			panic(fmt.Sprintf("BUG: at lease one predecessor should be visited for blk%d", blk.ID()))
   515  		}
   516  		for _, u := range s.argRealRegs {
   517  			s.useRealReg(u.RealReg(), u)
   518  		}
   519  	} else if predState != nil {
   520  		if wazevoapi.RegAllocLoggingEnabled {
   521  			fmt.Printf("allocating blk%d starting from blk%d (on index=%d) \n",
   522  				bID, blk.Pred(currentBlkState.startFromPredIndex).ID(), currentBlkState.startFromPredIndex)
   523  		}
   524  		s.resetAt(predState, liveness.liveIns)
   525  	}
   526  
   527  	s.regsInUse.range_(func(allocated RealReg, v VReg) {
   528  		currentBlkState.startRegs.add(allocated, v)
   529  	})
   530  
   531  	// Update the last use of each VReg.
   532  	var pc programCounter
   533  	for instr := blk.InstrIteratorBegin(); instr != nil; instr = blk.InstrIteratorNext() {
   534  		for _, use := range instr.Uses(&a.vs) {
   535  			if !use.IsRealReg() {
   536  				s.getVRegState(use).lastUse = pc
   537  			}
   538  		}
   539  		pc++
   540  	}
   541  	// Reset the last use of the liveOuts.
   542  	for outlive := range liveness.liveOuts {
   543  		s.getVRegState(outlive).lastUse = math.MaxInt32
   544  	}
   545  
   546  	pc = 0
   547  	for instr := blk.InstrIteratorBegin(); instr != nil; instr = blk.InstrIteratorNext() {
   548  		if wazevoapi.RegAllocLoggingEnabled {
   549  			fmt.Println(instr)
   550  		}
   551  
   552  		var currentUsedSet regSet
   553  		killSet := a.reals[:0]
   554  
   555  		// Gather the set of registers that will be used in the current instruction.
   556  		for _, use := range instr.Uses(&a.vs) {
   557  			if use.IsRealReg() {
   558  				r := use.RealReg()
   559  				currentUsedSet = currentUsedSet.add(r)
   560  				if a.allocatableSet.has(r) {
   561  					killSet = append(killSet, r)
   562  				}
   563  			} else {
   564  				vs := s.getVRegState(use)
   565  				if r := vs.r; r != RealRegInvalid {
   566  					currentUsedSet = currentUsedSet.add(r)
   567  				}
   568  			}
   569  		}
   570  
   571  		for i, use := range instr.Uses(&a.vs) {
   572  			if !use.IsRealReg() {
   573  				vs := s.getVRegState(use)
   574  				killed := liveness.isKilledAt(vs, pc)
   575  				r := vs.r
   576  
   577  				if r == RealRegInvalid {
   578  					r = s.findOrSpillAllocatable(a, a.regInfo.AllocatableRegisters[use.RegType()], currentUsedSet)
   579  					vs.recordReload(f, blk)
   580  					f.ReloadRegisterBefore(use.SetRealReg(r), instr)
   581  					s.useRealReg(r, use)
   582  				}
   583  				if wazevoapi.RegAllocLoggingEnabled {
   584  					fmt.Printf("\ttrying to use v%v on %s\n", use.ID(), a.regInfo.RealRegName(r))
   585  				}
   586  				instr.AssignUse(i, use.SetRealReg(r))
   587  				currentUsedSet = currentUsedSet.add(r)
   588  				if killed {
   589  					if wazevoapi.RegAllocLoggingEnabled {
   590  						fmt.Printf("\tkill v%d with %s\n", use.ID(), a.regInfo.RealRegName(r))
   591  					}
   592  					killSet = append(killSet, r)
   593  				}
   594  			}
   595  		}
   596  
   597  		isIndirect := instr.IsIndirectCall()
   598  		call := instr.IsCall() || isIndirect
   599  		if call {
   600  			addr := RealRegInvalid
   601  			if instr.IsIndirectCall() {
   602  				addr = a.vs[0].RealReg()
   603  			}
   604  			a.releaseCallerSavedRegs(addr)
   605  		}
   606  
   607  		for _, r := range killSet {
   608  			s.releaseRealReg(r)
   609  		}
   610  		a.reals = killSet
   611  
   612  		defs := instr.Defs(&a.vs)
   613  		switch {
   614  		case len(defs) > 1:
   615  			if !call {
   616  				panic("only call can have multiple defs")
   617  			}
   618  			// Call's defining register are all caller-saved registers.
   619  			// Therefore, we can assume that all of them are allocatable.
   620  			for _, def := range defs {
   621  				s.useRealReg(def.RealReg(), def)
   622  			}
   623  		case len(defs) == 1:
   624  			def := defs[0]
   625  			if def.IsRealReg() {
   626  				r := def.RealReg()
   627  				if a.allocatableSet.has(r) {
   628  					if s.regsInUse.has(r) {
   629  						s.releaseRealReg(r)
   630  					}
   631  					s.useRealReg(r, def)
   632  				}
   633  			} else {
   634  				vState := s.getVRegState(def)
   635  				r := vState.r
   636  				// Allocate a new real register if `def` is not currently assigned one.
   637  				// It can happen when multiple instructions define the same VReg (e.g. const loads).
   638  				if r == RealRegInvalid {
   639  					if instr.IsCopy() {
   640  						copySrc := instr.Uses(&a.vs)[0].RealReg()
   641  						if a.allocatableSet.has(copySrc) && !s.regsInUse.has(copySrc) {
   642  							r = copySrc
   643  						}
   644  					}
   645  					if r == RealRegInvalid {
   646  						typ := def.RegType()
   647  						r = s.findOrSpillAllocatable(a, a.regInfo.AllocatableRegisters[typ], regSet(0))
   648  					}
   649  					s.useRealReg(r, def)
   650  				}
   651  				instr.AssignDef(def.SetRealReg(r))
   652  				if wazevoapi.RegAllocLoggingEnabled {
   653  					fmt.Printf("\tdefining v%d with %s\n", def.ID(), a.regInfo.RealRegName(r))
   654  				}
   655  				if vState.isPhi {
   656  					n := a.phiDefInstListPool.Allocate()
   657  					n.instr = instr
   658  					n.next = vState.phiDefInstList
   659  					vState.phiDefInstList = n
   660  				} else {
   661  					vState.defInstr = instr
   662  					vState.defBlk = blk
   663  				}
   664  			}
   665  		}
   666  		if wazevoapi.RegAllocLoggingEnabled {
   667  			fmt.Println(instr)
   668  		}
   669  		pc++
   670  	}
   671  
   672  	s.regsInUse.range_(func(allocated RealReg, v VReg) {
   673  		currentBlkState.endRegs.add(allocated, v)
   674  	})
   675  
   676  	currentBlkState.visited = true
   677  	if wazevoapi.RegAllocLoggingEnabled {
   678  		currentBlkState.dump(a.regInfo)
   679  	}
   680  }
   681  
   682  func (a *Allocator) releaseCallerSavedRegs(addrReg RealReg) {
   683  	s := &a.state
   684  
   685  	for i := 0; i < 64; i++ {
   686  		allocated := RealReg(i)
   687  		if allocated == addrReg { // If this is the call indirect, we should not touch the addr register.
   688  			continue
   689  		}
   690  		if v := s.regsInUse.get(allocated); v.Valid() {
   691  			if v.IsRealReg() {
   692  				continue // This is the argument register as it's already used by VReg backed by the corresponding RealReg.
   693  			}
   694  			if !a.regInfo.isCallerSaved(allocated) {
   695  				// If this is not a caller-saved register, it is safe to keep it across the call.
   696  				continue
   697  			}
   698  			s.releaseRealReg(allocated)
   699  		}
   700  	}
   701  }
   702  
   703  func (a *Allocator) fixMergeState(f Function, blk Block) {
   704  	preds := blk.Preds()
   705  	if preds <= 1 {
   706  		return
   707  	}
   708  
   709  	s := &a.state
   710  
   711  	// Restores the state at the beginning of the block.
   712  	bID := blk.ID()
   713  	blkSt := a.getBlockState(bID)
   714  	desiredOccupants := &blkSt.startRegs
   715  	aliveOnRegVRegs := make(map[VReg]RealReg)
   716  	for i := 0; i < 64; i++ {
   717  		r := RealReg(i)
   718  		if v := blkSt.startRegs.get(r); v.Valid() {
   719  			aliveOnRegVRegs[v] = r
   720  		}
   721  	}
   722  
   723  	if wazevoapi.RegAllocLoggingEnabled {
   724  		fmt.Println("fixMergeState", blk.ID(), ":", desiredOccupants.format(a.regInfo))
   725  	}
   726  
   727  	currentOccupants := &a.currentOccupants
   728  	for i := 0; i < preds; i++ {
   729  		currentOccupants.reset()
   730  		if i == blkSt.startFromPredIndex {
   731  			continue
   732  		}
   733  
   734  		currentOccupantsRev := make(map[VReg]RealReg)
   735  		pred := blk.Pred(i)
   736  		predSt := a.getBlockState(pred.ID())
   737  		for ii := 0; ii < 64; ii++ {
   738  			r := RealReg(ii)
   739  			if v := predSt.endRegs.get(r); v.Valid() {
   740  				if _, ok := aliveOnRegVRegs[v]; !ok {
   741  					continue
   742  				}
   743  				currentOccupants.add(r, v)
   744  				currentOccupantsRev[v] = r
   745  			}
   746  		}
   747  
   748  		s.resetAt(predSt, a.livenessDataAt(bID).liveIns)
   749  
   750  		// Finds the free registers if any.
   751  		intTmp, floatTmp := VRegInvalid, VRegInvalid
   752  		if intFree := s.findAllocatable(
   753  			a.regInfo.AllocatableRegisters[RegTypeInt], desiredOccupants.set,
   754  		); intFree != RealRegInvalid {
   755  			intTmp = FromRealReg(intFree, RegTypeInt)
   756  		}
   757  		if floatFree := s.findAllocatable(
   758  			a.regInfo.AllocatableRegisters[RegTypeFloat], desiredOccupants.set,
   759  		); floatFree != RealRegInvalid {
   760  			floatTmp = FromRealReg(floatFree, RegTypeFloat)
   761  		}
   762  
   763  		if wazevoapi.RegAllocLoggingEnabled {
   764  			fmt.Println("\t", pred.ID(), ":", currentOccupants.format(a.regInfo))
   765  		}
   766  
   767  		for ii := 0; ii < 64; ii++ {
   768  			r := RealReg(ii)
   769  			desiredVReg := desiredOccupants.get(r)
   770  			if !desiredVReg.Valid() {
   771  				continue
   772  			}
   773  
   774  			currentVReg := currentOccupants.get(r)
   775  			if desiredVReg.ID() == currentVReg.ID() {
   776  				continue
   777  			}
   778  
   779  			typ := desiredVReg.RegType()
   780  			var tmpRealReg VReg
   781  			if typ == RegTypeInt {
   782  				tmpRealReg = intTmp
   783  			} else {
   784  				tmpRealReg = floatTmp
   785  			}
   786  			a.reconcileEdge(f, r, pred, currentOccupants, currentOccupantsRev, currentVReg, desiredVReg, tmpRealReg, typ)
   787  		}
   788  	}
   789  }
   790  
   791  func (a *Allocator) reconcileEdge(f Function,
   792  	r RealReg,
   793  	pred Block,
   794  	currentOccupants *regInUseSet,
   795  	currentOccupantsRev map[VReg]RealReg,
   796  	currentVReg, desiredVReg VReg,
   797  	freeReg VReg,
   798  	typ RegType,
   799  ) {
   800  	s := &a.state
   801  	if currentVReg.Valid() {
   802  		// Both are on reg.
   803  		er, ok := currentOccupantsRev[desiredVReg]
   804  		if !ok {
   805  			if wazevoapi.RegAllocLoggingEnabled {
   806  				fmt.Printf("\t\tv%d is desired to be on %s, but currently on the stack\n",
   807  					desiredVReg.ID(), a.regInfo.RealRegName(r),
   808  				)
   809  			}
   810  			// This case is that the desired value is on the stack, but currentVReg is on the target register.
   811  			// We need to move the current value to the stack, and reload the desired value.
   812  			// TODO: we can do better here.
   813  			f.StoreRegisterBefore(currentVReg.SetRealReg(r), pred.LastInstr())
   814  			delete(currentOccupantsRev, currentVReg)
   815  
   816  			s.getVRegState(desiredVReg).recordReload(f, pred)
   817  			f.ReloadRegisterBefore(desiredVReg.SetRealReg(r), pred.LastInstr())
   818  			currentOccupants.add(r, desiredVReg)
   819  			currentOccupantsRev[desiredVReg] = r
   820  			return
   821  		}
   822  
   823  		if wazevoapi.RegAllocLoggingEnabled {
   824  			fmt.Printf("\t\tv%d is desired to be on %s, but currently on %s\n",
   825  				desiredVReg.ID(), a.regInfo.RealRegName(r), a.regInfo.RealRegName(er),
   826  			)
   827  		}
   828  		f.SwapAtEndOfBlock(
   829  			currentVReg.SetRealReg(r),
   830  			desiredVReg.SetRealReg(er),
   831  			freeReg,
   832  			pred,
   833  		)
   834  		s.allocatedRegSet = s.allocatedRegSet.add(freeReg.RealReg())
   835  		currentOccupantsRev[desiredVReg] = r
   836  		currentOccupantsRev[currentVReg] = er
   837  		currentOccupants.add(r, desiredVReg)
   838  		currentOccupants.add(er, currentVReg)
   839  		if wazevoapi.RegAllocLoggingEnabled {
   840  			fmt.Printf("\t\tv%d previously on %s moved to %s\n", currentVReg.ID(), a.regInfo.RealRegName(r), a.regInfo.RealRegName(er))
   841  		}
   842  	} else {
   843  		// Desired is on reg, but currently the target register is not used.
   844  		if wazevoapi.RegAllocLoggingEnabled {
   845  			fmt.Printf("\t\tv%d is desired to be on %s, current not used\n",
   846  				desiredVReg.ID(), a.regInfo.RealRegName(r),
   847  			)
   848  		}
   849  		if currentReg, ok := currentOccupantsRev[desiredVReg]; ok {
   850  			f.InsertMoveBefore(
   851  				FromRealReg(r, typ),
   852  				desiredVReg.SetRealReg(currentReg),
   853  				pred.LastInstr(),
   854  			)
   855  			currentOccupants.remove(currentReg)
   856  		} else {
   857  			s.getVRegState(desiredVReg).recordReload(f, pred)
   858  			f.ReloadRegisterBefore(desiredVReg.SetRealReg(r), pred.LastInstr())
   859  		}
   860  		currentOccupantsRev[desiredVReg] = r
   861  		currentOccupants.add(r, desiredVReg)
   862  	}
   863  
   864  	if wazevoapi.RegAllocLoggingEnabled {
   865  		fmt.Println("\t", pred.ID(), ":", currentOccupants.format(a.regInfo))
   866  	}
   867  }
   868  
   869  func (a *Allocator) scheduleSpills(f Function) {
   870  	vrStates := a.state.vrStates
   871  	for i := 0; i <= a.state.maxVRegIDEncountered; i++ {
   872  		vs := &vrStates[i]
   873  		if vs.spilled {
   874  			a.scheduleSpill(f, vs)
   875  		}
   876  	}
   877  }
   878  
   879  func (a *Allocator) scheduleSpill(f Function, vs *vrState) {
   880  	v := vs.v
   881  	// If the value is the phi value, we need to insert a spill after each phi definition.
   882  	if vs.isPhi {
   883  		for defInstr := vs.phiDefInstList; defInstr != nil; defInstr = defInstr.next {
   884  			def := defInstr.instr.Defs(&a.vs)[0]
   885  			f.StoreRegisterAfter(def, defInstr.instr)
   886  		}
   887  		return
   888  	}
   889  
   890  	pos := vs.lca
   891  	definingBlk := vs.defBlk
   892  	r := RealRegInvalid
   893  	if wazevoapi.RegAllocLoggingEnabled {
   894  		fmt.Printf("v%d is spilled in blk%d, lca=blk%d\n", v.ID(), definingBlk.ID(), pos.ID())
   895  	}
   896  	for pos != definingBlk {
   897  		st := a.blockStates[pos.ID()]
   898  		for ii := 0; ii < 64; ii++ {
   899  			rr := RealReg(ii)
   900  			if st.startRegs.get(rr) == v {
   901  				r = rr
   902  				// Already in the register, so we can place the spill at the beginning of the block.
   903  				break
   904  			}
   905  		}
   906  
   907  		if r != RealRegInvalid {
   908  			break
   909  		}
   910  
   911  		pos = f.Idom(pos)
   912  	}
   913  
   914  	if pos == definingBlk {
   915  		defInstr := vs.defInstr
   916  		defInstr.Defs(&a.vs)
   917  		if wazevoapi.RegAllocLoggingEnabled {
   918  			fmt.Printf("schedule spill v%d after %v\n", v.ID(), defInstr)
   919  		}
   920  		f.StoreRegisterAfter(a.vs[0], defInstr)
   921  	} else {
   922  		// Found an ancestor block that holds the value in the register at the beginning of the block.
   923  		// We need to insert a spill before the last use.
   924  		first := pos.FirstInstr()
   925  		if wazevoapi.RegAllocLoggingEnabled {
   926  			fmt.Printf("schedule spill v%d before %v\n", v.ID(), first)
   927  		}
   928  		f.StoreRegisterAfter(v.SetRealReg(r), first)
   929  	}
   930  }
   931  
   932  // Reset resets the allocator's internal state so that it can be reused.
   933  func (a *Allocator) Reset() {
   934  	a.state.reset()
   935  	for i, l := 0, len(a.blockStates); i <= a.maxBlockID && i < l; i++ {
   936  		a.blockLivenessData[i] = nil
   937  		s := &a.blockStates[i]
   938  		s.reset()
   939  	}
   940  	a.blockLivenessDataPool.Reset()
   941  	a.phiDefInstListPool.Reset()
   942  
   943  	a.vs = a.vs[:0]
   944  	a.maxBlockID = -1
   945  }
   946  
   947  func (a *Allocator) allocateBlockLivenessData(blockID int) *blockLivenessData {
   948  	if blockID >= len(a.blockLivenessData) {
   949  		a.blockLivenessData = append(a.blockLivenessData, make([]*blockLivenessData, (blockID+1)-len(a.blockLivenessData))...)
   950  	}
   951  	info := a.blockLivenessData[blockID]
   952  	if info == nil {
   953  		info = a.blockLivenessDataPool.Allocate()
   954  		a.blockLivenessData[blockID] = info
   955  	}
   956  	return info
   957  }
   958  
   959  func (a *Allocator) livenessDataAt(blockID int) (info *blockLivenessData) {
   960  	info = a.blockLivenessData[blockID]
   961  	return
   962  }
   963  
   964  func resetBlockLivenessData(i *blockLivenessData) {
   965  	i.seen = false
   966  	i.liveOuts = resetMap(i.liveOuts)
   967  	i.liveIns = resetMap(i.liveIns)
   968  }
   969  
   970  func resetMap[K comparable, V any](m map[K]V) map[K]V {
   971  	if m == nil {
   972  		m = make(map[K]V)
   973  	} else {
   974  		for v := range m {
   975  			delete(m, v)
   976  		}
   977  	}
   978  	return m
   979  }
   980  
   981  // Format is for debugging.
   982  func (i *blockLivenessData) Format(ri *RegisterInfo) string {
   983  	var buf strings.Builder
   984  	buf.WriteString("\t\tblockLivenessData:")
   985  	buf.WriteString("\n\t\t\tliveOuts: ")
   986  	for v := range i.liveOuts {
   987  		if v.IsRealReg() {
   988  			buf.WriteString(fmt.Sprintf("%s ", ri.RealRegName(v.RealReg())))
   989  		} else {
   990  			buf.WriteString(fmt.Sprintf("%v ", v))
   991  		}
   992  	}
   993  	buf.WriteString("\n\t\t\tliveIns: ")
   994  	for v := range i.liveIns {
   995  		if v.IsRealReg() {
   996  			buf.WriteString(fmt.Sprintf("%s ", ri.RealRegName(v.RealReg())))
   997  		} else {
   998  			buf.WriteString(fmt.Sprintf("%v ", v))
   999  		}
  1000  	}
  1001  	buf.WriteString(fmt.Sprintf("\n\t\t\tseen: %v", i.seen))
  1002  	return buf.String()
  1003  }
  1004  
  1005  func (i *blockLivenessData) isKilledAt(vs *vrState, pos programCounter) bool {
  1006  	v := vs.v
  1007  	if vs.lastUse == pos {
  1008  		if _, ok := i.liveOuts[v]; !ok {
  1009  			return true
  1010  		}
  1011  	}
  1012  	return false
  1013  }
  1014  
  1015  func (r *RegisterInfo) isCalleeSaved(reg RealReg) bool {
  1016  	return r.CalleeSavedRegisters[reg]
  1017  }
  1018  
  1019  func (r *RegisterInfo) isCallerSaved(reg RealReg) bool {
  1020  	return r.CallerSavedRegisters[reg]
  1021  }