gonum.org/v1/gonum@v0.14.0/lapack/testlapack/dlasy2.go (about)

     1  // Copyright ©2016 The Gonum Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package testlapack
     6  
     7  import (
     8  	"fmt"
     9  	"testing"
    10  
    11  	"golang.org/x/exp/rand"
    12  
    13  	"gonum.org/v1/gonum/blas"
    14  	"gonum.org/v1/gonum/blas/blas64"
    15  	"gonum.org/v1/gonum/lapack"
    16  )
    17  
    18  type Dlasy2er interface {
    19  	Dlasy2(tranl, tranr bool, isgn, n1, n2 int, tl []float64, ldtl int, tr []float64, ldtr int, b []float64, ldb int, x []float64, ldx int) (scale, xnorm float64, ok bool)
    20  }
    21  
    22  func Dlasy2Test(t *testing.T, impl Dlasy2er) {
    23  	rnd := rand.New(rand.NewSource(1))
    24  	for _, tranl := range []bool{true, false} {
    25  		for _, tranr := range []bool{true, false} {
    26  			for _, isgn := range []int{1, -1} {
    27  				for _, n1 := range []int{0, 1, 2} {
    28  					for _, n2 := range []int{0, 1, 2} {
    29  						for _, extra := range []int{0, 3} {
    30  							for cas := 0; cas < 100; cas++ {
    31  								var big bool
    32  								if cas%2 == 0 {
    33  									big = true
    34  								}
    35  								testDlasy2(t, impl, tranl, tranr, isgn, n1, n2, extra, big, rnd)
    36  							}
    37  						}
    38  					}
    39  				}
    40  			}
    41  		}
    42  	}
    43  }
    44  
    45  func testDlasy2(t *testing.T, impl Dlasy2er, tranl, tranr bool, isgn, n1, n2, extra int, big bool, rnd *rand.Rand) {
    46  	const tol = 1e-14
    47  
    48  	name := fmt.Sprintf("Case n1=%v, n2=%v, isgn=%v, big=%v", n1, n2, isgn, big)
    49  
    50  	tl := randomGeneral(n1, n1, n1+extra, rnd)
    51  	tr := randomGeneral(n2, n2, n2+extra, rnd)
    52  	x := randomGeneral(n1, n2, n2+extra, rnd)
    53  	b := randomGeneral(n1, n2, n2+extra, rnd)
    54  	if big {
    55  		for i := 0; i < n1; i++ {
    56  			for j := 0; j < n2; j++ {
    57  				b.Data[i*b.Stride+j] *= bignum
    58  			}
    59  		}
    60  	}
    61  
    62  	tlCopy := cloneGeneral(tl)
    63  	trCopy := cloneGeneral(tr)
    64  	bCopy := cloneGeneral(b)
    65  
    66  	scale, xnorm, ok := impl.Dlasy2(tranl, tranr, isgn, n1, n2, tl.Data, tl.Stride, tr.Data, tr.Stride, b.Data, b.Stride, x.Data, x.Stride)
    67  
    68  	// Check any invalid modifications in read-only input.
    69  	if !equalGeneral(tl, tlCopy) {
    70  		t.Errorf("%v: unexpected modification in TL", name)
    71  	}
    72  	if !equalGeneral(tr, trCopy) {
    73  		t.Errorf("%v: unexpected modification in TR", name)
    74  	}
    75  	if !equalGeneral(b, bCopy) {
    76  		t.Errorf("%v: unexpected modification in B", name)
    77  	}
    78  
    79  	// Check any invalid modifications of x.
    80  	if !generalOutsideAllNaN(x) {
    81  		t.Errorf("%v: out-of-range write to x\n%v", name, x.Data)
    82  	}
    83  
    84  	if n1 == 0 || n2 == 0 {
    85  		return
    86  	}
    87  
    88  	if scale <= 0 || 1 < scale {
    89  		t.Errorf("%v: invalid value of scale, want in (0,1], got %v", name, scale)
    90  	}
    91  
    92  	xnormWant := dlange(lapack.MaxRowSum, x.Rows, x.Cols, x.Data, x.Stride)
    93  	if xnormWant != xnorm {
    94  		t.Errorf("%v: unexpected xnorm: want %v, got %v", name, xnormWant, xnorm)
    95  	}
    96  
    97  	if !ok {
    98  		t.Logf("%v: Dlasy2 returned ok=false", name)
    99  		return
   100  	}
   101  
   102  	// Compute diff := op(TL)*X + sgn*X*op(TR) - scale*B.
   103  	diff := zeros(n1, n2, n2)
   104  	if tranl {
   105  		blas64.Gemm(blas.Trans, blas.NoTrans, 1, tl, x, 0, diff)
   106  	} else {
   107  		blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, tl, x, 0, diff)
   108  	}
   109  	if tranr {
   110  		blas64.Gemm(blas.NoTrans, blas.Trans, float64(isgn), x, tr, 1, diff)
   111  	} else {
   112  		blas64.Gemm(blas.NoTrans, blas.NoTrans, float64(isgn), x, tr, 1, diff)
   113  	}
   114  	for i := 0; i < n1; i++ {
   115  		for j := 0; j < n2; j++ {
   116  			diff.Data[i*diff.Stride+j] -= scale * b.Data[i*b.Stride+j]
   117  		}
   118  	}
   119  	// Check that residual |op(TL)*X + sgn*X*op(TR) - scale*B| / |X| is small.
   120  	resid := dlange(lapack.MaxColumnSum, n1, n2, diff.Data, diff.Stride) / xnorm
   121  	if resid > tol {
   122  		t.Errorf("%v: unexpected result, resid=%v, want<=%v", name, resid, tol)
   123  	}
   124  }