gorgonia.org/gorgonia@v0.9.17/vm_tape.go

package gorgonia

import (
	"bytes"
	"fmt"
	"log"
	"runtime"
	"strings"

	"github.com/chewxy/hm"
	"github.com/pkg/errors"
	"gorgonia.org/tensor"
)

type tapeMachine struct {
	ExternMetadata

	p      *program
	locMap map[*Node]register

	// "register" banks
	cpumem []Value // Value - knows its own type and shape
	gpumem []Value // Values whose memory is stored on the GPU

	// state stuff, to allow continuation
	pc int

	// operational stuff
	bindNodesDV Nodes // nodes that require binding of DV
	watchNodes  Nodes
	watchRegs   []register
	logger      *log.Logger
	buf         *bytes.Buffer
	valueFmt    string
	tabcount    int
	logFlags    byte

	runFlags byte // spare2: trace (copy values and put into nodes)
}

// NewTapeMachine creates a VM that compiles a graph into a program.
func NewTapeMachine(g *ExprGraph, opts ...VMOpt) *tapeMachine {
	m := &tapeMachine{
		valueFmt: "%3.3g",
	}
	m.Engine = StandardEngine{}

	if b, ok := whichblas.(batchedBLAS); ok {
		m.b = b
	}

	for _, opt := range opts {
		opt(m)
	}

	m.doAlloc()

	if m.p == nil || m.locMap == nil {
		prog, locMap, err := Compile(g)
		if err != nil {
			panic(err)
		}

		m.p = prog
		m.locMap = locMap
	}
	m.cpumem = make([]Value, m.p.cpulocs)
	m.gpumem = make([]Value, m.p.gpulocs)
	m.init()
	for _, n := range m.p.g.AllNodes() {
		setEngine(n.boundTo, m.Engine)
	}

	runtime.SetFinalizer(m, finalizeTapeMachine) // a "defer" to deinitialize CUDA stuff (if using CUDA build)
	return m
}

func (m *tapeMachine) logBwd() bool { return (m.logFlags>>bwdOnly)&byte(1) == 1 }
func (m *tapeMachine) doLogBwd()    { m.logFlags |= byte(1) << bwdOnly }
func (m *tapeMachine) dontLogBwd()  { m.logFlags &= (^(byte(1) << bwdOnly)) }

func (m *tapeMachine) logFwd() bool { return (m.logFlags>>fwdOnly)&byte(1) == 1 }
func (m *tapeMachine) doLogFwd()    { m.logFlags |= byte(1) << fwdOnly }
func (m *tapeMachine) dontLogFwd()  { m.logFlags &= (^(byte(1) << fwdOnly)) }

func (m *tapeMachine) watchNaN() bool { return (m.runFlags>>watchNaN)&byte(1) == 1 }
func (m *tapeMachine) doWatchNaN()    { m.runFlags |= byte(1) << watchNaN }
func (m *tapeMachine) dontWatchNaN()  { m.runFlags &= (^(byte(1) << watchNaN)) }

func (m *tapeMachine) watchInf() bool { return (m.runFlags>>watchInf)&byte(1) == 1 }
func (m *tapeMachine) doWatchInf()    { m.runFlags |= byte(1) << watchInf }
func (m *tapeMachine) dontWatchInf()  { m.runFlags &= (^(byte(1) << watchInf)) }

func (m *tapeMachine) watchAll() bool { return (m.logFlags>>watchAll)&byte(1) == 1 }
func (m *tapeMachine) doWatchAll()    { m.logFlags |= (byte(1) << watchAll) }
func (m *tapeMachine) dontWatchAll()  { m.logFlags &= (^(byte(1) << watchAll)) }

func (m *tapeMachine) alloc() bool { return (m.runFlags>>allocVals)&byte(1) == 1 }
func (m *tapeMachine) doAlloc()    { m.runFlags |= byte(1) << allocVals }
func (m *tapeMachine) dontAlloc()  { m.runFlags &= (^(byte(1) << allocVals)) }

func (m *tapeMachine) trace() bool { return (m.runFlags>>spare2)&byte(1) == 1 }
func (m *tapeMachine) doTrace()    { m.runFlags |= byte(1) << spare2 }
func (m *tapeMachine) dontTrace()  { m.runFlags &= (^(byte(1) << spare2)) }

func (m *tapeMachine) bindDV() bool { return m.runFlags>>spare3&byte(1) == 1 }
func (m *tapeMachine) doBindDV()    { m.runFlags |= byte(1) << spare3 }
func (m *tapeMachine) dontBindDV()  { m.runFlags &= (^(byte(1) << spare3)) }
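// Editor's note (illustrative sketch, not part of the original source): the flag
// helpers above are normally flipped through VM options rather than called directly.
// Assuming the exported VMOpts this package documents (WithNaNWatch, WithInfWatch,
// TraceExec, BindDualValues, WithLogger, WithWatchlist) behave as their names suggest:
//
//	logger := log.New(os.Stderr, "", 0)
//	m := NewTapeMachine(g,
//		WithNaNWatch(),     // check values for NaNs after each execOp
//		WithInfWatch(),     // check values for Infs after each execOp
//		TraceExec(),        // copy intermediate values back into the nodes
//		BindDualValues(),   // bind dual values for nodes that need them
//		WithLogger(logger), // route logf/watchedLogf output to a *log.Logger
//		WithWatchlist(),    // watch registers/nodes during execution
//	)
//	defer m.Close()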
// Reset resets the run state of the machine by changing the instruction pointer back to 0
// and resetting the registers.
func (m *tapeMachine) Reset() {
	m.pc = 0
	m.ExternMetadata.Reset()

	for i := range m.gpumem {
		returnValue(m.gpumem[i])
		m.gpumem[i] = nil
	}
	for i := range m.cpumem {
		m.cpumem[i] = nil
	}
}

func (m *tapeMachine) Close() error {
	finalizeTapeMachine(m)
	return nil
}

// Prog returns the compiled program. This would mainly be used in debugging functions.
func (m *tapeMachine) Prog() *program { return m.p }

// LocMap returns the location where the Node's execution results are stored. This would mainly be used in debugging functions.
func (m *tapeMachine) LocMap() map[*Node]register { return m.locMap }

// Let wraps the Let() function of the package, with additional checks that n is in the machine.
func (m *tapeMachine) Let(n *Node, be interface{}) (err error) {
	if !m.p.g.Has(n.ID()) {
		return errors.Errorf("Node %v does not exist in this graph", n)
	}

	return Let(n, be)
}

// Set wraps the Set() function of this package, with additional checks that both a and b are in the machine.
func (m *tapeMachine) Set(a, b *Node) (err error) {
	if !m.p.g.Has(a.ID()) {
		return errors.Errorf("Node %v does not exist in this graph", a)
	}
	if !m.p.g.Has(b.ID()) {
		return errors.Errorf("Node %v does not exist in this graph", b)
	}

	if b.Value() != nil {
		return a.bind(b.Value())
	}

	// get the register location
	breg := m.locMap[b]
	v := m.getValue(breg)
	if v == nil {
		return nyi("handling of tensor.Memory -> Value", "tapeMachine.Set")
	}

	machineLogf("Setting %v to %v. Read from %v Value is %v", b, a, breg, v)
	return a.bind(v)
}
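// Editor's note (illustrative sketch): Let and Prog are typically used together when
// debugging a compiled graph. The graph-building helpers (NewGraph, NewScalar, Add,
// WithName) are assumed to work as in the package-level documentation:
//
//	g := NewGraph()
//	x := NewScalar(g, Float64, WithName("x"))
//	y := NewScalar(g, Float64, WithName("y"))
//	z, _ := Add(x, y)
//
//	m := NewTapeMachine(g)
//	defer m.Close()
//	fmt.Println(m.Prog()) // dump the compiled instruction stream
//	if err := m.Let(x, 2.0); err != nil {
//		// m.Let fails if x is not a node of this machine's graph
//	}
//	_ = z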
// Run runs a fragment (a subset of a program).
func (m *tapeMachine) Run(frag fragment) (err error) {
	defer func() {
		if err == nil {
			m.dontAlloc()
		}
	}()

	for _, instr := range frag {
		if err = instr.exec(m); err != nil {
			return errors.Wrap(err, "Failed to carry out exec()")
		}
	}
	machineLogf("Binding values based on final output")
	enterLogScope()
	for n, r := range m.locMap {
		if n.isInput() {
			continue
		}

		v := m.getValue(r)
		if v == nil {
			return nyi("converting tensor.Memory to Value", "TapeMachine.Run")
		}

		if err = n.bind(m.cpumem[r.id]); err != nil {
			return errors.Wrap(err, bindFail)
		}
	}
	leaveLogScope()
	return
}

func (m *tapeMachine) RunAll() (err error) {
	runtime.LockOSThread()
	defer runtime.UnlockOSThread()
	defer m.DoWork()

	workAvailable := m.ExternMetadata.WorkAvailable()
	syncChan := m.ExternMetadata.Sync()
	errChan := make(chan error)
	doneChan := make(chan struct{})

	go m.runall(errChan, doneChan)
	for {
		select {
		case synchronous := <-workAvailable:
			err := m.ExternMetadata.DoWork()
			if err != nil {
				return err
			}
			if synchronous {
				syncChan <- struct{}{}
			}
		case err := <-errChan:
			return errors.Wrapf(err, "PC: %d", m.pc)
		case <-doneChan:
			err := m.ExternMetadata.DoWork()
			if err != nil {
				return err
			}
			return nil
		}
	}
}
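// Editor's note (illustrative sketch): RunAll executes the whole compiled program once,
// and Reset rewinds the program counter so the same machine can be reused. Assuming a
// graph g whose inputs have already been bound with Let, the usual loop looks like:
//
//	m := NewTapeMachine(g)
//	defer m.Close()
//	for i := 0; i < epochs; i++ {
//		if err := m.RunAll(); err != nil {
//			log.Fatal(err)
//		}
//		// read outputs here, e.g. via node.Value() or a Read(...) node
//		m.Reset() // rewind pc and clear the register banks before the next run
//	}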
func (m *tapeMachine) runall(errChan chan error, doneChan chan struct{}) {
	for ; m.pc < len(m.p.instructions); m.pc++ {
		instr := m.p.instructions[m.pc]
		m.logf("PC %d", m.pc)
		if err := instr.exec(m); err != nil {
			err = errors.Wrapf(err, "PC %d. Failed to execute instruction %v", m.pc, instr)
			errChan <- err
			return
		}
		// only proceed to check NaNs and Infs for execOp
		if _, ok := instr.(*execOp); !ok {
			continue
		}

		if m.watchNaN() {
			writeTo := instr.writes().id
			id := instr.ID()
			if writeTo > 0 && id > 0 {
				v := m.getValue(instr.writes())
				if v == nil {
					err := errors.Errorf(nyiFail, "converting tensor.Memory to Value", "watchNaN")
					errChan <- err
					return
				}

				if hasNaN(v, CPU) {
					n := m.p.g.Node(id).(*Node)
					err := errors.Errorf("NaN found in value. Node: %v(%x)", n, n.ID())
					errChan <- err
					return
				}
			}
		}

		if m.watchInf() {
			writeTo := instr.writes().id
			id := instr.ID()
			if writeTo > 0 && id > 0 {
				v := m.getValue(instr.writes())
				if v == nil {
					err := errors.Errorf(nyiFail, "converting tensor.Memory to Value", "watchInf")
					errChan <- err
					return
				}

				if hasInf(v, CPU) {
					n := m.p.g.Node(id).(*Node)
					err := errors.Errorf("Inf found in value. Node: %v(%x)", n, n.ID())
					errChan <- err
					return
				}
			}
		}
	}
	doneChan <- struct{}{}
}

func (m *tapeMachine) getValue(r register) Value {
	switch r.device {
	case CPU:
		return m.cpumem[r.id]
	default:
		return m.gpumem[r.id]
	}
}

func (m *tapeMachine) writeValue(r register, v Value) {
	switch r.device {
	case CPU:
		m.cpumem[r.id] = v
	default:
		m.gpumem[r.id] = v
	}
}

func (m *tapeMachine) watchedLogf(format string, attrs ...interface{}) {
	instr := m.p.instructions[m.pc]
	reads := instr.reads()
	writes := instr.writes()

	watched := m.watchAll()

	if !watched {
		for _, reg := range reads {
			for _, watch := range m.watchRegs {
				if reg.id == watch.id {
					watched = true
					break
				}
			}
		}
	}

	if !watched {
		for _, watch := range m.watchRegs {
			if watch.id == writes.id {
				watched = true
				break
			}
		}
	}

	// TODO: Work on watched nodes
	if !watched {

	}

	if watched {
		m.logf(format, attrs...)
	}
}

func (m *tapeMachine) logf(format string, attrs ...interface{}) {
	switch {
	case machineDev:
		if m.logger != nil {
			goto loggercase
		}

		machineLogf(format, attrs...)
		break

	loggercase:
		fallthrough
	case m.logger != nil:
		s := fmt.Sprintf(format, attrs...)
		s = strings.Replace(s, "\n", m.buf.String(), -1)
		m.logger.Println(s)
	}
}

func (m *tapeMachine) enterLogScope() {
	if DEBUG && machineDev {
		enterLogScope()
	}
	m.tabcount++
	if m.logger != nil {
		reps := strings.Repeat("\t", m.tabcount)
		m.logger.SetPrefix(reps)
		m.buf.Reset()
		m.buf.WriteString("\n")
		m.buf.WriteString(reps)
	}
}

func (m *tapeMachine) leaveLogScope() {
	if DEBUG && machineDev {
		leaveLogScope()
	}
	m.tabcount--
	if m.tabcount < 0 {
		m.tabcount = 0
	}
	if m.logger != nil {
		reps := strings.Repeat("\t", m.tabcount)
		m.logger.SetPrefix(reps)
		m.buf.Reset()
		m.buf.WriteString("\n")
		m.buf.WriteString(reps)
	}
}

/* PROGRAM */

type program struct {
	instructions fragment
	args         int
	cpulocs      int
	gpulocs      int
	cpumem       int64
	gpumem       []int64
	g            *ExprGraph         // original dag
	df           *dataflow          // dataflow analysis
	m            map[*Node]fragment // store which nodes create which instructions
	sorted       Nodes
}

func (p *program) String() string {
	var buf bytes.Buffer
	fmt.Fprintf(&buf, "Instructions:\n%s\nArgs: %d | CPU Memories: %d | GPU Memories: %d\nCPU Mem: %v | GPU Mem %v\n\nNode:instructions map:\n", p.instructions, p.args, p.cpulocs, p.gpulocs, p.cpumem, p.gpumem)

	for i, n := range p.sorted {
		fmt.Fprintf(&buf, "\t%d\t%x:", i, n.ID())
		frag := p.m[n]
		for j, instr := range frag {
			if j == 0 {
				fmt.Fprintf(&buf, "\t%v\n", instr)
			} else {
				fmt.Fprintf(&buf, "\t\t%v\n", instr)
			}
		}
	}

	return buf.String()
}

// Graph enables the end user to inspect the graph (typically useful for debugging)
func (p *program) Graph() *ExprGraph { return p.g }

func (p *program) CPUMemReq() int64 { return p.cpumem }

func (p *program) GPUMemReq() []int64 {
	retVal := make([]int64, len(p.gpumem))
	copy(retVal, p.gpumem)
	return retVal
}
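// Editor's note (illustrative sketch): a *program is normally produced by Compile, as in
// NewTapeMachine above. Inspecting one directly, assuming Compile's (prog, locMap, err)
// return values and that the WithPrecompiled option accepts those two artifacts:
//
//	prog, locMap, err := Compile(g)
//	if err != nil {
//		log.Fatal(err)
//	}
//	fmt.Println(prog)             // instructions, register counts and the node:instruction map
//	fmt.Println(prog.CPUMemReq()) // total CPU memory required, in bytes
//	fmt.Println(prog.GPUMemReq()) // per-device GPU memory required, in bytes
//	m := NewTapeMachine(g, WithPrecompiled(prog, locMap))
//	defer m.Close()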
/* REGISTER */

type register struct {
	id     int
	device Device
}

func (r register) String() string { return fmt.Sprintf("%s%d", r.device, r.id) }

/* INSTRUCTIONS */

type tapeInstr interface {
	ID() int64 // ID is the node ID
	reads() []register
	writes() register
	exec(*tapeMachine) error
	fmt.Stringer
}

type fragment []tapeInstr

func (f fragment) String() string {
	var buf bytes.Buffer
	for i, instr := range f {
		fmt.Fprintf(&buf, "\t%d\t%s\n", i, instr)
	}
	return buf.String()
}

func (f fragment) has(want tapeInstr) bool {
	for _, instr := range f {
		if instr == want {
			return true
		}
	}
	return false
}

type alloc struct {
	id int64 // node ID
	t  hm.Type
	s  tensor.Shape

	readFrom []register
	writeTo  register
}

func newAlloc(n *Node, writeTo register) alloc {
	return alloc{
		id:      n.ID(),
		t:       n.t,
		s:       n.shape,
		writeTo: writeTo,
	}
}

func (instr alloc) ID() int64         { return instr.id }
func (instr alloc) reads() []register { return instr.readFrom }
func (instr alloc) writes() register  { return instr.writeTo }

func (instr alloc) exec(m *tapeMachine) (err error) {
	m.logf("Executing %v", instr)
	m.enterLogScope()
	defer m.leaveLogScope()

	var dt tensor.Dtype
	if dt, err = dtypeOf(instr.t); err != nil {
		return errors.Wrapf(err, dtypeExtractionFail, instr.t)
	}

	reg := m.getValue(instr.writeTo)
	if reg != nil && reg.Dtype() == dt && reg.Shape().Eq(instr.s) {
		return nil
	}

	dev := instr.writeTo.device
	var v Value
	switch dev {
	case CPU:
		v, err = makeValue(instr.t, instr.s)
	default:
		var mem tensor.Memory
		memsize := calcMemSize(dt, instr.s)
		if mem, err = m.ExternMetadata.Get(dev, memsize); err != nil {
			return errors.Wrapf(err, "Unable to allocate %v bytes from %v | %T", memsize, dev, err)
		}
		v, err = makeValueFromMem(instr.t, instr.s, mem)
	}
	if err != nil {
		return
	}
	setEngine(v, m.getEngine(dev))
	if vt, ok := v.(tensor.Tensor); ok {
		m.watchedLogf("%x | %T", v.Uintptr(), vt.Engine())
	} else {
		m.watchedLogf("%x", v.Uintptr())
	}

	m.writeValue(instr.writeTo, v)
	return nil
}

func (instr alloc) String() string {
	return fmt.Sprintf("Alloc %v%v\t\t%v", instr.t, instr.s, instr.writeTo)
}

type free struct {
	readsFrom register
}

func (instr free) ID() int64         { return -1 }
func (instr free) reads() []register { return []register{instr.readsFrom} }
func (instr free) writes() register  { return register{-1, CPU} }
func (instr free) exec(m *tapeMachine) error {
	m.logf("Executing Free %v", instr.readsFrom)
	switch instr.readsFrom.device {
	case CPU:
		return nil
	default:
		m.logf("instr.readsFrom is not on CPU - %v %v %d", instr.readsFrom, instr.readsFrom.device == CPU, instr.readsFrom.device)
		mem := m.gpumem[instr.readsFrom.id]
		size := int64(mem.MemSize())

		m.Put(instr.readsFrom.device, mem, size)
		m.gpumem[instr.readsFrom.id] = nil
		return nil
	}
}
func (instr free) String() string { return fmt.Sprintf("Free %v", instr.readsFrom) }

type loadArg struct {
	index   int64
	writeTo register
	name    string
}

func (instr loadArg) ID() int64         { return instr.index }
func (instr loadArg) reads() []register { return nil }
func (instr loadArg) writes() register  { return instr.writeTo }
func (instr loadArg) exec(m *tapeMachine) error {
	m.logf("Executing %v", instr)
	m.enterLogScope()
	defer m.leaveLogScope()

	node := m.p.g.Node(instr.index).(*Node)
	m.logf("node %v", node)

	if node.boundTo == nil {
		return errors.Errorf("No value bound to node %v (%x)", node, node.ID())
	}

	var v Value
	if dv, ok := node.boundTo.(*dualValue); ok {
		v = dv.Value
	} else {
		v = node.boundTo
	}

	m.writeValue(instr.writeTo, v)
	// m.watchedLogf("Write To: %v", instr.writeTo)
	// m.watchedLogf(m.valueFmt, m.cpumem[instr.writeTo.id])
	return nil
}

func (instr loadArg) String() string {
	return fmt.Sprintf("loadArg %x (%v) to %v", instr.index, instr.name, instr.writeTo)
}

type execOp struct {
	op Op

	id int64

	readFrom []register
	writeTo  register
	size     int64 // size represents the output size

	preAllocated bool
	useUnsafe    bool
	useGPU       bool
}

func (instr *execOp) ID() int64         { return instr.id }
func (instr *execOp) reads() []register { return instr.readFrom }
func (instr *execOp) writes() register  { return instr.writeTo }

func newExecOp(n *Node) *execOp {
	_, useGPU := n.op.(CUDADoer)
	compileLogf("op %v uses GPU %v", n.op, useGPU)
	dt, err := dtypeOf(n.t)
	if err != nil {
		panic(err)
	}
	size := calcMemSize(dt, n.Shape())

	return &execOp{
		op:     n.op,
		id:     n.ID(),
		useGPU: useGPU,
		size:   size,
	}
}

func (instr *execOp) String() string {
	return fmt.Sprintf("%v\t%v\t%v\t%t\t%t\t%t", instr.op, instr.readFrom, instr.writeTo, instr.op.CallsExtern(), instr.useUnsafe, instr.preAllocated)
}

// flushInstr is for blastoise and cubone
type flushInstr struct{}

func (instr flushInstr) exec(m *tapeMachine) error {
	m.logf("Executing DoWork")
	return m.ExternMetadata.DoWork()
}

func (instr flushInstr) ID() int64         { return -1 }
func (instr flushInstr) reads() []register { return nil }
func (instr flushInstr) writes() register  { return register{-1, CPU} }
func (instr flushInstr) String() string    { return "DoWork" }

type letInstr struct {
	readFrom register
	writeTo  register
}

func (instr letInstr) ID() int64               { return -1 }
func (instr letInstr) reads() []register       { return []register{instr.readFrom} }
func (instr letInstr) writes() register        { return instr.writeTo }
func (instr letInstr) exec(*tapeMachine) error { return nil }

func (instr letInstr) String() string {
	return fmt.Sprintf("LET %v = %v", instr.writeTo, instr.readFrom)
}

type readInstr struct {
	readFrom register
	into     *Value

	// required to convert tensor.Memory to Value
	t hm.Type
	s tensor.Shape
}

func (instr *readInstr) ID() int64         { return -1 }
func (instr *readInstr) reads() []register { return []register{instr.readFrom} }
func (instr *readInstr) writes() register  { return register{-1, CPU} }
func (instr *readInstr) exec(m *tapeMachine) (err error) {
	m.logf("Executing READ - read from %v into %v", instr.readFrom, instr.into)
	v := m.getValue(instr.readFrom)
	if v == nil {
		return nyi("value of nil", "readInstr.exec")
	}

	if *instr.into != nil {
		dest := *instr.into
		_, err = Copy(dest, v)
		return err
	}

	v2, err := CloneValue(v)
	if err != nil {
		return errors.Wrap(err, cloneFail)
	}

	*instr.into = v2
	return nil
}

func (instr *readInstr) String() string {
	return fmt.Sprintf("Read %v into %p", instr.readFrom, instr.into)
}
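// Editor's note (illustrative sketch): readInstr is what the compiler emits for a Read
// node. On the user-facing side, assuming the package-level Read(n *Node, into *Value)
// helper described in the package documentation:
//
//	var out Value
//	Read(z, &out) // compiled into a readInstr that reads z's register into &out
//	if err := m.RunAll(); err != nil {
//		log.Fatal(err)
//	}
//	fmt.Println(out) // a copy (or clone) of z's value after the run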
type deviceTransport struct {
	from, to register
}

func (instr deviceTransport) ID() int64 { return -1 }
func (instr deviceTransport) reads() []register {
	return []register{instr.from}
}
func (instr deviceTransport) writes() register { return instr.to }

func (instr deviceTransport) String() string {
	return fmt.Sprintf("memcpy(%v, %v)", instr.to, instr.from)
}