github.com/jingcheng-WU/gonum@v0.9.1-0.20210323123734-f1a2a11a8f7b/lapack/testlapack/dgehrd.go (about)

     1  // Copyright ©2016 The Gonum Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package testlapack
     6  
     7  import (
     8  	"fmt"
     9  	"math"
    10  	"testing"
    11  
    12  	"golang.org/x/exp/rand"
    13  
    14  	"github.com/jingcheng-WU/gonum/blas"
    15  	"github.com/jingcheng-WU/gonum/blas/blas64"
    16  )
    17  
    18  type Dgehrder interface {
    19  	Dgehrd(n, ilo, ihi int, a []float64, lda int, tau, work []float64, lwork int)
    20  
    21  	Dorgqr(m, n, k int, a []float64, lda int, tau, work []float64, lwork int)
    22  }
    23  
    24  func DgehrdTest(t *testing.T, impl Dgehrder) {
    25  	rnd := rand.New(rand.NewSource(1))
    26  
    27  	// Randomized tests for small matrix sizes that will most likely
    28  	// use the unblocked algorithm.
    29  	for _, n := range []int{1, 2, 3, 4, 5, 10, 34} {
    30  		for _, extra := range []int{0, 13} {
    31  			for _, optwork := range []bool{true, false} {
    32  				for cas := 0; cas < 10; cas++ {
    33  					ilo := rnd.Intn(n)
    34  					ihi := rnd.Intn(n)
    35  					if ilo > ihi {
    36  						ilo, ihi = ihi, ilo
    37  					}
    38  					testDgehrd(t, impl, n, ilo, ihi, extra, optwork, rnd)
    39  				}
    40  			}
    41  		}
    42  	}
    43  
    44  	// These are selected tests for larger matrix sizes to test the blocked
    45  	// algorithm. Use sizes around several powers of two because that is
    46  	// where the blocked path will most likely start to be taken. For
    47  	// example, at present the blocked algorithm is used for sizes larger
    48  	// than 129.
    49  	for _, test := range []struct {
    50  		n, ilo, ihi int
    51  	}{
    52  		{0, 0, -1},
    53  
    54  		{68, 0, 63},
    55  		{68, 0, 64},
    56  		{68, 0, 65},
    57  		{68, 0, 66},
    58  		{68, 0, 67},
    59  
    60  		{132, 2, 129},
    61  		{132, 1, 129}, // Size = 129, unblocked.
    62  		{132, 0, 129}, // Size = 130, blocked.
    63  		{132, 1, 130},
    64  		{132, 0, 130},
    65  		{132, 1, 131},
    66  		{132, 0, 131},
    67  
    68  		{260, 2, 257},
    69  		{260, 1, 257},
    70  		{260, 0, 257},
    71  		{260, 0, 258},
    72  		{260, 0, 259},
    73  	} {
    74  		for _, extra := range []int{0, 13} {
    75  			for _, optwork := range []bool{true, false} {
    76  				testDgehrd(t, impl, test.n, test.ilo, test.ihi, extra, optwork, rnd)
    77  			}
    78  		}
    79  	}
    80  }
    81  
    82  func testDgehrd(t *testing.T, impl Dgehrder, n, ilo, ihi, extra int, optwork bool, rnd *rand.Rand) {
    83  	const tol = 1e-13
    84  
    85  	a := randomGeneral(n, n, n+extra, rnd)
    86  	aCopy := a
    87  	aCopy.Data = make([]float64, len(a.Data))
    88  	copy(aCopy.Data, a.Data)
    89  
    90  	var tau []float64
    91  	if n > 1 {
    92  		tau = nanSlice(n - 1)
    93  	}
    94  
    95  	var work []float64
    96  	if optwork {
    97  		work = nanSlice(1)
    98  		impl.Dgehrd(n, ilo, ihi, a.Data, a.Stride, tau, work, -1)
    99  		work = nanSlice(int(work[0]))
   100  	} else {
   101  		work = nanSlice(max(1, n))
   102  	}
   103  
   104  	impl.Dgehrd(n, ilo, ihi, a.Data, a.Stride, tau, work, len(work))
   105  
   106  	if n == 0 {
   107  		// Just make sure there is no panic.
   108  		return
   109  	}
   110  
   111  	prefix := fmt.Sprintf("Case n=%v, ilo=%v, ihi=%v, extra=%v", n, ilo, ihi, extra)
   112  
   113  	// Check any invalid modifications of a.
   114  	if !generalOutsideAllNaN(a) {
   115  		t.Errorf("%v: out-of-range write to A\n%v", prefix, a.Data)
   116  	}
   117  	for i := ilo; i <= ihi; i++ {
   118  		for j := 0; j < min(ilo, i); j++ {
   119  			if a.Data[i*a.Stride+j] != aCopy.Data[i*aCopy.Stride+j] {
   120  				t.Errorf("%v: unexpected modification of A[%v,%v]", prefix, i, j)
   121  			}
   122  		}
   123  	}
   124  	for i := ihi + 1; i < n; i++ {
   125  		for j := 0; j < i; j++ {
   126  			if a.Data[i*a.Stride+j] != aCopy.Data[i*aCopy.Stride+j] {
   127  				t.Errorf("%v: unexpected modification of A[%v,%v]", prefix, i, j)
   128  			}
   129  		}
   130  	}
   131  	for i := 0; i <= ilo; i++ {
   132  		for j := i; j < ilo+1; j++ {
   133  			if a.Data[i*a.Stride+j] != aCopy.Data[i*aCopy.Stride+j] {
   134  				t.Errorf("%v: unexpected modification at A[%v,%v]", prefix, i, j)
   135  			}
   136  		}
   137  		for j := ihi + 1; j < n; j++ {
   138  			if a.Data[i*a.Stride+j] != aCopy.Data[i*aCopy.Stride+j] {
   139  				t.Errorf("%v: unexpected modification at A[%v,%v]", prefix, i, j)
   140  			}
   141  		}
   142  	}
   143  	for i := ihi + 1; i < n; i++ {
   144  		for j := i; j < n; j++ {
   145  			if a.Data[i*a.Stride+j] != aCopy.Data[i*aCopy.Stride+j] {
   146  				t.Errorf("%v: unexpected modification at A[%v,%v]", prefix, i, j)
   147  			}
   148  		}
   149  	}
   150  
   151  	// Check that tau has been assigned properly.
   152  	for i, v := range tau {
   153  		if math.IsNaN(v) {
   154  			t.Errorf("%v: unexpected NaN at tau[%v]", prefix, i)
   155  		}
   156  	}
   157  
   158  	// Extract Q and check that it is orthogonal.
   159  	q := eye(n, n)
   160  	if ilo != ihi {
   161  		for i := ilo + 2; i <= ihi; i++ {
   162  			for j := ilo + 1; j < ihi; j++ {
   163  				q.Data[i*q.Stride+j] = a.Data[i*a.Stride+j-1]
   164  			}
   165  		}
   166  		nh := ihi - ilo
   167  		impl.Dorgqr(nh, nh, nh, q.Data[(ilo+1)*q.Stride+ilo+1:], q.Stride, tau[ilo:ihi], work, len(work))
   168  	}
   169  	if resid := residualOrthogonal(q, false); resid > tol {
   170  		t.Errorf("%v: Q is not orthogonal; resid=%v, want<=%v", prefix, resid, tol)
   171  	}
   172  
   173  	// Construct Qᵀ * AOrig * Q and check that it is upper Hessenberg.
   174  	aq := blas64.General{
   175  		Rows:   n,
   176  		Cols:   n,
   177  		Stride: n,
   178  		Data:   make([]float64, n*n),
   179  	}
   180  	blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, aCopy, q, 0, aq)
   181  	qaq := blas64.General{
   182  		Rows:   n,
   183  		Cols:   n,
   184  		Stride: n,
   185  		Data:   make([]float64, n*n),
   186  	}
   187  	blas64.Gemm(blas.Trans, blas.NoTrans, 1, q, aq, 0, qaq)
   188  	for i := 0; i <= ilo; i++ {
   189  		for j := ilo + 1; j <= ihi; j++ {
   190  			qaqij := qaq.Data[i*qaq.Stride+j]
   191  			diff := qaqij - a.Data[i*a.Stride+j]
   192  			if math.Abs(diff) > tol {
   193  				t.Errorf("%v: Qᵀ*AOrig*Q and A are not equal, diff at [%v,%v]=%v", prefix, i, j, diff)
   194  			}
   195  		}
   196  	}
   197  	for i := ilo + 1; i <= ihi; i++ {
   198  		for j := ilo; j < n; j++ {
   199  			qaqij := qaq.Data[i*qaq.Stride+j]
   200  			if j < i-1 {
   201  				if math.Abs(qaqij) > tol {
   202  					t.Errorf("%v: Qᵀ*AOrig*Q is not upper Hessenberg, [%v,%v]=%v", prefix, i, j, qaqij)
   203  				}
   204  				continue
   205  			}
   206  			diff := qaqij - a.Data[i*a.Stride+j]
   207  			if math.Abs(diff) > tol {
   208  				t.Errorf("%v: Qᵀ*AOrig*Q and A are not equal, diff at [%v,%v]=%v", prefix, i, j, diff)
   209  			}
   210  		}
   211  	}
   212  }