github.com/gopherd/gonum@v0.0.4/lapack/testlapack/dtgsja.go (about)

     1  // Copyright ©2017 The Gonum Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package testlapack
     6  
     7  import (
     8  	"testing"
     9  
    10  	"math/rand"
    11  
    12  	"github.com/gopherd/gonum/blas"
    13  	"github.com/gopherd/gonum/blas/blas64"
    14  	"github.com/gopherd/gonum/floats/scalar"
    15  	"github.com/gopherd/gonum/lapack"
    16  )
    17  
    18  type Dtgsjaer interface {
    19  	Dlanger
    20  	Dtgsja(jobU, jobV, jobQ lapack.GSVDJob, m, p, n, k, l int, a []float64, lda int, b []float64, ldb int, tola, tolb float64, alpha, beta, u []float64, ldu int, v []float64, ldv int, q []float64, ldq int, work []float64) (cycles int, ok bool)
    21  }
    22  
    23  func DtgsjaTest(t *testing.T, impl Dtgsjaer) {
    24  	const tol = 1e-14
    25  
    26  	rnd := rand.New(rand.NewSource(1))
    27  	for cas, test := range []struct {
    28  		m, p, n, k, l, lda, ldb, ldu, ldv, ldq int
    29  
    30  		ok bool
    31  	}{
    32  		{m: 5, p: 5, n: 5, k: 2, l: 2, lda: 0, ldb: 0, ldu: 0, ldv: 0, ldq: 0, ok: true},
    33  		{m: 5, p: 5, n: 5, k: 4, l: 1, lda: 0, ldb: 0, ldu: 0, ldv: 0, ldq: 0, ok: true},
    34  		{m: 5, p: 5, n: 10, k: 2, l: 2, lda: 0, ldb: 0, ldu: 0, ldv: 0, ldq: 0, ok: true},
    35  		{m: 5, p: 5, n: 10, k: 4, l: 1, lda: 0, ldb: 0, ldu: 0, ldv: 0, ldq: 0, ok: true},
    36  		{m: 5, p: 5, n: 10, k: 4, l: 2, lda: 0, ldb: 0, ldu: 0, ldv: 0, ldq: 0, ok: true},
    37  		{m: 10, p: 5, n: 5, k: 2, l: 2, lda: 0, ldb: 0, ldu: 0, ldv: 0, ldq: 0, ok: true},
    38  		{m: 10, p: 5, n: 5, k: 4, l: 1, lda: 0, ldb: 0, ldu: 0, ldv: 0, ldq: 0, ok: true},
    39  		{m: 10, p: 10, n: 10, k: 5, l: 3, lda: 0, ldb: 0, ldu: 0, ldv: 0, ldq: 0, ok: true},
    40  		{m: 10, p: 10, n: 10, k: 6, l: 4, lda: 0, ldb: 0, ldu: 0, ldv: 0, ldq: 0, ok: true},
    41  		{m: 5, p: 5, n: 5, k: 2, l: 2, lda: 10, ldb: 10, ldu: 10, ldv: 10, ldq: 10, ok: true},
    42  		{m: 5, p: 5, n: 5, k: 4, l: 1, lda: 10, ldb: 10, ldu: 10, ldv: 10, ldq: 10, ok: true},
    43  		{m: 5, p: 5, n: 10, k: 2, l: 2, lda: 20, ldb: 20, ldu: 10, ldv: 10, ldq: 20, ok: true},
    44  		{m: 5, p: 5, n: 10, k: 4, l: 1, lda: 20, ldb: 20, ldu: 10, ldv: 10, ldq: 20, ok: true},
    45  		{m: 5, p: 5, n: 10, k: 4, l: 2, lda: 20, ldb: 20, ldu: 10, ldv: 10, ldq: 20, ok: true},
    46  		{m: 10, p: 5, n: 5, k: 2, l: 2, lda: 10, ldb: 10, ldu: 20, ldv: 10, ldq: 10, ok: true},
    47  		{m: 10, p: 5, n: 5, k: 4, l: 1, lda: 10, ldb: 10, ldu: 20, ldv: 10, ldq: 10, ok: true},
    48  		{m: 10, p: 10, n: 10, k: 5, l: 3, lda: 20, ldb: 20, ldu: 20, ldv: 20, ldq: 20, ok: true},
    49  		{m: 10, p: 10, n: 10, k: 6, l: 4, lda: 20, ldb: 20, ldu: 20, ldv: 20, ldq: 20, ok: true},
    50  	} {
    51  		m := test.m
    52  		p := test.p
    53  		n := test.n
    54  		k := test.k
    55  		l := test.l
    56  		lda := test.lda
    57  		if lda == 0 {
    58  			lda = n
    59  		}
    60  		ldb := test.ldb
    61  		if ldb == 0 {
    62  			ldb = n
    63  		}
    64  		ldu := test.ldu
    65  		if ldu == 0 {
    66  			ldu = m
    67  		}
    68  		ldv := test.ldv
    69  		if ldv == 0 {
    70  			ldv = p
    71  		}
    72  		ldq := test.ldq
    73  		if ldq == 0 {
    74  			ldq = n
    75  		}
    76  
    77  		a := blockedUpperTriGeneral(m, n, k, l, lda, true, rnd)
    78  		aCopy := cloneGeneral(a)
    79  		b := blockedUpperTriGeneral(p, n, k, l, ldb, false, rnd)
    80  		bCopy := cloneGeneral(b)
    81  
    82  		tola := float64(max(m, n)) * impl.Dlange(lapack.Frobenius, m, n, a.Data, a.Stride, nil) * dlamchE
    83  		tolb := float64(max(p, n)) * impl.Dlange(lapack.Frobenius, p, n, b.Data, b.Stride, nil) * dlamchE
    84  
    85  		alpha := make([]float64, n)
    86  		beta := make([]float64, n)
    87  
    88  		work := make([]float64, 2*n)
    89  
    90  		u := nanGeneral(m, m, ldu)
    91  		v := nanGeneral(p, p, ldv)
    92  		q := nanGeneral(n, n, ldq)
    93  
    94  		_, ok := impl.Dtgsja(lapack.GSVDUnit, lapack.GSVDUnit, lapack.GSVDUnit,
    95  			m, p, n, k, l,
    96  			a.Data, a.Stride,
    97  			b.Data, b.Stride,
    98  			tola, tolb,
    99  			alpha, beta,
   100  			u.Data, u.Stride,
   101  			v.Data, v.Stride,
   102  			q.Data, q.Stride,
   103  			work)
   104  
   105  		if !ok {
   106  			if test.ok {
   107  				t.Errorf("test %d unexpectedly did not converge", cas)
   108  			}
   109  			continue
   110  		}
   111  
   112  		// Check orthogonality of U, V and Q.
   113  		if resid := residualOrthogonal(u, false); resid > tol {
   114  			t.Errorf("Case %v: U is not orthogonal; resid=%v, want<=%v", cas, resid, tol)
   115  		}
   116  		if resid := residualOrthogonal(v, false); resid > tol {
   117  			t.Errorf("Case %v: V is not orthogonal; resid=%v, want<=%v", cas, resid, tol)
   118  		}
   119  		if resid := residualOrthogonal(q, false); resid > tol {
   120  			t.Errorf("Case %v: Q is not orthogonal; resid=%v, want<=%v", cas, resid, tol)
   121  		}
   122  
   123  		// Check C^2 + S^2 = I.
   124  		var elements []float64
   125  		if m-k-l >= 0 {
   126  			elements = alpha[k : k+l]
   127  		} else {
   128  			elements = alpha[k:m]
   129  		}
   130  		for i := range elements {
   131  			i += k
   132  			d := alpha[i]*alpha[i] + beta[i]*beta[i]
   133  			if !scalar.EqualWithinAbsOrRel(d, 1, tol, tol) {
   134  				t.Errorf("test %d: alpha_%d^2 + beta_%d^2 != 1: got: %v", cas, i, i, d)
   135  			}
   136  		}
   137  
   138  		zeroR, d1, d2 := constructGSVDresults(n, p, m, k, l, a, b, alpha, beta)
   139  
   140  		// Check Uᵀ*A*Q = D1*[ 0 R ].
   141  		uTmp := nanGeneral(m, n, n)
   142  		blas64.Gemm(blas.Trans, blas.NoTrans, 1, u, aCopy, 0, uTmp)
   143  		uAns := nanGeneral(m, n, n)
   144  		blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, uTmp, q, 0, uAns)
   145  
   146  		d10r := nanGeneral(m, n, n)
   147  		blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, d1, zeroR, 0, d10r)
   148  
   149  		if !equalApproxGeneral(uAns, d10r, tol) {
   150  			t.Errorf("test %d: Uᵀ*A*Q != D1*[ 0 R ]\nUᵀ*A*Q:\n%+v\nD1*[ 0 R ]:\n%+v",
   151  				cas, uAns, d10r)
   152  		}
   153  
   154  		// Check Vᵀ*B*Q = D2*[ 0 R ].
   155  		vTmp := nanGeneral(p, n, n)
   156  		blas64.Gemm(blas.Trans, blas.NoTrans, 1, v, bCopy, 0, vTmp)
   157  		vAns := nanGeneral(p, n, n)
   158  		blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, vTmp, q, 0, vAns)
   159  
   160  		d20r := nanGeneral(p, n, n)
   161  		blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, d2, zeroR, 0, d20r)
   162  
   163  		if !equalApproxGeneral(vAns, d20r, tol) {
   164  			t.Errorf("test %d: Vᵀ*B*Q != D2*[ 0 R ]\nVᵀ*B*Q:\n%+v\nD2*[ 0 R ]:\n%+v",
   165  				cas, vAns, d20r)
   166  		}
   167  	}
   168  }