gorgonia.org/gorgonia@v0.9.17/op.go

package gorgonia

import (
	"fmt"
	"hash"
	"hash/fnv"

	"github.com/chewxy/hm"
	"github.com/pkg/errors"
	"gorgonia.org/tensor"
)

// DimSizer is any type (typically a tensor.Shape) that allows querying for a dimension size given an input dimension.
type DimSizer interface {
	DimSize(int) (int, error)
}

// ShapesToDimSizers is a convenience function to convert a slice of tensor.Shape to a slice of DimSizer.
func ShapesToDimSizers(shapes []tensor.Shape) []DimSizer {
	retVal := make([]DimSizer, len(shapes))
	for i, s := range shapes {
		retVal[i] = s
	}
	return retVal
}

// DimSizersToShapes is a convenience function to convert a slice of DimSizer to a slice of tensor.Shape. It will return an error if any of them isn't a tensor.Shape.
func DimSizersToShapes(ds []DimSizer) ([]tensor.Shape, error) {
	retVal := make([]tensor.Shape, len(ds))
	var ok bool
	for i, d := range ds {
		if retVal[i], ok = d.(tensor.Shape); !ok {
			return nil, errors.Errorf("DimSizer %d is not a Shape.", i)
		}
	}
	return retVal, nil
}
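
// A minimal round-trip sketch (the shapes below are illustrative, not part of the API):
//
//	ds := ShapesToDimSizers([]tensor.Shape{{2, 3}, {3, 4}})
//	shapes, err := DimSizersToShapes(ds) // recovers the original []tensor.Shape
//	if err != nil {
//		// a non-Shape DimSizer was passed in
//	}
//	_ = shapes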

// An Op is a symbolic representation of an operation.
// Think of an Op as a function that takes one or more inputs and produces an output.
//
// All Ops have type signatures that look like this:
//		OpName :: (Floats a) ⇒ Tensor a → Tensor a → Tensor a
type Op interface {
	/* Graph Building Related Methods */

	// Arity returns the number of inputs the Op expects. -1 indicates that it's n-ary and will be determined at runtime.
	Arity() int

	// Type informs the type of the Op (not the node). This will be used by the type system to infer the final type of the node.
	Type() hm.Type

	// InferShape returns the output shape as a function of the inputs.
	InferShape(...DimSizer) (tensor.Shape, error)

	/* Machine related */

	// Do executes the op.
	Do(...Value) (Value, error)

	/* Analysis Related Methods */

	// ReturnsPtr indicates if the Op will return a pointer (allowing possible in-place edits) or by value.
	// If it's false, the return value of the Op will be a copy of its input.
	ReturnsPtr() bool

	// CallsExtern indicates if this op potentially calls external (cgo or CUDA) functions, thereby requiring the extra overhead of Go's foreign-function trampolining.
	CallsExtern() bool

	// OverwritesInput states which input the output will be overwriting.
	// This allows for some efficiency gains, as the underlying arrays wouldn't have to be re-allocated.
	// The method returns an int instead of a bool because different operations may be allowed
	// to overwrite different inputs. For example, consider an operation to increment a value:
	// the IncrementOp would be a unary operator, and assuming we would like to overwrite the input,
	// OverwritesInput() would return 0 (inputs[0]).
	// -1 is returned if overwriting of the input is disallowed.
	OverwritesInput() int

	/* Other methods */
	WriteHash(h hash.Hash)
	Hashcode() uint32
	fmt.Stringer
}
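
// A minimal sketch of implementing Op for a hypothetical elementwise square-root
// op. The name sqrtOp and its body are illustrative assumptions, not part of
// this package:
//
//	type sqrtOp struct{}
//
//	func (op sqrtOp) Arity() int { return 1 }
//	func (op sqrtOp) Type() hm.Type { // sqrt :: a → a
//		a := hm.TypeVariable('a')
//		return hm.NewFnType(a, a)
//	}
//	func (op sqrtOp) InferShape(ds ...DimSizer) (tensor.Shape, error) {
//		return ds[0].(tensor.Shape).Clone(), nil // output shape == input shape
//	}
//	func (op sqrtOp) Do(vs ...Value) (Value, error) {
//		t, ok := vs[0].(*tensor.Dense)
//		if !ok {
//			return nil, errors.New("expected a *tensor.Dense")
//		}
//		ret, err := tensor.Sqrt(t)
//		if err != nil {
//			return nil, err
//		}
//		return ret.(*tensor.Dense), nil
//	}
//	func (op sqrtOp) ReturnsPtr() bool      { return false }
//	func (op sqrtOp) CallsExtern() bool     { return false }
//	func (op sqrtOp) OverwritesInput() int  { return -1 }
//	func (op sqrtOp) WriteHash(h hash.Hash) { fmt.Fprint(h, "sqrt") }
//	func (op sqrtOp) Hashcode() uint32 {
//		h := fnv.New32a()
//		op.WriteHash(h)
//		return h.Sum32()
//	}
//	func (op sqrtOp) String() string { return "sqrt" }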

// A UnaryOp is an Op that takes only one input
type UnaryOp interface {
	Op

	IsUnary() bool
}

// A BinaryOp is an Op that takes only two inputs
type BinaryOp interface {
	Op

	IsBinary() bool
}

// A NoRetOp is an Op that reads a value, but does not return any value. It's a representation of an impure function.
type NoRetOp interface {
	Op

	ReturnsNothing() bool
}

// An ADOp is an Op that supports automatic differentiation.
type ADOp interface {
	Op

	DoDiff(ctx ExecutionContext, inputs Nodes, output *Node) error
}

// An SDOp is an Op that supports symbolic differentiation
type SDOp interface {
	Op

	// DiffWRT indicates if the op is differentiable with regard to the given number of inputs.
	// It returns a []bool with one entry per input, indicating which inputs it is differentiable with respect to.
	DiffWRT(inputs int) []bool

	// SymDiff symbolically differentiates the op
	SymDiff(inputs Nodes, output, grad *Node) (retVal Nodes, err error)
}
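
// A sketch of DiffWRT for a hypothetical binary op that is differentiable
// with respect to its first input only (maskOp is illustrative):
//
//	func (op maskOp) DiffWRT(inputs int) []bool {
//		return []bool{true, false} // gradient flows to inputs[0] only
//	}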

// A ReductionOp is an Op that reduces, and hence changes, the shape of the node
type ReductionOp interface {
	Op

	IsReduction() bool
}

// IncrDoer is an op that increments toIncr with the result of doing the op
type IncrDoer interface {
	IncrDo(toIncr Value, inputs ...Value) error
}

// UsePreallocDoer is an op that works when a preallocated value is provided
type UsePreallocDoer interface {
	UsePreallocDo(prealloc Value, inputs ...Value) (Value, error)
}

// UnsafeDoer is an op that will overwrite the underlying value.
type UnsafeDoer interface {
	UnsafeDo(inputs ...Value) (Value, error)
}
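
// These specializations are discovered by type assertion at execution time. A
// sketch of how an execution engine might prefer them over plain Do (the
// variables prealloc and inputs are illustrative):
//
//	var ret Value
//	var err error
//	switch o := op.(type) {
//	case UsePreallocDoer:
//		ret, err = o.UsePreallocDo(prealloc, inputs...)
//	case UnsafeDoer:
//		ret, err = o.UnsafeDo(inputs...)
//	default:
//		ret, err = op.Do(inputs...)
//	}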

// CUDADoer uses CUDA to perform the Op.
type CUDADoer interface {
	CUDADo(extern External, dev Device, prealloc Value, inputs ...Value) (retVal Value, err error)
}

// CLDoer uses OpenCL to perform the Op. As of now, there are NO Ops that support OpenCL.
type CLDoer interface {
	CLDo(inputs ...Value) (Value, error)
}

// A CUDAADOp is an ADOp that has a specific method to run its differentiation with CUDA.
type CUDAADOp interface {
	ADOp
	CUDADoDiff(extern External, dev Device, inputs Nodes, output *Node) error
}

// ApplyOp is the generic function application - for when no specialization is required
func ApplyOp(op Op, children ...*Node) (retVal *Node, err error) {
	var g *ExprGraph

	for _, child := range children {
		if child.g != nil {
			g = child.g
			break
		}
	}

	if g == nil {
		return nil, errors.New("No Graph Supplied")
	}

	if !Nodes(children).AllSameGraph() {
		return nil, errors.New("Not all children have the same graph")
	}

	// typecheck before creating
	typeSysLogf("Inferring node type of %v :: %v with children: %#Y", op, op.Type(), Nodes(children))
	enterLogScope()
	defer leaveLogScope()
	var retType hm.Type
	if retType, err = inferNodeType(op, children...); err != nil {
		return nil, errors.Wrapf(err, "Type inference error. Op: %v. Children: %#Y, OpType:%v", op, Nodes(children), op.Type())
	}
	typeSysLogf("Done inferring. Return type is: %#v(%T)", retType, retType)

	// check the arity, then infer the shape; any error is wrapped and returned
	shapeLogf("op: %v(%T) inferring shape", op, op)
	if err = checkArity(op, len(children)); err != nil {
		return
	}

	ds := Nodes(children).dimSizers()
	var s tensor.Shape
	if s, err = op.InferShape(ds...); err == nil {
		shapeLogf("inferred shape %v", s)
		retVal = NewUniqueNode(WithType(retType), WithOp(op), WithChildren(children), In(g), WithShape(s...))
	} else {
		err = errors.Wrapf(err, "Failed to infer shape. Op: %v", op)
		// retVal = newUniqueNode(withType(retType), withOp(op), withChildren(children), withGraph(g))
	}
	returnDimSizers(ds)
	return
}
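
// A minimal usage sketch, reusing the hypothetical sqrtOp from the Op example
// above (names are illustrative):
//
//	g := NewGraph()
//	x := NewMatrix(g, Float64, WithShape(2, 3), WithName("x"))
//	y, err := ApplyOp(sqrtOp{}, x)
//	if err != nil {
//		// type inference, arity checking, or shape inference failed
//	}
//	_ = y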

// ApplyOpWithName applies the op, and then gives the node the given name
func ApplyOpWithName(op Op, name string, children ...*Node) (retVal *Node, err error) {
	if retVal, err = ApplyOp(op, children...); err == nil {
		WithName(name)(retVal)
	} else {
		return nil, errors.Wrap(err, applyOpFail)
	}
	return
}

// a constant is an unchanging value.
// a constant op is an op that creates a constant. It is also a Value holding that constant value.
type constant interface {
	Op

	isconstant() bool
	Value() Value
}

type constantScalar struct {
	v Scalar
}

func (c constantScalar) Arity() int                                   { return 0 }
func (c constantScalar) Type() hm.Type                                { return TypeOf(c.v) }
func (c constantScalar) InferShape(...DimSizer) (tensor.Shape, error) { return scalarShape, nil }
func (c constantScalar) ReturnsPtr() bool                             { return false }
func (c constantScalar) CallsExtern() bool                            { return false }
func (c constantScalar) OverwritesInput() int                         { return -1 }
func (c constantScalar) DiffWRT(i int) []bool                         { return nil }
func (c constantScalar) SymDiff(Nodes, *Node, *Node) (Nodes, error)   { return nil, nil }

func (c constantScalar) Do(...Value) (Value, error) { return c.v, nil }
func (c constantScalar) String() string             { return fmt.Sprintf("const %s", c.v) }

func (c constantScalar) WriteHash(h hash.Hash) {
	fmt.Fprintf(h, "const %v: %v", TypeOf(c.v), c.v)
}

func (c constantScalar) Hashcode() uint32 {
	h := fnv.New32a()
	c.WriteHash(h)
	return h.Sum32()
}
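
// WriteHash and Hashcode allow the graph to deduplicate structurally identical
// ops. A small illustrative sketch: two constants with the same type and value
// hash identically.
//
//	pi := F64(3.14)
//	pi2 := F64(3.14)
//	a := constantScalar{v: &pi}
//	b := constantScalar{v: &pi2}
//	fmt.Println(a.Hashcode() == b.Hashcode()) // true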

func (c constantScalar) isconstant() bool { return true }
func (c constantScalar) Value() Value     { return c.v }

type constantTensor struct {
	v tensor.Tensor
}

func (c constantTensor) Arity() int                                   { return 1 }
func (c constantTensor) Type() hm.Type                                { return TypeOf(c.v) }
func (c constantTensor) InferShape(...DimSizer) (tensor.Shape, error) { return c.v.Shape(), nil }

// Danger! ReturnsPtr is true here only because tensors may be too large, and copying is costly.
// Constants should return a value, but for the sake of memory we return a pointer instead.
func (c constantTensor) ReturnsPtr() bool                           { return true }
func (c constantTensor) OverwritesInput() int                       { return -1 }
func (c constantTensor) CallsExtern() bool                          { return false }
func (c constantTensor) DiffWRT(i int) []bool                       { return nil }
func (c constantTensor) SymDiff(Nodes, *Node, *Node) (Nodes, error) { return nil, nil }
func (c constantTensor) Do(...Value) (Value, error)                 { return c.v, nil }
func (c constantTensor) String() string                             { return fmt.Sprintf("const %s", TypeOf(c.v)) }

func (c constantTensor) WriteHash(h hash.Hash) {
	fmt.Fprintf(h, "const %v:%v", c.Type(), c.v)
}

func (c constantTensor) Hashcode() uint32 {
	h := fnv.New32a()
	c.WriteHash(h)
	return h.Sum32()
}

func (c constantTensor) isconstant() bool { return true }
func (c constantTensor) Value() Value     { return c.v }