gonum.org/v1/gonum@v0.14.0/mat/mul_test.go (about) 1 // Copyright ©2015 The Gonum Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style 3 // license that can be found in the LICENSE file. 4 5 package mat 6 7 import ( 8 "testing" 9 10 "golang.org/x/exp/rand" 11 12 "gonum.org/v1/gonum/blas" 13 "gonum.org/v1/gonum/blas/blas64" 14 "gonum.org/v1/gonum/floats" 15 ) 16 17 // TODO: Need to add tests where one is overwritten. 18 func TestMulTypes(t *testing.T) { 19 t.Parallel() 20 src := rand.NewSource(1) 21 for _, test := range []struct { 22 ar int 23 ac int 24 br int 25 bc int 26 Panics bool 27 }{ 28 { 29 ar: 5, 30 ac: 5, 31 br: 5, 32 bc: 5, 33 Panics: false, 34 }, 35 { 36 ar: 10, 37 ac: 5, 38 br: 5, 39 bc: 3, 40 Panics: false, 41 }, 42 { 43 ar: 10, 44 ac: 5, 45 br: 5, 46 bc: 8, 47 Panics: false, 48 }, 49 { 50 ar: 8, 51 ac: 10, 52 br: 10, 53 bc: 3, 54 Panics: false, 55 }, 56 { 57 ar: 8, 58 ac: 3, 59 br: 3, 60 bc: 10, 61 Panics: false, 62 }, 63 { 64 ar: 5, 65 ac: 8, 66 br: 8, 67 bc: 10, 68 Panics: false, 69 }, 70 { 71 ar: 5, 72 ac: 12, 73 br: 12, 74 bc: 8, 75 Panics: false, 76 }, 77 { 78 ar: 5, 79 ac: 7, 80 br: 8, 81 bc: 10, 82 Panics: true, 83 }, 84 } { 85 ar := test.ar 86 ac := test.ac 87 br := test.br 88 bc := test.bc 89 90 // Generate random matrices 91 avec := make([]float64, ar*ac) 92 randomSlice(avec, src) 93 a := NewDense(ar, ac, avec) 94 95 bvec := make([]float64, br*bc) 96 randomSlice(bvec, src) 97 98 b := NewDense(br, bc, bvec) 99 100 // Check that it panics if it is supposed to 101 if test.Panics { 102 c := &Dense{} 103 fn := func() { 104 c.Mul(a, b) 105 } 106 pan, _ := panics(fn) 107 if !pan { 108 t.Errorf("Mul did not panic with dimension mismatch") 109 } 110 continue 111 } 112 113 cvec := make([]float64, ar*bc) 114 115 // Get correct matrix multiply answer from blas64.Gemm 116 blas64.Gemm(blas.NoTrans, blas.NoTrans, 117 1, a.mat, b.mat, 118 0, blas64.General{Rows: ar, Cols: bc, Stride: bc, Data: cvec}, 119 ) 120 121 avecCopy := append([]float64{}, avec...) 122 bvecCopy := append([]float64{}, bvec...) 123 cvecCopy := append([]float64{}, cvec...) 124 125 acomp := matComp{r: ar, c: ac, data: avecCopy} 126 bcomp := matComp{r: br, c: bc, data: bvecCopy} 127 ccomp := matComp{r: ar, c: bc, data: cvecCopy} 128 129 // Do normal multiply with empty dense 130 d := &Dense{} 131 132 testMul(t, a, b, d, acomp, bcomp, ccomp, false, "empty receiver") 133 134 // Normal multiply with existing receiver 135 c := NewDense(ar, bc, cvec) 136 randomSlice(cvec, src) 137 testMul(t, a, b, c, acomp, bcomp, ccomp, false, "existing receiver") 138 139 // Cast a as a basic matrix 140 am := (*basicMatrix)(a) 141 bm := (*basicMatrix)(b) 142 d.Reset() 143 testMul(t, am, b, d, acomp, bcomp, ccomp, true, "a is basic, receiver is empty") 144 d.Reset() 145 testMul(t, a, bm, d, acomp, bcomp, ccomp, true, "b is basic, receiver is empty") 146 d.Reset() 147 testMul(t, am, bm, d, acomp, bcomp, ccomp, true, "both basic, receiver is empty") 148 randomSlice(cvec, src) 149 testMul(t, am, b, d, acomp, bcomp, ccomp, true, "a is basic, receiver is full") 150 randomSlice(cvec, src) 151 testMul(t, a, bm, d, acomp, bcomp, ccomp, true, "b is basic, receiver is full") 152 randomSlice(cvec, src) 153 testMul(t, am, bm, d, acomp, bcomp, ccomp, true, "both basic, receiver is full") 154 } 155 } 156 157 func randomSlice(s []float64, src rand.Source) { 158 rnd := rand.New(src) 159 for i := range s { 160 s[i] = rnd.NormFloat64() 161 } 162 } 163 164 type matComp struct { 165 r, c int 166 data []float64 167 } 168 169 func testMul(t *testing.T, a, b Matrix, c *Dense, acomp, bcomp, ccomp matComp, cvecApprox bool, name string) { 170 c.Mul(a, b) 171 var aDense *Dense 172 switch t := a.(type) { 173 case *Dense: 174 aDense = t 175 case *basicMatrix: 176 aDense = (*Dense)(t) 177 } 178 179 var bDense *Dense 180 switch t := b.(type) { 181 case *Dense: 182 bDense = t 183 case *basicMatrix: 184 bDense = (*Dense)(t) 185 } 186 187 if !denseEqual(aDense, acomp) { 188 t.Errorf("a changed unexpectedly for %v", name) 189 } 190 if !denseEqual(bDense, bcomp) { 191 t.Errorf("b changed unexpectedly for %v", name) 192 } 193 if cvecApprox { 194 if !denseEqualApprox(c, ccomp, 1e-14) { 195 t.Errorf("mul answer not within tol for %v", name) 196 } 197 return 198 } 199 200 if !denseEqual(c, ccomp) { 201 t.Errorf("mul answer not equal for %v", name) 202 } 203 } 204 205 func denseEqual(a *Dense, acomp matComp) bool { 206 ar2, ac2 := a.Dims() 207 if ar2 != acomp.r { 208 return false 209 } 210 if ac2 != acomp.c { 211 return false 212 } 213 if !floats.Equal(a.mat.Data, acomp.data) { 214 return false 215 } 216 return true 217 } 218 219 func denseEqualApprox(a *Dense, acomp matComp, tol float64) bool { 220 ar2, ac2 := a.Dims() 221 if ar2 != acomp.r { 222 return false 223 } 224 if ac2 != acomp.c { 225 return false 226 } 227 if !floats.EqualApprox(a.mat.Data, acomp.data, tol) { 228 return false 229 } 230 return true 231 }