gonum.org/v1/gonum@v0.14.0/lapack/testlapack/dggsvd3.go (about) 1 // Copyright ©2017 The Gonum Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style 3 // license that can be found in the LICENSE file. 4 5 package testlapack 6 7 import ( 8 "testing" 9 10 "golang.org/x/exp/rand" 11 12 "gonum.org/v1/gonum/blas" 13 "gonum.org/v1/gonum/blas/blas64" 14 "gonum.org/v1/gonum/floats/scalar" 15 "gonum.org/v1/gonum/lapack" 16 ) 17 18 type Dggsvd3er interface { 19 Dggsvd3(jobU, jobV, jobQ lapack.GSVDJob, m, n, p int, a []float64, lda int, b []float64, ldb int, alpha, beta, u []float64, ldu int, v []float64, ldv int, q []float64, ldq int, work []float64, lwork int, iwork []int) (k, l int, ok bool) 20 } 21 22 func Dggsvd3Test(t *testing.T, impl Dggsvd3er) { 23 const tol = 1e-13 24 25 rnd := rand.New(rand.NewSource(1)) 26 for cas, test := range []struct { 27 m, p, n, lda, ldb, ldu, ldv, ldq int 28 29 ok bool 30 }{ 31 {m: 3, p: 3, n: 5, lda: 0, ldb: 0, ldu: 0, ldv: 0, ldq: 0, ok: true}, 32 {m: 5, p: 5, n: 5, lda: 0, ldb: 0, ldu: 0, ldv: 0, ldq: 0, ok: true}, 33 {m: 5, p: 5, n: 5, lda: 0, ldb: 0, ldu: 0, ldv: 0, ldq: 0, ok: true}, 34 {m: 5, p: 5, n: 10, lda: 0, ldb: 0, ldu: 0, ldv: 0, ldq: 0, ok: true}, 35 {m: 5, p: 5, n: 10, lda: 0, ldb: 0, ldu: 0, ldv: 0, ldq: 0, ok: true}, 36 {m: 5, p: 5, n: 10, lda: 0, ldb: 0, ldu: 0, ldv: 0, ldq: 0, ok: true}, 37 {m: 10, p: 5, n: 5, lda: 0, ldb: 0, ldu: 0, ldv: 0, ldq: 0, ok: true}, 38 {m: 10, p: 5, n: 5, lda: 0, ldb: 0, ldu: 0, ldv: 0, ldq: 0, ok: true}, 39 {m: 10, p: 10, n: 10, lda: 0, ldb: 0, ldu: 0, ldv: 0, ldq: 0, ok: true}, 40 {m: 10, p: 10, n: 10, lda: 0, ldb: 0, ldu: 0, ldv: 0, ldq: 0, ok: true}, 41 {m: 5, p: 5, n: 5, lda: 10, ldb: 10, ldu: 10, ldv: 10, ldq: 10, ok: true}, 42 {m: 5, p: 5, n: 5, lda: 10, ldb: 10, ldu: 10, ldv: 10, ldq: 10, ok: true}, 43 {m: 5, p: 5, n: 10, lda: 20, ldb: 20, ldu: 10, ldv: 10, ldq: 20, ok: true}, 44 {m: 5, p: 5, n: 10, lda: 20, ldb: 20, ldu: 10, ldv: 10, ldq: 20, ok: true}, 45 {m: 5, p: 5, n: 10, lda: 20, ldb: 20, ldu: 10, ldv: 10, ldq: 20, ok: true}, 46 {m: 10, p: 5, n: 5, lda: 10, ldb: 10, ldu: 20, ldv: 10, ldq: 10, ok: true}, 47 {m: 10, p: 5, n: 5, lda: 10, ldb: 10, ldu: 20, ldv: 10, ldq: 10, ok: true}, 48 {m: 10, p: 10, n: 10, lda: 20, ldb: 20, ldu: 20, ldv: 20, ldq: 20, ok: true}, 49 {m: 10, p: 10, n: 10, lda: 20, ldb: 20, ldu: 20, ldv: 20, ldq: 20, ok: true}, 50 } { 51 m := test.m 52 p := test.p 53 n := test.n 54 lda := test.lda 55 if lda == 0 { 56 lda = n 57 } 58 ldb := test.ldb 59 if ldb == 0 { 60 ldb = n 61 } 62 ldu := test.ldu 63 if ldu == 0 { 64 ldu = m 65 } 66 ldv := test.ldv 67 if ldv == 0 { 68 ldv = p 69 } 70 ldq := test.ldq 71 if ldq == 0 { 72 ldq = n 73 } 74 75 a := randomGeneral(m, n, lda, rnd) 76 aCopy := cloneGeneral(a) 77 b := randomGeneral(p, n, ldb, rnd) 78 bCopy := cloneGeneral(b) 79 80 alpha := make([]float64, n) 81 beta := make([]float64, n) 82 83 u := nanGeneral(m, m, ldu) 84 v := nanGeneral(p, p, ldv) 85 q := nanGeneral(n, n, ldq) 86 87 iwork := make([]int, n) 88 89 work := []float64{0} 90 impl.Dggsvd3(lapack.GSVDU, lapack.GSVDV, lapack.GSVDQ, 91 m, n, p, 92 a.Data, a.Stride, 93 b.Data, b.Stride, 94 alpha, beta, 95 u.Data, u.Stride, 96 v.Data, v.Stride, 97 q.Data, q.Stride, 98 work, -1, iwork) 99 100 lwork := int(work[0]) 101 work = make([]float64, lwork) 102 103 k, l, ok := impl.Dggsvd3(lapack.GSVDU, lapack.GSVDV, lapack.GSVDQ, 104 m, n, p, 105 a.Data, a.Stride, 106 b.Data, b.Stride, 107 alpha, beta, 108 u.Data, u.Stride, 109 v.Data, v.Stride, 110 q.Data, q.Stride, 111 work, lwork, iwork) 112 113 if !ok { 114 if test.ok { 115 t.Errorf("test %d unexpectedly did not converge", cas) 116 } 117 continue 118 } 119 120 // Check orthogonality of U, V and Q. 121 if resid := residualOrthogonal(u, false); resid > tol { 122 t.Errorf("Case %v: U is not orthogonal; resid=%v, want<=%v", cas, resid, tol) 123 } 124 if resid := residualOrthogonal(v, false); resid > tol { 125 t.Errorf("Case %v: V is not orthogonal; resid=%v, want<=%v", cas, resid, tol) 126 } 127 if resid := residualOrthogonal(q, false); resid > tol { 128 t.Errorf("Case %v: Q is not orthogonal; resid=%v, want<=%v", cas, resid, tol) 129 } 130 131 // Check C^2 + S^2 = I. 132 var elements []float64 133 if m-k-l >= 0 { 134 elements = alpha[k : k+l] 135 } else { 136 elements = alpha[k:m] 137 } 138 for i := range elements { 139 i += k 140 d := alpha[i]*alpha[i] + beta[i]*beta[i] 141 if !scalar.EqualWithinAbsOrRel(d, 1, tol, tol) { 142 t.Errorf("test %d: alpha_%d^2 + beta_%d^2 != 1: got: %v", cas, i, i, d) 143 } 144 } 145 146 zeroR, d1, d2 := constructGSVDresults(n, p, m, k, l, a, b, alpha, beta) 147 148 // Check Uᵀ*A*Q = D1*[ 0 R ]. 149 uTmp := nanGeneral(m, n, n) 150 blas64.Gemm(blas.Trans, blas.NoTrans, 1, u, aCopy, 0, uTmp) 151 uAns := nanGeneral(m, n, n) 152 blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, uTmp, q, 0, uAns) 153 154 d10r := nanGeneral(m, n, n) 155 blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, d1, zeroR, 0, d10r) 156 157 if !equalApproxGeneral(uAns, d10r, tol) { 158 t.Errorf("test %d: Uᵀ*A*Q != D1*[ 0 R ]\nUᵀ*A*Q:\n%+v\nD1*[ 0 R ]:\n%+v", 159 cas, uAns, d10r) 160 } 161 162 // Check Vᵀ*B*Q = D2*[ 0 R ]. 163 vTmp := nanGeneral(p, n, n) 164 blas64.Gemm(blas.Trans, blas.NoTrans, 1, v, bCopy, 0, vTmp) 165 vAns := nanGeneral(p, n, n) 166 blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, vTmp, q, 0, vAns) 167 168 d20r := nanGeneral(p, n, n) 169 blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, d2, zeroR, 0, d20r) 170 171 if !equalApproxGeneral(vAns, d20r, tol) { 172 t.Errorf("test %d: Vᵀ*B*Q != D2*[ 0 R ]\nVᵀ*B*Q:\n%+v\nD2*[ 0 R ]:\n%+v", 173 cas, vAns, d20r) 174 } 175 } 176 }