github.com/gonum/lapack@v0.0.0-20181123203213-e4cdc5a0bff9/testlapack/dgesvd.go (about)

     1  // Copyright ©2015 The gonum Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package testlapack
     6  
import (
	"fmt"
	"math"
	"math/rand"
	"testing"

	"github.com/gonum/blas"
	"github.com/gonum/blas/blas64"
	"github.com/gonum/floats"
	"github.com/gonum/lapack"
)
    17  
    18  type Dgesvder interface {
    19  	Dgesvd(jobU, jobVT lapack.SVDJob, m, n int, a []float64, lda int, s, u []float64, ldu int, vt []float64, ldvt int, work []float64, lwork int) (ok bool)
    20  }
    21  
    22  func DgesvdTest(t *testing.T, impl Dgesvder) {
    23  	rnd := rand.New(rand.NewSource(1))
    24  	// TODO(btracey): Add tests for all of the cases when the SVD implementation
    25  	// is finished.
    26  	// TODO(btracey): Add tests for m > mnthr and n > mnthr when other SVD
    27  	// conditions are implemented. Right now mnthr is 5,000,000 which is too
    28  	// large to create a square matrix of that size.
    29  	for _, test := range []struct {
    30  		m, n, lda, ldu, ldvt int
    31  	}{
    32  		{5, 5, 0, 0, 0},
    33  		{5, 6, 0, 0, 0},
    34  		{6, 5, 0, 0, 0},
    35  		{5, 9, 0, 0, 0},
    36  		{9, 5, 0, 0, 0},
    37  
    38  		{5, 5, 10, 11, 12},
    39  		{5, 6, 10, 11, 12},
    40  		{6, 5, 10, 11, 12},
    41  		{5, 5, 10, 11, 12},
    42  		{5, 9, 10, 11, 12},
    43  		{9, 5, 10, 11, 12},
    44  
    45  		{300, 300, 0, 0, 0},
    46  		{300, 400, 0, 0, 0},
    47  		{400, 300, 0, 0, 0},
    48  		{300, 600, 0, 0, 0},
    49  		{600, 300, 0, 0, 0},
    50  
    51  		{300, 300, 400, 450, 460},
    52  		{300, 400, 500, 550, 560},
    53  		{400, 300, 550, 550, 560},
    54  		{300, 600, 700, 750, 760},
    55  		{600, 300, 700, 750, 760},
    56  	} {
    57  		jobU := lapack.SVDAll
    58  		jobVT := lapack.SVDAll
    59  
    60  		m := test.m
    61  		n := test.n
    62  		lda := test.lda
    63  		if lda == 0 {
    64  			lda = n
    65  		}
    66  		ldu := test.ldu
    67  		if ldu == 0 {
    68  			ldu = m
    69  		}
    70  		ldvt := test.ldvt
    71  		if ldvt == 0 {
    72  			ldvt = n
    73  		}
    74  
    75  		a := make([]float64, m*lda)
    76  		for i := range a {
    77  			a[i] = rnd.NormFloat64()
    78  		}
    79  
    80  		u := make([]float64, m*ldu)
    81  		for i := range u {
    82  			u[i] = rnd.NormFloat64()
    83  		}
    84  
    85  		vt := make([]float64, n*ldvt)
    86  		for i := range vt {
    87  			vt[i] = rnd.NormFloat64()
    88  		}
    89  
    90  		uAllOrig := make([]float64, len(u))
    91  		copy(uAllOrig, u)
    92  		vtAllOrig := make([]float64, len(vt))
    93  		copy(vtAllOrig, vt)
    94  		aCopy := make([]float64, len(a))
    95  		copy(aCopy, a)
    96  
    97  		s := make([]float64, min(m, n))
    98  
    99  		work := make([]float64, 1)
   100  		impl.Dgesvd(jobU, jobVT, m, n, a, lda, s, u, ldu, vt, ldvt, work, -1)
   101  
   102  		if !floats.Equal(a, aCopy) {
   103  			t.Errorf("a changed during call to get work length")
   104  		}
   105  
   106  		work = make([]float64, int(work[0]))
   107  		impl.Dgesvd(jobU, jobVT, m, n, a, lda, s, u, ldu, vt, ldvt, work, len(work))
   108  
   109  		errStr := fmt.Sprintf("m = %v, n = %v, lda = %v, ldu = %v, ldv = %v", m, n, lda, ldu, ldvt)
   110  		svdCheck(t, false, errStr, m, n, s, a, u, ldu, vt, ldvt, aCopy, lda)
   111  		svdCheckPartial(t, impl, lapack.SVDAll, errStr, uAllOrig, vtAllOrig, aCopy, m, n, a, lda, s, u, ldu, vt, ldvt, work, false)
   112  
   113  		// Test InPlace
   114  		jobU = lapack.SVDInPlace
   115  		jobVT = lapack.SVDInPlace
   116  		copy(a, aCopy)
   117  		copy(u, uAllOrig)
   118  		copy(vt, vtAllOrig)
   119  
   120  		impl.Dgesvd(jobU, jobVT, m, n, a, lda, s, u, ldu, vt, ldvt, work, len(work))
   121  		svdCheck(t, true, errStr, m, n, s, a, u, ldu, vt, ldvt, aCopy, lda)
   122  		svdCheckPartial(t, impl, lapack.SVDInPlace, errStr, uAllOrig, vtAllOrig, aCopy, m, n, a, lda, s, u, ldu, vt, ldvt, work, false)
   123  	}
   124  }
   125  
   126  // svdCheckPartial checks that the singular values and vectors are computed when
   127  // not all of them are computed.
   128  func svdCheckPartial(t *testing.T, impl Dgesvder, job lapack.SVDJob, errStr string, uAllOrig, vtAllOrig, aCopy []float64, m, n int, a []float64, lda int, s, u []float64, ldu int, vt []float64, ldvt int, work []float64, shortWork bool) {
   129  	rnd := rand.New(rand.NewSource(1))
   130  	jobU := job
   131  	jobVT := job
   132  	// Compare the singular values when computed with {SVDNone, SVDNone.}
   133  	sCopy := make([]float64, len(s))
   134  	copy(sCopy, s)
   135  	copy(a, aCopy)
   136  	for i := range s {
   137  		s[i] = rnd.Float64()
   138  	}
   139  	tmp1 := make([]float64, 1)
   140  	tmp2 := make([]float64, 1)
   141  	jobU = lapack.SVDNone
   142  	jobVT = lapack.SVDNone
   143  
   144  	impl.Dgesvd(jobU, jobVT, m, n, a, lda, s, tmp1, ldu, tmp2, ldvt, work, -1)
   145  	work = make([]float64, int(work[0]))
   146  	lwork := len(work)
   147  	if shortWork {
   148  		lwork--
   149  	}
   150  	ok := impl.Dgesvd(jobU, jobVT, m, n, a, lda, s, tmp1, ldu, tmp2, ldvt, work, lwork)
   151  	if !ok {
   152  		t.Errorf("Dgesvd did not complete successfully")
   153  	}
   154  	if !floats.EqualApprox(s, sCopy, 1e-10) {
   155  		t.Errorf("Singular value mismatch when singular vectors not computed: %s", errStr)
   156  	}
   157  	// Check that the singular vectors are correctly computed when the other
   158  	// is none.
   159  	uAll := make([]float64, len(u))
   160  	copy(uAll, u)
   161  	vtAll := make([]float64, len(vt))
   162  	copy(vtAll, vt)
   163  
   164  	// Copy the original vectors so the data outside the matrix bounds is the same.
   165  	copy(u, uAllOrig)
   166  	copy(vt, vtAllOrig)
   167  
   168  	jobU = job
   169  	jobVT = lapack.SVDNone
   170  	copy(a, aCopy)
   171  	for i := range s {
   172  		s[i] = rnd.Float64()
   173  	}
   174  	impl.Dgesvd(jobU, jobVT, m, n, a, lda, s, u, ldu, tmp2, ldvt, work, -1)
   175  	work = make([]float64, int(work[0]))
   176  	lwork = len(work)
   177  	if shortWork {
   178  		lwork--
   179  	}
   180  	impl.Dgesvd(jobU, jobVT, m, n, a, lda, s, u, ldu, tmp2, ldvt, work, len(work))
   181  	if !floats.EqualApprox(uAll, u, 1e-10) {
   182  		t.Errorf("U mismatch when VT is not computed: %s", errStr)
   183  	}
   184  	if !floats.EqualApprox(s, sCopy, 1e-10) {
   185  		t.Errorf("Singular value mismatch when U computed VT not")
   186  	}
   187  	jobU = lapack.SVDNone
   188  	jobVT = job
   189  	copy(a, aCopy)
   190  	for i := range s {
   191  		s[i] = rnd.Float64()
   192  	}
   193  	impl.Dgesvd(jobU, jobVT, m, n, a, lda, s, tmp1, ldu, vt, ldvt, work, -1)
   194  	work = make([]float64, int(work[0]))
   195  	lwork = len(work)
   196  	if shortWork {
   197  		lwork--
   198  	}
   199  	impl.Dgesvd(jobU, jobVT, m, n, a, lda, s, tmp1, ldu, vt, ldvt, work, len(work))
   200  	if !floats.EqualApprox(vtAll, vt, 1e-10) {
   201  		t.Errorf("VT mismatch when U is not computed: %s", errStr)
   202  	}
   203  	if !floats.EqualApprox(s, sCopy, 1e-10) {
   204  		t.Errorf("Singular value mismatch when VT computed U not")
   205  	}
   206  }
   207  
   208  // svdCheck checks that the singular value decomposition correctly multiplies back
   209  // to the original matrix.
   210  func svdCheck(t *testing.T, thin bool, errStr string, m, n int, s, a, u []float64, ldu int, vt []float64, ldvt int, aCopy []float64, lda int) {
   211  	sigma := blas64.General{
   212  		Rows:   m,
   213  		Cols:   n,
   214  		Stride: n,
   215  		Data:   make([]float64, m*n),
   216  	}
   217  	for i := 0; i < min(m, n); i++ {
   218  		sigma.Data[i*sigma.Stride+i] = s[i]
   219  	}
   220  
   221  	uMat := blas64.General{
   222  		Rows:   m,
   223  		Cols:   m,
   224  		Stride: ldu,
   225  		Data:   u,
   226  	}
   227  	vTMat := blas64.General{
   228  		Rows:   n,
   229  		Cols:   n,
   230  		Stride: ldvt,
   231  		Data:   vt,
   232  	}
   233  	if thin {
   234  		sigma.Rows = min(m, n)
   235  		sigma.Cols = min(m, n)
   236  		uMat.Cols = min(m, n)
   237  		vTMat.Rows = min(m, n)
   238  	}
   239  
   240  	tmp := blas64.General{
   241  		Rows:   m,
   242  		Cols:   n,
   243  		Stride: n,
   244  		Data:   make([]float64, m*n),
   245  	}
   246  	ans := blas64.General{
   247  		Rows:   m,
   248  		Cols:   n,
   249  		Stride: lda,
   250  		Data:   make([]float64, m*lda),
   251  	}
   252  	copy(ans.Data, a)
   253  
   254  	blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, uMat, sigma, 0, tmp)
   255  	blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, tmp, vTMat, 0, ans)
   256  
   257  	if !floats.EqualApprox(ans.Data, aCopy, 1e-8) {
   258  		t.Errorf("Decomposition mismatch. Trim = %v, %s", thin, errStr)
   259  	}
   260  
   261  	if !thin {
   262  		// Check that U and V are orthogonal.
   263  		for i := 0; i < uMat.Rows; i++ {
   264  			for j := i + 1; j < uMat.Rows; j++ {
   265  				dot := blas64.Dot(uMat.Cols,
   266  					blas64.Vector{Inc: 1, Data: uMat.Data[i*uMat.Stride:]},
   267  					blas64.Vector{Inc: 1, Data: uMat.Data[j*uMat.Stride:]},
   268  				)
   269  				if dot > 1e-8 {
   270  					t.Errorf("U not orthogonal %s", errStr)
   271  				}
   272  			}
   273  		}
   274  		for i := 0; i < vTMat.Rows; i++ {
   275  			for j := i + 1; j < vTMat.Rows; j++ {
   276  				dot := blas64.Dot(vTMat.Cols,
   277  					blas64.Vector{Inc: 1, Data: vTMat.Data[i*vTMat.Stride:]},
   278  					blas64.Vector{Inc: 1, Data: vTMat.Data[j*vTMat.Stride:]},
   279  				)
   280  				if dot > 1e-8 {
   281  					t.Errorf("V not orthogonal %s", errStr)
   282  				}
   283  			}
   284  		}
   285  	}
   286  }