github.com/gopherd/gonum@v0.0.4/mat/band_test.go (about)

     1  // Copyright ©2017 The Gonum Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package mat
     6  
     7  import (
     8  	"reflect"
     9  	"testing"
    10  
    11  	"github.com/gopherd/gonum/blas/blas64"
    12  )
    13  
    14  func TestNewBand(t *testing.T) {
    15  	t.Parallel()
    16  	for i, test := range []struct {
    17  		data   []float64
    18  		r, c   int
    19  		kl, ku int
    20  		mat    *BandDense
    21  		dense  *Dense
    22  	}{
    23  		{
    24  			data: []float64{
    25  				-1, 1, 2, 3,
    26  				4, 5, 6, 7,
    27  				8, 9, 10, 11,
    28  				12, 13, 14, 15,
    29  				16, 17, 18, -1,
    30  				19, 20, -1, -1,
    31  			},
    32  			r: 6, c: 6,
    33  			kl: 1, ku: 2,
    34  			mat: &BandDense{
    35  				mat: blas64.Band{
    36  					Rows:   6,
    37  					Cols:   6,
    38  					KL:     1,
    39  					KU:     2,
    40  					Stride: 4,
    41  					Data: []float64{
    42  						-1, 1, 2, 3,
    43  						4, 5, 6, 7,
    44  						8, 9, 10, 11,
    45  						12, 13, 14, 15,
    46  						16, 17, 18, -1,
    47  						19, 20, -1, -1,
    48  					},
    49  				},
    50  			},
    51  			dense: NewDense(6, 6, []float64{
    52  				1, 2, 3, 0, 0, 0,
    53  				4, 5, 6, 7, 0, 0,
    54  				0, 8, 9, 10, 11, 0,
    55  				0, 0, 12, 13, 14, 15,
    56  				0, 0, 0, 16, 17, 18,
    57  				0, 0, 0, 0, 19, 20,
    58  			}),
    59  		},
    60  		{
    61  			data: []float64{
    62  				-1, 1, 2, 3,
    63  				4, 5, 6, 7,
    64  				8, 9, 10, 11,
    65  				12, 13, 14, 15,
    66  				16, 17, 18, -1,
    67  				19, 20, -1, -1,
    68  				21, -1, -1, -1,
    69  			},
    70  			r: 10, c: 6,
    71  			kl: 1, ku: 2,
    72  			mat: &BandDense{
    73  				mat: blas64.Band{
    74  					Rows:   10,
    75  					Cols:   6,
    76  					KL:     1,
    77  					KU:     2,
    78  					Stride: 4,
    79  					Data: []float64{
    80  						-1, 1, 2, 3,
    81  						4, 5, 6, 7,
    82  						8, 9, 10, 11,
    83  						12, 13, 14, 15,
    84  						16, 17, 18, -1,
    85  						19, 20, -1, -1,
    86  						21, -1, -1, -1,
    87  					},
    88  				},
    89  			},
    90  			dense: NewDense(10, 6, []float64{
    91  				1, 2, 3, 0, 0, 0,
    92  				4, 5, 6, 7, 0, 0,
    93  				0, 8, 9, 10, 11, 0,
    94  				0, 0, 12, 13, 14, 15,
    95  				0, 0, 0, 16, 17, 18,
    96  				0, 0, 0, 0, 19, 20,
    97  				0, 0, 0, 0, 0, 21,
    98  				0, 0, 0, 0, 0, 0,
    99  				0, 0, 0, 0, 0, 0,
   100  				0, 0, 0, 0, 0, 0,
   101  			}),
   102  		},
   103  		{
   104  			data: []float64{
   105  				-1, 1, 2, 3,
   106  				4, 5, 6, 7,
   107  				8, 9, 10, 11,
   108  				12, 13, 14, 15,
   109  				16, 17, 18, 19,
   110  				20, 21, 22, 23,
   111  			},
   112  			r: 6, c: 10,
   113  			kl: 1, ku: 2,
   114  			mat: &BandDense{
   115  				mat: blas64.Band{
   116  					Rows:   6,
   117  					Cols:   10,
   118  					KL:     1,
   119  					KU:     2,
   120  					Stride: 4,
   121  					Data: []float64{
   122  						-1, 1, 2, 3,
   123  						4, 5, 6, 7,
   124  						8, 9, 10, 11,
   125  						12, 13, 14, 15,
   126  						16, 17, 18, 19,
   127  						20, 21, 22, 23,
   128  					},
   129  				},
   130  			},
   131  			dense: NewDense(6, 10, []float64{
   132  				1, 2, 3, 0, 0, 0, 0, 0, 0, 0,
   133  				4, 5, 6, 7, 0, 0, 0, 0, 0, 0,
   134  				0, 8, 9, 10, 11, 0, 0, 0, 0, 0,
   135  				0, 0, 12, 13, 14, 15, 0, 0, 0, 0,
   136  				0, 0, 0, 16, 17, 18, 19, 0, 0, 0,
   137  				0, 0, 0, 0, 20, 21, 22, 23, 0, 0,
   138  			}),
   139  		},
   140  	} {
   141  		band := NewBandDense(test.r, test.c, test.kl, test.ku, test.data)
   142  		rows, cols := band.Dims()
   143  
   144  		if rows != test.r {
   145  			t.Errorf("unexpected number of rows for test %d: got: %d want: %d", i, rows, test.r)
   146  		}
   147  		if cols != test.c {
   148  			t.Errorf("unexpected number of cols for test %d: got: %d want: %d", i, cols, test.c)
   149  		}
   150  		if !reflect.DeepEqual(band, test.mat) {
   151  			t.Errorf("unexpected value via reflect for test %d: got: %v want: %v", i, band, test.mat)
   152  		}
   153  		if !Equal(band, test.mat) {
   154  			t.Errorf("unexpected value via mat.Equal for test %d: got: %v want: %v", i, band, test.mat)
   155  		}
   156  		if !Equal(band, test.dense) {
   157  			t.Errorf("unexpected value via mat.Equal(band, dense) for test %d:\ngot:\n% v\nwant:\n% v", i, Formatted(band), Formatted(test.dense))
   158  		}
   159  	}
   160  }
   161  
   162  func TestNewDiagonalRect(t *testing.T) {
   163  	t.Parallel()
   164  	for i, test := range []struct {
   165  		data  []float64
   166  		r, c  int
   167  		mat   *BandDense
   168  		dense *Dense
   169  	}{
   170  		{
   171  			data: []float64{1, 2, 3, 4, 5, 6},
   172  			r:    6, c: 6,
   173  			mat: &BandDense{
   174  				mat: blas64.Band{
   175  					Rows:   6,
   176  					Cols:   6,
   177  					Stride: 1,
   178  					Data:   []float64{1, 2, 3, 4, 5, 6},
   179  				},
   180  			},
   181  			dense: NewDense(6, 6, []float64{
   182  				1, 0, 0, 0, 0, 0,
   183  				0, 2, 0, 0, 0, 0,
   184  				0, 0, 3, 0, 0, 0,
   185  				0, 0, 0, 4, 0, 0,
   186  				0, 0, 0, 0, 5, 0,
   187  				0, 0, 0, 0, 0, 6,
   188  			}),
   189  		},
   190  		{
   191  			data: []float64{1, 2, 3, 4, 5, 6},
   192  			r:    7, c: 6,
   193  			mat: &BandDense{
   194  				mat: blas64.Band{
   195  					Rows:   7,
   196  					Cols:   6,
   197  					Stride: 1,
   198  					Data:   []float64{1, 2, 3, 4, 5, 6},
   199  				},
   200  			},
   201  			dense: NewDense(7, 6, []float64{
   202  				1, 0, 0, 0, 0, 0,
   203  				0, 2, 0, 0, 0, 0,
   204  				0, 0, 3, 0, 0, 0,
   205  				0, 0, 0, 4, 0, 0,
   206  				0, 0, 0, 0, 5, 0,
   207  				0, 0, 0, 0, 0, 6,
   208  				0, 0, 0, 0, 0, 0,
   209  			}),
   210  		},
   211  		{
   212  			data: []float64{1, 2, 3, 4, 5, 6},
   213  			r:    6, c: 7,
   214  			mat: &BandDense{
   215  				mat: blas64.Band{
   216  					Rows:   6,
   217  					Cols:   7,
   218  					Stride: 1,
   219  					Data:   []float64{1, 2, 3, 4, 5, 6},
   220  				},
   221  			},
   222  			dense: NewDense(6, 7, []float64{
   223  				1, 0, 0, 0, 0, 0, 0,
   224  				0, 2, 0, 0, 0, 0, 0,
   225  				0, 0, 3, 0, 0, 0, 0,
   226  				0, 0, 0, 4, 0, 0, 0,
   227  				0, 0, 0, 0, 5, 0, 0,
   228  				0, 0, 0, 0, 0, 6, 0,
   229  			}),
   230  		},
   231  	} {
   232  		band := NewDiagonalRect(test.r, test.c, test.data)
   233  		rows, cols := band.Dims()
   234  
   235  		if rows != test.r {
   236  			t.Errorf("unexpected number of rows for test %d: got: %d want: %d", i, rows, test.r)
   237  		}
   238  		if cols != test.c {
   239  			t.Errorf("unexpected number of cols for test %d: got: %d want: %d", i, cols, test.c)
   240  		}
   241  		if !reflect.DeepEqual(band, test.mat) {
   242  			t.Errorf("unexpected value via reflect for test %d: got: %v want: %v", i, band, test.mat)
   243  		}
   244  		if !Equal(band, test.mat) {
   245  			t.Errorf("unexpected value via mat.Equal for test %d: got: %v want: %v", i, band, test.mat)
   246  		}
   247  		if !Equal(band, test.dense) {
   248  			t.Errorf("unexpected value via mat.Equal(band, dense) for test %d:\ngot:\n% v\nwant:\n% v", i, Formatted(band), Formatted(test.dense))
   249  		}
   250  	}
   251  }
   252  
   253  func TestBandDenseZero(t *testing.T) {
   254  	t.Parallel()
   255  	// Elements that equal 1 should be set to zero, elements that equal -1
   256  	// should remain unchanged.
   257  	for _, test := range []*BandDense{
   258  		{
   259  			mat: blas64.Band{
   260  				Rows:   6,
   261  				Cols:   7,
   262  				Stride: 8,
   263  				KL:     1,
   264  				KU:     2,
   265  				Data: []float64{
   266  					-1, 1, 1, 1, -1, -1, -1, -1,
   267  					1, 1, 1, 1, -1, -1, -1, -1,
   268  					1, 1, 1, 1, -1, -1, -1, -1,
   269  					1, 1, 1, 1, -1, -1, -1, -1,
   270  					1, 1, 1, -1, -1, -1, -1, -1,
   271  					1, 1, -1, -1, -1, -1, -1, -1,
   272  				},
   273  			},
   274  		},
   275  		{
   276  			mat: blas64.Band{
   277  				Rows:   6,
   278  				Cols:   7,
   279  				Stride: 8,
   280  				KL:     2,
   281  				KU:     1,
   282  				Data: []float64{
   283  					-1, -1, 1, 1, -1, -1, -1, -1,
   284  					-1, 1, 1, 1, -1, -1, -1, -1,
   285  					1, 1, 1, 1, -1, -1, -1, -1,
   286  					1, 1, 1, 1, -1, -1, -1, -1,
   287  					1, 1, 1, 1, -1, -1, -1, -1,
   288  					1, 1, 1, -1, -1, -1, -1, -1,
   289  				},
   290  			},
   291  		},
   292  	} {
   293  		dataCopy := make([]float64, len(test.mat.Data))
   294  		copy(dataCopy, test.mat.Data)
   295  		test.Zero()
   296  		for i, v := range test.mat.Data {
   297  			if dataCopy[i] != -1 && v != 0 {
   298  				t.Errorf("Matrix not zeroed in bounds")
   299  			}
   300  			if dataCopy[i] == -1 && v != -1 {
   301  				t.Errorf("Matrix zeroed out of bounds")
   302  			}
   303  		}
   304  	}
   305  }
   306  
   307  func TestBandDiagView(t *testing.T) {
   308  	t.Parallel()
   309  	for cas, test := range []*BandDense{
   310  		NewBandDense(1, 1, 0, 0, []float64{1}),
   311  		NewBandDense(6, 6, 1, 2, []float64{
   312  			-1, 2, 3, 4,
   313  			5, 6, 7, 8,
   314  			9, 10, 11, 12,
   315  			13, 14, 15, 16,
   316  			17, 18, 19, -1,
   317  			21, 22, -1, -1,
   318  		}),
   319  		NewBandDense(6, 6, 2, 1, []float64{
   320  			-1, -1, 1, 2,
   321  			-1, 3, 4, 5,
   322  			6, 7, 8, 9,
   323  			10, 11, 12, 13,
   324  			14, 15, 16, 17,
   325  			18, 19, 20, -1,
   326  		}),
   327  	} {
   328  		testDiagView(t, cas, test)
   329  	}
   330  }
   331  
   332  func TestBandAtSet(t *testing.T) {
   333  	t.Parallel()
   334  	// 2  3  4  0  0  0
   335  	// 5  6  7  8  0  0
   336  	// 0  9 10 11 12  0
   337  	// 0  0 13 14 15 16
   338  	// 0  0  0 17 18 19
   339  	// 0  0  0  0 21 22
   340  	band := NewBandDense(6, 6, 1, 2, []float64{
   341  		-1, 2, 3, 4,
   342  		5, 6, 7, 8,
   343  		9, 10, 11, 12,
   344  		13, 14, 15, 16,
   345  		17, 18, 19, -1,
   346  		21, 22, -1, -1,
   347  	})
   348  
   349  	rows, cols := band.Dims()
   350  	kl, ku := band.Bandwidth()
   351  
   352  	// Explicitly test all indexes.
   353  	want := bandImplicit{rows, cols, kl, ku, func(i, j int) float64 {
   354  		return float64(i*(kl+ku) + j + kl + 1)
   355  	}}
   356  	for i := 0; i < 6; i++ {
   357  		for j := 0; j < 6; j++ {
   358  			if band.At(i, j) != want.At(i, j) {
   359  				t.Errorf("unexpected value for band.At(%d, %d): got:%v want:%v", i, j, band.At(i, j), want.At(i, j))
   360  			}
   361  		}
   362  	}
   363  	// Do that same thing via a call to Equal.
   364  	if !Equal(band, want) {
   365  		t.Errorf("unexpected value via mat.Equal:\ngot:\n% v\nwant:\n% v", Formatted(band), Formatted(want))
   366  	}
   367  
   368  	// Check At out of bounds
   369  	for _, row := range []int{-1, rows, rows + 1} {
   370  		panicked, message := panics(func() { band.At(row, 0) })
   371  		if !panicked || message != ErrRowAccess.Error() {
   372  			t.Errorf("expected panic for invalid row access N=%d r=%d", rows, row)
   373  		}
   374  	}
   375  	for _, col := range []int{-1, cols, cols + 1} {
   376  		panicked, message := panics(func() { band.At(0, col) })
   377  		if !panicked || message != ErrColAccess.Error() {
   378  			t.Errorf("expected panic for invalid column access N=%d c=%d", cols, col)
   379  		}
   380  	}
   381  
   382  	// Check Set out of bounds
   383  	for _, row := range []int{-1, rows, rows + 1} {
   384  		panicked, message := panics(func() { band.SetBand(row, 0, 1.2) })
   385  		if !panicked || message != ErrRowAccess.Error() {
   386  			t.Errorf("expected panic for invalid row access N=%d r=%d", rows, row)
   387  		}
   388  	}
   389  	for _, col := range []int{-1, cols, cols + 1} {
   390  		panicked, message := panics(func() { band.SetBand(0, col, 1.2) })
   391  		if !panicked || message != ErrColAccess.Error() {
   392  			t.Errorf("expected panic for invalid column access N=%d c=%d", cols, col)
   393  		}
   394  	}
   395  
   396  	for _, st := range []struct {
   397  		row, col int
   398  	}{
   399  		{row: 0, col: 3},
   400  		{row: 0, col: 4},
   401  		{row: 0, col: 5},
   402  		{row: 1, col: 4},
   403  		{row: 1, col: 5},
   404  		{row: 2, col: 5},
   405  		{row: 2, col: 0},
   406  		{row: 3, col: 1},
   407  		{row: 4, col: 2},
   408  		{row: 5, col: 3},
   409  	} {
   410  		panicked, message := panics(func() { band.SetBand(st.row, st.col, 1.2) })
   411  		if !panicked || message != ErrBandSet.Error() {
   412  			t.Errorf("expected panic for %+v %s", st, message)
   413  		}
   414  	}
   415  
   416  	for _, st := range []struct {
   417  		row, col  int
   418  		orig, new float64
   419  	}{
   420  		{row: 1, col: 2, orig: 7, new: 15},
   421  		{row: 2, col: 3, orig: 11, new: 15},
   422  	} {
   423  		if e := band.At(st.row, st.col); e != st.orig {
   424  			t.Errorf("unexpected value for At(%d, %d): got: %v want: %v", st.row, st.col, e, st.orig)
   425  		}
   426  		band.SetBand(st.row, st.col, st.new)
   427  		if e := band.At(st.row, st.col); e != st.new {
   428  			t.Errorf("unexpected value for At(%d, %d) after SetBand(%[1]d, %d, %v): got: %v want: %[3]v", st.row, st.col, st.new, e)
   429  		}
   430  	}
   431  }
   432  
   433  // bandImplicit is an implicit band matrix returning val(i, j)
   434  // for the value at (i, j).
   435  type bandImplicit struct {
   436  	r, c, kl, ku int
   437  	val          func(i, j int) float64
   438  }
   439  
   440  func (b bandImplicit) Dims() (r, c int) {
   441  	return b.r, b.c
   442  }
   443  
   444  func (b bandImplicit) T() Matrix {
   445  	return Transpose{b}
   446  }
   447  
   448  func (b bandImplicit) At(i, j int) float64 {
   449  	if i < 0 || b.r <= i {
   450  		panic("row")
   451  	}
   452  	if j < 0 || b.c <= j {
   453  		panic("col")
   454  	}
   455  	if j < i-b.kl || i+b.ku < j {
   456  		return 0
   457  	}
   458  	return b.val(i, j)
   459  }