gorgonia.org/gorgonia@v0.9.17/dual.go (about) 1 package gorgonia 2 3 import ( 4 "fmt" 5 6 "github.com/chewxy/hm" 7 "github.com/pkg/errors" 8 "gorgonia.org/tensor" 9 ) 10 11 type dualValue struct { 12 Value 13 d Value // the derivative wrt to each input 14 } 15 16 func (dv *dualValue) SetDeriv(d Value) error { 17 if t, ok := d.(tensor.Tensor); ok && t.IsScalar() { 18 d, _ = anyToScalar(t.ScalarValue()) 19 } 20 dv.d = d 21 22 return dv.sanity() 23 } 24 25 func (dv *dualValue) SetValue(v Value) error { 26 dv.Value = v 27 return dv.sanity() 28 } 29 30 func (dv *dualValue) Clone() (retVal interface{}, err error) { 31 var v, d Value 32 if v, err = CloneValue(dv.Value); err != nil { 33 return nil, errors.Wrap(err, cloneFail) 34 } 35 36 if dv.d != nil { 37 if d, err = CloneValue(dv.d); err != nil { 38 return nil, errors.Wrap(err, cloneFail) 39 } 40 } 41 42 dv2 := borrowDV() 43 dv2.Value = v 44 dv2.d = d 45 retVal = dv2 46 return 47 } 48 49 func (dv *dualValue) Type() hm.Type { return TypeOf(dv.Value) } 50 func (dv *dualValue) Dtype() tensor.Dtype { return dv.Value.Dtype() } 51 52 func (dv *dualValue) ValueEq(a Value) bool { 53 switch at := a.(type) { 54 case *dualValue: 55 if at == dv { 56 return true 57 } 58 veq := ValueEq(at.Value, dv.Value) 59 deq := ValueEq(at.d, dv.d) 60 return veq && deq 61 // case Value: 62 // return ValueEq(at, dv.Value) 63 default: 64 return false 65 } 66 } 67 68 func (dv *dualValue) String() string { 69 return fmt.Sprintf("%#+v", dv.Value) 70 } 71 72 func (dv *dualValue) sanity() error { 73 // check that d and v are the same type 74 75 // dvv := typeCheckTypeOf(dv.Value) 76 // dvd := typeCheckTypeOf(dv.d) 77 // if !dvv.Eq(dvd) { 78 // return errors.Errorf("DualValues do not have the same types: %v and %v", dvv, dvd) 79 // } 80 // ReturnType(dvv) 81 // ReturnType(dvd) 82 83 // TODO: check that the shapes are the same 84 85 return nil 86 } 87 88 // clones the dualValue and zeroes out the ndarrays 89 func (dv *dualValue) clone0() (retVal *dualValue, err error) { 90 var v, d Value 91 if v, err = CloneValue(dv.Value); err != nil { 92 return nil, errors.Wrap(err, cloneFail) 93 } 94 95 if d, err = CloneValue(dv.d); err != nil { 96 return nil, errors.Wrap(err, cloneFail) 97 } 98 99 v = ZeroValue(v) 100 d = ZeroValue(d) 101 102 dv2 := borrowDV() 103 dv2.Value = v 104 dv2.d = d 105 retVal = dv2 106 return 107 } 108 109 // the derivative of a constant is zero. 110 // 111 // The original implementation was to have a constantDualValue type. This would lead to waaay less allocations of matrices 112 // but as it turns out, as I waws working, the constants turn out to be not so constant afterall. 113 // Is this a problem with the graph that leads to derivation of constant values? I don't quite know. TO CHECK 114 func constantDV(val Value) *dualValue { 115 enterLogScope() 116 defer leaveLogScope() 117 118 // retVal := &dualValue{Value: val} 119 retVal := borrowDV() 120 retVal.Value = val 121 122 var err error 123 if retVal.d, err = CloneValue(val); err != nil { 124 panic(err) 125 } 126 127 retVal.d = ZeroValue(retVal.d) 128 return retVal 129 } 130 131 // the derivative of x is 1. 132 func variableDV(val Value) *dualValue { 133 // retVal := &dualValue{Value: val} 134 retVal := borrowDV() 135 retVal.Value = val 136 137 switch v := val.(type) { 138 case Scalar: 139 retVal.d = one(v.Dtype()) 140 case tensor.Tensor: 141 shp := v.Shape() 142 dt := v.Dtype() 143 retVal.d = tensor.Ones(dt, shp...) 144 default: 145 panic(fmt.Sprintf("%v(%T) not handled yet", v, v)) 146 } 147 148 return retVal 149 } 150 151 // monadic unit() function. This unit() function will allocate a Value for dv.d 152 // this is useful for forward mode autodiff 153 func dvUnit(v Value) *dualValue { 154 enterLogScope() 155 defer leaveLogScope() 156 157 if dv, ok := v.(*dualValue); ok { 158 return dv 159 } 160 return constantDV(v) 161 } 162 163 func dvUnitVar(v Value) *dualValue { 164 if dv, ok := v.(*dualValue); ok { 165 return dv 166 } 167 return variableDV(v) 168 } 169 170 // no alloc is done. It'll just return a *dualValue with nil as the dv.d 171 func dvUnit0(v Value) *dualValue { 172 if dv, ok := v.(*dualValue); ok { 173 return dv 174 } 175 176 retVal := borrowDV() 177 retVal.Value = v 178 179 return retVal 180 } 181 182 // dvUnitManaged does dvUnit for values whose memories are manually managed 183 func dvUnitManaged(v Value, op *ExternalOp) (*dualValue, error) { 184 if op.Device == CPU { 185 return dvUnit(v), nil 186 } 187 188 if dv, ok := v.(*dualValue); ok { 189 return dv, nil 190 } 191 192 retVal := borrowDV() 193 retVal.Value = v 194 195 s := v.Shape() 196 dt := v.Dtype() 197 memsize := calcMemSize(dt, s) 198 // allocate on device 199 mem, err := op.Get(op.Device, memsize) 200 if err != nil { 201 return nil, err 202 } 203 204 d, err := makeValueFromMem(TypeOf(v), s, mem) 205 if err != nil { 206 return nil, err 207 } 208 retVal.d = d 209 210 return retVal, nil 211 } 212 213 func dvUnitVarManaged(v Value, op *ExternalOp) (*dualValue, error) { 214 dv, err := dvUnitManaged(v, op) 215 if err != nil { 216 return dv, err 217 } 218 219 switch d := dv.d.(type) { 220 case tensor.Tensor: 221 dt := d.Dtype() 222 switch dt { 223 case tensor.Float64: 224 d.Memset(1.0) 225 case tensor.Float32: 226 d.Memset(float32(1)) 227 case tensor.Bool: 228 d.Memset(true) 229 default: 230 return dv, errors.Errorf("Unhandled dtype: %v", dt) 231 } 232 case *F64: 233 *d = F64(1) 234 case *F32: 235 *d = F32(1) 236 case *I: 237 *d = I(1) 238 case *I64: 239 *d = I64(1) 240 case *I32: 241 *d = I32(1) 242 case *U8: 243 *d = U8(1) 244 case *B: 245 *d = B(true) 246 default: 247 return dv, errors.Errorf("Unhandeled type: %T", d) 248 } 249 return dv, nil 250 } 251 252 // helper to unpack from []*dualValue 253 func idValue(inputs []*dualValue) (retVals []Value) { 254 retVals = make([]Value, len(inputs)) 255 for i, input := range inputs { 256 retVals[i] = input.Value 257 } 258 return 259 } 260 261 // dvBind applies an op to the inputs, and returns a *dualValue 262 func dvBind(op Op, inputs []*dualValue) (retVal *dualValue, err error) { 263 enterLogScope() 264 defer leaveLogScope() 265 266 vals := idValue(inputs) 267 268 var ret Value 269 if ret, err = op.Do(vals...); err != nil { 270 return nil, errors.Wrap(err, opDoFail) 271 } 272 if o, ok := op.(*ExternalOp); ok { 273 return dvUnitManaged(ret, o) 274 } 275 return dvUnit(ret), nil 276 } 277 278 // dvBindVar returns a dvUnitVar instead of dvUnit (which zeroes the derivative). 279 // The default derivative of a variable wrt itself is 1 (dx/dx == 1) 280 func dvBindVar(op Op, inputs []*dualValue) (retVal *dualValue, err error) { 281 vals := idValue(inputs) 282 283 var ret Value 284 if ret, err = op.Do(vals...); err != nil { 285 return nil, errors.Wrap(err, opDoFail) 286 } 287 if o, ok := op.(*ExternalOp); ok { 288 return dvUnitVarManaged(ret, o) 289 } 290 return dvUnitVar(ret), nil 291 } 292 293 //TODO test vecvecdot divBind0 294 295 // doesn't alloc a dualValue, and reuses whatever that is there, and zeroes out the deriv 296 func dvBind0(op Op, retVal *dualValue, inputs []*dualValue) (err error) { 297 prealloc := retVal.Value 298 vals := idValue(inputs) 299 300 var ret Value 301 if pd, ok := op.(UsePreallocDoer); ok { 302 if ret, err = pd.UsePreallocDo(prealloc, vals...); err == nil { 303 goto next 304 } 305 } 306 if ret, err = op.Do(vals...); err != nil { 307 return errors.Wrap(err, opDoFail) 308 } 309 310 next: 311 if err != nil { 312 return 313 } 314 315 if err = retVal.SetValue(ret); err != nil { 316 return 317 } 318 319 retVal.SetDeriv(ZeroValue(retVal.d)) 320 return 321 } 322 323 func dvBindVar0(op Op, retVal *dualValue, inputs []*dualValue) (err error) { 324 prealloc := retVal.Value 325 326 vals := idValue(inputs) 327 328 var ret Value 329 if pd, ok := op.(UsePreallocDoer); ok { 330 ret, err = pd.UsePreallocDo(prealloc, vals...) 331 } else { 332 if ret, err = op.Do(vals...); err != nil { 333 return errors.Wrap(err, opDoFail) 334 } 335 } 336 337 if err != nil { 338 return errors.Wrapf(err, opDoFail) 339 } 340 341 if err = retVal.SetValue(ret); err != nil { 342 return errors.Wrap(err, "Failed at setting the value") 343 } 344 345 switch v := retVal.d.(type) { 346 case Scalar: 347 retVal.d = one(v.Dtype()) 348 case tensor.Tensor: 349 switch v.Dtype() { 350 case tensor.Float64: 351 err = v.Memset(float64(1)) 352 case tensor.Float32: 353 err = v.Memset(float32(1)) 354 } 355 retVal.d = v 356 default: 357 err = errors.Errorf(nyiTypeFail, "dvBindVar0", retVal.d) 358 } 359 return 360 }