github.com/wzzhu/tensor@v0.9.24/dense_colmajor_linalg_test.go (about) 1 package tensor 2 3 import ( 4 "testing" 5 6 "github.com/stretchr/testify/assert" 7 ) 8 9 var colMajorTraceTests = []struct { 10 data interface{} 11 12 correct interface{} 13 err bool 14 }{ 15 {[]int{0, 1, 2, 3, 4, 5}, int(4), false}, 16 {[]int8{0, 1, 2, 3, 4, 5}, int8(4), false}, 17 {[]int16{0, 1, 2, 3, 4, 5}, int16(4), false}, 18 {[]int32{0, 1, 2, 3, 4, 5}, int32(4), false}, 19 {[]int64{0, 1, 2, 3, 4, 5}, int64(4), false}, 20 {[]uint{0, 1, 2, 3, 4, 5}, uint(4), false}, 21 {[]uint8{0, 1, 2, 3, 4, 5}, uint8(4), false}, 22 {[]uint16{0, 1, 2, 3, 4, 5}, uint16(4), false}, 23 {[]uint32{0, 1, 2, 3, 4, 5}, uint32(4), false}, 24 {[]uint64{0, 1, 2, 3, 4, 5}, uint64(4), false}, 25 {[]float32{0, 1, 2, 3, 4, 5}, float32(4), false}, 26 {[]float64{0, 1, 2, 3, 4, 5}, float64(4), false}, 27 {[]complex64{0, 1, 2, 3, 4, 5}, complex64(4), false}, 28 {[]complex128{0, 1, 2, 3, 4, 5}, complex128(4), false}, 29 {[]bool{true, false, true, false, true, false}, nil, true}, 30 } 31 32 func TestColMajor_Dense_Trace(t *testing.T) { 33 assert := assert.New(t) 34 for i, tts := range colMajorTraceTests { 35 T := New(WithShape(2, 3), AsFortran(tts.data)) 36 trace, err := T.Trace() 37 38 if checkErr(t, tts.err, err, "Trace", i) { 39 continue 40 } 41 assert.Equal(tts.correct, trace) 42 43 // 44 T = New(WithBacking(tts.data)) 45 _, err = T.Trace() 46 if err == nil { 47 t.Error("Expected an error when Trace() on non-matrices") 48 } 49 } 50 } 51 52 var colMajorInnerTests = []struct { 53 a, b interface{} 54 shapeA, shapeB Shape 55 56 correct interface{} 57 err bool 58 }{ 59 {Range(Float64, 0, 3), Range(Float64, 0, 3), Shape{3}, Shape{3}, float64(5), false}, 60 {Range(Float64, 0, 3), Range(Float64, 0, 3), Shape{3, 1}, Shape{3}, float64(5), false}, 61 {Range(Float64, 0, 3), Range(Float64, 0, 3), Shape{1, 3}, Shape{3}, float64(5), false}, 62 {Range(Float64, 0, 3), Range(Float64, 0, 3), Shape{3, 1}, Shape{3, 1}, float64(5), false}, 63 {Range(Float64, 0, 3), Range(Float64, 0, 3), Shape{1, 3}, Shape{3, 1}, float64(5), false}, 64 {Range(Float64, 0, 3), Range(Float64, 0, 3), Shape{3, 1}, Shape{1, 3}, float64(5), false}, 65 {Range(Float64, 0, 3), Range(Float64, 0, 3), Shape{1, 3}, Shape{1, 3}, float64(5), false}, 66 67 {Range(Float32, 0, 3), Range(Float32, 0, 3), Shape{3}, Shape{3}, float32(5), false}, 68 {Range(Float32, 0, 3), Range(Float32, 0, 3), Shape{3, 1}, Shape{3}, float32(5), false}, 69 {Range(Float32, 0, 3), Range(Float32, 0, 3), Shape{1, 3}, Shape{3}, float32(5), false}, 70 {Range(Float32, 0, 3), Range(Float32, 0, 3), Shape{3, 1}, Shape{3, 1}, float32(5), false}, 71 {Range(Float32, 0, 3), Range(Float32, 0, 3), Shape{1, 3}, Shape{3, 1}, float32(5), false}, 72 {Range(Float32, 0, 3), Range(Float32, 0, 3), Shape{3, 1}, Shape{1, 3}, float32(5), false}, 73 {Range(Float32, 0, 3), Range(Float32, 0, 3), Shape{1, 3}, Shape{1, 3}, float32(5), false}, 74 75 // stupids: type differences 76 {Range(Int, 0, 3), Range(Float64, 0, 3), Shape{3}, Shape{3}, nil, true}, 77 {Range(Float32, 0, 3), Range(Byte, 0, 3), Shape{3}, Shape{3}, nil, true}, 78 {Range(Float64, 0, 3), Range(Float32, 0, 3), Shape{3}, Shape{3}, nil, true}, 79 {Range(Float32, 0, 3), Range(Float64, 0, 3), Shape{3}, Shape{3}, nil, true}, 80 81 // differing size 82 {Range(Float64, 0, 4), Range(Float64, 0, 3), Shape{4}, Shape{3}, nil, true}, 83 84 // A is not a matrix 85 {Range(Float64, 0, 4), Range(Float64, 0, 3), Shape{2, 2}, Shape{3}, nil, true}, 86 } 87 88 func TestColMajor_Dense_Inner(t *testing.T) { 89 for i, its := range colMajorInnerTests { 90 a := New(WithShape(its.shapeA...), AsFortran(its.a)) 91 b := New(WithShape(its.shapeB...), AsFortran(its.b)) 92 93 T, err := a.Inner(b) 94 if checkErr(t, its.err, err, "Inner", i) { 95 continue 96 } 97 98 assert.Equal(t, its.correct, T) 99 } 100 } 101 102 var colMajorMatVecMulTests = []linalgTest{ 103 // Float64s 104 {Range(Float64, 0, 6), Range(Float64, 0, 3), Shape{2, 3}, Shape{3}, false, false, 105 Range(Float64, 52, 54), Range(Float64, 100, 102), Shape{2}, Shape{2}, 106 []float64{5, 14}, []float64{105, 115}, []float64{110, 129}, Shape{2}, false, false, false}, 107 {Range(Float64, 0, 6), Range(Float64, 0, 3), Shape{2, 3}, Shape{3, 1}, false, false, 108 Range(Float64, 52, 54), Range(Float64, 100, 102), Shape{2}, Shape{2}, 109 []float64{5, 14}, []float64{105, 115}, []float64{110, 129}, Shape{2}, false, false, false}, 110 {Range(Float64, 0, 6), Range(Float64, 0, 3), Shape{2, 3}, Shape{1, 3}, false, false, 111 Range(Float64, 52, 54), Range(Float64, 100, 102), Shape{2}, Shape{2}, 112 []float64{5, 14}, []float64{105, 115}, []float64{110, 129}, Shape{2}, false, false, false}, 113 114 // float64s with transposed matrix 115 {Range(Float64, 0, 6), Range(Float64, 0, 2), Shape{2, 3}, Shape{2}, true, false, 116 Range(Float64, 52, 55), Range(Float64, 100, 103), Shape{3}, Shape{3}, 117 []float64{3, 4, 5}, []float64{103, 105, 107}, []float64{106, 109, 112}, Shape{3}, false, false, false}, 118 119 // Float32s 120 {Range(Float32, 0, 6), Range(Float32, 0, 3), Shape{2, 3}, Shape{3}, false, false, 121 Range(Float32, 52, 54), Range(Float32, 100, 102), Shape{2}, Shape{2}, 122 []float32{5, 14}, []float32{105, 115}, []float32{110, 129}, Shape{2}, false, false, false}, 123 {Range(Float32, 0, 6), Range(Float32, 0, 3), Shape{2, 3}, Shape{3, 1}, false, false, 124 Range(Float32, 52, 54), Range(Float32, 100, 102), Shape{2}, Shape{2}, 125 []float32{5, 14}, []float32{105, 115}, []float32{110, 129}, Shape{2}, false, false, false}, 126 {Range(Float32, 0, 6), Range(Float32, 0, 3), Shape{2, 3}, Shape{1, 3}, false, false, 127 Range(Float32, 52, 54), Range(Float32, 100, 102), Shape{2}, Shape{2}, 128 []float32{5, 14}, []float32{105, 115}, []float32{110, 129}, Shape{2}, false, false, false}, 129 130 // stupids : unpossible shapes (wrong A) 131 {Range(Float64, 0, 6), Range(Float64, 0, 3), Shape{6}, Shape{3}, false, false, 132 Range(Float64, 52, 54), Range(Float64, 100, 102), Shape{2}, Shape{2}, 133 []float64{5, 14}, []float64{105, 115}, []float64{110, 129}, Shape{2}, true, false, false}, 134 135 //stupids: bad A shape 136 {Range(Float64, 0, 8), Range(Float64, 0, 3), Shape{4, 2}, Shape{3}, false, false, 137 Range(Float64, 52, 54), Range(Float64, 100, 102), Shape{2}, Shape{2}, 138 []float64{5, 14}, []float64{105, 115}, []float64{110, 129}, Shape{2}, true, false, false}, 139 140 //stupids: bad B shape 141 {Range(Float64, 0, 6), Range(Float64, 0, 6), Shape{2, 3}, Shape{3, 2}, false, false, 142 Range(Float64, 52, 54), Range(Float64, 100, 102), Shape{2}, Shape{2}, 143 []float64{5, 14}, []float64{105, 115}, []float64{110, 129}, Shape{2}, true, false, false}, 144 145 //stupids: bad reuse 146 {Range(Float64, 0, 6), Range(Float64, 0, 3), Shape{2, 3}, Shape{3}, false, false, 147 Range(Float64, 52, 55), Range(Float64, 100, 102), Shape{3}, Shape{2}, 148 []float64{5, 14}, []float64{105, 115}, []float64{110, 129}, Shape{2}, false, false, true}, 149 150 //stupids: bad incr shape 151 {Range(Float64, 0, 6), Range(Float64, 0, 3), Shape{2, 3}, Shape{3}, false, false, 152 Range(Float64, 52, 54), Range(Float64, 100, 105), Shape{2}, Shape{5}, 153 []float64{5, 14}, []float64{105, 115}, []float64{110, 129}, Shape{2}, false, true, false}, 154 155 // stupids: type mismatch A and B 156 {Range(Float64, 0, 6), Range(Float32, 0, 3), Shape{2, 3}, Shape{3}, false, false, 157 Range(Float64, 52, 54), Range(Float64, 100, 103), Shape{2}, Shape{3}, 158 []float64{5, 14}, []float64{105, 115}, []float64{110, 129}, Shape{2}, true, false, false}, 159 160 // stupids: type mismatch A and B 161 {Range(Float32, 0, 6), Range(Float64, 0, 3), Shape{2, 3}, Shape{3}, false, false, 162 Range(Float64, 52, 54), Range(Float64, 100, 103), Shape{2}, Shape{3}, 163 []float64{5, 14}, []float64{105, 115}, []float64{110, 129}, Shape{2}, true, false, false}, 164 165 // stupids: type mismatch A and B 166 {Range(Float64, 0, 6), Range(Float32, 0, 3), Shape{2, 3}, Shape{3}, false, false, 167 Range(Float64, 52, 54), Range(Float64, 100, 103), Shape{2}, Shape{3}, 168 []float64{5, 14}, []float64{105, 115}, []float64{110, 129}, Shape{2}, true, false, false}, 169 170 // stupids: type mismatch A and B 171 {Range(Float32, 0, 6), Range(Float64, 0, 3), Shape{2, 3}, Shape{3}, false, false, 172 Range(Float64, 52, 54), Range(Float64, 100, 103), Shape{2}, Shape{3}, 173 []float64{5, 14}, []float64{105, 115}, []float64{110, 129}, Shape{2}, true, false, false}, 174 175 // stupids: type mismatch A and B (non-Float) 176 {Range(Float64, 0, 6), Range(Int, 0, 3), Shape{2, 3}, Shape{3}, false, false, 177 Range(Float64, 52, 54), Range(Float64, 100, 103), Shape{2}, Shape{3}, 178 []float64{5, 14}, []float64{105, 115}, []float64{110, 129}, Shape{2}, true, false, false}, 179 180 // stupids: type mismatch, reuse 181 {Range(Float64, 0, 6), Range(Float64, 0, 3), Shape{2, 3}, Shape{3}, false, false, 182 Range(Float32, 52, 54), Range(Float64, 100, 102), Shape{2}, Shape{2}, 183 []float64{5, 14}, []float64{105, 115}, []float64{110, 129}, Shape{2}, false, false, true}, 184 185 // stupids: type mismatch, incr 186 {Range(Float64, 0, 6), Range(Float64, 0, 3), Shape{2, 3}, Shape{3}, false, false, 187 Range(Float64, 52, 54), Range(Float32, 100, 103), Shape{2}, Shape{3}, 188 []float64{5, 14}, []float64{105, 115}, []float64{110, 129}, Shape{2}, false, true, false}, 189 190 // stupids: type mismatch, incr not a Number 191 {Range(Float64, 0, 6), Range(Float64, 0, 3), Shape{2, 3}, Shape{3}, false, false, 192 Range(Float64, 52, 54), []bool{true, true, true}, Shape{2}, Shape{3}, 193 []float64{5, 14}, []float64{105, 115}, []float64{110, 129}, Shape{2}, false, true, false}, 194 } 195 196 func TestColMajor_Dense_MatVecMul(t *testing.T) { 197 assert := assert.New(t) 198 for i, mvmt := range colMajorMatVecMulTests { 199 a := New(WithShape(mvmt.shapeA...), AsFortran(mvmt.a)) 200 b := New(WithShape(mvmt.shapeB...), AsFortran(mvmt.b)) 201 202 if mvmt.transA { 203 if err := a.T(); err != nil { 204 t.Error(err) 205 continue 206 } 207 } 208 209 T, err := a.MatVecMul(b) 210 if checkErr(t, mvmt.err, err, "Safe", i) { 211 continue 212 } 213 214 assert.True(mvmt.correctShape.Eq(T.Shape())) 215 assert.True(T.DataOrder().IsColMajor()) 216 assert.Equal(mvmt.correct, T.Data()) 217 218 // incr 219 incr := New(WithShape(mvmt.shapeI...), AsFortran(mvmt.incr)) 220 T, err = a.MatVecMul(b, WithIncr(incr)) 221 if checkErr(t, mvmt.errIncr, err, "WithIncr", i) { 222 continue 223 } 224 225 assert.True(mvmt.correctShape.Eq(T.Shape())) 226 assert.True(T.DataOrder().IsColMajor()) 227 assert.Equal(mvmt.correctIncr, T.Data()) 228 229 // reuse 230 reuse := New(WithShape(mvmt.shapeR...), AsFortran(mvmt.reuse)) 231 T, err = a.MatVecMul(b, WithReuse(reuse)) 232 if checkErr(t, mvmt.errReuse, err, "WithReuse", i) { 233 continue 234 } 235 236 assert.True(mvmt.correctShape.Eq(T.Shape())) 237 assert.True(T.DataOrder().IsColMajor()) 238 assert.Equal(mvmt.correct, T.Data()) 239 240 // reuse AND incr 241 T, err = a.MatVecMul(b, WithIncr(incr), WithReuse(reuse)) 242 if checkErr(t, mvmt.err, err, "WithReuse and WithIncr", i) { 243 continue 244 } 245 assert.True(mvmt.correctShape.Eq(T.Shape())) 246 assert.True(T.DataOrder().IsColMajor()) 247 assert.Equal(mvmt.correctIncrReuse, T.Data()) 248 } 249 } 250 251 var colMajorMatMulTests = []linalgTest{ 252 // Float64s 253 {Range(Float64, 0, 6), Range(Float64, 0, 6), Shape{2, 3}, Shape{3, 2}, false, false, 254 Range(Float64, 52, 56), Range(Float64, 100, 104), Shape{2, 2}, Shape{2, 2}, 255 []float64{10, 28, 13, 40}, []float64{110, 130, 114, 143}, []float64{120, 158, 127, 183}, Shape{2, 2}, false, false, false}, 256 257 // Float32s 258 {Range(Float32, 0, 6), Range(Float32, 0, 6), Shape{2, 3}, Shape{3, 2}, false, false, 259 Range(Float32, 52, 56), Range(Float32, 100, 104), Shape{2, 2}, Shape{2, 2}, 260 []float32{10, 28, 13, 40}, []float32{110, 130, 114, 143}, []float32{120, 158, 127, 183}, Shape{2, 2}, false, false, false}, 261 262 // Edge cases - Row Vecs (Float64) 263 {Range(Float64, 0, 2), Range(Float64, 0, 3), Shape{2, 1}, Shape{1, 3}, false, false, 264 Range(Float64, 10, 16), Range(Float64, 100, 106), Shape{2, 3}, Shape{2, 3}, 265 []float64{0, 0, 0, 1, 0, 2}, []float64{100, 103, 101, 105, 102, 107}, []float64{100, 103, 101, 106, 102, 109}, Shape{2, 3}, false, false, false}, 266 {Range(Float64, 0, 2), Range(Float64, 0, 6), Shape{1, 2}, Shape{2, 3}, false, false, 267 Range(Float64, 10, 13), Range(Float64, 100, 103), Shape{1, 3}, Shape{1, 3}, 268 []float64{3, 4, 5}, []float64{103, 105, 107}, []float64{106, 109, 112}, Shape{1, 3}, false, false, false}, 269 {Range(Float64, 0, 2), Range(Float64, 0, 2), Shape{1, 2}, Shape{2, 1}, false, false, 270 Range(Float64, 0, 1), Range(Float64, 100, 101), Shape{1, 1}, Shape{1, 1}, 271 []float64{1}, []float64{101}, []float64{102}, Shape{1, 1}, false, false, false}, 272 273 // Edge cases - Row Vecs (Float32) 274 {Range(Float32, 0, 2), Range(Float32, 0, 3), Shape{2, 1}, Shape{1, 3}, false, false, 275 Range(Float32, 10, 16), Range(Float32, 100, 106), Shape{2, 3}, Shape{2, 3}, 276 []float32{0, 0, 0, 1, 0, 2}, []float32{100, 103, 101, 105, 102, 107}, []float32{100, 103, 101, 106, 102, 109}, Shape{2, 3}, false, false, false}, 277 {Range(Float32, 0, 2), Range(Float32, 0, 6), Shape{1, 2}, Shape{2, 3}, false, false, 278 Range(Float32, 10, 13), Range(Float32, 100, 103), Shape{1, 3}, Shape{1, 3}, 279 []float32{3, 4, 5}, []float32{103, 105, 107}, []float32{106, 109, 112}, Shape{1, 3}, false, false, false}, 280 {Range(Float32, 0, 2), Range(Float32, 0, 2), Shape{1, 2}, Shape{2, 1}, false, false, 281 Range(Float32, 0, 1), Range(Float32, 100, 101), Shape{1, 1}, Shape{1, 1}, 282 []float32{1}, []float32{101}, []float32{102}, Shape{1, 1}, false, false, false}, 283 284 // stupids - bad shape (not matrices): 285 {Range(Float64, 0, 6), Range(Float64, 0, 6), Shape{2, 3}, Shape{6}, false, false, 286 Range(Float64, 52, 56), Range(Float64, 100, 104), Shape{2, 2}, Shape{2, 2}, 287 []float64{10, 28, 13, 40}, []float64{110, 130, 114, 143}, []float64{120, 158, 127, 183}, Shape{2, 2}, true, false, false}, 288 289 // stupids - bad shape (incompatible shapes): 290 {Range(Float64, 0, 6), Range(Float64, 0, 6), Shape{2, 3}, Shape{6, 1}, false, false, 291 Range(Float64, 52, 56), Range(Float64, 100, 104), Shape{2, 2}, Shape{2, 2}, 292 []float64{10, 28, 13, 40}, []float64{110, 130, 114, 143}, []float64{120, 158, 127, 183}, Shape{2, 2}, true, false, false}, 293 294 // stupids - bad shape (bad reuse shape): 295 {Range(Float64, 0, 6), Range(Float64, 0, 6), Shape{2, 3}, Shape{3, 2}, false, false, 296 Range(Float64, 52, 57), Range(Float64, 100, 104), Shape{5}, Shape{2, 2}, 297 []float64{10, 28, 13, 40}, []float64{110, 130, 114, 143}, []float64{120, 158, 127, 183}, Shape{2, 2}, false, false, true}, 298 299 // stupids - bad shape (bad incr shape): 300 {Range(Float64, 0, 6), Range(Float64, 0, 6), Shape{2, 3}, Shape{3, 2}, false, false, 301 Range(Float64, 52, 56), Range(Float64, 100, 104), Shape{2, 2}, Shape{4}, 302 []float64{10, 28, 13, 40}, []float64{110, 130, 114, 143}, []float64{120, 158, 127, 183}, Shape{2, 2}, false, true, false}, 303 304 // stupids - type mismatch (a,b) 305 {Range(Float64, 0, 6), Range(Float32, 0, 6), Shape{2, 3}, Shape{3, 2}, false, false, 306 Range(Float64, 52, 56), Range(Float64, 100, 104), Shape{2, 2}, Shape{2, 2}, 307 []float64{10, 28, 13, 40}, []float64{110, 130, 114, 143}, []float64{120, 158, 127, 183}, Shape{2, 2}, true, false, false}, 308 309 // stupids - type mismatch (a,b) 310 {Range(Float32, 0, 6), Range(Float64, 0, 6), Shape{2, 3}, Shape{3, 2}, false, false, 311 Range(Float64, 52, 56), Range(Float64, 100, 104), Shape{2, 2}, Shape{2, 2}, 312 []float64{10, 28, 13, 40}, []float64{110, 130, 114, 143}, []float64{120, 158, 127, 183}, Shape{2, 2}, true, false, false}, 313 314 // stupids type mismatch (b not float) 315 {Range(Float64, 0, 6), Range(Int, 0, 6), Shape{2, 3}, Shape{3, 2}, false, false, 316 Range(Float64, 52, 56), Range(Float64, 100, 104), Shape{2, 2}, Shape{2, 2}, 317 []float64{10, 28, 13, 40}, []float64{110, 130, 114, 143}, []float64{120, 158, 127, 183}, Shape{2, 2}, true, false, false}, 318 319 // stupids type mismatch (a not float) 320 {Range(Int, 0, 6), Range(Int, 0, 6), Shape{2, 3}, Shape{3, 2}, false, false, 321 Range(Float64, 52, 56), Range(Float64, 100, 104), Shape{2, 2}, Shape{2, 2}, 322 []float64{10, 28, 13, 40}, []float64{110, 130, 114, 143}, []float64{120, 158, 127, 183}, Shape{2, 2}, true, false, false}, 323 324 // stupids: type mismatch (incr) 325 {Range(Float64, 0, 6), Range(Float64, 0, 6), Shape{2, 3}, Shape{3, 2}, false, false, 326 Range(Float64, 52, 56), Range(Float32, 100, 104), Shape{2, 2}, Shape{2, 2}, 327 []float64{10, 28, 13, 40}, []float64{110, 130, 114, 143}, []float64{120, 158, 127, 183}, Shape{2, 2}, false, true, false}, 328 329 // stupids: type mismatch (reuse) 330 {Range(Float64, 0, 6), Range(Float64, 0, 6), Shape{2, 3}, Shape{3, 2}, false, false, 331 Range(Float32, 52, 56), Range(Float64, 100, 104), Shape{2, 2}, Shape{2, 2}, 332 []float64{10, 28, 13, 40}, []float64{110, 130, 114, 143}, []float64{120, 158, 127, 183}, Shape{2, 2}, false, false, true}, 333 334 // stupids: type mismatch (reuse) 335 {Range(Float32, 0, 6), Range(Float32, 0, 6), Shape{2, 3}, Shape{3, 2}, false, false, 336 Range(Float64, 52, 56), Range(Float32, 100, 104), Shape{2, 2}, Shape{2, 2}, 337 []float32{10, 28, 13, 40}, []float32{110, 130, 114, 143}, []float32{120, 158, 127, 183}, Shape{2, 2}, false, false, true}, 338 } 339 340 func TestColMajorDense_MatMul(t *testing.T) { 341 assert := assert.New(t) 342 for i, mmt := range colMajorMatMulTests { 343 a := New(WithShape(mmt.shapeA...), AsFortran(mmt.a)) 344 b := New(WithShape(mmt.shapeB...), AsFortran(mmt.b)) 345 346 T, err := a.MatMul(b) 347 if checkErr(t, mmt.err, err, "Safe", i) { 348 continue 349 } 350 assert.True(mmt.correctShape.Eq(T.Shape())) 351 assert.True(T.DataOrder().IsColMajor()) 352 assert.Equal(mmt.correct, T.Data(), "Test %d", i) 353 354 // incr 355 incr := New(WithShape(mmt.shapeI...), AsFortran(mmt.incr)) 356 T, err = a.MatMul(b, WithIncr(incr)) 357 if checkErr(t, mmt.errIncr, err, "WithIncr", i) { 358 continue 359 } 360 assert.True(mmt.correctShape.Eq(T.Shape())) 361 assert.Equal(mmt.correctIncr, T.Data()) 362 363 // reuse 364 reuse := New(WithShape(mmt.shapeR...), AsFortran(mmt.reuse)) 365 T, err = a.MatMul(b, WithReuse(reuse)) 366 367 if checkErr(t, mmt.errReuse, err, "WithReuse", i) { 368 continue 369 } 370 assert.True(mmt.correctShape.Eq(T.Shape())) 371 assert.Equal(mmt.correct, T.Data()) 372 373 // reuse AND incr 374 T, err = a.MatMul(b, WithIncr(incr), WithReuse(reuse)) 375 if checkErr(t, mmt.err, err, "WithIncr and WithReuse", i) { 376 continue 377 } 378 assert.True(mmt.correctShape.Eq(T.Shape())) 379 assert.Equal(mmt.correctIncrReuse, T.Data()) 380 } 381 } 382 383 var colMajorOuterTests = []linalgTest{ 384 // Float64s 385 {Range(Float64, 0, 3), Range(Float64, 0, 3), Shape{3}, Shape{3}, false, false, 386 Range(Float64, 52, 61), Range(Float64, 100, 109), Shape{3, 3}, Shape{3, 3}, 387 []float64{0, 0, 0, 0, 1, 2, 0, 2, 4}, []float64{100, 103, 106, 101, 105, 109, 102, 107, 112}, []float64{100, 103, 106, 101, 106, 111, 102, 109, 116}, Shape{3, 3}, 388 false, false, false}, 389 390 // Float32s 391 {Range(Float32, 0, 3), Range(Float32, 0, 3), Shape{3}, Shape{3}, false, false, 392 Range(Float32, 52, 61), Range(Float32, 100, 109), Shape{3, 3}, Shape{3, 3}, 393 []float32{0, 0, 0, 0, 1, 2, 0, 2, 4}, []float32{100, 103, 106, 101, 105, 109, 102, 107, 112}, []float32{100, 103, 106, 101, 106, 111, 102, 109, 116}, Shape{3, 3}, 394 false, false, false}, 395 396 // stupids - a or b not vector 397 {Range(Float64, 0, 3), Range(Float64, 0, 6), Shape{3}, Shape{3, 2}, false, false, 398 Range(Float64, 52, 61), Range(Float64, 100, 109), Shape{3, 3}, Shape{3, 3}, 399 []float64{0, 0, 0, 0, 1, 2, 0, 2, 4}, []float64{100, 103, 106, 101, 105, 109, 102, 107, 112}, []float64{100, 103, 106, 101, 106, 111, 102, 109, 116}, Shape{3, 3}, 400 true, false, false}, 401 402 // stupids - bad incr shape 403 {Range(Float64, 0, 3), Range(Float64, 0, 3), Shape{3}, Shape{3}, false, false, 404 Range(Float64, 52, 61), Range(Float64, 100, 106), Shape{3, 3}, Shape{3, 2}, 405 []float64{0, 0, 0, 0, 1, 2, 0, 2, 4}, []float64{100, 103, 106, 101, 105, 109, 102, 107, 112}, []float64{100, 103, 106, 101, 106, 111, 102, 109, 116}, Shape{3, 3}, 406 false, true, false}, 407 408 // stupids - bad reuse shape 409 {Range(Float64, 0, 3), Range(Float64, 0, 3), Shape{3}, Shape{3}, false, false, 410 Range(Float64, 52, 58), Range(Float64, 100, 109), Shape{3, 2}, Shape{3, 3}, 411 []float64{0, 0, 0, 0, 1, 2, 0, 2, 4}, []float64{100, 103, 106, 101, 105, 109, 102, 107, 112}, []float64{100, 103, 106, 101, 106, 111, 102, 109, 116}, Shape{3, 3}, 412 false, false, true}, 413 414 // stupids - b not Float 415 {Range(Float64, 0, 3), Range(Int, 0, 3), Shape{3}, Shape{3}, false, false, 416 Range(Float64, 52, 61), Range(Float64, 100, 109), Shape{3, 3}, Shape{3, 3}, 417 []float64{0, 0, 0, 0, 1, 2, 0, 2, 4}, []float64{100, 103, 106, 101, 105, 109, 102, 107, 112}, []float64{100, 103, 106, 101, 106, 111, 102, 109, 116}, Shape{3, 3}, 418 true, false, false}, 419 420 // stupids - a not Float 421 {Range(Int, 0, 3), Range(Float64, 0, 3), Shape{3}, Shape{3}, false, false, 422 Range(Float64, 52, 61), Range(Float64, 100, 109), Shape{3, 3}, Shape{3, 3}, 423 []float64{0, 0, 0, 0, 1, 2, 0, 2, 4}, []float64{100, 103, 106, 101, 105, 109, 102, 107, 112}, []float64{100, 103, 106, 101, 106, 111, 102, 109, 116}, Shape{3, 3}, 424 true, false, false}, 425 426 // stupids - a-b type mismatch 427 {Range(Float64, 0, 3), Range(Float32, 0, 3), Shape{3}, Shape{3}, false, false, 428 Range(Float64, 52, 61), Range(Float64, 100, 109), Shape{3, 3}, Shape{3, 3}, 429 []float64{0, 0, 0, 0, 1, 2, 0, 2, 4}, []float64{100, 103, 106, 101, 105, 109, 102, 107, 112}, []float64{100, 103, 106, 101, 106, 111, 102, 109, 116}, Shape{3, 3}, 430 true, false, false}, 431 432 // stupids a-b type mismatch 433 {Range(Float32, 0, 3), Range(Float64, 0, 3), Shape{3}, Shape{3}, false, false, 434 Range(Float64, 52, 61), Range(Float64, 100, 109), Shape{3, 3}, Shape{3, 3}, 435 []float64{0, 0, 0, 0, 1, 2, 0, 2, 4}, []float64{100, 103, 106, 101, 105, 109, 102, 107, 112}, []float64{100, 103, 106, 101, 106, 111, 102, 109, 116}, Shape{3, 3}, 436 true, false, false}, 437 } 438 439 func TestColMajor_Dense_Outer(t *testing.T) { 440 assert := assert.New(t) 441 for i, ot := range colMajorOuterTests { 442 a := New(WithShape(ot.shapeA...), AsFortran(ot.a)) 443 b := New(WithShape(ot.shapeB...), AsFortran(ot.b)) 444 445 T, err := a.Outer(b) 446 if checkErr(t, ot.err, err, "Safe", i) { 447 continue 448 } 449 assert.True(ot.correctShape.Eq(T.Shape())) 450 assert.True(T.DataOrder().IsColMajor()) 451 assert.Equal(ot.correct, T.Data()) 452 453 // incr 454 incr := New(WithShape(ot.shapeI...), AsFortran(ot.incr)) 455 T, err = a.Outer(b, WithIncr(incr)) 456 if checkErr(t, ot.errIncr, err, "WithIncr", i) { 457 continue 458 } 459 assert.True(ot.correctShape.Eq(T.Shape())) 460 assert.True(T.DataOrder().IsColMajor()) 461 assert.Equal(ot.correctIncr, T.Data()) 462 463 // reuse 464 reuse := New(WithShape(ot.shapeR...), AsFortran(ot.reuse)) 465 T, err = a.Outer(b, WithReuse(reuse)) 466 if checkErr(t, ot.errReuse, err, "WithReuse", i) { 467 continue 468 } 469 assert.True(ot.correctShape.Eq(T.Shape())) 470 assert.True(T.DataOrder().IsColMajor()) 471 assert.Equal(ot.correct, T.Data()) 472 473 // reuse AND incr 474 T, err = a.Outer(b, WithIncr(incr), WithReuse(reuse)) 475 if err != nil { 476 t.Errorf("Reuse and Incr error'd %+v", err) 477 continue 478 } 479 assert.True(ot.correctShape.Eq(T.Shape())) 480 assert.True(T.DataOrder().IsColMajor()) 481 assert.Equal(ot.correctIncrReuse, T.Data()) 482 } 483 }