gorgonia.org/gorgonia@v0.9.17/dual_test.go (about) 1 package gorgonia 2 3 import ( 4 "testing" 5 6 "github.com/chewxy/hm" 7 "github.com/stretchr/testify/assert" 8 "gorgonia.org/tensor" 9 ) 10 11 func Test_dvBind0(t *testing.T) { 12 var x, y, z Value 13 var xT, yT, zT hm.Type 14 x, xT = anyToScalar(2.0) 15 y, yT = anyToScalar(3.0) 16 z, zT = anyToScalar(0.0) 17 18 op := newEBOByType(addOpType, xT, yT) 19 xdv := constantDV(x) 20 ydv := constantDV(y) 21 zdv := constantDV(z) 22 dvBind0(op, zdv, []*dualValue{xdv, ydv}) 23 24 t.Logf("%v %v", zdv, zT) 25 26 } 27 28 func TestDVBindVar(t *testing.T) { 29 var x, y Value 30 var xT, yT hm.Type 31 x, xT = anyToScalar(2.0) 32 y, yT = anyToScalar(3.0) 33 34 op := newEBOByType(addOpType, xT, yT) 35 xdv := constantDV(x) 36 ydv := constantDV(y) 37 retVal, err := dvBindVar(op, []*dualValue{xdv, ydv}) 38 if err != nil { 39 t.Error(err) 40 } 41 assert.Equal(t, 1.0, retVal.d.Data()) 42 43 x = tensor.New(tensor.WithBacking([]float64{4, 3, 2, 1})) 44 op = newEBOByType(addOpType, TypeOf(x), TypeOf(y)) 45 xdv = constantDV(x) 46 ydv = constantDV(y) 47 if retVal, err = dvBindVar(op, []*dualValue{xdv, ydv}); err != nil { 48 t.Error(err) 49 } 50 assert.Equal(t, []float64{1, 1, 1, 1}, retVal.d.Data()) 51 }