gonum.org/v1/gonum@v0.14.0/lapack/testlapack/dgehd2.go (about)

     1  // Copyright ©2016 The Gonum Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package testlapack
     6  
     7  import (
     8  	"fmt"
     9  	"math"
    10  	"testing"
    11  
    12  	"golang.org/x/exp/rand"
    13  
    14  	"gonum.org/v1/gonum/blas"
    15  	"gonum.org/v1/gonum/blas/blas64"
    16  )
    17  
    18  type Dgehd2er interface {
    19  	Dgehd2(n, ilo, ihi int, a []float64, lda int, tau, work []float64)
    20  }
    21  
    22  func Dgehd2Test(t *testing.T, impl Dgehd2er) {
    23  	rnd := rand.New(rand.NewSource(1))
    24  	for _, n := range []int{1, 2, 3, 4, 5, 7, 10, 30} {
    25  		for _, extra := range []int{0, 1, 13} {
    26  			for cas := 0; cas < 100; cas++ {
    27  				testDgehd2(t, impl, n, extra, rnd)
    28  			}
    29  		}
    30  	}
    31  }
    32  
    33  func testDgehd2(t *testing.T, impl Dgehd2er, n, extra int, rnd *rand.Rand) {
    34  	const tol = 1e-14
    35  
    36  	ilo := rnd.Intn(n)
    37  	ihi := rnd.Intn(n)
    38  	if ilo > ihi {
    39  		ilo, ihi = ihi, ilo
    40  	}
    41  
    42  	tau := nanSlice(n - 1)
    43  	work := nanSlice(n)
    44  
    45  	a := randomGeneral(n, n, n+extra, rnd)
    46  	// NaN out elements under the diagonal except
    47  	// for the [ilo:ihi,ilo:ihi] block.
    48  	for i := 1; i <= ihi; i++ {
    49  		for j := 0; j < min(ilo, i); j++ {
    50  			a.Data[i*a.Stride+j] = math.NaN()
    51  		}
    52  	}
    53  	for i := ihi + 1; i < n; i++ {
    54  		for j := 0; j < i; j++ {
    55  			a.Data[i*a.Stride+j] = math.NaN()
    56  		}
    57  	}
    58  	aCopy := a
    59  	aCopy.Data = make([]float64, len(a.Data))
    60  	copy(aCopy.Data, a.Data)
    61  
    62  	impl.Dgehd2(n, ilo, ihi, a.Data, a.Stride, tau, work)
    63  
    64  	prefix := fmt.Sprintf("Case n=%v, ilo=%v, ihi=%v, extra=%v", n, ilo, ihi, extra)
    65  
    66  	// Check any invalid modifications of a.
    67  	if !generalOutsideAllNaN(a) {
    68  		t.Errorf("%v: out-of-range write to A\n%v", prefix, a.Data)
    69  	}
    70  	for i := ilo; i <= ihi; i++ {
    71  		for j := 0; j < min(ilo, i); j++ {
    72  			if !math.IsNaN(a.Data[i*a.Stride+j]) {
    73  				t.Errorf("%v: expected NaN at A[%v,%v]", prefix, i, j)
    74  			}
    75  		}
    76  	}
    77  	for i := ihi + 1; i < n; i++ {
    78  		for j := 0; j < i; j++ {
    79  			if !math.IsNaN(a.Data[i*a.Stride+j]) {
    80  				t.Errorf("%v: expected NaN at A[%v,%v]", prefix, i, j)
    81  			}
    82  		}
    83  	}
    84  	for i := 0; i <= ilo; i++ {
    85  		for j := i; j < ilo+1; j++ {
    86  			if a.Data[i*a.Stride+j] != aCopy.Data[i*aCopy.Stride+j] {
    87  				t.Errorf("%v: unexpected modification at A[%v,%v]", prefix, i, j)
    88  			}
    89  		}
    90  		for j := ihi + 1; j < n; j++ {
    91  			if a.Data[i*a.Stride+j] != aCopy.Data[i*aCopy.Stride+j] {
    92  				t.Errorf("%v: unexpected modification at A[%v,%v]", prefix, i, j)
    93  			}
    94  		}
    95  	}
    96  	for i := ihi + 1; i < n; i++ {
    97  		for j := i; j < n; j++ {
    98  			if a.Data[i*a.Stride+j] != aCopy.Data[i*aCopy.Stride+j] {
    99  				t.Errorf("%v: unexpected modification at A[%v,%v]", prefix, i, j)
   100  			}
   101  		}
   102  	}
   103  
   104  	// Check that tau has been assigned properly.
   105  	for i, v := range tau {
   106  		if i < ilo || i >= ihi {
   107  			if !math.IsNaN(v) {
   108  				t.Errorf("%v: expected NaN at tau[%v]", prefix, i)
   109  			}
   110  		} else {
   111  			if math.IsNaN(v) {
   112  				t.Errorf("%v: unexpected NaN at tau[%v]", prefix, i)
   113  			}
   114  		}
   115  	}
   116  
   117  	// Extract Q and check that it is orthogonal.
   118  	q := blas64.General{
   119  		Rows:   n,
   120  		Cols:   n,
   121  		Stride: n,
   122  		Data:   make([]float64, n*n),
   123  	}
   124  	for i := 0; i < q.Rows; i++ {
   125  		q.Data[i*q.Stride+i] = 1
   126  	}
   127  	qCopy := q
   128  	qCopy.Data = make([]float64, len(q.Data))
   129  	for j := ilo; j < ihi; j++ {
   130  		h := blas64.General{
   131  			Rows:   n,
   132  			Cols:   n,
   133  			Stride: n,
   134  			Data:   make([]float64, n*n),
   135  		}
   136  		for i := 0; i < h.Rows; i++ {
   137  			h.Data[i*h.Stride+i] = 1
   138  		}
   139  		v := blas64.Vector{
   140  			Inc:  1,
   141  			Data: make([]float64, n),
   142  		}
   143  		v.Data[j+1] = 1
   144  		for i := j + 2; i < ihi+1; i++ {
   145  			v.Data[i] = a.Data[i*a.Stride+j]
   146  		}
   147  		blas64.Ger(-tau[j], v, v, h)
   148  		copy(qCopy.Data, q.Data)
   149  		blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, qCopy, h, 0, q)
   150  	}
   151  	if resid := residualOrthogonal(q, false); resid > tol {
   152  		t.Errorf("%v: Q is not orthogonal; resid=%v, want<=%v", prefix, resid, tol)
   153  	}
   154  
   155  	// Overwrite NaN elements of aCopy with zeros
   156  	// (we will multiply with it below).
   157  	for i := 1; i <= ihi; i++ {
   158  		for j := 0; j < min(ilo, i); j++ {
   159  			aCopy.Data[i*aCopy.Stride+j] = 0
   160  		}
   161  	}
   162  	for i := ihi + 1; i < n; i++ {
   163  		for j := 0; j < i; j++ {
   164  			aCopy.Data[i*aCopy.Stride+j] = 0
   165  		}
   166  	}
   167  
   168  	// Construct Qᵀ * AOrig * Q and check that it is
   169  	// equal to A from Dgehd2.
   170  	aq := blas64.General{
   171  		Rows:   n,
   172  		Cols:   n,
   173  		Stride: n,
   174  		Data:   make([]float64, n*n),
   175  	}
   176  	blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, aCopy, q, 0, aq)
   177  	qaq := blas64.General{
   178  		Rows:   n,
   179  		Cols:   n,
   180  		Stride: n,
   181  		Data:   make([]float64, n*n),
   182  	}
   183  	blas64.Gemm(blas.Trans, blas.NoTrans, 1, q, aq, 0, qaq)
   184  	for i := ilo; i <= ihi; i++ {
   185  		for j := ilo; j <= ihi; j++ {
   186  			qaqij := qaq.Data[i*qaq.Stride+j]
   187  			if j < i-1 {
   188  				if math.Abs(qaqij) > tol {
   189  					t.Errorf("%v: Qᵀ*A*Q is not upper Hessenberg, [%v,%v]=%v", prefix, i, j, qaqij)
   190  				}
   191  				continue
   192  			}
   193  			diff := qaqij - a.Data[i*a.Stride+j]
   194  			if math.Abs(diff) > tol {
   195  				t.Errorf("%v: Qᵀ*AOrig*Q and A are not equal, diff at [%v,%v]=%v", prefix, i, j, diff)
   196  			}
   197  		}
   198  	}
   199  }