gonum.org/v1/gonum@v0.14.0/mat/lu_test.go (about) 1 // Copyright ©2013 The Gonum Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style 3 // license that can be found in the LICENSE file. 4 5 package mat 6 7 import ( 8 "testing" 9 10 "golang.org/x/exp/rand" 11 ) 12 13 func TestLUD(t *testing.T) { 14 t.Parallel() 15 rnd := rand.New(rand.NewSource(1)) 16 for _, n := range []int{1, 5, 10, 11, 50} { 17 a := NewDense(n, n, nil) 18 for i := 0; i < n; i++ { 19 for j := 0; j < n; j++ { 20 a.Set(i, j, rnd.NormFloat64()) 21 } 22 } 23 var want Dense 24 want.CloneFrom(a) 25 26 var lu LU 27 lu.Factorize(a) 28 29 var l, u TriDense 30 lu.LTo(&l) 31 lu.UTo(&u) 32 var p Dense 33 pivot := lu.Pivot(nil) 34 p.Permutation(n, pivot) 35 var got Dense 36 got.Product(&p, &l, &u) 37 if !EqualApprox(&got, &want, 1e-12) { 38 t.Errorf("PLU does not equal original matrix.\nWant: %v\n Got: %v", want, got) 39 } 40 } 41 } 42 43 func TestLURankOne(t *testing.T) { 44 t.Parallel() 45 rnd := rand.New(rand.NewSource(1)) 46 for _, pivoting := range []bool{true} { 47 for _, n := range []int{3, 10, 50} { 48 // Construct a random LU factorization 49 lu := &LU{} 50 lu.lu = NewDense(n, n, nil) 51 for i := 0; i < n; i++ { 52 for j := 0; j < n; j++ { 53 lu.lu.Set(i, j, rnd.Float64()) 54 } 55 } 56 lu.pivot = make([]int, n) 57 for i := range lu.pivot { 58 lu.pivot[i] = i 59 } 60 if pivoting { 61 // For each row, randomly swap with itself or a row after (like is done) 62 // in the actual LU factorization. 63 for i := range lu.pivot { 64 idx := i + rnd.Intn(n-i) 65 lu.pivot[i], lu.pivot[idx] = lu.pivot[idx], lu.pivot[i] 66 } 67 } 68 // Apply a rank one update. Ensure the update magnitude is larger than 69 // the equal tolerance. 70 alpha := rnd.Float64() + 1 71 x := NewVecDense(n, nil) 72 y := NewVecDense(n, nil) 73 for i := 0; i < n; i++ { 74 x.setVec(i, rnd.Float64()+1) 75 y.setVec(i, rnd.Float64()+1) 76 } 77 a := luReconstruct(lu) 78 a.RankOne(a, alpha, x, y) 79 80 var luNew LU 81 luNew.RankOne(lu, alpha, x, y) 82 lu.RankOne(lu, alpha, x, y) 83 84 aR1New := luReconstruct(&luNew) 85 aR1 := luReconstruct(lu) 86 87 if !Equal(aR1, aR1New) { 88 t.Error("Different answer when new receiver") 89 } 90 if !EqualApprox(aR1, a, 1e-10) { 91 t.Errorf("Rank one mismatch, pivot %v.\nWant: %v\nGot:%v\n", pivoting, a, aR1) 92 } 93 } 94 } 95 } 96 97 // luReconstruct reconstructs the original A matrix from an LU decomposition. 98 func luReconstruct(lu *LU) *Dense { 99 var L, U TriDense 100 lu.LTo(&L) 101 lu.UTo(&U) 102 var P Dense 103 pivot := lu.Pivot(nil) 104 P.Permutation(len(pivot), pivot) 105 106 var a Dense 107 a.Mul(&L, &U) 108 a.Mul(&P, &a) 109 return &a 110 } 111 112 func TestLUSolveTo(t *testing.T) { 113 t.Parallel() 114 rnd := rand.New(rand.NewSource(1)) 115 for _, test := range []struct { 116 n, bc int 117 }{ 118 {5, 5}, 119 {5, 10}, 120 {10, 5}, 121 } { 122 n := test.n 123 bc := test.bc 124 a := NewDense(n, n, nil) 125 for i := 0; i < n; i++ { 126 for j := 0; j < n; j++ { 127 a.Set(i, j, rnd.NormFloat64()) 128 } 129 } 130 b := NewDense(n, bc, nil) 131 for i := 0; i < n; i++ { 132 for j := 0; j < bc; j++ { 133 b.Set(i, j, rnd.NormFloat64()) 134 } 135 } 136 var lu LU 137 lu.Factorize(a) 138 var x Dense 139 if err := lu.SolveTo(&x, false, b); err != nil { 140 continue 141 } 142 var got Dense 143 got.Mul(a, &x) 144 if !EqualApprox(&got, b, 1e-12) { 145 t.Errorf("SolveTo mismatch for non-singular matrix. n = %v, bc = %v.\nWant: %v\nGot: %v", n, bc, b, got) 146 } 147 } 148 // TODO(btracey): Add testOneInput test when such a function exists. 149 } 150 151 func TestLUSolveToCond(t *testing.T) { 152 t.Parallel() 153 for _, test := range []*Dense{ 154 NewDense(2, 2, []float64{1, 0, 0, 1e-20}), 155 } { 156 m, _ := test.Dims() 157 var lu LU 158 lu.Factorize(test) 159 b := NewDense(m, 2, nil) 160 var x Dense 161 if err := lu.SolveTo(&x, false, b); err == nil { 162 t.Error("No error for near-singular matrix in matrix solve.") 163 } 164 165 bvec := NewVecDense(m, nil) 166 var xvec VecDense 167 if err := lu.SolveVecTo(&xvec, false, bvec); err == nil { 168 t.Error("No error for near-singular matrix in matrix solve.") 169 } 170 } 171 } 172 173 func TestLUSolveVecTo(t *testing.T) { 174 t.Parallel() 175 rnd := rand.New(rand.NewSource(1)) 176 for _, n := range []int{5, 10} { 177 a := NewDense(n, n, nil) 178 for i := 0; i < n; i++ { 179 for j := 0; j < n; j++ { 180 a.Set(i, j, rnd.NormFloat64()) 181 } 182 } 183 b := NewVecDense(n, nil) 184 for i := 0; i < n; i++ { 185 b.SetVec(i, rnd.NormFloat64()) 186 } 187 var lu LU 188 lu.Factorize(a) 189 var x VecDense 190 if err := lu.SolveVecTo(&x, false, b); err != nil { 191 continue 192 } 193 var got VecDense 194 got.MulVec(a, &x) 195 if !EqualApprox(&got, b, 1e-12) { 196 t.Errorf("SolveTo mismatch n = %v.\nWant: %v\nGot: %v", n, b, got) 197 } 198 } 199 // TODO(btracey): Add testOneInput test when such a function exists. 200 }