gonum.org/v1/gonum@v0.14.0/lapack/testlapack/dgetrs.go (about)

     1  // Copyright ©2015 The Gonum Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package testlapack
     6  
     7  import (
     8  	"testing"
     9  
    10  	"golang.org/x/exp/rand"
    11  
    12  	"gonum.org/v1/gonum/blas"
    13  	"gonum.org/v1/gonum/blas/blas64"
    14  	"gonum.org/v1/gonum/floats"
    15  )
    16  
        // Dgetrser is the implementation interface consumed by DgetrsTest. It
        // embeds Dgetrfer and adds Dgetrs; DgetrsTest calls Dgetrf (reached
        // through the embedded interface) to factor a matrix and then Dgetrs to
        // solve with the factors.
    17  type Dgetrser interface {
    18  	Dgetrfer
        	// Dgetrs is the routine under test. DgetrsTest passes it the a and
        	// ipiv slices that Dgetrf has just operated on, together with the
        	// right-hand side b, and afterwards reads the solution back out of b.
    19  	Dgetrs(trans blas.Transpose, n, nrhs int, a []float64, lda int, ipiv []int, b []float64, ldb int)
    20  }
    21  
    22  func DgetrsTest(t *testing.T, impl Dgetrser) {
    23  	rnd := rand.New(rand.NewSource(1))
    24  	// TODO(btracey): Put more thought into creating more regularized matrices
    25  	// and what correct tolerances should be. Consider also seeding the random
    26  	// number in this test to make it more robust to code changes in other
    27  	// parts of the suite.
    28  	for _, trans := range []blas.Transpose{blas.NoTrans, blas.Trans} {
    29  		for _, test := range []struct {
    30  			n, nrhs, lda, ldb int
    31  			tol               float64
    32  		}{
    33  			{3, 3, 0, 0, 1e-12},
    34  			{3, 5, 0, 0, 1e-12},
    35  			{5, 3, 0, 0, 1e-12},
    36  
    37  			{3, 3, 8, 10, 1e-12},
    38  			{3, 5, 8, 10, 1e-12},
    39  			{5, 3, 8, 10, 1e-12},
    40  
    41  			{300, 300, 0, 0, 1e-8},
    42  			{300, 500, 0, 0, 1e-8},
    43  			{500, 300, 0, 0, 1e-6},
    44  
    45  			{300, 300, 700, 600, 1e-8},
    46  			{300, 500, 700, 600, 1e-8},
    47  			{500, 300, 700, 600, 1e-6},
    48  		} {
    49  			n := test.n
    50  			nrhs := test.nrhs
    51  			lda := test.lda
    52  			if lda == 0 {
    53  				lda = n
    54  			}
    55  			ldb := test.ldb
    56  			if ldb == 0 {
    57  				ldb = nrhs
    58  			}
    59  			a := make([]float64, n*lda)
    60  			for i := range a {
    61  				a[i] = rnd.Float64()
    62  			}
    63  			b := make([]float64, n*ldb)
    64  			for i := range b {
    65  				b[i] = rnd.Float64()
    66  			}
    67  			aCopy := make([]float64, len(a))
    68  			copy(aCopy, a)
    69  			bCopy := make([]float64, len(b))
    70  			copy(bCopy, b)
    71  
    72  			ipiv := make([]int, n)
    73  			for i := range ipiv {
    74  				ipiv[i] = rnd.Int()
    75  			}
    76  
    77  			// Compute the LU factorization.
    78  			impl.Dgetrf(n, n, a, lda, ipiv)
    79  			// Solve the system of equations given the result.
    80  			impl.Dgetrs(trans, n, nrhs, a, lda, ipiv, b, ldb)
    81  
    82  			// Check that the system of equations holds.
    83  			A := blas64.General{
    84  				Rows:   n,
    85  				Cols:   n,
    86  				Stride: lda,
    87  				Data:   aCopy,
    88  			}
    89  			B := blas64.General{
    90  				Rows:   n,
    91  				Cols:   nrhs,
    92  				Stride: ldb,
    93  				Data:   bCopy,
    94  			}
    95  			X := blas64.General{
    96  				Rows:   n,
    97  				Cols:   nrhs,
    98  				Stride: ldb,
    99  				Data:   b,
   100  			}
   101  			tmp := blas64.General{
   102  				Rows:   n,
   103  				Cols:   nrhs,
   104  				Stride: ldb,
   105  				Data:   make([]float64, n*ldb),
   106  			}
   107  			copy(tmp.Data, bCopy)
   108  			blas64.Gemm(trans, blas.NoTrans, 1, A, X, 0, B)
   109  			if !floats.EqualApprox(tmp.Data, bCopy, test.tol) {
   110  				t.Errorf("Linear solve mismatch. trans = %v, n = %v, nrhs = %v, lda = %v, ldb = %v", trans, n, nrhs, lda, ldb)
   111  			}
   112  		}
   113  	}
   114  }