gonum.org/v1/gonum@v0.14.0/lapack/testlapack/dlahr2.go (about)

     1  // Copyright ©2016 The Gonum Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package testlapack
     6  
     7  import (
     8  	"fmt"
     9  	"math"
    10  	"testing"
    11  
    12  	"golang.org/x/exp/rand"
    13  
    14  	"gonum.org/v1/gonum/blas"
    15  	"gonum.org/v1/gonum/blas/blas64"
    16  )
    17  
    18  type Dlahr2er interface {
    19  	Dlahr2(n, k, nb int, a []float64, lda int, tau, t []float64, ldt int, y []float64, ldy int)
    20  }
    21  
    22  func Dlahr2Test(t *testing.T, impl Dlahr2er) {
    23  	const tol = 1e-14
    24  
    25  	rnd := rand.New(rand.NewSource(1))
    26  	for _, test := range []struct {
    27  		n, k, nb int
    28  	}{
    29  		{3, 0, 3},
    30  		{3, 1, 2},
    31  		{3, 1, 1},
    32  
    33  		{5, 0, 5},
    34  		{5, 1, 4},
    35  		{5, 1, 3},
    36  		{5, 1, 2},
    37  		{5, 1, 1},
    38  		{5, 2, 3},
    39  		{5, 2, 2},
    40  		{5, 2, 1},
    41  		{5, 3, 2},
    42  		{5, 3, 1},
    43  
    44  		{7, 3, 4},
    45  		{7, 3, 3},
    46  		{7, 3, 2},
    47  		{7, 3, 1},
    48  
    49  		{10, 0, 10},
    50  		{10, 1, 9},
    51  		{10, 1, 5},
    52  		{10, 1, 1},
    53  		{10, 5, 5},
    54  		{10, 5, 3},
    55  		{10, 5, 1},
    56  	} {
    57  		for cas := 0; cas < 100; cas++ {
    58  			for _, extraStride := range []int{0, 1, 10} {
    59  				n := test.n
    60  				k := test.k
    61  				nb := test.nb
    62  
    63  				a := randomGeneral(n, n-k+1, n-k+1+extraStride, rnd)
    64  				aCopy := a
    65  				aCopy.Data = make([]float64, len(a.Data))
    66  				copy(aCopy.Data, a.Data)
    67  				tmat := nanTriangular(blas.Upper, nb, nb+extraStride)
    68  				y := nanGeneral(n, nb, nb+extraStride)
    69  				tau := nanSlice(nb)
    70  
    71  				impl.Dlahr2(n, k, nb, a.Data, a.Stride, tau, tmat.Data, tmat.Stride, y.Data, y.Stride)
    72  
    73  				prefix := fmt.Sprintf("Case n=%v, k=%v, nb=%v, ldex=%v", n, k, nb, extraStride)
    74  
    75  				if !generalOutsideAllNaN(a) {
    76  					t.Errorf("%v: out-of-range write to A\n%v", prefix, a.Data)
    77  				}
    78  				if !triangularOutsideAllNaN(tmat) {
    79  					t.Errorf("%v: out-of-range write to T\n%v", prefix, tmat.Data)
    80  				}
    81  				if !generalOutsideAllNaN(y) {
    82  					t.Errorf("%v: out-of-range write to Y\n%v", prefix, y.Data)
    83  				}
    84  
    85  				// Check that A[:k,:] and A[:,nb:] blocks were not modified.
    86  				for i := 0; i < n; i++ {
    87  					for j := 0; j < n-k+1; j++ {
    88  						if i >= k && j < nb {
    89  							continue
    90  						}
    91  						if a.Data[i*a.Stride+j] != aCopy.Data[i*aCopy.Stride+j] {
    92  							t.Errorf("%v: unexpected write to A[%v,%v]", prefix, i, j)
    93  						}
    94  					}
    95  				}
    96  
    97  				// Check that all elements of tau were assigned.
    98  				for i, v := range tau {
    99  					if math.IsNaN(v) {
   100  						t.Errorf("%v: tau[%v] not assigned", prefix, i)
   101  					}
   102  				}
   103  
   104  				// Extract V from a.
   105  				v := blas64.General{
   106  					Rows:   n - k + 1,
   107  					Cols:   nb,
   108  					Stride: nb,
   109  					Data:   make([]float64, (n-k+1)*nb),
   110  				}
   111  				for j := 0; j < v.Cols; j++ {
   112  					v.Data[(j+1)*v.Stride+j] = 1
   113  					for i := j + 2; i < v.Rows; i++ {
   114  						v.Data[i*v.Stride+j] = a.Data[(i+k-1)*a.Stride+j]
   115  					}
   116  				}
   117  
   118  				// VT = V.
   119  				vt := v
   120  				vt.Data = make([]float64, len(v.Data))
   121  				copy(vt.Data, v.Data)
   122  				// VT = V * T.
   123  				blas64.Trmm(blas.Right, blas.NoTrans, 1, tmat, vt)
   124  				// YWant = A * V * T.
   125  				ywant := blas64.General{
   126  					Rows:   n,
   127  					Cols:   nb,
   128  					Stride: nb,
   129  					Data:   make([]float64, n*nb),
   130  				}
   131  				blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, aCopy, vt, 0, ywant)
   132  
   133  				// Compare Y and YWant.
   134  				for i := 0; i < n; i++ {
   135  					for j := 0; j < nb; j++ {
   136  						diff := math.Abs(ywant.Data[i*ywant.Stride+j] - y.Data[i*y.Stride+j])
   137  						if diff > tol {
   138  							t.Errorf("%v: unexpected Y[%v,%v], diff=%v", prefix, i, j, diff)
   139  						}
   140  					}
   141  				}
   142  
   143  				// Construct Q directly from the first nb columns of a.
   144  				q := constructQ("QR", n-k, nb, a.Data[k*a.Stride:], a.Stride, tau)
   145  				if resid := residualOrthogonal(q, false); resid > tol*float64(n) {
   146  					t.Errorf("Case %v: Q is not orthogonal; resid=%v, want<=%v", prefix, resid, tol*float64(n))
   147  				}
   148  				// Construct Q as the product Q = I - V*T*Vᵀ.
   149  				qwant := blas64.General{
   150  					Rows:   n - k + 1,
   151  					Cols:   n - k + 1,
   152  					Stride: n - k + 1,
   153  					Data:   make([]float64, (n-k+1)*(n-k+1)),
   154  				}
   155  				for i := 0; i < qwant.Rows; i++ {
   156  					qwant.Data[i*qwant.Stride+i] = 1
   157  				}
   158  				blas64.Gemm(blas.NoTrans, blas.Trans, -1, vt, v, 1, qwant)
   159  				if resid := residualOrthogonal(qwant, false); resid > tol*float64(n) {
   160  					t.Errorf("Case %v: Q = I - V*T*Vᵀ is not orthogonal; resid=%v, want<=%v", prefix, resid, tol*float64(n))
   161  				}
   162  
   163  				// Compare Q and QWant. Note that since Q is
   164  				// (n-k)×(n-k) and QWant is (n-k+1)×(n-k+1), we
   165  				// ignore the first row and column of QWant.
   166  				for i := 0; i < n-k; i++ {
   167  					for j := 0; j < n-k; j++ {
   168  						diff := math.Abs(q.Data[i*q.Stride+j] - qwant.Data[(i+1)*qwant.Stride+j+1])
   169  						if diff > tol {
   170  							t.Errorf("%v: unexpected Q[%v,%v], diff=%v", prefix, i, j, diff)
   171  						}
   172  					}
   173  				}
   174  			}
   175  		}
   176  	}
   177  }