github.com/gonum/lapack@v0.0.0-20181123203213-e4cdc5a0bff9/testlapack/dlasr.go (about)

     1  // Copyright ©2015 The gonum Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package testlapack
     6  
     7  import (
     8  	"math"
     9  	"math/rand"
    10  	"testing"
    11  
    12  	"github.com/gonum/blas"
    13  	"github.com/gonum/blas/blas64"
    14  	"github.com/gonum/floats"
    15  	"github.com/gonum/lapack"
    16  )
    17  
    18  type Dlasrer interface {
    19  	Dlasr(side blas.Side, pivot lapack.Pivot, direct lapack.Direct, m, n int, c, s, a []float64, lda int)
    20  }
    21  
    22  func DlasrTest(t *testing.T, impl Dlasrer) {
    23  	rnd := rand.New(rand.NewSource(1))
    24  	for _, side := range []blas.Side{blas.Left, blas.Right} {
    25  		for _, pivot := range []lapack.Pivot{lapack.Variable, lapack.Top, lapack.Bottom} {
    26  			for _, direct := range []lapack.Direct{lapack.Forward, lapack.Backward} {
    27  				for _, test := range []struct {
    28  					m, n, lda int
    29  				}{
    30  					{5, 5, 0},
    31  					{5, 10, 0},
    32  					{10, 5, 0},
    33  
    34  					{5, 5, 20},
    35  					{5, 10, 20},
    36  					{10, 5, 20},
    37  				} {
    38  					m := test.m
    39  					n := test.n
    40  					lda := test.lda
    41  					if lda == 0 {
    42  						lda = n
    43  					}
    44  					a := make([]float64, m*lda)
    45  					for i := range a {
    46  						a[i] = rnd.Float64()
    47  					}
    48  					var s, c []float64
    49  					if side == blas.Left {
    50  						s = make([]float64, m-1)
    51  						c = make([]float64, m-1)
    52  					} else {
    53  						s = make([]float64, n-1)
    54  						c = make([]float64, n-1)
    55  					}
    56  					for k := range s {
    57  						theta := rnd.Float64() * 2 * math.Pi
    58  						s[k] = math.Sin(theta)
    59  						c[k] = math.Cos(theta)
    60  					}
    61  					aCopy := make([]float64, len(a))
    62  					copy(a, aCopy)
    63  					impl.Dlasr(side, pivot, direct, m, n, c, s, a, lda)
    64  
    65  					pSize := m
    66  					if side == blas.Right {
    67  						pSize = n
    68  					}
    69  					p := blas64.General{
    70  						Rows:   pSize,
    71  						Cols:   pSize,
    72  						Stride: pSize,
    73  						Data:   make([]float64, pSize*pSize),
    74  					}
    75  					pk := blas64.General{
    76  						Rows:   pSize,
    77  						Cols:   pSize,
    78  						Stride: pSize,
    79  						Data:   make([]float64, pSize*pSize),
    80  					}
    81  					ptmp := blas64.General{
    82  						Rows:   pSize,
    83  						Cols:   pSize,
    84  						Stride: pSize,
    85  						Data:   make([]float64, pSize*pSize),
    86  					}
    87  					for i := 0; i < pSize; i++ {
    88  						p.Data[i*p.Stride+i] = 1
    89  						ptmp.Data[i*p.Stride+i] = 1
    90  					}
    91  					// Compare to direct computation.
    92  					for k := range s {
    93  						for i := range p.Data {
    94  							pk.Data[i] = 0
    95  						}
    96  						for i := 0; i < pSize; i++ {
    97  							pk.Data[i*p.Stride+i] = 1
    98  						}
    99  						if pivot == lapack.Variable {
   100  							pk.Data[k*p.Stride+k] = c[k]
   101  							pk.Data[k*p.Stride+k+1] = s[k]
   102  							pk.Data[(k+1)*p.Stride+k] = -s[k]
   103  							pk.Data[(k+1)*p.Stride+k+1] = c[k]
   104  						} else if pivot == lapack.Top {
   105  							pk.Data[0] = c[k]
   106  							pk.Data[k+1] = s[k]
   107  							pk.Data[(k+1)*p.Stride] = -s[k]
   108  							pk.Data[(k+1)*p.Stride+k+1] = c[k]
   109  						} else {
   110  							pk.Data[(pSize-1-k)*p.Stride+pSize-k-1] = c[k]
   111  							pk.Data[(pSize-1-k)*p.Stride+pSize-1] = s[k]
   112  							pk.Data[(pSize-1)*p.Stride+pSize-1-k] = -s[k]
   113  							pk.Data[(pSize-1)*p.Stride+pSize-1] = c[k]
   114  						}
   115  						if direct == lapack.Forward {
   116  							blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, pk, ptmp, 0, p)
   117  						} else {
   118  							blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, ptmp, pk, 0, p)
   119  						}
   120  						copy(ptmp.Data, p.Data)
   121  					}
   122  
   123  					aMat := blas64.General{
   124  						Rows:   m,
   125  						Cols:   n,
   126  						Stride: lda,
   127  						Data:   make([]float64, m*lda),
   128  					}
   129  					copy(a, aCopy)
   130  					newA := blas64.General{
   131  						Rows:   m,
   132  						Cols:   n,
   133  						Stride: lda,
   134  						Data:   make([]float64, m*lda),
   135  					}
   136  					if side == blas.Left {
   137  						blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, p, aMat, 0, newA)
   138  					} else {
   139  						blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, aMat, p, 0, newA)
   140  					}
   141  					if !floats.EqualApprox(newA.Data, a, 1e-12) {
   142  						t.Errorf("A update mismatch")
   143  					}
   144  				}
   145  			}
   146  		}
   147  	}
   148  }