// Source: github.com/gonum/lapack@v0.0.0-20181123203213-e4cdc5a0bff9/testlapack/dlaqp2.go

     1  // Copyright ©2017 The gonum Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package testlapack
     6  
     7  import (
     8  	"fmt"
     9  	"math"
    10  	"testing"
    11  
    12  	"github.com/gonum/blas"
    13  	"github.com/gonum/blas/blas64"
    14  )
    15  
    16  type Dlaqp2er interface {
    17  	Dlapmter
    18  	Dlaqp2(m, n, offset int, a []float64, lda int, jpvt []int, tau, vn1, vn2, work []float64)
    19  }
    20  
    21  func Dlaqp2Test(t *testing.T, impl Dlaqp2er) {
    22  	for ti, test := range []struct {
    23  		m, n, offset int
    24  	}{
    25  		{m: 4, n: 3, offset: 0},
    26  		{m: 4, n: 3, offset: 2},
    27  		{m: 4, n: 3, offset: 4},
    28  		{m: 3, n: 4, offset: 0},
    29  		{m: 3, n: 4, offset: 1},
    30  		{m: 3, n: 4, offset: 2},
    31  		{m: 8, n: 3, offset: 0},
    32  		{m: 8, n: 3, offset: 4},
    33  		{m: 8, n: 3, offset: 8},
    34  		{m: 3, n: 8, offset: 0},
    35  		{m: 3, n: 8, offset: 1},
    36  		{m: 3, n: 8, offset: 2},
    37  		{m: 10, n: 10, offset: 0},
    38  		{m: 10, n: 10, offset: 5},
    39  		{m: 10, n: 10, offset: 10},
    40  	} {
    41  		m := test.m
    42  		n := test.n
    43  		jpiv := make([]int, n)
    44  
    45  		for _, extra := range []int{0, 11} {
    46  			a := zeros(m, n, n+extra)
    47  			c := 1
    48  			for i := 0; i < m; i++ {
    49  				for j := 0; j < n; j++ {
    50  					a.Data[i*a.Stride+j] = float64(c)
    51  					c++
    52  				}
    53  			}
    54  			aCopy := cloneGeneral(a)
    55  			for j := range jpiv {
    56  				jpiv[j] = j
    57  			}
    58  
    59  			tau := make([]float64, n)
    60  			vn1 := columnNorms(m, n, a.Data, a.Stride)
    61  			vn2 := columnNorms(m, n, a.Data, a.Stride)
    62  			work := make([]float64, n)
    63  
    64  			impl.Dlaqp2(m, n, test.offset, a.Data, a.Stride, jpiv, tau, vn1, vn2, work)
    65  
    66  			prefix := fmt.Sprintf("Case %v (offset=%t,m=%v,n=%v,extra=%v)", ti, test.offset, m, n, extra)
    67  			if !generalOutsideAllNaN(a) {
    68  				t.Errorf("%v: out-of-range write to A", prefix)
    69  			}
    70  
    71  			if test.offset == m {
    72  				continue
    73  			}
    74  
    75  			mo := m - test.offset
    76  			q := constructQ("QR", mo, n, a.Data[test.offset*a.Stride:], a.Stride, tau)
    77  			// Check that q is orthonormal
    78  			for i := 0; i < mo; i++ {
    79  				nrm := blas64.Nrm2(mo, blas64.Vector{Inc: 1, Data: q.Data[i*mo:]})
    80  				if math.Abs(nrm-1) > 1e-13 {
    81  					t.Errorf("Case %v, q not normal", ti)
    82  				}
    83  				for j := 0; j < i; j++ {
    84  					dot := blas64.Dot(mo, blas64.Vector{Inc: 1, Data: q.Data[i*mo:]}, blas64.Vector{Inc: 1, Data: q.Data[j*mo:]})
    85  					if math.Abs(dot) > 1e-14 {
    86  						t.Errorf("Case %v, q not orthogonal", ti)
    87  					}
    88  				}
    89  			}
    90  
    91  			// Check that A * P = Q * R
    92  			r := blas64.General{
    93  				Rows:   mo,
    94  				Cols:   n,
    95  				Stride: n,
    96  				Data:   make([]float64, mo*n),
    97  			}
    98  			for i := 0; i < mo; i++ {
    99  				for j := i; j < n; j++ {
   100  					r.Data[i*n+j] = a.Data[(test.offset+i)*a.Stride+j]
   101  				}
   102  			}
   103  			got := nanGeneral(mo, n, n)
   104  			blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, q, r, 0, got)
   105  
   106  			want := aCopy
   107  			impl.Dlapmt(true, want.Rows, want.Cols, want.Data, want.Stride, jpiv)
   108  			want.Rows = mo
   109  			want.Data = want.Data[test.offset*want.Stride:]
   110  			if !equalApproxGeneral(got, want, 1e-12) {
   111  				t.Errorf("Case %v,  Q*R != A*P\nQ*R=%v\nA*P=%v", ti, got, want)
   112  			}
   113  		}
   114  	}
   115  }