github.com/gonum/lapack@v0.0.0-20181123203213-e4cdc5a0bff9/testlapack/dlahr2.go (about)

     1  // Copyright ©2016 The gonum Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package testlapack
     6  
     7  import (
     8  	"compress/gzip"
     9  	"encoding/json"
    10  	"fmt"
    11  	"log"
    12  	"math"
    13  	"math/rand"
    14  	"os"
    15  	"path/filepath"
    16  	"testing"
    17  
    18  	"github.com/gonum/blas"
    19  	"github.com/gonum/blas/blas64"
    20  	"github.com/gonum/floats"
    21  )
    22  
    23  type Dlahr2er interface {
    24  	Dlahr2(n, k, nb int, a []float64, lda int, tau, t []float64, ldt int, y []float64, ldy int)
    25  }
    26  
    27  type Dlahr2test struct {
    28  	N, K, NB int
    29  	A        []float64
    30  
    31  	AWant   []float64
    32  	TWant   []float64
    33  	YWant   []float64
    34  	TauWant []float64
    35  }
    36  
    37  func Dlahr2Test(t *testing.T, impl Dlahr2er) {
    38  	rnd := rand.New(rand.NewSource(1))
    39  	for _, test := range []struct {
    40  		n, k, nb int
    41  	}{
    42  		{3, 0, 3},
    43  		{3, 1, 2},
    44  		{3, 1, 1},
    45  
    46  		{5, 0, 5},
    47  		{5, 1, 4},
    48  		{5, 1, 3},
    49  		{5, 1, 2},
    50  		{5, 1, 1},
    51  		{5, 2, 3},
    52  		{5, 2, 2},
    53  		{5, 2, 1},
    54  		{5, 3, 2},
    55  		{5, 3, 1},
    56  
    57  		{7, 3, 4},
    58  		{7, 3, 3},
    59  		{7, 3, 2},
    60  		{7, 3, 1},
    61  
    62  		{10, 0, 10},
    63  		{10, 1, 9},
    64  		{10, 1, 5},
    65  		{10, 1, 1},
    66  		{10, 5, 5},
    67  		{10, 5, 3},
    68  		{10, 5, 1},
    69  	} {
    70  		for cas := 0; cas < 100; cas++ {
    71  			for _, extraStride := range []int{0, 1, 10} {
    72  				n := test.n
    73  				k := test.k
    74  				nb := test.nb
    75  
    76  				a := randomGeneral(n, n-k+1, n-k+1+extraStride, rnd)
    77  				aCopy := a
    78  				aCopy.Data = make([]float64, len(a.Data))
    79  				copy(aCopy.Data, a.Data)
    80  				tmat := nanTriangular(blas.Upper, nb, nb+extraStride)
    81  				y := nanGeneral(n, nb, nb+extraStride)
    82  				tau := nanSlice(nb)
    83  
    84  				impl.Dlahr2(n, k, nb, a.Data, a.Stride, tau, tmat.Data, tmat.Stride, y.Data, y.Stride)
    85  
    86  				prefix := fmt.Sprintf("Case n=%v, k=%v, nb=%v, ldex=%v", n, k, nb, extraStride)
    87  
    88  				if !generalOutsideAllNaN(a) {
    89  					t.Errorf("%v: out-of-range write to A\n%v", prefix, a.Data)
    90  				}
    91  				if !triangularOutsideAllNaN(tmat) {
    92  					t.Errorf("%v: out-of-range write to T\n%v", prefix, tmat.Data)
    93  				}
    94  				if !generalOutsideAllNaN(y) {
    95  					t.Errorf("%v: out-of-range write to Y\n%v", prefix, y.Data)
    96  				}
    97  
    98  				// Check that A[:k,:] and A[:,nb:] blocks were not modified.
    99  				for i := 0; i < n; i++ {
   100  					for j := 0; j < n-k+1; j++ {
   101  						if i >= k && j < nb {
   102  							continue
   103  						}
   104  						if a.Data[i*a.Stride+j] != aCopy.Data[i*aCopy.Stride+j] {
   105  							t.Errorf("%v: unexpected write to A[%v,%v]", prefix, i, j)
   106  						}
   107  					}
   108  				}
   109  
   110  				// Check that all elements of tau were assigned.
   111  				for i, v := range tau {
   112  					if math.IsNaN(v) {
   113  						t.Errorf("%v: tau[%v] not assigned", prefix, i)
   114  					}
   115  				}
   116  
   117  				// Extract V from a.
   118  				v := blas64.General{
   119  					Rows:   n - k + 1,
   120  					Cols:   nb,
   121  					Stride: nb,
   122  					Data:   make([]float64, (n-k+1)*nb),
   123  				}
   124  				for j := 0; j < v.Cols; j++ {
   125  					v.Data[(j+1)*v.Stride+j] = 1
   126  					for i := j + 2; i < v.Rows; i++ {
   127  						v.Data[i*v.Stride+j] = a.Data[(i+k-1)*a.Stride+j]
   128  					}
   129  				}
   130  
   131  				// VT = V.
   132  				vt := v
   133  				vt.Data = make([]float64, len(v.Data))
   134  				copy(vt.Data, v.Data)
   135  				// VT = V * T.
   136  				blas64.Trmm(blas.Right, blas.NoTrans, 1, tmat, vt)
   137  				// YWant = A * V * T.
   138  				ywant := blas64.General{
   139  					Rows:   n,
   140  					Cols:   nb,
   141  					Stride: nb,
   142  					Data:   make([]float64, n*nb),
   143  				}
   144  				blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, aCopy, vt, 0, ywant)
   145  
   146  				// Compare Y and YWant.
   147  				for i := 0; i < n; i++ {
   148  					for j := 0; j < nb; j++ {
   149  						diff := math.Abs(ywant.Data[i*ywant.Stride+j] - y.Data[i*y.Stride+j])
   150  						if diff > 1e-14 {
   151  							t.Errorf("%v: unexpected Y[%v,%v], diff=%v", prefix, i, j, diff)
   152  						}
   153  					}
   154  				}
   155  
   156  				// Construct Q directly from the first nb columns of a.
   157  				q := constructQ("QR", n-k, nb, a.Data[k*a.Stride:], a.Stride, tau)
   158  				if !isOrthonormal(q) {
   159  					t.Errorf("%v: Q is not orthogonal", prefix)
   160  				}
   161  				// Construct Q as the product Q = I - V*T*V^T.
   162  				qwant := blas64.General{
   163  					Rows:   n - k + 1,
   164  					Cols:   n - k + 1,
   165  					Stride: n - k + 1,
   166  					Data:   make([]float64, (n-k+1)*(n-k+1)),
   167  				}
   168  				for i := 0; i < qwant.Rows; i++ {
   169  					qwant.Data[i*qwant.Stride+i] = 1
   170  				}
   171  				blas64.Gemm(blas.NoTrans, blas.Trans, -1, vt, v, 1, qwant)
   172  				if !isOrthonormal(qwant) {
   173  					t.Errorf("%v: Q = I - V*T*V^T is not orthogonal", prefix)
   174  				}
   175  
   176  				// Compare Q and QWant. Note that since Q is
   177  				// (n-k)×(n-k) and QWant is (n-k+1)×(n-k+1), we
   178  				// ignore the first row and column of QWant.
   179  				for i := 0; i < n-k; i++ {
   180  					for j := 0; j < n-k; j++ {
   181  						diff := math.Abs(q.Data[i*q.Stride+j] - qwant.Data[(i+1)*qwant.Stride+j+1])
   182  						if diff > 1e-14 {
   183  							t.Errorf("%v: unexpected Q[%v,%v], diff=%v", prefix, i, j, diff)
   184  						}
   185  					}
   186  				}
   187  			}
   188  		}
   189  	}
   190  
   191  	// Go runs tests from the source directory, so unfortunately we need to
   192  	// include the "../testlapack" part.
   193  	file, err := os.Open(filepath.FromSlash("../testlapack/testdata/dlahr2data.json.gz"))
   194  	if err != nil {
   195  		log.Fatal(err)
   196  	}
   197  	defer file.Close()
   198  	r, err := gzip.NewReader(file)
   199  	if err != nil {
   200  		log.Fatal(err)
   201  	}
   202  	defer r.Close()
   203  
   204  	var tests []Dlahr2test
   205  	json.NewDecoder(r).Decode(&tests)
   206  	for _, test := range tests {
   207  		tau := make([]float64, len(test.TauWant))
   208  		for _, ldex := range []int{0, 1, 20} {
   209  			n := test.N
   210  			k := test.K
   211  			nb := test.NB
   212  
   213  			lda := n - k + 1 + ldex
   214  			a := make([]float64, (n-1)*lda+n-k+1)
   215  			copyMatrix(n, n-k+1, a, lda, test.A)
   216  
   217  			ldt := nb + ldex
   218  			tmat := make([]float64, (nb-1)*ldt+nb)
   219  
   220  			ldy := nb + ldex
   221  			y := make([]float64, (n-1)*ldy+nb)
   222  
   223  			impl.Dlahr2(n, k, nb, a, lda, tau, tmat, ldt, y, ldy)
   224  
   225  			prefix := fmt.Sprintf("Case n=%v, k=%v, nb=%v, ldex=%v", n, k, nb, ldex)
   226  			if !equalApprox(n, n-k+1, a, lda, test.AWant, 1e-14) {
   227  				t.Errorf("%v: unexpected matrix A\n got=%v\nwant=%v", prefix, a, test.AWant)
   228  			}
   229  			if !equalApproxTriangular(true, nb, tmat, ldt, test.TWant, 1e-14) {
   230  				t.Errorf("%v: unexpected matrix T\n got=%v\nwant=%v", prefix, tmat, test.TWant)
   231  			}
   232  			if !equalApprox(n, nb, y, ldy, test.YWant, 1e-14) {
   233  				t.Errorf("%v: unexpected matrix Y\n got=%v\nwant=%v", prefix, y, test.YWant)
   234  			}
   235  			if !floats.EqualApprox(tau, test.TauWant, 1e-14) {
   236  				t.Errorf("%v: unexpected slice tau\n got=%v\nwant=%v", prefix, tau, test.TauWant)
   237  			}
   238  		}
   239  	}
   240  }