github.com/gonum/lapack@v0.0.0-20181123203213-e4cdc5a0bff9/testlapack/dgebrd.go (about)

     1  // Copyright ©2015 The gonum Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package testlapack
     6  
     7  import (
     8  	"math"
     9  	"math/rand"
    10  	"testing"
    11  
    12  	"github.com/gonum/floats"
    13  )
    14  
    15  type Dgebrder interface {
    16  	Dgebrd(m, n int, a []float64, lda int, d, e, tauQ, tauP, work []float64, lwork int)
    17  	Dgebd2er
    18  }
    19  
    20  func DgebrdTest(t *testing.T, impl Dgebrder) {
    21  	rnd := rand.New(rand.NewSource(1))
    22  	for _, test := range []struct {
    23  		m, n, lda int
    24  	}{
    25  		{100, 100, 0},
    26  		{100, 150, 0},
    27  		{150, 100, 0},
    28  		{100, 100, 200},
    29  		{100, 150, 200},
    30  		{150, 100, 200},
    31  
    32  		{300, 300, 0},
    33  		{300, 400, 0},
    34  		{400, 300, 0},
    35  		{300, 300, 500},
    36  		{300, 400, 500},
    37  		{300, 400, 500},
    38  	} {
    39  		m := test.m
    40  		n := test.n
    41  		lda := test.lda
    42  		if lda == 0 {
    43  			lda = n
    44  		}
    45  		minmn := min(m, n)
    46  		a := make([]float64, m*lda)
    47  		for i := range a {
    48  			a[i] = rnd.NormFloat64()
    49  		}
    50  
    51  		d := make([]float64, minmn)
    52  		e := make([]float64, minmn-1)
    53  		tauP := make([]float64, minmn)
    54  		tauQ := make([]float64, minmn)
    55  		work := make([]float64, max(m, n))
    56  		for i := range work {
    57  			work[i] = math.NaN()
    58  		}
    59  
    60  		// Store a.
    61  		aCopy := make([]float64, len(a))
    62  		copy(aCopy, a)
    63  
    64  		// Compute the true answer with the unblocked algorithm.
    65  		impl.Dgebd2(m, n, a, lda, d, e, tauQ, tauP, work)
    66  		aAns := make([]float64, len(a))
    67  		copy(aAns, a)
    68  		dAns := make([]float64, len(d))
    69  		copy(dAns, d)
    70  		eAns := make([]float64, len(e))
    71  		copy(eAns, e)
    72  		tauQAns := make([]float64, len(tauQ))
    73  		copy(tauQAns, tauQ)
    74  		tauPAns := make([]float64, len(tauP))
    75  		copy(tauPAns, tauP)
    76  
    77  		// Test with optimal work.
    78  		lwork := -1
    79  		copy(a, aCopy)
    80  		impl.Dgebrd(m, n, a, lda, d, e, tauQ, tauP, work, lwork)
    81  		work = make([]float64, int(work[0]))
    82  		lwork = len(work)
    83  		for i := range work {
    84  			work[i] = math.NaN()
    85  		}
    86  		for i := range d {
    87  			d[i] = math.NaN()
    88  		}
    89  		for i := range e {
    90  			e[i] = math.NaN()
    91  		}
    92  		for i := range tauQ {
    93  			tauQ[i] = math.NaN()
    94  		}
    95  		for i := range tauP {
    96  			tauP[i] = math.NaN()
    97  		}
    98  		impl.Dgebrd(m, n, a, lda, d, e, tauQ, tauP, work, lwork)
    99  
   100  		// Test answers
   101  		if !floats.EqualApprox(a, aAns, 1e-10) {
   102  			t.Errorf("a mismatch")
   103  		}
   104  		if !floats.EqualApprox(d, dAns, 1e-10) {
   105  			t.Errorf("d mismatch")
   106  		}
   107  		if !floats.EqualApprox(e, eAns, 1e-10) {
   108  			t.Errorf("e mismatch")
   109  		}
   110  		if !floats.EqualApprox(tauQ, tauQAns, 1e-10) {
   111  			t.Errorf("tauQ mismatch")
   112  		}
   113  		if !floats.EqualApprox(tauP, tauPAns, 1e-10) {
   114  			t.Errorf("tauP mismatch")
   115  		}
   116  
   117  		// Test with shorter than optimal work.
   118  		lwork--
   119  		copy(a, aCopy)
   120  		for i := range d {
   121  			d[i] = 0
   122  		}
   123  		for i := range e {
   124  			e[i] = 0
   125  		}
   126  		for i := range tauP {
   127  			tauP[i] = 0
   128  		}
   129  		for i := range tauQ {
   130  			tauQ[i] = 0
   131  		}
   132  		impl.Dgebrd(m, n, a, lda, d, e, tauQ, tauP, work, lwork)
   133  
   134  		// Test answers
   135  		if !floats.EqualApprox(a, aAns, 1e-10) {
   136  			t.Errorf("a mismatch")
   137  		}
   138  		if !floats.EqualApprox(d, dAns, 1e-10) {
   139  			t.Errorf("d mismatch")
   140  		}
   141  		if !floats.EqualApprox(e, eAns, 1e-10) {
   142  			t.Errorf("e mismatch")
   143  		}
   144  		if !floats.EqualApprox(tauQ, tauQAns, 1e-10) {
   145  			t.Errorf("tauQ mismatch")
   146  		}
   147  		if !floats.EqualApprox(tauP, tauPAns, 1e-10) {
   148  			t.Errorf("tauP mismatch")
   149  		}
   150  	}
   151  }