github.com/gonum/lapack@v0.0.0-20181123203213-e4cdc5a0bff9/testlapack/dtrti2.go (about) 1 // Copyright ©2015 The gonum Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style 3 // license that can be found in the LICENSE file. 4 5 package testlapack 6 7 import ( 8 "math" 9 "math/rand" 10 "testing" 11 12 "github.com/gonum/blas" 13 "github.com/gonum/blas/blas64" 14 "github.com/gonum/floats" 15 ) 16 17 type Dtrti2er interface { 18 Dtrti2(uplo blas.Uplo, diag blas.Diag, n int, a []float64, lda int) 19 } 20 21 func Dtrti2Test(t *testing.T, impl Dtrti2er) { 22 const tol = 1e-14 23 for _, test := range []struct { 24 a []float64 25 n int 26 uplo blas.Uplo 27 diag blas.Diag 28 ans []float64 29 }{ 30 { 31 a: []float64{ 32 2, 3, 4, 33 0, 5, 6, 34 8, 0, 8}, 35 n: 3, 36 uplo: blas.Upper, 37 diag: blas.NonUnit, 38 ans: []float64{ 39 0.5, -0.3, -0.025, 40 0, 0.2, -0.15, 41 8, 0, 0.125, 42 }, 43 }, 44 { 45 a: []float64{ 46 5, 3, 4, 47 0, 7, 6, 48 10, 0, 8}, 49 n: 3, 50 uplo: blas.Upper, 51 diag: blas.Unit, 52 ans: []float64{ 53 5, -3, 14, 54 0, 7, -6, 55 10, 0, 8, 56 }, 57 }, 58 { 59 a: []float64{ 60 2, 0, 0, 61 3, 5, 0, 62 4, 6, 8}, 63 n: 3, 64 uplo: blas.Lower, 65 diag: blas.NonUnit, 66 ans: []float64{ 67 0.5, 0, 0, 68 -0.3, 0.2, 0, 69 -0.025, -0.15, 0.125, 70 }, 71 }, 72 { 73 a: []float64{ 74 1, 0, 0, 75 3, 1, 0, 76 4, 6, 1}, 77 n: 3, 78 uplo: blas.Lower, 79 diag: blas.Unit, 80 ans: []float64{ 81 1, 0, 0, 82 -3, 1, 0, 83 14, -6, 1, 84 }, 85 }, 86 } { 87 impl.Dtrti2(test.uplo, test.diag, test.n, test.a, test.n) 88 if !floats.EqualApprox(test.ans, test.a, tol) { 89 t.Errorf("Matrix inverse mismatch. Want %v, got %v.", test.ans, test.a) 90 } 91 } 92 rnd := rand.New(rand.NewSource(1)) 93 bi := blas64.Implementation() 94 for _, uplo := range []blas.Uplo{blas.Upper, blas.Lower} { 95 for _, diag := range []blas.Diag{blas.NonUnit, blas.Unit} { 96 for _, test := range []struct { 97 n, lda int 98 }{ 99 {1, 0}, 100 {2, 0}, 101 {3, 0}, 102 {1, 5}, 103 {2, 5}, 104 {3, 5}, 105 } { 106 n := test.n 107 lda := test.lda 108 if lda == 0 { 109 lda = n 110 } 111 a := make([]float64, n*lda) 112 for i := range a { 113 a[i] = rnd.Float64() 114 } 115 aCopy := make([]float64, len(a)) 116 copy(aCopy, a) 117 impl.Dtrti2(uplo, diag, n, a, lda) 118 if uplo == blas.Upper { 119 for i := 1; i < n; i++ { 120 for j := 0; j < i; j++ { 121 aCopy[i*lda+j] = 0 122 a[i*lda+j] = 0 123 } 124 } 125 } else { 126 for i := 0; i < n; i++ { 127 for j := i + 1; j < n; j++ { 128 aCopy[i*lda+j] = 0 129 a[i*lda+j] = 0 130 } 131 } 132 } 133 if diag == blas.Unit { 134 for i := 0; i < n; i++ { 135 a[i*lda+i] = 1 136 aCopy[i*lda+i] = 1 137 } 138 } 139 ans := make([]float64, len(a)) 140 bi.Dgemm(blas.NoTrans, blas.NoTrans, n, n, n, 1, a, lda, aCopy, lda, 0, ans, lda) 141 iseye := true 142 for i := 0; i < n; i++ { 143 for j := 0; j < n; j++ { 144 if i == j { 145 if math.Abs(ans[i*lda+i]-1) > tol { 146 iseye = false 147 break 148 } 149 } else { 150 if math.Abs(ans[i*lda+j]) > tol { 151 iseye = false 152 break 153 } 154 } 155 } 156 } 157 if !iseye { 158 t.Errorf("inv(A) * A != I. Upper = %v, unit = %v, ans = %v", uplo == blas.Upper, diag == blas.Unit, ans) 159 } 160 } 161 } 162 } 163 }