gonum.org/v1/gonum@v0.15.1-0.20240517103525-f853624cb1bb/mat/qr_test.go (about) 1 // Copyright ©2013 The Gonum Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style 3 // license that can be found in the LICENSE file. 4 5 package mat 6 7 import ( 8 "math" 9 "testing" 10 11 "golang.org/x/exp/rand" 12 13 "gonum.org/v1/gonum/blas/blas64" 14 ) 15 16 func TestQR(t *testing.T) { 17 t.Parallel() 18 rnd := rand.New(rand.NewSource(1)) 19 for _, test := range []struct { 20 m, n int 21 }{ 22 {5, 5}, 23 {10, 5}, 24 } { 25 m := test.m 26 n := test.n 27 a := NewDense(m, n, nil) 28 for i := 0; i < m; i++ { 29 for j := 0; j < n; j++ { 30 a.Set(i, j, rnd.NormFloat64()) 31 } 32 } 33 var want Dense 34 want.CloneFrom(a) 35 36 var qr QR 37 qr.Factorize(a) 38 var q, r Dense 39 qr.QTo(&q) 40 41 if !isOrthonormal(&q, 1e-10) { 42 t.Errorf("Q is not orthonormal: m = %v, n = %v", m, n) 43 } 44 45 if !EqualApprox(a, &qr, 1e-14) { 46 t.Errorf("m=%d,n=%d: A and QR are not equal", m, n) 47 } 48 if !EqualApprox(a.T(), qr.T(), 1e-14) { 49 t.Errorf("m=%d,n=%d: Aᵀ and (QR)ᵀ are not equal", m, n) 50 } 51 52 qr.RTo(&r) 53 54 var got Dense 55 got.Mul(&q, &r) 56 if !EqualApprox(&got, &want, 1e-12) { 57 t.Errorf("QR does not equal original matrix. \nWant: %v\nGot: %v", want, got) 58 } 59 } 60 } 61 62 func isOrthonormal(q *Dense, tol float64) bool { 63 m, n := q.Dims() 64 if m != n { 65 return false 66 } 67 for i := 0; i < m; i++ { 68 for j := i; j < m; j++ { 69 dot := blas64.Dot(blas64.Vector{N: m, Inc: 1, Data: q.mat.Data[i*q.mat.Stride:]}, 70 blas64.Vector{N: m, Inc: 1, Data: q.mat.Data[j*q.mat.Stride:]}) 71 // Dot product should be 1 if i == j and 0 otherwise. 72 if i == j && math.Abs(dot-1) > tol { 73 return false 74 } 75 if i != j && math.Abs(dot) > tol { 76 return false 77 } 78 } 79 } 80 return true 81 } 82 83 func TestQRSolveTo(t *testing.T) { 84 t.Parallel() 85 rnd := rand.New(rand.NewSource(1)) 86 for _, trans := range []bool{false, true} { 87 for _, test := range []struct { 88 m, n, bc int 89 }{ 90 {5, 5, 1}, 91 {10, 5, 1}, 92 {5, 5, 3}, 93 {10, 5, 3}, 94 } { 95 m := test.m 96 n := test.n 97 bc := test.bc 98 a := NewDense(m, n, nil) 99 for i := 0; i < m; i++ { 100 for j := 0; j < n; j++ { 101 a.Set(i, j, rnd.Float64()) 102 } 103 } 104 br := m 105 if trans { 106 br = n 107 } 108 b := NewDense(br, bc, nil) 109 for i := 0; i < br; i++ { 110 for j := 0; j < bc; j++ { 111 b.Set(i, j, rnd.Float64()) 112 } 113 } 114 var x Dense 115 var qr QR 116 qr.Factorize(a) 117 err := qr.SolveTo(&x, trans, b) 118 if err != nil { 119 t.Errorf("unexpected error from QR solve: %v", err) 120 } 121 122 // Test that the normal equations hold. 123 // Aᵀ * A * x = Aᵀ * b if !trans 124 // A * Aᵀ * x = A * b if trans 125 var lhs Dense 126 var rhs Dense 127 if trans { 128 var tmp Dense 129 tmp.Mul(a, a.T()) 130 lhs.Mul(&tmp, &x) 131 rhs.Mul(a, b) 132 } else { 133 var tmp Dense 134 tmp.Mul(a.T(), a) 135 lhs.Mul(&tmp, &x) 136 rhs.Mul(a.T(), b) 137 } 138 if !EqualApprox(&lhs, &rhs, 1e-10) { 139 t.Errorf("Normal equations do not hold.\nLHS: %v\n, RHS: %v\n", lhs, rhs) 140 } 141 } 142 } 143 // TODO(btracey): Add in testOneInput when it exists. 144 } 145 146 func TestQRSolveVecTo(t *testing.T) { 147 t.Parallel() 148 rnd := rand.New(rand.NewSource(1)) 149 for _, trans := range []bool{false, true} { 150 for _, test := range []struct { 151 m, n int 152 }{ 153 {5, 5}, 154 {10, 5}, 155 } { 156 m := test.m 157 n := test.n 158 a := NewDense(m, n, nil) 159 for i := 0; i < m; i++ { 160 for j := 0; j < n; j++ { 161 a.Set(i, j, rnd.Float64()) 162 } 163 } 164 br := m 165 if trans { 166 br = n 167 } 168 b := NewVecDense(br, nil) 169 for i := 0; i < br; i++ { 170 b.SetVec(i, rnd.Float64()) 171 } 172 var x VecDense 173 var qr QR 174 qr.Factorize(a) 175 err := qr.SolveVecTo(&x, trans, b) 176 if err != nil { 177 t.Errorf("unexpected error from QR solve: %v", err) 178 } 179 180 // Test that the normal equations hold. 181 // Aᵀ * A * x = Aᵀ * b if !trans 182 // A * Aᵀ * x = A * b if trans 183 var lhs Dense 184 var rhs Dense 185 if trans { 186 var tmp Dense 187 tmp.Mul(a, a.T()) 188 lhs.Mul(&tmp, &x) 189 rhs.Mul(a, b) 190 } else { 191 var tmp Dense 192 tmp.Mul(a.T(), a) 193 lhs.Mul(&tmp, &x) 194 rhs.Mul(a.T(), b) 195 } 196 if !EqualApprox(&lhs, &rhs, 1e-10) { 197 t.Errorf("Normal equations do not hold.\nLHS: %v\n, RHS: %v\n", lhs, rhs) 198 } 199 } 200 } 201 // TODO(btracey): Add in testOneInput when it exists. 202 } 203 204 func TestQRSolveCondTo(t *testing.T) { 205 t.Parallel() 206 for _, test := range []*Dense{ 207 NewDense(2, 2, []float64{1, 0, 0, 1e-20}), 208 NewDense(3, 2, []float64{1, 0, 0, 1e-20, 0, 0}), 209 } { 210 m, _ := test.Dims() 211 var qr QR 212 qr.Factorize(test) 213 b := NewDense(m, 2, nil) 214 var x Dense 215 if err := qr.SolveTo(&x, false, b); err == nil { 216 t.Error("No error for near-singular matrix in matrix solve.") 217 } 218 219 bvec := NewVecDense(m, nil) 220 var xvec VecDense 221 if err := qr.SolveVecTo(&xvec, false, bvec); err == nil { 222 t.Error("No error for near-singular matrix in matrix solve.") 223 } 224 } 225 }