github.com/gonum/lapack@v0.0.0-20181123203213-e4cdc5a0bff9/testlapack/dlarfb.go (about)

     1  // Copyright ©2015 The gonum Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package testlapack
     6  
     7  import (
     8  	"math/rand"
     9  	"testing"
    10  
    11  	"github.com/gonum/blas"
    12  	"github.com/gonum/blas/blas64"
    13  	"github.com/gonum/floats"
    14  	"github.com/gonum/lapack"
    15  )
    16  
    17  type Dlarfber interface {
    18  	Dlarfter
    19  	Dlarfb(side blas.Side, trans blas.Transpose, direct lapack.Direct,
    20  		store lapack.StoreV, m, n, k int, v []float64, ldv int, t []float64, ldt int,
    21  		c []float64, ldc int, work []float64, ldwork int)
    22  }
    23  
    24  func DlarfbTest(t *testing.T, impl Dlarfber) {
    25  	rnd := rand.New(rand.NewSource(1))
    26  	for _, store := range []lapack.StoreV{lapack.ColumnWise, lapack.RowWise} {
    27  		for _, direct := range []lapack.Direct{lapack.Forward, lapack.Backward} {
    28  			for _, side := range []blas.Side{blas.Left, blas.Right} {
    29  				for _, trans := range []blas.Transpose{blas.Trans, blas.NoTrans} {
    30  					for cas, test := range []struct {
    31  						ma, na, cdim, lda, ldt, ldc int
    32  					}{
    33  						{6, 6, 6, 0, 0, 0},
    34  						{6, 8, 10, 0, 0, 0},
    35  						{6, 10, 8, 0, 0, 0},
    36  						{8, 6, 10, 0, 0, 0},
    37  						{8, 10, 6, 0, 0, 0},
    38  						{10, 6, 8, 0, 0, 0},
    39  						{10, 8, 6, 0, 0, 0},
    40  						{6, 6, 6, 12, 15, 30},
    41  						{6, 8, 10, 12, 15, 30},
    42  						{6, 10, 8, 12, 15, 30},
    43  						{8, 6, 10, 12, 15, 30},
    44  						{8, 10, 6, 12, 15, 30},
    45  						{10, 6, 8, 12, 15, 30},
    46  						{10, 8, 6, 12, 15, 30},
    47  						{6, 6, 6, 15, 12, 30},
    48  						{6, 8, 10, 15, 12, 30},
    49  						{6, 10, 8, 15, 12, 30},
    50  						{8, 6, 10, 15, 12, 30},
    51  						{8, 10, 6, 15, 12, 30},
    52  						{10, 6, 8, 15, 12, 30},
    53  						{10, 8, 6, 15, 12, 30},
    54  					} {
    55  						// Generate a matrix for QR
    56  						ma := test.ma
    57  						na := test.na
    58  						lda := test.lda
    59  						if lda == 0 {
    60  							lda = na
    61  						}
    62  						a := make([]float64, ma*lda)
    63  						for i := 0; i < ma; i++ {
    64  							for j := 0; j < lda; j++ {
    65  								a[i*lda+j] = rnd.Float64()
    66  							}
    67  						}
    68  						k := min(ma, na)
    69  
    70  						// H is always ma x ma
    71  						var m, n, rowsWork int
    72  						switch {
    73  						default:
    74  							panic("not implemented")
    75  						case side == blas.Left:
    76  							m = test.ma
    77  							n = test.cdim
    78  							rowsWork = n
    79  						case side == blas.Right:
    80  							m = test.cdim
    81  							n = test.ma
    82  							rowsWork = m
    83  						}
    84  
    85  						// Use dgeqr2 to find the v vectors
    86  						tau := make([]float64, na)
    87  						work := make([]float64, na)
    88  						impl.Dgeqr2(ma, k, a, lda, tau, work)
    89  
    90  						// Correct the v vectors based on the direct and store
    91  						vMatTmp := extractVMat(ma, na, a, lda, lapack.Forward, lapack.ColumnWise)
    92  						vMat := constructVMat(vMatTmp, store, direct)
    93  						v := vMat.Data
    94  						ldv := vMat.Stride
    95  
    96  						// Use dlarft to find the t vector
    97  						ldt := test.ldt
    98  						if ldt == 0 {
    99  							ldt = k
   100  						}
   101  						tm := make([]float64, k*ldt)
   102  
   103  						impl.Dlarft(direct, store, ma, k, v, ldv, tau, tm, ldt)
   104  
   105  						// Generate c matrix
   106  						ldc := test.ldc
   107  						if ldc == 0 {
   108  							ldc = n
   109  						}
   110  						c := make([]float64, m*ldc)
   111  						for i := 0; i < m; i++ {
   112  							for j := 0; j < ldc; j++ {
   113  								c[i*ldc+j] = rnd.Float64()
   114  							}
   115  						}
   116  						cCopy := make([]float64, len(c))
   117  						copy(cCopy, c)
   118  
   119  						ldwork := k
   120  						work = make([]float64, rowsWork*k)
   121  
   122  						// Call Dlarfb with this information
   123  						impl.Dlarfb(side, trans, direct, store, m, n, k, v, ldv, tm, ldt, c, ldc, work, ldwork)
   124  
   125  						h := constructH(tau, vMat, store, direct)
   126  
   127  						cMat := blas64.General{
   128  							Rows:   m,
   129  							Cols:   n,
   130  							Stride: ldc,
   131  							Data:   make([]float64, m*ldc),
   132  						}
   133  						copy(cMat.Data, cCopy)
   134  						ans := blas64.General{
   135  							Rows:   m,
   136  							Cols:   n,
   137  							Stride: ldc,
   138  							Data:   make([]float64, m*ldc),
   139  						}
   140  						copy(ans.Data, cMat.Data)
   141  						switch {
   142  						default:
   143  							panic("not implemented")
   144  						case side == blas.Left && trans == blas.NoTrans:
   145  							blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, h, cMat, 0, ans)
   146  						case side == blas.Left && trans == blas.Trans:
   147  							blas64.Gemm(blas.Trans, blas.NoTrans, 1, h, cMat, 0, ans)
   148  						case side == blas.Right && trans == blas.NoTrans:
   149  							blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, cMat, h, 0, ans)
   150  						case side == blas.Right && trans == blas.Trans:
   151  							blas64.Gemm(blas.NoTrans, blas.Trans, 1, cMat, h, 0, ans)
   152  						}
   153  						if !floats.EqualApprox(ans.Data, c, 1e-14) {
   154  							t.Errorf("Cas %v mismatch. Want %v, got %v.", cas, ans.Data, c)
   155  						}
   156  					}
   157  				}
   158  			}
   159  		}
   160  	}
   161  }