gorgonia.org/gorgonia@v0.9.17/vm_tape.go (about)

     1  package gorgonia
     2  
     3  import (
     4  	"bytes"
     5  	"fmt"
     6  	"log"
     7  	"runtime"
     8  	"strings"
     9  
    10  	"github.com/chewxy/hm"
    11  	"github.com/pkg/errors"
    12  	"gorgonia.org/tensor"
    13  )
    14  
// tapeMachine is a VM that executes a compiled program: a linear tape of
// instructions that read and write values held in CPU and GPU "register"
// banks. It is constructed by NewTapeMachine.
type tapeMachine struct {
	// Embedded extern state. The machine uses its Engine, its batched BLAS
	// handle (b) and its work/sync machinery (DoWork, WorkAvailable, Sync).
	ExternMetadata

	p      *program           // compiled program being executed
	locMap map[*Node]register // register that holds each node's result

	// "register" banks
	cpumem []Value // Value - knows its own type and shape
	gpumem []Value // Value of which the memories are stored in GPU memory

	// state stuff, to allow continuation
	pc int // index into p.instructions of the next instruction RunAll executes; Reset sets it to 0

	// operational stuff
	bindNodesDV Nodes // nodes that require binding of DV
	watchNodes  Nodes
	watchRegs   []register    // registers whose use makes watchedLogf emit output
	logger      *log.Logger   // optional logger used by logf
	buf         *bytes.Buffer // newline + current indentation; logf substitutes it for "\n"
	valueFmt    string        // format verb for values; NewTapeMachine defaults it to "%3.3g"
	tabcount    int           // current log indentation depth
	logFlags    byte          // bit flags: fwdOnly, bwdOnly, watchAll

	// bit flags: watchNaN, watchInf, allocVals, spare2 (trace), spare3 (bindDV)
	runFlags byte //  spare2: trace(copy values and put into nodes)
}
    40  
    41  // NewTapeMachine creates a VM that compiles a graph into a prog.
    42  func NewTapeMachine(g *ExprGraph, opts ...VMOpt) *tapeMachine {
    43  	m := &tapeMachine{
    44  		valueFmt: "%3.3g",
    45  	}
    46  	m.Engine = StandardEngine{}
    47  
    48  	if b, ok := whichblas.(batchedBLAS); ok {
    49  		m.b = b
    50  	}
    51  
    52  	for _, opt := range opts {
    53  		opt(m)
    54  	}
    55  
    56  	m.doAlloc()
    57  
    58  	if m.p == nil || m.locMap == nil {
    59  		prog, locMap, err := Compile(g)
    60  		if err != nil {
    61  			panic(err)
    62  		}
    63  
    64  		m.p = prog
    65  		m.locMap = locMap
    66  	}
    67  	m.cpumem = make([]Value, m.p.cpulocs)
    68  	m.gpumem = make([]Value, m.p.gpulocs)
    69  	m.init()
    70  	for _, n := range m.p.g.AllNodes() {
    71  		setEngine(n.boundTo, m.Engine)
    72  	}
    73  
    74  	runtime.SetFinalizer(m, finalizeTapeMachine) // a "defer" to deinitialize CUDA stuff (if using CUDA build)
    75  	return m
    76  }
    77  
    78  func (m *tapeMachine) logBwd() bool { return (m.logFlags>>bwdOnly)&byte(1) == 1 }
    79  func (m *tapeMachine) doLogBwd()    { m.logFlags |= byte(1) << bwdOnly }
    80  func (m *tapeMachine) dontLogBwd()  { m.logFlags &= (^(byte(1) << bwdOnly)) }
    81  
    82  func (m *tapeMachine) logFwd() bool { return (m.logFlags>>fwdOnly)&byte(1) == 1 }
    83  func (m *tapeMachine) doLogFwd()    { m.logFlags |= byte(1) << fwdOnly }
    84  func (m *tapeMachine) dontLogFwd()  { m.logFlags &= (^(byte(1) << fwdOnly)) }
    85  
    86  func (m *tapeMachine) watchNaN() bool { return (m.runFlags>>watchNaN)&byte(1) == 1 }
    87  func (m *tapeMachine) doWatchNaN()    { m.runFlags |= byte(1) << watchNaN }
    88  func (m *tapeMachine) dontWatchNaN()  { m.runFlags &= (^(byte(1) << watchNaN)) }
    89  
    90  func (m *tapeMachine) watchInf() bool { return (m.runFlags>>watchInf)&byte(1) == 1 }
    91  func (m *tapeMachine) doWatchInf()    { m.runFlags |= byte(1) << watchInf }
    92  func (m *tapeMachine) dontWatchInf()  { m.runFlags &= (^(byte(1) << watchInf)) }
    93  
    94  func (m *tapeMachine) watchAll() bool { return (m.logFlags>>watchAll)&byte(1) == 1 }
    95  func (m *tapeMachine) doWatchAll()    { m.logFlags |= (byte(1) << watchAll) }
    96  func (m *tapeMachine) dontWatchAll()  { m.logFlags &= (^(byte(1) << watchAll)) }
    97  
    98  func (m *tapeMachine) alloc() bool { return (m.runFlags>>allocVals)&byte(1) == 1 }
    99  func (m *tapeMachine) doAlloc()    { m.runFlags |= byte(1) << allocVals }
   100  func (m *tapeMachine) dontAlloc()  { m.runFlags &= (^(byte(1) << allocVals)) }
   101  
   102  func (m *tapeMachine) trace() bool { return (m.runFlags>>spare2)&byte(1) == 1 }
   103  func (m *tapeMachine) doTrace()    { m.runFlags |= byte(1) << spare2 }
   104  func (m *tapeMachine) dontTrace()  { m.runFlags &= (^(byte(1) << spare2)) }
   105  
   106  func (m *tapeMachine) bindDV() bool { return m.runFlags>>spare3&byte(1) == 1 }
   107  func (m *tapeMachine) doBindDV()    { m.runFlags |= byte(1) << spare3 }
   108  func (m *tapeMachine) dontBindDV()  { m.runFlags &= (^(byte(1) << spare3)) }
   109  
   110  // Reset resets the run state of the machine by changing the instruction pointer back to 0
   111  // and reseting the registry
   112  func (m *tapeMachine) Reset() {
   113  	m.pc = 0
   114  	m.ExternMetadata.Reset()
   115  
   116  	for i := range m.gpumem {
   117  		returnValue(m.gpumem[i])
   118  		m.gpumem[i] = nil //
   119  	}
   120  	for i := range m.cpumem {
   121  		m.cpumem[i] = nil
   122  	}
   123  }
   124  
   125  func (m *tapeMachine) Close() error {
   126  	finalizeTapeMachine(m)
   127  	return nil
   128  }
   129  
   130  // Prog returns the compiled program. This would mainly be used in debugging functions
   131  func (m *tapeMachine) Prog() *program { return m.p }
   132  
   133  // LocMap returns the location where the Node's execution results are stored. This would mainly be used in debugging functions.
   134  func (m *tapeMachine) LocMap() map[*Node]register { return m.locMap }
   135  
   136  // Let wraps the Let() function of the package, with additional checks that n is in the machine
   137  func (m *tapeMachine) Let(n *Node, be interface{}) (err error) {
   138  	if !m.p.g.Has(n.ID()) {
   139  		return errors.Errorf("Node %v does not exist in this graph", n)
   140  	}
   141  
   142  	return Let(n, be)
   143  }
   144  
   145  // Set wraps the Set() function of this package, with additional checks that both a and b are in the machine
   146  func (m *tapeMachine) Set(a, b *Node) (err error) {
   147  	if !m.p.g.Has(a.ID()) {
   148  		return errors.Errorf("Node %v does not exist in this graph", a)
   149  	}
   150  	if !m.p.g.Has(b.ID()) {
   151  		return errors.Errorf("Node %v does not exist in this graph", b)
   152  	}
   153  
   154  	if b.Value() != nil {
   155  		return a.bind(b.Value())
   156  	}
   157  
   158  	// get the registry location
   159  	breg := m.locMap[b]
   160  	v := m.getValue(breg)
   161  	if v == nil {
   162  		return nyi("handling of tensor.Memory -> Value", "tapeMachine.Set")
   163  	}
   164  
   165  	machineLogf("Setting %v to %v. Read from %v Value is %v", b, a, breg, v)
   166  	return a.bind(v)
   167  }
   168  
   169  // Run runs a fragment (a subset of a program).
   170  func (m *tapeMachine) Run(frag fragment) (err error) {
   171  	defer func() {
   172  		if err == nil {
   173  			m.dontAlloc()
   174  		}
   175  	}()
   176  
   177  	for _, instr := range frag {
   178  		if err = instr.exec(m); err != nil {
   179  			return errors.Wrap(err, "Failed to carry exec()")
   180  		}
   181  	}
   182  	machineLogf("Binding values based on final output")
   183  	enterLogScope()
   184  	for n, r := range m.locMap {
   185  		if n.isInput() {
   186  			continue
   187  		}
   188  
   189  		v := m.getValue(r)
   190  		if v == nil {
   191  			return nyi("converting tensor.Memory to Value", "TapeMachine.Run")
   192  		}
   193  
   194  		if err = n.bind(m.cpumem[r.id]); err != nil {
   195  			return errors.Wrap(err, bindFail)
   196  		}
   197  	}
   198  	leaveLogScope()
   199  	return
   200  }
   201  
   202  func (m *tapeMachine) RunAll() (err error) {
   203  	runtime.LockOSThread()
   204  	defer runtime.UnlockOSThread()
   205  	defer m.DoWork()
   206  
   207  	workAvailable := m.ExternMetadata.WorkAvailable()
   208  	syncChan := m.ExternMetadata.Sync()
   209  	errChan := make(chan error)
   210  	doneChan := make(chan struct{})
   211  
   212  	go m.runall(errChan, doneChan)
   213  	for {
   214  		select {
   215  		case sychronous := <-workAvailable:
   216  			err := m.ExternMetadata.DoWork()
   217  			if err != nil {
   218  				return err
   219  			}
   220  			if sychronous {
   221  				syncChan <- struct{}{}
   222  			}
   223  		case err := <-errChan:
   224  			return errors.Wrapf(err, "PC: %d", m.pc)
   225  		case <-doneChan:
   226  			err := m.ExternMetadata.DoWork()
   227  			if err != nil {
   228  				return err
   229  			}
   230  			return nil
   231  		}
   232  	}
   233  }
   234  
   235  func (m *tapeMachine) runall(errChan chan error, doneChan chan struct{}) {
   236  	for ; m.pc < len(m.p.instructions); m.pc++ {
   237  		instr := m.p.instructions[m.pc]
   238  		m.logf("PC %d", m.pc)
   239  		if err := instr.exec(m); err != nil {
   240  			err = errors.Wrapf(err, "PC %d. Failed to execute instruction %v", m.pc, instr)
   241  			errChan <- err
   242  			return
   243  		}
   244  		// only proceed to check NaNs and Infs for execOp
   245  		if _, ok := instr.(*execOp); !ok {
   246  			continue
   247  		}
   248  
   249  		if m.watchNaN() {
   250  			writeTo := instr.writes().id
   251  			id := instr.ID()
   252  			if writeTo > 0 && id > 0 {
   253  				v := m.getValue(instr.writes())
   254  				if v == nil {
   255  					err := errors.Errorf(nyiFail, "converting tensor.Memory to Value", "watchNaN")
   256  					errChan <- err
   257  					return
   258  				}
   259  
   260  				if hasNaN(v, CPU) {
   261  					n := m.p.g.Node(id).(*Node)
   262  					err := errors.Errorf("NaN found in value. Node: %v(%x)", n, n.ID())
   263  					errChan <- err
   264  					return
   265  				}
   266  			}
   267  		}
   268  
   269  		if m.watchInf() {
   270  			writeTo := instr.writes().id
   271  			id := instr.ID()
   272  			if writeTo > 0 && id > 0 {
   273  				v := m.getValue(instr.writes())
   274  				if v == nil {
   275  					err := errors.Errorf(nyiFail, "converting tensor.Memory to Value", "watchInf")
   276  					errChan <- err
   277  					return
   278  				}
   279  
   280  				if hasInf(v, CPU) {
   281  					n := m.p.g.Node(id).(*Node)
   282  					err := errors.Errorf("Inf found in value. Node: %v(%x)", n, n.ID())
   283  					errChan <- err
   284  					return
   285  				}
   286  			}
   287  		}
   288  	}
   289  	doneChan <- struct{}{}
   290  }
   291  
   292  func (m *tapeMachine) getValue(r register) Value {
   293  	switch r.device {
   294  	case CPU:
   295  		return m.cpumem[r.id]
   296  	default:
   297  		return m.gpumem[r.id]
   298  	}
   299  }
   300  
   301  func (m *tapeMachine) writeValue(r register, v Value) {
   302  	switch r.device {
   303  	case CPU:
   304  		m.cpumem[r.id] = v
   305  	default:
   306  		m.gpumem[r.id] = v
   307  	}
   308  }
   309  
   310  func (m *tapeMachine) watchedLogf(format string, attrs ...interface{}) {
   311  	instr := m.p.instructions[m.pc]
   312  	reads := instr.reads()
   313  	writes := instr.writes()
   314  
   315  	watched := m.watchAll()
   316  
   317  	if !watched {
   318  		for _, reg := range reads {
   319  			for _, watch := range m.watchRegs {
   320  				if reg.id == watch.id {
   321  					watched = true
   322  					break
   323  				}
   324  			}
   325  		}
   326  	}
   327  
   328  	if !watched {
   329  		for _, watch := range m.watchRegs {
   330  			if watch.id == writes.id {
   331  				watched = true
   332  				break
   333  			}
   334  		}
   335  	}
   336  
   337  	// TODO: Work on watched nodes
   338  	if !watched {
   339  
   340  	}
   341  
   342  	if watched {
   343  		m.logf(format, attrs...)
   344  	}
   345  }
   346  
   347  func (m *tapeMachine) logf(format string, attrs ...interface{}) {
   348  	switch {
   349  	case machineDev:
   350  		if m.logger != nil {
   351  			goto loggercase
   352  		}
   353  
   354  		machineLogf(format, attrs...)
   355  		break
   356  
   357  	loggercase:
   358  		fallthrough
   359  	case m.logger != nil:
   360  		s := fmt.Sprintf(format, attrs...)
   361  		s = strings.Replace(s, "\n", m.buf.String(), -1)
   362  		m.logger.Println(s)
   363  	}
   364  }
   365  
   366  func (m *tapeMachine) enterLogScope() {
   367  	if DEBUG && machineDev {
   368  		enterLogScope()
   369  	}
   370  	m.tabcount++
   371  	if m.logger != nil {
   372  		reps := strings.Repeat("\t", m.tabcount)
   373  		m.logger.SetPrefix(reps)
   374  		m.buf.Reset()
   375  		m.buf.WriteString("\n")
   376  		m.buf.WriteString(reps)
   377  	}
   378  }
   379  
   380  func (m *tapeMachine) leaveLogScope() {
   381  	if DEBUG && machineDev {
   382  		leaveLogScope()
   383  	}
   384  	m.tabcount--
   385  	if m.tabcount < 0 {
   386  		m.tabcount = 0
   387  	}
   388  	if m.logger != nil {
   389  		reps := strings.Repeat("\t", m.tabcount)
   390  		m.logger.SetPrefix(reps)
   391  		m.buf.Reset()
   392  		m.buf.WriteString("\n")
   393  		m.buf.WriteString(reps)
   394  	}
   395  }
   396  
   397  /* PROGRAM */
   398  
   399  type program struct {
   400  	instructions fragment
   401  	args         int
   402  	cpulocs      int
   403  	gpulocs      int
   404  	cpumem       int64
   405  	gpumem       []int64
   406  	g            *ExprGraph         // original dag
   407  	df           *dataflow          // dataflow analysis
   408  	m            map[*Node]fragment // store which nodes create which instructions
   409  	sorted       Nodes
   410  }
   411  
   412  func (p *program) String() string {
   413  	var buf bytes.Buffer
   414  	fmt.Fprintf(&buf, "Instructions:\n%s\nArgs: %d | CPU Memories: %d | GPU Memories: %d\nCPU Mem: %v | GPU Mem %v\n\nNode:instructions map:\n", p.instructions, p.args, p.cpulocs, p.gpulocs, p.cpumem, p.gpumem)
   415  
   416  	for i, n := range p.sorted {
   417  		fmt.Fprintf(&buf, "\t%d\t%x:", i, n.ID())
   418  		frag := p.m[n]
   419  		for j, instr := range frag {
   420  			if j == 0 {
   421  				fmt.Fprintf(&buf, "\t%v\n", instr)
   422  			} else {
   423  				fmt.Fprintf(&buf, "\t\t%v\n", instr)
   424  			}
   425  		}
   426  
   427  	}
   428  
   429  	return buf.String()
   430  }
   431  
   432  // Graph enables the end user to inspect the graph (typically useful for debugging)
   433  func (p *program) Graph() *ExprGraph { return p.g }
   434  
   435  func (p *program) CPUMemReq() int64 { return p.cpumem }
   436  
   437  func (p *program) GPUMemReq() []int64 {
   438  	retVal := make([]int64, len(p.gpumem))
   439  	copy(retVal, p.gpumem)
   440  	return retVal
   441  }
   442  
   443  /* REGISTER */
   444  
// register names one slot in a tapeMachine's register banks. The id of a CPU
// register indexes cpumem; the id of a register on any other device indexes
// gpumem (see tapeMachine.getValue). Instructions that write nothing report
// register{-1, CPU} from writes().
type register struct {
	id     int
	device Device
}

// String formats the register as its device immediately followed by its id.
func (r register) String() string { return fmt.Sprintf("%s%d", r.device, r.id) }
   451  
   452  /* INSTRUCTIONS */
   453  
// tapeInstr is a single instruction on a tapeMachine's tape.
type tapeInstr interface {
	ID() int64               // ID is the node ID; -1 for instructions not tied to a node
	reads() []register       // registers the instruction reads
	writes() register        // register the instruction writes; id -1 when it writes nothing
	exec(*tapeMachine) error // runs the instruction against the machine
	fmt.Stringer
}
   461  
   462  type fragment []tapeInstr
   463  
   464  func (f fragment) String() string {
   465  	var buf bytes.Buffer
   466  	for i, instr := range f {
   467  		fmt.Fprintf(&buf, "\t%d\t%s\n", i, instr)
   468  	}
   469  	return buf.String()
   470  }
   471  
   472  func (f fragment) has(want tapeInstr) bool {
   473  	for _, instr := range f {
   474  		if instr == want {
   475  			return true
   476  		}
   477  	}
   478  	return false
   479  }
   480  
   481  type alloc struct {
   482  	id int64 // node ID
   483  	t  hm.Type
   484  	s  tensor.Shape
   485  
   486  	readFrom []register
   487  	writeTo  register
   488  }
   489  
   490  func newAlloc(n *Node, writeTo register) alloc {
   491  	return alloc{
   492  		id:      n.ID(),
   493  		t:       n.t,
   494  		s:       n.shape,
   495  		writeTo: writeTo,
   496  	}
   497  }
   498  
   499  func (instr alloc) ID() int64         { return instr.id }
   500  func (instr alloc) reads() []register { return instr.readFrom }
   501  func (instr alloc) writes() register  { return instr.writeTo }
   502  
   503  func (instr alloc) exec(m *tapeMachine) (err error) {
   504  	m.logf("Executing %v", instr)
   505  	m.enterLogScope()
   506  	defer m.leaveLogScope()
   507  
   508  	var dt tensor.Dtype
   509  	if dt, err = dtypeOf(instr.t); err != nil {
   510  		return errors.Wrapf(err, dtypeExtractionFail, instr.t)
   511  	}
   512  
   513  	reg := m.getValue(instr.writeTo)
   514  	if reg != nil && reg.Dtype() == dt && reg.Shape().Eq(instr.s) {
   515  		return nil
   516  	}
   517  
   518  	dev := instr.writeTo.device
   519  	var v Value
   520  	switch dev {
   521  	case CPU:
   522  
   523  		v, err = makeValue(instr.t, instr.s)
   524  
   525  	default:
   526  		var mem tensor.Memory
   527  		memsize := calcMemSize(dt, instr.s)
   528  		if mem, err = m.ExternMetadata.Get(dev, memsize); err != nil {
   529  			return errors.Wrapf(err, "Unable to allocate %v bytes from %v | %T", memsize, dev, err)
   530  		}
   531  		v, err = makeValueFromMem(instr.t, instr.s, mem)
   532  	}
   533  	if err != nil {
   534  		return
   535  	}
   536  	setEngine(v, m.getEngine(dev))
   537  	if vt, ok := v.(tensor.Tensor); ok {
   538  		m.watchedLogf("%x | %T", v.Uintptr(), vt.Engine())
   539  	} else {
   540  		m.watchedLogf("%x", v.Uintptr())
   541  	}
   542  
   543  	m.writeValue(instr.writeTo, v)
   544  	return nil
   545  }
   546  
   547  func (instr alloc) String() string {
   548  	return fmt.Sprintf("Alloc %v%v\t\t%v", instr.t, instr.s, instr.writeTo)
   549  }
   550  
   551  type free struct {
   552  	readsFrom register
   553  }
   554  
   555  func (instr free) ID() int64         { return -1 }
   556  func (instr free) reads() []register { return []register{instr.readsFrom} }
   557  func (instr free) writes() register  { return register{-1, CPU} }
   558  func (instr free) exec(m *tapeMachine) error {
   559  	m.logf("Executing Free %v", instr.readsFrom)
   560  	switch instr.readsFrom.device {
   561  	case CPU:
   562  		return nil
   563  	default:
   564  		m.logf("instr.read from not CPU - %v %v %d", instr.readsFrom, instr.readsFrom.device == CPU, instr.readsFrom.device)
   565  		mem := m.gpumem[instr.readsFrom.id]
   566  		size := int64(mem.MemSize())
   567  
   568  		m.Put(instr.readsFrom.device, mem, size)
   569  		m.gpumem[instr.readsFrom.id] = nil
   570  		return nil
   571  	}
   572  }
   573  func (instr free) String() string { return fmt.Sprintf("Free %v", instr.readsFrom) }
   574  
   575  type loadArg struct {
   576  	index   int64
   577  	writeTo register
   578  	name    string
   579  }
   580  
   581  func (instr loadArg) ID() int64         { return instr.index }
   582  func (instr loadArg) reads() []register { return nil }
   583  func (instr loadArg) writes() register  { return instr.writeTo }
   584  
   585  func (instr loadArg) exec(m *tapeMachine) error {
   586  	m.logf("Executing %v", instr)
   587  	m.enterLogScope()
   588  	defer m.leaveLogScope()
   589  
   590  	node := m.p.g.Node(instr.index).(*Node)
   591  	m.logf("node %v", node)
   592  
   593  	if node.boundTo == nil {
   594  		return errors.Errorf("No value bound to node %v (%x)", node, node.ID())
   595  	}
   596  
   597  	var v Value
   598  	if dv, ok := node.boundTo.(*dualValue); ok {
   599  		v = dv.Value
   600  	} else {
   601  		v = node.boundTo
   602  	}
   603  
   604  	m.writeValue(instr.writeTo, v)
   605  	// m.watchedLogf("Write To: %v", instr.writeTo)
   606  	// m.watchedLogf(m.valueFmt, m.cpumem[instr.writeTo.id])
   607  	return nil
   608  }
   609  
   610  func (instr loadArg) String() string {
   611  	return fmt.Sprintf("loadArg %x (%v) to %v", instr.index, instr.name, instr.writeTo)
   612  }
   613  
   614  type execOp struct {
   615  	op Op
   616  
   617  	id int64
   618  
   619  	readFrom []register
   620  	writeTo  register
   621  	size     int64 // size represents the outputsize
   622  
   623  	preAllocated bool
   624  	useUnsafe    bool
   625  	useGPU       bool
   626  }
   627  
   628  func (instr *execOp) ID() int64         { return instr.id }
   629  func (instr *execOp) reads() []register { return instr.readFrom }
   630  func (instr *execOp) writes() register  { return instr.writeTo }
   631  
   632  func newExecOp(n *Node) *execOp {
   633  	_, useGPU := n.op.(CUDADoer)
   634  	compileLogf("op %v uses GPU %v", n.op, useGPU)
   635  	dt, err := dtypeOf(n.t)
   636  	if err != nil {
   637  		panic(err)
   638  	}
   639  	size := calcMemSize(dt, n.Shape())
   640  
   641  	return &execOp{
   642  		op:     n.op,
   643  		id:     n.ID(),
   644  		useGPU: useGPU,
   645  		size:   size,
   646  	}
   647  }
   648  
   649  func (instr *execOp) String() string {
   650  	return fmt.Sprintf("%v\t%v\t%v\t%t\t%t\t%t", instr.op, instr.readFrom, instr.writeTo, instr.op.CallsExtern(), instr.useUnsafe, instr.preAllocated)
   651  }
   652  
   653  // flushInstr is for blastoise and cubone
   654  type flushInstr struct{}
   655  
   656  func (instr flushInstr) exec(m *tapeMachine) error {
   657  	m.logf("Executing DoWork")
   658  	return m.ExternMetadata.DoWork()
   659  }
   660  
   661  func (instr flushInstr) ID() int64         { return -1 }
   662  func (instr flushInstr) reads() []register { return nil }
   663  func (instr flushInstr) writes() register  { return register{-1, CPU} }
   664  func (instr flushInstr) String() string    { return "DoWork" }
   665  
   666  type letInstr struct {
   667  	readFrom register
   668  	writeTo  register
   669  }
   670  
   671  func (instr letInstr) ID() int64               { return -1 }
   672  func (instr letInstr) reads() []register       { return []register{instr.readFrom} }
   673  func (instr letInstr) writes() register        { return instr.writeTo }
   674  func (instr letInstr) exec(*tapeMachine) error { return nil }
   675  
   676  func (instr letInstr) String() string {
   677  	return fmt.Sprintf("LET %v = %v", instr.writeTo, instr.readFrom)
   678  }
   679  
   680  type readInstr struct {
   681  	readFrom register
   682  	into     *Value
   683  
   684  	// required to convert tensor.Memory to Value
   685  	t hm.Type
   686  	s tensor.Shape
   687  }
   688  
   689  func (instr *readInstr) ID() int64         { return -1 }
   690  func (instr *readInstr) reads() []register { return []register{instr.readFrom} }
   691  func (instr *readInstr) writes() register  { return register{-1, CPU} }
   692  func (instr *readInstr) exec(m *tapeMachine) (err error) {
   693  	m.logf("Executing READ - read from %v into %v", instr.readFrom, instr.into)
   694  	v := m.getValue(instr.readFrom)
   695  	if v == nil {
   696  		return nyi("value of nil", "readInstr.exec")
   697  	}
   698  
   699  	if *instr.into != nil {
   700  		dest := *instr.into
   701  		_, err = Copy(dest, v)
   702  		return err
   703  	}
   704  
   705  	v2, err := CloneValue(v)
   706  	if err != nil {
   707  		return errors.Wrap(err, cloneFail)
   708  	}
   709  
   710  	*instr.into = v2
   711  	return nil
   712  }
   713  
   714  func (instr *readInstr) String() string {
   715  	return fmt.Sprintf("Read %v into %p", instr.readFrom, instr.into)
   716  }
   717  
   718  type deviceTransport struct {
   719  	from, to register
   720  }
   721  
   722  func (instr deviceTransport) ID() int64 { return -1 }
   723  func (instr deviceTransport) reads() []register {
   724  	return []register{instr.from}
   725  }
   726  func (instr deviceTransport) writes() register { return instr.to }
   727  
   728  func (instr deviceTransport) String() string {
   729  	return fmt.Sprintf("memcpy(%v, %v)", instr.to, instr.from)
   730  }