github.com/gonum/lapack@v0.0.0-20181123203213-e4cdc5a0bff9/testlapack/dggsvp3.go (about) 1 // Copyright ©2017 The gonum Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style 3 // license that can be found in the LICENSE file. 4 5 package testlapack 6 7 import ( 8 "math/rand" 9 "testing" 10 11 "github.com/gonum/blas" 12 "github.com/gonum/blas/blas64" 13 "github.com/gonum/lapack" 14 ) 15 16 type Dggsvp3er interface { 17 Dlanger 18 Dggsvp3(jobU, jobV, jobQ lapack.GSVDJob, m, p, n int, a []float64, lda int, b []float64, ldb int, tola, tolb float64, u []float64, ldu int, v []float64, ldv int, q []float64, ldq int, iwork []int, tau, work []float64, lwork int) (k, l int) 19 } 20 21 func Dggsvp3Test(t *testing.T, impl Dggsvp3er) { 22 rnd := rand.New(rand.NewSource(1)) 23 for cas, test := range []struct { 24 m, p, n, lda, ldb, ldu, ldv, ldq int 25 }{ 26 {m: 3, p: 3, n: 5, lda: 0, ldb: 0, ldu: 0, ldv: 0, ldq: 0}, 27 {m: 5, p: 5, n: 5, lda: 0, ldb: 0, ldu: 0, ldv: 0, ldq: 0}, 28 {m: 5, p: 5, n: 5, lda: 0, ldb: 0, ldu: 0, ldv: 0, ldq: 0}, 29 {m: 5, p: 5, n: 10, lda: 0, ldb: 0, ldu: 0, ldv: 0, ldq: 0}, 30 {m: 5, p: 5, n: 10, lda: 0, ldb: 0, ldu: 0, ldv: 0, ldq: 0}, 31 {m: 5, p: 5, n: 10, lda: 0, ldb: 0, ldu: 0, ldv: 0, ldq: 0}, 32 {m: 10, p: 5, n: 5, lda: 0, ldb: 0, ldu: 0, ldv: 0, ldq: 0}, 33 {m: 10, p: 5, n: 5, lda: 0, ldb: 0, ldu: 0, ldv: 0, ldq: 0}, 34 {m: 10, p: 10, n: 10, lda: 0, ldb: 0, ldu: 0, ldv: 0, ldq: 0}, 35 {m: 10, p: 10, n: 10, lda: 0, ldb: 0, ldu: 0, ldv: 0, ldq: 0}, 36 {m: 5, p: 5, n: 5, lda: 10, ldb: 10, ldu: 10, ldv: 10, ldq: 10}, 37 {m: 5, p: 5, n: 5, lda: 10, ldb: 10, ldu: 10, ldv: 10, ldq: 10}, 38 {m: 5, p: 5, n: 10, lda: 20, ldb: 20, ldu: 10, ldv: 10, ldq: 20}, 39 {m: 5, p: 5, n: 10, lda: 20, ldb: 20, ldu: 10, ldv: 10, ldq: 20}, 40 {m: 5, p: 5, n: 10, lda: 20, ldb: 20, ldu: 10, ldv: 10, ldq: 20}, 41 {m: 10, p: 5, n: 5, lda: 10, ldb: 10, ldu: 20, ldv: 10, ldq: 10}, 42 {m: 10, p: 5, n: 5, lda: 10, ldb: 10, ldu: 20, ldv: 10, ldq: 10}, 43 {m: 10, p: 10, n: 10, lda: 20, ldb: 20, ldu: 20, ldv: 20, ldq: 20}, 44 {m: 10, p: 10, n: 10, lda: 20, ldb: 20, ldu: 20, ldv: 20, ldq: 20}, 45 } { 46 m := test.m 47 p := test.p 48 n := test.n 49 lda := test.lda 50 if lda == 0 { 51 lda = n 52 } 53 ldb := test.ldb 54 if ldb == 0 { 55 ldb = n 56 } 57 ldu := test.ldu 58 if ldu == 0 { 59 ldu = m 60 } 61 ldv := test.ldv 62 if ldv == 0 { 63 ldv = p 64 } 65 ldq := test.ldq 66 if ldq == 0 { 67 ldq = n 68 } 69 70 a := randomGeneral(m, n, lda, rnd) 71 aCopy := cloneGeneral(a) 72 b := randomGeneral(p, n, ldb, rnd) 73 bCopy := cloneGeneral(b) 74 75 tola := float64(max(m, n)) * impl.Dlange(lapack.NormFrob, m, n, a.Data, a.Stride, nil) * dlamchE 76 tolb := float64(max(p, n)) * impl.Dlange(lapack.NormFrob, p, n, b.Data, b.Stride, nil) * dlamchE 77 78 u := nanGeneral(m, m, ldu) 79 v := nanGeneral(p, p, ldv) 80 q := nanGeneral(n, n, ldq) 81 82 iwork := make([]int, n) 83 tau := make([]float64, n) 84 85 work := []float64{0} 86 impl.Dggsvp3(lapack.GSVDU, lapack.GSVDV, lapack.GSVDQ, 87 m, p, n, 88 a.Data, a.Stride, 89 b.Data, b.Stride, 90 tola, tolb, 91 u.Data, u.Stride, 92 v.Data, v.Stride, 93 q.Data, q.Stride, 94 iwork, tau, 95 work, -1) 96 97 lwork := int(work[0]) 98 work = make([]float64, lwork) 99 100 k, l := impl.Dggsvp3(lapack.GSVDU, lapack.GSVDV, lapack.GSVDQ, 101 m, p, n, 102 a.Data, a.Stride, 103 b.Data, b.Stride, 104 tola, tolb, 105 u.Data, u.Stride, 106 v.Data, v.Stride, 107 q.Data, q.Stride, 108 iwork, tau, 109 work, lwork) 110 111 // Check orthogonality of U, V and Q. 112 if !isOrthonormal(u) { 113 t.Errorf("test %d: U is not orthogonal\n%+v", cas, u) 114 } 115 if !isOrthonormal(v) { 116 t.Errorf("test %d: V is not orthogonal\n%+v", cas, v) 117 } 118 if !isOrthonormal(q) { 119 t.Errorf("test %d: Q is not orthogonal\n%+v", cas, q) 120 } 121 122 zeroA, zeroB := constructGSVPresults(n, p, m, k, l, a, b) 123 124 // Check U^T*A*Q = [ 0 RA ]. 125 uTmp := nanGeneral(m, n, n) 126 blas64.Gemm(blas.Trans, blas.NoTrans, 1, u, aCopy, 0, uTmp) 127 uAns := nanGeneral(m, n, n) 128 blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, uTmp, q, 0, uAns) 129 130 if !equalApproxGeneral(uAns, zeroA, 1e-14) { 131 t.Errorf("test %d: U^T*A*Q != [ 0 RA ]\nU^T*A*Q:\n%+v\n[ 0 RA ]:\n%+v", 132 cas, uAns, zeroA) 133 } 134 135 // Check V^T*B*Q = [ 0 RB ]. 136 vTmp := nanGeneral(p, n, n) 137 blas64.Gemm(blas.Trans, blas.NoTrans, 1, v, bCopy, 0, vTmp) 138 vAns := nanGeneral(p, n, n) 139 blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, vTmp, q, 0, vAns) 140 141 if !equalApproxGeneral(vAns, zeroB, 1e-14) { 142 t.Errorf("test %d: V^T*B*Q != [ 0 RB ]\nV^T*B*Q:\n%+v\n[ 0 RB ]:\n%+v", 143 cas, vAns, zeroB) 144 } 145 } 146 }