github.com/gonum/matrix@v0.0.0-20181209220409-c518dec07be9/mat64/triangular_test.go (about)

     1  package mat64
     2  
     3  import (
     4  	"math"
     5  	"math/rand"
     6  	"reflect"
     7  	"testing"
     8  
     9  	"github.com/gonum/blas"
    10  	"github.com/gonum/blas/blas64"
    11  	"github.com/gonum/matrix"
    12  )
    13  
    14  func TestNewTriangular(t *testing.T) {
    15  	for i, test := range []struct {
    16  		data []float64
    17  		n    int
    18  		kind matrix.TriKind
    19  		mat  *TriDense
    20  	}{
    21  		{
    22  			data: []float64{
    23  				1, 2, 3,
    24  				4, 5, 6,
    25  				7, 8, 9,
    26  			},
    27  			n:    3,
    28  			kind: matrix.Upper,
    29  			mat: &TriDense{
    30  				mat: blas64.Triangular{
    31  					N:      3,
    32  					Stride: 3,
    33  					Uplo:   blas.Upper,
    34  					Data:   []float64{1, 2, 3, 4, 5, 6, 7, 8, 9},
    35  					Diag:   blas.NonUnit,
    36  				},
    37  				cap: 3,
    38  			},
    39  		},
    40  	} {
    41  		tri := NewTriDense(test.n, test.kind, test.data)
    42  		rows, cols := tri.Dims()
    43  
    44  		if rows != test.n {
    45  			t.Errorf("unexpected number of rows for test %d: got: %d want: %d", i, rows, test.n)
    46  		}
    47  		if cols != test.n {
    48  			t.Errorf("unexpected number of cols for test %d: got: %d want: %d", i, cols, test.n)
    49  		}
    50  		if !reflect.DeepEqual(tri, test.mat) {
    51  			t.Errorf("unexpected data slice for test %d: got: %v want: %v", i, tri, test.mat)
    52  		}
    53  	}
    54  
    55  	for _, kind := range []matrix.TriKind{matrix.Lower, matrix.Upper} {
    56  		panicked, message := panics(func() { NewTriDense(3, kind, []float64{1, 2}) })
    57  		if !panicked || message != matrix.ErrShape.Error() {
    58  			t.Errorf("expected panic for invalid data slice length for upper=%t", kind)
    59  		}
    60  	}
    61  }
    62  
    63  func TestTriAtSet(t *testing.T) {
    64  	tri := &TriDense{
    65  		mat: blas64.Triangular{
    66  			N:      3,
    67  			Stride: 3,
    68  			Uplo:   blas.Upper,
    69  			Data:   []float64{1, 2, 3, 4, 5, 6, 7, 8, 9},
    70  			Diag:   blas.NonUnit,
    71  		},
    72  		cap: 3,
    73  	}
    74  
    75  	rows, cols := tri.Dims()
    76  
    77  	// Check At out of bounds
    78  	for _, row := range []int{-1, rows, rows + 1} {
    79  		panicked, message := panics(func() { tri.At(row, 0) })
    80  		if !panicked || message != matrix.ErrRowAccess.Error() {
    81  			t.Errorf("expected panic for invalid row access N=%d r=%d", rows, row)
    82  		}
    83  	}
    84  	for _, col := range []int{-1, cols, cols + 1} {
    85  		panicked, message := panics(func() { tri.At(0, col) })
    86  		if !panicked || message != matrix.ErrColAccess.Error() {
    87  			t.Errorf("expected panic for invalid column access N=%d c=%d", cols, col)
    88  		}
    89  	}
    90  
    91  	// Check Set out of bounds
    92  	for _, row := range []int{-1, rows, rows + 1} {
    93  		panicked, message := panics(func() { tri.SetTri(row, 0, 1.2) })
    94  		if !panicked || message != matrix.ErrRowAccess.Error() {
    95  			t.Errorf("expected panic for invalid row access N=%d r=%d", rows, row)
    96  		}
    97  	}
    98  	for _, col := range []int{-1, cols, cols + 1} {
    99  		panicked, message := panics(func() { tri.SetTri(0, col, 1.2) })
   100  		if !panicked || message != matrix.ErrColAccess.Error() {
   101  			t.Errorf("expected panic for invalid column access N=%d c=%d", cols, col)
   102  		}
   103  	}
   104  
   105  	for _, st := range []struct {
   106  		row, col int
   107  		uplo     blas.Uplo
   108  	}{
   109  		{row: 2, col: 1, uplo: blas.Upper},
   110  		{row: 1, col: 2, uplo: blas.Lower},
   111  	} {
   112  		tri.mat.Uplo = st.uplo
   113  		panicked, message := panics(func() { tri.SetTri(st.row, st.col, 1.2) })
   114  		if !panicked || message != matrix.ErrTriangleSet.Error() {
   115  			t.Errorf("expected panic for %+v", st)
   116  		}
   117  	}
   118  
   119  	for _, st := range []struct {
   120  		row, col  int
   121  		uplo      blas.Uplo
   122  		orig, new float64
   123  	}{
   124  		{row: 2, col: 1, uplo: blas.Lower, orig: 8, new: 15},
   125  		{row: 1, col: 2, uplo: blas.Upper, orig: 6, new: 15},
   126  	} {
   127  		tri.mat.Uplo = st.uplo
   128  		if e := tri.At(st.row, st.col); e != st.orig {
   129  			t.Errorf("unexpected value for At(%d, %d): got: %v want: %v", st.row, st.col, e, st.orig)
   130  		}
   131  		tri.SetTri(st.row, st.col, st.new)
   132  		if e := tri.At(st.row, st.col); e != st.new {
   133  			t.Errorf("unexpected value for At(%d, %d) after SetTri(%[1]d, %d, %v): got: %v want: %[3]v", st.row, st.col, st.new, e)
   134  		}
   135  	}
   136  }
   137  
   138  func TestTriDenseCopy(t *testing.T) {
   139  	for i := 0; i < 100; i++ {
   140  		size := rand.Intn(100)
   141  		r, err := randDense(size, 0.9, rand.NormFloat64)
   142  		if size == 0 {
   143  			if err != matrix.ErrZeroLength {
   144  				t.Fatalf("expected error %v: got: %v", matrix.ErrZeroLength, err)
   145  			}
   146  			continue
   147  		}
   148  		if err != nil {
   149  			t.Fatalf("unexpected error: %v", err)
   150  		}
   151  
   152  		u := NewTriDense(size, true, nil)
   153  		l := NewTriDense(size, false, nil)
   154  
   155  		for _, typ := range []Matrix{r, (*basicMatrix)(r)} {
   156  			for j := range u.mat.Data {
   157  				u.mat.Data[j] = math.NaN()
   158  				l.mat.Data[j] = math.NaN()
   159  			}
   160  			u.Copy(typ)
   161  			l.Copy(typ)
   162  			for m := 0; m < size; m++ {
   163  				for n := 0; n < size; n++ {
   164  					want := typ.At(m, n)
   165  					switch {
   166  					case m < n: // Upper triangular matrix.
   167  						if got := u.At(m, n); got != want {
   168  							t.Errorf("unexpected upper value for At(%d, %d) for test %d: got: %v want: %v", m, n, i, got, want)
   169  						}
   170  					case m == n: // Diagonal matrix.
   171  						if got := u.At(m, n); got != want {
   172  							t.Errorf("unexpected upper value for At(%d, %d) for test %d: got: %v want: %v", m, n, i, got, want)
   173  						}
   174  						if got := l.At(m, n); got != want {
   175  							t.Errorf("unexpected diagonal value for At(%d, %d) for test %d: got: %v want: %v", m, n, i, got, want)
   176  						}
   177  					case m < n: // Lower triangular matrix.
   178  						if got := l.At(m, n); got != want {
   179  							t.Errorf("unexpected lower value for At(%d, %d) for test %d: got: %v want: %v", m, n, i, got, want)
   180  						}
   181  					}
   182  				}
   183  			}
   184  		}
   185  	}
   186  }
   187  
   188  func TestTriTriDenseCopy(t *testing.T) {
   189  	for i := 0; i < 100; i++ {
   190  		size := rand.Intn(100)
   191  		r, err := randDense(size, 1, rand.NormFloat64)
   192  		if size == 0 {
   193  			if err != matrix.ErrZeroLength {
   194  				t.Fatalf("expected error %v: got: %v", matrix.ErrZeroLength, err)
   195  			}
   196  			continue
   197  		}
   198  		if err != nil {
   199  			t.Fatalf("unexpected error: %v", err)
   200  		}
   201  
   202  		ur := NewTriDense(size, true, nil)
   203  		lr := NewTriDense(size, false, nil)
   204  
   205  		ur.Copy(r)
   206  		lr.Copy(r)
   207  
   208  		u := NewTriDense(size, true, nil)
   209  		u.Copy(ur)
   210  		if !equal(u, ur) {
   211  			t.Fatal("unexpected result for U triangle copy of U triangle: not equal")
   212  		}
   213  
   214  		l := NewTriDense(size, false, nil)
   215  		l.Copy(lr)
   216  		if !equal(l, lr) {
   217  			t.Fatal("unexpected result for L triangle copy of L triangle: not equal")
   218  		}
   219  
   220  		zero(u.mat.Data)
   221  		u.Copy(lr)
   222  		if !isDiagonal(u) {
   223  			t.Fatal("unexpected result for U triangle copy of L triangle: off diagonal non-zero element")
   224  		}
   225  		if !equalDiagonal(u, lr) {
   226  			t.Fatal("unexpected result for U triangle copy of L triangle: diagonal not equal")
   227  		}
   228  
   229  		zero(l.mat.Data)
   230  		l.Copy(ur)
   231  		if !isDiagonal(l) {
   232  			t.Fatal("unexpected result for L triangle copy of U triangle: off diagonal non-zero element")
   233  		}
   234  		if !equalDiagonal(l, ur) {
   235  			t.Fatal("unexpected result for L triangle copy of U triangle: diagonal not equal")
   236  		}
   237  	}
   238  }
   239  
   240  func TestTriInverse(t *testing.T) {
   241  	for _, kind := range []matrix.TriKind{matrix.Upper, matrix.Lower} {
   242  		for _, n := range []int{1, 3, 5, 9} {
   243  			data := make([]float64, n*n)
   244  			for i := range data {
   245  				data[i] = rand.NormFloat64()
   246  			}
   247  			a := NewTriDense(n, kind, data)
   248  			var tr TriDense
   249  			err := tr.InverseTri(a)
   250  			if err != nil {
   251  				t.Errorf("Bad test: %s", err)
   252  			}
   253  			var d Dense
   254  			d.Mul(a, &tr)
   255  			if !equalApprox(eye(n), &d, 1e-8, false) {
   256  				var diff Dense
   257  				diff.Sub(eye(n), &d)
   258  				t.Errorf("Tri times inverse is not identity. Norm of difference: %v", Norm(&diff, 2))
   259  			}
   260  		}
   261  	}
   262  }
   263  
   264  func TestTriMul(t *testing.T) {
   265  	method := func(receiver, a, b Matrix) {
   266  		type MulTrier interface {
   267  			MulTri(a, b Triangular)
   268  		}
   269  		receiver.(MulTrier).MulTri(a.(Triangular), b.(Triangular))
   270  	}
   271  	denseComparison := func(receiver, a, b *Dense) {
   272  		receiver.Mul(a, b)
   273  	}
   274  	legalSizeTriMul := func(ar, ac, br, bc int) bool {
   275  		// Need both to be square and the sizes to be the same
   276  		return ar == ac && br == bc && ar == br
   277  	}
   278  
   279  	// The legal types are triangles with the same TriKind.
   280  	// legalTypesTri returns whether both input arguments are Triangular.
   281  	legalTypes := func(a, b Matrix) bool {
   282  		at, ok := a.(Triangular)
   283  		if !ok {
   284  			return false
   285  		}
   286  		bt, ok := b.(Triangular)
   287  		if !ok {
   288  			return false
   289  		}
   290  		_, ak := at.Triangle()
   291  		_, bk := bt.Triangle()
   292  		return ak == bk
   293  	}
   294  	legalTypesLower := func(a, b Matrix) bool {
   295  		legal := legalTypes(a, b)
   296  		if !legal {
   297  			return false
   298  		}
   299  		_, kind := a.(Triangular).Triangle()
   300  		r := kind == matrix.Lower
   301  		return r
   302  	}
   303  	receiver := NewTriDense(3, matrix.Lower, nil)
   304  	testTwoInput(t, "TriMul", receiver, method, denseComparison, legalTypesLower, legalSizeTriMul, 1e-14)
   305  
   306  	legalTypesUpper := func(a, b Matrix) bool {
   307  		legal := legalTypes(a, b)
   308  		if !legal {
   309  			return false
   310  		}
   311  		_, kind := a.(Triangular).Triangle()
   312  		r := kind == matrix.Upper
   313  		return r
   314  	}
   315  	receiver = NewTriDense(3, matrix.Upper, nil)
   316  	testTwoInput(t, "TriMul", receiver, method, denseComparison, legalTypesUpper, legalSizeTriMul, 1e-14)
   317  }