github.com/gonum/lapack@v0.0.0-20181123203213-e4cdc5a0bff9/testlapack/dlarf.go (about)

     1  // Copyright ©2015 The gonum Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package testlapack
     6  
     7  import (
     8  	"math/rand"
     9  	"testing"
    10  
    11  	"github.com/gonum/blas"
    12  	"github.com/gonum/blas/blas64"
    13  	"github.com/gonum/floats"
    14  )
    15  
    16  type Dlarfer interface {
    17  	Dlarf(side blas.Side, m, n int, v []float64, incv int, tau float64, c []float64, ldc int, work []float64)
    18  }
    19  
    20  func DlarfTest(t *testing.T, impl Dlarfer) {
    21  	rnd := rand.New(rand.NewSource(1))
    22  	for i, test := range []struct {
    23  		m, n, ldc    int
    24  		incv, lastv  int
    25  		lastr, lastc int
    26  		tau          float64
    27  	}{
    28  		{
    29  			m:   3,
    30  			n:   2,
    31  			ldc: 2,
    32  
    33  			incv:  4,
    34  			lastv: 1,
    35  
    36  			lastr: 2,
    37  			lastc: 1,
    38  
    39  			tau: 2,
    40  		},
    41  		{
    42  			m:   2,
    43  			n:   3,
    44  			ldc: 3,
    45  
    46  			incv:  4,
    47  			lastv: 1,
    48  
    49  			lastr: 1,
    50  			lastc: 2,
    51  
    52  			tau: 2,
    53  		},
    54  		{
    55  			m:   2,
    56  			n:   3,
    57  			ldc: 3,
    58  
    59  			incv:  4,
    60  			lastv: 1,
    61  
    62  			lastr: 0,
    63  			lastc: 1,
    64  
    65  			tau: 2,
    66  		},
    67  		{
    68  			m:   2,
    69  			n:   3,
    70  			ldc: 3,
    71  
    72  			incv:  4,
    73  			lastv: 0,
    74  
    75  			lastr: 0,
    76  			lastc: 1,
    77  
    78  			tau: 2,
    79  		},
    80  		{
    81  			m:   10,
    82  			n:   10,
    83  			ldc: 10,
    84  
    85  			incv:  4,
    86  			lastv: 6,
    87  
    88  			lastr: 9,
    89  			lastc: 8,
    90  
    91  			tau: 2,
    92  		},
    93  	} {
    94  		// Construct a random matrix.
    95  		c := make([]float64, test.ldc*test.m)
    96  		for i := 0; i <= test.lastr; i++ {
    97  			for j := 0; j <= test.lastc; j++ {
    98  				c[i*test.ldc+j] = rnd.Float64()
    99  			}
   100  		}
   101  		cCopy := make([]float64, len(c))
   102  		copy(cCopy, c)
   103  		cCopy2 := make([]float64, len(c))
   104  		copy(cCopy2, c)
   105  
   106  		// Test with side right.
   107  		sz := max(test.m, test.n) // so v works for both right and left side.
   108  		v := make([]float64, test.incv*sz+1)
   109  		// Fill with nonzero entries up until lastv.
   110  		for i := 0; i <= test.lastv; i++ {
   111  			v[i*test.incv] = rnd.Float64()
   112  		}
   113  		// Construct h explicitly to compare.
   114  		h := make([]float64, test.n*test.n)
   115  		for i := 0; i < test.n; i++ {
   116  			h[i*test.n+i] = 1
   117  		}
   118  		hMat := blas64.General{
   119  			Rows:   test.n,
   120  			Cols:   test.n,
   121  			Stride: test.n,
   122  			Data:   h,
   123  		}
   124  		vVec := blas64.Vector{
   125  			Inc:  test.incv,
   126  			Data: v,
   127  		}
   128  		blas64.Ger(-test.tau, vVec, vVec, hMat)
   129  
   130  		// Apply multiplication (2nd copy is to avoid aliasing).
   131  		cMat := blas64.General{
   132  			Rows:   test.m,
   133  			Cols:   test.n,
   134  			Stride: test.ldc,
   135  			Data:   cCopy,
   136  		}
   137  		cMat2 := blas64.General{
   138  			Rows:   test.m,
   139  			Cols:   test.n,
   140  			Stride: test.ldc,
   141  			Data:   cCopy2,
   142  		}
   143  		blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, cMat2, hMat, 0, cMat)
   144  
   145  		// cMat now stores the true answer. Compare with the function call.
   146  		work := make([]float64, sz)
   147  		impl.Dlarf(blas.Right, test.m, test.n, v, test.incv, test.tau, c, test.ldc, work)
   148  		if !floats.EqualApprox(c, cMat.Data, 1e-14) {
   149  			t.Errorf("Dlarf mismatch right, case %v. Want %v, got %v", i, cMat.Data, c)
   150  		}
   151  
   152  		// Test on the left side.
   153  		copy(c, cCopy2)
   154  		copy(cCopy, c)
   155  		// Construct h.
   156  		h = make([]float64, test.m*test.m)
   157  		for i := 0; i < test.m; i++ {
   158  			h[i*test.m+i] = 1
   159  		}
   160  		hMat = blas64.General{
   161  			Rows:   test.m,
   162  			Cols:   test.m,
   163  			Stride: test.m,
   164  			Data:   h,
   165  		}
   166  		blas64.Ger(-test.tau, vVec, vVec, hMat)
   167  		blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, hMat, cMat2, 0, cMat)
   168  		impl.Dlarf(blas.Left, test.m, test.n, v, test.incv, test.tau, c, test.ldc, work)
   169  		if !floats.EqualApprox(c, cMat.Data, 1e-14) {
   170  			t.Errorf("Dlarf mismatch left, case %v. Want %v, got %v", i, cMat.Data, c)
   171  		}
   172  	}
   173  }