gonum.org/v1/gonum@v0.14.0/lapack/gonum/dormqr.go (about)

     1  // Copyright ©2015 The Gonum Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package gonum
     6  
     7  import (
     8  	"gonum.org/v1/gonum/blas"
     9  	"gonum.org/v1/gonum/lapack"
    10  )
    11  
    12  // Dormqr multiplies an m×n matrix C by an orthogonal matrix Q as
    13  //
    14  //	C = Q * C   if side == blas.Left  and trans == blas.NoTrans,
    15  //	C = Qᵀ * C  if side == blas.Left  and trans == blas.Trans,
    16  //	C = C * Q   if side == blas.Right and trans == blas.NoTrans,
    17  //	C = C * Qᵀ  if side == blas.Right and trans == blas.Trans,
    18  //
    19  // where Q is defined as the product of k elementary reflectors
    20  //
    21  //	Q = H_0 * H_1 * ... * H_{k-1}.
    22  //
    23  // If side == blas.Left, A is an m×k matrix and 0 <= k <= m.
    24  // If side == blas.Right, A is an n×k matrix and 0 <= k <= n.
    25  // The ith column of A contains the vector which defines the elementary
    26  // reflector H_i and tau[i] contains its scalar factor. tau must have length k
    27  // and Dormqr will panic otherwise. Dgeqrf returns A and tau in the required
    28  // form.
    29  //
    30  // work must have length at least max(1,lwork), and lwork must be at least
    31  // max(1,n) if side == blas.Left and at least max(1,m) if side == blas.Right,
    32  // otherwise Dormqr will panic.
    33  //
    34  // work is temporary storage, and lwork specifies the usable memory length.
    35  // With only the minimum lwork the unblocked routine Dorm2r is used. The blocked
    36  // path needs nb*n (side == blas.Left) or nb*m (side == blas.Right) elements plus
    37  // a 64*64 scratch area, so larger lwork will generally give better performance.
    38  // On return, work[0] will contain the optimal value of lwork.
    39  //
    40  // If lwork is -1, instead of performing Dormqr, the optimal workspace size will
    41  // be stored into work[0].
    42  func (impl Implementation) Dormqr(side blas.Side, trans blas.Transpose, m, n, k int, a []float64, lda int, tau, c []float64, ldc int, work []float64, lwork int) {
    43  	left := side == blas.Left
    44  	nq := n // number of rows of A; n for side == blas.Right
    45  	nw := m // minimum workspace length; m for side == blas.Right
    46  	if left {
    47  		nq = m
    48  		nw = n
    49  	}
    50  	switch {
    51  	case !left && side != blas.Right:
    52  		panic(badSide)
    53  	case trans != blas.NoTrans && trans != blas.Trans:
    54  		panic(badTrans)
    55  	case m < 0:
    56  		panic(mLT0)
    57  	case n < 0:
    58  		panic(nLT0)
    59  	case k < 0:
    60  		panic(kLT0)
    61  	case left && k > m:
    62  		panic(kGTM)
    63  	case !left && k > n:
    64  		panic(kGTN)
    65  	case lda < max(1, k):
    66  		panic(badLdA)
    67  	case ldc < max(1, n):
    68  		panic(badLdC)
    69  	case lwork < max(1, nw) && lwork != -1:
    70  		panic(badLWork)
    71  	case len(work) < max(1, lwork):
    72  		panic(shortWork)
    73  	}
    74  
    75  	// Quick return if possible.
    76  	if m == 0 || n == 0 || k == 0 {
    77  		work[0] = 1
    78  		return
    79  	}
    80  
    81  	const (
    82  		nbmax = 64 // upper bound on the block size nb
    83  		ldt   = nbmax
    84  		tsize = nbmax * ldt // work[:tsize] is the scratch handed to Dlarft and Dlarfb
    85  	)
    86  	opts := string(side) + string(trans)
    87  	nb := min(nbmax, impl.Ilaenv(1, "DORMQR", opts, m, n, k, -1))
    88  	lworkopt := max(1, nw)*nb + tsize
    89  	if lwork == -1 { // workspace query: only report the optimal size
    90  		work[0] = float64(lworkopt)
    91  		return
    92  	}
    93  
    94  	switch {
    95  	case len(a) < (nq-1)*lda+k:
    96  		panic(shortA)
    97  	case len(tau) != k:
    98  		panic(badLenTau)
    99  	case len(c) < (m-1)*ldc+n:
   100  		panic(shortC)
   101  	}
   102  
   103  	nbmin := 2
   104  	if 1 < nb && nb < k {
   105  		if lwork < nw*nb+tsize { // too little workspace for nb: shrink the block size
   106  			nb = (lwork - tsize) / nw
   107  			nbmin = max(2, impl.Ilaenv(2, "DORMQR", opts, m, n, k, -1))
   108  		}
   109  	}
   110  
   111  	if nb < nbmin || k <= nb {
   112  		// Call unblocked code.
   113  		impl.Dorm2r(side, trans, m, n, k, a, lda, tau, c, ldc, work)
   114  		work[0] = float64(lworkopt)
   115  		return
   116  	}
   117  
   118  	var (
   119  		ldwork  = nb // stride passed to Dlarfb for work[tsize:]
   120  		notrans = trans == blas.NoTrans
   121  	)
   122  	switch {
   123  	case left && notrans:
   124  		for i := ((k - 1) / nb) * nb; i >= 0; i -= nb { // last block of reflectors first
   125  			ib := min(nb, k-i)
   126  			impl.Dlarft(lapack.Forward, lapack.ColumnWise, m-i, ib,
   127  				a[i*lda+i:], lda,
   128  				tau[i:],
   129  				work[:tsize], ldt)
   130  			impl.Dlarfb(side, trans, lapack.Forward, lapack.ColumnWise, m-i, n, ib,
   131  				a[i*lda+i:], lda,
   132  				work[:tsize], ldt,
   133  				c[i*ldc:], ldc,
   134  				work[tsize:], ldwork)
   135  		}
   136  
   137  	case left && !notrans:
   138  		for i := 0; i < k; i += nb { // first block of reflectors first
   139  			ib := min(nb, k-i)
   140  			impl.Dlarft(lapack.Forward, lapack.ColumnWise, m-i, ib,
   141  				a[i*lda+i:], lda,
   142  				tau[i:],
   143  				work[:tsize], ldt)
   144  			impl.Dlarfb(side, trans, lapack.Forward, lapack.ColumnWise, m-i, n, ib,
   145  				a[i*lda+i:], lda,
   146  				work[:tsize], ldt,
   147  				c[i*ldc:], ldc,
   148  				work[tsize:], ldwork)
   149  		}
   150  
   151  	case !left && notrans:
   152  		for i := 0; i < k; i += nb { // first block of reflectors first
   153  			ib := min(nb, k-i)
   154  			impl.Dlarft(lapack.Forward, lapack.ColumnWise, n-i, ib,
   155  				a[i*lda+i:], lda,
   156  				tau[i:],
   157  				work[:tsize], ldt)
   158  			impl.Dlarfb(side, trans, lapack.Forward, lapack.ColumnWise, m, n-i, ib,
   159  				a[i*lda+i:], lda,
   160  				work[:tsize], ldt,
   161  				c[i:], ldc,
   162  				work[tsize:], ldwork)
   163  		}
   164  
   165  	case !left && !notrans:
   166  		for i := ((k - 1) / nb) * nb; i >= 0; i -= nb { // last block of reflectors first
   167  			ib := min(nb, k-i)
   168  			impl.Dlarft(lapack.Forward, lapack.ColumnWise, n-i, ib,
   169  				a[i*lda+i:], lda,
   170  				tau[i:],
   171  				work[:tsize], ldt)
   172  			impl.Dlarfb(side, trans, lapack.Forward, lapack.ColumnWise, m, n-i, ib,
   173  				a[i*lda+i:], lda,
   174  				work[:tsize], ldt,
   175  				c[i:], ldc,
   176  				work[tsize:], ldwork)
   177  		}
   178  	}
   179  	work[0] = float64(lworkopt)
   180  }