gonum.org/v1/gonum@v0.15.1-0.20240517103525-f853624cb1bb/mat/qr_test.go (about)

     1  // Copyright ©2013 The Gonum Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package mat
     6  
     7  import (
     8  	"math"
     9  	"testing"
    10  
    11  	"golang.org/x/exp/rand"
    12  
    13  	"gonum.org/v1/gonum/blas/blas64"
    14  )
    15  
    16  func TestQR(t *testing.T) {
    17  	t.Parallel()
    18  	rnd := rand.New(rand.NewSource(1))
    19  	for _, test := range []struct {
    20  		m, n int
    21  	}{
    22  		{5, 5},
    23  		{10, 5},
    24  	} {
    25  		m := test.m
    26  		n := test.n
    27  		a := NewDense(m, n, nil)
    28  		for i := 0; i < m; i++ {
    29  			for j := 0; j < n; j++ {
    30  				a.Set(i, j, rnd.NormFloat64())
    31  			}
    32  		}
    33  		var want Dense
    34  		want.CloneFrom(a)
    35  
    36  		var qr QR
    37  		qr.Factorize(a)
    38  		var q, r Dense
    39  		qr.QTo(&q)
    40  
    41  		if !isOrthonormal(&q, 1e-10) {
    42  			t.Errorf("Q is not orthonormal: m = %v, n = %v", m, n)
    43  		}
    44  
    45  		if !EqualApprox(a, &qr, 1e-14) {
    46  			t.Errorf("m=%d,n=%d: A and QR are not equal", m, n)
    47  		}
    48  		if !EqualApprox(a.T(), qr.T(), 1e-14) {
    49  			t.Errorf("m=%d,n=%d: Aᵀ and (QR)ᵀ are not equal", m, n)
    50  		}
    51  
    52  		qr.RTo(&r)
    53  
    54  		var got Dense
    55  		got.Mul(&q, &r)
    56  		if !EqualApprox(&got, &want, 1e-12) {
    57  			t.Errorf("QR does not equal original matrix. \nWant: %v\nGot: %v", want, got)
    58  		}
    59  	}
    60  }
    61  
    62  func isOrthonormal(q *Dense, tol float64) bool {
    63  	m, n := q.Dims()
    64  	if m != n {
    65  		return false
    66  	}
    67  	for i := 0; i < m; i++ {
    68  		for j := i; j < m; j++ {
    69  			dot := blas64.Dot(blas64.Vector{N: m, Inc: 1, Data: q.mat.Data[i*q.mat.Stride:]},
    70  				blas64.Vector{N: m, Inc: 1, Data: q.mat.Data[j*q.mat.Stride:]})
    71  			// Dot product should be 1 if i == j and 0 otherwise.
    72  			if i == j && math.Abs(dot-1) > tol {
    73  				return false
    74  			}
    75  			if i != j && math.Abs(dot) > tol {
    76  				return false
    77  			}
    78  		}
    79  	}
    80  	return true
    81  }
    82  
    83  func TestQRSolveTo(t *testing.T) {
    84  	t.Parallel()
    85  	rnd := rand.New(rand.NewSource(1))
    86  	for _, trans := range []bool{false, true} {
    87  		for _, test := range []struct {
    88  			m, n, bc int
    89  		}{
    90  			{5, 5, 1},
    91  			{10, 5, 1},
    92  			{5, 5, 3},
    93  			{10, 5, 3},
    94  		} {
    95  			m := test.m
    96  			n := test.n
    97  			bc := test.bc
    98  			a := NewDense(m, n, nil)
    99  			for i := 0; i < m; i++ {
   100  				for j := 0; j < n; j++ {
   101  					a.Set(i, j, rnd.Float64())
   102  				}
   103  			}
   104  			br := m
   105  			if trans {
   106  				br = n
   107  			}
   108  			b := NewDense(br, bc, nil)
   109  			for i := 0; i < br; i++ {
   110  				for j := 0; j < bc; j++ {
   111  					b.Set(i, j, rnd.Float64())
   112  				}
   113  			}
   114  			var x Dense
   115  			var qr QR
   116  			qr.Factorize(a)
   117  			err := qr.SolveTo(&x, trans, b)
   118  			if err != nil {
   119  				t.Errorf("unexpected error from QR solve: %v", err)
   120  			}
   121  
   122  			// Test that the normal equations hold.
   123  			// Aᵀ * A * x = Aᵀ * b if !trans
   124  			// A * Aᵀ * x = A * b if trans
   125  			var lhs Dense
   126  			var rhs Dense
   127  			if trans {
   128  				var tmp Dense
   129  				tmp.Mul(a, a.T())
   130  				lhs.Mul(&tmp, &x)
   131  				rhs.Mul(a, b)
   132  			} else {
   133  				var tmp Dense
   134  				tmp.Mul(a.T(), a)
   135  				lhs.Mul(&tmp, &x)
   136  				rhs.Mul(a.T(), b)
   137  			}
   138  			if !EqualApprox(&lhs, &rhs, 1e-10) {
   139  				t.Errorf("Normal equations do not hold.\nLHS: %v\n, RHS: %v\n", lhs, rhs)
   140  			}
   141  		}
   142  	}
   143  	// TODO(btracey): Add in testOneInput when it exists.
   144  }
   145  
   146  func TestQRSolveVecTo(t *testing.T) {
   147  	t.Parallel()
   148  	rnd := rand.New(rand.NewSource(1))
   149  	for _, trans := range []bool{false, true} {
   150  		for _, test := range []struct {
   151  			m, n int
   152  		}{
   153  			{5, 5},
   154  			{10, 5},
   155  		} {
   156  			m := test.m
   157  			n := test.n
   158  			a := NewDense(m, n, nil)
   159  			for i := 0; i < m; i++ {
   160  				for j := 0; j < n; j++ {
   161  					a.Set(i, j, rnd.Float64())
   162  				}
   163  			}
   164  			br := m
   165  			if trans {
   166  				br = n
   167  			}
   168  			b := NewVecDense(br, nil)
   169  			for i := 0; i < br; i++ {
   170  				b.SetVec(i, rnd.Float64())
   171  			}
   172  			var x VecDense
   173  			var qr QR
   174  			qr.Factorize(a)
   175  			err := qr.SolveVecTo(&x, trans, b)
   176  			if err != nil {
   177  				t.Errorf("unexpected error from QR solve: %v", err)
   178  			}
   179  
   180  			// Test that the normal equations hold.
   181  			// Aᵀ * A * x = Aᵀ * b if !trans
   182  			// A * Aᵀ * x = A * b if trans
   183  			var lhs Dense
   184  			var rhs Dense
   185  			if trans {
   186  				var tmp Dense
   187  				tmp.Mul(a, a.T())
   188  				lhs.Mul(&tmp, &x)
   189  				rhs.Mul(a, b)
   190  			} else {
   191  				var tmp Dense
   192  				tmp.Mul(a.T(), a)
   193  				lhs.Mul(&tmp, &x)
   194  				rhs.Mul(a.T(), b)
   195  			}
   196  			if !EqualApprox(&lhs, &rhs, 1e-10) {
   197  				t.Errorf("Normal equations do not hold.\nLHS: %v\n, RHS: %v\n", lhs, rhs)
   198  			}
   199  		}
   200  	}
   201  	// TODO(btracey): Add in testOneInput when it exists.
   202  }
   203  
   204  func TestQRSolveCondTo(t *testing.T) {
   205  	t.Parallel()
   206  	for _, test := range []*Dense{
   207  		NewDense(2, 2, []float64{1, 0, 0, 1e-20}),
   208  		NewDense(3, 2, []float64{1, 0, 0, 1e-20, 0, 0}),
   209  	} {
   210  		m, _ := test.Dims()
   211  		var qr QR
   212  		qr.Factorize(test)
   213  		b := NewDense(m, 2, nil)
   214  		var x Dense
   215  		if err := qr.SolveTo(&x, false, b); err == nil {
   216  			t.Error("No error for near-singular matrix in matrix solve.")
   217  		}
   218  
   219  		bvec := NewVecDense(m, nil)
   220  		var xvec VecDense
   221  		if err := qr.SolveVecTo(&xvec, false, bvec); err == nil {
   222  			t.Error("No error for near-singular matrix in matrix solve.")
   223  		}
   224  	}
   225  }