github.com/gonum/lapack@v0.0.0-20181123203213-e4cdc5a0bff9/testlapack/dtrtri.go (about)

     1  // Copyright ©2015 The gonum Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package testlapack
     6  
     7  import (
     8  	"math"
     9  	"math/rand"
    10  	"testing"
    11  
    12  	"github.com/gonum/blas"
    13  	"github.com/gonum/blas/blas64"
    14  )
    15  
    16  type Dtrtrier interface {
    17  	Dtrconer
    18  	Dtrtri(uplo blas.Uplo, diag blas.Diag, n int, a []float64, lda int) bool
    19  }
    20  
    21  func DtrtriTest(t *testing.T, impl Dtrtrier) {
    22  	const tol = 1e-6
    23  	rnd := rand.New(rand.NewSource(1))
    24  	bi := blas64.Implementation()
    25  	for _, uplo := range []blas.Uplo{blas.Upper, blas.Lower} {
    26  		for _, diag := range []blas.Diag{blas.NonUnit, blas.Unit} {
    27  			for _, test := range []struct {
    28  				n, lda int
    29  			}{
    30  				{3, 0},
    31  				{70, 0},
    32  				{200, 0},
    33  				{3, 5},
    34  				{70, 92},
    35  				{200, 205},
    36  			} {
    37  				n := test.n
    38  				lda := test.lda
    39  				if lda == 0 {
    40  					lda = n
    41  				}
    42  				a := make([]float64, n*lda)
    43  				for i := range a {
    44  					a[i] = rnd.Float64() + 1 // This keeps the matrices well conditioned.
    45  				}
    46  				aCopy := make([]float64, len(a))
    47  				copy(aCopy, a)
    48  				impl.Dtrtri(uplo, diag, n, a, lda)
    49  				if uplo == blas.Upper {
    50  					for i := 1; i < n; i++ {
    51  						for j := 0; j < i; j++ {
    52  							aCopy[i*lda+j] = 0
    53  							a[i*lda+j] = 0
    54  						}
    55  					}
    56  				} else {
    57  					for i := 0; i < n; i++ {
    58  						for j := i + 1; j < n; j++ {
    59  							aCopy[i*lda+j] = 0
    60  							a[i*lda+j] = 0
    61  						}
    62  					}
    63  				}
    64  				if diag == blas.Unit {
    65  					for i := 0; i < n; i++ {
    66  						a[i*lda+i] = 1
    67  						aCopy[i*lda+i] = 1
    68  					}
    69  				}
    70  				ans := make([]float64, len(a))
    71  				bi.Dgemm(blas.NoTrans, blas.NoTrans, n, n, n, 1, a, lda, aCopy, lda, 0, ans, lda)
    72  				iseye := true
    73  				for i := 0; i < n; i++ {
    74  					for j := 0; j < n; j++ {
    75  						if i == j {
    76  							if math.Abs(ans[i*lda+i]-1) > tol {
    77  								iseye = false
    78  								break
    79  							}
    80  						} else {
    81  							if math.Abs(ans[i*lda+j]) > tol {
    82  								iseye = false
    83  								break
    84  							}
    85  						}
    86  					}
    87  				}
    88  				if !iseye {
    89  					t.Errorf("inv(A) * A != I. Upper = %v, unit = %v, n = %v, lda = %v",
    90  						uplo == blas.Upper, diag == blas.Unit, n, lda)
    91  				}
    92  			}
    93  		}
    94  	}
    95  }