gonum.org/v1/gonum@v0.14.0/lapack/testlapack/dggsvp3.go (about) 1 // Copyright ©2017 The Gonum Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style 3 // license that can be found in the LICENSE file. 4 5 package testlapack 6 7 import ( 8 "testing" 9 10 "golang.org/x/exp/rand" 11 12 "gonum.org/v1/gonum/blas" 13 "gonum.org/v1/gonum/blas/blas64" 14 "gonum.org/v1/gonum/lapack" 15 ) 16 17 type Dggsvp3er interface { 18 Dlanger 19 Dggsvp3(jobU, jobV, jobQ lapack.GSVDJob, m, p, n int, a []float64, lda int, b []float64, ldb int, tola, tolb float64, u []float64, ldu int, v []float64, ldv int, q []float64, ldq int, iwork []int, tau, work []float64, lwork int) (k, l int) 20 } 21 22 func Dggsvp3Test(t *testing.T, impl Dggsvp3er) { 23 const tol = 1e-14 24 25 rnd := rand.New(rand.NewSource(1)) 26 for cas, test := range []struct { 27 m, p, n, lda, ldb, ldu, ldv, ldq int 28 }{ 29 {m: 3, p: 3, n: 5, lda: 0, ldb: 0, ldu: 0, ldv: 0, ldq: 0}, 30 {m: 5, p: 5, n: 5, lda: 0, ldb: 0, ldu: 0, ldv: 0, ldq: 0}, 31 {m: 5, p: 5, n: 5, lda: 0, ldb: 0, ldu: 0, ldv: 0, ldq: 0}, 32 {m: 5, p: 5, n: 10, lda: 0, ldb: 0, ldu: 0, ldv: 0, ldq: 0}, 33 {m: 5, p: 5, n: 10, lda: 0, ldb: 0, ldu: 0, ldv: 0, ldq: 0}, 34 {m: 5, p: 5, n: 10, lda: 0, ldb: 0, ldu: 0, ldv: 0, ldq: 0}, 35 {m: 10, p: 5, n: 5, lda: 0, ldb: 0, ldu: 0, ldv: 0, ldq: 0}, 36 {m: 10, p: 5, n: 5, lda: 0, ldb: 0, ldu: 0, ldv: 0, ldq: 0}, 37 {m: 10, p: 10, n: 10, lda: 0, ldb: 0, ldu: 0, ldv: 0, ldq: 0}, 38 {m: 10, p: 10, n: 10, lda: 0, ldb: 0, ldu: 0, ldv: 0, ldq: 0}, 39 {m: 5, p: 5, n: 5, lda: 10, ldb: 10, ldu: 10, ldv: 10, ldq: 10}, 40 {m: 5, p: 5, n: 5, lda: 10, ldb: 10, ldu: 10, ldv: 10, ldq: 10}, 41 {m: 5, p: 5, n: 10, lda: 20, ldb: 20, ldu: 10, ldv: 10, ldq: 20}, 42 {m: 5, p: 5, n: 10, lda: 20, ldb: 20, ldu: 10, ldv: 10, ldq: 20}, 43 {m: 5, p: 5, n: 10, lda: 20, ldb: 20, ldu: 10, ldv: 10, ldq: 20}, 44 {m: 10, p: 5, n: 5, lda: 10, ldb: 10, ldu: 20, ldv: 10, ldq: 10}, 45 {m: 10, p: 5, n: 5, lda: 10, ldb: 10, ldu: 20, ldv: 10, ldq: 10}, 46 {m: 10, p: 10, n: 10, lda: 20, ldb: 20, ldu: 20, ldv: 20, ldq: 20}, 47 {m: 10, p: 10, n: 10, lda: 20, ldb: 20, ldu: 20, ldv: 20, ldq: 20}, 48 } { 49 m := test.m 50 p := test.p 51 n := test.n 52 lda := test.lda 53 if lda == 0 { 54 lda = n 55 } 56 ldb := test.ldb 57 if ldb == 0 { 58 ldb = n 59 } 60 ldu := test.ldu 61 if ldu == 0 { 62 ldu = m 63 } 64 ldv := test.ldv 65 if ldv == 0 { 66 ldv = p 67 } 68 ldq := test.ldq 69 if ldq == 0 { 70 ldq = n 71 } 72 73 a := randomGeneral(m, n, lda, rnd) 74 aCopy := cloneGeneral(a) 75 b := randomGeneral(p, n, ldb, rnd) 76 bCopy := cloneGeneral(b) 77 78 tola := float64(max(m, n)) * impl.Dlange(lapack.Frobenius, m, n, a.Data, a.Stride, nil) * dlamchE 79 tolb := float64(max(p, n)) * impl.Dlange(lapack.Frobenius, p, n, b.Data, b.Stride, nil) * dlamchE 80 81 u := nanGeneral(m, m, ldu) 82 v := nanGeneral(p, p, ldv) 83 q := nanGeneral(n, n, ldq) 84 85 iwork := make([]int, n) 86 tau := make([]float64, n) 87 88 work := []float64{0} 89 impl.Dggsvp3(lapack.GSVDU, lapack.GSVDV, lapack.GSVDQ, 90 m, p, n, 91 a.Data, a.Stride, 92 b.Data, b.Stride, 93 tola, tolb, 94 u.Data, u.Stride, 95 v.Data, v.Stride, 96 q.Data, q.Stride, 97 iwork, tau, 98 work, -1) 99 100 lwork := int(work[0]) 101 work = make([]float64, lwork) 102 103 k, l := impl.Dggsvp3(lapack.GSVDU, lapack.GSVDV, lapack.GSVDQ, 104 m, p, n, 105 a.Data, a.Stride, 106 b.Data, b.Stride, 107 tola, tolb, 108 u.Data, u.Stride, 109 v.Data, v.Stride, 110 q.Data, q.Stride, 111 iwork, tau, 112 work, lwork) 113 114 // Check orthogonality of U, V and Q. 115 if resid := residualOrthogonal(u, false); resid > tol { 116 t.Errorf("Case %v: U is not orthogonal; resid=%v, want<=%v", cas, resid, tol) 117 } 118 if resid := residualOrthogonal(v, false); resid > tol { 119 t.Errorf("Case %v: V is not orthogonal; resid=%v, want<=%v", cas, resid, tol) 120 } 121 if resid := residualOrthogonal(q, false); resid > tol { 122 t.Errorf("Case %v: Q is not orthogonal; resid=%v, want<=%v", cas, resid, tol) 123 } 124 125 zeroA, zeroB := constructGSVPresults(n, p, m, k, l, a, b) 126 127 // Check Uᵀ*A*Q = [ 0 RA ]. 128 uTmp := nanGeneral(m, n, n) 129 blas64.Gemm(blas.Trans, blas.NoTrans, 1, u, aCopy, 0, uTmp) 130 uAns := nanGeneral(m, n, n) 131 blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, uTmp, q, 0, uAns) 132 133 if !equalApproxGeneral(uAns, zeroA, tol) { 134 t.Errorf("test %d: Uᵀ*A*Q != [ 0 RA ]\nUᵀ*A*Q:\n%+v\n[ 0 RA ]:\n%+v", 135 cas, uAns, zeroA) 136 } 137 138 // Check Vᵀ*B*Q = [ 0 RB ]. 139 vTmp := nanGeneral(p, n, n) 140 blas64.Gemm(blas.Trans, blas.NoTrans, 1, v, bCopy, 0, vTmp) 141 vAns := nanGeneral(p, n, n) 142 blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, vTmp, q, 0, vAns) 143 144 if !equalApproxGeneral(vAns, zeroB, tol) { 145 t.Errorf("test %d: Vᵀ*B*Q != [ 0 RB ]\nVᵀ*B*Q:\n%+v\n[ 0 RB ]:\n%+v", 146 cas, vAns, zeroB) 147 } 148 } 149 }