gonum.org/v1/gonum@v0.14.0/mat/symband_test.go (about) 1 // Copyright ©2017 The Gonum Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style 3 // license that can be found in the LICENSE file. 4 5 package mat 6 7 import ( 8 "reflect" 9 "testing" 10 11 "gonum.org/v1/gonum/blas" 12 "gonum.org/v1/gonum/blas/blas64" 13 ) 14 15 func TestNewSymBand(t *testing.T) { 16 t.Parallel() 17 for i, test := range []struct { 18 data []float64 19 n int 20 k int 21 mat *SymBandDense 22 dense *Dense 23 }{ 24 { 25 data: []float64{ 26 1, 2, 3, 27 4, 5, 6, 28 7, 8, 9, 29 10, 11, 12, 30 13, 14, -1, 31 15, -1, -1, 32 }, 33 n: 6, 34 k: 2, 35 mat: &SymBandDense{ 36 mat: blas64.SymmetricBand{ 37 N: 6, 38 K: 2, 39 Stride: 3, 40 Uplo: blas.Upper, 41 Data: []float64{ 42 1, 2, 3, 43 4, 5, 6, 44 7, 8, 9, 45 10, 11, 12, 46 13, 14, -1, 47 15, -1, -1, 48 }, 49 }, 50 }, 51 dense: NewDense(6, 6, []float64{ 52 1, 2, 3, 0, 0, 0, 53 2, 4, 5, 6, 0, 0, 54 3, 5, 7, 8, 9, 0, 55 0, 6, 8, 10, 11, 12, 56 0, 0, 9, 11, 13, 14, 57 0, 0, 0, 12, 14, 15, 58 }), 59 }, 60 } { 61 band := NewSymBandDense(test.n, test.k, test.data) 62 rows, cols := band.Dims() 63 64 if rows != test.n { 65 t.Errorf("unexpected number of rows for test %d: got: %d want: %d", i, rows, test.n) 66 } 67 if cols != test.n { 68 t.Errorf("unexpected number of cols for test %d: got: %d want: %d", i, cols, test.n) 69 } 70 if !reflect.DeepEqual(band, test.mat) { 71 t.Errorf("unexpected value via reflect for test %d: got: %v want: %v", i, band, test.mat) 72 } 73 if !Equal(band, test.mat) { 74 t.Errorf("unexpected value via mat.Equal for test %d: got: %v want: %v", i, band, test.mat) 75 } 76 if !Equal(band, test.dense) { 77 t.Errorf("unexpected value via mat.Equal(band, dense) for test %d:\ngot:\n% v\nwant:\n% v", i, Formatted(band), Formatted(test.dense)) 78 } 79 } 80 } 81 82 func TestSymBandAtSet(t *testing.T) { 83 t.Parallel() 84 // 1 2 3 0 0 0 85 // 2 4 5 6 0 0 86 // 3 5 7 8 9 0 87 // 0 6 8 10 11 12 88 // 0 0 9 11 13 14 89 // 0 0 0 12 14 16 90 band := NewSymBandDense(6, 2, []float64{ 91 1, 2, 3, 92 4, 5, 6, 93 7, 8, 9, 94 10, 11, 12, 95 13, 14, -1, 96 16, -1, -1, 97 }) 98 99 rows, cols := band.Dims() 100 kl, ku := band.Bandwidth() 101 102 // Explicitly test all indexes. 103 want := bandImplicit{rows, cols, kl, ku, func(i, j int) float64 { 104 if i > j { 105 i, j = j, i 106 } 107 return float64(i*ku + j + 1) 108 }} 109 for i := 0; i < 6; i++ { 110 for j := 0; j < 6; j++ { 111 if band.At(i, j) != want.At(i, j) { 112 t.Errorf("unexpected value for band.At(%d, %d): got:%v want:%v", i, j, band.At(i, j), want.At(i, j)) 113 } 114 } 115 } 116 // Do that same thing via a call to Equal. 117 if !Equal(band, want) { 118 t.Errorf("unexpected value via mat.Equal:\ngot:\n% v\nwant:\n% v", Formatted(band), Formatted(want)) 119 } 120 121 // Check At out of bounds 122 for _, row := range []int{-1, rows, rows + 1} { 123 panicked, message := panics(func() { band.At(row, 0) }) 124 if !panicked || message != ErrRowAccess.Error() { 125 t.Errorf("expected panic for invalid row access N=%d r=%d", rows, row) 126 } 127 } 128 for _, col := range []int{-1, cols, cols + 1} { 129 panicked, message := panics(func() { band.At(0, col) }) 130 if !panicked || message != ErrColAccess.Error() { 131 t.Errorf("expected panic for invalid column access N=%d c=%d", cols, col) 132 } 133 } 134 135 // Check Set out of bounds 136 for _, row := range []int{-1, rows, rows + 1} { 137 panicked, message := panics(func() { band.SetSymBand(row, 0, 1.2) }) 138 if !panicked || message != ErrRowAccess.Error() { 139 t.Errorf("expected panic for invalid row access N=%d r=%d", rows, row) 140 } 141 } 142 for _, col := range []int{-1, cols, cols + 1} { 143 panicked, message := panics(func() { band.SetSymBand(0, col, 1.2) }) 144 if !panicked || message != ErrColAccess.Error() { 145 t.Errorf("expected panic for invalid column access N=%d c=%d", cols, col) 146 } 147 } 148 149 for _, st := range []struct { 150 row, col int 151 }{ 152 {row: 0, col: 3}, 153 {row: 0, col: 4}, 154 {row: 0, col: 5}, 155 {row: 1, col: 4}, 156 {row: 1, col: 5}, 157 {row: 2, col: 5}, 158 {row: 3, col: 0}, 159 {row: 4, col: 1}, 160 {row: 5, col: 2}, 161 } { 162 panicked, message := panics(func() { band.SetSymBand(st.row, st.col, 1.2) }) 163 if !panicked || message != ErrBandSet.Error() { 164 t.Errorf("expected panic for %+v %s", st, message) 165 } 166 } 167 168 for _, st := range []struct { 169 row, col int 170 orig, new float64 171 }{ 172 {row: 1, col: 2, orig: 5, new: 15}, 173 {row: 2, col: 3, orig: 8, new: 15}, 174 } { 175 if e := band.At(st.row, st.col); e != st.orig { 176 t.Errorf("unexpected value for At(%d, %d): got: %v want: %v", st.row, st.col, e, st.orig) 177 } 178 band.SetSymBand(st.row, st.col, st.new) 179 if e := band.At(st.row, st.col); e != st.new { 180 t.Errorf("unexpected value for At(%d, %d) after SetSymBand(%[1]d, %d, %v): got: %v want: %[3]v", st.row, st.col, st.new, e) 181 } 182 } 183 } 184 185 func TestSymBandDiagView(t *testing.T) { 186 t.Parallel() 187 for cas, test := range []*SymBandDense{ 188 NewSymBandDense(1, 0, []float64{1}), 189 NewSymBandDense(6, 2, []float64{ 190 1, 2, 3, 191 4, 5, 6, 192 7, 8, 9, 193 10, 11, 12, 194 13, 14, -1, 195 16, -1, -1, 196 }), 197 } { 198 testDiagView(t, cas, test) 199 } 200 } 201 202 func TestSymBandDenseZero(t *testing.T) { 203 t.Parallel() 204 // Elements that equal 1 should be set to zero, elements that equal -1 205 // should remain unchanged. 206 for _, test := range []*SymBandDense{ 207 { 208 mat: blas64.SymmetricBand{ 209 Uplo: blas.Upper, 210 N: 6, 211 K: 2, 212 Stride: 5, 213 Data: []float64{ 214 1, 1, 1, -1, -1, 215 1, 1, 1, -1, -1, 216 1, 1, 1, -1, -1, 217 1, 1, 1, -1, -1, 218 1, 1, -1, -1, -1, 219 1, -1, -1, -1, -1, 220 }, 221 }, 222 }, 223 } { 224 dataCopy := make([]float64, len(test.mat.Data)) 225 copy(dataCopy, test.mat.Data) 226 test.Zero() 227 for i, v := range test.mat.Data { 228 if dataCopy[i] != -1 && v != 0 { 229 t.Errorf("Matrix not zeroed in bounds") 230 } 231 if dataCopy[i] == -1 && v != -1 { 232 t.Errorf("Matrix zeroed out of bounds") 233 } 234 } 235 } 236 }