github.com/gonum/lapack@v0.0.0-20181123203213-e4cdc5a0bff9/testlapack/dgehd2.go (about)

     1  // Copyright ©2016 The gonum Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package testlapack
     6  
     7  import (
     8  	"fmt"
     9  	"math"
    10  	"math/rand"
    11  	"testing"
    12  
    13  	"github.com/gonum/blas"
    14  	"github.com/gonum/blas/blas64"
    15  )
    16  
    17  type Dgehd2er interface {
    18  	Dgehd2(n, ilo, ihi int, a []float64, lda int, tau, work []float64)
    19  }
    20  
    21  func Dgehd2Test(t *testing.T, impl Dgehd2er) {
    22  	rnd := rand.New(rand.NewSource(1))
    23  	for _, n := range []int{1, 2, 3, 4, 5, 7, 10, 30} {
    24  		for _, extra := range []int{0, 1, 13} {
    25  			for cas := 0; cas < 100; cas++ {
    26  				testDgehd2(t, impl, n, extra, rnd)
    27  			}
    28  		}
    29  	}
    30  }
    31  
    32  func testDgehd2(t *testing.T, impl Dgehd2er, n, extra int, rnd *rand.Rand) {
    33  	ilo := rnd.Intn(n)
    34  	ihi := rnd.Intn(n)
    35  	if ilo > ihi {
    36  		ilo, ihi = ihi, ilo
    37  	}
    38  
    39  	tau := nanSlice(n - 1)
    40  	work := nanSlice(n)
    41  
    42  	a := randomGeneral(n, n, n+extra, rnd)
    43  	// NaN out elements under the diagonal except
    44  	// for the [ilo:ihi,ilo:ihi] block.
    45  	for i := 1; i <= ihi; i++ {
    46  		for j := 0; j < min(ilo, i); j++ {
    47  			a.Data[i*a.Stride+j] = math.NaN()
    48  		}
    49  	}
    50  	for i := ihi + 1; i < n; i++ {
    51  		for j := 0; j < i; j++ {
    52  			a.Data[i*a.Stride+j] = math.NaN()
    53  		}
    54  	}
    55  	aCopy := a
    56  	aCopy.Data = make([]float64, len(a.Data))
    57  	copy(aCopy.Data, a.Data)
    58  
    59  	impl.Dgehd2(n, ilo, ihi, a.Data, a.Stride, tau, work)
    60  
    61  	prefix := fmt.Sprintf("Case n=%v, ilo=%v, ihi=%v, extra=%v", n, ilo, ihi, extra)
    62  
    63  	// Check any invalid modifications of a.
    64  	if !generalOutsideAllNaN(a) {
    65  		t.Errorf("%v: out-of-range write to A\n%v", prefix, a.Data)
    66  	}
    67  	for i := ilo; i <= ihi; i++ {
    68  		for j := 0; j < min(ilo, i); j++ {
    69  			if !math.IsNaN(a.Data[i*a.Stride+j]) {
    70  				t.Errorf("%v: expected NaN at A[%v,%v]", prefix, i, j)
    71  			}
    72  		}
    73  	}
    74  	for i := ihi + 1; i < n; i++ {
    75  		for j := 0; j < i; j++ {
    76  			if !math.IsNaN(a.Data[i*a.Stride+j]) {
    77  				t.Errorf("%v: expected NaN at A[%v,%v]", prefix, i, j)
    78  			}
    79  		}
    80  	}
    81  	for i := 0; i <= ilo; i++ {
    82  		for j := i; j < ilo+1; j++ {
    83  			if a.Data[i*a.Stride+j] != aCopy.Data[i*aCopy.Stride+j] {
    84  				t.Errorf("%v: unexpected modification at A[%v,%v]", prefix, i, j)
    85  			}
    86  		}
    87  		for j := ihi + 1; j < n; j++ {
    88  			if a.Data[i*a.Stride+j] != aCopy.Data[i*aCopy.Stride+j] {
    89  				t.Errorf("%v: unexpected modification at A[%v,%v]", prefix, i, j)
    90  			}
    91  		}
    92  	}
    93  	for i := ihi + 1; i < n; i++ {
    94  		for j := i; j < n; j++ {
    95  			if a.Data[i*a.Stride+j] != aCopy.Data[i*aCopy.Stride+j] {
    96  				t.Errorf("%v: unexpected modification at A[%v,%v]", prefix, i, j)
    97  			}
    98  		}
    99  	}
   100  
   101  	// Check that tau has been assigned properly.
   102  	for i, v := range tau {
   103  		if i < ilo || i >= ihi {
   104  			if !math.IsNaN(v) {
   105  				t.Errorf("%v: expected NaN at tau[%v]", prefix, i)
   106  			}
   107  		} else {
   108  			if math.IsNaN(v) {
   109  				t.Errorf("%v: unexpected NaN at tau[%v]", prefix, i)
   110  			}
   111  		}
   112  	}
   113  
   114  	// Extract Q and check that it is orthogonal.
   115  	q := blas64.General{
   116  		Rows:   n,
   117  		Cols:   n,
   118  		Stride: n,
   119  		Data:   make([]float64, n*n),
   120  	}
   121  	for i := 0; i < q.Rows; i++ {
   122  		q.Data[i*q.Stride+i] = 1
   123  	}
   124  	qCopy := q
   125  	qCopy.Data = make([]float64, len(q.Data))
   126  	for j := ilo; j < ihi; j++ {
   127  		h := blas64.General{
   128  			Rows:   n,
   129  			Cols:   n,
   130  			Stride: n,
   131  			Data:   make([]float64, n*n),
   132  		}
   133  		for i := 0; i < h.Rows; i++ {
   134  			h.Data[i*h.Stride+i] = 1
   135  		}
   136  		v := blas64.Vector{
   137  			Inc:  1,
   138  			Data: make([]float64, n),
   139  		}
   140  		v.Data[j+1] = 1
   141  		for i := j + 2; i < ihi+1; i++ {
   142  			v.Data[i] = a.Data[i*a.Stride+j]
   143  		}
   144  		blas64.Ger(-tau[j], v, v, h)
   145  		copy(qCopy.Data, q.Data)
   146  		blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, qCopy, h, 0, q)
   147  	}
   148  	if !isOrthonormal(q) {
   149  		t.Errorf("%v: Q is not orthogonal\nQ=%v", prefix, q)
   150  	}
   151  
   152  	// Overwrite NaN elements of aCopy with zeros
   153  	// (we will multiply with it below).
   154  	for i := 1; i <= ihi; i++ {
   155  		for j := 0; j < min(ilo, i); j++ {
   156  			aCopy.Data[i*aCopy.Stride+j] = 0
   157  		}
   158  	}
   159  	for i := ihi + 1; i < n; i++ {
   160  		for j := 0; j < i; j++ {
   161  			aCopy.Data[i*aCopy.Stride+j] = 0
   162  		}
   163  	}
   164  
   165  	// Construct Q^T * AOrig * Q and check that it is
   166  	// equal to A from Dgehd2.
   167  	aq := blas64.General{
   168  		Rows:   n,
   169  		Cols:   n,
   170  		Stride: n,
   171  		Data:   make([]float64, n*n),
   172  	}
   173  	blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, aCopy, q, 0, aq)
   174  	qaq := blas64.General{
   175  		Rows:   n,
   176  		Cols:   n,
   177  		Stride: n,
   178  		Data:   make([]float64, n*n),
   179  	}
   180  	blas64.Gemm(blas.Trans, blas.NoTrans, 1, q, aq, 0, qaq)
   181  	for i := ilo; i <= ihi; i++ {
   182  		for j := ilo; j <= ihi; j++ {
   183  			qaqij := qaq.Data[i*qaq.Stride+j]
   184  			if j < i-1 {
   185  				if math.Abs(qaqij) > 1e-14 {
   186  					t.Errorf("%v: Q^T*A*Q is not upper Hessenberg, [%v,%v]=%v", prefix, i, j, qaqij)
   187  				}
   188  				continue
   189  			}
   190  			diff := qaqij - a.Data[i*a.Stride+j]
   191  			if math.Abs(diff) > 1e-14 {
   192  				t.Errorf("%v: Q^T*AOrig*Q and A are not equal, diff at [%v,%v]=%v", prefix, i, j, diff)
   193  			}
   194  		}
   195  	}
   196  }