github.com/gonum/lapack@v0.0.0-20181123203213-e4cdc5a0bff9/testlapack/dggsvd3.go (about)

     1  // Copyright ©2017 The gonum Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package testlapack
     6  
     7  import (
     8  	"math/rand"
     9  	"testing"
    10  
    11  	"github.com/gonum/blas"
    12  	"github.com/gonum/blas/blas64"
    13  	"github.com/gonum/floats"
    14  	"github.com/gonum/lapack"
    15  )
    16  
    17  type Dggsvd3er interface {
    18  	Dggsvd3(jobU, jobV, jobQ lapack.GSVDJob, m, n, p int, a []float64, lda int, b []float64, ldb int, alpha, beta, u []float64, ldu int, v []float64, ldv int, q []float64, ldq int, work []float64, lwork int, iwork []int) (k, l int, ok bool)
    19  }
    20  
    21  func Dggsvd3Test(t *testing.T, impl Dggsvd3er) {
    22  	rnd := rand.New(rand.NewSource(1))
    23  	for cas, test := range []struct {
    24  		m, p, n, lda, ldb, ldu, ldv, ldq int
    25  
    26  		ok bool
    27  	}{
    28  		{m: 3, p: 3, n: 5, lda: 0, ldb: 0, ldu: 0, ldv: 0, ldq: 0, ok: true},
    29  		{m: 5, p: 5, n: 5, lda: 0, ldb: 0, ldu: 0, ldv: 0, ldq: 0, ok: true},
    30  		{m: 5, p: 5, n: 5, lda: 0, ldb: 0, ldu: 0, ldv: 0, ldq: 0, ok: true},
    31  		{m: 5, p: 5, n: 10, lda: 0, ldb: 0, ldu: 0, ldv: 0, ldq: 0, ok: true},
    32  		{m: 5, p: 5, n: 10, lda: 0, ldb: 0, ldu: 0, ldv: 0, ldq: 0, ok: true},
    33  		{m: 5, p: 5, n: 10, lda: 0, ldb: 0, ldu: 0, ldv: 0, ldq: 0, ok: true},
    34  		{m: 10, p: 5, n: 5, lda: 0, ldb: 0, ldu: 0, ldv: 0, ldq: 0, ok: true},
    35  		{m: 10, p: 5, n: 5, lda: 0, ldb: 0, ldu: 0, ldv: 0, ldq: 0, ok: true},
    36  		{m: 10, p: 10, n: 10, lda: 0, ldb: 0, ldu: 0, ldv: 0, ldq: 0, ok: true},
    37  		{m: 10, p: 10, n: 10, lda: 0, ldb: 0, ldu: 0, ldv: 0, ldq: 0, ok: true},
    38  		{m: 5, p: 5, n: 5, lda: 10, ldb: 10, ldu: 10, ldv: 10, ldq: 10, ok: true},
    39  		{m: 5, p: 5, n: 5, lda: 10, ldb: 10, ldu: 10, ldv: 10, ldq: 10, ok: true},
    40  		{m: 5, p: 5, n: 10, lda: 20, ldb: 20, ldu: 10, ldv: 10, ldq: 20, ok: true},
    41  		{m: 5, p: 5, n: 10, lda: 20, ldb: 20, ldu: 10, ldv: 10, ldq: 20, ok: true},
    42  		{m: 5, p: 5, n: 10, lda: 20, ldb: 20, ldu: 10, ldv: 10, ldq: 20, ok: true},
    43  		{m: 10, p: 5, n: 5, lda: 10, ldb: 10, ldu: 20, ldv: 10, ldq: 10, ok: true},
    44  		{m: 10, p: 5, n: 5, lda: 10, ldb: 10, ldu: 20, ldv: 10, ldq: 10, ok: true},
    45  		{m: 10, p: 10, n: 10, lda: 20, ldb: 20, ldu: 20, ldv: 20, ldq: 20, ok: true},
    46  		{m: 10, p: 10, n: 10, lda: 20, ldb: 20, ldu: 20, ldv: 20, ldq: 20, ok: true},
    47  	} {
    48  		m := test.m
    49  		p := test.p
    50  		n := test.n
    51  		lda := test.lda
    52  		if lda == 0 {
    53  			lda = n
    54  		}
    55  		ldb := test.ldb
    56  		if ldb == 0 {
    57  			ldb = n
    58  		}
    59  		ldu := test.ldu
    60  		if ldu == 0 {
    61  			ldu = m
    62  		}
    63  		ldv := test.ldv
    64  		if ldv == 0 {
    65  			ldv = p
    66  		}
    67  		ldq := test.ldq
    68  		if ldq == 0 {
    69  			ldq = n
    70  		}
    71  
    72  		a := randomGeneral(m, n, lda, rnd)
    73  		aCopy := cloneGeneral(a)
    74  		b := randomGeneral(p, n, ldb, rnd)
    75  		bCopy := cloneGeneral(b)
    76  
    77  		alpha := make([]float64, n)
    78  		beta := make([]float64, n)
    79  
    80  		u := nanGeneral(m, m, ldu)
    81  		v := nanGeneral(p, p, ldv)
    82  		q := nanGeneral(n, n, ldq)
    83  
    84  		iwork := make([]int, n)
    85  
    86  		work := []float64{0}
    87  		impl.Dggsvd3(lapack.GSVDU, lapack.GSVDV, lapack.GSVDQ,
    88  			m, n, p,
    89  			a.Data, a.Stride,
    90  			b.Data, b.Stride,
    91  			alpha, beta,
    92  			u.Data, u.Stride,
    93  			v.Data, v.Stride,
    94  			q.Data, q.Stride,
    95  			work, -1, iwork)
    96  
    97  		lwork := int(work[0])
    98  		work = make([]float64, lwork)
    99  
   100  		k, l, ok := impl.Dggsvd3(lapack.GSVDU, lapack.GSVDV, lapack.GSVDQ,
   101  			m, n, p,
   102  			a.Data, a.Stride,
   103  			b.Data, b.Stride,
   104  			alpha, beta,
   105  			u.Data, u.Stride,
   106  			v.Data, v.Stride,
   107  			q.Data, q.Stride,
   108  			work, lwork, iwork)
   109  
   110  		if !ok {
   111  			if test.ok {
   112  				t.Errorf("test %d unexpectedly did not converge", cas)
   113  			}
   114  			continue
   115  		}
   116  
   117  		// Check orthogonality of U, V and Q.
   118  		if !isOrthonormal(u) {
   119  			t.Errorf("test %d: U is not orthogonal\n%+v", cas, u)
   120  		}
   121  		if !isOrthonormal(v) {
   122  			t.Errorf("test %d: V is not orthogonal\n%+v", cas, v)
   123  		}
   124  		if !isOrthonormal(q) {
   125  			t.Errorf("test %d: Q is not orthogonal\n%+v", cas, q)
   126  		}
   127  
   128  		// Check C^2 + S^2 = I.
   129  		var elements []float64
   130  		if m-k-l >= 0 {
   131  			elements = alpha[k : k+l]
   132  		} else {
   133  			elements = alpha[k:m]
   134  		}
   135  		for i := range elements {
   136  			i += k
   137  			d := alpha[i]*alpha[i] + beta[i]*beta[i]
   138  			if !floats.EqualWithinAbsOrRel(d, 1, 1e-14, 1e-14) {
   139  				t.Errorf("test %d: alpha_%d^2 + beta_%d^2 != 1: got: %v", cas, i, i, d)
   140  			}
   141  		}
   142  
   143  		zeroR, d1, d2 := constructGSVDresults(n, p, m, k, l, a, b, alpha, beta)
   144  
   145  		// Check U^T*A*Q = D1*[ 0 R ].
   146  		uTmp := nanGeneral(m, n, n)
   147  		blas64.Gemm(blas.Trans, blas.NoTrans, 1, u, aCopy, 0, uTmp)
   148  		uAns := nanGeneral(m, n, n)
   149  		blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, uTmp, q, 0, uAns)
   150  
   151  		d10r := nanGeneral(m, n, n)
   152  		blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, d1, zeroR, 0, d10r)
   153  
   154  		if !equalApproxGeneral(uAns, d10r, 1e-14) {
   155  			t.Errorf("test %d: U^T*A*Q != D1*[ 0 R ]\nU^T*A*Q:\n%+v\nD1*[ 0 R ]:\n%+v",
   156  				cas, uAns, d10r)
   157  		}
   158  
   159  		// Check V^T*B*Q = D2*[ 0 R ].
   160  		vTmp := nanGeneral(p, n, n)
   161  		blas64.Gemm(blas.Trans, blas.NoTrans, 1, v, bCopy, 0, vTmp)
   162  		vAns := nanGeneral(p, n, n)
   163  		blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, vTmp, q, 0, vAns)
   164  
   165  		d20r := nanGeneral(p, n, n)
   166  		blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, d2, zeroR, 0, d20r)
   167  
   168  		if !equalApproxGeneral(vAns, d20r, 1e-14) {
   169  			t.Errorf("test %d: V^T*B*Q != D2*[ 0 R ]\nV^T*B*Q:\n%+v\nD2*[ 0 R ]:\n%+v",
   170  				cas, vAns, d20r)
   171  		}
   172  	}
   173  }