github.com/gonum/lapack@v0.0.0-20181123203213-e4cdc5a0bff9/testlapack/dgerq2.go (about)

     1  // Copyright ©2015 The gonum Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package testlapack
     6  
     7  import (
     8  	"math"
     9  	"math/rand"
    10  	"testing"
    11  
    12  	"github.com/gonum/blas"
    13  	"github.com/gonum/blas/blas64"
    14  )
    15  
    // Dgerq2er is the interface a Dgerq2 implementation must satisfy to be
    // checked by Dgerq2Test.
    //
    // As exercised by Dgerq2Test, Dgerq2 factors the m×n row-major matrix held
    // in a (element (i, j) at a[i*lda+j]) so that A = R*Q. On return a holds R
    // in its upper-trapezoidal part plus the vectors from which Q is rebuilt.
    // tau (length min(m, n) in the test) receives the scalars that, together
    // with those vectors, determine Q. work is scratch space; the test supplies
    // len(work) == m.
    type Dgerq2er interface {
    	Dgerq2(m, n int, a []float64, lda int, tau, work []float64)
    }
    19  
    20  func Dgerq2Test(t *testing.T, impl Dgerq2er) {
    21  	rnd := rand.New(rand.NewSource(1))
    22  	for c, test := range []struct {
    23  		m, n, lda int
    24  	}{
    25  		{1, 1, 0},
    26  		{2, 2, 0},
    27  		{3, 2, 0},
    28  		{2, 3, 0},
    29  		{1, 12, 0},
    30  		{2, 6, 0},
    31  		{3, 4, 0},
    32  		{4, 3, 0},
    33  		{6, 2, 0},
    34  		{12, 1, 0},
    35  		{1, 1, 20},
    36  		{2, 2, 20},
    37  		{3, 2, 20},
    38  		{2, 3, 20},
    39  		{1, 12, 20},
    40  		{2, 6, 20},
    41  		{3, 4, 20},
    42  		{4, 3, 20},
    43  		{6, 2, 20},
    44  		{12, 1, 20},
    45  	} {
    46  		n := test.n
    47  		m := test.m
    48  		lda := test.lda
    49  		if lda == 0 {
    50  			lda = test.n
    51  		}
    52  		a := make([]float64, m*lda)
    53  		for i := range a {
    54  			a[i] = rnd.Float64()
    55  		}
    56  		aCopy := make([]float64, len(a))
    57  		k := min(m, n)
    58  		tau := make([]float64, k)
    59  		for i := range tau {
    60  			tau[i] = rnd.Float64()
    61  		}
    62  		work := make([]float64, m)
    63  		for i := range work {
    64  			work[i] = rnd.Float64()
    65  		}
    66  		copy(aCopy, a)
    67  		impl.Dgerq2(m, n, a, lda, tau, work)
    68  
    69  		// Test that the RQ factorization has completed successfully. Compute
    70  		// Q based on the vectors.
    71  		q := constructQ("RQ", m, n, a, lda, tau)
    72  
    73  		// Check that q is orthonormal
    74  		for i := 0; i < q.Rows; i++ {
    75  			nrm := blas64.Nrm2(q.Cols, blas64.Vector{Inc: 1, Data: q.Data[i*q.Stride:]})
    76  			if math.IsNaN(nrm) || math.Abs(nrm-1) > 1e-14 {
    77  				t.Errorf("Case %v, q not normal", c)
    78  			}
    79  			for j := 0; j < i; j++ {
    80  				dot := blas64.Dot(q.Cols, blas64.Vector{Inc: 1, Data: q.Data[i*q.Stride:]}, blas64.Vector{Inc: 1, Data: q.Data[j*q.Stride:]})
    81  				if math.IsNaN(dot) || math.Abs(dot) > 1e-14 {
    82  					t.Errorf("Case %v, q not orthogonal", c)
    83  				}
    84  			}
    85  		}
    86  		// Check that A = R * Q
    87  		r := blas64.General{
    88  			Rows:   m,
    89  			Cols:   n,
    90  			Stride: n,
    91  			Data:   make([]float64, m*n),
    92  		}
    93  		for i := 0; i < m; i++ {
    94  			off := m - n
    95  			for j := max(0, i-off); j < n; j++ {
    96  				r.Data[i*r.Stride+j] = a[i*lda+j]
    97  			}
    98  		}
    99  
   100  		got := blas64.General{
   101  			Rows:   m,
   102  			Cols:   n,
   103  			Stride: lda,
   104  			Data:   make([]float64, m*lda),
   105  		}
   106  		blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, r, q, 0, got)
   107  		want := blas64.General{
   108  			Rows:   m,
   109  			Cols:   n,
   110  			Stride: lda,
   111  			Data:   aCopy,
   112  		}
   113  		if !equalApproxGeneral(got, want, 1e-14) {
   114  			t.Errorf("Case %d, R*Q != a\ngot: %+v\nwant:%+v", c, got, want)
   115  		}
   116  	}
   117  }