gonum.org/v1/gonum@v0.14.0/mat/lu_test.go (about)

     1  // Copyright ©2013 The Gonum Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package mat
     6  
     7  import (
     8  	"testing"
     9  
    10  	"golang.org/x/exp/rand"
    11  )
    12  
    13  func TestLUD(t *testing.T) {
    14  	t.Parallel()
    15  	rnd := rand.New(rand.NewSource(1))
    16  	for _, n := range []int{1, 5, 10, 11, 50} {
    17  		a := NewDense(n, n, nil)
    18  		for i := 0; i < n; i++ {
    19  			for j := 0; j < n; j++ {
    20  				a.Set(i, j, rnd.NormFloat64())
    21  			}
    22  		}
    23  		var want Dense
    24  		want.CloneFrom(a)
    25  
    26  		var lu LU
    27  		lu.Factorize(a)
    28  
    29  		var l, u TriDense
    30  		lu.LTo(&l)
    31  		lu.UTo(&u)
    32  		var p Dense
    33  		pivot := lu.Pivot(nil)
    34  		p.Permutation(n, pivot)
    35  		var got Dense
    36  		got.Product(&p, &l, &u)
    37  		if !EqualApprox(&got, &want, 1e-12) {
    38  			t.Errorf("PLU does not equal original matrix.\nWant: %v\n Got: %v", want, got)
    39  		}
    40  	}
    41  }
    42  
    43  func TestLURankOne(t *testing.T) {
    44  	t.Parallel()
    45  	rnd := rand.New(rand.NewSource(1))
    46  	for _, pivoting := range []bool{true} {
    47  		for _, n := range []int{3, 10, 50} {
    48  			// Construct a random LU factorization
    49  			lu := &LU{}
    50  			lu.lu = NewDense(n, n, nil)
    51  			for i := 0; i < n; i++ {
    52  				for j := 0; j < n; j++ {
    53  					lu.lu.Set(i, j, rnd.Float64())
    54  				}
    55  			}
    56  			lu.pivot = make([]int, n)
    57  			for i := range lu.pivot {
    58  				lu.pivot[i] = i
    59  			}
    60  			if pivoting {
    61  				// For each row, randomly swap with itself or a row after (like is done)
    62  				// in the actual LU factorization.
    63  				for i := range lu.pivot {
    64  					idx := i + rnd.Intn(n-i)
    65  					lu.pivot[i], lu.pivot[idx] = lu.pivot[idx], lu.pivot[i]
    66  				}
    67  			}
    68  			// Apply a rank one update. Ensure the update magnitude is larger than
    69  			// the equal tolerance.
    70  			alpha := rnd.Float64() + 1
    71  			x := NewVecDense(n, nil)
    72  			y := NewVecDense(n, nil)
    73  			for i := 0; i < n; i++ {
    74  				x.setVec(i, rnd.Float64()+1)
    75  				y.setVec(i, rnd.Float64()+1)
    76  			}
    77  			a := luReconstruct(lu)
    78  			a.RankOne(a, alpha, x, y)
    79  
    80  			var luNew LU
    81  			luNew.RankOne(lu, alpha, x, y)
    82  			lu.RankOne(lu, alpha, x, y)
    83  
    84  			aR1New := luReconstruct(&luNew)
    85  			aR1 := luReconstruct(lu)
    86  
    87  			if !Equal(aR1, aR1New) {
    88  				t.Error("Different answer when new receiver")
    89  			}
    90  			if !EqualApprox(aR1, a, 1e-10) {
    91  				t.Errorf("Rank one mismatch, pivot %v.\nWant: %v\nGot:%v\n", pivoting, a, aR1)
    92  			}
    93  		}
    94  	}
    95  }
    96  
    97  // luReconstruct reconstructs the original A matrix from an LU decomposition.
    98  func luReconstruct(lu *LU) *Dense {
    99  	var L, U TriDense
   100  	lu.LTo(&L)
   101  	lu.UTo(&U)
   102  	var P Dense
   103  	pivot := lu.Pivot(nil)
   104  	P.Permutation(len(pivot), pivot)
   105  
   106  	var a Dense
   107  	a.Mul(&L, &U)
   108  	a.Mul(&P, &a)
   109  	return &a
   110  }
   111  
   112  func TestLUSolveTo(t *testing.T) {
   113  	t.Parallel()
   114  	rnd := rand.New(rand.NewSource(1))
   115  	for _, test := range []struct {
   116  		n, bc int
   117  	}{
   118  		{5, 5},
   119  		{5, 10},
   120  		{10, 5},
   121  	} {
   122  		n := test.n
   123  		bc := test.bc
   124  		a := NewDense(n, n, nil)
   125  		for i := 0; i < n; i++ {
   126  			for j := 0; j < n; j++ {
   127  				a.Set(i, j, rnd.NormFloat64())
   128  			}
   129  		}
   130  		b := NewDense(n, bc, nil)
   131  		for i := 0; i < n; i++ {
   132  			for j := 0; j < bc; j++ {
   133  				b.Set(i, j, rnd.NormFloat64())
   134  			}
   135  		}
   136  		var lu LU
   137  		lu.Factorize(a)
   138  		var x Dense
   139  		if err := lu.SolveTo(&x, false, b); err != nil {
   140  			continue
   141  		}
   142  		var got Dense
   143  		got.Mul(a, &x)
   144  		if !EqualApprox(&got, b, 1e-12) {
   145  			t.Errorf("SolveTo mismatch for non-singular matrix. n = %v, bc = %v.\nWant: %v\nGot: %v", n, bc, b, got)
   146  		}
   147  	}
   148  	// TODO(btracey): Add testOneInput test when such a function exists.
   149  }
   150  
   151  func TestLUSolveToCond(t *testing.T) {
   152  	t.Parallel()
   153  	for _, test := range []*Dense{
   154  		NewDense(2, 2, []float64{1, 0, 0, 1e-20}),
   155  	} {
   156  		m, _ := test.Dims()
   157  		var lu LU
   158  		lu.Factorize(test)
   159  		b := NewDense(m, 2, nil)
   160  		var x Dense
   161  		if err := lu.SolveTo(&x, false, b); err == nil {
   162  			t.Error("No error for near-singular matrix in matrix solve.")
   163  		}
   164  
   165  		bvec := NewVecDense(m, nil)
   166  		var xvec VecDense
   167  		if err := lu.SolveVecTo(&xvec, false, bvec); err == nil {
   168  			t.Error("No error for near-singular matrix in matrix solve.")
   169  		}
   170  	}
   171  }
   172  
   173  func TestLUSolveVecTo(t *testing.T) {
   174  	t.Parallel()
   175  	rnd := rand.New(rand.NewSource(1))
   176  	for _, n := range []int{5, 10} {
   177  		a := NewDense(n, n, nil)
   178  		for i := 0; i < n; i++ {
   179  			for j := 0; j < n; j++ {
   180  				a.Set(i, j, rnd.NormFloat64())
   181  			}
   182  		}
   183  		b := NewVecDense(n, nil)
   184  		for i := 0; i < n; i++ {
   185  			b.SetVec(i, rnd.NormFloat64())
   186  		}
   187  		var lu LU
   188  		lu.Factorize(a)
   189  		var x VecDense
   190  		if err := lu.SolveVecTo(&x, false, b); err != nil {
   191  			continue
   192  		}
   193  		var got VecDense
   194  		got.MulVec(a, &x)
   195  		if !EqualApprox(&got, b, 1e-12) {
   196  			t.Errorf("SolveTo mismatch n = %v.\nWant: %v\nGot: %v", n, b, got)
   197  		}
   198  	}
   199  	// TODO(btracey): Add testOneInput test when such a function exists.
   200  }