github.com/gonum/lapack@v0.0.0-20181123203213-e4cdc5a0bff9/testlapack/dgetf2.go (about)

     1  // Copyright ©2015 The gonum Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package testlapack
     6  
     7  import (
     8  	"math/rand"
     9  	"testing"
    10  
    11  	"github.com/gonum/blas"
    12  	"github.com/gonum/blas/blas64"
    13  	"github.com/gonum/floats"
    14  )
    15  
    16  type Dgetf2er interface {
    17  	Dgetf2(m, n int, a []float64, lda int, ipiv []int) bool
    18  }
    19  
    20  func Dgetf2Test(t *testing.T, impl Dgetf2er) {
    21  	rnd := rand.New(rand.NewSource(1))
    22  	for _, test := range []struct {
    23  		m, n, lda int
    24  	}{
    25  		{10, 10, 0},
    26  		{10, 5, 0},
    27  		{10, 5, 0},
    28  
    29  		{10, 10, 20},
    30  		{5, 10, 20},
    31  		{10, 5, 20},
    32  	} {
    33  		m := test.m
    34  		n := test.n
    35  		lda := test.lda
    36  		if lda == 0 {
    37  			lda = n
    38  		}
    39  		a := make([]float64, m*lda)
    40  		for i := range a {
    41  			a[i] = rnd.Float64()
    42  		}
    43  		aCopy := make([]float64, len(a))
    44  		copy(aCopy, a)
    45  
    46  		mn := min(m, n)
    47  		ipiv := make([]int, mn)
    48  		for i := range ipiv {
    49  			ipiv[i] = rnd.Int()
    50  		}
    51  		ok := impl.Dgetf2(m, n, a, lda, ipiv)
    52  		checkPLU(t, ok, m, n, lda, ipiv, a, aCopy, 1e-14, true)
    53  	}
    54  
    55  	// Test with singular matrices (random matrices are almost surely non-singular).
    56  	for _, test := range []struct {
    57  		m, n, lda int
    58  		a         []float64
    59  	}{
    60  		{
    61  			m:   2,
    62  			n:   2,
    63  			lda: 2,
    64  			a: []float64{
    65  				1, 0,
    66  				0, 0,
    67  			},
    68  		},
    69  		{
    70  			m:   2,
    71  			n:   2,
    72  			lda: 2,
    73  			a: []float64{
    74  				1, 5,
    75  				2, 10,
    76  			},
    77  		},
    78  		{
    79  			m:   3,
    80  			n:   3,
    81  			lda: 3,
    82  			// row 3 = row1 + 2 * row2
    83  			a: []float64{
    84  				1, 5, 7,
    85  				2, 10, -3,
    86  				5, 25, 1,
    87  			},
    88  		},
    89  		{
    90  			m:   3,
    91  			n:   4,
    92  			lda: 4,
    93  			// row 3 = row1 + 2 * row2
    94  			a: []float64{
    95  				1, 5, 7, 9,
    96  				2, 10, -3, 11,
    97  				5, 25, 1, 31,
    98  			},
    99  		},
   100  	} {
   101  		if impl.Dgetf2(test.m, test.n, test.a, test.lda, make([]int, min(test.m, test.n))) {
   102  			t.Log("Returned ok with singular matrix.")
   103  		}
   104  	}
   105  }
   106  
   107  // checkPLU checks that the PLU factorization contained in factorize matches
   108  // the original matrix contained in original.
   109  func checkPLU(t *testing.T, ok bool, m, n, lda int, ipiv []int, factorized, original []float64, tol float64, print bool) {
   110  	var hasZeroDiagonal bool
   111  	for i := 0; i < min(m, n); i++ {
   112  		if factorized[i*lda+i] == 0 {
   113  			hasZeroDiagonal = true
   114  			break
   115  		}
   116  	}
   117  	if hasZeroDiagonal && ok {
   118  		t.Error("Has a zero diagonal but returned ok")
   119  	}
   120  	if !hasZeroDiagonal && !ok {
   121  		t.Error("Non-zero diagonal but returned !ok")
   122  	}
   123  
   124  	// Check that the LU decomposition is correct.
   125  	mn := min(m, n)
   126  	l := make([]float64, m*mn)
   127  	ldl := mn
   128  	u := make([]float64, mn*n)
   129  	ldu := n
   130  	for i := 0; i < m; i++ {
   131  		for j := 0; j < n; j++ {
   132  			v := factorized[i*lda+j]
   133  			switch {
   134  			case i == j:
   135  				l[i*ldl+i] = 1
   136  				u[i*ldu+i] = v
   137  			case i > j:
   138  				l[i*ldl+j] = v
   139  			case i < j:
   140  				u[i*ldu+j] = v
   141  			}
   142  		}
   143  	}
   144  
   145  	LU := blas64.General{
   146  		Rows:   m,
   147  		Cols:   n,
   148  		Stride: n,
   149  		Data:   make([]float64, m*n),
   150  	}
   151  	U := blas64.General{
   152  		Rows:   mn,
   153  		Cols:   n,
   154  		Stride: ldu,
   155  		Data:   u,
   156  	}
   157  	L := blas64.General{
   158  		Rows:   m,
   159  		Cols:   mn,
   160  		Stride: ldl,
   161  		Data:   l,
   162  	}
   163  	blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, L, U, 0, LU)
   164  
   165  	p := make([]float64, m*m)
   166  	ldp := m
   167  	for i := 0; i < m; i++ {
   168  		p[i*ldp+i] = 1
   169  	}
   170  	for i := len(ipiv) - 1; i >= 0; i-- {
   171  		v := ipiv[i]
   172  		blas64.Swap(m, blas64.Vector{Inc: 1, Data: p[i*ldp:]}, blas64.Vector{Inc: 1, Data: p[v*ldp:]})
   173  	}
   174  	P := blas64.General{
   175  		Rows:   m,
   176  		Cols:   m,
   177  		Stride: m,
   178  		Data:   p,
   179  	}
   180  	aComp := blas64.General{
   181  		Rows:   m,
   182  		Cols:   n,
   183  		Stride: lda,
   184  		Data:   make([]float64, m*lda),
   185  	}
   186  	copy(aComp.Data, factorized)
   187  	blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, P, LU, 0, aComp)
   188  	if !floats.EqualApprox(aComp.Data, original, tol) {
   189  		if print {
   190  			t.Errorf("PLU multiplication does not match original matrix.\nWant: %v\nGot: %v", original, aComp.Data)
   191  			return
   192  		}
   193  		t.Error("PLU multiplication does not match original matrix.")
   194  	}
   195  }