github.com/gonum/lapack@v0.0.0-20181123203213-e4cdc5a0bff9/testlapack/dtrtri.go (about) 1 // Copyright ©2015 The gonum Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style 3 // license that can be found in the LICENSE file. 4 5 package testlapack 6 7 import ( 8 "math" 9 "math/rand" 10 "testing" 11 12 "github.com/gonum/blas" 13 "github.com/gonum/blas/blas64" 14 ) 15 16 type Dtrtrier interface { 17 Dtrconer 18 Dtrtri(uplo blas.Uplo, diag blas.Diag, n int, a []float64, lda int) bool 19 } 20 21 func DtrtriTest(t *testing.T, impl Dtrtrier) { 22 const tol = 1e-6 23 rnd := rand.New(rand.NewSource(1)) 24 bi := blas64.Implementation() 25 for _, uplo := range []blas.Uplo{blas.Upper, blas.Lower} { 26 for _, diag := range []blas.Diag{blas.NonUnit, blas.Unit} { 27 for _, test := range []struct { 28 n, lda int 29 }{ 30 {3, 0}, 31 {70, 0}, 32 {200, 0}, 33 {3, 5}, 34 {70, 92}, 35 {200, 205}, 36 } { 37 n := test.n 38 lda := test.lda 39 if lda == 0 { 40 lda = n 41 } 42 a := make([]float64, n*lda) 43 for i := range a { 44 a[i] = rnd.Float64() + 1 // This keeps the matrices well conditioned. 45 } 46 aCopy := make([]float64, len(a)) 47 copy(aCopy, a) 48 impl.Dtrtri(uplo, diag, n, a, lda) 49 if uplo == blas.Upper { 50 for i := 1; i < n; i++ { 51 for j := 0; j < i; j++ { 52 aCopy[i*lda+j] = 0 53 a[i*lda+j] = 0 54 } 55 } 56 } else { 57 for i := 0; i < n; i++ { 58 for j := i + 1; j < n; j++ { 59 aCopy[i*lda+j] = 0 60 a[i*lda+j] = 0 61 } 62 } 63 } 64 if diag == blas.Unit { 65 for i := 0; i < n; i++ { 66 a[i*lda+i] = 1 67 aCopy[i*lda+i] = 1 68 } 69 } 70 ans := make([]float64, len(a)) 71 bi.Dgemm(blas.NoTrans, blas.NoTrans, n, n, n, 1, a, lda, aCopy, lda, 0, ans, lda) 72 iseye := true 73 for i := 0; i < n; i++ { 74 for j := 0; j < n; j++ { 75 if i == j { 76 if math.Abs(ans[i*lda+i]-1) > tol { 77 iseye = false 78 break 79 } 80 } else { 81 if math.Abs(ans[i*lda+j]) > tol { 82 iseye = false 83 break 84 } 85 } 86 } 87 } 88 if !iseye { 89 t.Errorf("inv(A) * A != I. Upper = %v, unit = %v, n = %v, lda = %v", 90 uplo == blas.Upper, diag == blas.Unit, n, lda) 91 } 92 } 93 } 94 } 95 }