gonum.org/v1/gonum@v0.14.0/mat/hogsvd_test.go (about)

     1  // Copyright ©2017 The Gonum Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package mat
     6  
     7  import (
     8  	"testing"
     9  
    10  	"golang.org/x/exp/rand"
    11  )
    12  
    13  func TestHOGSVD(t *testing.T) {
    14  	t.Parallel()
    15  	const tol = 1e-10
    16  	rnd := rand.New(rand.NewSource(1))
    17  	for cas, test := range []struct {
    18  		r, c int
    19  	}{
    20  		{5, 3},
    21  		{5, 5},
    22  		{150, 150},
    23  		{200, 150},
    24  
    25  		// Calculating A_i*A_jᵀ and A_j*A_iᵀ fails for wide matrices.
    26  		{3, 5},
    27  	} {
    28  		r := test.r
    29  		c := test.c
    30  		for n := 3; n < 6; n++ {
    31  			data := make([]Matrix, n)
    32  			dataCopy := make([]*Dense, n)
    33  			for trial := 0; trial < 10; trial++ {
    34  				for i := range data {
    35  					d := NewDense(r, c, nil)
    36  					for j := range d.mat.Data {
    37  						d.mat.Data[j] = rnd.Float64()
    38  					}
    39  					data[i] = d
    40  					dataCopy[i] = DenseCopyOf(d)
    41  				}
    42  
    43  				var gsvd HOGSVD
    44  				ok := gsvd.Factorize(data...)
    45  				if r >= c {
    46  					if !ok {
    47  						t.Errorf("HOGSVD factorization failed for %d %d×%d matrices: %v", n, r, c, gsvd.Err())
    48  						continue
    49  					}
    50  				} else {
    51  					if ok {
    52  						t.Errorf("HOGSVD factorization unexpectedly succeeded for %d %d×%d matrices", n, r, c)
    53  					}
    54  					continue
    55  				}
    56  				for i := range data {
    57  					if !Equal(data[i], dataCopy[i]) {
    58  						t.Errorf("A changed during call to HOGSVD.Factorize")
    59  					}
    60  				}
    61  				u, s, v := extractHOGSVD(&gsvd)
    62  				for i, want := range data {
    63  					var got Dense
    64  					sigma := NewDense(c, c, nil)
    65  					for j := 0; j < c; j++ {
    66  						sigma.Set(j, j, s[i][j])
    67  					}
    68  
    69  					got.Product(u[i], sigma, v.T())
    70  					if !EqualApprox(&got, want, tol) {
    71  						t.Errorf("test %d n=%d trial %d: unexpected answer\nU_%[4]d * S_%[4]d * Vᵀ:\n% 0.2f\nD_%d:\n% 0.2f",
    72  							cas, n, trial, i, Formatted(&got, Excerpt(5)), i, Formatted(want, Excerpt(5)))
    73  					}
    74  				}
    75  			}
    76  		}
    77  	}
    78  }
    79  
    80  func extractHOGSVD(gsvd *HOGSVD) (u []*Dense, s [][]float64, v *Dense) {
    81  	u = make([]*Dense, gsvd.Len())
    82  	s = make([][]float64, gsvd.Len())
    83  	for i := 0; i < gsvd.Len(); i++ {
    84  		u[i] = &Dense{}
    85  		gsvd.UTo(u[i], i)
    86  		s[i] = gsvd.Values(nil, i)
    87  	}
    88  	v = &Dense{}
    89  	gsvd.VTo(v)
    90  	return u, s, v
    91  }