github.com/jingcheng-WU/gonum@v0.9.1-0.20210323123734-f1a2a11a8f7b/lapack/gonum/dtrti2.go (about)

     1  // Copyright ©2015 The Gonum Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package gonum
     6  
     7  import (
     8  	"github.com/jingcheng-WU/gonum/blas"
     9  	"github.com/jingcheng-WU/gonum/blas/blas64"
    10  )
    11  
    12  // Dtrti2 computes the inverse of a triangular matrix, storing the result in place
    13  // into a. This is the BLAS level 2 version of the algorithm.
    14  //
    15  // Dtrti2 is an internal routine. It is exported for testing purposes.
    16  func (impl Implementation) Dtrti2(uplo blas.Uplo, diag blas.Diag, n int, a []float64, lda int) {
    17  	switch {
    18  	case uplo != blas.Upper && uplo != blas.Lower:
    19  		panic(badUplo)
    20  	case diag != blas.NonUnit && diag != blas.Unit:
    21  		panic(badDiag)
    22  	case n < 0:
    23  		panic(nLT0)
    24  	case lda < max(1, n):
    25  		panic(badLdA)
    26  	}
    27  
    28  	if n == 0 {
    29  		return
    30  	}
    31  
    32  	if len(a) < (n-1)*lda+n {
    33  		panic(shortA)
    34  	}
    35  
    36  	bi := blas64.Implementation()
    37  
    38  	nonUnit := diag == blas.NonUnit
    39  	// TODO(btracey): Replace this with a row-major ordering.
    40  	if uplo == blas.Upper {
    41  		for j := 0; j < n; j++ {
    42  			var ajj float64
    43  			if nonUnit {
    44  				ajj = 1 / a[j*lda+j]
    45  				a[j*lda+j] = ajj
    46  				ajj *= -1
    47  			} else {
    48  				ajj = -1
    49  			}
    50  			bi.Dtrmv(blas.Upper, blas.NoTrans, diag, j, a, lda, a[j:], lda)
    51  			bi.Dscal(j, ajj, a[j:], lda)
    52  		}
    53  		return
    54  	}
    55  	for j := n - 1; j >= 0; j-- {
    56  		var ajj float64
    57  		if nonUnit {
    58  			ajj = 1 / a[j*lda+j]
    59  			a[j*lda+j] = ajj
    60  			ajj *= -1
    61  		} else {
    62  			ajj = -1
    63  		}
    64  		if j < n-1 {
    65  			bi.Dtrmv(blas.Lower, blas.NoTrans, diag, n-j-1, a[(j+1)*lda+j+1:], lda, a[(j+1)*lda+j:], lda)
    66  			bi.Dscal(n-j-1, ajj, a[(j+1)*lda+j:], lda)
    67  		}
    68  	}
    69  }