github.com/gonum/lapack@v0.0.0-20181123203213-e4cdc5a0bff9/testlapack/dlaqr5.go (about)

     1  // Copyright ©2016 The gonum Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package testlapack
     6  
     7  import (
     8  	"compress/gzip"
     9  	"encoding/json"
    10  	"fmt"
    11  	"log"
    12  	"math"
    13  	"math/rand"
    14  	"os"
    15  	"path/filepath"
    16  	"testing"
    17  
    18  	"github.com/gonum/blas"
    19  	"github.com/gonum/blas/blas64"
    20  )
    21  
    22  type Dlaqr5er interface {
    23  	Dlaqr5(wantt, wantz bool, kacc22 int, n, ktop, kbot, nshfts int, sr, si []float64, h []float64, ldh int, iloz, ihiz int, z []float64, ldz int, v []float64, ldv int, u []float64, ldu int, nh int, wh []float64, ldwh int, nv int, wv []float64, ldwv int)
    24  }
    25  
    26  type Dlaqr5test struct {
    27  	WantT          bool
    28  	N              int
    29  	NShifts        int
    30  	KTop, KBot     int
    31  	ShiftR, ShiftI []float64
    32  	H              []float64
    33  
    34  	HWant []float64
    35  	ZWant []float64
    36  }
    37  
    38  func Dlaqr5Test(t *testing.T, impl Dlaqr5er) {
    39  	// Test without using reference data.
    40  	rnd := rand.New(rand.NewSource(1))
    41  	for _, n := range []int{1, 2, 3, 4, 5, 6, 10, 30} {
    42  		for _, extra := range []int{0, 1, 20} {
    43  			for _, kacc22 := range []int{0, 1, 2} {
    44  				for cas := 0; cas < 100; cas++ {
    45  					testDlaqr5(t, impl, n, extra, kacc22, rnd)
    46  				}
    47  			}
    48  		}
    49  	}
    50  
    51  	// Test using reference data computed by the reference netlib
    52  	// implementation.
    53  	file, err := os.Open(filepath.FromSlash("../testlapack/testdata/dlaqr5data.json.gz"))
    54  	if err != nil {
    55  		log.Fatal(err)
    56  	}
    57  	defer file.Close()
    58  	r, err := gzip.NewReader(file)
    59  	if err != nil {
    60  		log.Fatal(err)
    61  	}
    62  	defer r.Close()
    63  
    64  	var tests []Dlaqr5test
    65  	json.NewDecoder(r).Decode(&tests)
    66  	for _, test := range tests {
    67  		wantt := test.WantT
    68  		n := test.N
    69  		nshfts := test.NShifts
    70  		ktop := test.KTop
    71  		kbot := test.KBot
    72  		sr := test.ShiftR
    73  		si := test.ShiftI
    74  
    75  		for _, extra := range []int{0, 1, 10} {
    76  			v := randomGeneral(nshfts/2, 3, 3+extra, rnd)
    77  			u := randomGeneral(3*nshfts-3, 3*nshfts-3, 3*nshfts-3+extra, rnd)
    78  			nh := n
    79  			wh := randomGeneral(3*nshfts-3, n, n+extra, rnd)
    80  			nv := n
    81  			wv := randomGeneral(n, 3*nshfts-3, 3*nshfts-3+extra, rnd)
    82  
    83  			h := nanGeneral(n, n, n+extra)
    84  
    85  			for _, kacc22 := range []int{0, 1, 2} {
    86  				copyMatrix(n, n, h.Data, h.Stride, test.H)
    87  				z := eye(n, n+extra)
    88  
    89  				impl.Dlaqr5(wantt, true, kacc22,
    90  					n, ktop, kbot,
    91  					nshfts, sr, si,
    92  					h.Data, h.Stride,
    93  					0, n-1, z.Data, z.Stride,
    94  					v.Data, v.Stride,
    95  					u.Data, u.Stride,
    96  					nv, wv.Data, wv.Stride,
    97  					nh, wh.Data, wh.Stride)
    98  
    99  				prefix := fmt.Sprintf("wantt=%v, n=%v, nshfts=%v, ktop=%v, kbot=%v, extra=%v, kacc22=%v",
   100  					wantt, n, nshfts, ktop, kbot, extra, kacc22)
   101  				if !equalApprox(n, n, h.Data, h.Stride, test.HWant, 1e-13) {
   102  					t.Errorf("Case %v: unexpected matrix H\nh    =%v\nhwant=%v", prefix, h.Data, test.HWant)
   103  				}
   104  				if !equalApprox(n, n, z.Data, z.Stride, test.ZWant, 1e-13) {
   105  					t.Errorf("Case %v: unexpected matrix Z\nz    =%v\nzwant=%v", prefix, z.Data, test.ZWant)
   106  				}
   107  			}
   108  		}
   109  	}
   110  }
   111  
   112  func testDlaqr5(t *testing.T, impl Dlaqr5er, n, extra, kacc22 int, rnd *rand.Rand) {
   113  	wantt := true
   114  	wantz := true
   115  	nshfts := 2 * n
   116  	sr := make([]float64, nshfts)
   117  	si := make([]float64, nshfts)
   118  	for i := 0; i < n; i++ {
   119  		re := rnd.NormFloat64()
   120  		im := rnd.NormFloat64()
   121  		sr[2*i], sr[2*i+1] = re, re
   122  		si[2*i], si[2*i+1] = im, -im
   123  	}
   124  	ktop := rnd.Intn(n)
   125  	kbot := rnd.Intn(n)
   126  	if kbot < ktop {
   127  		ktop, kbot = kbot, ktop
   128  	}
   129  
   130  	v := randomGeneral(nshfts/2, 3, 3+extra, rnd)
   131  	u := randomGeneral(3*nshfts-3, 3*nshfts-3, 3*nshfts-3+extra, rnd)
   132  	nh := n
   133  	wh := randomGeneral(3*nshfts-3, n, n+extra, rnd)
   134  	nv := n
   135  	wv := randomGeneral(n, 3*nshfts-3, 3*nshfts-3+extra, rnd)
   136  
   137  	h := randomHessenberg(n, n+extra, rnd)
   138  	if ktop > 0 {
   139  		h.Data[ktop*h.Stride+ktop-1] = 0
   140  	}
   141  	if kbot < n-1 {
   142  		h.Data[(kbot+1)*h.Stride+kbot] = 0
   143  	}
   144  	hCopy := h
   145  	hCopy.Data = make([]float64, len(h.Data))
   146  	copy(hCopy.Data, h.Data)
   147  
   148  	z := eye(n, n+extra)
   149  
   150  	impl.Dlaqr5(wantt, wantz, kacc22,
   151  		n, ktop, kbot,
   152  		nshfts, sr, si,
   153  		h.Data, h.Stride,
   154  		0, n-1, z.Data, z.Stride,
   155  		v.Data, v.Stride,
   156  		u.Data, u.Stride,
   157  		nv, wv.Data, wv.Stride,
   158  		nh, wh.Data, wh.Stride)
   159  
   160  	prefix := fmt.Sprintf("Case n=%v, extra=%v, kacc22=%v", n, extra, kacc22)
   161  
   162  	if !generalOutsideAllNaN(h) {
   163  		t.Errorf("%v: out-of-range write to H\n%v", prefix, h.Data)
   164  	}
   165  	if !generalOutsideAllNaN(z) {
   166  		t.Errorf("%v: out-of-range write to Z\n%v", prefix, z.Data)
   167  	}
   168  	if !generalOutsideAllNaN(u) {
   169  		t.Errorf("%v: out-of-range write to U\n%v", prefix, u.Data)
   170  	}
   171  	if !generalOutsideAllNaN(v) {
   172  		t.Errorf("%v: out-of-range write to V\n%v", prefix, v.Data)
   173  	}
   174  	if !generalOutsideAllNaN(wh) {
   175  		t.Errorf("%v: out-of-range write to WH\n%v", prefix, wh.Data)
   176  	}
   177  	if !generalOutsideAllNaN(wv) {
   178  		t.Errorf("%v: out-of-range write to WV\n%v", prefix, wv.Data)
   179  	}
   180  
   181  	for i := 0; i < n; i++ {
   182  		for j := 0; j < i-1; j++ {
   183  			if h.Data[i*h.Stride+j] != 0 {
   184  				t.Errorf("%v: H is not Hessenberg, H[%v,%v]!=0", prefix, i, j)
   185  			}
   186  		}
   187  	}
   188  	if !isOrthonormal(z) {
   189  		t.Errorf("%v: Z is not orthogonal", prefix)
   190  	}
   191  	// Construct Z^T * HOrig * Z and check that it is equal to H from Dlaqr5.
   192  	hz := blas64.General{
   193  		Rows:   n,
   194  		Cols:   n,
   195  		Stride: n,
   196  		Data:   make([]float64, n*n),
   197  	}
   198  	blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, hCopy, z, 0, hz)
   199  	zhz := blas64.General{
   200  		Rows:   n,
   201  		Cols:   n,
   202  		Stride: n,
   203  		Data:   make([]float64, n*n),
   204  	}
   205  	blas64.Gemm(blas.Trans, blas.NoTrans, 1, z, hz, 0, zhz)
   206  	for i := 0; i < n; i++ {
   207  		for j := 0; j < n; j++ {
   208  			diff := zhz.Data[i*zhz.Stride+j] - h.Data[i*h.Stride+j]
   209  			if math.Abs(diff) > 1e-13 {
   210  				t.Errorf("%v: Z^T*HOrig*Z and H are not equal, diff at [%v,%v]=%v", prefix, i, j, diff)
   211  			}
   212  		}
   213  	}
   214  }