github.com/gopherd/gonum@v0.0.4/lapack/testlapack/dtrti2.go (about) 1 // Copyright ©2015 The Gonum Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style 3 // license that can be found in the LICENSE file. 4 5 package testlapack 6 7 import ( 8 "testing" 9 10 "math/rand" 11 12 "github.com/gopherd/gonum/blas" 13 "github.com/gopherd/gonum/blas/blas64" 14 "github.com/gopherd/gonum/floats" 15 ) 16 17 type Dtrti2er interface { 18 Dtrti2(uplo blas.Uplo, diag blas.Diag, n int, a []float64, lda int) 19 } 20 21 func Dtrti2Test(t *testing.T, impl Dtrti2er) { 22 const tol = 1e-14 23 for _, test := range []struct { 24 a []float64 25 n int 26 uplo blas.Uplo 27 diag blas.Diag 28 ans []float64 29 }{ 30 { 31 a: []float64{ 32 2, 3, 4, 33 0, 5, 6, 34 8, 0, 8}, 35 n: 3, 36 uplo: blas.Upper, 37 diag: blas.NonUnit, 38 ans: []float64{ 39 0.5, -0.3, -0.025, 40 0, 0.2, -0.15, 41 8, 0, 0.125, 42 }, 43 }, 44 { 45 a: []float64{ 46 5, 3, 4, 47 0, 7, 6, 48 10, 0, 8}, 49 n: 3, 50 uplo: blas.Upper, 51 diag: blas.Unit, 52 ans: []float64{ 53 5, -3, 14, 54 0, 7, -6, 55 10, 0, 8, 56 }, 57 }, 58 { 59 a: []float64{ 60 2, 0, 0, 61 3, 5, 0, 62 4, 6, 8}, 63 n: 3, 64 uplo: blas.Lower, 65 diag: blas.NonUnit, 66 ans: []float64{ 67 0.5, 0, 0, 68 -0.3, 0.2, 0, 69 -0.025, -0.15, 0.125, 70 }, 71 }, 72 { 73 a: []float64{ 74 1, 0, 0, 75 3, 1, 0, 76 4, 6, 1}, 77 n: 3, 78 uplo: blas.Lower, 79 diag: blas.Unit, 80 ans: []float64{ 81 1, 0, 0, 82 -3, 1, 0, 83 14, -6, 1, 84 }, 85 }, 86 } { 87 impl.Dtrti2(test.uplo, test.diag, test.n, test.a, test.n) 88 if !floats.EqualApprox(test.ans, test.a, tol) { 89 t.Errorf("Matrix inverse mismatch. Want %v, got %v.", test.ans, test.a) 90 } 91 } 92 rnd := rand.New(rand.NewSource(1)) 93 bi := blas64.Implementation() 94 for _, uplo := range []blas.Uplo{blas.Upper, blas.Lower} { 95 for _, diag := range []blas.Diag{blas.NonUnit, blas.Unit} { 96 for _, test := range []struct { 97 n, lda int 98 }{ 99 {1, 0}, 100 {2, 0}, 101 {3, 0}, 102 {1, 5}, 103 {2, 5}, 104 {3, 5}, 105 } { 106 n := test.n 107 lda := test.lda 108 if lda == 0 { 109 lda = n 110 } 111 // Allocate n×n matrix A and fill it with random numbers. 112 a := make([]float64, n*lda) 113 for i := range a { 114 a[i] = rnd.Float64() 115 } 116 for i := 0; i < n; i++ { 117 // This keeps the matrices well conditioned. 118 a[i*lda+i] += float64(n) 119 } 120 aCopy := make([]float64, len(a)) 121 copy(aCopy, a) 122 // Compute the inverse of the uplo triangle. 123 impl.Dtrti2(uplo, diag, n, a, lda) 124 // Zero out the opposite triangle. 125 if uplo == blas.Upper { 126 for i := 1; i < n; i++ { 127 for j := 0; j < i; j++ { 128 aCopy[i*lda+j] = 0 129 a[i*lda+j] = 0 130 } 131 } 132 } else { 133 for i := 0; i < n; i++ { 134 for j := i + 1; j < n; j++ { 135 aCopy[i*lda+j] = 0 136 a[i*lda+j] = 0 137 } 138 } 139 } 140 if diag == blas.Unit { 141 // Set the diagonal of A^{-1} and A explicitly to 1. 142 for i := 0; i < n; i++ { 143 a[i*lda+i] = 1 144 aCopy[i*lda+i] = 1 145 } 146 } 147 // Compute A^{-1} * A and store the result in ans. 148 ans := make([]float64, len(a)) 149 bi.Dgemm(blas.NoTrans, blas.NoTrans, n, n, n, 1, a, lda, aCopy, lda, 0, ans, lda) 150 // Check that ans is close to the identity matrix. 151 dist := distFromIdentity(n, ans, lda) 152 if dist > tol { 153 t.Errorf("|inv(A) * A - I| = %v. Upper = %v, unit = %v, ans = %v", dist, uplo == blas.Upper, diag == blas.Unit, ans) 154 } 155 } 156 } 157 } 158 }