github.com/gonum/lapack@v0.0.0-20181123203213-e4cdc5a0bff9/testlapack/dtgsja.go (about) 1 // Copyright ©2017 The gonum Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style 3 // license that can be found in the LICENSE file. 4 5 package testlapack 6 7 import ( 8 "math/rand" 9 "testing" 10 11 "github.com/gonum/blas" 12 "github.com/gonum/blas/blas64" 13 "github.com/gonum/floats" 14 "github.com/gonum/lapack" 15 ) 16 17 type Dtgsjaer interface { 18 Dlanger 19 Dtgsja(jobU, jobV, jobQ lapack.GSVDJob, m, p, n, k, l int, a []float64, lda int, b []float64, ldb int, tola, tolb float64, alpha, beta, u []float64, ldu int, v []float64, ldv int, q []float64, ldq int, work []float64) (cycles int, ok bool) 20 } 21 22 func DtgsjaTest(t *testing.T, impl Dtgsjaer) { 23 rnd := rand.New(rand.NewSource(1)) 24 for cas, test := range []struct { 25 m, p, n, k, l, lda, ldb, ldu, ldv, ldq int 26 27 ok bool 28 }{ 29 {m: 5, p: 5, n: 5, k: 2, l: 2, lda: 0, ldb: 0, ldu: 0, ldv: 0, ldq: 0, ok: true}, 30 {m: 5, p: 5, n: 5, k: 4, l: 1, lda: 0, ldb: 0, ldu: 0, ldv: 0, ldq: 0, ok: true}, 31 {m: 5, p: 5, n: 10, k: 2, l: 2, lda: 0, ldb: 0, ldu: 0, ldv: 0, ldq: 0, ok: true}, 32 {m: 5, p: 5, n: 10, k: 4, l: 1, lda: 0, ldb: 0, ldu: 0, ldv: 0, ldq: 0, ok: true}, 33 {m: 5, p: 5, n: 10, k: 4, l: 2, lda: 0, ldb: 0, ldu: 0, ldv: 0, ldq: 0, ok: true}, 34 {m: 10, p: 5, n: 5, k: 2, l: 2, lda: 0, ldb: 0, ldu: 0, ldv: 0, ldq: 0, ok: true}, 35 {m: 10, p: 5, n: 5, k: 4, l: 1, lda: 0, ldb: 0, ldu: 0, ldv: 0, ldq: 0, ok: true}, 36 {m: 10, p: 10, n: 10, k: 5, l: 3, lda: 0, ldb: 0, ldu: 0, ldv: 0, ldq: 0, ok: true}, 37 {m: 10, p: 10, n: 10, k: 6, l: 4, lda: 0, ldb: 0, ldu: 0, ldv: 0, ldq: 0, ok: true}, 38 {m: 5, p: 5, n: 5, k: 2, l: 2, lda: 10, ldb: 10, ldu: 10, ldv: 10, ldq: 10, ok: true}, 39 {m: 5, p: 5, n: 5, k: 4, l: 1, lda: 10, ldb: 10, ldu: 10, ldv: 10, ldq: 10, ok: true}, 40 {m: 5, p: 5, n: 10, k: 2, l: 2, lda: 20, ldb: 20, ldu: 10, ldv: 10, ldq: 20, ok: true}, 41 {m: 5, p: 5, n: 10, k: 4, l: 1, lda: 20, ldb: 20, ldu: 10, ldv: 10, ldq: 20, ok: true}, 42 {m: 5, p: 5, n: 10, k: 4, l: 2, lda: 20, ldb: 20, ldu: 10, ldv: 10, ldq: 20, ok: true}, 43 {m: 10, p: 5, n: 5, k: 2, l: 2, lda: 10, ldb: 10, ldu: 20, ldv: 10, ldq: 10, ok: true}, 44 {m: 10, p: 5, n: 5, k: 4, l: 1, lda: 10, ldb: 10, ldu: 20, ldv: 10, ldq: 10, ok: true}, 45 {m: 10, p: 10, n: 10, k: 5, l: 3, lda: 20, ldb: 20, ldu: 20, ldv: 20, ldq: 20, ok: true}, 46 {m: 10, p: 10, n: 10, k: 6, l: 4, lda: 20, ldb: 20, ldu: 20, ldv: 20, ldq: 20, ok: true}, 47 } { 48 m := test.m 49 p := test.p 50 n := test.n 51 k := test.k 52 l := test.l 53 lda := test.lda 54 if lda == 0 { 55 lda = n 56 } 57 ldb := test.ldb 58 if ldb == 0 { 59 ldb = n 60 } 61 ldu := test.ldu 62 if ldu == 0 { 63 ldu = m 64 } 65 ldv := test.ldv 66 if ldv == 0 { 67 ldv = p 68 } 69 ldq := test.ldq 70 if ldq == 0 { 71 ldq = n 72 } 73 74 a := blockedUpperTriGeneral(m, n, k, l, lda, true, rnd) 75 aCopy := cloneGeneral(a) 76 b := blockedUpperTriGeneral(p, n, k, l, ldb, false, rnd) 77 bCopy := cloneGeneral(b) 78 79 tola := float64(max(m, n)) * impl.Dlange(lapack.NormFrob, m, n, a.Data, a.Stride, nil) * dlamchE 80 tolb := float64(max(p, n)) * impl.Dlange(lapack.NormFrob, p, n, b.Data, b.Stride, nil) * dlamchE 81 82 alpha := make([]float64, n) 83 beta := make([]float64, n) 84 85 work := make([]float64, 2*n) 86 87 u := nanGeneral(m, m, ldu) 88 v := nanGeneral(p, p, ldv) 89 q := nanGeneral(n, n, ldq) 90 91 _, ok := impl.Dtgsja(lapack.GSVDUnit, lapack.GSVDUnit, lapack.GSVDUnit, 92 m, p, n, k, l, 93 a.Data, a.Stride, 94 b.Data, b.Stride, 95 tola, tolb, 96 alpha, beta, 97 u.Data, u.Stride, 98 v.Data, v.Stride, 99 q.Data, q.Stride, 100 work) 101 102 if !ok { 103 if test.ok { 104 t.Errorf("test %d unexpectedly did not converge", cas) 105 } 106 continue 107 } 108 109 // Check orthogonality of U, V and Q. 110 if !isOrthonormal(u) { 111 t.Errorf("test %d: U is not orthogonal\n%+v", cas, u) 112 } 113 if !isOrthonormal(v) { 114 t.Errorf("test %d: V is not orthogonal\n%+v", cas, v) 115 } 116 if !isOrthonormal(q) { 117 t.Errorf("test %d: Q is not orthogonal\n%+v", cas, q) 118 } 119 120 // Check C^2 + S^2 = I. 121 var elements []float64 122 if m-k-l >= 0 { 123 elements = alpha[k : k+l] 124 } else { 125 elements = alpha[k:m] 126 } 127 for i := range elements { 128 i += k 129 d := alpha[i]*alpha[i] + beta[i]*beta[i] 130 if !floats.EqualWithinAbsOrRel(d, 1, 1e-14, 1e-14) { 131 t.Errorf("test %d: alpha_%d^2 + beta_%d^2 != 1: got: %v", cas, i, i, d) 132 } 133 } 134 135 zeroR, d1, d2 := constructGSVDresults(n, p, m, k, l, a, b, alpha, beta) 136 137 // Check U^T*A*Q = D1*[ 0 R ]. 138 uTmp := nanGeneral(m, n, n) 139 blas64.Gemm(blas.Trans, blas.NoTrans, 1, u, aCopy, 0, uTmp) 140 uAns := nanGeneral(m, n, n) 141 blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, uTmp, q, 0, uAns) 142 143 d10r := nanGeneral(m, n, n) 144 blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, d1, zeroR, 0, d10r) 145 146 if !equalApproxGeneral(uAns, d10r, 1e-14) { 147 t.Errorf("test %d: U^T*A*Q != D1*[ 0 R ]\nU^T*A*Q:\n%+v\nD1*[ 0 R ]:\n%+v", 148 cas, uAns, d10r) 149 } 150 151 // Check V^T*B*Q = D2*[ 0 R ]. 152 vTmp := nanGeneral(p, n, n) 153 blas64.Gemm(blas.Trans, blas.NoTrans, 1, v, bCopy, 0, vTmp) 154 vAns := nanGeneral(p, n, n) 155 blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, vTmp, q, 0, vAns) 156 157 d20r := nanGeneral(p, n, n) 158 blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, d2, zeroR, 0, d20r) 159 160 if !equalApproxGeneral(vAns, d20r, 1e-14) { 161 t.Errorf("test %d: V^T*B*Q != D2*[ 0 R ]\nV^T*B*Q:\n%+v\nD2*[ 0 R ]:\n%+v", 162 cas, vAns, d20r) 163 } 164 } 165 }