github.com/gonum/lapack@v0.0.0-20181123203213-e4cdc5a0bff9/testlapack/dgerqf.go (about)

     1  // Copyright ©2015 The gonum Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package testlapack
     6  
     7  import (
     8  	"math"
     9  	"math/rand"
    10  	"testing"
    11  
    12  	"github.com/gonum/blas"
    13  	"github.com/gonum/blas/blas64"
    14  )
    15  
    16  type Dgerqfer interface {
    17  	Dgerqf(m, n int, a []float64, lda int, tau, work []float64, lwork int)
    18  }
    19  
    20  func DgerqfTest(t *testing.T, impl Dgerqfer) {
    21  	rnd := rand.New(rand.NewSource(1))
    22  	for c, test := range []struct {
    23  		m, n, lda int
    24  	}{
    25  		{1, 1, 0},
    26  		{2, 2, 0},
    27  		{3, 2, 0},
    28  		{2, 3, 0},
    29  		{1, 12, 0},
    30  		{2, 6, 0},
    31  		{3, 4, 0},
    32  		{4, 3, 0},
    33  		{6, 2, 0},
    34  		{12, 1, 0},
    35  		{1, 1, 20},
    36  		{2, 2, 20},
    37  		{3, 2, 20},
    38  		{2, 3, 20},
    39  		{1, 12, 20},
    40  		{2, 6, 20},
    41  		{3, 4, 20},
    42  		{4, 3, 20},
    43  		{6, 2, 20},
    44  		{12, 1, 20},
    45  	} {
    46  		n := test.n
    47  		m := test.m
    48  		lda := test.lda
    49  		if lda == 0 {
    50  			lda = test.n
    51  		}
    52  		a := make([]float64, m*lda)
    53  		for i := range a {
    54  			a[i] = rnd.Float64()
    55  		}
    56  		aCopy := make([]float64, len(a))
    57  		copy(aCopy, a)
    58  		k := min(m, n)
    59  		tau := make([]float64, k)
    60  		for i := range tau {
    61  			tau[i] = rnd.Float64()
    62  		}
    63  		work := []float64{0}
    64  		impl.Dgerqf(m, n, a, lda, tau, work, -1)
    65  		lwkopt := int(work[0])
    66  		for _, wk := range []struct {
    67  			name   string
    68  			length int
    69  		}{
    70  			{name: "short", length: m},
    71  			{name: "medium", length: lwkopt - 1},
    72  			{name: "long", length: lwkopt},
    73  		} {
    74  			if wk.length < max(1, m) {
    75  				continue
    76  			}
    77  			lwork := wk.length
    78  			work = make([]float64, lwork)
    79  			for i := range work {
    80  				work[i] = rnd.Float64()
    81  			}
    82  			copy(a, aCopy)
    83  			impl.Dgerqf(m, n, a, lda, tau, work, lwork)
    84  
    85  			// Test that the RQ factorization has completed successfully. Compute
    86  			// Q based on the vectors.
    87  			q := constructQ("RQ", m, n, a, lda, tau)
    88  
    89  			// Check that q is orthonormal
    90  			for i := 0; i < q.Rows; i++ {
    91  				nrm := blas64.Nrm2(q.Cols, blas64.Vector{Inc: 1, Data: q.Data[i*q.Stride:]})
    92  				if math.IsNaN(nrm) || math.Abs(nrm-1) > 1e-14 {
    93  					t.Errorf("Case %v, q not normal", c)
    94  				}
    95  				for j := 0; j < i; j++ {
    96  					dot := blas64.Dot(q.Cols, blas64.Vector{Inc: 1, Data: q.Data[i*q.Stride:]}, blas64.Vector{Inc: 1, Data: q.Data[j*q.Stride:]})
    97  					if math.IsNaN(dot) || math.Abs(dot) > 1e-14 {
    98  						t.Errorf("Case %v, q not orthogonal", c)
    99  					}
   100  				}
   101  			}
   102  			// Check that A = R * Q
   103  			r := blas64.General{
   104  				Rows:   m,
   105  				Cols:   n,
   106  				Stride: n,
   107  				Data:   make([]float64, m*n),
   108  			}
   109  			for i := 0; i < m; i++ {
   110  				off := m - n
   111  				for j := max(0, i-off); j < n; j++ {
   112  					r.Data[i*r.Stride+j] = a[i*lda+j]
   113  				}
   114  			}
   115  
   116  			got := blas64.General{
   117  				Rows:   m,
   118  				Cols:   n,
   119  				Stride: lda,
   120  				Data:   make([]float64, m*lda),
   121  			}
   122  			blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, r, q, 0, got)
   123  			want := blas64.General{
   124  				Rows:   m,
   125  				Cols:   n,
   126  				Stride: lda,
   127  				Data:   aCopy,
   128  			}
   129  			if !equalApproxGeneral(got, want, 1e-14) {
   130  				t.Errorf("Case %d, R*Q != a %s\ngot: %+v\nwant:%+v", c, wk.name, got, want)
   131  			}
   132  		}
   133  	}
   134  }