github.com/gonum/lapack@v0.0.0-20181123203213-e4cdc5a0bff9/testlapack/dorgql.go (about)

     1  // Copyright ©2016 The gonum Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package testlapack
     6  
     7  import (
     8  	"fmt"
     9  	"math/rand"
    10  	"testing"
    11  
    12  	"github.com/gonum/blas"
    13  	"github.com/gonum/blas/blas64"
    14  )
    15  
    16  type Dorgqler interface {
    17  	Dorgql(m, n, k int, a []float64, lda int, tau, work []float64, lwork int)
    18  
    19  	Dlarfger
    20  }
    21  
    22  func DorgqlTest(t *testing.T, impl Dorgqler) {
    23  	const tol = 1e-14
    24  
    25  	type Dorg2ler interface {
    26  		Dorg2l(m, n, k int, a []float64, lda int, tau, work []float64)
    27  	}
    28  	dorg2ler, hasDorg2l := impl.(Dorg2ler)
    29  
    30  	rnd := rand.New(rand.NewSource(1))
    31  	for _, m := range []int{0, 1, 2, 3, 4, 5, 7, 10, 15, 30, 50, 150} {
    32  		for _, extra := range []int{0, 11} {
    33  			for _, wl := range []worklen{minimumWork, mediumWork, optimumWork} {
    34  				var k int
    35  				if m >= 129 {
    36  					// For large matrices make sure that k
    37  					// is large enough to trigger blocked
    38  					// path.
    39  					k = 129 + rnd.Intn(m-129+1)
    40  				} else {
    41  					k = rnd.Intn(m + 1)
    42  				}
    43  				n := k + rnd.Intn(m-k+1)
    44  				if m == 0 || n == 0 {
    45  					m = 0
    46  					n = 0
    47  					k = 0
    48  				}
    49  
    50  				// Generate k elementary reflectors in the last
    51  				// k columns of A.
    52  				a := nanGeneral(m, n, n+extra)
    53  				tau := make([]float64, k)
    54  				for l := 0; l < k; l++ {
    55  					jj := m - k + l
    56  					v := randomSlice(jj, rnd)
    57  					_, tau[l] = impl.Dlarfg(len(v)+1, rnd.NormFloat64(), v, 1)
    58  					j := n - k + l
    59  					for i := 0; i < jj; i++ {
    60  						a.Data[i*a.Stride+j] = v[i]
    61  					}
    62  				}
    63  				aCopy := cloneGeneral(a)
    64  
    65  				// Compute the full matrix Q by forming the
    66  				// Householder reflectors explicitly.
    67  				q := eye(m, m)
    68  				qCopy := eye(m, m)
    69  				for l := 0; l < k; l++ {
    70  					h := eye(m, m)
    71  					jj := m - k + l
    72  					j := n - k + l
    73  					v := blas64.Vector{1, make([]float64, m)}
    74  					for i := 0; i < jj; i++ {
    75  						v.Data[i] = a.Data[i*a.Stride+j]
    76  					}
    77  					v.Data[jj] = 1
    78  					blas64.Ger(-tau[l], v, v, h)
    79  					copy(qCopy.Data, q.Data)
    80  					blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, h, qCopy, 0, q)
    81  				}
    82  				// View the last n columns of Q as 'want'.
    83  				want := blas64.General{
    84  					Rows:   m,
    85  					Cols:   n,
    86  					Stride: q.Stride,
    87  					Data:   q.Data[m-n:],
    88  				}
    89  
    90  				var lwork int
    91  				switch wl {
    92  				case minimumWork:
    93  					lwork = max(1, n)
    94  				case mediumWork:
    95  					work := make([]float64, 1)
    96  					impl.Dorgql(m, n, k, nil, a.Stride, nil, work, -1)
    97  					lwork = (int(work[0]) + n) / 2
    98  					lwork = max(1, lwork)
    99  				case optimumWork:
   100  					work := make([]float64, 1)
   101  					impl.Dorgql(m, n, k, nil, a.Stride, nil, work, -1)
   102  					lwork = int(work[0])
   103  				}
   104  				work := make([]float64, lwork)
   105  
   106  				// Compute the last n columns of Q by a call to
   107  				// Dorgql.
   108  				impl.Dorgql(m, n, k, a.Data, a.Stride, tau, work, len(work))
   109  
   110  				prefix := fmt.Sprintf("Case m=%v,n=%v,k=%v,wl=%v", m, n, k, wl)
   111  				if !generalOutsideAllNaN(a) {
   112  					t.Errorf("%v: out-of-range write to A", prefix)
   113  				}
   114  				if !equalApproxGeneral(want, a, tol) {
   115  					t.Errorf("%v: unexpected Q", prefix)
   116  				}
   117  
   118  				// Compute the last n columns of Q by a call to
   119  				// Dorg2l and check that we get the same result.
   120  				if !hasDorg2l {
   121  					continue
   122  				}
   123  				dorg2ler.Dorg2l(m, n, k, aCopy.Data, aCopy.Stride, tau, work)
   124  				if !equalApproxGeneral(aCopy, a, tol) {
   125  					t.Errorf("%v: mismatch between Dorgql and Dorg2l", prefix)
   126  				}
   127  			}
   128  		}
   129  	}
   130  }