github.com/gonum/lapack@v0.0.0-20181123203213-e4cdc5a0bff9/testlapack/dgeqr2.go (about) 1 // Copyright ©2015 The gonum Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style 3 // license that can be found in the LICENSE file. 4 5 package testlapack 6 7 import ( 8 "math" 9 "math/rand" 10 "testing" 11 12 "github.com/gonum/blas" 13 "github.com/gonum/blas/blas64" 14 "github.com/gonum/floats" 15 ) 16 17 type Dgeqr2er interface { 18 Dgeqr2(m, n int, a []float64, lda int, tau []float64, work []float64) 19 } 20 21 func Dgeqr2Test(t *testing.T, impl Dgeqr2er) { 22 rnd := rand.New(rand.NewSource(1)) 23 for c, test := range []struct { 24 m, n, lda int 25 }{ 26 {1, 1, 0}, 27 {2, 2, 0}, 28 {3, 2, 0}, 29 {2, 3, 0}, 30 {1, 12, 0}, 31 {2, 6, 0}, 32 {3, 4, 0}, 33 {4, 3, 0}, 34 {6, 2, 0}, 35 {12, 1, 0}, 36 {1, 1, 20}, 37 {2, 2, 20}, 38 {3, 2, 20}, 39 {2, 3, 20}, 40 {1, 12, 20}, 41 {2, 6, 20}, 42 {3, 4, 20}, 43 {4, 3, 20}, 44 {6, 2, 20}, 45 {12, 1, 20}, 46 } { 47 n := test.n 48 m := test.m 49 lda := test.lda 50 if lda == 0 { 51 lda = test.n 52 } 53 a := make([]float64, m*lda) 54 for i := range a { 55 a[i] = rnd.Float64() 56 } 57 aCopy := make([]float64, len(a)) 58 k := min(m, n) 59 tau := make([]float64, k) 60 for i := range tau { 61 tau[i] = rnd.Float64() 62 } 63 work := make([]float64, n) 64 for i := range work { 65 work[i] = rnd.Float64() 66 } 67 copy(aCopy, a) 68 impl.Dgeqr2(m, n, a, lda, tau, work) 69 70 // Test that the QR factorization has completed successfully. Compute 71 // Q based on the vectors. 72 q := constructQ("QR", m, n, a, lda, tau) 73 74 // Check that q is orthonormal 75 for i := 0; i < m; i++ { 76 nrm := blas64.Nrm2(m, blas64.Vector{Inc: 1, Data: q.Data[i*m:]}) 77 if math.Abs(nrm-1) > 1e-14 { 78 t.Errorf("Case %v, q not normal", c) 79 } 80 for j := 0; j < i; j++ { 81 dot := blas64.Dot(m, blas64.Vector{Inc: 1, Data: q.Data[i*m:]}, blas64.Vector{Inc: 1, Data: q.Data[j*m:]}) 82 if math.Abs(dot) > 1e-14 { 83 t.Errorf("Case %v, q not orthogonal", c) 84 } 85 } 86 } 87 // Check that A = Q * R 88 r := blas64.General{ 89 Rows: m, 90 Cols: n, 91 Stride: n, 92 Data: make([]float64, m*n), 93 } 94 for i := 0; i < m; i++ { 95 for j := i; j < n; j++ { 96 r.Data[i*n+j] = a[i*lda+j] 97 } 98 } 99 atmp := blas64.General{ 100 Rows: m, 101 Cols: n, 102 Stride: lda, 103 Data: make([]float64, m*lda), 104 } 105 copy(atmp.Data, a) 106 blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, q, r, 0, atmp) 107 if !floats.EqualApprox(atmp.Data, aCopy, 1e-14) { 108 t.Errorf("Q*R != a") 109 } 110 } 111 }