github.com/gonum/lapack@v0.0.0-20181123203213-e4cdc5a0bff9/testlapack/dgetrs.go (about)

     1  // Copyright ©2015 The gonum Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package testlapack
     6  
     7  import (
     8  	"math/rand"
     9  	"testing"
    10  
    11  	"github.com/gonum/blas"
    12  	"github.com/gonum/blas/blas64"
    13  	"github.com/gonum/floats"
    14  )
    15  
    16  type Dgetrser interface {
    17  	Dgetrfer
    18  	Dgetrs(trans blas.Transpose, n, nrhs int, a []float64, lda int, ipiv []int, b []float64, ldb int)
    19  }
    20  
    21  func DgetrsTest(t *testing.T, impl Dgetrser) {
    22  	rnd := rand.New(rand.NewSource(1))
    23  	// TODO(btracey): Put more thought into creating more regularized matrices
    24  	// and what correct tolerances should be. Consider also seeding the random
    25  	// number in this test to make it more robust to code changes in other
    26  	// parts of the suite.
    27  	for _, trans := range []blas.Transpose{blas.NoTrans, blas.Trans} {
    28  		for _, test := range []struct {
    29  			n, nrhs, lda, ldb int
    30  			tol               float64
    31  		}{
    32  			{3, 3, 0, 0, 1e-12},
    33  			{3, 5, 0, 0, 1e-12},
    34  			{5, 3, 0, 0, 1e-12},
    35  
    36  			{3, 3, 8, 10, 1e-12},
    37  			{3, 5, 8, 10, 1e-12},
    38  			{5, 3, 8, 10, 1e-12},
    39  
    40  			{300, 300, 0, 0, 1e-8},
    41  			{300, 500, 0, 0, 1e-8},
    42  			{500, 300, 0, 0, 1e-6},
    43  
    44  			{300, 300, 700, 600, 1e-8},
    45  			{300, 500, 700, 600, 1e-8},
    46  			{500, 300, 700, 600, 1e-6},
    47  		} {
    48  			n := test.n
    49  			nrhs := test.nrhs
    50  			lda := test.lda
    51  			if lda == 0 {
    52  				lda = n
    53  			}
    54  			ldb := test.ldb
    55  			if ldb == 0 {
    56  				ldb = nrhs
    57  			}
    58  			a := make([]float64, n*lda)
    59  			for i := range a {
    60  				a[i] = rnd.Float64()
    61  			}
    62  			b := make([]float64, n*ldb)
    63  			for i := range b {
    64  				b[i] = rnd.Float64()
    65  			}
    66  			aCopy := make([]float64, len(a))
    67  			copy(aCopy, a)
    68  			bCopy := make([]float64, len(b))
    69  			copy(bCopy, b)
    70  
    71  			ipiv := make([]int, n)
    72  			for i := range ipiv {
    73  				ipiv[i] = rnd.Int()
    74  			}
    75  
    76  			// Compute the LU factorization.
    77  			impl.Dgetrf(n, n, a, lda, ipiv)
    78  			// Solve the system of equations given the result.
    79  			impl.Dgetrs(trans, n, nrhs, a, lda, ipiv, b, ldb)
    80  
    81  			// Check that the system of equations holds.
    82  			A := blas64.General{
    83  				Rows:   n,
    84  				Cols:   n,
    85  				Stride: lda,
    86  				Data:   aCopy,
    87  			}
    88  			B := blas64.General{
    89  				Rows:   n,
    90  				Cols:   nrhs,
    91  				Stride: ldb,
    92  				Data:   bCopy,
    93  			}
    94  			X := blas64.General{
    95  				Rows:   n,
    96  				Cols:   nrhs,
    97  				Stride: ldb,
    98  				Data:   b,
    99  			}
   100  			tmp := blas64.General{
   101  				Rows:   n,
   102  				Cols:   nrhs,
   103  				Stride: ldb,
   104  				Data:   make([]float64, n*ldb),
   105  			}
   106  			copy(tmp.Data, bCopy)
   107  			blas64.Gemm(trans, blas.NoTrans, 1, A, X, 0, B)
   108  			if !floats.EqualApprox(tmp.Data, bCopy, test.tol) {
   109  				t.Errorf("Linear solve mismatch. trans = %v, n = %v, nrhs = %v, lda = %v, ldb = %v", trans, n, nrhs, lda, ldb)
   110  			}
   111  		}
   112  	}
   113  }