gonum.org/v1/gonum@v0.14.0/lapack/testlapack/dlasq2.go (about)

     1  // Copyright ©2015 The Gonum Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package testlapack
     6  
     7  import (
     8  	"fmt"
     9  	"math"
    10  	"sort"
    11  	"testing"
    12  
    13  	"golang.org/x/exp/rand"
    14  	"gonum.org/v1/gonum/blas"
    15  	"gonum.org/v1/gonum/floats"
    16  	"gonum.org/v1/gonum/lapack"
    17  )
    18  
    19  type Dlasq2er interface {
    20  	Dlasq2(n int, z []float64) (info int)
    21  
    22  	Dsyev(jobz lapack.EVJob, uplo blas.Uplo, n int, a []float64, lda int, w, work []float64, lwork int) (ok bool)
    23  }
    24  
    25  func Dlasq2Test(t *testing.T, impl Dlasq2er) {
    26  	const tol = 1e-14
    27  
    28  	rnd := rand.New(rand.NewSource(1))
    29  	for _, n := range []int{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 20, 25, 50} {
    30  		for k := 0; k < 10; k++ {
    31  			for typ := 0; typ <= 2; typ++ {
    32  				name := fmt.Sprintf("n=%v,typ=%v", n, typ)
    33  
    34  				want := make([]float64, n)
    35  				z := make([]float64, 4*n)
    36  				switch typ {
    37  				case 0:
    38  					// L is the identity, U has zero diagonal.
    39  				case 1:
    40  					// L is the identity, U has random diagonal, and so T is upper triangular.
    41  					for i := 0; i < n; i++ {
    42  						z[2*i] = rnd.Float64()
    43  						want[i] = z[2*i]
    44  					}
    45  					sort.Float64s(want)
    46  				case 2:
    47  					// Random tridiagonal matrix
    48  					for i := range z {
    49  						z[i] = rnd.Float64()
    50  					}
    51  					// The slice 'want' is computed below.
    52  				}
    53  				zCopy := make([]float64, len(z))
    54  				copy(zCopy, z)
    55  
    56  				// Compute the eigenvalues of the symmetric positive definite
    57  				// tridiagonal matrix associated with the slice z.
    58  				info := impl.Dlasq2(n, z)
    59  				if info != 0 {
    60  					t.Fatalf("%v: Dlasq2 failed", name)
    61  				}
    62  
    63  				if n == 0 {
    64  					continue
    65  				}
    66  
    67  				got := z[:n]
    68  
    69  				if typ == 2 {
    70  					// Compute the expected result.
    71  
    72  					// Compute the non-symmetric tridiagonal matrix T = L*U where L and
    73  					// U are represented by the slice z.
    74  					ldt := n
    75  					T := make([]float64, n*ldt)
    76  					for i := 0; i < n; i++ {
    77  						if i == 0 {
    78  							T[0] = zCopy[0]
    79  						} else {
    80  							T[i*ldt+i] = zCopy[2*i-1] + zCopy[2*i]
    81  						}
    82  						if i < n-1 {
    83  							T[i*ldt+i+1] = 1
    84  							T[(i+1)*ldt+i] = zCopy[2*i+1] * zCopy[2*i]
    85  						}
    86  					}
    87  					// Compute the symmetric tridiagonal matrix by applying a similarity
    88  					// transformation on T: D^{-1}*T*D. See discussion and references in
    89  					//  http://icl.cs.utk.edu/lapack-forum/viewtopic.php?f=5&t=4839
    90  					d := make([]float64, n)
    91  					d[0] = 1
    92  					for i := 1; i < n; i++ {
    93  						d[i] = d[i-1] * T[i*ldt+i-1] / T[(i-1)*ldt+i]
    94  					}
    95  					for i, di := range d {
    96  						d[i] = math.Sqrt(di)
    97  					}
    98  					for i := 0; i < n; i++ {
    99  						// Update only the upper triangle.
   100  						for j := i; j <= min(i+1, n-1); j++ {
   101  							T[i*ldt+j] *= d[j] / d[i]
   102  						}
   103  					}
   104  
   105  					// Compute the eigenvalues of D^{-1}*T*D by using Dsyev. It's call
   106  					// tree doesn't include Dlasq2.
   107  					work := make([]float64, 3*n)
   108  					ok := impl.Dsyev(lapack.EVNone, blas.Upper, n, T, ldt, want, work, len(work))
   109  					if !ok {
   110  						t.Fatalf("%v: Dsyev failed", name)
   111  					}
   112  				}
   113  
   114  				sort.Float64s(got)
   115  				diff := floats.Distance(got, want, math.Inf(1))
   116  				if diff > tol {
   117  					t.Errorf("%v: unexpected eigenvalues; diff=%v\n%v\n%v\n\n", name, diff, got, want)
   118  				}
   119  			}
   120  		}
   121  	}
   122  }