github.com/gonum/lapack@v0.0.0-20181123203213-e4cdc5a0bff9/native/dlarfb.go (about) 1 // Copyright ©2015 The gonum Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style 3 // license that can be found in the LICENSE file. 4 5 package native 6 7 import ( 8 "github.com/gonum/blas" 9 "github.com/gonum/blas/blas64" 10 "github.com/gonum/lapack" 11 ) 12 13 // Dlarfb applies a block reflector to a matrix. 14 // 15 // In the call to Dlarfb, the mxn c is multiplied by the implicitly defined matrix h as follows: 16 // c = h * c if side == Left and trans == NoTrans 17 // c = c * h if side == Right and trans == NoTrans 18 // c = h^T * c if side == Left and trans == Trans 19 // c = c * h^T if side == Right and trans == Trans 20 // h is a product of elementary reflectors. direct sets the direction of multiplication 21 // h = h_1 * h_2 * ... * h_k if direct == Forward 22 // h = h_k * h_k-1 * ... * h_1 if direct == Backward 23 // The combination of direct and store defines the orientation of the elementary 24 // reflectors. In all cases the ones on the diagonal are implicitly represented. 25 // 26 // If direct == lapack.Forward and store == lapack.ColumnWise 27 // V = [ 1 ] 28 // [v1 1 ] 29 // [v1 v2 1] 30 // [v1 v2 v3] 31 // [v1 v2 v3] 32 // If direct == lapack.Forward and store == lapack.RowWise 33 // V = [ 1 v1 v1 v1 v1] 34 // [ 1 v2 v2 v2] 35 // [ 1 v3 v3] 36 // If direct == lapack.Backward and store == lapack.ColumnWise 37 // V = [v1 v2 v3] 38 // [v1 v2 v3] 39 // [ 1 v2 v3] 40 // [ 1 v3] 41 // [ 1] 42 // If direct == lapack.Backward and store == lapack.RowWise 43 // V = [v1 v1 1 ] 44 // [v2 v2 v2 1 ] 45 // [v3 v3 v3 v3 1] 46 // An elementary reflector can be explicitly constructed by extracting the 47 // corresponding elements of v, placing a 1 where the diagonal would be, and 48 // placing zeros in the remaining elements. 49 // 50 // t is a k×k matrix containing the block reflector, and this function will panic 51 // if t is not of sufficient size. See Dlarft for more information. 52 // 53 // work is a temporary storage matrix with stride ldwork. 54 // work must be of size at least n×k side == Left and m×k if side == Right, and 55 // this function will panic if this size is not met. 56 // 57 // Dlarfb is an internal routine. It is exported for testing purposes. 58 func (Implementation) Dlarfb(side blas.Side, trans blas.Transpose, direct lapack.Direct, store lapack.StoreV, m, n, k int, v []float64, ldv int, t []float64, ldt int, c []float64, ldc int, work []float64, ldwork int) { 59 if side != blas.Left && side != blas.Right { 60 panic(badSide) 61 } 62 if trans != blas.Trans && trans != blas.NoTrans { 63 panic(badTrans) 64 } 65 if direct != lapack.Forward && direct != lapack.Backward { 66 panic(badDirect) 67 } 68 if store != lapack.ColumnWise && store != lapack.RowWise { 69 panic(badStore) 70 } 71 checkMatrix(m, n, c, ldc) 72 if k < 0 { 73 panic(kLT0) 74 } 75 checkMatrix(k, k, t, ldt) 76 nv := m 77 nw := n 78 if side == blas.Right { 79 nv = n 80 nw = m 81 } 82 if store == lapack.ColumnWise { 83 checkMatrix(nv, k, v, ldv) 84 } else { 85 checkMatrix(k, nv, v, ldv) 86 } 87 checkMatrix(nw, k, work, ldwork) 88 89 if m == 0 || n == 0 { 90 return 91 } 92 93 bi := blas64.Implementation() 94 95 transt := blas.Trans 96 if trans == blas.Trans { 97 transt = blas.NoTrans 98 } 99 // TODO(btracey): This follows the original Lapack code where the 100 // elements are copied into the columns of the working array. The 101 // loops should go in the other direction so the data is written 102 // into the rows of work so the copy is not strided. A bigger change 103 // would be to replace work with work^T, but benchmarks would be 104 // needed to see if the change is merited. 105 if store == lapack.ColumnWise { 106 if direct == lapack.Forward { 107 // V1 is the first k rows of C. V2 is the remaining rows. 108 if side == blas.Left { 109 // W = C^T V = C1^T V1 + C2^T V2 (stored in work). 110 111 // W = C1. 112 for j := 0; j < k; j++ { 113 bi.Dcopy(n, c[j*ldc:], 1, work[j:], ldwork) 114 } 115 // W = W * V1. 116 bi.Dtrmm(blas.Right, blas.Lower, blas.NoTrans, blas.Unit, 117 n, k, 1, 118 v, ldv, 119 work, ldwork) 120 if m > k { 121 // W = W + C2^T V2. 122 bi.Dgemm(blas.Trans, blas.NoTrans, n, k, m-k, 123 1, c[k*ldc:], ldc, v[k*ldv:], ldv, 124 1, work, ldwork) 125 } 126 // W = W * T^T or W * T. 127 bi.Dtrmm(blas.Right, blas.Upper, transt, blas.NonUnit, n, k, 128 1, t, ldt, 129 work, ldwork) 130 // C -= V * W^T. 131 if m > k { 132 // C2 -= V2 * W^T. 133 bi.Dgemm(blas.NoTrans, blas.Trans, m-k, n, k, 134 -1, v[k*ldv:], ldv, work, ldwork, 135 1, c[k*ldc:], ldc) 136 } 137 // W *= V1^T. 138 bi.Dtrmm(blas.Right, blas.Lower, blas.Trans, blas.Unit, n, k, 139 1, v, ldv, 140 work, ldwork) 141 // C1 -= W^T. 142 // TODO(btracey): This should use blas.Axpy. 143 for i := 0; i < n; i++ { 144 for j := 0; j < k; j++ { 145 c[j*ldc+i] -= work[i*ldwork+j] 146 } 147 } 148 return 149 } 150 // Form C = C * H or C * H^T, where C = (C1 C2). 151 152 // W = C1. 153 for i := 0; i < k; i++ { 154 bi.Dcopy(m, c[i:], ldc, work[i:], ldwork) 155 } 156 // W *= V1. 157 bi.Dtrmm(blas.Right, blas.Lower, blas.NoTrans, blas.Unit, m, k, 158 1, v, ldv, 159 work, ldwork) 160 if n > k { 161 bi.Dgemm(blas.NoTrans, blas.NoTrans, m, k, n-k, 162 1, c[k:], ldc, v[k*ldv:], ldv, 163 1, work, ldwork) 164 } 165 // W *= T or T^T. 166 bi.Dtrmm(blas.Right, blas.Upper, trans, blas.NonUnit, m, k, 167 1, t, ldt, 168 work, ldwork) 169 if n > k { 170 bi.Dgemm(blas.NoTrans, blas.Trans, m, n-k, k, 171 -1, work, ldwork, v[k*ldv:], ldv, 172 1, c[k:], ldc) 173 } 174 // C -= W * V^T. 175 bi.Dtrmm(blas.Right, blas.Lower, blas.Trans, blas.Unit, m, k, 176 1, v, ldv, 177 work, ldwork) 178 // C -= W. 179 // TODO(btracey): This should use blas.Axpy. 180 for i := 0; i < m; i++ { 181 for j := 0; j < k; j++ { 182 c[i*ldc+j] -= work[i*ldwork+j] 183 } 184 } 185 return 186 } 187 // V = (V1) 188 // = (V2) (last k rows) 189 // Where V2 is unit upper triangular. 190 if side == blas.Left { 191 // Form H * C or 192 // W = C^T V. 193 194 // W = C2^T. 195 for j := 0; j < k; j++ { 196 bi.Dcopy(n, c[(m-k+j)*ldc:], 1, work[j:], ldwork) 197 } 198 // W *= V2. 199 bi.Dtrmm(blas.Right, blas.Upper, blas.NoTrans, blas.Unit, n, k, 200 1, v[(m-k)*ldv:], ldv, 201 work, ldwork) 202 if m > k { 203 // W += C1^T * V1. 204 bi.Dgemm(blas.Trans, blas.NoTrans, n, k, m-k, 205 1, c, ldc, v, ldv, 206 1, work, ldwork) 207 } 208 // W *= T or T^T. 209 bi.Dtrmm(blas.Right, blas.Lower, transt, blas.NonUnit, n, k, 210 1, t, ldt, 211 work, ldwork) 212 // C -= V * W^T. 213 if m > k { 214 bi.Dgemm(blas.NoTrans, blas.Trans, m-k, n, k, 215 -1, v, ldv, work, ldwork, 216 1, c, ldc) 217 } 218 // W *= V2^T. 219 bi.Dtrmm(blas.Right, blas.Upper, blas.Trans, blas.Unit, n, k, 220 1, v[(m-k)*ldv:], ldv, 221 work, ldwork) 222 // C2 -= W^T. 223 // TODO(btracey): This should use blas.Axpy. 224 for i := 0; i < n; i++ { 225 for j := 0; j < k; j++ { 226 c[(m-k+j)*ldc+i] -= work[i*ldwork+j] 227 } 228 } 229 return 230 } 231 // Form C * H or C * H^T where C = (C1 C2). 232 // W = C * V. 233 234 // W = C2. 235 for j := 0; j < k; j++ { 236 bi.Dcopy(m, c[n-k+j:], ldc, work[j:], ldwork) 237 } 238 239 // W = W * V2. 240 bi.Dtrmm(blas.Right, blas.Upper, blas.NoTrans, blas.Unit, m, k, 241 1, v[(n-k)*ldv:], ldv, 242 work, ldwork) 243 if n > k { 244 bi.Dgemm(blas.NoTrans, blas.NoTrans, m, k, n-k, 245 1, c, ldc, v, ldv, 246 1, work, ldwork) 247 } 248 // W *= T or T^T. 249 bi.Dtrmm(blas.Right, blas.Lower, trans, blas.NonUnit, m, k, 250 1, t, ldt, 251 work, ldwork) 252 // C -= W * V^T. 253 if n > k { 254 // C1 -= W * V1^T. 255 bi.Dgemm(blas.NoTrans, blas.Trans, m, n-k, k, 256 -1, work, ldwork, v, ldv, 257 1, c, ldc) 258 } 259 // W *= V2^T. 260 bi.Dtrmm(blas.Right, blas.Upper, blas.Trans, blas.Unit, m, k, 261 1, v[(n-k)*ldv:], ldv, 262 work, ldwork) 263 // C2 -= W. 264 // TODO(btracey): This should use blas.Axpy. 265 for i := 0; i < m; i++ { 266 for j := 0; j < k; j++ { 267 c[i*ldc+n-k+j] -= work[i*ldwork+j] 268 } 269 } 270 return 271 } 272 // Store = Rowwise. 273 if direct == lapack.Forward { 274 // V = (V1 V2) where v1 is unit upper triangular. 275 if side == blas.Left { 276 // Form H * C or H^T * C where C = (C1; C2). 277 // W = C^T * V^T. 278 279 // W = C1^T. 280 for j := 0; j < k; j++ { 281 bi.Dcopy(n, c[j*ldc:], 1, work[j:], ldwork) 282 } 283 // W *= V1^T. 284 bi.Dtrmm(blas.Right, blas.Upper, blas.Trans, blas.Unit, n, k, 285 1, v, ldv, 286 work, ldwork) 287 if m > k { 288 bi.Dgemm(blas.Trans, blas.Trans, n, k, m-k, 289 1, c[k*ldc:], ldc, v[k:], ldv, 290 1, work, ldwork) 291 } 292 // W *= T or T^T. 293 bi.Dtrmm(blas.Right, blas.Upper, transt, blas.NonUnit, n, k, 294 1, t, ldt, 295 work, ldwork) 296 // C -= V^T * W^T. 297 if m > k { 298 bi.Dgemm(blas.Trans, blas.Trans, m-k, n, k, 299 -1, v[k:], ldv, work, ldwork, 300 1, c[k*ldc:], ldc) 301 } 302 // W *= V1. 303 bi.Dtrmm(blas.Right, blas.Upper, blas.NoTrans, blas.Unit, n, k, 304 1, v, ldv, 305 work, ldwork) 306 // C1 -= W^T. 307 // TODO(btracey): This should use blas.Axpy. 308 for i := 0; i < n; i++ { 309 for j := 0; j < k; j++ { 310 c[j*ldc+i] -= work[i*ldwork+j] 311 } 312 } 313 return 314 } 315 // Form C * H or C * H^T where C = (C1 C2). 316 // W = C * V^T. 317 318 // W = C1. 319 for j := 0; j < k; j++ { 320 bi.Dcopy(m, c[j:], ldc, work[j:], ldwork) 321 } 322 // W *= V1^T. 323 bi.Dtrmm(blas.Right, blas.Upper, blas.Trans, blas.Unit, m, k, 324 1, v, ldv, 325 work, ldwork) 326 if n > k { 327 bi.Dgemm(blas.NoTrans, blas.Trans, m, k, n-k, 328 1, c[k:], ldc, v[k:], ldv, 329 1, work, ldwork) 330 } 331 // W *= T or T^T. 332 bi.Dtrmm(blas.Right, blas.Upper, trans, blas.NonUnit, m, k, 333 1, t, ldt, 334 work, ldwork) 335 // C -= W * V. 336 if n > k { 337 bi.Dgemm(blas.NoTrans, blas.NoTrans, m, n-k, k, 338 -1, work, ldwork, v[k:], ldv, 339 1, c[k:], ldc) 340 } 341 // W *= V1. 342 bi.Dtrmm(blas.Right, blas.Upper, blas.NoTrans, blas.Unit, m, k, 343 1, v, ldv, 344 work, ldwork) 345 // C1 -= W. 346 // TODO(btracey): This should use blas.Axpy. 347 for i := 0; i < m; i++ { 348 for j := 0; j < k; j++ { 349 c[i*ldc+j] -= work[i*ldwork+j] 350 } 351 } 352 return 353 } 354 // V = (V1 V2) where V2 is the last k columns and is lower unit triangular. 355 if side == blas.Left { 356 // Form H * C or H^T C where C = (C1 ; C2). 357 // W = C^T * V^T. 358 359 // W = C2^T. 360 for j := 0; j < k; j++ { 361 bi.Dcopy(n, c[(m-k+j)*ldc:], 1, work[j:], ldwork) 362 } 363 // W *= V2^T. 364 bi.Dtrmm(blas.Right, blas.Lower, blas.Trans, blas.Unit, n, k, 365 1, v[m-k:], ldv, 366 work, ldwork) 367 if m > k { 368 bi.Dgemm(blas.Trans, blas.Trans, n, k, m-k, 369 1, c, ldc, v, ldv, 370 1, work, ldwork) 371 } 372 // W *= T or T^T. 373 bi.Dtrmm(blas.Right, blas.Lower, transt, blas.NonUnit, n, k, 374 1, t, ldt, 375 work, ldwork) 376 // C -= V^T * W^T. 377 if m > k { 378 bi.Dgemm(blas.Trans, blas.Trans, m-k, n, k, 379 -1, v, ldv, work, ldwork, 380 1, c, ldc) 381 } 382 // W *= V2. 383 bi.Dtrmm(blas.Right, blas.Lower, blas.NoTrans, blas.Unit, n, k, 384 1, v[m-k:], ldv, 385 work, ldwork) 386 // C2 -= W^T. 387 // TODO(btracey): This should use blas.Axpy. 388 for i := 0; i < n; i++ { 389 for j := 0; j < k; j++ { 390 c[(m-k+j)*ldc+i] -= work[i*ldwork+j] 391 } 392 } 393 return 394 } 395 // Form C * H or C * H^T where C = (C1 C2). 396 // W = C * V^T. 397 // W = C2. 398 for j := 0; j < k; j++ { 399 bi.Dcopy(m, c[n-k+j:], ldc, work[j:], ldwork) 400 } 401 // W *= V2^T. 402 bi.Dtrmm(blas.Right, blas.Lower, blas.Trans, blas.Unit, m, k, 403 1, v[n-k:], ldv, 404 work, ldwork) 405 if n > k { 406 bi.Dgemm(blas.NoTrans, blas.Trans, m, k, n-k, 407 1, c, ldc, v, ldv, 408 1, work, ldwork) 409 } 410 // W *= T or T^T. 411 bi.Dtrmm(blas.Right, blas.Lower, trans, blas.NonUnit, m, k, 412 1, t, ldt, 413 work, ldwork) 414 // C -= W * V. 415 if n > k { 416 bi.Dgemm(blas.NoTrans, blas.NoTrans, m, n-k, k, 417 -1, work, ldwork, v, ldv, 418 1, c, ldc) 419 } 420 // W *= V2. 421 bi.Dtrmm(blas.Right, blas.Lower, blas.NoTrans, blas.Unit, m, k, 422 1, v[n-k:], ldv, 423 work, ldwork) 424 // C1 -= W. 425 // TODO(btracey): This should use blas.Axpy. 426 for i := 0; i < m; i++ { 427 for j := 0; j < k; j++ { 428 c[i*ldc+n-k+j] -= work[i*ldwork+j] 429 } 430 } 431 }