gonum.org/v1/gonum@v0.14.0/lapack/testlapack/dgeqr2.go (about) 1 // Copyright ©2015 The Gonum Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style 3 // license that can be found in the LICENSE file. 4 5 package testlapack 6 7 import ( 8 "fmt" 9 "testing" 10 11 "golang.org/x/exp/rand" 12 13 "gonum.org/v1/gonum/blas" 14 "gonum.org/v1/gonum/blas/blas64" 15 "gonum.org/v1/gonum/lapack" 16 ) 17 18 type Dgeqr2er interface { 19 Dgeqr2(m, n int, a []float64, lda int, tau []float64, work []float64) 20 } 21 22 func Dgeqr2Test(t *testing.T, impl Dgeqr2er) { 23 rnd := rand.New(rand.NewSource(1)) 24 for _, m := range []int{0, 1, 2, 3, 4, 5, 6, 12, 23} { 25 for _, n := range []int{0, 1, 2, 3, 4, 5, 6, 12, 23} { 26 for _, lda := range []int{max(1, n), n + 4} { 27 dgeqr2Test(t, impl, rnd, m, n, lda) 28 } 29 } 30 } 31 } 32 33 func dgeqr2Test(t *testing.T, impl Dgeqr2er, rnd *rand.Rand, m, n, lda int) { 34 const tol = 1e-14 35 36 name := fmt.Sprintf("m=%d,n=%d,lda=%d", m, n, lda) 37 38 a := randomGeneral(m, n, lda, rnd) 39 aCopy := cloneGeneral(a) 40 41 k := min(m, n) 42 tau := make([]float64, k) 43 for i := range tau { 44 tau[i] = rnd.Float64() 45 } 46 47 work := make([]float64, n) 48 for i := range work { 49 work[i] = rnd.Float64() 50 } 51 52 impl.Dgeqr2(m, n, a.Data, a.Stride, tau, work) 53 54 // Test that the QR factorization has completed successfully. Compute 55 // Q based on the vectors. 56 q := constructQ("QR", m, n, a.Data, a.Stride, tau) 57 58 // Check that Q is orthogonal. 59 if resid := residualOrthogonal(q, false); resid > tol { 60 t.Errorf("Case %v: Q not orthogonal; resid=%v, want<=%v", name, resid, tol) 61 } 62 63 // Check that |Q*R - A| is small. 64 r := zeros(m, n, n) 65 for i := 0; i < m; i++ { 66 for j := i; j < n; j++ { 67 r.Data[i*r.Stride+j] = a.Data[i*a.Stride+j] 68 } 69 } 70 qra := cloneGeneral(aCopy) 71 blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, q, r, -1, qra) 72 resid := dlange(lapack.MaxColumnSum, qra.Rows, qra.Cols, qra.Data, qra.Stride) 73 if resid > tol*float64(m) { 74 t.Errorf("Case %v: |Q*R - A|=%v, want<=%v", name, resid, tol*float64(m)) 75 } 76 }