github.com/jingcheng-WU/gonum@v0.9.1-0.20210323123734-f1a2a11a8f7b/lapack/testlapack/dlarfb.go (about)

     1  // Copyright ©2015 The Gonum Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package testlapack
     6  
     7  import (
     8  	"testing"
     9  
    10  	"golang.org/x/exp/rand"
    11  
    12  	"github.com/jingcheng-WU/gonum/blas"
    13  	"github.com/jingcheng-WU/gonum/blas/blas64"
    14  	"github.com/jingcheng-WU/gonum/floats"
    15  	"github.com/jingcheng-WU/gonum/lapack"
    16  )
    17  
    18  type Dlarfber interface {
    19  	Dlarfter
    20  	Dlarfb(side blas.Side, trans blas.Transpose, direct lapack.Direct,
    21  		store lapack.StoreV, m, n, k int, v []float64, ldv int, t []float64, ldt int,
    22  		c []float64, ldc int, work []float64, ldwork int)
    23  }
    24  
    25  func DlarfbTest(t *testing.T, impl Dlarfber) {
    26  	rnd := rand.New(rand.NewSource(1))
    27  	for _, store := range []lapack.StoreV{lapack.ColumnWise, lapack.RowWise} {
    28  		for _, direct := range []lapack.Direct{lapack.Forward, lapack.Backward} {
    29  			for _, side := range []blas.Side{blas.Left, blas.Right} {
    30  				for _, trans := range []blas.Transpose{blas.Trans, blas.NoTrans} {
    31  					for cas, test := range []struct {
    32  						ma, na, cdim, lda, ldt, ldc int
    33  					}{
    34  						{6, 6, 6, 0, 0, 0},
    35  						{6, 8, 10, 0, 0, 0},
    36  						{6, 10, 8, 0, 0, 0},
    37  						{8, 6, 10, 0, 0, 0},
    38  						{8, 10, 6, 0, 0, 0},
    39  						{10, 6, 8, 0, 0, 0},
    40  						{10, 8, 6, 0, 0, 0},
    41  						{6, 6, 6, 12, 15, 30},
    42  						{6, 8, 10, 12, 15, 30},
    43  						{6, 10, 8, 12, 15, 30},
    44  						{8, 6, 10, 12, 15, 30},
    45  						{8, 10, 6, 12, 15, 30},
    46  						{10, 6, 8, 12, 15, 30},
    47  						{10, 8, 6, 12, 15, 30},
    48  						{6, 6, 6, 15, 12, 30},
    49  						{6, 8, 10, 15, 12, 30},
    50  						{6, 10, 8, 15, 12, 30},
    51  						{8, 6, 10, 15, 12, 30},
    52  						{8, 10, 6, 15, 12, 30},
    53  						{10, 6, 8, 15, 12, 30},
    54  						{10, 8, 6, 15, 12, 30},
    55  					} {
    56  						// Generate a matrix for QR
    57  						ma := test.ma
    58  						na := test.na
    59  						lda := test.lda
    60  						if lda == 0 {
    61  							lda = na
    62  						}
    63  						a := make([]float64, ma*lda)
    64  						for i := 0; i < ma; i++ {
    65  							for j := 0; j < lda; j++ {
    66  								a[i*lda+j] = rnd.Float64()
    67  							}
    68  						}
    69  						k := min(ma, na)
    70  
    71  						// H is always ma x ma
    72  						var m, n, rowsWork int
    73  						switch {
    74  						default:
    75  							panic("not implemented")
    76  						case side == blas.Left:
    77  							m = test.ma
    78  							n = test.cdim
    79  							rowsWork = n
    80  						case side == blas.Right:
    81  							m = test.cdim
    82  							n = test.ma
    83  							rowsWork = m
    84  						}
    85  
    86  						// Use dgeqr2 to find the v vectors
    87  						tau := make([]float64, na)
    88  						work := make([]float64, na)
    89  						impl.Dgeqr2(ma, k, a, lda, tau, work)
    90  
    91  						// Correct the v vectors based on the direct and store
    92  						vMatTmp := extractVMat(ma, na, a, lda, lapack.Forward, lapack.ColumnWise)
    93  						vMat := constructVMat(vMatTmp, store, direct)
    94  						v := vMat.Data
    95  						ldv := vMat.Stride
    96  
    97  						// Use dlarft to find the t vector
    98  						ldt := test.ldt
    99  						if ldt == 0 {
   100  							ldt = k
   101  						}
   102  						tm := make([]float64, k*ldt)
   103  
   104  						impl.Dlarft(direct, store, ma, k, v, ldv, tau, tm, ldt)
   105  
   106  						// Generate c matrix
   107  						ldc := test.ldc
   108  						if ldc == 0 {
   109  							ldc = n
   110  						}
   111  						c := make([]float64, m*ldc)
   112  						for i := 0; i < m; i++ {
   113  							for j := 0; j < ldc; j++ {
   114  								c[i*ldc+j] = rnd.Float64()
   115  							}
   116  						}
   117  						cCopy := make([]float64, len(c))
   118  						copy(cCopy, c)
   119  
   120  						ldwork := k
   121  						work = make([]float64, rowsWork*k)
   122  
   123  						// Call Dlarfb with this information
   124  						impl.Dlarfb(side, trans, direct, store, m, n, k, v, ldv, tm, ldt, c, ldc, work, ldwork)
   125  
   126  						h := constructH(tau, vMat, store, direct)
   127  
   128  						cMat := blas64.General{
   129  							Rows:   m,
   130  							Cols:   n,
   131  							Stride: ldc,
   132  							Data:   make([]float64, m*ldc),
   133  						}
   134  						copy(cMat.Data, cCopy)
   135  						ans := blas64.General{
   136  							Rows:   m,
   137  							Cols:   n,
   138  							Stride: ldc,
   139  							Data:   make([]float64, m*ldc),
   140  						}
   141  						copy(ans.Data, cMat.Data)
   142  						switch {
   143  						default:
   144  							panic("not implemented")
   145  						case side == blas.Left && trans == blas.NoTrans:
   146  							blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, h, cMat, 0, ans)
   147  						case side == blas.Left && trans == blas.Trans:
   148  							blas64.Gemm(blas.Trans, blas.NoTrans, 1, h, cMat, 0, ans)
   149  						case side == blas.Right && trans == blas.NoTrans:
   150  							blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, cMat, h, 0, ans)
   151  						case side == blas.Right && trans == blas.Trans:
   152  							blas64.Gemm(blas.NoTrans, blas.Trans, 1, cMat, h, 0, ans)
   153  						}
   154  						if !floats.EqualApprox(ans.Data, c, 1e-14) {
   155  							t.Errorf("Cas %v mismatch. Want %v, got %v.", cas, ans.Data, c)
   156  						}
   157  					}
   158  				}
   159  			}
   160  		}
   161  	}
   162  }