github.com/gonum/lapack@v0.0.0-20181123203213-e4cdc5a0bff9/testlapack/dormbr.go (about)

     1  // Copyright ©2015 The gonum Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package testlapack
     6  
     7  import (
     8  	"math/rand"
     9  	"testing"
    10  
    11  	"github.com/gonum/blas"
    12  	"github.com/gonum/blas/blas64"
    13  	"github.com/gonum/floats"
    14  	"github.com/gonum/lapack"
    15  )
    16  
    17  type Dormbrer interface {
    18  	Dormbr(vect lapack.DecompUpdate, side blas.Side, trans blas.Transpose, m, n, k int, a []float64, lda int, tau, c []float64, ldc int, work []float64, lwork int)
    19  	Dgebrder
    20  }
    21  
    22  func DormbrTest(t *testing.T, impl Dormbrer) {
    23  	rnd := rand.New(rand.NewSource(1))
    24  	bi := blas64.Implementation()
    25  	for _, vect := range []lapack.DecompUpdate{lapack.ApplyQ, lapack.ApplyP} {
    26  		for _, side := range []blas.Side{blas.Left, blas.Right} {
    27  			for _, trans := range []blas.Transpose{blas.NoTrans, blas.Trans} {
    28  				for _, wl := range []worklen{minimumWork, mediumWork, optimumWork} {
    29  					for _, test := range []struct {
    30  						m, n, k, lda, ldc int
    31  					}{
    32  						{3, 4, 5, 0, 0},
    33  						{3, 5, 4, 0, 0},
    34  						{4, 3, 5, 0, 0},
    35  						{4, 5, 3, 0, 0},
    36  						{5, 3, 4, 0, 0},
    37  						{5, 4, 3, 0, 0},
    38  
    39  						{3, 4, 5, 10, 12},
    40  						{3, 5, 4, 10, 12},
    41  						{4, 3, 5, 10, 12},
    42  						{4, 5, 3, 10, 12},
    43  						{5, 3, 4, 10, 12},
    44  						{5, 4, 3, 10, 12},
    45  
    46  						{150, 140, 130, 0, 0},
    47  					} {
    48  						m := test.m
    49  						n := test.n
    50  						k := test.k
    51  						ldc := test.ldc
    52  						if ldc == 0 {
    53  							ldc = n
    54  						}
    55  						nq := n
    56  						nw := m
    57  						if side == blas.Left {
    58  							nq = m
    59  							nw = n
    60  						}
    61  
    62  						// Compute a decomposition.
    63  						var ma, na int
    64  						var a []float64
    65  						if vect == lapack.ApplyQ {
    66  							ma = nq
    67  							na = k
    68  						} else {
    69  							ma = k
    70  							na = nq
    71  						}
    72  						lda := test.lda
    73  						if lda == 0 {
    74  							lda = na
    75  						}
    76  						a = make([]float64, ma*lda)
    77  						for i := range a {
    78  							a[i] = rnd.NormFloat64()
    79  						}
    80  						nTau := min(nq, k)
    81  						tauP := make([]float64, nTau)
    82  						tauQ := make([]float64, nTau)
    83  						d := make([]float64, nTau)
    84  						e := make([]float64, nTau)
    85  
    86  						work := make([]float64, 1)
    87  						impl.Dgebrd(ma, na, a, lda, d, e, tauQ, tauP, work, -1)
    88  						work = make([]float64, int(work[0]))
    89  						impl.Dgebrd(ma, na, a, lda, d, e, tauQ, tauP, work, len(work))
    90  
    91  						// Apply and compare update.
    92  						c := make([]float64, m*ldc)
    93  						for i := range c {
    94  							c[i] = rnd.NormFloat64()
    95  						}
    96  						cCopy := make([]float64, len(c))
    97  						copy(cCopy, c)
    98  
    99  						var lwork int
   100  						switch wl {
   101  						case minimumWork:
   102  							lwork = nw
   103  						case optimumWork:
   104  							impl.Dormbr(vect, side, trans, m, n, k, a, lda, tauQ, c, ldc, work, -1)
   105  							lwork = int(work[0])
   106  						case mediumWork:
   107  							work := make([]float64, 1)
   108  							impl.Dormbr(vect, side, trans, m, n, k, a, lda, tauQ, c, ldc, work, -1)
   109  							lwork = (int(work[0]) + nw) / 2
   110  						}
   111  						lwork = max(1, lwork)
   112  						work = make([]float64, lwork)
   113  
   114  						if vect == lapack.ApplyQ {
   115  							impl.Dormbr(vect, side, trans, m, n, k, a, lda, tauQ, c, ldc, work, lwork)
   116  						} else {
   117  							impl.Dormbr(vect, side, trans, m, n, k, a, lda, tauP, c, ldc, work, lwork)
   118  						}
   119  
   120  						// Check that the multiplication was correct.
   121  						cOrig := blas64.General{
   122  							Rows:   m,
   123  							Cols:   n,
   124  							Stride: ldc,
   125  							Data:   make([]float64, len(cCopy)),
   126  						}
   127  						copy(cOrig.Data, cCopy)
   128  						cAns := blas64.General{
   129  							Rows:   m,
   130  							Cols:   n,
   131  							Stride: ldc,
   132  							Data:   make([]float64, len(cCopy)),
   133  						}
   134  						copy(cAns.Data, cCopy)
   135  						nb := min(ma, na)
   136  						var mulMat blas64.General
   137  						if vect == lapack.ApplyQ {
   138  							mulMat = constructQPBidiagonal(lapack.ApplyQ, ma, na, nb, a, lda, tauQ)
   139  						} else {
   140  							mulMat = constructQPBidiagonal(lapack.ApplyP, ma, na, nb, a, lda, tauP)
   141  						}
   142  
   143  						mulTrans := trans
   144  
   145  						if side == blas.Left {
   146  							bi.Dgemm(mulTrans, blas.NoTrans, m, n, m, 1, mulMat.Data, mulMat.Stride, cOrig.Data, cOrig.Stride, 0, cAns.Data, cAns.Stride)
   147  						} else {
   148  							bi.Dgemm(blas.NoTrans, mulTrans, m, n, n, 1, cOrig.Data, cOrig.Stride, mulMat.Data, mulMat.Stride, 0, cAns.Data, cAns.Stride)
   149  						}
   150  
   151  						if !floats.EqualApprox(cAns.Data, c, 1e-13) {
   152  							isApplyQ := vect == lapack.ApplyQ
   153  							isLeft := side == blas.Left
   154  							isTrans := trans == blas.Trans
   155  
   156  							t.Errorf("C mismatch. isApplyQ: %v, isLeft: %v, isTrans: %v, m = %v, n = %v, k = %v, lda = %v, ldc = %v",
   157  								isApplyQ, isLeft, isTrans, m, n, k, lda, ldc)
   158  						}
   159  					}
   160  				}
   161  			}
   162  		}
   163  	}
   164  }