gorgonia.org/gorgonia@v0.9.17/regalloc.go (about)

     1  package gorgonia
     2  
     3  import (
     4  	"fmt"
     5  
     6  	"github.com/xtgo/set"
     7  )
     8  
     9  // this file holds all the code that relates to register allocation
    10  // a lot of the code is shamelessly copied from my previous HIL work, the thirteenthfloor
    11  // TODO: cleanup
    12  
// interval is a live interval for a value, used for linear-scan style
// register allocation. start/end are instruction IDs; the interval is
// inclusive of start and exclusive of end (see liveAt).
type interval struct {
	start, end int // -1 means "unset" (see newInterval); maintained by setFrom/addRange/fix/merge

	result       register        // register the value is written to
	reads        []register      // registers read by the op producing this value
	ranges       []intervalRange // recorded [from, to] spans (deduplicated in addRange)
	usePositions []int           // instruction IDs where the value is used; sorted/deduplicated by fix and merge
}
    21  
    22  func newInterval() *interval {
    23  	retVal := &interval{
    24  		start: -1,
    25  		end:   -1,
    26  	}
    27  	return retVal
    28  }
    29  
// String renders the interval as "result | start - end | usePositions",
// mainly for the compile logs.
func (i *interval) String() string {
	return fmt.Sprintf("%s | %d - %d | %v", i.result, i.start, i.end, i.usePositions)
}
    33  
    34  func (i *interval) setFrom(from int) {
    35  	if i.start == -1 || (from < i.start && from >= 0) {
    36  		i.start = from
    37  	}
    38  }
    39  
    40  func (i *interval) fix() {
    41  	if len(i.usePositions) == 0 {
    42  		return
    43  	}
    44  	i.usePositions = set.Ints(i.usePositions)
    45  	i.end = i.usePositions[len(i.usePositions)-1]
    46  
    47  	for _, r := range i.ranges {
    48  		if r.to > i.end {
    49  			i.end = r.to
    50  		}
    51  	}
    52  }
    53  
    54  func (i *interval) addRange(from, to int) {
    55  	if to < from {
    56  		panic("to < from") // note: to == from is a valid interval range
    57  	}
    58  
    59  	r := intervalRange{from, to}
    60  
    61  	// because I'm lazy to create a intervalRangeSet type, we'll just iterate and check
    62  	for _, ra := range i.ranges {
    63  		if r == ra {
    64  			return
    65  		}
    66  	}
    67  
    68  	i.ranges = append(i.ranges, r)
    69  
    70  	// set the end property
    71  	if to > i.end {
    72  		i.end = to
    73  	}
    74  
    75  	i.setFrom(from)
    76  }
    77  
    78  // added so only unique usePositions are added
    79  func (i *interval) addUsePositions(up int) {
    80  	i.usePositions = append(i.usePositions, up)
    81  }
    82  
    83  func (i *interval) noUsePositions() bool {
    84  	if len(i.usePositions) == 0 || i.usePositions == nil {
    85  		return true
    86  	}
    87  	return false
    88  }
    89  
    90  // inclusive of start, but exclusive of end
    91  func (i *interval) liveAt(id int) bool {
    92  	// compileLogf("%v live at %d", i, id)
    93  	if i.start <= id && id < i.end {
    94  		return true
    95  	}
    96  	return false
    97  }
    98  
    99  func (i *interval) lastUse() int {
   100  	if len(i.usePositions) == 0 {
   101  		return -1
   102  	}
   103  
   104  	// if !sort.IntsAreSorted(i.usePositions) {
   105  	// 	sort.Ints(i.usePositions)
   106  	// }
   107  	return i.usePositions[len(i.usePositions)-1]
   108  }
   109  
   110  func (i *interval) merge(other *interval) {
   111  	if other.start < i.start && other.start >= 0 {
   112  		i.start = other.start
   113  	}
   114  
   115  	if other.end > i.end {
   116  		i.end = other.end
   117  	}
   118  
   119  	for _, r := range other.ranges {
   120  		i.addRange(r.from, r.to)
   121  	}
   122  
   123  	i.usePositions = append(i.usePositions, other.usePositions...)
   124  	i.usePositions = set.Ints(i.usePositions)
   125  
   126  }
   127  
// intervalRange is a single [from, to] span of instruction IDs within an
// interval (to == from is valid — see addRange).
type intervalRange struct {
	from, to int
}
   131  
// regalloc assigns registers to the intervals computed by dataflow analysis.
// It is a simple bump allocator: registers are handed out with ever-increasing
// numbers and never freed (see newReg); reuse only happens when an op is
// allowed to overwrite an input's register (see allocMutableOp).
type regalloc struct {
	cpucount      int // number of CPU registers handed out so far
	gpucount      int // number of non-CPU (device) registers handed out so far
	instructionID int // ID of the instruction currently being allocated (set per node in alloc)
	df            *dataflow
}
   138  
   139  func newRegalloc(df *dataflow) *regalloc {
   140  	return &regalloc{
   141  		df: df,
   142  	}
   143  }
   144  
   145  func (ra *regalloc) newReg(device Device) register {
   146  	var out register
   147  	switch device {
   148  	case CPU:
   149  		out = register{ra.cpucount, device}
   150  		ra.cpucount++
   151  	default:
   152  		out = register{ra.gpucount, device}
   153  		ra.gpucount++
   154  
   155  	}
   156  	return out
   157  }
   158  
// allocArg assigns a fresh CPU register to an argument node's interval —
// argument values are always placed in CPU registers at allocation time.
func (ra *regalloc) allocArg(nInterv *interval) {
	nInterv.result = ra.newReg(CPU)
}
   162  
// allocMutableOp allocates the result register for an op that returns a
// pointer (and therefore may overwrite one of its inputs). Where it is
// provably safe — the overwritten input is no longer live, or has exactly
// one let-statement consumer, and the devices line up — the input's register
// is reused; otherwise a fresh register is allocated on the appropriate
// device.
func (ra *regalloc) allocMutableOp(node *Node, nInterv *interval) {
	// create new write to if overwriteInput and the used register is stil live
	compileLogf("Allocating MutableOp NodeID: %x returns pointer", node.ID())
	compileLogf("Op: %v", node.op)
	enterLogScope()
	defer leaveLogScope()

	var writeTo register
	var reads []*interval

	// Prefer the device-transfer-adjusted children when present, so reads
	// reflect any transfer nodes the dataflow pass inserted.
	var children Nodes
	var ok bool
	if children, ok = ra.df.devTransChildren[node]; !ok {
		compileLogf("replacement children not found")
		children = node.children
	}
	// Gather the intervals of the (replacement) children: these hold the
	// registers this op reads from.
	for _, child := range children {
		cReplace := ra.df.replacements[child]
		repInterv := ra.df.intervals[cReplace]
		reads = append(reads, repInterv)
	}
	compileLogf("Read %v", reads)

	// Collect the let-statement parents of this node; a single letStmt
	// consumer is one of the conditions that permits overwriting below.
	var letStmts Nodes
	it := node.g.To(node.ID())
	for it.Next() {
		parent := it.Node()

		n := parent.(*Node)
		compileLogf("Parent: %v | %T", n, n.op)
		if n.isStmt {
			// compileLogf("isStmt")
			if _, ok := n.op.(letOp); ok {
				letStmts = append(letStmts, n)
			}
		}
	}

	// overwrites is the index of the input this op may clobber; < 0 means none.
	overwrites := node.op.OverwritesInput()
	// onDev reports whether the op executes on an external device (CUDA or CL).
	var onDev bool
	switch node.op.(type) {
	case CUDADoer:
		onDev = true
	case CLDoer:
		onDev = true
	default:
	}

	if overwrites >= 0 {
		overwriteReg := reads[overwrites].result
		overwriteDev := overwriteReg.device
		overwrittenIsLive := reads[overwrites].liveAt(ra.instructionID)
		compileLogf("Overwrites : %v ", overwrites)
		compileLogf("Overwritten (%v) is live at %d? %t", reads[overwrites], ra.instructionID, overwrittenIsLive)
		compileLogf("Let Statements: %d | %v", len(letStmts), reads[overwrites])

		// If the overwritten is not live, and the node does not call external processes (obiviating the need to prealloc)
		// then we can directly overwrite the register.
		if len(letStmts) == 1 || !overwrittenIsLive {

			switch {
			case onDev && overwriteDev != CPU:
				// if overwritten reg is on external device and op will execute on external device
				// then safe to overwrite
				writeTo = overwriteReg
			case !node.op.CallsExtern() && overwriteDev == CPU:
				// original case:
				// if the op doesn't call an extern, and is executed on CPU
				// safe to overwrite
				writeTo = overwriteReg
			case onDev:
				// new register otherwise
				// NOTE(review): Device(0) appears to mean "first external
				// device" here — TODO confirm against the Device type.
				writeTo = ra.newReg(Device(0))
			case !onDev:
				// new register otherwise
				writeTo = ra.newReg(CPU)
			}

		} else {
			// Overwriting is not safe (input still live with multiple
			// consumers): allocate a fresh register on the op's device.
			if onDev {
				writeTo = ra.newReg(Device(0))
			} else {
				writeTo = ra.newReg(CPU)
			}
		}
	} else {
		// The op overwrites nothing — always allocate a fresh register.
		compileLogf("New register")
		if onDev {
			writeTo = ra.newReg(Device(0))
		} else {
			writeTo = ra.newReg(CPU)
		}
	}

	// Record the read registers and the chosen result register on the interval.
	for _, r := range reads {
		nInterv.reads = append(nInterv.reads, r.result)
	}
	nInterv.result = writeTo
	compileLogf("%v: %v", node.op, nInterv)
}
   263  
// allocImmutableOp allocates a fresh result register for an op that does
// not return a pointer to any of its inputs. CUDA ops get a register on
// Device(0); all other ops get a CPU register.
func (ra *regalloc) allocImmutableOp(node *Node, nInterv *interval) {
	compileLogf("Allocating Immutable Op")
	enterLogScope()
	defer leaveLogScope()

	var writeTo register
	var reads []*interval

	// Prefer the device-transfer-adjusted children when present (mirrors
	// allocMutableOp).
	var children Nodes
	var ok bool
	if children, ok = ra.df.devTransChildren[node]; !ok {
		children = node.children
	}
	for _, child := range children {
		cReplace := ra.df.replacements[child]
		repInterv := ra.df.intervals[cReplace]
		reads = append(reads, repInterv)
	}

	compileLogf("NodeID: %x does not returns pointer", node.ID())
	// NOTE(review): unlike allocMutableOp, only CUDADoer is checked here, so
	// a CLDoer op falls through to a CPU register — confirm whether that
	// asymmetry is intentional.
	if _, ok := node.op.(CUDADoer); ok {
		writeTo = ra.newReg(Device(0))
	} else {
		writeTo = ra.newReg(CPU)
	}

	// Record the read registers and the chosen result register on the interval.
	for _, r := range reads {
		nInterv.reads = append(nInterv.reads, r.result)
	}
	nInterv.result = writeTo
}
   295  
   296  func (ra *regalloc) allocStatement(node *Node, nInterv *interval) {
   297  	var writeTo register
   298  	switch op := node.op.(type) {
   299  	case devTrans:
   300  		writeTo = ra.newReg(op.to)
   301  	}
   302  	nInterv.result = writeTo
   303  }
   304  
// alloc walks the topologically sorted nodes and assigns a register to each
// node's interval, dispatching on the kind of node (argument, statement,
// pointer-returning op, or plain op). The loop index is used as the
// instruction ID for liveness checks during allocation.
func (ra *regalloc) alloc(sorted Nodes) {
	compileLogf("Allocating registers")
	enterLogScope()
	defer leaveLogScope()

	for i, node := range sorted {
		ra.instructionID = i

		replacement := ra.df.replacements[node]
		nInterv := ra.df.intervals[replacement]

		compileLogf("replacement %v, interval %v", replacement, nInterv)

		// When the node has been replaced by another node, fold the
		// replacement's interval into the original node's interval so the
		// two stay consistent.
		if node != replacement {
			compileLogf("Merging")
			ra.df.intervals[node].merge(nInterv)
		}
		compileLogf("Working on %v(%x). InstructionID: %d", node, node.ID(), ra.instructionID)

		// Dispatch order matters: arg and statement checks come before the
		// op-based cases.
		switch {
		case node.isArg():
			ra.allocArg(nInterv)
		case node.isStmt:
			ra.allocStatement(node, nInterv)
		case node.op.ReturnsPtr():
			ra.allocMutableOp(node, nInterv)
		default:
			ra.allocImmutableOp(node, nInterv)
		}
		compileLogf("n: %x; result: %v; reads: %v", node.ID(), nInterv.result, nInterv.reads)
		// ra.instructionID++
	}
}
   337  }