github.com/jingcheng-WU/gonum@v0.9.1-0.20210323123734-f1a2a11a8f7b/mat/band_test.go (about) 1 // Copyright ©2017 The Gonum Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style 3 // license that can be found in the LICENSE file. 4 5 package mat 6 7 import ( 8 "reflect" 9 "testing" 10 11 "github.com/jingcheng-WU/gonum/blas/blas64" 12 ) 13 14 func TestNewBand(t *testing.T) { 15 t.Parallel() 16 for i, test := range []struct { 17 data []float64 18 r, c int 19 kl, ku int 20 mat *BandDense 21 dense *Dense 22 }{ 23 { 24 data: []float64{ 25 -1, 1, 2, 3, 26 4, 5, 6, 7, 27 8, 9, 10, 11, 28 12, 13, 14, 15, 29 16, 17, 18, -1, 30 19, 20, -1, -1, 31 }, 32 r: 6, c: 6, 33 kl: 1, ku: 2, 34 mat: &BandDense{ 35 mat: blas64.Band{ 36 Rows: 6, 37 Cols: 6, 38 KL: 1, 39 KU: 2, 40 Stride: 4, 41 Data: []float64{ 42 -1, 1, 2, 3, 43 4, 5, 6, 7, 44 8, 9, 10, 11, 45 12, 13, 14, 15, 46 16, 17, 18, -1, 47 19, 20, -1, -1, 48 }, 49 }, 50 }, 51 dense: NewDense(6, 6, []float64{ 52 1, 2, 3, 0, 0, 0, 53 4, 5, 6, 7, 0, 0, 54 0, 8, 9, 10, 11, 0, 55 0, 0, 12, 13, 14, 15, 56 0, 0, 0, 16, 17, 18, 57 0, 0, 0, 0, 19, 20, 58 }), 59 }, 60 { 61 data: []float64{ 62 -1, 1, 2, 3, 63 4, 5, 6, 7, 64 8, 9, 10, 11, 65 12, 13, 14, 15, 66 16, 17, 18, -1, 67 19, 20, -1, -1, 68 21, -1, -1, -1, 69 }, 70 r: 10, c: 6, 71 kl: 1, ku: 2, 72 mat: &BandDense{ 73 mat: blas64.Band{ 74 Rows: 10, 75 Cols: 6, 76 KL: 1, 77 KU: 2, 78 Stride: 4, 79 Data: []float64{ 80 -1, 1, 2, 3, 81 4, 5, 6, 7, 82 8, 9, 10, 11, 83 12, 13, 14, 15, 84 16, 17, 18, -1, 85 19, 20, -1, -1, 86 21, -1, -1, -1, 87 }, 88 }, 89 }, 90 dense: NewDense(10, 6, []float64{ 91 1, 2, 3, 0, 0, 0, 92 4, 5, 6, 7, 0, 0, 93 0, 8, 9, 10, 11, 0, 94 0, 0, 12, 13, 14, 15, 95 0, 0, 0, 16, 17, 18, 96 0, 0, 0, 0, 19, 20, 97 0, 0, 0, 0, 0, 21, 98 0, 0, 0, 0, 0, 0, 99 0, 0, 0, 0, 0, 0, 100 0, 0, 0, 0, 0, 0, 101 }), 102 }, 103 { 104 data: []float64{ 105 -1, 1, 2, 3, 106 4, 5, 6, 7, 107 8, 9, 10, 11, 108 12, 13, 14, 15, 109 16, 17, 18, 19, 110 20, 21, 22, 23, 111 }, 112 r: 6, c: 10, 113 kl: 1, ku: 2, 114 mat: &BandDense{ 115 mat: blas64.Band{ 116 Rows: 6, 117 Cols: 10, 118 KL: 1, 119 KU: 2, 120 Stride: 4, 121 Data: []float64{ 122 -1, 1, 2, 3, 123 4, 5, 6, 7, 124 8, 9, 10, 11, 125 12, 13, 14, 15, 126 16, 17, 18, 19, 127 20, 21, 22, 23, 128 }, 129 }, 130 }, 131 dense: NewDense(6, 10, []float64{ 132 1, 2, 3, 0, 0, 0, 0, 0, 0, 0, 133 4, 5, 6, 7, 0, 0, 0, 0, 0, 0, 134 0, 8, 9, 10, 11, 0, 0, 0, 0, 0, 135 0, 0, 12, 13, 14, 15, 0, 0, 0, 0, 136 0, 0, 0, 16, 17, 18, 19, 0, 0, 0, 137 0, 0, 0, 0, 20, 21, 22, 23, 0, 0, 138 }), 139 }, 140 } { 141 band := NewBandDense(test.r, test.c, test.kl, test.ku, test.data) 142 rows, cols := band.Dims() 143 144 if rows != test.r { 145 t.Errorf("unexpected number of rows for test %d: got: %d want: %d", i, rows, test.r) 146 } 147 if cols != test.c { 148 t.Errorf("unexpected number of cols for test %d: got: %d want: %d", i, cols, test.c) 149 } 150 if !reflect.DeepEqual(band, test.mat) { 151 t.Errorf("unexpected value via reflect for test %d: got: %v want: %v", i, band, test.mat) 152 } 153 if !Equal(band, test.mat) { 154 t.Errorf("unexpected value via mat.Equal for test %d: got: %v want: %v", i, band, test.mat) 155 } 156 if !Equal(band, test.dense) { 157 t.Errorf("unexpected value via mat.Equal(band, dense) for test %d:\ngot:\n% v\nwant:\n% v", i, Formatted(band), Formatted(test.dense)) 158 } 159 } 160 } 161 162 func TestNewDiagonalRect(t *testing.T) { 163 t.Parallel() 164 for i, test := range []struct { 165 data []float64 166 r, c int 167 mat *BandDense 168 dense *Dense 169 }{ 170 { 171 data: []float64{1, 2, 3, 4, 5, 6}, 172 r: 6, c: 6, 173 mat: &BandDense{ 174 mat: blas64.Band{ 175 Rows: 6, 176 Cols: 6, 177 Stride: 1, 178 Data: []float64{1, 2, 3, 4, 5, 6}, 179 }, 180 }, 181 dense: NewDense(6, 6, []float64{ 182 1, 0, 0, 0, 0, 0, 183 0, 2, 0, 0, 0, 0, 184 0, 0, 3, 0, 0, 0, 185 0, 0, 0, 4, 0, 0, 186 0, 0, 0, 0, 5, 0, 187 0, 0, 0, 0, 0, 6, 188 }), 189 }, 190 { 191 data: []float64{1, 2, 3, 4, 5, 6}, 192 r: 7, c: 6, 193 mat: &BandDense{ 194 mat: blas64.Band{ 195 Rows: 7, 196 Cols: 6, 197 Stride: 1, 198 Data: []float64{1, 2, 3, 4, 5, 6}, 199 }, 200 }, 201 dense: NewDense(7, 6, []float64{ 202 1, 0, 0, 0, 0, 0, 203 0, 2, 0, 0, 0, 0, 204 0, 0, 3, 0, 0, 0, 205 0, 0, 0, 4, 0, 0, 206 0, 0, 0, 0, 5, 0, 207 0, 0, 0, 0, 0, 6, 208 0, 0, 0, 0, 0, 0, 209 }), 210 }, 211 { 212 data: []float64{1, 2, 3, 4, 5, 6}, 213 r: 6, c: 7, 214 mat: &BandDense{ 215 mat: blas64.Band{ 216 Rows: 6, 217 Cols: 7, 218 Stride: 1, 219 Data: []float64{1, 2, 3, 4, 5, 6}, 220 }, 221 }, 222 dense: NewDense(6, 7, []float64{ 223 1, 0, 0, 0, 0, 0, 0, 224 0, 2, 0, 0, 0, 0, 0, 225 0, 0, 3, 0, 0, 0, 0, 226 0, 0, 0, 4, 0, 0, 0, 227 0, 0, 0, 0, 5, 0, 0, 228 0, 0, 0, 0, 0, 6, 0, 229 }), 230 }, 231 } { 232 band := NewDiagonalRect(test.r, test.c, test.data) 233 rows, cols := band.Dims() 234 235 if rows != test.r { 236 t.Errorf("unexpected number of rows for test %d: got: %d want: %d", i, rows, test.r) 237 } 238 if cols != test.c { 239 t.Errorf("unexpected number of cols for test %d: got: %d want: %d", i, cols, test.c) 240 } 241 if !reflect.DeepEqual(band, test.mat) { 242 t.Errorf("unexpected value via reflect for test %d: got: %v want: %v", i, band, test.mat) 243 } 244 if !Equal(band, test.mat) { 245 t.Errorf("unexpected value via mat.Equal for test %d: got: %v want: %v", i, band, test.mat) 246 } 247 if !Equal(band, test.dense) { 248 t.Errorf("unexpected value via mat.Equal(band, dense) for test %d:\ngot:\n% v\nwant:\n% v", i, Formatted(band), Formatted(test.dense)) 249 } 250 } 251 } 252 253 func TestBandDenseZero(t *testing.T) { 254 t.Parallel() 255 // Elements that equal 1 should be set to zero, elements that equal -1 256 // should remain unchanged. 257 for _, test := range []*BandDense{ 258 { 259 mat: blas64.Band{ 260 Rows: 6, 261 Cols: 7, 262 Stride: 8, 263 KL: 1, 264 KU: 2, 265 Data: []float64{ 266 -1, 1, 1, 1, -1, -1, -1, -1, 267 1, 1, 1, 1, -1, -1, -1, -1, 268 1, 1, 1, 1, -1, -1, -1, -1, 269 1, 1, 1, 1, -1, -1, -1, -1, 270 1, 1, 1, -1, -1, -1, -1, -1, 271 1, 1, -1, -1, -1, -1, -1, -1, 272 }, 273 }, 274 }, 275 { 276 mat: blas64.Band{ 277 Rows: 6, 278 Cols: 7, 279 Stride: 8, 280 KL: 2, 281 KU: 1, 282 Data: []float64{ 283 -1, -1, 1, 1, -1, -1, -1, -1, 284 -1, 1, 1, 1, -1, -1, -1, -1, 285 1, 1, 1, 1, -1, -1, -1, -1, 286 1, 1, 1, 1, -1, -1, -1, -1, 287 1, 1, 1, 1, -1, -1, -1, -1, 288 1, 1, 1, -1, -1, -1, -1, -1, 289 }, 290 }, 291 }, 292 } { 293 dataCopy := make([]float64, len(test.mat.Data)) 294 copy(dataCopy, test.mat.Data) 295 test.Zero() 296 for i, v := range test.mat.Data { 297 if dataCopy[i] != -1 && v != 0 { 298 t.Errorf("Matrix not zeroed in bounds") 299 } 300 if dataCopy[i] == -1 && v != -1 { 301 t.Errorf("Matrix zeroed out of bounds") 302 } 303 } 304 } 305 } 306 307 func TestBandDiagView(t *testing.T) { 308 t.Parallel() 309 for cas, test := range []*BandDense{ 310 NewBandDense(1, 1, 0, 0, []float64{1}), 311 NewBandDense(6, 6, 1, 2, []float64{ 312 -1, 2, 3, 4, 313 5, 6, 7, 8, 314 9, 10, 11, 12, 315 13, 14, 15, 16, 316 17, 18, 19, -1, 317 21, 22, -1, -1, 318 }), 319 NewBandDense(6, 6, 2, 1, []float64{ 320 -1, -1, 1, 2, 321 -1, 3, 4, 5, 322 6, 7, 8, 9, 323 10, 11, 12, 13, 324 14, 15, 16, 17, 325 18, 19, 20, -1, 326 }), 327 } { 328 testDiagView(t, cas, test) 329 } 330 } 331 332 func TestBandAtSet(t *testing.T) { 333 t.Parallel() 334 // 2 3 4 0 0 0 335 // 5 6 7 8 0 0 336 // 0 9 10 11 12 0 337 // 0 0 13 14 15 16 338 // 0 0 0 17 18 19 339 // 0 0 0 0 21 22 340 band := NewBandDense(6, 6, 1, 2, []float64{ 341 -1, 2, 3, 4, 342 5, 6, 7, 8, 343 9, 10, 11, 12, 344 13, 14, 15, 16, 345 17, 18, 19, -1, 346 21, 22, -1, -1, 347 }) 348 349 rows, cols := band.Dims() 350 kl, ku := band.Bandwidth() 351 352 // Explicitly test all indexes. 353 want := bandImplicit{rows, cols, kl, ku, func(i, j int) float64 { 354 return float64(i*(kl+ku) + j + kl + 1) 355 }} 356 for i := 0; i < 6; i++ { 357 for j := 0; j < 6; j++ { 358 if band.At(i, j) != want.At(i, j) { 359 t.Errorf("unexpected value for band.At(%d, %d): got:%v want:%v", i, j, band.At(i, j), want.At(i, j)) 360 } 361 } 362 } 363 // Do that same thing via a call to Equal. 364 if !Equal(band, want) { 365 t.Errorf("unexpected value via mat.Equal:\ngot:\n% v\nwant:\n% v", Formatted(band), Formatted(want)) 366 } 367 368 // Check At out of bounds 369 for _, row := range []int{-1, rows, rows + 1} { 370 panicked, message := panics(func() { band.At(row, 0) }) 371 if !panicked || message != ErrRowAccess.Error() { 372 t.Errorf("expected panic for invalid row access N=%d r=%d", rows, row) 373 } 374 } 375 for _, col := range []int{-1, cols, cols + 1} { 376 panicked, message := panics(func() { band.At(0, col) }) 377 if !panicked || message != ErrColAccess.Error() { 378 t.Errorf("expected panic for invalid column access N=%d c=%d", cols, col) 379 } 380 } 381 382 // Check Set out of bounds 383 for _, row := range []int{-1, rows, rows + 1} { 384 panicked, message := panics(func() { band.SetBand(row, 0, 1.2) }) 385 if !panicked || message != ErrRowAccess.Error() { 386 t.Errorf("expected panic for invalid row access N=%d r=%d", rows, row) 387 } 388 } 389 for _, col := range []int{-1, cols, cols + 1} { 390 panicked, message := panics(func() { band.SetBand(0, col, 1.2) }) 391 if !panicked || message != ErrColAccess.Error() { 392 t.Errorf("expected panic for invalid column access N=%d c=%d", cols, col) 393 } 394 } 395 396 for _, st := range []struct { 397 row, col int 398 }{ 399 {row: 0, col: 3}, 400 {row: 0, col: 4}, 401 {row: 0, col: 5}, 402 {row: 1, col: 4}, 403 {row: 1, col: 5}, 404 {row: 2, col: 5}, 405 {row: 2, col: 0}, 406 {row: 3, col: 1}, 407 {row: 4, col: 2}, 408 {row: 5, col: 3}, 409 } { 410 panicked, message := panics(func() { band.SetBand(st.row, st.col, 1.2) }) 411 if !panicked || message != ErrBandSet.Error() { 412 t.Errorf("expected panic for %+v %s", st, message) 413 } 414 } 415 416 for _, st := range []struct { 417 row, col int 418 orig, new float64 419 }{ 420 {row: 1, col: 2, orig: 7, new: 15}, 421 {row: 2, col: 3, orig: 11, new: 15}, 422 } { 423 if e := band.At(st.row, st.col); e != st.orig { 424 t.Errorf("unexpected value for At(%d, %d): got: %v want: %v", st.row, st.col, e, st.orig) 425 } 426 band.SetBand(st.row, st.col, st.new) 427 if e := band.At(st.row, st.col); e != st.new { 428 t.Errorf("unexpected value for At(%d, %d) after SetBand(%[1]d, %d, %v): got: %v want: %[3]v", st.row, st.col, st.new, e) 429 } 430 } 431 } 432 433 // bandImplicit is an implicit band matrix returning val(i, j) 434 // for the value at (i, j). 435 type bandImplicit struct { 436 r, c, kl, ku int 437 val func(i, j int) float64 438 } 439 440 func (b bandImplicit) Dims() (r, c int) { 441 return b.r, b.c 442 } 443 444 func (b bandImplicit) T() Matrix { 445 return Transpose{b} 446 } 447 448 func (b bandImplicit) At(i, j int) float64 { 449 if i < 0 || b.r <= i { 450 panic("row") 451 } 452 if j < 0 || b.c <= j { 453 panic("col") 454 } 455 if j < i-b.kl || i+b.ku < j { 456 return 0 457 } 458 return b.val(i, j) 459 }