gonum.org/v1/gonum@v0.14.0/lapack/testlapack/dgetc2.go (about)

     1  // Copyright ©2021 The Gonum Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package testlapack
     6  
     7  import (
     8  	"fmt"
     9  	"math"
    10  	"testing"
    11  
    12  	"golang.org/x/exp/rand"
    13  
    14  	"gonum.org/v1/gonum/blas"
    15  	"gonum.org/v1/gonum/blas/blas64"
    16  	"gonum.org/v1/gonum/lapack"
    17  )
    18  
    19  type Dgetc2er interface {
    20  	Dgetc2(n int, a []float64, lda int, ipiv, jpiv []int) (k int)
    21  }
    22  
    23  func Dgetc2Test(t *testing.T, impl Dgetc2er) {
    24  	rnd := rand.New(rand.NewSource(1))
    25  	for _, n := range []int{0, 1, 2, 3, 4, 5, 10, 20} {
    26  		for _, lda := range []int{n, n + 5} {
    27  			dgetc2Test(t, impl, rnd, n, lda, false)
    28  			dgetc2Test(t, impl, rnd, n, lda, true)
    29  		}
    30  	}
    31  }
    32  
    33  func dgetc2Test(t *testing.T, impl Dgetc2er, rnd *rand.Rand, n, lda int, perturb bool) {
    34  	const tol = 1e-14
    35  
    36  	name := fmt.Sprintf("n=%v,lda=%v,perturb=%v", n, lda, perturb)
    37  
    38  	// Generate a random lower-triangular matrix with unit diagonal.
    39  	l := randomGeneral(n, n, max(1, n), rnd)
    40  	for i := 0; i < n; i++ {
    41  		l.Data[i*l.Stride+i] = 1
    42  		for j := i + 1; j < n; j++ {
    43  			l.Data[i*l.Stride+j] = 0
    44  		}
    45  	}
    46  
    47  	// Generate a random upper-triangular matrix.
    48  	u := randomGeneral(n, n, max(1, n), rnd)
    49  	for i := 0; i < n; i++ {
    50  		for j := 0; j < i; j++ {
    51  			u.Data[i*u.Stride+j] = 0
    52  		}
    53  	}
    54  	if perturb && n > 0 {
    55  		// Make U singular by randomly placing a zero on the diagonal.
    56  		i := rnd.Intn(n)
    57  		u.Data[i*u.Stride+i] = 0
    58  	}
    59  
    60  	// Construct A = L*U.
    61  	a := zeros(n, n, max(1, lda))
    62  	blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, l, u, 0, a)
    63  
    64  	// Allocate slices for pivots and pre-fill them with invalid indices.
    65  	ipiv := make([]int, n)
    66  	jpiv := make([]int, n)
    67  	for i := 0; i < n; i++ {
    68  		ipiv[i] = -1
    69  		jpiv[i] = -1
    70  	}
    71  
    72  	// Call Dgetc2 to compute the LU decomposition.
    73  	lu := cloneGeneral(a)
    74  	k := impl.Dgetc2(n, lu.Data, lu.Stride, ipiv, jpiv)
    75  
    76  	if n == 0 {
    77  		return
    78  	}
    79  
    80  	if perturb && k < 0 {
    81  		t.Errorf("%v: expected matrix perturbation", name)
    82  	}
    83  
    84  	// Verify all indices have been set.
    85  	for i := 0; i < n; i++ {
    86  		if ipiv[i] < 0 {
    87  			t.Errorf("%v: ipiv[%d] is not set", name, i)
    88  		}
    89  		if jpiv[i] < 0 {
    90  			t.Errorf("%v: jpiv[%d] is not set", name, i)
    91  		}
    92  	}
    93  
    94  	// Construct L and U matrices from Dgetc2 output.
    95  	l = zeros(n, n, n)
    96  	u = zeros(n, n, n)
    97  	for i := 0; i < n; i++ {
    98  		for j := 0; j < i; j++ {
    99  			l.Data[i*l.Stride+j] = lu.Data[i*lu.Stride+j]
   100  		}
   101  		l.Data[i*l.Stride+i] = 1
   102  		for j := i; j < n; j++ {
   103  			u.Data[i*u.Stride+j] = lu.Data[i*lu.Stride+j]
   104  		}
   105  	}
   106  	diff := zeros(n, n, n)
   107  	blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, l, u, 0, diff)
   108  
   109  	// Apply permutation matrices P and Q to L*U.
   110  	for i := n - 1; i >= 0; i-- {
   111  		ipv := ipiv[i]
   112  		if ipv != i {
   113  			row1 := blas64.Vector{N: n, Data: diff.Data[i*diff.Stride:], Inc: 1}
   114  			row2 := blas64.Vector{N: n, Data: diff.Data[ipv*diff.Stride:], Inc: 1}
   115  			blas64.Swap(row1, row2)
   116  		}
   117  		jpv := jpiv[i]
   118  		if jpv != i {
   119  			col1 := blas64.Vector{N: n, Data: diff.Data[i:], Inc: diff.Stride}
   120  			col2 := blas64.Vector{N: n, Data: diff.Data[jpv:], Inc: diff.Stride}
   121  			blas64.Swap(col1, col2)
   122  		}
   123  	}
   124  
   125  	// Compute the residual |P*L*U*Q - A| and check that it is small.
   126  	for i := 0; i < n; i++ {
   127  		for j := 0; j < n; j++ {
   128  			diff.Data[i*diff.Stride+j] -= a.Data[i*a.Stride+j]
   129  		}
   130  	}
   131  	resid := dlange(lapack.MaxColumnSum, n, n, diff.Data, diff.Stride)
   132  	if resid > tol || math.IsNaN(resid) {
   133  		t.Errorf("%v: unexpected result; resid=%v, want<=%v", name, resid, tol)
   134  	}
   135  }