gonum.org/v1/gonum@v0.14.0/lapack/testlapack/dlarft.go (about)

     1  // Copyright ©2015 The Gonum Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package testlapack
     6  
import (
	"math"
	"testing"

	"golang.org/x/exp/rand"

	"gonum.org/v1/gonum/blas"
	"gonum.org/v1/gonum/blas/blas64"
	"gonum.org/v1/gonum/floats"
	"gonum.org/v1/gonum/lapack"
)
    17  
    18  type Dlarfter interface {
    19  	Dgeqr2er
    20  	Dlarft(direct lapack.Direct, store lapack.StoreV, n, k int, v []float64, ldv int, tau []float64, t []float64, ldt int)
    21  }
    22  
    23  func DlarftTest(t *testing.T, impl Dlarfter) {
    24  	rnd := rand.New(rand.NewSource(1))
    25  	for _, store := range []lapack.StoreV{lapack.ColumnWise, lapack.RowWise} {
    26  		for _, direct := range []lapack.Direct{lapack.Forward, lapack.Backward} {
    27  			for _, test := range []struct {
    28  				m, n, ldv, ldt int
    29  			}{
    30  				{6, 6, 0, 0},
    31  				{8, 6, 0, 0},
    32  				{6, 8, 0, 0},
    33  				{6, 6, 10, 15},
    34  				{8, 6, 10, 15},
    35  				{6, 8, 10, 15},
    36  				{6, 6, 15, 10},
    37  				{8, 6, 15, 10},
    38  				{6, 8, 15, 10},
    39  			} {
    40  				// Generate a matrix
    41  				m := test.m
    42  				n := test.n
    43  				lda := n
    44  				if lda == 0 {
    45  					lda = n
    46  				}
    47  
    48  				a := make([]float64, m*lda)
    49  				for i := 0; i < m; i++ {
    50  					for j := 0; j < lda; j++ {
    51  						a[i*lda+j] = rnd.Float64()
    52  					}
    53  				}
    54  				// Use dgeqr2 to find the v vectors
    55  				tau := make([]float64, n)
    56  				work := make([]float64, n)
    57  				impl.Dgeqr2(m, n, a, lda, tau, work)
    58  
    59  				// Construct H using these answers
    60  				vMatTmp := extractVMat(m, n, a, lda, lapack.Forward, lapack.ColumnWise)
    61  				vMat := constructVMat(vMatTmp, store, direct)
    62  				v := vMat.Data
    63  				ldv := vMat.Stride
    64  
    65  				h := constructH(tau, vMat, store, direct)
    66  
    67  				k := min(m, n)
    68  				ldt := test.ldt
    69  				if ldt == 0 {
    70  					ldt = k
    71  				}
    72  				// Find T from the actual function
    73  				tm := make([]float64, k*ldt)
    74  				for i := range tm {
    75  					tm[i] = 100 + rnd.Float64()
    76  				}
    77  				// The v data has been put into a.
    78  				impl.Dlarft(direct, store, m, k, v, ldv, tau, tm, ldt)
    79  
    80  				tData := make([]float64, len(tm))
    81  				copy(tData, tm)
    82  				if direct == lapack.Forward {
    83  					// Zero out the lower triangular portion.
    84  					for i := 0; i < k; i++ {
    85  						for j := 0; j < i; j++ {
    86  							tData[i*ldt+j] = 0
    87  						}
    88  					}
    89  				} else {
    90  					// Zero out the upper triangular portion.
    91  					for i := 0; i < k; i++ {
    92  						for j := i + 1; j < k; j++ {
    93  							tData[i*ldt+j] = 0
    94  						}
    95  					}
    96  				}
    97  
    98  				T := blas64.General{
    99  					Rows:   k,
   100  					Cols:   k,
   101  					Stride: ldt,
   102  					Data:   tData,
   103  				}
   104  
   105  				vMatT := blas64.General{
   106  					Rows:   vMat.Cols,
   107  					Cols:   vMat.Rows,
   108  					Stride: vMat.Rows,
   109  					Data:   make([]float64, vMat.Cols*vMat.Rows),
   110  				}
   111  				for i := 0; i < vMat.Rows; i++ {
   112  					for j := 0; j < vMat.Cols; j++ {
   113  						vMatT.Data[j*vMatT.Stride+i] = vMat.Data[i*vMat.Stride+j]
   114  					}
   115  				}
   116  				var comp blas64.General
   117  				if store == lapack.ColumnWise {
   118  					// H = I - V * T * Vᵀ
   119  					tmp := blas64.General{
   120  						Rows:   T.Rows,
   121  						Cols:   vMatT.Cols,
   122  						Stride: vMatT.Cols,
   123  						Data:   make([]float64, T.Rows*vMatT.Cols),
   124  					}
   125  					// T * Vᵀ
   126  					blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, T, vMatT, 0, tmp)
   127  					comp = blas64.General{
   128  						Rows:   vMat.Rows,
   129  						Cols:   tmp.Cols,
   130  						Stride: tmp.Cols,
   131  						Data:   make([]float64, vMat.Rows*tmp.Cols),
   132  					}
   133  					// V * (T * Vᵀ)
   134  					blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, vMat, tmp, 0, comp)
   135  				} else {
   136  					// H = I - Vᵀ * T * V
   137  					tmp := blas64.General{
   138  						Rows:   T.Rows,
   139  						Cols:   vMat.Cols,
   140  						Stride: vMat.Cols,
   141  						Data:   make([]float64, T.Rows*vMat.Cols),
   142  					}
   143  					// T * V
   144  					blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, T, vMat, 0, tmp)
   145  					comp = blas64.General{
   146  						Rows:   vMatT.Rows,
   147  						Cols:   tmp.Cols,
   148  						Stride: tmp.Cols,
   149  						Data:   make([]float64, vMatT.Rows*tmp.Cols),
   150  					}
   151  					// Vᵀ * (T * V)
   152  					blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, vMatT, tmp, 0, comp)
   153  				}
   154  				// I - Vᵀ * T * V
   155  				for i := 0; i < comp.Rows; i++ {
   156  					for j := 0; j < comp.Cols; j++ {
   157  						comp.Data[i*m+j] *= -1
   158  						if i == j {
   159  							comp.Data[i*m+j] += 1
   160  						}
   161  					}
   162  				}
   163  				if !floats.EqualApprox(comp.Data, h.Data, 1e-14) {
   164  					t.Errorf("T does not construct proper H. Store = %v, Direct = %v.\nWant %v\ngot %v.", string(store), string(direct), h.Data, comp.Data)
   165  				}
   166  			}
   167  		}
   168  	}
   169  }