gonum.org/v1/gonum@v0.14.0/lapack/gonum/dgetri.go (about) 1 // Copyright ©2015 The Gonum Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style 3 // license that can be found in the LICENSE file. 4 5 package gonum 6 7 import ( 8 "gonum.org/v1/gonum/blas" 9 "gonum.org/v1/gonum/blas/blas64" 10 ) 11 12 // Dgetri computes the inverse of the matrix A using the LU factorization computed 13 // by Dgetrf. On entry, a contains the PLU decomposition of A as computed by 14 // Dgetrf and on exit contains the reciprocal of the original matrix. 15 // 16 // Dgetri will not perform the inversion if the matrix is singular, and returns 17 // a boolean indicating whether the inversion was successful. 18 // 19 // work is temporary storage, and lwork specifies the usable memory length. 20 // At minimum, lwork >= n and this function will panic otherwise. 21 // Dgetri is a blocked inversion, but the block size is limited 22 // by the temporary space available. If lwork == -1, instead of performing Dgetri, 23 // the optimal work length will be stored into work[0]. 24 func (impl Implementation) Dgetri(n int, a []float64, lda int, ipiv []int, work []float64, lwork int) (ok bool) { 25 iws := max(1, n) 26 switch { 27 case n < 0: 28 panic(nLT0) 29 case lda < max(1, n): 30 panic(badLdA) 31 case lwork < iws && lwork != -1: 32 panic(badLWork) 33 case len(work) < max(1, lwork): 34 panic(shortWork) 35 } 36 37 if n == 0 { 38 work[0] = 1 39 return true 40 } 41 42 nb := impl.Ilaenv(1, "DGETRI", " ", n, -1, -1, -1) 43 if lwork == -1 { 44 work[0] = float64(n * nb) 45 return true 46 } 47 48 switch { 49 case len(a) < (n-1)*lda+n: 50 panic(shortA) 51 case len(ipiv) != n: 52 panic(badLenIpiv) 53 } 54 55 // Form inv(U). 56 ok = impl.Dtrtri(blas.Upper, blas.NonUnit, n, a, lda) 57 if !ok { 58 return false 59 } 60 61 nbmin := 2 62 if 1 < nb && nb < n { 63 iws = max(n*nb, 1) 64 if lwork < iws { 65 nb = lwork / n 66 nbmin = max(2, impl.Ilaenv(2, "DGETRI", " ", n, -1, -1, -1)) 67 } 68 } 69 ldwork := nb 70 71 bi := blas64.Implementation() 72 // Solve the equation inv(A)*L = inv(U) for inv(A). 73 // TODO(btracey): Replace this with a more row-major oriented algorithm. 74 if nb < nbmin || n <= nb { 75 // Unblocked code. 76 for j := n - 1; j >= 0; j-- { 77 for i := j + 1; i < n; i++ { 78 // Copy current column of L to work and replace with zeros. 79 work[i] = a[i*lda+j] 80 a[i*lda+j] = 0 81 } 82 // Compute current column of inv(A). 83 if j < n-1 { 84 bi.Dgemv(blas.NoTrans, n, n-j-1, -1, a[(j+1):], lda, work[(j+1):], 1, 1, a[j:], lda) 85 } 86 } 87 } else { 88 // Blocked code. 89 nn := ((n - 1) / nb) * nb 90 for j := nn; j >= 0; j -= nb { 91 jb := min(nb, n-j) 92 // Copy current block column of L to work and replace 93 // with zeros. 94 for jj := j; jj < j+jb; jj++ { 95 for i := jj + 1; i < n; i++ { 96 work[i*ldwork+(jj-j)] = a[i*lda+jj] 97 a[i*lda+jj] = 0 98 } 99 } 100 // Compute current block column of inv(A). 101 if j+jb < n { 102 bi.Dgemm(blas.NoTrans, blas.NoTrans, n, jb, n-j-jb, -1, a[(j+jb):], lda, work[(j+jb)*ldwork:], ldwork, 1, a[j:], lda) 103 } 104 bi.Dtrsm(blas.Right, blas.Lower, blas.NoTrans, blas.Unit, n, jb, 1, work[j*ldwork:], ldwork, a[j:], lda) 105 } 106 } 107 // Apply column interchanges. 108 for j := n - 2; j >= 0; j-- { 109 jp := ipiv[j] 110 if jp != j { 111 bi.Dswap(n, a[j:], lda, a[jp:], lda) 112 } 113 } 114 work[0] = float64(iws) 115 return true 116 }