github.com/gonum/lapack@v0.0.0-20181123203213-e4cdc5a0bff9/testlapack/dpotf2.go (about)

     1  // Copyright ©2015 The gonum Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package testlapack
     6  
     7  import (
     8  	"testing"
     9  
    10  	"github.com/gonum/blas"
    11  	"github.com/gonum/floats"
    12  )
    13  
    14  type Dpotf2er interface {
    15  	Dpotf2(ul blas.Uplo, n int, a []float64, lda int) (ok bool)
    16  }
    17  
    18  func Dpotf2Test(t *testing.T, impl Dpotf2er) {
    19  	for _, test := range []struct {
    20  		a   [][]float64
    21  		ul  blas.Uplo
    22  		pos bool
    23  		U   [][]float64
    24  	}{
    25  		{
    26  			a: [][]float64{
    27  				{23, 37, 34, 32},
    28  				{108, 71, 48, 48},
    29  				{109, 109, 67, 58},
    30  				{106, 107, 106, 63},
    31  			},
    32  			pos: true,
    33  			U: [][]float64{
    34  				{4.795831523312719, 7.715033320111766, 7.089490077940543, 6.672461249826393},
    35  				{0, 3.387958215439679, -1.976308959006481, -1.026654004678691},
    36  				{0, 0, 3.582364210034111, 2.419258947036024},
    37  				{0, 0, 0, 3.401680257083044},
    38  			},
    39  		},
    40  		{
    41  			a: [][]float64{
    42  				{8, 2},
    43  				{2, 4},
    44  			},
    45  			pos: true,
    46  			U: [][]float64{
    47  				{2.82842712474619, 0.707106781186547},
    48  				{0, 1.870828693386971},
    49  			},
    50  		},
    51  	} {
    52  		testDpotf2(t, impl, test.pos, test.a, test.U, len(test.a[0]), blas.Upper)
    53  		testDpotf2(t, impl, test.pos, test.a, test.U, len(test.a[0])+5, blas.Upper)
    54  		aT := transpose(test.a)
    55  		L := transpose(test.U)
    56  		testDpotf2(t, impl, test.pos, aT, L, len(test.a[0]), blas.Lower)
    57  		testDpotf2(t, impl, test.pos, aT, L, len(test.a[0])+5, blas.Lower)
    58  	}
    59  }
    60  
    61  func testDpotf2(t *testing.T, impl Dpotf2er, testPos bool, a, ans [][]float64, stride int, ul blas.Uplo) {
    62  	aFlat := flattenTri(a, stride, ul)
    63  	ansFlat := flattenTri(ans, stride, ul)
    64  	pos := impl.Dpotf2(ul, len(a[0]), aFlat, stride)
    65  	if pos != testPos {
    66  		t.Errorf("Positive definite mismatch: Want %v, Got %v", testPos, pos)
    67  		return
    68  	}
    69  	if testPos && !floats.EqualApprox(ansFlat, aFlat, 1e-14) {
    70  		t.Errorf("Result mismatch: Want %v, Got  %v", ansFlat, aFlat)
    71  	}
    72  }
    73  
    74  // flattenTri  with a certain stride. stride must be >= dimension. Puts repeatable
    75  // nonce values in non-accessed places
    76  func flattenTri(a [][]float64, stride int, ul blas.Uplo) []float64 {
    77  	m := len(a)
    78  	n := len(a[0])
    79  	if stride < n {
    80  		panic("bad stride")
    81  	}
    82  	upper := ul == blas.Upper
    83  	v := make([]float64, m*stride)
    84  	count := 1000.0
    85  	for i := 0; i < m; i++ {
    86  		for j := 0; j < stride; j++ {
    87  			if j >= n || (upper && j < i) || (!upper && j > i) {
    88  				// not accessed, so give a unique crazy number
    89  				v[i*stride+j] = count
    90  				count++
    91  				continue
    92  			}
    93  			v[i*stride+j] = a[i][j]
    94  		}
    95  	}
    96  	return v
    97  }
    98  
    99  func transpose(a [][]float64) [][]float64 {
   100  	m := len(a)
   101  	n := len(a[0])
   102  	if m != n {
   103  		panic("not square")
   104  	}
   105  	aNew := make([][]float64, m)
   106  	for i := 0; i < m; i++ {
   107  		aNew[i] = make([]float64, n)
   108  	}
   109  	for i := 0; i < m; i++ {
   110  		if len(a[i]) != n {
   111  			panic("bad n size")
   112  		}
   113  		for j := 0; j < n; j++ {
   114  			aNew[j][i] = a[i][j]
   115  		}
   116  	}
   117  	return aNew
   118  }