gonum.org/v1/gonum@v0.14.0/lapack/testlapack/dtrti2.go (about)

     1  // Copyright ©2015 The Gonum Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package testlapack
     6  
     7  import (
     8  	"testing"
     9  
    10  	"golang.org/x/exp/rand"
    11  
    12  	"gonum.org/v1/gonum/blas"
    13  	"gonum.org/v1/gonum/blas/blas64"
    14  	"gonum.org/v1/gonum/floats"
    15  )
    16  
    17  type Dtrti2er interface {
    18  	Dtrti2(uplo blas.Uplo, diag blas.Diag, n int, a []float64, lda int)
    19  }
    20  
    21  func Dtrti2Test(t *testing.T, impl Dtrti2er) {
    22  	const tol = 1e-14
    23  	for _, test := range []struct {
    24  		a    []float64
    25  		n    int
    26  		uplo blas.Uplo
    27  		diag blas.Diag
    28  		ans  []float64
    29  	}{
    30  		{
    31  			a: []float64{
    32  				2, 3, 4,
    33  				0, 5, 6,
    34  				8, 0, 8},
    35  			n:    3,
    36  			uplo: blas.Upper,
    37  			diag: blas.NonUnit,
    38  			ans: []float64{
    39  				0.5, -0.3, -0.025,
    40  				0, 0.2, -0.15,
    41  				8, 0, 0.125,
    42  			},
    43  		},
    44  		{
    45  			a: []float64{
    46  				5, 3, 4,
    47  				0, 7, 6,
    48  				10, 0, 8},
    49  			n:    3,
    50  			uplo: blas.Upper,
    51  			diag: blas.Unit,
    52  			ans: []float64{
    53  				5, -3, 14,
    54  				0, 7, -6,
    55  				10, 0, 8,
    56  			},
    57  		},
    58  		{
    59  			a: []float64{
    60  				2, 0, 0,
    61  				3, 5, 0,
    62  				4, 6, 8},
    63  			n:    3,
    64  			uplo: blas.Lower,
    65  			diag: blas.NonUnit,
    66  			ans: []float64{
    67  				0.5, 0, 0,
    68  				-0.3, 0.2, 0,
    69  				-0.025, -0.15, 0.125,
    70  			},
    71  		},
    72  		{
    73  			a: []float64{
    74  				1, 0, 0,
    75  				3, 1, 0,
    76  				4, 6, 1},
    77  			n:    3,
    78  			uplo: blas.Lower,
    79  			diag: blas.Unit,
    80  			ans: []float64{
    81  				1, 0, 0,
    82  				-3, 1, 0,
    83  				14, -6, 1,
    84  			},
    85  		},
    86  	} {
    87  		impl.Dtrti2(test.uplo, test.diag, test.n, test.a, test.n)
    88  		if !floats.EqualApprox(test.ans, test.a, tol) {
    89  			t.Errorf("Matrix inverse mismatch. Want %v, got %v.", test.ans, test.a)
    90  		}
    91  	}
    92  	rnd := rand.New(rand.NewSource(1))
    93  	bi := blas64.Implementation()
    94  	for _, uplo := range []blas.Uplo{blas.Upper, blas.Lower} {
    95  		for _, diag := range []blas.Diag{blas.NonUnit, blas.Unit} {
    96  			for _, test := range []struct {
    97  				n, lda int
    98  			}{
    99  				{1, 0},
   100  				{2, 0},
   101  				{3, 0},
   102  				{1, 5},
   103  				{2, 5},
   104  				{3, 5},
   105  			} {
   106  				n := test.n
   107  				lda := test.lda
   108  				if lda == 0 {
   109  					lda = n
   110  				}
   111  				// Allocate n×n matrix A and fill it with random numbers.
   112  				a := make([]float64, n*lda)
   113  				for i := range a {
   114  					a[i] = rnd.Float64()
   115  				}
   116  				for i := 0; i < n; i++ {
   117  					// This keeps the matrices well conditioned.
   118  					a[i*lda+i] += float64(n)
   119  				}
   120  				aCopy := make([]float64, len(a))
   121  				copy(aCopy, a)
   122  				// Compute the inverse of the uplo triangle.
   123  				impl.Dtrti2(uplo, diag, n, a, lda)
   124  				// Zero out the opposite triangle.
   125  				if uplo == blas.Upper {
   126  					for i := 1; i < n; i++ {
   127  						for j := 0; j < i; j++ {
   128  							aCopy[i*lda+j] = 0
   129  							a[i*lda+j] = 0
   130  						}
   131  					}
   132  				} else {
   133  					for i := 0; i < n; i++ {
   134  						for j := i + 1; j < n; j++ {
   135  							aCopy[i*lda+j] = 0
   136  							a[i*lda+j] = 0
   137  						}
   138  					}
   139  				}
   140  				if diag == blas.Unit {
   141  					// Set the diagonal of A^{-1} and A explicitly to 1.
   142  					for i := 0; i < n; i++ {
   143  						a[i*lda+i] = 1
   144  						aCopy[i*lda+i] = 1
   145  					}
   146  				}
   147  				// Compute A^{-1} * A and store the result in ans.
   148  				ans := make([]float64, len(a))
   149  				bi.Dgemm(blas.NoTrans, blas.NoTrans, n, n, n, 1, a, lda, aCopy, lda, 0, ans, lda)
   150  				// Check that ans is close to the identity matrix.
   151  				dist := distFromIdentity(n, ans, lda)
   152  				if dist > tol {
   153  					t.Errorf("|inv(A) * A - I| = %v. Upper = %v, unit = %v, ans = %v", dist, uplo == blas.Upper, diag == blas.Unit, ans)
   154  				}
   155  			}
   156  		}
   157  	}
   158  }