gorgonia.org/gorgonia@v0.9.17/dual_test.go (about)

     1  package gorgonia
     2  
     3  import (
     4  	"testing"
     5  
     6  	"github.com/chewxy/hm"
     7  	"github.com/stretchr/testify/assert"
     8  	"gorgonia.org/tensor"
     9  )
    10  
    11  func Test_dvBind0(t *testing.T) {
    12  	var x, y, z Value
    13  	var xT, yT, zT hm.Type
    14  	x, xT = anyToScalar(2.0)
    15  	y, yT = anyToScalar(3.0)
    16  	z, zT = anyToScalar(0.0)
    17  
    18  	op := newEBOByType(addOpType, xT, yT)
    19  	xdv := constantDV(x)
    20  	ydv := constantDV(y)
    21  	zdv := constantDV(z)
    22  	dvBind0(op, zdv, []*dualValue{xdv, ydv})
    23  
    24  	t.Logf("%v %v", zdv, zT)
    25  
    26  }
    27  
    28  func TestDVBindVar(t *testing.T) {
    29  	var x, y Value
    30  	var xT, yT hm.Type
    31  	x, xT = anyToScalar(2.0)
    32  	y, yT = anyToScalar(3.0)
    33  
    34  	op := newEBOByType(addOpType, xT, yT)
    35  	xdv := constantDV(x)
    36  	ydv := constantDV(y)
    37  	retVal, err := dvBindVar(op, []*dualValue{xdv, ydv})
    38  	if err != nil {
    39  		t.Error(err)
    40  	}
    41  	assert.Equal(t, 1.0, retVal.d.Data())
    42  
    43  	x = tensor.New(tensor.WithBacking([]float64{4, 3, 2, 1}))
    44  	op = newEBOByType(addOpType, TypeOf(x), TypeOf(y))
    45  	xdv = constantDV(x)
    46  	ydv = constantDV(y)
    47  	if retVal, err = dvBindVar(op, []*dualValue{xdv, ydv}); err != nil {
    48  		t.Error(err)
    49  	}
    50  	assert.Equal(t, []float64{1, 1, 1, 1}, retVal.d.Data())
    51  }