gonum.org/v1/gonum@v0.14.0/lapack/gonum/dlarfb.go (about)

     1  // Copyright ©2015 The Gonum Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package gonum
     6  
     7  import (
     8  	"gonum.org/v1/gonum/blas"
     9  	"gonum.org/v1/gonum/blas/blas64"
    10  	"gonum.org/v1/gonum/lapack"
    11  )
    12  
    13  // Dlarfb applies a block reflector to a matrix.
    14  //
    15  // In the call to Dlarfb, the m×n matrix c is multiplied by the implicitly defined matrix h as follows:
    16  //
    17  //	c = h * c   if side == Left and trans == NoTrans
    18  //	c = c * h   if side == Right and trans == NoTrans
    19  //	c = hᵀ * c  if side == Left and trans == Trans
    20  //	c = c * hᵀ  if side == Right and trans == Trans
    21  //
    22  // h is a product of elementary reflectors. direct sets the direction of multiplication:
    23  //
    24  //	h = h_1 * h_2 * ... * h_k    if direct == Forward
    25  //	h = h_k * h_k-1 * ... * h_1  if direct == Backward
    26  //
    27  // The combination of direct and store defines the orientation of the elementary
    28  // reflectors. In all cases the ones on the diagonal are implicitly represented.
    29  //
    30  // If direct == lapack.Forward and store == lapack.ColumnWise
    31  //
    32  //	V = [ 1        ]
    33  //	    [v1   1    ]
    34  //	    [v1  v2   1]
    35  //	    [v1  v2  v3]
    36  //	    [v1  v2  v3]
    37  //
    38  // If direct == lapack.Forward and store == lapack.RowWise
    39  //
    40  //	V = [ 1  v1  v1  v1  v1]
    41  //	    [     1  v2  v2  v2]
    42  //	    [         1  v3  v3]
    43  //
    44  // If direct == lapack.Backward and store == lapack.ColumnWise
    45  //
    46  //	V = [v1  v2  v3]
    47  //	    [v1  v2  v3]
    48  //	    [ 1  v2  v3]
    49  //	    [     1  v3]
    50  //	    [         1]
    51  //
    52  // If direct == lapack.Backward and store == lapack.RowWise
    53  //
    54  //	V = [v1  v1   1        ]
    55  //	    [v2  v2  v2   1    ]
    56  //	    [v3  v3  v3  v3   1]
    57  //
    58  // An elementary reflector can be explicitly constructed by extracting the
    59  // corresponding elements of v, placing a 1 where the diagonal would be, and
    60  // placing zeros in the remaining elements.
    61  //
    62  // t is a k×k matrix containing the block reflector, and this function will panic
    63  // if t is not of sufficient size. See Dlarft for more information.
        // t is used as an upper triangular matrix if direct == Forward and as a
        // lower triangular matrix if direct == Backward.
    64  //
    65  // work is a temporary storage matrix with stride ldwork.
    66  // work must be of size at least n×k if side == Left and m×k if side == Right, and
    67  // this function will panic if this size is not met.
    68  //
        // Dlarfb panics if side, trans, direct or store is not a valid value, if m, n
        // or k is negative, or if any of ldv, ldt, ldc or ldwork is too small. If
        // m == 0 or n == 0, Dlarfb returns immediately without checking the slice
        // lengths and without modifying c.
        //
    69  // Dlarfb is an internal routine. It is exported for testing purposes.
    70  func (Implementation) Dlarfb(side blas.Side, trans blas.Transpose, direct lapack.Direct, store lapack.StoreV, m, n, k int, v []float64, ldv int, t []float64, ldt int, c []float64, ldc int, work []float64, ldwork int) {
        	// nv is the length of each elementary reflector: m when h is applied
        	// from the left and n when it is applied from the right.
    71  	nv := m
    72  	if side == blas.Right {
    73  		nv = n
    74  	}
    75  	switch {
    76  	case side != blas.Left && side != blas.Right:
    77  		panic(badSide)
    78  	case trans != blas.Trans && trans != blas.NoTrans:
    79  		panic(badTrans)
    80  	case direct != lapack.Forward && direct != lapack.Backward:
    81  		panic(badDirect)
    82  	case store != lapack.ColumnWise && store != lapack.RowWise:
    83  		panic(badStoreV)
    84  	case m < 0:
    85  		panic(mLT0)
    86  	case n < 0:
    87  		panic(nLT0)
    88  	case k < 0:
    89  		panic(kLT0)
    90  	case store == lapack.ColumnWise && ldv < max(1, k):
    91  		panic(badLdV)
    92  	case store == lapack.RowWise && ldv < max(1, nv):
    93  		panic(badLdV)
    94  	case ldt < max(1, k):
    95  		panic(badLdT)
    96  	case ldc < max(1, n):
    97  		panic(badLdC)
    98  	case ldwork < max(1, k):
    99  		panic(badLdWork)
   100  	}
   101  
        	// Quick return if c is empty. Slice lengths are deliberately checked
        	// only after this point.
   102  	if m == 0 || n == 0 {
   103  		return
   104  	}
   105  
        	// nw is the number of rows of work that are used: n when side == Left
        	// and m when side == Right.
   106  	nw := n
   107  	if side == blas.Right {
   108  		nw = m
   109  	}
   110  	switch {
   111  	case store == lapack.ColumnWise && len(v) < (nv-1)*ldv+k:
   112  		panic(shortV)
   113  	case store == lapack.RowWise && len(v) < (k-1)*ldv+nv:
   114  		panic(shortV)
   115  	case len(t) < (k-1)*ldt+k:
   116  		panic(shortT)
   117  	case len(c) < (m-1)*ldc+n:
   118  		panic(shortC)
   119  	case len(work) < (nw-1)*ldwork+k:
   120  		panic(shortWork)
   121  	}
   122  
   123  	bi := blas64.Implementation()
   124  
        	// transt is the opposite of trans. The side == Left branches below
        	// multiply W by T using transt because they build W from Cᵀ.
   125  	transt := blas.Trans
   126  	if trans == blas.Trans {
   127  		transt = blas.NoTrans
   128  	}
   129  	// TODO(btracey): This follows the original Lapack code where the
   130  	// elements are copied into the columns of the working array. The
   131  	// loops should go in the other direction so the data is written
   132  	// into the rows of work so the copy is not strided. A bigger change
   133  	// would be to replace work with workᵀ, but benchmarks would be
   134  	// needed to see if the change is merited.
   135  	if store == lapack.ColumnWise {
   136  		if direct == lapack.Forward {
   137  			// V1 is the first k rows of V. V2 is the remaining rows.
   138  			if side == blas.Left {
   139  				// W = Cᵀ V = C1ᵀ V1 + C2ᵀ V2 (stored in work).
   140  
   141  				// W = C1ᵀ.
   142  				for j := 0; j < k; j++ {
   143  					bi.Dcopy(n, c[j*ldc:], 1, work[j:], ldwork)
   144  				}
   145  				// W = W * V1.
   146  				bi.Dtrmm(blas.Right, blas.Lower, blas.NoTrans, blas.Unit,
   147  					n, k, 1,
   148  					v, ldv,
   149  					work, ldwork)
   150  				if m > k {
   151  					// W = W + C2ᵀ V2.
   152  					bi.Dgemm(blas.Trans, blas.NoTrans, n, k, m-k,
   153  						1, c[k*ldc:], ldc, v[k*ldv:], ldv,
   154  						1, work, ldwork)
   155  				}
   156  				// W = W * Tᵀ or W * T.
   157  				bi.Dtrmm(blas.Right, blas.Upper, transt, blas.NonUnit, n, k,
   158  					1, t, ldt,
   159  					work, ldwork)
   160  				// C -= V * Wᵀ.
   161  				if m > k {
   162  					// C2 -= V2 * Wᵀ.
   163  					bi.Dgemm(blas.NoTrans, blas.Trans, m-k, n, k,
   164  						-1, v[k*ldv:], ldv, work, ldwork,
   165  						1, c[k*ldc:], ldc)
   166  				}
   167  				// W *= V1ᵀ.
   168  				bi.Dtrmm(blas.Right, blas.Lower, blas.Trans, blas.Unit, n, k,
   169  					1, v, ldv,
   170  					work, ldwork)
   171  				// C1 -= Wᵀ.
   172  				// TODO(btracey): This should use blas.Axpy.
   173  				for i := 0; i < n; i++ {
   174  					for j := 0; j < k; j++ {
   175  						c[j*ldc+i] -= work[i*ldwork+j]
   176  					}
   177  				}
   178  				return
   179  			}
   180  			// Form C = C * H or C * Hᵀ, where C = (C1 C2).
   181  
   182  			// W = C1.
   183  			for i := 0; i < k; i++ {
   184  				bi.Dcopy(m, c[i:], ldc, work[i:], ldwork)
   185  			}
   186  			// W *= V1.
   187  			bi.Dtrmm(blas.Right, blas.Lower, blas.NoTrans, blas.Unit, m, k,
   188  				1, v, ldv,
   189  				work, ldwork)
   190  			if n > k {
        				// W += C2 * V2.
   191  				bi.Dgemm(blas.NoTrans, blas.NoTrans, m, k, n-k,
   192  					1, c[k:], ldc, v[k*ldv:], ldv,
   193  					1, work, ldwork)
   194  			}
   195  			// W *= T or Tᵀ.
   196  			bi.Dtrmm(blas.Right, blas.Upper, trans, blas.NonUnit, m, k,
   197  				1, t, ldt,
   198  				work, ldwork)
        			// C -= W * Vᵀ.
   199  			if n > k {
        				// C2 -= W * V2ᵀ.
   200  				bi.Dgemm(blas.NoTrans, blas.Trans, m, n-k, k,
   201  					-1, work, ldwork, v[k*ldv:], ldv,
   202  					1, c[k:], ldc)
   203  			}
   204  			// W *= V1ᵀ.
   205  			bi.Dtrmm(blas.Right, blas.Lower, blas.Trans, blas.Unit, m, k,
   206  				1, v, ldv,
   207  				work, ldwork)
   208  			// C1 -= W.
   209  			// TODO(btracey): This should use blas.Axpy.
   210  			for i := 0; i < m; i++ {
   211  				for j := 0; j < k; j++ {
   212  					c[i*ldc+j] -= work[i*ldwork+j]
   213  				}
   214  			}
   215  			return
   216  		}
   217  		// V = (V1)
   218  		//   = (V2) (last k rows)
   219  		// Where V2 is unit upper triangular.
   220  		if side == blas.Left {
   221  			// Form H * C or Hᵀ * C where C = (C1; C2).
   222  			// W = Cᵀ V.
   223  
   224  			// W = C2ᵀ.
   225  			for j := 0; j < k; j++ {
   226  				bi.Dcopy(n, c[(m-k+j)*ldc:], 1, work[j:], ldwork)
   227  			}
   228  			// W *= V2.
   229  			bi.Dtrmm(blas.Right, blas.Upper, blas.NoTrans, blas.Unit, n, k,
   230  				1, v[(m-k)*ldv:], ldv,
   231  				work, ldwork)
   232  			if m > k {
   233  				// W += C1ᵀ * V1.
   234  				bi.Dgemm(blas.Trans, blas.NoTrans, n, k, m-k,
   235  					1, c, ldc, v, ldv,
   236  					1, work, ldwork)
   237  			}
   238  			// W *= T or Tᵀ.
   239  			bi.Dtrmm(blas.Right, blas.Lower, transt, blas.NonUnit, n, k,
   240  				1, t, ldt,
   241  				work, ldwork)
   242  			// C -= V * Wᵀ.
   243  			if m > k {
        				// C1 -= V1 * Wᵀ.
   244  				bi.Dgemm(blas.NoTrans, blas.Trans, m-k, n, k,
   245  					-1, v, ldv, work, ldwork,
   246  					1, c, ldc)
   247  			}
   248  			// W *= V2ᵀ.
   249  			bi.Dtrmm(blas.Right, blas.Upper, blas.Trans, blas.Unit, n, k,
   250  				1, v[(m-k)*ldv:], ldv,
   251  				work, ldwork)
   252  			// C2 -= Wᵀ.
   253  			// TODO(btracey): This should use blas.Axpy.
   254  			for i := 0; i < n; i++ {
   255  				for j := 0; j < k; j++ {
   256  					c[(m-k+j)*ldc+i] -= work[i*ldwork+j]
   257  				}
   258  			}
   259  			return
   260  		}
   261  		// Form C * H or C * Hᵀ where C = (C1 C2).
   262  		// W = C * V.
   263  
   264  		// W = C2.
   265  		for j := 0; j < k; j++ {
   266  			bi.Dcopy(m, c[n-k+j:], ldc, work[j:], ldwork)
   267  		}
   268  
   269  		// W = W * V2.
   270  		bi.Dtrmm(blas.Right, blas.Upper, blas.NoTrans, blas.Unit, m, k,
   271  			1, v[(n-k)*ldv:], ldv,
   272  			work, ldwork)
   273  		if n > k {
        			// W += C1 * V1.
   274  			bi.Dgemm(blas.NoTrans, blas.NoTrans, m, k, n-k,
   275  				1, c, ldc, v, ldv,
   276  				1, work, ldwork)
   277  		}
   278  		// W *= T or Tᵀ.
   279  		bi.Dtrmm(blas.Right, blas.Lower, trans, blas.NonUnit, m, k,
   280  			1, t, ldt,
   281  			work, ldwork)
   282  		// C -= W * Vᵀ.
   283  		if n > k {
   284  			// C1 -= W * V1ᵀ.
   285  			bi.Dgemm(blas.NoTrans, blas.Trans, m, n-k, k,
   286  				-1, work, ldwork, v, ldv,
   287  				1, c, ldc)
   288  		}
   289  		// W *= V2ᵀ.
   290  		bi.Dtrmm(blas.Right, blas.Upper, blas.Trans, blas.Unit, m, k,
   291  			1, v[(n-k)*ldv:], ldv,
   292  			work, ldwork)
   293  		// C2 -= W.
   294  		// TODO(btracey): This should use blas.Axpy.
   295  		for i := 0; i < m; i++ {
   296  			for j := 0; j < k; j++ {
   297  				c[i*ldc+n-k+j] -= work[i*ldwork+j]
   298  			}
   299  		}
   300  		return
   301  	}
   302  	// Store = Rowwise.
   303  	if direct == lapack.Forward {
   304  		// V = (V1 V2) where V1 is the first k columns and is unit upper triangular.
   305  		if side == blas.Left {
   306  			// Form H * C or Hᵀ * C where C = (C1; C2).
   307  			// W = Cᵀ * Vᵀ.
   308  
   309  			// W = C1ᵀ.
   310  			for j := 0; j < k; j++ {
   311  				bi.Dcopy(n, c[j*ldc:], 1, work[j:], ldwork)
   312  			}
   313  			// W *= V1ᵀ.
   314  			bi.Dtrmm(blas.Right, blas.Upper, blas.Trans, blas.Unit, n, k,
   315  				1, v, ldv,
   316  				work, ldwork)
   317  			if m > k {
        				// W += C2ᵀ * V2ᵀ.
   318  				bi.Dgemm(blas.Trans, blas.Trans, n, k, m-k,
   319  					1, c[k*ldc:], ldc, v[k:], ldv,
   320  					1, work, ldwork)
   321  			}
   322  			// W *= T or Tᵀ.
   323  			bi.Dtrmm(blas.Right, blas.Upper, transt, blas.NonUnit, n, k,
   324  				1, t, ldt,
   325  				work, ldwork)
   326  			// C -= Vᵀ * Wᵀ.
   327  			if m > k {
        				// C2 -= V2ᵀ * Wᵀ.
   328  				bi.Dgemm(blas.Trans, blas.Trans, m-k, n, k,
   329  					-1, v[k:], ldv, work, ldwork,
   330  					1, c[k*ldc:], ldc)
   331  			}
   332  			// W *= V1.
   333  			bi.Dtrmm(blas.Right, blas.Upper, blas.NoTrans, blas.Unit, n, k,
   334  				1, v, ldv,
   335  				work, ldwork)
   336  			// C1 -= Wᵀ.
   337  			// TODO(btracey): This should use blas.Axpy.
   338  			for i := 0; i < n; i++ {
   339  				for j := 0; j < k; j++ {
   340  					c[j*ldc+i] -= work[i*ldwork+j]
   341  				}
   342  			}
   343  			return
   344  		}
   345  		// Form C * H or C * Hᵀ where C = (C1 C2).
   346  		// W = C * Vᵀ.
   347  
   348  		// W = C1.
   349  		for j := 0; j < k; j++ {
   350  			bi.Dcopy(m, c[j:], ldc, work[j:], ldwork)
   351  		}
   352  		// W *= V1ᵀ.
   353  		bi.Dtrmm(blas.Right, blas.Upper, blas.Trans, blas.Unit, m, k,
   354  			1, v, ldv,
   355  			work, ldwork)
   356  		if n > k {
        			// W += C2 * V2ᵀ.
   357  			bi.Dgemm(blas.NoTrans, blas.Trans, m, k, n-k,
   358  				1, c[k:], ldc, v[k:], ldv,
   359  				1, work, ldwork)
   360  		}
   361  		// W *= T or Tᵀ.
   362  		bi.Dtrmm(blas.Right, blas.Upper, trans, blas.NonUnit, m, k,
   363  			1, t, ldt,
   364  			work, ldwork)
   365  		// C -= W * V.
   366  		if n > k {
        			// C2 -= W * V2.
   367  			bi.Dgemm(blas.NoTrans, blas.NoTrans, m, n-k, k,
   368  				-1, work, ldwork, v[k:], ldv,
   369  				1, c[k:], ldc)
   370  		}
   371  		// W *= V1.
   372  		bi.Dtrmm(blas.Right, blas.Upper, blas.NoTrans, blas.Unit, m, k,
   373  			1, v, ldv,
   374  			work, ldwork)
   375  		// C1 -= W.
   376  		// TODO(btracey): This should use blas.Axpy.
   377  		for i := 0; i < m; i++ {
   378  			for j := 0; j < k; j++ {
   379  				c[i*ldc+j] -= work[i*ldwork+j]
   380  			}
   381  		}
   382  		return
   383  	}
   384  	// V = (V1 V2) where V2 is the last k columns and is lower unit triangular.
   385  	if side == blas.Left {
   386  		// Form H * C or Hᵀ C where C = (C1 ; C2).
   387  		// W = Cᵀ * Vᵀ.
   388  
   389  		// W = C2ᵀ.
   390  		for j := 0; j < k; j++ {
   391  			bi.Dcopy(n, c[(m-k+j)*ldc:], 1, work[j:], ldwork)
   392  		}
   393  		// W *= V2ᵀ.
   394  		bi.Dtrmm(blas.Right, blas.Lower, blas.Trans, blas.Unit, n, k,
   395  			1, v[m-k:], ldv,
   396  			work, ldwork)
   397  		if m > k {
        			// W += C1ᵀ * V1ᵀ.
   398  			bi.Dgemm(blas.Trans, blas.Trans, n, k, m-k,
   399  				1, c, ldc, v, ldv,
   400  				1, work, ldwork)
   401  		}
   402  		// W *= T or Tᵀ.
   403  		bi.Dtrmm(blas.Right, blas.Lower, transt, blas.NonUnit, n, k,
   404  			1, t, ldt,
   405  			work, ldwork)
   406  		// C -= Vᵀ * Wᵀ.
   407  		if m > k {
        			// C1 -= V1ᵀ * Wᵀ.
   408  			bi.Dgemm(blas.Trans, blas.Trans, m-k, n, k,
   409  				-1, v, ldv, work, ldwork,
   410  				1, c, ldc)
   411  		}
   412  		// W *= V2.
   413  		bi.Dtrmm(blas.Right, blas.Lower, blas.NoTrans, blas.Unit, n, k,
   414  			1, v[m-k:], ldv,
   415  			work, ldwork)
   416  		// C2 -= Wᵀ.
   417  		// TODO(btracey): This should use blas.Axpy.
   418  		for i := 0; i < n; i++ {
   419  			for j := 0; j < k; j++ {
   420  				c[(m-k+j)*ldc+i] -= work[i*ldwork+j]
   421  			}
   422  		}
   423  		return
   424  	}
   425  	// Form C * H or C * Hᵀ where C = (C1 C2).
   426  	// W = C * Vᵀ.
   427  	// W = C2.
   428  	for j := 0; j < k; j++ {
   429  		bi.Dcopy(m, c[n-k+j:], ldc, work[j:], ldwork)
   430  	}
   431  	// W *= V2ᵀ.
   432  	bi.Dtrmm(blas.Right, blas.Lower, blas.Trans, blas.Unit, m, k,
   433  		1, v[n-k:], ldv,
   434  		work, ldwork)
   435  	if n > k {
        		// W += C1 * V1ᵀ.
   436  		bi.Dgemm(blas.NoTrans, blas.Trans, m, k, n-k,
   437  			1, c, ldc, v, ldv,
   438  			1, work, ldwork)
   439  	}
   440  	// W *= T or Tᵀ.
   441  	bi.Dtrmm(blas.Right, blas.Lower, trans, blas.NonUnit, m, k,
   442  		1, t, ldt,
   443  		work, ldwork)
   444  	// C -= W * V.
   445  	if n > k {
        		// C1 -= W * V1.
   446  		bi.Dgemm(blas.NoTrans, blas.NoTrans, m, n-k, k,
   447  			-1, work, ldwork, v, ldv,
   448  			1, c, ldc)
   449  	}
   450  	// W *= V2.
   451  	bi.Dtrmm(blas.Right, blas.Lower, blas.NoTrans, blas.Unit, m, k,
   452  		1, v[n-k:], ldv,
   453  		work, ldwork)
   454  	// C2 -= W.
   455  	// TODO(btracey): This should use blas.Axpy.
   456  	for i := 0; i < m; i++ {
   457  		for j := 0; j < k; j++ {
   458  			c[i*ldc+n-k+j] -= work[i*ldwork+j]
   459  		}
   460  	}
   461  }