github.com/jingcheng-WU/gonum@v0.9.1-0.20210323123734-f1a2a11a8f7b/lapack/gonum/dormqr.go (about)

     1  // Copyright ©2015 The Gonum Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package gonum
     6  
     7  import (
     8  	"github.com/jingcheng-WU/gonum/blas"
     9  	"github.com/jingcheng-WU/gonum/lapack"
    10  )
    11  
    12  // Dormqr multiplies an m×n matrix C by an orthogonal matrix Q as
    13  //  C = Q * C   if side == blas.Left  and trans == blas.NoTrans,
    14  //  C = Qᵀ * C  if side == blas.Left  and trans == blas.Trans,
    15  //  C = C * Q   if side == blas.Right and trans == blas.NoTrans,
    16  //  C = C * Qᵀ  if side == blas.Right and trans == blas.Trans,
    17  // where Q is defined as the product of k elementary reflectors
    18  //  Q = H_0 * H_1 * ... * H_{k-1}.
    19  //
    20  // If side == blas.Left, A is an m×k matrix and 0 <= k <= m.
    21  // If side == blas.Right, A is an n×k matrix and 0 <= k <= n.
    22  // The ith column of A contains the vector which defines the elementary
    23  // reflector H_i and tau[i] contains its scalar factor. tau must have length k
    24  // and Dormqr will panic otherwise. Dgeqrf returns A and tau in the required
    25  // form.
    26  //
    27  // work must have length at least max(1,lwork), and lwork must be at least n if
    28  // side == blas.Left and at least m if side == blas.Right, otherwise Dormqr will
    29  // panic.
    30  //
    31  // work is temporary storage, and lwork specifies the usable memory length. At
    32  // minimum, lwork >= m if side == blas.Left and lwork >= n if side ==
    33  // blas.Right, and this function will panic otherwise. Larger values of lwork
    34  // will generally give better performance. On return, work[0] will contain the
    35  // optimal value of lwork.
    36  //
    37  // If lwork is -1, instead of performing Dormqr, the optimal workspace size will
    38  // be stored into work[0].
    39  func (impl Implementation) Dormqr(side blas.Side, trans blas.Transpose, m, n, k int, a []float64, lda int, tau, c []float64, ldc int, work []float64, lwork int) {
    40  	left := side == blas.Left
    41  	nq := n
    42  	nw := m
    43  	if left {
    44  		nq = m
    45  		nw = n
    46  	}
    47  	switch {
    48  	case !left && side != blas.Right:
    49  		panic(badSide)
    50  	case trans != blas.NoTrans && trans != blas.Trans:
    51  		panic(badTrans)
    52  	case m < 0:
    53  		panic(mLT0)
    54  	case n < 0:
    55  		panic(nLT0)
    56  	case k < 0:
    57  		panic(kLT0)
    58  	case left && k > m:
    59  		panic(kGTM)
    60  	case !left && k > n:
    61  		panic(kGTN)
    62  	case lda < max(1, k):
    63  		panic(badLdA)
    64  	case ldc < max(1, n):
    65  		panic(badLdC)
    66  	case lwork < max(1, nw) && lwork != -1:
    67  		panic(badLWork)
    68  	case len(work) < max(1, lwork):
    69  		panic(shortWork)
    70  	}
    71  
    72  	// Quick return if possible.
    73  	if m == 0 || n == 0 || k == 0 {
    74  		work[0] = 1
    75  		return
    76  	}
    77  
    78  	const (
    79  		nbmax = 64
    80  		ldt   = nbmax
    81  		tsize = nbmax * ldt
    82  	)
    83  	opts := string(side) + string(trans)
    84  	nb := min(nbmax, impl.Ilaenv(1, "DORMQR", opts, m, n, k, -1))
    85  	lworkopt := max(1, nw)*nb + tsize
    86  	if lwork == -1 {
    87  		work[0] = float64(lworkopt)
    88  		return
    89  	}
    90  
    91  	switch {
    92  	case len(a) < (nq-1)*lda+k:
    93  		panic(shortA)
    94  	case len(tau) != k:
    95  		panic(badLenTau)
    96  	case len(c) < (m-1)*ldc+n:
    97  		panic(shortC)
    98  	}
    99  
   100  	nbmin := 2
   101  	if 1 < nb && nb < k {
   102  		if lwork < nw*nb+tsize {
   103  			nb = (lwork - tsize) / nw
   104  			nbmin = max(2, impl.Ilaenv(2, "DORMQR", opts, m, n, k, -1))
   105  		}
   106  	}
   107  
   108  	if nb < nbmin || k <= nb {
   109  		// Call unblocked code.
   110  		impl.Dorm2r(side, trans, m, n, k, a, lda, tau, c, ldc, work)
   111  		work[0] = float64(lworkopt)
   112  		return
   113  	}
   114  
   115  	var (
   116  		ldwork  = nb
   117  		notrans = trans == blas.NoTrans
   118  	)
   119  	switch {
   120  	case left && notrans:
   121  		for i := ((k - 1) / nb) * nb; i >= 0; i -= nb {
   122  			ib := min(nb, k-i)
   123  			impl.Dlarft(lapack.Forward, lapack.ColumnWise, m-i, ib,
   124  				a[i*lda+i:], lda,
   125  				tau[i:],
   126  				work[:tsize], ldt)
   127  			impl.Dlarfb(side, trans, lapack.Forward, lapack.ColumnWise, m-i, n, ib,
   128  				a[i*lda+i:], lda,
   129  				work[:tsize], ldt,
   130  				c[i*ldc:], ldc,
   131  				work[tsize:], ldwork)
   132  		}
   133  
   134  	case left && !notrans:
   135  		for i := 0; i < k; i += nb {
   136  			ib := min(nb, k-i)
   137  			impl.Dlarft(lapack.Forward, lapack.ColumnWise, m-i, ib,
   138  				a[i*lda+i:], lda,
   139  				tau[i:],
   140  				work[:tsize], ldt)
   141  			impl.Dlarfb(side, trans, lapack.Forward, lapack.ColumnWise, m-i, n, ib,
   142  				a[i*lda+i:], lda,
   143  				work[:tsize], ldt,
   144  				c[i*ldc:], ldc,
   145  				work[tsize:], ldwork)
   146  		}
   147  
   148  	case !left && notrans:
   149  		for i := 0; i < k; i += nb {
   150  			ib := min(nb, k-i)
   151  			impl.Dlarft(lapack.Forward, lapack.ColumnWise, n-i, ib,
   152  				a[i*lda+i:], lda,
   153  				tau[i:],
   154  				work[:tsize], ldt)
   155  			impl.Dlarfb(side, trans, lapack.Forward, lapack.ColumnWise, m, n-i, ib,
   156  				a[i*lda+i:], lda,
   157  				work[:tsize], ldt,
   158  				c[i:], ldc,
   159  				work[tsize:], ldwork)
   160  		}
   161  
   162  	case !left && !notrans:
   163  		for i := ((k - 1) / nb) * nb; i >= 0; i -= nb {
   164  			ib := min(nb, k-i)
   165  			impl.Dlarft(lapack.Forward, lapack.ColumnWise, n-i, ib,
   166  				a[i*lda+i:], lda,
   167  				tau[i:],
   168  				work[:tsize], ldt)
   169  			impl.Dlarfb(side, trans, lapack.Forward, lapack.ColumnWise, m, n-i, ib,
   170  				a[i*lda+i:], lda,
   171  				work[:tsize], ldt,
   172  				c[i:], ldc,
   173  				work[tsize:], ldwork)
   174  		}
   175  	}
   176  	work[0] = float64(lworkopt)
   177  }