gorgonia.org/gorgonia@v0.9.17/dual.go

package gorgonia

import (
	"fmt"

	"github.com/chewxy/hm"
	"github.com/pkg/errors"
	"gorgonia.org/tensor"
)

type dualValue struct {
	Value
	d Value // the derivative wrt each input
}
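
// A minimal sketch of what a dualValue carries (illustrative; not part of the
// original source): like the dual number a + bε, the primal value and its
// derivative travel together through forward mode autodiff.
//
//	x := newF64(3.0)    // newF64 is assumed: the package's *F64 constructor helper
//	dv := variableDV(x) // dv.Value == 3.0, dv.d == 1.0, since dx/dx == 1
//	cv := constantDV(x) // cv.Value == 3.0, cv.d == 0.0, since dc/dx == 0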

func (dv *dualValue) SetDeriv(d Value) error {
	if t, ok := d.(tensor.Tensor); ok && t.IsScalar() {
		d, _ = anyToScalar(t.ScalarValue())
	}
	dv.d = d

	return dv.sanity()
}

func (dv *dualValue) SetValue(v Value) error {
	dv.Value = v
	return dv.sanity()
}

func (dv *dualValue) Clone() (retVal interface{}, err error) {
	var v, d Value
	if v, err = CloneValue(dv.Value); err != nil {
		return nil, errors.Wrap(err, cloneFail)
	}

	if dv.d != nil {
		if d, err = CloneValue(dv.d); err != nil {
			return nil, errors.Wrap(err, cloneFail)
		}
	}

	dv2 := borrowDV()
	dv2.Value = v
	dv2.d = d
	retVal = dv2
	return
}

func (dv *dualValue) Type() hm.Type       { return TypeOf(dv.Value) }
func (dv *dualValue) Dtype() tensor.Dtype { return dv.Value.Dtype() }

func (dv *dualValue) ValueEq(a Value) bool {
	switch at := a.(type) {
	case *dualValue:
		if at == dv {
			return true
		}
		veq := ValueEq(at.Value, dv.Value)
		deq := ValueEq(at.d, dv.d)
		return veq && deq
	// case Value:
	// 	return ValueEq(at, dv.Value)
	default:
		return false
	}
}

func (dv *dualValue) String() string {
	return fmt.Sprintf("%#+v", dv.Value)
}

func (dv *dualValue) sanity() error {
	// check that d and v are the same type

	// dvv := typeCheckTypeOf(dv.Value)
	// dvd := typeCheckTypeOf(dv.d)
	// if !dvv.Eq(dvd) {
	// 	return errors.Errorf("DualValues do not have the same types: %v and %v", dvv, dvd)
	// }
	// ReturnType(dvv)
	// ReturnType(dvd)

	// TODO: check that the shapes are the same

	return nil
}

// clone0 clones the dualValue and zeroes out the ndarrays
func (dv *dualValue) clone0() (retVal *dualValue, err error) {
	var v, d Value
	if v, err = CloneValue(dv.Value); err != nil {
		return nil, errors.Wrap(err, cloneFail)
	}

	if d, err = CloneValue(dv.d); err != nil {
		return nil, errors.Wrap(err, cloneFail)
	}

	v = ZeroValue(v)
	d = ZeroValue(d)

	dv2 := borrowDV()
	dv2.Value = v
	dv2.d = d
	retVal = dv2
	return
}

// constantDV creates a *dualValue for a constant: the derivative of a constant is zero.
//
// The original implementation was to have a constantDualValue type. This would lead to way fewer allocations of matrices,
// but as it turns out, as I was working, the constants turned out to be not so constant after all.
// Is this a problem with the graph that leads to derivation of constant values? I don't quite know. TO CHECK
func constantDV(val Value) *dualValue {
	enterLogScope()
	defer leaveLogScope()

	// retVal := &dualValue{Value: val}
	retVal := borrowDV()
	retVal.Value = val

	var err error
	if retVal.d, err = CloneValue(val); err != nil {
		panic(err)
	}

	retVal.d = ZeroValue(retVal.d)
	return retVal
}

// variableDV creates a *dualValue for a variable: the derivative of x with respect to itself is 1.
func variableDV(val Value) *dualValue {
	// retVal := &dualValue{Value: val}
	retVal := borrowDV()
	retVal.Value = val

	switch v := val.(type) {
	case Scalar:
		retVal.d = one(v.Dtype())
	case tensor.Tensor:
		shp := v.Shape()
		dt := v.Dtype()
		retVal.d = tensor.Ones(dt, shp...)
	default:
		panic(fmt.Sprintf("%v(%T) not handled yet", v, v))
	}

	return retVal
}

// dvUnit is the monadic unit() function: it lifts a Value into a *dualValue, allocating a Value for dv.d.
// This is useful for forward mode autodiff.
func dvUnit(v Value) *dualValue {
	enterLogScope()
	defer leaveLogScope()

	if dv, ok := v.(*dualValue); ok {
		return dv
	}
	return constantDV(v)
}

// dvUnitVar lifts a Value into a *dualValue whose derivative is seeded with 1.
func dvUnitVar(v Value) *dualValue {
	if dv, ok := v.(*dualValue); ok {
		return dv
	}
	return variableDV(v)
}

// dvUnit0 lifts a Value into a *dualValue without allocating a derivative: dv.d is left nil.
func dvUnit0(v Value) *dualValue {
	if dv, ok := v.(*dualValue); ok {
		return dv
	}

	retVal := borrowDV()
	retVal.Value = v

	return retVal
}
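
// Taken together, the three lifts above differ only in how dv.d is seeded
// (a sketch, assuming x is a plain scalar float64 Value):
//
//	dvUnit(x)    // dv.d == 0.0 : x is treated as a constant
//	dvUnitVar(x) // dv.d == 1.0 : x is the variable being differentiated
//	dvUnit0(x)   // dv.d == nil : the caller fills in the derivative later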

// dvUnitManaged does dvUnit for values whose memory is manually managed.
func dvUnitManaged(v Value, op *ExternalOp) (*dualValue, error) {
	if op.Device == CPU {
		return dvUnit(v), nil
	}

	if dv, ok := v.(*dualValue); ok {
		return dv, nil
	}

	retVal := borrowDV()
	retVal.Value = v

	s := v.Shape()
	dt := v.Dtype()
	memsize := calcMemSize(dt, s)
	// allocate on device
	mem, err := op.Get(op.Device, memsize)
	if err != nil {
		return nil, err
	}

	d, err := makeValueFromMem(TypeOf(v), s, mem)
	if err != nil {
		return nil, err
	}
	retVal.d = d

	return retVal, nil
}
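
// For reference, the device allocation above is sized from the dtype and the
// shape. A minimal sketch of what calcMemSize is assumed to compute:
//
//	memsize := int64(dt.Size()) * int64(s.TotalSize())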

func dvUnitVarManaged(v Value, op *ExternalOp) (*dualValue, error) {
	dv, err := dvUnitManaged(v, op)
	if err != nil {
		return dv, err
	}

	switch d := dv.d.(type) {
	case tensor.Tensor:
		dt := d.Dtype()
		switch dt {
		case tensor.Float64:
			err = d.Memset(1.0)
		case tensor.Float32:
			err = d.Memset(float32(1))
		case tensor.Bool:
			err = d.Memset(true)
		default:
			return dv, errors.Errorf("Unhandled dtype: %v", dt)
		}
		if err != nil {
			return dv, err
		}
	case *F64:
		*d = F64(1)
	case *F32:
		*d = F32(1)
	case *I:
		*d = I(1)
	case *I64:
		*d = I64(1)
	case *I32:
		*d = I32(1)
	case *U8:
		*d = U8(1)
	case *B:
		*d = B(true)
	default:
		return dv, errors.Errorf("Unhandled type: %T", d)
	}
	return dv, nil
}

// idValue is a helper that unpacks the primal Values from a []*dualValue.
func idValue(inputs []*dualValue) (retVals []Value) {
	retVals = make([]Value, len(inputs))
	for i, input := range inputs {
		retVals[i] = input.Value
	}
	return
}

// dvBind applies an op to the inputs, and returns a *dualValue.
func dvBind(op Op, inputs []*dualValue) (retVal *dualValue, err error) {
	enterLogScope()
	defer leaveLogScope()

	vals := idValue(inputs)

	var ret Value
	if ret, err = op.Do(vals...); err != nil {
		return nil, errors.Wrap(err, opDoFail)
	}
	if o, ok := op.(*ExternalOp); ok {
		return dvUnitManaged(ret, o)
	}
	return dvUnit(ret), nil
}
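
// A sketch of dvBind in use (illustrative; mulOp stands in for whatever Op the
// caller holds): the op runs on the primal values, and the result is lifted
// back into a *dualValue so that a later differentiation step can fill in d.
//
//	a := dvUnitVar(newF64(2.0)) // newF64 as assumed above
//	b := dvUnit(newF64(5.0))
//	out, err := dvBind(mulOp, []*dualValue{a, b}) // out.Value == 10.0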

// dvBindVar returns a dvUnitVar instead of a dvUnit (which zeroes the derivative).
// The default derivative of a variable wrt itself is 1 (dx/dx == 1).
func dvBindVar(op Op, inputs []*dualValue) (retVal *dualValue, err error) {
	vals := idValue(inputs)

	var ret Value
	if ret, err = op.Do(vals...); err != nil {
		return nil, errors.Wrap(err, opDoFail)
	}
	if o, ok := op.(*ExternalOp); ok {
		return dvUnitVarManaged(ret, o)
	}
	return dvUnitVar(ret), nil
}
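
// Illustrative contrast with dvBind (names as in the sketch above):
//
//	out, err := dvBindVar(mulOp, []*dualValue{a, b}) // out.d seeded with 1, not 0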

// TODO: test vecvecdot dvBind0

// dvBind0 doesn't allocate a new dualValue. It reuses the one that is passed in, and zeroes out the deriv.
func dvBind0(op Op, retVal *dualValue, inputs []*dualValue) (err error) {
	prealloc := retVal.Value
	vals := idValue(inputs)

	var ret Value
	if pd, ok := op.(UsePreallocDoer); ok {
		// prefer the preallocated version; fall back to op.Do if it fails
		if ret, err = pd.UsePreallocDo(prealloc, vals...); err == nil {
			goto next
		}
	}
	if ret, err = op.Do(vals...); err != nil {
		return errors.Wrap(err, opDoFail)
	}

next:
	if err = retVal.SetValue(ret); err != nil {
		return
	}

	return retVal.SetDeriv(ZeroValue(retVal.d))
}
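
// The zeroing at the end keeps the derivative buffer allocated but cleared, so
// repeated calls reuse memory. A sketch of the intended call pattern
// (illustrative; assumes out was produced by an earlier dvBind):
//
//	for step := 0; step < nSteps; step++ {
//		if err := dvBind0(op, out, inputs); err != nil {
//			return err
//		}
//	}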

// dvBindVar0 is like dvBind0, but seeds the derivative with 1 instead of zeroing it.
func dvBindVar0(op Op, retVal *dualValue, inputs []*dualValue) (err error) {
	prealloc := retVal.Value

	vals := idValue(inputs)

	var ret Value
	if pd, ok := op.(UsePreallocDoer); ok {
		ret, err = pd.UsePreallocDo(prealloc, vals...)
	} else {
		ret, err = op.Do(vals...)
	}
	if err != nil {
		return errors.Wrap(err, opDoFail)
	}

	if err = retVal.SetValue(ret); err != nil {
		return errors.Wrap(err, "Failed at setting the value")
	}

	switch v := retVal.d.(type) {
	case Scalar:
		retVal.d = one(v.Dtype())
	case tensor.Tensor:
		switch v.Dtype() {
		case tensor.Float64:
			err = v.Memset(float64(1))
		case tensor.Float32:
			err = v.Memset(float32(1))
		default:
			err = errors.Errorf("Unhandled dtype: %v", v.Dtype())
		}
		retVal.d = v
	default:
		err = errors.Errorf(nyiTypeFail, "dvBindVar0", retVal.d)
	}
	return
}