github.com/gopherd/gonum@v0.0.4/mat/symband_test.go (about)

     1  // Copyright ©2017 The Gonum Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package mat
     6  
     7  import (
     8  	"reflect"
     9  	"testing"
    10  
    11  	"github.com/gopherd/gonum/blas"
    12  	"github.com/gopherd/gonum/blas/blas64"
    13  )
    14  
    15  func TestNewSymBand(t *testing.T) {
    16  	t.Parallel()
    17  	for i, test := range []struct {
    18  		data  []float64
    19  		n     int
    20  		k     int
    21  		mat   *SymBandDense
    22  		dense *Dense
    23  	}{
    24  		{
    25  			data: []float64{
    26  				1, 2, 3,
    27  				4, 5, 6,
    28  				7, 8, 9,
    29  				10, 11, 12,
    30  				13, 14, -1,
    31  				15, -1, -1,
    32  			},
    33  			n: 6,
    34  			k: 2,
    35  			mat: &SymBandDense{
    36  				mat: blas64.SymmetricBand{
    37  					N:      6,
    38  					K:      2,
    39  					Stride: 3,
    40  					Uplo:   blas.Upper,
    41  					Data: []float64{
    42  						1, 2, 3,
    43  						4, 5, 6,
    44  						7, 8, 9,
    45  						10, 11, 12,
    46  						13, 14, -1,
    47  						15, -1, -1,
    48  					},
    49  				},
    50  			},
    51  			dense: NewDense(6, 6, []float64{
    52  				1, 2, 3, 0, 0, 0,
    53  				2, 4, 5, 6, 0, 0,
    54  				3, 5, 7, 8, 9, 0,
    55  				0, 6, 8, 10, 11, 12,
    56  				0, 0, 9, 11, 13, 14,
    57  				0, 0, 0, 12, 14, 15,
    58  			}),
    59  		},
    60  	} {
    61  		band := NewSymBandDense(test.n, test.k, test.data)
    62  		rows, cols := band.Dims()
    63  
    64  		if rows != test.n {
    65  			t.Errorf("unexpected number of rows for test %d: got: %d want: %d", i, rows, test.n)
    66  		}
    67  		if cols != test.n {
    68  			t.Errorf("unexpected number of cols for test %d: got: %d want: %d", i, cols, test.n)
    69  		}
    70  		if !reflect.DeepEqual(band, test.mat) {
    71  			t.Errorf("unexpected value via reflect for test %d: got: %v want: %v", i, band, test.mat)
    72  		}
    73  		if !Equal(band, test.mat) {
    74  			t.Errorf("unexpected value via mat.Equal for test %d: got: %v want: %v", i, band, test.mat)
    75  		}
    76  		if !Equal(band, test.dense) {
    77  			t.Errorf("unexpected value via mat.Equal(band, dense) for test %d:\ngot:\n% v\nwant:\n% v", i, Formatted(band), Formatted(test.dense))
    78  		}
    79  	}
    80  }
    81  
    82  func TestSymBandAtSet(t *testing.T) {
    83  	t.Parallel()
    84  	// 1  2  3  0  0  0
    85  	// 2  4  5  6  0  0
    86  	// 3  5  7  8  9  0
    87  	// 0  6  8 10 11 12
    88  	// 0  0  9 11 13 14
    89  	// 0  0  0 12 14 16
    90  	band := NewSymBandDense(6, 2, []float64{
    91  		1, 2, 3,
    92  		4, 5, 6,
    93  		7, 8, 9,
    94  		10, 11, 12,
    95  		13, 14, -1,
    96  		16, -1, -1,
    97  	})
    98  
    99  	rows, cols := band.Dims()
   100  	kl, ku := band.Bandwidth()
   101  
   102  	// Explicitly test all indexes.
   103  	want := bandImplicit{rows, cols, kl, ku, func(i, j int) float64 {
   104  		if i > j {
   105  			i, j = j, i
   106  		}
   107  		return float64(i*ku + j + 1)
   108  	}}
   109  	for i := 0; i < 6; i++ {
   110  		for j := 0; j < 6; j++ {
   111  			if band.At(i, j) != want.At(i, j) {
   112  				t.Errorf("unexpected value for band.At(%d, %d): got:%v want:%v", i, j, band.At(i, j), want.At(i, j))
   113  			}
   114  		}
   115  	}
   116  	// Do that same thing via a call to Equal.
   117  	if !Equal(band, want) {
   118  		t.Errorf("unexpected value via mat.Equal:\ngot:\n% v\nwant:\n% v", Formatted(band), Formatted(want))
   119  	}
   120  
   121  	// Check At out of bounds
   122  	for _, row := range []int{-1, rows, rows + 1} {
   123  		panicked, message := panics(func() { band.At(row, 0) })
   124  		if !panicked || message != ErrRowAccess.Error() {
   125  			t.Errorf("expected panic for invalid row access N=%d r=%d", rows, row)
   126  		}
   127  	}
   128  	for _, col := range []int{-1, cols, cols + 1} {
   129  		panicked, message := panics(func() { band.At(0, col) })
   130  		if !panicked || message != ErrColAccess.Error() {
   131  			t.Errorf("expected panic for invalid column access N=%d c=%d", cols, col)
   132  		}
   133  	}
   134  
   135  	// Check Set out of bounds
   136  	for _, row := range []int{-1, rows, rows + 1} {
   137  		panicked, message := panics(func() { band.SetSymBand(row, 0, 1.2) })
   138  		if !panicked || message != ErrRowAccess.Error() {
   139  			t.Errorf("expected panic for invalid row access N=%d r=%d", rows, row)
   140  		}
   141  	}
   142  	for _, col := range []int{-1, cols, cols + 1} {
   143  		panicked, message := panics(func() { band.SetSymBand(0, col, 1.2) })
   144  		if !panicked || message != ErrColAccess.Error() {
   145  			t.Errorf("expected panic for invalid column access N=%d c=%d", cols, col)
   146  		}
   147  	}
   148  
   149  	for _, st := range []struct {
   150  		row, col int
   151  	}{
   152  		{row: 0, col: 3},
   153  		{row: 0, col: 4},
   154  		{row: 0, col: 5},
   155  		{row: 1, col: 4},
   156  		{row: 1, col: 5},
   157  		{row: 2, col: 5},
   158  		{row: 3, col: 0},
   159  		{row: 4, col: 1},
   160  		{row: 5, col: 2},
   161  	} {
   162  		panicked, message := panics(func() { band.SetSymBand(st.row, st.col, 1.2) })
   163  		if !panicked || message != ErrBandSet.Error() {
   164  			t.Errorf("expected panic for %+v %s", st, message)
   165  		}
   166  	}
   167  
   168  	for _, st := range []struct {
   169  		row, col  int
   170  		orig, new float64
   171  	}{
   172  		{row: 1, col: 2, orig: 5, new: 15},
   173  		{row: 2, col: 3, orig: 8, new: 15},
   174  	} {
   175  		if e := band.At(st.row, st.col); e != st.orig {
   176  			t.Errorf("unexpected value for At(%d, %d): got: %v want: %v", st.row, st.col, e, st.orig)
   177  		}
   178  		band.SetSymBand(st.row, st.col, st.new)
   179  		if e := band.At(st.row, st.col); e != st.new {
   180  			t.Errorf("unexpected value for At(%d, %d) after SetSymBand(%[1]d, %d, %v): got: %v want: %[3]v", st.row, st.col, st.new, e)
   181  		}
   182  	}
   183  }
   184  
   185  func TestSymBandDiagView(t *testing.T) {
   186  	t.Parallel()
   187  	for cas, test := range []*SymBandDense{
   188  		NewSymBandDense(1, 0, []float64{1}),
   189  		NewSymBandDense(6, 2, []float64{
   190  			1, 2, 3,
   191  			4, 5, 6,
   192  			7, 8, 9,
   193  			10, 11, 12,
   194  			13, 14, -1,
   195  			16, -1, -1,
   196  		}),
   197  	} {
   198  		testDiagView(t, cas, test)
   199  	}
   200  }
   201  
   202  func TestSymBandDenseZero(t *testing.T) {
   203  	t.Parallel()
   204  	// Elements that equal 1 should be set to zero, elements that equal -1
   205  	// should remain unchanged.
   206  	for _, test := range []*SymBandDense{
   207  		{
   208  			mat: blas64.SymmetricBand{
   209  				Uplo:   blas.Upper,
   210  				N:      6,
   211  				K:      2,
   212  				Stride: 5,
   213  				Data: []float64{
   214  					1, 1, 1, -1, -1,
   215  					1, 1, 1, -1, -1,
   216  					1, 1, 1, -1, -1,
   217  					1, 1, 1, -1, -1,
   218  					1, 1, -1, -1, -1,
   219  					1, -1, -1, -1, -1,
   220  				},
   221  			},
   222  		},
   223  	} {
   224  		dataCopy := make([]float64, len(test.mat.Data))
   225  		copy(dataCopy, test.mat.Data)
   226  		test.Zero()
   227  		for i, v := range test.mat.Data {
   228  			if dataCopy[i] != -1 && v != 0 {
   229  				t.Errorf("Matrix not zeroed in bounds")
   230  			}
   231  			if dataCopy[i] == -1 && v != -1 {
   232  				t.Errorf("Matrix zeroed out of bounds")
   233  			}
   234  		}
   235  	}
   236  }