gonum.org/v1/gonum@v0.14.0/lapack/testlapack/dtrexc.go (about)

     1  // Copyright ©2016 The Gonum Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package testlapack
     6  
     7  import (
     8  	"fmt"
     9  	"testing"
    10  
    11  	"golang.org/x/exp/rand"
    12  
    13  	"gonum.org/v1/gonum/blas"
    14  	"gonum.org/v1/gonum/blas/blas64"
    15  	"gonum.org/v1/gonum/lapack"
    16  )
    17  
    18  type Dtrexcer interface {
    19  	Dtrexc(compq lapack.UpdateSchurComp, n int, t []float64, ldt int, q []float64, ldq int, ifst, ilst int, work []float64) (ifstOut, ilstOut int, ok bool)
    20  }
    21  
    22  func DtrexcTest(t *testing.T, impl Dtrexcer) {
    23  	rnd := rand.New(rand.NewSource(1))
    24  
    25  	for _, n := range []int{0, 1, 2, 3, 4, 5, 6, 10, 18, 31, 53} {
    26  		for _, extra := range []int{0, 3} {
    27  			for cas := 0; cas < 100; cas++ {
    28  				var ifst, ilst int
    29  				if n > 0 {
    30  					ifst = rnd.Intn(n)
    31  					ilst = rnd.Intn(n)
    32  				}
    33  				dtrexcTest(t, impl, rnd, n, ifst, ilst, extra)
    34  			}
    35  		}
    36  	}
    37  }
    38  
    39  func dtrexcTest(t *testing.T, impl Dtrexcer, rnd *rand.Rand, n, ifst, ilst, extra int) {
    40  	const tol = 1e-13
    41  
    42  	tmat, _, _ := randomSchurCanonical(n, n+extra, true, rnd)
    43  	tmatCopy := cloneGeneral(tmat)
    44  
    45  	fstSize, fstFirst := schurBlockSize(tmat, ifst)
    46  	lstSize, lstFirst := schurBlockSize(tmat, ilst)
    47  
    48  	name := fmt.Sprintf("Case n=%v,ifst=%v,nbfst=%v,ilst=%v,nblst=%v,extra=%v",
    49  		n, ifst, fstSize, ilst, lstSize, extra)
    50  
    51  	// 1. Test without accumulating Q.
    52  
    53  	compq := lapack.UpdateSchurNone
    54  
    55  	work := nanSlice(n)
    56  
    57  	ifstGot, ilstGot, ok := impl.Dtrexc(compq, n, tmat.Data, tmat.Stride, nil, 1, ifst, ilst, work)
    58  
    59  	if !generalOutsideAllNaN(tmat) {
    60  		t.Errorf("%v: out-of-range write to T", name)
    61  	}
    62  
    63  	// 2. Test with accumulating Q.
    64  
    65  	compq = lapack.UpdateSchur
    66  
    67  	tmat2 := cloneGeneral(tmatCopy)
    68  	q := eye(n, n+extra)
    69  	qCopy := cloneGeneral(q)
    70  	work = nanSlice(n)
    71  
    72  	ifstGot2, ilstGot2, ok2 := impl.Dtrexc(compq, n, tmat2.Data, tmat2.Stride, q.Data, q.Stride, ifst, ilst, work)
    73  
    74  	if !generalOutsideAllNaN(tmat2) {
    75  		t.Errorf("%v: out-of-range write to T2", name)
    76  	}
    77  	if !generalOutsideAllNaN(q) {
    78  		t.Errorf("%v: out-of-range write to Q", name)
    79  	}
    80  
    81  	// Check that outputs from cases 1. and 2. are exactly equal, then check one of them.
    82  	if ifstGot != ifstGot2 {
    83  		t.Errorf("%v: ifstGot != ifstGot2", name)
    84  	}
    85  	if ilstGot != ilstGot2 {
    86  		t.Errorf("%v: ilstGot != ilstGot2", name)
    87  	}
    88  	if ok != ok2 {
    89  		t.Errorf("%v: ok != ok2", name)
    90  	}
    91  	if !equalGeneral(tmat, tmat2) {
    92  		t.Errorf("%v: T != T2", name)
    93  	}
    94  
    95  	// Check that the index of the first block was correctly updated (if
    96  	// necessary).
    97  	ifstWant := ifst
    98  	if !fstFirst {
    99  		ifstWant = ifst - 1
   100  	}
   101  	if ifstWant != ifstGot {
   102  		t.Errorf("%v: unexpected ifst=%v, want %v", name, ifstGot, ifstWant)
   103  	}
   104  
   105  	// Check that the index of the last block is as expected when ok=true.
   106  	// When ok=false, we don't know at which block the algorithm failed, so
   107  	// we don't check.
   108  	ilstWant := ilst
   109  	if !lstFirst {
   110  		ilstWant--
   111  	}
   112  	if ok {
   113  		if ifstWant < ilstWant {
   114  			// If the blocks are swapped backwards, these
   115  			// adjustments are not necessary, the first row of the
   116  			// last block will end up at ifst.
   117  			switch {
   118  			case fstSize == 2 && lstSize == 1:
   119  				ilstWant--
   120  			case fstSize == 1 && lstSize == 2:
   121  				ilstWant++
   122  			}
   123  		}
   124  		if ilstWant != ilstGot {
   125  			t.Errorf("%v: unexpected ilst=%v, want %v", name, ilstGot, ilstWant)
   126  		}
   127  	}
   128  
   129  	if n <= 1 || ifstGot == ilstGot {
   130  		// Too small matrix or no swapping.
   131  		// Check that T was not modified.
   132  		if !equalGeneral(tmat, tmatCopy) {
   133  			t.Errorf("%v: unexpected modification of T when no swapping", name)
   134  		}
   135  		// Check that Q was not modified.
   136  		if !equalGeneral(q, qCopy) {
   137  			t.Errorf("%v: unexpected modification of Q when no swapping", name)
   138  		}
   139  		// Nothing more to check
   140  		return
   141  	}
   142  
   143  	if !isSchurCanonicalGeneral(tmat) {
   144  		t.Errorf("%v: T is not in Schur canonical form", name)
   145  	}
   146  
   147  	// Check that T was not modified except above the second subdiagonal in
   148  	// rows and columns [modMin,modMax].
   149  	modMin := min(ifstGot, ilstGot)
   150  	modMax := max(ifstGot, ilstGot) + fstSize
   151  	for i := 0; i < n; i++ {
   152  		for j := 0; j < n; j++ {
   153  			if modMin <= i && i < modMax && j+1 >= i {
   154  				continue
   155  			}
   156  			if modMin <= j && j < modMax && j+1 >= i {
   157  				continue
   158  			}
   159  			diff := tmat.Data[i*tmat.Stride+j] - tmatCopy.Data[i*tmatCopy.Stride+j]
   160  			if diff != 0 {
   161  				t.Errorf("%v: unexpected modification at T[%v,%v]", name, i, j)
   162  			}
   163  		}
   164  	}
   165  
   166  	// Check that Q is orthogonal.
   167  	resid := residualOrthogonal(q, false)
   168  	if resid > tol {
   169  		t.Errorf("%v: Q is not orthogonal; resid=%v, want<=%v", name, resid, tol)
   170  	}
   171  
   172  	// Check that Q is unchanged outside of columns [modMin,modMax].
   173  	for i := 0; i < n; i++ {
   174  		for j := 0; j < n; j++ {
   175  			if modMin <= j && j < modMax {
   176  				continue
   177  			}
   178  			if q.Data[i*q.Stride+j] != qCopy.Data[i*qCopy.Stride+j] {
   179  				t.Errorf("%v: unexpected modification of Q[%v,%v]", name, i, j)
   180  			}
   181  		}
   182  	}
   183  
   184  	// Check that Qᵀ * TOrig * Q == T
   185  	qt := zeros(n, n, n)
   186  	blas64.Gemm(blas.Trans, blas.NoTrans, 1, q, tmatCopy, 0, qt)
   187  	qtq := cloneGeneral(tmat)
   188  	blas64.Gemm(blas.NoTrans, blas.NoTrans, -1, qt, q, 1, qtq)
   189  	resid = dlange(lapack.MaxColumnSum, n, n, qtq.Data, qtq.Stride)
   190  	if resid > tol {
   191  		t.Errorf("%v: mismatch between Qᵀ*(initial T)*Q and (final T); resid=%v, want<=%v",
   192  			name, resid, tol)
   193  	}
   194  }