gonum.org/v1/gonum@v0.14.0/lapack/testlapack/dlarf.go (about) 1 // Copyright ©2015 The Gonum Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style 3 // license that can be found in the LICENSE file. 4 5 package testlapack 6 7 import ( 8 "fmt" 9 "testing" 10 11 "golang.org/x/exp/rand" 12 13 "gonum.org/v1/gonum/blas" 14 "gonum.org/v1/gonum/blas/blas64" 15 "gonum.org/v1/gonum/floats" 16 "gonum.org/v1/gonum/lapack" 17 ) 18 19 type Dlarfer interface { 20 Dlarf(side blas.Side, m, n int, v []float64, incv int, tau float64, c []float64, ldc int, work []float64) 21 } 22 23 func DlarfTest(t *testing.T, impl Dlarfer) { 24 for _, side := range []blas.Side{blas.Left, blas.Right} { 25 name := sideToString(side) 26 t.Run(name, func(t *testing.T) { 27 runDlarfTest(t, impl, side) 28 }) 29 } 30 } 31 32 func runDlarfTest(t *testing.T, impl Dlarfer, side blas.Side) { 33 rnd := rand.New(rand.NewSource(1)) 34 for _, m := range []int{0, 1, 2, 3, 4, 5, 10} { 35 for _, n := range []int{0, 1, 2, 3, 4, 5, 10} { 36 for _, incv := range []int{1, 4} { 37 for _, ldc := range []int{max(1, n), n + 3} { 38 for _, nnzv := range []int{0, 1, 2} { 39 for _, nnzc := range []int{0, 1, 2} { 40 for _, tau := range []float64{0, rnd.NormFloat64()} { 41 dlarfTest(t, impl, rnd, side, m, n, incv, ldc, nnzv, nnzc, tau) 42 } 43 } 44 } 45 } 46 } 47 } 48 } 49 } 50 51 func dlarfTest(t *testing.T, impl Dlarfer, rnd *rand.Rand, side blas.Side, m, n, incv, ldc, nnzv, nnzc int, tau float64) { 52 const tol = 1e-14 53 54 c := make([]float64, m*ldc) 55 for i := range c { 56 c[i] = rnd.NormFloat64() 57 } 58 switch nnzc { 59 case 0: 60 // Zero out all of C. 61 for i := 0; i < m; i++ { 62 for j := 0; j < n; j++ { 63 c[i*ldc+j] = 0 64 } 65 } 66 case 1: 67 // Zero out right or bottom half of C. 68 if side == blas.Left { 69 for i := 0; i < m; i++ { 70 for j := n / 2; j < n; j++ { 71 c[i*ldc+j] = 0 72 } 73 } 74 } else { 75 for i := m / 2; i < m; i++ { 76 for j := 0; j < n; j++ { 77 c[i*ldc+j] = 0 78 } 79 } 80 } 81 default: 82 // Leave C with random content. 83 } 84 cCopy := make([]float64, len(c)) 85 copy(cCopy, c) 86 87 var work []float64 88 if side == blas.Left { 89 work = make([]float64, n) 90 } else { 91 work = make([]float64, m) 92 } 93 94 vlen := n 95 if side == blas.Left { 96 vlen = m 97 } 98 vlen = max(1, vlen) 99 v := make([]float64, 1+(vlen-1)*incv) 100 for i := range v { 101 v[i] = rnd.NormFloat64() 102 } 103 switch nnzv { 104 case 0: 105 // Zero out all of v. 106 for i := 0; i < vlen; i++ { 107 v[i*incv] = 0 108 } 109 case 1: 110 // Zero out half of v. 111 for i := vlen / 2; i < vlen; i++ { 112 v[i*incv] = 0 113 } 114 default: 115 // Leave v with random content. 116 } 117 vCopy := make([]float64, len(v)) 118 copy(vCopy, v) 119 120 impl.Dlarf(side, m, n, v, incv, tau, c, ldc, work) 121 got := c 122 123 name := fmt.Sprintf("m=%d,n=%d,incv=%d,tau=%f,ldc=%d", m, n, incv, tau, ldc) 124 125 if !floats.Equal(v, vCopy) { 126 t.Errorf("%v: unexpected modification of v", name) 127 } 128 if tau == 0 && !floats.Equal(got, cCopy) { 129 t.Errorf("%v: unexpected modification of C", name) 130 } 131 132 if m == 0 || n == 0 || tau == 0 { 133 return 134 } 135 136 bi := blas64.Implementation() 137 138 want := make([]float64, len(cCopy)) 139 if side == blas.Left { 140 // Compute want = (I - tau * v * vᵀ) * C 141 142 // vtc = -tau * vᵀ * C = -tau * Cᵀ * v 143 vtc := make([]float64, n) 144 bi.Dgemv(blas.Trans, m, n, -tau, cCopy, ldc, v, incv, 0, vtc, 1) 145 146 // want = C + v * vtcᵀ 147 for i := 0; i < m; i++ { 148 for j := 0; j < n; j++ { 149 want[i*ldc+j] = cCopy[i*ldc+j] + v[i*incv]*vtc[j] 150 } 151 } 152 } else { 153 // Compute want = C * (I - tau * v * vᵀ) 154 155 // cv = -tau * C * v 156 cv := make([]float64, m) 157 bi.Dgemv(blas.NoTrans, m, n, -tau, cCopy, ldc, v, incv, 0, cv, 1) 158 159 // want = C + cv * vᵀ 160 for i := 0; i < m; i++ { 161 for j := 0; j < n; j++ { 162 want[i*ldc+j] = cCopy[i*ldc+j] + cv[i]*v[j*incv] 163 } 164 } 165 } 166 diff := make([]float64, m*n) 167 for i := 0; i < m; i++ { 168 for j := 0; j < n; j++ { 169 diff[i*n+j] = got[i*ldc+j] - want[i*ldc+j] 170 } 171 } 172 resid := dlange(lapack.MaxColumnSum, m, n, diff, n) 173 if resid > tol*float64(max(m, n)) { 174 t.Errorf("%v: unexpected result; resid=%v, want<=%v", name, resid, tol*float64(max(m, n))) 175 } 176 }