gonum.org/v1/gonum@v0.14.0/lapack/gonum/dlatrd.go (about)

     1  // Copyright ©2016 The Gonum Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package gonum
     6  
     7  import (
     8  	"gonum.org/v1/gonum/blas"
     9  	"gonum.org/v1/gonum/blas/blas64"
    10  )
    11  
    12  // Dlatrd reduces nb rows and columns of a real n×n symmetric matrix A to symmetric
    13  // tridiagonal form. It computes the orthonormal similarity transformation
    14  //
    15  //	Qᵀ * A * Q
    16  //
    17  // and returns the matrices V and W to apply to the unreduced part of A. If
    18  // uplo == blas.Upper, the upper triangle is supplied and the last nb rows are
    19  // reduced. If uplo == blas.Lower, the lower triangle is supplied and the first
    20  // nb rows are reduced.
    21  //
    22  // a contains the symmetric matrix on entry with active triangular half specified
    23  // by uplo. On exit, the nb columns have been reduced to tridiagonal form. The
    24  // diagonal contains the diagonal of the reduced matrix, the off-diagonal is
    25  // set to 1, and the remaining elements contain the data to construct Q.
    26  //
    27  // If uplo == blas.Upper, with n = 5 and nb = 2 on exit a is
    28  //
    29  //	[ a   a   a  v4  v5]
    30  //	[     a   a  v4  v5]
    31  //	[         a   1  v5]
    32  //	[             d   1]
    33  //	[                 d]
    34  //
    35  // If uplo == blas.Lower, with n = 5 and nb = 2, on exit a is
    36  //
    37  //	[ d                ]
    38  //	[ 1   d            ]
    39  //	[v1   1   a        ]
    40  //	[v1  v2   a   a    ]
    41  //	[v1  v2   a   a   a]
    42  //
    43  // e contains the superdiagonal elements of the reduced matrix. If uplo == blas.Upper,
    44  // e[n-nb:n-1] contains the last nb columns of the reduced matrix, while if
    45  // uplo == blas.Lower, e[:nb] contains the first nb columns of the reduced matrix.
    46  // e must have length at least n-1, and Dlatrd will panic otherwise.
    47  //
    48  // tau contains the scalar factors of the elementary reflectors needed to construct Q.
    49  // The reflectors are stored in tau[n-nb:n-1] if uplo == blas.Upper, and in
    50  // tau[:nb] if uplo == blas.Lower. tau must have length n-1, and Dlatrd will panic
    51  // otherwise.
    52  //
    53  // w is an n×nb matrix. On exit it contains the data to update the unreduced part
    54  // of A.
    55  //
    56  // The matrix Q is represented as a product of elementary reflectors. Each reflector
    57  // H has the form
    58  //
    59  //	I - tau * v * vᵀ
    60  //
    61  // If uplo == blas.Upper,
    62  //
    63  //	Q = H_{n-1} * H_{n-2} * ... * H_{n-nb}
    64  //
    65  // where v[:i-1] is stored in A[:i-1,i], v[i-1] = 1, and v[i:n] = 0.
    66  //
    67  // If uplo == blas.Lower,
    68  //
    69  //	Q = H_0 * H_1 * ... * H_{nb-1}
    70  //
    71  // where v[:i+1] = 0, v[i+1] = 1, and v[i+2:n] is stored in A[i+2:n,i].
    72  //
    73  // The vectors v form the n×nb matrix V which is used with W to apply a
    74  // symmetric rank-2 update to the unreduced part of A
    75  //
    76  //	A = A - V * Wᵀ - W * Vᵀ
    77  //
    78  // Dlatrd is an internal routine. It is exported for testing purposes.
    79  func (impl Implementation) Dlatrd(uplo blas.Uplo, n, nb int, a []float64, lda int, e, tau, w []float64, ldw int) {
    80  	switch {
    81  	case uplo != blas.Upper && uplo != blas.Lower:
    82  		panic(badUplo)
    83  	case n < 0:
    84  		panic(nLT0)
    85  	case nb < 0:
    86  		panic(nbLT0)
    87  	case nb > n:
    88  		panic(nbGTN)
    89  	case lda < max(1, n):
    90  		panic(badLdA)
    91  	case ldw < max(1, nb):
    92  		panic(badLdW)
    93  	}
    94  
    95  	if n == 0 {
    96  		return
    97  	}
    98  
    99  	switch {
   100  	case len(a) < (n-1)*lda+n:
   101  		panic(shortA)
   102  	case len(w) < (n-1)*ldw+nb:
   103  		panic(shortW)
   104  	case len(e) < n-1:
   105  		panic(shortE)
   106  	case len(tau) < n-1:
   107  		panic(shortTau)
   108  	}
   109  
   110  	bi := blas64.Implementation()
   111  
   112  	if uplo == blas.Upper {
   113  		for i := n - 1; i >= n-nb; i-- {
   114  			iw := i - n + nb
   115  			if i < n-1 {
   116  				// Update A(0:i, i).
   117  				bi.Dgemv(blas.NoTrans, i+1, n-i-1, -1, a[i+1:], lda,
   118  					w[i*ldw+iw+1:], 1, 1, a[i:], lda)
   119  				bi.Dgemv(blas.NoTrans, i+1, n-i-1, -1, w[iw+1:], ldw,
   120  					a[i*lda+i+1:], 1, 1, a[i:], lda)
   121  			}
   122  			if i > 0 {
   123  				// Generate elementary reflector H_i to annihilate A(0:i-2,i).
   124  				e[i-1], tau[i-1] = impl.Dlarfg(i, a[(i-1)*lda+i], a[i:], lda)
   125  				a[(i-1)*lda+i] = 1
   126  
   127  				// Compute W(0:i-1, i).
   128  				bi.Dsymv(blas.Upper, i, 1, a, lda, a[i:], lda, 0, w[iw:], ldw)
   129  				if i < n-1 {
   130  					bi.Dgemv(blas.Trans, i, n-i-1, 1, w[iw+1:], ldw,
   131  						a[i:], lda, 0, w[(i+1)*ldw+iw:], ldw)
   132  					bi.Dgemv(blas.NoTrans, i, n-i-1, -1, a[i+1:], lda,
   133  						w[(i+1)*ldw+iw:], ldw, 1, w[iw:], ldw)
   134  					bi.Dgemv(blas.Trans, i, n-i-1, 1, a[i+1:], lda,
   135  						a[i:], lda, 0, w[(i+1)*ldw+iw:], ldw)
   136  					bi.Dgemv(blas.NoTrans, i, n-i-1, -1, w[iw+1:], ldw,
   137  						w[(i+1)*ldw+iw:], ldw, 1, w[iw:], ldw)
   138  				}
   139  				bi.Dscal(i, tau[i-1], w[iw:], ldw)
   140  				alpha := -0.5 * tau[i-1] * bi.Ddot(i, w[iw:], ldw, a[i:], lda)
   141  				bi.Daxpy(i, alpha, a[i:], lda, w[iw:], ldw)
   142  			}
   143  		}
   144  	} else {
   145  		// Reduce first nb columns of lower triangle.
   146  		for i := 0; i < nb; i++ {
   147  			// Update A(i:n, i)
   148  			bi.Dgemv(blas.NoTrans, n-i, i, -1, a[i*lda:], lda,
   149  				w[i*ldw:], 1, 1, a[i*lda+i:], lda)
   150  			bi.Dgemv(blas.NoTrans, n-i, i, -1, w[i*ldw:], ldw,
   151  				a[i*lda:], 1, 1, a[i*lda+i:], lda)
   152  			if i < n-1 {
   153  				// Generate elementary reflector H_i to annihilate A(i+2:n,i).
   154  				e[i], tau[i] = impl.Dlarfg(n-i-1, a[(i+1)*lda+i], a[min(i+2, n-1)*lda+i:], lda)
   155  				a[(i+1)*lda+i] = 1
   156  
   157  				// Compute W(i+1:n,i).
   158  				bi.Dsymv(blas.Lower, n-i-1, 1, a[(i+1)*lda+i+1:], lda,
   159  					a[(i+1)*lda+i:], lda, 0, w[(i+1)*ldw+i:], ldw)
   160  				bi.Dgemv(blas.Trans, n-i-1, i, 1, w[(i+1)*ldw:], ldw,
   161  					a[(i+1)*lda+i:], lda, 0, w[i:], ldw)
   162  				bi.Dgemv(blas.NoTrans, n-i-1, i, -1, a[(i+1)*lda:], lda,
   163  					w[i:], ldw, 1, w[(i+1)*ldw+i:], ldw)
   164  				bi.Dgemv(blas.Trans, n-i-1, i, 1, a[(i+1)*lda:], lda,
   165  					a[(i+1)*lda+i:], lda, 0, w[i:], ldw)
   166  				bi.Dgemv(blas.NoTrans, n-i-1, i, -1, w[(i+1)*ldw:], ldw,
   167  					w[i:], ldw, 1, w[(i+1)*ldw+i:], ldw)
   168  				bi.Dscal(n-i-1, tau[i], w[(i+1)*ldw+i:], ldw)
   169  				alpha := -0.5 * tau[i] * bi.Ddot(n-i-1, w[(i+1)*ldw+i:], ldw,
   170  					a[(i+1)*lda+i:], lda)
   171  				bi.Daxpy(n-i-1, alpha, a[(i+1)*lda+i:], lda,
   172  					w[(i+1)*ldw+i:], ldw)
   173  			}
   174  		}
   175  	}
   176  }