gonum.org/v1/gonum@v0.14.0/lapack/testlapack/dgebrd.go (about)

     1  // Copyright ©2015 The Gonum Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package testlapack
     6  
     7  import (
     8  	"math"
     9  	"testing"
    10  
    11  	"golang.org/x/exp/rand"
    12  
    13  	"gonum.org/v1/gonum/floats"
    14  )
    15  
    16  type Dgebrder interface {
    17  	Dgebrd(m, n int, a []float64, lda int, d, e, tauQ, tauP, work []float64, lwork int)
    18  	Dgebd2er
    19  }
    20  
    21  func DgebrdTest(t *testing.T, impl Dgebrder) {
    22  	rnd := rand.New(rand.NewSource(1))
    23  	for _, test := range []struct {
    24  		m, n, lda int
    25  	}{
    26  		{100, 100, 0},
    27  		{100, 150, 0},
    28  		{150, 100, 0},
    29  		{100, 100, 200},
    30  		{100, 150, 200},
    31  		{150, 100, 200},
    32  
    33  		{300, 300, 0},
    34  		{300, 400, 0},
    35  		{400, 300, 0},
    36  		{300, 300, 500},
    37  		{300, 400, 500},
    38  		{300, 400, 500},
    39  	} {
    40  		m := test.m
    41  		n := test.n
    42  		lda := test.lda
    43  		if lda == 0 {
    44  			lda = n
    45  		}
    46  		minmn := min(m, n)
    47  		a := make([]float64, m*lda)
    48  		for i := range a {
    49  			a[i] = rnd.NormFloat64()
    50  		}
    51  
    52  		d := make([]float64, minmn)
    53  		e := make([]float64, minmn-1)
    54  		tauP := make([]float64, minmn)
    55  		tauQ := make([]float64, minmn)
    56  		work := make([]float64, max(m, n))
    57  		for i := range work {
    58  			work[i] = math.NaN()
    59  		}
    60  
    61  		// Store a.
    62  		aCopy := make([]float64, len(a))
    63  		copy(aCopy, a)
    64  
    65  		// Compute the true answer with the unblocked algorithm.
    66  		impl.Dgebd2(m, n, a, lda, d, e, tauQ, tauP, work)
    67  		aAns := make([]float64, len(a))
    68  		copy(aAns, a)
    69  		dAns := make([]float64, len(d))
    70  		copy(dAns, d)
    71  		eAns := make([]float64, len(e))
    72  		copy(eAns, e)
    73  		tauQAns := make([]float64, len(tauQ))
    74  		copy(tauQAns, tauQ)
    75  		tauPAns := make([]float64, len(tauP))
    76  		copy(tauPAns, tauP)
    77  
    78  		// Test with optimal work.
    79  		lwork := -1
    80  		copy(a, aCopy)
    81  		impl.Dgebrd(m, n, a, lda, d, e, tauQ, tauP, work, lwork)
    82  		work = make([]float64, int(work[0]))
    83  		lwork = len(work)
    84  		for i := range work {
    85  			work[i] = math.NaN()
    86  		}
    87  		for i := range d {
    88  			d[i] = math.NaN()
    89  		}
    90  		for i := range e {
    91  			e[i] = math.NaN()
    92  		}
    93  		for i := range tauQ {
    94  			tauQ[i] = math.NaN()
    95  		}
    96  		for i := range tauP {
    97  			tauP[i] = math.NaN()
    98  		}
    99  		impl.Dgebrd(m, n, a, lda, d, e, tauQ, tauP, work, lwork)
   100  
   101  		// Test answers
   102  		if !floats.EqualApprox(a, aAns, 1e-10) {
   103  			t.Errorf("a mismatch")
   104  		}
   105  		if !floats.EqualApprox(d, dAns, 1e-10) {
   106  			t.Errorf("d mismatch")
   107  		}
   108  		if !floats.EqualApprox(e, eAns, 1e-10) {
   109  			t.Errorf("e mismatch")
   110  		}
   111  		if !floats.EqualApprox(tauQ, tauQAns, 1e-10) {
   112  			t.Errorf("tauQ mismatch")
   113  		}
   114  		if !floats.EqualApprox(tauP, tauPAns, 1e-10) {
   115  			t.Errorf("tauP mismatch")
   116  		}
   117  
   118  		// Test with shorter than optimal work.
   119  		lwork--
   120  		copy(a, aCopy)
   121  		for i := range d {
   122  			d[i] = 0
   123  		}
   124  		for i := range e {
   125  			e[i] = 0
   126  		}
   127  		for i := range tauP {
   128  			tauP[i] = 0
   129  		}
   130  		for i := range tauQ {
   131  			tauQ[i] = 0
   132  		}
   133  		impl.Dgebrd(m, n, a, lda, d, e, tauQ, tauP, work, lwork)
   134  
   135  		// Test answers
   136  		if !floats.EqualApprox(a, aAns, 1e-10) {
   137  			t.Errorf("a mismatch")
   138  		}
   139  		if !floats.EqualApprox(d, dAns, 1e-10) {
   140  			t.Errorf("d mismatch")
   141  		}
   142  		if !floats.EqualApprox(e, eAns, 1e-10) {
   143  			t.Errorf("e mismatch")
   144  		}
   145  		if !floats.EqualApprox(tauQ, tauQAns, 1e-10) {
   146  			t.Errorf("tauQ mismatch")
   147  		}
   148  		if !floats.EqualApprox(tauP, tauPAns, 1e-10) {
   149  			t.Errorf("tauP mismatch")
   150  		}
   151  	}
   152  }