github.com/gonum/lapack@v0.0.0-20181123203213-e4cdc5a0bff9/testlapack/dgels.go (about)

     1  // Copyright ©2015 The gonum Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package testlapack
     6  
     7  import (
     8  	"math/rand"
     9  	"testing"
    10  
    11  	"github.com/gonum/blas"
    12  	"github.com/gonum/blas/blas64"
    13  	"github.com/gonum/floats"
    14  )
    15  
    16  type Dgelser interface {
    17  	Dgels(trans blas.Transpose, m, n, nrhs int, a []float64, lda int, b []float64, ldb int, work []float64, lwork int) bool
    18  }
    19  
    20  func DgelsTest(t *testing.T, impl Dgelser) {
    21  	rnd := rand.New(rand.NewSource(1))
    22  	for _, trans := range []blas.Transpose{blas.NoTrans, blas.Trans} {
    23  		for _, test := range []struct {
    24  			m, n, nrhs, lda, ldb int
    25  		}{
    26  			{3, 4, 5, 0, 0},
    27  			{3, 5, 4, 0, 0},
    28  			{4, 3, 5, 0, 0},
    29  			{4, 5, 3, 0, 0},
    30  			{5, 3, 4, 0, 0},
    31  			{5, 4, 3, 0, 0},
    32  			{3, 4, 5, 10, 20},
    33  			{3, 5, 4, 10, 20},
    34  			{4, 3, 5, 10, 20},
    35  			{4, 5, 3, 10, 20},
    36  			{5, 3, 4, 10, 20},
    37  			{5, 4, 3, 10, 20},
    38  			{3, 4, 5, 20, 10},
    39  			{3, 5, 4, 20, 10},
    40  			{4, 3, 5, 20, 10},
    41  			{4, 5, 3, 20, 10},
    42  			{5, 3, 4, 20, 10},
    43  			{5, 4, 3, 20, 10},
    44  			{200, 300, 400, 0, 0},
    45  			{200, 400, 300, 0, 0},
    46  			{300, 200, 400, 0, 0},
    47  			{300, 400, 200, 0, 0},
    48  			{400, 200, 300, 0, 0},
    49  			{400, 300, 200, 0, 0},
    50  			{200, 300, 400, 500, 600},
    51  			{200, 400, 300, 500, 600},
    52  			{300, 200, 400, 500, 600},
    53  			{300, 400, 200, 500, 600},
    54  			{400, 200, 300, 500, 600},
    55  			{400, 300, 200, 500, 600},
    56  			{200, 300, 400, 600, 500},
    57  			{200, 400, 300, 600, 500},
    58  			{300, 200, 400, 600, 500},
    59  			{300, 400, 200, 600, 500},
    60  			{400, 200, 300, 600, 500},
    61  			{400, 300, 200, 600, 500},
    62  		} {
    63  			m := test.m
    64  			n := test.n
    65  			nrhs := test.nrhs
    66  
    67  			lda := test.lda
    68  			if lda == 0 {
    69  				lda = n
    70  			}
    71  			a := make([]float64, m*lda)
    72  			for i := range a {
    73  				a[i] = rnd.Float64()
    74  			}
    75  			aCopy := make([]float64, len(a))
    76  			copy(aCopy, a)
    77  
    78  			// Size of b is the same trans or no trans, because the number of rows
    79  			// has to be the max of (m,n).
    80  			mb := max(m, n)
    81  			nb := nrhs
    82  			ldb := test.ldb
    83  			if ldb == 0 {
    84  				ldb = nb
    85  			}
    86  			b := make([]float64, mb*ldb)
    87  			for i := range b {
    88  				b[i] = rnd.Float64()
    89  			}
    90  			bCopy := make([]float64, len(b))
    91  			copy(bCopy, b)
    92  
    93  			// Find optimal work length.
    94  			work := make([]float64, 1)
    95  			impl.Dgels(trans, m, n, nrhs, a, lda, b, ldb, work, -1)
    96  
    97  			// Perform linear solve
    98  			work = make([]float64, int(work[0]))
    99  			lwork := len(work)
   100  			for i := range work {
   101  				work[i] = rnd.Float64()
   102  			}
   103  			impl.Dgels(trans, m, n, nrhs, a, lda, b, ldb, work, lwork)
   104  
   105  			// Check that the answer is correct by comparing to the normal equations.
   106  			aMat := blas64.General{
   107  				Rows:   m,
   108  				Cols:   n,
   109  				Stride: lda,
   110  				Data:   make([]float64, len(aCopy)),
   111  			}
   112  			copy(aMat.Data, aCopy)
   113  			szAta := n
   114  			if trans == blas.Trans {
   115  				szAta = m
   116  			}
   117  			aTA := blas64.General{
   118  				Rows:   szAta,
   119  				Cols:   szAta,
   120  				Stride: szAta,
   121  				Data:   make([]float64, szAta*szAta),
   122  			}
   123  
   124  			// Compute A^T * A if notrans and A * A^T otherwise.
   125  			if trans == blas.NoTrans {
   126  				blas64.Gemm(blas.Trans, blas.NoTrans, 1, aMat, aMat, 0, aTA)
   127  			} else {
   128  				blas64.Gemm(blas.NoTrans, blas.Trans, 1, aMat, aMat, 0, aTA)
   129  			}
   130  
   131  			// Multiply by X.
   132  			X := blas64.General{
   133  				Rows:   szAta,
   134  				Cols:   nrhs,
   135  				Stride: ldb,
   136  				Data:   b,
   137  			}
   138  			ans := blas64.General{
   139  				Rows:   aTA.Rows,
   140  				Cols:   X.Cols,
   141  				Stride: X.Cols,
   142  				Data:   make([]float64, aTA.Rows*X.Cols),
   143  			}
   144  			blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, aTA, X, 0, ans)
   145  
   146  			B := blas64.General{
   147  				Rows:   szAta,
   148  				Cols:   nrhs,
   149  				Stride: ldb,
   150  				Data:   make([]float64, len(bCopy)),
   151  			}
   152  
   153  			copy(B.Data, bCopy)
   154  			var ans2 blas64.General
   155  			if trans == blas.NoTrans {
   156  				ans2 = blas64.General{
   157  					Rows:   aMat.Cols,
   158  					Cols:   B.Cols,
   159  					Stride: B.Cols,
   160  					Data:   make([]float64, aMat.Cols*B.Cols),
   161  				}
   162  			} else {
   163  				ans2 = blas64.General{
   164  					Rows:   aMat.Rows,
   165  					Cols:   B.Cols,
   166  					Stride: B.Cols,
   167  					Data:   make([]float64, aMat.Rows*B.Cols),
   168  				}
   169  			}
   170  
   171  			// Compute A^T B if Trans or A * B otherwise
   172  			if trans == blas.NoTrans {
   173  				blas64.Gemm(blas.Trans, blas.NoTrans, 1, aMat, B, 0, ans2)
   174  			} else {
   175  				blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, aMat, B, 0, ans2)
   176  			}
   177  			if !floats.EqualApprox(ans.Data, ans2.Data, 1e-12) {
   178  				t.Errorf("Normal equations not satisfied")
   179  			}
   180  		}
   181  	}
   182  }