github.com/wzzhu/tensor@v0.9.24/dense_linalg_test.go (about) 1 package tensor 2 3 import ( 4 "testing" 5 6 "github.com/stretchr/testify/assert" 7 "gorgonia.org/vecf64" 8 ) 9 10 type linalgTest struct { 11 a, b interface{} 12 shapeA, shapeB Shape 13 transA, transB bool 14 15 reuse, incr interface{} 16 shapeR, shapeI Shape 17 18 correct interface{} 19 correctIncr interface{} 20 correctIncrReuse interface{} 21 correctShape Shape 22 err bool 23 errIncr bool 24 errReuse bool 25 } 26 27 var traceTests = []struct { 28 data interface{} 29 30 correct interface{} 31 err bool 32 }{ 33 {[]int{0, 1, 2, 3, 4, 5}, int(4), false}, 34 {[]int8{0, 1, 2, 3, 4, 5}, int8(4), false}, 35 {[]int16{0, 1, 2, 3, 4, 5}, int16(4), false}, 36 {[]int32{0, 1, 2, 3, 4, 5}, int32(4), false}, 37 {[]int64{0, 1, 2, 3, 4, 5}, int64(4), false}, 38 {[]uint{0, 1, 2, 3, 4, 5}, uint(4), false}, 39 {[]uint8{0, 1, 2, 3, 4, 5}, uint8(4), false}, 40 {[]uint16{0, 1, 2, 3, 4, 5}, uint16(4), false}, 41 {[]uint32{0, 1, 2, 3, 4, 5}, uint32(4), false}, 42 {[]uint64{0, 1, 2, 3, 4, 5}, uint64(4), false}, 43 {[]float32{0, 1, 2, 3, 4, 5}, float32(4), false}, 44 {[]float64{0, 1, 2, 3, 4, 5}, float64(4), false}, 45 {[]complex64{0, 1, 2, 3, 4, 5}, complex64(4), false}, 46 {[]complex128{0, 1, 2, 3, 4, 5}, complex128(4), false}, 47 {[]bool{true, false, true, false, true, false}, nil, true}, 48 } 49 50 func TestDense_Trace(t *testing.T) { 51 assert := assert.New(t) 52 for i, tts := range traceTests { 53 T := New(WithBacking(tts.data), WithShape(2, 3)) 54 trace, err := T.Trace() 55 56 if checkErr(t, tts.err, err, "Trace", i) { 57 continue 58 } 59 assert.Equal(tts.correct, trace) 60 61 // 62 T = New(WithBacking(tts.data)) 63 _, err = T.Trace() 64 if err == nil { 65 t.Error("Expected an error when Trace() on non-matrices") 66 } 67 } 68 } 69 70 var innerTests = []struct { 71 a, b interface{} 72 shapeA, shapeB Shape 73 74 correct interface{} 75 err bool 76 }{ 77 {Range(Float64, 0, 3), Range(Float64, 0, 3), Shape{3}, Shape{3}, float64(5), false}, 78 {Range(Float64, 0, 3), Range(Float64, 0, 3), Shape{3, 1}, Shape{3}, float64(5), false}, 79 {Range(Float64, 0, 3), Range(Float64, 0, 3), Shape{1, 3}, Shape{3}, float64(5), false}, 80 {Range(Float64, 0, 3), Range(Float64, 0, 3), Shape{3, 1}, Shape{3, 1}, float64(5), false}, 81 {Range(Float64, 0, 3), Range(Float64, 0, 3), Shape{1, 3}, Shape{3, 1}, float64(5), false}, 82 {Range(Float64, 0, 3), Range(Float64, 0, 3), Shape{3, 1}, Shape{1, 3}, float64(5), false}, 83 {Range(Float64, 0, 3), Range(Float64, 0, 3), Shape{1, 3}, Shape{1, 3}, float64(5), false}, 84 85 {Range(Float32, 0, 3), Range(Float32, 0, 3), Shape{3}, Shape{3}, float32(5), false}, 86 {Range(Float32, 0, 3), Range(Float32, 0, 3), Shape{3, 1}, Shape{3}, float32(5), false}, 87 {Range(Float32, 0, 3), Range(Float32, 0, 3), Shape{1, 3}, Shape{3}, float32(5), false}, 88 {Range(Float32, 0, 3), Range(Float32, 0, 3), Shape{3, 1}, Shape{3, 1}, float32(5), false}, 89 {Range(Float32, 0, 3), Range(Float32, 0, 3), Shape{1, 3}, Shape{3, 1}, float32(5), false}, 90 {Range(Float32, 0, 3), Range(Float32, 0, 3), Shape{3, 1}, Shape{1, 3}, float32(5), false}, 91 {Range(Float32, 0, 3), Range(Float32, 0, 3), Shape{1, 3}, Shape{1, 3}, float32(5), false}, 92 93 // stupids: type differences 94 {Range(Int, 0, 3), Range(Float64, 0, 3), Shape{3}, Shape{3}, nil, true}, 95 {Range(Float32, 0, 3), Range(Byte, 0, 3), Shape{3}, Shape{3}, nil, true}, 96 {Range(Float64, 0, 3), Range(Float32, 0, 3), Shape{3}, Shape{3}, nil, true}, 97 {Range(Float32, 0, 3), Range(Float64, 0, 3), Shape{3}, Shape{3}, nil, true}, 98 99 // differing size 100 {Range(Float64, 0, 4), Range(Float64, 0, 3), Shape{4}, Shape{3}, nil, true}, 101 102 // A is not a matrix 103 {Range(Float64, 0, 4), Range(Float64, 0, 3), Shape{2, 2}, Shape{3}, nil, true}, 104 } 105 106 func TestDense_Inner(t *testing.T) { 107 for i, its := range innerTests { 108 a := New(WithShape(its.shapeA...), WithBacking(its.a)) 109 b := New(WithShape(its.shapeB...), WithBacking(its.b)) 110 111 T, err := a.Inner(b) 112 if checkErr(t, its.err, err, "Inner", i) { 113 continue 114 } 115 116 assert.Equal(t, its.correct, T) 117 } 118 } 119 120 var matVecMulTests = []linalgTest{ 121 // Float64s 122 {Range(Float64, 0, 6), Range(Float64, 0, 3), Shape{2, 3}, Shape{3}, false, false, 123 Range(Float64, 52, 54), Range(Float64, 100, 102), Shape{2}, Shape{2}, 124 []float64{5, 14}, []float64{105, 115}, []float64{110, 129}, Shape{2}, false, false, false}, 125 {Range(Float64, 0, 6), Range(Float64, 0, 3), Shape{2, 3}, Shape{3, 1}, false, false, 126 Range(Float64, 52, 54), Range(Float64, 100, 102), Shape{2}, Shape{2}, 127 []float64{5, 14}, []float64{105, 115}, []float64{110, 129}, Shape{2}, false, false, false}, 128 {Range(Float64, 0, 6), Range(Float64, 0, 3), Shape{2, 3}, Shape{1, 3}, false, false, 129 Range(Float64, 52, 54), Range(Float64, 100, 102), Shape{2}, Shape{2}, 130 []float64{5, 14}, []float64{105, 115}, []float64{110, 129}, Shape{2}, false, false, false}, 131 132 // float64s with transposed matrix 133 {Range(Float64, 0, 6), Range(Float64, 0, 2), Shape{2, 3}, Shape{2}, true, false, 134 Range(Float64, 52, 55), Range(Float64, 100, 103), Shape{3}, Shape{3}, 135 []float64{3, 4, 5}, []float64{103, 105, 107}, []float64{106, 109, 112}, Shape{3}, false, false, false}, 136 137 // Float32s 138 {Range(Float32, 0, 6), Range(Float32, 0, 3), Shape{2, 3}, Shape{3}, false, false, 139 Range(Float32, 52, 54), Range(Float32, 100, 102), Shape{2}, Shape{2}, 140 []float32{5, 14}, []float32{105, 115}, []float32{110, 129}, Shape{2}, false, false, false}, 141 {Range(Float32, 0, 6), Range(Float32, 0, 3), Shape{2, 3}, Shape{3, 1}, false, false, 142 Range(Float32, 52, 54), Range(Float32, 100, 102), Shape{2}, Shape{2}, 143 []float32{5, 14}, []float32{105, 115}, []float32{110, 129}, Shape{2}, false, false, false}, 144 {Range(Float32, 0, 6), Range(Float32, 0, 3), Shape{2, 3}, Shape{1, 3}, false, false, 145 Range(Float32, 52, 54), Range(Float32, 100, 102), Shape{2}, Shape{2}, 146 []float32{5, 14}, []float32{105, 115}, []float32{110, 129}, Shape{2}, false, false, false}, 147 148 // stupids : unpossible shapes (wrong A) 149 {Range(Float64, 0, 6), Range(Float64, 0, 3), Shape{6}, Shape{3}, false, false, 150 Range(Float64, 52, 54), Range(Float64, 100, 102), Shape{2}, Shape{2}, 151 []float64{5, 14}, []float64{105, 115}, []float64{110, 129}, Shape{2}, true, false, false}, 152 153 //stupids: bad A shape 154 {Range(Float64, 0, 8), Range(Float64, 0, 3), Shape{4, 2}, Shape{3}, false, false, 155 Range(Float64, 52, 54), Range(Float64, 100, 102), Shape{2}, Shape{2}, 156 []float64{5, 14}, []float64{105, 115}, []float64{110, 129}, Shape{2}, true, false, false}, 157 158 //stupids: bad B shape 159 {Range(Float64, 0, 6), Range(Float64, 0, 6), Shape{2, 3}, Shape{3, 2}, false, false, 160 Range(Float64, 52, 54), Range(Float64, 100, 102), Shape{2}, Shape{2}, 161 []float64{5, 14}, []float64{105, 115}, []float64{110, 129}, Shape{2}, true, false, false}, 162 163 //stupids: bad reuse 164 {Range(Float64, 0, 6), Range(Float64, 0, 3), Shape{2, 3}, Shape{3}, false, false, 165 Range(Float64, 52, 55), Range(Float64, 100, 102), Shape{3}, Shape{2}, 166 []float64{5, 14}, []float64{105, 115}, []float64{110, 129}, Shape{2}, false, false, true}, 167 168 //stupids: bad incr shape 169 {Range(Float64, 0, 6), Range(Float64, 0, 3), Shape{2, 3}, Shape{3}, false, false, 170 Range(Float64, 52, 54), Range(Float64, 100, 105), Shape{2}, Shape{5}, 171 []float64{5, 14}, []float64{105, 115}, []float64{110, 129}, Shape{2}, false, true, false}, 172 173 // stupids: type mismatch A and B 174 {Range(Float64, 0, 6), Range(Float32, 0, 3), Shape{2, 3}, Shape{3}, false, false, 175 Range(Float64, 52, 54), Range(Float64, 100, 103), Shape{2}, Shape{3}, 176 []float64{5, 14}, []float64{105, 115}, []float64{110, 129}, Shape{2}, true, false, false}, 177 178 // stupids: type mismatch A and B 179 {Range(Float32, 0, 6), Range(Float64, 0, 3), Shape{2, 3}, Shape{3}, false, false, 180 Range(Float64, 52, 54), Range(Float64, 100, 103), Shape{2}, Shape{3}, 181 []float64{5, 14}, []float64{105, 115}, []float64{110, 129}, Shape{2}, true, false, false}, 182 183 // stupids: type mismatch A and B 184 {Range(Float64, 0, 6), Range(Float32, 0, 3), Shape{2, 3}, Shape{3}, false, false, 185 Range(Float64, 52, 54), Range(Float64, 100, 103), Shape{2}, Shape{3}, 186 []float64{5, 14}, []float64{105, 115}, []float64{110, 129}, Shape{2}, true, false, false}, 187 188 // stupids: type mismatch A and B 189 {Range(Float32, 0, 6), Range(Float64, 0, 3), Shape{2, 3}, Shape{3}, false, false, 190 Range(Float64, 52, 54), Range(Float64, 100, 103), Shape{2}, Shape{3}, 191 []float64{5, 14}, []float64{105, 115}, []float64{110, 129}, Shape{2}, true, false, false}, 192 193 // stupids: type mismatch A and B (non-Float) 194 {Range(Float64, 0, 6), Range(Int, 0, 3), Shape{2, 3}, Shape{3}, false, false, 195 Range(Float64, 52, 54), Range(Float64, 100, 103), Shape{2}, Shape{3}, 196 []float64{5, 14}, []float64{105, 115}, []float64{110, 129}, Shape{2}, true, false, false}, 197 198 // stupids: type mismatch, reuse 199 {Range(Float64, 0, 6), Range(Float64, 0, 3), Shape{2, 3}, Shape{3}, false, false, 200 Range(Float32, 52, 54), Range(Float64, 100, 102), Shape{2}, Shape{2}, 201 []float64{5, 14}, []float64{105, 115}, []float64{110, 129}, Shape{2}, false, false, true}, 202 203 // stupids: type mismatch, incr 204 {Range(Float64, 0, 6), Range(Float64, 0, 3), Shape{2, 3}, Shape{3}, false, false, 205 Range(Float64, 52, 54), Range(Float32, 100, 103), Shape{2}, Shape{3}, 206 []float64{5, 14}, []float64{105, 115}, []float64{110, 129}, Shape{2}, false, true, false}, 207 208 // stupids: type mismatch, incr not a Number 209 {Range(Float64, 0, 6), Range(Float64, 0, 3), Shape{2, 3}, Shape{3}, false, false, 210 Range(Float64, 52, 54), []bool{true, true, true}, Shape{2}, Shape{3}, 211 []float64{5, 14}, []float64{105, 115}, []float64{110, 129}, Shape{2}, false, true, false}, 212 } 213 214 func TestDense_MatVecMul(t *testing.T) { 215 assert := assert.New(t) 216 for i, mvmt := range matVecMulTests { 217 a := New(WithBacking(mvmt.a), WithShape(mvmt.shapeA...)) 218 b := New(WithBacking(mvmt.b), WithShape(mvmt.shapeB...)) 219 220 if mvmt.transA { 221 if err := a.T(); err != nil { 222 t.Error(err) 223 continue 224 } 225 } 226 T, err := a.MatVecMul(b) 227 if checkErr(t, mvmt.err, err, "Safe", i) { 228 continue 229 } 230 231 assert.True(mvmt.correctShape.Eq(T.Shape())) 232 assert.True(T.DataOrder().IsRowMajor()) 233 assert.Equal(mvmt.correct, T.Data()) 234 235 // incr 236 incr := New(WithBacking(mvmt.incr), WithShape(mvmt.shapeI...)) 237 T, err = a.MatVecMul(b, WithIncr(incr)) 238 if checkErr(t, mvmt.errIncr, err, "WithIncr", i) { 239 continue 240 } 241 242 assert.True(mvmt.correctShape.Eq(T.Shape())) 243 assert.True(T.DataOrder().IsRowMajor()) 244 assert.Equal(mvmt.correctIncr, T.Data()) 245 246 // reuse 247 reuse := New(WithBacking(mvmt.reuse), WithShape(mvmt.shapeR...)) 248 T, err = a.MatVecMul(b, WithReuse(reuse)) 249 if checkErr(t, mvmt.errReuse, err, "WithReuse", i) { 250 continue 251 } 252 253 assert.True(mvmt.correctShape.Eq(T.Shape())) 254 assert.True(T.DataOrder().IsRowMajor()) 255 assert.Equal(mvmt.correct, T.Data()) 256 257 // reuse AND incr 258 T, err = a.MatVecMul(b, WithIncr(incr), WithReuse(reuse)) 259 if checkErr(t, mvmt.err, err, "WithReuse and WithIncr", i) { 260 continue 261 } 262 assert.True(mvmt.correctShape.Eq(T.Shape())) 263 assert.Equal(mvmt.correctIncrReuse, T.Data()) 264 } 265 } 266 267 var matMulTests = []linalgTest{ 268 // Float64s 269 {Range(Float64, 0, 6), Range(Float64, 0, 6), Shape{2, 3}, Shape{3, 2}, false, false, 270 Range(Float64, 52, 56), Range(Float64, 100, 104), Shape{2, 2}, Shape{2, 2}, 271 []float64{10, 13, 28, 40}, []float64{110, 114, 130, 143}, []float64{120, 127, 158, 183}, Shape{2, 2}, false, false, false}, 272 273 // Float32s 274 {Range(Float32, 0, 6), Range(Float32, 0, 6), Shape{2, 3}, Shape{3, 2}, false, false, 275 Range(Float32, 52, 56), Range(Float32, 100, 104), Shape{2, 2}, Shape{2, 2}, 276 []float32{10, 13, 28, 40}, []float32{110, 114, 130, 143}, []float32{120, 127, 158, 183}, Shape{2, 2}, false, false, false}, 277 278 // Edge cases - Row Vecs (Float64) 279 {Range(Float64, 0, 2), Range(Float64, 0, 3), Shape{2, 1}, Shape{1, 3}, false, false, 280 Range(Float64, 10, 16), Range(Float64, 100, 106), Shape{2, 3}, Shape{2, 3}, 281 []float64{0, 0, 0, 0, 1, 2}, []float64{100, 101, 102, 103, 105, 107}, []float64{100, 101, 102, 103, 106, 109}, Shape{2, 3}, false, false, false}, 282 {Range(Float64, 0, 2), Range(Float64, 0, 6), Shape{1, 2}, Shape{2, 3}, false, false, 283 Range(Float64, 10, 13), Range(Float64, 100, 103), Shape{1, 3}, Shape{1, 3}, 284 []float64{3, 4, 5}, []float64{103, 105, 107}, []float64{106, 109, 112}, Shape{1, 3}, false, false, false}, 285 {Range(Float64, 0, 2), Range(Float64, 0, 2), Shape{1, 2}, Shape{2, 1}, false, false, 286 Range(Float64, 0, 1), Range(Float64, 100, 101), Shape{1, 1}, Shape{1, 1}, 287 []float64{1}, []float64{101}, []float64{102}, Shape{1, 1}, false, false, false}, 288 289 // Edge cases - Row Vecs (Float32) 290 {Range(Float32, 0, 2), Range(Float32, 0, 3), Shape{2, 1}, Shape{1, 3}, false, false, 291 Range(Float32, 10, 16), Range(Float32, 100, 106), Shape{2, 3}, Shape{2, 3}, 292 []float32{0, 0, 0, 0, 1, 2}, []float32{100, 101, 102, 103, 105, 107}, []float32{100, 101, 102, 103, 106, 109}, Shape{2, 3}, false, false, false}, 293 {Range(Float32, 0, 2), Range(Float32, 0, 6), Shape{1, 2}, Shape{2, 3}, false, false, 294 Range(Float32, 10, 13), Range(Float32, 100, 103), Shape{1, 3}, Shape{1, 3}, 295 []float32{3, 4, 5}, []float32{103, 105, 107}, []float32{106, 109, 112}, Shape{1, 3}, false, false, false}, 296 {Range(Float32, 0, 2), Range(Float32, 0, 2), Shape{1, 2}, Shape{2, 1}, false, false, 297 Range(Float32, 0, 1), Range(Float32, 100, 101), Shape{1, 1}, Shape{1, 1}, 298 []float32{1}, []float32{101}, []float32{102}, Shape{1, 1}, false, false, false}, 299 300 // stupids - bad shape (not matrices): 301 {Range(Float64, 0, 6), Range(Float64, 0, 6), Shape{2, 3}, Shape{6}, false, false, 302 Range(Float64, 52, 56), Range(Float64, 100, 104), Shape{2, 2}, Shape{2, 2}, 303 []float64{10, 13, 28, 40}, []float64{110, 114, 130, 143}, []float64{120, 127, 158, 183}, Shape{2, 2}, true, false, false}, 304 305 // stupids - bad shape (incompatible shapes): 306 {Range(Float64, 0, 6), Range(Float64, 0, 6), Shape{2, 3}, Shape{6, 1}, false, false, 307 Range(Float64, 52, 56), Range(Float64, 100, 104), Shape{2, 2}, Shape{2, 2}, 308 []float64{10, 13, 28, 40}, []float64{110, 114, 130, 143}, []float64{120, 127, 158, 183}, Shape{2, 2}, true, false, false}, 309 310 // stupids - bad shape (bad reuse shape): 311 {Range(Float64, 0, 6), Range(Float64, 0, 6), Shape{2, 3}, Shape{3, 2}, false, false, 312 Range(Float64, 52, 57), Range(Float64, 100, 104), Shape{5}, Shape{2, 2}, 313 []float64{10, 13, 28, 40}, []float64{110, 114, 130, 143}, []float64{120, 127, 158, 183}, Shape{2, 2}, false, false, true}, 314 315 // stupids - bad shape (bad incr shape): 316 {Range(Float64, 0, 6), Range(Float64, 0, 6), Shape{2, 3}, Shape{3, 2}, false, false, 317 Range(Float64, 52, 56), Range(Float64, 100, 104), Shape{2, 2}, Shape{4}, 318 []float64{10, 13, 28, 40}, []float64{110, 114, 130, 143}, []float64{120, 127, 158, 183}, Shape{2, 2}, false, true, false}, 319 320 // stupids - type mismatch (a,b) 321 {Range(Float64, 0, 6), Range(Float32, 0, 6), Shape{2, 3}, Shape{3, 2}, false, false, 322 Range(Float64, 52, 56), Range(Float64, 100, 104), Shape{2, 2}, Shape{2, 2}, 323 []float64{10, 13, 28, 40}, []float64{110, 114, 130, 143}, []float64{120, 127, 158, 183}, Shape{2, 2}, true, false, false}, 324 325 // stupids - type mismatch (a,b) 326 {Range(Float32, 0, 6), Range(Float64, 0, 6), Shape{2, 3}, Shape{3, 2}, false, false, 327 Range(Float64, 52, 56), Range(Float64, 100, 104), Shape{2, 2}, Shape{2, 2}, 328 []float64{10, 13, 28, 40}, []float64{110, 114, 130, 143}, []float64{120, 127, 158, 183}, Shape{2, 2}, true, false, false}, 329 330 // stupids type mismatch (b not float) 331 {Range(Float64, 0, 6), Range(Int, 0, 6), Shape{2, 3}, Shape{3, 2}, false, false, 332 Range(Float64, 52, 56), Range(Float64, 100, 104), Shape{2, 2}, Shape{2, 2}, 333 []float64{10, 13, 28, 40}, []float64{110, 114, 130, 143}, []float64{120, 127, 158, 183}, Shape{2, 2}, true, false, false}, 334 335 // stupids type mismatch (a not float) 336 {Range(Int, 0, 6), Range(Int, 0, 6), Shape{2, 3}, Shape{3, 2}, false, false, 337 Range(Float64, 52, 56), Range(Float64, 100, 104), Shape{2, 2}, Shape{2, 2}, 338 []float64{10, 13, 28, 40}, []float64{110, 114, 130, 143}, []float64{120, 127, 158, 183}, Shape{2, 2}, true, false, false}, 339 340 // stupids: type mismatch (incr) 341 {Range(Float64, 0, 6), Range(Float64, 0, 6), Shape{2, 3}, Shape{3, 2}, false, false, 342 Range(Float64, 52, 56), Range(Float32, 100, 104), Shape{2, 2}, Shape{2, 2}, 343 []float64{10, 13, 28, 40}, []float64{110, 114, 130, 143}, []float64{120, 127, 158, 183}, Shape{2, 2}, false, true, false}, 344 345 // stupids: type mismatch (reuse) 346 {Range(Float64, 0, 6), Range(Float64, 0, 6), Shape{2, 3}, Shape{3, 2}, false, false, 347 Range(Float32, 52, 56), Range(Float64, 100, 104), Shape{2, 2}, Shape{2, 2}, 348 []float64{10, 13, 28, 40}, []float64{110, 114, 130, 143}, []float64{120, 127, 158, 183}, Shape{2, 2}, false, false, true}, 349 350 // stupids: type mismatch (reuse) 351 {Range(Float32, 0, 6), Range(Float32, 0, 6), Shape{2, 3}, Shape{3, 2}, false, false, 352 Range(Float64, 52, 56), Range(Float32, 100, 104), Shape{2, 2}, Shape{2, 2}, 353 []float32{10, 13, 28, 40}, []float32{110, 114, 130, 143}, []float32{120, 127, 158, 183}, Shape{2, 2}, false, false, true}, 354 } 355 356 func TestDense_MatMul(t *testing.T) { 357 assert := assert.New(t) 358 for i, mmt := range matMulTests { 359 a := New(WithBacking(mmt.a), WithShape(mmt.shapeA...)) 360 b := New(WithBacking(mmt.b), WithShape(mmt.shapeB...)) 361 362 T, err := a.MatMul(b) 363 if checkErr(t, mmt.err, err, "Safe", i) { 364 continue 365 } 366 assert.True(mmt.correctShape.Eq(T.Shape())) 367 assert.Equal(mmt.correct, T.Data()) 368 369 // incr 370 incr := New(WithBacking(mmt.incr), WithShape(mmt.shapeI...)) 371 T, err = a.MatMul(b, WithIncr(incr)) 372 if checkErr(t, mmt.errIncr, err, "WithIncr", i) { 373 continue 374 } 375 assert.True(mmt.correctShape.Eq(T.Shape())) 376 assert.Equal(mmt.correctIncr, T.Data()) 377 378 // reuse 379 reuse := New(WithBacking(mmt.reuse), WithShape(mmt.shapeR...)) 380 T, err = a.MatMul(b, WithReuse(reuse)) 381 382 if checkErr(t, mmt.errReuse, err, "WithReuse", i) { 383 continue 384 } 385 assert.True(mmt.correctShape.Eq(T.Shape())) 386 assert.Equal(mmt.correct, T.Data()) 387 388 // reuse AND incr 389 T, err = a.MatMul(b, WithIncr(incr), WithReuse(reuse)) 390 if checkErr(t, mmt.err, err, "WithIncr and WithReuse", i) { 391 continue 392 } 393 assert.True(mmt.correctShape.Eq(T.Shape())) 394 assert.Equal(mmt.correctIncrReuse, T.Data()) 395 } 396 } 397 398 var outerTests = []linalgTest{ 399 // Float64s 400 {Range(Float64, 0, 3), Range(Float64, 0, 3), Shape{3}, Shape{3}, false, false, 401 Range(Float64, 52, 61), Range(Float64, 100, 109), Shape{3, 3}, Shape{3, 3}, 402 []float64{0, 0, 0, 0, 1, 2, 0, 2, 4}, []float64{100, 101, 102, 103, 105, 107, 106, 109, 112}, []float64{100, 101, 102, 103, 106, 109, 106, 111, 116}, Shape{3, 3}, 403 false, false, false}, 404 405 // Float32s 406 {Range(Float32, 0, 3), Range(Float32, 0, 3), Shape{3}, Shape{3}, false, false, 407 Range(Float32, 52, 61), Range(Float32, 100, 109), Shape{3, 3}, Shape{3, 3}, 408 []float32{0, 0, 0, 0, 1, 2, 0, 2, 4}, []float32{100, 101, 102, 103, 105, 107, 106, 109, 112}, []float32{100, 101, 102, 103, 106, 109, 106, 111, 116}, Shape{3, 3}, 409 false, false, false}, 410 411 // stupids - a or b not vector 412 {Range(Float64, 0, 3), Range(Float64, 0, 6), Shape{3}, Shape{3, 2}, false, false, 413 Range(Float64, 52, 61), Range(Float64, 100, 109), Shape{3, 3}, Shape{3, 3}, 414 []float64{0, 0, 0, 0, 1, 2, 0, 2, 4}, []float64{100, 101, 102, 103, 105, 107, 106, 109, 112}, []float64{100, 101, 102, 103, 106, 109, 106, 111, 116}, Shape{3, 3}, 415 true, false, false}, 416 417 // stupids - bad incr shape 418 {Range(Float64, 0, 3), Range(Float64, 0, 3), Shape{3}, Shape{3}, false, false, 419 Range(Float64, 52, 61), Range(Float64, 100, 106), Shape{3, 3}, Shape{3, 2}, 420 []float64{0, 0, 0, 0, 1, 2, 0, 2, 4}, []float64{100, 101, 102, 103, 105, 107, 106, 109, 112}, []float64{100, 101, 102, 103, 106, 109, 106, 111, 116}, Shape{3, 3}, 421 false, true, false}, 422 423 // stupids - bad reuse shape 424 {Range(Float64, 0, 3), Range(Float64, 0, 3), Shape{3}, Shape{3}, false, false, 425 Range(Float64, 52, 58), Range(Float64, 100, 109), Shape{3, 2}, Shape{3, 3}, 426 []float64{0, 0, 0, 0, 1, 2, 0, 2, 4}, []float64{100, 101, 102, 103, 105, 107, 106, 109, 112}, []float64{100, 101, 102, 103, 106, 109, 106, 111, 116}, Shape{3, 3}, 427 false, false, true}, 428 429 // stupids - b not Float 430 {Range(Float64, 0, 3), Range(Int, 0, 3), Shape{3}, Shape{3}, false, false, 431 Range(Float64, 52, 61), Range(Float64, 100, 109), Shape{3, 3}, Shape{3, 3}, 432 []float64{0, 0, 0, 0, 1, 2, 0, 2, 4}, []float64{100, 101, 102, 103, 105, 107, 106, 109, 112}, []float64{100, 101, 102, 103, 106, 109, 106, 111, 116}, Shape{3, 3}, 433 true, false, false}, 434 435 // stupids - a not Float 436 {Range(Int, 0, 3), Range(Float64, 0, 3), Shape{3}, Shape{3}, false, false, 437 Range(Float64, 52, 61), Range(Float64, 100, 109), Shape{3, 3}, Shape{3, 3}, 438 []float64{0, 0, 0, 0, 1, 2, 0, 2, 4}, []float64{100, 101, 102, 103, 105, 107, 106, 109, 112}, []float64{100, 101, 102, 103, 106, 109, 106, 111, 116}, Shape{3, 3}, 439 true, false, false}, 440 441 // stupids - a-b type mismatch 442 {Range(Float64, 0, 3), Range(Float32, 0, 3), Shape{3}, Shape{3}, false, false, 443 Range(Float64, 52, 61), Range(Float64, 100, 109), Shape{3, 3}, Shape{3, 3}, 444 []float64{0, 0, 0, 0, 1, 2, 0, 2, 4}, []float64{100, 101, 102, 103, 105, 107, 106, 109, 112}, []float64{100, 101, 102, 103, 106, 109, 106, 111, 116}, Shape{3, 3}, 445 true, false, false}, 446 447 // stupids a-b type mismatch 448 {Range(Float32, 0, 3), Range(Float64, 0, 3), Shape{3}, Shape{3}, false, false, 449 Range(Float64, 52, 61), Range(Float64, 100, 109), Shape{3, 3}, Shape{3, 3}, 450 []float64{0, 0, 0, 0, 1, 2, 0, 2, 4}, []float64{100, 101, 102, 103, 105, 107, 106, 109, 112}, []float64{100, 101, 102, 103, 106, 109, 106, 111, 116}, Shape{3, 3}, 451 true, false, false}, 452 } 453 454 func TestDense_Outer(t *testing.T) { 455 assert := assert.New(t) 456 for i, ot := range outerTests { 457 a := New(WithBacking(ot.a), WithShape(ot.shapeA...)) 458 b := New(WithBacking(ot.b), WithShape(ot.shapeB...)) 459 460 T, err := a.Outer(b) 461 if checkErr(t, ot.err, err, "Safe", i) { 462 continue 463 } 464 assert.True(ot.correctShape.Eq(T.Shape())) 465 assert.Equal(ot.correct, T.Data()) 466 467 // incr 468 incr := New(WithBacking(ot.incr), WithShape(ot.shapeI...)) 469 T, err = a.Outer(b, WithIncr(incr)) 470 if checkErr(t, ot.errIncr, err, "WithIncr", i) { 471 continue 472 } 473 assert.True(ot.correctShape.Eq(T.Shape())) 474 assert.Equal(ot.correctIncr, T.Data()) 475 476 // reuse 477 reuse := New(WithBacking(ot.reuse), WithShape(ot.shapeR...)) 478 T, err = a.Outer(b, WithReuse(reuse)) 479 if checkErr(t, ot.errReuse, err, "WithReuse", i) { 480 continue 481 } 482 assert.True(ot.correctShape.Eq(T.Shape())) 483 assert.Equal(ot.correct, T.Data()) 484 485 // reuse AND incr 486 T, err = a.Outer(b, WithIncr(incr), WithReuse(reuse)) 487 if err != nil { 488 t.Errorf("Reuse and Incr error'd %+v", err) 489 continue 490 } 491 assert.True(ot.correctShape.Eq(T.Shape())) 492 assert.Equal(ot.correctIncrReuse, T.Data()) 493 } 494 } 495 496 var tensorMulTests = []struct { 497 a, b interface{} 498 shapeA, shapeB Shape 499 500 reuse, incr interface{} 501 shapeR, shapeI Shape 502 503 correct interface{} 504 correctIncr interface{} 505 correctIncrReuse interface{} 506 correctShape Shape 507 err bool 508 errIncr bool 509 errReuse bool 510 511 axesA, axesB []int 512 }{ 513 {a: Range(Float64, 0, 60), b: Range(Float64, 0, 24), shapeA: Shape{3, 4, 5}, shapeB: Shape{4, 3, 2}, 514 axesA: []int{1, 0}, axesB: []int{0, 1}, 515 correct: []float64{4400, 4730, 4532, 4874, 4664, 5018, 4796, 5162, 4928, 5306}, correctShape: Shape{5, 2}}, 516 } 517 518 func TestDense_TensorMul(t *testing.T) { 519 assert := assert.New(t) 520 for i, tmt := range tensorMulTests { 521 a := New(WithShape(tmt.shapeA...), WithBacking(tmt.a)) 522 b := New(WithShape(tmt.shapeB...), WithBacking(tmt.b)) 523 524 T, err := a.TensorMul(b, tmt.axesA, tmt.axesB) 525 if checkErr(t, tmt.err, err, "Safe", i) { 526 continue 527 } 528 assert.True(tmt.correctShape.Eq(T.Shape())) 529 assert.Equal(tmt.correct, T.Data()) 530 } 531 } 532 533 func TestDot(t *testing.T) { 534 assert := assert.New(t) 535 var a, b, c, r Tensor 536 var A, B, R, R2 Tensor 537 var s, s2 Tensor 538 var incr Tensor 539 var err error 540 var expectedShape Shape 541 var expectedData []float64 542 var expectedScalar float64 543 544 // vector-vector 545 t.Log("Vec⋅Vec") 546 a = New(Of(Float64), WithShape(3, 1), WithBacking(Range(Float64, 0, 3))) 547 b = New(Of(Float64), WithShape(3, 1), WithBacking(Range(Float64, 0, 3))) 548 r, err = Dot(a, b) 549 expectedShape = Shape{1} 550 expectedScalar = float64(5) 551 assert.Nil(err) 552 assert.Equal(expectedScalar, r.Data()) 553 assert.True(ScalarShape().Eq(r.Shape())) 554 555 // vector-mat (which is the same as matᵀ*vec) 556 t.Log("Vec⋅Mat dot, should be equal to Aᵀb") 557 A = New(Of(Float64), WithShape(3, 2), WithBacking(Range(Float64, 0, 6))) 558 R, err = Dot(b, A) 559 expectedShape = Shape{2} 560 expectedData = []float64{10, 13} 561 assert.Nil(err) 562 assert.Equal(expectedData, R.Data()) 563 assert.Equal(expectedShape, R.Shape()) 564 // mat-mat 565 t.Log("Mat⋅Mat") 566 A = New(Of(Float64), WithShape(4, 5), WithBacking(Range(Float64, 0, 20))) 567 B = New(Of(Float64), WithShape(5, 10), WithBacking(Range(Float64, 2, 52))) 568 R, err = Dot(A, B) 569 expectedShape = Shape{4, 10} 570 expectedData = []float64{ 571 320, 330, 340, 350, 360, 370, 380, 390, 400, 410, 870, 572 905, 940, 975, 1010, 1045, 1080, 1115, 1150, 1185, 1420, 1480, 573 1540, 1600, 1660, 1720, 1780, 1840, 1900, 1960, 1970, 2055, 2140, 574 2225, 2310, 2395, 2480, 2565, 2650, 2735, 575 } 576 assert.Nil(err) 577 assert.Equal(expectedData, R.Data()) 578 assert.Equal(expectedShape, R.Shape()) 579 // T-T 580 t.Log("3T⋅3T") 581 A = New(Of(Float64), WithShape(2, 3, 4), WithBacking(Range(Float64, 0, 24))) 582 B = New(Of(Float64), WithShape(3, 4, 2), WithBacking(Range(Float64, 0, 24))) 583 R, err = Dot(A, B) 584 expectedShape = Shape{2, 3, 3, 2} 585 expectedData = []float64{ 586 28, 34, 587 76, 82, 588 124, 130, 589 76, 98, 590 252, 274, 591 428, 450, 592 124, 162, 593 428, 466, 594 732, 770, 595 // 596 172, 226, 597 604, 658, 598 1036, 1090, 599 220, 290, 600 780, 850, 601 1340, 1410, 602 268, 354, 603 956, 1042, 604 1644, 1730, 605 } 606 assert.Nil(err) 607 assert.Equal(expectedData, R.Data()) 608 assert.Equal(expectedShape, R.Shape()) 609 610 // T-T 611 t.Log("3T⋅4T") 612 A = New(Of(Float64), WithShape(2, 3, 4), WithBacking(Range(Float64, 0, 24))) 613 B = New(Of(Float64), WithShape(2, 3, 4, 5), WithBacking(Range(Float64, 0, 120))) 614 R, err = Dot(A, B) 615 expectedShape = Shape{2, 3, 2, 3, 5} 616 expectedData = []float64{ 617 70, 76, 82, 88, 94, 190, 196, 202, 208, 214, 310, 618 316, 322, 328, 334, 430, 436, 442, 448, 454, 550, 556, 619 562, 568, 574, 670, 676, 682, 688, 694, 190, 212, 234, 620 256, 278, 630, 652, 674, 696, 718, 1070, 1092, 1114, 1136, 621 1158, 1510, 1532, 1554, 1576, 1598, 1950, 1972, 1994, 2016, 2038, 622 2390, 2412, 2434, 2456, 2478, 310, 348, 386, 424, 462, 1070, 623 1108, 1146, 1184, 1222, 1830, 1868, 1906, 1944, 1982, 2590, 2628, 624 2666, 2704, 2742, 3350, 3388, 3426, 3464, 3502, 4110, 4148, 4186, 625 4224, 4262, 430, 484, 538, 592, 646, 1510, 1564, 1618, 1672, 626 1726, 2590, 2644, 2698, 2752, 2806, 3670, 3724, 3778, 3832, 3886, 627 4750, 4804, 4858, 4912, 4966, 5830, 5884, 5938, 5992, 6046, 550, 628 620, 690, 760, 830, 1950, 2020, 2090, 2160, 2230, 3350, 3420, 629 3490, 3560, 3630, 4750, 4820, 4890, 4960, 5030, 6150, 6220, 6290, 630 6360, 6430, 7550, 7620, 7690, 7760, 7830, 670, 756, 842, 928, 631 1014, 2390, 2476, 2562, 2648, 2734, 4110, 4196, 4282, 4368, 4454, 632 5830, 5916, 6002, 6088, 6174, 7550, 7636, 7722, 7808, 7894, 9270, 633 9356, 9442, 9528, 9614, 634 } 635 assert.Nil(err) 636 assert.Equal(expectedData, R.Data()) 637 assert.Equal(expectedShape, R.Shape()) 638 // T-v 639 640 t.Log("3T⋅Vec") 641 b = New(Of(Float64), WithShape(4), WithBacking(Range(Float64, 0, 4))) 642 R, err = Dot(A, b) 643 expectedShape = Shape{2, 3} 644 expectedData = []float64{ 645 14, 38, 62, 646 86, 110, 134, 647 } 648 assert.Nil(err) 649 assert.Equal(expectedData, R.Data()) 650 assert.Equal(expectedShape, R.Shape()) 651 652 // v-T 653 t.Log("Vec⋅3T") 654 R2, err = Dot(b, B) 655 expectedShape = Shape{2, 3, 5} 656 expectedData = []float64{ 657 70, 76, 82, 88, 94, 658 190, 196, 202, 208, 214, 659 310, 316, 322, 328, 334, 660 430, 436, 442, 448, 454, 661 550, 556, 562, 568, 574, 662 670, 676, 682, 688, 694, 663 } 664 assert.Nil(err) 665 assert.Equal(expectedData, R2.Data()) 666 assert.Equal(expectedShape, R2.Shape()) 667 // m-3T 668 t.Log("Mat⋅3T") 669 A = New(Of(Float64), WithShape(2, 4), WithBacking(Range(Float64, 0, 8))) 670 B = New(Of(Float64), WithShape(2, 4, 5), WithBacking(Range(Float64, 0, 40))) 671 R, err = Dot(A, B) 672 expectedShape = Shape{2, 2, 5} 673 expectedData = []float64{ 674 70, 76, 82, 88, 94, 675 190, 196, 202, 208, 214, 676 190, 212, 234, 256, 278, 677 630, 652, 674, 696, 718, 678 } 679 assert.Nil(err) 680 assert.Equal(expectedData, R.Data()) 681 assert.Equal(expectedShape, R.Shape()) 682 // test reuse 683 // m-v with reuse 684 t.Log("Mat⋅Vec with reuse") 685 R = New(Of(Float64), WithShape(2)) 686 R2, err = Dot(A, b, WithReuse(R)) 687 expectedShape = Shape{2} 688 expectedData = []float64{14, 38} 689 assert.Nil(err) 690 assert.Equal(R, R2) 691 assert.Equal(expectedData, R.Data()) 692 assert.Equal(expectedShape, R.Shape()) 693 694 // 3T-vec with reuse 695 t.Logf("3T⋅vec with reuse") 696 R = New(Of(Float64), WithShape(6)) 697 A = New(Of(Float64), WithShape(2, 3, 4), WithBacking(Range(Float64, 0, 24))) 698 R2, err = Dot(A, b, WithReuse(R)) 699 expectedShape = Shape{2, 3} 700 expectedData = []float64{ 701 14, 38, 62, 702 86, 110, 134, 703 } 704 assert.Nil(err) 705 assert.Equal(R, R2) 706 assert.Equal(expectedData, R2.Data()) 707 assert.Equal(expectedShape, R2.Shape()) 708 // v-m 709 t.Log("vec⋅Mat with reuse") 710 R = New(Of(Float64), WithShape(2)) 711 a = New(Of(Float64), WithShape(4), WithBacking(Range(Float64, 0, 4))) 712 B = New(Of(Float64), WithShape(4, 2), WithBacking(Range(Float64, 0, 8))) 713 R2, err = Dot(a, B, WithReuse(R)) 714 expectedShape = Shape{2} 715 expectedData = []float64{28, 34} 716 assert.Nil(err) 717 assert.Equal(R, R2) 718 assert.Equal(expectedData, R.Data()) 719 assert.Equal(expectedShape, R.Shape()) 720 // test incr 721 incrBack := make([]float64, 2) 722 copy(incrBack, expectedData) 723 incr = New(Of(Float64), WithBacking(incrBack), WithShape(2)) 724 R, err = Dot(a, B, WithIncr(incr)) 725 vecf64.Scale(expectedData, 2) 726 assert.Nil(err) 727 assert.Equal(incr, R) 728 assert.Equal(expectedData, R.Data()) 729 assert.Equal(expectedShape, R.Shape()) 730 731 // The Nearly Stupids 732 s = New(FromScalar(5.0)) 733 s2 = New(FromScalar(10.0)) 734 R, err = Dot(s, s2) 735 assert.Nil(err) 736 assert.True(R.IsScalar()) 737 assert.Equal(float64(50), R.Data()) 738 R.Zero() 739 R2, err = Dot(s, s2, WithReuse(R)) 740 assert.Nil(err) 741 assert.True(R2.IsScalar()) 742 assert.Equal(float64(50), R2.Data()) 743 744 R, err = Dot(s, A) 745 expectedData = vecf64.Range(0, 24) 746 vecf64.Scale(expectedData, 5) 747 assert.Nil(err) 748 assert.Equal(A.Shape(), R.Shape()) 749 assert.Equal(expectedData, R.Data()) 750 R.Zero() 751 R2, err = Dot(s, A, WithReuse(R)) 752 assert.Nil(err) 753 assert.Equal(R, R2) 754 assert.Equal(A.Shape(), R2.Shape()) 755 assert.Equal(expectedData, R2.Data()) 756 R, err = Dot(A, s) 757 assert.Nil(err) 758 assert.Equal(A.Shape(), R.Shape()) 759 assert.Equal(expectedData, R.Data()) 760 R.Zero() 761 R2, err = Dot(A, s, WithReuse(R)) 762 assert.Nil(err) 763 assert.Equal(R, R2) 764 assert.Equal(A.Shape(), R2.Shape()) 765 assert.Equal(expectedData, R2.Data()) 766 incr = New(Of(Float64), WithShape(R2.Shape()...)) 767 copy(incr.Data().([]float64), expectedData) 768 incr2 := incr.Clone().(*Dense) // backup a copy for the following test 769 vecf64.Scale(expectedData, 2) 770 R, err = Dot(A, s, WithIncr(incr)) 771 assert.Nil(err) 772 assert.Equal(incr, R) 773 assert.Equal(A.Shape(), R.Shape()) 774 assert.Equal(expectedData, R.Data()) 775 incr = incr2 776 777 R, err = Dot(s, A, WithIncr(incr)) 778 assert.Nil(err) 779 assert.Equal(incr, R) 780 assert.Equal(A.Shape(), R.Shape()) 781 assert.Equal(expectedData, R.Data()) 782 incr = New(Of(Float64), FromScalar(float64(50))) 783 784 R, err = Dot(s, s2, WithIncr(incr)) 785 assert.Nil(err) 786 assert.Equal(R, incr) 787 assert.True(R.IsScalar()) 788 assert.Equal(float64(100), R.Data()) 789 790 /* HERE BE STUPIDS */ 791 // different sizes of vectors 792 c = New(Of(Float64), WithShape(1, 100)) 793 _, err = Dot(a, c) 794 assert.NotNil(err) 795 // vector mat, but with shape mismatch 796 B = New(Of(Float64), WithShape(2, 3), WithBacking(Range(Float64, 0, 6))) 797 _, err = Dot(b, B) 798 assert.NotNil(err) 799 // mat-mat but wrong reuse size 800 A = New(Of(Float64), WithShape(2, 2)) 801 R = New(Of(Float64), WithShape(5, 10)) 802 _, err = Dot(A, B, WithReuse(R)) 803 assert.NotNil(err) 804 // mat-vec but wrong reuse size 805 b = New(Of(Float64), WithShape(2)) 806 _, err = Dot(A, b, WithReuse(R)) 807 assert.NotNil(err) 808 // T-T but misaligned shape 809 A = New(Of(Float64), WithShape(2, 3, 4)) 810 B = New(Of(Float64), WithShape(4, 2, 3)) 811 _, err = Dot(A, B) 812 assert.NotNil(err) 813 } 814 815 func TestOneDot(t *testing.T) { 816 assert := assert.New(t) 817 A := New(Of(Float64), WithShape(2, 3, 4), WithBacking(Range(Float64, 0, 24))) 818 b := New(Of(Float64), WithShape(4), WithBacking(Range(Float64, 0, 4))) 819 820 R, err := Dot(A, b) 821 expectedShape := Shape{2, 3} 822 expectedData := []float64{ 823 14, 38, 62, 824 86, 110, 134, 825 } 826 assert.Nil(err) 827 assert.Equal(expectedData, R.Data()) 828 assert.Equal(expectedShape, R.Shape()) 829 830 // 3T-vec with reuse 831 t.Logf("3T⋅vec with reuse") 832 R.Zero() 833 A = New(Of(Float64), WithShape(2, 3, 4), WithBacking(Range(Float64, 0, 24))) 834 R2, err := Dot(A, b, WithReuse(R)) 835 expectedShape = Shape{2, 3} 836 expectedData = []float64{ 837 14, 38, 62, 838 86, 110, 134, 839 } 840 assert.Nil(err) 841 assert.Equal(R, R2) 842 assert.Equal(expectedData, R2.Data()) 843 assert.Equal(expectedShape, R2.Shape()) 844 }