gonum.org/v1/gonum@v0.14.0/lapack/gonum/dlarfb.go (about) 1 // Copyright ©2015 The Gonum Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style 3 // license that can be found in the LICENSE file. 4 5 package gonum 6 7 import ( 8 "gonum.org/v1/gonum/blas" 9 "gonum.org/v1/gonum/blas/blas64" 10 "gonum.org/v1/gonum/lapack" 11 ) 12 13 // Dlarfb applies a block reflector to a matrix. 14 // 15 // In the call to Dlarfb, the mxn c is multiplied by the implicitly defined matrix h as follows: 16 // 17 // c = h * c if side == Left and trans == NoTrans 18 // c = c * h if side == Right and trans == NoTrans 19 // c = hᵀ * c if side == Left and trans == Trans 20 // c = c * hᵀ if side == Right and trans == Trans 21 // 22 // h is a product of elementary reflectors. direct sets the direction of multiplication 23 // 24 // h = h_1 * h_2 * ... * h_k if direct == Forward 25 // h = h_k * h_k-1 * ... * h_1 if direct == Backward 26 // 27 // The combination of direct and store defines the orientation of the elementary 28 // reflectors. In all cases the ones on the diagonal are implicitly represented. 29 // 30 // If direct == lapack.Forward and store == lapack.ColumnWise 31 // 32 // V = [ 1 ] 33 // [v1 1 ] 34 // [v1 v2 1] 35 // [v1 v2 v3] 36 // [v1 v2 v3] 37 // 38 // If direct == lapack.Forward and store == lapack.RowWise 39 // 40 // V = [ 1 v1 v1 v1 v1] 41 // [ 1 v2 v2 v2] 42 // [ 1 v3 v3] 43 // 44 // If direct == lapack.Backward and store == lapack.ColumnWise 45 // 46 // V = [v1 v2 v3] 47 // [v1 v2 v3] 48 // [ 1 v2 v3] 49 // [ 1 v3] 50 // [ 1] 51 // 52 // If direct == lapack.Backward and store == lapack.RowWise 53 // 54 // V = [v1 v1 1 ] 55 // [v2 v2 v2 1 ] 56 // [v3 v3 v3 v3 1] 57 // 58 // An elementary reflector can be explicitly constructed by extracting the 59 // corresponding elements of v, placing a 1 where the diagonal would be, and 60 // placing zeros in the remaining elements. 61 // 62 // t is a k×k matrix containing the block reflector, and this function will panic 63 // if t is not of sufficient size. See Dlarft for more information. 64 // 65 // work is a temporary storage matrix with stride ldwork. 66 // work must be of size at least n×k side == Left and m×k if side == Right, and 67 // this function will panic if this size is not met. 68 // 69 // Dlarfb is an internal routine. It is exported for testing purposes. 70 func (Implementation) Dlarfb(side blas.Side, trans blas.Transpose, direct lapack.Direct, store lapack.StoreV, m, n, k int, v []float64, ldv int, t []float64, ldt int, c []float64, ldc int, work []float64, ldwork int) { 71 nv := m 72 if side == blas.Right { 73 nv = n 74 } 75 switch { 76 case side != blas.Left && side != blas.Right: 77 panic(badSide) 78 case trans != blas.Trans && trans != blas.NoTrans: 79 panic(badTrans) 80 case direct != lapack.Forward && direct != lapack.Backward: 81 panic(badDirect) 82 case store != lapack.ColumnWise && store != lapack.RowWise: 83 panic(badStoreV) 84 case m < 0: 85 panic(mLT0) 86 case n < 0: 87 panic(nLT0) 88 case k < 0: 89 panic(kLT0) 90 case store == lapack.ColumnWise && ldv < max(1, k): 91 panic(badLdV) 92 case store == lapack.RowWise && ldv < max(1, nv): 93 panic(badLdV) 94 case ldt < max(1, k): 95 panic(badLdT) 96 case ldc < max(1, n): 97 panic(badLdC) 98 case ldwork < max(1, k): 99 panic(badLdWork) 100 } 101 102 if m == 0 || n == 0 { 103 return 104 } 105 106 nw := n 107 if side == blas.Right { 108 nw = m 109 } 110 switch { 111 case store == lapack.ColumnWise && len(v) < (nv-1)*ldv+k: 112 panic(shortV) 113 case store == lapack.RowWise && len(v) < (k-1)*ldv+nv: 114 panic(shortV) 115 case len(t) < (k-1)*ldt+k: 116 panic(shortT) 117 case len(c) < (m-1)*ldc+n: 118 panic(shortC) 119 case len(work) < (nw-1)*ldwork+k: 120 panic(shortWork) 121 } 122 123 bi := blas64.Implementation() 124 125 transt := blas.Trans 126 if trans == blas.Trans { 127 transt = blas.NoTrans 128 } 129 // TODO(btracey): This follows the original Lapack code where the 130 // elements are copied into the columns of the working array. The 131 // loops should go in the other direction so the data is written 132 // into the rows of work so the copy is not strided. A bigger change 133 // would be to replace work with workᵀ, but benchmarks would be 134 // needed to see if the change is merited. 135 if store == lapack.ColumnWise { 136 if direct == lapack.Forward { 137 // V1 is the first k rows of C. V2 is the remaining rows. 138 if side == blas.Left { 139 // W = Cᵀ V = C1ᵀ V1 + C2ᵀ V2 (stored in work). 140 141 // W = C1. 142 for j := 0; j < k; j++ { 143 bi.Dcopy(n, c[j*ldc:], 1, work[j:], ldwork) 144 } 145 // W = W * V1. 146 bi.Dtrmm(blas.Right, blas.Lower, blas.NoTrans, blas.Unit, 147 n, k, 1, 148 v, ldv, 149 work, ldwork) 150 if m > k { 151 // W = W + C2ᵀ V2. 152 bi.Dgemm(blas.Trans, blas.NoTrans, n, k, m-k, 153 1, c[k*ldc:], ldc, v[k*ldv:], ldv, 154 1, work, ldwork) 155 } 156 // W = W * Tᵀ or W * T. 157 bi.Dtrmm(blas.Right, blas.Upper, transt, blas.NonUnit, n, k, 158 1, t, ldt, 159 work, ldwork) 160 // C -= V * Wᵀ. 161 if m > k { 162 // C2 -= V2 * Wᵀ. 163 bi.Dgemm(blas.NoTrans, blas.Trans, m-k, n, k, 164 -1, v[k*ldv:], ldv, work, ldwork, 165 1, c[k*ldc:], ldc) 166 } 167 // W *= V1ᵀ. 168 bi.Dtrmm(blas.Right, blas.Lower, blas.Trans, blas.Unit, n, k, 169 1, v, ldv, 170 work, ldwork) 171 // C1 -= Wᵀ. 172 // TODO(btracey): This should use blas.Axpy. 173 for i := 0; i < n; i++ { 174 for j := 0; j < k; j++ { 175 c[j*ldc+i] -= work[i*ldwork+j] 176 } 177 } 178 return 179 } 180 // Form C = C * H or C * Hᵀ, where C = (C1 C2). 181 182 // W = C1. 183 for i := 0; i < k; i++ { 184 bi.Dcopy(m, c[i:], ldc, work[i:], ldwork) 185 } 186 // W *= V1. 187 bi.Dtrmm(blas.Right, blas.Lower, blas.NoTrans, blas.Unit, m, k, 188 1, v, ldv, 189 work, ldwork) 190 if n > k { 191 bi.Dgemm(blas.NoTrans, blas.NoTrans, m, k, n-k, 192 1, c[k:], ldc, v[k*ldv:], ldv, 193 1, work, ldwork) 194 } 195 // W *= T or Tᵀ. 196 bi.Dtrmm(blas.Right, blas.Upper, trans, blas.NonUnit, m, k, 197 1, t, ldt, 198 work, ldwork) 199 if n > k { 200 bi.Dgemm(blas.NoTrans, blas.Trans, m, n-k, k, 201 -1, work, ldwork, v[k*ldv:], ldv, 202 1, c[k:], ldc) 203 } 204 // C -= W * Vᵀ. 205 bi.Dtrmm(blas.Right, blas.Lower, blas.Trans, blas.Unit, m, k, 206 1, v, ldv, 207 work, ldwork) 208 // C -= W. 209 // TODO(btracey): This should use blas.Axpy. 210 for i := 0; i < m; i++ { 211 for j := 0; j < k; j++ { 212 c[i*ldc+j] -= work[i*ldwork+j] 213 } 214 } 215 return 216 } 217 // V = (V1) 218 // = (V2) (last k rows) 219 // Where V2 is unit upper triangular. 220 if side == blas.Left { 221 // Form H * C or 222 // W = Cᵀ V. 223 224 // W = C2ᵀ. 225 for j := 0; j < k; j++ { 226 bi.Dcopy(n, c[(m-k+j)*ldc:], 1, work[j:], ldwork) 227 } 228 // W *= V2. 229 bi.Dtrmm(blas.Right, blas.Upper, blas.NoTrans, blas.Unit, n, k, 230 1, v[(m-k)*ldv:], ldv, 231 work, ldwork) 232 if m > k { 233 // W += C1ᵀ * V1. 234 bi.Dgemm(blas.Trans, blas.NoTrans, n, k, m-k, 235 1, c, ldc, v, ldv, 236 1, work, ldwork) 237 } 238 // W *= T or Tᵀ. 239 bi.Dtrmm(blas.Right, blas.Lower, transt, blas.NonUnit, n, k, 240 1, t, ldt, 241 work, ldwork) 242 // C -= V * Wᵀ. 243 if m > k { 244 bi.Dgemm(blas.NoTrans, blas.Trans, m-k, n, k, 245 -1, v, ldv, work, ldwork, 246 1, c, ldc) 247 } 248 // W *= V2ᵀ. 249 bi.Dtrmm(blas.Right, blas.Upper, blas.Trans, blas.Unit, n, k, 250 1, v[(m-k)*ldv:], ldv, 251 work, ldwork) 252 // C2 -= Wᵀ. 253 // TODO(btracey): This should use blas.Axpy. 254 for i := 0; i < n; i++ { 255 for j := 0; j < k; j++ { 256 c[(m-k+j)*ldc+i] -= work[i*ldwork+j] 257 } 258 } 259 return 260 } 261 // Form C * H or C * Hᵀ where C = (C1 C2). 262 // W = C * V. 263 264 // W = C2. 265 for j := 0; j < k; j++ { 266 bi.Dcopy(m, c[n-k+j:], ldc, work[j:], ldwork) 267 } 268 269 // W = W * V2. 270 bi.Dtrmm(blas.Right, blas.Upper, blas.NoTrans, blas.Unit, m, k, 271 1, v[(n-k)*ldv:], ldv, 272 work, ldwork) 273 if n > k { 274 bi.Dgemm(blas.NoTrans, blas.NoTrans, m, k, n-k, 275 1, c, ldc, v, ldv, 276 1, work, ldwork) 277 } 278 // W *= T or Tᵀ. 279 bi.Dtrmm(blas.Right, blas.Lower, trans, blas.NonUnit, m, k, 280 1, t, ldt, 281 work, ldwork) 282 // C -= W * Vᵀ. 283 if n > k { 284 // C1 -= W * V1ᵀ. 285 bi.Dgemm(blas.NoTrans, blas.Trans, m, n-k, k, 286 -1, work, ldwork, v, ldv, 287 1, c, ldc) 288 } 289 // W *= V2ᵀ. 290 bi.Dtrmm(blas.Right, blas.Upper, blas.Trans, blas.Unit, m, k, 291 1, v[(n-k)*ldv:], ldv, 292 work, ldwork) 293 // C2 -= W. 294 // TODO(btracey): This should use blas.Axpy. 295 for i := 0; i < m; i++ { 296 for j := 0; j < k; j++ { 297 c[i*ldc+n-k+j] -= work[i*ldwork+j] 298 } 299 } 300 return 301 } 302 // Store = Rowwise. 303 if direct == lapack.Forward { 304 // V = (V1 V2) where v1 is unit upper triangular. 305 if side == blas.Left { 306 // Form H * C or Hᵀ * C where C = (C1; C2). 307 // W = Cᵀ * Vᵀ. 308 309 // W = C1ᵀ. 310 for j := 0; j < k; j++ { 311 bi.Dcopy(n, c[j*ldc:], 1, work[j:], ldwork) 312 } 313 // W *= V1ᵀ. 314 bi.Dtrmm(blas.Right, blas.Upper, blas.Trans, blas.Unit, n, k, 315 1, v, ldv, 316 work, ldwork) 317 if m > k { 318 bi.Dgemm(blas.Trans, blas.Trans, n, k, m-k, 319 1, c[k*ldc:], ldc, v[k:], ldv, 320 1, work, ldwork) 321 } 322 // W *= T or Tᵀ. 323 bi.Dtrmm(blas.Right, blas.Upper, transt, blas.NonUnit, n, k, 324 1, t, ldt, 325 work, ldwork) 326 // C -= Vᵀ * Wᵀ. 327 if m > k { 328 bi.Dgemm(blas.Trans, blas.Trans, m-k, n, k, 329 -1, v[k:], ldv, work, ldwork, 330 1, c[k*ldc:], ldc) 331 } 332 // W *= V1. 333 bi.Dtrmm(blas.Right, blas.Upper, blas.NoTrans, blas.Unit, n, k, 334 1, v, ldv, 335 work, ldwork) 336 // C1 -= Wᵀ. 337 // TODO(btracey): This should use blas.Axpy. 338 for i := 0; i < n; i++ { 339 for j := 0; j < k; j++ { 340 c[j*ldc+i] -= work[i*ldwork+j] 341 } 342 } 343 return 344 } 345 // Form C * H or C * Hᵀ where C = (C1 C2). 346 // W = C * Vᵀ. 347 348 // W = C1. 349 for j := 0; j < k; j++ { 350 bi.Dcopy(m, c[j:], ldc, work[j:], ldwork) 351 } 352 // W *= V1ᵀ. 353 bi.Dtrmm(blas.Right, blas.Upper, blas.Trans, blas.Unit, m, k, 354 1, v, ldv, 355 work, ldwork) 356 if n > k { 357 bi.Dgemm(blas.NoTrans, blas.Trans, m, k, n-k, 358 1, c[k:], ldc, v[k:], ldv, 359 1, work, ldwork) 360 } 361 // W *= T or Tᵀ. 362 bi.Dtrmm(blas.Right, blas.Upper, trans, blas.NonUnit, m, k, 363 1, t, ldt, 364 work, ldwork) 365 // C -= W * V. 366 if n > k { 367 bi.Dgemm(blas.NoTrans, blas.NoTrans, m, n-k, k, 368 -1, work, ldwork, v[k:], ldv, 369 1, c[k:], ldc) 370 } 371 // W *= V1. 372 bi.Dtrmm(blas.Right, blas.Upper, blas.NoTrans, blas.Unit, m, k, 373 1, v, ldv, 374 work, ldwork) 375 // C1 -= W. 376 // TODO(btracey): This should use blas.Axpy. 377 for i := 0; i < m; i++ { 378 for j := 0; j < k; j++ { 379 c[i*ldc+j] -= work[i*ldwork+j] 380 } 381 } 382 return 383 } 384 // V = (V1 V2) where V2 is the last k columns and is lower unit triangular. 385 if side == blas.Left { 386 // Form H * C or Hᵀ C where C = (C1 ; C2). 387 // W = Cᵀ * Vᵀ. 388 389 // W = C2ᵀ. 390 for j := 0; j < k; j++ { 391 bi.Dcopy(n, c[(m-k+j)*ldc:], 1, work[j:], ldwork) 392 } 393 // W *= V2ᵀ. 394 bi.Dtrmm(blas.Right, blas.Lower, blas.Trans, blas.Unit, n, k, 395 1, v[m-k:], ldv, 396 work, ldwork) 397 if m > k { 398 bi.Dgemm(blas.Trans, blas.Trans, n, k, m-k, 399 1, c, ldc, v, ldv, 400 1, work, ldwork) 401 } 402 // W *= T or Tᵀ. 403 bi.Dtrmm(blas.Right, blas.Lower, transt, blas.NonUnit, n, k, 404 1, t, ldt, 405 work, ldwork) 406 // C -= Vᵀ * Wᵀ. 407 if m > k { 408 bi.Dgemm(blas.Trans, blas.Trans, m-k, n, k, 409 -1, v, ldv, work, ldwork, 410 1, c, ldc) 411 } 412 // W *= V2. 413 bi.Dtrmm(blas.Right, blas.Lower, blas.NoTrans, blas.Unit, n, k, 414 1, v[m-k:], ldv, 415 work, ldwork) 416 // C2 -= Wᵀ. 417 // TODO(btracey): This should use blas.Axpy. 418 for i := 0; i < n; i++ { 419 for j := 0; j < k; j++ { 420 c[(m-k+j)*ldc+i] -= work[i*ldwork+j] 421 } 422 } 423 return 424 } 425 // Form C * H or C * Hᵀ where C = (C1 C2). 426 // W = C * Vᵀ. 427 // W = C2. 428 for j := 0; j < k; j++ { 429 bi.Dcopy(m, c[n-k+j:], ldc, work[j:], ldwork) 430 } 431 // W *= V2ᵀ. 432 bi.Dtrmm(blas.Right, blas.Lower, blas.Trans, blas.Unit, m, k, 433 1, v[n-k:], ldv, 434 work, ldwork) 435 if n > k { 436 bi.Dgemm(blas.NoTrans, blas.Trans, m, k, n-k, 437 1, c, ldc, v, ldv, 438 1, work, ldwork) 439 } 440 // W *= T or Tᵀ. 441 bi.Dtrmm(blas.Right, blas.Lower, trans, blas.NonUnit, m, k, 442 1, t, ldt, 443 work, ldwork) 444 // C -= W * V. 445 if n > k { 446 bi.Dgemm(blas.NoTrans, blas.NoTrans, m, n-k, k, 447 -1, work, ldwork, v, ldv, 448 1, c, ldc) 449 } 450 // W *= V2. 451 bi.Dtrmm(blas.Right, blas.Lower, blas.NoTrans, blas.Unit, m, k, 452 1, v[n-k:], ldv, 453 work, ldwork) 454 // C1 -= W. 455 // TODO(btracey): This should use blas.Axpy. 456 for i := 0; i < m; i++ { 457 for j := 0; j < k; j++ { 458 c[i*ldc+n-k+j] -= work[i*ldwork+j] 459 } 460 } 461 }