github.com/jingcheng-WU/gonum@v0.9.1-0.20210323123734-f1a2a11a8f7b/lapack/gonum/dbdsqr.go (about)

     1  // Copyright ©2015 The Gonum Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package gonum
     6  
     7  import (
     8  	"math"
     9  
    10  	"github.com/jingcheng-WU/gonum/blas"
    11  	"github.com/jingcheng-WU/gonum/blas/blas64"
    12  	"github.com/jingcheng-WU/gonum/lapack"
    13  )
    14  
    15  // Dbdsqr performs a singular value decomposition of a real n×n bidiagonal matrix.
    16  //
    17  // The SVD of the bidiagonal matrix B is
    18  //  B = Q * S * Pᵀ
    19  // where S is a diagonal matrix of singular values, Q is an orthogonal matrix of
    20  // left singular vectors, and P is an orthogonal matrix of right singular vectors.
    21  //
    22  // Q and P are only computed if requested. If left singular vectors are requested,
    23  // this routine returns U * Q instead of Q, and if right singular vectors are
    24  // requested Pᵀ * VT is returned instead of Pᵀ.
    25  //
    26  // Frequently Dbdsqr is used in conjunction with Dgebrd which reduces a general
    27  // matrix A into bidiagonal form. In this case, the SVD of A is
    28  //  A = (U * Q) * S * (Pᵀ * VT)
    29  //
    30  // This routine may also compute Qᵀ * C.
    31  //
    32  // d and e contain the elements of the bidiagonal matrix b. d must have length at
    33  // least n, and e must have length at least n-1. Dbdsqr will panic if there is
    34  // insufficient length. On exit, D contains the singular values of B in decreasing
    35  // order.
    36  //
    37  // VT is a matrix of size n×ncvt whose elements are stored in vt. The elements
    38  // of vt are modified to contain Pᵀ * VT on exit. VT is not used if ncvt == 0.
    39  //
    40  // U is a matrix of size nru×n whose elements are stored in u. The elements
    41  // of u are modified to contain U * Q on exit. U is not used if nru == 0.
    42  //
    43  // C is a matrix of size n×ncc whose elements are stored in c. The elements
    44  // of c are modified to contain Qᵀ * C on exit. C is not used if ncc == 0.
    45  //
    46  // work contains temporary storage and must have length at least 4*(n-1). Dbdsqr
    47  // will panic if there is insufficient working memory.
    48  //
    49  // Dbdsqr returns whether the decomposition was successful.
    50  //
    51  // Dbdsqr is an internal routine. It is exported for testing purposes.
    52  func (impl Implementation) Dbdsqr(uplo blas.Uplo, n, ncvt, nru, ncc int, d, e, vt []float64, ldvt int, u []float64, ldu int, c []float64, ldc int, work []float64) (ok bool) {
    53  	switch {
    54  	case uplo != blas.Upper && uplo != blas.Lower:
    55  		panic(badUplo)
    56  	case n < 0:
    57  		panic(nLT0)
    58  	case ncvt < 0:
    59  		panic(ncvtLT0)
    60  	case nru < 0:
    61  		panic(nruLT0)
    62  	case ncc < 0:
    63  		panic(nccLT0)
    64  	case ldvt < max(1, ncvt):
    65  		panic(badLdVT)
    66  	case (ldu < max(1, n) && nru > 0) || (ldu < 1 && nru == 0):
    67  		panic(badLdU)
    68  	case ldc < max(1, ncc):
    69  		panic(badLdC)
    70  	}
    71  
    72  	// Quick return if possible.
    73  	if n == 0 {
    74  		return true
    75  	}
    76  
    77  	if len(vt) < (n-1)*ldvt+ncvt && ncvt != 0 {
    78  		panic(shortVT)
    79  	}
    80  	if len(u) < (nru-1)*ldu+n && nru != 0 {
    81  		panic(shortU)
    82  	}
    83  	if len(c) < (n-1)*ldc+ncc && ncc != 0 {
    84  		panic(shortC)
    85  	}
    86  	if len(d) < n {
    87  		panic(shortD)
    88  	}
    89  	if len(e) < n-1 {
    90  		panic(shortE)
    91  	}
    92  	if len(work) < 4*(n-1) {
    93  		panic(shortWork)
    94  	}
    95  
    96  	var info int
    97  	bi := blas64.Implementation()
    98  	const maxIter = 6
    99  
   100  	if n != 1 {
   101  		// If the singular vectors do not need to be computed, use qd algorithm.
   102  		if !(ncvt > 0 || nru > 0 || ncc > 0) {
   103  			info = impl.Dlasq1(n, d, e, work)
   104  			// If info is 2 dqds didn't finish, and so try to.
   105  			if info != 2 {
   106  				return info == 0
   107  			}
   108  		}
   109  		nm1 := n - 1
   110  		nm12 := nm1 + nm1
   111  		nm13 := nm12 + nm1
   112  		idir := 0
   113  
   114  		eps := dlamchE
   115  		unfl := dlamchS
   116  		lower := uplo == blas.Lower
   117  		var cs, sn, r float64
   118  		if lower {
   119  			for i := 0; i < n-1; i++ {
   120  				cs, sn, r = impl.Dlartg(d[i], e[i])
   121  				d[i] = r
   122  				e[i] = sn * d[i+1]
   123  				d[i+1] *= cs
   124  				work[i] = cs
   125  				work[nm1+i] = sn
   126  			}
   127  			if nru > 0 {
   128  				impl.Dlasr(blas.Right, lapack.Variable, lapack.Forward, nru, n, work, work[n-1:], u, ldu)
   129  			}
   130  			if ncc > 0 {
   131  				impl.Dlasr(blas.Left, lapack.Variable, lapack.Forward, n, ncc, work, work[n-1:], c, ldc)
   132  			}
   133  		}
   134  		// Compute singular values to a relative accuracy of tol. If tol is negative
   135  		// the values will be computed to an absolute accuracy of math.Abs(tol) * norm(b)
   136  		tolmul := math.Max(10, math.Min(100, math.Pow(eps, -1.0/8)))
   137  		tol := tolmul * eps
   138  		var smax float64
   139  		for i := 0; i < n; i++ {
   140  			smax = math.Max(smax, math.Abs(d[i]))
   141  		}
   142  		for i := 0; i < n-1; i++ {
   143  			smax = math.Max(smax, math.Abs(e[i]))
   144  		}
   145  
   146  		var sminl float64
   147  		var thresh float64
   148  		if tol >= 0 {
   149  			sminoa := math.Abs(d[0])
   150  			if sminoa != 0 {
   151  				mu := sminoa
   152  				for i := 1; i < n; i++ {
   153  					mu = math.Abs(d[i]) * (mu / (mu + math.Abs(e[i-1])))
   154  					sminoa = math.Min(sminoa, mu)
   155  					if sminoa == 0 {
   156  						break
   157  					}
   158  				}
   159  			}
   160  			sminoa = sminoa / math.Sqrt(float64(n))
   161  			thresh = math.Max(tol*sminoa, float64(maxIter*n*n)*unfl)
   162  		} else {
   163  			thresh = math.Max(math.Abs(tol)*smax, float64(maxIter*n*n)*unfl)
   164  		}
   165  		// Prepare for the main iteration loop for the singular values.
   166  		maxIt := maxIter * n * n
   167  		iter := 0
   168  		oldl2 := -1
   169  		oldm := -1
   170  		// m points to the last element of unconverged part of matrix.
   171  		m := n
   172  
   173  	Outer:
   174  		for m > 1 {
   175  			if iter > maxIt {
   176  				info = 0
   177  				for i := 0; i < n-1; i++ {
   178  					if e[i] != 0 {
   179  						info++
   180  					}
   181  				}
   182  				return info == 0
   183  			}
   184  			// Find diagonal block of matrix to work on.
   185  			if tol < 0 && math.Abs(d[m-1]) <= thresh {
   186  				d[m-1] = 0
   187  			}
   188  			smax = math.Abs(d[m-1])
   189  			smin := smax
   190  			var l2 int
   191  			var broke bool
   192  			for l3 := 0; l3 < m-1; l3++ {
   193  				l2 = m - l3 - 2
   194  				abss := math.Abs(d[l2])
   195  				abse := math.Abs(e[l2])
   196  				if tol < 0 && abss <= thresh {
   197  					d[l2] = 0
   198  				}
   199  				if abse <= thresh {
   200  					broke = true
   201  					break
   202  				}
   203  				smin = math.Min(smin, abss)
   204  				smax = math.Max(math.Max(smax, abss), abse)
   205  			}
   206  			if broke {
   207  				e[l2] = 0
   208  				if l2 == m-2 {
   209  					// Convergence of bottom singular value, return to top.
   210  					m--
   211  					continue
   212  				}
   213  				l2++
   214  			} else {
   215  				l2 = 0
   216  			}
   217  			// e[ll] through e[m-2] are nonzero, e[ll-1] is zero
   218  			if l2 == m-2 {
   219  				// Handle 2×2 block separately.
   220  				var sinr, cosr, sinl, cosl float64
   221  				d[m-1], d[m-2], sinr, cosr, sinl, cosl = impl.Dlasv2(d[m-2], e[m-2], d[m-1])
   222  				e[m-2] = 0
   223  				if ncvt > 0 {
   224  					bi.Drot(ncvt, vt[(m-2)*ldvt:], 1, vt[(m-1)*ldvt:], 1, cosr, sinr)
   225  				}
   226  				if nru > 0 {
   227  					bi.Drot(nru, u[m-2:], ldu, u[m-1:], ldu, cosl, sinl)
   228  				}
   229  				if ncc > 0 {
   230  					bi.Drot(ncc, c[(m-2)*ldc:], 1, c[(m-1)*ldc:], 1, cosl, sinl)
   231  				}
   232  				m -= 2
   233  				continue
   234  			}
   235  			// If working on a new submatrix, choose shift direction from larger end
   236  			// diagonal element toward smaller.
   237  			if l2 > oldm-1 || m-1 < oldl2 {
   238  				if math.Abs(d[l2]) >= math.Abs(d[m-1]) {
   239  					idir = 1
   240  				} else {
   241  					idir = 2
   242  				}
   243  			}
   244  			// Apply convergence tests.
   245  			// TODO(btracey): There is a lot of similar looking code here. See
   246  			// if there is a better way to de-duplicate.
   247  			if idir == 1 {
   248  				// Run convergence test in forward direction.
   249  				// First apply standard test to bottom of matrix.
   250  				if math.Abs(e[m-2]) <= math.Abs(tol)*math.Abs(d[m-1]) || (tol < 0 && math.Abs(e[m-2]) <= thresh) {
   251  					e[m-2] = 0
   252  					continue
   253  				}
   254  				if tol >= 0 {
   255  					// If relative accuracy desired, apply convergence criterion forward.
   256  					mu := math.Abs(d[l2])
   257  					sminl = mu
   258  					for l3 := l2; l3 < m-1; l3++ {
   259  						if math.Abs(e[l3]) <= tol*mu {
   260  							e[l3] = 0
   261  							continue Outer
   262  						}
   263  						mu = math.Abs(d[l3+1]) * (mu / (mu + math.Abs(e[l3])))
   264  						sminl = math.Min(sminl, mu)
   265  					}
   266  				}
   267  			} else {
   268  				// Run convergence test in backward direction.
   269  				// First apply standard test to top of matrix.
   270  				if math.Abs(e[l2]) <= math.Abs(tol)*math.Abs(d[l2]) || (tol < 0 && math.Abs(e[l2]) <= thresh) {
   271  					e[l2] = 0
   272  					continue
   273  				}
   274  				if tol >= 0 {
   275  					// If relative accuracy desired, apply convergence criterion backward.
   276  					mu := math.Abs(d[m-1])
   277  					sminl = mu
   278  					for l3 := m - 2; l3 >= l2; l3-- {
   279  						if math.Abs(e[l3]) <= tol*mu {
   280  							e[l3] = 0
   281  							continue Outer
   282  						}
   283  						mu = math.Abs(d[l3]) * (mu / (mu + math.Abs(e[l3])))
   284  						sminl = math.Min(sminl, mu)
   285  					}
   286  				}
   287  			}
   288  			oldl2 = l2
   289  			oldm = m
   290  			// Compute shift. First, test if shifting would ruin relative accuracy,
   291  			// and if so set the shift to zero.
   292  			var shift float64
   293  			if tol >= 0 && float64(n)*tol*(sminl/smax) <= math.Max(eps, (1.0/100)*tol) {
   294  				shift = 0
   295  			} else {
   296  				var sl2 float64
   297  				if idir == 1 {
   298  					sl2 = math.Abs(d[l2])
   299  					shift, _ = impl.Dlas2(d[m-2], e[m-2], d[m-1])
   300  				} else {
   301  					sl2 = math.Abs(d[m-1])
   302  					shift, _ = impl.Dlas2(d[l2], e[l2], d[l2+1])
   303  				}
   304  				// Test if shift is negligible
   305  				if sl2 > 0 {
   306  					if (shift/sl2)*(shift/sl2) < eps {
   307  						shift = 0
   308  					}
   309  				}
   310  			}
   311  			iter += m - l2 + 1
   312  			// If no shift, do simplified QR iteration.
   313  			if shift == 0 {
   314  				if idir == 1 {
   315  					cs := 1.0
   316  					oldcs := 1.0
   317  					var sn, r, oldsn float64
   318  					for i := l2; i < m-1; i++ {
   319  						cs, sn, r = impl.Dlartg(d[i]*cs, e[i])
   320  						if i > l2 {
   321  							e[i-1] = oldsn * r
   322  						}
   323  						oldcs, oldsn, d[i] = impl.Dlartg(oldcs*r, d[i+1]*sn)
   324  						work[i-l2] = cs
   325  						work[i-l2+nm1] = sn
   326  						work[i-l2+nm12] = oldcs
   327  						work[i-l2+nm13] = oldsn
   328  					}
   329  					h := d[m-1] * cs
   330  					d[m-1] = h * oldcs
   331  					e[m-2] = h * oldsn
   332  					if ncvt > 0 {
   333  						impl.Dlasr(blas.Left, lapack.Variable, lapack.Forward, m-l2, ncvt, work, work[n-1:], vt[l2*ldvt:], ldvt)
   334  					}
   335  					if nru > 0 {
   336  						impl.Dlasr(blas.Right, lapack.Variable, lapack.Forward, nru, m-l2, work[nm12:], work[nm13:], u[l2:], ldu)
   337  					}
   338  					if ncc > 0 {
   339  						impl.Dlasr(blas.Left, lapack.Variable, lapack.Forward, m-l2, ncc, work[nm12:], work[nm13:], c[l2*ldc:], ldc)
   340  					}
   341  					if math.Abs(e[m-2]) < thresh {
   342  						e[m-2] = 0
   343  					}
   344  				} else {
   345  					cs := 1.0
   346  					oldcs := 1.0
   347  					var sn, r, oldsn float64
   348  					for i := m - 1; i >= l2+1; i-- {
   349  						cs, sn, r = impl.Dlartg(d[i]*cs, e[i-1])
   350  						if i < m-1 {
   351  							e[i] = oldsn * r
   352  						}
   353  						oldcs, oldsn, d[i] = impl.Dlartg(oldcs*r, d[i-1]*sn)
   354  						work[i-l2-1] = cs
   355  						work[i-l2+nm1-1] = -sn
   356  						work[i-l2+nm12-1] = oldcs
   357  						work[i-l2+nm13-1] = -oldsn
   358  					}
   359  					h := d[l2] * cs
   360  					d[l2] = h * oldcs
   361  					e[l2] = h * oldsn
   362  					if ncvt > 0 {
   363  						impl.Dlasr(blas.Left, lapack.Variable, lapack.Backward, m-l2, ncvt, work[nm12:], work[nm13:], vt[l2*ldvt:], ldvt)
   364  					}
   365  					if nru > 0 {
   366  						impl.Dlasr(blas.Right, lapack.Variable, lapack.Backward, nru, m-l2, work, work[n-1:], u[l2:], ldu)
   367  					}
   368  					if ncc > 0 {
   369  						impl.Dlasr(blas.Left, lapack.Variable, lapack.Backward, m-l2, ncc, work, work[n-1:], c[l2*ldc:], ldc)
   370  					}
   371  					if math.Abs(e[l2]) <= thresh {
   372  						e[l2] = 0
   373  					}
   374  				}
   375  			} else {
   376  				// Use nonzero shift.
   377  				if idir == 1 {
   378  					// Chase bulge from top to bottom. Save cosines and sines for
   379  					// later singular vector updates.
   380  					f := (math.Abs(d[l2]) - shift) * (math.Copysign(1, d[l2]) + shift/d[l2])
   381  					g := e[l2]
   382  					var cosl, sinl float64
   383  					for i := l2; i < m-1; i++ {
   384  						cosr, sinr, r := impl.Dlartg(f, g)
   385  						if i > l2 {
   386  							e[i-1] = r
   387  						}
   388  						f = cosr*d[i] + sinr*e[i]
   389  						e[i] = cosr*e[i] - sinr*d[i]
   390  						g = sinr * d[i+1]
   391  						d[i+1] *= cosr
   392  						cosl, sinl, r = impl.Dlartg(f, g)
   393  						d[i] = r
   394  						f = cosl*e[i] + sinl*d[i+1]
   395  						d[i+1] = cosl*d[i+1] - sinl*e[i]
   396  						if i < m-2 {
   397  							g = sinl * e[i+1]
   398  							e[i+1] = cosl * e[i+1]
   399  						}
   400  						work[i-l2] = cosr
   401  						work[i-l2+nm1] = sinr
   402  						work[i-l2+nm12] = cosl
   403  						work[i-l2+nm13] = sinl
   404  					}
   405  					e[m-2] = f
   406  					if ncvt > 0 {
   407  						impl.Dlasr(blas.Left, lapack.Variable, lapack.Forward, m-l2, ncvt, work, work[n-1:], vt[l2*ldvt:], ldvt)
   408  					}
   409  					if nru > 0 {
   410  						impl.Dlasr(blas.Right, lapack.Variable, lapack.Forward, nru, m-l2, work[nm12:], work[nm13:], u[l2:], ldu)
   411  					}
   412  					if ncc > 0 {
   413  						impl.Dlasr(blas.Left, lapack.Variable, lapack.Forward, m-l2, ncc, work[nm12:], work[nm13:], c[l2*ldc:], ldc)
   414  					}
   415  					if math.Abs(e[m-2]) <= thresh {
   416  						e[m-2] = 0
   417  					}
   418  				} else {
   419  					// Chase bulge from top to bottom. Save cosines and sines for
   420  					// later singular vector updates.
   421  					f := (math.Abs(d[m-1]) - shift) * (math.Copysign(1, d[m-1]) + shift/d[m-1])
   422  					g := e[m-2]
   423  					for i := m - 1; i > l2; i-- {
   424  						cosr, sinr, r := impl.Dlartg(f, g)
   425  						if i < m-1 {
   426  							e[i] = r
   427  						}
   428  						f = cosr*d[i] + sinr*e[i-1]
   429  						e[i-1] = cosr*e[i-1] - sinr*d[i]
   430  						g = sinr * d[i-1]
   431  						d[i-1] *= cosr
   432  						cosl, sinl, r := impl.Dlartg(f, g)
   433  						d[i] = r
   434  						f = cosl*e[i-1] + sinl*d[i-1]
   435  						d[i-1] = cosl*d[i-1] - sinl*e[i-1]
   436  						if i > l2+1 {
   437  							g = sinl * e[i-2]
   438  							e[i-2] *= cosl
   439  						}
   440  						work[i-l2-1] = cosr
   441  						work[i-l2+nm1-1] = -sinr
   442  						work[i-l2+nm12-1] = cosl
   443  						work[i-l2+nm13-1] = -sinl
   444  					}
   445  					e[l2] = f
   446  					if math.Abs(e[l2]) <= thresh {
   447  						e[l2] = 0
   448  					}
   449  					if ncvt > 0 {
   450  						impl.Dlasr(blas.Left, lapack.Variable, lapack.Backward, m-l2, ncvt, work[nm12:], work[nm13:], vt[l2*ldvt:], ldvt)
   451  					}
   452  					if nru > 0 {
   453  						impl.Dlasr(blas.Right, lapack.Variable, lapack.Backward, nru, m-l2, work, work[n-1:], u[l2:], ldu)
   454  					}
   455  					if ncc > 0 {
   456  						impl.Dlasr(blas.Left, lapack.Variable, lapack.Backward, m-l2, ncc, work, work[n-1:], c[l2*ldc:], ldc)
   457  					}
   458  				}
   459  			}
   460  		}
   461  	}
   462  
   463  	// All singular values converged, make them positive.
   464  	for i := 0; i < n; i++ {
   465  		if d[i] < 0 {
   466  			d[i] *= -1
   467  			if ncvt > 0 {
   468  				bi.Dscal(ncvt, -1, vt[i*ldvt:], 1)
   469  			}
   470  		}
   471  	}
   472  
   473  	// Sort the singular values in decreasing order.
   474  	for i := 0; i < n-1; i++ {
   475  		isub := 0
   476  		smin := d[0]
   477  		for j := 1; j < n-i; j++ {
   478  			if d[j] <= smin {
   479  				isub = j
   480  				smin = d[j]
   481  			}
   482  		}
   483  		if isub != n-i {
   484  			// Swap singular values and vectors.
   485  			d[isub] = d[n-i-1]
   486  			d[n-i-1] = smin
   487  			if ncvt > 0 {
   488  				bi.Dswap(ncvt, vt[isub*ldvt:], 1, vt[(n-i-1)*ldvt:], 1)
   489  			}
   490  			if nru > 0 {
   491  				bi.Dswap(nru, u[isub:], ldu, u[n-i-1:], ldu)
   492  			}
   493  			if ncc > 0 {
   494  				bi.Dswap(ncc, c[isub*ldc:], 1, c[(n-i-1)*ldc:], 1)
   495  			}
   496  		}
   497  	}
   498  	info = 0
   499  	for i := 0; i < n-1; i++ {
   500  		if e[i] != 0 {
   501  			info++
   502  		}
   503  	}
   504  	return info == 0
   505  }