gorgonia.org/gorgonia@v0.9.17/op_tensor_test.go (about)

     1  package gorgonia
     2  
     3  import (
     4  	"crypto/sha256"
     5  	"runtime"
     6  	"testing"
     7  
     8  	"github.com/stretchr/testify/assert"
     9  	"gorgonia.org/tensor"
    10  )
    11  
    12  var repeatOpTests = []struct {
    13  	name string
    14  	rep  int
    15  	axes int
    16  	val  Value
    17  
    18  	correct       Value
    19  	expectedShape tensor.Shape
    20  	err           bool
    21  }{
    22  	{
    23  		"repeat matrix on axis 0", 2, 0,
    24  		tensor.New(tensor.WithBacking([]float64{1, 2, 3, 4}), tensor.WithShape(2, 2)),
    25  		tensor.New(tensor.WithBacking([]float64{1, 2, 1, 2, 3, 4, 3, 4}), tensor.WithShape(4, 2)),
    26  		tensor.Shape{4, 2}, false,
    27  	},
    28  
    29  	{
    30  		"repeat matrix on axis 1", 2, 1,
    31  		tensor.New(tensor.WithBacking([]float64{1, 2, 3, 4}), tensor.WithShape(2, 2)),
    32  		tensor.New(tensor.WithBacking([]float64{1, 1, 2, 2, 3, 3, 4, 4}), tensor.WithShape(2, 4)),
    33  		tensor.Shape{2, 4}, false,
    34  	},
    35  
    36  	{
    37  		"repeat col vec on axis 0", 2, 0,
    38  		tensor.New(tensor.WithBacking([]float64{1, 2}), tensor.WithShape(2, 1)),
    39  		tensor.New(tensor.WithBacking([]float64{1, 1, 2, 2}), tensor.WithShape(4, 1)),
    40  		tensor.Shape{4, 1}, false,
    41  	},
    42  
    43  	{
    44  		"repeat col vec on axis 1", 2, 1,
    45  		tensor.New(tensor.WithBacking([]float64{1, 2}), tensor.WithShape(2, 1)),
    46  		tensor.New(tensor.WithBacking([]float64{1, 1, 2, 2}), tensor.WithShape(2, 2)),
    47  		tensor.Shape{2, 2}, false,
    48  	},
    49  
    50  	{
    51  		"repeat row vec on axis 0", 2, 0,
    52  		tensor.New(tensor.WithBacking([]float64{1, 2}), tensor.WithShape(1, 2)),
    53  		tensor.New(tensor.WithBacking([]float64{1, 2, 1, 2}), tensor.WithShape(2, 2)),
    54  		tensor.Shape{2, 2}, false,
    55  	},
    56  
    57  	{
    58  		"repeat row vec on axis 1", 2, 1,
    59  		tensor.New(tensor.WithBacking([]float64{1, 2}), tensor.WithShape(1, 2)),
    60  		tensor.New(tensor.WithBacking([]float64{1, 1, 2, 2}), tensor.WithShape(1, 4)),
    61  		tensor.Shape{1, 4}, false,
    62  	},
    63  
    64  	{
    65  		"repeat vector on axis 0", 2, 0,
    66  		tensor.New(tensor.WithBacking([]float64{1, 2}), tensor.WithShape(2)),
    67  		tensor.New(tensor.WithBacking([]float64{1, 1, 2, 2}), tensor.WithShape(4)),
    68  		tensor.Shape{4}, false,
    69  	},
    70  
    71  	{
    72  		"repeat vector on axis 1", 2, 1,
    73  		tensor.New(tensor.WithBacking([]float64{1, 2}), tensor.WithShape(2)),
    74  		tensor.New(tensor.WithBacking([]float64{1, 1, 2, 2}), tensor.WithShape(2, 2)),
    75  		tensor.Shape{2, 2}, false,
    76  	},
    77  
    78  	{
    79  		"repeat scalar", 2, 0,
    80  		NewF64(3.14), tensor.New(tensor.WithBacking([]float64{3.14, 3.14}), tensor.WithShape(2)),
    81  		tensor.Shape{2}, false,
    82  	},
    83  }
    84  
    85  func TestRepeatOp(t *testing.T) {
    86  	// assert := assert.New(t)
    87  
    88  	for _, rots := range repeatOpTests {
    89  		// if rots.name != "repeat matrix on axis 1" {
    90  		// 	continue
    91  		// }
    92  		g := NewGraph()
    93  		var res Value
    94  		var err error
    95  		var repeat *repeatOp
    96  
    97  		rep := NewI(rots.rep)
    98  		n := NodeFromAny(g, rots.val)
    99  
   100  		repeat = newRepeatOp(rots.axes, n)
   101  
   102  		res, err = repeat.Do(rots.val, rep)
   103  		switch {
   104  		case rots.err:
   105  			if err == nil {
   106  				t.Errorf("Test %q: Expected an error", rots.name)
   107  			}
   108  			goto infershape
   109  		case !rots.err && err != nil:
   110  			t.Errorf("%+v", err)
   111  			goto infershape
   112  		}
   113  
   114  		if !ValueEq(res, rots.correct) {
   115  			t.Errorf("Test %q: Expected %v. Got %v", rots.name, rots.correct, res)
   116  		}
   117  
   118  	infershape:
   119  		var s tensor.Shape
   120  		size := sizeOp{axis: rots.axes, val: rots.rep}
   121  		s, err = repeat.InferShape(rots.val.Shape(), size)
   122  		switch {
   123  		case rots.err:
   124  			if err == nil {
   125  				t.Error("Expected an error")
   126  			}
   127  			continue
   128  		case !rots.err && err != nil:
   129  			t.Errorf("Test %q %+v", rots.name, err)
   130  			continue
   131  		}
   132  
   133  		if !rots.expectedShape.Eq(s) {
   134  			t.Errorf("Test %q InferShape: Expected %v. Got %v instead", rots.name, rots.expectedShape, s)
   135  		}
   136  	}
   137  }
   138  
   139  func repeatOpDiff(repeatOn int, shape tensor.Shape, xV, yV interface{}) (g *ExprGraph, x, y *Node, err error) {
   140  	g = NewGraph()
   141  	switch shape.Dims() {
   142  	case 0:
   143  		x = NewScalar(g, Float64, WithName("x"))
   144  	case 1:
   145  		// vanilla vector
   146  		x = NewVector(g, Float64, WithName("x"), WithShape(shape...))
   147  	case 2:
   148  		x = NewMatrix(g, Float64, WithName("x"), WithShape(shape...))
   149  	default:
   150  		//matrix and tensors
   151  		x = NewTensor(g, Float64, shape.Dims(), WithName("x"), WithShape(shape...))
   152  	}
   153  
   154  	repOp := sizeOp{axis: repeatOn, val: 2}
   155  	repN := NewScalar(g, Float64, WithName("REPCONST"), WithOp(repOp), WithValue(2.0))
   156  	repeat := newRepeatOp(repeatOn, x)
   157  
   158  	if y, err = ApplyOp(repeat, x, repN); err != nil {
   159  		return
   160  	}
   161  	xVal, _, _, _ := anyToValue(xV)
   162  	yVal, _, _, _ := anyToValue(yV)
   163  	x.bind(dvUnit(xVal))
   164  	y.bind(dvUnitVar(yVal))
   165  	if err = repeat.DoDiff(ExecutionContext{}, Nodes{x, repN}, y); err != nil {
   166  		return
   167  	}
   168  	return
   169  }
   170  
   171  func TestRepeatOpDoDiff(t *testing.T) {
   172  	//t.SkipNow()
   173  	assert := assert.New(t)
   174  	// var g *ExprGraph
   175  	// var x, y, repN *Node
   176  	// var repeat *repeatOp
   177  	var x *Node
   178  	var err error
   179  
   180  	var xG Value
   181  	var xT, yT *tensor.Dense
   182  
   183  	yT = tensor.New(tensor.WithShape(2), tensor.WithBacking([]float64{3.14, 3.14}))
   184  
   185  	// scalar repeated into a vec/colvec
   186  	if _, x, _, err = repeatOpDiff(0, scalarShape, 3.14, yT); err != nil {
   187  		t.Fatal(err)
   188  	}
   189  	xG, _ = x.Grad()
   190  	assert.Equal(2.0, extractF64(xG))
   191  
   192  	// scalar repeated into a rowvec
   193  	// if _, x, _, err = repeatOpDiff(1, scalarShape, 3.14, yT); err != nil {
   194  	// 	t.Fatal(err)
   195  	// }
   196  	// xG, _ = x.Grad()
   197  	// assert.Equal(2.0, extractF64(xG))
   198  
   199  	// vector repeated unto itself
   200  	xT = tensor.New(tensor.WithShape(2), tensor.WithBacking([]float64{3.14, 3.14}))
   201  	yT = tensor.New(tensor.WithShape(4), tensor.WithBacking([]float64{3.14, 3.14, 3.14, 3.14}))
   202  	if _, x, _, err = repeatOpDiff(0, tensor.Shape{2}, xT, yT); err != nil {
   203  		t.Fatal(err)
   204  	}
   205  	xG, _ = x.Grad()
   206  	assert.Equal([]float64{2, 2}, extractF64s(xG))
   207  
   208  	// colvec repeated unto itself
   209  	xT = tensor.New(tensor.WithShape(2, 1), tensor.WithBacking([]float64{3.14, 3.14}))
   210  	yT = tensor.New(tensor.WithShape(4, 1), tensor.WithBacking([]float64{3.14, 3.14, 3.14, 3.14}))
   211  	if _, x, _, err = repeatOpDiff(0, tensor.Shape{2}, xT, yT); err != nil {
   212  		t.Fatal(err)
   213  	}
   214  	xG, _ = x.Grad()
   215  	assert.Equal([]float64{2, 2}, extractF64s(xG))
   216  
   217  	// rowvec repeated unto itself
   218  	xT = tensor.New(tensor.WithShape(1, 2), tensor.WithBacking([]float64{3.14, 3.14}))
   219  	yT = tensor.New(tensor.WithShape(1, 4), tensor.WithBacking([]float64{3.14, 3.14, 3.14, 3.14}))
   220  	if _, x, _, err = repeatOpDiff(1, tensor.Shape{1, 2}, xT, yT); err != nil {
   221  		t.Fatal(err)
   222  	}
   223  	xG, _ = x.Grad()
   224  	assert.Equal([]float64{2, 2}, extractF64s(xG))
   225  
   226  	// matrix on axis 0
   227  	xT = tensor.New(tensor.WithShape(2, 2), tensor.WithBacking([]float64{3.14, 2.718, 1.618, 1.414}))
   228  	yT = tensor.New(tensor.WithShape(4, 2), tensor.WithBacking([]float64{3.14, 2.718, 3.14, 2.718, 1.618, 1.414, 1.618, 1.414}))
   229  	if _, x, _, err = repeatOpDiff(0, tensor.Shape{1, 2}, xT, yT); err != nil {
   230  		t.Fatal(err)
   231  	}
   232  	xG, _ = x.Grad()
   233  	assert.Equal([]float64{2, 2, 2, 2}, extractF64s(xG))
   234  
   235  	// matrix on axis 1
   236  	xT = tensor.New(tensor.WithShape(2, 2), tensor.WithBacking([]float64{3.14, 2.718, 1.618, 1.414}))
   237  	yT = tensor.New(tensor.WithShape(4, 2), tensor.WithBacking([]float64{3.14, 2.718, 3.14, 2.718, 1.618, 1.414, 1.618, 1.414}))
   238  	if _, x, _, err = repeatOpDiff(1, tensor.Shape{1, 2}, xT, yT); err != nil {
   239  		t.Fatal(err)
   240  	}
   241  	xG, _ = x.Grad()
   242  	assert.Equal([]float64{2, 2, 2, 2}, extractF64s(xG))
   243  
   244  }
   245  
   246  func TestTransposeOp(t *testing.T) {
   247  	assert := assert.New(t)
   248  	g := NewGraph()
   249  	A := NewMatrix(g, Float64, WithShape(2, 3), WithInit(RangedFrom(0)))
   250  	AT := Must(Transpose(A))
   251  	cost1 := Must(Sum(AT))
   252  
   253  	var m VM
   254  	var err error
   255  
   256  	m = NewLispMachine(g)
   257  	defer m.Close()
   258  	if err = m.RunAll(); err != nil {
   259  		t.Error(err)
   260  	}
   261  
   262  	assert.Equal(tensor.Shape{3, 2}, AT.shape)
   263  
   264  	h := NewGraph()
   265  	B := NewMatrix(h, Float64, WithShape(2, 3), WithInit(RangedFrom(0)))
   266  	BT := Must(Transpose(B))
   267  	cost2 := Must(Sum(BT))
   268  	Grad(cost2, B)
   269  
   270  	m = NewTapeMachine(h)
   271  	defer m.Close()
   272  	if err = m.RunAll(); err != nil {
   273  		t.Error(err)
   274  	}
   275  	assert.Equal(tensor.Shape{3, 2}, BT.shape)
   276  
   277  	var ag, bg Value
   278  	if ag, err = A.Grad(); err != nil {
   279  		t.Fatalf("Cannot get grad of A. Err: %v", err)
   280  	}
   281  
   282  	if bg, err = B.Grad(); err != nil {
   283  		t.Fatalf("Cannot get grad of B. Err: %v", err)
   284  	}
   285  
   286  	var costGrad1, costGrad2 Value
   287  	if costGrad1, err = cost1.Grad(); err != nil {
   288  		t.Fatalf("Cannot get grad of Cost1. Err %v", err)
   289  	}
   290  
   291  	if costGrad2, err = cost2.Grad(); err != nil {
   292  		t.Fatalf("Cannot get grad of Cost2. Err %v", err)
   293  	}
   294  
   295  	t.Logf("%v %v", cost1.Value(), cost2.Value())
   296  	t.Logf("%v %v", costGrad1, costGrad2)
   297  
   298  	assert.True(ValueEq(ag, bg))
   299  }
   300  
   301  func TestConcatOp(t *testing.T) {
   302  	defer runtime.GC()
   303  
   304  	assert := assert.New(t)
   305  	g := NewGraph()
   306  	x := NewVector(g, Float64, WithShape(2))
   307  	xx, err := Concat(0, x, x)
   308  	if err != nil {
   309  		t.Fatalf("%+v", err)
   310  	}
   311  
   312  	cost := Must(Sum(xx))
   313  	Grad(cost, x)
   314  
   315  	g2 := NewGraph()
   316  	a := NewVector(g2, Float64, WithShape(2))
   317  	aa, err := Concat(0, a, a)
   318  	if err != nil {
   319  		t.Fatalf("%+v", err)
   320  	}
   321  	Must(Sum(aa)) // cost
   322  
   323  	aBack := []float64{1, 2}
   324  	aT := tensor.New(tensor.WithShape(2), tensor.WithBacking(aBack))
   325  
   326  	xBack := []float64{1, 2}
   327  	xT := tensor.New(tensor.WithShape(2), tensor.WithBacking(xBack))
   328  
   329  	Let(a, aT)
   330  	Let(x, xT)
   331  	m1 := NewTapeMachine(g)
   332  	m2 := NewLispMachine(g2)
   333  	defer m1.Close()
   334  	defer m2.Close()
   335  
   336  	if err = m1.RunAll(); err != nil {
   337  		t.Fatal(err)
   338  	}
   339  
   340  	if err = m2.RunAll(); err != nil {
   341  		t.Fatalf("%+v", err)
   342  	}
   343  
   344  	xG, _ := x.Grad()
   345  	aG, _ := a.Grad()
   346  	assert.True(ValueEq(xG, aG))
   347  	assert.True(ValueEq(xx.Value(), aa.Value()))
   348  }
   349  
   350  func Test_atOp_WriteHash(t *testing.T) {
   351  	defer func() {
   352  		if r := recover(); r != nil {
   353  			t.Fail()
   354  		}
   355  	}()
   356  	h := sha256.New()
   357  
   358  	at := &atOp{}
   359  	at.WriteHash(h)
   360  }