gorgonia.org/gorgonia@v0.9.17/regalloc.go (about) 1 package gorgonia 2 3 import ( 4 "fmt" 5 6 "github.com/xtgo/set" 7 ) 8 9 // this file holds all the code that relates to register allocation 10 // a lot of the code is shamelessly copied from my previous HIL work, the thirteenthfloor 11 // TODO: cleanup 12 13 type interval struct { 14 start, end int 15 16 result register 17 reads []register 18 ranges []intervalRange 19 usePositions []int 20 } 21 22 func newInterval() *interval { 23 retVal := &interval{ 24 start: -1, 25 end: -1, 26 } 27 return retVal 28 } 29 30 func (i *interval) String() string { 31 return fmt.Sprintf("%s | %d - %d | %v", i.result, i.start, i.end, i.usePositions) 32 } 33 34 func (i *interval) setFrom(from int) { 35 if i.start == -1 || (from < i.start && from >= 0) { 36 i.start = from 37 } 38 } 39 40 func (i *interval) fix() { 41 if len(i.usePositions) == 0 { 42 return 43 } 44 i.usePositions = set.Ints(i.usePositions) 45 i.end = i.usePositions[len(i.usePositions)-1] 46 47 for _, r := range i.ranges { 48 if r.to > i.end { 49 i.end = r.to 50 } 51 } 52 } 53 54 func (i *interval) addRange(from, to int) { 55 if to < from { 56 panic("to < from") // note: to == from is a valid interval range 57 } 58 59 r := intervalRange{from, to} 60 61 // because I'm lazy to create a intervalRangeSet type, we'll just iterate and check 62 for _, ra := range i.ranges { 63 if r == ra { 64 return 65 } 66 } 67 68 i.ranges = append(i.ranges, r) 69 70 // set the end property 71 if to > i.end { 72 i.end = to 73 } 74 75 i.setFrom(from) 76 } 77 78 // added so only unique usePositions are added 79 func (i *interval) addUsePositions(up int) { 80 i.usePositions = append(i.usePositions, up) 81 } 82 83 func (i *interval) noUsePositions() bool { 84 if len(i.usePositions) == 0 || i.usePositions == nil { 85 return true 86 } 87 return false 88 } 89 90 // inclusive of start, but exclusive of end 91 func (i *interval) liveAt(id int) bool { 92 // compileLogf("%v live at %d", i, id) 93 if i.start <= id && id < i.end { 94 return true 95 } 96 return false 97 } 98 99 func (i *interval) lastUse() int { 100 if len(i.usePositions) == 0 { 101 return -1 102 } 103 104 // if !sort.IntsAreSorted(i.usePositions) { 105 // sort.Ints(i.usePositions) 106 // } 107 return i.usePositions[len(i.usePositions)-1] 108 } 109 110 func (i *interval) merge(other *interval) { 111 if other.start < i.start && other.start >= 0 { 112 i.start = other.start 113 } 114 115 if other.end > i.end { 116 i.end = other.end 117 } 118 119 for _, r := range other.ranges { 120 i.addRange(r.from, r.to) 121 } 122 123 i.usePositions = append(i.usePositions, other.usePositions...) 124 i.usePositions = set.Ints(i.usePositions) 125 126 } 127 128 type intervalRange struct { 129 from, to int 130 } 131 132 type regalloc struct { 133 cpucount int 134 gpucount int 135 instructionID int 136 df *dataflow 137 } 138 139 func newRegalloc(df *dataflow) *regalloc { 140 return ®alloc{ 141 df: df, 142 } 143 } 144 145 func (ra *regalloc) newReg(device Device) register { 146 var out register 147 switch device { 148 case CPU: 149 out = register{ra.cpucount, device} 150 ra.cpucount++ 151 default: 152 out = register{ra.gpucount, device} 153 ra.gpucount++ 154 155 } 156 return out 157 } 158 159 func (ra *regalloc) allocArg(nInterv *interval) { 160 nInterv.result = ra.newReg(CPU) 161 } 162 163 func (ra *regalloc) allocMutableOp(node *Node, nInterv *interval) { 164 // create new write to if overwriteInput and the used register is stil live 165 compileLogf("Allocating MutableOp NodeID: %x returns pointer", node.ID()) 166 compileLogf("Op: %v", node.op) 167 enterLogScope() 168 defer leaveLogScope() 169 170 var writeTo register 171 var reads []*interval 172 173 var children Nodes 174 var ok bool 175 if children, ok = ra.df.devTransChildren[node]; !ok { 176 compileLogf("replacement children not found") 177 children = node.children 178 } 179 for _, child := range children { 180 cReplace := ra.df.replacements[child] 181 repInterv := ra.df.intervals[cReplace] 182 reads = append(reads, repInterv) 183 } 184 compileLogf("Read %v", reads) 185 186 var letStmts Nodes 187 it := node.g.To(node.ID()) 188 for it.Next() { 189 parent := it.Node() 190 191 n := parent.(*Node) 192 compileLogf("Parent: %v | %T", n, n.op) 193 if n.isStmt { 194 // compileLogf("isStmt") 195 if _, ok := n.op.(letOp); ok { 196 letStmts = append(letStmts, n) 197 } 198 } 199 } 200 201 overwrites := node.op.OverwritesInput() 202 var onDev bool 203 switch node.op.(type) { 204 case CUDADoer: 205 onDev = true 206 case CLDoer: 207 onDev = true 208 default: 209 } 210 211 if overwrites >= 0 { 212 overwriteReg := reads[overwrites].result 213 overwriteDev := overwriteReg.device 214 overwrittenIsLive := reads[overwrites].liveAt(ra.instructionID) 215 compileLogf("Overwrites : %v ", overwrites) 216 compileLogf("Overwritten (%v) is live at %d? %t", reads[overwrites], ra.instructionID, overwrittenIsLive) 217 compileLogf("Let Statements: %d | %v", len(letStmts), reads[overwrites]) 218 219 // If the overwritten is not live, and the node does not call external processes (obiviating the need to prealloc) 220 // then we can directly overwrite the register. 221 if len(letStmts) == 1 || !overwrittenIsLive { 222 223 switch { 224 case onDev && overwriteDev != CPU: 225 // if overwritten reg is on external device and op will execute on external device 226 // then safe to overwrite 227 writeTo = overwriteReg 228 case !node.op.CallsExtern() && overwriteDev == CPU: 229 // original case: 230 // if the op doesn't call an extern, and is executed on CPU 231 // safe to overwrite 232 writeTo = overwriteReg 233 case onDev: 234 // new register otherwise 235 writeTo = ra.newReg(Device(0)) 236 case !onDev: 237 // new register otherwise 238 writeTo = ra.newReg(CPU) 239 } 240 241 } else { 242 if onDev { 243 writeTo = ra.newReg(Device(0)) 244 } else { 245 writeTo = ra.newReg(CPU) 246 } 247 } 248 } else { 249 compileLogf("New register") 250 if onDev { 251 writeTo = ra.newReg(Device(0)) 252 } else { 253 writeTo = ra.newReg(CPU) 254 } 255 } 256 257 for _, r := range reads { 258 nInterv.reads = append(nInterv.reads, r.result) 259 } 260 nInterv.result = writeTo 261 compileLogf("%v: %v", node.op, nInterv) 262 } 263 264 func (ra *regalloc) allocImmutableOp(node *Node, nInterv *interval) { 265 compileLogf("Allocating Immutable Op") 266 enterLogScope() 267 defer leaveLogScope() 268 269 var writeTo register 270 var reads []*interval 271 272 var children Nodes 273 var ok bool 274 if children, ok = ra.df.devTransChildren[node]; !ok { 275 children = node.children 276 } 277 for _, child := range children { 278 cReplace := ra.df.replacements[child] 279 repInterv := ra.df.intervals[cReplace] 280 reads = append(reads, repInterv) 281 } 282 283 compileLogf("NodeID: %x does not returns pointer", node.ID()) 284 if _, ok := node.op.(CUDADoer); ok { 285 writeTo = ra.newReg(Device(0)) 286 } else { 287 writeTo = ra.newReg(CPU) 288 } 289 290 for _, r := range reads { 291 nInterv.reads = append(nInterv.reads, r.result) 292 } 293 nInterv.result = writeTo 294 } 295 296 func (ra *regalloc) allocStatement(node *Node, nInterv *interval) { 297 var writeTo register 298 switch op := node.op.(type) { 299 case devTrans: 300 writeTo = ra.newReg(op.to) 301 } 302 nInterv.result = writeTo 303 } 304 305 func (ra *regalloc) alloc(sorted Nodes) { 306 compileLogf("Allocating registers") 307 enterLogScope() 308 defer leaveLogScope() 309 310 for i, node := range sorted { 311 ra.instructionID = i 312 313 replacement := ra.df.replacements[node] 314 nInterv := ra.df.intervals[replacement] 315 316 compileLogf("replacement %v, interval %v", replacement, nInterv) 317 318 if node != replacement { 319 compileLogf("Merging") 320 ra.df.intervals[node].merge(nInterv) 321 } 322 compileLogf("Working on %v(%x). InstructionID: %d", node, node.ID(), ra.instructionID) 323 324 switch { 325 case node.isArg(): 326 ra.allocArg(nInterv) 327 case node.isStmt: 328 ra.allocStatement(node, nInterv) 329 case node.op.ReturnsPtr(): 330 ra.allocMutableOp(node, nInterv) 331 default: 332 ra.allocImmutableOp(node, nInterv) 333 } 334 compileLogf("n: %x; result: %v; reads: %v", node.ID(), nInterv.result, nInterv.reads) 335 // ra.instructionID++ 336 } 337 }