github.com/gopherd/gonum@v0.0.4/lapack/gonum/dlarfb.go (about)

     1  // Copyright ©2015 The Gonum Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package gonum
     6  
     7  import (
     8  	"github.com/gopherd/gonum/blas"
     9  	"github.com/gopherd/gonum/blas/blas64"
    10  	"github.com/gopherd/gonum/lapack"
    11  )
    12  
    13  // Dlarfb applies a block reflector to a matrix.
    14  //
    15  // In the call to Dlarfb, the mxn c is multiplied by the implicitly defined matrix h as follows:
    16  //  c = h * c   if side == Left and trans == NoTrans
    17  //  c = c * h   if side == Right and trans == NoTrans
    18  //  c = hᵀ * c  if side == Left and trans == Trans
    19  //  c = c * hᵀ  if side == Right and trans == Trans
    20  // h is a product of elementary reflectors. direct sets the direction of multiplication
    21  //  h = h_1 * h_2 * ... * h_k    if direct == Forward
    22  //  h = h_k * h_k-1 * ... * h_1  if direct == Backward
    23  // The combination of direct and store defines the orientation of the elementary
    24  // reflectors. In all cases the ones on the diagonal are implicitly represented.
    25  //
    26  // If direct == lapack.Forward and store == lapack.ColumnWise
    27  //  V = [ 1        ]
    28  //      [v1   1    ]
    29  //      [v1  v2   1]
    30  //      [v1  v2  v3]
    31  //      [v1  v2  v3]
    32  // If direct == lapack.Forward and store == lapack.RowWise
    33  //  V = [ 1  v1  v1  v1  v1]
    34  //      [     1  v2  v2  v2]
    35  //      [         1  v3  v3]
    36  // If direct == lapack.Backward and store == lapack.ColumnWise
    37  //  V = [v1  v2  v3]
    38  //      [v1  v2  v3]
    39  //      [ 1  v2  v3]
    40  //      [     1  v3]
    41  //      [         1]
    42  // If direct == lapack.Backward and store == lapack.RowWise
    43  //  V = [v1  v1   1        ]
    44  //      [v2  v2  v2   1    ]
    45  //      [v3  v3  v3  v3   1]
    46  // An elementary reflector can be explicitly constructed by extracting the
    47  // corresponding elements of v, placing a 1 where the diagonal would be, and
    48  // placing zeros in the remaining elements.
    49  //
    50  // t is a k×k matrix containing the block reflector, and this function will panic
    51  // if t is not of sufficient size. See Dlarft for more information.
    52  //
    53  // work is a temporary storage matrix with stride ldwork.
    54  // work must be of size at least n×k side == Left and m×k if side == Right, and
    55  // this function will panic if this size is not met.
    56  //
    57  // Dlarfb is an internal routine. It is exported for testing purposes.
    58  func (Implementation) Dlarfb(side blas.Side, trans blas.Transpose, direct lapack.Direct, store lapack.StoreV, m, n, k int, v []float64, ldv int, t []float64, ldt int, c []float64, ldc int, work []float64, ldwork int) {
    59  	nv := m
    60  	if side == blas.Right {
    61  		nv = n
    62  	}
    63  	switch {
    64  	case side != blas.Left && side != blas.Right:
    65  		panic(badSide)
    66  	case trans != blas.Trans && trans != blas.NoTrans:
    67  		panic(badTrans)
    68  	case direct != lapack.Forward && direct != lapack.Backward:
    69  		panic(badDirect)
    70  	case store != lapack.ColumnWise && store != lapack.RowWise:
    71  		panic(badStoreV)
    72  	case m < 0:
    73  		panic(mLT0)
    74  	case n < 0:
    75  		panic(nLT0)
    76  	case k < 0:
    77  		panic(kLT0)
    78  	case store == lapack.ColumnWise && ldv < max(1, k):
    79  		panic(badLdV)
    80  	case store == lapack.RowWise && ldv < max(1, nv):
    81  		panic(badLdV)
    82  	case ldt < max(1, k):
    83  		panic(badLdT)
    84  	case ldc < max(1, n):
    85  		panic(badLdC)
    86  	case ldwork < max(1, k):
    87  		panic(badLdWork)
    88  	}
    89  
    90  	if m == 0 || n == 0 {
    91  		return
    92  	}
    93  
    94  	nw := n
    95  	if side == blas.Right {
    96  		nw = m
    97  	}
    98  	switch {
    99  	case store == lapack.ColumnWise && len(v) < (nv-1)*ldv+k:
   100  		panic(shortV)
   101  	case store == lapack.RowWise && len(v) < (k-1)*ldv+nv:
   102  		panic(shortV)
   103  	case len(t) < (k-1)*ldt+k:
   104  		panic(shortT)
   105  	case len(c) < (m-1)*ldc+n:
   106  		panic(shortC)
   107  	case len(work) < (nw-1)*ldwork+k:
   108  		panic(shortWork)
   109  	}
   110  
   111  	bi := blas64.Implementation()
   112  
   113  	transt := blas.Trans
   114  	if trans == blas.Trans {
   115  		transt = blas.NoTrans
   116  	}
   117  	// TODO(btracey): This follows the original Lapack code where the
   118  	// elements are copied into the columns of the working array. The
   119  	// loops should go in the other direction so the data is written
   120  	// into the rows of work so the copy is not strided. A bigger change
   121  	// would be to replace work with workᵀ, but benchmarks would be
   122  	// needed to see if the change is merited.
   123  	if store == lapack.ColumnWise {
   124  		if direct == lapack.Forward {
   125  			// V1 is the first k rows of C. V2 is the remaining rows.
   126  			if side == blas.Left {
   127  				// W = Cᵀ V = C1ᵀ V1 + C2ᵀ V2 (stored in work).
   128  
   129  				// W = C1.
   130  				for j := 0; j < k; j++ {
   131  					bi.Dcopy(n, c[j*ldc:], 1, work[j:], ldwork)
   132  				}
   133  				// W = W * V1.
   134  				bi.Dtrmm(blas.Right, blas.Lower, blas.NoTrans, blas.Unit,
   135  					n, k, 1,
   136  					v, ldv,
   137  					work, ldwork)
   138  				if m > k {
   139  					// W = W + C2ᵀ V2.
   140  					bi.Dgemm(blas.Trans, blas.NoTrans, n, k, m-k,
   141  						1, c[k*ldc:], ldc, v[k*ldv:], ldv,
   142  						1, work, ldwork)
   143  				}
   144  				// W = W * Tᵀ or W * T.
   145  				bi.Dtrmm(blas.Right, blas.Upper, transt, blas.NonUnit, n, k,
   146  					1, t, ldt,
   147  					work, ldwork)
   148  				// C -= V * Wᵀ.
   149  				if m > k {
   150  					// C2 -= V2 * Wᵀ.
   151  					bi.Dgemm(blas.NoTrans, blas.Trans, m-k, n, k,
   152  						-1, v[k*ldv:], ldv, work, ldwork,
   153  						1, c[k*ldc:], ldc)
   154  				}
   155  				// W *= V1ᵀ.
   156  				bi.Dtrmm(blas.Right, blas.Lower, blas.Trans, blas.Unit, n, k,
   157  					1, v, ldv,
   158  					work, ldwork)
   159  				// C1 -= Wᵀ.
   160  				// TODO(btracey): This should use blas.Axpy.
   161  				for i := 0; i < n; i++ {
   162  					for j := 0; j < k; j++ {
   163  						c[j*ldc+i] -= work[i*ldwork+j]
   164  					}
   165  				}
   166  				return
   167  			}
   168  			// Form C = C * H or C * Hᵀ, where C = (C1 C2).
   169  
   170  			// W = C1.
   171  			for i := 0; i < k; i++ {
   172  				bi.Dcopy(m, c[i:], ldc, work[i:], ldwork)
   173  			}
   174  			// W *= V1.
   175  			bi.Dtrmm(blas.Right, blas.Lower, blas.NoTrans, blas.Unit, m, k,
   176  				1, v, ldv,
   177  				work, ldwork)
   178  			if n > k {
   179  				bi.Dgemm(blas.NoTrans, blas.NoTrans, m, k, n-k,
   180  					1, c[k:], ldc, v[k*ldv:], ldv,
   181  					1, work, ldwork)
   182  			}
   183  			// W *= T or Tᵀ.
   184  			bi.Dtrmm(blas.Right, blas.Upper, trans, blas.NonUnit, m, k,
   185  				1, t, ldt,
   186  				work, ldwork)
   187  			if n > k {
   188  				bi.Dgemm(blas.NoTrans, blas.Trans, m, n-k, k,
   189  					-1, work, ldwork, v[k*ldv:], ldv,
   190  					1, c[k:], ldc)
   191  			}
   192  			// C -= W * Vᵀ.
   193  			bi.Dtrmm(blas.Right, blas.Lower, blas.Trans, blas.Unit, m, k,
   194  				1, v, ldv,
   195  				work, ldwork)
   196  			// C -= W.
   197  			// TODO(btracey): This should use blas.Axpy.
   198  			for i := 0; i < m; i++ {
   199  				for j := 0; j < k; j++ {
   200  					c[i*ldc+j] -= work[i*ldwork+j]
   201  				}
   202  			}
   203  			return
   204  		}
   205  		// V = (V1)
   206  		//   = (V2) (last k rows)
   207  		// Where V2 is unit upper triangular.
   208  		if side == blas.Left {
   209  			// Form H * C or
   210  			// W = Cᵀ V.
   211  
   212  			// W = C2ᵀ.
   213  			for j := 0; j < k; j++ {
   214  				bi.Dcopy(n, c[(m-k+j)*ldc:], 1, work[j:], ldwork)
   215  			}
   216  			// W *= V2.
   217  			bi.Dtrmm(blas.Right, blas.Upper, blas.NoTrans, blas.Unit, n, k,
   218  				1, v[(m-k)*ldv:], ldv,
   219  				work, ldwork)
   220  			if m > k {
   221  				// W += C1ᵀ * V1.
   222  				bi.Dgemm(blas.Trans, blas.NoTrans, n, k, m-k,
   223  					1, c, ldc, v, ldv,
   224  					1, work, ldwork)
   225  			}
   226  			// W *= T or Tᵀ.
   227  			bi.Dtrmm(blas.Right, blas.Lower, transt, blas.NonUnit, n, k,
   228  				1, t, ldt,
   229  				work, ldwork)
   230  			// C -= V * Wᵀ.
   231  			if m > k {
   232  				bi.Dgemm(blas.NoTrans, blas.Trans, m-k, n, k,
   233  					-1, v, ldv, work, ldwork,
   234  					1, c, ldc)
   235  			}
   236  			// W *= V2ᵀ.
   237  			bi.Dtrmm(blas.Right, blas.Upper, blas.Trans, blas.Unit, n, k,
   238  				1, v[(m-k)*ldv:], ldv,
   239  				work, ldwork)
   240  			// C2 -= Wᵀ.
   241  			// TODO(btracey): This should use blas.Axpy.
   242  			for i := 0; i < n; i++ {
   243  				for j := 0; j < k; j++ {
   244  					c[(m-k+j)*ldc+i] -= work[i*ldwork+j]
   245  				}
   246  			}
   247  			return
   248  		}
   249  		// Form C * H or C * Hᵀ where C = (C1 C2).
   250  		// W = C * V.
   251  
   252  		// W = C2.
   253  		for j := 0; j < k; j++ {
   254  			bi.Dcopy(m, c[n-k+j:], ldc, work[j:], ldwork)
   255  		}
   256  
   257  		// W = W * V2.
   258  		bi.Dtrmm(blas.Right, blas.Upper, blas.NoTrans, blas.Unit, m, k,
   259  			1, v[(n-k)*ldv:], ldv,
   260  			work, ldwork)
   261  		if n > k {
   262  			bi.Dgemm(blas.NoTrans, blas.NoTrans, m, k, n-k,
   263  				1, c, ldc, v, ldv,
   264  				1, work, ldwork)
   265  		}
   266  		// W *= T or Tᵀ.
   267  		bi.Dtrmm(blas.Right, blas.Lower, trans, blas.NonUnit, m, k,
   268  			1, t, ldt,
   269  			work, ldwork)
   270  		// C -= W * Vᵀ.
   271  		if n > k {
   272  			// C1 -= W * V1ᵀ.
   273  			bi.Dgemm(blas.NoTrans, blas.Trans, m, n-k, k,
   274  				-1, work, ldwork, v, ldv,
   275  				1, c, ldc)
   276  		}
   277  		// W *= V2ᵀ.
   278  		bi.Dtrmm(blas.Right, blas.Upper, blas.Trans, blas.Unit, m, k,
   279  			1, v[(n-k)*ldv:], ldv,
   280  			work, ldwork)
   281  		// C2 -= W.
   282  		// TODO(btracey): This should use blas.Axpy.
   283  		for i := 0; i < m; i++ {
   284  			for j := 0; j < k; j++ {
   285  				c[i*ldc+n-k+j] -= work[i*ldwork+j]
   286  			}
   287  		}
   288  		return
   289  	}
   290  	// Store = Rowwise.
   291  	if direct == lapack.Forward {
   292  		// V = (V1 V2) where v1 is unit upper triangular.
   293  		if side == blas.Left {
   294  			// Form H * C or Hᵀ * C where C = (C1; C2).
   295  			// W = Cᵀ * Vᵀ.
   296  
   297  			// W = C1ᵀ.
   298  			for j := 0; j < k; j++ {
   299  				bi.Dcopy(n, c[j*ldc:], 1, work[j:], ldwork)
   300  			}
   301  			// W *= V1ᵀ.
   302  			bi.Dtrmm(blas.Right, blas.Upper, blas.Trans, blas.Unit, n, k,
   303  				1, v, ldv,
   304  				work, ldwork)
   305  			if m > k {
   306  				bi.Dgemm(blas.Trans, blas.Trans, n, k, m-k,
   307  					1, c[k*ldc:], ldc, v[k:], ldv,
   308  					1, work, ldwork)
   309  			}
   310  			// W *= T or Tᵀ.
   311  			bi.Dtrmm(blas.Right, blas.Upper, transt, blas.NonUnit, n, k,
   312  				1, t, ldt,
   313  				work, ldwork)
   314  			// C -= Vᵀ * Wᵀ.
   315  			if m > k {
   316  				bi.Dgemm(blas.Trans, blas.Trans, m-k, n, k,
   317  					-1, v[k:], ldv, work, ldwork,
   318  					1, c[k*ldc:], ldc)
   319  			}
   320  			// W *= V1.
   321  			bi.Dtrmm(blas.Right, blas.Upper, blas.NoTrans, blas.Unit, n, k,
   322  				1, v, ldv,
   323  				work, ldwork)
   324  			// C1 -= Wᵀ.
   325  			// TODO(btracey): This should use blas.Axpy.
   326  			for i := 0; i < n; i++ {
   327  				for j := 0; j < k; j++ {
   328  					c[j*ldc+i] -= work[i*ldwork+j]
   329  				}
   330  			}
   331  			return
   332  		}
   333  		// Form C * H or C * Hᵀ where C = (C1 C2).
   334  		// W = C * Vᵀ.
   335  
   336  		// W = C1.
   337  		for j := 0; j < k; j++ {
   338  			bi.Dcopy(m, c[j:], ldc, work[j:], ldwork)
   339  		}
   340  		// W *= V1ᵀ.
   341  		bi.Dtrmm(blas.Right, blas.Upper, blas.Trans, blas.Unit, m, k,
   342  			1, v, ldv,
   343  			work, ldwork)
   344  		if n > k {
   345  			bi.Dgemm(blas.NoTrans, blas.Trans, m, k, n-k,
   346  				1, c[k:], ldc, v[k:], ldv,
   347  				1, work, ldwork)
   348  		}
   349  		// W *= T or Tᵀ.
   350  		bi.Dtrmm(blas.Right, blas.Upper, trans, blas.NonUnit, m, k,
   351  			1, t, ldt,
   352  			work, ldwork)
   353  		// C -= W * V.
   354  		if n > k {
   355  			bi.Dgemm(blas.NoTrans, blas.NoTrans, m, n-k, k,
   356  				-1, work, ldwork, v[k:], ldv,
   357  				1, c[k:], ldc)
   358  		}
   359  		// W *= V1.
   360  		bi.Dtrmm(blas.Right, blas.Upper, blas.NoTrans, blas.Unit, m, k,
   361  			1, v, ldv,
   362  			work, ldwork)
   363  		// C1 -= W.
   364  		// TODO(btracey): This should use blas.Axpy.
   365  		for i := 0; i < m; i++ {
   366  			for j := 0; j < k; j++ {
   367  				c[i*ldc+j] -= work[i*ldwork+j]
   368  			}
   369  		}
   370  		return
   371  	}
   372  	// V = (V1 V2) where V2 is the last k columns and is lower unit triangular.
   373  	if side == blas.Left {
   374  		// Form H * C or Hᵀ C where C = (C1 ; C2).
   375  		// W = Cᵀ * Vᵀ.
   376  
   377  		// W = C2ᵀ.
   378  		for j := 0; j < k; j++ {
   379  			bi.Dcopy(n, c[(m-k+j)*ldc:], 1, work[j:], ldwork)
   380  		}
   381  		// W *= V2ᵀ.
   382  		bi.Dtrmm(blas.Right, blas.Lower, blas.Trans, blas.Unit, n, k,
   383  			1, v[m-k:], ldv,
   384  			work, ldwork)
   385  		if m > k {
   386  			bi.Dgemm(blas.Trans, blas.Trans, n, k, m-k,
   387  				1, c, ldc, v, ldv,
   388  				1, work, ldwork)
   389  		}
   390  		// W *= T or Tᵀ.
   391  		bi.Dtrmm(blas.Right, blas.Lower, transt, blas.NonUnit, n, k,
   392  			1, t, ldt,
   393  			work, ldwork)
   394  		// C -= Vᵀ * Wᵀ.
   395  		if m > k {
   396  			bi.Dgemm(blas.Trans, blas.Trans, m-k, n, k,
   397  				-1, v, ldv, work, ldwork,
   398  				1, c, ldc)
   399  		}
   400  		// W *= V2.
   401  		bi.Dtrmm(blas.Right, blas.Lower, blas.NoTrans, blas.Unit, n, k,
   402  			1, v[m-k:], ldv,
   403  			work, ldwork)
   404  		// C2 -= Wᵀ.
   405  		// TODO(btracey): This should use blas.Axpy.
   406  		for i := 0; i < n; i++ {
   407  			for j := 0; j < k; j++ {
   408  				c[(m-k+j)*ldc+i] -= work[i*ldwork+j]
   409  			}
   410  		}
   411  		return
   412  	}
   413  	// Form C * H or C * Hᵀ where C = (C1 C2).
   414  	// W = C * Vᵀ.
   415  	// W = C2.
   416  	for j := 0; j < k; j++ {
   417  		bi.Dcopy(m, c[n-k+j:], ldc, work[j:], ldwork)
   418  	}
   419  	// W *= V2ᵀ.
   420  	bi.Dtrmm(blas.Right, blas.Lower, blas.Trans, blas.Unit, m, k,
   421  		1, v[n-k:], ldv,
   422  		work, ldwork)
   423  	if n > k {
   424  		bi.Dgemm(blas.NoTrans, blas.Trans, m, k, n-k,
   425  			1, c, ldc, v, ldv,
   426  			1, work, ldwork)
   427  	}
   428  	// W *= T or Tᵀ.
   429  	bi.Dtrmm(blas.Right, blas.Lower, trans, blas.NonUnit, m, k,
   430  		1, t, ldt,
   431  		work, ldwork)
   432  	// C -= W * V.
   433  	if n > k {
   434  		bi.Dgemm(blas.NoTrans, blas.NoTrans, m, n-k, k,
   435  			-1, work, ldwork, v, ldv,
   436  			1, c, ldc)
   437  	}
   438  	// W *= V2.
   439  	bi.Dtrmm(blas.Right, blas.Lower, blas.NoTrans, blas.Unit, m, k,
   440  		1, v[n-k:], ldv,
   441  		work, ldwork)
   442  	// C1 -= W.
   443  	// TODO(btracey): This should use blas.Axpy.
   444  	for i := 0; i < m; i++ {
   445  		for j := 0; j < k; j++ {
   446  			c[i*ldc+n-k+j] -= work[i*ldwork+j]
   447  		}
   448  	}
   449  }