github.com/gonum/lapack@v0.0.0-20181123203213-e4cdc5a0bff9/testlapack/dgeql2.go (about)

     1  // Copyright ©2016 The gonum Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package testlapack
     6  
     7  import (
     8  	"math/rand"
     9  	"testing"
    10  
    11  	"github.com/gonum/blas"
    12  	"github.com/gonum/blas/blas64"
    13  	"github.com/gonum/floats"
    14  )
    15  
    16  type Dgeql2er interface {
    17  	Dgeql2(m, n int, a []float64, lda int, tau, work []float64)
    18  }
    19  
    20  func Dgeql2Test(t *testing.T, impl Dgeql2er) {
    21  	rnd := rand.New(rand.NewSource(1))
    22  	// TODO(btracey): Add tests for m < n.
    23  	for _, test := range []struct {
    24  		m, n, lda int
    25  	}{
    26  		{5, 5, 0},
    27  		{5, 3, 0},
    28  		{5, 4, 0},
    29  	} {
    30  		m := test.m
    31  		n := test.n
    32  		lda := test.lda
    33  		if lda == 0 {
    34  			lda = n
    35  		}
    36  		a := make([]float64, m*lda)
    37  		for i := range a {
    38  			a[i] = rnd.NormFloat64()
    39  		}
    40  		tau := nanSlice(min(m, n))
    41  		work := nanSlice(n)
    42  
    43  		aCopy := make([]float64, len(a))
    44  		copy(aCopy, a)
    45  		impl.Dgeql2(m, n, a, lda, tau, work)
    46  
    47  		k := min(m, n)
    48  		// Construct Q.
    49  		q := blas64.General{
    50  			Rows:   m,
    51  			Cols:   m,
    52  			Stride: m,
    53  			Data:   make([]float64, m*m),
    54  		}
    55  		for i := 0; i < m; i++ {
    56  			q.Data[i*q.Stride+i] = 1
    57  		}
    58  		for i := 0; i < k; i++ {
    59  			h := blas64.General{Rows: m, Cols: m, Stride: m, Data: make([]float64, m*m)}
    60  			for j := 0; j < m; j++ {
    61  				h.Data[j*h.Stride+j] = 1
    62  			}
    63  			v := blas64.Vector{Inc: 1, Data: make([]float64, m)}
    64  			v.Data[m-k+i] = 1
    65  			for j := 0; j < m-k+i; j++ {
    66  				v.Data[j] = a[j*lda+n-k+i]
    67  			}
    68  			blas64.Ger(-tau[i], v, v, h)
    69  			qTmp := blas64.General{Rows: q.Rows, Cols: q.Cols, Stride: q.Stride, Data: make([]float64, len(q.Data))}
    70  			copy(qTmp.Data, q.Data)
    71  			blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, h, qTmp, 0, q)
    72  		}
    73  		if !isOrthonormal(q) {
    74  			t.Errorf("Q is not orthonormal")
    75  		}
    76  		l := blas64.General{
    77  			Rows:   m,
    78  			Cols:   n,
    79  			Stride: n,
    80  			Data:   make([]float64, m*n),
    81  		}
    82  		if m >= n {
    83  			for i := m - n; i < m; i++ {
    84  				for j := 0; j <= min(i-(m-n), n-1); j++ {
    85  					l.Data[i*l.Stride+j] = a[i*lda+j]
    86  				}
    87  			}
    88  		} else {
    89  			panic("untested")
    90  		}
    91  		ans := blas64.General{Rows: m, Cols: n, Stride: lda, Data: make([]float64, len(a))}
    92  		copy(ans.Data, a)
    93  
    94  		blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, q, l, 0, ans)
    95  		if !floats.EqualApprox(ans.Data, aCopy, 1e-10) {
    96  			t.Errorf("Reconstruction mismatch: m = %v, n = %v", m, n)
    97  		}
    98  	}
    99  }