github.com/gonum/lapack@v0.0.0-20181123203213-e4cdc5a0bff9/native/dlarfb.go (about)

     1  // Copyright ©2015 The gonum Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package native
     6  
     7  import (
     8  	"github.com/gonum/blas"
     9  	"github.com/gonum/blas/blas64"
    10  	"github.com/gonum/lapack"
    11  )
    12  
    13  // Dlarfb applies a block reflector to a matrix.
    14  //
    15  // In the call to Dlarfb, the m×n matrix c is multiplied by the implicitly defined matrix h as follows:
    16  //  c = h * c if side == Left and trans == NoTrans
    17  //  c = c * h if side == Right and trans == NoTrans
    18  //  c = h^T * c if side == Left and trans == Trans
    19  //  c = c * h^T if side == Right and trans == Trans
    20  // h is a product of k elementary reflectors. direct sets the order of multiplication:
    21  //  h = h_1 * h_2 * ... * h_k if direct == Forward
    22  //  h = h_k * h_{k-1} * ... * h_1 if direct == Backward
    23  // The combination of direct and store defines the orientation of the elementary
    24  // reflectors. In all cases the ones on the diagonal are implicitly represented.
    25  //
    26  // If direct == lapack.Forward and store == lapack.ColumnWise
    27  //  V = [ 1        ]
    28  //      [v1   1    ]
    29  //      [v1  v2   1]
    30  //      [v1  v2  v3]
    31  //      [v1  v2  v3]
    32  // If direct == lapack.Forward and store == lapack.RowWise
    33  //  V = [ 1  v1  v1  v1  v1]
    34  //      [     1  v2  v2  v2]
    35  //      [         1  v3  v3]
    36  // If direct == lapack.Backward and store == lapack.ColumnWise
    37  //  V = [v1  v2  v3]
    38  //      [v1  v2  v3]
    39  //      [ 1  v2  v3]
    40  //      [     1  v3]
    41  //      [         1]
    42  // If direct == lapack.Backward and store == lapack.RowWise
    43  //  V = [v1  v1   1        ]
    44  //      [v2  v2  v2   1    ]
    45  //      [v3  v3  v3  v3   1]
    46  // An elementary reflector can be explicitly constructed by extracting the
    47  // corresponding elements of v, placing a 1 where the diagonal would be, and
    48  // placing zeros in the remaining elements.
    49  //
    50  // t is a k×k matrix containing the triangular factor of the block reflector (it is read as
    51  // upper triangular if direct == Forward and as lower triangular if direct == Backward), and this function will panic if t is not of sufficient size. See Dlarft for more information.
    52  //
    53  // work is a temporary storage matrix with stride ldwork.
    54  // work must be of size at least n×k if side == Left and m×k if side == Right, and
    55  // this function will panic if this size is not met.
    56  //
    57  // Dlarfb is an internal routine. It is exported for testing purposes.
    58  func (Implementation) Dlarfb(side blas.Side, trans blas.Transpose, direct lapack.Direct, store lapack.StoreV, m, n, k int, v []float64, ldv int, t []float64, ldt int, c []float64, ldc int, work []float64, ldwork int) {
    59  	if side != blas.Left && side != blas.Right {
    60  		panic(badSide)
    61  	}
    62  	if trans != blas.Trans && trans != blas.NoTrans {
    63  		panic(badTrans)
    64  	}
    65  	if direct != lapack.Forward && direct != lapack.Backward {
    66  		panic(badDirect)
    67  	}
    68  	if store != lapack.ColumnWise && store != lapack.RowWise {
    69  		panic(badStore)
    70  	}
    71  	checkMatrix(m, n, c, ldc)
    72  	if k < 0 {
    73  		panic(kLT0)
    74  	}
    75  	checkMatrix(k, k, t, ldt)
    76  	nv := m // Length of each reflector vector: m for Left, n for Right.
    77  	nw := n // Number of rows of work: n for Left, m for Right.
    78  	if side == blas.Right {
    79  		nv = n
    80  		nw = m
    81  	}
    82  	if store == lapack.ColumnWise {
    83  		checkMatrix(nv, k, v, ldv)
    84  	} else {
    85  		checkMatrix(k, nv, v, ldv)
    86  	}
    87  	checkMatrix(nw, k, work, ldwork)
    88  
    89  	if m == 0 || n == 0 { // c is empty, so there is nothing to update.
    90  		return
    91  	}
    92  
    93  	bi := blas64.Implementation()
    94  
    95  	transt := blas.Trans // transt is the opposite of trans; it is used in the side == Left branches.
    96  	if trans == blas.Trans {
    97  		transt = blas.NoTrans
    98  	}
    99  	// TODO(btracey): This follows the original Lapack code where the
   100  	// elements are copied into the columns of the working array. The
   101  	// loops should go in the other direction so the data is written
   102  	// into the rows of work so the copy is not strided. A bigger change
   103  	// would be to replace work with work^T, but benchmarks would be
   104  	// needed to see if the change is merited.
   105  	if store == lapack.ColumnWise {
   106  		if direct == lapack.Forward {
   107  			// V1 is the first k rows of V. V2 is the remaining rows.
   108  			if side == blas.Left {
   109  				// W = C^T V = C1^T V1 + C2^T V2 (stored in work).
   110  
   111  				// W = C1^T (row j of C becomes column j of work).
   112  				for j := 0; j < k; j++ {
   113  					bi.Dcopy(n, c[j*ldc:], 1, work[j:], ldwork)
   114  				}
   115  				// W = W * V1.
   116  				bi.Dtrmm(blas.Right, blas.Lower, blas.NoTrans, blas.Unit,
   117  					n, k, 1,
   118  					v, ldv,
   119  					work, ldwork)
   120  				if m > k {
   121  					// W = W + C2^T V2.
   122  					bi.Dgemm(blas.Trans, blas.NoTrans, n, k, m-k,
   123  						1, c[k*ldc:], ldc, v[k*ldv:], ldv,
   124  						1, work, ldwork)
   125  				}
   126  				// W = W * T^T or W * T.
   127  				bi.Dtrmm(blas.Right, blas.Upper, transt, blas.NonUnit, n, k,
   128  					1, t, ldt,
   129  					work, ldwork)
   130  				// C -= V * W^T.
   131  				if m > k {
   132  					// C2 -= V2 * W^T.
   133  					bi.Dgemm(blas.NoTrans, blas.Trans, m-k, n, k,
   134  						-1, v[k*ldv:], ldv, work, ldwork,
   135  						1, c[k*ldc:], ldc)
   136  				}
   137  				// W *= V1^T.
   138  				bi.Dtrmm(blas.Right, blas.Lower, blas.Trans, blas.Unit, n, k,
   139  					1, v, ldv,
   140  					work, ldwork)
   141  				// C1 -= W^T.
   142  				// TODO(btracey): This should use blas.Axpy.
   143  				for i := 0; i < n; i++ {
   144  					for j := 0; j < k; j++ {
   145  						c[j*ldc+i] -= work[i*ldwork+j]
   146  					}
   147  				}
   148  				return
   149  			}
   150  			// Form C = C * H or C * H^T, where C = (C1 C2).
   151  
   152  			// W = C1 (column i of C becomes column i of work).
   153  			for i := 0; i < k; i++ {
   154  				bi.Dcopy(m, c[i:], ldc, work[i:], ldwork)
   155  			}
   156  			// W *= V1.
   157  			bi.Dtrmm(blas.Right, blas.Lower, blas.NoTrans, blas.Unit, m, k,
   158  				1, v, ldv,
   159  				work, ldwork)
   160  			if n > k { // W += C2 * V2.
   161  				bi.Dgemm(blas.NoTrans, blas.NoTrans, m, k, n-k,
   162  					1, c[k:], ldc, v[k*ldv:], ldv,
   163  					1, work, ldwork)
   164  			}
   165  			// W *= T or T^T.
   166  			bi.Dtrmm(blas.Right, blas.Upper, trans, blas.NonUnit, m, k,
   167  				1, t, ldt,
   168  				work, ldwork)
   169  			if n > k { // C2 -= W * V2^T.
   170  				bi.Dgemm(blas.NoTrans, blas.Trans, m, n-k, k,
   171  					-1, work, ldwork, v[k*ldv:], ldv,
   172  					1, c[k:], ldc)
   173  			}
   174  			// W *= V1^T.
   175  			bi.Dtrmm(blas.Right, blas.Lower, blas.Trans, blas.Unit, m, k,
   176  				1, v, ldv,
   177  				work, ldwork)
   178  			// C1 -= W.
   179  			// TODO(btracey): This should use blas.Axpy.
   180  			for i := 0; i < m; i++ {
   181  				for j := 0; j < k; j++ {
   182  					c[i*ldc+j] -= work[i*ldwork+j]
   183  				}
   184  			}
   185  			return
   186  		}
   187  		// direct == lapack.Backward:
   188  		//   V = (V1; V2) with V2 the last k rows,
   189  		// where V2 is unit upper triangular.
   190  		if side == blas.Left {
   191  			// Form H * C or H^T * C where C = (C1; C2).
   192  			// W = C^T V.
   193  
   194  			// W = C2^T.
   195  			for j := 0; j < k; j++ {
   196  				bi.Dcopy(n, c[(m-k+j)*ldc:], 1, work[j:], ldwork)
   197  			}
   198  			// W *= V2.
   199  			bi.Dtrmm(blas.Right, blas.Upper, blas.NoTrans, blas.Unit, n, k,
   200  				1, v[(m-k)*ldv:], ldv,
   201  				work, ldwork)
   202  			if m > k {
   203  				// W += C1^T * V1.
   204  				bi.Dgemm(blas.Trans, blas.NoTrans, n, k, m-k,
   205  					1, c, ldc, v, ldv,
   206  					1, work, ldwork)
   207  			}
   208  			// W *= T or T^T.
   209  			bi.Dtrmm(blas.Right, blas.Lower, transt, blas.NonUnit, n, k,
   210  				1, t, ldt,
   211  				work, ldwork)
   212  			// C -= V * W^T.
   213  			if m > k { // C1 -= V1 * W^T.
   214  				bi.Dgemm(blas.NoTrans, blas.Trans, m-k, n, k,
   215  					-1, v, ldv, work, ldwork,
   216  					1, c, ldc)
   217  			}
   218  			// W *= V2^T.
   219  			bi.Dtrmm(blas.Right, blas.Upper, blas.Trans, blas.Unit, n, k,
   220  				1, v[(m-k)*ldv:], ldv,
   221  				work, ldwork)
   222  			// C2 -= W^T.
   223  			// TODO(btracey): This should use blas.Axpy.
   224  			for i := 0; i < n; i++ {
   225  				for j := 0; j < k; j++ {
   226  					c[(m-k+j)*ldc+i] -= work[i*ldwork+j]
   227  				}
   228  			}
   229  			return
   230  		}
   231  		// Form C * H or C * H^T where C = (C1 C2).
   232  		// W = C * V.
   233  
   234  		// W = C2.
   235  		for j := 0; j < k; j++ {
   236  			bi.Dcopy(m, c[n-k+j:], ldc, work[j:], ldwork)
   237  		}
   238  
   239  		// W = W * V2.
   240  		bi.Dtrmm(blas.Right, blas.Upper, blas.NoTrans, blas.Unit, m, k,
   241  			1, v[(n-k)*ldv:], ldv,
   242  			work, ldwork)
   243  		if n > k { // W += C1 * V1.
   244  			bi.Dgemm(blas.NoTrans, blas.NoTrans, m, k, n-k,
   245  				1, c, ldc, v, ldv,
   246  				1, work, ldwork)
   247  		}
   248  		// W *= T or T^T.
   249  		bi.Dtrmm(blas.Right, blas.Lower, trans, blas.NonUnit, m, k,
   250  			1, t, ldt,
   251  			work, ldwork)
   252  		// C -= W * V^T.
   253  		if n > k {
   254  			// C1 -= W * V1^T.
   255  			bi.Dgemm(blas.NoTrans, blas.Trans, m, n-k, k,
   256  				-1, work, ldwork, v, ldv,
   257  				1, c, ldc)
   258  		}
   259  		// W *= V2^T.
   260  		bi.Dtrmm(blas.Right, blas.Upper, blas.Trans, blas.Unit, m, k,
   261  			1, v[(n-k)*ldv:], ldv,
   262  			work, ldwork)
   263  		// C2 -= W.
   264  		// TODO(btracey): This should use blas.Axpy.
   265  		for i := 0; i < m; i++ {
   266  			for j := 0; j < k; j++ {
   267  				c[i*ldc+n-k+j] -= work[i*ldwork+j]
   268  			}
   269  		}
   270  		return
   271  	}
   272  	// store == lapack.RowWise.
   273  	if direct == lapack.Forward {
   274  		// V = (V1 V2) where V1 is the first k columns and is unit upper triangular.
   275  		if side == blas.Left {
   276  			// Form H * C or H^T * C where C = (C1; C2).
   277  			// W = C^T * V^T.
   278  
   279  			// W = C1^T.
   280  			for j := 0; j < k; j++ {
   281  				bi.Dcopy(n, c[j*ldc:], 1, work[j:], ldwork)
   282  			}
   283  			// W *= V1^T.
   284  			bi.Dtrmm(blas.Right, blas.Upper, blas.Trans, blas.Unit, n, k,
   285  				1, v, ldv,
   286  				work, ldwork)
   287  			if m > k { // W += C2^T * V2^T.
   288  				bi.Dgemm(blas.Trans, blas.Trans, n, k, m-k,
   289  					1, c[k*ldc:], ldc, v[k:], ldv,
   290  					1, work, ldwork)
   291  			}
   292  			// W *= T or T^T.
   293  			bi.Dtrmm(blas.Right, blas.Upper, transt, blas.NonUnit, n, k,
   294  				1, t, ldt,
   295  				work, ldwork)
   296  			// C -= V^T * W^T.
   297  			if m > k { // C2 -= V2^T * W^T.
   298  				bi.Dgemm(blas.Trans, blas.Trans, m-k, n, k,
   299  					-1, v[k:], ldv, work, ldwork,
   300  					1, c[k*ldc:], ldc)
   301  			}
   302  			// W *= V1.
   303  			bi.Dtrmm(blas.Right, blas.Upper, blas.NoTrans, blas.Unit, n, k,
   304  				1, v, ldv,
   305  				work, ldwork)
   306  			// C1 -= W^T.
   307  			// TODO(btracey): This should use blas.Axpy.
   308  			for i := 0; i < n; i++ {
   309  				for j := 0; j < k; j++ {
   310  					c[j*ldc+i] -= work[i*ldwork+j]
   311  				}
   312  			}
   313  			return
   314  		}
   315  		// Form C * H or C * H^T where C = (C1 C2).
   316  		// W = C * V^T.
   317  
   318  		// W = C1.
   319  		for j := 0; j < k; j++ {
   320  			bi.Dcopy(m, c[j:], ldc, work[j:], ldwork)
   321  		}
   322  		// W *= V1^T.
   323  		bi.Dtrmm(blas.Right, blas.Upper, blas.Trans, blas.Unit, m, k,
   324  			1, v, ldv,
   325  			work, ldwork)
   326  		if n > k { // W += C2 * V2^T.
   327  			bi.Dgemm(blas.NoTrans, blas.Trans, m, k, n-k,
   328  				1, c[k:], ldc, v[k:], ldv,
   329  				1, work, ldwork)
   330  		}
   331  		// W *= T or T^T.
   332  		bi.Dtrmm(blas.Right, blas.Upper, trans, blas.NonUnit, m, k,
   333  			1, t, ldt,
   334  			work, ldwork)
   335  		// C -= W * V.
   336  		if n > k { // C2 -= W * V2.
   337  			bi.Dgemm(blas.NoTrans, blas.NoTrans, m, n-k, k,
   338  				-1, work, ldwork, v[k:], ldv,
   339  				1, c[k:], ldc)
   340  		}
   341  		// W *= V1.
   342  		bi.Dtrmm(blas.Right, blas.Upper, blas.NoTrans, blas.Unit, m, k,
   343  			1, v, ldv,
   344  			work, ldwork)
   345  		// C1 -= W.
   346  		// TODO(btracey): This should use blas.Axpy.
   347  		for i := 0; i < m; i++ {
   348  			for j := 0; j < k; j++ {
   349  				c[i*ldc+j] -= work[i*ldwork+j]
   350  			}
   351  		}
   352  		return
   353  	}
   354  	// direct == lapack.Backward: V = (V1 V2) where V2 is the last k columns and is lower unit triangular.
   355  	if side == blas.Left {
   356  		// Form H * C or H^T * C where C = (C1; C2).
   357  		// W = C^T * V^T.
   358  
   359  		// W = C2^T.
   360  		for j := 0; j < k; j++ {
   361  			bi.Dcopy(n, c[(m-k+j)*ldc:], 1, work[j:], ldwork)
   362  		}
   363  		// W *= V2^T.
   364  		bi.Dtrmm(blas.Right, blas.Lower, blas.Trans, blas.Unit, n, k,
   365  			1, v[m-k:], ldv,
   366  			work, ldwork)
   367  		if m > k { // W += C1^T * V1^T.
   368  			bi.Dgemm(blas.Trans, blas.Trans, n, k, m-k,
   369  				1, c, ldc, v, ldv,
   370  				1, work, ldwork)
   371  		}
   372  		// W *= T or T^T.
   373  		bi.Dtrmm(blas.Right, blas.Lower, transt, blas.NonUnit, n, k,
   374  			1, t, ldt,
   375  			work, ldwork)
   376  		// C -= V^T * W^T.
   377  		if m > k { // C1 -= V1^T * W^T.
   378  			bi.Dgemm(blas.Trans, blas.Trans, m-k, n, k,
   379  				-1, v, ldv, work, ldwork,
   380  				1, c, ldc)
   381  		}
   382  		// W *= V2.
   383  		bi.Dtrmm(blas.Right, blas.Lower, blas.NoTrans, blas.Unit, n, k,
   384  			1, v[m-k:], ldv,
   385  			work, ldwork)
   386  		// C2 -= W^T.
   387  		// TODO(btracey): This should use blas.Axpy.
   388  		for i := 0; i < n; i++ {
   389  			for j := 0; j < k; j++ {
   390  				c[(m-k+j)*ldc+i] -= work[i*ldwork+j]
   391  			}
   392  		}
   393  		return
   394  	}
   395  	// Form C * H or C * H^T where C = (C1 C2).
   396  	// W = C * V^T.
   397  	// W = C2.
   398  	for j := 0; j < k; j++ {
   399  		bi.Dcopy(m, c[n-k+j:], ldc, work[j:], ldwork)
   400  	}
   401  	// W *= V2^T.
   402  	bi.Dtrmm(blas.Right, blas.Lower, blas.Trans, blas.Unit, m, k,
   403  		1, v[n-k:], ldv,
   404  		work, ldwork)
   405  	if n > k { // W += C1 * V1^T.
   406  		bi.Dgemm(blas.NoTrans, blas.Trans, m, k, n-k,
   407  			1, c, ldc, v, ldv,
   408  			1, work, ldwork)
   409  	}
   410  	// W *= T or T^T.
   411  	bi.Dtrmm(blas.Right, blas.Lower, trans, blas.NonUnit, m, k,
   412  		1, t, ldt,
   413  		work, ldwork)
   414  	// C -= W * V.
   415  	if n > k { // C1 -= W * V1.
   416  		bi.Dgemm(blas.NoTrans, blas.NoTrans, m, n-k, k,
   417  			-1, work, ldwork, v, ldv,
   418  			1, c, ldc)
   419  	}
   420  	// W *= V2.
   421  	bi.Dtrmm(blas.Right, blas.Lower, blas.NoTrans, blas.Unit, m, k,
   422  		1, v[n-k:], ldv,
   423  		work, ldwork)
   424  	// C2 -= W (the last k columns of C).
   425  	// TODO(btracey): This should use blas.Axpy.
   426  	for i := 0; i < m; i++ {
   427  		for j := 0; j < k; j++ {
   428  			c[i*ldc+n-k+j] -= work[i*ldwork+j]
   429  		}
   430  	}
   431  }