github.com/gonum/lapack@v0.0.0-20181123203213-e4cdc5a0bff9/native/dbdsqr.go (about)

     1  // Copyright ©2015 The gonum Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package native
     6  
     7  import (
     8  	"math"
     9  
    10  	"github.com/gonum/blas"
    11  	"github.com/gonum/blas/blas64"
    12  	"github.com/gonum/lapack"
    13  )
    14  
    15  // Dbdsqr performs a singular value decomposition of a real n×n bidiagonal matrix.
    16  //
    17  // The SVD of the bidiagonal matrix B is
    18  //  B = Q * S * P^T
    19  // where S is a diagonal matrix of singular values, Q is an orthogonal matrix of
    20  // left singular vectors, and P is an orthogonal matrix of right singular vectors.
    21  //
    22  // Q and P are only computed if requested. If left singular vectors are requested,
    23  // this routine returns U * Q instead of Q, and if right singular vectors are
    24  // requested P^T * VT is returned instead of P^T.
    25  //
    26  // Frequently Dbdsqr is used in conjunction with Dgebrd which reduces a general
    27  // matrix A into bidiagonal form. In this case, the SVD of A is
    28  //  A = (U * Q) * S * (P^T * VT)
    29  //
    30  // This routine may also compute Q^T * C.
    31  //
    32  // d and e contain the elements of the bidiagonal matrix b. d must have length at
    33  // least n, and e must have length at least n-1. Dbdsqr will panic if there is
    34  // insufficient length. On exit, D contains the singular values of B in decreasing
    35  // order.
    36  //
    37  // VT is a matrix of size n×ncvt whose elements are stored in vt. The elements
    38  // of vt are modified to contain P^T * VT on exit. VT is not used if ncvt == 0.
    39  //
    40  // U is a matrix of size nru×n whose elements are stored in u. The elements
    41  // of u are modified to contain U * Q on exit. U is not used if nru == 0.
    42  //
    43  // C is a matrix of size n×ncc whose elements are stored in c. The elements
    44  // of c are modified to contain Q^T * C on exit. C is not used if ncc == 0.
    45  //
    46  // work contains temporary storage and must have length at least 4*n. Dbdsqr
    47  // will panic if there is insufficient working memory.
    48  //
    49  // Dbdsqr returns whether the decomposition was successful.
    50  //
    51  // Dbdsqr is an internal routine. It is exported for testing purposes.
    52  func (impl Implementation) Dbdsqr(uplo blas.Uplo, n, ncvt, nru, ncc int, d, e, vt []float64, ldvt int, u []float64, ldu int, c []float64, ldc int, work []float64) (ok bool) {
    53  	if uplo != blas.Upper && uplo != blas.Lower {
    54  		panic(badUplo)
    55  	}
    56  	if ncvt != 0 {
    57  		checkMatrix(n, ncvt, vt, ldvt)
    58  	}
    59  	if nru != 0 {
    60  		checkMatrix(nru, n, u, ldu)
    61  	}
    62  	if ncc != 0 {
    63  		checkMatrix(n, ncc, c, ldc)
    64  	}
    65  	if len(d) < n {
    66  		panic(badD)
    67  	}
    68  	if len(e) < n-1 {
    69  		panic(badE)
    70  	}
    71  	if len(work) < 4*n {
    72  		panic(badWork)
    73  	}
    74  	var info int
    75  	bi := blas64.Implementation()
    76  	const (
    77  		maxIter = 6
    78  	)
    79  	if n == 0 {
    80  		return true
    81  	}
    82  	if n != 1 {
    83  		// If the singular vectors do not need to be computed, use qd algorithm.
    84  		if !(ncvt > 0 || nru > 0 || ncc > 0) {
    85  			info = impl.Dlasq1(n, d, e, work)
    86  			// If info is 2 dqds didn't finish, and so try to.
    87  			if info != 2 {
    88  				return info == 0
    89  			}
    90  			info = 0
    91  		}
    92  		nm1 := n - 1
    93  		nm12 := nm1 + nm1
    94  		nm13 := nm12 + nm1
    95  		idir := 0
    96  
    97  		eps := dlamchE
    98  		unfl := dlamchS
    99  		lower := uplo == blas.Lower
   100  		var cs, sn, r float64
   101  		if lower {
   102  			for i := 0; i < n-1; i++ {
   103  				cs, sn, r = impl.Dlartg(d[i], e[i])
   104  				d[i] = r
   105  				e[i] = sn * d[i+1]
   106  				d[i+1] *= cs
   107  				work[i] = cs
   108  				work[nm1+i] = sn
   109  			}
   110  			if nru > 0 {
   111  				impl.Dlasr(blas.Right, lapack.Variable, lapack.Forward, nru, n, work, work[n-1:], u, ldu)
   112  			}
   113  			if ncc > 0 {
   114  				impl.Dlasr(blas.Left, lapack.Variable, lapack.Forward, n, ncc, work, work[n-1:], c, ldc)
   115  			}
   116  		}
   117  		// Compute singular values to a relative accuracy of tol. If tol is negative
   118  		// the values will be computed to an absolute accuracy of math.Abs(tol) * norm(b)
   119  		tolmul := math.Max(10, math.Min(100, math.Pow(eps, -1.0/8)))
   120  		tol := tolmul * eps
   121  		var smax float64
   122  		for i := 0; i < n; i++ {
   123  			smax = math.Max(smax, math.Abs(d[i]))
   124  		}
   125  		for i := 0; i < n-1; i++ {
   126  			smax = math.Max(smax, math.Abs(e[i]))
   127  		}
   128  
   129  		var sminl float64
   130  		var thresh float64
   131  		if tol >= 0 {
   132  			sminoa := math.Abs(d[0])
   133  			if sminoa != 0 {
   134  				mu := sminoa
   135  				for i := 1; i < n; i++ {
   136  					mu = math.Abs(d[i]) * (mu / (mu + math.Abs(e[i-1])))
   137  					sminoa = math.Min(sminoa, mu)
   138  					if sminoa == 0 {
   139  						break
   140  					}
   141  				}
   142  			}
   143  			sminoa = sminoa / math.Sqrt(float64(n))
   144  			thresh = math.Max(tol*sminoa, float64(maxIter*n*n)*unfl)
   145  		} else {
   146  			thresh = math.Max(math.Abs(tol)*smax, float64(maxIter*n*n)*unfl)
   147  		}
   148  		// Prepare for the main iteration loop for the singular values.
   149  		maxIt := maxIter * n * n
   150  		iter := 0
   151  		oldl2 := -1
   152  		oldm := -1
   153  		// m points to the last element of unconverged part of matrix.
   154  		m := n
   155  
   156  	Outer:
   157  		for m > 1 {
   158  			if iter > maxIt {
   159  				info = 0
   160  				for i := 0; i < n-1; i++ {
   161  					if e[i] != 0 {
   162  						info++
   163  					}
   164  				}
   165  				return info == 0
   166  			}
   167  			// Find diagonal block of matrix to work on.
   168  			if tol < 0 && math.Abs(d[m-1]) <= thresh {
   169  				d[m-1] = 0
   170  			}
   171  			smax = math.Abs(d[m-1])
   172  			smin := smax
   173  			var l2 int
   174  			var broke bool
   175  			for l3 := 0; l3 < m-1; l3++ {
   176  				l2 = m - l3 - 2
   177  				abss := math.Abs(d[l2])
   178  				abse := math.Abs(e[l2])
   179  				if tol < 0 && abss <= thresh {
   180  					d[l2] = 0
   181  				}
   182  				if abse <= thresh {
   183  					broke = true
   184  					break
   185  				}
   186  				smin = math.Min(smin, abss)
   187  				smax = math.Max(math.Max(smax, abss), abse)
   188  			}
   189  			if broke {
   190  				e[l2] = 0
   191  				if l2 == m-2 {
   192  					// Convergence of bottom singular value, return to top.
   193  					m--
   194  					continue
   195  				}
   196  				l2++
   197  			} else {
   198  				l2 = 0
   199  			}
   200  			// e[ll] through e[m-2] are nonzero, e[ll-1] is zero
   201  			if l2 == m-2 {
   202  				// Handle 2×2 block separately.
   203  				var sinr, cosr, sinl, cosl float64
   204  				d[m-1], d[m-2], sinr, cosr, sinl, cosl = impl.Dlasv2(d[m-2], e[m-2], d[m-1])
   205  				e[m-2] = 0
   206  				if ncvt > 0 {
   207  					bi.Drot(ncvt, vt[(m-2)*ldvt:], 1, vt[(m-1)*ldvt:], 1, cosr, sinr)
   208  				}
   209  				if nru > 0 {
   210  					bi.Drot(nru, u[m-2:], ldu, u[m-1:], ldu, cosl, sinl)
   211  				}
   212  				if ncc > 0 {
   213  					bi.Drot(ncc, c[(m-2)*ldc:], 1, c[(m-1)*ldc:], 1, cosl, sinl)
   214  				}
   215  				m -= 2
   216  				continue
   217  			}
   218  			// If working on a new submatrix, choose shift direction from larger end
   219  			// diagonal element toward smaller.
   220  			if l2 > oldm-1 || m-1 < oldl2 {
   221  				if math.Abs(d[l2]) >= math.Abs(d[m-1]) {
   222  					idir = 1
   223  				} else {
   224  					idir = 2
   225  				}
   226  			}
   227  			// Apply convergence tests.
   228  			// TODO(btracey): There is a lot of similar looking code here. See
   229  			// if there is a better way to de-duplicate.
   230  			if idir == 1 {
   231  				// Run convergence test in forward direction.
   232  				// First apply standard test to bottom of matrix.
   233  				if math.Abs(e[m-2]) <= math.Abs(tol)*math.Abs(d[m-1]) || (tol < 0 && math.Abs(e[m-2]) <= thresh) {
   234  					e[m-2] = 0
   235  					continue
   236  				}
   237  				if tol >= 0 {
   238  					// If relative accuracy desired, apply convergence criterion forward.
   239  					mu := math.Abs(d[l2])
   240  					sminl = mu
   241  					for l3 := l2; l3 < m-1; l3++ {
   242  						if math.Abs(e[l3]) <= tol*mu {
   243  							e[l3] = 0
   244  							continue Outer
   245  						}
   246  						mu = math.Abs(d[l3+1]) * (mu / (mu + math.Abs(e[l3])))
   247  						sminl = math.Min(sminl, mu)
   248  					}
   249  				}
   250  			} else {
   251  				// Run convergence test in backward direction.
   252  				// First apply standard test to top of matrix.
   253  				if math.Abs(e[l2]) <= math.Abs(tol)*math.Abs(d[l2]) || (tol < 0 && math.Abs(e[l2]) <= thresh) {
   254  					e[l2] = 0
   255  					continue
   256  				}
   257  				if tol >= 0 {
   258  					// If relative accuracy desired, apply convergence criterion backward.
   259  					mu := math.Abs(d[m-1])
   260  					sminl = mu
   261  					for l3 := m - 2; l3 >= l2; l3-- {
   262  						if math.Abs(e[l3]) <= tol*mu {
   263  							e[l3] = 0
   264  							continue Outer
   265  						}
   266  						mu = math.Abs(d[l3]) * (mu / (mu + math.Abs(e[l3])))
   267  						sminl = math.Min(sminl, mu)
   268  					}
   269  				}
   270  			}
   271  			oldl2 = l2
   272  			oldm = m
   273  			// Compute shift. First, test if shifting would ruin relative accuracy,
   274  			// and if so set the shift to zero.
   275  			var shift float64
   276  			if tol >= 0 && float64(n)*tol*(sminl/smax) <= math.Max(eps, (1.0/100)*tol) {
   277  				shift = 0
   278  			} else {
   279  				var sl2 float64
   280  				if idir == 1 {
   281  					sl2 = math.Abs(d[l2])
   282  					shift, _ = impl.Dlas2(d[m-2], e[m-2], d[m-1])
   283  				} else {
   284  					sl2 = math.Abs(d[m-1])
   285  					shift, _ = impl.Dlas2(d[l2], e[l2], d[l2+1])
   286  				}
   287  				// Test if shift is negligible
   288  				if sl2 > 0 {
   289  					if (shift/sl2)*(shift/sl2) < eps {
   290  						shift = 0
   291  					}
   292  				}
   293  			}
   294  			iter += m - l2 + 1
   295  			// If no shift, do simplified QR iteration.
   296  			if shift == 0 {
   297  				if idir == 1 {
   298  					cs := 1.0
   299  					oldcs := 1.0
   300  					var sn, r, oldsn float64
   301  					for i := l2; i < m-1; i++ {
   302  						cs, sn, r = impl.Dlartg(d[i]*cs, e[i])
   303  						if i > l2 {
   304  							e[i-1] = oldsn * r
   305  						}
   306  						oldcs, oldsn, d[i] = impl.Dlartg(oldcs*r, d[i+1]*sn)
   307  						work[i-l2] = cs
   308  						work[i-l2+nm1] = sn
   309  						work[i-l2+nm12] = oldcs
   310  						work[i-l2+nm13] = oldsn
   311  					}
   312  					h := d[m-1] * cs
   313  					d[m-1] = h * oldcs
   314  					e[m-2] = h * oldsn
   315  					if ncvt > 0 {
   316  						impl.Dlasr(blas.Left, lapack.Variable, lapack.Forward, m-l2, ncvt, work, work[n-1:], vt[l2*ldvt:], ldvt)
   317  					}
   318  					if nru > 0 {
   319  						impl.Dlasr(blas.Right, lapack.Variable, lapack.Forward, nru, m-l2, work[nm12:], work[nm13:], u[l2:], ldu)
   320  					}
   321  					if ncc > 0 {
   322  						impl.Dlasr(blas.Left, lapack.Variable, lapack.Forward, m-l2, ncc, work[nm12:], work[nm13:], c[l2*ldc:], ldc)
   323  					}
   324  					if math.Abs(e[m-2]) < thresh {
   325  						e[m-2] = 0
   326  					}
   327  				} else {
   328  					cs := 1.0
   329  					oldcs := 1.0
   330  					var sn, r, oldsn float64
   331  					for i := m - 1; i >= l2+1; i-- {
   332  						cs, sn, r = impl.Dlartg(d[i]*cs, e[i-1])
   333  						if i < m-1 {
   334  							e[i] = oldsn * r
   335  						}
   336  						oldcs, oldsn, d[i] = impl.Dlartg(oldcs*r, d[i-1]*sn)
   337  						work[i-l2-1] = cs
   338  						work[i-l2+nm1-1] = -sn
   339  						work[i-l2+nm12-1] = oldcs
   340  						work[i-l2+nm13-1] = -oldsn
   341  					}
   342  					h := d[l2] * cs
   343  					d[l2] = h * oldcs
   344  					e[l2] = h * oldsn
   345  					if ncvt > 0 {
   346  						impl.Dlasr(blas.Left, lapack.Variable, lapack.Backward, m-l2, ncvt, work[nm12:], work[nm13:], vt[l2*ldvt:], ldvt)
   347  					}
   348  					if nru > 0 {
   349  						impl.Dlasr(blas.Right, lapack.Variable, lapack.Backward, nru, m-l2, work, work[n-1:], u[l2:], ldu)
   350  					}
   351  					if ncc > 0 {
   352  						impl.Dlasr(blas.Left, lapack.Variable, lapack.Backward, m-l2, ncc, work, work[n-1:], c[l2*ldc:], ldc)
   353  					}
   354  					if math.Abs(e[l2]) <= thresh {
   355  						e[l2] = 0
   356  					}
   357  				}
   358  			} else {
   359  				// Use nonzero shift.
   360  				if idir == 1 {
   361  					// Chase bulge from top to bottom. Save cosines and sines for
   362  					// later singular vector updates.
   363  					f := (math.Abs(d[l2]) - shift) * (math.Copysign(1, d[l2]) + shift/d[l2])
   364  					g := e[l2]
   365  					var cosl, sinl float64
   366  					for i := l2; i < m-1; i++ {
   367  						cosr, sinr, r := impl.Dlartg(f, g)
   368  						if i > l2 {
   369  							e[i-1] = r
   370  						}
   371  						f = cosr*d[i] + sinr*e[i]
   372  						e[i] = cosr*e[i] - sinr*d[i]
   373  						g = sinr * d[i+1]
   374  						d[i+1] *= cosr
   375  						cosl, sinl, r = impl.Dlartg(f, g)
   376  						d[i] = r
   377  						f = cosl*e[i] + sinl*d[i+1]
   378  						d[i+1] = cosl*d[i+1] - sinl*e[i]
   379  						if i < m-2 {
   380  							g = sinl * e[i+1]
   381  							e[i+1] = cosl * e[i+1]
   382  						}
   383  						work[i-l2] = cosr
   384  						work[i-l2+nm1] = sinr
   385  						work[i-l2+nm12] = cosl
   386  						work[i-l2+nm13] = sinl
   387  					}
   388  					e[m-2] = f
   389  					if ncvt > 0 {
   390  						impl.Dlasr(blas.Left, lapack.Variable, lapack.Forward, m-l2, ncvt, work, work[n-1:], vt[l2*ldvt:], ldvt)
   391  					}
   392  					if nru > 0 {
   393  						impl.Dlasr(blas.Right, lapack.Variable, lapack.Forward, nru, m-l2, work[nm12:], work[nm13:], u[l2:], ldu)
   394  					}
   395  					if ncc > 0 {
   396  						impl.Dlasr(blas.Left, lapack.Variable, lapack.Forward, m-l2, ncc, work[nm12:], work[nm13:], c[l2*ldc:], ldc)
   397  					}
   398  					if math.Abs(e[m-2]) <= thresh {
   399  						e[m-2] = 0
   400  					}
   401  				} else {
   402  					// Chase bulge from top to bottom. Save cosines and sines for
   403  					// later singular vector updates.
   404  					f := (math.Abs(d[m-1]) - shift) * (math.Copysign(1, d[m-1]) + shift/d[m-1])
   405  					g := e[m-2]
   406  					for i := m - 1; i > l2; i-- {
   407  						cosr, sinr, r := impl.Dlartg(f, g)
   408  						if i < m-1 {
   409  							e[i] = r
   410  						}
   411  						f = cosr*d[i] + sinr*e[i-1]
   412  						e[i-1] = cosr*e[i-1] - sinr*d[i]
   413  						g = sinr * d[i-1]
   414  						d[i-1] *= cosr
   415  						cosl, sinl, r := impl.Dlartg(f, g)
   416  						d[i] = r
   417  						f = cosl*e[i-1] + sinl*d[i-1]
   418  						d[i-1] = cosl*d[i-1] - sinl*e[i-1]
   419  						if i > l2+1 {
   420  							g = sinl * e[i-2]
   421  							e[i-2] *= cosl
   422  						}
   423  						work[i-l2-1] = cosr
   424  						work[i-l2+nm1-1] = -sinr
   425  						work[i-l2+nm12-1] = cosl
   426  						work[i-l2+nm13-1] = -sinl
   427  					}
   428  					e[l2] = f
   429  					if math.Abs(e[l2]) <= thresh {
   430  						e[l2] = 0
   431  					}
   432  					if ncvt > 0 {
   433  						impl.Dlasr(blas.Left, lapack.Variable, lapack.Backward, m-l2, ncvt, work[nm12:], work[nm13:], vt[l2*ldvt:], ldvt)
   434  					}
   435  					if nru > 0 {
   436  						impl.Dlasr(blas.Right, lapack.Variable, lapack.Backward, nru, m-l2, work, work[n-1:], u[l2:], ldu)
   437  					}
   438  					if ncc > 0 {
   439  						impl.Dlasr(blas.Left, lapack.Variable, lapack.Backward, m-l2, ncc, work, work[n-1:], c[l2*ldc:], ldc)
   440  					}
   441  				}
   442  			}
   443  		}
   444  	}
   445  
   446  	// All singular values converged, make them positive.
   447  	for i := 0; i < n; i++ {
   448  		if d[i] < 0 {
   449  			d[i] *= -1
   450  			if ncvt > 0 {
   451  				bi.Dscal(ncvt, -1, vt[i*ldvt:], 1)
   452  			}
   453  		}
   454  	}
   455  
   456  	// Sort the singular values in decreasing order.
   457  	for i := 0; i < n-1; i++ {
   458  		isub := 0
   459  		smin := d[0]
   460  		for j := 1; j < n-i; j++ {
   461  			if d[j] <= smin {
   462  				isub = j
   463  				smin = d[j]
   464  			}
   465  		}
   466  		if isub != n-i {
   467  			// Swap singular values and vectors.
   468  			d[isub] = d[n-i-1]
   469  			d[n-i-1] = smin
   470  			if ncvt > 0 {
   471  				bi.Dswap(ncvt, vt[isub*ldvt:], 1, vt[(n-i-1)*ldvt:], 1)
   472  			}
   473  			if nru > 0 {
   474  				bi.Dswap(nru, u[isub:], ldu, u[n-i-1:], ldu)
   475  			}
   476  			if ncc > 0 {
   477  				bi.Dswap(ncc, c[isub*ldc:], 1, c[(n-i-1)*ldc:], 1)
   478  			}
   479  		}
   480  	}
   481  	info = 0
   482  	for i := 0; i < n-1; i++ {
   483  		if e[i] != 0 {
   484  			info++
   485  		}
   486  	}
   487  	return info == 0
   488  }