github.com/gonum/lapack@v0.0.0-20181123203213-e4cdc5a0bff9/testlapack/dpotf2.go (about) 1 // Copyright ©2015 The gonum Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style 3 // license that can be found in the LICENSE file. 4 5 package testlapack 6 7 import ( 8 "testing" 9 10 "github.com/gonum/blas" 11 "github.com/gonum/floats" 12 ) 13 14 type Dpotf2er interface { 15 Dpotf2(ul blas.Uplo, n int, a []float64, lda int) (ok bool) 16 } 17 18 func Dpotf2Test(t *testing.T, impl Dpotf2er) { 19 for _, test := range []struct { 20 a [][]float64 21 ul blas.Uplo 22 pos bool 23 U [][]float64 24 }{ 25 { 26 a: [][]float64{ 27 {23, 37, 34, 32}, 28 {108, 71, 48, 48}, 29 {109, 109, 67, 58}, 30 {106, 107, 106, 63}, 31 }, 32 pos: true, 33 U: [][]float64{ 34 {4.795831523312719, 7.715033320111766, 7.089490077940543, 6.672461249826393}, 35 {0, 3.387958215439679, -1.976308959006481, -1.026654004678691}, 36 {0, 0, 3.582364210034111, 2.419258947036024}, 37 {0, 0, 0, 3.401680257083044}, 38 }, 39 }, 40 { 41 a: [][]float64{ 42 {8, 2}, 43 {2, 4}, 44 }, 45 pos: true, 46 U: [][]float64{ 47 {2.82842712474619, 0.707106781186547}, 48 {0, 1.870828693386971}, 49 }, 50 }, 51 } { 52 testDpotf2(t, impl, test.pos, test.a, test.U, len(test.a[0]), blas.Upper) 53 testDpotf2(t, impl, test.pos, test.a, test.U, len(test.a[0])+5, blas.Upper) 54 aT := transpose(test.a) 55 L := transpose(test.U) 56 testDpotf2(t, impl, test.pos, aT, L, len(test.a[0]), blas.Lower) 57 testDpotf2(t, impl, test.pos, aT, L, len(test.a[0])+5, blas.Lower) 58 } 59 } 60 61 func testDpotf2(t *testing.T, impl Dpotf2er, testPos bool, a, ans [][]float64, stride int, ul blas.Uplo) { 62 aFlat := flattenTri(a, stride, ul) 63 ansFlat := flattenTri(ans, stride, ul) 64 pos := impl.Dpotf2(ul, len(a[0]), aFlat, stride) 65 if pos != testPos { 66 t.Errorf("Positive definite mismatch: Want %v, Got %v", testPos, pos) 67 return 68 } 69 if testPos && !floats.EqualApprox(ansFlat, aFlat, 1e-14) { 70 t.Errorf("Result mismatch: Want %v, Got %v", ansFlat, aFlat) 71 } 72 } 73 74 // flattenTri with a certain stride. stride must be >= dimension. Puts repeatable 75 // nonce values in non-accessed places 76 func flattenTri(a [][]float64, stride int, ul blas.Uplo) []float64 { 77 m := len(a) 78 n := len(a[0]) 79 if stride < n { 80 panic("bad stride") 81 } 82 upper := ul == blas.Upper 83 v := make([]float64, m*stride) 84 count := 1000.0 85 for i := 0; i < m; i++ { 86 for j := 0; j < stride; j++ { 87 if j >= n || (upper && j < i) || (!upper && j > i) { 88 // not accessed, so give a unique crazy number 89 v[i*stride+j] = count 90 count++ 91 continue 92 } 93 v[i*stride+j] = a[i][j] 94 } 95 } 96 return v 97 } 98 99 func transpose(a [][]float64) [][]float64 { 100 m := len(a) 101 n := len(a[0]) 102 if m != n { 103 panic("not square") 104 } 105 aNew := make([][]float64, m) 106 for i := 0; i < m; i++ { 107 aNew[i] = make([]float64, n) 108 } 109 for i := 0; i < m; i++ { 110 if len(a[i]) != n { 111 panic("bad n size") 112 } 113 for j := 0; j < n; j++ { 114 aNew[j][i] = a[i][j] 115 } 116 } 117 return aNew 118 }