github.com/gopherd/gonum@v0.0.4/lapack/gonum/dlarfb.go (about) 1 // Copyright ©2015 The Gonum Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style 3 // license that can be found in the LICENSE file. 4 5 package gonum 6 7 import ( 8 "github.com/gopherd/gonum/blas" 9 "github.com/gopherd/gonum/blas/blas64" 10 "github.com/gopherd/gonum/lapack" 11 ) 12 13 // Dlarfb applies a block reflector to a matrix. 14 // 15 // In the call to Dlarfb, the mxn c is multiplied by the implicitly defined matrix h as follows: 16 // c = h * c if side == Left and trans == NoTrans 17 // c = c * h if side == Right and trans == NoTrans 18 // c = hᵀ * c if side == Left and trans == Trans 19 // c = c * hᵀ if side == Right and trans == Trans 20 // h is a product of elementary reflectors. direct sets the direction of multiplication 21 // h = h_1 * h_2 * ... * h_k if direct == Forward 22 // h = h_k * h_k-1 * ... * h_1 if direct == Backward 23 // The combination of direct and store defines the orientation of the elementary 24 // reflectors. In all cases the ones on the diagonal are implicitly represented. 25 // 26 // If direct == lapack.Forward and store == lapack.ColumnWise 27 // V = [ 1 ] 28 // [v1 1 ] 29 // [v1 v2 1] 30 // [v1 v2 v3] 31 // [v1 v2 v3] 32 // If direct == lapack.Forward and store == lapack.RowWise 33 // V = [ 1 v1 v1 v1 v1] 34 // [ 1 v2 v2 v2] 35 // [ 1 v3 v3] 36 // If direct == lapack.Backward and store == lapack.ColumnWise 37 // V = [v1 v2 v3] 38 // [v1 v2 v3] 39 // [ 1 v2 v3] 40 // [ 1 v3] 41 // [ 1] 42 // If direct == lapack.Backward and store == lapack.RowWise 43 // V = [v1 v1 1 ] 44 // [v2 v2 v2 1 ] 45 // [v3 v3 v3 v3 1] 46 // An elementary reflector can be explicitly constructed by extracting the 47 // corresponding elements of v, placing a 1 where the diagonal would be, and 48 // placing zeros in the remaining elements. 49 // 50 // t is a k×k matrix containing the block reflector, and this function will panic 51 // if t is not of sufficient size. See Dlarft for more information. 52 // 53 // work is a temporary storage matrix with stride ldwork. 54 // work must be of size at least n×k side == Left and m×k if side == Right, and 55 // this function will panic if this size is not met. 56 // 57 // Dlarfb is an internal routine. It is exported for testing purposes. 58 func (Implementation) Dlarfb(side blas.Side, trans blas.Transpose, direct lapack.Direct, store lapack.StoreV, m, n, k int, v []float64, ldv int, t []float64, ldt int, c []float64, ldc int, work []float64, ldwork int) { 59 nv := m 60 if side == blas.Right { 61 nv = n 62 } 63 switch { 64 case side != blas.Left && side != blas.Right: 65 panic(badSide) 66 case trans != blas.Trans && trans != blas.NoTrans: 67 panic(badTrans) 68 case direct != lapack.Forward && direct != lapack.Backward: 69 panic(badDirect) 70 case store != lapack.ColumnWise && store != lapack.RowWise: 71 panic(badStoreV) 72 case m < 0: 73 panic(mLT0) 74 case n < 0: 75 panic(nLT0) 76 case k < 0: 77 panic(kLT0) 78 case store == lapack.ColumnWise && ldv < max(1, k): 79 panic(badLdV) 80 case store == lapack.RowWise && ldv < max(1, nv): 81 panic(badLdV) 82 case ldt < max(1, k): 83 panic(badLdT) 84 case ldc < max(1, n): 85 panic(badLdC) 86 case ldwork < max(1, k): 87 panic(badLdWork) 88 } 89 90 if m == 0 || n == 0 { 91 return 92 } 93 94 nw := n 95 if side == blas.Right { 96 nw = m 97 } 98 switch { 99 case store == lapack.ColumnWise && len(v) < (nv-1)*ldv+k: 100 panic(shortV) 101 case store == lapack.RowWise && len(v) < (k-1)*ldv+nv: 102 panic(shortV) 103 case len(t) < (k-1)*ldt+k: 104 panic(shortT) 105 case len(c) < (m-1)*ldc+n: 106 panic(shortC) 107 case len(work) < (nw-1)*ldwork+k: 108 panic(shortWork) 109 } 110 111 bi := blas64.Implementation() 112 113 transt := blas.Trans 114 if trans == blas.Trans { 115 transt = blas.NoTrans 116 } 117 // TODO(btracey): This follows the original Lapack code where the 118 // elements are copied into the columns of the working array. The 119 // loops should go in the other direction so the data is written 120 // into the rows of work so the copy is not strided. A bigger change 121 // would be to replace work with workᵀ, but benchmarks would be 122 // needed to see if the change is merited. 123 if store == lapack.ColumnWise { 124 if direct == lapack.Forward { 125 // V1 is the first k rows of C. V2 is the remaining rows. 126 if side == blas.Left { 127 // W = Cᵀ V = C1ᵀ V1 + C2ᵀ V2 (stored in work). 128 129 // W = C1. 130 for j := 0; j < k; j++ { 131 bi.Dcopy(n, c[j*ldc:], 1, work[j:], ldwork) 132 } 133 // W = W * V1. 134 bi.Dtrmm(blas.Right, blas.Lower, blas.NoTrans, blas.Unit, 135 n, k, 1, 136 v, ldv, 137 work, ldwork) 138 if m > k { 139 // W = W + C2ᵀ V2. 140 bi.Dgemm(blas.Trans, blas.NoTrans, n, k, m-k, 141 1, c[k*ldc:], ldc, v[k*ldv:], ldv, 142 1, work, ldwork) 143 } 144 // W = W * Tᵀ or W * T. 145 bi.Dtrmm(blas.Right, blas.Upper, transt, blas.NonUnit, n, k, 146 1, t, ldt, 147 work, ldwork) 148 // C -= V * Wᵀ. 149 if m > k { 150 // C2 -= V2 * Wᵀ. 151 bi.Dgemm(blas.NoTrans, blas.Trans, m-k, n, k, 152 -1, v[k*ldv:], ldv, work, ldwork, 153 1, c[k*ldc:], ldc) 154 } 155 // W *= V1ᵀ. 156 bi.Dtrmm(blas.Right, blas.Lower, blas.Trans, blas.Unit, n, k, 157 1, v, ldv, 158 work, ldwork) 159 // C1 -= Wᵀ. 160 // TODO(btracey): This should use blas.Axpy. 161 for i := 0; i < n; i++ { 162 for j := 0; j < k; j++ { 163 c[j*ldc+i] -= work[i*ldwork+j] 164 } 165 } 166 return 167 } 168 // Form C = C * H or C * Hᵀ, where C = (C1 C2). 169 170 // W = C1. 171 for i := 0; i < k; i++ { 172 bi.Dcopy(m, c[i:], ldc, work[i:], ldwork) 173 } 174 // W *= V1. 175 bi.Dtrmm(blas.Right, blas.Lower, blas.NoTrans, blas.Unit, m, k, 176 1, v, ldv, 177 work, ldwork) 178 if n > k { 179 bi.Dgemm(blas.NoTrans, blas.NoTrans, m, k, n-k, 180 1, c[k:], ldc, v[k*ldv:], ldv, 181 1, work, ldwork) 182 } 183 // W *= T or Tᵀ. 184 bi.Dtrmm(blas.Right, blas.Upper, trans, blas.NonUnit, m, k, 185 1, t, ldt, 186 work, ldwork) 187 if n > k { 188 bi.Dgemm(blas.NoTrans, blas.Trans, m, n-k, k, 189 -1, work, ldwork, v[k*ldv:], ldv, 190 1, c[k:], ldc) 191 } 192 // C -= W * Vᵀ. 193 bi.Dtrmm(blas.Right, blas.Lower, blas.Trans, blas.Unit, m, k, 194 1, v, ldv, 195 work, ldwork) 196 // C -= W. 197 // TODO(btracey): This should use blas.Axpy. 198 for i := 0; i < m; i++ { 199 for j := 0; j < k; j++ { 200 c[i*ldc+j] -= work[i*ldwork+j] 201 } 202 } 203 return 204 } 205 // V = (V1) 206 // = (V2) (last k rows) 207 // Where V2 is unit upper triangular. 208 if side == blas.Left { 209 // Form H * C or 210 // W = Cᵀ V. 211 212 // W = C2ᵀ. 213 for j := 0; j < k; j++ { 214 bi.Dcopy(n, c[(m-k+j)*ldc:], 1, work[j:], ldwork) 215 } 216 // W *= V2. 217 bi.Dtrmm(blas.Right, blas.Upper, blas.NoTrans, blas.Unit, n, k, 218 1, v[(m-k)*ldv:], ldv, 219 work, ldwork) 220 if m > k { 221 // W += C1ᵀ * V1. 222 bi.Dgemm(blas.Trans, blas.NoTrans, n, k, m-k, 223 1, c, ldc, v, ldv, 224 1, work, ldwork) 225 } 226 // W *= T or Tᵀ. 227 bi.Dtrmm(blas.Right, blas.Lower, transt, blas.NonUnit, n, k, 228 1, t, ldt, 229 work, ldwork) 230 // C -= V * Wᵀ. 231 if m > k { 232 bi.Dgemm(blas.NoTrans, blas.Trans, m-k, n, k, 233 -1, v, ldv, work, ldwork, 234 1, c, ldc) 235 } 236 // W *= V2ᵀ. 237 bi.Dtrmm(blas.Right, blas.Upper, blas.Trans, blas.Unit, n, k, 238 1, v[(m-k)*ldv:], ldv, 239 work, ldwork) 240 // C2 -= Wᵀ. 241 // TODO(btracey): This should use blas.Axpy. 242 for i := 0; i < n; i++ { 243 for j := 0; j < k; j++ { 244 c[(m-k+j)*ldc+i] -= work[i*ldwork+j] 245 } 246 } 247 return 248 } 249 // Form C * H or C * Hᵀ where C = (C1 C2). 250 // W = C * V. 251 252 // W = C2. 253 for j := 0; j < k; j++ { 254 bi.Dcopy(m, c[n-k+j:], ldc, work[j:], ldwork) 255 } 256 257 // W = W * V2. 258 bi.Dtrmm(blas.Right, blas.Upper, blas.NoTrans, blas.Unit, m, k, 259 1, v[(n-k)*ldv:], ldv, 260 work, ldwork) 261 if n > k { 262 bi.Dgemm(blas.NoTrans, blas.NoTrans, m, k, n-k, 263 1, c, ldc, v, ldv, 264 1, work, ldwork) 265 } 266 // W *= T or Tᵀ. 267 bi.Dtrmm(blas.Right, blas.Lower, trans, blas.NonUnit, m, k, 268 1, t, ldt, 269 work, ldwork) 270 // C -= W * Vᵀ. 271 if n > k { 272 // C1 -= W * V1ᵀ. 273 bi.Dgemm(blas.NoTrans, blas.Trans, m, n-k, k, 274 -1, work, ldwork, v, ldv, 275 1, c, ldc) 276 } 277 // W *= V2ᵀ. 278 bi.Dtrmm(blas.Right, blas.Upper, blas.Trans, blas.Unit, m, k, 279 1, v[(n-k)*ldv:], ldv, 280 work, ldwork) 281 // C2 -= W. 282 // TODO(btracey): This should use blas.Axpy. 283 for i := 0; i < m; i++ { 284 for j := 0; j < k; j++ { 285 c[i*ldc+n-k+j] -= work[i*ldwork+j] 286 } 287 } 288 return 289 } 290 // Store = Rowwise. 291 if direct == lapack.Forward { 292 // V = (V1 V2) where v1 is unit upper triangular. 293 if side == blas.Left { 294 // Form H * C or Hᵀ * C where C = (C1; C2). 295 // W = Cᵀ * Vᵀ. 296 297 // W = C1ᵀ. 298 for j := 0; j < k; j++ { 299 bi.Dcopy(n, c[j*ldc:], 1, work[j:], ldwork) 300 } 301 // W *= V1ᵀ. 302 bi.Dtrmm(blas.Right, blas.Upper, blas.Trans, blas.Unit, n, k, 303 1, v, ldv, 304 work, ldwork) 305 if m > k { 306 bi.Dgemm(blas.Trans, blas.Trans, n, k, m-k, 307 1, c[k*ldc:], ldc, v[k:], ldv, 308 1, work, ldwork) 309 } 310 // W *= T or Tᵀ. 311 bi.Dtrmm(blas.Right, blas.Upper, transt, blas.NonUnit, n, k, 312 1, t, ldt, 313 work, ldwork) 314 // C -= Vᵀ * Wᵀ. 315 if m > k { 316 bi.Dgemm(blas.Trans, blas.Trans, m-k, n, k, 317 -1, v[k:], ldv, work, ldwork, 318 1, c[k*ldc:], ldc) 319 } 320 // W *= V1. 321 bi.Dtrmm(blas.Right, blas.Upper, blas.NoTrans, blas.Unit, n, k, 322 1, v, ldv, 323 work, ldwork) 324 // C1 -= Wᵀ. 325 // TODO(btracey): This should use blas.Axpy. 326 for i := 0; i < n; i++ { 327 for j := 0; j < k; j++ { 328 c[j*ldc+i] -= work[i*ldwork+j] 329 } 330 } 331 return 332 } 333 // Form C * H or C * Hᵀ where C = (C1 C2). 334 // W = C * Vᵀ. 335 336 // W = C1. 337 for j := 0; j < k; j++ { 338 bi.Dcopy(m, c[j:], ldc, work[j:], ldwork) 339 } 340 // W *= V1ᵀ. 341 bi.Dtrmm(blas.Right, blas.Upper, blas.Trans, blas.Unit, m, k, 342 1, v, ldv, 343 work, ldwork) 344 if n > k { 345 bi.Dgemm(blas.NoTrans, blas.Trans, m, k, n-k, 346 1, c[k:], ldc, v[k:], ldv, 347 1, work, ldwork) 348 } 349 // W *= T or Tᵀ. 350 bi.Dtrmm(blas.Right, blas.Upper, trans, blas.NonUnit, m, k, 351 1, t, ldt, 352 work, ldwork) 353 // C -= W * V. 354 if n > k { 355 bi.Dgemm(blas.NoTrans, blas.NoTrans, m, n-k, k, 356 -1, work, ldwork, v[k:], ldv, 357 1, c[k:], ldc) 358 } 359 // W *= V1. 360 bi.Dtrmm(blas.Right, blas.Upper, blas.NoTrans, blas.Unit, m, k, 361 1, v, ldv, 362 work, ldwork) 363 // C1 -= W. 364 // TODO(btracey): This should use blas.Axpy. 365 for i := 0; i < m; i++ { 366 for j := 0; j < k; j++ { 367 c[i*ldc+j] -= work[i*ldwork+j] 368 } 369 } 370 return 371 } 372 // V = (V1 V2) where V2 is the last k columns and is lower unit triangular. 373 if side == blas.Left { 374 // Form H * C or Hᵀ C where C = (C1 ; C2). 375 // W = Cᵀ * Vᵀ. 376 377 // W = C2ᵀ. 378 for j := 0; j < k; j++ { 379 bi.Dcopy(n, c[(m-k+j)*ldc:], 1, work[j:], ldwork) 380 } 381 // W *= V2ᵀ. 382 bi.Dtrmm(blas.Right, blas.Lower, blas.Trans, blas.Unit, n, k, 383 1, v[m-k:], ldv, 384 work, ldwork) 385 if m > k { 386 bi.Dgemm(blas.Trans, blas.Trans, n, k, m-k, 387 1, c, ldc, v, ldv, 388 1, work, ldwork) 389 } 390 // W *= T or Tᵀ. 391 bi.Dtrmm(blas.Right, blas.Lower, transt, blas.NonUnit, n, k, 392 1, t, ldt, 393 work, ldwork) 394 // C -= Vᵀ * Wᵀ. 395 if m > k { 396 bi.Dgemm(blas.Trans, blas.Trans, m-k, n, k, 397 -1, v, ldv, work, ldwork, 398 1, c, ldc) 399 } 400 // W *= V2. 401 bi.Dtrmm(blas.Right, blas.Lower, blas.NoTrans, blas.Unit, n, k, 402 1, v[m-k:], ldv, 403 work, ldwork) 404 // C2 -= Wᵀ. 405 // TODO(btracey): This should use blas.Axpy. 406 for i := 0; i < n; i++ { 407 for j := 0; j < k; j++ { 408 c[(m-k+j)*ldc+i] -= work[i*ldwork+j] 409 } 410 } 411 return 412 } 413 // Form C * H or C * Hᵀ where C = (C1 C2). 414 // W = C * Vᵀ. 415 // W = C2. 416 for j := 0; j < k; j++ { 417 bi.Dcopy(m, c[n-k+j:], ldc, work[j:], ldwork) 418 } 419 // W *= V2ᵀ. 420 bi.Dtrmm(blas.Right, blas.Lower, blas.Trans, blas.Unit, m, k, 421 1, v[n-k:], ldv, 422 work, ldwork) 423 if n > k { 424 bi.Dgemm(blas.NoTrans, blas.Trans, m, k, n-k, 425 1, c, ldc, v, ldv, 426 1, work, ldwork) 427 } 428 // W *= T or Tᵀ. 429 bi.Dtrmm(blas.Right, blas.Lower, trans, blas.NonUnit, m, k, 430 1, t, ldt, 431 work, ldwork) 432 // C -= W * V. 433 if n > k { 434 bi.Dgemm(blas.NoTrans, blas.NoTrans, m, n-k, k, 435 -1, work, ldwork, v, ldv, 436 1, c, ldc) 437 } 438 // W *= V2. 439 bi.Dtrmm(blas.Right, blas.Lower, blas.NoTrans, blas.Unit, m, k, 440 1, v[n-k:], ldv, 441 work, ldwork) 442 // C1 -= W. 443 // TODO(btracey): This should use blas.Axpy. 444 for i := 0; i < m; i++ { 445 for j := 0; j < k; j++ { 446 c[i*ldc+n-k+j] -= work[i*ldwork+j] 447 } 448 } 449 }