// gorgonia.org/gorgonia@v0.9.17/vm_genera.go

package gorgonia

import (
	"bytes"
	"fmt"
	"io/ioutil"
	"log"
	"runtime"
	"strings"

	"github.com/pkg/errors"
	"gorgonia.org/tensor"
)

// lispMachine is a VM that executes an expression graph node by node and,
// when backpropagation is enabled, records and later replays differentiation
// instructions for the nodes it executed.
type lispMachine struct {
	ExternMetadata
	g *ExprGraph
	q []adInstr // a to-do list of differentiation instructions

	// device stuff
	// NOTE(review): cpumem and gpumem are not referenced in the visible part of
	// this file; presumably they are maintained by the initialisation code.
	cpumem int64
	gpumem []int64 // gpumem is indexed by gpuid

	// state stuff, to allow continuation
	// sorted is the execution order (filled in by prepGraph); fwd is the index
	// into sorted of the node currently being executed and bwd is the index into
	// q of the differentiation instruction currently being executed.
	// NOTE(review): df is not referenced in the visible part of this file.
	sorted Nodes
	df     *dataflow
	fwd    int
	bwd    int

	// logging stuff
	// watchlist holds the nodes whose execution watchedLogf reports. buf holds
	// the newline-plus-indentation string that logf substitutes for "\n",
	// valueFmt is the format string used when values are logged, tabcount is the
	// current log nesting depth and logFlags is a bit set (fwdOnly, bwdOnly and
	// watchAll bits).
	watchlist Nodes
	logger    *log.Logger
	buf       *bytes.Buffer
	valueFmt  string
	tabcount  int
	logFlags  byte

	// runFlags is a bit set (fwdOnly, bwdOnly, watchNaN, watchInf, allocVals and
	// spare3 bits). checkedRoots is consulted by checkRoots before validating
	// the roots of the graph.
	runFlags     byte // supposed to go into state stuff.  Placed here for better compacting of struct
	checkedRoots bool // supposed to go into state stuff.
}

// NewLispMachine creates a VM that executes the graph as it is traversed. Depending on the VMOpts passed in
// this VM is also capable of performing automatic differentiation.
//
// It panics if the machine cannot be initialised.
44 func NewLispMachine(g *ExprGraph, opts ...VMOpt) *lispMachine { 45 runFlags := (byte(0) | (byte(1) << fwdOnly)) | (1 << bwdOnly) // run fwd and backwards 46 m := &lispMachine{ 47 g: g, 48 fwd: -1, 49 bwd: -1, 50 valueFmt: "%3.3f", 51 logFlags: 0x0, // log nothing 52 runFlags: runFlags, // run only fwd and bwd 53 } 54 m.Engine = StandardEngine{} 55 56 for _, opt := range opts { 57 opt(m) 58 } 59 if err := m.init(); err != nil { 60 panic(err) 61 } 62 63 for _, n := range g.AllNodes() { 64 setEngine(n.boundTo, m.Engine) 65 } 66 67 runtime.SetFinalizer(m, finalizeLispMachine) 68 return m 69 } 70 71 func (m *lispMachine) logBwd() bool { return (m.logFlags>>bwdOnly)&byte(1) == 1 } 72 func (m *lispMachine) doLogBwd() { m.logFlags |= byte(1) << bwdOnly } 73 func (m *lispMachine) dontLogBwd() { m.logFlags &= (^(byte(1) << bwdOnly)) } 74 func (m *lispMachine) runBwd() bool { return m.runFlags>>bwdOnly&byte(1) == 1 } 75 func (m *lispMachine) doExecBwd() { m.runFlags |= byte(1) << bwdOnly } 76 func (m *lispMachine) dontExecBwd() { m.runFlags &= (^(byte(1) << bwdOnly)) } 77 78 func (m *lispMachine) logFwd() bool { return (m.logFlags>>fwdOnly)&byte(1) == 1 } 79 func (m *lispMachine) doLogFwd() { m.logFlags |= byte(1) << fwdOnly } 80 func (m *lispMachine) dontLogFwd() { m.logFlags &= (^(byte(1) << fwdOnly)) } 81 func (m *lispMachine) runFwd() bool { return m.runFlags>>fwdOnly&byte(1) == 1 } 82 func (m *lispMachine) doExecFwd() { m.runFlags |= byte(1) << fwdOnly } 83 func (m *lispMachine) dontExecFwd() { m.runFlags &= (^(byte(1) << fwdOnly)) } 84 85 func (m *lispMachine) watchNaN() bool { return (m.runFlags>>watchNaN)&byte(1) == 1 } 86 func (m *lispMachine) doWatchNaN() { m.runFlags |= byte(1) << watchNaN } 87 func (m *lispMachine) dontWatchNaN() { m.runFlags &= (^(byte(1) << watchNaN)) } 88 89 func (m *lispMachine) watchInf() bool { return (m.runFlags>>watchInf)&byte(1) == 1 } 90 func (m *lispMachine) doWatchInf() { m.runFlags |= byte(1) << watchInf } 91 func (m *lispMachine) 
dontWatchInf() { m.runFlags &= (^(byte(1) << watchInf)) } 92 93 func (m *lispMachine) watchAll() bool { return (m.logFlags>>watchAll)&byte(1) == 1 } 94 func (m *lispMachine) doWatchAll() { m.logFlags |= (byte(1) << watchAll) } 95 func (m *lispMachine) dontWatchAll() { m.logFlags &= (^(byte(1) << watchAll)) } 96 97 func (m *lispMachine) dealloc() bool { return (m.runFlags>>allocVals)&byte(1) == 1 } 98 func (m *lispMachine) doDealloc() { m.runFlags |= byte(1) << allocVals } 99 func (m *lispMachine) dontDealloc() { m.runFlags &= (^(byte(1) << allocVals)) } 100 101 func (m *lispMachine) setRootGrad() bool { return (m.runFlags>>spare3)&byte(1) == 1 } 102 func (m *lispMachine) allowSetRootGrad() { m.runFlags |= byte(1) << spare3 } 103 func (m *lispMachine) disallowSetRootGrad() { m.runFlags &= (^(byte(1) << spare3)) } 104 105 func (m *lispMachine) Reset() { 106 m.fwd = len(m.sorted) - 1 107 m.bwd = len(m.q) - 1 108 } 109 110 func (m *lispMachine) Close() error { 111 finalizeLispMachine(m) 112 return nil 113 } 114 115 // RunAll traverses a graph and executes every node. Backpropagation is done if necessary 116 func (m *lispMachine) RunAll() (err error) { 117 runtime.LockOSThread() 118 defer runtime.UnlockOSThread() 119 120 if err = m.checkRoots(); err != nil { 121 return errors.Wrap(err, "Could not checkRoots()") 122 } 123 124 if m.runBwd() { 125 defer func() { 126 m.q = nil // this needs to be nil'd or else there would still be references to m. 
Then there won't be any garbage being collected 127 }() 128 } 129 130 workAvailable := m.WorkAvailable() 131 syncChan := m.ExternMetadata.Sync() 132 errChan := make(chan error) 133 doneChan := make(chan struct{}) 134 135 go m.runall(errChan, doneChan) 136 for { 137 select { 138 case synchronous := <-workAvailable: 139 err := m.ExternMetadata.DoWork() 140 if err != nil { 141 var node *Node 142 switch { 143 case synchronous: 144 if m.fwd < len(m.sorted) { 145 node = m.sorted[m.fwd] 146 } else { 147 node = m.sorted[m.fwd-1] 148 } 149 default: 150 if m.fwd-1 > 0 && m.fwd <= len(m.sorted) { 151 node = m.sorted[m.fwd-1] 152 } else { 153 node = m.sorted[0] 154 } 155 } 156 157 err = vmContextualError{ 158 error: errors.Wrapf(err, "DoWork failed"), 159 node: node, 160 instr: m.fwd, 161 } 162 return err 163 } 164 if synchronous { 165 syncChan <- struct{}{} 166 } 167 case err = <-errChan: 168 if m.fwd < len(m.sorted) { 169 err = vmContextualError{ 170 error: errors.Wrapf(err, "Running Node: %v", m.sorted[m.fwd]), 171 node: m.sorted[m.fwd], 172 instr: m.fwd, 173 } 174 return 175 } 176 return errors.Wrap(err, "RunAll") 177 case <-doneChan: 178 err := m.ExternMetadata.DoWork() 179 if err != nil { 180 return err 181 } 182 return nil 183 } 184 } 185 } 186 187 // UnbindAll detaches the values from the node, allowing for them to be cleaned up the next GC cycle. 188 func (m *lispMachine) UnbindAll() { 189 // if m.dealloc() { 190 for _, n := range m.sorted { 191 m.logf("dealloc n; %v %x %p", n, n.Hashcode(), n) 192 if !n.isInput() { 193 n.unbind() 194 } 195 } 196 // } 197 } 198 199 // LastRun returns the nodes and results from the last run. Additionally it returns whether backprop was done. 
200 func (m *lispMachine) LastRun() (n *Node, backprop bool) { 201 if m.fwd < 0 && m.runBwd() { 202 goto backward 203 } else if !m.runBwd() { 204 n = m.sorted[0] // last to run 205 return 206 } else { 207 n = m.sorted[m.fwd] 208 return 209 } 210 211 backward: 212 backprop = true 213 if m.bwd < 0 { 214 n = m.q[0].output 215 return 216 } 217 n = m.q[m.bwd].output 218 return 219 } 220 221 // check roots only applies if you want to run a backprop as well 222 func (m *lispMachine) checkRoots() (err error) { 223 if !m.checkedRoots && m.runBwd() { 224 machineLogf("Checking if provided graph is sensible") 225 m.logf("roots: %v", m.g.Roots()) 226 for _, root := range m.g.Roots() { 227 switch { 228 case m.setRootGrad() && !root.isStmt: 229 // check root's value 230 // if _, ok := root.boundTo.(*dualValue); !ok { 231 // err = errors.Errorf("Expected root %v to have a boundTo of a dualValue", root) 232 // return 233 // } 234 case !m.setRootGrad() && !root.IsScalar() && !root.isStmt: 235 err = errors.Errorf("Expected cost to be a scalar. 
Got %v with shape %v instead", root, root.Shape()) 236 ioutil.WriteFile("err.dot", []byte(root.RestrictedToDot(2, 10)), 0644) 237 return 238 } 239 } 240 } 241 return 242 } 243 244 func (m *lispMachine) prepGraph() (err error) { 245 if m.sorted == nil { 246 if m.sorted, err = Sort(m.g); err != nil { 247 return errors.Wrap(err, sortFail) 248 } 249 reverseNodes(m.sorted) 250 m.fwd = 0 251 } 252 return 253 } 254 255 func (m *lispMachine) runall(errChan chan error, doneChan chan struct{}) { 256 var err error 257 if !m.runFwd() { 258 goto backward 259 } 260 261 for err = nil; err == nil && m.fwd < len(m.sorted); m.fwd++ { 262 err = m.forward() 263 } 264 265 if err != nil { 266 errChan <- err 267 } 268 269 // send a synchronous signal, do all (if any) CUDA work before continuing with backprop 270 m.Signal() 271 272 backward: 273 if !m.runBwd() { 274 doneChan <- struct{}{} 275 return 276 } 277 278 if m.bwd < 0 { 279 m.bwd = len(m.q) - 1 280 } 281 282 for err = nil; err == nil && m.bwd >= 0; m.bwd-- { 283 err = m.backward() 284 } 285 if err != nil { 286 errChan <- err 287 } 288 doneChan <- struct{}{} 289 } 290 291 func (m *lispMachine) forward() (err error) { 292 if m.fwd < 0 { 293 return nil // or err? 
294 } 295 n := m.sorted[m.fwd] 296 297 m.watchedLogf("n: %v | (%x) | %p", n, n.id, n) 298 m.enterLogScope() 299 defer m.leaveLogScope() 300 301 defer setEngine(n.boundTo, m.Engine) 302 303 if !n.isStmt { 304 switch { 305 case n.isArg(): 306 machineLogf("Unit() on input node") 307 if err = n.bind(dvUnit(n.boundTo)); err != nil { 308 return errors.Wrap(err, bindFail) 309 } 310 return 311 case n.isRandom(): 312 machineLogf("binding value of random node") 313 var v Value 314 if v, err = n.op.Do(); err != nil { 315 return errors.Wrapf(err, execFail, n.op, n) 316 } 317 318 // we wrap it in a dualValue, but we make it a constant 319 if err = n.bind(dvUnit(v)); err != nil { 320 return errors.Wrap(err, bindFail) 321 } 322 323 return 324 default: 325 // do nothihng 326 } 327 m.watchedLogf(m.valueFmt, n.boundTo) 328 } 329 330 // other wise it's time to execute the op 331 m.logf("execute Op") 332 dev := n.dataOn 333 op := NewExternalOp(n.op, ExecutionContext{m, dev}, nil) 334 335 // m.watchedLogf("Result of execution of this node would reside in %v", dev) 336 var output *dualValue 337 338 inputs := make([]*dualValue, len(n.children)) 339 children := n.children 340 341 m.enterLogScope() 342 for i, child := range children { 343 m.logf("child %d: %v %v", i, child, child.Shape()) 344 if child.Device() == n.Device() { 345 inputs[i] = child.boundTo.(*dualValue) 346 // continue 347 } 348 349 var allocV, allocD bool 350 var v, d Value 351 if v, allocV, err = child.ValueOnDevice(dev, m); err != nil { 352 return errors.Wrapf(err, "Unable to get Value on Device %v", dev) 353 } 354 if d, allocD, err = child.GradOnDevice(dev, m); err != nil { 355 if !child.isRandom() { 356 return errors.Wrapf(err, "Unable to get Grad on Device %v", dev) 357 } 358 err = nil 359 } 360 361 dv := borrowDV() 362 363 dv.Value = v 364 dv.d = d 365 inputs[i] = dv 366 367 defer func() { 368 if allocV { 369 m.logf("Putting 0x%x |%T", v.Uintptr(), v) 370 m.PutValue(dev, v) 371 } 372 if allocD { 373 m.PutValue(dev, d) 
374 } 375 if allocV && allocD { 376 returnDV(dv) 377 } 378 }() 379 } 380 m.leaveLogScope() 381 m.watchedLogf("Before:") 382 m.watchedLogf(m.valueFmt, n.boundTo) 383 384 switch { 385 case (m.g.roots.Contains(n) || n.isRoot()) && !n.isStmt: 386 machineLogf("Applying op %v to root", op) 387 if n.boundTo == nil { 388 machineLogf("dvBindVar") 389 m.logf("dvBindVar") 390 if output, err = dvBindVar(op, inputs); err != nil { 391 return errors.Wrap(err, "Failed to bindVar") 392 } 393 if err = n.bind(output); err != nil { 394 return errors.Wrap(err, bindFail) 395 } 396 } else { 397 machineLogf("dvBindVar0") 398 m.logf("dvBindVar0") 399 dv, ok := n.boundTo.(*dualValue) 400 if !ok { 401 dv = dvUnitVar(n.boundTo) 402 n.boundTo = dv 403 // panic(fmt.Sprintf("n not dual value %v", n)) 404 } 405 if err = dvBindVar0(op, dv, inputs); err != nil { 406 return errors.Wrapf(err, execFail, op, n) 407 } 408 } 409 410 case n.isStmt: 411 switch ot := n.op.(type) { 412 case readOp: 413 machineLogf("ReadOp: %v ", op) 414 child := children[0] 415 childVal := child.boundTo 416 if child.Device() != CPU { 417 m.Signal() // get work to be done first 418 419 if dv, ok := n.children[0].boundTo.(*dualValue); ok { 420 *ot.into = dv.Value 421 } else { 422 *ot.into = childVal 423 } 424 425 } else { 426 if dv, ok := childVal.(*dualValue); ok { 427 *ot.into = dv.Value 428 } else { 429 *ot.into = childVal 430 } 431 } 432 } 433 434 case n.boundTo == nil: 435 m.watchedLogf("Fresh, unencountered node, so dvBind(%v)", op) 436 if dev != CPU { 437 var dt tensor.Dtype 438 if dt, err = dtypeOf(n.t); err != nil { 439 return errors.Wrapf(err, dtypeExtractionFail, n.t) 440 } 441 442 var mem tensor.Memory 443 memsize := calcMemSize(dt, n.shape) 444 if mem, err = m.Get(dev, memsize); err != nil { 445 return errors.Wrapf(err, allocFail, memsize, dev) 446 } 447 448 var reuse Value 449 if reuse, err = makeValueFromMem(n.t, n.shape, mem); err != nil { 450 return errors.Wrapf(err, makeValueFail, n.t, n.shape) 451 } 452 453 
op.Prealloc = reuse 454 } 455 456 if output, err = dvBind(op, inputs); err != nil { 457 return errors.Wrapf(err, execFail, op, n) 458 } 459 460 if err = n.bind(output); err != nil { 461 return errors.Wrap(err, bindFail) 462 } 463 464 default: 465 m.logf("bind(%v) with as much reuse as possible", op) 466 // reuse as much as possible 467 output := dvUnit(n.boundTo) 468 if err = n.bind(output); err != nil { 469 return errors.Wrap(err, bindFail) 470 } 471 472 if dev != CPU { 473 op.Prealloc = output.Value 474 } 475 476 err = dvBind0(op, output, inputs) 477 if _, ok := errors.Cause(err).(AutoDiffError); ok { 478 err = nil 479 } else if err != nil { 480 return errors.Wrapf(err, execFail, op, n) 481 } 482 } 483 m.watchedLogf("After:") 484 m.watchedLogf(m.valueFmt, n.boundTo) 485 486 if aop, ok := op.Op.(ADOp); ok && m.runBwd() { 487 instr := adInstr{ 488 ADOp: aop, 489 ctx: op.ExecutionContext, 490 491 inputs: n.children, // this is correct. 492 output: n, 493 } 494 m.q = append(m.q, instr) 495 } 496 m.watchedLogf("Added to Queue") 497 498 if m.watchNaN() && !n.isStmt { 499 if hasNaN(n.boundTo, dev) { 500 return errors.New("NaN found in value") 501 } 502 } 503 504 return 505 } 506 507 func (m *lispMachine) backward() (err error) { 508 if m.bwd < 0 { 509 return errors.New("no backprop queue") 510 } 511 if m.bwd >= len(m.q) { 512 return errors.New("Nothing to backprop") 513 } 514 515 instr := m.q[m.bwd] 516 m.watchedLogf("Differentiating op %v. 
Output: %v (%x)", instr, instr.output, instr.output.Hashcode()) 517 m.enterLogScope() 518 defer m.leaveLogScope() 519 520 m.watchedLogf("Inputs: %v", instr.inputs) 521 m.enterLogScope() 522 for _, in := range instr.inputs { 523 m.watchedLogf(m.valueFmt, in.boundTo.(*dualValue).d) 524 } 525 m.leaveLogScope() 526 527 // actual differentiation 528 if err = instr.do(); err != nil { 529 return errors.Wrapf(err, autodiffFail, instr.ADOp) 530 } 531 532 // Make sure that all the engines of all the values are set to use the correct engine 533 for _, in := range instr.inputs { 534 setEngine(in.boundTo, m.Engine) 535 } 536 setEngine(instr.output.boundTo, m.Engine) 537 538 m.watchedLogf("After:") 539 m.enterLogScope() 540 for _, in := range instr.inputs { 541 m.watchedLogf(m.valueFmt, in.boundTo.(*dualValue).d) 542 } 543 544 m.leaveLogScope() 545 546 if m.watchNaN() { 547 if hasNaN(instr.output.boundTo, instr.ctx.Device) { 548 return errors.New("NaN found in value") 549 } 550 551 for _, in := range instr.inputs { 552 if hasNaN(in.boundTo, instr.ctx.Device) { 553 return errors.New("NaN found in value") 554 } 555 } 556 } 557 return 558 } 559 560 func (m *lispMachine) watchedLogf(format string, attrs ...interface{}) { 561 if !m.logFwd() && !DEBUG { 562 goto backwards 563 } 564 565 if m.fwd >= 0 && m.fwd < len(m.sorted) { 566 n := m.sorted[m.fwd] 567 if m.watchlist.Contains(n) || m.watchAll() { 568 m.logf(format, attrs...) 569 } 570 return 571 } 572 573 backwards: 574 if !m.logBwd() && !DEBUG { 575 return 576 } 577 578 if m.bwd >= 0 { 579 instr := m.q[m.bwd] 580 write := m.watchlist.Contains(instr.output) 581 if !write { 582 for _, in := range instr.inputs { 583 if m.watchlist.Contains(in) { 584 write = true 585 break 586 } 587 } 588 } 589 590 if write || m.watchAll() || DEBUG { 591 m.logf(format, attrs...) 
592 } 593 } 594 } 595 596 func (m *lispMachine) logf(format string, attrs ...interface{}) { 597 switch { 598 case machineDev, autodiffDev: 599 if machineDev { 600 601 machineLogf(format, attrs...) 602 } else { 603 autodiffLogf(format, attrs...) 604 } 605 606 if m.logger != nil { 607 goto loggercase 608 } 609 610 break 611 612 loggercase: 613 fallthrough 614 case m.logger != nil: 615 s := fmt.Sprintf(format, attrs...) 616 s = strings.Replace(s, "\n", m.buf.String(), -1) 617 m.logger.Println(s) 618 } 619 } 620 621 func (m *lispMachine) enterLogScope() { 622 if DEBUG && machineDev { 623 enterLogScope() 624 } 625 m.tabcount++ 626 if m.logger != nil { 627 reps := strings.Repeat("\t", m.tabcount) 628 m.logger.SetPrefix(reps) 629 m.buf.Reset() 630 m.buf.WriteString("\n") 631 m.buf.WriteString(reps) 632 } 633 } 634 635 func (m *lispMachine) leaveLogScope() { 636 if DEBUG && machineDev { 637 leaveLogScope() 638 } 639 m.tabcount-- 640 if m.tabcount < 0 { 641 m.tabcount = 0 642 } 643 if m.logger != nil { 644 reps := strings.Repeat("\t", m.tabcount) 645 m.logger.SetPrefix(reps) 646 m.buf.Reset() 647 m.buf.WriteString("\n") 648 m.buf.WriteString(reps) 649 } 650 } 651 652 // adInstr is an autodifferentiation instruction 653 type adInstr struct { 654 ADOp 655 ctx ExecutionContext 656 657 inputs Nodes 658 output *Node 659 } 660 661 func (instr adInstr) do() error { 662 if instr.output.dataOn != CPU { 663 for _, in := range instr.inputs { 664 if in.dataOn == CPU { 665 // ensure everything gets executed in the GPU first 666 instr.ctx.Signal() 667 break 668 } 669 } 670 } 671 err := instr.ADOp.DoDiff(instr.ctx, instr.inputs, instr.output) 672 // logf("INPUTS:") 673 // for _, in := range instr.inputs { 674 // logf("%v\n", in.boundTo.(*dualValue).d) 675 // } 676 return err 677 }