github.com/jingcheng-WU/gonum@v0.9.1-0.20210323123734-f1a2a11a8f7b/lapack/testlapack/dormbr.go (about)

     1  // Copyright ©2015 The Gonum Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package testlapack
     6  
     7  import (
     8  	"testing"
     9  
    10  	"golang.org/x/exp/rand"
    11  
    12  	"github.com/jingcheng-WU/gonum/blas"
    13  	"github.com/jingcheng-WU/gonum/blas/blas64"
    14  	"github.com/jingcheng-WU/gonum/floats"
    15  	"github.com/jingcheng-WU/gonum/lapack"
    16  )
    17  
    18  type Dormbrer interface {
    19  	Dormbr(vect lapack.ApplyOrtho, side blas.Side, trans blas.Transpose, m, n, k int, a []float64, lda int, tau, c []float64, ldc int, work []float64, lwork int)
    20  	Dgebrder
    21  }
    22  
    23  func DormbrTest(t *testing.T, impl Dormbrer) {
    24  	rnd := rand.New(rand.NewSource(1))
    25  	bi := blas64.Implementation()
    26  	for _, vect := range []lapack.ApplyOrtho{lapack.ApplyQ, lapack.ApplyP} {
    27  		for _, side := range []blas.Side{blas.Left, blas.Right} {
    28  			for _, trans := range []blas.Transpose{blas.NoTrans, blas.Trans} {
    29  				for _, wl := range []worklen{minimumWork, mediumWork, optimumWork} {
    30  					for _, test := range []struct {
    31  						m, n, k, lda, ldc int
    32  					}{
    33  						{3, 4, 5, 0, 0},
    34  						{3, 5, 4, 0, 0},
    35  						{4, 3, 5, 0, 0},
    36  						{4, 5, 3, 0, 0},
    37  						{5, 3, 4, 0, 0},
    38  						{5, 4, 3, 0, 0},
    39  
    40  						{3, 4, 5, 10, 12},
    41  						{3, 5, 4, 10, 12},
    42  						{4, 3, 5, 10, 12},
    43  						{4, 5, 3, 10, 12},
    44  						{5, 3, 4, 10, 12},
    45  						{5, 4, 3, 10, 12},
    46  
    47  						{150, 140, 130, 0, 0},
    48  					} {
    49  						m := test.m
    50  						n := test.n
    51  						k := test.k
    52  						ldc := test.ldc
    53  						if ldc == 0 {
    54  							ldc = n
    55  						}
    56  						nq := n
    57  						nw := m
    58  						if side == blas.Left {
    59  							nq = m
    60  							nw = n
    61  						}
    62  
    63  						// Compute a decomposition.
    64  						var ma, na int
    65  						var a []float64
    66  						if vect == lapack.ApplyQ {
    67  							ma = nq
    68  							na = k
    69  						} else {
    70  							ma = k
    71  							na = nq
    72  						}
    73  						lda := test.lda
    74  						if lda == 0 {
    75  							lda = na
    76  						}
    77  						a = make([]float64, ma*lda)
    78  						for i := range a {
    79  							a[i] = rnd.NormFloat64()
    80  						}
    81  						nTau := min(nq, k)
    82  						tauP := make([]float64, nTau)
    83  						tauQ := make([]float64, nTau)
    84  						d := make([]float64, nTau)
    85  						e := make([]float64, nTau)
    86  
    87  						work := make([]float64, 1)
    88  						impl.Dgebrd(ma, na, a, lda, d, e, tauQ, tauP, work, -1)
    89  						work = make([]float64, int(work[0]))
    90  						impl.Dgebrd(ma, na, a, lda, d, e, tauQ, tauP, work, len(work))
    91  
    92  						// Apply and compare update.
    93  						c := make([]float64, m*ldc)
    94  						for i := range c {
    95  							c[i] = rnd.NormFloat64()
    96  						}
    97  						cCopy := make([]float64, len(c))
    98  						copy(cCopy, c)
    99  
   100  						var lwork int
   101  						switch wl {
   102  						case minimumWork:
   103  							lwork = nw
   104  						case optimumWork:
   105  							impl.Dormbr(vect, side, trans, m, n, k, a, lda, tauQ, c, ldc, work, -1)
   106  							lwork = int(work[0])
   107  						case mediumWork:
   108  							work := make([]float64, 1)
   109  							impl.Dormbr(vect, side, trans, m, n, k, a, lda, tauQ, c, ldc, work, -1)
   110  							lwork = (int(work[0]) + nw) / 2
   111  						}
   112  						lwork = max(1, lwork)
   113  						work = make([]float64, lwork)
   114  
   115  						if vect == lapack.ApplyQ {
   116  							impl.Dormbr(vect, side, trans, m, n, k, a, lda, tauQ, c, ldc, work, lwork)
   117  						} else {
   118  							impl.Dormbr(vect, side, trans, m, n, k, a, lda, tauP, c, ldc, work, lwork)
   119  						}
   120  
   121  						// Check that the multiplication was correct.
   122  						cOrig := blas64.General{
   123  							Rows:   m,
   124  							Cols:   n,
   125  							Stride: ldc,
   126  							Data:   make([]float64, len(cCopy)),
   127  						}
   128  						copy(cOrig.Data, cCopy)
   129  						cAns := blas64.General{
   130  							Rows:   m,
   131  							Cols:   n,
   132  							Stride: ldc,
   133  							Data:   make([]float64, len(cCopy)),
   134  						}
   135  						copy(cAns.Data, cCopy)
   136  						nb := min(ma, na)
   137  						var mulMat blas64.General
   138  						if vect == lapack.ApplyQ {
   139  							mulMat = constructQPBidiagonal(lapack.ApplyQ, ma, na, nb, a, lda, tauQ)
   140  						} else {
   141  							mulMat = constructQPBidiagonal(lapack.ApplyP, ma, na, nb, a, lda, tauP)
   142  						}
   143  
   144  						mulTrans := trans
   145  
   146  						if side == blas.Left {
   147  							bi.Dgemm(mulTrans, blas.NoTrans, m, n, m, 1, mulMat.Data, mulMat.Stride, cOrig.Data, cOrig.Stride, 0, cAns.Data, cAns.Stride)
   148  						} else {
   149  							bi.Dgemm(blas.NoTrans, mulTrans, m, n, n, 1, cOrig.Data, cOrig.Stride, mulMat.Data, mulMat.Stride, 0, cAns.Data, cAns.Stride)
   150  						}
   151  
   152  						if !floats.EqualApprox(cAns.Data, c, 1e-13) {
   153  							isApplyQ := vect == lapack.ApplyQ
   154  							isLeft := side == blas.Left
   155  							isTrans := trans == blas.Trans
   156  
   157  							t.Errorf("C mismatch. isApplyQ: %v, isLeft: %v, isTrans: %v, m = %v, n = %v, k = %v, lda = %v, ldc = %v",
   158  								isApplyQ, isLeft, isTrans, m, n, k, lda, ldc)
   159  						}
   160  					}
   161  				}
   162  			}
   163  		}
   164  	}
   165  }