gonum.org/v1/gonum@v0.14.0/lapack/gonum/dgetri.go (about)

     1  // Copyright ©2015 The Gonum Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package gonum
     6  
     7  import (
     8  	"gonum.org/v1/gonum/blas"
     9  	"gonum.org/v1/gonum/blas/blas64"
    10  )
    11  
    12  // Dgetri computes the inverse of the matrix A using the LU factorization computed
    13  // by Dgetrf. On entry, a contains the PLU decomposition of A as computed by
    14  // Dgetrf and on exit contains the reciprocal of the original matrix.
    15  //
    16  // Dgetri will not perform the inversion if the matrix is singular, and returns
    17  // a boolean indicating whether the inversion was successful.
    18  //
    19  // work is temporary storage, and lwork specifies the usable memory length.
    20  // At minimum, lwork >= n and this function will panic otherwise.
    21  // Dgetri is a blocked inversion, but the block size is limited
    22  // by the temporary space available. If lwork == -1, instead of performing Dgetri,
    23  // the optimal work length will be stored into work[0].
    24  func (impl Implementation) Dgetri(n int, a []float64, lda int, ipiv []int, work []float64, lwork int) (ok bool) {
    25  	iws := max(1, n)
    26  	switch {
    27  	case n < 0:
    28  		panic(nLT0)
    29  	case lda < max(1, n):
    30  		panic(badLdA)
    31  	case lwork < iws && lwork != -1:
    32  		panic(badLWork)
    33  	case len(work) < max(1, lwork):
    34  		panic(shortWork)
    35  	}
    36  
    37  	if n == 0 {
    38  		work[0] = 1
    39  		return true
    40  	}
    41  
    42  	nb := impl.Ilaenv(1, "DGETRI", " ", n, -1, -1, -1)
    43  	if lwork == -1 {
    44  		work[0] = float64(n * nb)
    45  		return true
    46  	}
    47  
    48  	switch {
    49  	case len(a) < (n-1)*lda+n:
    50  		panic(shortA)
    51  	case len(ipiv) != n:
    52  		panic(badLenIpiv)
    53  	}
    54  
    55  	// Form inv(U).
    56  	ok = impl.Dtrtri(blas.Upper, blas.NonUnit, n, a, lda)
    57  	if !ok {
    58  		return false
    59  	}
    60  
    61  	nbmin := 2
    62  	if 1 < nb && nb < n {
    63  		iws = max(n*nb, 1)
    64  		if lwork < iws {
    65  			nb = lwork / n
    66  			nbmin = max(2, impl.Ilaenv(2, "DGETRI", " ", n, -1, -1, -1))
    67  		}
    68  	}
    69  	ldwork := nb
    70  
    71  	bi := blas64.Implementation()
    72  	// Solve the equation inv(A)*L = inv(U) for inv(A).
    73  	// TODO(btracey): Replace this with a more row-major oriented algorithm.
    74  	if nb < nbmin || n <= nb {
    75  		// Unblocked code.
    76  		for j := n - 1; j >= 0; j-- {
    77  			for i := j + 1; i < n; i++ {
    78  				// Copy current column of L to work and replace with zeros.
    79  				work[i] = a[i*lda+j]
    80  				a[i*lda+j] = 0
    81  			}
    82  			// Compute current column of inv(A).
    83  			if j < n-1 {
    84  				bi.Dgemv(blas.NoTrans, n, n-j-1, -1, a[(j+1):], lda, work[(j+1):], 1, 1, a[j:], lda)
    85  			}
    86  		}
    87  	} else {
    88  		// Blocked code.
    89  		nn := ((n - 1) / nb) * nb
    90  		for j := nn; j >= 0; j -= nb {
    91  			jb := min(nb, n-j)
    92  			// Copy current block column of L to work and replace
    93  			// with zeros.
    94  			for jj := j; jj < j+jb; jj++ {
    95  				for i := jj + 1; i < n; i++ {
    96  					work[i*ldwork+(jj-j)] = a[i*lda+jj]
    97  					a[i*lda+jj] = 0
    98  				}
    99  			}
   100  			// Compute current block column of inv(A).
   101  			if j+jb < n {
   102  				bi.Dgemm(blas.NoTrans, blas.NoTrans, n, jb, n-j-jb, -1, a[(j+jb):], lda, work[(j+jb)*ldwork:], ldwork, 1, a[j:], lda)
   103  			}
   104  			bi.Dtrsm(blas.Right, blas.Lower, blas.NoTrans, blas.Unit, n, jb, 1, work[j*ldwork:], ldwork, a[j:], lda)
   105  		}
   106  	}
   107  	// Apply column interchanges.
   108  	for j := n - 2; j >= 0; j-- {
   109  		jp := ipiv[j]
   110  		if jp != j {
   111  			bi.Dswap(n, a[j:], lda, a[jp:], lda)
   112  		}
   113  	}
   114  	work[0] = float64(iws)
   115  	return true
   116  }