github.com/gopherd/gonum@v0.0.4/lapack/testlapack/dpotf2.go (about)

     1  // Copyright ©2015 The Gonum Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package testlapack
     6  
     7  import (
     8  	"testing"
     9  
    10  	"github.com/gopherd/gonum/blas"
    11  	"github.com/gopherd/gonum/floats"
    12  )
    13  
// Dpotf2er is the interface that wraps the Dpotf2 method exercised by
// Dpotf2Test.
type Dpotf2er interface {
	// Dpotf2 is called by the tests in this file with the triangle selector
	// ul, the matrix order n, a row-major backing slice a and its row stride
	// lda. The tests expect the referenced triangle of a to be overwritten
	// in place with the triangular factor, every other element of a to be
	// left untouched, and ok to report whether the matrix was found to be
	// positive definite.
	Dpotf2(ul blas.Uplo, n int, a []float64, lda int) (ok bool)
}
    17  
    18  func Dpotf2Test(t *testing.T, impl Dpotf2er) {
    19  	for _, test := range []struct {
    20  		a   [][]float64
    21  		pos bool
    22  		U   [][]float64
    23  	}{
    24  		{
    25  			a: [][]float64{
    26  				{23, 37, 34, 32},
    27  				{108, 71, 48, 48},
    28  				{109, 109, 67, 58},
    29  				{106, 107, 106, 63},
    30  			},
    31  			pos: true,
    32  			U: [][]float64{
    33  				{4.795831523312719, 7.715033320111766, 7.089490077940543, 6.672461249826393},
    34  				{0, 3.387958215439679, -1.976308959006481, -1.026654004678691},
    35  				{0, 0, 3.582364210034111, 2.419258947036024},
    36  				{0, 0, 0, 3.401680257083044},
    37  			},
    38  		},
    39  		{
    40  			a: [][]float64{
    41  				{8, 2},
    42  				{2, 4},
    43  			},
    44  			pos: true,
    45  			U: [][]float64{
    46  				{2.82842712474619, 0.707106781186547},
    47  				{0, 1.870828693386971},
    48  			},
    49  		},
    50  	} {
    51  		testDpotf2(t, impl, test.pos, test.a, test.U, len(test.a[0]), blas.Upper)
    52  		testDpotf2(t, impl, test.pos, test.a, test.U, len(test.a[0])+5, blas.Upper)
    53  		aT := transpose(test.a)
    54  		L := transpose(test.U)
    55  		testDpotf2(t, impl, test.pos, aT, L, len(test.a[0]), blas.Lower)
    56  		testDpotf2(t, impl, test.pos, aT, L, len(test.a[0])+5, blas.Lower)
    57  	}
    58  }
    59  
    60  func testDpotf2(t *testing.T, impl Dpotf2er, testPos bool, a, ans [][]float64, stride int, ul blas.Uplo) {
    61  	aFlat := flattenTri(a, stride, ul)
    62  	ansFlat := flattenTri(ans, stride, ul)
    63  	pos := impl.Dpotf2(ul, len(a[0]), aFlat, stride)
    64  	if pos != testPos {
    65  		t.Errorf("Positive definite mismatch: Want %v, Got %v", testPos, pos)
    66  		return
    67  	}
    68  	if testPos && !floats.EqualApprox(ansFlat, aFlat, 1e-14) {
    69  		t.Errorf("Result mismatch: Want %v, Got  %v", ansFlat, aFlat)
    70  	}
    71  }
    72  
    73  // flattenTri  with a certain stride. stride must be >= dimension. Puts repeatable
    74  // nonce values in non-accessed places
    75  func flattenTri(a [][]float64, stride int, ul blas.Uplo) []float64 {
    76  	m := len(a)
    77  	n := len(a[0])
    78  	if stride < n {
    79  		panic("bad stride")
    80  	}
    81  	upper := ul == blas.Upper
    82  	v := make([]float64, m*stride)
    83  	count := 1000.0
    84  	for i := 0; i < m; i++ {
    85  		for j := 0; j < stride; j++ {
    86  			if j >= n || (upper && j < i) || (!upper && j > i) {
    87  				// not accessed, so give a unique crazy number
    88  				v[i*stride+j] = count
    89  				count++
    90  				continue
    91  			}
    92  			v[i*stride+j] = a[i][j]
    93  		}
    94  	}
    95  	return v
    96  }
    97  
    98  func transpose(a [][]float64) [][]float64 {
    99  	m := len(a)
   100  	n := len(a[0])
   101  	if m != n {
   102  		panic("not square")
   103  	}
   104  	aNew := make([][]float64, m)
   105  	for i := 0; i < m; i++ {
   106  		aNew[i] = make([]float64, n)
   107  	}
   108  	for i := 0; i < m; i++ {
   109  		if len(a[i]) != n {
   110  			panic("bad n size")
   111  		}
   112  		for j := 0; j < n; j++ {
   113  			aNew[j][i] = a[i][j]
   114  		}
   115  	}
   116  	return aNew
   117  }