gonum.org/v1/gonum@v0.14.0/mat/inner_test.go (about) 1 // Copyright ©2014 The Gonum Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style 3 // license that can be found in the LICENSE file. 4 5 package mat 6 7 import ( 8 "math" 9 "testing" 10 11 "golang.org/x/exp/rand" 12 13 "gonum.org/v1/gonum/blas/blas64" 14 "gonum.org/v1/gonum/blas/testblas" 15 ) 16 17 func TestInner(t *testing.T) { 18 t.Parallel() 19 for i, test := range []struct { 20 x []float64 21 y []float64 22 m [][]float64 23 }{ 24 { 25 x: []float64{5}, 26 y: []float64{10}, 27 m: [][]float64{{2}}, 28 }, 29 { 30 x: []float64{5, 6, 1}, 31 y: []float64{10}, 32 m: [][]float64{{2}, {-3}, {5}}, 33 }, 34 { 35 x: []float64{5}, 36 y: []float64{10, 15}, 37 m: [][]float64{{2, -3}}, 38 }, 39 { 40 x: []float64{1, 5}, 41 y: []float64{10, 15}, 42 m: [][]float64{ 43 {2, -3}, 44 {4, -1}, 45 }, 46 }, 47 { 48 x: []float64{2, 3, 9}, 49 y: []float64{8, 9}, 50 m: [][]float64{ 51 {2, 3}, 52 {4, 5}, 53 {6, 7}, 54 }, 55 }, 56 { 57 x: []float64{2, 3}, 58 y: []float64{8, 9, 9}, 59 m: [][]float64{ 60 {2, 3, 6}, 61 {4, 5, 7}, 62 }, 63 }, 64 } { 65 for _, inc := range []struct{ x, y int }{ 66 {1, 1}, 67 {1, 2}, 68 {2, 1}, 69 {2, 2}, 70 } { 71 x := NewDense(1, len(test.x), test.x) 72 m := NewDense(flatten(test.m)) 73 mWant := NewDense(flatten(test.m)) 74 y := NewDense(len(test.y), 1, test.y) 75 76 var tmp, cell Dense 77 tmp.Mul(mWant, y) 78 cell.Mul(x, &tmp) 79 80 rm, cm := cell.Dims() 81 if rm != 1 { 82 t.Errorf("Test %d result doesn't have 1 row", i) 83 } 84 if cm != 1 { 85 t.Errorf("Test %d result doesn't have 1 column", i) 86 } 87 88 want := cell.At(0, 0) 89 got := Inner(makeVecDenseInc(inc.x, test.x), m, makeVecDenseInc(inc.y, test.y)) 90 if got != want { 91 t.Errorf("Test %v: want %v, got %v", i, want, got) 92 } 93 } 94 } 95 } 96 97 func TestInnerSym(t *testing.T) { 98 t.Parallel() 99 for _, inc := range []struct{ x, y int }{ 100 {1, 1}, 101 {1, 2}, 102 {2, 1}, 103 {2, 2}, 104 } { 105 n := 10 106 xData := make([]float64, n) 107 yData := make([]float64, n) 108 data := make([]float64, n*n) 109 for i := 0; i < n; i++ { 110 xData[i] = float64(i) 111 yData[i] = float64(i) 112 for j := i; j < n; j++ { 113 data[i*n+j] = float64(i*n + j) 114 data[j*n+i] = data[i*n+j] 115 } 116 } 117 x := makeVecDenseInc(inc.x, xData) 118 y := makeVecDenseInc(inc.y, yData) 119 m := NewDense(n, n, data) 120 ans := Inner(x, m, y) 121 sym := NewSymDense(n, data) 122 // Poison the lower half of data to ensure it is not used. 123 for i := 1; i < n; i++ { 124 for j := 0; j < i; j++ { 125 data[i*n+j] = math.NaN() 126 } 127 } 128 129 if math.Abs(Inner(x, sym, y)-ans) > 1e-14 { 130 t.Error("inner different symmetric and dense") 131 } 132 } 133 } 134 135 func makeVecDenseInc(inc int, f []float64) *VecDense { 136 v := &VecDense{ 137 mat: blas64.Vector{ 138 N: len(f), 139 Inc: inc, 140 Data: make([]float64, (len(f)-1)*inc+1), 141 }, 142 } 143 144 // Contaminate backing data in all positions... 145 const base = 100 146 for i := range v.mat.Data { 147 v.mat.Data[i] = float64(i + base) 148 } 149 150 // then write real elements. 151 for i := range f { 152 v.mat.Data[i*inc] = f[i] 153 } 154 return v 155 } 156 157 func benchmarkInner(b *testing.B, m, n int) { 158 src := rand.NewSource(1) 159 x := NewVecDense(m, nil) 160 randomSlice(x.mat.Data, src) 161 y := NewVecDense(n, nil) 162 randomSlice(y.mat.Data, src) 163 data := make([]float64, m*n) 164 randomSlice(data, src) 165 mat := &Dense{mat: blas64.General{Rows: m, Cols: n, Stride: n, Data: data}, capRows: m, capCols: n} 166 b.ResetTimer() 167 for i := 0; i < b.N; i++ { 168 Inner(x, mat, y) 169 } 170 } 171 172 func BenchmarkInnerSmSm(b *testing.B) { 173 benchmarkInner(b, testblas.SmallMat, testblas.SmallMat) 174 } 175 176 func BenchmarkInnerMedMed(b *testing.B) { 177 benchmarkInner(b, testblas.MediumMat, testblas.MediumMat) 178 } 179 180 func BenchmarkInnerLgLg(b *testing.B) { 181 benchmarkInner(b, testblas.LargeMat, testblas.LargeMat) 182 } 183 184 func BenchmarkInnerLgSm(b *testing.B) { 185 benchmarkInner(b, testblas.LargeMat, testblas.SmallMat) 186 }