gorgonia.org/gorgonia@v0.9.17/gorgonia.go (about)

     1  package gorgonia
     2  
     3  import (
     4  	"fmt"
     5  
     6  	"github.com/chewxy/hm"
     7  	"github.com/pkg/errors"
     8  	"gorgonia.org/tensor"
     9  )
    10  
// Most functions in this file return *Node and panic if an error happens.
    12  
    13  /* Helper functions to create new input nodes */
    14  
    15  // Must indicates a node must be created. If there isn't a node created, or there was an error,
    16  // it subsumes the error, and immediately panics
    17  func Must(n *Node, err error, opts ...NodeConsOpt) *Node {
    18  	if err != nil || n == nil {
    19  		panic(err)
    20  	}
    21  	return n
    22  }
    23  
    24  // NodeFromAny creates a Node from a tensor.Tensor, automatically filling in shape and type info
    25  func NodeFromAny(g *ExprGraph, any interface{}, opts ...NodeConsOpt) *Node {
    26  	v, t, dt, err := anyToValue(any)
    27  	if err != nil {
    28  		panic(err)
    29  	}
    30  
    31  	opts = append(opts, WithValue(v))
    32  
    33  	switch t.(type) {
    34  	case tensor.Dtype:
    35  		return NewScalar(g, dt, opts...)
    36  	case TensorType:
    37  		opts = append(opts, nil)
    38  		copy(opts[1:], opts[0:len(opts)-1])
    39  		opts[0] = WithShape(v.Shape()...)
    40  		return NewTensor(g, dt, v.Shape().Dims(), opts...)
    41  	default:
    42  		panic(nyi("NewNodeFromAny", any))
    43  	}
    44  }
    45  
    46  // NewScalar creates a Node representing a variable that holds a scalar value
    47  func NewScalar(g *ExprGraph, t tensor.Dtype, opts ...NodeConsOpt) *Node {
    48  	curOpts := []NodeConsOpt{WithType(t), In(g), WithShape()}
    49  	curOpts = append(curOpts, opts...)
    50  
    51  	return NewUniqueNode(curOpts...)
    52  }
    53  
    54  // NewVector creates a Node representing a variable that holds a vector (nx1 matrix)
    55  func NewVector(g *ExprGraph, t tensor.Dtype, opts ...NodeConsOpt) *Node {
    56  	tt := makeTensorType(1, t)
    57  	curOpts := []NodeConsOpt{WithType(tt), In(g)}
    58  	curOpts = append(curOpts, opts...)
    59  
    60  	return NewUniqueNode(curOpts...)
    61  }
    62  
    63  // NewMatrix creates a Node representing a variable that holds a matrix (nxm)
    64  func NewMatrix(g *ExprGraph, t tensor.Dtype, opts ...NodeConsOpt) *Node {
    65  	tt := makeTensorType(2, t)
    66  	curOpts := []NodeConsOpt{WithType(tt), In(g)}
    67  	curOpts = append(curOpts, opts...)
    68  
    69  	return NewUniqueNode(curOpts...)
    70  }
    71  
    72  // NewTensor creates a Node representing a variable that holds a tensor (any n-dimensional array with dimensions greater than 2)
    73  func NewTensor(g *ExprGraph, t tensor.Dtype, dims int, opts ...NodeConsOpt) *Node {
    74  	var tt hm.Type
    75  	if dims == 0 {
    76  		tt = t
    77  	} else {
    78  		tt = makeTensorType(dims, t)
    79  	}
    80  	curOpts := []NodeConsOpt{WithType(tt), In(g)}
    81  	curOpts = append(curOpts, opts...)
    82  
    83  	return NewUniqueNode(curOpts...)
    84  }
    85  
    86  // NewConstant takes in any reasonable value and makes it a constant node.
    87  func NewConstant(v interface{}, opts ...NodeConsOpt) *Node {
    88  	var op Op
    89  	var t hm.Type
    90  	var name string
    91  	var s tensor.Shape
    92  	var val Value
    93  
    94  	val, t, _, err := anyToValue(v)
    95  	if err != nil {
    96  		panic(err)
    97  	}
    98  	switch vt := val.(type) {
    99  	case Scalar:
   100  		op = constantScalar{vt}
   101  		s = scalarShape
   102  	case tensor.Tensor:
   103  		op = constantTensor{vt}
   104  		s = vt.Shape()
   105  	}
   106  
   107  	if op == nil || t == nil {
   108  		panic(fmt.Sprintf("HELP. Op: %v, t: %v", op, t))
   109  	}
   110  
   111  	dummy := borrowNode()
   112  	consOpts := []NodeConsOpt{WithOp(op), WithType(t), WithShape(s...), WithValue(val)}
   113  	consOpts = append(consOpts, opts...)
   114  	for i := range opts {
   115  		opts[i](dummy)
   116  	}
   117  	if dummy.name == "" {
   118  		name = fmt.Sprintf("%v", v)
   119  	} else {
   120  		name = dummy.name
   121  	}
   122  	returnNode(dummy)
   123  
   124  	consOpts = append(consOpts, WithName(name))
   125  	return newNode(consOpts...)
   126  }
   127  
   128  // UniformRandomNode creates an input node that has a random op so everytime the node is passed, random values will be plucked from
   129  // a uniform distribution. The type of the node depends on the
   130  // shape passed in. To get a scalar value at run time, don't pass in any shapes
   131  func UniformRandomNode(g *ExprGraph, dt tensor.Dtype, low, high float64, shape ...int) *Node {
   132  	op := makeRandomOp(uniform, dt, low, high, shape...)
   133  	s := tensor.Shape(shape)
   134  
   135  	var t hm.Type
   136  	if s.Eq(scalarShape) {
   137  		t = dt
   138  	} else {
   139  		t = makeTensorType(s.Dims(), dt)
   140  	}
   141  
   142  	retVal := NewUniqueNode(WithType(t), WithOp(op), In(g), WithShape(shape...))
   143  	return retVal
   144  }
   145  
   146  // GaussianRandomNode creates an input node that has a random op so everytime the node is passed, random values will be plucked from
   147  // a gaussian distribution with the mean and stdev provided. The type of the node depends on the
   148  // shape passed in. To get a scalar value at run time, don't pass in any shapes
   149  func GaussianRandomNode(g *ExprGraph, dt tensor.Dtype, mean, stdev float64, shape ...int) *Node {
   150  	op := makeRandomOp(gaussian, dt, mean, stdev, shape...)
   151  	s := tensor.Shape(shape)
   152  
   153  	var t hm.Type
   154  	if s.Eq(scalarShape) {
   155  		t = dt
   156  	} else {
   157  		t = makeTensorType(s.Dims(), dt)
   158  	}
   159  
   160  	retVal := NewUniqueNode(WithType(t), WithOp(op), In(g), WithShape(shape...))
   161  	return retVal
   162  }
   163  
   164  // BinomialRandomNode creates an input node that has a random op so that everytime the node is passed, random values will be plucked from
   165  // a binomial distribution with the mean and stdev provided. The type of the node depends on the
   166  // shape passed in. To get a scalar value at run time, don't pass in any shapes
   167  //
   168  // Whilst technically the number of trials of a binomal distribution should be a discrete value (you can't have half a trial), to keep with
   169  // API uniformity, trials is passed in as a float64, but will be truncated to an int at runtime.
   170  func BinomialRandomNode(g *ExprGraph, dt tensor.Dtype, trials, prob float64, shape ...int) *Node {
   171  	op := makeRandomOp(binomial, dt, trials, prob, shape...)
   172  	s := tensor.Shape(shape)
   173  
   174  	var t hm.Type
   175  	if s.Eq(scalarShape) {
   176  		t = dt
   177  	} else {
   178  		t = makeTensorType(s.Dims(), dt)
   179  	}
   180  
   181  	retVal := NewUniqueNode(WithType(t), WithOp(op), In(g), WithShape(shape...))
   182  	return retVal
   183  }
   184  
   185  // OneHotVector creates a node representing a one hot vector
   186  func OneHotVector(id, classes int, t tensor.Dtype, opts ...NodeConsOpt) *Node {
   187  	T := tensor.New(tensor.Of(t), tensor.WithShape(classes))
   188  	var err error
   189  	// This is stupid, I want generics. - docmerlin
   190  	switch t {
   191  	case tensor.Float32:
   192  		err = T.SetAt(float32(1), id)
   193  	case tensor.Float64:
   194  		err = T.SetAt(float64(1), id)
   195  	case tensor.Int64:
   196  		err = T.SetAt(int64(1), id)
   197  	case tensor.Int:
   198  		err = T.SetAt(int(1), id)
   199  	case tensor.Int32:
   200  		err = T.SetAt(int32(1), id)
   201  	default:
   202  		panic("tensor.Dtype not implemented")
   203  	}
   204  	if err != nil {
   205  		panic(err.Error())
   206  	}
   207  	return NewConstant(T, opts...)
   208  }
   209  
   210  // Grad takes a scalar cost node and a list of with-regards-to, and returns the gradient
   211  func Grad(cost *Node, WRTs ...*Node) (retVal Nodes, err error) {
   212  	symdiffLogf("Cost:%v", cost)
   213  	if !cost.IsScalar() {
   214  		return nil, errors.Errorf("Expected Cost to be a scalar. Got %v instead", cost)
   215  	}
   216  
   217  	for i, n := range WRTs {
   218  		if !n.isInput() {
   219  			err = errors.Errorf("Can only differentiate with regards to input nodes. %dth Node %v isn't an input", i, n)
   220  			return nil, err
   221  		}
   222  	}
   223  
   224  	var dt tensor.Dtype
   225  	var ok bool
   226  	if dt, ok = cost.t.(tensor.Dtype); !ok {
   227  		err = errors.Wrap(err, "Expected a scalar dtype for cost")
   228  		return
   229  	}
   230  
   231  	var gradOut *Node
   232  	switch dt {
   233  	case Float64:
   234  		gradOut = onef64
   235  	case Float32:
   236  		gradOut = onef32
   237  	default:
   238  		return nil, errors.Wrapf(err, "%s not yet implemented for %v of %T", dt.String(), "Grad()'s gradOut", gradOut)
   239  	}
   240  
   241  	gradOut = cost.g.AddNode(gradOut)
   242  	return Backpropagate(Nodes{cost}, Nodes{gradOut}, Nodes(WRTs))
   243  }
   244  
   245  // Let binds a Value to a node that is a variable. A variable is represented as a *Node with no Op.
   246  // It is equivalent to :
   247  //		x = 2
   248  func Let(n *Node, be interface{}) error {
   249  	if !n.isInput() {
   250  		return errors.New("Cannot bind a value to a non input node")
   251  	}
   252  
   253  	return UnsafeLet(n, be)
   254  }
   255  
   256  // UnsafeLet binds a Value to any node, not just a variable node. This means that you can use it to change any node's value at the runtime of the graph. UNSAFE!
   257  //
   258  // Additional notes: if `be` is a tensor.Slice, and the node's op is a sliceOp or sliceIncrOp, the op's slice will be replaced with the new slice.
   259  func UnsafeLet(n *Node, be interface{}) error {
   260  	switch v := be.(type) {
   261  	case tensor.Slice:
   262  		switch so := n.op.(type) {
   263  		case *sliceOp:
   264  			so.Slice = v
   265  			n.op = so
   266  		case sliceIncrOp:
   267  			so.Slice = v
   268  			n.op = so
   269  		default:
   270  			return errors.Errorf("Trying to Let() a node with a slice. Node's op is %v, not sliceOp", n.op)
   271  		}
   272  
   273  	case Value:
   274  		if !n.Shape().Eq(v.Shape()) {
   275  			return fmt.Errorf("Node's expected shape is %v. Got %v instead", n.Shape(), v.Shape())
   276  		}
   277  
   278  		if !n.Dtype().Eq(v.Dtype()) {
   279  			return errors.Errorf("Unable to let %v be %v. Expected Dtype of %v. Got %v instead", n.name, be, n.Dtype(), v.Dtype())
   280  		}
   281  		n.bind(v)
   282  	default:
   283  		var val Value
   284  		var err error
   285  		if val, _, _, err = anyToValue(be); err != nil {
   286  			return errors.Wrapf(err, anyToValueFail, be, be)
   287  		}
   288  
   289  		n.bind(val)
   290  	}
   291  	return nil
   292  }
   293  
   294  // Set is the equivalent of doing this:
   295  //		a = b
   296  // where a and b are both variables
   297  func Set(a, b *Node) (retVal *Node) {
   298  	op := letOp{}
   299  	name := fmt.Sprintf("%v %s %v", a, op, b)
   300  	return NewUniqueNode(WithOp(op), WithChildren(Nodes{a, b}), WithName(name), In(a.g))
   301  }
   302  
   303  // Read allows for extraction of the value of the *Node at runtime into a Value.
   304  // To achieve this, a pointer to a Value (*Value) is passed into this function, not a Value.
   305  // The 'into' value remains nil until the execution of the graph (via a call to the Run() methods of the VM)
   306  func Read(n *Node, into *Value) (retVal *Node) {
   307  	op := readOp{into}
   308  	name := fmt.Sprintf("read %v into %v", n, into)
   309  	retVal = NewUniqueNode(WithOp(op), WithChildren(Nodes{n}), WithName(name), In(n.g))
   310  	retVal.op = op // this ensures the correct pointer is written
   311  	retVal.name = name
   312  	return
   313  }