github.com/gonum/matrix@v0.0.0-20181209220409-c518dec07be9/mat64/svd_test.go (about)

     1  // Copyright ©2013 The gonum Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package mat64
     6  
     7  import (
     8  	"math/rand"
     9  	"testing"
    10  
    11  	"github.com/gonum/floats"
    12  	"github.com/gonum/matrix"
    13  )
    14  
    15  func TestSVD(t *testing.T) {
    16  	// Hand coded tests
    17  	for _, test := range []struct {
    18  		a *Dense
    19  		u *Dense
    20  		v *Dense
    21  		s []float64
    22  	}{
    23  		{
    24  			a: NewDense(4, 2, []float64{2, 4, 1, 3, 0, 0, 0, 0}),
    25  			u: NewDense(4, 2, []float64{
    26  				-0.8174155604703632, -0.5760484367663209,
    27  				-0.5760484367663209, 0.8174155604703633,
    28  				0, 0,
    29  				0, 0,
    30  			}),
    31  			v: NewDense(2, 2, []float64{
    32  				-0.4045535848337571, -0.9145142956773044,
    33  				-0.9145142956773044, 0.4045535848337571,
    34  			}),
    35  			s: []float64{5.464985704219041, 0.365966190626258},
    36  		},
    37  		{
    38  			// Issue #5.
    39  			a: NewDense(3, 11, []float64{
    40  				1, 1, 0, 1, 0, 0, 0, 0, 0, 11, 1,
    41  				1, 0, 0, 0, 0, 0, 1, 0, 0, 12, 2,
    42  				1, 1, 0, 0, 0, 0, 0, 0, 1, 13, 3,
    43  			}),
    44  			u: NewDense(3, 3, []float64{
    45  				-0.5224167862273765, 0.7864430360363114, 0.3295270133658976,
    46  				-0.5739526766688285, -0.03852203026050301, -0.8179818935216693,
    47  				-0.6306021141833781, -0.6164603833618163, 0.4715056408282468,
    48  			}),
    49  			v: NewDense(11, 3, []float64{
    50  				-0.08123293141915189, 0.08528085505260324, -0.013165501690885152,
    51  				-0.05423546426886932, 0.1102707844980355, 0.622210623111631,
    52  				0, 0, 0,
    53  				-0.0245733326078166, 0.510179651760153, 0.25596360803140994,
    54  				0, 0, 0,
    55  				0, 0, 0,
    56  				-0.026997467150282436, -0.024989929445430496, -0.6353761248025164,
    57  				0, 0, 0,
    58  				-0.029662131661052707, -0.3999088672621176, 0.3662470150802212,
    59  				-0.9798839760830571, 0.11328174160898856, -0.047702613241813366,
    60  				-0.16755466189153964, -0.7395268089170608, 0.08395240366704032,
    61  			}),
    62  			s: []float64{21.259500881097434, 1.5415021616856566, 1.2873979074613628},
    63  		},
    64  	} {
    65  		var svd SVD
    66  		ok := svd.Factorize(test.a, matrix.SVDThin)
    67  		if !ok {
    68  			t.Errorf("SVD failed")
    69  		}
    70  		s, u, v := extractSVD(&svd)
    71  		if !floats.EqualApprox(s, test.s, 1e-10) {
    72  			t.Errorf("Singular value mismatch. Got %v, want %v.", s, test.s)
    73  		}
    74  		if !EqualApprox(u, test.u, 1e-10) {
    75  			t.Errorf("U mismatch.\nGot:\n%v\nWant:\n%v", Formatted(u), Formatted(test.u))
    76  		}
    77  		if !EqualApprox(v, test.v, 1e-10) {
    78  			t.Errorf("V mismatch.\nGot:\n%v\nWant:\n%v", Formatted(v), Formatted(test.v))
    79  		}
    80  		m, n := test.a.Dims()
    81  		sigma := NewDense(min(m, n), min(m, n), nil)
    82  		for i := 0; i < min(m, n); i++ {
    83  			sigma.Set(i, i, s[i])
    84  		}
    85  
    86  		var ans Dense
    87  		ans.Product(u, sigma, v.T())
    88  		if !EqualApprox(test.a, &ans, 1e-10) {
    89  			t.Errorf("A reconstruction mismatch.\nGot:\n%v\nWant:\n%v\n", Formatted(&ans), Formatted(test.a))
    90  		}
    91  	}
    92  
    93  	for _, test := range []struct {
    94  		m, n int
    95  	}{
    96  		{5, 5},
    97  		{5, 3},
    98  		{3, 5},
    99  		{150, 150},
   100  		{200, 150},
   101  		{150, 200},
   102  	} {
   103  		m := test.m
   104  		n := test.n
   105  		for trial := 0; trial < 10; trial++ {
   106  			a := NewDense(m, n, nil)
   107  			for i := range a.mat.Data {
   108  				a.mat.Data[i] = rand.NormFloat64()
   109  			}
   110  			aCopy := DenseCopyOf(a)
   111  
   112  			// Test Full decomposition.
   113  			var svd SVD
   114  			ok := svd.Factorize(a, matrix.SVDFull)
   115  			if !ok {
   116  				t.Errorf("SVD factorization failed")
   117  			}
   118  			if !Equal(a, aCopy) {
   119  				t.Errorf("A changed during call to SVD with full")
   120  			}
   121  			s, u, v := extractSVD(&svd)
   122  			sigma := NewDense(m, n, nil)
   123  			for i := 0; i < min(m, n); i++ {
   124  				sigma.Set(i, i, s[i])
   125  			}
   126  			var ansFull Dense
   127  			ansFull.Product(u, sigma, v.T())
   128  			if !EqualApprox(&ansFull, a, 1e-8) {
   129  				t.Errorf("Answer mismatch when SVDFull")
   130  			}
   131  
   132  			// Test Thin decomposition.
   133  			ok = svd.Factorize(a, matrix.SVDThin)
   134  			if !ok {
   135  				t.Errorf("SVD factorization failed")
   136  			}
   137  			if !Equal(a, aCopy) {
   138  				t.Errorf("A changed during call to SVD with Thin")
   139  			}
   140  			sThin, u, v := extractSVD(&svd)
   141  			if !floats.EqualApprox(s, sThin, 1e-8) {
   142  				t.Errorf("Singular value mismatch between Full and Thin decomposition")
   143  			}
   144  			sigma = NewDense(min(m, n), min(m, n), nil)
   145  			for i := 0; i < min(m, n); i++ {
   146  				sigma.Set(i, i, sThin[i])
   147  			}
   148  			ansFull.Reset()
   149  			ansFull.Product(u, sigma, v.T())
   150  			if !EqualApprox(&ansFull, a, 1e-8) {
   151  				t.Errorf("Answer mismatch when SVDFull")
   152  			}
   153  
   154  			// Test None decomposition.
   155  			ok = svd.Factorize(a, matrix.SVDNone)
   156  			if !ok {
   157  				t.Errorf("SVD factorization failed")
   158  			}
   159  			if !Equal(a, aCopy) {
   160  				t.Errorf("A changed during call to SVD with none")
   161  			}
   162  			sNone := make([]float64, min(m, n))
   163  			svd.Values(sNone)
   164  			if !floats.EqualApprox(s, sNone, 1e-8) {
   165  				t.Errorf("Singular value mismatch between Full and None decomposition")
   166  			}
   167  		}
   168  	}
   169  }
   170  
   171  func extractSVD(svd *SVD) (s []float64, u, v *Dense) {
   172  	var um, vm Dense
   173  	um.UFromSVD(svd)
   174  	vm.VFromSVD(svd)
   175  	s = svd.Values(nil)
   176  	return s, &um, &vm
   177  }