gonum.org/v1/gonum@v0.14.0/lapack/testlapack/dpotf2.go (about) 1 // Copyright ©2015 The Gonum Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style 3 // license that can be found in the LICENSE file. 4 5 package testlapack 6 7 import ( 8 "testing" 9 10 "gonum.org/v1/gonum/blas" 11 "gonum.org/v1/gonum/floats" 12 ) 13 14 type Dpotf2er interface { 15 Dpotf2(ul blas.Uplo, n int, a []float64, lda int) (ok bool) 16 } 17 18 func Dpotf2Test(t *testing.T, impl Dpotf2er) { 19 for _, test := range []struct { 20 a [][]float64 21 pos bool 22 U [][]float64 23 }{ 24 { 25 a: [][]float64{ 26 {23, 37, 34, 32}, 27 {108, 71, 48, 48}, 28 {109, 109, 67, 58}, 29 {106, 107, 106, 63}, 30 }, 31 pos: true, 32 U: [][]float64{ 33 {4.795831523312719, 7.715033320111766, 7.089490077940543, 6.672461249826393}, 34 {0, 3.387958215439679, -1.976308959006481, -1.026654004678691}, 35 {0, 0, 3.582364210034111, 2.419258947036024}, 36 {0, 0, 0, 3.401680257083044}, 37 }, 38 }, 39 { 40 a: [][]float64{ 41 {8, 2}, 42 {2, 4}, 43 }, 44 pos: true, 45 U: [][]float64{ 46 {2.82842712474619, 0.707106781186547}, 47 {0, 1.870828693386971}, 48 }, 49 }, 50 } { 51 testDpotf2(t, impl, test.pos, test.a, test.U, len(test.a[0]), blas.Upper) 52 testDpotf2(t, impl, test.pos, test.a, test.U, len(test.a[0])+5, blas.Upper) 53 aT := transpose(test.a) 54 L := transpose(test.U) 55 testDpotf2(t, impl, test.pos, aT, L, len(test.a[0]), blas.Lower) 56 testDpotf2(t, impl, test.pos, aT, L, len(test.a[0])+5, blas.Lower) 57 } 58 } 59 60 func testDpotf2(t *testing.T, impl Dpotf2er, testPos bool, a, ans [][]float64, stride int, ul blas.Uplo) { 61 aFlat := flattenTri(a, stride, ul) 62 ansFlat := flattenTri(ans, stride, ul) 63 pos := impl.Dpotf2(ul, len(a[0]), aFlat, stride) 64 if pos != testPos { 65 t.Errorf("Positive definite mismatch: Want %v, Got %v", testPos, pos) 66 return 67 } 68 if testPos && !floats.EqualApprox(ansFlat, aFlat, 1e-14) { 69 t.Errorf("Result mismatch: Want %v, Got %v", ansFlat, aFlat) 70 } 71 } 72 73 // flattenTri with a certain stride. stride must be >= dimension. Puts repeatable 74 // nonce values in non-accessed places 75 func flattenTri(a [][]float64, stride int, ul blas.Uplo) []float64 { 76 m := len(a) 77 n := len(a[0]) 78 if stride < n { 79 panic("bad stride") 80 } 81 upper := ul == blas.Upper 82 v := make([]float64, m*stride) 83 count := 1000.0 84 for i := 0; i < m; i++ { 85 for j := 0; j < stride; j++ { 86 if j >= n || (upper && j < i) || (!upper && j > i) { 87 // not accessed, so give a unique crazy number 88 v[i*stride+j] = count 89 count++ 90 continue 91 } 92 v[i*stride+j] = a[i][j] 93 } 94 } 95 return v 96 } 97 98 func transpose(a [][]float64) [][]float64 { 99 m := len(a) 100 n := len(a[0]) 101 if m != n { 102 panic("not square") 103 } 104 aNew := make([][]float64, m) 105 for i := 0; i < m; i++ { 106 aNew[i] = make([]float64, n) 107 } 108 for i := 0; i < m; i++ { 109 if len(a[i]) != n { 110 panic("bad n size") 111 } 112 for j := 0; j < n; j++ { 113 aNew[j][i] = a[i][j] 114 } 115 } 116 return aNew 117 }