gonum.org/v1/gonum@v0.14.0/lapack/testlapack/dsyev.go (about)

     1  // Copyright ©2016 The Gonum Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package testlapack
     6  
     7  import (
     8  	"testing"
     9  
    10  	"golang.org/x/exp/rand"
    11  
    12  	"gonum.org/v1/gonum/blas"
    13  	"gonum.org/v1/gonum/blas/blas64"
    14  	"gonum.org/v1/gonum/floats"
    15  	"gonum.org/v1/gonum/lapack"
    16  )
    17  
    18  type Dsyever interface {
    19  	Dsyev(jobz lapack.EVJob, uplo blas.Uplo, n int, a []float64, lda int, w, work []float64, lwork int) (ok bool)
    20  }
    21  
    22  func DsyevTest(t *testing.T, impl Dsyever) {
    23  	rnd := rand.New(rand.NewSource(1))
    24  	for _, uplo := range []blas.Uplo{blas.Lower, blas.Upper} {
    25  		for _, test := range []struct {
    26  			n, lda int
    27  		}{
    28  			{1, 0},
    29  			{2, 0},
    30  			{5, 0},
    31  			{10, 0},
    32  			{100, 0},
    33  
    34  			{1, 5},
    35  			{2, 5},
    36  			{5, 10},
    37  			{10, 20},
    38  			{100, 110},
    39  		} {
    40  			for cas := 0; cas < 10; cas++ {
    41  				n := test.n
    42  				lda := test.lda
    43  				if lda == 0 {
    44  					lda = n
    45  				}
    46  				a := make([]float64, n*lda)
    47  				for i := range a {
    48  					a[i] = rnd.NormFloat64()
    49  				}
    50  				aCopy := make([]float64, len(a))
    51  				copy(aCopy, a)
    52  				w := make([]float64, n)
    53  				for i := range w {
    54  					w[i] = rnd.NormFloat64()
    55  				}
    56  
    57  				work := make([]float64, 1)
    58  				impl.Dsyev(lapack.EVCompute, uplo, n, a, lda, w, work, -1)
    59  				work = make([]float64, int(work[0]))
    60  				impl.Dsyev(lapack.EVCompute, uplo, n, a, lda, w, work, len(work))
    61  
    62  				// Check that the decomposition is correct
    63  				orig := blas64.General{
    64  					Rows:   n,
    65  					Cols:   n,
    66  					Stride: n,
    67  					Data:   make([]float64, n*n),
    68  				}
    69  				if uplo == blas.Upper {
    70  					for i := 0; i < n; i++ {
    71  						for j := i; j < n; j++ {
    72  							v := aCopy[i*lda+j]
    73  							orig.Data[i*orig.Stride+j] = v
    74  							orig.Data[j*orig.Stride+i] = v
    75  						}
    76  					}
    77  				} else {
    78  					for i := 0; i < n; i++ {
    79  						for j := 0; j <= i; j++ {
    80  							v := aCopy[i*lda+j]
    81  							orig.Data[i*orig.Stride+j] = v
    82  							orig.Data[j*orig.Stride+i] = v
    83  						}
    84  					}
    85  				}
    86  
    87  				V := blas64.General{
    88  					Rows:   n,
    89  					Cols:   n,
    90  					Stride: lda,
    91  					Data:   a,
    92  				}
    93  
    94  				if !eigenDecompCorrect(w, orig, V) {
    95  					t.Errorf("Decomposition mismatch")
    96  				}
    97  
    98  				// Check that the decomposition is correct when the eigenvectors
    99  				// are not computed.
   100  				wAns := make([]float64, len(w))
   101  				copy(wAns, w)
   102  				copy(a, aCopy)
   103  				for i := range w {
   104  					w[i] = rnd.Float64()
   105  				}
   106  				for i := range work {
   107  					work[i] = rnd.Float64()
   108  				}
   109  				impl.Dsyev(lapack.EVNone, uplo, n, a, lda, w, work, len(work))
   110  				if !floats.EqualApprox(w, wAns, 1e-8) {
   111  					t.Errorf("Eigenvalue mismatch when vectors not computed")
   112  				}
   113  			}
   114  		}
   115  	}
   116  }