github.com/gonum/lapack@v0.0.0-20181123203213-e4cdc5a0bff9/testlapack/dlarfg.go (about)

     1  // Copyright ©2015 The gonum Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package testlapack
     6  
     7  import (
     8  	"math"
     9  	"math/rand"
    10  	"testing"
    11  
    12  	"github.com/gonum/blas"
    13  	"github.com/gonum/blas/blas64"
    14  )
    15  
    16  type Dlarfger interface {
    17  	Dlarfg(n int, alpha float64, x []float64, incX int) (beta, tau float64)
    18  }
    19  
    20  func DlarfgTest(t *testing.T, impl Dlarfger) {
    21  	rnd := rand.New(rand.NewSource(1))
    22  	for i, test := range []struct {
    23  		alpha float64
    24  		n     int
    25  		x     []float64
    26  	}{
    27  		{
    28  			alpha: 4,
    29  			n:     3,
    30  		},
    31  		{
    32  			alpha: -2,
    33  			n:     3,
    34  		},
    35  		{
    36  			alpha: 0,
    37  			n:     3,
    38  		},
    39  		{
    40  			alpha: 1,
    41  			n:     1,
    42  		},
    43  		{
    44  			alpha: 1,
    45  			n:     2,
    46  			x:     []float64{4, 5, 6},
    47  		},
    48  	} {
    49  		n := test.n
    50  		incX := 1
    51  		var x []float64
    52  		if test.x == nil {
    53  			x = make([]float64, n-1)
    54  			for i := range x {
    55  				x[i] = rnd.Float64()
    56  			}
    57  		} else {
    58  			x = make([]float64, n-1)
    59  			copy(x, test.x)
    60  		}
    61  		xcopy := make([]float64, n-1)
    62  		copy(xcopy, x)
    63  		alpha := test.alpha
    64  		beta, tau := impl.Dlarfg(n, alpha, x, incX)
    65  
    66  		// Verify the returns and the values in v. Construct h and perform
    67  		// the explicit multiplication.
    68  		h := make([]float64, n*n)
    69  		for i := 0; i < n; i++ {
    70  			h[i*n+i] = 1
    71  		}
    72  		hmat := blas64.General{
    73  			Rows:   n,
    74  			Cols:   n,
    75  			Stride: n,
    76  			Data:   h,
    77  		}
    78  		v := make([]float64, n)
    79  		copy(v[1:], x)
    80  		v[0] = 1
    81  		vVec := blas64.Vector{
    82  			Inc:  1,
    83  			Data: v,
    84  		}
    85  		blas64.Ger(-tau, vVec, vVec, hmat)
    86  		eye := blas64.General{
    87  			Rows:   n,
    88  			Cols:   n,
    89  			Stride: n,
    90  			Data:   make([]float64, n*n),
    91  		}
    92  		blas64.Gemm(blas.Trans, blas.NoTrans, 1, hmat, hmat, 0, eye)
    93  		iseye := true
    94  		for i := 0; i < n; i++ {
    95  			for j := 0; j < n; j++ {
    96  				if i == j {
    97  					if math.Abs(eye.Data[i*n+j]-1) > 1e-14 {
    98  						iseye = false
    99  					}
   100  				} else {
   101  					if math.Abs(eye.Data[i*n+j]) > 1e-14 {
   102  						iseye = false
   103  					}
   104  				}
   105  			}
   106  		}
   107  		if !iseye {
   108  			t.Errorf("H^T * H is not I %v", eye)
   109  		}
   110  
   111  		xVec := blas64.Vector{
   112  			Inc:  1,
   113  			Data: make([]float64, n),
   114  		}
   115  		xVec.Data[0] = test.alpha
   116  		copy(xVec.Data[1:], xcopy)
   117  
   118  		ans := make([]float64, n)
   119  		ansVec := blas64.Vector{
   120  			Inc:  1,
   121  			Data: ans,
   122  		}
   123  		blas64.Gemv(blas.NoTrans, 1, hmat, xVec, 0, ansVec)
   124  		if math.Abs(ans[0]-beta) > 1e-14 {
   125  			t.Errorf("Case %v, beta mismatch. Want %v, got %v", i, ans[0], beta)
   126  		}
   127  		for i := 1; i < n; i++ {
   128  			if math.Abs(ans[i]) > 1e-14 {
   129  				t.Errorf("Case %v, nonzero answer %v", i, ans)
   130  				break
   131  			}
   132  		}
   133  	}
   134  }