github.com/jingcheng-WU/gonum@v0.9.1-0.20210323123734-f1a2a11a8f7b/lapack/testlapack/dgels.go (about)

     1  // Copyright ©2015 The Gonum Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package testlapack
     6  
     7  import (
     8  	"testing"
     9  
    10  	"golang.org/x/exp/rand"
    11  
    12  	"github.com/jingcheng-WU/gonum/blas"
    13  	"github.com/jingcheng-WU/gonum/blas/blas64"
    14  	"github.com/jingcheng-WU/gonum/floats"
    15  )
    16  
    17  type Dgelser interface {
    18  	Dgels(trans blas.Transpose, m, n, nrhs int, a []float64, lda int, b []float64, ldb int, work []float64, lwork int) bool
    19  }
    20  
    21  func DgelsTest(t *testing.T, impl Dgelser) {
    22  	rnd := rand.New(rand.NewSource(1))
    23  	for _, trans := range []blas.Transpose{blas.NoTrans, blas.Trans} {
    24  		for _, test := range []struct {
    25  			m, n, nrhs, lda, ldb int
    26  		}{
    27  			{3, 4, 5, 0, 0},
    28  			{3, 5, 4, 0, 0},
    29  			{4, 3, 5, 0, 0},
    30  			{4, 5, 3, 0, 0},
    31  			{5, 3, 4, 0, 0},
    32  			{5, 4, 3, 0, 0},
    33  			{3, 4, 5, 10, 20},
    34  			{3, 5, 4, 10, 20},
    35  			{4, 3, 5, 10, 20},
    36  			{4, 5, 3, 10, 20},
    37  			{5, 3, 4, 10, 20},
    38  			{5, 4, 3, 10, 20},
    39  			{3, 4, 5, 20, 10},
    40  			{3, 5, 4, 20, 10},
    41  			{4, 3, 5, 20, 10},
    42  			{4, 5, 3, 20, 10},
    43  			{5, 3, 4, 20, 10},
    44  			{5, 4, 3, 20, 10},
    45  			{200, 300, 400, 0, 0},
    46  			{200, 400, 300, 0, 0},
    47  			{300, 200, 400, 0, 0},
    48  			{300, 400, 200, 0, 0},
    49  			{400, 200, 300, 0, 0},
    50  			{400, 300, 200, 0, 0},
    51  			{200, 300, 400, 500, 600},
    52  			{200, 400, 300, 500, 600},
    53  			{300, 200, 400, 500, 600},
    54  			{300, 400, 200, 500, 600},
    55  			{400, 200, 300, 500, 600},
    56  			{400, 300, 200, 500, 600},
    57  			{200, 300, 400, 600, 500},
    58  			{200, 400, 300, 600, 500},
    59  			{300, 200, 400, 600, 500},
    60  			{300, 400, 200, 600, 500},
    61  			{400, 200, 300, 600, 500},
    62  			{400, 300, 200, 600, 500},
    63  		} {
    64  			m := test.m
    65  			n := test.n
    66  			nrhs := test.nrhs
    67  
    68  			lda := test.lda
    69  			if lda == 0 {
    70  				lda = n
    71  			}
    72  			a := make([]float64, m*lda)
    73  			for i := range a {
    74  				a[i] = rnd.Float64()
    75  			}
    76  			aCopy := make([]float64, len(a))
    77  			copy(aCopy, a)
    78  
    79  			// Size of b is the same trans or no trans, because the number of rows
    80  			// has to be the max of (m,n).
    81  			mb := max(m, n)
    82  			nb := nrhs
    83  			ldb := test.ldb
    84  			if ldb == 0 {
    85  				ldb = nb
    86  			}
    87  			b := make([]float64, mb*ldb)
    88  			for i := range b {
    89  				b[i] = rnd.Float64()
    90  			}
    91  			bCopy := make([]float64, len(b))
    92  			copy(bCopy, b)
    93  
    94  			// Find optimal work length.
    95  			work := make([]float64, 1)
    96  			impl.Dgels(trans, m, n, nrhs, a, lda, b, ldb, work, -1)
    97  
    98  			// Perform linear solve
    99  			work = make([]float64, int(work[0]))
   100  			lwork := len(work)
   101  			for i := range work {
   102  				work[i] = rnd.Float64()
   103  			}
   104  			impl.Dgels(trans, m, n, nrhs, a, lda, b, ldb, work, lwork)
   105  
   106  			// Check that the answer is correct by comparing to the normal equations.
   107  			aMat := blas64.General{
   108  				Rows:   m,
   109  				Cols:   n,
   110  				Stride: lda,
   111  				Data:   make([]float64, len(aCopy)),
   112  			}
   113  			copy(aMat.Data, aCopy)
   114  			szAta := n
   115  			if trans == blas.Trans {
   116  				szAta = m
   117  			}
   118  			aTA := blas64.General{
   119  				Rows:   szAta,
   120  				Cols:   szAta,
   121  				Stride: szAta,
   122  				Data:   make([]float64, szAta*szAta),
   123  			}
   124  
   125  			// Compute Aᵀ * A if notrans and A * Aᵀ otherwise.
   126  			if trans == blas.NoTrans {
   127  				blas64.Gemm(blas.Trans, blas.NoTrans, 1, aMat, aMat, 0, aTA)
   128  			} else {
   129  				blas64.Gemm(blas.NoTrans, blas.Trans, 1, aMat, aMat, 0, aTA)
   130  			}
   131  
   132  			// Multiply by X.
   133  			X := blas64.General{
   134  				Rows:   szAta,
   135  				Cols:   nrhs,
   136  				Stride: ldb,
   137  				Data:   b,
   138  			}
   139  			ans := blas64.General{
   140  				Rows:   aTA.Rows,
   141  				Cols:   X.Cols,
   142  				Stride: X.Cols,
   143  				Data:   make([]float64, aTA.Rows*X.Cols),
   144  			}
   145  			blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, aTA, X, 0, ans)
   146  
   147  			B := blas64.General{
   148  				Rows:   szAta,
   149  				Cols:   nrhs,
   150  				Stride: ldb,
   151  				Data:   make([]float64, len(bCopy)),
   152  			}
   153  
   154  			copy(B.Data, bCopy)
   155  			var ans2 blas64.General
   156  			if trans == blas.NoTrans {
   157  				ans2 = blas64.General{
   158  					Rows:   aMat.Cols,
   159  					Cols:   B.Cols,
   160  					Stride: B.Cols,
   161  					Data:   make([]float64, aMat.Cols*B.Cols),
   162  				}
   163  			} else {
   164  				ans2 = blas64.General{
   165  					Rows:   aMat.Rows,
   166  					Cols:   B.Cols,
   167  					Stride: B.Cols,
   168  					Data:   make([]float64, aMat.Rows*B.Cols),
   169  				}
   170  			}
   171  
   172  			// Compute Aᵀ B if Trans or A * B otherwise
   173  			if trans == blas.NoTrans {
   174  				blas64.Gemm(blas.Trans, blas.NoTrans, 1, aMat, B, 0, ans2)
   175  			} else {
   176  				blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, aMat, B, 0, ans2)
   177  			}
   178  			if !floats.EqualApprox(ans.Data, ans2.Data, 1e-12) {
   179  				t.Errorf("Normal equations not satisfied")
   180  			}
   181  		}
   182  	}
   183  }