gorgonia.org/gorgonia@v0.9.17/op.go (about) 1 package gorgonia 2 3 import ( 4 "fmt" 5 "hash" 6 "hash/fnv" 7 8 "github.com/chewxy/hm" 9 "github.com/pkg/errors" 10 "gorgonia.org/tensor" 11 ) 12 13 // DimSizer is any type (typically a tensor.Shape) that allows querying for a dimension size given an input dimension. 14 type DimSizer interface { 15 DimSize(int) (int, error) 16 } 17 18 // ShapesToDimSizers is a convenience function to convert a slice of tensor.Shape to a slice of DimSizer 19 func ShapesToDimSizers(shapes []tensor.Shape) []DimSizer { 20 retVal := make([]DimSizer, len(shapes)) 21 for i, s := range shapes { 22 retVal[i] = s 23 } 24 return retVal 25 } 26 27 // DimSizersToShapes is a convenience function to convert a slice of DimSizer to a slice of tensor.Shape. It will return an error if any of them isn't a tensor.Shape 28 func DimSizersToShapes(ds []DimSizer) ([]tensor.Shape, error) { 29 retVal := make([]tensor.Shape, len(ds)) 30 var ok bool 31 for i, d := range ds { 32 if retVal[i], ok = d.(tensor.Shape); !ok { 33 return nil, errors.Errorf("Dimsizer %d is not a Shape.", i) 34 } 35 } 36 return retVal, nil 37 } 38 39 // An Op is a symbolic representation of an operation 40 // Think of them as functions, taking an input (or multiple), and outputting something 41 // 42 // All Ops have type signatures that look like this: 43 // OpName :: (Floats a) ⇒ Tensor a → Tensor a → Tensor a 44 type Op interface { 45 /* Graph Building Related Methods */ 46 47 // Arity returns the number of inputs the Op expects. -1 indicates that it's n-ary and will be determined at runtime 48 Arity() int 49 50 // Informs the type of the Op (not the node). 
This will be used by the type system to infer the final type of the node 51 Type() hm.Type 52 53 // returns the output shape as a function of the inputs 54 InferShape(...DimSizer) (tensor.Shape, error) 55 56 /* Machine related */ 57 58 // executes the op 59 Do(...Value) (Value, error) 60 61 /* Analysis Related Methods */ 62 63 // indicates if the Op will return a pointer (allowing possible inplace edits) or by value 64 // if it's false, the return value of the Op will be a copy of its input 65 ReturnsPtr() bool 66 67 // Does this op potentially call external (cgo or cuda) functions (thereby requiring extra overhead for Go's trampolining thing) 68 CallsExtern() bool 69 70 // overwriteInput() is a method which states which input the output will be overwriting. 71 // This allows for some efficiency gains as the underlying arrays wouldn't have to be re-allocated. 72 // The method returns an int instead of a bool because potentially different operations may be allowed 73 // to overwrite certain inputs. For example, consider an operation to increment a value: 74 // the IncrementOp would be a unary operator, and assuming we would like to overwrite the input, 75 // the retVal of overwriteInput() will be 0 (inputs[0]). 76 // -1 is returned if overwriting of input is disallowed 77 OverwritesInput() int 78 79 /* Other methods */ 80 WriteHash(h hash.Hash) 81 Hashcode() uint32 82 fmt.Stringer 83 } 84 85 // A UnaryOp is an Op that takes only one input 86 type UnaryOp interface { 87 Op 88 89 IsUnary() bool 90 } 91 92 // A BinaryOp is an Op that takes only two inputs 93 type BinaryOp interface { 94 Op 95 96 IsBinary() bool 97 } 98 99 // A NoRetOp is an Op that reads a value, but does not return any value. It's a representation of a not-pure function 100 type NoRetOp interface { 101 Op 102 103 ReturnsNothing() bool 104 } 105 106 // An ADOp is an Op that supports automatic differentiation. 
107 type ADOp interface { 108 Op 109 110 DoDiff(ctx ExecutionContext, inputs Nodes, output *Node) error 111 } 112 113 // A SDOp is an Op that supports symbolic differentiation 114 type SDOp interface { 115 Op 116 117 // DiffWRT indicates if the op is differentiable with regards to the given number of inputs 118 // returns []bool to indicate which input it is differentiable to 119 DiffWRT(inputs int) []bool 120 121 // SymDiff symbolically differentiates the op 122 SymDiff(inputs Nodes, output, grad *Node) (retVal Nodes, err error) 123 } 124 125 // ReductionOp changes the shape of the node 126 type ReductionOp interface { 127 Op 128 129 IsReduction() bool 130 } 131 132 // IncrDoer increments the toIncr with the result of doing 133 type IncrDoer interface { 134 IncrDo(toIncr Value, inputs ...Value) error 135 } 136 137 // UsePreallocDoer is an op that works when a preallocated value is provided 138 type UsePreallocDoer interface { 139 UsePreallocDo(prealloc Value, inputs ...Value) (Value, error) 140 } 141 142 // UnsafeDoer is an op that will overwrite the underlying value. 143 type UnsafeDoer interface { 144 UnsafeDo(inputs ...Value) (Value, error) 145 } 146 147 // CUDADoer uses CUDA to perform the Op. 148 type CUDADoer interface { 149 CUDADo(extern External, dev Device, prealloc Value, inputs ...Value) (retVal Value, err error) 150 } 151 152 // CLDoer uses OpenCL to perform the Op. 
As of now, there are NO Ops that support OpenCL 153 type CLDoer interface { 154 CLDo(inputs ...Value) (Value, error) 155 } 156 157 // A CUDAADOp operation have a specific method to run with CUDA 158 type CUDAADOp interface { 159 ADOp 160 CUDADoDiff(extern External, dev Device, inputs Nodes, output *Node) error 161 } 162 163 // ApplyOp is the generic function application - for when no specialization is required 164 func ApplyOp(op Op, children ...*Node) (retVal *Node, err error) { 165 var g *ExprGraph 166 167 for _, child := range children { 168 if child.g != nil { 169 g = child.g 170 break 171 } 172 } 173 174 if g == nil { 175 return nil, errors.New("No Graph Supplied") 176 } 177 178 if !Nodes(children).AllSameGraph() { 179 return nil, errors.New("Not all children have the same graph") 180 } 181 182 // typecheck before creating 183 typeSysLogf("Inferring node type of %v :: %v with children: %#Y", op, op.Type(), Nodes(children)) 184 enterLogScope() 185 defer leaveLogScope() 186 var retType hm.Type 187 if retType, err = inferNodeType(op, children...); err != nil { 188 return nil, errors.Wrapf(err, "Type inference error. Op: %v. Children: %#Y, OpType:%v", op, Nodes(children), op.Type()) 189 } 190 typeSysLogf("Done inferring. Return type is: %#v(%T)", retType, retType) 191 192 // infer shapes, but print errors instead of returning 193 shapeLogf("op: %v(%T) inferring shape", op, op) 194 if err = checkArity(op, len(children)); err != nil { 195 return 196 } 197 198 ds := Nodes(children).dimSizers() 199 var s tensor.Shape 200 if s, err = op.InferShape(ds...); err == nil { 201 shapeLogf("inferred shape %v", s) 202 retVal = NewUniqueNode(WithType(retType), WithOp(op), WithChildren(children), In(g), WithShape(s...)) 203 } else { 204 err = errors.Wrapf(err, "Failed to infer shape. 
Op: %v", op) 205 // retVal = newUniqueNode(withType(retType), withOp(op), withChildren(children), withGraph(g)) 206 } 207 returnDimSizers(ds) 208 return 209 } 210 211 // ApplyOpWithName applies the op, and then gives the node the given name 212 func ApplyOpWithName(op Op, name string, children ...*Node) (retVal *Node, err error) { 213 if retVal, err = ApplyOp(op, children...); err == nil { 214 WithName(name)(retVal) 215 } else { 216 return nil, errors.Wrap(err, applyOpFail) 217 } 218 return 219 } 220 221 // a constant is an unchanging value. I think everyone would know what a constant is 222 // a constant op is an op that creates a constant. It is also a Value of a constant value 223 type constant interface { 224 Op 225 226 isconstant() bool 227 Value() Value 228 } 229 230 type constantScalar struct { 231 v Scalar 232 } 233 234 func (c constantScalar) Arity() int { return 0 } 235 func (c constantScalar) Type() hm.Type { return TypeOf(c.v) } 236 func (c constantScalar) InferShape(...DimSizer) (tensor.Shape, error) { return scalarShape, nil } 237 func (c constantScalar) ReturnsPtr() bool { return false } 238 func (c constantScalar) CallsExtern() bool { return false } 239 func (c constantScalar) OverwritesInput() int { return -1 } 240 func (c constantScalar) DiffWRT(i int) []bool { return nil } 241 func (c constantScalar) SymDiff(Nodes, *Node, *Node) (Nodes, error) { return nil, nil } 242 243 func (c constantScalar) Do(...Value) (Value, error) { return c.v, nil } 244 func (c constantScalar) String() string { return fmt.Sprintf("const %s", c.v) } 245 246 func (c constantScalar) WriteHash(h hash.Hash) { 247 fmt.Fprintf(h, "const %v: %v", TypeOf(c.v), c.v) 248 } 249 250 func (c constantScalar) Hashcode() uint32 { 251 h := fnv.New32a() 252 c.WriteHash(h) 253 return h.Sum32() 254 } 255 256 func (c constantScalar) isconstant() bool { return true } 257 func (c constantScalar) Value() Value { return c.v } 258 259 type constantTensor struct { 260 v tensor.Tensor 261 } 262 263 
// Arity reports 1 for a tensor constant.
// NOTE(review): constantScalar.Arity returns 0, and Do below ignores its
// inputs; confirm that 1 is intentional before relying on it.
func (ct constantTensor) Arity() int {
	return 1
}

// Type returns the type of the wrapped tensor.
func (ct constantTensor) Type() hm.Type {
	return TypeOf(ct.v)
}

// InferShape ignores its arguments and reports the wrapped tensor's shape.
func (ct constantTensor) InferShape(_ ...DimSizer) (tensor.Shape, error) {
	return ct.v.Shape(), nil
}

// Do ignores its inputs and hands back the wrapped tensor itself (no copy).
func (ct constantTensor) Do(_ ...Value) (Value, error) {
	return ct.v, nil
}

// String describes the constant by its type only, not its contents.
func (ct constantTensor) String() string {
	return fmt.Sprintf("const %s", TypeOf(ct.v))
}

// ReturnsPtr is true. Danger! The only reason why this is the case is
// because matrices may be too large, and copying is costly: constants should
// return values, but for the sake of memory we return pointers.
func (ct constantTensor) ReturnsPtr() bool {
	return true
}

// OverwritesInput returns -1: no input is ever overwritten.
func (ct constantTensor) OverwritesInput() int {
	return -1
}

// CallsExtern is false: Do only returns the stored tensor.
func (ct constantTensor) CallsExtern() bool {
	return false
}

// DiffWRT returns nil regardless of the number of inputs.
func (ct constantTensor) DiffWRT(_ int) []bool {
	return nil
}

// SymDiff produces no gradient nodes and no error.
func (ct constantTensor) SymDiff(_ Nodes, _, _ *Node) (Nodes, error) {
	return nil, nil
}

// WriteHash writes "const <type>:<value>" into h.
func (ct constantTensor) WriteHash(h hash.Hash) {
	fmt.Fprintf(h, "const %v:%v", ct.Type(), ct.v)
}

// Hashcode is the 32-bit FNV-1a hash of WriteHash's output.
func (ct constantTensor) Hashcode() uint32 {
	hasher := fnv.New32a()
	ct.WriteHash(hasher)
	return hasher.Sum32()
}

// isconstant marks constantTensor as a constant op.
func (ct constantTensor) isconstant() bool {
	return true
}

// Value returns the wrapped tensor.
func (ct constantTensor) Value() Value {
	return ct.v
}