github.com/gonum/lapack@v0.0.0-20181123203213-e4cdc5a0bff9/testlapack/dgeqp3.go (about) 1 // Copyright ©2015 The gonum Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style 3 // license that can be found in the LICENSE file. 4 5 package testlapack 6 7 import ( 8 "math" 9 "math/rand" 10 "testing" 11 12 "github.com/gonum/blas" 13 "github.com/gonum/blas/blas64" 14 ) 15 16 type Dgeqp3er interface { 17 Dlapmter 18 Dgeqp3(m, n int, a []float64, lda int, jpvt []int, tau, work []float64, lwork int) 19 } 20 21 func Dgeqp3Test(t *testing.T, impl Dgeqp3er) { 22 rnd := rand.New(rand.NewSource(1)) 23 for c, test := range []struct { 24 m, n, lda int 25 }{ 26 {1, 1, 0}, 27 {2, 2, 0}, 28 {3, 2, 0}, 29 {2, 3, 0}, 30 {1, 12, 0}, 31 {2, 6, 0}, 32 {3, 4, 0}, 33 {4, 3, 0}, 34 {6, 2, 0}, 35 {12, 1, 0}, 36 {1, 1, 20}, 37 {2, 2, 20}, 38 {3, 2, 20}, 39 {2, 3, 20}, 40 {1, 12, 20}, 41 {2, 6, 20}, 42 {3, 4, 20}, 43 {4, 3, 20}, 44 {6, 2, 20}, 45 {12, 1, 20}, 46 {129, 256, 0}, 47 {256, 129, 0}, 48 {129, 256, 266}, 49 {256, 129, 266}, 50 } { 51 n := test.n 52 m := test.m 53 lda := test.lda 54 if lda == 0 { 55 lda = test.n 56 } 57 const ( 58 all = iota 59 some 60 none 61 ) 62 for _, free := range []int{all, some, none} { 63 a := make([]float64, m*lda) 64 for i := range a { 65 a[i] = rnd.Float64() 66 } 67 aCopy := make([]float64, len(a)) 68 copy(aCopy, a) 69 jpvt := make([]int, n) 70 for j := range jpvt { 71 switch free { 72 case all: 73 jpvt[j] = -1 74 case some: 75 jpvt[j] = rnd.Intn(2) - 1 76 case none: 77 jpvt[j] = 0 78 default: 79 panic("bad freedom") 80 } 81 } 82 k := min(m, n) 83 tau := make([]float64, k) 84 for i := range tau { 85 tau[i] = rnd.Float64() 86 } 87 work := make([]float64, 1) 88 impl.Dgeqp3(m, n, a, lda, jpvt, tau, work, -1) 89 lwork := int(work[0]) 90 work = make([]float64, lwork) 91 for i := range work { 92 work[i] = rnd.Float64() 93 } 94 impl.Dgeqp3(m, n, a, lda, jpvt, tau, work, lwork) 95 96 // Test that the QR factorization has completed successfully. Compute 97 // Q based on the vectors. 98 q := constructQ("QR", m, n, a, lda, tau) 99 100 // Check that q is orthonormal 101 for i := 0; i < m; i++ { 102 nrm := blas64.Nrm2(m, blas64.Vector{Inc: 1, Data: q.Data[i*m:]}) 103 if math.Abs(nrm-1) > 1e-13 { 104 t.Errorf("Case %v, q not normal", c) 105 } 106 for j := 0; j < i; j++ { 107 dot := blas64.Dot(m, blas64.Vector{Inc: 1, Data: q.Data[i*m:]}, blas64.Vector{Inc: 1, Data: q.Data[j*m:]}) 108 if math.Abs(dot) > 1e-14 { 109 t.Errorf("Case %v, q not orthogonal", c) 110 } 111 } 112 } 113 // Check that A * P = Q * R 114 r := blas64.General{ 115 Rows: m, 116 Cols: n, 117 Stride: n, 118 Data: make([]float64, m*n), 119 } 120 for i := 0; i < m; i++ { 121 for j := i; j < n; j++ { 122 r.Data[i*n+j] = a[i*lda+j] 123 } 124 } 125 got := nanGeneral(m, n, lda) 126 blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, q, r, 0, got) 127 128 want := blas64.General{Rows: m, Cols: n, Stride: lda, Data: aCopy} 129 impl.Dlapmt(true, want.Rows, want.Cols, want.Data, want.Stride, jpvt) 130 if !equalApproxGeneral(got, want, 1e-13) { 131 t.Errorf("Case %v, Q*R != A*P\nQ*R=%v\nA*P=%v", c, got, want) 132 } 133 } 134 } 135 }