gonum.org/v1/gonum@v0.15.1-0.20240517103525-f853624cb1bb/lapack/testlapack/dgghrd.go (about) 1 // Copyright ©2023 The Gonum Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style 3 // license that can be found in the LICENSE file. 4 5 package testlapack 6 7 import ( 8 "fmt" 9 "testing" 10 11 "golang.org/x/exp/rand" 12 13 "gonum.org/v1/gonum/blas" 14 "gonum.org/v1/gonum/blas/blas64" 15 "gonum.org/v1/gonum/lapack" 16 ) 17 18 type Dgghrder interface { 19 Dgghrd(compq, compz lapack.OrthoComp, n, ilo, ihi int, a []float64, lda int, b []float64, ldb int, q []float64, ldq int, z []float64, ldz int) 20 } 21 22 func DgghrdTest(t *testing.T, impl Dgghrder) { 23 rnd := rand.New(rand.NewSource(1)) 24 comps := []lapack.OrthoComp{lapack.OrthoExplicit, lapack.OrthoNone, lapack.OrthoPostmul} 25 for _, compq := range comps { 26 for _, compz := range comps { 27 for _, n := range []int{0, 1, 2, 3, 4, 15} { 28 for _, ld := range []int{max(1, n), n + 5} { 29 testDgghrd(t, impl, rnd, compq, compz, n, 0, n-1, ld, ld, ld, ld) 30 } 31 } 32 } 33 } 34 } 35 36 func testDgghrd(t *testing.T, impl Dgghrder, rnd *rand.Rand, compq, compz lapack.OrthoComp, n, ilo, ihi, lda, ldb, ldq, ldz int) { 37 const tol = 1e-13 38 39 a := randomGeneral(n, n, lda, rnd) 40 b := randomGeneral(n, n, ldb, rnd) 41 42 var q, q1 blas64.General 43 switch compq { 44 case lapack.OrthoExplicit: 45 // Initialize q to a non-orthogonal matrix, Dgghrd should overwrite it 46 // with an orthogonal Q. 47 q = randomGeneral(n, n, ldq, rnd) 48 case lapack.OrthoPostmul: 49 // Initialize q to an orthogonal matrix Q1, so that the result Q1*Q is 50 // again orthogonal. 51 q = randomOrthogonal(n, rnd) 52 q1 = cloneGeneral(q) 53 } 54 55 var z, z1 blas64.General 56 switch compz { 57 case lapack.OrthoExplicit: 58 z = randomGeneral(n, n, ldz, rnd) 59 case lapack.OrthoPostmul: 60 z = randomOrthogonal(n, rnd) 61 z1 = cloneGeneral(z) 62 } 63 64 hGot := cloneGeneral(a) 65 tGot := cloneGeneral(b) 66 impl.Dgghrd(compq, compz, n, ilo, ihi, hGot.Data, hGot.Stride, tGot.Data, tGot.Stride, q.Data, max(1, q.Stride), z.Data, max(1, z.Stride)) 67 68 if n == 0 { 69 return 70 } 71 72 name := fmt.Sprintf("Case compq=%v,compz=%v,n=%v,ilo=%v,ihi=%v", compq, compz, n, ilo, ihi) 73 74 if !isUpperHessenberg(hGot) { 75 t.Errorf("%v: H is not upper Hessenberg", name) 76 } 77 if !isUpperTriangular(tGot) { 78 t.Errorf("%v: T is not upper triangular", name) 79 } 80 if compq != lapack.OrthoNone { 81 if resid := residualOrthogonal(q, true); resid > tol { 82 t.Errorf("%v: Q is not orthogonal, resid=%v", name, resid) 83 } 84 } 85 if compz != lapack.OrthoNone { 86 if resid := residualOrthogonal(z, true); resid > tol { 87 t.Errorf("%v: Z is not orthogonal, resid=%v", name, resid) 88 } 89 } 90 91 if compq != compz { 92 // Verify reduction only when both Q and Z are computed. 93 return 94 } 95 96 // Zero out the lower triangle of B. 97 for i := 1; i < n; i++ { 98 for j := 0; j < i; j++ { 99 b.Data[i*b.Stride+j] = 0 100 } 101 } 102 103 aux := zeros(n, n, n) 104 switch compq { 105 case lapack.OrthoExplicit: 106 // Qᵀ*A*Z = H 107 hCalc := zeros(n, n, n) 108 blas64.Gemm(blas.Trans, blas.NoTrans, 1, q, a, 0, aux) 109 blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, aux, z, 1, hCalc) 110 if !equalApproxGeneral(hGot, hCalc, tol) { 111 t.Errorf("%v: Qᵀ*A*Z != H", name) 112 } 113 114 // Qᵀ*B*Z = T 115 tCalc := zeros(n, n, n) 116 blas64.Gemm(blas.Trans, blas.NoTrans, 1, q, b, 0, aux) 117 blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, aux, z, 1, tCalc) 118 if !equalApproxGeneral(tGot, tCalc, tol) { 119 t.Errorf("%v: Qᵀ*B*Z != T", name) 120 } 121 case lapack.OrthoPostmul: 122 // Q1 * A * Z1ᵀ = (Q1*Q) * H * (Z1*Z)ᵀ 123 lhs := zeros(n, n, n) 124 blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, q1, a, 0, aux) 125 blas64.Gemm(blas.NoTrans, blas.Trans, 1, aux, z1, 0, lhs) 126 127 rhs := zeros(n, n, n) 128 blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, q, hGot, 0, aux) 129 blas64.Gemm(blas.NoTrans, blas.Trans, 1, aux, z, 0, rhs) 130 if !equalApproxGeneral(lhs, rhs, tol) { 131 t.Errorf("%v: Q1 * A * Z1ᵀ != (Q1*Q) * H * (Z1*Z)ᵀ", name) 132 } 133 134 // Q1 * B * Z1ᵀ = (Q1*Q) * T * (Z1*Z)ᵀ 135 blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, q1, b, 0, aux) 136 blas64.Gemm(blas.NoTrans, blas.Trans, 1, aux, z1, 0, lhs) 137 138 blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, q, tGot, 0, aux) 139 blas64.Gemm(blas.NoTrans, blas.Trans, 1, aux, z, 0, rhs) 140 if !equalApproxGeneral(lhs, rhs, tol) { 141 t.Errorf("%v: Q1 * B * Z1ᵀ != (Q1*Q) * T * (Z1*Z)ᵀ", name) 142 } 143 } 144 }