github.com/wzzhu/tensor@v0.9.24/dense_colmajor_linalg_test.go (about)

     1  package tensor
     2  
     3  import (
     4  	"testing"
     5  
     6  	"github.com/stretchr/testify/assert"
     7  )
     8  
     9  var colMajorTraceTests = []struct {
    10  	data interface{}
    11  
    12  	correct interface{}
    13  	err     bool
    14  }{
    15  	{[]int{0, 1, 2, 3, 4, 5}, int(4), false},
    16  	{[]int8{0, 1, 2, 3, 4, 5}, int8(4), false},
    17  	{[]int16{0, 1, 2, 3, 4, 5}, int16(4), false},
    18  	{[]int32{0, 1, 2, 3, 4, 5}, int32(4), false},
    19  	{[]int64{0, 1, 2, 3, 4, 5}, int64(4), false},
    20  	{[]uint{0, 1, 2, 3, 4, 5}, uint(4), false},
    21  	{[]uint8{0, 1, 2, 3, 4, 5}, uint8(4), false},
    22  	{[]uint16{0, 1, 2, 3, 4, 5}, uint16(4), false},
    23  	{[]uint32{0, 1, 2, 3, 4, 5}, uint32(4), false},
    24  	{[]uint64{0, 1, 2, 3, 4, 5}, uint64(4), false},
    25  	{[]float32{0, 1, 2, 3, 4, 5}, float32(4), false},
    26  	{[]float64{0, 1, 2, 3, 4, 5}, float64(4), false},
    27  	{[]complex64{0, 1, 2, 3, 4, 5}, complex64(4), false},
    28  	{[]complex128{0, 1, 2, 3, 4, 5}, complex128(4), false},
    29  	{[]bool{true, false, true, false, true, false}, nil, true},
    30  }
    31  
    32  func TestColMajor_Dense_Trace(t *testing.T) {
    33  	assert := assert.New(t)
    34  	for i, tts := range colMajorTraceTests {
    35  		T := New(WithShape(2, 3), AsFortran(tts.data))
    36  		trace, err := T.Trace()
    37  
    38  		if checkErr(t, tts.err, err, "Trace", i) {
    39  			continue
    40  		}
    41  		assert.Equal(tts.correct, trace)
    42  
    43  		//
    44  		T = New(WithBacking(tts.data))
    45  		_, err = T.Trace()
    46  		if err == nil {
    47  			t.Error("Expected an error when Trace() on non-matrices")
    48  		}
    49  	}
    50  }
    51  
    52  var colMajorInnerTests = []struct {
    53  	a, b           interface{}
    54  	shapeA, shapeB Shape
    55  
    56  	correct interface{}
    57  	err     bool
    58  }{
    59  	{Range(Float64, 0, 3), Range(Float64, 0, 3), Shape{3}, Shape{3}, float64(5), false},
    60  	{Range(Float64, 0, 3), Range(Float64, 0, 3), Shape{3, 1}, Shape{3}, float64(5), false},
    61  	{Range(Float64, 0, 3), Range(Float64, 0, 3), Shape{1, 3}, Shape{3}, float64(5), false},
    62  	{Range(Float64, 0, 3), Range(Float64, 0, 3), Shape{3, 1}, Shape{3, 1}, float64(5), false},
    63  	{Range(Float64, 0, 3), Range(Float64, 0, 3), Shape{1, 3}, Shape{3, 1}, float64(5), false},
    64  	{Range(Float64, 0, 3), Range(Float64, 0, 3), Shape{3, 1}, Shape{1, 3}, float64(5), false},
    65  	{Range(Float64, 0, 3), Range(Float64, 0, 3), Shape{1, 3}, Shape{1, 3}, float64(5), false},
    66  
    67  	{Range(Float32, 0, 3), Range(Float32, 0, 3), Shape{3}, Shape{3}, float32(5), false},
    68  	{Range(Float32, 0, 3), Range(Float32, 0, 3), Shape{3, 1}, Shape{3}, float32(5), false},
    69  	{Range(Float32, 0, 3), Range(Float32, 0, 3), Shape{1, 3}, Shape{3}, float32(5), false},
    70  	{Range(Float32, 0, 3), Range(Float32, 0, 3), Shape{3, 1}, Shape{3, 1}, float32(5), false},
    71  	{Range(Float32, 0, 3), Range(Float32, 0, 3), Shape{1, 3}, Shape{3, 1}, float32(5), false},
    72  	{Range(Float32, 0, 3), Range(Float32, 0, 3), Shape{3, 1}, Shape{1, 3}, float32(5), false},
    73  	{Range(Float32, 0, 3), Range(Float32, 0, 3), Shape{1, 3}, Shape{1, 3}, float32(5), false},
    74  
    75  	// stupids: type differences
    76  	{Range(Int, 0, 3), Range(Float64, 0, 3), Shape{3}, Shape{3}, nil, true},
    77  	{Range(Float32, 0, 3), Range(Byte, 0, 3), Shape{3}, Shape{3}, nil, true},
    78  	{Range(Float64, 0, 3), Range(Float32, 0, 3), Shape{3}, Shape{3}, nil, true},
    79  	{Range(Float32, 0, 3), Range(Float64, 0, 3), Shape{3}, Shape{3}, nil, true},
    80  
    81  	// differing size
    82  	{Range(Float64, 0, 4), Range(Float64, 0, 3), Shape{4}, Shape{3}, nil, true},
    83  
    84  	// A is not a matrix
    85  	{Range(Float64, 0, 4), Range(Float64, 0, 3), Shape{2, 2}, Shape{3}, nil, true},
    86  }
    87  
    88  func TestColMajor_Dense_Inner(t *testing.T) {
    89  	for i, its := range colMajorInnerTests {
    90  		a := New(WithShape(its.shapeA...), AsFortran(its.a))
    91  		b := New(WithShape(its.shapeB...), AsFortran(its.b))
    92  
    93  		T, err := a.Inner(b)
    94  		if checkErr(t, its.err, err, "Inner", i) {
    95  			continue
    96  		}
    97  
    98  		assert.Equal(t, its.correct, T)
    99  	}
   100  }
   101  
   102  var colMajorMatVecMulTests = []linalgTest{
   103  	// Float64s
   104  	{Range(Float64, 0, 6), Range(Float64, 0, 3), Shape{2, 3}, Shape{3}, false, false,
   105  		Range(Float64, 52, 54), Range(Float64, 100, 102), Shape{2}, Shape{2},
   106  		[]float64{5, 14}, []float64{105, 115}, []float64{110, 129}, Shape{2}, false, false, false},
   107  	{Range(Float64, 0, 6), Range(Float64, 0, 3), Shape{2, 3}, Shape{3, 1}, false, false,
   108  		Range(Float64, 52, 54), Range(Float64, 100, 102), Shape{2}, Shape{2},
   109  		[]float64{5, 14}, []float64{105, 115}, []float64{110, 129}, Shape{2}, false, false, false},
   110  	{Range(Float64, 0, 6), Range(Float64, 0, 3), Shape{2, 3}, Shape{1, 3}, false, false,
   111  		Range(Float64, 52, 54), Range(Float64, 100, 102), Shape{2}, Shape{2},
   112  		[]float64{5, 14}, []float64{105, 115}, []float64{110, 129}, Shape{2}, false, false, false},
   113  
   114  	// float64s with transposed matrix
   115  	{Range(Float64, 0, 6), Range(Float64, 0, 2), Shape{2, 3}, Shape{2}, true, false,
   116  		Range(Float64, 52, 55), Range(Float64, 100, 103), Shape{3}, Shape{3},
   117  		[]float64{3, 4, 5}, []float64{103, 105, 107}, []float64{106, 109, 112}, Shape{3}, false, false, false},
   118  
   119  	// Float32s
   120  	{Range(Float32, 0, 6), Range(Float32, 0, 3), Shape{2, 3}, Shape{3}, false, false,
   121  		Range(Float32, 52, 54), Range(Float32, 100, 102), Shape{2}, Shape{2},
   122  		[]float32{5, 14}, []float32{105, 115}, []float32{110, 129}, Shape{2}, false, false, false},
   123  	{Range(Float32, 0, 6), Range(Float32, 0, 3), Shape{2, 3}, Shape{3, 1}, false, false,
   124  		Range(Float32, 52, 54), Range(Float32, 100, 102), Shape{2}, Shape{2},
   125  		[]float32{5, 14}, []float32{105, 115}, []float32{110, 129}, Shape{2}, false, false, false},
   126  	{Range(Float32, 0, 6), Range(Float32, 0, 3), Shape{2, 3}, Shape{1, 3}, false, false,
   127  		Range(Float32, 52, 54), Range(Float32, 100, 102), Shape{2}, Shape{2},
   128  		[]float32{5, 14}, []float32{105, 115}, []float32{110, 129}, Shape{2}, false, false, false},
   129  
   130  	// stupids : unpossible shapes (wrong A)
   131  	{Range(Float64, 0, 6), Range(Float64, 0, 3), Shape{6}, Shape{3}, false, false,
   132  		Range(Float64, 52, 54), Range(Float64, 100, 102), Shape{2}, Shape{2},
   133  		[]float64{5, 14}, []float64{105, 115}, []float64{110, 129}, Shape{2}, true, false, false},
   134  
   135  	//stupids: bad A shape
   136  	{Range(Float64, 0, 8), Range(Float64, 0, 3), Shape{4, 2}, Shape{3}, false, false,
   137  		Range(Float64, 52, 54), Range(Float64, 100, 102), Shape{2}, Shape{2},
   138  		[]float64{5, 14}, []float64{105, 115}, []float64{110, 129}, Shape{2}, true, false, false},
   139  
   140  	//stupids: bad B shape
   141  	{Range(Float64, 0, 6), Range(Float64, 0, 6), Shape{2, 3}, Shape{3, 2}, false, false,
   142  		Range(Float64, 52, 54), Range(Float64, 100, 102), Shape{2}, Shape{2},
   143  		[]float64{5, 14}, []float64{105, 115}, []float64{110, 129}, Shape{2}, true, false, false},
   144  
   145  	//stupids: bad reuse
   146  	{Range(Float64, 0, 6), Range(Float64, 0, 3), Shape{2, 3}, Shape{3}, false, false,
   147  		Range(Float64, 52, 55), Range(Float64, 100, 102), Shape{3}, Shape{2},
   148  		[]float64{5, 14}, []float64{105, 115}, []float64{110, 129}, Shape{2}, false, false, true},
   149  
   150  	//stupids: bad incr shape
   151  	{Range(Float64, 0, 6), Range(Float64, 0, 3), Shape{2, 3}, Shape{3}, false, false,
   152  		Range(Float64, 52, 54), Range(Float64, 100, 105), Shape{2}, Shape{5},
   153  		[]float64{5, 14}, []float64{105, 115}, []float64{110, 129}, Shape{2}, false, true, false},
   154  
   155  	// stupids: type mismatch A and B
   156  	{Range(Float64, 0, 6), Range(Float32, 0, 3), Shape{2, 3}, Shape{3}, false, false,
   157  		Range(Float64, 52, 54), Range(Float64, 100, 103), Shape{2}, Shape{3},
   158  		[]float64{5, 14}, []float64{105, 115}, []float64{110, 129}, Shape{2}, true, false, false},
   159  
   160  	// stupids: type mismatch A and B
   161  	{Range(Float32, 0, 6), Range(Float64, 0, 3), Shape{2, 3}, Shape{3}, false, false,
   162  		Range(Float64, 52, 54), Range(Float64, 100, 103), Shape{2}, Shape{3},
   163  		[]float64{5, 14}, []float64{105, 115}, []float64{110, 129}, Shape{2}, true, false, false},
   164  
   165  	// stupids: type mismatch A and B
   166  	{Range(Float64, 0, 6), Range(Float32, 0, 3), Shape{2, 3}, Shape{3}, false, false,
   167  		Range(Float64, 52, 54), Range(Float64, 100, 103), Shape{2}, Shape{3},
   168  		[]float64{5, 14}, []float64{105, 115}, []float64{110, 129}, Shape{2}, true, false, false},
   169  
   170  	// stupids: type mismatch A and B
   171  	{Range(Float32, 0, 6), Range(Float64, 0, 3), Shape{2, 3}, Shape{3}, false, false,
   172  		Range(Float64, 52, 54), Range(Float64, 100, 103), Shape{2}, Shape{3},
   173  		[]float64{5, 14}, []float64{105, 115}, []float64{110, 129}, Shape{2}, true, false, false},
   174  
   175  	// stupids: type mismatch A and B (non-Float)
   176  	{Range(Float64, 0, 6), Range(Int, 0, 3), Shape{2, 3}, Shape{3}, false, false,
   177  		Range(Float64, 52, 54), Range(Float64, 100, 103), Shape{2}, Shape{3},
   178  		[]float64{5, 14}, []float64{105, 115}, []float64{110, 129}, Shape{2}, true, false, false},
   179  
   180  	// stupids: type mismatch, reuse
   181  	{Range(Float64, 0, 6), Range(Float64, 0, 3), Shape{2, 3}, Shape{3}, false, false,
   182  		Range(Float32, 52, 54), Range(Float64, 100, 102), Shape{2}, Shape{2},
   183  		[]float64{5, 14}, []float64{105, 115}, []float64{110, 129}, Shape{2}, false, false, true},
   184  
   185  	// stupids: type mismatch, incr
   186  	{Range(Float64, 0, 6), Range(Float64, 0, 3), Shape{2, 3}, Shape{3}, false, false,
   187  		Range(Float64, 52, 54), Range(Float32, 100, 103), Shape{2}, Shape{3},
   188  		[]float64{5, 14}, []float64{105, 115}, []float64{110, 129}, Shape{2}, false, true, false},
   189  
   190  	// stupids: type mismatch, incr not a Number
   191  	{Range(Float64, 0, 6), Range(Float64, 0, 3), Shape{2, 3}, Shape{3}, false, false,
   192  		Range(Float64, 52, 54), []bool{true, true, true}, Shape{2}, Shape{3},
   193  		[]float64{5, 14}, []float64{105, 115}, []float64{110, 129}, Shape{2}, false, true, false},
   194  }
   195  
   196  func TestColMajor_Dense_MatVecMul(t *testing.T) {
   197  	assert := assert.New(t)
   198  	for i, mvmt := range colMajorMatVecMulTests {
   199  		a := New(WithShape(mvmt.shapeA...), AsFortran(mvmt.a))
   200  		b := New(WithShape(mvmt.shapeB...), AsFortran(mvmt.b))
   201  
   202  		if mvmt.transA {
   203  			if err := a.T(); err != nil {
   204  				t.Error(err)
   205  				continue
   206  			}
   207  		}
   208  
   209  		T, err := a.MatVecMul(b)
   210  		if checkErr(t, mvmt.err, err, "Safe", i) {
   211  			continue
   212  		}
   213  
   214  		assert.True(mvmt.correctShape.Eq(T.Shape()))
   215  		assert.True(T.DataOrder().IsColMajor())
   216  		assert.Equal(mvmt.correct, T.Data())
   217  
   218  		// incr
   219  		incr := New(WithShape(mvmt.shapeI...), AsFortran(mvmt.incr))
   220  		T, err = a.MatVecMul(b, WithIncr(incr))
   221  		if checkErr(t, mvmt.errIncr, err, "WithIncr", i) {
   222  			continue
   223  		}
   224  
   225  		assert.True(mvmt.correctShape.Eq(T.Shape()))
   226  		assert.True(T.DataOrder().IsColMajor())
   227  		assert.Equal(mvmt.correctIncr, T.Data())
   228  
   229  		// reuse
   230  		reuse := New(WithShape(mvmt.shapeR...), AsFortran(mvmt.reuse))
   231  		T, err = a.MatVecMul(b, WithReuse(reuse))
   232  		if checkErr(t, mvmt.errReuse, err, "WithReuse", i) {
   233  			continue
   234  		}
   235  
   236  		assert.True(mvmt.correctShape.Eq(T.Shape()))
   237  		assert.True(T.DataOrder().IsColMajor())
   238  		assert.Equal(mvmt.correct, T.Data())
   239  
   240  		// reuse AND incr
   241  		T, err = a.MatVecMul(b, WithIncr(incr), WithReuse(reuse))
   242  		if checkErr(t, mvmt.err, err, "WithReuse and WithIncr", i) {
   243  			continue
   244  		}
   245  		assert.True(mvmt.correctShape.Eq(T.Shape()))
   246  		assert.True(T.DataOrder().IsColMajor())
   247  		assert.Equal(mvmt.correctIncrReuse, T.Data())
   248  	}
   249  }
   250  
   251  var colMajorMatMulTests = []linalgTest{
   252  	// Float64s
   253  	{Range(Float64, 0, 6), Range(Float64, 0, 6), Shape{2, 3}, Shape{3, 2}, false, false,
   254  		Range(Float64, 52, 56), Range(Float64, 100, 104), Shape{2, 2}, Shape{2, 2},
   255  		[]float64{10, 28, 13, 40}, []float64{110, 130, 114, 143}, []float64{120, 158, 127, 183}, Shape{2, 2}, false, false, false},
   256  
   257  	// Float32s
   258  	{Range(Float32, 0, 6), Range(Float32, 0, 6), Shape{2, 3}, Shape{3, 2}, false, false,
   259  		Range(Float32, 52, 56), Range(Float32, 100, 104), Shape{2, 2}, Shape{2, 2},
   260  		[]float32{10, 28, 13, 40}, []float32{110, 130, 114, 143}, []float32{120, 158, 127, 183}, Shape{2, 2}, false, false, false},
   261  
   262  	// Edge cases - Row Vecs (Float64)
   263  	{Range(Float64, 0, 2), Range(Float64, 0, 3), Shape{2, 1}, Shape{1, 3}, false, false,
   264  		Range(Float64, 10, 16), Range(Float64, 100, 106), Shape{2, 3}, Shape{2, 3},
   265  		[]float64{0, 0, 0, 1, 0, 2}, []float64{100, 103, 101, 105, 102, 107}, []float64{100, 103, 101, 106, 102, 109}, Shape{2, 3}, false, false, false},
   266  	{Range(Float64, 0, 2), Range(Float64, 0, 6), Shape{1, 2}, Shape{2, 3}, false, false,
   267  		Range(Float64, 10, 13), Range(Float64, 100, 103), Shape{1, 3}, Shape{1, 3},
   268  		[]float64{3, 4, 5}, []float64{103, 105, 107}, []float64{106, 109, 112}, Shape{1, 3}, false, false, false},
   269  	{Range(Float64, 0, 2), Range(Float64, 0, 2), Shape{1, 2}, Shape{2, 1}, false, false,
   270  		Range(Float64, 0, 1), Range(Float64, 100, 101), Shape{1, 1}, Shape{1, 1},
   271  		[]float64{1}, []float64{101}, []float64{102}, Shape{1, 1}, false, false, false},
   272  
   273  	// Edge cases - Row Vecs (Float32)
   274  	{Range(Float32, 0, 2), Range(Float32, 0, 3), Shape{2, 1}, Shape{1, 3}, false, false,
   275  		Range(Float32, 10, 16), Range(Float32, 100, 106), Shape{2, 3}, Shape{2, 3},
   276  		[]float32{0, 0, 0, 1, 0, 2}, []float32{100, 103, 101, 105, 102, 107}, []float32{100, 103, 101, 106, 102, 109}, Shape{2, 3}, false, false, false},
   277  	{Range(Float32, 0, 2), Range(Float32, 0, 6), Shape{1, 2}, Shape{2, 3}, false, false,
   278  		Range(Float32, 10, 13), Range(Float32, 100, 103), Shape{1, 3}, Shape{1, 3},
   279  		[]float32{3, 4, 5}, []float32{103, 105, 107}, []float32{106, 109, 112}, Shape{1, 3}, false, false, false},
   280  	{Range(Float32, 0, 2), Range(Float32, 0, 2), Shape{1, 2}, Shape{2, 1}, false, false,
   281  		Range(Float32, 0, 1), Range(Float32, 100, 101), Shape{1, 1}, Shape{1, 1},
   282  		[]float32{1}, []float32{101}, []float32{102}, Shape{1, 1}, false, false, false},
   283  
   284  	// stupids - bad shape (not matrices):
   285  	{Range(Float64, 0, 6), Range(Float64, 0, 6), Shape{2, 3}, Shape{6}, false, false,
   286  		Range(Float64, 52, 56), Range(Float64, 100, 104), Shape{2, 2}, Shape{2, 2},
   287  		[]float64{10, 28, 13, 40}, []float64{110, 130, 114, 143}, []float64{120, 158, 127, 183}, Shape{2, 2}, true, false, false},
   288  
   289  	// stupids - bad shape (incompatible shapes):
   290  	{Range(Float64, 0, 6), Range(Float64, 0, 6), Shape{2, 3}, Shape{6, 1}, false, false,
   291  		Range(Float64, 52, 56), Range(Float64, 100, 104), Shape{2, 2}, Shape{2, 2},
   292  		[]float64{10, 28, 13, 40}, []float64{110, 130, 114, 143}, []float64{120, 158, 127, 183}, Shape{2, 2}, true, false, false},
   293  
   294  	// stupids - bad shape (bad reuse shape):
   295  	{Range(Float64, 0, 6), Range(Float64, 0, 6), Shape{2, 3}, Shape{3, 2}, false, false,
   296  		Range(Float64, 52, 57), Range(Float64, 100, 104), Shape{5}, Shape{2, 2},
   297  		[]float64{10, 28, 13, 40}, []float64{110, 130, 114, 143}, []float64{120, 158, 127, 183}, Shape{2, 2}, false, false, true},
   298  
   299  	// stupids - bad shape (bad incr shape):
   300  	{Range(Float64, 0, 6), Range(Float64, 0, 6), Shape{2, 3}, Shape{3, 2}, false, false,
   301  		Range(Float64, 52, 56), Range(Float64, 100, 104), Shape{2, 2}, Shape{4},
   302  		[]float64{10, 28, 13, 40}, []float64{110, 130, 114, 143}, []float64{120, 158, 127, 183}, Shape{2, 2}, false, true, false},
   303  
   304  	// stupids - type mismatch (a,b)
   305  	{Range(Float64, 0, 6), Range(Float32, 0, 6), Shape{2, 3}, Shape{3, 2}, false, false,
   306  		Range(Float64, 52, 56), Range(Float64, 100, 104), Shape{2, 2}, Shape{2, 2},
   307  		[]float64{10, 28, 13, 40}, []float64{110, 130, 114, 143}, []float64{120, 158, 127, 183}, Shape{2, 2}, true, false, false},
   308  
   309  	// stupids - type mismatch (a,b)
   310  	{Range(Float32, 0, 6), Range(Float64, 0, 6), Shape{2, 3}, Shape{3, 2}, false, false,
   311  		Range(Float64, 52, 56), Range(Float64, 100, 104), Shape{2, 2}, Shape{2, 2},
   312  		[]float64{10, 28, 13, 40}, []float64{110, 130, 114, 143}, []float64{120, 158, 127, 183}, Shape{2, 2}, true, false, false},
   313  
   314  	// stupids type mismatch (b not float)
   315  	{Range(Float64, 0, 6), Range(Int, 0, 6), Shape{2, 3}, Shape{3, 2}, false, false,
   316  		Range(Float64, 52, 56), Range(Float64, 100, 104), Shape{2, 2}, Shape{2, 2},
   317  		[]float64{10, 28, 13, 40}, []float64{110, 130, 114, 143}, []float64{120, 158, 127, 183}, Shape{2, 2}, true, false, false},
   318  
   319  	// stupids type mismatch (a not float)
   320  	{Range(Int, 0, 6), Range(Int, 0, 6), Shape{2, 3}, Shape{3, 2}, false, false,
   321  		Range(Float64, 52, 56), Range(Float64, 100, 104), Shape{2, 2}, Shape{2, 2},
   322  		[]float64{10, 28, 13, 40}, []float64{110, 130, 114, 143}, []float64{120, 158, 127, 183}, Shape{2, 2}, true, false, false},
   323  
   324  	// stupids: type mismatch (incr)
   325  	{Range(Float64, 0, 6), Range(Float64, 0, 6), Shape{2, 3}, Shape{3, 2}, false, false,
   326  		Range(Float64, 52, 56), Range(Float32, 100, 104), Shape{2, 2}, Shape{2, 2},
   327  		[]float64{10, 28, 13, 40}, []float64{110, 130, 114, 143}, []float64{120, 158, 127, 183}, Shape{2, 2}, false, true, false},
   328  
   329  	// stupids: type mismatch (reuse)
   330  	{Range(Float64, 0, 6), Range(Float64, 0, 6), Shape{2, 3}, Shape{3, 2}, false, false,
   331  		Range(Float32, 52, 56), Range(Float64, 100, 104), Shape{2, 2}, Shape{2, 2},
   332  		[]float64{10, 28, 13, 40}, []float64{110, 130, 114, 143}, []float64{120, 158, 127, 183}, Shape{2, 2}, false, false, true},
   333  
   334  	// stupids: type mismatch (reuse)
   335  	{Range(Float32, 0, 6), Range(Float32, 0, 6), Shape{2, 3}, Shape{3, 2}, false, false,
   336  		Range(Float64, 52, 56), Range(Float32, 100, 104), Shape{2, 2}, Shape{2, 2},
   337  		[]float32{10, 28, 13, 40}, []float32{110, 130, 114, 143}, []float32{120, 158, 127, 183}, Shape{2, 2}, false, false, true},
   338  }
   339  
   340  func TestColMajorDense_MatMul(t *testing.T) {
   341  	assert := assert.New(t)
   342  	for i, mmt := range colMajorMatMulTests {
   343  		a := New(WithShape(mmt.shapeA...), AsFortran(mmt.a))
   344  		b := New(WithShape(mmt.shapeB...), AsFortran(mmt.b))
   345  
   346  		T, err := a.MatMul(b)
   347  		if checkErr(t, mmt.err, err, "Safe", i) {
   348  			continue
   349  		}
   350  		assert.True(mmt.correctShape.Eq(T.Shape()))
   351  		assert.True(T.DataOrder().IsColMajor())
   352  		assert.Equal(mmt.correct, T.Data(), "Test %d", i)
   353  
   354  		// incr
   355  		incr := New(WithShape(mmt.shapeI...), AsFortran(mmt.incr))
   356  		T, err = a.MatMul(b, WithIncr(incr))
   357  		if checkErr(t, mmt.errIncr, err, "WithIncr", i) {
   358  			continue
   359  		}
   360  		assert.True(mmt.correctShape.Eq(T.Shape()))
   361  		assert.Equal(mmt.correctIncr, T.Data())
   362  
   363  		// reuse
   364  		reuse := New(WithShape(mmt.shapeR...), AsFortran(mmt.reuse))
   365  		T, err = a.MatMul(b, WithReuse(reuse))
   366  
   367  		if checkErr(t, mmt.errReuse, err, "WithReuse", i) {
   368  			continue
   369  		}
   370  		assert.True(mmt.correctShape.Eq(T.Shape()))
   371  		assert.Equal(mmt.correct, T.Data())
   372  
   373  		// reuse AND incr
   374  		T, err = a.MatMul(b, WithIncr(incr), WithReuse(reuse))
   375  		if checkErr(t, mmt.err, err, "WithIncr and WithReuse", i) {
   376  			continue
   377  		}
   378  		assert.True(mmt.correctShape.Eq(T.Shape()))
   379  		assert.Equal(mmt.correctIncrReuse, T.Data())
   380  	}
   381  }
   382  
   383  var colMajorOuterTests = []linalgTest{
   384  	// Float64s
   385  	{Range(Float64, 0, 3), Range(Float64, 0, 3), Shape{3}, Shape{3}, false, false,
   386  		Range(Float64, 52, 61), Range(Float64, 100, 109), Shape{3, 3}, Shape{3, 3},
   387  		[]float64{0, 0, 0, 0, 1, 2, 0, 2, 4}, []float64{100, 103, 106, 101, 105, 109, 102, 107, 112}, []float64{100, 103, 106, 101, 106, 111, 102, 109, 116}, Shape{3, 3},
   388  		false, false, false},
   389  
   390  	// Float32s
   391  	{Range(Float32, 0, 3), Range(Float32, 0, 3), Shape{3}, Shape{3}, false, false,
   392  		Range(Float32, 52, 61), Range(Float32, 100, 109), Shape{3, 3}, Shape{3, 3},
   393  		[]float32{0, 0, 0, 0, 1, 2, 0, 2, 4}, []float32{100, 103, 106, 101, 105, 109, 102, 107, 112}, []float32{100, 103, 106, 101, 106, 111, 102, 109, 116}, Shape{3, 3},
   394  		false, false, false},
   395  
   396  	// stupids - a or b not vector
   397  	{Range(Float64, 0, 3), Range(Float64, 0, 6), Shape{3}, Shape{3, 2}, false, false,
   398  		Range(Float64, 52, 61), Range(Float64, 100, 109), Shape{3, 3}, Shape{3, 3},
   399  		[]float64{0, 0, 0, 0, 1, 2, 0, 2, 4}, []float64{100, 103, 106, 101, 105, 109, 102, 107, 112}, []float64{100, 103, 106, 101, 106, 111, 102, 109, 116}, Shape{3, 3},
   400  		true, false, false},
   401  
   402  	//	stupids - bad incr shape
   403  	{Range(Float64, 0, 3), Range(Float64, 0, 3), Shape{3}, Shape{3}, false, false,
   404  		Range(Float64, 52, 61), Range(Float64, 100, 106), Shape{3, 3}, Shape{3, 2},
   405  		[]float64{0, 0, 0, 0, 1, 2, 0, 2, 4}, []float64{100, 103, 106, 101, 105, 109, 102, 107, 112}, []float64{100, 103, 106, 101, 106, 111, 102, 109, 116}, Shape{3, 3},
   406  		false, true, false},
   407  
   408  	// stupids - bad reuse shape
   409  	{Range(Float64, 0, 3), Range(Float64, 0, 3), Shape{3}, Shape{3}, false, false,
   410  		Range(Float64, 52, 58), Range(Float64, 100, 109), Shape{3, 2}, Shape{3, 3},
   411  		[]float64{0, 0, 0, 0, 1, 2, 0, 2, 4}, []float64{100, 103, 106, 101, 105, 109, 102, 107, 112}, []float64{100, 103, 106, 101, 106, 111, 102, 109, 116}, Shape{3, 3},
   412  		false, false, true},
   413  
   414  	// stupids - b not Float
   415  	{Range(Float64, 0, 3), Range(Int, 0, 3), Shape{3}, Shape{3}, false, false,
   416  		Range(Float64, 52, 61), Range(Float64, 100, 109), Shape{3, 3}, Shape{3, 3},
   417  		[]float64{0, 0, 0, 0, 1, 2, 0, 2, 4}, []float64{100, 103, 106, 101, 105, 109, 102, 107, 112}, []float64{100, 103, 106, 101, 106, 111, 102, 109, 116}, Shape{3, 3},
   418  		true, false, false},
   419  
   420  	// stupids - a not Float
   421  	{Range(Int, 0, 3), Range(Float64, 0, 3), Shape{3}, Shape{3}, false, false,
   422  		Range(Float64, 52, 61), Range(Float64, 100, 109), Shape{3, 3}, Shape{3, 3},
   423  		[]float64{0, 0, 0, 0, 1, 2, 0, 2, 4}, []float64{100, 103, 106, 101, 105, 109, 102, 107, 112}, []float64{100, 103, 106, 101, 106, 111, 102, 109, 116}, Shape{3, 3},
   424  		true, false, false},
   425  
   426  	// stupids - a-b type mismatch
   427  	{Range(Float64, 0, 3), Range(Float32, 0, 3), Shape{3}, Shape{3}, false, false,
   428  		Range(Float64, 52, 61), Range(Float64, 100, 109), Shape{3, 3}, Shape{3, 3},
   429  		[]float64{0, 0, 0, 0, 1, 2, 0, 2, 4}, []float64{100, 103, 106, 101, 105, 109, 102, 107, 112}, []float64{100, 103, 106, 101, 106, 111, 102, 109, 116}, Shape{3, 3},
   430  		true, false, false},
   431  
   432  	// stupids a-b type mismatch
   433  	{Range(Float32, 0, 3), Range(Float64, 0, 3), Shape{3}, Shape{3}, false, false,
   434  		Range(Float64, 52, 61), Range(Float64, 100, 109), Shape{3, 3}, Shape{3, 3},
   435  		[]float64{0, 0, 0, 0, 1, 2, 0, 2, 4}, []float64{100, 103, 106, 101, 105, 109, 102, 107, 112}, []float64{100, 103, 106, 101, 106, 111, 102, 109, 116}, Shape{3, 3},
   436  		true, false, false},
   437  }
   438  
   439  func TestColMajor_Dense_Outer(t *testing.T) {
   440  	assert := assert.New(t)
   441  	for i, ot := range colMajorOuterTests {
   442  		a := New(WithShape(ot.shapeA...), AsFortran(ot.a))
   443  		b := New(WithShape(ot.shapeB...), AsFortran(ot.b))
   444  
   445  		T, err := a.Outer(b)
   446  		if checkErr(t, ot.err, err, "Safe", i) {
   447  			continue
   448  		}
   449  		assert.True(ot.correctShape.Eq(T.Shape()))
   450  		assert.True(T.DataOrder().IsColMajor())
   451  		assert.Equal(ot.correct, T.Data())
   452  
   453  		// incr
   454  		incr := New(WithShape(ot.shapeI...), AsFortran(ot.incr))
   455  		T, err = a.Outer(b, WithIncr(incr))
   456  		if checkErr(t, ot.errIncr, err, "WithIncr", i) {
   457  			continue
   458  		}
   459  		assert.True(ot.correctShape.Eq(T.Shape()))
   460  		assert.True(T.DataOrder().IsColMajor())
   461  		assert.Equal(ot.correctIncr, T.Data())
   462  
   463  		// reuse
   464  		reuse := New(WithShape(ot.shapeR...), AsFortran(ot.reuse))
   465  		T, err = a.Outer(b, WithReuse(reuse))
   466  		if checkErr(t, ot.errReuse, err, "WithReuse", i) {
   467  			continue
   468  		}
   469  		assert.True(ot.correctShape.Eq(T.Shape()))
   470  		assert.True(T.DataOrder().IsColMajor())
   471  		assert.Equal(ot.correct, T.Data())
   472  
   473  		// reuse AND incr
   474  		T, err = a.Outer(b, WithIncr(incr), WithReuse(reuse))
   475  		if err != nil {
   476  			t.Errorf("Reuse and Incr error'd %+v", err)
   477  			continue
   478  		}
   479  		assert.True(ot.correctShape.Eq(T.Shape()))
   480  		assert.True(T.DataOrder().IsColMajor())
   481  		assert.Equal(ot.correctIncrReuse, T.Data())
   482  	}
   483  }