gorgonia.org/gorgonia@v0.9.17/execution.go (about)

     1  package gorgonia
     2  
     3  import (
     4  	"github.com/pkg/errors"
     5  	"gorgonia.org/tensor"
     6  )
     7  
     8  // Arena is a representation of a pool of tensor.Memory
     9  type Arena interface {
    10  	Get(dev Device, size int64) (tensor.Memory, error)       // Get returns a NoOpError when it cannot get a memory. Please allocate
    11  	GetFromValue(dev Device, v Value) (tensor.Memory, error) // Gets a memory and copies the values into the memory and returns it.
    12  	Put(dev Device, mem tensor.Memory, size int64)           // puts the memory back into the arena
    13  	PutValue(dev Device, v Value)                            // puts the memory back into the arena
    14  
    15  	// Transfers memory from device to device
    16  	Transfer(toDev, fromDev Device, v Value, synchronous bool) (retVal Value, err error)
    17  }
    18  
    19  // External is a representation of an external device (cuda/cgo/openCL), conceptually modelled as a machine.
    20  type External interface {
    21  	Arena
    22  	Signal() // signals the machine to do work
    23  	Sync() chan struct{}
    24  }
    25  
    26  // ExecutionContext informs how an op should be executed
    27  type ExecutionContext struct {
    28  	External
    29  	Device
    30  }
    31  
    32  // ExternalOp is an op that contains an external context. This allows for ops to be run without needing a VM
    33  type ExternalOp struct {
    34  	Op
    35  	ExecutionContext
    36  
    37  	Prealloc  Value
    38  	Incr      Value // is this a Incr? IncrDoers have higher precedence over PreallocDo
    39  	UseUnsafe bool  // Is this an unsafe op? Lowest of all "special" Dos
    40  }
    41  
    42  // NewExternalOp creates a new *ExternalOp.
    43  func NewExternalOp(op Op, ctx ExecutionContext, prealloc Value) *ExternalOp {
    44  	retVal := &ExternalOp{
    45  		Op:               op,
    46  		ExecutionContext: ctx,
    47  		Prealloc:         prealloc,
    48  		UseUnsafe:        false,
    49  	}
    50  
    51  	return retVal
    52  }
    53  
    54  // DetermineDevice ...
    55  func (op *ExternalOp) DetermineDevice(inputs Nodes, output *Node) error {
    56  	dev := output.dataOn
    57  	var inDev Device = -2
    58  	var allSame bool
    59  	for _, in := range inputs {
    60  		if in.dataOn != dev {
    61  			allSame = false
    62  		}
    63  
    64  		if inDev == -2 {
    65  			inDev = in.dataOn
    66  			continue
    67  		}
    68  		if in.dataOn != inDev && in.dataOn != dev {
    69  			return errors.Errorf("Cannot automatically determine device.")
    70  		}
    71  	}
    72  
    73  	if !allSame {
    74  		return errors.Errorf("Not all the same devices")
    75  	}
    76  	op.Device = dev
    77  	return nil
    78  }
    79  
    80  // Do performs the op,
    81  func (op *ExternalOp) Do(vals ...Value) (Value, error) {
    82  	if op.Device == CPU {
    83  		switch {
    84  		case op.Incr != nil:
    85  			if id, ok := op.Op.(IncrDoer); ok {
    86  				if err := id.IncrDo(op.Incr, vals...); err != nil {
    87  					if ver, ok := err.(Valuer); ok {
    88  						return ver.Value(), nil
    89  					}
    90  					return nil, err
    91  				}
    92  				return op.Incr, nil
    93  			}
    94  		case op.Prealloc != nil:
    95  			if pd, ok := op.Op.(UsePreallocDoer); ok {
    96  				pd.UsePreallocDo(op.Prealloc, vals...)
    97  			}
    98  			retVal, err := op.Op.Do(vals...)
    99  			if err != nil {
   100  				return retVal, err
   101  			}
   102  			return Copy(op.Prealloc, retVal)
   103  		case op.UseUnsafe:
   104  			if ud, ok := op.Op.(UnsafeDoer); ok {
   105  				return ud.UnsafeDo(vals...)
   106  			}
   107  			fallthrough
   108  		default:
   109  			return op.Op.Do(vals...)
   110  		}
   111  	}
   112  
   113  	switch o := op.Op.(type) {
   114  	case CUDADoer:
   115  		if op.Incr != nil {
   116  			v, err := o.CUDADo(op.External, op.Device, op.Prealloc, vals...)
   117  			if err != nil {
   118  				return nil, err
   119  			}
   120  
   121  			add := newEBOByType(addOpType, TypeOf(op.Incr), TypeOf(v))
   122  			addOp := NewExternalOp(add, op.ExecutionContext, nil)
   123  			addOp.UseUnsafe = true
   124  			retVal, err := addOp.Do(op.Incr, v)
   125  			return retVal, err
   126  		}
   127  		return o.CUDADo(op.External, op.Device, op.Prealloc, vals...)
   128  	case CLDoer:
   129  	case IncrDoer:
   130  		if op.Incr != nil {
   131  			if err := o.IncrDo(op.Incr, vals...); err != nil {
   132  				if ver, ok := err.(Valuer); ok {
   133  					return ver.Value(), nil
   134  				}
   135  				return nil, err
   136  			}
   137  			return op.Incr, nil
   138  		}
   139  		return op.Op.Do(vals...)
   140  	case UsePreallocDoer:
   141  		if op.Prealloc != nil {
   142  			return o.UsePreallocDo(op.Prealloc, vals...)
   143  		}
   144  		return op.Op.Do(vals...)
   145  	case UnsafeDoer:
   146  		if op.UseUnsafe {
   147  			return o.UnsafeDo(vals...)
   148  		}
   149  		return op.Op.Do(vals...)
   150  	default:
   151  		return o.Do(vals...)
   152  	}
   153  
   154  	panic("Unreachable")
   155  }
   156  
   157  func (op *ExternalOp) String() string {
   158  	return op.Op.String()
   159  }