gonum.org/v1/gonum@v0.14.0/lapack/testlapack/dggsvp3.go (about)

     1  // Copyright ©2017 The Gonum Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package testlapack
     6  
     7  import (
     8  	"testing"
     9  
    10  	"golang.org/x/exp/rand"
    11  
    12  	"gonum.org/v1/gonum/blas"
    13  	"gonum.org/v1/gonum/blas/blas64"
    14  	"gonum.org/v1/gonum/lapack"
    15  )
    16  
    17  type Dggsvp3er interface {
    18  	Dlanger
    19  	Dggsvp3(jobU, jobV, jobQ lapack.GSVDJob, m, p, n int, a []float64, lda int, b []float64, ldb int, tola, tolb float64, u []float64, ldu int, v []float64, ldv int, q []float64, ldq int, iwork []int, tau, work []float64, lwork int) (k, l int)
    20  }
    21  
    22  func Dggsvp3Test(t *testing.T, impl Dggsvp3er) {
    23  	const tol = 1e-14
    24  
    25  	rnd := rand.New(rand.NewSource(1))
    26  	for cas, test := range []struct {
    27  		m, p, n, lda, ldb, ldu, ldv, ldq int
    28  	}{
    29  		{m: 3, p: 3, n: 5, lda: 0, ldb: 0, ldu: 0, ldv: 0, ldq: 0},
    30  		{m: 5, p: 5, n: 5, lda: 0, ldb: 0, ldu: 0, ldv: 0, ldq: 0},
    31  		{m: 5, p: 5, n: 5, lda: 0, ldb: 0, ldu: 0, ldv: 0, ldq: 0},
    32  		{m: 5, p: 5, n: 10, lda: 0, ldb: 0, ldu: 0, ldv: 0, ldq: 0},
    33  		{m: 5, p: 5, n: 10, lda: 0, ldb: 0, ldu: 0, ldv: 0, ldq: 0},
    34  		{m: 5, p: 5, n: 10, lda: 0, ldb: 0, ldu: 0, ldv: 0, ldq: 0},
    35  		{m: 10, p: 5, n: 5, lda: 0, ldb: 0, ldu: 0, ldv: 0, ldq: 0},
    36  		{m: 10, p: 5, n: 5, lda: 0, ldb: 0, ldu: 0, ldv: 0, ldq: 0},
    37  		{m: 10, p: 10, n: 10, lda: 0, ldb: 0, ldu: 0, ldv: 0, ldq: 0},
    38  		{m: 10, p: 10, n: 10, lda: 0, ldb: 0, ldu: 0, ldv: 0, ldq: 0},
    39  		{m: 5, p: 5, n: 5, lda: 10, ldb: 10, ldu: 10, ldv: 10, ldq: 10},
    40  		{m: 5, p: 5, n: 5, lda: 10, ldb: 10, ldu: 10, ldv: 10, ldq: 10},
    41  		{m: 5, p: 5, n: 10, lda: 20, ldb: 20, ldu: 10, ldv: 10, ldq: 20},
    42  		{m: 5, p: 5, n: 10, lda: 20, ldb: 20, ldu: 10, ldv: 10, ldq: 20},
    43  		{m: 5, p: 5, n: 10, lda: 20, ldb: 20, ldu: 10, ldv: 10, ldq: 20},
    44  		{m: 10, p: 5, n: 5, lda: 10, ldb: 10, ldu: 20, ldv: 10, ldq: 10},
    45  		{m: 10, p: 5, n: 5, lda: 10, ldb: 10, ldu: 20, ldv: 10, ldq: 10},
    46  		{m: 10, p: 10, n: 10, lda: 20, ldb: 20, ldu: 20, ldv: 20, ldq: 20},
    47  		{m: 10, p: 10, n: 10, lda: 20, ldb: 20, ldu: 20, ldv: 20, ldq: 20},
    48  	} {
    49  		m := test.m
    50  		p := test.p
    51  		n := test.n
    52  		lda := test.lda
    53  		if lda == 0 {
    54  			lda = n
    55  		}
    56  		ldb := test.ldb
    57  		if ldb == 0 {
    58  			ldb = n
    59  		}
    60  		ldu := test.ldu
    61  		if ldu == 0 {
    62  			ldu = m
    63  		}
    64  		ldv := test.ldv
    65  		if ldv == 0 {
    66  			ldv = p
    67  		}
    68  		ldq := test.ldq
    69  		if ldq == 0 {
    70  			ldq = n
    71  		}
    72  
    73  		a := randomGeneral(m, n, lda, rnd)
    74  		aCopy := cloneGeneral(a)
    75  		b := randomGeneral(p, n, ldb, rnd)
    76  		bCopy := cloneGeneral(b)
    77  
    78  		tola := float64(max(m, n)) * impl.Dlange(lapack.Frobenius, m, n, a.Data, a.Stride, nil) * dlamchE
    79  		tolb := float64(max(p, n)) * impl.Dlange(lapack.Frobenius, p, n, b.Data, b.Stride, nil) * dlamchE
    80  
    81  		u := nanGeneral(m, m, ldu)
    82  		v := nanGeneral(p, p, ldv)
    83  		q := nanGeneral(n, n, ldq)
    84  
    85  		iwork := make([]int, n)
    86  		tau := make([]float64, n)
    87  
    88  		work := []float64{0}
    89  		impl.Dggsvp3(lapack.GSVDU, lapack.GSVDV, lapack.GSVDQ,
    90  			m, p, n,
    91  			a.Data, a.Stride,
    92  			b.Data, b.Stride,
    93  			tola, tolb,
    94  			u.Data, u.Stride,
    95  			v.Data, v.Stride,
    96  			q.Data, q.Stride,
    97  			iwork, tau,
    98  			work, -1)
    99  
   100  		lwork := int(work[0])
   101  		work = make([]float64, lwork)
   102  
   103  		k, l := impl.Dggsvp3(lapack.GSVDU, lapack.GSVDV, lapack.GSVDQ,
   104  			m, p, n,
   105  			a.Data, a.Stride,
   106  			b.Data, b.Stride,
   107  			tola, tolb,
   108  			u.Data, u.Stride,
   109  			v.Data, v.Stride,
   110  			q.Data, q.Stride,
   111  			iwork, tau,
   112  			work, lwork)
   113  
   114  		// Check orthogonality of U, V and Q.
   115  		if resid := residualOrthogonal(u, false); resid > tol {
   116  			t.Errorf("Case %v: U is not orthogonal; resid=%v, want<=%v", cas, resid, tol)
   117  		}
   118  		if resid := residualOrthogonal(v, false); resid > tol {
   119  			t.Errorf("Case %v: V is not orthogonal; resid=%v, want<=%v", cas, resid, tol)
   120  		}
   121  		if resid := residualOrthogonal(q, false); resid > tol {
   122  			t.Errorf("Case %v: Q is not orthogonal; resid=%v, want<=%v", cas, resid, tol)
   123  		}
   124  
   125  		zeroA, zeroB := constructGSVPresults(n, p, m, k, l, a, b)
   126  
   127  		// Check Uᵀ*A*Q = [ 0 RA ].
   128  		uTmp := nanGeneral(m, n, n)
   129  		blas64.Gemm(blas.Trans, blas.NoTrans, 1, u, aCopy, 0, uTmp)
   130  		uAns := nanGeneral(m, n, n)
   131  		blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, uTmp, q, 0, uAns)
   132  
   133  		if !equalApproxGeneral(uAns, zeroA, tol) {
   134  			t.Errorf("test %d: Uᵀ*A*Q != [ 0 RA ]\nUᵀ*A*Q:\n%+v\n[ 0 RA ]:\n%+v",
   135  				cas, uAns, zeroA)
   136  		}
   137  
   138  		// Check Vᵀ*B*Q = [ 0 RB ].
   139  		vTmp := nanGeneral(p, n, n)
   140  		blas64.Gemm(blas.Trans, blas.NoTrans, 1, v, bCopy, 0, vTmp)
   141  		vAns := nanGeneral(p, n, n)
   142  		blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, vTmp, q, 0, vAns)
   143  
   144  		if !equalApproxGeneral(vAns, zeroB, tol) {
   145  			t.Errorf("test %d: Vᵀ*B*Q != [ 0 RB ]\nVᵀ*B*Q:\n%+v\n[ 0 RB ]:\n%+v",
   146  				cas, vAns, zeroB)
   147  		}
   148  	}
   149  }