gonum.org/v1/gonum@v0.14.0/mat/gsvd_test.go (about)

     1  // Copyright ©2017 The Gonum Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package mat
     6  
     7  import (
     8  	"fmt"
     9  	"testing"
    10  
    11  	"golang.org/x/exp/rand"
    12  
    13  	"gonum.org/v1/gonum/floats"
    14  	"gonum.org/v1/gonum/floats/scalar"
    15  )
    16  
    17  func TestGSVD(t *testing.T) {
    18  	t.Parallel()
    19  
    20  	const tol = 1e-10
    21  	for _, test := range []struct {
    22  		m, p, n int
    23  	}{
    24  		{5, 3, 5},
    25  		{5, 3, 3},
    26  		{3, 3, 5},
    27  		{5, 5, 5},
    28  		{5, 5, 3},
    29  		{3, 5, 5},
    30  		{150, 150, 150},
    31  		{200, 150, 150},
    32  		{150, 150, 200},
    33  		{150, 200, 150},
    34  		{200, 200, 150},
    35  		{150, 200, 200},
    36  	} {
    37  		m := test.m
    38  		p := test.p
    39  		n := test.n
    40  		t.Run(fmt.Sprintf("%v", test), func(t *testing.T) {
    41  			t.Parallel()
    42  
    43  			rnd := rand.New(rand.NewSource(1))
    44  			for trial := 0; trial < 10; trial++ {
    45  				a := NewDense(m, n, nil)
    46  				for i := range a.mat.Data {
    47  					a.mat.Data[i] = rnd.NormFloat64()
    48  				}
    49  				aCopy := DenseCopyOf(a)
    50  
    51  				b := NewDense(p, n, nil)
    52  				for i := range b.mat.Data {
    53  					b.mat.Data[i] = rnd.NormFloat64()
    54  				}
    55  				bCopy := DenseCopyOf(b)
    56  
    57  				// Test Full decomposition.
    58  				var gsvd GSVD
    59  				ok := gsvd.Factorize(a, b, GSVDU|GSVDV|GSVDQ)
    60  				if !ok {
    61  					t.Errorf("GSVD factorization failed")
    62  				}
    63  				if !Equal(a, aCopy) {
    64  					t.Errorf("A changed during call to GSVD.Factorize with GSVDU|GSVDV|GSVDQ")
    65  				}
    66  				if !Equal(b, bCopy) {
    67  					t.Errorf("B changed during call to GSVD.Factorize with GSVDU|GSVDV|GSVDQ")
    68  				}
    69  				c, s, sigma1, sigma2, zeroR, u, v, q := extractGSVD(&gsvd)
    70  				var ansU, ansV, d1R, d2R Dense
    71  				ansU.Product(u.T(), a, q)
    72  				ansV.Product(v.T(), b, q)
    73  				d1R.Mul(sigma1, zeroR)
    74  				d2R.Mul(sigma2, zeroR)
    75  				if !EqualApprox(&ansU, &d1R, tol) {
    76  					t.Errorf("Answer mismatch with GSVDU|GSVDV|GSVDQ\nUᵀ * A * Q:\n% 0.2f\nΣ₁ * [ 0 R ]:\n% 0.2f",
    77  						Formatted(&ansU), Formatted(&d1R))
    78  				}
    79  				if !EqualApprox(&ansV, &d2R, tol) {
    80  					t.Errorf("Answer mismatch with GSVDU|GSVDV|GSVDQ\nVᵀ * B  *Q:\n% 0.2f\nΣ₂ * [ 0 R ]:\n% 0.2f",
    81  						Formatted(&d2R), Formatted(&ansV))
    82  				}
    83  
    84  				// Check C^2 + S^2 = I.
    85  				for i := range c {
    86  					d := c[i]*c[i] + s[i]*s[i]
    87  					if !scalar.EqualWithinAbsOrRel(d, 1, 1e-14, 1e-14) {
    88  						t.Errorf("c_%d^2 + s_%d^2 != 1: got: %v", i, i, d)
    89  					}
    90  				}
    91  
    92  				// Test None decomposition.
    93  				ok = gsvd.Factorize(a, b, GSVDNone)
    94  				if !ok {
    95  					t.Errorf("GSVD factorization failed")
    96  				}
    97  				if !Equal(a, aCopy) {
    98  					t.Errorf("A changed during call to GSVD with GSVDNone")
    99  				}
   100  				if !Equal(b, bCopy) {
   101  					t.Errorf("B changed during call to GSVD with GSVDNone")
   102  				}
   103  				cNone := gsvd.ValuesA(nil)
   104  				if !floats.EqualApprox(c, cNone, tol) {
   105  					t.Errorf("Singular value mismatch between GSVDU|GSVDV|GSVDQ and GSVDNone decomposition")
   106  				}
   107  				sNone := gsvd.ValuesB(nil)
   108  				if !floats.EqualApprox(s, sNone, tol) {
   109  					t.Errorf("Singular value mismatch between GSVDU|GSVDV|GSVDQ and GSVDNone decomposition")
   110  				}
   111  			}
   112  		})
   113  
   114  	}
   115  }
   116  
   117  func extractGSVD(gsvd *GSVD) (c, s []float64, s1, s2, zR, u, v, q *Dense) {
   118  	s1 = &Dense{}
   119  	s2 = &Dense{}
   120  	zR = &Dense{}
   121  	u = &Dense{}
   122  	v = &Dense{}
   123  	q = &Dense{}
   124  	gsvd.SigmaATo(s1)
   125  	gsvd.SigmaBTo(s2)
   126  	gsvd.ZeroRTo(zR)
   127  	gsvd.UTo(u)
   128  	gsvd.VTo(v)
   129  	gsvd.QTo(q)
   130  	c = gsvd.ValuesA(nil)
   131  	s = gsvd.ValuesB(nil)
   132  	return c, s, s1, s2, zR, u, v, q
   133  }