github.com/gonum/lapack@v0.0.0-20181123203213-e4cdc5a0bff9/testlapack/dtrti2.go (about)

     1  // Copyright ©2015 The gonum Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package testlapack
     6  
     7  import (
     8  	"math"
     9  	"math/rand"
    10  	"testing"
    11  
    12  	"github.com/gonum/blas"
    13  	"github.com/gonum/blas/blas64"
    14  	"github.com/gonum/floats"
    15  )
    16  
    17  type Dtrti2er interface {
    18  	Dtrti2(uplo blas.Uplo, diag blas.Diag, n int, a []float64, lda int)
    19  }
    20  
    21  func Dtrti2Test(t *testing.T, impl Dtrti2er) {
    22  	const tol = 1e-14
    23  	for _, test := range []struct {
    24  		a    []float64
    25  		n    int
    26  		uplo blas.Uplo
    27  		diag blas.Diag
    28  		ans  []float64
    29  	}{
    30  		{
    31  			a: []float64{
    32  				2, 3, 4,
    33  				0, 5, 6,
    34  				8, 0, 8},
    35  			n:    3,
    36  			uplo: blas.Upper,
    37  			diag: blas.NonUnit,
    38  			ans: []float64{
    39  				0.5, -0.3, -0.025,
    40  				0, 0.2, -0.15,
    41  				8, 0, 0.125,
    42  			},
    43  		},
    44  		{
    45  			a: []float64{
    46  				5, 3, 4,
    47  				0, 7, 6,
    48  				10, 0, 8},
    49  			n:    3,
    50  			uplo: blas.Upper,
    51  			diag: blas.Unit,
    52  			ans: []float64{
    53  				5, -3, 14,
    54  				0, 7, -6,
    55  				10, 0, 8,
    56  			},
    57  		},
    58  		{
    59  			a: []float64{
    60  				2, 0, 0,
    61  				3, 5, 0,
    62  				4, 6, 8},
    63  			n:    3,
    64  			uplo: blas.Lower,
    65  			diag: blas.NonUnit,
    66  			ans: []float64{
    67  				0.5, 0, 0,
    68  				-0.3, 0.2, 0,
    69  				-0.025, -0.15, 0.125,
    70  			},
    71  		},
    72  		{
    73  			a: []float64{
    74  				1, 0, 0,
    75  				3, 1, 0,
    76  				4, 6, 1},
    77  			n:    3,
    78  			uplo: blas.Lower,
    79  			diag: blas.Unit,
    80  			ans: []float64{
    81  				1, 0, 0,
    82  				-3, 1, 0,
    83  				14, -6, 1,
    84  			},
    85  		},
    86  	} {
    87  		impl.Dtrti2(test.uplo, test.diag, test.n, test.a, test.n)
    88  		if !floats.EqualApprox(test.ans, test.a, tol) {
    89  			t.Errorf("Matrix inverse mismatch. Want %v, got %v.", test.ans, test.a)
    90  		}
    91  	}
    92  	rnd := rand.New(rand.NewSource(1))
    93  	bi := blas64.Implementation()
    94  	for _, uplo := range []blas.Uplo{blas.Upper, blas.Lower} {
    95  		for _, diag := range []blas.Diag{blas.NonUnit, blas.Unit} {
    96  			for _, test := range []struct {
    97  				n, lda int
    98  			}{
    99  				{1, 0},
   100  				{2, 0},
   101  				{3, 0},
   102  				{1, 5},
   103  				{2, 5},
   104  				{3, 5},
   105  			} {
   106  				n := test.n
   107  				lda := test.lda
   108  				if lda == 0 {
   109  					lda = n
   110  				}
   111  				a := make([]float64, n*lda)
   112  				for i := range a {
   113  					a[i] = rnd.Float64()
   114  				}
   115  				aCopy := make([]float64, len(a))
   116  				copy(aCopy, a)
   117  				impl.Dtrti2(uplo, diag, n, a, lda)
   118  				if uplo == blas.Upper {
   119  					for i := 1; i < n; i++ {
   120  						for j := 0; j < i; j++ {
   121  							aCopy[i*lda+j] = 0
   122  							a[i*lda+j] = 0
   123  						}
   124  					}
   125  				} else {
   126  					for i := 0; i < n; i++ {
   127  						for j := i + 1; j < n; j++ {
   128  							aCopy[i*lda+j] = 0
   129  							a[i*lda+j] = 0
   130  						}
   131  					}
   132  				}
   133  				if diag == blas.Unit {
   134  					for i := 0; i < n; i++ {
   135  						a[i*lda+i] = 1
   136  						aCopy[i*lda+i] = 1
   137  					}
   138  				}
   139  				ans := make([]float64, len(a))
   140  				bi.Dgemm(blas.NoTrans, blas.NoTrans, n, n, n, 1, a, lda, aCopy, lda, 0, ans, lda)
   141  				iseye := true
   142  				for i := 0; i < n; i++ {
   143  					for j := 0; j < n; j++ {
   144  						if i == j {
   145  							if math.Abs(ans[i*lda+i]-1) > tol {
   146  								iseye = false
   147  								break
   148  							}
   149  						} else {
   150  							if math.Abs(ans[i*lda+j]) > tol {
   151  								iseye = false
   152  								break
   153  							}
   154  						}
   155  					}
   156  				}
   157  				if !iseye {
   158  					t.Errorf("inv(A) * A != I. Upper = %v, unit = %v, ans = %v", uplo == blas.Upper, diag == blas.Unit, ans)
   159  				}
   160  			}
   161  		}
   162  	}
   163  }