gorgonia.org/gorgonia@v0.9.17/vm_tape_nocuda.go (about)

     1  // +build !cuda
     2  
     3  package gorgonia
     4  
     5  import (
     6  	"github.com/pkg/errors"
     7  	"gorgonia.org/tensor"
     8  )
     9  
    10  func finalizeTapeMachine(m *tapeMachine) {}
    11  
    12  // UseCudaFor is an option for *tapeMachine. This function is NO-OP unless the program is built with the `cuda` tag.
    13  func UseCudaFor(ops ...string) VMOpt {
    14  	return func(m VM) {}
    15  }
    16  
    17  func (m *tapeMachine) getEngine(dev Device) tensor.Engine { return m.Engine }
    18  
    19  func (instr *execOp) exec(m *tapeMachine) (err error) {
    20  	m.logf("Executing %v. Node is: %x", instr, instr.id)
    21  	m.enterLogScope()
    22  	defer m.leaveLogScope()
    23  
    24  	// Read
    25  	m.watchedLogf("Inputs:")
    26  	m.enterLogScope()
    27  	var inputs []Value
    28  	for _, reg := range instr.readFrom {
    29  		v := m.cpumem[reg.id]
    30  		inputs = append(inputs, v)
    31  		m.watchedLogf(m.valueFmt, v)
    32  	}
    33  	m.leaveLogScope()
    34  
    35  	// check if the destination has already been allocated
    36  	var usePrealloc bool
    37  	dest := instr.writeTo.id
    38  	if m.cpumem[dest] != nil {
    39  		usePrealloc = true
    40  	}
    41  
    42  	// Execute
    43  	var v Value
    44  	switch {
    45  	case instr.preAllocated:
    46  		if pd, ok := instr.op.(UsePreallocDoer); ok {
    47  			p := m.cpumem[instr.writeTo.id]
    48  			if v, err = pd.UsePreallocDo(p, inputs...); err != nil {
    49  				return errors.Wrapf(err, "Happened while attempting to execute %v. Node is %x. Register was: %v ", instr, instr.id, instr.writeTo.id)
    50  			}
    51  		} else {
    52  			// TODO: maybe warn?
    53  			if v, err = instr.op.Do(inputs...); err != nil {
    54  				return errors.Wrap(err, opDoFail)
    55  			}
    56  		}
    57  	case usePrealloc:
    58  		if pd, ok := instr.op.(UsePreallocDoer); ok {
    59  			p := m.cpumem[instr.writeTo.id]
    60  			if v, err = pd.UsePreallocDo(p, inputs...); err != nil {
    61  				if v, err = instr.op.Do(inputs...); err != nil {
    62  					return errors.Wrap(err, opDoFail)
    63  				}
    64  			}
    65  		} else {
    66  			if v, err = instr.op.Do(inputs...); err != nil {
    67  				return errors.Wrap(err, opDoFail)
    68  			}
    69  		}
    70  	case instr.useUnsafe:
    71  		if ud, ok := instr.op.(UnsafeDoer); ok {
    72  			if v, err = ud.UnsafeDo(inputs...); err != nil {
    73  				return errors.Wrap(err, "Failed to carry UnsafeDo()")
    74  			}
    75  		} else {
    76  			// TODO: warn?
    77  			if v, err = instr.op.Do(inputs...); err != nil {
    78  				return errors.Wrap(err, opDoFail)
    79  			}
    80  		}
    81  	default:
    82  		if v, err = instr.op.Do(inputs...); err != nil {
    83  			return errors.Wrap(err, opDoFail)
    84  		}
    85  	}
    86  
    87  	m.watchedLogf("Result:")
    88  	m.enterLogScope()
    89  	m.watchedLogf(m.valueFmt, v)
    90  	m.leaveLogScope()
    91  	// TODO: type and shape checks
    92  
    93  	// Write
    94  	setEngine(v, m.Engine)
    95  
    96  	m.cpumem[dest] = v
    97  	node := m.p.g.Node(instr.id).(*Node)
    98  
    99  	if m.trace() && (len(m.watchNodes) == 0 || m.watchNodes.Contains(node)) {
   100  		if err = node.bindCopy(v); err != nil {
   101  			return errors.Wrapf(err, "TraceExec failed to bind copy")
   102  		}
   103  	} else {
   104  		node.bind(v)
   105  	}
   106  
   107  	// this is a gradient node then, we should also bind the value to the node's dualValue
   108  	if m.bindDV() && node.derivOf != nil {
   109  		for _, src := range node.derivOf {
   110  			if len(m.bindNodesDV) > 0 && !m.bindNodesDV.Contains(src) {
   111  				continue
   112  			}
   113  
   114  			if src.boundTo != nil {
   115  				dv := dvUnit(src.boundTo)
   116  
   117  				add := newEBOByType(addOpType, TypeOf(dv.d), TypeOf(v))
   118  
   119  				if d, err := add.UnsafeDo(dv.d, v); err == nil {
   120  					dv.SetDeriv(d)
   121  					src.bind(dv)
   122  				} else {
   123  					return err
   124  				}
   125  			}
   126  		}
   127  
   128  	}
   129  
   130  	m.watchedLogf("Written To: %v", instr.writeTo)
   131  	m.enterLogScope()
   132  	m.watchedLogf(m.valueFmt, v)
   133  	m.leaveLogScope()
   134  	return nil
   135  }
   136  
   137  func (instr deviceTransport) exec(m *tapeMachine) error {
   138  	return nil
   139  }