gonum.org/v1/gonum@v0.14.0/lapack/testlapack/dtrtri.go (about) 1 // Copyright ©2015 The Gonum Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style 3 // license that can be found in the LICENSE file. 4 5 package testlapack 6 7 import ( 8 "testing" 9 10 "golang.org/x/exp/rand" 11 12 "gonum.org/v1/gonum/blas" 13 "gonum.org/v1/gonum/blas/blas64" 14 ) 15 16 type Dtrtrier interface { 17 Dtrtri(uplo blas.Uplo, diag blas.Diag, n int, a []float64, lda int) bool 18 } 19 20 func DtrtriTest(t *testing.T, impl Dtrtrier) { 21 const tol = 1e-10 22 rnd := rand.New(rand.NewSource(1)) 23 bi := blas64.Implementation() 24 for _, uplo := range []blas.Uplo{blas.Upper, blas.Lower} { 25 for _, diag := range []blas.Diag{blas.NonUnit, blas.Unit} { 26 for _, test := range []struct { 27 n, lda int 28 }{ 29 {3, 0}, 30 {70, 0}, 31 {200, 0}, 32 {3, 5}, 33 {70, 92}, 34 {200, 205}, 35 } { 36 n := test.n 37 lda := test.lda 38 if lda == 0 { 39 lda = n 40 } 41 // Allocate n×n matrix A and fill it with random numbers. 42 a := make([]float64, n*lda) 43 for i := range a { 44 a[i] = rnd.Float64() 45 } 46 for i := 0; i < n; i++ { 47 // This keeps the matrices well conditioned. 48 a[i*lda+i] += float64(n) 49 } 50 aCopy := make([]float64, len(a)) 51 copy(aCopy, a) 52 // Compute the inverse of the uplo triangle. 53 impl.Dtrtri(uplo, diag, n, a, lda) 54 // Zero out the opposite triangle. 55 if uplo == blas.Upper { 56 for i := 1; i < n; i++ { 57 for j := 0; j < i; j++ { 58 aCopy[i*lda+j] = 0 59 a[i*lda+j] = 0 60 } 61 } 62 } else { 63 for i := 0; i < n; i++ { 64 for j := i + 1; j < n; j++ { 65 aCopy[i*lda+j] = 0 66 a[i*lda+j] = 0 67 } 68 } 69 } 70 if diag == blas.Unit { 71 // Set the diagonal explicitly to 1. 72 for i := 0; i < n; i++ { 73 a[i*lda+i] = 1 74 aCopy[i*lda+i] = 1 75 } 76 } 77 // Compute A^{-1} * A and store the result in ans. 78 ans := make([]float64, len(a)) 79 bi.Dgemm(blas.NoTrans, blas.NoTrans, n, n, n, 1, a, lda, aCopy, lda, 0, ans, lda) 80 // Check that ans is the identity matrix. 81 dist := distFromIdentity(n, ans, lda) 82 if dist > tol { 83 t.Errorf("|inv(A) * A - I| = %v is too large. Upper = %v, unit = %v, n = %v, lda = %v", 84 dist, uplo == blas.Upper, diag == blas.Unit, n, lda) 85 } 86 } 87 } 88 } 89 }