github.com/gopherd/gonum@v0.0.4/mat/tridiag_test.go (about)

     1  // Copyright ©2021 The Gonum Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package mat
     6  
     7  import (
     8  	"fmt"
     9  	"reflect"
    10  	"testing"
    11  
    12  	"math/rand"
    13  
    14  	"github.com/gopherd/gonum/lapack/lapack64"
    15  )
    16  
    17  func TestNewTridiag(t *testing.T) {
    18  	for i, test := range []struct {
    19  		n         int
    20  		dl, d, du []float64
    21  		panics    bool
    22  		want      *Tridiag
    23  		dense     *Dense
    24  	}{
    25  		{
    26  			n:      1,
    27  			dl:     nil,
    28  			d:      []float64{1.2},
    29  			du:     nil,
    30  			panics: false,
    31  			want: &Tridiag{
    32  				mat: lapack64.Tridiagonal{
    33  					N:  1,
    34  					DL: nil,
    35  					D:  []float64{1.2},
    36  					DU: nil,
    37  				},
    38  			},
    39  			dense: NewDense(1, 1, []float64{1.2}),
    40  		},
    41  		{
    42  			n:      1,
    43  			dl:     []float64{},
    44  			d:      []float64{1.2},
    45  			du:     []float64{},
    46  			panics: false,
    47  			want: &Tridiag{
    48  				mat: lapack64.Tridiagonal{
    49  					N:  1,
    50  					DL: []float64{},
    51  					D:  []float64{1.2},
    52  					DU: []float64{},
    53  				},
    54  			},
    55  			dense: NewDense(1, 1, []float64{1.2}),
    56  		},
    57  		{
    58  			n:      4,
    59  			dl:     []float64{1.2, 2.3, 3.4},
    60  			d:      []float64{4.5, 5.6, 6.7, 7.8},
    61  			du:     []float64{8.9, 9.0, 0.1},
    62  			panics: false,
    63  			want: &Tridiag{
    64  				mat: lapack64.Tridiagonal{
    65  					N:  4,
    66  					DL: []float64{1.2, 2.3, 3.4},
    67  					D:  []float64{4.5, 5.6, 6.7, 7.8},
    68  					DU: []float64{8.9, 9.0, 0.1},
    69  				},
    70  			},
    71  			dense: NewDense(4, 4, []float64{
    72  				4.5, 8.9, 0, 0,
    73  				1.2, 5.6, 9.0, 0,
    74  				0, 2.3, 6.7, 0.1,
    75  				0, 0, 3.4, 7.8,
    76  			}),
    77  		},
    78  		{
    79  			n:      4,
    80  			dl:     nil,
    81  			d:      nil,
    82  			du:     nil,
    83  			panics: false,
    84  			want: &Tridiag{
    85  				mat: lapack64.Tridiagonal{
    86  					N:  4,
    87  					DL: []float64{0, 0, 0},
    88  					D:  []float64{0, 0, 0, 0},
    89  					DU: []float64{0, 0, 0},
    90  				},
    91  			},
    92  			dense: NewDense(4, 4, nil),
    93  		},
    94  		{
    95  			n:      -1,
    96  			panics: true,
    97  		},
    98  		{
    99  			n:      0,
   100  			panics: true,
   101  		},
   102  		{
   103  			n:      1,
   104  			dl:     []float64{1.2},
   105  			d:      nil,
   106  			du:     nil,
   107  			panics: true,
   108  		},
   109  		{
   110  			n:      1,
   111  			dl:     nil,
   112  			d:      []float64{1.2, 2.3},
   113  			du:     nil,
   114  			panics: true,
   115  		},
   116  		{
   117  			n:      1,
   118  			dl:     []float64{},
   119  			d:      nil,
   120  			du:     []float64{},
   121  			panics: true,
   122  		},
   123  		{
   124  			n:      4,
   125  			dl:     []float64{1.2},
   126  			d:      nil,
   127  			du:     nil,
   128  			panics: true,
   129  		},
   130  		{
   131  			n:      4,
   132  			dl:     []float64{1.2, 2.3, 3.4},
   133  			d:      []float64{4.5, 5.6, 6.7, 7.8, 1.2},
   134  			du:     []float64{8.9, 9.0, 0.1},
   135  			panics: true,
   136  		},
   137  	} {
   138  		var a *Tridiag
   139  		panicked, msg := panics(func() {
   140  			a = NewTridiag(test.n, test.dl, test.d, test.du)
   141  		})
   142  		if panicked {
   143  			if !test.panics {
   144  				t.Errorf("Case %d: unexpected panic: %s", i, msg)
   145  			}
   146  			continue
   147  		}
   148  		if test.panics {
   149  			t.Errorf("Case %d: expected panic", i)
   150  			continue
   151  		}
   152  
   153  		r, c := a.Dims()
   154  		if r != test.n {
   155  			t.Errorf("Case %d: unexpected number of rows: got=%d want=%d", i, r, test.n)
   156  		}
   157  		if c != test.n {
   158  			t.Errorf("Case %d: unexpected number of columns: got=%d want=%d", i, c, test.n)
   159  		}
   160  
   161  		kl, ku := a.Bandwidth()
   162  		if kl != 1 || ku != 1 {
   163  			t.Errorf("Case %d: unexpected bandwidth: got=%d,%d want=1,1", i, kl, ku)
   164  		}
   165  
   166  		if !reflect.DeepEqual(a, test.want) {
   167  			t.Errorf("Case %d: unexpected value via reflect: got=%v, want=%v", i, a, test.want)
   168  		}
   169  		if !Equal(a, test.want) {
   170  			t.Errorf("Case %d: unexpected value via mat.Equal: got=%v, want=%v", i, a, test.want)
   171  		}
   172  		if !Equal(a, test.dense) {
   173  			t.Errorf("Case %d: unexpected value via mat.Equal(Tridiag,Dense):\ngot:\n% v\nwant:\n% v", i, Formatted(a), Formatted(test.dense))
   174  		}
   175  	}
   176  }
   177  
   178  func TestTridiagAtSet(t *testing.T) {
   179  	t.Parallel()
   180  	for _, n := range []int{1, 2, 3, 4, 7, 10} {
   181  		tri, ref := newTestTridiag(n)
   182  
   183  		name := fmt.Sprintf("Case n=%v", n)
   184  
   185  		// Check At explicitly with all valid indices.
   186  		for i := 0; i < n; i++ {
   187  			for j := 0; j < n; j++ {
   188  				if tri.At(i, j) != ref.At(i, j) {
   189  					t.Errorf("%v: unexpected value for At(%d,%d): got %v, want %v",
   190  						name, i, j, tri.At(i, j), ref.At(i, j))
   191  				}
   192  			}
   193  		}
   194  		// Check At via a call to Equal.
   195  		if !Equal(tri, ref) {
   196  			t.Errorf("%v: unexpected value:\ngot: % v\nwant:% v",
   197  				name, Formatted(tri, Prefix("     ")), Formatted(ref, Prefix("     ")))
   198  		}
   199  
   200  		// Check At out of bounds.
   201  		for _, i := range []int{-1, n, n + 1} {
   202  			for j := 0; j < n; j++ {
   203  				panicked, message := panics(func() { tri.At(i, j) })
   204  				if !panicked || message != ErrRowAccess.Error() {
   205  					t.Errorf("%v: expected panic for invalid row access at (%d,%d)", name, i, j)
   206  				}
   207  			}
   208  		}
   209  		for _, j := range []int{-1, n, n + 1} {
   210  			for i := 0; i < n; i++ {
   211  				panicked, message := panics(func() { tri.At(i, j) })
   212  				if !panicked || message != ErrColAccess.Error() {
   213  					t.Errorf("%v: expected panic for invalid column access at (%d,%d)", name, i, j)
   214  				}
   215  			}
   216  		}
   217  
   218  		// Check SetBand out of bounds.
   219  		for _, i := range []int{-1, n, n + 1} {
   220  			for j := 0; j < n; j++ {
   221  				panicked, message := panics(func() { tri.SetBand(i, j, 1.2) })
   222  				if !panicked || message != ErrRowAccess.Error() {
   223  					t.Errorf("%v: expected panic for invalid row access at (%d,%d)", name, i, j)
   224  				}
   225  			}
   226  		}
   227  		for _, j := range []int{-1, n, n + 1} {
   228  			for i := 0; i < n; i++ {
   229  				panicked, message := panics(func() { tri.SetBand(i, j, 1.2) })
   230  				if !panicked || message != ErrColAccess.Error() {
   231  					t.Errorf("%v: expected panic for invalid column access at (%d,%d)", name, i, j)
   232  				}
   233  			}
   234  		}
   235  		for i := 0; i < n; i++ {
   236  			for j := 0; j <= i-2; j++ {
   237  				panicked, message := panics(func() { tri.SetBand(i, j, 1.2) })
   238  				if !panicked || message != ErrBandSet.Error() {
   239  					t.Errorf("%v: expected panic for invalid access at (%d,%d)", name, i, j)
   240  				}
   241  			}
   242  			for j := i + 2; j < n; j++ {
   243  				panicked, message := panics(func() { tri.SetBand(i, j, 1.2) })
   244  				if !panicked || message != ErrBandSet.Error() {
   245  					t.Errorf("%v: expected panic for invalid access at (%d,%d)", name, i, j)
   246  				}
   247  			}
   248  		}
   249  
   250  		// Check SetBand within bandwidth.
   251  		for i := 0; i < n; i++ {
   252  			for j := max(0, i-1); j <= min(i+1, n-1); j++ {
   253  				want := float64(i*n + j + 100)
   254  				tri.SetBand(i, j, want)
   255  				if got := tri.At(i, j); got != want {
   256  					t.Errorf("%v: unexpected value at (%d,%d) after SetBand: got %v, want %v", name, i, j, got, want)
   257  				}
   258  			}
   259  		}
   260  	}
   261  }
   262  
   263  func newTestTridiag(n int) (*Tridiag, *Dense) {
   264  	var dl, d, du []float64
   265  	d = make([]float64, n)
   266  	if n > 1 {
   267  		dl = make([]float64, n-1)
   268  		du = make([]float64, n-1)
   269  	}
   270  	for i := range d {
   271  		d[i] = float64(i*n + i + 1)
   272  	}
   273  	for j := range dl {
   274  		i := j + 1
   275  		dl[j] = float64(i*n + j + 1)
   276  	}
   277  	for i := range du {
   278  		j := i + 1
   279  		du[i] = float64(i*n + j + 1)
   280  	}
   281  	dense := make([]float64, n*n)
   282  	for i := 0; i < n; i++ {
   283  		for j := max(0, i-1); j <= min(i+1, n-1); j++ {
   284  			dense[i*n+j] = float64(i*n + j + 1)
   285  		}
   286  	}
   287  	return NewTridiag(n, dl, d, du), NewDense(n, n, dense)
   288  }
   289  
   290  func TestTridiagReset(t *testing.T) {
   291  	t.Parallel()
   292  	for _, n := range []int{1, 2, 3, 4, 7, 10} {
   293  		a, _ := newTestTridiag(n)
   294  		if a.IsEmpty() {
   295  			t.Errorf("Case n=%d: matrix is empty", n)
   296  		}
   297  		a.Reset()
   298  		if !a.IsEmpty() {
   299  			t.Errorf("Case n=%d: matrix is not empty after Reset", n)
   300  		}
   301  	}
   302  }
   303  
   304  func TestTridiagDiagView(t *testing.T) {
   305  	t.Parallel()
   306  	for _, n := range []int{1, 2, 3, 4, 7, 10} {
   307  		a, _ := newTestTridiag(n)
   308  		testDiagView(t, n, a)
   309  	}
   310  }
   311  
   312  func TestTridiagZero(t *testing.T) {
   313  	t.Parallel()
   314  	for _, n := range []int{1, 2, 3, 4, 7, 10} {
   315  		a, _ := newTestTridiag(n)
   316  		a.Zero()
   317  		for i := 0; i < n; i++ {
   318  			for j := 0; j < n; j++ {
   319  				if a.At(i, j) != 0 {
   320  					t.Errorf("Case n=%d: unexpected non-zero at (%d,%d): got %f", n, i, j, a.At(i, j))
   321  				}
   322  			}
   323  		}
   324  	}
   325  }
   326  
   327  func TestTridiagSolveTo(t *testing.T) {
   328  	t.Parallel()
   329  
   330  	const tol = 1e-13
   331  
   332  	rnd := rand.New(rand.NewSource(1))
   333  	random := func(n int) []float64 {
   334  		d := make([]float64, n)
   335  		for i := range d {
   336  			d[i] = rnd.NormFloat64()
   337  		}
   338  		return d
   339  	}
   340  
   341  	for _, n := range []int{1, 2, 3, 4, 7, 10} {
   342  		a := NewTridiag(n, random(n-1), random(n), random(n-1))
   343  		var aDense Dense
   344  		aDense.CloneFrom(a)
   345  		for _, trans := range []bool{false, true} {
   346  			for _, nrhs := range []int{1, 2, 5} {
   347  				const (
   348  					denseB = iota
   349  					rawB
   350  					basicB
   351  				)
   352  				for _, bType := range []int{denseB, rawB, basicB} {
   353  					const (
   354  						emptyDst = iota
   355  						shapedDst
   356  						bIsDst
   357  					)
   358  					for _, dstType := range []int{emptyDst, shapedDst, bIsDst} {
   359  						if dstType == bIsDst && bType != denseB {
   360  							continue
   361  						}
   362  
   363  						var b Matrix
   364  						switch bType {
   365  						case denseB:
   366  							b = NewDense(n, nrhs, random(n*nrhs))
   367  						case rawB:
   368  							b = &rawMatrix{asBasicMatrix(NewDense(n, nrhs, random(n*nrhs)))}
   369  						case basicB:
   370  							b = asBasicMatrix(NewDense(n, nrhs, random(n*nrhs)))
   371  						default:
   372  							panic("bad bType")
   373  						}
   374  
   375  						var dst *Dense
   376  						switch dstType {
   377  						case emptyDst:
   378  							dst = new(Dense)
   379  						case shapedDst:
   380  							dst = NewDense(n, nrhs, random(n*nrhs))
   381  						case bIsDst:
   382  							dst = b.(*Dense)
   383  						default:
   384  							panic("bad dstType")
   385  						}
   386  
   387  						name := fmt.Sprintf("n=%d,nrhs=%d,trans=%t,dstType=%d,bType=%d", n, nrhs, trans, dstType, bType)
   388  
   389  						var want Dense
   390  						var err error
   391  						if !trans {
   392  							err = want.Solve(&aDense, b)
   393  						} else {
   394  							err = want.Solve(aDense.T(), b)
   395  						}
   396  						if err != nil {
   397  							t.Fatalf("%v: unexpected failure when computing reference solution: %v", name, err)
   398  						}
   399  
   400  						err = a.SolveTo(dst, trans, b)
   401  						if err != nil {
   402  							t.Fatalf("%v: unexpected failure from Tridiag.SolveTo: %v", name, err)
   403  						}
   404  
   405  						var diff Dense
   406  						diff.Sub(dst, &want)
   407  						if resid := Norm(&diff, 1); resid > tol*float64(n) {
   408  							t.Errorf("%v: unexpected result; resid=%v, want<=%v", name, resid, tol*float64(n))
   409  						}
   410  					}
   411  				}
   412  			}
   413  		}
   414  	}
   415  }
   416  
   417  func TestTridiagSolveVecTo(t *testing.T) {
   418  	t.Parallel()
   419  
   420  	const tol = 1e-13
   421  
   422  	rnd := rand.New(rand.NewSource(1))
   423  	random := func(n int) []float64 {
   424  		d := make([]float64, n)
   425  		for i := range d {
   426  			d[i] = rnd.NormFloat64()
   427  		}
   428  		return d
   429  	}
   430  
   431  	for _, n := range []int{1, 2, 3, 4, 7, 10} {
   432  		a := NewTridiag(n, random(n-1), random(n), random(n-1))
   433  		var aDense Dense
   434  		aDense.CloneFrom(a)
   435  		for _, trans := range []bool{false, true} {
   436  			const (
   437  				denseB = iota
   438  				rawB
   439  				basicB
   440  			)
   441  			for _, bType := range []int{denseB, rawB, basicB} {
   442  				const (
   443  					emptyDst = iota
   444  					shapedDst
   445  					bIsDst
   446  				)
   447  				for _, dstType := range []int{emptyDst, shapedDst, bIsDst} {
   448  					if dstType == bIsDst && bType != denseB {
   449  						continue
   450  					}
   451  
   452  					var b Vector
   453  					switch bType {
   454  					case denseB:
   455  						b = NewVecDense(n, random(n))
   456  					case rawB:
   457  						b = &rawVector{asBasicVector(NewVecDense(n, random(n)))}
   458  					case basicB:
   459  						b = asBasicVector(NewVecDense(n, random(n)))
   460  					default:
   461  						panic("bad bType")
   462  					}
   463  
   464  					var dst *VecDense
   465  					switch dstType {
   466  					case emptyDst:
   467  						dst = new(VecDense)
   468  					case shapedDst:
   469  						dst = NewVecDense(n, random(n))
   470  					case bIsDst:
   471  						dst = b.(*VecDense)
   472  					default:
   473  						panic("bad dstType")
   474  					}
   475  
   476  					name := fmt.Sprintf("n=%d,trans=%t,dstType=%d,bType=%d", n, trans, dstType, bType)
   477  
   478  					var want VecDense
   479  					var err error
   480  					if !trans {
   481  						err = want.SolveVec(&aDense, b)
   482  					} else {
   483  						err = want.SolveVec(aDense.T(), b)
   484  					}
   485  					if err != nil {
   486  						t.Fatalf("%v: unexpected failure when computing reference solution: %v", name, err)
   487  					}
   488  
   489  					err = a.SolveVecTo(dst, trans, b)
   490  					if err != nil {
   491  						t.Fatalf("%v: unexpected failure from Tridiag.SolveTo: %v", name, err)
   492  					}
   493  
   494  					var diff Dense
   495  					diff.Sub(dst, &want)
   496  					if resid := Norm(&diff, 1); resid > tol*float64(n) {
   497  						t.Errorf("%v: unexpected result; resid=%v, want<=%v", name, resid, tol*float64(n))
   498  					}
   499  				}
   500  			}
   501  		}
   502  	}
   503  }