github.com/gonum/lapack@v0.0.0-20181123203213-e4cdc5a0bff9/testlapack/dgelq2.go (about) 1 // Copyright ©2015 The gonum Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style 3 // license that can be found in the LICENSE file. 4 5 package testlapack 6 7 import ( 8 "math" 9 "math/rand" 10 "testing" 11 12 "github.com/gonum/blas" 13 "github.com/gonum/blas/blas64" 14 "github.com/gonum/floats" 15 ) 16 17 type Dgelq2er interface { 18 Dgelq2(m, n int, a []float64, lda int, tau, work []float64) 19 } 20 21 func Dgelq2Test(t *testing.T, impl Dgelq2er) { 22 rnd := rand.New(rand.NewSource(1)) 23 for c, test := range []struct { 24 m, n, lda int 25 }{ 26 {1, 1, 0}, 27 {2, 2, 0}, 28 {3, 2, 0}, 29 {2, 3, 0}, 30 {1, 12, 0}, 31 {2, 6, 0}, 32 {3, 4, 0}, 33 {4, 3, 0}, 34 {6, 2, 0}, 35 {1, 12, 0}, 36 {1, 1, 20}, 37 {2, 2, 20}, 38 {3, 2, 20}, 39 {2, 3, 20}, 40 {1, 12, 20}, 41 {2, 6, 20}, 42 {3, 4, 20}, 43 {4, 3, 20}, 44 {6, 2, 20}, 45 {1, 12, 20}, 46 } { 47 n := test.n 48 m := test.m 49 lda := test.lda 50 if lda == 0 { 51 lda = test.n 52 } 53 k := min(m, n) 54 tau := make([]float64, k) 55 for i := range tau { 56 tau[i] = rnd.Float64() 57 } 58 work := make([]float64, m) 59 for i := range work { 60 work[i] = rnd.Float64() 61 } 62 a := make([]float64, m*lda) 63 for i := 0; i < m*lda; i++ { 64 a[i] = rnd.Float64() 65 } 66 aCopy := make([]float64, len(a)) 67 copy(aCopy, a) 68 impl.Dgelq2(m, n, a, lda, tau, work) 69 70 Q := constructQ("LQ", m, n, a, lda, tau) 71 72 // Check that Q is orthonormal 73 for i := 0; i < Q.Rows; i++ { 74 nrm := blas64.Nrm2(Q.Cols, blas64.Vector{Inc: 1, Data: Q.Data[i*Q.Stride:]}) 75 if math.Abs(nrm-1) > 1e-14 { 76 t.Errorf("Q not normal. Norm is %v", nrm) 77 } 78 for j := 0; j < i; j++ { 79 dot := blas64.Dot(Q.Rows, 80 blas64.Vector{Inc: 1, Data: Q.Data[i*Q.Stride:]}, 81 blas64.Vector{Inc: 1, Data: Q.Data[j*Q.Stride:]}, 82 ) 83 if math.Abs(dot) > 1e-14 { 84 t.Errorf("Q not orthogonal. Dot is %v", dot) 85 } 86 } 87 } 88 89 L := blas64.General{ 90 Rows: m, 91 Cols: n, 92 Stride: n, 93 Data: make([]float64, m*n), 94 } 95 for i := 0; i < m; i++ { 96 for j := 0; j <= min(i, n-1); j++ { 97 L.Data[i*L.Stride+j] = a[i*lda+j] 98 } 99 } 100 101 ans := blas64.General{ 102 Rows: m, 103 Cols: n, 104 Stride: lda, 105 Data: make([]float64, m*lda), 106 } 107 copy(ans.Data, aCopy) 108 blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, L, Q, 0, ans) 109 if !floats.EqualApprox(aCopy, ans.Data, 1e-14) { 110 t.Errorf("Case %v, LQ mismatch. Want %v, got %v.", c, aCopy, ans.Data) 111 } 112 } 113 }