github.com/gonum/lapack@v0.0.0-20181123203213-e4cdc5a0bff9/testlapack/dgeqrf.go (about)

     1  // Copyright ©2015 The gonum Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package testlapack
     6  
     7  import (
     8  	"math/rand"
     9  	"testing"
    10  
    11  	"github.com/gonum/floats"
    12  )
    13  
// Dgeqrfer is the set of methods needed by DgeqrfTest. It embeds Dgeqr2er,
// whose Dgeqr2 method DgeqrfTest uses to compute the unblocked QR reference
// result, and adds Dgeqrf, the blocked QR routine under test.
type Dgeqrfer interface {
	Dgeqr2er
	// Dgeqrf computes the blocked QR factorization of the m×n matrix stored in
	// a with leading dimension lda, writing the scalar factors into tau.
	// DgeqrfTest calls it with lwork = -1 as a workspace query and reads the
	// suggested workspace length back from work[0].
	Dgeqrf(m, n int, a []float64, lda int, tau, work []float64, lwork int)
}
    18  
    19  func DgeqrfTest(t *testing.T, impl Dgeqrfer) {
    20  	rnd := rand.New(rand.NewSource(1))
    21  	for c, test := range []struct {
    22  		m, n, lda int
    23  	}{
    24  		{10, 5, 0},
    25  		{5, 10, 0},
    26  		{10, 10, 0},
    27  		{300, 5, 0},
    28  		{3, 500, 0},
    29  		{200, 200, 0},
    30  		{300, 200, 0},
    31  		{204, 300, 0},
    32  		{1, 3000, 0},
    33  		{3000, 1, 0},
    34  		{10, 5, 20},
    35  		{5, 10, 20},
    36  		{10, 10, 20},
    37  		{300, 5, 400},
    38  		{3, 500, 600},
    39  		{200, 200, 300},
    40  		{300, 200, 300},
    41  		{204, 300, 400},
    42  		{1, 3000, 4000},
    43  		{3000, 1, 4000},
    44  	} {
    45  		m := test.m
    46  		n := test.n
    47  		lda := test.lda
    48  		if lda == 0 {
    49  			lda = test.n
    50  		}
    51  		a := make([]float64, m*lda)
    52  		for i := 0; i < m; i++ {
    53  			for j := 0; j < n; j++ {
    54  				a[i*lda+j] = rnd.Float64()
    55  			}
    56  		}
    57  		tau := make([]float64, n)
    58  		for i := 0; i < n; i++ {
    59  			tau[i] = rnd.Float64()
    60  		}
    61  		aCopy := make([]float64, len(a))
    62  		copy(aCopy, a)
    63  		ans := make([]float64, len(a))
    64  		copy(ans, a)
    65  		work := make([]float64, n)
    66  		// Compute unblocked QR.
    67  		impl.Dgeqr2(m, n, ans, lda, tau, work)
    68  		// Compute blocked QR with small work.
    69  		impl.Dgeqrf(m, n, a, lda, tau, work, len(work))
    70  		if !floats.EqualApprox(ans, a, 1e-12) {
    71  			t.Errorf("Case %v, mismatch small work.", c)
    72  		}
    73  		// Try the full length of work.
    74  		impl.Dgeqrf(m, n, a, lda, tau, work, -1)
    75  		lwork := int(work[0])
    76  		work = make([]float64, lwork)
    77  		copy(a, aCopy)
    78  		impl.Dgeqrf(m, n, a, lda, tau, work, lwork)
    79  		if !floats.EqualApprox(ans, a, 1e-12) {
    80  			t.Errorf("Case %v, mismatch large work.", c)
    81  		}
    82  
    83  		// Try a slightly smaller version of work to test blocking.
    84  		if len(work) <= n {
    85  			continue
    86  		}
    87  		work = work[1:]
    88  		lwork--
    89  		copy(a, aCopy)
    90  		impl.Dgeqrf(m, n, a, lda, tau, work, lwork)
    91  		if !floats.EqualApprox(ans, a, 1e-12) {
    92  			t.Errorf("Case %v, mismatch large work.", c)
    93  		}
    94  	}
    95  }