github.com/jingcheng-WU/gonum@v0.9.1-0.20210323123734-f1a2a11a8f7b/lapack/testlapack/dgesvd.go

     1  // Copyright ©2015 The Gonum Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package testlapack
     6  
     7  import (
     8  	"fmt"
     9  	"math"
    10  	"sort"
    11  	"testing"
    12  
    13  	"golang.org/x/exp/rand"
    14  
    15  	"github.com/jingcheng-WU/gonum/blas"
    16  	"github.com/jingcheng-WU/gonum/blas/blas64"
    17  	"github.com/jingcheng-WU/gonum/floats"
    18  	"github.com/jingcheng-WU/gonum/lapack"
    19  )
    20  
    21  type Dgesvder interface {
    22  	Dgesvd(jobU, jobVT lapack.SVDJob, m, n int, a []float64, lda int, s, u []float64, ldu int, vt []float64, ldvt int, work []float64, lwork int) (ok bool)
    23  }
    24  
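// dgesvdTest below sizes its workspace with the standard LAPACK workspace-query
// convention: calling Dgesvd with lwork = -1 performs no computation and
// instead stores the optimal workspace length in work[0]. The helper below is a
// minimal sketch of that convention; optimalSVDWorklen is hypothetical and is
// not used by the tests in this file.
func optimalSVDWorklen(impl Dgesvder, m, n int) int {
	lda := max(1, n)
	ldu := max(1, m)
	ldvt := max(1, n)
	a := make([]float64, max(1, m*lda))
	s := make([]float64, min(m, n))
	u := make([]float64, max(1, m*ldu))
	vt := make([]float64, max(1, n*ldvt))
	work := make([]float64, 1)
	// lwork = -1 requests a workspace query; the optimum is returned in work[0].
	impl.Dgesvd(lapack.SVDAll, lapack.SVDAll, m, n, a, lda, s, u, ldu, vt, ldvt, work, -1)
	return int(work[0])
}
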
    25  func DgesvdTest(t *testing.T, impl Dgesvder, tol float64) {
    26  	for _, m := range []int{0, 1, 2, 3, 4, 5, 10, 150, 300} {
    27  		for _, n := range []int{0, 1, 2, 3, 4, 5, 10, 150} {
    28  			for _, mtype := range []int{1, 2, 3, 4, 5} {
    29  				dgesvdTest(t, impl, m, n, mtype, tol)
    30  			}
    31  		}
    32  	}
    33  }
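
// A typical driver for DgesvdTest is sketched below. It is hypothetical: the
// concrete implementation type and the tolerance are assumptions, and any type
// satisfying Dgesvder can be passed.
//
//	func TestDgesvd(t *testing.T) {
//		impl := gonum.Implementation{}
//		DgesvdTest(t, impl, 1e-13)
//	}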
    34  
    35  // dgesvdTest tests a Dgesvd implementation on an m×n matrix A generated
    36  // according to mtype as:
    37  //  - the zero matrix if mtype == 1,
    38  //  - the identity matrix if mtype == 2,
    39  //  - a random matrix with a given condition number and singular values if mtype == 3, 4, or 5.
    40  // It first computes the full SVD  A = U*Sigma*Vᵀ  and checks that
    41  //  - U has orthonormal columns, and Vᵀ has orthonormal rows,
     42  //  - the product U*Sigma*Vᵀ reconstructs A,
    43  //  - the singular values are non-negative and sorted in decreasing order.
     44  // Then it computes all combinations of partial SVD results and checks that
     45  // they match the full SVD result.
    46  func dgesvdTest(t *testing.T, impl Dgesvder, m, n, mtype int, tol float64) {
    47  	const tolOrtho = 1e-15
    48  
    49  	rnd := rand.New(rand.NewSource(1))
    50  
     51  	// Use fixed leading dimensions (rather than iterating over several) to reduce testing time.
    52  	lda := n + 3
    53  	ldu := m + 5
    54  	ldvt := n + 7
    55  
    56  	minmn := min(m, n)
    57  
    58  	// Allocate A and fill it with random values. The in-range elements will
    59  	// be overwritten below according to mtype.
    60  	a := make([]float64, m*lda)
    61  	for i := range a {
    62  		a[i] = rnd.NormFloat64()
    63  	}
    64  
    65  	var aNorm float64
    66  	switch mtype {
    67  	default:
    68  		panic("unknown test matrix type")
    69  	case 1:
    70  		// Zero matrix.
    71  		for i := 0; i < m; i++ {
    72  			for j := 0; j < n; j++ {
    73  				a[i*lda+j] = 0
    74  			}
    75  		}
    76  		aNorm = 0
    77  	case 2:
    78  		// Identity matrix.
    79  		for i := 0; i < m; i++ {
    80  			for j := 0; j < n; j++ {
    81  				if i == j {
    82  					a[i*lda+i] = 1
    83  				} else {
    84  					a[i*lda+j] = 0
    85  				}
    86  			}
    87  		}
    88  		aNorm = 1
    89  	case 3, 4, 5:
    90  		// Scaled random matrix.
    91  		// Generate singular values.
    92  		s := make([]float64, minmn)
    93  		Dlatm1(s,
    94  			4,                      // s[i] = 1 - i*(1-1/cond)/(minmn-1)
    95  			float64(max(1, minmn)), // where cond = max(1,minmn)
    96  			false,                  // signs of s[i] are not randomly flipped
    97  			1, rnd)                 // random numbers are drawn uniformly from [0,1)
    98  		// Decide scale factor for the singular values based on the matrix type.
    99  		ulp := dlamchP
   100  		unfl := dlamchS
   101  		ovfl := 1 / unfl
   102  		aNorm = 1
   103  		if mtype == 4 {
   104  			aNorm = unfl / ulp
   105  		}
   106  		if mtype == 5 {
   107  			aNorm = ovfl * ulp
   108  		}
   109  		// Scale singular values so that the maximum singular value is
   110  		// equal to aNorm (we know that the singular values are
   111  		// generated above to be spread linearly between 1/cond and 1).
   112  		floats.Scale(aNorm, s)
   113  		// Generate A by multiplying S by random orthogonal matrices
   114  		// from left and right.
   115  		Dlagge(m, n, max(0, m-1), max(0, n-1), s, a, lda, rnd, make([]float64, m+n))
   116  	}
   117  	aCopy := make([]float64, len(a))
   118  	copy(aCopy, a)
   119  
   120  	for _, wl := range []worklen{minimumWork, mediumWork, optimumWork} {
   121  		// Restore A because Dgesvd overwrites it.
   122  		copy(a, aCopy)
   123  
   124  		// Allocate slices that will be used below to store the results of full
   125  		// SVD and fill them.
   126  		uAll := make([]float64, m*ldu)
   127  		for i := range uAll {
   128  			uAll[i] = rnd.NormFloat64()
   129  		}
   130  		vtAll := make([]float64, n*ldvt)
   131  		for i := range vtAll {
   132  			vtAll[i] = rnd.NormFloat64()
   133  		}
   134  		sAll := make([]float64, min(m, n))
   135  		for i := range sAll {
   136  			sAll[i] = math.NaN()
   137  		}
   138  
   139  		prefix := fmt.Sprintf("m=%v,n=%v,work=%v,mtype=%v", m, n, wl, mtype)
   140  
   141  		// Determine workspace size based on wl.
   142  		minwork := max(1, max(5*min(m, n), 3*min(m, n)+max(m, n)))
   143  		var lwork int
   144  		switch wl {
   145  		case minimumWork:
   146  			lwork = minwork
   147  		case mediumWork:
   148  			work := make([]float64, 1)
   149  			impl.Dgesvd(lapack.SVDAll, lapack.SVDAll, m, n, a, lda, sAll, uAll, ldu, vtAll, ldvt, work, -1)
   150  			lwork = (int(work[0]) + minwork) / 2
   151  		case optimumWork:
   152  			work := make([]float64, 1)
   153  			impl.Dgesvd(lapack.SVDAll, lapack.SVDAll, m, n, a, lda, sAll, uAll, ldu, vtAll, ldvt, work, -1)
   154  			lwork = int(work[0])
   155  		}
   156  		work := make([]float64, max(1, lwork))
   157  		for i := range work {
   158  			work[i] = math.NaN()
   159  		}
   160  
   161  		// Compute the full SVD which will be used later for checking the partial results.
   162  		ok := impl.Dgesvd(lapack.SVDAll, lapack.SVDAll, m, n, a, lda, sAll, uAll, ldu, vtAll, ldvt, work, len(work))
   163  		if !ok {
   164  			t.Fatalf("Case %v: unexpected failure in full SVD", prefix)
   165  		}
   166  
   167  		// Check that uAll, sAll, and vtAll multiply back to A by computing a residual
    168  		//  |A - U*D*VT| / (n*aNorm)
   169  		if resid := svdFullResidual(m, n, aNorm, aCopy, lda, uAll, ldu, sAll, vtAll, ldvt); resid > tol {
   170  			t.Errorf("Case %v: original matrix not recovered for full SVD, |A - U*D*VT|=%v", prefix, resid)
   171  		}
   172  		if minmn > 0 {
   173  			// Check that uAll is orthogonal.
   174  			q := blas64.General{Rows: m, Cols: m, Data: uAll, Stride: ldu}
   175  			if resid := residualOrthogonal(q, false); resid > tolOrtho*float64(m) {
   176  				t.Errorf("Case %v: UAll is not orthogonal; resid=%v, want<=%v", prefix, resid, tolOrtho*float64(m))
   177  			}
   178  			// Check that vtAll is orthogonal.
   179  			q = blas64.General{Rows: n, Cols: n, Data: vtAll, Stride: ldvt}
   180  			if resid := residualOrthogonal(q, false); resid > tolOrtho*float64(n) {
   181  				t.Errorf("Case %v: VTAll is not orthogonal; resid=%v, want<=%v", prefix, resid, tolOrtho*float64(n))
   182  			}
   183  		}
   184  		// Check that singular values are decreasing.
   185  		if !sort.IsSorted(sort.Reverse(sort.Float64Slice(sAll))) {
   186  			t.Errorf("Case %v: singular values from full SVD are not decreasing", prefix)
   187  		}
   188  		// Check that singular values are non-negative.
   189  		if minmn > 0 && floats.Min(sAll) < 0 {
   190  			t.Errorf("Case %v: some singular values from full SVD are negative", prefix)
   191  		}
   192  
   193  		// Do partial SVD and compare the results to sAll, uAll, and vtAll.
   194  		for _, jobU := range []lapack.SVDJob{lapack.SVDAll, lapack.SVDStore, lapack.SVDOverwrite, lapack.SVDNone} {
   195  			for _, jobVT := range []lapack.SVDJob{lapack.SVDAll, lapack.SVDStore, lapack.SVDOverwrite, lapack.SVDNone} {
   196  				if jobU == lapack.SVDOverwrite || jobVT == lapack.SVDOverwrite {
   197  					// Not implemented.
   198  					continue
   199  				}
   200  				if jobU == lapack.SVDAll && jobVT == lapack.SVDAll {
   201  					// Already checked above.
   202  					continue
   203  				}
   204  
   205  				prefix := prefix + ",job=" + svdJobString(jobU) + "U-" + svdJobString(jobVT) + "VT"
   206  
   207  				// Restore A to its original values.
   208  				copy(a, aCopy)
   209  
   210  				// Allocate slices for the results of partial SVD and fill them.
   211  				u := make([]float64, m*ldu)
   212  				for i := range u {
   213  					u[i] = rnd.NormFloat64()
   214  				}
   215  				vt := make([]float64, n*ldvt)
   216  				for i := range vt {
   217  					vt[i] = rnd.NormFloat64()
   218  				}
   219  				s := make([]float64, min(m, n))
   220  				for i := range s {
   221  					s[i] = math.NaN()
   222  				}
   223  
   224  				for i := range work {
   225  					work[i] = math.NaN()
   226  				}
   227  
   228  				ok := impl.Dgesvd(jobU, jobVT, m, n, a, lda, s, u, ldu, vt, ldvt, work, len(work))
   229  				if !ok {
   230  					t.Fatalf("Case %v: unexpected failure in partial Dgesvd", prefix)
   231  				}
   232  
   233  				if minmn == 0 {
    234  					// The call did not panic and returned ok; for an
    235  					// empty matrix there is nothing else to check.
   236  					continue
   237  				}
   238  
   239  				// Check that U has orthogonal columns and that it matches UAll.
   240  				switch jobU {
   241  				case lapack.SVDStore:
   242  					q := blas64.General{Rows: m, Cols: minmn, Data: u, Stride: ldu}
   243  					if resid := residualOrthogonal(q, false); resid > tolOrtho*float64(m) {
   244  						t.Errorf("Case %v: columns of U are not orthogonal; resid=%v, want<=%v", prefix, resid, tolOrtho*float64(m))
   245  					}
   246  					if res := svdPartialUResidual(m, minmn, u, uAll, ldu); res > tol {
   247  						t.Errorf("Case %v: columns of U do not match UAll", prefix)
   248  					}
   249  				case lapack.SVDAll:
   250  					q := blas64.General{Rows: m, Cols: m, Data: u, Stride: ldu}
   251  					if resid := residualOrthogonal(q, false); resid > tolOrtho*float64(m) {
   252  						t.Errorf("Case %v: columns of U are not orthogonal; resid=%v, want<=%v", prefix, resid, tolOrtho*float64(m))
   253  					}
   254  					if res := svdPartialUResidual(m, m, u, uAll, ldu); res > tol {
   255  						t.Errorf("Case %v: columns of U do not match UAll", prefix)
   256  					}
   257  				}
   258  				// Check that VT has orthogonal rows and that it matches VTAll.
   259  				switch jobVT {
   260  				case lapack.SVDStore:
    261  					q := blas64.General{Rows: minmn, Cols: n, Data: vt, Stride: ldvt}
   262  					if resid := residualOrthogonal(q, true); resid > tolOrtho*float64(n) {
   263  						t.Errorf("Case %v: rows of VT are not orthogonal; resid=%v, want<=%v", prefix, resid, tolOrtho*float64(n))
   264  					}
   265  					if res := svdPartialVTResidual(minmn, n, vt, vtAll, ldvt); res > tol {
   266  						t.Errorf("Case %v: rows of VT do not match VTAll", prefix)
   267  					}
   268  				case lapack.SVDAll:
    269  					q := blas64.General{Rows: n, Cols: n, Data: vt, Stride: ldvt}
   270  					if resid := residualOrthogonal(q, true); resid > tolOrtho*float64(n) {
   271  						t.Errorf("Case %v: rows of VT are not orthogonal; resid=%v, want<=%v", prefix, resid, tolOrtho*float64(n))
   272  					}
   273  					if res := svdPartialVTResidual(n, n, vt, vtAll, ldvt); res > tol {
   274  						t.Errorf("Case %v: rows of VT do not match VTAll", prefix)
   275  					}
   276  				}
   277  				// Check that singular values are decreasing.
   278  				if !sort.IsSorted(sort.Reverse(sort.Float64Slice(s))) {
    279  					t.Errorf("Case %v: singular values from partial SVD are not decreasing", prefix)
   280  				}
   281  				// Check that singular values are non-negative.
   282  				if floats.Min(s) < 0 {
    283  					t.Errorf("Case %v: some singular values from partial SVD are negative", prefix)
   284  				}
   285  				if !floats.EqualApprox(s, sAll, tol/10) {
   286  					t.Errorf("Case %v: singular values differ between full and partial SVD\n%v\n%v", prefix, s, sAll)
   287  				}
   288  			}
   289  		}
   290  	}
   291  }
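
// randomSVDMatrix is a hypothetical sketch, not used by the tests above, of how
// dgesvdTest builds its random matrices for mtype 3, 4 and 5: singular values
// spread linearly between 1/cond and 1 (the mode-4 distribution of Dlatm1) are
// scaled so that the largest equals aNorm, and A is then formed by applying
// random orthogonal transformations from both sides with Dlagge.
func randomSVDMatrix(m, n int, aNorm float64, rnd *rand.Rand) []float64 {
	lda := max(1, n)
	a := make([]float64, m*lda)
	minmn := min(m, n)
	s := make([]float64, minmn)
	cond := float64(max(1, minmn))
	for i := range s {
		// s[i] = 1 - i*(1-1/cond)/(minmn-1), so s[0] = 1 and s[minmn-1] = 1/cond.
		if minmn == 1 {
			s[i] = 1
		} else {
			s[i] = 1 - float64(i)*(1-1/cond)/float64(minmn-1)
		}
	}
	// Scale so that the largest singular value equals aNorm.
	floats.Scale(aNorm, s)
	// A = U*diag(s)*Vᵀ with random orthogonal U and V.
	Dlagge(m, n, max(0, m-1), max(0, n-1), s, a, lda, rnd, make([]float64, m+n))
	return a
}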
   292  
   293  // svdFullResidual returns
   294  //  |A - U*D*VT| / (n * aNorm)
   295  // where U, D, and VT are as computed by Dgesvd with jobU = jobVT = lapack.SVDAll.
   296  func svdFullResidual(m, n int, aNorm float64, a []float64, lda int, u []float64, ldu int, d []float64, vt []float64, ldvt int) float64 {
   297  	// The implementation follows TESTING/dbdt01.f from the reference.
   298  
   299  	minmn := min(m, n)
   300  	if minmn == 0 {
   301  		return 0
   302  	}
   303  
   304  	// j-th column of A - U*D*VT.
   305  	aMinusUDVT := make([]float64, m)
   306  	// D times the j-th column of VT.
   307  	dvt := make([]float64, minmn)
   308  	// Compute the residual |A - U*D*VT| one column at a time.
   309  	var resid float64
   310  	for j := 0; j < n; j++ {
    311  		// Copy the j-th column of A to aMinusUDVT.
   312  		blas64.Copy(blas64.Vector{N: m, Data: a[j:], Inc: lda}, blas64.Vector{N: m, Data: aMinusUDVT, Inc: 1})
   313  		// Multiply D times j-th column of VT.
   314  		for i := 0; i < minmn; i++ {
   315  			dvt[i] = d[i] * vt[i*ldvt+j]
   316  		}
   317  		// Compute the j-th column of A - U*D*VT.
   318  		blas64.Gemv(blas.NoTrans,
   319  			-1, blas64.General{Rows: m, Cols: minmn, Data: u, Stride: ldu}, blas64.Vector{N: minmn, Data: dvt, Inc: 1},
   320  			1, blas64.Vector{N: m, Data: aMinusUDVT, Inc: 1})
   321  		resid = math.Max(resid, blas64.Asum(blas64.Vector{N: m, Data: aMinusUDVT, Inc: 1}))
   322  	}
   323  	if aNorm == 0 {
   324  		if resid != 0 {
   325  			// Original matrix A is zero but the residual is non-zero,
   326  			// return infinity.
   327  			return math.Inf(1)
   328  		}
   329  		// Original matrix A is zero, residual is zero, return 0.
   330  		return 0
   331  	}
   332  	// Original matrix A is non-zero.
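	// Normalize the residual by n*aNorm, following dbdt01: the operations are
	// ordered so that the quotient cannot overflow when aNorm is tiny.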
   333  	if aNorm >= resid {
   334  		resid = resid / aNorm / float64(n)
   335  	} else {
   336  		if aNorm < 1 {
   337  			resid = math.Min(resid, float64(n)*aNorm) / aNorm / float64(n)
   338  		} else {
   339  			resid = math.Min(resid/aNorm, float64(n)) / float64(n)
   340  		}
   341  	}
   342  	return resid
   343  }
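
// reconstructSVD is a hypothetical sketch, not used by svdFullResidual (which
// works one column at a time to save memory), showing the product that the
// residual above measures: it forms U*diag(d)*VT explicitly as a dense m×n
// matrix by first scaling the rows of VT by d and then multiplying by U.
func reconstructSVD(m, n int, u []float64, ldu int, d []float64, vt []float64, ldvt int) blas64.General {
	minmn := min(m, n)
	prod := blas64.General{Rows: m, Cols: n, Stride: max(1, n), Data: make([]float64, m*max(1, n))}
	if minmn == 0 {
		return prod
	}
	// Form diag(d)*VT by scaling the i-th row of VT by d[i].
	dvt := blas64.General{Rows: minmn, Cols: n, Stride: max(1, n), Data: make([]float64, minmn*max(1, n))}
	for i := 0; i < minmn; i++ {
		for j := 0; j < n; j++ {
			dvt.Data[i*dvt.Stride+j] = d[i] * vt[i*ldvt+j]
		}
	}
	// Multiply the m×minmn matrix U by the minmn×n matrix diag(d)*VT.
	blas64.Gemm(blas.NoTrans, blas.NoTrans,
		1, blas64.General{Rows: m, Cols: minmn, Stride: ldu, Data: u}, dvt,
		0, prod)
	return prod
}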
   344  
   345  // svdPartialUResidual compares U and URef to see if their columns span the same
    346  // space. It returns the maximum over columns of
   347  //  |URef(i) - S*U(i)|
   348  // where URef(i) and U(i) are the i-th columns of URef and U, respectively, and
   349  // S is ±1 chosen to minimize the expression.
   350  func svdPartialUResidual(m, n int, u, uRef []float64, ldu int) float64 {
   351  	var res float64
   352  	for j := 0; j < n; j++ {
   353  		imax := blas64.Iamax(blas64.Vector{N: m, Data: uRef[j:], Inc: ldu})
   354  		s := math.Copysign(1, uRef[imax*ldu+j]) * math.Copysign(1, u[imax*ldu+j])
   355  		for i := 0; i < m; i++ {
   356  			diff := math.Abs(uRef[i*ldu+j] - s*u[i*ldu+j])
   357  			res = math.Max(res, diff)
   358  		}
   359  	}
   360  	return res
   361  }
   362  
   363  // svdPartialVTResidual compares VT and VTRef to see if their rows span the same
    364  // space. It returns the maximum over rows of
    365  //  |VTRef(i) - S*VT(i)|
    366  // where VTRef(i) and VT(i) are the i-th rows of VTRef and VT, respectively, and
   367  // S is ±1 chosen to minimize the expression.
   368  func svdPartialVTResidual(m, n int, vt, vtRef []float64, ldvt int) float64 {
   369  	var res float64
   370  	for i := 0; i < m; i++ {
   371  		jmax := blas64.Iamax(blas64.Vector{N: n, Data: vtRef[i*ldvt:], Inc: 1})
   372  		s := math.Copysign(1, vtRef[i*ldvt+jmax]) * math.Copysign(1, vt[i*ldvt+jmax])
   373  		for j := 0; j < n; j++ {
   374  			diff := math.Abs(vtRef[i*ldvt+j] - s*vt[i*ldvt+j])
   375  			res = math.Max(res, diff)
   376  		}
   377  	}
   378  	return res
   379  }
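
// Singular vectors are determined only up to a factor of ±1, which is why
// svdPartialUResidual and svdPartialVTResidual align signs before comparing.
// signInvariantMaxDiff is a hypothetical one-vector sketch of that idea, not
// used by the tests above: the sign is fixed at the entry of xRef with the
// largest magnitude, and the maximum elementwise difference is returned.
func signInvariantMaxDiff(x, xRef []float64) float64 {
	if len(xRef) == 0 {
		return 0
	}
	imax := blas64.Iamax(blas64.Vector{N: len(xRef), Data: xRef, Inc: 1})
	s := math.Copysign(1, xRef[imax]) * math.Copysign(1, x[imax])
	var res float64
	for i := range x {
		res = math.Max(res, math.Abs(xRef[i]-s*x[i]))
	}
	return res
}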