github.com/jingcheng-WU/gonum@v0.9.1-0.20210323123734-f1a2a11a8f7b/lapack/testlapack/dlaexc.go (about)

     1  // Copyright ©2016 The Gonum Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package testlapack
     6  
     7  import (
     8  	"fmt"
     9  	"testing"
    10  
    11  	"golang.org/x/exp/rand"
    12  
    13  	"github.com/jingcheng-WU/gonum/blas"
    14  	"github.com/jingcheng-WU/gonum/blas/blas64"
    15  	"github.com/jingcheng-WU/gonum/lapack"
    16  )
    17  
    18  type Dlaexcer interface {
    19  	Dlaexc(wantq bool, n int, t []float64, ldt int, q []float64, ldq int, j1, n1, n2 int, work []float64) bool
    20  }
    21  
    22  func DlaexcTest(t *testing.T, impl Dlaexcer) {
    23  	rnd := rand.New(rand.NewSource(1))
    24  
    25  	for _, n := range []int{1, 2, 3, 4, 5, 6, 10, 18, 31, 53} {
    26  		for _, extra := range []int{0, 3} {
    27  			for cas := 0; cas < 100; cas++ {
    28  				testDlaexc(t, impl, rnd, n, extra)
    29  			}
    30  		}
    31  	}
    32  }
    33  
    34  func testDlaexc(t *testing.T, impl Dlaexcer, rnd *rand.Rand, n, extra int) {
    35  	const tol = 1e-14
    36  
    37  	// Generate random T in Schur canonical form.
    38  	tmat, _, _ := randomSchurCanonical(n, n+extra, true, rnd)
    39  	tmatCopy := cloneGeneral(tmat)
    40  
    41  	// Randomly pick the index of the first block.
    42  	j1 := rnd.Intn(n)
    43  	if j1 > 0 && tmat.Data[j1*tmat.Stride+j1-1] != 0 {
    44  		// Adjust j1 if it points to the second row of a 2x2 block.
    45  		j1--
    46  	}
    47  	// Read sizes of the two blocks based on properties of T.
    48  	var n1, n2 int
    49  	switch j1 {
    50  	case n - 1:
    51  		n1, n2 = 1, 0
    52  	case n - 2:
    53  		if tmat.Data[(j1+1)*tmat.Stride+j1] == 0 {
    54  			n1, n2 = 1, 1
    55  		} else {
    56  			n1, n2 = 2, 0
    57  		}
    58  	case n - 3:
    59  		if tmat.Data[(j1+1)*tmat.Stride+j1] == 0 {
    60  			n1, n2 = 1, 2
    61  		} else {
    62  			n1, n2 = 2, 1
    63  		}
    64  	default:
    65  		if tmat.Data[(j1+1)*tmat.Stride+j1] == 0 {
    66  			n1 = 1
    67  			if tmat.Data[(j1+2)*tmat.Stride+j1+1] == 0 {
    68  				n2 = 1
    69  			} else {
    70  				n2 = 2
    71  			}
    72  		} else {
    73  			n1 = 2
    74  			if tmat.Data[(j1+3)*tmat.Stride+j1+2] == 0 {
    75  				n2 = 1
    76  			} else {
    77  				n2 = 2
    78  			}
    79  		}
    80  	}
    81  
    82  	name := fmt.Sprintf("Case n=%v,j1=%v,n1=%v,n2=%v,extra=%v", n, j1, n1, n2, extra)
    83  
    84  	// 1. Test without accumulating Q.
    85  
    86  	wantq := false
    87  
    88  	work := nanSlice(n)
    89  
    90  	ok := impl.Dlaexc(wantq, n, tmat.Data, tmat.Stride, nil, 1, j1, n1, n2, work)
    91  
    92  	// 2. Test with accumulating Q.
    93  
    94  	wantq = true
    95  
    96  	tmat2 := cloneGeneral(tmatCopy)
    97  	q := eye(n, n+extra)
    98  	qCopy := cloneGeneral(q)
    99  	work = nanSlice(n)
   100  
   101  	ok2 := impl.Dlaexc(wantq, n, tmat2.Data, tmat2.Stride, q.Data, q.Stride, j1, n1, n2, work)
   102  
   103  	if !generalOutsideAllNaN(tmat) {
   104  		t.Errorf("%v: out-of-range write to T", name)
   105  	}
   106  	if !generalOutsideAllNaN(tmat2) {
   107  		t.Errorf("%v: out-of-range write to T2", name)
   108  	}
   109  	if !generalOutsideAllNaN(q) {
   110  		t.Errorf("%v: out-of-range write to Q", name)
   111  	}
   112  
   113  	// Check that outputs from cases 1. and 2. are exactly equal, then check one of them.
   114  	if ok != ok2 {
   115  		t.Errorf("%v: ok != ok2", name)
   116  	}
   117  	if !equalGeneral(tmat, tmat2) {
   118  		t.Errorf("%v: T != T2", name)
   119  	}
   120  
   121  	if !ok {
   122  		if n1 == 1 && n2 == 1 {
   123  			t.Errorf("%v: unexpected failure", name)
   124  		} else {
   125  			t.Logf("%v: Dlaexc returned false", name)
   126  		}
   127  	}
   128  
   129  	if !ok || n1 == 0 || n2 == 0 || j1+n1 >= n {
   130  		// Check that T is not modified.
   131  		if !equalGeneral(tmat, tmatCopy) {
   132  			t.Errorf("%v: unexpected modification of T", name)
   133  		}
   134  		// Check that Q is not modified.
   135  		if !equalGeneral(q, qCopy) {
   136  			t.Errorf("%v: unexpected modification of Q", name)
   137  		}
   138  		return
   139  	}
   140  
   141  	// Check that T is not modified outside of rows and columns [j1:j1+n1+n2].
   142  	for i := 0; i < n; i++ {
   143  		if j1 <= i && i < j1+n1+n2 {
   144  			continue
   145  		}
   146  		for j := 0; j < n; j++ {
   147  			if j1 <= j && j < j1+n1+n2 {
   148  				continue
   149  			}
   150  			diff := tmat.Data[i*tmat.Stride+j] - tmatCopy.Data[i*tmatCopy.Stride+j]
   151  			if diff != 0 {
   152  				t.Errorf("%v: unexpected modification of T[%v,%v]", name, i, j)
   153  			}
   154  		}
   155  	}
   156  
   157  	if !isSchurCanonicalGeneral(tmat) {
   158  		t.Errorf("%v: T is not in Schur canonical form", name)
   159  	}
   160  
   161  	// Check that Q is orthogonal.
   162  	resid := residualOrthogonal(q, false)
   163  	if resid > tol {
   164  		t.Errorf("%v: Q is not orthogonal; resid=%v, want<=%v", name, resid, tol)
   165  	}
   166  
   167  	// Check that Q is unchanged outside of columns [j1:j1+n1+n2].
   168  	for i := 0; i < n; i++ {
   169  		for j := 0; j < n; j++ {
   170  			if j1 <= j && j < j1+n1+n2 {
   171  				continue
   172  			}
   173  			diff := q.Data[i*q.Stride+j] - qCopy.Data[i*qCopy.Stride+j]
   174  			if diff != 0 {
   175  				t.Errorf("%v: unexpected modification of Q[%v,%v]", name, i, j)
   176  			}
   177  		}
   178  	}
   179  
   180  	// Check that Qᵀ * TOrig * Q == T
   181  	qt := zeros(n, n, n)
   182  	blas64.Gemm(blas.Trans, blas.NoTrans, 1, q, tmatCopy, 0, qt)
   183  	qtq := cloneGeneral(tmat)
   184  	blas64.Gemm(blas.NoTrans, blas.NoTrans, -1, qt, q, 1, qtq)
   185  	resid = dlange(lapack.MaxColumnSum, n, n, qtq.Data, qtq.Stride)
   186  	if resid > float64(n)*tol {
   187  		t.Errorf("%v: mismatch between Qᵀ*(initial T)*Q and (final T); resid=%v, want<=%v",
   188  			name, resid, float64(n)*tol)
   189  	}
   190  }