gorgonia.org/gorgonia@v0.9.17/operatorPointwise_binary_test.go (about)

     1  package gorgonia
     2  
     3  import (
     4  	"fmt"
     5  	"io/ioutil"
     6  	"math/rand"
     7  	"runtime"
     8  	"testing"
     9  
    10  	"github.com/pkg/errors"
    11  	"github.com/stretchr/testify/assert"
    12  	"gorgonia.org/tensor"
    13  )
    14  
    15  func ssBinOpTest(t *testing.T, op ʘBinaryOperatorType, dt tensor.Dtype) (err error) {
    16  	defer runtime.GC()
    17  	assert := assert.New(t)
    18  	var randX, randY interface{}
    19  	switch dt {
    20  	case Float64:
    21  		randX = rand.ExpFloat64()
    22  		randY = rand.ExpFloat64()
    23  	case Float32:
    24  		randX = float32(rand.ExpFloat64())
    25  		randY = float32(rand.ExpFloat64())
    26  	default:
    27  		return errors.Errorf("op %v Test not yet implemented for %v ", op, dt)
    28  	}
    29  
    30  	binOp := newEBOByType(op, dt, dt)
    31  	t.Logf("ssBinOp %v %v %v", randX, op, randY)
    32  
    33  	var g, g2 *ExprGraph
    34  	var x, y, z *Node
    35  	var a, b, c *Node
    36  	var i, j, k *Node
    37  	g = NewGraph()
    38  	x = NewScalar(g, dt, WithName("x"))
    39  	y = NewScalar(g, dt, WithName("y"))
    40  	if z, err = ApplyOp(binOp, x, y); err != nil {
    41  		return err
    42  	}
    43  
    44  	g2 = NewGraph()
    45  	a = NewScalar(g2, dt, WithName("a"))
    46  	b = NewScalar(g2, dt, WithName("b"))
    47  	if c, err = ApplyOp(binOp, a, b); err != nil {
    48  		return err
    49  	}
    50  
    51  	i = NewScalar(g, dt, WithName("i"))
    52  	j = NewScalar(g, dt, WithName("j"))
    53  	binOp.retSame = true
    54  	if k, err = ApplyOp(binOp, i, j); err != nil {
    55  		return err
    56  	}
    57  
    58  	// var grads Nodes
    59  	var m1 VM
    60  	if op.isArith() {
    61  		if _, err = Grad(c, a, b); err != nil {
    62  			return err
    63  		}
    64  		m1 = NewLispMachine(g)
    65  	} else {
    66  		m1 = NewLispMachine(g, ExecuteFwdOnly())
    67  	}
    68  
    69  	m2 := NewTapeMachine(g2, TraceExec(), BindDualValues())
    70  	defer m2.Close()
    71  	defer m1.Close()
    72  
    73  	Let(x, randX)
    74  	Let(y, randY)
    75  	Let(i, randX)
    76  	Let(j, randY)
    77  	if err = m1.RunAll(); err != nil {
    78  		return
    79  	}
    80  
    81  	Let(a, randX)
    82  	Let(b, randY)
    83  	if err = m2.RunAll(); err != nil {
    84  		return
    85  	}
    86  
    87  	var xG, aG, yG, bG, zG, cG Value
    88  	if op.isArith() {
    89  		if xG, err = x.Grad(); err != nil {
    90  			return
    91  		}
    92  		if yG, err = y.Grad(); err != nil {
    93  			return
    94  		}
    95  		if aG, err = a.Grad(); err != nil {
    96  			return
    97  		}
    98  		if bG, err = b.Grad(); err != nil {
    99  			return
   100  		}
   101  
   102  		if zG, err = z.Grad(); err != nil {
   103  			return
   104  		}
   105  		if cG, err = c.Grad(); err != nil {
   106  			return
   107  		}
   108  
   109  		if _, err = i.Grad(); err != nil {
   110  			return
   111  		}
   112  
   113  		if _, err = j.Grad(); err != nil {
   114  			return
   115  		}
   116  		if _, err = k.Grad(); err != nil {
   117  			return
   118  		}
   119  
   120  		assert.True(ValueClose(xG, aG), "Test ssDiff of %v. xG != aG. Got %v and %v", op, xG, aG)
   121  		assert.True(ValueClose(yG, bG), "Test ssDiff of %v. yG != bG. Got %v and %v", op, yG, bG)
   122  		assert.True(ValueClose(zG, cG), "Test ssDiff of %v. zG != cG. Got %v and %v", op, zG, cG)
   123  	}
   124  
   125  	assert.True(ValueClose(x.Value(), a.Value()), "Test ss op %v. Values are different: x: %v, a %v", op, x.Value(), a.Value())
   126  	assert.True(ValueClose(y.Value(), b.Value()), "Test ss op %v. Values are different: y: %v, b %v", op, y.Value(), b.Value())
   127  	assert.True(ValueClose(z.Value(), c.Value()), "Test ss op %v. Values are different: z: %v, c %v", op, z.Value(), c.Value())
   128  
   129  	return nil
   130  }
   131  
   132  func ttBinOpTest(t *testing.T, op ʘBinaryOperatorType, dt tensor.Dtype) (err error) {
   133  	defer runtime.GC()
   134  	assert := assert.New(t)
   135  	var x, y, z, a, b, c, cost *Node
   136  	var g, g2 *ExprGraph
   137  
   138  	var randX, randY interface{}
   139  	switch dt {
   140  	case Float32:
   141  		randX = []float32{1, 2, 3, 4}
   142  		randY = []float32{2, 2, 2, 2}
   143  	case Float64:
   144  		randX = []float64{1, 2, 3, 4}
   145  		randY = []float64{2, 2, 2, 2}
   146  	}
   147  
   148  	t.Logf("ttBinOp: %v %v %v", randX, op, randY)
   149  	// randX := Gaussian(0, 1)(dt, 2, 2)
   150  	// randY := Gaussian(0, 1)(dt, 2, 2)
   151  
   152  	xV := tensor.New(tensor.WithShape(2, 2), tensor.WithBacking(randX))
   153  	yV := tensor.New(tensor.WithShape(2, 2), tensor.WithBacking(randY))
   154  
   155  	g = NewGraph()
   156  	g2 = NewGraph()
   157  	x = NewMatrix(g, dt, WithName("x"), WithShape(2, 2))
   158  	y = NewMatrix(g, dt, WithName("y"), WithShape(2, 2))
   159  	a = NewMatrix(g2, dt, WithName("a"), WithShape(2, 2))
   160  	b = NewMatrix(g2, dt, WithName("b"), WithShape(2, 2))
   161  
   162  	binOp := newEBOByType(op, x.t, y.t)
   163  	if z, err = ApplyOp(binOp, x, y); err != nil {
   164  		return err
   165  	}
   166  	if c, err = ApplyOp(binOp, a, b); err != nil {
   167  		return err
   168  	}
   169  
   170  	var m1 VM
   171  	if op.isArith() {
   172  		if _, err = Sum(z); err != nil {
   173  			return err
   174  		}
   175  		if cost, err = Sum(c); err != nil {
   176  			return err
   177  		}
   178  
   179  		if _, err = Grad(cost, a, b); err != nil {
   180  			return err
   181  		}
   182  		m1 = NewLispMachine(g)
   183  	} else {
   184  		m1 = NewLispMachine(g, ExecuteFwdOnly())
   185  	}
   186  
   187  	// lg := log.New(os.Stderr, "", 0)
   188  	m2 := NewTapeMachine(g2, TraceExec())
   189  	defer m2.Close()
   190  	defer m1.Close()
   191  
   192  	// m2 := NewTapeMachine(prog, locMap, TraceExec(), WithLogger(logger), WithWatchlist())
   193  
   194  	Let(x, xV)
   195  	Let(y, yV)
   196  	if err = m1.RunAll(); err != nil {
   197  		return
   198  	}
   199  
   200  	Let(a, xV)
   201  	Let(b, yV)
   202  	if err = m2.RunAll(); err != nil {
   203  		return
   204  	}
   205  
   206  	var xG, aG, yG, bG, zG, cG Value
   207  	if op.isArith() {
   208  		if xG, err = x.Grad(); err != nil {
   209  			return
   210  		}
   211  		if yG, err = y.Grad(); err != nil {
   212  			return
   213  		}
   214  		if aG, err = a.Grad(); err != nil {
   215  			return
   216  		}
   217  		if bG, err = b.Grad(); err != nil {
   218  			return
   219  		}
   220  
   221  		if zG, err = z.Grad(); err != nil {
   222  			return
   223  		}
   224  		if cG, err = c.Grad(); err != nil {
   225  			return
   226  		}
   227  		assert.True(ValueClose(xG, aG), "Test ttDiff of %v. xG != aG. Got %+v \nand %+v", op, xG, aG)
   228  		assert.True(ValueClose(yG, bG), "Test ttDiff of %v. yG != bG. Got %+v \nand %+v", op, yG, bG)
   229  		assert.True(ValueClose(zG, cG), "Test ttDiff of %v. zG != cG. Got %+v \nand %+v", op, zG, cG)
   230  	}
   231  
   232  	assert.True(ValueClose(x.Value(), a.Value()), "Test tt op %v. Values are different: x: %+v\n a %+v", op, x.Value(), a.Value())
   233  	assert.True(ValueClose(y.Value(), b.Value()), "Test tt op %v. Values are different: y: %+v\n b %+v", op, y.Value(), b.Value())
   234  	assert.True(ValueClose(z.Value(), c.Value()), "Test tt op %v. Values are different: z: %+v\n c %+v", op, z.Value(), c.Value())
   235  
   236  	if t.Failed() {
   237  		ioutil.WriteFile(fmt.Sprintf("Test_%v_tt.dot", op), []byte(g2.ToDot()), 0644)
   238  	}
   239  
   240  	return nil
   241  }
   242  
   243  func TestBinOps(t *testing.T) {
   244  	for op := addOpType; op < maxʘBinaryOpType; op++ {
   245  		t.Logf("OP: %v", op)
   246  
   247  		// if op != addOpType {
   248  		// 	continue
   249  		// }
   250  
   251  		// for op := subOpType; op < mulOpType; op++ {
   252  		var err error
   253  		err = ssBinOpTest(t, op, Float64)
   254  		if err != nil {
   255  			t.Errorf("Float64 version err: %v", err)
   256  		}
   257  
   258  		err = ssBinOpTest(t, op, Float32)
   259  		if err != nil {
   260  			t.Errorf("Float32 version err: %v", err)
   261  		}
   262  
   263  		t.Logf("Float64 T-T test for %v", op)
   264  		err = ttBinOpTest(t, op, Float64)
   265  		if err != nil {
   266  			t.Errorf("ttBinOp Float64 version err %v", err)
   267  		}
   268  
   269  		t.Logf("Float32 T-T test")
   270  		err = ttBinOpTest(t, op, Float32)
   271  		if err != nil {
   272  			t.Errorf("ttBinOp Float64 version err %v", err)
   273  		}
   274  	}
   275  
   276  	// single tests
   277  
   278  	// ttBinOpTest(t, subOpType, Float64)
   279  }