github.com/gonum/lapack@v0.0.0-20181123203213-e4cdc5a0bff9/native/dgetri.go (about) 1 // Copyright ©2015 The gonum Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style 3 // license that can be found in the LICENSE file. 4 5 package native 6 7 import ( 8 "github.com/gonum/blas" 9 "github.com/gonum/blas/blas64" 10 ) 11 12 // Dgetri computes the inverse of the matrix A using the LU factorization computed 13 // by Dgetrf. On entry, a contains the PLU decomposition of A as computed by 14 // Dgetrf and on exit contains the reciprocal of the original matrix. 15 // 16 // Dgetri will not perform the inversion if the matrix is singular, and returns 17 // a boolean indicating whether the inversion was successful. 18 // 19 // work is temporary storage, and lwork specifies the usable memory length. 20 // At minimum, lwork >= n and this function will panic otherwise. 21 // Dgetri is a blocked inversion, but the block size is limited 22 // by the temporary space available. If lwork == -1, instead of performing Dgetri, 23 // the optimal work length will be stored into work[0]. 24 func (impl Implementation) Dgetri(n int, a []float64, lda int, ipiv []int, work []float64, lwork int) (ok bool) { 25 checkMatrix(n, n, a, lda) 26 if len(ipiv) < n { 27 panic(badIpiv) 28 } 29 nb := impl.Ilaenv(1, "DGETRI", " ", n, -1, -1, -1) 30 if lwork == -1 { 31 work[0] = float64(n * nb) 32 return true 33 } 34 if lwork < n { 35 panic(badWork) 36 } 37 if len(work) < lwork { 38 panic(badWork) 39 } 40 if n == 0 { 41 return true 42 } 43 ok = impl.Dtrtri(blas.Upper, blas.NonUnit, n, a, lda) 44 if !ok { 45 return false 46 } 47 nbmin := 2 48 ldwork := nb 49 if nb > 1 && nb < n { 50 iws := max(ldwork*n, 1) 51 if lwork < iws { 52 nb = lwork / ldwork 53 nbmin = max(2, impl.Ilaenv(2, "DGETRI", " ", n, -1, -1, -1)) 54 } 55 } 56 bi := blas64.Implementation() 57 // TODO(btracey): Replace this with a more row-major oriented algorithm. 58 if nb < nbmin || nb >= n { 59 // Unblocked code. 60 for j := n - 1; j >= 0; j-- { 61 for i := j + 1; i < n; i++ { 62 work[i*ldwork] = a[i*lda+j] 63 a[i*lda+j] = 0 64 } 65 if j < n { 66 bi.Dgemv(blas.NoTrans, n, n-j-1, -1, a[(j+1):], lda, work[(j+1)*ldwork:], ldwork, 1, a[j:], lda) 67 } 68 } 69 } else { 70 nn := ((n - 1) / nb) * nb 71 for j := nn; j >= 0; j -= nb { 72 jb := min(nb, n-j) 73 for jj := j; jj < j+jb-1; jj++ { 74 for i := jj + 1; i < n; i++ { 75 work[i*ldwork+(jj-j)] = a[i*lda+jj] 76 a[i*lda+jj] = 0 77 } 78 } 79 if j+jb < n { 80 bi.Dgemm(blas.NoTrans, blas.NoTrans, n, jb, n-j-jb, -1, a[(j+jb):], lda, work[(j+jb)*ldwork:], ldwork, 1, a[j:], lda) 81 bi.Dtrsm(blas.Right, blas.Lower, blas.NoTrans, blas.Unit, n, jb, 1, work[j*ldwork:], ldwork, a[j:], lda) 82 } 83 } 84 } 85 for j := n - 2; j >= 0; j-- { 86 jp := ipiv[j] 87 if jp != j { 88 bi.Dswap(n, a[j:], lda, a[jp:], lda) 89 } 90 } 91 return true 92 }