github.com/gopherd/gonum@v0.0.4/mat/triband_test.go (about)

     1  // Copyright ©2018 The Gonum Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package mat
     6  
     7  import (
     8  	"fmt"
     9  	"reflect"
    10  	"testing"
    11  
    12  	"github.com/gopherd/gonum/blas"
    13  	"github.com/gopherd/gonum/blas/blas64"
    14  )
    15  
    16  func TestNewTriBand(t *testing.T) {
    17  	t.Parallel()
    18  	for cas, test := range []struct {
    19  		data  []float64
    20  		n, k  int
    21  		kind  TriKind
    22  		mat   *TriBandDense
    23  		dense *Dense
    24  	}{
    25  		{
    26  			data: []float64{1, 2, 3},
    27  			n:    3, k: 0,
    28  			kind: Upper,
    29  			mat: &TriBandDense{
    30  				mat: blas64.TriangularBand{
    31  					Diag: blas.NonUnit,
    32  					Uplo: blas.Upper,
    33  					N:    3, K: 0,
    34  					Data:   []float64{1, 2, 3},
    35  					Stride: 1,
    36  				},
    37  			},
    38  			dense: NewDense(3, 3, []float64{
    39  				1, 0, 0,
    40  				0, 2, 0,
    41  				0, 0, 3,
    42  			}),
    43  		},
    44  		{
    45  			data: []float64{
    46  				1, 2,
    47  				3, 4,
    48  				5, 6,
    49  				7, 8,
    50  				9, 10,
    51  				11, -1,
    52  			},
    53  			n: 6, k: 1,
    54  			kind: Upper,
    55  			mat: &TriBandDense{
    56  				mat: blas64.TriangularBand{
    57  					Diag: blas.NonUnit,
    58  					Uplo: blas.Upper,
    59  					N:    6, K: 1,
    60  					Data: []float64{
    61  						1, 2,
    62  						3, 4,
    63  						5, 6,
    64  						7, 8,
    65  						9, 10,
    66  						11, -1,
    67  					},
    68  					Stride: 2,
    69  				},
    70  			},
    71  			dense: NewDense(6, 6, []float64{
    72  				1, 2, 0, 0, 0, 0,
    73  				0, 3, 4, 0, 0, 0,
    74  				0, 0, 5, 6, 0, 0,
    75  				0, 0, 0, 7, 8, 0,
    76  				0, 0, 0, 0, 9, 10,
    77  				0, 0, 0, 0, 0, 11,
    78  			}),
    79  		},
    80  		{
    81  			data: []float64{
    82  				1, 2, 3,
    83  				4, 5, 6,
    84  				7, 8, 9,
    85  				10, 11, 12,
    86  				13, 14, -1,
    87  				15, -1, -1,
    88  			},
    89  			n: 6, k: 2,
    90  			kind: Upper,
    91  			mat: &TriBandDense{
    92  				mat: blas64.TriangularBand{
    93  					Diag: blas.NonUnit,
    94  					Uplo: blas.Upper,
    95  					N:    6, K: 2,
    96  					Data: []float64{
    97  						1, 2, 3,
    98  						4, 5, 6,
    99  						7, 8, 9,
   100  						10, 11, 12,
   101  						13, 14, -1,
   102  						15, -1, -1,
   103  					},
   104  					Stride: 3,
   105  				},
   106  			},
   107  			dense: NewDense(6, 6, []float64{
   108  				1, 2, 3, 0, 0, 0,
   109  				0, 4, 5, 6, 0, 0,
   110  				0, 0, 7, 8, 9, 0,
   111  				0, 0, 0, 10, 11, 12,
   112  				0, 0, 0, 0, 13, 14,
   113  				0, 0, 0, 0, 0, 15,
   114  			}),
   115  		},
   116  		{
   117  			data: []float64{
   118  				-1, 1,
   119  				2, 3,
   120  				4, 5,
   121  				6, 7,
   122  				8, 9,
   123  				10, 11,
   124  			},
   125  			n: 6, k: 1,
   126  			kind: Lower,
   127  			mat: &TriBandDense{
   128  				mat: blas64.TriangularBand{
   129  					Diag: blas.NonUnit,
   130  					Uplo: blas.Lower,
   131  					N:    6, K: 1,
   132  					Data: []float64{
   133  						-1, 1,
   134  						2, 3,
   135  						4, 5,
   136  						6, 7,
   137  						8, 9,
   138  						10, 11,
   139  					},
   140  					Stride: 2,
   141  				},
   142  			},
   143  			dense: NewDense(6, 6, []float64{
   144  				1, 0, 0, 0, 0, 0,
   145  				2, 3, 0, 0, 0, 0,
   146  				0, 4, 5, 0, 0, 0,
   147  				0, 0, 6, 7, 0, 0,
   148  				0, 0, 0, 8, 9, 0,
   149  				0, 0, 0, 0, 10, 11,
   150  			}),
   151  		},
   152  		{
   153  			data: []float64{
   154  				-1, -1, 1,
   155  				-1, 2, 3,
   156  				4, 5, 6,
   157  				7, 8, 9,
   158  				10, 11, 12,
   159  				13, 14, 15,
   160  			},
   161  			n: 6, k: 2,
   162  			kind: Lower,
   163  			mat: &TriBandDense{
   164  				mat: blas64.TriangularBand{
   165  					Diag: blas.NonUnit,
   166  					Uplo: blas.Lower,
   167  					N:    6, K: 2,
   168  					Data: []float64{
   169  						-1, -1, 1,
   170  						-1, 2, 3,
   171  						4, 5, 6,
   172  						7, 8, 9,
   173  						10, 11, 12,
   174  						13, 14, 15,
   175  					},
   176  					Stride: 3,
   177  				},
   178  			},
   179  			dense: NewDense(6, 6, []float64{
   180  				1, 0, 0, 0, 0, 0,
   181  				2, 3, 0, 0, 0, 0,
   182  				4, 5, 6, 0, 0, 0,
   183  				0, 7, 8, 9, 0, 0,
   184  				0, 0, 10, 11, 12, 0,
   185  				0, 0, 0, 13, 14, 15,
   186  			}),
   187  		},
   188  	} {
   189  		triBand := NewTriBandDense(test.n, test.k, test.kind, test.data)
   190  		r, c := triBand.Dims()
   191  		n, k, kind := triBand.TriBand()
   192  		if n != test.n {
   193  			t.Errorf("unexpected triband size for test %d: got: %d want: %d", cas, n, test.n)
   194  		}
   195  		if k != test.k {
   196  			t.Errorf("unexpected triband bandwidth for test %d: got: %d want: %d", cas, k, test.k)
   197  		}
   198  		if kind != test.kind {
   199  			t.Errorf("unexpected triband bandwidth for test %v: got: %v want: %v", cas, kind, test.kind)
   200  		}
   201  		if r != n {
   202  			t.Errorf("unexpected number of rows for test %d: got: %d want: %d", cas, r, n)
   203  		}
   204  		if c != n {
   205  			t.Errorf("unexpected number of cols for test %d: got: %d want: %d", cas, c, n)
   206  		}
   207  		if !reflect.DeepEqual(triBand, test.mat) {
   208  			t.Errorf("unexpected value via reflect for test %d: got: %v want: %v", cas, triBand, test.mat)
   209  		}
   210  		if !Equal(triBand, test.mat) {
   211  			t.Errorf("unexpected value via mat.Equal for test %d: got: %v want: %v", cas, triBand, test.mat)
   212  		}
   213  		if !Equal(triBand, test.dense) {
   214  			t.Errorf("unexpected value via mat.Equal(band, dense) for test %d:\ngot:\n% v\nwant:\n% v", cas, Formatted(triBand), Formatted(test.dense))
   215  		}
   216  	}
   217  }
   218  
   219  func TestTriBandAtSetUpper(t *testing.T) {
   220  	t.Parallel()
   221  	for _, kind := range []TriKind{Upper, Lower} {
   222  		var band *TriBandDense
   223  		var data []float64
   224  		if kind {
   225  			// 1  2  3  0  0  0
   226  			// 0  4  5  6  0  0
   227  			// 0  0  7  8  9  0
   228  			// 0  0  0 10 11 12
   229  			// 0  0  0  0 13 14
   230  			// 0  0  0  0  0 15
   231  			data = []float64{
   232  				1, 2, 3,
   233  				4, 5, 6,
   234  				7, 8, 9,
   235  				10, 11, 12,
   236  				13, 14, -1,
   237  				15, -1, -1,
   238  			}
   239  			band = NewTriBandDense(6, 2, kind, data)
   240  		} else {
   241  			// 1  0  0  0  0  0
   242  			// 2  3  0  0  0  0
   243  			// 4  5  6  0  0  0
   244  			// 0  7  8  9  0  0
   245  			// 0  0 10 11 12  0
   246  			// 0  0  0 13 14 15
   247  			data = []float64{
   248  				-1, -1, 1,
   249  				-1, 2, 3,
   250  				4, 5, 6,
   251  				7, 8, 9,
   252  				10, 11, 12,
   253  				13, 14, 15,
   254  			}
   255  			band = NewTriBandDense(6, 2, kind, data)
   256  		}
   257  
   258  		rows, cols := band.Dims()
   259  
   260  		// Check At out of bounds.
   261  		for _, row := range []int{-1, rows, rows + 1} {
   262  			panicked, message := panics(func() { band.At(row, 0) })
   263  			if !panicked || message != ErrRowAccess.Error() {
   264  				t.Errorf("expected panic for invalid row access N=%d r=%d", rows, row)
   265  			}
   266  		}
   267  		for _, col := range []int{-1, cols, cols + 1} {
   268  			panicked, message := panics(func() { band.At(0, col) })
   269  			if !panicked || message != ErrColAccess.Error() {
   270  				t.Errorf("expected panic for invalid column access N=%d c=%d", cols, col)
   271  			}
   272  		}
   273  
   274  		// Check Set out of bounds
   275  		// First, check outside the matrix bounds.
   276  		for _, row := range []int{-1, rows, rows + 1} {
   277  			panicked, message := panics(func() { band.SetTriBand(row, 0, 1.2) })
   278  			if !panicked || message != ErrRowAccess.Error() {
   279  				t.Errorf("expected panic for invalid row access N=%d r=%d", rows, row)
   280  			}
   281  		}
   282  		for _, col := range []int{-1, cols, cols + 1} {
   283  			panicked, message := panics(func() { band.SetTriBand(0, col, 1.2) })
   284  			if !panicked || message != ErrColAccess.Error() {
   285  				t.Errorf("expected panic for invalid column access N=%d c=%d", cols, col)
   286  			}
   287  		}
   288  		// Next, check outside the Triangular bounds.
   289  		for _, s := range []struct{ r, c int }{
   290  			{3, 2},
   291  		} {
   292  			if kind == Lower {
   293  				s.r, s.c = s.c, s.r
   294  			}
   295  			panicked, message := panics(func() { band.SetTriBand(s.r, s.c, 1.2) })
   296  			if !panicked || message != ErrTriangleSet.Error() {
   297  				t.Errorf("expected panic for invalid triangular access N=%d, r=%d c=%d", cols, s.r, s.c)
   298  			}
   299  		}
   300  		// Finally, check inside the triangle, but outside the band.
   301  		for _, s := range []struct{ r, c int }{
   302  			{1, 5},
   303  		} {
   304  			if kind == Lower {
   305  				s.r, s.c = s.c, s.r
   306  			}
   307  			panicked, message := panics(func() { band.SetTriBand(s.r, s.c, 1.2) })
   308  			if !panicked || message != ErrBandSet.Error() {
   309  				t.Errorf("expected panic for invalid triangular access N=%d, r=%d c=%d", cols, s.r, s.c)
   310  			}
   311  		}
   312  
   313  		// Test that At and Set work correctly.
   314  		offset := 100.0
   315  		dataCopy := make([]float64, len(data))
   316  		copy(dataCopy, data)
   317  		for i := 0; i < rows; i++ {
   318  			for j := 0; j < rows; j++ {
   319  				v := band.At(i, j)
   320  				if v != 0 {
   321  					band.SetTriBand(i, j, v+offset)
   322  				}
   323  			}
   324  		}
   325  		for i, v := range dataCopy {
   326  			if v == -1 {
   327  				if data[i] != -1 {
   328  					t.Errorf("Set changed unexpected entry. Want %v, got %v", -1, data[i])
   329  				}
   330  			} else {
   331  				if v != data[i]-offset {
   332  					t.Errorf("Set incorrectly changed for %v. got %v, want %v", v, data[i], v+offset)
   333  				}
   334  			}
   335  		}
   336  	}
   337  }
   338  
   339  func TestTriBandDenseZero(t *testing.T) {
   340  	t.Parallel()
   341  	// Elements that equal 1 should be set to zero, elements that equal -1
   342  	// should remain unchanged.
   343  	for _, test := range []*TriBandDense{
   344  		{
   345  			mat: blas64.TriangularBand{
   346  				Uplo:   blas.Upper,
   347  				N:      6,
   348  				K:      2,
   349  				Stride: 5,
   350  				Data: []float64{
   351  					1, 1, 1, -1, -1,
   352  					1, 1, 1, -1, -1,
   353  					1, 1, 1, -1, -1,
   354  					1, 1, 1, -1, -1,
   355  					1, 1, -1, -1, -1,
   356  					1, -1, -1, -1, -1,
   357  				},
   358  			},
   359  		},
   360  		{
   361  			mat: blas64.TriangularBand{
   362  				Uplo:   blas.Lower,
   363  				N:      6,
   364  				K:      2,
   365  				Stride: 5,
   366  				Data: []float64{
   367  					-1, -1, 1, -1, -1,
   368  					-1, 1, 1, -1, -1,
   369  					1, 1, 1, -1, -1,
   370  					1, 1, 1, -1, -1,
   371  					1, 1, 1, -1, -1,
   372  					1, 1, 1, -1, -1,
   373  				},
   374  			},
   375  		},
   376  	} {
   377  		dataCopy := make([]float64, len(test.mat.Data))
   378  		copy(dataCopy, test.mat.Data)
   379  		test.Zero()
   380  		for i, v := range test.mat.Data {
   381  			if dataCopy[i] != -1 && v != 0 {
   382  				t.Errorf("Matrix not zeroed in bounds")
   383  			}
   384  			if dataCopy[i] == -1 && v != -1 {
   385  				t.Errorf("Matrix zeroed out of bounds")
   386  			}
   387  		}
   388  	}
   389  }
   390  
   391  func TestTriBandDiagView(t *testing.T) {
   392  	t.Parallel()
   393  	for cas, test := range []*TriBandDense{
   394  		NewTriBandDense(1, 0, Upper, []float64{1}),
   395  		NewTriBandDense(4, 0, Upper, []float64{1, 2, 3, 4}),
   396  		NewTriBandDense(6, 2, Upper, []float64{
   397  			1, 2, 3,
   398  			4, 5, 6,
   399  			7, 8, 9,
   400  			10, 11, 12,
   401  			13, 14, -1,
   402  			15, -1, -1,
   403  		}),
   404  		NewTriBandDense(1, 0, Lower, []float64{1}),
   405  		NewTriBandDense(4, 0, Lower, []float64{1, 2, 3, 4}),
   406  		NewTriBandDense(6, 2, Lower, []float64{
   407  			-1, -1, 1,
   408  			-1, 2, 3,
   409  			4, 5, 6,
   410  			7, 8, 9,
   411  			10, 11, 12,
   412  			13, 14, 15,
   413  		}),
   414  	} {
   415  		testDiagView(t, cas, test)
   416  	}
   417  }
   418  
   419  func TestTriBandDenseSolveTo(t *testing.T) {
   420  	t.Parallel()
   421  
   422  	const tol = 1e-15
   423  
   424  	for tc, test := range []struct {
   425  		a *TriBandDense
   426  		b *Dense
   427  	}{
   428  		{
   429  			a: NewTriBandDense(5, 2, Upper, []float64{
   430  				-0.34, -0.49, -0.51,
   431  				-0.25, -0.5, 1.03,
   432  				-1.1, 0.3, -0.82,
   433  				1.69, 0.69, -2.22,
   434  				-0.62, 1.22, -0.85,
   435  			}),
   436  			b: NewDense(5, 2, []float64{
   437  				0.44, 1.34,
   438  				0.07, -1.45,
   439  				-0.32, -0.88,
   440  				-0.09, -0.15,
   441  				-1.17, -0.19,
   442  			}),
   443  		},
   444  		{
   445  			a: NewTriBandDense(5, 2, Lower, []float64{
   446  				0, 0, -0.34,
   447  				0, -0.49, -0.25,
   448  				-0.51, -0.5, -1.1,
   449  				1.03, 0.3, 1.69,
   450  				-0.82, 0.69, -0.62,
   451  			}),
   452  			b: NewDense(5, 2, []float64{
   453  				0.44, 1.34,
   454  				0.07, -1.45,
   455  				-0.32, -0.88,
   456  				-0.09, -0.15,
   457  				-1.17, -0.19,
   458  			}),
   459  		},
   460  	} {
   461  		a := test.a
   462  		for _, trans := range []bool{false, true} {
   463  			for _, dstSameAsB := range []bool{false, true} {
   464  				name := fmt.Sprintf("Case %d,trans=%v,dstSameAsB=%v", tc, trans, dstSameAsB)
   465  
   466  				n, nrhs := test.b.Dims()
   467  				var dst Dense
   468  				var err error
   469  				if dstSameAsB {
   470  					dst = *NewDense(n, nrhs, nil)
   471  					dst.Copy(test.b)
   472  					err = a.SolveTo(&dst, trans, &dst)
   473  				} else {
   474  					tmp := NewDense(n, nrhs, nil)
   475  					tmp.Copy(test.b)
   476  					err = a.SolveTo(&dst, trans, asBasicMatrix(tmp))
   477  				}
   478  
   479  				if err != nil {
   480  					t.Fatalf("%v: unexpected error from SolveTo", name)
   481  				}
   482  
   483  				var resid Dense
   484  				if trans {
   485  					resid.Mul(a.T(), &dst)
   486  				} else {
   487  					resid.Mul(a, &dst)
   488  				}
   489  				resid.Sub(&resid, test.b)
   490  				diff := Norm(&resid, 1)
   491  				if diff > tol {
   492  					t.Errorf("%v: unexpected result; diff=%v,want<=%v", name, diff, tol)
   493  				}
   494  			}
   495  		}
   496  	}
   497  }
   498  
   499  func TestTriBandDenseSolveVecTo(t *testing.T) {
   500  	t.Parallel()
   501  
   502  	const tol = 1e-15
   503  
   504  	for tc, test := range []struct {
   505  		a *TriBandDense
   506  		b *VecDense
   507  	}{
   508  		{
   509  			a: NewTriBandDense(5, 2, Upper, []float64{
   510  				-0.34, -0.49, -0.51,
   511  				-0.25, -0.5, 1.03,
   512  				-1.1, 0.3, -0.82,
   513  				1.69, 0.69, -2.22,
   514  				-0.62, 1.22, -0.85,
   515  			}),
   516  			b: NewVecDense(5, []float64{
   517  				0.44,
   518  				0.07,
   519  				-0.32,
   520  				-0.09,
   521  				-1.17,
   522  			}),
   523  		},
   524  		{
   525  			a: NewTriBandDense(5, 2, Lower, []float64{
   526  				0, 0, -0.34,
   527  				0, -0.49, -0.25,
   528  				-0.51, -0.5, -1.1,
   529  				1.03, 0.3, 1.69,
   530  				-0.82, 0.69, -0.62,
   531  			}),
   532  			b: NewVecDense(5, []float64{
   533  				0.44,
   534  				0.07,
   535  				-0.32,
   536  				-0.09,
   537  				-1.17,
   538  			}),
   539  		},
   540  	} {
   541  		a := test.a
   542  		for _, trans := range []bool{false, true} {
   543  			for _, dstSameAsB := range []bool{false, true} {
   544  				name := fmt.Sprintf("Case %d,trans=%v,dstSameAsB=%v", tc, trans, dstSameAsB)
   545  
   546  				n, _ := test.b.Dims()
   547  				var dst VecDense
   548  				var err error
   549  				if dstSameAsB {
   550  					dst = *NewVecDense(n, nil)
   551  					dst.CopyVec(test.b)
   552  					err = a.SolveVecTo(&dst, trans, &dst)
   553  				} else {
   554  					tmp := NewVecDense(n, nil)
   555  					tmp.CopyVec(test.b)
   556  					err = a.SolveVecTo(&dst, trans, asBasicVector(tmp))
   557  				}
   558  
   559  				if err != nil {
   560  					t.Fatalf("%v: unexpected error from SolveVecTo", name)
   561  				}
   562  
   563  				var resid VecDense
   564  				if trans {
   565  					resid.MulVec(a.T(), &dst)
   566  				} else {
   567  					resid.MulVec(a, &dst)
   568  				}
   569  				resid.SubVec(&resid, test.b)
   570  				diff := Norm(&resid, 1)
   571  				if diff > tol {
   572  					t.Errorf("%v: unexpected result; diff=%v,want<=%v", name, diff, tol)
   573  				}
   574  			}
   575  		}
   576  	}
   577  }