gorgonia.org/gorgonia@v0.9.17/operatorPointwise_unary_test.go

     1  package gorgonia
     2  
     3  import (
     4  	"math"
     5  	"math/rand"
     6  	"testing"
     7  
     8  	"github.com/chewxy/math32"
     9  	"github.com/stretchr/testify/assert"
    10  	"gorgonia.org/dawson"
    11  	"gorgonia.org/tensor"
    12  )
    13  
    14  func unaryOpTest(t *testing.T, dt tensor.Dtype, shape tensor.Shape, fn func(*Node) (*Node, error)) (x, y, a, b *Node, v Value, err error) {
    15  	var xV, aV Value
    16  	var any interface{}
    17  	if shape.IsScalar() {
    18  		if dt == tensor.Float64 {
    19  			any = rand.ExpFloat64()
    20  		} else {
    21  			any = float32(rand.ExpFloat64())
    22  		}
    23  	} else {
    24  		any = tensor.New(tensor.WithShape(shape...), tensor.WithBacking(tensor.Random(dt, shape.TotalSize())))
    25  	}
    26  	if v, _, _, err = anyToValue(any); err != nil {
    27  		t.Errorf("anyToValue failed: %v", err)
    28  		return
    29  	}
    30  	if xV, err = CloneValue(v); err != nil {
    31  		t.Errorf("Clone to xV failed: %v", err)
    32  		return
    33  	}
    34  
    35  	g := NewGraph()
    36  	x = NodeFromAny(g, xV, WithName("x"))
    37  	y = Must(fn(x))
    38  	Must(Sum(y))
    39  
    40  	var grads Nodes
    41  	h := NewGraph()
    42  	a = NodeFromAny(h, xV, WithName("x"))
    43  	b = Must(fn(a))
    44  	cost := Must(Sum(b))
    45  	if grads, err = Grad(cost, a); err != nil {
    46  		t.Errorf("Unable to get gradient: %v", err)
    47  		return
    48  	}
    49  
    50  	if aV, err = CloneValue(v); err != nil {
    51  		t.Errorf("Clone to aV failed: %v", err)
    52  		return
    53  	}
    54  
    55  	m0 := NewLispMachine(g)
    56  	m1 := NewTapeMachine(h)
    57  	defer m1.Close()
    58  	defer m0.Close()
    59  
    60  	Let(x, xV)
    61  	if err = m0.RunAll(); err != nil {
    62  		t.Errorf("m0 failed: %v", err)
    63  		return
    64  	}
    65  
    66  	Let(a, aV)
    67  	if err = m1.RunAll(); err != nil {
    68  		t.Errorf("m1 failed: %v", err)
    69  		return
    70  	}
    71  
    72  	var yV, xG, bV, aG Value
    73  	yV = y.Value()
    74  	if xG, err = x.Grad(); err != nil {
    75  		t.Errorf("x has no grad: %v", err)
    76  		return
    77  	}
    78  
    79  	bV = b.Value()
    80  	if aG, err = a.Grad(); err != nil {
    81  		t.Errorf("a has no grad: %v", err)
    82  		t.Logf("a.deriv %p | %p", a.deriv, grads[0])
    83  		return
    84  	}
    85  
    86  	if !ValueClose(yV, bV) {
    87  		t.Errorf("Expected yV and bV to be close. yV: %v, bV: %v", yV, bV)
    88  	}
    89  
    90  	if !ValueClose(aG, xG) {
    91  		t.Errorf("Expected aG and xG to be close. aG: %v, xG %v", aG, xG)
    92  	}
    93  
    94  	return
    95  }
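
// unaryOpTest builds the same unary expression twice: once on a graph run by a
// LispMachine (x, y) and once on a graph that is symbolically differentiated and
// run by a TapeMachine (a, b); the caller then cross-checks that forward values and
// gradients agree. The sketch below shows the intended calling pattern for another
// pointwise op. It is illustrative only and assumes the exported Exp wrapper (the
// op exercised via expOpType further down); since d/dx Σ exp(x) = exp(x), the
// gradient of x should match the forward value. Rename it to TestXxx to have
// `go test` pick it up.
func sketchExpUnaryOpTest(t *testing.T) { // not part of the original suite
	x, y, _, _, _, err := unaryOpTest(t, Float64, tensor.Shape{}, Exp)
	if err != nil {
		t.Fatal(err)
	}
	xG, err := x.Grad()
	if err != nil {
		t.Fatalf("x has no grad: %v", err)
	}
	// For exp, the derivative equals the value itself.
	if !ValueClose(y.Value(), xG) {
		t.Errorf("Expected grad of exp to equal its value. xG: %v, yV: %v", xG, y.Value())
	}
}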
    96  
    97  func unaryOpDiffTest(op ʘUnaryOperatorType) (xRandVal float64, x, y, xT, yT *Node, err error) {
    98  	_, x, y = simpleUnaryEqn()
    99  
   100  	xRandVal = rand.ExpFloat64()
   101  	fn := *(sf64UnaryOperators[op])
   102  	diff := ʘUnaryOpDiffFns[op]
   103  
   104  	// let the first stone be cast!
   105  	Let(x, xRandVal)
   106  	v, _, _, _ := anyToValue(fn(xRandVal)) // as if the graph had already been executed
   107  	ydv := variableDV(v)
   108  
   109  	if err = y.bind(ydv); err != nil {
   110  		return
   111  	}
   112  
   113  	if err = x.bind(dvUnit(x.boundTo)); err != nil {
   114  		return
   115  	}
   116  
   117  	if err = diff(x, y); err != nil {
   118  		return
   119  	}
   120  
   121  	// Tensor edition
   122  	_, xT, yT = simpleUnaryVecEqn()
   123  
   124  	xBack := []float64{-xRandVal, xRandVal}
   125  	yBack := []float64{fn(-xRandVal), fn(xRandVal)}
   126  	Let(xT, tensor.New(tensor.WithShape(2, 1), tensor.WithBacking(xBack)))
   127  	vT, _, _, _ := anyToValue(tensor.New(tensor.WithShape(2, 1), tensor.WithBacking(yBack)))
   128  	yTdv := variableDV(vT)
   129  
   130  	if err = yT.bind(yTdv); err != nil {
   131  		return
   132  	}
   133  
   134  	if err = xT.bind(dvUnit(xT.boundTo)); err != nil {
   135  		return
   136  	}
   137  
   138  	if err = diff(xT, yT); err != nil {
   139  		return
   140  	}
   141  	return
   142  }
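
// unaryOpDiffTest drives the ʘUnaryOpDiffFns entries directly, manually binding
// dual values to the nodes before calling the diff function. The same derivative
// identities can also be checked end-to-end through the public API, mirroring the
// TapeMachine path used in unaryOpTest above. This is a minimal sketch, assuming
// the exported Sin wrapper; it is not part of the original suite.
func sketchSinGradEndToEnd(t *testing.T) {
	g := NewGraph()
	x := NodeFromAny(g, NewF64(0.5), WithName("x"))
	y := Must(Sin(x))
	if _, err := Grad(y, x); err != nil {
		t.Fatal(err)
	}
	m := NewTapeMachine(g)
	defer m.Close()

	Let(x, NewF64(0.5))
	if err := m.RunAll(); err != nil {
		t.Fatal(err)
	}

	xG, err := x.Grad()
	if err != nil {
		t.Fatalf("x has no grad: %v", err)
	}
	// d/dx sin(x) = cos(x)
	if !dawson.CloseF64(math.Cos(0.5), xG.Data().(float64)) {
		t.Errorf("Expected cos(0.5), got %v", xG)
	}
}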
   143  
   144  func TestAbs(t *testing.T) {
   145  	assert := assert.New(t)
   146  
   147  	var x, y, a, b *Node
   148  	var v Value
   149  	var yV, xG, bV, aG Value
   150  	var err error
   151  
   152  	/* FLOAT 64 Scalar */
   153  
   154  	x, y, a, b, v, err = unaryOpTest(t, Float64, tensor.Shape{}, Abs)
   155  	if err != nil {
   156  		t.Fatal(err)
   157  	}
   158  
   159  	yV = y.Value()
   160  	if xG, err = x.Grad(); err != nil {
   161  		t.Errorf("x has no grad: %v", err)
   162  		return
   163  	}
   164  
   165  	bV = b.Value()
   166  	if aG, err = a.Grad(); err != nil {
   167  		t.Errorf("a has no grad: %v", err)
   168  	}
   169  
   170  	correctF64 := math.Abs(v.Data().(float64))
   171  	assert.True(ValueClose(NewF64(correctF64), yV))
   172  	assert.True(ValueClose(NewF64(correctF64), bV))
   173  	assert.True(ValueClose(NewF64(1.0), xG))
   174  	assert.True(ValueClose(NewF64(1.0), aG))
   175  
   176  	/* FLOAT 32 Scalar */
   177  
   178  	x, y, a, b, v, err = unaryOpTest(t, Float32, tensor.Shape{}, Abs)
   179  	if err != nil {
   180  		t.Fatal(err)
   181  	}
   182  
   183  	yV = y.Value()
   184  	if xG, err = x.Grad(); err != nil {
   185  		t.Errorf("x has no grad: %v", err)
   186  		return
   187  	}
   188  
   189  	bV = b.Value()
   190  	if aG, err = a.Grad(); err != nil {
   191  		t.Errorf("a has no grad: %v", err)
   192  	}
   193  
   194  	correctF32 := math32.Abs(v.Data().(float32))
   195  	assert.True(ValueClose(NewF32(correctF32), yV))
   196  	assert.True(ValueClose(NewF32(correctF32), bV))
   197  	assert.True(ValueClose(NewF32(1.0), xG))
   198  	assert.True(ValueClose(NewF32(1.0), aG))
   199  
   200  	/* FLOAT64 Vector */
   201  
   202  	x, y, a, b, v, err = unaryOpTest(t, Float64, tensor.Shape{10}, Abs)
   203  	if err != nil {
   204  		t.Fatal(err)
   205  	}
   206  
   207  	yV = y.Value()
   208  	if xG, err = x.Grad(); err != nil {
   209  		t.Errorf("x has no grad: %v", err)
   210  		return
   211  	}
   212  
   213  	bV = b.Value()
   214  	if aG, err = a.Grad(); err != nil {
   215  		t.Errorf("a has no grad: %v", err)
   216  	}
   217  
   218  	absF64s := v.Data().([]float64)
   219  	backingGrad64 := make([]float64, len(absF64s))
   220  	for i, v := range absF64s {
   221  		absF64s[i] = math.Abs(v)
   222  		if v > 0 {
   223  			backingGrad64[i] = 1
   224  		} else {
   225  			backingGrad64[i] = -1
   226  		}
   227  	}
   228  	correctVecF64 := tensor.New(tensor.WithBacking(absF64s))
   229  	gradF64s := tensor.New(tensor.WithBacking(backingGrad64))
   230  
   231  	assert.True(ValueClose(correctVecF64, yV))
   232  	assert.True(ValueClose(correctVecF64, bV))
   233  	assert.True(ValueClose(gradF64s, xG), "xG %v", xG)
   234  	assert.True(ValueClose(gradF64s, aG), "aG %v", aG)
   235  
   236  	/* FLOAT32 Vector */
   237  
   238  	x, y, a, b, v, err = unaryOpTest(t, Float32, tensor.Shape{10}, Abs)
   239  	if err != nil {
   240  		t.Fatal(err)
   241  	}
   242  
   243  	yV = y.Value()
   244  	if xG, err = x.Grad(); err != nil {
   245  		t.Errorf("x has no grad: %v", err)
   246  		return
   247  	}
   248  
   249  	bV = b.Value()
   250  	if aG, err = a.Grad(); err != nil {
   251  		t.Errorf("a has no grad: %v", err)
   252  	}
   253  
   254  	absF32s := v.Data().([]float32)
   255  	backingGrad32 := make([]float32, len(absF32s))
   256  	for i, v := range absF32s {
   257  		absF32s[i] = math32.Abs(v)
   258  		if v > 0 {
   259  			backingGrad32[i] = 1
   260  		} else {
   261  			backingGrad32[i] = -1
   262  		}
   263  	}
   264  	correctVecF32 := tensor.New(tensor.WithBacking(absF32s))
   265  	gradF32s := tensor.New(tensor.WithBacking(backingGrad32))
   266  
   267  	assert.True(ValueClose(correctVecF32, yV))
   268  	assert.True(ValueClose(correctVecF32, bV))
   269  	assert.True(ValueClose(gradF32s, xG), "xG %v", xG)
   270  	assert.True(ValueClose(gradF32s, aG), "aG %v", aG)
   271  
   272  }
   273  
   274  func TestSinDiff(t *testing.T) {
   275  	assert := assert.New(t)
   276  	v, x, _, xT, _, err := unaryOpDiffTest(sinOpType)
   277  	if err != nil {
   278  		t.Error(err)
   279  	}
   280  
   281  	correct := math.Cos(v)
   282  	assert.Equal(correct, x.boundTo.(*dualValue).d.Data())
   283  
   284  	// Tensor edition
   285  	xdvd := xT.boundTo.(*dualValue).d.(*tensor.Dense)
   286  	correctT := []float64{math.Cos(-v), math.Cos(v)}
   287  	assert.Equal(correctT, xdvd.Data())
   288  }
   289  
   290  func TestCosDiff(t *testing.T) {
   291  	assert := assert.New(t)
   292  
   293  	v, x, _, xT, _, err := unaryOpDiffTest(cosOpType)
   294  	if err != nil {
   295  		t.Error(err)
   296  	}
   297  
   298  	assert.Equal(-math.Sin(v), x.boundTo.(*dualValue).d.Data())
   299  
   300  	// Tensor edition
   301  	xdvd := xT.boundTo.(*dualValue).d.(*tensor.Dense)
   302  	correct := []float64{-math.Sin(-v), -math.Sin(v)}
   303  	assert.Equal(correct, xdvd.Data())
   304  }
   305  
   306  func TestExpDiff(t *testing.T) {
   307  	assert := assert.New(t)
   308  	_, x, y, xT, yT, err := unaryOpDiffTest(expOpType)
   309  	if err != nil {
   310  		t.Error(err)
   311  	}
   312  
   313  	assert.Equal(y.boundTo.(*dualValue).Value, x.boundTo.(*dualValue).d)
   314  
   315  	// Tensor edition
   316  	xdvd := xT.boundTo.(*dualValue).d.(*tensor.Dense)
   317  	ydvd := yT.boundTo.(*dualValue).Value.(*tensor.Dense)
   318  	assert.Equal(ydvd.Data(), xdvd.Data())
   319  }
   320  
   321  func TestLnDiff(t *testing.T) {
   322  	assert := assert.New(t)
   323  	var err error
   324  	v, x, _, xT, _, err := unaryOpDiffTest(lnOpType)
   325  	if err != nil {
   326  		t.Error(err)
   327  	}
   328  	correct := 1.0 / v
   329  	assert.Equal(correct, x.boundTo.(*dualValue).d.Data(), "v was %v", v)
   330  
   331  	// Tensor edition
   332  	xdvd := xT.boundTo.(*dualValue).d.(*tensor.Dense)
   333  	correctT := []float64{1.0 / -v, 1.0 / v}
   334  	assert.Equal(correctT, xdvd.Data())
   335  }
   336  
   337  func TestLog2Diff(t *testing.T) {
   338  	assert := assert.New(t)
   339  	v, x, _, xT, _, err := unaryOpDiffTest(log2OpType)
   340  	if err != nil {
   341  		t.Error(err)
   342  	}
   343  	correct := 1.0 / (v * math.Ln2)
   344  	assert.Equal(correct, x.boundTo.(*dualValue).d.Data())
   345  
   346  	// Tensor edition
   347  	xdvd := xT.boundTo.(*dualValue).d.(*tensor.Dense)
   348  	correctT := []float64{1.0 / (-v * math.Ln2), 1.0 / (v * math.Ln2)}
   349  	assert.Equal(correctT, xdvd.Data())
   350  }
   351  
   352  func TestSquareDiff(t *testing.T) {
   353  	assert := assert.New(t)
   354  	var err error
   355  	v, x, _, xT, _, err := unaryOpDiffTest(squareOpType)
   356  	if err != nil {
   357  		t.Error(err)
   358  	}
   359  
   360  	assert.Equal(2*v, x.boundTo.(*dualValue).d.Data())
   361  
   362  	// Tensor edition
   363  	xdvd := xT.boundTo.(*dualValue).d.(*tensor.Dense)
   364  	correct := []float64{2 * -v, 2 * v}
   365  	assert.Equal(correct, xdvd.Data())
   366  }
   367  
   368  func TestSqrtDiff(t *testing.T) {
   369  	assert := assert.New(t)
   370  	v, x, _, xT, _, err := unaryOpDiffTest(sqrtOpType)
   371  	if err != nil {
   372  		t.Error(err)
   373  	}
   374  
   375  	assert.Equal(1.0/(2*math.Sqrt(v)), x.boundTo.(*dualValue).d.Data())
   376  
   377  	// Tensor edition
   378  	xdvd := xT.boundTo.(*dualValue).d.(*tensor.Dense)
   379  	correct := []float64{1.0 / (2 * math.Sqrt(-v)), 1.0 / (2 * math.Sqrt(v))}
   380  	got := xdvd.Data().([]float64)
   381  	if !math.IsNaN(got[0]) && math.IsNaN(correct[0]) {
   382  		t.Error("Expected NaN for the first value")
   383  	}
   384  	if got[1] != correct[1] {
   385  		t.Errorf("Different second values: got %v, want %v", got[1], correct[1])
   386  	}
   387  }
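
// TestSqrtDiff special-cases its first component because the input is negated:
// math.Sqrt of a negative number is NaN, the NaN propagates through 1/(2*sqrt(x)),
// and NaN never compares equal to itself, so assert.Equal (reflect.DeepEqual
// underneath) cannot be used on the slices. A tiny illustration of that behaviour:
func sketchSqrtNaNDerivative(t *testing.T) { // illustrative sketch, not part of the original suite
	d := 1.0 / (2 * math.Sqrt(-1.5)) // sqrt of a negative float64 is NaN; NaN propagates
	if !math.IsNaN(d) {
		t.Errorf("expected NaN for the derivative at a negative input, got %v", d)
	}
	// NaN != NaN, which is why the test above checks math.IsNaN instead of equality.
}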
   388  
   389  func TestInverseDiff(t *testing.T) {
   390  	assert := assert.New(t)
   391  	v, x, _, xT, _, err := unaryOpDiffTest(inverseOpType)
   392  	if err != nil {
   393  		t.Error(err)
   394  	}
   395  
   396  	correct := -((1 / v) * (1 / v))
   397  	assert.Equal(correct, x.boundTo.(*dualValue).d.Data())
   398  
   399  	// Tensor edition
   400  	xdvd := xT.boundTo.(*dualValue).d.(*tensor.Dense)
   401  	correctT := []float64{correct, correct}
   402  	assert.Equal(correctT, xdvd.Data())
   403  }
   404  
   405  func TestCubeDiff(t *testing.T) {
   406  	assert := assert.New(t)
   407  	v, x, _, xT, _, err := unaryOpDiffTest(cubeOpType)
   408  	if err != nil {
   409  		t.Error(err)
   410  	}
   411  
   412  	correct := 3 * v * v
   413  	xG, err := x.Grad()
   414  	if err != nil {
   415  		t.Error(err)
   416  	}
   417  
   418  	assert.True(dawson.CloseF64(correct, extractF64(xG)), "%v != %v", xG, correct)
   419  
   420  	// Tensor edition
   421  	xdvd := xT.boundTo.(*dualValue).d
   422  	correctT := []float64{correct, correct}
   423  	assert.True(floatsEqual64(correctT, extractF64s(xdvd)))
   424  }
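
// The analytic gradients asserted throughout these tests can also be sanity-checked
// numerically. A minimal central-difference sketch for the cube op's derivative 3v²
// (pure math, no graph machinery; the step size and tolerance are arbitrary choices):
func sketchCubeFiniteDifference(t *testing.T) { // illustrative sketch, not part of the original suite
	cube := func(x float64) float64 { return x * x * x }
	v := 1.3
	h := 1e-6
	numeric := (cube(v+h) - cube(v-h)) / (2 * h) // central difference approximates f'(v)
	analytic := 3 * v * v
	if math.Abs(numeric-analytic) > 1e-6 {
		t.Errorf("finite difference %v does not match analytic derivative %v", numeric, analytic)
	}
}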
   425  
   426  func TestTanhDiff(t *testing.T) {
   427  	assert := assert.New(t)
   428  	v, x, _, xT, _, err := unaryOpDiffTest(tanhOpType)
   429  	if err != nil {
   430  		t.Error(err)
   431  	}
   432  
   433  	// NOTE: there are no guarantees of identical behaviour across architectures;
   434  	// in this case arm64 gives different results than amd64 for Tanh.
   435  	// See https://github.com/golang/go/issues/18354#issuecomment-267705645
   436  	correct := 1.0 - (math.Tanh(v) * math.Tanh(v)) // i.e. sech²(v); the math package has no hyperbolic secant function
   437  	assert.InDeltaf(correct, x.boundTo.(*dualValue).d.Data(), 1e-14, "")
   438  
   439  	// Tensor edition
   440  	xdvd := xT.boundTo.(*dualValue).d.(*tensor.Dense)
   441  	assert.InDeltaSlicef([]float64{correct, correct}, xdvd.Data(), 1e-14, "")
   442  }
   443  
   444  func TestSigmoidDiff(t *testing.T) {
   445  	assert := assert.New(t)
   446  	v, x, _, xT, _, err := unaryOpDiffTest(sigmoidOpType)
   447  	if err != nil {
   448  		t.Error(err)
   449  	}
   450  
   451  	correct := math.Exp(-v) / ((1 + math.Exp(-v)) * (1 + math.Exp(-v)))
   452  	xG := x.boundTo.(*dualValue).d
   453  	assert.True(dawson.CloseF64(correct, extractF64(xG)))
   454  
   455  	// Tensor edition
   456  	xdvd := xT.boundTo.(*dualValue).d
   457  	negCorrect := math.Exp(v) / ((1 + math.Exp(v)) * (1 + math.Exp(v)))
   458  	corrects := []float64{negCorrect, correct}
   459  	assert.True(floatsEqual64(corrects, extractF64s(xdvd)))
   460  }
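
// The closed form used above, exp(-v)/(1+exp(-v))², is just σ'(v) = σ(v)(1-σ(v))
// written out. A quick numeric cross-check of the two forms, using the package's
// sigmoidf64 helper as TestSoftplus does (the sample points and tolerance are
// arbitrary choices):
func sketchSigmoidDerivativeIdentity(t *testing.T) { // illustrative sketch, not part of the original suite
	for _, v := range []float64{-2, -0.5, 0, 0.5, 2} {
		closed := math.Exp(-v) / ((1 + math.Exp(-v)) * (1 + math.Exp(-v)))
		viaSigmoid := sigmoidf64(v) * (1 - sigmoidf64(v))
		if math.Abs(closed-viaSigmoid) > 1e-12 {
			t.Errorf("sigmoid'(%v): %v != %v", v, closed, viaSigmoid)
		}
	}
}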
   461  
   462  func TestLog1pDiff(t *testing.T) {
   463  	assert := assert.New(t)
   464  	v, x, _, xT, _, err := unaryOpDiffTest(log1pOpType)
   465  	if err != nil {
   466  		t.Error(err)
   467  	}
   468  
   469  	correct := 1 / (1.0 + v)
   470  	assert.Equal(correct, x.boundTo.(*dualValue).d.Data())
   471  
   472  	// Tensor edition
   473  	xdvd := xT.boundTo.(*dualValue).d.(*tensor.Dense)
   474  	correct0 := 1 / (1.0 - v)
   475  	assert.Equal([]float64{correct0, correct}, xdvd.Data())
   476  }
   477  
   478  func TestExpm1Diff(t *testing.T) {
   479  	assert := assert.New(t)
   480  	v, x, _, xT, _, err := unaryOpDiffTest(expm1OpType)
   481  	if err != nil {
   482  		t.Error(err)
   483  	}
   484  
   485  	correct := math.Exp(v)
   486  	assert.Equal(correct, x.boundTo.(*dualValue).d.Data())
   487  
   488  	// Tensor edition
   489  	xdvd := xT.boundTo.(*dualValue).d.(*tensor.Dense)
   490  	correct0 := math.Exp(-v)
   491  	assert.Equal([]float64{correct0, correct}, xdvd.Data())
   492  }
   493  
   494  func TestSoftplus(t *testing.T) {
   495  	assert := assert.New(t)
   496  
   497  	var x, y, a, b *Node
   498  	var v Value
   499  	var xV, yV, xG, bV, aG Value
   500  	var err error
   501  
   502  	/* FLOAT64 SCALAR */
   503  
   504  	if x, y, a, b, v, err = unaryOpTest(t, Float64, tensor.Shape{}, Softplus); err != nil {
   505  		t.Fatal(err)
   506  	}
   507  
   508  	xV = x.Value()
   509  	yV = y.Value()
   510  	if xG, err = x.Grad(); err != nil {
   511  		t.Errorf("x has no grad: %v", err)
   512  		return
   513  	}
   514  
   515  	bV = b.Value()
   516  	if aG, err = a.Grad(); err != nil {
   517  		t.Errorf("a has no grad: %v", err)
   518  	}
   519  
   520  	correctVF64 := softplusf64(v.Data().(float64))
   521  	correctDF64 := sigmoidf64(xV.Data().(float64))
   522  	assert.True(ValueClose(NewF64(correctVF64), yV))
   523  	assert.True(ValueClose(NewF64(correctVF64), bV))
   524  	assert.True(ValueClose(NewF64(correctDF64), xG))
   525  	assert.True(ValueClose(NewF64(correctDF64), aG))
   526  
   527  	/* FLOAT32 SCALAR */
   528  
   529  	if x, y, a, b, v, err = unaryOpTest(t, Float32, tensor.Shape{}, Softplus); err != nil {
   530  		t.Fatal(err)
   531  	}
   532  
   533  	xV = x.Value()
   534  	yV = y.Value()
   535  	if xG, err = x.Grad(); err != nil {
   536  		t.Errorf("x has no grad: %v", err)
   537  		return
   538  	}
   539  
   540  	bV = b.Value()
   541  	if aG, err = a.Grad(); err != nil {
   542  		t.Errorf("a has no grad: %v", err)
   543  	}
   544  
   545  	correctVF32 := softplusf32(v.Data().(float32))
   546  	correctDF32 := sigmoidf32(xV.Data().(float32))
   547  	assert.True(ValueClose(NewF32(correctVF32), yV))
   548  	assert.True(ValueClose(NewF32(correctVF32), bV))
   549  	assert.True(ValueClose(NewF32(correctDF32), xG))
   550  	assert.True(ValueClose(NewF32(correctDF32), aG))
   551  
   552  	/* FLOAT64 Vector */
   553  
   554  	if x, y, a, b, v, err = unaryOpTest(t, Float64, tensor.Shape{10}, Softplus); err != nil {
   555  		t.Fatal(err)
   556  	}
   557  
   558  	xV = x.Value()
   559  	yV = y.Value()
   560  	if xG, err = x.Grad(); err != nil {
   561  		t.Errorf("x has no grad: %v", err)
   562  		return
   563  	}
   564  
   565  	bV = b.Value()
   566  	if aG, err = a.Grad(); err != nil {
   567  		t.Errorf("a has no grad: %v", err)
   568  	}
   569  
   570  	correctVF64s := v.Data().([]float64)
   571  	correctDF64s := xV.Data().([]float64)
   572  
   573  	for i, v := range correctVF64s {
   574  		correctVF64s[i] = softplusf64(v)
   575  		correctDF64s[i] = sigmoidf64(correctDF64s[i])
   576  	}
   577  	assert.True(floatsEqual64(correctVF64s, yV.Data().([]float64)))
   578  	assert.True(floatsEqual64(correctVF64s, bV.Data().([]float64)))
   579  	assert.True(floatsEqual64(correctDF64s, xG.Data().([]float64)))
   580  	assert.True(floatsEqual64(correctDF64s, aG.Data().([]float64)))
   581  
   582  	/* FLOAT32 Vector */
   583  
   584  	if x, y, a, b, v, err = unaryOpTest(t, Float32, tensor.Shape{10}, Softplus); err != nil {
   585  		t.Fatal(err)
   586  	}
   587  
   588  	xV = x.Value()
   589  	yV = y.Value()
   590  	if xG, err = x.Grad(); err != nil {
   591  		t.Errorf("x has no grad: %v", err)
   592  		return
   593  	}
   594  
   595  	bV = b.Value()
   596  	if aG, err = a.Grad(); err != nil {
   597  		t.Errorf("a has no grad: %v", err)
   598  	}
   599  
   600  	correctVF32s := v.Data().([]float32)
   601  	correctDF32s := xV.Data().([]float32)
   602  
   603  	for i, v := range correctVF32s {
   604  		correctVF32s[i] = softplusf32(v)
   605  		correctDF32s[i] = sigmoidf32(correctDF32s[i])
   606  	}
   607  	assert.True(floatsEqual32(correctVF32s, yV.Data().([]float32)))
   608  	assert.True(floatsEqual32(correctVF32s, bV.Data().([]float32)))
   609  	assert.True(floatsEqual32(correctDF32s, xG.Data().([]float32)))
   610  	assert.True(floatsEqual32(correctDF32s, aG.Data().([]float32)))
   611  }
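
// TestSoftplus relies on the identity softplus'(x) = sigmoid(x). Below is a small
// numeric cross-check of that identity using the same package helpers (softplusf64
// and sigmoidf64) and a central difference; the step size and tolerance are
// arbitrary choices. Illustrative sketch only, not part of the original suite.
func sketchSoftplusDerivativeIsSigmoid(t *testing.T) {
	const h = 1e-6
	for _, v := range []float64{-3, -0.5, 0, 0.5, 3} {
		numeric := (softplusf64(v+h) - softplusf64(v-h)) / (2 * h)
		if math.Abs(numeric-sigmoidf64(v)) > 1e-6 {
			t.Errorf("softplus'(%v): finite difference %v != sigmoid %v", v, numeric, sigmoidf64(v))
		}
	}
}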