github.com/jingcheng-WU/gonum@v0.9.1-0.20210323123734-f1a2a11a8f7b/lapack/testlapack/dgeql2.go (about)

     1  // Copyright ©2016 The Gonum Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package testlapack
     6  
     7  import (
     8  	"testing"
     9  
    10  	"golang.org/x/exp/rand"
    11  
    12  	"github.com/jingcheng-WU/gonum/blas"
    13  	"github.com/jingcheng-WU/gonum/blas/blas64"
    14  	"github.com/jingcheng-WU/gonum/floats"
    15  )
    16  
    17  type Dgeql2er interface {
    18  	Dgeql2(m, n int, a []float64, lda int, tau, work []float64)
    19  }
    20  
    21  func Dgeql2Test(t *testing.T, impl Dgeql2er) {
    22  	const tol = 1e-14
    23  
    24  	rnd := rand.New(rand.NewSource(1))
    25  	// TODO(btracey): Add tests for m < n.
    26  	for _, test := range []struct {
    27  		m, n, lda int
    28  	}{
    29  		{5, 5, 0},
    30  		{5, 3, 0},
    31  		{5, 4, 0},
    32  	} {
    33  		m := test.m
    34  		n := test.n
    35  		lda := test.lda
    36  		if lda == 0 {
    37  			lda = n
    38  		}
    39  		a := make([]float64, m*lda)
    40  		for i := range a {
    41  			a[i] = rnd.NormFloat64()
    42  		}
    43  		tau := nanSlice(min(m, n))
    44  		work := nanSlice(n)
    45  
    46  		aCopy := make([]float64, len(a))
    47  		copy(aCopy, a)
    48  		impl.Dgeql2(m, n, a, lda, tau, work)
    49  
    50  		k := min(m, n)
    51  		// Construct Q.
    52  		q := blas64.General{
    53  			Rows:   m,
    54  			Cols:   m,
    55  			Stride: m,
    56  			Data:   make([]float64, m*m),
    57  		}
    58  		for i := 0; i < m; i++ {
    59  			q.Data[i*q.Stride+i] = 1
    60  		}
    61  		for i := 0; i < k; i++ {
    62  			h := blas64.General{Rows: m, Cols: m, Stride: m, Data: make([]float64, m*m)}
    63  			for j := 0; j < m; j++ {
    64  				h.Data[j*h.Stride+j] = 1
    65  			}
    66  			v := blas64.Vector{Inc: 1, Data: make([]float64, m)}
    67  			v.Data[m-k+i] = 1
    68  			for j := 0; j < m-k+i; j++ {
    69  				v.Data[j] = a[j*lda+n-k+i]
    70  			}
    71  			blas64.Ger(-tau[i], v, v, h)
    72  			qTmp := blas64.General{Rows: q.Rows, Cols: q.Cols, Stride: q.Stride, Data: make([]float64, len(q.Data))}
    73  			copy(qTmp.Data, q.Data)
    74  			blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, h, qTmp, 0, q)
    75  		}
    76  		if resid := residualOrthogonal(q, false); resid > tol {
    77  			t.Errorf("Q is not orthogonal; resid=%v, want<=%v", resid, tol)
    78  		}
    79  		l := blas64.General{
    80  			Rows:   m,
    81  			Cols:   n,
    82  			Stride: n,
    83  			Data:   make([]float64, m*n),
    84  		}
    85  		if m >= n {
    86  			for i := m - n; i < m; i++ {
    87  				for j := 0; j <= min(i-(m-n), n-1); j++ {
    88  					l.Data[i*l.Stride+j] = a[i*lda+j]
    89  				}
    90  			}
    91  		} else {
    92  			panic("untested")
    93  		}
    94  		ans := blas64.General{Rows: m, Cols: n, Stride: lda, Data: make([]float64, len(a))}
    95  		copy(ans.Data, a)
    96  
    97  		blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, q, l, 0, ans)
    98  		if !floats.EqualApprox(ans.Data, aCopy, tol) {
    99  			t.Errorf("Reconstruction mismatch: m = %v, n = %v", m, n)
   100  		}
   101  	}
   102  }