gonum.org/v1/gonum@v0.14.0/lapack/testlapack/dlarfg.go (about)

     1  // Copyright ©2015 The Gonum Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package testlapack
     6  
     7  import (
     8  	"math"
     9  	"testing"
    10  
    11  	"golang.org/x/exp/rand"
    12  
    13  	"gonum.org/v1/gonum/blas"
    14  	"gonum.org/v1/gonum/blas/blas64"
    15  	"gonum.org/v1/gonum/floats"
    16  )
    17  
    18  type Dlarfger interface {
    19  	Dlarfg(n int, alpha float64, x []float64, incX int) (beta, tau float64)
    20  }
    21  
    22  func DlarfgTest(t *testing.T, impl Dlarfger) {
    23  	const tol = 1e-14
    24  	rnd := rand.New(rand.NewSource(1))
    25  	for i, test := range []struct {
    26  		alpha float64
    27  		n     int
    28  		x     []float64
    29  	}{
    30  		{
    31  			alpha: 4,
    32  			n:     3,
    33  		},
    34  		{
    35  			alpha: -2,
    36  			n:     3,
    37  		},
    38  		{
    39  			alpha: 0,
    40  			n:     3,
    41  		},
    42  		{
    43  			alpha: 1,
    44  			n:     1,
    45  		},
    46  		{
    47  			alpha: 1,
    48  			n:     4,
    49  			x:     []float64{4, 5, 6},
    50  		},
    51  		{
    52  			alpha: 1,
    53  			n:     4,
    54  			x:     []float64{0, 0, 0},
    55  		},
    56  		{
    57  			alpha: dlamchS,
    58  			n:     4,
    59  			x:     []float64{dlamchS, dlamchS, dlamchS},
    60  		},
    61  	} {
    62  		n := test.n
    63  		incX := 1
    64  		var x []float64
    65  		if test.x == nil {
    66  			x = make([]float64, n-1)
    67  			for i := range x {
    68  				x[i] = rnd.Float64()
    69  			}
    70  		} else {
    71  			if len(test.x) != n-1 {
    72  				panic("bad test")
    73  			}
    74  			x = make([]float64, n-1)
    75  			copy(x, test.x)
    76  		}
    77  		xcopy := make([]float64, n-1)
    78  		copy(xcopy, x)
    79  		alpha := test.alpha
    80  		beta, tau := impl.Dlarfg(n, alpha, x, incX)
    81  
    82  		// Verify the returns and the values in v. Construct h and perform
    83  		// the explicit multiplication.
    84  		h := make([]float64, n*n)
    85  		for i := 0; i < n; i++ {
    86  			h[i*n+i] = 1
    87  		}
    88  		hmat := blas64.General{
    89  			Rows:   n,
    90  			Cols:   n,
    91  			Stride: n,
    92  			Data:   h,
    93  		}
    94  		v := make([]float64, n)
    95  		copy(v[1:], x)
    96  		v[0] = 1
    97  		vVec := blas64.Vector{
    98  			Inc:  1,
    99  			Data: v,
   100  		}
   101  		blas64.Ger(-tau, vVec, vVec, hmat)
   102  		eye := blas64.General{
   103  			Rows:   n,
   104  			Cols:   n,
   105  			Stride: n,
   106  			Data:   make([]float64, n*n),
   107  		}
   108  		blas64.Gemm(blas.Trans, blas.NoTrans, 1, hmat, hmat, 0, eye)
   109  		dist := distFromIdentity(n, eye.Data, n)
   110  		if dist > tol {
   111  			t.Errorf("Hᵀ * H is not close to I, dist=%v", dist)
   112  		}
   113  
   114  		xVec := blas64.Vector{
   115  			Inc:  1,
   116  			Data: make([]float64, n),
   117  		}
   118  		xVec.Data[0] = test.alpha
   119  		copy(xVec.Data[1:], xcopy)
   120  
   121  		ans := make([]float64, n)
   122  		ansVec := blas64.Vector{
   123  			Inc:  1,
   124  			Data: ans,
   125  		}
   126  		blas64.Gemv(blas.NoTrans, 1, hmat, xVec, 0, ansVec)
   127  		if math.Abs(ans[0]-beta) > tol {
   128  			t.Errorf("Case %v, beta mismatch. Want %v, got %v", i, ans[0], beta)
   129  		}
   130  		if floats.Norm(ans[1:n], math.Inf(1)) > tol {
   131  			t.Errorf("Case %v, nonzero answer %v", i, ans[1:n])
   132  		}
   133  	}
   134  }