github.com/gonum/lapack@v0.0.0-20181123203213-e4cdc5a0bff9/testlapack/dsyev.go (about)

     1  // Copyright ©2016 The gonum Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package testlapack
     6  
     7  import (
     8  	"math/rand"
     9  	"testing"
    10  
    11  	"github.com/gonum/blas"
    12  	"github.com/gonum/blas/blas64"
    13  	"github.com/gonum/floats"
    14  	"github.com/gonum/lapack"
    15  )
    16  
    17  type Dsyever interface {
    18  	Dsyev(jobz lapack.EVJob, uplo blas.Uplo, n int, a []float64, lda int, w, work []float64, lwork int) (ok bool)
    19  }
    20  
    21  func DsyevTest(t *testing.T, impl Dsyever) {
    22  	rnd := rand.New(rand.NewSource(1))
    23  	for _, uplo := range []blas.Uplo{blas.Lower, blas.Upper} {
    24  		for _, test := range []struct {
    25  			n, lda int
    26  		}{
    27  			{1, 0},
    28  			{2, 0},
    29  			{5, 0},
    30  			{10, 0},
    31  			{100, 0},
    32  
    33  			{1, 5},
    34  			{2, 5},
    35  			{5, 10},
    36  			{10, 20},
    37  			{100, 110},
    38  		} {
    39  			for cas := 0; cas < 10; cas++ {
    40  				n := test.n
    41  				lda := test.lda
    42  				if lda == 0 {
    43  					lda = n
    44  				}
    45  				a := make([]float64, n*lda)
    46  				for i := range a {
    47  					a[i] = rnd.NormFloat64()
    48  				}
    49  				aCopy := make([]float64, len(a))
    50  				copy(aCopy, a)
    51  				w := make([]float64, n)
    52  				for i := range w {
    53  					w[i] = rnd.NormFloat64()
    54  				}
    55  
    56  				work := make([]float64, 1)
    57  				impl.Dsyev(lapack.ComputeEV, uplo, n, a, lda, w, work, -1)
    58  				work = make([]float64, int(work[0]))
    59  				impl.Dsyev(lapack.ComputeEV, uplo, n, a, lda, w, work, len(work))
    60  
    61  				// Check that the decomposition is correct
    62  				orig := blas64.General{
    63  					Rows:   n,
    64  					Cols:   n,
    65  					Stride: n,
    66  					Data:   make([]float64, n*n),
    67  				}
    68  				if uplo == blas.Upper {
    69  					for i := 0; i < n; i++ {
    70  						for j := i; j < n; j++ {
    71  							v := aCopy[i*lda+j]
    72  							orig.Data[i*orig.Stride+j] = v
    73  							orig.Data[j*orig.Stride+i] = v
    74  						}
    75  					}
    76  				} else {
    77  					for i := 0; i < n; i++ {
    78  						for j := 0; j <= i; j++ {
    79  							v := aCopy[i*lda+j]
    80  							orig.Data[i*orig.Stride+j] = v
    81  							orig.Data[j*orig.Stride+i] = v
    82  						}
    83  					}
    84  				}
    85  
    86  				V := blas64.General{
    87  					Rows:   n,
    88  					Cols:   n,
    89  					Stride: lda,
    90  					Data:   a,
    91  				}
    92  
    93  				if !eigenDecompCorrect(w, orig, V) {
    94  					t.Errorf("Decomposition mismatch")
    95  				}
    96  
    97  				// Check that the decomposition is correct when the eigenvectors
    98  				// are not computed.
    99  				wAns := make([]float64, len(w))
   100  				copy(wAns, w)
   101  				copy(a, aCopy)
   102  				for i := range w {
   103  					w[i] = rnd.Float64()
   104  				}
   105  				for i := range work {
   106  					work[i] = rnd.Float64()
   107  				}
   108  				impl.Dsyev(lapack.None, uplo, n, a, lda, w, work, len(work))
   109  				if !floats.EqualApprox(w, wAns, 1e-8) {
   110  					t.Errorf("Eigenvalue mismatch when vectors not computed")
   111  				}
   112  			}
   113  		}
   114  	}
   115  }