gonum.org/v1/gonum@v0.14.0/lapack/gonum/dbdsqr.go (about)

     1  // Copyright ©2015 The Gonum Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package gonum
     6  
     7  import (
     8  	"math"
     9  
    10  	"gonum.org/v1/gonum/blas"
    11  	"gonum.org/v1/gonum/blas/blas64"
    12  	"gonum.org/v1/gonum/lapack"
    13  )
    14  
    15  // Dbdsqr performs a singular value decomposition of a real n×n bidiagonal matrix.
    16  //
    17  // The SVD of the bidiagonal matrix B is
    18  //
    19  //	B = Q * S * Pᵀ
    20  //
    21  // where S is a diagonal matrix of singular values, Q is an orthogonal matrix of
    22  // left singular vectors, and P is an orthogonal matrix of right singular vectors.
    23  //
    24  // Q and P are only computed if requested. If left singular vectors are requested,
    25  // this routine returns U * Q instead of Q, and if right singular vectors are
    26  // requested Pᵀ * VT is returned instead of Pᵀ.
    27  //
    28  // Frequently Dbdsqr is used in conjunction with Dgebrd which reduces a general
    29  // matrix A into bidiagonal form. In this case, the SVD of A is
    30  //
    31  //	A = (U * Q) * S * (Pᵀ * VT)
    32  //
    33  // This routine may also compute Qᵀ * C.
    34  //
    35  // d and e contain the elements of the bidiagonal matrix b. d must have length at
    36  // least n, and e must have length at least n-1. Dbdsqr will panic if there is
    37  // insufficient length. On exit, D contains the singular values of B in decreasing
    38  // order.
    39  //
    40  // VT is a matrix of size n×ncvt whose elements are stored in vt. The elements
    41  // of vt are modified to contain Pᵀ * VT on exit. VT is not used if ncvt == 0.
    42  //
    43  // U is a matrix of size nru×n whose elements are stored in u. The elements
    44  // of u are modified to contain U * Q on exit. U is not used if nru == 0.
    45  //
    46  // C is a matrix of size n×ncc whose elements are stored in c. The elements
    47  // of c are modified to contain Qᵀ * C on exit. C is not used if ncc == 0.
    48  //
    49  // work contains temporary storage and must have length at least 4*(n-1). Dbdsqr
    50  // will panic if there is insufficient working memory.
    51  //
    52  // Dbdsqr returns whether the decomposition was successful.
    53  //
    54  // Dbdsqr is an internal routine. It is exported for testing purposes.
    55  func (impl Implementation) Dbdsqr(uplo blas.Uplo, n, ncvt, nru, ncc int, d, e, vt []float64, ldvt int, u []float64, ldu int, c []float64, ldc int, work []float64) (ok bool) {
    56  	switch {
    57  	case uplo != blas.Upper && uplo != blas.Lower:
    58  		panic(badUplo)
    59  	case n < 0:
    60  		panic(nLT0)
    61  	case ncvt < 0:
    62  		panic(ncvtLT0)
    63  	case nru < 0:
    64  		panic(nruLT0)
    65  	case ncc < 0:
    66  		panic(nccLT0)
    67  	case ldvt < max(1, ncvt):
    68  		panic(badLdVT)
    69  	case (ldu < max(1, n) && nru > 0) || (ldu < 1 && nru == 0):
    70  		panic(badLdU)
    71  	case ldc < max(1, ncc):
    72  		panic(badLdC)
    73  	}
    74  
    75  	// Quick return if possible.
    76  	if n == 0 {
    77  		return true
    78  	}
    79  
    80  	if len(vt) < (n-1)*ldvt+ncvt && ncvt != 0 {
    81  		panic(shortVT)
    82  	}
    83  	if len(u) < (nru-1)*ldu+n && nru != 0 {
    84  		panic(shortU)
    85  	}
    86  	if len(c) < (n-1)*ldc+ncc && ncc != 0 {
    87  		panic(shortC)
    88  	}
    89  	if len(d) < n {
    90  		panic(shortD)
    91  	}
    92  	if len(e) < n-1 {
    93  		panic(shortE)
    94  	}
    95  	if len(work) < 4*(n-1) {
    96  		panic(shortWork)
    97  	}
    98  
    99  	var info int
   100  	bi := blas64.Implementation()
   101  	const maxIter = 6
   102  
   103  	if n != 1 {
   104  		// If the singular vectors do not need to be computed, use qd algorithm.
   105  		if !(ncvt > 0 || nru > 0 || ncc > 0) {
   106  			info = impl.Dlasq1(n, d, e, work)
   107  			// If info is 2 dqds didn't finish, and so try to.
   108  			if info != 2 {
   109  				return info == 0
   110  			}
   111  		}
   112  		nm1 := n - 1
   113  		nm12 := nm1 + nm1
   114  		nm13 := nm12 + nm1
   115  		idir := 0
   116  
   117  		eps := dlamchE
   118  		unfl := dlamchS
   119  		lower := uplo == blas.Lower
   120  		var cs, sn, r float64
   121  		if lower {
   122  			for i := 0; i < n-1; i++ {
   123  				cs, sn, r = impl.Dlartg(d[i], e[i])
   124  				d[i] = r
   125  				e[i] = sn * d[i+1]
   126  				d[i+1] *= cs
   127  				work[i] = cs
   128  				work[nm1+i] = sn
   129  			}
   130  			if nru > 0 {
   131  				impl.Dlasr(blas.Right, lapack.Variable, lapack.Forward, nru, n, work, work[n-1:], u, ldu)
   132  			}
   133  			if ncc > 0 {
   134  				impl.Dlasr(blas.Left, lapack.Variable, lapack.Forward, n, ncc, work, work[n-1:], c, ldc)
   135  			}
   136  		}
   137  		// Compute singular values to a relative accuracy of tol. If tol is negative
   138  		// the values will be computed to an absolute accuracy of math.Abs(tol) * norm(b)
   139  		tolmul := math.Max(10, math.Min(100, math.Pow(eps, -1.0/8)))
   140  		tol := tolmul * eps
   141  		var smax float64
   142  		for i := 0; i < n; i++ {
   143  			smax = math.Max(smax, math.Abs(d[i]))
   144  		}
   145  		for i := 0; i < n-1; i++ {
   146  			smax = math.Max(smax, math.Abs(e[i]))
   147  		}
   148  
   149  		var smin float64
   150  		var thresh float64
   151  		if tol >= 0 {
   152  			sminoa := math.Abs(d[0])
   153  			if sminoa != 0 {
   154  				mu := sminoa
   155  				for i := 1; i < n; i++ {
   156  					mu = math.Abs(d[i]) * (mu / (mu + math.Abs(e[i-1])))
   157  					sminoa = math.Min(sminoa, mu)
   158  					if sminoa == 0 {
   159  						break
   160  					}
   161  				}
   162  			}
   163  			sminoa = sminoa / math.Sqrt(float64(n))
   164  			thresh = math.Max(tol*sminoa, float64(maxIter*n*n)*unfl)
   165  		} else {
   166  			thresh = math.Max(math.Abs(tol)*smax, float64(maxIter*n*n)*unfl)
   167  		}
   168  		// Prepare for the main iteration loop for the singular values.
   169  		maxIt := maxIter * n * n
   170  		iter := 0
   171  		oldl2 := -1
   172  		oldm := -1
   173  		// m points to the last element of unconverged part of matrix.
   174  		m := n
   175  
   176  	Outer:
   177  		for m > 1 {
   178  			if iter > maxIt {
   179  				info = 0
   180  				for i := 0; i < n-1; i++ {
   181  					if e[i] != 0 {
   182  						info++
   183  					}
   184  				}
   185  				return info == 0
   186  			}
   187  			// Find diagonal block of matrix to work on.
   188  			if tol < 0 && math.Abs(d[m-1]) <= thresh {
   189  				d[m-1] = 0
   190  			}
   191  			smax = math.Abs(d[m-1])
   192  			var l2 int
   193  			var broke bool
   194  			for l3 := 0; l3 < m-1; l3++ {
   195  				l2 = m - l3 - 2
   196  				abss := math.Abs(d[l2])
   197  				abse := math.Abs(e[l2])
   198  				if tol < 0 && abss <= thresh {
   199  					d[l2] = 0
   200  				}
   201  				if abse <= thresh {
   202  					broke = true
   203  					break
   204  				}
   205  				smax = math.Max(math.Max(smax, abss), abse)
   206  			}
   207  			if broke {
   208  				e[l2] = 0
   209  				if l2 == m-2 {
   210  					// Convergence of bottom singular value, return to top.
   211  					m--
   212  					continue
   213  				}
   214  				l2++
   215  			} else {
   216  				l2 = 0
   217  			}
   218  			// e[ll] through e[m-2] are nonzero, e[ll-1] is zero
   219  			if l2 == m-2 {
   220  				// Handle 2×2 block separately.
   221  				var sinr, cosr, sinl, cosl float64
   222  				d[m-1], d[m-2], sinr, cosr, sinl, cosl = impl.Dlasv2(d[m-2], e[m-2], d[m-1])
   223  				e[m-2] = 0
   224  				if ncvt > 0 {
   225  					bi.Drot(ncvt, vt[(m-2)*ldvt:], 1, vt[(m-1)*ldvt:], 1, cosr, sinr)
   226  				}
   227  				if nru > 0 {
   228  					bi.Drot(nru, u[m-2:], ldu, u[m-1:], ldu, cosl, sinl)
   229  				}
   230  				if ncc > 0 {
   231  					bi.Drot(ncc, c[(m-2)*ldc:], 1, c[(m-1)*ldc:], 1, cosl, sinl)
   232  				}
   233  				m -= 2
   234  				continue
   235  			}
   236  			// If working on a new submatrix, choose shift direction from larger end
   237  			// diagonal element toward smaller.
   238  			if l2 > oldm-1 || m-1 < oldl2 {
   239  				if math.Abs(d[l2]) >= math.Abs(d[m-1]) {
   240  					idir = 1
   241  				} else {
   242  					idir = 2
   243  				}
   244  			}
   245  			// Apply convergence tests.
   246  			// TODO(btracey): There is a lot of similar looking code here. See
   247  			// if there is a better way to de-duplicate.
   248  			if idir == 1 {
   249  				// Run convergence test in forward direction.
   250  				// First apply standard test to bottom of matrix.
   251  				if math.Abs(e[m-2]) <= math.Abs(tol)*math.Abs(d[m-1]) || (tol < 0 && math.Abs(e[m-2]) <= thresh) {
   252  					e[m-2] = 0
   253  					continue
   254  				}
   255  				if tol >= 0 {
   256  					// If relative accuracy desired, apply convergence criterion forward.
   257  					mu := math.Abs(d[l2])
   258  					smin = mu
   259  					for l3 := l2; l3 < m-1; l3++ {
   260  						if math.Abs(e[l3]) <= tol*mu {
   261  							e[l3] = 0
   262  							continue Outer
   263  						}
   264  						mu = math.Abs(d[l3+1]) * (mu / (mu + math.Abs(e[l3])))
   265  						smin = math.Min(smin, mu)
   266  					}
   267  				}
   268  			} else {
   269  				// Run convergence test in backward direction.
   270  				// First apply standard test to top of matrix.
   271  				if math.Abs(e[l2]) <= math.Abs(tol)*math.Abs(d[l2]) || (tol < 0 && math.Abs(e[l2]) <= thresh) {
   272  					e[l2] = 0
   273  					continue
   274  				}
   275  				if tol >= 0 {
   276  					// If relative accuracy desired, apply convergence criterion backward.
   277  					mu := math.Abs(d[m-1])
   278  					smin = mu
   279  					for l3 := m - 2; l3 >= l2; l3-- {
   280  						if math.Abs(e[l3]) <= tol*mu {
   281  							e[l3] = 0
   282  							continue Outer
   283  						}
   284  						mu = math.Abs(d[l3]) * (mu / (mu + math.Abs(e[l3])))
   285  						smin = math.Min(smin, mu)
   286  					}
   287  				}
   288  			}
   289  			oldl2 = l2
   290  			oldm = m
   291  			// Compute shift. First, test if shifting would ruin relative accuracy,
   292  			// and if so set the shift to zero.
   293  			var shift float64
   294  			if tol >= 0 && float64(n)*tol*(smin/smax) <= math.Max(eps, (1.0/100)*tol) {
   295  				shift = 0
   296  			} else {
   297  				var sl2 float64
   298  				if idir == 1 {
   299  					sl2 = math.Abs(d[l2])
   300  					shift, _ = impl.Dlas2(d[m-2], e[m-2], d[m-1])
   301  				} else {
   302  					sl2 = math.Abs(d[m-1])
   303  					shift, _ = impl.Dlas2(d[l2], e[l2], d[l2+1])
   304  				}
   305  				// Test if shift is negligible
   306  				if sl2 > 0 {
   307  					if (shift/sl2)*(shift/sl2) < eps {
   308  						shift = 0
   309  					}
   310  				}
   311  			}
   312  			iter += m - l2 + 1
   313  			// If no shift, do simplified QR iteration.
   314  			if shift == 0 {
   315  				if idir == 1 {
   316  					cs := 1.0
   317  					oldcs := 1.0
   318  					var sn, r, oldsn float64
   319  					for i := l2; i < m-1; i++ {
   320  						cs, sn, r = impl.Dlartg(d[i]*cs, e[i])
   321  						if i > l2 {
   322  							e[i-1] = oldsn * r
   323  						}
   324  						oldcs, oldsn, d[i] = impl.Dlartg(oldcs*r, d[i+1]*sn)
   325  						work[i-l2] = cs
   326  						work[i-l2+nm1] = sn
   327  						work[i-l2+nm12] = oldcs
   328  						work[i-l2+nm13] = oldsn
   329  					}
   330  					h := d[m-1] * cs
   331  					d[m-1] = h * oldcs
   332  					e[m-2] = h * oldsn
   333  					if ncvt > 0 {
   334  						impl.Dlasr(blas.Left, lapack.Variable, lapack.Forward, m-l2, ncvt, work, work[n-1:], vt[l2*ldvt:], ldvt)
   335  					}
   336  					if nru > 0 {
   337  						impl.Dlasr(blas.Right, lapack.Variable, lapack.Forward, nru, m-l2, work[nm12:], work[nm13:], u[l2:], ldu)
   338  					}
   339  					if ncc > 0 {
   340  						impl.Dlasr(blas.Left, lapack.Variable, lapack.Forward, m-l2, ncc, work[nm12:], work[nm13:], c[l2*ldc:], ldc)
   341  					}
   342  					if math.Abs(e[m-2]) < thresh {
   343  						e[m-2] = 0
   344  					}
   345  				} else {
   346  					cs := 1.0
   347  					oldcs := 1.0
   348  					var sn, r, oldsn float64
   349  					for i := m - 1; i >= l2+1; i-- {
   350  						cs, sn, r = impl.Dlartg(d[i]*cs, e[i-1])
   351  						if i < m-1 {
   352  							e[i] = oldsn * r
   353  						}
   354  						oldcs, oldsn, d[i] = impl.Dlartg(oldcs*r, d[i-1]*sn)
   355  						work[i-l2-1] = cs
   356  						work[i-l2+nm1-1] = -sn
   357  						work[i-l2+nm12-1] = oldcs
   358  						work[i-l2+nm13-1] = -oldsn
   359  					}
   360  					h := d[l2] * cs
   361  					d[l2] = h * oldcs
   362  					e[l2] = h * oldsn
   363  					if ncvt > 0 {
   364  						impl.Dlasr(blas.Left, lapack.Variable, lapack.Backward, m-l2, ncvt, work[nm12:], work[nm13:], vt[l2*ldvt:], ldvt)
   365  					}
   366  					if nru > 0 {
   367  						impl.Dlasr(blas.Right, lapack.Variable, lapack.Backward, nru, m-l2, work, work[n-1:], u[l2:], ldu)
   368  					}
   369  					if ncc > 0 {
   370  						impl.Dlasr(blas.Left, lapack.Variable, lapack.Backward, m-l2, ncc, work, work[n-1:], c[l2*ldc:], ldc)
   371  					}
   372  					if math.Abs(e[l2]) <= thresh {
   373  						e[l2] = 0
   374  					}
   375  				}
   376  			} else {
   377  				// Use nonzero shift.
   378  				if idir == 1 {
   379  					// Chase bulge from top to bottom. Save cosines and sines for
   380  					// later singular vector updates.
   381  					f := (math.Abs(d[l2]) - shift) * (math.Copysign(1, d[l2]) + shift/d[l2])
   382  					g := e[l2]
   383  					var cosl, sinl float64
   384  					for i := l2; i < m-1; i++ {
   385  						cosr, sinr, r := impl.Dlartg(f, g)
   386  						if i > l2 {
   387  							e[i-1] = r
   388  						}
   389  						f = cosr*d[i] + sinr*e[i]
   390  						e[i] = cosr*e[i] - sinr*d[i]
   391  						g = sinr * d[i+1]
   392  						d[i+1] *= cosr
   393  						cosl, sinl, r = impl.Dlartg(f, g)
   394  						d[i] = r
   395  						f = cosl*e[i] + sinl*d[i+1]
   396  						d[i+1] = cosl*d[i+1] - sinl*e[i]
   397  						if i < m-2 {
   398  							g = sinl * e[i+1]
   399  							e[i+1] = cosl * e[i+1]
   400  						}
   401  						work[i-l2] = cosr
   402  						work[i-l2+nm1] = sinr
   403  						work[i-l2+nm12] = cosl
   404  						work[i-l2+nm13] = sinl
   405  					}
   406  					e[m-2] = f
   407  					if ncvt > 0 {
   408  						impl.Dlasr(blas.Left, lapack.Variable, lapack.Forward, m-l2, ncvt, work, work[n-1:], vt[l2*ldvt:], ldvt)
   409  					}
   410  					if nru > 0 {
   411  						impl.Dlasr(blas.Right, lapack.Variable, lapack.Forward, nru, m-l2, work[nm12:], work[nm13:], u[l2:], ldu)
   412  					}
   413  					if ncc > 0 {
   414  						impl.Dlasr(blas.Left, lapack.Variable, lapack.Forward, m-l2, ncc, work[nm12:], work[nm13:], c[l2*ldc:], ldc)
   415  					}
   416  					if math.Abs(e[m-2]) <= thresh {
   417  						e[m-2] = 0
   418  					}
   419  				} else {
   420  					// Chase bulge from top to bottom. Save cosines and sines for
   421  					// later singular vector updates.
   422  					f := (math.Abs(d[m-1]) - shift) * (math.Copysign(1, d[m-1]) + shift/d[m-1])
   423  					g := e[m-2]
   424  					for i := m - 1; i > l2; i-- {
   425  						cosr, sinr, r := impl.Dlartg(f, g)
   426  						if i < m-1 {
   427  							e[i] = r
   428  						}
   429  						f = cosr*d[i] + sinr*e[i-1]
   430  						e[i-1] = cosr*e[i-1] - sinr*d[i]
   431  						g = sinr * d[i-1]
   432  						d[i-1] *= cosr
   433  						cosl, sinl, r := impl.Dlartg(f, g)
   434  						d[i] = r
   435  						f = cosl*e[i-1] + sinl*d[i-1]
   436  						d[i-1] = cosl*d[i-1] - sinl*e[i-1]
   437  						if i > l2+1 {
   438  							g = sinl * e[i-2]
   439  							e[i-2] *= cosl
   440  						}
   441  						work[i-l2-1] = cosr
   442  						work[i-l2+nm1-1] = -sinr
   443  						work[i-l2+nm12-1] = cosl
   444  						work[i-l2+nm13-1] = -sinl
   445  					}
   446  					e[l2] = f
   447  					if math.Abs(e[l2]) <= thresh {
   448  						e[l2] = 0
   449  					}
   450  					if ncvt > 0 {
   451  						impl.Dlasr(blas.Left, lapack.Variable, lapack.Backward, m-l2, ncvt, work[nm12:], work[nm13:], vt[l2*ldvt:], ldvt)
   452  					}
   453  					if nru > 0 {
   454  						impl.Dlasr(blas.Right, lapack.Variable, lapack.Backward, nru, m-l2, work, work[n-1:], u[l2:], ldu)
   455  					}
   456  					if ncc > 0 {
   457  						impl.Dlasr(blas.Left, lapack.Variable, lapack.Backward, m-l2, ncc, work, work[n-1:], c[l2*ldc:], ldc)
   458  					}
   459  				}
   460  			}
   461  		}
   462  	}
   463  
   464  	// All singular values converged, make them positive.
   465  	for i := 0; i < n; i++ {
   466  		if d[i] < 0 {
   467  			d[i] *= -1
   468  			if ncvt > 0 {
   469  				bi.Dscal(ncvt, -1, vt[i*ldvt:], 1)
   470  			}
   471  		}
   472  	}
   473  
   474  	// Sort the singular values in decreasing order.
   475  	for i := 0; i < n-1; i++ {
   476  		isub := 0
   477  		smin := d[0]
   478  		for j := 1; j < n-i; j++ {
   479  			if d[j] <= smin {
   480  				isub = j
   481  				smin = d[j]
   482  			}
   483  		}
   484  		if isub != n-i {
   485  			// Swap singular values and vectors.
   486  			d[isub] = d[n-i-1]
   487  			d[n-i-1] = smin
   488  			if ncvt > 0 {
   489  				bi.Dswap(ncvt, vt[isub*ldvt:], 1, vt[(n-i-1)*ldvt:], 1)
   490  			}
   491  			if nru > 0 {
   492  				bi.Dswap(nru, u[isub:], ldu, u[n-i-1:], ldu)
   493  			}
   494  			if ncc > 0 {
   495  				bi.Dswap(ncc, c[isub*ldc:], 1, c[(n-i-1)*ldc:], 1)
   496  			}
   497  		}
   498  	}
   499  	info = 0
   500  	for i := 0; i < n-1; i++ {
   501  		if e[i] != 0 {
   502  			info++
   503  		}
   504  	}
   505  	return info == 0
   506  }