gonum.org/v1/gonum@v0.14.0/lapack/testlapack/dlanv2.go (about)

     1  // Copyright ©2016 The Gonum Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package testlapack
     6  
     7  import (
     8  	"fmt"
     9  	"math"
    10  	"testing"
    11  
    12  	"golang.org/x/exp/rand"
    13  )
    14  
    15  type Dlanv2er interface {
    16  	Dlanv2(a, b, c, d float64) (aa, bb, cc, dd float64, rt1r, rt1i, rt2r, rt2i float64, cs, sn float64)
    17  }
    18  
    19  func Dlanv2Test(t *testing.T, impl Dlanv2er) {
    20  	rnd := rand.New(rand.NewSource(1))
    21  	t.Run("UpperTriangular", func(t *testing.T) {
    22  		for i := 0; i < 10; i++ {
    23  			a := rnd.NormFloat64()
    24  			b := rnd.NormFloat64()
    25  			d := rnd.NormFloat64()
    26  			dlanv2Test(t, impl, a, b, 0, d)
    27  		}
    28  	})
    29  	t.Run("LowerTriangular", func(t *testing.T) {
    30  		for i := 0; i < 10; i++ {
    31  			a := rnd.NormFloat64()
    32  			c := rnd.NormFloat64()
    33  			d := rnd.NormFloat64()
    34  			dlanv2Test(t, impl, a, 0, c, d)
    35  		}
    36  	})
    37  	t.Run("StandardSchur", func(t *testing.T) {
    38  		for i := 0; i < 10; i++ {
    39  			a := rnd.NormFloat64()
    40  			b := rnd.NormFloat64()
    41  			c := rnd.NormFloat64()
    42  			if math.Signbit(b) == math.Signbit(c) {
    43  				c = -c
    44  			}
    45  			dlanv2Test(t, impl, a, b, c, a)
    46  		}
    47  	})
    48  	t.Run("General", func(t *testing.T) {
    49  		for i := 0; i < 100; i++ {
    50  			a := rnd.NormFloat64()
    51  			b := rnd.NormFloat64()
    52  			c := rnd.NormFloat64()
    53  			d := rnd.NormFloat64()
    54  			dlanv2Test(t, impl, a, b, c, d)
    55  		}
    56  
    57  		// https://github.com/Reference-LAPACK/lapack/issues/263
    58  		dlanv2Test(t, impl, 0, 1, -1, math.Nextafter(0, 1))
    59  	})
    60  }
    61  
    62  func dlanv2Test(t *testing.T, impl Dlanv2er, a, b, c, d float64) {
    63  	aa, bb, cc, dd, rt1r, rt1i, rt2r, rt2i, cs, sn := impl.Dlanv2(a, b, c, d)
    64  
    65  	mat := fmt.Sprintf("[%v %v; %v %v]", a, b, c, d)
    66  	if cc == 0 {
    67  		// The eigenvalues are real, so check that the imaginary parts
    68  		// are zero.
    69  		if rt1i != 0 || rt2i != 0 {
    70  			t.Errorf("Unexpected complex eigenvalues for %v", mat)
    71  		}
    72  	} else {
    73  		// The eigenvalues are complex, so check that documented
    74  		// conditions hold.
    75  		if aa != dd {
    76  			t.Errorf("Diagonal elements not equal for %v: got [%v %v]", mat, aa, dd)
    77  		}
    78  		if bb*cc >= 0 {
    79  			t.Errorf("Non-diagonal elements have the same sign for %v: got [%v %v]", mat, bb, cc)
    80  		} else {
    81  			// Compute the absolute value of the imaginary part.
    82  			im := math.Sqrt(-bb * cc)
    83  			// Check that ±im is close to one of the returned
    84  			// imaginary parts.
    85  			if math.Abs(rt1i-im) > 1e-14 && math.Abs(rt1i+im) > 1e-14 {
    86  				t.Errorf("Unexpected imaginary part of eigenvalue for %v: got %v, want %v or %v", mat, rt1i, im, -im)
    87  			}
    88  			if math.Abs(rt2i-im) > 1e-14 && math.Abs(rt2i+im) > 1e-14 {
    89  				t.Errorf("Unexpected imaginary part of eigenvalue for %v: got %v, want %v or %v", mat, rt2i, im, -im)
    90  			}
    91  		}
    92  	}
    93  	// Check that the returned real parts are consistent.
    94  	if rt1r != aa && rt1r != dd {
    95  		t.Errorf("Unexpected real part of eigenvalue for %v: got %v, want %v or %v", mat, rt1r, aa, dd)
    96  	}
    97  	if rt2r != aa && rt2r != dd {
    98  		t.Errorf("Unexpected real part of eigenvalue for %v: got %v, want %v or %v", mat, rt2r, aa, dd)
    99  	}
   100  	// Check that the columns of the orthogonal matrix have unit norm.
   101  	if math.Abs(math.Hypot(cs, sn)-1) > 1e-14 {
   102  		t.Errorf("Unexpected unitary matrix for %v: got cs %v, sn %v", mat, cs, sn)
   103  	}
   104  
   105  	// Re-compute the original matrix [a b; c d] from its factorization.
   106  	gota := cs*(aa*cs-bb*sn) - sn*(cc*cs-dd*sn)
   107  	gotb := cs*(aa*sn+bb*cs) - sn*(cc*sn+dd*cs)
   108  	gotc := sn*(aa*cs-bb*sn) + cs*(cc*cs-dd*sn)
   109  	gotd := sn*(aa*sn+bb*cs) + cs*(cc*sn+dd*cs)
   110  	if math.Abs(gota-a) > 1e-14 ||
   111  		math.Abs(gotb-b) > 1e-14 ||
   112  		math.Abs(gotc-c) > 1e-14 ||
   113  		math.Abs(gotd-d) > 1e-14 {
   114  		t.Errorf("Unexpected factorization: got [%v %v; %v %v], want [%v %v; %v %v]", gota, gotb, gotc, gotd, a, b, c, d)
   115  	}
   116  }