github.com/gonum/lapack@v0.0.0-20181123203213-e4cdc5a0bff9/testlapack/dlaqps.go (about)

     1  // Copyright ©2017 The gonum Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package testlapack
     6  
     7  import (
     8  	"fmt"
     9  	"math"
    10  	"testing"
    11  
    12  	"github.com/gonum/blas"
    13  	"github.com/gonum/blas/blas64"
    14  )
    15  
    16  type Dlaqpser interface {
    17  	Dlapmter
    18  	Dlaqps(m, n, offset, nb int, a []float64, lda int, jpvt []int, tau, vn1, vn2, auxv, f []float64, ldf int) (kb int)
    19  }
    20  
    21  func DlaqpsTest(t *testing.T, impl Dlaqpser) {
    22  	for ti, test := range []struct {
    23  		m, n, nb, offset int
    24  	}{
    25  		{m: 4, n: 3, nb: 2, offset: 0},
    26  		{m: 4, n: 3, nb: 1, offset: 2},
    27  		{m: 3, n: 4, nb: 2, offset: 0},
    28  		{m: 3, n: 4, nb: 1, offset: 2},
    29  		{m: 8, n: 3, nb: 2, offset: 0},
    30  		{m: 8, n: 3, nb: 1, offset: 4},
    31  		{m: 3, n: 8, nb: 2, offset: 0},
    32  		{m: 3, n: 8, nb: 1, offset: 1},
    33  		{m: 10, n: 10, nb: 3, offset: 0},
    34  		{m: 10, n: 10, nb: 2, offset: 5},
    35  	} {
    36  		m := test.m
    37  		n := test.n
    38  		jpiv := make([]int, n)
    39  
    40  		for _, extra := range []int{0, 11} {
    41  			a := zeros(m, n, n+extra)
    42  			c := 1
    43  			for i := 0; i < m; i++ {
    44  				for j := 0; j < n; j++ {
    45  					a.Data[i*a.Stride+j] = float64(c)
    46  					c++
    47  				}
    48  			}
    49  			aCopy := cloneGeneral(a)
    50  			for j := range jpiv {
    51  				jpiv[j] = j
    52  			}
    53  
    54  			tau := make([]float64, n)
    55  			vn1 := columnNorms(m, n, a.Data, a.Stride)
    56  			vn2 := columnNorms(m, n, a.Data, a.Stride)
    57  			auxv := make([]float64, test.nb)
    58  			f := zeros(test.n, test.nb, n)
    59  
    60  			kb := impl.Dlaqps(m, n, test.offset, test.nb, a.Data, a.Stride, jpiv, tau, vn1, vn2, auxv, f.Data, f.Stride)
    61  
    62  			prefix := fmt.Sprintf("Case %v (offset=%t,m=%v,n=%v,extra=%v)", ti, test.offset, m, n, extra)
    63  			if !generalOutsideAllNaN(a) {
    64  				t.Errorf("%v: out-of-range write to A", prefix)
    65  			}
    66  
    67  			if test.offset == m {
    68  				continue
    69  			}
    70  
    71  			mo := m - test.offset
    72  			q := constructQ("QR", mo, kb, a.Data[test.offset*a.Stride:], a.Stride, tau)
    73  			// Check that q is orthonormal
    74  			for i := 0; i < mo; i++ {
    75  				nrm := blas64.Nrm2(mo, blas64.Vector{Inc: 1, Data: q.Data[i*mo:]})
    76  				if math.Abs(nrm-1) > 1e-13 {
    77  					t.Errorf("Case %v, q not normal", ti)
    78  				}
    79  				for j := 0; j < i; j++ {
    80  					dot := blas64.Dot(mo, blas64.Vector{Inc: 1, Data: q.Data[i*mo:]}, blas64.Vector{Inc: 1, Data: q.Data[j*mo:]})
    81  					if math.Abs(dot) > 1e-14 {
    82  						t.Errorf("Case %v, q not orthogonal", ti)
    83  					}
    84  				}
    85  			}
    86  
    87  			// Check that A * P = Q * R
    88  			r := blas64.General{
    89  				Rows:   mo,
    90  				Cols:   kb,
    91  				Stride: kb,
    92  				Data:   make([]float64, mo*kb),
    93  			}
    94  			for i := 0; i < mo; i++ {
    95  				for j := i; j < kb; j++ {
    96  					r.Data[i*kb+j] = a.Data[(test.offset+i)*a.Stride+j]
    97  				}
    98  			}
    99  			got := nanGeneral(mo, kb, kb)
   100  			blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, q, r, 0, got)
   101  
   102  			want := aCopy
   103  			impl.Dlapmt(true, want.Rows, want.Cols, want.Data, want.Stride, jpiv)
   104  			want.Rows = mo
   105  			want.Cols = kb
   106  			want.Data = want.Data[test.offset*want.Stride:]
   107  			if !equalApproxGeneral(got, want, 1e-12) {
   108  				t.Errorf("Case %v,  Q*R != A*P\nQ*R=%v\nA*P=%v", ti, got, want)
   109  			}
   110  		}
   111  	}
   112  }