gorgonia.org/gorgonia@v0.9.17/op_nondiff.go (about)

     1  package gorgonia
     2  
     3  import (
     4  	"fmt"
     5  	"hash"
     6  
     7  	"github.com/chewxy/hm"
     8  	"gorgonia.org/tensor"
     9  )
    10  
    11  type diagFlatOp struct{}
    12  
    13  /* Graph Building Related Methods */
    14  
    15  // Arity returns the number of inputs the Op expects. -1 indicates that it's n-ary and will be determined at runtime
    16  func (op diagFlatOp) Arity() int { return 1 }
    17  
    18  // Informs the type of the Op (not the node). This will be used by the type system to infer the final type of the node
    19  func (op diagFlatOp) Type() hm.Type {
    20  	a := hm.TypeVariable('a')
    21  	b := hm.TypeVariable('a')
    22  	T := makeTensorType(2, b)
    23  	return hm.NewFnType(a, T)
    24  }
    25  
    26  // returns the output shape as a function of the inputs
    27  func (op diagFlatOp) InferShape(inputs ...DimSizer) (tensor.Shape, error) {
    28  	if err := checkArity(op, len(inputs)); err != nil {
    29  		return nil, err
    30  	}
    31  	in := inputs[0].(tensor.Shape)
    32  	return tensor.Shape{in.TotalSize(), in.TotalSize()}, nil
    33  }
    34  
    35  /* Machine related */ // executes the op
    36  func (op diagFlatOp) Do(vals ...Value) (Value, error) {
    37  	if err := checkArity(op, len(vals)); err != nil {
    38  		return nil, err
    39  	}
    40  
    41  	T := vals[0].(tensor.Tensor)
    42  	return tensor.New(tensor.AsDenseDiag(T.Data())), nil
    43  }
    44  
    45  /* Analysis Related Methods */
    46  
    47  // indicates if the Op will return a pointer (allowing possible inplace edits) or by value
    48  // if it's false, the return value of the Op will be a copy of its input
    49  func (op diagFlatOp) ReturnsPtr() bool { return false }
    50  
    51  // Does this op potentially call external (cgo or cuda) functions (thereby requiring extra overhead for Go's trampolining thing)
    52  func (op diagFlatOp) CallsExtern() bool { return false }
    53  
    54  // overwriteInput() is a method which states which input the output will be overwriting.
    55  // This allows for some efficiency gains as the underlying arrays wouldn't have to be re-allocated.
    56  // The method returns an int instead of a bool because potentially different operations may be allowed
    57  // to overwrite certain inputs. For example, consider an operation to increment a value:
    58  // the IncrementOp would be a unary operator, and assuming we would like to overwrite the input,
    59  // the retVal of overwriteInput() will be 0 (inputs[0]).
    60  // -1 is returned if overwriting of input is disallowed
    61  func (op diagFlatOp) OverwritesInput() int { return -1 }
    62  
    63  /* Other methods */
    64  func (op diagFlatOp) WriteHash(h hash.Hash) { fmt.Fprintf(h, "DiagFlatOp") }
    65  
    66  func (op diagFlatOp) Hashcode() uint32 { return simpleHash(op) }
    67  
    68  func (op diagFlatOp) String() string { return "DiagFlat" }