gorgonia.org/tensor@v0.9.24/dense_matop_test.go (about)

     1  package tensor
     2  
     3  import (
     4  	"fmt"
     5  	"testing"
     6  
     7  	"github.com/stretchr/testify/assert"
     8  	"gorgonia.org/vecf64"
     9  )
    10  
    11  func cloneArray(a interface{}) interface{} {
    12  	switch at := a.(type) {
    13  	case []float64:
    14  		retVal := make([]float64, len(at))
    15  		copy(retVal, at)
    16  		return retVal
    17  	case []float32:
    18  		retVal := make([]float32, len(at))
    19  		copy(retVal, at)
    20  		return retVal
    21  	case []int:
    22  		retVal := make([]int, len(at))
    23  		copy(retVal, at)
    24  		return retVal
    25  	case []int64:
    26  		retVal := make([]int64, len(at))
    27  		copy(retVal, at)
    28  		return retVal
    29  	case []int32:
    30  		retVal := make([]int32, len(at))
    31  		copy(retVal, at)
    32  		return retVal
    33  	case []byte:
    34  		retVal := make([]byte, len(at))
    35  		copy(retVal, at)
    36  		return retVal
    37  	case []bool:
    38  		retVal := make([]bool, len(at))
    39  		copy(retVal, at)
    40  		return retVal
    41  	}
    42  	return nil
    43  }
    44  
    45  func castToDt(val float64, dt Dtype) interface{} {
    46  	switch dt {
    47  	case Bool:
    48  		return false
    49  	case Int:
    50  		return int(val)
    51  	case Int8:
    52  		return int8(val)
    53  	case Int16:
    54  		return int16(val)
    55  	case Int32:
    56  		return int32(val)
    57  	case Int64:
    58  		return int64(val)
    59  	case Uint:
    60  		return uint(val)
    61  	case Uint8:
    62  		return uint8(val)
    63  	case Uint16:
    64  		return uint16(val)
    65  	case Uint32:
    66  		return uint32(val)
    67  	case Uint64:
    68  		return uint64(val)
    69  	case Float32:
    70  		return float32(val)
    71  	case Float64:
    72  		return float64(val)
    73  	default:
    74  		return 0
    75  	}
    76  }
    77  
// atTests is the table driving TestDense_At. Each case builds a tensor from
// data and shape and reads the element at coord.
var atTests = []struct {
	data  interface{} // backing slice handed to WithBacking
	shape Shape       // shape handed to WithShape
	coord []int       // coordinates passed to At

	correct interface{} // value At is expected to return when err is false
	err     bool        // true when At is expected to return an error
}{
	// matrix
	{[]float64{0, 1, 2, 3, 4, 5}, Shape{2, 3}, []int{0, 1}, float64(1), false},
	{[]float32{0, 1, 2, 3, 4, 5}, Shape{2, 3}, []int{1, 1}, float32(4), false},
	{[]float64{0, 1, 2, 3, 4, 5}, Shape{2, 3}, []int{1, 2, 3}, nil, true},

	// 3-tensor
	{[]int{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23},
		Shape{2, 3, 4}, []int{1, 1, 1}, 17, false},
	{[]int64{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23},
		Shape{2, 3, 4}, []int{1, 2, 3}, int64(23), false},
	{[]int{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23},
		Shape{2, 3, 4}, []int{0, 3, 2}, 23, true},
}
    99  
   100  func TestDense_At(t *testing.T) {
   101  	for i, ats := range atTests {
   102  		T := New(WithShape(ats.shape...), WithBacking(ats.data))
   103  		got, err := T.At(ats.coord...)
   104  		if checkErr(t, ats.err, err, "At", i) {
   105  			continue
   106  		}
   107  
   108  		if got != ats.correct {
   109  			t.Errorf("Expected %v. Got %v", ats.correct, got)
   110  		}
   111  	}
   112  }
   113  
   114  func Test_transposeIndex(t *testing.T) {
   115  	a := []byte{0, 1, 2, 3}
   116  	T := New(WithShape(2, 2), WithBacking(a))
   117  
   118  	correct := []int{0, 2, 1, 3}
   119  	for i, v := range correct {
   120  		got := T.transposeIndex(i, []int{1, 0}, []int{2, 1})
   121  		if v != got {
   122  			t.Errorf("transposeIndex error. Expected %v. Got %v", v, got)
   123  		}
   124  	}
   125  }
   126  
// transposeTests is the table driving the first half of TestDense_Transpose.
// Each case builds a tensor from shape and data, calls T(transposeWith...) and
// then Transpose(), and compares shape, strides and data at the points noted
// on the fields below.
var transposeTests = []struct {
	name          string      // label used in failure messages
	shape         Shape       // shape of the tensor before transposing
	transposeWith []int       // axes passed to T(); nil means T() is called with no arguments
	data          interface{} // backing slice before transposing

	correctShape    Shape       // expected shape, checked after .T() and again after .Transpose()
	correctStrides  []int       // after .T()
	correctStrides2 []int       // after .Transpose()
	correctData     interface{} // expected Data() after .Transpose()
}{
	{"c.T()", Shape{4, 1}, nil, []float64{0, 1, 2, 3},
		Shape{1, 4}, []int{1, 1}, []int{4, 1}, []float64{0, 1, 2, 3}},

	{"r.T()", Shape{1, 4}, nil, []float32{0, 1, 2, 3},
		Shape{4, 1}, []int{1, 1}, []int{1, 1}, []float32{0, 1, 2, 3}},

	{"v.T()", Shape{4}, nil, []int{0, 1, 2, 3},
		Shape{4}, []int{1}, []int{1}, []int{0, 1, 2, 3}},

	{"M.T()", Shape{2, 3}, nil, []int64{0, 1, 2, 3, 4, 5},
		Shape{3, 2}, []int{1, 3}, []int{2, 1}, []int64{0, 3, 1, 4, 2, 5}},

	{"M.T(0,1) (NOOP)", Shape{2, 3}, []int{0, 1}, []int32{0, 1, 2, 3, 4, 5},
		Shape{2, 3}, []int{3, 1}, []int{3, 1}, []int32{0, 1, 2, 3, 4, 5}},

	{"3T.T()", Shape{2, 3, 4}, nil,
		[]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23},

		Shape{4, 3, 2}, []int{1, 4, 12}, []int{6, 2, 1},
		[]byte{0, 12, 4, 16, 8, 20, 1, 13, 5, 17, 9, 21, 2, 14, 6, 18, 10, 22, 3, 15, 7, 19, 11, 23}},

	{"3T.T(2, 1, 0) (Same as .T())", Shape{2, 3, 4}, []int{2, 1, 0},
		[]int{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23},
		Shape{4, 3, 2}, []int{1, 4, 12}, []int{6, 2, 1},
		[]int{0, 12, 4, 16, 8, 20, 1, 13, 5, 17, 9, 21, 2, 14, 6, 18, 10, 22, 3, 15, 7, 19, 11, 23}},

	{"3T.T(2, 1, 0) (Same as .T())", Shape{2, 3, 4}, []int{2, 1, 0},
		[]int16{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23},
		Shape{4, 3, 2}, []int{1, 4, 12}, []int{6, 2, 1},
		[]int16{0, 12, 4, 16, 8, 20, 1, 13, 5, 17, 9, 21, 2, 14, 6, 18, 10, 22, 3, 15, 7, 19, 11, 23}},

	{"3T.T(0, 2, 1)", Shape{2, 3, 4}, []int{0, 2, 1},
		[]int32{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23},
		Shape{2, 4, 3}, []int{12, 1, 4}, []int{12, 3, 1},
		[]int32{0, 4, 8, 1, 5, 9, 2, 6, 10, 3, 7, 11, 12, 16, 20, 13, 17, 21, 14, 18, 22, 15, 19, 23}},

	{"3T.T{1, 0, 2)", Shape{2, 3, 4}, []int{1, 0, 2},
		[]float64{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23},
		Shape{3, 2, 4}, []int{4, 12, 1}, []int{8, 4, 1},
		[]float64{0, 1, 2, 3, 12, 13, 14, 15, 4, 5, 6, 7, 16, 17, 18, 19, 8, 9, 10, 11, 20, 21, 22, 23}},

	{"3T.T{1, 2, 0)", Shape{2, 3, 4}, []int{1, 2, 0},
		[]float64{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23},
		Shape{3, 4, 2}, []int{4, 1, 12}, []int{8, 2, 1},
		[]float64{0, 12, 1, 13, 2, 14, 3, 15, 4, 16, 5, 17, 6, 18, 7, 19, 8, 20, 9, 21, 10, 22, 11, 23}},

	{"3T.T{2, 0, 1)", Shape{2, 3, 4}, []int{2, 0, 1},
		[]float32{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23},
		Shape{4, 2, 3}, []int{1, 12, 4}, []int{6, 3, 1},
		[]float32{0, 4, 8, 12, 16, 20, 1, 5, 9, 13, 17, 21, 2, 6, 10, 14, 18, 22, 3, 7, 11, 15, 19, 23}},

	{"3T.T{0, 1, 2} (NOOP)", Shape{2, 3, 4}, []int{0, 1, 2},
		[]bool{true, false, true, false, true, false, true, false, true, false, true, false, true, false, true, false, true, false, true, false, true, false, true, false},
		Shape{2, 3, 4}, []int{12, 4, 1}, []int{12, 4, 1},
		[]bool{true, false, true, false, true, false, true, false, true, false, true, false, true, false, true, false, true, false, true, false, true, false, true, false}},

	{"M[2,2].T for bools, just for completeness sake", Shape{2, 2}, nil,
		[]bool{true, true, false, false},
		Shape{2, 2}, []int{1, 2}, []int{2, 1},
		[]bool{true, false, true, false},
	},

	{"M[2,2].T for strings, just for completeness sake", Shape{2, 2}, nil,
		[]string{"hello", "world", "今日は", "世界"},
		Shape{2, 2}, []int{1, 2}, []int{2, 1},
		[]string{"hello", "今日は", "world", "世界"},
	},
}
   206  
   207  func TestDense_Transpose(t *testing.T) {
   208  	assert := assert.New(t)
   209  	var err error
   210  
   211  	// standard transposes
   212  	for _, tts := range transposeTests {
   213  		T := New(WithShape(tts.shape...), WithBacking(tts.data))
   214  		if err = T.T(tts.transposeWith...); err != nil {
   215  			t.Errorf("%v - %v", tts.name, err)
   216  			continue
   217  		}
   218  
   219  		assert.True(tts.correctShape.Eq(T.Shape()), "Transpose %v Expected shape: %v. Got %v", tts.name, tts.correctShape, T.Shape())
   220  		assert.Equal(tts.correctStrides, T.Strides(), "Transpose %v. Expected stride: %v. Got %v", tts.name, tts.correctStrides, T.Strides())
   221  		T.Transpose()
   222  		assert.True(tts.correctShape.Eq(T.Shape()), "Transpose %v Expected shape: %v. Got %v", tts.name, tts.correctShape, T.Shape())
   223  		assert.Equal(tts.correctStrides2, T.Strides(), "Transpose2 %v - Expected stride %v. Got %v", tts.name, tts.correctStrides2, T.Strides())
   224  		assert.Equal(tts.correctData, T.Data(), "Transpose %v", tts.name)
   225  	}
   226  
   227  	// test stacked .T() calls
   228  	var T *Dense
   229  
   230  	// column vector
   231  	T = New(WithShape(4, 1), WithBacking(Range(Int, 0, 4)))
   232  	if err = T.T(); err != nil {
   233  		t.Errorf("Stacked .T() #1 for vector. Error: %v", err)
   234  		goto matrev
   235  	}
   236  	if err = T.T(); err != nil {
   237  		t.Errorf("Stacked .T() #1 for vector. Error: %v", err)
   238  		goto matrev
   239  	}
   240  	assert.True(T.old.IsZero())
   241  	assert.Nil(T.transposeWith)
   242  	assert.True(T.IsColVec())
   243  
   244  matrev:
   245  	// matrix, reversed
   246  	T = New(WithShape(2, 3), WithBacking(Range(Byte, 0, 6)))
   247  	if err = T.T(); err != nil {
   248  		t.Errorf("Stacked .T() #1 for matrix reverse. Error: %v", err)
   249  		goto matnorev
   250  	}
   251  	if err = T.T(); err != nil {
   252  		t.Errorf("Stacked .T() #2 for matrix reverse. Error: %v", err)
   253  		goto matnorev
   254  	}
   255  	assert.True(T.old.IsZero())
   256  	assert.Nil(T.transposeWith)
   257  	assert.True(Shape{2, 3}.Eq(T.Shape()))
   258  
   259  matnorev:
   260  	// 3-tensor, non reversed
   261  	T = New(WithShape(2, 3, 4), WithBacking(Range(Int64, 0, 24)))
   262  	if err = T.T(); err != nil {
   263  		t.Fatalf("Stacked .T() #1 for tensor with no reverse. Error: %v", err)
   264  	}
   265  	if err = T.T(2, 0, 1); err != nil {
   266  		t.Fatalf("Stacked .T() #2 for tensor with no reverse. Error: %v", err)
   267  	}
   268  	correctData := []int64{0, 12, 4, 16, 8, 20, 1, 13, 5, 17, 9, 21, 2, 14, 6, 18, 10, 22, 3, 15, 7, 19, 11, 23}
   269  	assert.Equal(correctData, T.Data())
   270  	assert.Equal([]int{2, 0, 1}, T.transposeWith)
   271  	assert.NotNil(T.old)
   272  
   273  }
   274  
   275  func TestTUT(t *testing.T) {
   276  	assert := assert.New(t)
   277  	var T *Dense
   278  
   279  	T = New(Of(Float64), WithShape(2, 3, 4))
   280  	T.T()
   281  	T.UT()
   282  	assert.True(T.old.IsZero())
   283  	assert.Nil(T.transposeWith)
   284  
   285  	T.T(2, 0, 1)
   286  	T.UT()
   287  	assert.True(T.old.IsZero())
   288  	assert.Nil(T.transposeWith)
   289  }
   290  
// repeatTest describes one case for TestDense_Repeat and
// TestDense_Repeat_Slow. The field order matters: repeatTests fills the struct
// with positional literals.
type repeatTest struct {
	name    string // label attached to assertion failures
	tensor  *Dense // tensor on which Repeat is called
	ne      bool   // should assert tensor not equal
	axis    int    // axis argument passed to Repeat
	repeats []int  // repeats arguments passed to Repeat

	correct interface{} // expected Data() of the result
	shape   Shape       // expected Shape() of the result
	err     bool        // true when Repeat is expected to return an error
}
   302  
// repeatTests is the shared table for TestDense_Repeat and
// TestDense_Repeat_Slow. Fields are positional; see repeatTest for their
// meaning.
var repeatTests = []repeatTest{
	{"Scalar Repeat on axis 0", New(FromScalar(true)),
		true, 0, []int{3},
		[]bool{true, true, true},
		Shape{3}, false,
	},

	{"Scalar Repeat on axis 1", New(FromScalar(byte(255))),
		false, 1, []int{3},
		[]byte{255, 255, 255},
		Shape{1, 3}, false,
	},

	{"Vector Repeat on axis 0", New(WithShape(2), WithBacking([]int32{1, 2})),
		false, 0, []int{3},
		[]int32{1, 1, 1, 2, 2, 2},
		Shape{6}, false,
	},

	{"ColVec Repeat on axis 0", New(WithShape(2, 1), WithBacking([]int64{1, 2})),
		false, 0, []int{3},
		[]int64{1, 1, 1, 2, 2, 2},
		Shape{6, 1}, false,
	},

	{"RowVec Repeat on axis 0", New(WithShape(1, 2), WithBacking([]int{1, 2})),
		false, 0, []int{3},
		[]int{1, 2, 1, 2, 1, 2},
		Shape{3, 2}, false,
	},

	{"ColVec Repeat on axis 1", New(WithShape(2, 1), WithBacking([]float32{1, 2})),
		false, 1, []int{3},
		[]float32{1, 1, 1, 2, 2, 2},
		Shape{2, 3}, false,
	},

	{"RowVec Repeat on axis 1", New(WithShape(1, 2), WithBacking([]float64{1, 2})),
		false, 1, []int{3},
		[]float64{1, 1, 1, 2, 2, 2},
		Shape{1, 6}, false,
	},

	{"Vector Repeat on all axes", New(WithShape(2), WithBacking([]byte{1, 2})),
		false, AllAxes, []int{3},
		[]byte{1, 1, 1, 2, 2, 2},
		Shape{6}, false,
	},

	{"ColVec Repeat on all axes", New(WithShape(2, 1), WithBacking([]int32{1, 2})),
		false, AllAxes, []int{3},
		[]int32{1, 1, 1, 2, 2, 2},
		Shape{6}, false,
	},

	{"RowVec Repeat on all axes", New(WithShape(1, 2), WithBacking([]int64{1, 2})),
		false, AllAxes, []int{3},
		[]int64{1, 1, 1, 2, 2, 2},
		Shape{6}, false,
	},

	{"M[2,2] Repeat on all axes with repeats = (1,2,1,1)", New(WithShape(2, 2), WithBacking([]int{1, 2, 3, 4})),
		false, AllAxes, []int{1, 2, 1, 1},
		[]int{1, 2, 2, 3, 4},
		Shape{5}, false,
	},

	{"M[2,2] Repeat on axis 1 with repeats = (2, 1)", New(WithShape(2, 2), WithBacking([]float32{1, 2, 3, 4})),
		false, 1, []int{2, 1},
		[]float32{1, 1, 2, 3, 3, 4},
		Shape{2, 3}, false,
	},

	{"M[2,2] Repeat on axis 1 with repeats = (1, 2)", New(WithShape(2, 2), WithBacking([]float64{1, 2, 3, 4})),
		false, 1, []int{1, 2},
		[]float64{1, 2, 2, 3, 4, 4},
		Shape{2, 3}, false,
	},

	{"M[2,2] Repeat on axis 0 with repeats = (1, 2)", New(WithShape(2, 2), WithBacking([]float64{1, 2, 3, 4})),
		false, 0, []int{1, 2},
		[]float64{1, 2, 3, 4, 3, 4},
		Shape{3, 2}, false,
	},

	{"M[2,2] Repeat on axis 0 with repeats = (2, 1)", New(WithShape(2, 2), WithBacking([]float64{1, 2, 3, 4})),
		false, 0, []int{2, 1},
		[]float64{1, 2, 1, 2, 3, 4},
		Shape{3, 2}, false,
	},

	{"3T[2,3,2] Repeat on axis 1 with repeats = (1,2,1)", New(WithShape(2, 3, 2), WithBacking(vecf64.Range(1, 2*3*2+1))),
		false, 1, []int{1, 2, 1},
		[]float64{1, 2, 3, 4, 3, 4, 5, 6, 7, 8, 9, 10, 9, 10, 11, 12},
		Shape{2, 4, 2}, false,
	},

	{"3T[2,3,2] Generic Repeat by 2", New(WithShape(2, 3, 2), WithBacking(vecf64.Range(1, 2*3*2+1))),
		false, AllAxes, []int{2},
		[]float64{1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7, 7, 8, 8, 9, 9, 10, 10, 11, 11, 12, 12},
		Shape{24}, false,
	},

	{"3T[2,3,2] repeat with broadcast errors", New(WithShape(2, 3, 2), WithBacking(vecf64.Range(1, 2*3*2+1))),
		false, 0, []int{1, 2, 1},
		nil, nil, true,
	},

	// invalid input: the axis does not exist on the tensor
	{"Nonexistent axis", New(WithShape(2, 1), WithBacking([]bool{true, false})),
		false, 2, []int{3}, nil, nil, true,
	},
}
   416  
   417  func TestDense_Repeat(t *testing.T) {
   418  	assert := assert.New(t)
   419  
   420  	for i, test := range repeatTests {
   421  		T, err := test.tensor.Repeat(test.axis, test.repeats...)
   422  		if checkErr(t, test.err, err, "Repeat", i) {
   423  			continue
   424  		}
   425  
   426  		var D DenseTensor
   427  		if D, err = getDenseTensor(T); err != nil {
   428  			t.Errorf("Expected Repeat to return a *Dense. got %v of %T instead", T, T)
   429  			continue
   430  		}
   431  
   432  		if test.ne {
   433  			assert.NotEqual(test.tensor, D, test.name)
   434  		}
   435  
   436  		assert.Equal(test.correct, D.Data(), test.name)
   437  		assert.Equal(test.shape, D.Shape(), test.name)
   438  	}
   439  }
   440  
   441  func TestDense_Repeat_Slow(t *testing.T) {
   442  	rt2 := make([]repeatTest, len(repeatTests))
   443  	for i, rt := range repeatTests {
   444  		rt2[i] = repeatTest{
   445  			name:    rt.name,
   446  			ne:      rt.ne,
   447  			axis:    rt.axis,
   448  			repeats: rt.repeats,
   449  			correct: rt.correct,
   450  			shape:   rt.shape,
   451  			err:     rt.err,
   452  			tensor:  rt.tensor.Clone().(*Dense),
   453  		}
   454  	}
   455  	for i := range rt2 {
   456  		maskLen := rt2[i].tensor.len()
   457  		mask := make([]bool, maskLen)
   458  		rt2[i].tensor.mask = mask
   459  	}
   460  
   461  	assert := assert.New(t)
   462  
   463  	for i, test := range rt2 {
   464  		T, err := test.tensor.Repeat(test.axis, test.repeats...)
   465  		if checkErr(t, test.err, err, "Repeat", i) {
   466  			continue
   467  		}
   468  
   469  		var D DenseTensor
   470  		if D, err = getDenseTensor(T); err != nil {
   471  			t.Errorf("Expected Repeat to return a *Dense. got %v of %T instead", T, T)
   472  			continue
   473  		}
   474  
   475  		if test.ne {
   476  			assert.NotEqual(test.tensor, D, test.name)
   477  		}
   478  
   479  		assert.Equal(test.correct, D.Data(), test.name)
   480  		assert.Equal(test.shape, D.Shape(), test.name)
   481  	}
   482  }
   483  
   484  func TestDense_CopyTo(t *testing.T) {
   485  	assert := assert.New(t)
   486  	var T, T2 *Dense
   487  	var T3 Tensor
   488  	var err error
   489  
   490  	T = New(WithShape(2), WithBacking([]float64{1, 2}))
   491  	T2 = New(Of(Float64), WithShape(1, 2))
   492  
   493  	err = T.CopyTo(T2)
   494  	if err != nil {
   495  		t.Fatal(err)
   496  	}
   497  	assert.Equal(T2.Data(), T.Data())
   498  
   499  	// now, modify T1's data
   500  	T.Set(0, float64(5000))
   501  	assert.NotEqual(T2.Data(), T.Data())
   502  
   503  	// test views
   504  	T = New(Of(Byte), WithShape(3, 3))
   505  	T2 = New(Of(Byte), WithShape(2, 2))
   506  	T3, _ = T.Slice(makeRS(0, 2), makeRS(0, 2)) // T[0:2, 0:2], shape == (2,2)
   507  	if err = T2.CopyTo(T3.(*Dense)); err != nil {
   508  		t.Log(err) // for now it's a not yet implemented error. TODO: FIX THIS
   509  	}
   510  
   511  	// dumbass time
   512  
   513  	T = New(Of(Float32), WithShape(3, 3))
   514  	T2 = New(Of(Float32), WithShape(2, 2))
   515  	if err = T.CopyTo(T2); err == nil {
   516  		t.Error("Expected an error")
   517  	}
   518  
   519  	if err = T.CopyTo(T); err != nil {
   520  		t.Error("Copying a *Tensor to itself should yield no error. ")
   521  	}
   522  
   523  }
   524  
// denseSliceTests is the table driving the first half of TestDense_Slice.
// Each case builds a tensor from data and shape, applies slices, and compares
// the resulting view's Shape(), Strides() and Data() with the expected values.
var denseSliceTests = []struct {
	name   string      // label used in logs and failure messages
	data   interface{} // backing data of the tensor being sliced
	shape  Shape       // shape of the tensor being sliced
	slices []Slice     // arguments passed to Slice; a nil entry is passed through as-is

	correctShape  Shape       // expected Shape() of the view
	correctStride []int       // expected Strides() of the view
	correctData   interface{} // expected Data() of the view
}{
	// scalar-equiv vector (issue 102)
	{"a[0], a is scalar-equiv", []float64{2},
		Shape{1}, []Slice{ss(0)}, ScalarShape(), nil, 2.0},

	// vector
	{"a[0]", []bool{true, true, false, false, false},
		Shape{5}, []Slice{ss(0)}, ScalarShape(), nil, true},
	{"a[0:2]", Range(Byte, 0, 5), Shape{5}, []Slice{makeRS(0, 2)}, Shape{2}, []int{1}, []byte{0, 1}},
	{"a[1:5:2]", Range(Int32, 0, 5), Shape{5}, []Slice{makeRS(1, 5, 2)}, Shape{2}, []int{2}, []int32{1, 2, 3, 4}},

	// colvec
	{"c[0]", Range(Int64, 0, 5), Shape{5, 1}, []Slice{ss(0)}, ScalarShape(), nil, int64(0)},
	{"c[0:2]", Range(Float32, 0, 5), Shape{5, 1}, []Slice{makeRS(0, 2)}, Shape{2, 1}, []int{1, 1}, []float32{0, 1}},
	{"c[1:5:2]", Range(Float64, 0, 5), Shape{5, 1}, []Slice{makeRS(0, 5, 2)}, Shape{2, 1}, []int{2, 1}, []float64{0, 1, 2, 3, 4}},

	// rowvec
	{"r[0]", Range(Float64, 0, 5), Shape{1, 5}, []Slice{ss(0)}, Shape{1, 5}, []int{1}, []float64{0, 1, 2, 3, 4}},
	{"r[0:2]", Range(Float64, 0, 5), Shape{1, 5}, []Slice{makeRS(0, 2)}, Shape{1, 5}, []int{1}, []float64{0, 1, 2, 3, 4}},
	{"r[0:5:2]", Range(Float64, 0, 5), Shape{1, 5}, []Slice{makeRS(0, 5, 2)}, Shape{1, 5}, []int{1}, []float64{0, 1, 2, 3, 4}},
	{"r[:, 0]", Range(Float64, 0, 5), Shape{1, 5}, []Slice{nil, ss(0)}, ScalarShape(), nil, float64(0)},
	{"r[:, 0:2]", Range(Float64, 0, 5), Shape{1, 5}, []Slice{nil, makeRS(0, 2)}, Shape{1, 2}, []int{5, 1}, []float64{0, 1}},
	{"r[:, 1:5:2]", Range(Float64, 0, 5), Shape{1, 5}, []Slice{nil, makeRS(1, 5, 2)}, Shape{1, 2}, []int{5, 2}, []float64{1, 2, 3, 4}},

	// matrix
	{"A[0]", Range(Float64, 0, 6), Shape{2, 3}, []Slice{ss(0)}, Shape{1, 3}, []int{1}, Range(Float64, 0, 3)},
	{"A[0:2]", Range(Float64, 0, 20), Shape{4, 5}, []Slice{makeRS(0, 2)}, Shape{2, 5}, []int{5, 1}, Range(Float64, 0, 10)},
	{"A[0, 0]", Range(Float64, 0, 20), Shape{4, 5}, []Slice{ss(0), ss(0)}, ScalarShape(), nil, float64(0)},
	{"A[0, 1:5]", Range(Float64, 0, 20), Shape{4, 5}, []Slice{ss(0), makeRS(1, 5)}, Shape{4}, []int{1}, Range(Float64, 1, 5)},
	{"A[0, 1:5:2]", Range(Float64, 0, 20), Shape{4, 5}, []Slice{ss(0), makeRS(1, 5, 2)}, Shape{1, 2}, []int{2}, Range(Float64, 1, 5)},
	{"A[:, 0]", Range(Float64, 0, 20), Shape{4, 5}, []Slice{nil, ss(0)}, Shape{4, 1}, []int{5}, Range(Float64, 0, 16)},
	{"A[:, 1:5]", Range(Float64, 0, 20), Shape{4, 5}, []Slice{nil, makeRS(1, 5)}, Shape{4, 4}, []int{5, 1}, Range(Float64, 1, 20)},
	{"A[:, 1:5:2]", Range(Float64, 0, 20), Shape{4, 5}, []Slice{nil, makeRS(1, 5, 2)}, Shape{4, 2}, []int{5, 2}, Range(Float64, 1, 20)},

	// 3tensor with leading and trailing 1s

	{"3T1[0]", Range(Float64, 0, 9), Shape{1, 9, 1}, []Slice{ss(0)}, Shape{9, 1}, []int{1, 1}, Range(Float64, 0, 9)},
	{"3T1[nil, 0:2]", Range(Float64, 0, 9), Shape{1, 9, 1}, []Slice{nil, makeRS(0, 2)}, Shape{1, 2, 1}, []int{9, 1, 1}, Range(Float64, 0, 2)},
	{"3T1[nil, 0:5:3]", Range(Float64, 0, 9), Shape{1, 9, 1}, []Slice{nil, makeRS(0, 5, 3)}, Shape{1, 2, 1}, []int{9, 3, 1}, Range(Float64, 0, 5)},
	{"3T1[nil, 1:5:3]", Range(Float64, 0, 9), Shape{1, 9, 1}, []Slice{nil, makeRS(1, 5, 3)}, Shape{1, 2, 1}, []int{9, 3, 1}, Range(Float64, 1, 5)},
	{"3T1[nil, 1:9:3]", Range(Float64, 0, 9), Shape{1, 9, 1}, []Slice{nil, makeRS(1, 9, 3)}, Shape{1, 3, 1}, []int{9, 3, 1}, Range(Float64, 1, 9)},

	// 3tensor
	{"3T[0]", Range(Float64, 0, 36), Shape{2, 9, 2}, []Slice{ss(0)}, Shape{9, 2}, []int{2, 1}, Range(Float64, 0, 18)},
	{"3T[1]", Range(Float64, 0, 36), Shape{2, 9, 2}, []Slice{ss(1)}, Shape{9, 2}, []int{2, 1}, Range(Float64, 18, 36)},
	{"3T[1, 2]", Range(Float64, 0, 36), Shape{2, 9, 2}, []Slice{ss(1), ss(2)}, Shape{2}, []int{1}, Range(Float64, 22, 24)},
	{"3T[1, 2:4]", Range(Float64, 0, 36), Shape{2, 9, 2}, []Slice{ss(1), makeRS(2, 4)}, Shape{2, 2}, []int{2, 1}, Range(Float64, 22, 26)},
	{"3T[1, 2:8:2]", Range(Float64, 0, 36), Shape{2, 9, 2}, []Slice{ss(1), makeRS(2, 8, 2)}, Shape{3, 2}, []int{4, 1}, Range(Float64, 22, 34)},
	{"3T[1, 2:8:3]", Range(Float64, 0, 36), Shape{2, 9, 2}, []Slice{ss(1), makeRS(2, 8, 3)}, Shape{2, 2}, []int{6, 1}, Range(Float64, 22, 34)},
	{"3T[1, 2:9:2]", Range(Float64, 0, 126), Shape{2, 9, 7}, []Slice{ss(1), makeRS(2, 9, 2)}, Shape{4, 7}, []int{14, 1}, Range(Float64, 77, 126)},
	{"3T[1, 2:9:2, 1]", Range(Float64, 0, 126), Shape{2, 9, 7}, []Slice{ss(1), makeRS(2, 9, 2), ss(1)}, Shape{4}, []int{14}, Range(Float64, 78, 121)}, // should this be a colvec?
	{"3T[1, 2:9:2, 1:4:2]", Range(Float64, 0, 126), Shape{2, 9, 7}, []Slice{ss(1), makeRS(2, 9, 2), makeRS(1, 4, 2)}, Shape{4, 2}, []int{14, 2}, Range(Float64, 78, 123)},
}
   587  
   588  func TestDense_Slice(t *testing.T) {
   589  	assert := assert.New(t)
   590  	var T *Dense
   591  	var V Tensor
   592  	var err error
   593  
   594  	for _, sts := range denseSliceTests {
   595  		T = New(WithShape(sts.shape...), WithBacking(sts.data))
   596  		t.Log(sts.name)
   597  		if V, err = T.Slice(sts.slices...); err != nil {
   598  			t.Error(err)
   599  			continue
   600  		}
   601  		assert.True(sts.correctShape.Eq(V.Shape()), "Test: %v - Incorrect Shape. Correct: %v. Got %v", sts.name, sts.correctShape, V.Shape())
   602  		assert.Equal(sts.correctStride, V.Strides(), "Test: %v - Incorrect Stride", sts.name)
   603  		assert.Equal(sts.correctData, V.Data(), "Test: %v - Incorrect Data", sts.name)
   604  	}
   605  
   606  	// Transposed slice
   607  	T = New(WithShape(2, 3), WithBacking(Range(Float32, 0, 6)))
   608  	T.T()
   609  	V, err = T.Slice(ss(0))
   610  	assert.True(Shape{2}.Eq(V.Shape()))
   611  	assert.Equal([]int{3}, V.Strides())
   612  	assert.Equal([]float32{0, 1, 2, 3}, V.Data())
   613  	assert.True(V.(*Dense).old.IsZero())
   614  
   615  	// slice a sliced
   616  	t.Logf("%v", V)
   617  	V, err = V.Slice(makeRS(1, 2))
   618  	t.Logf("%v", V)
   619  	assert.True(ScalarShape().Eq(V.Shape()))
   620  	assert.Equal(float32(3), V.Data())
   621  
   622  	// And now, ladies and gentlemen, the idiots!
   623  
   624  	// too many slices
   625  	_, err = T.Slice(ss(1), ss(2), ss(3), ss(4))
   626  	if err == nil {
   627  		t.Error("Expected a DimMismatchError error")
   628  	}
   629  
   630  	// out of range sliced
   631  	_, err = T.Slice(makeRS(20, 5))
   632  	if err == nil {
   633  		t.Error("Expected a IndexError")
   634  	}
   635  
   636  	// surely nobody can be this dumb? Having a start of negatives
   637  	_, err = T.Slice(makeRS(-1, 1))
   638  	if err == nil {
   639  		t.Error("Expected a IndexError")
   640  	}
   641  }
   642  
   643  func TestDense_Narrow(t *testing.T) {
   644  	testCases := []struct {
   645  		x                  *Dense
   646  		dim, start, length int
   647  		expected           *Dense
   648  	}{
   649  		{
   650  			x: New(
   651  				WithShape(3),
   652  				WithBacking([]int{1, 2, 3}),
   653  			),
   654  			dim:    0,
   655  			start:  1,
   656  			length: 1,
   657  			expected: New(
   658  				WithShape(),
   659  				WithBacking([]int{2}),
   660  			),
   661  		},
   662  		{
   663  			x: New(
   664  				WithShape(3, 3),
   665  				WithBacking([]int{1, 2, 3, 4, 5, 6, 7, 8, 9}),
   666  			),
   667  			dim:    0,
   668  			start:  0,
   669  			length: 2,
   670  			expected: New(
   671  				WithShape(2, 3),
   672  				WithBacking([]int{1, 2, 3, 4, 5, 6}),
   673  			),
   674  		},
   675  		{
   676  			x: New(
   677  				WithShape(3, 3),
   678  				WithBacking([]int{1, 2, 3, 4, 5, 6, 7, 8, 9}),
   679  			),
   680  			dim:    1,
   681  			start:  1,
   682  			length: 2,
   683  			expected: New(
   684  				WithShape(3, 2),
   685  				WithBacking([]int{2, 3, 5, 6, 8, 9}),
   686  			),
   687  		},
   688  		{
   689  			x: New(
   690  				WithShape(3, 3),
   691  				WithBacking([]int{1, 2, 3, 4, 5, 6, 7, 8, 9}),
   692  			),
   693  			dim:    1,
   694  			start:  0,
   695  			length: 1,
   696  			expected: New(
   697  				WithShape(3),
   698  				WithBacking([]int{1, 4, 7}),
   699  			),
   700  		},
   701  	}
   702  
   703  	for i, tC := range testCases {
   704  		t.Run(fmt.Sprintf("Example #%d narrow(%v,%d,%d,%v)", i+1, tC.x.Shape(), tC.dim, tC.start, tC.length), func(t *testing.T) {
   705  			c := assert.New(t)
   706  			// t.Logf("X:\n%v", tC.x)
   707  
   708  			y, err := tC.x.Narrow(tC.dim, tC.start, tC.length)
   709  			c.NoError(err)
   710  			// t.Logf("y:\n%v", y)
   711  
   712  			yMat := y.Materialize()
   713  			c.Equal(tC.expected.Shape(), yMat.Shape())
   714  			c.Equal(tC.expected.Data(), yMat.Data())
   715  
   716  			// err = y.Memset(1024)
   717  			// c.NoError(err)
   718  			// t.Logf("After Memset\nY: %v\nX:\n%v", y, tC.x)
   719  		})
   720  	}
   721  }
   722  
   723  func TestDense_SliceInto(t *testing.T) {
   724  	V := New(WithShape(100), Of(Byte))
   725  	T := New(WithBacking([]float64{1, 2, 3, 4, 5, 6}), WithShape(2, 3))
   726  	T.SliceInto(V, ss(0))
   727  
   728  	assert.True(t, Shape{3}.Eq(V.Shape()), "Got %v", V.Shape())
   729  	assert.Equal(t, []float64{1, 2, 3}, V.Data())
   730  }
   731  
// rollaxisTests is the table driving TestDense_RollAxis. Every case starts
// from a tensor of shape (1, 2, 3, 4) and lists the shape expected after
// RollAxis(axis, start, false).
var rollaxisTests = []struct {
	axis, start int // arguments passed to RollAxis

	correctShape Shape // expected shape of the tensor after the call
}{
	{0, 0, Shape{1, 2, 3, 4}},
	{0, 1, Shape{1, 2, 3, 4}},
	{0, 2, Shape{2, 1, 3, 4}},
	{0, 3, Shape{2, 3, 1, 4}},
	{0, 4, Shape{2, 3, 4, 1}},

	{1, 0, Shape{2, 1, 3, 4}},
	{1, 1, Shape{1, 2, 3, 4}},
	{1, 2, Shape{1, 2, 3, 4}},
	{1, 3, Shape{1, 3, 2, 4}},
	{1, 4, Shape{1, 3, 4, 2}},

	{2, 0, Shape{3, 1, 2, 4}},
	{2, 1, Shape{1, 3, 2, 4}},
	{2, 2, Shape{1, 2, 3, 4}},
	{2, 3, Shape{1, 2, 3, 4}},
	{2, 4, Shape{1, 2, 4, 3}},

	{3, 0, Shape{4, 1, 2, 3}},
	{3, 1, Shape{1, 4, 2, 3}},
	{3, 2, Shape{1, 2, 4, 3}},
	{3, 3, Shape{1, 2, 3, 4}},
	{3, 4, Shape{1, 2, 3, 4}},
}
   761  
   762  // The RollAxis tests are directly adapted from Numpy's test cases.
   763  func TestDense_RollAxis(t *testing.T) {
   764  	assert := assert.New(t)
   765  	var T *Dense
   766  	var err error
   767  
   768  	for _, rats := range rollaxisTests {
   769  		T = New(Of(Byte), WithShape(1, 2, 3, 4))
   770  		if _, err = T.RollAxis(rats.axis, rats.start, false); assert.NoError(err) {
   771  			assert.True(rats.correctShape.Eq(T.Shape()), "%d %d Expected %v, got %v", rats.axis, rats.start, rats.correctShape, T.Shape())
   772  		}
   773  	}
   774  }
   775  
   776  var concatTests = []struct {
   777  	name   string
   778  	dt     Dtype
   779  	a      interface{}
   780  	b      interface{}
   781  	shape  Shape
   782  	shapeB Shape
   783  	axis   int
   784  
   785  	correctShape Shape
   786  	correctData  interface{}
   787  }{
   788  	// Float64
   789  	{"vector", Float64, nil, nil, Shape{2}, nil, 0, Shape{4}, []float64{0, 1, 0, 1}},
   790  	{"matrix; axis 0 ", Float64, nil, nil, Shape{2, 2}, nil, 0, Shape{4, 2}, []float64{0, 1, 2, 3, 0, 1, 2, 3}},
   791  	{"matrix; axis 1 ", Float64, nil, nil, Shape{2, 2}, nil, 1, Shape{2, 4}, []float64{0, 1, 0, 1, 2, 3, 2, 3}},
   792  
   793  	// Float32
   794  	{"vector", Float32, nil, nil, Shape{2}, nil, 0, Shape{4}, []float32{0, 1, 0, 1}},
   795  	{"matrix; axis 0 ", Float32, nil, nil, Shape{2, 2}, nil, 0, Shape{4, 2}, []float32{0, 1, 2, 3, 0, 1, 2, 3}},
   796  	{"matrix; axis 1 ", Float32, nil, nil, Shape{2, 2}, nil, 1, Shape{2, 4}, []float32{0, 1, 0, 1, 2, 3, 2, 3}},
   797  
   798  	// Int
   799  	{"vector", Int, nil, nil, Shape{2}, nil, 0, Shape{4}, []int{0, 1, 0, 1}},
   800  	{"matrix; axis 0 ", Int, nil, nil, Shape{2, 2}, nil, 0, Shape{4, 2}, []int{0, 1, 2, 3, 0, 1, 2, 3}},
   801  	{"matrix; axis 1 ", Int, nil, nil, Shape{2, 2}, nil, 1, Shape{2, 4}, []int{0, 1, 0, 1, 2, 3, 2, 3}},
   802  
   803  	// Int64
   804  	{"vector", Int64, nil, nil, Shape{2}, nil, 0, Shape{4}, []int64{0, 1, 0, 1}},
   805  	{"matrix; axis 0 ", Int64, nil, nil, Shape{2, 2}, nil, 0, Shape{4, 2}, []int64{0, 1, 2, 3, 0, 1, 2, 3}},
   806  	{"matrix; axis 1 ", Int64, nil, nil, Shape{2, 2}, nil, 1, Shape{2, 4}, []int64{0, 1, 0, 1, 2, 3, 2, 3}},
   807  
   808  	// Int32
   809  	{"vector", Int32, nil, nil, Shape{2}, nil, 0, Shape{4}, []int32{0, 1, 0, 1}},
   810  	{"matrix; axis 0 ", Int32, nil, nil, Shape{2, 2}, nil, 0, Shape{4, 2}, []int32{0, 1, 2, 3, 0, 1, 2, 3}},
   811  	{"matrix; axis 1 ", Int32, nil, nil, Shape{2, 2}, nil, 1, Shape{2, 4}, []int32{0, 1, 0, 1, 2, 3, 2, 3}},
   812  
   813  	// Byte
   814  	{"vector", Byte, nil, nil, Shape{2}, nil, 0, Shape{4}, []byte{0, 1, 0, 1}},
   815  	{"matrix; axis 0 ", Byte, nil, nil, Shape{2, 2}, nil, 0, Shape{4, 2}, []byte{0, 1, 2, 3, 0, 1, 2, 3}},
   816  	{"matrix; axis 1 ", Byte, nil, nil, Shape{2, 2}, nil, 1, Shape{2, 4}, []byte{0, 1, 0, 1, 2, 3, 2, 3}},
   817  
   818  	// Bool
   819  	{"vector", Bool, []bool{true, false}, nil, Shape{2}, nil, 0, Shape{4}, []bool{true, false, true, false}},
   820  	{"matrix; axis 0 ", Bool, []bool{true, false, true, false}, nil, Shape{2, 2}, nil, 0, Shape{4, 2}, []bool{true, false, true, false, true, false, true, false}},
   821  	{"matrix; axis 1 ", Bool, []bool{true, false, true, false}, nil, Shape{2, 2}, nil, 1, Shape{2, 4}, []bool{true, false, true, false, true, false, true, false}},
   822  
   823  	// gorgonia/gorgonia#218 related
   824  	{"matrix; axis 0", Float64, nil, nil, Shape{2, 2}, Shape{1, 2}, 0, Shape{3, 2}, []float64{0, 1, 2, 3, 0, 1}},
   825  	{"matrix; axis 1", Float64, nil, nil, Shape{2, 2}, Shape{2, 1}, 1, Shape{2, 3}, []float64{0, 1, 0, 2, 3, 1}},
   826  	{"colvec matrix, axis 0", Float64, nil, nil, Shape{2, 1}, Shape{1, 1}, 0, Shape{3, 1}, []float64{0, 1, 0}},
   827  	{"rowvec matrix, axis 1", Float64, nil, nil, Shape{1, 2}, Shape{1, 1}, 1, Shape{1, 3}, []float64{0, 1, 0}},
   828  
   829  	{"3tensor; axis 0", Float64, nil, nil, Shape{2, 3, 2}, Shape{1, 3, 2}, 0, Shape{3, 3, 2}, []float64{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 0, 1, 2, 3, 4, 5}},
   830  	{"3tensor; axis 2", Float64, nil, nil, Shape{2, 3, 2}, Shape{2, 3, 1}, 2, Shape{2, 3, 3}, []float64{0, 1, 0, 2, 3, 1, 4, 5, 2, 6, 7, 3, 8, 9, 4, 10, 11, 5}},
   831  	{"3tensor; axis 1", Float64, nil, nil, Shape{2, 3, 2}, Shape{2, 1, 2}, 1, Shape{2, 4, 2}, []float64{0, 1, 2, 3, 4, 5, 0, 1, 6, 7, 8, 9, 10, 11, 2, 3}},
   832  }
   833  
   834  func TestDense_Concat(t *testing.T) {
   835  	assert := assert.New(t)
   836  
   837  	for _, cts := range concatTests {
   838  		var T0, T1 *Dense
   839  
   840  		if cts.a == nil {
   841  			T0 = New(WithShape(cts.shape...), WithBacking(Range(cts.dt, 0, cts.shape.TotalSize())))
   842  		} else {
   843  			T0 = New(WithShape(cts.shape...), WithBacking(cts.a))
   844  		}
   845  
   846  		switch {
   847  		case cts.shapeB == nil && cts.a == nil:
   848  			T1 = New(WithShape(cts.shape...), WithBacking(Range(cts.dt, 0, cts.shape.TotalSize())))
   849  		case cts.shapeB == nil && cts.a != nil:
   850  			T1 = New(WithShape(cts.shape...), WithBacking(cloneArray(cts.a)))
   851  		case cts.shapeB != nil && cts.b == nil:
   852  			T1 = New(WithShape(cts.shapeB...), WithBacking(Range(cts.dt, 0, cts.shapeB.TotalSize())))
   853  		case cts.shapeB != nil && cts.b != nil:
   854  			T1 = New(WithShape(cts.shapeB...), WithBacking(cts.b))
   855  		}
   856  
   857  		T2, err := T0.Concat(cts.axis, T1)
   858  		if err != nil {
   859  			t.Errorf("Test %v failed: %v", cts.name, err)
   860  			continue
   861  		}
   862  
   863  		assert.True(cts.correctShape.Eq(T2.Shape()))
   864  		assert.Equal(cts.correctData, T2.Data())
   865  	}
   866  
   867  	//Masked case
   868  
   869  	for _, cts := range concatTests {
   870  		var T0, T1 *Dense
   871  
   872  		if cts.a == nil {
   873  			T0 = New(WithShape(cts.shape...), WithBacking(Range(cts.dt, 0, cts.shape.TotalSize())))
   874  			T0.MaskedEqual(castToDt(0.0, cts.dt))
   875  		} else {
   876  			T0 = New(WithShape(cts.shape...), WithBacking(cts.a))
   877  			T0.MaskedEqual(castToDt(0.0, cts.dt))
   878  		}
   879  
   880  		switch {
   881  		case cts.shapeB == nil && cts.a == nil:
   882  			T1 = New(WithShape(cts.shape...), WithBacking(Range(cts.dt, 0, cts.shape.TotalSize())))
   883  		case cts.shapeB == nil && cts.a != nil:
   884  			T1 = New(WithShape(cts.shape...), WithBacking(cloneArray(cts.a)))
   885  		case cts.shapeB != nil && cts.b == nil:
   886  			T1 = New(WithShape(cts.shapeB...), WithBacking(Range(cts.dt, 0, cts.shapeB.TotalSize())))
   887  		case cts.shapeB != nil && cts.b != nil:
   888  			T1 = New(WithShape(cts.shapeB...), WithBacking(cts.b))
   889  		}
   890  		T1.MaskedEqual(castToDt(0.0, cts.dt))
   891  
   892  		T2, err := T0.Concat(cts.axis, T1)
   893  		if err != nil {
   894  			t.Errorf("Test %v failed: %v", cts.name, err)
   895  			continue
   896  		}
   897  
   898  		T3 := New(WithShape(cts.correctShape...), WithBacking(cts.correctData))
   899  		T3.MaskedEqual(castToDt(0.0, cts.dt))
   900  
   901  		assert.True(cts.correctShape.Eq(T2.Shape()))
   902  		assert.Equal(cts.correctData, T2.Data())
   903  		assert.Equal(T3.mask, T2.mask)
   904  	}
   905  }
   906  
   907  func TestDense_Concat_sliced(t *testing.T) {
   908  	v := New(
   909  		WithShape(1, 5),
   910  		WithBacking([]float64{0, 1, 2, 3, 4}),
   911  	)
   912  	cols := make([]Tensor, v.Shape().TotalSize())
   913  	for i := 0; i < v.Shape().TotalSize(); i++ {
   914  		sliced, err := v.Slice(nil, ss(i))
   915  		if err != nil {
   916  			t.Fatalf("Failed to slice %d. Error: %v", i, err)
   917  		}
   918  		if err = sliced.Reshape(sliced.Shape().TotalSize(), 1); err != nil {
   919  			t.Fatalf("Failed to reshape %d. Error %v", i, err)
   920  		}
   921  		cols[i] = sliced
   922  	}
   923  	result, err := Concat(1, cols[0], cols[1:]...)
   924  	if err != nil {
   925  		t.Error(err)
   926  	}
   927  	assert.Equal(t, v.Data(), result.Data())
   928  	if v.Uintptr() == result.Uintptr() {
   929  		t.Error("They should not share the same backing data!")
   930  	}
   931  
   932  }
   933  
   934  var simpleStackTests = []struct {
   935  	name       string
   936  	dt         Dtype
   937  	shape      Shape
   938  	axis       int
   939  	stackCount int
   940  
   941  	correctShape Shape
   942  	correctData  interface{}
   943  }{
   944  	// Size 8
   945  	{"vector, axis 0, stack 2", Float64, Shape{2}, 0, 2, Shape{2, 2}, []float64{0, 1, 100, 101}},
   946  	{"vector, axis 1, stack 2", Float64, Shape{2}, 1, 2, Shape{2, 2}, []float64{0, 100, 1, 101}},
   947  	{"matrix, axis 0, stack 2", Float64, Shape{2, 3}, 0, 2, Shape{2, 2, 3}, []float64{0, 1, 2, 3, 4, 5, 100, 101, 102, 103, 104, 105}},
   948  	{"matrix, axis 1, stack 2", Float64, Shape{2, 3}, 1, 2, Shape{2, 2, 3}, []float64{0, 1, 2, 100, 101, 102, 3, 4, 5, 103, 104, 105}},
   949  	{"matrix, axis 2, stack 2", Float64, Shape{2, 3}, 2, 2, Shape{2, 3, 2}, []float64{0, 100, 1, 101, 2, 102, 3, 103, 4, 104, 5, 105}},
   950  	{"matrix, axis 0, stack 3", Float64, Shape{2, 3}, 0, 3, Shape{3, 2, 3}, []float64{0, 1, 2, 3, 4, 5, 100, 101, 102, 103, 104, 105, 200, 201, 202, 203, 204, 205}},
   951  	{"matrix, axis 1, stack 3", Float64, Shape{2, 3}, 1, 3, Shape{2, 3, 3}, []float64{0, 1, 2, 100, 101, 102, 200, 201, 202, 3, 4, 5, 103, 104, 105, 203, 204, 205}},
   952  	{"matrix, axis 2, stack 3", Float64, Shape{2, 3}, 2, 3, Shape{2, 3, 3}, []float64{0, 100, 200, 1, 101, 201, 2, 102, 202, 3, 103, 203, 4, 104, 204, 5, 105, 205}},
   953  
   954  	// Size 4
   955  	{"vector, axis 0, stack 2 (f32)", Float32, Shape{2}, 0, 2, Shape{2, 2}, []float32{0, 1, 100, 101}},
   956  	{"vector, axis 1, stack 2 (f32)", Float32, Shape{2}, 1, 2, Shape{2, 2}, []float32{0, 100, 1, 101}},
   957  	{"matrix, axis 0, stack 2 (f32)", Float32, Shape{2, 3}, 0, 2, Shape{2, 2, 3}, []float32{0, 1, 2, 3, 4, 5, 100, 101, 102, 103, 104, 105}},
   958  	{"matrix, axis 1, stack 2 (f32)", Float32, Shape{2, 3}, 1, 2, Shape{2, 2, 3}, []float32{0, 1, 2, 100, 101, 102, 3, 4, 5, 103, 104, 105}},
   959  	{"matrix, axis 2, stack 2 (f32)", Float32, Shape{2, 3}, 2, 2, Shape{2, 3, 2}, []float32{0, 100, 1, 101, 2, 102, 3, 103, 4, 104, 5, 105}},
   960  	{"matrix, axis 0, stack 3 (f32)", Float32, Shape{2, 3}, 0, 3, Shape{3, 2, 3}, []float32{0, 1, 2, 3, 4, 5, 100, 101, 102, 103, 104, 105, 200, 201, 202, 203, 204, 205}},
   961  	{"matrix, axis 1, stack 3 (f32)", Float32, Shape{2, 3}, 1, 3, Shape{2, 3, 3}, []float32{0, 1, 2, 100, 101, 102, 200, 201, 202, 3, 4, 5, 103, 104, 105, 203, 204, 205}},
   962  	{"matrix, axis 2, stack 3 (f32)", Float32, Shape{2, 3}, 2, 3, Shape{2, 3, 3}, []float32{0, 100, 200, 1, 101, 201, 2, 102, 202, 3, 103, 203, 4, 104, 204, 5, 105, 205}},
   963  
   964  	// Size 2
   965  	{"vector, axis 0, stack 2 (i16)", Int16, Shape{2}, 0, 2, Shape{2, 2}, []int16{0, 1, 100, 101}},
   966  	{"vector, axis 1, stack 2 (i16)", Int16, Shape{2}, 1, 2, Shape{2, 2}, []int16{0, 100, 1, 101}},
   967  	{"matrix, axis 0, stack 2 (i16)", Int16, Shape{2, 3}, 0, 2, Shape{2, 2, 3}, []int16{0, 1, 2, 3, 4, 5, 100, 101, 102, 103, 104, 105}},
   968  	{"matrix, axis 1, stack 2 (i16)", Int16, Shape{2, 3}, 1, 2, Shape{2, 2, 3}, []int16{0, 1, 2, 100, 101, 102, 3, 4, 5, 103, 104, 105}},
   969  	{"matrix, axis 2, stack 2 (i16)", Int16, Shape{2, 3}, 2, 2, Shape{2, 3, 2}, []int16{0, 100, 1, 101, 2, 102, 3, 103, 4, 104, 5, 105}},
   970  	{"matrix, axis 0, stack 3 (i16)", Int16, Shape{2, 3}, 0, 3, Shape{3, 2, 3}, []int16{0, 1, 2, 3, 4, 5, 100, 101, 102, 103, 104, 105, 200, 201, 202, 203, 204, 205}},
   971  	{"matrix, axis 1, stack 3 (i16)", Int16, Shape{2, 3}, 1, 3, Shape{2, 3, 3}, []int16{0, 1, 2, 100, 101, 102, 200, 201, 202, 3, 4, 5, 103, 104, 105, 203, 204, 205}},
   972  	{"matrix, axis 2, stack 3 (i16)", Int16, Shape{2, 3}, 2, 3, Shape{2, 3, 3}, []int16{0, 100, 200, 1, 101, 201, 2, 102, 202, 3, 103, 203, 4, 104, 204, 5, 105, 205}},
   973  
   974  	// Size 1
   975  	{"vector, axis 0, stack 2 (u8)", Byte, Shape{2}, 0, 2, Shape{2, 2}, []byte{0, 1, 100, 101}},
   976  	{"vector, axis 1, stack 2 (u8)", Byte, Shape{2}, 1, 2, Shape{2, 2}, []byte{0, 100, 1, 101}},
   977  	{"matrix, axis 0, stack 2 (u8)", Byte, Shape{2, 3}, 0, 2, Shape{2, 2, 3}, []byte{0, 1, 2, 3, 4, 5, 100, 101, 102, 103, 104, 105}},
   978  	{"matrix, axis 1, stack 2 (u8)", Byte, Shape{2, 3}, 1, 2, Shape{2, 2, 3}, []byte{0, 1, 2, 100, 101, 102, 3, 4, 5, 103, 104, 105}},
   979  	{"matrix, axis 2, stack 2 (u8)", Byte, Shape{2, 3}, 2, 2, Shape{2, 3, 2}, []byte{0, 100, 1, 101, 2, 102, 3, 103, 4, 104, 5, 105}},
   980  	{"matrix, axis 0, stack 3 (u8)", Byte, Shape{2, 3}, 0, 3, Shape{3, 2, 3}, []byte{0, 1, 2, 3, 4, 5, 100, 101, 102, 103, 104, 105, 200, 201, 202, 203, 204, 205}},
   981  	{"matrix, axis 1, stack 3 (u8)", Byte, Shape{2, 3}, 1, 3, Shape{2, 3, 3}, []byte{0, 1, 2, 100, 101, 102, 200, 201, 202, 3, 4, 5, 103, 104, 105, 203, 204, 205}},
   982  	{"matrix, axis 2, stack 3 (u8)", Byte, Shape{2, 3}, 2, 3, Shape{2, 3, 3}, []byte{0, 100, 200, 1, 101, 201, 2, 102, 202, 3, 103, 203, 4, 104, 204, 5, 105, 205}},
   983  }
   984  
   985  var viewStackTests = []struct {
   986  	name       string
   987  	dt         Dtype
   988  	shape      Shape
   989  	transform  []int
   990  	slices     []Slice
   991  	axis       int
   992  	stackCount int
   993  
   994  	correctShape Shape
   995  	correctData  interface{}
   996  }{
   997  	// Size 8
   998  	{"matrix(4x4)[1:3, 1:3] axis 0", Float64, Shape{4, 4}, nil, []Slice{makeRS(1, 3), makeRS(1, 3)}, 0, 2, Shape{2, 2, 2}, []float64{5, 6, 9, 10, 105, 106, 109, 110}},
   999  	{"matrix(4x4)[1:3, 1:3] axis 1", Float64, Shape{4, 4}, nil, []Slice{makeRS(1, 3), makeRS(1, 3)}, 1, 2, Shape{2, 2, 2}, []float64{5, 6, 105, 106, 9, 10, 109, 110}},
  1000  	{"matrix(4x4)[1:3, 1:3] axis 2", Float64, Shape{4, 4}, nil, []Slice{makeRS(1, 3), makeRS(1, 3)}, 2, 2, Shape{2, 2, 2}, []float64{5, 105, 6, 106, 9, 109, 10, 110}},
  1001  
  1002  	// Size 4
  1003  	{"matrix(4x4)[1:3, 1:3] axis 0 (u32)", Uint32, Shape{4, 4}, nil, []Slice{makeRS(1, 3), makeRS(1, 3)}, 0, 2, Shape{2, 2, 2}, []uint32{5, 6, 9, 10, 105, 106, 109, 110}},
  1004  	{"matrix(4x4)[1:3, 1:3] axis 1 (u32)", Uint32, Shape{4, 4}, nil, []Slice{makeRS(1, 3), makeRS(1, 3)}, 1, 2, Shape{2, 2, 2}, []uint32{5, 6, 105, 106, 9, 10, 109, 110}},
  1005  	{"matrix(4x4)[1:3, 1:3] axis 2 (u32)", Uint32, Shape{4, 4}, nil, []Slice{makeRS(1, 3), makeRS(1, 3)}, 2, 2, Shape{2, 2, 2}, []uint32{5, 105, 6, 106, 9, 109, 10, 110}},
  1006  
  1007  	// Size 2
  1008  	{"matrix(4x4)[1:3, 1:3] axis 0 (u16)", Uint16, Shape{4, 4}, nil, []Slice{makeRS(1, 3), makeRS(1, 3)}, 0, 2, Shape{2, 2, 2}, []uint16{5, 6, 9, 10, 105, 106, 109, 110}},
  1009  	{"matrix(4x4)[1:3, 1:3] axis 1 (u16)", Uint16, Shape{4, 4}, nil, []Slice{makeRS(1, 3), makeRS(1, 3)}, 1, 2, Shape{2, 2, 2}, []uint16{5, 6, 105, 106, 9, 10, 109, 110}},
  1010  	{"matrix(4x4)[1:3, 1:3] axis 2 (u16)", Uint16, Shape{4, 4}, nil, []Slice{makeRS(1, 3), makeRS(1, 3)}, 2, 2, Shape{2, 2, 2}, []uint16{5, 105, 6, 106, 9, 109, 10, 110}},
  1011  
  1012  	// Size 1
  1013  	{"matrix(4x4)[1:3, 1:3] axis 0 (u8)", Byte, Shape{4, 4}, nil, []Slice{makeRS(1, 3), makeRS(1, 3)}, 0, 2, Shape{2, 2, 2}, []byte{5, 6, 9, 10, 105, 106, 109, 110}},
  1014  	{"matrix(4x4)[1:3, 1:3] axis 1 (u8)", Byte, Shape{4, 4}, nil, []Slice{makeRS(1, 3), makeRS(1, 3)}, 1, 2, Shape{2, 2, 2}, []byte{5, 6, 105, 106, 9, 10, 109, 110}},
  1015  	{"matrix(4x4)[1:3, 1:3] axis 2 (u8)", Byte, Shape{4, 4}, nil, []Slice{makeRS(1, 3), makeRS(1, 3)}, 2, 2, Shape{2, 2, 2}, []byte{5, 105, 6, 106, 9, 109, 10, 110}},
  1016  }
  1017  
  1018  func TestDense_Stack(t *testing.T) {
  1019  	assert := assert.New(t)
  1020  	var err error
  1021  	for _, sts := range simpleStackTests {
  1022  		T := New(WithShape(sts.shape...), WithBacking(Range(sts.dt, 0, sts.shape.TotalSize())))
  1023  
  1024  		var stacked []*Dense
  1025  		for i := 0; i < sts.stackCount-1; i++ {
  1026  			offset := (i + 1) * 100
  1027  			T1 := New(WithShape(sts.shape...), WithBacking(Range(sts.dt, offset, sts.shape.TotalSize()+offset)))
  1028  			stacked = append(stacked, T1)
  1029  		}
  1030  
  1031  		T2, err := T.Stack(sts.axis, stacked...)
  1032  		if err != nil {
  1033  			t.Error(err)
  1034  			continue
  1035  		}
  1036  		assert.True(sts.correctShape.Eq(T2.Shape()))
  1037  		assert.Equal(sts.correctData, T2.Data())
  1038  	}
  1039  
  1040  	for _, sts := range viewStackTests {
  1041  		T := New(WithShape(sts.shape...), WithBacking(Range(sts.dt, 0, sts.shape.TotalSize())))
  1042  		switch {
  1043  		case sts.slices != nil && sts.transform == nil:
  1044  			var sliced Tensor
  1045  			if sliced, err = T.Slice(sts.slices...); err != nil {
  1046  				t.Error(err)
  1047  				continue
  1048  			}
  1049  			T = sliced.(*Dense)
  1050  		case sts.transform != nil && sts.slices == nil:
  1051  			T.T(sts.transform...)
  1052  		}
  1053  
  1054  		var stacked []*Dense
  1055  		for i := 0; i < sts.stackCount-1; i++ {
  1056  			offset := (i + 1) * 100
  1057  			T1 := New(WithShape(sts.shape...), WithBacking(Range(sts.dt, offset, sts.shape.TotalSize()+offset)))
  1058  			switch {
  1059  			case sts.slices != nil && sts.transform == nil:
  1060  				var sliced Tensor
  1061  				if sliced, err = T1.Slice(sts.slices...); err != nil {
  1062  					t.Error(err)
  1063  					continue
  1064  				}
  1065  				T1 = sliced.(*Dense)
  1066  			case sts.transform != nil && sts.slices == nil:
  1067  				T1.T(sts.transform...)
  1068  			}
  1069  
  1070  			stacked = append(stacked, T1)
  1071  		}
  1072  		T2, err := T.Stack(sts.axis, stacked...)
  1073  		if err != nil {
  1074  			t.Error(err)
  1075  			continue
  1076  		}
  1077  		assert.True(sts.correctShape.Eq(T2.Shape()))
  1078  		assert.Equal(sts.correctData, T2.Data(), "%q failed", sts.name)
  1079  	}
  1080  
  1081  	// Repeat tests with masks
  1082  	for _, sts := range simpleStackTests {
  1083  		T := New(WithShape(sts.shape...), WithBacking(Range(sts.dt, 0, sts.shape.TotalSize())))
  1084  
  1085  		var stacked []*Dense
  1086  		for i := 0; i < sts.stackCount-1; i++ {
  1087  			offset := (i + 1) * 100
  1088  			T1 := New(WithShape(sts.shape...), WithBacking(Range(sts.dt, offset, sts.shape.TotalSize()+offset)))
  1089  			T1.MaskedInside(castToDt(102.0, sts.dt), castToDt(225.0, sts.dt))
  1090  			stacked = append(stacked, T1)
  1091  		}
  1092  
  1093  		T2, err := T.Stack(sts.axis, stacked...)
  1094  		if err != nil {
  1095  			t.Error(err)
  1096  			continue
  1097  		}
  1098  
  1099  		T3 := New(WithShape(sts.correctShape...), WithBacking(sts.correctData))
  1100  		T3.MaskedInside(castToDt(102.0, sts.dt), castToDt(225.0, sts.dt))
  1101  
  1102  		assert.True(sts.correctShape.Eq(T2.Shape()))
  1103  		assert.Equal(sts.correctData, T2.Data())
  1104  		assert.Equal(T3.mask, T2.mask)
  1105  	}
  1106  
  1107  	for _, sts := range viewStackTests {
  1108  		T := New(WithShape(sts.shape...), WithBacking(Range(sts.dt, 0, sts.shape.TotalSize())))
  1109  		switch {
  1110  		case sts.slices != nil && sts.transform == nil:
  1111  			var sliced Tensor
  1112  			if sliced, err = T.Slice(sts.slices...); err != nil {
  1113  				t.Error(err)
  1114  				continue
  1115  			}
  1116  			T = sliced.(*Dense)
  1117  		case sts.transform != nil && sts.slices == nil:
  1118  			T.T(sts.transform...)
  1119  		}
  1120  
  1121  		var stacked []*Dense
  1122  		for i := 0; i < sts.stackCount-1; i++ {
  1123  			offset := (i + 1) * 100
  1124  			T1 := New(WithShape(sts.shape...), WithBacking(Range(sts.dt, offset, sts.shape.TotalSize()+offset)))
  1125  			T1.MaskedInside(castToDt(102.0, sts.dt), castToDt(225.0, sts.dt))
  1126  			switch {
  1127  			case sts.slices != nil && sts.transform == nil:
  1128  				var sliced Tensor
  1129  				if sliced, err = T1.Slice(sts.slices...); err != nil {
  1130  					t.Error(err)
  1131  					continue
  1132  				}
  1133  				T1 = sliced.(*Dense)
  1134  			case sts.transform != nil && sts.slices == nil:
  1135  				T1.T(sts.transform...)
  1136  			}
  1137  
  1138  			stacked = append(stacked, T1)
  1139  		}
  1140  
  1141  		T2, err := T.Stack(sts.axis, stacked...)
  1142  		if err != nil {
  1143  			t.Error(err)
  1144  			continue
  1145  		}
  1146  
  1147  		T3 := New(WithShape(sts.correctShape...), WithBacking(sts.correctData))
  1148  		T3.MaskedInside(castToDt(102.0, sts.dt), castToDt(225.0, sts.dt))
  1149  
  1150  		assert.True(sts.correctShape.Eq(T2.Shape()))
  1151  		assert.Equal(sts.correctData, T2.Data())
  1152  		assert.Equal(T3.mask, T2.mask)
  1153  	}
  1154  
  1155  	// arbitrary view slices
  1156  
  1157  	T := New(WithShape(2, 2), WithBacking([]string{"hello", "world", "nihao", "sekai"}))
  1158  	var stacked []*Dense
  1159  	for i := 0; i < 1; i++ {
  1160  		T1 := New(WithShape(2, 2), WithBacking([]string{"blah1", "blah2", "blah3", "blah4"}))
  1161  		var sliced Tensor
  1162  		if sliced, err = T1.Slice(nil, nil); err != nil {
  1163  			t.Error(err)
  1164  			break
  1165  		}
  1166  		T1 = sliced.(*Dense)
  1167  		stacked = append(stacked, T1)
  1168  	}
  1169  	T2, err := T.Stack(0, stacked...)
  1170  	if err != nil {
  1171  		t.Error(err)
  1172  		return
  1173  	}
  1174  
  1175  	correctShape := Shape{2, 2, 2}
  1176  	correctData := []string{"hello", "world", "nihao", "sekai", "blah1", "blah2", "blah3", "blah4"}
  1177  	assert.True(correctShape.Eq(T2.Shape()))
  1178  	assert.Equal(correctData, T2.Data(), "%q failed", "arbitrary view slice")
  1179  }