gorgonia.org/gorgonia@v0.9.17/gorgonia.go (about) 1 package gorgonia 2 3 import ( 4 "fmt" 5 6 "github.com/chewxy/hm" 7 "github.com/pkg/errors" 8 "gorgonia.org/tensor" 9 ) 10 11 // Functions in this file returns *Node and panics if an error happens 12 13 /* Helper functions to create new input nodes */ 14 15 // Must indicates a node must be created. If there isn't a node created, or there was an error, 16 // it subsumes the error, and immediately panics 17 func Must(n *Node, err error, opts ...NodeConsOpt) *Node { 18 if err != nil || n == nil { 19 panic(err) 20 } 21 return n 22 } 23 24 // NodeFromAny creates a Node from a tensor.Tensor, automatically filling in shape and type info 25 func NodeFromAny(g *ExprGraph, any interface{}, opts ...NodeConsOpt) *Node { 26 v, t, dt, err := anyToValue(any) 27 if err != nil { 28 panic(err) 29 } 30 31 opts = append(opts, WithValue(v)) 32 33 switch t.(type) { 34 case tensor.Dtype: 35 return NewScalar(g, dt, opts...) 36 case TensorType: 37 opts = append(opts, nil) 38 copy(opts[1:], opts[0:len(opts)-1]) 39 opts[0] = WithShape(v.Shape()...) 40 return NewTensor(g, dt, v.Shape().Dims(), opts...) 41 default: 42 panic(nyi("NewNodeFromAny", any)) 43 } 44 } 45 46 // NewScalar creates a Node representing a variable that holds a scalar value 47 func NewScalar(g *ExprGraph, t tensor.Dtype, opts ...NodeConsOpt) *Node { 48 curOpts := []NodeConsOpt{WithType(t), In(g), WithShape()} 49 curOpts = append(curOpts, opts...) 50 51 return NewUniqueNode(curOpts...) 52 } 53 54 // NewVector creates a Node representing a variable that holds a vector (nx1 matrix) 55 func NewVector(g *ExprGraph, t tensor.Dtype, opts ...NodeConsOpt) *Node { 56 tt := makeTensorType(1, t) 57 curOpts := []NodeConsOpt{WithType(tt), In(g)} 58 curOpts = append(curOpts, opts...) 59 60 return NewUniqueNode(curOpts...) 61 } 62 63 // NewMatrix creates a Node representing a variable that holds a matrix (nxm) 64 func NewMatrix(g *ExprGraph, t tensor.Dtype, opts ...NodeConsOpt) *Node { 65 tt := makeTensorType(2, t) 66 curOpts := []NodeConsOpt{WithType(tt), In(g)} 67 curOpts = append(curOpts, opts...) 68 69 return NewUniqueNode(curOpts...) 70 } 71 72 // NewTensor creates a Node representing a variable that holds a tensor (any n-dimensional array with dimensions greater than 2) 73 func NewTensor(g *ExprGraph, t tensor.Dtype, dims int, opts ...NodeConsOpt) *Node { 74 var tt hm.Type 75 if dims == 0 { 76 tt = t 77 } else { 78 tt = makeTensorType(dims, t) 79 } 80 curOpts := []NodeConsOpt{WithType(tt), In(g)} 81 curOpts = append(curOpts, opts...) 82 83 return NewUniqueNode(curOpts...) 84 } 85 86 // NewConstant takes in any reasonable value and makes it a constant node. 87 func NewConstant(v interface{}, opts ...NodeConsOpt) *Node { 88 var op Op 89 var t hm.Type 90 var name string 91 var s tensor.Shape 92 var val Value 93 94 val, t, _, err := anyToValue(v) 95 if err != nil { 96 panic(err) 97 } 98 switch vt := val.(type) { 99 case Scalar: 100 op = constantScalar{vt} 101 s = scalarShape 102 case tensor.Tensor: 103 op = constantTensor{vt} 104 s = vt.Shape() 105 } 106 107 if op == nil || t == nil { 108 panic(fmt.Sprintf("HELP. Op: %v, t: %v", op, t)) 109 } 110 111 dummy := borrowNode() 112 consOpts := []NodeConsOpt{WithOp(op), WithType(t), WithShape(s...), WithValue(val)} 113 consOpts = append(consOpts, opts...) 114 for i := range opts { 115 opts[i](dummy) 116 } 117 if dummy.name == "" { 118 name = fmt.Sprintf("%v", v) 119 } else { 120 name = dummy.name 121 } 122 returnNode(dummy) 123 124 consOpts = append(consOpts, WithName(name)) 125 return newNode(consOpts...) 126 } 127 128 // UniformRandomNode creates an input node that has a random op so everytime the node is passed, random values will be plucked from 129 // a uniform distribution. The type of the node depends on the 130 // shape passed in. To get a scalar value at run time, don't pass in any shapes 131 func UniformRandomNode(g *ExprGraph, dt tensor.Dtype, low, high float64, shape ...int) *Node { 132 op := makeRandomOp(uniform, dt, low, high, shape...) 133 s := tensor.Shape(shape) 134 135 var t hm.Type 136 if s.Eq(scalarShape) { 137 t = dt 138 } else { 139 t = makeTensorType(s.Dims(), dt) 140 } 141 142 retVal := NewUniqueNode(WithType(t), WithOp(op), In(g), WithShape(shape...)) 143 return retVal 144 } 145 146 // GaussianRandomNode creates an input node that has a random op so everytime the node is passed, random values will be plucked from 147 // a gaussian distribution with the mean and stdev provided. The type of the node depends on the 148 // shape passed in. To get a scalar value at run time, don't pass in any shapes 149 func GaussianRandomNode(g *ExprGraph, dt tensor.Dtype, mean, stdev float64, shape ...int) *Node { 150 op := makeRandomOp(gaussian, dt, mean, stdev, shape...) 151 s := tensor.Shape(shape) 152 153 var t hm.Type 154 if s.Eq(scalarShape) { 155 t = dt 156 } else { 157 t = makeTensorType(s.Dims(), dt) 158 } 159 160 retVal := NewUniqueNode(WithType(t), WithOp(op), In(g), WithShape(shape...)) 161 return retVal 162 } 163 164 // BinomialRandomNode creates an input node that has a random op so that everytime the node is passed, random values will be plucked from 165 // a binomial distribution with the mean and stdev provided. The type of the node depends on the 166 // shape passed in. To get a scalar value at run time, don't pass in any shapes 167 // 168 // Whilst technically the number of trials of a binomal distribution should be a discrete value (you can't have half a trial), to keep with 169 // API uniformity, trials is passed in as a float64, but will be truncated to an int at runtime. 170 func BinomialRandomNode(g *ExprGraph, dt tensor.Dtype, trials, prob float64, shape ...int) *Node { 171 op := makeRandomOp(binomial, dt, trials, prob, shape...) 172 s := tensor.Shape(shape) 173 174 var t hm.Type 175 if s.Eq(scalarShape) { 176 t = dt 177 } else { 178 t = makeTensorType(s.Dims(), dt) 179 } 180 181 retVal := NewUniqueNode(WithType(t), WithOp(op), In(g), WithShape(shape...)) 182 return retVal 183 } 184 185 // OneHotVector creates a node representing a one hot vector 186 func OneHotVector(id, classes int, t tensor.Dtype, opts ...NodeConsOpt) *Node { 187 T := tensor.New(tensor.Of(t), tensor.WithShape(classes)) 188 var err error 189 // This is stupid, I want generics. - docmerlin 190 switch t { 191 case tensor.Float32: 192 err = T.SetAt(float32(1), id) 193 case tensor.Float64: 194 err = T.SetAt(float64(1), id) 195 case tensor.Int64: 196 err = T.SetAt(int64(1), id) 197 case tensor.Int: 198 err = T.SetAt(int(1), id) 199 case tensor.Int32: 200 err = T.SetAt(int32(1), id) 201 default: 202 panic("tensor.Dtype not implemented") 203 } 204 if err != nil { 205 panic(err.Error()) 206 } 207 return NewConstant(T, opts...) 208 } 209 210 // Grad takes a scalar cost node and a list of with-regards-to, and returns the gradient 211 func Grad(cost *Node, WRTs ...*Node) (retVal Nodes, err error) { 212 symdiffLogf("Cost:%v", cost) 213 if !cost.IsScalar() { 214 return nil, errors.Errorf("Expected Cost to be a scalar. Got %v instead", cost) 215 } 216 217 for i, n := range WRTs { 218 if !n.isInput() { 219 err = errors.Errorf("Can only differentiate with regards to input nodes. %dth Node %v isn't an input", i, n) 220 return nil, err 221 } 222 } 223 224 var dt tensor.Dtype 225 var ok bool 226 if dt, ok = cost.t.(tensor.Dtype); !ok { 227 err = errors.Wrap(err, "Expected a scalar dtype for cost") 228 return 229 } 230 231 var gradOut *Node 232 switch dt { 233 case Float64: 234 gradOut = onef64 235 case Float32: 236 gradOut = onef32 237 default: 238 return nil, errors.Wrapf(err, "%s not yet implemented for %v of %T", dt.String(), "Grad()'s gradOut", gradOut) 239 } 240 241 gradOut = cost.g.AddNode(gradOut) 242 return Backpropagate(Nodes{cost}, Nodes{gradOut}, Nodes(WRTs)) 243 } 244 245 // Let binds a Value to a node that is a variable. A variable is represented as a *Node with no Op. 246 // It is equivalent to : 247 // x = 2 248 func Let(n *Node, be interface{}) error { 249 if !n.isInput() { 250 return errors.New("Cannot bind a value to a non input node") 251 } 252 253 return UnsafeLet(n, be) 254 } 255 256 // UnsafeLet binds a Value to any node, not just a variable node. This means that you can use it to change any node's value at the runtime of the graph. UNSAFE! 257 // 258 // Additional notes: if `be` is a tensor.Slice, and the node's op is a sliceOp or sliceIncrOp, the op's slice will be replaced with the new slice. 259 func UnsafeLet(n *Node, be interface{}) error { 260 switch v := be.(type) { 261 case tensor.Slice: 262 switch so := n.op.(type) { 263 case *sliceOp: 264 so.Slice = v 265 n.op = so 266 case sliceIncrOp: 267 so.Slice = v 268 n.op = so 269 default: 270 return errors.Errorf("Trying to Let() a node with a slice. Node's op is %v, not sliceOp", n.op) 271 } 272 273 case Value: 274 if !n.Shape().Eq(v.Shape()) { 275 return fmt.Errorf("Node's expected shape is %v. Got %v instead", n.Shape(), v.Shape()) 276 } 277 278 if !n.Dtype().Eq(v.Dtype()) { 279 return errors.Errorf("Unable to let %v be %v. Expected Dtype of %v. Got %v instead", n.name, be, n.Dtype(), v.Dtype()) 280 } 281 n.bind(v) 282 default: 283 var val Value 284 var err error 285 if val, _, _, err = anyToValue(be); err != nil { 286 return errors.Wrapf(err, anyToValueFail, be, be) 287 } 288 289 n.bind(val) 290 } 291 return nil 292 } 293 294 // Set is the equivalent of doing this: 295 // a = b 296 // where a and b are both variables 297 func Set(a, b *Node) (retVal *Node) { 298 op := letOp{} 299 name := fmt.Sprintf("%v %s %v", a, op, b) 300 return NewUniqueNode(WithOp(op), WithChildren(Nodes{a, b}), WithName(name), In(a.g)) 301 } 302 303 // Read allows for extraction of the value of the *Node at runtime into a Value. 304 // To achieve this, a pointer to a Value (*Value) is passed into this function, not a Value. 305 // The 'into' value remains nil until the execution of the graph (via a call to the Run() methods of the VM) 306 func Read(n *Node, into *Value) (retVal *Node) { 307 op := readOp{into} 308 name := fmt.Sprintf("read %v into %v", n, into) 309 retVal = NewUniqueNode(WithOp(op), WithChildren(Nodes{n}), WithName(name), In(n.g)) 310 retVal.op = op // this ensures the correct pointer is written 311 retVal.name = name 312 return 313 }