gorgonia.org/gorgonia@v0.9.17/op_nondiff.go (about) 1 package gorgonia 2 3 import ( 4 "fmt" 5 "hash" 6 7 "github.com/chewxy/hm" 8 "gorgonia.org/tensor" 9 ) 10 11 type diagFlatOp struct{} 12 13 /* Graph Building Related Methods */ 14 15 // Arity returns the number of inputs the Op expects. -1 indicates that it's n-ary and will be determined at runtime 16 func (op diagFlatOp) Arity() int { return 1 } 17 18 // Informs the type of the Op (not the node). This will be used by the type system to infer the final type of the node 19 func (op diagFlatOp) Type() hm.Type { 20 a := hm.TypeVariable('a') 21 b := hm.TypeVariable('a') 22 T := makeTensorType(2, b) 23 return hm.NewFnType(a, T) 24 } 25 26 // returns the output shape as a function of the inputs 27 func (op diagFlatOp) InferShape(inputs ...DimSizer) (tensor.Shape, error) { 28 if err := checkArity(op, len(inputs)); err != nil { 29 return nil, err 30 } 31 in := inputs[0].(tensor.Shape) 32 return tensor.Shape{in.TotalSize(), in.TotalSize()}, nil 33 } 34 35 /* Machine related */ // executes the op 36 func (op diagFlatOp) Do(vals ...Value) (Value, error) { 37 if err := checkArity(op, len(vals)); err != nil { 38 return nil, err 39 } 40 41 T := vals[0].(tensor.Tensor) 42 return tensor.New(tensor.AsDenseDiag(T.Data())), nil 43 } 44 45 /* Analysis Related Methods */ 46 47 // indicates if the Op will return a pointer (allowing possible inplace edits) or by value 48 // if it's false, the return value of the Op will be a copy of its input 49 func (op diagFlatOp) ReturnsPtr() bool { return false } 50 51 // Does this op potentially call external (cgo or cuda) functions (thereby requiring extra overhead for Go's trampolining thing) 52 func (op diagFlatOp) CallsExtern() bool { return false } 53 54 // overwriteInput() is a method which states which input the output will be overwriting. 55 // This allows for some efficiency gains as the underlying arrays wouldn't have to be re-allocated. 56 // The method returns an int instead of a bool because potentially different operations may be allowed 57 // to overwrite certain inputs. For example, consider an operation to increment a value: 58 // the IncrementOp would be a unary operator, and assuming we would like to overwrite the input, 59 // the retVal of overwriteInput() will be 0 (inputs[0]). 60 // -1 is returned if overwriting of input is disallowed 61 func (op diagFlatOp) OverwritesInput() int { return -1 } 62 63 /* Other methods */ 64 func (op diagFlatOp) WriteHash(h hash.Hash) { fmt.Fprintf(h, "DiagFlatOp") } 65 66 func (op diagFlatOp) Hashcode() uint32 { return simpleHash(op) } 67 68 func (op diagFlatOp) String() string { return "DiagFlat" }