gonum.org/v1/gonum@v0.15.1-0.20240517103525-f853624cb1bb/mat/eigen_test.go (about)

     1  // Copyright ©2013 The Gonum Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package mat
     6  
     7  import (
     8  	"math"
     9  	"sort"
    10  	"testing"
    11  
    12  	"golang.org/x/exp/rand"
    13  
    14  	"gonum.org/v1/gonum/floats"
    15  )
    16  
    17  func TestEigen(t *testing.T) {
    18  	t.Parallel()
    19  	for i, test := range []struct {
    20  		a *Dense
    21  
    22  		values []complex128
    23  		left   *CDense
    24  		right  *CDense
    25  	}{
    26  		{
    27  			a: NewDense(3, 3, []float64{
    28  				1, 0, 0,
    29  				0, 1, 0,
    30  				0, 0, 1,
    31  			}),
    32  			values: []complex128{1, 1, 1},
    33  			left: NewCDense(3, 3, []complex128{
    34  				1, 0, 0,
    35  				0, 1, 0,
    36  				0, 0, 1,
    37  			}),
    38  			right: NewCDense(3, 3, []complex128{
    39  				1, 0, 0,
    40  				0, 1, 0,
    41  				0, 0, 1,
    42  			}),
    43  		},
    44  		{
    45  			// Values compared with numpy.
    46  			a: NewDense(4, 4, []float64{
    47  				0.9025, 0.025, 0.475, 0.0475,
    48  				0.0475, 0.475, 0.475, 0.0025,
    49  				0.0475, 0.025, 0.025, 0.9025,
    50  				0.0025, 0.475, 0.025, 0.0475,
    51  			}),
    52  			values: []complex128{1, 0.7300317046114154, -0.1400158523057075 + 0.452854925738716i, -0.1400158523057075 - 0.452854925738716i},
    53  			left: NewCDense(4, 4, []complex128{
    54  				0.5, -0.3135167160788313, -0.02058121780136903 + 0.004580939300127051i, -0.02058121780136903 - 0.004580939300127051i,
    55  				0.5, 0.7842199280224781, 0.37551026954193356 - 0.2924634904103879i, 0.37551026954193356 + 0.2924634904103879i,
    56  				0.5, 0.33202200780783525, 0.16052616322784943 + 0.3881393645202527i, 0.16052616322784943 - 0.3881393645202527i,
    57  				0.5, 0.42008065840123954, -0.7723935249234155, -0.7723935249234155,
    58  			}),
    59  			right: NewCDense(4, 4, []complex128{
    60  				0.9476399565969628, -0.8637347682162745, -0.2688989440320280 - 0.1282234938321029i, -0.2688989440320280 + 0.1282234938321029i,
    61  				0.2394935907064427, 0.3457075153704627, -0.3621360383713332 - 0.2583198964498771i, -0.3621360383713332 + 0.2583198964498771i,
    62  				0.1692743801716332, 0.2706851011641580, 0.7426369401030960, 0.7426369401030960,
    63  				0.1263626404003607, 0.2473421516816520, -0.1116019576997347 + 0.3865433902819795i, -0.1116019576997347 - 0.3865433902819795i,
    64  			}),
    65  		},
    66  	} {
    67  		var e1, e2, e3, e4 Eigen
    68  		ok := e1.Factorize(test.a, EigenBoth)
    69  		if !ok {
    70  			panic("bad factorization")
    71  		}
    72  		e2.Factorize(test.a, EigenRight)
    73  		e3.Factorize(test.a, EigenLeft)
    74  		e4.Factorize(test.a, EigenNone)
    75  
    76  		v1 := e1.Values(nil)
    77  		if !cmplxEqualTol(v1, test.values, 1e-14) {
    78  			t.Errorf("eigenvalue mismatch. Case %v", i)
    79  		}
    80  		var left CDense
    81  		e1.LeftVectorsTo(&left)
    82  		if !CEqualApprox(&left, test.left, 1e-14) {
    83  			t.Errorf("left eigenvector mismatch. Case %v", i)
    84  		}
    85  		var right CDense
    86  		e1.VectorsTo(&right)
    87  		if !CEqualApprox(&right, test.right, 1e-14) {
    88  			t.Errorf("right eigenvector mismatch. Case %v", i)
    89  		}
    90  
    91  		// Check that the eigenvectors and values are the same in all combinations.
    92  		if !cmplxEqual(v1, e2.Values(nil)) {
    93  			t.Errorf("eigenvector mismatch. Case %v", i)
    94  		}
    95  		if !cmplxEqual(v1, e3.Values(nil)) {
    96  			t.Errorf("eigenvector mismatch. Case %v", i)
    97  		}
    98  		if !cmplxEqual(v1, e4.Values(nil)) {
    99  			t.Errorf("eigenvector mismatch. Case %v", i)
   100  		}
   101  		var right2 CDense
   102  		e2.VectorsTo(&right2)
   103  		if !CEqual(&right, &right2) {
   104  			t.Errorf("right eigenvector mismatch. Case %v", i)
   105  		}
   106  		var left3 CDense
   107  		e3.LeftVectorsTo(&left3)
   108  		if !CEqual(&left, &left3) {
   109  			t.Errorf("left eigenvector mismatch. Case %v", i)
   110  		}
   111  
   112  		// TODO(btracey): Also add in a test for correctness when #308 is
   113  		// resolved and we have a CMat.Mul().
   114  	}
   115  }
   116  
   117  func cmplxEqual(v1, v2 []complex128) bool {
   118  	for i, v := range v1 {
   119  		if v != v2[i] {
   120  			return false
   121  		}
   122  	}
   123  	return true
   124  }
   125  
   126  func cmplxEqualTol(v1, v2 []complex128, tol float64) bool {
   127  	for i, v := range v1 {
   128  		if !cEqualWithinAbsOrRel(v, v2[i], tol, tol) {
   129  			return false
   130  		}
   131  	}
   132  	return true
   133  }
   134  
   135  func TestEigenSym(t *testing.T) {
   136  	t.Parallel()
   137  	const tol = 1e-14
   138  	// Hand coded tests with results from lapack.
   139  	for cas, test := range []struct {
   140  		mat *SymDense
   141  
   142  		values  []float64
   143  		vectors *Dense
   144  	}{
   145  		{
   146  			mat:    NewSymDense(3, []float64{8, 2, 4, 2, 6, 10, 4, 10, 5}),
   147  			values: []float64{-4.707679201365891, 6.294580208480216, 17.413098992885672},
   148  			vectors: NewDense(3, 3, []float64{
   149  				-0.127343483135656, -0.902414161226903, -0.411621572466779,
   150  				-0.664177720955769, 0.385801900032553, -0.640331827193739,
   151  				0.736648893495999, 0.191847792659746, -0.648492738712395,
   152  			}),
   153  		},
   154  	} {
   155  		var es EigenSym
   156  		ok := es.Factorize(test.mat, true)
   157  		if !ok {
   158  			t.Errorf("case %d: bad test", cas)
   159  			continue
   160  		}
   161  		if !floats.EqualApprox(test.values, es.values, tol) {
   162  			t.Errorf("case %d: eigenvalue mismatch", cas)
   163  		}
   164  		if !EqualApprox(test.vectors, es.vectors, tol) {
   165  			t.Errorf("case %d: eigenvector mismatch", cas)
   166  		}
   167  
   168  		var es2 EigenSym
   169  		es2.Factorize(test.mat, false)
   170  		if !floats.EqualApprox(es2.values, es.values, tol) {
   171  			t.Errorf("case %d: eigenvalue mismatch when no vectors computed", cas)
   172  		}
   173  	}
   174  
   175  	// Randomized tests
   176  	rnd := rand.New(rand.NewSource(1))
   177  	for _, n := range []int{1, 2, 3, 5, 10, 70} {
   178  		for cas := 0; cas < 10; cas++ {
   179  			a := make([]float64, n*n)
   180  			for i := range a {
   181  				a[i] = rnd.NormFloat64()
   182  			}
   183  			s := NewSymDense(n, a)
   184  			var es EigenSym
   185  			ok := es.Factorize(s, true)
   186  			if !ok {
   187  				t.Errorf("n=%d,cas=%d: bad test", n, cas)
   188  				continue
   189  			}
   190  
   191  			// Check that A and EigenSym are equal as Matrix.
   192  			if !EqualApprox(s, &es, tol*float64(n)) {
   193  				t.Errorf("n=%d,cas=%d: A and EigenSym are not equal as Matrix", n, cas)
   194  			}
   195  			if !EqualApprox(s.T(), es.T(), tol*float64(n)) {
   196  				t.Errorf("n=%d,cas=%d: Aᵀ and EigenSymᵀ are not equal as Matrix", n, cas)
   197  			}
   198  
   199  			// Check that the eigenvectors are orthonormal.
   200  			if !isOrthonormal(es.vectors, 1e-8) {
   201  				t.Errorf("n=%d,cas=%d: eigenvectors not orthonormal", n, cas)
   202  			}
   203  
   204  			// Check that the eigenvalues are actually eigenvalues.
   205  			for i := 0; i < n; i++ {
   206  				v := NewVecDense(n, Col(nil, i, es.vectors))
   207  				var m VecDense
   208  				m.MulVec(s, v)
   209  
   210  				var scal VecDense
   211  				scal.ScaleVec(es.values[i], v)
   212  
   213  				if !EqualApprox(&m, &scal, 1e-8) {
   214  					t.Errorf("n=%d,cas=%d: eigenvalue %d does not match", n, cas, i)
   215  				}
   216  			}
   217  
   218  			// Check that A = Q * D * Qᵀ using the Raw methods.
   219  			var got Dense
   220  			got.Product(es.RawQ(), NewDiagDense(n, es.RawValues()), es.RawQ().T())
   221  			if !EqualApprox(s, &got, tol*float64(n)) {
   222  				var diff Dense
   223  				diff.Sub(s, &got)
   224  				diff.Apply(func(i, j int, v float64) float64 { return math.Abs(diff.At(i, j)) }, &diff)
   225  				t.Errorf("n=%d,cas=%d: A not reconstructed from Q*D*Qᵀ\n|diff|=%v", n, cas,
   226  					Formatted(&diff, Prefix("       ")))
   227  			}
   228  
   229  			// Check that the eigenvalues are in ascending order.
   230  			if !sort.Float64sAreSorted(es.values) {
   231  				t.Errorf("n=%d,cas=%d: eigenvalues not ascending", n, cas)
   232  			}
   233  		}
   234  	}
   235  }