github.com/wzzhu/tensor@v0.9.24/dense_svd_test.go (about)

     1  package tensor
     2  
     3  import (
     4  	"fmt"
     5  	"testing"
     6  
     7  	"github.com/pkg/errors"
     8  	"gonum.org/v1/gonum/mat"
     9  )
    10  
    11  // tests for SVD adapted from Gonum's SVD tests.
    12  // Gonum's licence is listed at https://gonum.org/v1/gonum/license
    13  
// svdtestsThin holds thin-SVD cases with known answers. Each entry gives an
// input matrix (backing data plus shape) and the expected singular values (s),
// left singular vectors (u) and right singular vectors (v), each with its
// expected shape. Singular vectors are only determined up to sign, so these
// values also pin the sign convention of the implementation under test.
var svdtestsThin = []struct {
	// input matrix: backing data and shape (rows, cols)
	data  []float64
	shape Shape

	// expected singular values, one per min(rows, cols)
	correctSData  []float64
	correctSShape Shape

	// expected left singular vectors u, shape rows×min(rows, cols)
	correctUData  []float64
	correctUShape Shape

	// expected right singular vectors v, shape cols×min(rows, cols)
	// (i.e. V itself, not its transpose)
	correctVData  []float64
	correctVShape Shape
}{
	// 4×2 input (tall): thin u is 4×2, v is 2×2.
	{
		[]float64{2, 4, 1, 3, 0, 0, 0, 0}, Shape{4, 2},
		[]float64{5.464985704219041, 0.365966190626258}, Shape{2},
		[]float64{-0.8174155604703632, -0.5760484367663209, -0.5760484367663209, 0.8174155604703633, 0, 0, 0, 0}, Shape{4, 2},
		[]float64{-0.4045535848337571, -0.9145142956773044, -0.9145142956773044, 0.4045535848337571}, Shape{2, 2},
	},

	// 3×11 input (wide): thin u is 3×3, v is 11×3.
	{
		[]float64{1, 1, 0, 1, 0, 0, 0, 0, 0, 11, 1, 1, 0, 0, 0, 0, 0, 1, 0, 0, 12, 2, 1, 1, 0, 0, 0, 0, 0, 0, 1, 13, 3}, Shape{3, 11},
		[]float64{21.259500881097434, 1.5415021616856566, 1.2873979074613628}, Shape{3},
		[]float64{-0.5224167862273765, 0.7864430360363114, 0.3295270133658976, -0.5739526766688285, -0.03852203026050301, -0.8179818935216693, -0.6306021141833781, -0.6164603833618163, 0.4715056408282468}, Shape{3, 3},
		[]float64{
			-0.08123293141915189, 0.08528085505260324, -0.013165501690885152,
			-0.05423546426886932, 0.1102707844980355, 0.622210623111631,
			0, 0, 0,
			-0.0245733326078166, 0.510179651760153, 0.25596360803140994,
			0, 0, 0,
			0, 0, 0,
			-0.026997467150282436, -0.024989929445430496, -0.6353761248025164,
			0, 0, 0,
			-0.029662131661052707, -0.3999088672621176, 0.3662470150802212,
			-0.9798839760830571, 0.11328174160898856, -0.047702613241813366,
			-0.16755466189153964, -0.7395268089170608, 0.08395240366704032}, Shape{11, 3},
	},
}
    52  
// svdtestsFull lists the matrix shapes that TestDense_SVD fills with random
// float64 data and checks by reconstruction: square, tall (more rows than
// columns) and wide (more columns than rows), at both a small and a larger size.
var svdtestsFull = []Shape{
	{5, 5},
	{5, 3},
	{3, 5},
	{150, 150},
	{200, 150},
	{150, 200},
}
    61  
    62  // calculate corrects
    63  func calcSigma(s, T *Dense, shape Shape) (sigma *Dense, err error) {
    64  	sigma = New(Of(Float64), WithShape(shape...))
    65  	for i := 0; i < MinInt(shape[0], shape[1]); i++ {
    66  		var idx int
    67  		if idx, err = Ltoi(sigma.Shape(), sigma.Strides(), i, i); err != nil {
    68  			return
    69  		}
    70  		sigma.Float64s()[idx] = s.Float64s()[i]
    71  	}
    72  
    73  	return
    74  }
    75  
    76  // test svd by doing the SVD, then calculating the corrects
    77  func testSVD(T, T2, s, u, v *Dense, t string, i int) (err error) {
    78  	var sigma, reconstructed *Dense
    79  
    80  	if !allClose(T2.Data(), T.Data(), closeenoughf64) {
    81  		return errors.Errorf("A call to SVD modified the underlying data! %s Test %d", t, i)
    82  	}
    83  
    84  	shape := T2.Shape()
    85  	if t == "thin" {
    86  		shape = Shape{MinInt(shape[0], shape[1]), MinInt(shape[0], shape[1])}
    87  	}
    88  
    89  	if sigma, err = calcSigma(s, T, shape); err != nil {
    90  		return
    91  	}
    92  	v.T()
    93  
    94  	if reconstructed, err = u.MatMul(sigma, UseSafe()); err != nil {
    95  		return
    96  	}
    97  	if reconstructed, err = reconstructed.MatMul(v, UseSafe()); err != nil {
    98  		return
    99  	}
   100  
   101  	if !allClose(T2.Data(), reconstructed.Data(), closeenoughf64) {
   102  		return errors.Errorf("Expected reconstructed to be %v. Got %v instead", T2.Data(), reconstructed.Data())
   103  	}
   104  	return nil
   105  }
   106  
   107  func ExampleDense_SVD() {
   108  	T := New(
   109  		WithShape(4, 5),
   110  		WithBacking([]float64{1, 0, 0, 0, 2, 0, 0, 3, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0}),
   111  	)
   112  	_, u, _, _ := T.SVD(true, true)
   113  	uT := u.Clone().(*Dense)
   114  	uT.T()
   115  	eye, err := u.MatMul(uT)
   116  	fmt.Println(eye)
   117  	fmt.Println(err)
   118  
   119  	// Output:
   120  	// ⎡1  0  0  0⎤
   121  	// ⎢0  1  0  0⎥
   122  	// ⎢0  0  1  0⎥
   123  	// ⎣0  0  0  1⎦
   124  	//
   125  	// <nil>
   126  }
   127  
   128  func TestDense_SVD(t *testing.T) {
   129  	var T, T2, s, u, v *Dense
   130  	var err error
   131  
   132  	// gonum specific thin special cases
   133  	for i, stts := range svdtestsThin {
   134  		T = New(WithShape(stts.shape...), WithBacking(stts.data))
   135  		T2 = T.Clone().(*Dense)
   136  
   137  		if s, u, v, err = T.SVD(true, false); err != nil {
   138  			t.Error(err)
   139  			continue
   140  		}
   141  
   142  		if !allClose(T2.Data(), T.Data(), closeenoughf64) {
   143  			t.Errorf("A call to SVD modified the underlying data! Thin Test %d", i)
   144  			continue
   145  		}
   146  
   147  		if !allClose(stts.correctSData, s.Data(), closeenoughf64) {
   148  			t.Errorf("Expected s = %v. Got %v instead", stts.correctSData, s.Data())
   149  		}
   150  
   151  		if !allClose(stts.correctUData, u.Data(), closeenoughf64) {
   152  			t.Errorf("Expected u = %v. Got %v instead", stts.correctUData, u.Data())
   153  		}
   154  
   155  		if !allClose(stts.correctVData, v.Data(), closeenoughf64) {
   156  			t.Errorf("Expected v = %v. Got %v instead", stts.correctVData, v.Data())
   157  		}
   158  	}
   159  	// standard tests
   160  	for i, stfs := range svdtestsFull {
   161  		T = New(WithShape(stfs...), WithBacking(Random(Float64, stfs.TotalSize())))
   162  		T2 = T.Clone().(*Dense)
   163  
   164  		// full
   165  		if s, u, v, err = T.SVD(true, true); err != nil {
   166  			t.Error(err)
   167  			fmt.Println(err)
   168  			continue
   169  		}
   170  		if err = testSVD(T, T2, s, u, v, "full", i); err != nil {
   171  			t.Error(err)
   172  			fmt.Println(err)
   173  			continue
   174  		}
   175  		// thin
   176  		if s, u, v, err = T.SVD(true, false); err != nil {
   177  			t.Error(err)
   178  			continue
   179  		}
   180  
   181  		if err = testSVD(T, T2, s, u, v, "thin", i); err != nil {
   182  			t.Error(err)
   183  			continue
   184  		}
   185  
   186  		// none
   187  		if s, u, v, err = T.SVD(false, false); err != nil {
   188  			t.Error(err)
   189  			continue
   190  		}
   191  
   192  		var svd mat.SVD
   193  		var m *mat.Dense
   194  		if m, err = ToMat64(T); err != nil {
   195  			t.Error(err)
   196  			continue
   197  		}
   198  
   199  		if !svd.Factorize(m, mat.SVDFull) {
   200  			t.Errorf("Unable to factorise %v", m)
   201  			continue
   202  		}
   203  
   204  		if !allClose(s.Data(), svd.Values(nil), closeenoughf64) {
   205  			t.Errorf("Singular value mismatch between Full and None decomposition. Expected %v. Got %v instead", svd.Values(nil), s.Data())
   206  		}
   207  
   208  	}
   209  	// this is illogical
   210  	T = New(Of(Float64), WithShape(2, 2))
   211  	if _, _, _, err = T.SVD(false, true); err == nil {
   212  		t.Errorf("Expected an error!")
   213  	}
   214  
   215  	// if you do this, it is bad and you should feel bad
   216  	T = New(Of(Float64), WithShape(2, 3, 4))
   217  	if _, _, _, err = T.SVD(true, true); err == nil {
   218  		t.Errorf("Expecetd an error: cannot SVD() a Tensor > 2 dimensions")
   219  	}
   220  
   221  	T = New(Of(Float64), WithShape(2))
   222  	if _, _, _, err = T.SVD(true, true); err == nil {
   223  		t.Errorf("Expecetd an error: cannot SVD() a Tensor < 2 dimensions")
   224  	}
   225  }