github.com/jingcheng-WU/gonum@v0.9.1-0.20210323123734-f1a2a11a8f7b/lapack/gonum/dlauum.go (about)

     1  // Copyright ©2018 The Gonum Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package gonum
     6  
     7  import (
     8  	"github.com/jingcheng-WU/gonum/blas"
     9  	"github.com/jingcheng-WU/gonum/blas/blas64"
    10  )
    11  
    12  // Dlauum computes the product
    13  //  U * Uᵀ  if uplo is blas.Upper
    14  //  Lᵀ * L  if uplo is blas.Lower
    15  // where U or L is stored in the upper or lower triangular part of A.
    16  // Only the upper or lower triangle of the result is stored, overwriting
    17  // the corresponding factor in A.
    18  func (impl Implementation) Dlauum(uplo blas.Uplo, n int, a []float64, lda int) {
    19  	switch {
    20  	case uplo != blas.Upper && uplo != blas.Lower:
    21  		panic(badUplo)
    22  	case n < 0:
    23  		panic(nLT0)
    24  	case lda < max(1, n):
    25  		panic(badLdA)
    26  	}
    27  
    28  	// Quick return if possible.
    29  	if n == 0 {
    30  		return
    31  	}
    32  
    33  	if len(a) < (n-1)*lda+n {
    34  		panic(shortA)
    35  	}
    36  
    37  	// Determine the block size.
    38  	opts := "U"
    39  	if uplo == blas.Lower {
    40  		opts = "L"
    41  	}
    42  	nb := impl.Ilaenv(1, "DLAUUM", opts, n, -1, -1, -1)
    43  
    44  	if nb <= 1 || n <= nb {
    45  		// Use unblocked code.
    46  		impl.Dlauu2(uplo, n, a, lda)
    47  		return
    48  	}
    49  
    50  	// Use blocked code.
    51  	bi := blas64.Implementation()
    52  	if uplo == blas.Upper {
    53  		// Compute the product U*Uᵀ.
    54  		for i := 0; i < n; i += nb {
    55  			ib := min(nb, n-i)
    56  			bi.Dtrmm(blas.Right, blas.Upper, blas.Trans, blas.NonUnit,
    57  				i, ib, 1, a[i*lda+i:], lda, a[i:], lda)
    58  			impl.Dlauu2(blas.Upper, ib, a[i*lda+i:], lda)
    59  			if n-i-ib > 0 {
    60  				bi.Dgemm(blas.NoTrans, blas.Trans, i, ib, n-i-ib,
    61  					1, a[i+ib:], lda, a[i*lda+i+ib:], lda, 1, a[i:], lda)
    62  				bi.Dsyrk(blas.Upper, blas.NoTrans, ib, n-i-ib,
    63  					1, a[i*lda+i+ib:], lda, 1, a[i*lda+i:], lda)
    64  			}
    65  		}
    66  	} else {
    67  		// Compute the product Lᵀ*L.
    68  		for i := 0; i < n; i += nb {
    69  			ib := min(nb, n-i)
    70  			bi.Dtrmm(blas.Left, blas.Lower, blas.Trans, blas.NonUnit,
    71  				ib, i, 1, a[i*lda+i:], lda, a[i*lda:], lda)
    72  			impl.Dlauu2(blas.Lower, ib, a[i*lda+i:], lda)
    73  			if n-i-ib > 0 {
    74  				bi.Dgemm(blas.Trans, blas.NoTrans, ib, i, n-i-ib,
    75  					1, a[(i+ib)*lda+i:], lda, a[(i+ib)*lda:], lda, 1, a[i*lda:], lda)
    76  				bi.Dsyrk(blas.Lower, blas.Trans, ib, n-i-ib,
    77  					1, a[(i+ib)*lda+i:], lda, 1, a[i*lda+i:], lda)
    78  			}
    79  		}
    80  	}
    81  }