gonum.org/v1/gonum@v0.14.0/lapack/testlapack/dtgsja.go (about) 1 // Copyright ©2017 The Gonum Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style 3 // license that can be found in the LICENSE file. 4 5 package testlapack 6 7 import ( 8 "testing" 9 10 "golang.org/x/exp/rand" 11 12 "gonum.org/v1/gonum/blas" 13 "gonum.org/v1/gonum/blas/blas64" 14 "gonum.org/v1/gonum/floats/scalar" 15 "gonum.org/v1/gonum/lapack" 16 ) 17 18 type Dtgsjaer interface { 19 Dlanger 20 Dtgsja(jobU, jobV, jobQ lapack.GSVDJob, m, p, n, k, l int, a []float64, lda int, b []float64, ldb int, tola, tolb float64, alpha, beta, u []float64, ldu int, v []float64, ldv int, q []float64, ldq int, work []float64) (cycles int, ok bool) 21 } 22 23 func DtgsjaTest(t *testing.T, impl Dtgsjaer) { 24 const tol = 1e-14 25 26 rnd := rand.New(rand.NewSource(1)) 27 for cas, test := range []struct { 28 m, p, n, k, l, lda, ldb, ldu, ldv, ldq int 29 30 ok bool 31 }{ 32 {m: 5, p: 5, n: 5, k: 2, l: 2, lda: 0, ldb: 0, ldu: 0, ldv: 0, ldq: 0, ok: true}, 33 {m: 5, p: 5, n: 5, k: 4, l: 1, lda: 0, ldb: 0, ldu: 0, ldv: 0, ldq: 0, ok: true}, 34 {m: 5, p: 5, n: 10, k: 2, l: 2, lda: 0, ldb: 0, ldu: 0, ldv: 0, ldq: 0, ok: true}, 35 {m: 5, p: 5, n: 10, k: 4, l: 1, lda: 0, ldb: 0, ldu: 0, ldv: 0, ldq: 0, ok: true}, 36 {m: 5, p: 5, n: 10, k: 4, l: 2, lda: 0, ldb: 0, ldu: 0, ldv: 0, ldq: 0, ok: true}, 37 {m: 10, p: 5, n: 5, k: 2, l: 2, lda: 0, ldb: 0, ldu: 0, ldv: 0, ldq: 0, ok: true}, 38 {m: 10, p: 5, n: 5, k: 4, l: 1, lda: 0, ldb: 0, ldu: 0, ldv: 0, ldq: 0, ok: true}, 39 {m: 10, p: 10, n: 10, k: 5, l: 3, lda: 0, ldb: 0, ldu: 0, ldv: 0, ldq: 0, ok: true}, 40 {m: 10, p: 10, n: 10, k: 6, l: 4, lda: 0, ldb: 0, ldu: 0, ldv: 0, ldq: 0, ok: true}, 41 {m: 5, p: 5, n: 5, k: 2, l: 2, lda: 10, ldb: 10, ldu: 10, ldv: 10, ldq: 10, ok: true}, 42 {m: 5, p: 5, n: 5, k: 4, l: 1, lda: 10, ldb: 10, ldu: 10, ldv: 10, ldq: 10, ok: true}, 43 {m: 5, p: 5, n: 10, k: 2, l: 2, lda: 20, ldb: 20, ldu: 10, ldv: 10, ldq: 20, ok: true}, 44 {m: 5, p: 5, n: 10, k: 4, l: 1, lda: 20, ldb: 20, ldu: 10, ldv: 10, ldq: 20, ok: true}, 45 {m: 5, p: 5, n: 10, k: 4, l: 2, lda: 20, ldb: 20, ldu: 10, ldv: 10, ldq: 20, ok: true}, 46 {m: 10, p: 5, n: 5, k: 2, l: 2, lda: 10, ldb: 10, ldu: 20, ldv: 10, ldq: 10, ok: true}, 47 {m: 10, p: 5, n: 5, k: 4, l: 1, lda: 10, ldb: 10, ldu: 20, ldv: 10, ldq: 10, ok: true}, 48 {m: 10, p: 10, n: 10, k: 5, l: 3, lda: 20, ldb: 20, ldu: 20, ldv: 20, ldq: 20, ok: true}, 49 {m: 10, p: 10, n: 10, k: 6, l: 4, lda: 20, ldb: 20, ldu: 20, ldv: 20, ldq: 20, ok: true}, 50 } { 51 m := test.m 52 p := test.p 53 n := test.n 54 k := test.k 55 l := test.l 56 lda := test.lda 57 if lda == 0 { 58 lda = n 59 } 60 ldb := test.ldb 61 if ldb == 0 { 62 ldb = n 63 } 64 ldu := test.ldu 65 if ldu == 0 { 66 ldu = m 67 } 68 ldv := test.ldv 69 if ldv == 0 { 70 ldv = p 71 } 72 ldq := test.ldq 73 if ldq == 0 { 74 ldq = n 75 } 76 77 a := blockedUpperTriGeneral(m, n, k, l, lda, true, rnd) 78 aCopy := cloneGeneral(a) 79 b := blockedUpperTriGeneral(p, n, k, l, ldb, false, rnd) 80 bCopy := cloneGeneral(b) 81 82 tola := float64(max(m, n)) * impl.Dlange(lapack.Frobenius, m, n, a.Data, a.Stride, nil) * dlamchE 83 tolb := float64(max(p, n)) * impl.Dlange(lapack.Frobenius, p, n, b.Data, b.Stride, nil) * dlamchE 84 85 alpha := make([]float64, n) 86 beta := make([]float64, n) 87 88 work := make([]float64, 2*n) 89 90 u := nanGeneral(m, m, ldu) 91 v := nanGeneral(p, p, ldv) 92 q := nanGeneral(n, n, ldq) 93 94 _, ok := impl.Dtgsja(lapack.GSVDUnit, lapack.GSVDUnit, lapack.GSVDUnit, 95 m, p, n, k, l, 96 a.Data, a.Stride, 97 b.Data, b.Stride, 98 tola, tolb, 99 alpha, beta, 100 u.Data, u.Stride, 101 v.Data, v.Stride, 102 q.Data, q.Stride, 103 work) 104 105 if !ok { 106 if test.ok { 107 t.Errorf("test %d unexpectedly did not converge", cas) 108 } 109 continue 110 } 111 112 // Check orthogonality of U, V and Q. 113 if resid := residualOrthogonal(u, false); resid > tol { 114 t.Errorf("Case %v: U is not orthogonal; resid=%v, want<=%v", cas, resid, tol) 115 } 116 if resid := residualOrthogonal(v, false); resid > tol { 117 t.Errorf("Case %v: V is not orthogonal; resid=%v, want<=%v", cas, resid, tol) 118 } 119 if resid := residualOrthogonal(q, false); resid > tol { 120 t.Errorf("Case %v: Q is not orthogonal; resid=%v, want<=%v", cas, resid, tol) 121 } 122 123 // Check C^2 + S^2 = I. 124 var elements []float64 125 if m-k-l >= 0 { 126 elements = alpha[k : k+l] 127 } else { 128 elements = alpha[k:m] 129 } 130 for i := range elements { 131 i += k 132 d := alpha[i]*alpha[i] + beta[i]*beta[i] 133 if !scalar.EqualWithinAbsOrRel(d, 1, tol, tol) { 134 t.Errorf("test %d: alpha_%d^2 + beta_%d^2 != 1: got: %v", cas, i, i, d) 135 } 136 } 137 138 zeroR, d1, d2 := constructGSVDresults(n, p, m, k, l, a, b, alpha, beta) 139 140 // Check Uᵀ*A*Q = D1*[ 0 R ]. 141 uTmp := nanGeneral(m, n, n) 142 blas64.Gemm(blas.Trans, blas.NoTrans, 1, u, aCopy, 0, uTmp) 143 uAns := nanGeneral(m, n, n) 144 blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, uTmp, q, 0, uAns) 145 146 d10r := nanGeneral(m, n, n) 147 blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, d1, zeroR, 0, d10r) 148 149 if !equalApproxGeneral(uAns, d10r, tol) { 150 t.Errorf("test %d: Uᵀ*A*Q != D1*[ 0 R ]\nUᵀ*A*Q:\n%+v\nD1*[ 0 R ]:\n%+v", 151 cas, uAns, d10r) 152 } 153 154 // Check Vᵀ*B*Q = D2*[ 0 R ]. 155 vTmp := nanGeneral(p, n, n) 156 blas64.Gemm(blas.Trans, blas.NoTrans, 1, v, bCopy, 0, vTmp) 157 vAns := nanGeneral(p, n, n) 158 blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, vTmp, q, 0, vAns) 159 160 d20r := nanGeneral(p, n, n) 161 blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, d2, zeroR, 0, d20r) 162 163 if !equalApproxGeneral(vAns, d20r, tol) { 164 t.Errorf("test %d: Vᵀ*B*Q != D2*[ 0 R ]\nVᵀ*B*Q:\n%+v\nD2*[ 0 R ]:\n%+v", 165 cas, vAns, d20r) 166 } 167 } 168 }