gonum.org/v1/gonum@v0.15.1-0.20240517103525-f853624cb1bb/lapack/testlapack/dgghrd.go (about)

     1  // Copyright ©2023 The Gonum Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package testlapack
     6  
     7  import (
     8  	"fmt"
     9  	"testing"
    10  
    11  	"golang.org/x/exp/rand"
    12  
    13  	"gonum.org/v1/gonum/blas"
    14  	"gonum.org/v1/gonum/blas/blas64"
    15  	"gonum.org/v1/gonum/lapack"
    16  )
    17  
    18  type Dgghrder interface {
    19  	Dgghrd(compq, compz lapack.OrthoComp, n, ilo, ihi int, a []float64, lda int, b []float64, ldb int, q []float64, ldq int, z []float64, ldz int)
    20  }
    21  
    22  func DgghrdTest(t *testing.T, impl Dgghrder) {
    23  	rnd := rand.New(rand.NewSource(1))
    24  	comps := []lapack.OrthoComp{lapack.OrthoExplicit, lapack.OrthoNone, lapack.OrthoPostmul}
    25  	for _, compq := range comps {
    26  		for _, compz := range comps {
    27  			for _, n := range []int{0, 1, 2, 3, 4, 15} {
    28  				for _, ld := range []int{max(1, n), n + 5} {
    29  					testDgghrd(t, impl, rnd, compq, compz, n, 0, n-1, ld, ld, ld, ld)
    30  				}
    31  			}
    32  		}
    33  	}
    34  }
    35  
    36  func testDgghrd(t *testing.T, impl Dgghrder, rnd *rand.Rand, compq, compz lapack.OrthoComp, n, ilo, ihi, lda, ldb, ldq, ldz int) {
    37  	const tol = 1e-13
    38  
    39  	a := randomGeneral(n, n, lda, rnd)
    40  	b := randomGeneral(n, n, ldb, rnd)
    41  
    42  	var q, q1 blas64.General
    43  	switch compq {
    44  	case lapack.OrthoExplicit:
    45  		// Initialize q to a non-orthogonal matrix, Dgghrd should overwrite it
    46  		// with an orthogonal Q.
    47  		q = randomGeneral(n, n, ldq, rnd)
    48  	case lapack.OrthoPostmul:
    49  		// Initialize q to an orthogonal matrix Q1, so that the result Q1*Q is
    50  		// again orthogonal.
    51  		q = randomOrthogonal(n, rnd)
    52  		q1 = cloneGeneral(q)
    53  	}
    54  
    55  	var z, z1 blas64.General
    56  	switch compz {
    57  	case lapack.OrthoExplicit:
    58  		z = randomGeneral(n, n, ldz, rnd)
    59  	case lapack.OrthoPostmul:
    60  		z = randomOrthogonal(n, rnd)
    61  		z1 = cloneGeneral(z)
    62  	}
    63  
    64  	hGot := cloneGeneral(a)
    65  	tGot := cloneGeneral(b)
    66  	impl.Dgghrd(compq, compz, n, ilo, ihi, hGot.Data, hGot.Stride, tGot.Data, tGot.Stride, q.Data, max(1, q.Stride), z.Data, max(1, z.Stride))
    67  
    68  	if n == 0 {
    69  		return
    70  	}
    71  
    72  	name := fmt.Sprintf("Case compq=%v,compz=%v,n=%v,ilo=%v,ihi=%v", compq, compz, n, ilo, ihi)
    73  
    74  	if !isUpperHessenberg(hGot) {
    75  		t.Errorf("%v: H is not upper Hessenberg", name)
    76  	}
    77  	if !isUpperTriangular(tGot) {
    78  		t.Errorf("%v: T is not upper triangular", name)
    79  	}
    80  	if compq != lapack.OrthoNone {
    81  		if resid := residualOrthogonal(q, true); resid > tol {
    82  			t.Errorf("%v: Q is not orthogonal, resid=%v", name, resid)
    83  		}
    84  	}
    85  	if compz != lapack.OrthoNone {
    86  		if resid := residualOrthogonal(z, true); resid > tol {
    87  			t.Errorf("%v: Z is not orthogonal, resid=%v", name, resid)
    88  		}
    89  	}
    90  
    91  	if compq != compz {
    92  		// Verify reduction only when both Q and Z are computed.
    93  		return
    94  	}
    95  
    96  	// Zero out the lower triangle of B.
    97  	for i := 1; i < n; i++ {
    98  		for j := 0; j < i; j++ {
    99  			b.Data[i*b.Stride+j] = 0
   100  		}
   101  	}
   102  
   103  	aux := zeros(n, n, n)
   104  	switch compq {
   105  	case lapack.OrthoExplicit:
   106  		// Qᵀ*A*Z = H
   107  		hCalc := zeros(n, n, n)
   108  		blas64.Gemm(blas.Trans, blas.NoTrans, 1, q, a, 0, aux)
   109  		blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, aux, z, 1, hCalc)
   110  		if !equalApproxGeneral(hGot, hCalc, tol) {
   111  			t.Errorf("%v: Qᵀ*A*Z != H", name)
   112  		}
   113  
   114  		// Qᵀ*B*Z = T
   115  		tCalc := zeros(n, n, n)
   116  		blas64.Gemm(blas.Trans, blas.NoTrans, 1, q, b, 0, aux)
   117  		blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, aux, z, 1, tCalc)
   118  		if !equalApproxGeneral(tGot, tCalc, tol) {
   119  			t.Errorf("%v: Qᵀ*B*Z != T", name)
   120  		}
   121  	case lapack.OrthoPostmul:
   122  		//	Q1 * A * Z1ᵀ = (Q1*Q) * H * (Z1*Z)ᵀ
   123  		lhs := zeros(n, n, n)
   124  		blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, q1, a, 0, aux)
   125  		blas64.Gemm(blas.NoTrans, blas.Trans, 1, aux, z1, 0, lhs)
   126  
   127  		rhs := zeros(n, n, n)
   128  		blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, q, hGot, 0, aux)
   129  		blas64.Gemm(blas.NoTrans, blas.Trans, 1, aux, z, 0, rhs)
   130  		if !equalApproxGeneral(lhs, rhs, tol) {
   131  			t.Errorf("%v: Q1 * A * Z1ᵀ != (Q1*Q) * H * (Z1*Z)ᵀ", name)
   132  		}
   133  
   134  		//	Q1 * B * Z1ᵀ = (Q1*Q) * T * (Z1*Z)ᵀ
   135  		blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, q1, b, 0, aux)
   136  		blas64.Gemm(blas.NoTrans, blas.Trans, 1, aux, z1, 0, lhs)
   137  
   138  		blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, q, tGot, 0, aux)
   139  		blas64.Gemm(blas.NoTrans, blas.Trans, 1, aux, z, 0, rhs)
   140  		if !equalApproxGeneral(lhs, rhs, tol) {
   141  			t.Errorf("%v: Q1 * B * Z1ᵀ != (Q1*Q) * T * (Z1*Z)ᵀ", name)
   142  		}
   143  	}
   144  }