gorgonia.org/gorgonia@v0.9.17/op_types.go

package gorgonia

import (
	"fmt"
	"hash"

	"github.com/chewxy/hm"
	"github.com/pkg/errors"
	"gorgonia.org/tensor"
)

// ConvType converts the type of the x Node from one Dtype to another
func ConvType(x *Node, from, to tensor.Dtype) (*Node, error) {
	op := &dtConvOp{
		inshape: x.Shape(),
		from:    from,
		to:      to,
	}

	return ApplyOp(op, x)
}
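// A minimal usage sketch for ConvType, assuming the package's usual graph and
// node constructors (the variable names are illustrative):
//
//	g := NewGraph()
//	x := NewVector(g, tensor.Float64, WithShape(3), WithName("x"))
//	xi, err := ConvType(x, tensor.Float64, tensor.Int)
//	if err != nil {
//		// handle the error from applying the op
//	}
//	_ = xi // xi has the same shape as x, with dtype tensor.Int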
type dtConvOp struct {
	inshape  tensor.Shape
	from, to tensor.Dtype
}

/* Graph Building Related Methods */

// Arity returns the number of inputs the Op expects. -1 indicates that it's n-ary and will be determined at runtime
func (op *dtConvOp) Arity() int { return 1 }

// Type informs the type of the Op (not the node). This will be used by the type system to infer the final type of the node
func (op *dtConvOp) Type() hm.Type {
	if op.inshape.IsScalar() {
		return hm.NewFnType(op.from, op.to)
	}
	t := makeTensorType(op.inshape.Dims(), op.from)
	u := makeTensorType(op.inshape.Dims(), op.to)
	return hm.NewFnType(t, u)
}

// InferShape returns the output shape as a function of the inputs
func (op *dtConvOp) InferShape(_ ...DimSizer) (tensor.Shape, error) {
	return op.inshape.Clone(), nil
}

// Do executes the op
func (op *dtConvOp) Do(vals ...Value) (Value, error) {
	retVal := tensor.New(tensor.Of(op.to), tensor.WithShape(op.inshape.Clone()...))
	return op.UsePreallocDo(retVal, vals...)
}

/* Analysis Related Methods */

// ReturnsPtr indicates if the Op will return a pointer (allowing possible inplace edits) or by value.
// If it's false, the return value of the Op will be a copy of its input.
func (op *dtConvOp) ReturnsPtr() bool { return false }

// CallsExtern reports whether this op potentially calls external (cgo or CUDA) functions (thereby requiring extra overhead for Go's trampolining)
func (op *dtConvOp) CallsExtern() bool { return false }

// OverwritesInput states which input the output will be overwriting.
// This allows for some efficiency gains as the underlying arrays wouldn't have to be re-allocated.
// The method returns an int instead of a bool because different operations may be allowed
// to overwrite certain inputs. For example, consider an operation to increment a value:
// the IncrementOp would be a unary operator, and assuming we would like to overwrite the input,
// the retVal of OverwritesInput() would be 0 (inputs[0]).
// -1 is returned if overwriting of the input is disallowed.
func (op *dtConvOp) OverwritesInput() int { return -1 }

/* Other methods */

func (op *dtConvOp) WriteHash(h hash.Hash) { fmt.Fprintf(h, "(%v)", op.Type()) }

func (op *dtConvOp) Hashcode() uint32 { return simpleHash(op) }

func (op *dtConvOp) String() string { return fmt.Sprintf("%v", op.Type()) }

// DiffWRT indicates whether the op is differentiable with regard to the given number of inputs.
// It returns a []bool indicating which inputs it is differentiable with respect to.
func (op *dtConvOp) DiffWRT(inputs int) []bool { return []bool{true} }

// SymDiff symbolically differentiates the op: the gradient is converted back
// from op.to to op.from.
func (op *dtConvOp) SymDiff(inputs Nodes, output *Node, grad *Node) (retVal Nodes, err error) {
	diffOp := &dtConvOp{
		inshape: grad.Shape().Clone(),
		from:    op.to,
		to:      op.from,
	}
	retVal = make(Nodes, op.Arity())
	retVal[0], err = ApplyOp(diffOp, grad)
	return retVal, err
}

// UsePreallocDo executes the Op with a preallocated value for the result.
func (op *dtConvOp) UsePreallocDo(prealloc Value, inputs ...Value) (Value, error) {
	a := inputs[0]
	retVal := prealloc
	switch {
	case op.from == tensor.Float64 && op.to == tensor.Int:
		switch aData := a.Data().(type) {
		case []float64:
			retData := retVal.Data().([]int)
			for i := range aData {
				retData[i] = int(aData[i])
			}
		case float64:
			// scalar input: a fresh 1-element tensor is allocated instead of
			// writing into prealloc
			retVal = tensor.New(
				tensor.Of(tensor.Int),
				tensor.WithShape(1),
				tensor.WithBacking([]int{int(aData)}),
			)
		}
	case op.from == tensor.Float32 && op.to == tensor.Int:
		switch aData := a.Data().(type) {
		case []float32:
			retData := retVal.Data().([]int)
			for i := range aData {
				retData[i] = int(aData[i])
			}
		case float32:
			retVal = tensor.New(
				tensor.Of(tensor.Int),
				tensor.WithShape(1),
				tensor.WithBacking([]int{int(aData)}),
			)
		}
	case op.from == tensor.Int && op.to == tensor.Float64:
		switch aData := a.Data().(type) {
		case []int:
			retData := retVal.Data().([]float64)
			for i := range aData {
				retData[i] = float64(aData[i])
			}
		case int:
			retVal = tensor.New(
				tensor.Of(tensor.Float64),
				tensor.WithShape(1),
				tensor.WithBacking([]float64{float64(aData)}),
			)
		}
	case op.from == tensor.Int && op.to == tensor.Float32:
		switch aData := a.Data().(type) {
		case []int:
			retData := retVal.Data().([]float32)
			for i := range aData {
				retData[i] = float32(aData[i])
			}
		case int:
			retVal = tensor.New(
				tensor.Of(tensor.Float32),
				tensor.WithShape(1),
				tensor.WithBacking([]float32{float32(aData)}),
			)
		}
	default:
		// TODO: other types
		return nil, errors.Errorf("Cannot do conversion %v", op.Type())
	}

	return retVal, nil
}
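// A compile-time assertion that dtConvOp satisfies the Op interface; added
// here as a sketch, assuming Op covers exactly the methods implemented above.
var _ Op = &dtConvOp{}

// A minimal sketch of exercising the op's Do path directly (shapes, values,
// and variable names are illustrative):
//
//	op := &dtConvOp{inshape: tensor.Shape{2}, from: tensor.Float64, to: tensor.Int}
//	a := tensor.New(tensor.WithShape(2), tensor.WithBacking([]float64{1.5, 2.5}))
//	out, err := op.Do(a)
//	// out is backed by []int{1, 2} (Go's conversion truncates toward zero);
//	// err is non-nil for from/to pairs not handled in UsePreallocDo.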