gorgonia.org/gorgonia@v0.9.17/compile.go (about) 1 package gorgonia 2 3 import ( 4 "encoding/csv" 5 "fmt" 6 "io" 7 8 "github.com/pkg/errors" 9 "gorgonia.org/tensor" 10 ) 11 12 // This file deals with the compilation from a expression graph into a program 13 // that is executed by an interpreter 14 15 // Compile takes a graph and outputs a program suitable for *tapeMachine to run 16 func Compile(g *ExprGraph) (prog *program, locMap map[*Node]register, err error) { 17 compileLogf("Compiling") 18 enterLogScope() 19 defer leaveLogScope() 20 21 switch { 22 case len(g.AllNodes()) == 0: 23 err = errors.Errorf("Cannot compile an empty graph") 24 return 25 case g.Inputs().Len() == 0: 26 err = errors.Errorf("Cannot compile a graph that has no input nodes") 27 return 28 } 29 30 compileLogf("sorting") 31 var sortedNodes Nodes 32 if sortedNodes, err = Sort(g); err != nil { 33 return nil, nil, errors.Wrap(err, sortFail) 34 } 35 reverseNodes(sortedNodes) 36 37 df := analyze(g, sortedNodes) 38 sortedNodes = df.insertDeviceInstr(sortedNodes) 39 df.buildIntervals(sortedNodes) 40 41 ra := newRegalloc(df) 42 ra.alloc(sortedNodes) 43 44 // debug related stuff 45 df.debugIntervals(sortedNodes) 46 logCompileState(g.name, g, df) 47 48 inputs := g.Inputs() 49 cg := newCodeGenerator(inputs, sortedNodes, df) 50 prog, locMap = cg.gen() 51 prog.cpulocs = ra.cpucount 52 prog.gpulocs = ra.gpucount 53 prog.cpumem = cg.cpumem 54 prog.gpumem = cg.gpumem 55 prog.df = df 56 prog.g = g 57 prog.sorted = sortedNodes 58 59 return 60 } 61 62 // CompileFunction takes a graph, subsets it based on the input and output nodes provided and outputs a program suitable for *tapeMachine to run. 63 // It is analogous to theano.Function(). 64 // If some input nodes are not used or is not reachable, this function will return an error 65 func CompileFunction(g *ExprGraph, inputs, outputs Nodes) (prog *program, locMap map[*Node]register, err error) { 66 compileLogf("CompileFunctionNEW. Inputs: %d; outputs: %d", inputs, outputs) 67 enterLogScope() 68 defer leaveLogScope() 69 70 subgraph := g.ExactSubgraphRoots(outputs...) 71 var unused Nodes 72 for _, in := range inputs { 73 if !subgraph.all.Contains(in) { 74 unused = append(unused, in) 75 } 76 } 77 78 if len(unused) > 0 { 79 return nil, nil, errors.Errorf("Not all the inputs are used: %v", unused) 80 } 81 82 var sortedNodes Nodes 83 if sortedNodes, err = Sort(subgraph); err != nil { 84 return nil, nil, errors.Wrap(err, sortFail) 85 } 86 reverseNodes(sortedNodes) 87 88 df := analyze(subgraph, sortedNodes) 89 sortedNodes = df.insertDeviceInstr(sortedNodes) 90 df.buildIntervals(sortedNodes) 91 92 ra := newRegalloc(df) 93 ra.alloc(sortedNodes) 94 95 cg := newCodeGenerator(inputs, sortedNodes, df) 96 prog, locMap = cg.gen() 97 prog.cpulocs = ra.cpucount 98 prog.gpulocs = ra.gpucount 99 prog.df = df 100 prog.g = subgraph 101 prog.sorted = sortedNodes 102 103 return 104 } 105 106 // codgenerator holds the state for the code generation process 107 type codegenerator struct { 108 locMap map[*Node]register 109 lastWrites map[register]*Node 110 flushed map[int]struct{} 111 allocated map[register]struct{} 112 freed map[register]struct{} 113 deferFree map[register]struct{} 114 instrMap map[*Node]fragment 115 queue []int // queue to flush 116 117 lastReads map[register]int 118 119 cpumem int64 120 gpumem []int64 121 122 g *ExprGraph 123 inputs, sorted Nodes 124 df *dataflow 125 instructions fragment 126 } 127 128 func newCodeGenerator(inputs, sorted Nodes, df *dataflow) *codegenerator { 129 return &codegenerator{ 130 locMap: make(map[*Node]register), 131 lastWrites: make(map[register]*Node), 132 flushed: make(map[int]struct{}), 133 allocated: make(map[register]struct{}), 134 freed: make(map[register]struct{}), 135 deferFree: make(map[register]struct{}), 136 instrMap: make(map[*Node]fragment), 137 lastReads: make(map[register]int), 138 139 g: inputs[0].g, 140 inputs: inputs, 141 sorted: sorted, 142 df: df, 143 } 144 } 145 146 // addInstr adds the instruction to the associated node in the instrMap. 147 // when we add instructions to the node map, we also try to determine the size of the allocations required 148 func (cg *codegenerator) addInstr(node *Node, instr tapeInstr) { 149 if instrs := cg.instrMap[node]; instrs != nil { 150 instrs = append(instrs, instr) 151 cg.instrMap[node] = instrs 152 } else { 153 cg.instrMap[node] = fragment{instr} 154 } 155 156 var dt tensor.Dtype 157 var err error 158 switch inst := instr.(type) { 159 case loadArg: 160 if dt, err = dtypeOf(node.t); err != nil { 161 panic(err) 162 } 163 d := instr.writes().device 164 if d != CPU { 165 if len(cg.gpumem) < int(d)+1 { 166 diff := int(d) + 1 - len(cg.gpumem) 167 cg.gpumem = append(cg.gpumem, make([]int64, diff)...) 168 } 169 } 170 171 switch d { 172 case CPU: 173 cg.cpumem += calcMemSize(dt, node.Shape()) 174 default: 175 cg.gpumem[int(d)] += calcMemSize(dt, node.Shape()) 176 } 177 case alloc: 178 if dt, err = dtypeOf(inst.t); err != nil { 179 panic(err) 180 } 181 182 d := instr.writes().device 183 if d != CPU { 184 if len(cg.gpumem) < int(d)+1 { 185 diff := int(d) + 1 - len(cg.gpumem) 186 cg.gpumem = append(cg.gpumem, make([]int64, diff)...) 187 } 188 } 189 190 switch d { 191 case CPU: 192 cg.cpumem += calcMemSize(dt, inst.s) 193 default: 194 cg.gpumem[int(d)] += calcMemSize(dt, inst.s) 195 } 196 case *execOp: 197 if !inst.op.ReturnsPtr() { 198 d := instr.writes().device 199 if d != CPU { 200 if len(cg.gpumem) < int(d)+1 { 201 diff := int(d) + 1 - len(cg.gpumem) 202 cg.gpumem = append(cg.gpumem, make([]int64, diff)...) 203 } 204 } 205 switch d { 206 case CPU: 207 cg.cpumem += inst.size 208 default: 209 cg.gpumem[int(d)] += inst.size 210 } 211 } 212 213 default: 214 // panic("EHLP") 215 } 216 } 217 218 // every time an instruction is added to the list of instructions, 219 // also add the instructionID and the register the instruction writes to. 220 // This helps with determining if a flushInstruction needs to be issued. 221 func (cg *codegenerator) updateLastWrites(reg register, n *Node) { 222 cg.lastWrites[reg] = n 223 } 224 225 func (cg *codegenerator) flush() { 226 compileLogf("Flushing") 227 for _, instrID := range cg.queue { 228 cg.flushed[instrID] = struct{}{} 229 } 230 cg.queue = cg.queue[:0] 231 } 232 233 func (cg *codegenerator) addArg(node *Node, interv *interval) { 234 compileLogf("LoadArg: %x", node.ID()) 235 writeTo := interv.result 236 237 cg.locMap[node] = writeTo 238 instr := loadArg{ 239 // index: index, 240 index: node.ID(), 241 writeTo: writeTo, 242 name: node.Name(), 243 } 244 // cg.instructions = append(cg.instructions, instr) 245 246 cg.addInstr(node, instr) 247 cg.updateLastWrites(writeTo, node) 248 } 249 250 func (cg *codegenerator) addStmt(node *Node, interv *interval, i int) { 251 compileLogf("Add Statement") 252 enterLogScope() 253 defer leaveLogScope() 254 255 writeTo := interv.result 256 257 var children Nodes 258 var ok bool 259 if children, ok = cg.df.devTransChildren[node]; !ok { 260 children = node.children 261 } 262 263 switch op := node.op.(type) { 264 case letOp: 265 // there should be only 2 chilren 266 if len(children) != 2 { 267 panic("Expected only two children") 268 } 269 compileLogf("node.children %d. [1]: %v; [0]: %v", node.ID(), children[1], children[0]) 270 compileLogf("node isInput %v", node.isInput()) 271 from := cg.df.intervals[children[1]].result 272 to := cg.df.intervals[children[0]].result 273 274 instr := letInstr{ 275 readFrom: from, 276 writeTo: to, 277 } 278 // cg.instructions = append(cg.instructions, instr) 279 280 cg.addInstr(node, instr) 281 cg.updateLastWrites(writeTo, node) 282 case readOp: 283 // there should be only 1 child 284 if len(children) != 1 { 285 panic("Expected only one child") 286 } 287 compileLogf("node.children %d. [0]: %v", node.ID(), children[0]) 288 compileLogf("node isInput %v", node.isInput()) 289 compileLogf("node.children[0] Type %v, shape %v", children[0].t, children[0].shape) 290 291 if _, ok := cg.flushed[i]; !ok { 292 cg.addInstr(node, flushInstr{}) 293 cg.flush() 294 } 295 296 from := cg.df.intervals[children[0]].result 297 instr := &readInstr{ 298 into: op.into, 299 readFrom: from, 300 301 t: children[0].t, 302 s: children[0].shape, 303 } 304 // cg.instructions = append(cg.instructions, instr) 305 306 cg.addInstr(node, instr) 307 cg.updateLastWrites(writeTo, node) 308 case devTrans: 309 if _, ok := cg.allocated[writeTo]; !ok { 310 // insert new alloc 311 var instr alloc 312 instr = newAlloc(node, writeTo) 313 // cg.instructions = append(cg.instructions, instr) 314 315 cg.addInstr(node, instr) 316 cg.updateLastWrites(writeTo, node) 317 cg.queue = append(cg.queue, i) 318 cg.allocated[writeTo] = struct{}{} 319 } 320 321 compileLogf("devTrans") 322 if len(children) != 1 { 323 panic("Expected only one child") 324 } 325 326 from := cg.df.intervals[children[0]].result 327 to := cg.df.intervals[node].result 328 329 instr := deviceTransport{ 330 from: from, to: to, 331 } 332 cg.addInstr(node, instr) 333 334 if op.from != CPU && op.to == CPU { 335 instrID := cg.sorted.index(op.toNode) 336 if _, ok := cg.flushed[instrID]; !ok { 337 // cg.instructions = append(cg.instructions, flushInstr{}) 338 cg.addInstr(node, flushInstr{}) 339 cg.flush() 340 } 341 } 342 cg.updateLastWrites(writeTo, node) 343 344 } 345 } 346 347 func (cg *codegenerator) addNode(node, replacement *Node, interv *interval, i int) { 348 compileLogf("AddNode: %x %v", node.ID(), node.op) 349 compileLogf("interval %v", interv) 350 enterLogScope() 351 defer leaveLogScope() 352 353 writeTo := interv.result 354 355 var reads []register 356 var children Nodes 357 var ok bool 358 if children, ok = cg.df.devTransChildren[node]; !ok { 359 children = node.children 360 } 361 for _, child := range children { 362 cReplacement := cg.df.replacements[child] 363 cInterv := cg.df.intervals[cReplacement] 364 reads = append(reads, cInterv.result) 365 } 366 enterLogScope() 367 defer leaveLogScope() 368 369 var prealloc bool 370 var useUnsafe bool 371 // if it's not mutable, there is no chance it will be overwritten 372 if node.isMutable() { 373 // if the instruction calls an extern (cBLAS or cuBlas), then we should preallocate the vector 374 if node.op.CallsExtern() { 375 compileLogf("calls extern") 376 if _, ok := cg.allocated[writeTo]; !ok { 377 compileLogf("Inserting new alloc") 378 var instr alloc 379 instr = newAlloc(node, writeTo) 380 cg.addInstr(node, instr) 381 cg.updateLastWrites(writeTo, node) 382 383 prealloc = true 384 385 cg.queue = append(cg.queue, i) 386 cg.allocated[writeTo] = struct{}{} 387 } 388 } 389 } 390 compileLogf("Node Reads %v", reads) 391 // check if any previously buffered cBLAS or cuBLAS calls need to be flushed 392 // it doesn't matter if the machine isn't using a batchedBLAS. flushInstr would just be a no-op at runtime 393 for _, read := range reads { 394 if lastWriteNode, ok := cg.lastWrites[read]; ok { 395 instrID := cg.sorted.index(lastWriteNode) 396 var op Op 397 var onDev, nodeOnDev Device 398 399 _, isDevTrans := lastWriteNode.Op().(devTrans) 400 switch { 401 case lastWriteNode.isArg(), lastWriteNode.isStmt && !isDevTrans: 402 continue 403 default: 404 op = lastWriteNode.op 405 } 406 switch op.(type) { 407 case CUDADoer: 408 onDev = Device(0) 409 case CLDoer: 410 onDev = Device(0) 411 default: 412 onDev = CPU 413 } 414 415 switch node.op.(type) { 416 case CUDADoer: 417 nodeOnDev = Device(0) 418 case CLDoer: 419 nodeOnDev = Device(0) 420 default: 421 nodeOnDev = CPU 422 } 423 424 // if we have sequential Extern calls, we just add it to the batch. 425 // sequential in this can mean several instructions apart. For example: 426 // 4 A × B ; read %2 ; write to %3 427 // ⋮ (doesn't use %3 or %10) 428 // ⋮ 429 // 10 Aᵀ × B ; read %3 ; write to %10 430 // ⋮ (doesn't use %3, or %10) 431 // ⋮ 432 // 12 + ; read %10 ; write to %12 433 // 434 // It is before instruction 12 that the flush will be added. 4 and 10 are considered sequential 435 // 436 // It is not sequential when both are not the same devices 437 switch { 438 case !op.CallsExtern(): 439 compileLogf("ToFlush: Node doesn't call extern. NO FLUSH") 440 // op doesn't call extern... don't bother flushing 441 case op.CallsExtern() && node.op.CallsExtern() && onDev == nodeOnDev && !isDevTrans: 442 compileLogf("ToFlush: Both calls extern, both same device. NO FLUSH") 443 // same device, both calls extern 444 // no flush needed 445 case op.CallsExtern() && node.op.CallsExtern() && onDev != nodeOnDev: 446 compileLogf("ToFlush: Differing devices") 447 // different devices, both calls extern 448 // flush needed 449 fallthrough 450 case op.CallsExtern() && !node.op.CallsExtern(): 451 compileLogf("ToFlush: Node requires value immediately") 452 // node is gonna use the value immediately 453 // flush needed 454 fallthrough 455 default: 456 compileLogf("ToFlush: FLUSH") 457 if _, ok := cg.flushed[instrID]; !ok { 458 // cg.instructions = append(cg.instructions, flushInstr{}) 459 cg.addInstr(node, flushInstr{}) 460 cg.flush() 461 } 462 } 463 464 // viaticum := cg.instructions[instrID] // ;) - it IS on the way 465 // if instr, ok := viaticum.(*execOp); ok { 466 // if op.CallsExtern() && !node.op.CallsExtern() { 467 // } 468 // } 469 } 470 471 // check the overwrites - if the overwrite and the resulting register is the same, 472 // then use unsafe options when available 473 overwrites := node.op.OverwritesInput() 474 if overwrites >= 0 { 475 compileLogf("Overwrites %d", overwrites) 476 overwritten := reads[overwrites] 477 compileLogf("NodeID:%d overwritten: %v, reads: %v, interval: %v", node.ID(), overwritten, interv.reads, interv.result) 478 if overwritten == interv.result { 479 compileLogf("Use unsafe") 480 useUnsafe = true 481 } 482 } 483 484 } 485 486 cg.locMap[node] = writeTo 487 488 // otherwise, the replacement has already been written 489 if node == replacement { 490 compileLogf("New Exec Op: %v", node.op) 491 instr := newExecOp(node) 492 instr.readFrom = reads 493 instr.writeTo = writeTo 494 instr.preAllocated = prealloc 495 instr.useUnsafe = useUnsafe 496 497 // cg.instructions = append(cg.instructions, instr) 498 cg.addInstr(node, instr) 499 cg.updateLastWrites(writeTo, node) 500 } 501 } 502 503 func (cg *codegenerator) insertFree(instrID int, node *Node) { 504 compileLogf("Inserting Free for instrID %d | instr: %v | op: %v", instrID, node, node.op) 505 enterLogScope() 506 defer leaveLogScope() 507 508 var reads []register 509 var children Nodes 510 var ok bool 511 if children, ok = cg.df.devTransChildren[node]; !ok { 512 children = node.children 513 } 514 for _, child := range children { 515 cReplacement := cg.df.replacements[child] 516 cInterv := cg.df.intervals[cReplacement] 517 reads = append(reads, cInterv.result) 518 } 519 compileLogf("reads %v", reads) 520 521 // check if anything needs to be freed 522 for _, read := range reads { 523 var readNode *Node 524 for n, reg := range cg.locMap { 525 if reg == read { 526 if readNode == nil { 527 readNode = n 528 continue 529 } 530 if readNode.id < n.id { 531 readNode = n 532 } 533 } 534 } 535 // interv := cg.df.intervals[readNode] 536 readRepl := cg.df.replacements[readNode] 537 if readRepl == nil { 538 readRepl = readNode 539 } 540 if readRepl == nil { 541 continue 542 } 543 interv := cg.df.intervals[readRepl] 544 compileLogf("interv for readRepl %v: %v", readRepl, interv) 545 lastUse := interv.lastUse() 546 compileLogf("Interval: %v; read: %v; Read Node %v; Op %v; LastUse %v; Instrid: %v", interv, read, readNode, readNode.op, lastUse, instrID) 547 if lastUse >= 0 && lastUse <= instrID && read.device != CPU { 548 if _, ok := cg.freed[read]; !ok { 549 compileLogf("Adding Free %v. LastUse %d", read, interv.lastUse()) 550 // cg.instructions = append(cg.instructions, free{read}) 551 cg.addInstr(node, free{read}) 552 cg.freed[read] = struct{}{} 553 } 554 } 555 } 556 557 write := cg.locMap[node] 558 repl := cg.df.replacements[node] 559 interv := cg.df.intervals[repl] 560 compileLogf("Node %v | write %v | Last Use %v | %v", node, write, interv.lastUse(), node.isRoot()) 561 if interv.lastUse() == -1 || interv.lastUse() >= len(cg.sorted) { 562 // if node.isRoot() { 563 cg.deferFree[write] = struct{}{} 564 // return 565 // } 566 567 // otherwise, it's essentially a NOOP, so we free the memory immediately after the Op is executed 568 // TODO: do NO-OP optimizations 569 // if _, ok := cg.freed[write]; !ok { 570 // compileLogf("Adding Free %v. Last Use %d", write, interv.lastUse()) 571 // cg.addInstr(node, free{write}) 572 // cg.freed[write] = struct{}{} 573 // } 574 } 575 } 576 577 func (cg *codegenerator) insertLastFrees() int { 578 node := cg.sorted[len(cg.sorted)-1] 579 var instructionsAdded int 580 for reg := range cg.deferFree { 581 if _, ok := cg.freed[reg]; !ok { 582 compileLogf("Adding Free %v to the final instruction", reg) 583 cg.addInstr(node, free{reg}) 584 instructionsAdded++ 585 } 586 } 587 return instructionsAdded 588 } 589 590 func (cg *codegenerator) gen() (*program, map[*Node]register) { 591 compileLogf("Generating from SORTED: %v", cg.sorted) 592 enterLogScope() 593 defer leaveLogScope() 594 for i, node := range cg.sorted { 595 // for i := len(cg.sorted) - 1; i ⩾ 0; i-- { 596 // node := cg.sorted[i] 597 replacement := cg.df.replacements[node] 598 compileLogf("Working on %x. Replacement: %x", node.ID(), replacement.ID()) 599 600 nInterv := cg.df.intervals[replacement] 601 switch { 602 case node.isArg(): 603 cg.addArg(node, nInterv) 604 case node.isStmt: 605 cg.addStmt(node, nInterv, i) 606 default: 607 cg.addNode(node, replacement, nInterv, i) 608 } 609 } 610 611 var instructionCount int 612 for i := len(cg.sorted) - 1; i >= 0; i-- { 613 node := cg.sorted[i] 614 cg.insertFree(i, node) 615 616 instructionCount += len(cg.instrMap[node]) 617 } 618 619 instructionCount += cg.insertLastFrees() 620 621 cg.instructions = make(fragment, 0, instructionCount) 622 for _, node := range cg.sorted { 623 instrs := cg.instrMap[node] 624 cg.instructions = append(cg.instructions, instrs...) 625 } 626 627 return &program{ 628 instructions: cg.instructions, 629 args: len(cg.inputs), 630 g: cg.g, 631 m: cg.instrMap, 632 }, cg.locMap 633 } 634 635 func compileState(w io.Writer, g *ExprGraph, df *dataflow) { 636 header := []string{ 637 "ID", "Op", "Type", "Register", "Interval", "Used By", "Uses", "Overwrites", "BLAS?", 638 } 639 640 var rows [][]string 641 for _, n := range g.AllNodes() { 642 interv := df.intervals[n] 643 644 row := make([]string, len(header)) 645 row[0] = fmt.Sprintf("%d", n.ID()) 646 row[2] = fmt.Sprintf("%s", n.t) 647 row[3] = fmt.Sprintf("%s", interv.result) 648 row[4] = fmt.Sprintf("%d - %d", interv.start, interv.end) 649 row[5] = fmt.Sprintf("%v", interv.usePositions) 650 row[6] = fmt.Sprintf("%d", n.children) 651 652 if n.op != nil { 653 row[1] = fmt.Sprintf("%s", n.op) 654 overwrites := n.op.OverwritesInput() 655 if overwrites >= 0 { 656 row[7] = fmt.Sprintf("%d", n.children[overwrites].ID()) 657 } 658 659 if n.op.CallsExtern() { 660 row[8] = "yes" 661 } 662 } 663 664 rows = append(rows, row) 665 } 666 cw := csv.NewWriter(w) 667 cw.Comma = ';' 668 // TODO: Check errors on writes here. 669 cw.Write(header) 670 cw.WriteAll(rows) 671 }