gonum.org/v1/gonum@v0.14.0/lapack/testlapack/dorgbr.go (about)

     1  // Copyright ©2015 The Gonum Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package testlapack
     6  
     7  import (
     8  	"testing"
     9  
    10  	"golang.org/x/exp/rand"
    11  
    12  	"gonum.org/v1/gonum/blas/blas64"
    13  	"gonum.org/v1/gonum/floats/scalar"
    14  	"gonum.org/v1/gonum/lapack"
    15  )
    16  
    17  type Dorgbrer interface {
    18  	Dorgbr(vect lapack.GenOrtho, m, n, k int, a []float64, lda int, tau, work []float64, lwork int)
    19  	Dgebrder
    20  }
    21  
    22  func DorgbrTest(t *testing.T, impl Dorgbrer) {
    23  	rnd := rand.New(rand.NewSource(1))
    24  	for _, vect := range []lapack.GenOrtho{lapack.GenerateQ, lapack.GeneratePT} {
    25  		for _, test := range []struct {
    26  			m, n, k, lda int
    27  		}{
    28  			{5, 5, 5, 0},
    29  			{5, 5, 3, 0},
    30  			{5, 3, 5, 0},
    31  			{3, 5, 5, 0},
    32  			{3, 4, 5, 0},
    33  			{3, 5, 4, 0},
    34  			{4, 3, 5, 0},
    35  			{4, 5, 3, 0},
    36  			{5, 3, 4, 0},
    37  			{5, 4, 3, 0},
    38  
    39  			{5, 5, 5, 10},
    40  			{5, 5, 3, 10},
    41  			{5, 3, 5, 10},
    42  			{3, 5, 5, 10},
    43  			{3, 4, 5, 10},
    44  			{3, 5, 4, 10},
    45  			{4, 3, 5, 10},
    46  			{4, 5, 3, 10},
    47  			{5, 3, 4, 10},
    48  			{5, 4, 3, 10},
    49  		} {
    50  			m := test.m
    51  			n := test.n
    52  			k := test.k
    53  			lda := test.lda
    54  			// Filter out bad tests
    55  			if vect == lapack.GenerateQ {
    56  				if m < n || n < min(m, k) || m < min(m, k) {
    57  					continue
    58  				}
    59  			} else {
    60  				if n < m || m < min(n, k) || n < min(n, k) {
    61  					continue
    62  				}
    63  			}
    64  			// Sizes for Dorgbr.
    65  			var ma, na int
    66  			if vect == lapack.GenerateQ {
    67  				if m >= k {
    68  					ma = m
    69  					na = k
    70  				} else {
    71  					ma = m
    72  					na = m
    73  				}
    74  			} else {
    75  				if n >= k {
    76  					ma = k
    77  					na = n
    78  				} else {
    79  					ma = n
    80  					na = n
    81  				}
    82  			}
    83  			// a eventually needs to store either P or Q, so it must be
    84  			// sufficiently big.
    85  			var a []float64
    86  			if vect == lapack.GenerateQ {
    87  				lda = max(m, lda)
    88  				a = make([]float64, m*lda)
    89  			} else {
    90  				lda = max(n, lda)
    91  				a = make([]float64, n*lda)
    92  			}
    93  			for i := range a {
    94  				a[i] = rnd.NormFloat64()
    95  			}
    96  
    97  			nTau := min(ma, na)
    98  			tauP := make([]float64, nTau)
    99  			tauQ := make([]float64, nTau)
   100  			d := make([]float64, nTau)
   101  			e := make([]float64, nTau)
   102  			lwork := -1
   103  			work := make([]float64, 1)
   104  			impl.Dgebrd(ma, na, a, lda, d, e, tauQ, tauP, work, lwork)
   105  			work = make([]float64, int(work[0]))
   106  			lwork = len(work)
   107  			impl.Dgebrd(ma, na, a, lda, d, e, tauQ, tauP, work, lwork)
   108  
   109  			aCopy := make([]float64, len(a))
   110  			copy(aCopy, a)
   111  
   112  			var tau []float64
   113  			if vect == lapack.GenerateQ {
   114  				tau = tauQ
   115  			} else {
   116  				tau = tauP
   117  			}
   118  
   119  			impl.Dorgbr(vect, m, n, k, a, lda, tau, work, -1)
   120  			work = make([]float64, int(work[0]))
   121  			lwork = len(work)
   122  			impl.Dorgbr(vect, m, n, k, a, lda, tau, work, lwork)
   123  
   124  			var ans blas64.General
   125  			var nRows, nCols int
   126  			equal := true
   127  			if vect == lapack.GenerateQ {
   128  				nRows = m
   129  				nCols = m
   130  				if m >= k {
   131  					nCols = n
   132  				}
   133  				ans = constructQPBidiagonal(lapack.ApplyQ, ma, na, min(m, k), aCopy, lda, tau)
   134  			} else {
   135  				nRows = n
   136  				if k < n {
   137  					nRows = m
   138  				}
   139  				nCols = n
   140  				ansTmp := constructQPBidiagonal(lapack.ApplyP, ma, na, min(k, n), aCopy, lda, tau)
   141  				// Dorgbr actually computes Pᵀ
   142  				ans = transposeGeneral(ansTmp)
   143  			}
   144  			for i := 0; i < nRows; i++ {
   145  				for j := 0; j < nCols; j++ {
   146  					if !scalar.EqualWithinAbsOrRel(a[i*lda+j], ans.Data[i*ans.Stride+j], 1e-8, 1e-8) {
   147  						equal = false
   148  					}
   149  				}
   150  			}
   151  			if !equal {
   152  				t.Errorf("Extracted matrix mismatch. gen = %v, m = %v, n = %v, k = %v", string(vect), m, n, k)
   153  			}
   154  		}
   155  	}
   156  }