github.com/gonum/lapack@v0.0.0-20181123203213-e4cdc5a0bff9/testlapack/dlarft.go (about)

     1  // Copyright ©2015 The gonum Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package testlapack
     6  
     7  import (
     8  	"math/rand"
     9  	"testing"
    10  
    11  	"github.com/gonum/blas"
    12  	"github.com/gonum/blas/blas64"
    13  	"github.com/gonum/floats"
    14  	"github.com/gonum/lapack"
    15  )
    16  
    17  type Dlarfter interface {
    18  	Dgeqr2er
    19  	Dlarft(direct lapack.Direct, store lapack.StoreV, n, k int, v []float64, ldv int, tau []float64, t []float64, ldt int)
    20  }
    21  
    22  func DlarftTest(t *testing.T, impl Dlarfter) {
    23  	rnd := rand.New(rand.NewSource(1))
    24  	for _, store := range []lapack.StoreV{lapack.ColumnWise, lapack.RowWise} {
    25  		for _, direct := range []lapack.Direct{lapack.Forward, lapack.Backward} {
    26  			for _, test := range []struct {
    27  				m, n, ldv, ldt int
    28  			}{
    29  				{6, 6, 0, 0},
    30  				{8, 6, 0, 0},
    31  				{6, 8, 0, 0},
    32  				{6, 6, 10, 15},
    33  				{8, 6, 10, 15},
    34  				{6, 8, 10, 15},
    35  				{6, 6, 15, 10},
    36  				{8, 6, 15, 10},
    37  				{6, 8, 15, 10},
    38  			} {
    39  				// Generate a matrix
    40  				m := test.m
    41  				n := test.n
    42  				lda := n
    43  				if lda == 0 {
    44  					lda = n
    45  				}
    46  
    47  				a := make([]float64, m*lda)
    48  				for i := 0; i < m; i++ {
    49  					for j := 0; j < lda; j++ {
    50  						a[i*lda+j] = rnd.Float64()
    51  					}
    52  				}
    53  				// Use dgeqr2 to find the v vectors
    54  				tau := make([]float64, n)
    55  				work := make([]float64, n)
    56  				impl.Dgeqr2(m, n, a, lda, tau, work)
    57  
    58  				// Construct H using these answers
    59  				vMatTmp := extractVMat(m, n, a, lda, lapack.Forward, lapack.ColumnWise)
    60  				vMat := constructVMat(vMatTmp, store, direct)
    61  				v := vMat.Data
    62  				ldv := vMat.Stride
    63  
    64  				h := constructH(tau, vMat, store, direct)
    65  
    66  				k := min(m, n)
    67  				ldt := test.ldt
    68  				if ldt == 0 {
    69  					ldt = k
    70  				}
    71  				// Find T from the actual function
    72  				tm := make([]float64, k*ldt)
    73  				for i := range tm {
    74  					tm[i] = 100 + rnd.Float64()
    75  				}
    76  				// The v data has been put into a.
    77  				impl.Dlarft(direct, store, m, k, v, ldv, tau, tm, ldt)
    78  
    79  				tData := make([]float64, len(tm))
    80  				copy(tData, tm)
    81  				if direct == lapack.Forward {
    82  					// Zero out the lower traingular portion.
    83  					for i := 0; i < k; i++ {
    84  						for j := 0; j < i; j++ {
    85  							tData[i*ldt+j] = 0
    86  						}
    87  					}
    88  				} else {
    89  					// Zero out the upper traingular portion.
    90  					for i := 0; i < k; i++ {
    91  						for j := i + 1; j < k; j++ {
    92  							tData[i*ldt+j] = 0
    93  						}
    94  					}
    95  				}
    96  
    97  				T := blas64.General{
    98  					Rows:   k,
    99  					Cols:   k,
   100  					Stride: ldt,
   101  					Data:   tData,
   102  				}
   103  
   104  				vMatT := blas64.General{
   105  					Rows:   vMat.Cols,
   106  					Cols:   vMat.Rows,
   107  					Stride: vMat.Rows,
   108  					Data:   make([]float64, vMat.Cols*vMat.Rows),
   109  				}
   110  				for i := 0; i < vMat.Rows; i++ {
   111  					for j := 0; j < vMat.Cols; j++ {
   112  						vMatT.Data[j*vMatT.Stride+i] = vMat.Data[i*vMat.Stride+j]
   113  					}
   114  				}
   115  				var comp blas64.General
   116  				if store == lapack.ColumnWise {
   117  					// H = I - V * T * V^T
   118  					tmp := blas64.General{
   119  						Rows:   T.Rows,
   120  						Cols:   vMatT.Cols,
   121  						Stride: vMatT.Cols,
   122  						Data:   make([]float64, T.Rows*vMatT.Cols),
   123  					}
   124  					// T * V^T
   125  					blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, T, vMatT, 0, tmp)
   126  					comp = blas64.General{
   127  						Rows:   vMat.Rows,
   128  						Cols:   tmp.Cols,
   129  						Stride: tmp.Cols,
   130  						Data:   make([]float64, vMat.Rows*tmp.Cols),
   131  					}
   132  					// V * (T * V^T)
   133  					blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, vMat, tmp, 0, comp)
   134  				} else {
   135  					// H = I - V^T * T * V
   136  					tmp := blas64.General{
   137  						Rows:   T.Rows,
   138  						Cols:   vMat.Cols,
   139  						Stride: vMat.Cols,
   140  						Data:   make([]float64, T.Rows*vMat.Cols),
   141  					}
   142  					// T * V
   143  					blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, T, vMat, 0, tmp)
   144  					comp = blas64.General{
   145  						Rows:   vMatT.Rows,
   146  						Cols:   tmp.Cols,
   147  						Stride: tmp.Cols,
   148  						Data:   make([]float64, vMatT.Rows*tmp.Cols),
   149  					}
   150  					// V^T * (T * V)
   151  					blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, vMatT, tmp, 0, comp)
   152  				}
   153  				// I - V^T * T * V
   154  				for i := 0; i < comp.Rows; i++ {
   155  					for j := 0; j < comp.Cols; j++ {
   156  						comp.Data[i*m+j] *= -1
   157  						if i == j {
   158  							comp.Data[i*m+j] += 1
   159  						}
   160  					}
   161  				}
   162  				if !floats.EqualApprox(comp.Data, h.Data, 1e-14) {
   163  					t.Errorf("T does not construct proper H. Store = %v, Direct = %v.\nWant %v\ngot %v.", string(store), string(direct), h.Data, comp.Data)
   164  				}
   165  			}
   166  		}
   167  	}
   168  }