gonum.org/v1/gonum@v0.14.0/lapack/gonum/dlarft.go (about)

     1  // Copyright ©2015 The Gonum Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package gonum
     6  
     7  import (
     8  	"gonum.org/v1/gonum/blas"
     9  	"gonum.org/v1/gonum/blas/blas64"
    10  	"gonum.org/v1/gonum/lapack"
    11  )
    12  
    13  // Dlarft forms the triangular factor T of a block reflector H, storing the answer
    14  // in t.
    15  //
    16  //	H = I - V * T * Vᵀ  if store == lapack.ColumnWise
    17  //	H = I - Vᵀ * T * V  if store == lapack.RowWise
    18  //
    19  // H is defined by a product of the elementary reflectors where
    20  //
    21  //	H = H_0 * H_1 * ... * H_{k-1}  if direct == lapack.Forward
    22  //	H = H_{k-1} * ... * H_1 * H_0  if direct == lapack.Backward
    23  //
    24  // t is a k×k triangular matrix. t is upper triangular if direct = lapack.Forward
    25  // and lower triangular otherwise. This function will panic if t is not of
    26  // sufficient size.
    27  //
    28  // store describes the storage of the elementary reflectors in v. See
    29  // Dlarfb for a description of layout.
    30  //
    31  // tau contains the scalar factors of the elementary reflectors H_i.
    32  //
    33  // Dlarft is an internal routine. It is exported for testing purposes.
    34  func (Implementation) Dlarft(direct lapack.Direct, store lapack.StoreV, n, k int, v []float64, ldv int, tau []float64, t []float64, ldt int) {
    35  	mv, nv := n, k
    36  	if store == lapack.RowWise {
    37  		mv, nv = k, n
    38  	}
    39  	switch {
    40  	case direct != lapack.Forward && direct != lapack.Backward:
    41  		panic(badDirect)
    42  	case store != lapack.RowWise && store != lapack.ColumnWise:
    43  		panic(badStoreV)
    44  	case n < 0:
    45  		panic(nLT0)
    46  	case k < 1:
    47  		panic(kLT1)
    48  	case ldv < max(1, nv):
    49  		panic(badLdV)
    50  	case len(tau) < k:
    51  		panic(shortTau)
    52  	case ldt < max(1, k):
    53  		panic(shortT)
    54  	}
    55  
    56  	if n == 0 {
    57  		return
    58  	}
    59  
    60  	switch {
    61  	case len(v) < (mv-1)*ldv+nv:
    62  		panic(shortV)
    63  	case len(t) < (k-1)*ldt+k:
    64  		panic(shortT)
    65  	}
    66  
    67  	bi := blas64.Implementation()
    68  
    69  	// TODO(btracey): There are a number of minor obvious loop optimizations here.
    70  	// TODO(btracey): It may be possible to rearrange some of the code so that
    71  	// index of 1 is more common in the Dgemv.
    72  	if direct == lapack.Forward {
    73  		prevlastv := n - 1
    74  		for i := 0; i < k; i++ {
    75  			prevlastv = max(i, prevlastv)
    76  			if tau[i] == 0 {
    77  				for j := 0; j <= i; j++ {
    78  					t[j*ldt+i] = 0
    79  				}
    80  				continue
    81  			}
    82  			var lastv int
    83  			if store == lapack.ColumnWise {
    84  				// skip trailing zeros
    85  				for lastv = n - 1; lastv >= i+1; lastv-- {
    86  					if v[lastv*ldv+i] != 0 {
    87  						break
    88  					}
    89  				}
    90  				for j := 0; j < i; j++ {
    91  					t[j*ldt+i] = -tau[i] * v[i*ldv+j]
    92  				}
    93  				j := min(lastv, prevlastv)
    94  				bi.Dgemv(blas.Trans, j-i, i,
    95  					-tau[i], v[(i+1)*ldv:], ldv, v[(i+1)*ldv+i:], ldv,
    96  					1, t[i:], ldt)
    97  			} else {
    98  				for lastv = n - 1; lastv >= i+1; lastv-- {
    99  					if v[i*ldv+lastv] != 0 {
   100  						break
   101  					}
   102  				}
   103  				for j := 0; j < i; j++ {
   104  					t[j*ldt+i] = -tau[i] * v[j*ldv+i]
   105  				}
   106  				j := min(lastv, prevlastv)
   107  				bi.Dgemv(blas.NoTrans, i, j-i,
   108  					-tau[i], v[i+1:], ldv, v[i*ldv+i+1:], 1,
   109  					1, t[i:], ldt)
   110  			}
   111  			bi.Dtrmv(blas.Upper, blas.NoTrans, blas.NonUnit, i, t, ldt, t[i:], ldt)
   112  			t[i*ldt+i] = tau[i]
   113  			if i > 1 {
   114  				prevlastv = max(prevlastv, lastv)
   115  			} else {
   116  				prevlastv = lastv
   117  			}
   118  		}
   119  		return
   120  	}
   121  	prevlastv := 0
   122  	for i := k - 1; i >= 0; i-- {
   123  		if tau[i] == 0 {
   124  			for j := i; j < k; j++ {
   125  				t[j*ldt+i] = 0
   126  			}
   127  			continue
   128  		}
   129  		var lastv int
   130  		if i < k-1 {
   131  			if store == lapack.ColumnWise {
   132  				for lastv = 0; lastv < i; lastv++ {
   133  					if v[lastv*ldv+i] != 0 {
   134  						break
   135  					}
   136  				}
   137  				for j := i + 1; j < k; j++ {
   138  					t[j*ldt+i] = -tau[i] * v[(n-k+i)*ldv+j]
   139  				}
   140  				j := max(lastv, prevlastv)
   141  				bi.Dgemv(blas.Trans, n-k+i-j, k-i-1,
   142  					-tau[i], v[j*ldv+i+1:], ldv, v[j*ldv+i:], ldv,
   143  					1, t[(i+1)*ldt+i:], ldt)
   144  			} else {
   145  				for lastv = 0; lastv < i; lastv++ {
   146  					if v[i*ldv+lastv] != 0 {
   147  						break
   148  					}
   149  				}
   150  				for j := i + 1; j < k; j++ {
   151  					t[j*ldt+i] = -tau[i] * v[j*ldv+n-k+i]
   152  				}
   153  				j := max(lastv, prevlastv)
   154  				bi.Dgemv(blas.NoTrans, k-i-1, n-k+i-j,
   155  					-tau[i], v[(i+1)*ldv+j:], ldv, v[i*ldv+j:], 1,
   156  					1, t[(i+1)*ldt+i:], ldt)
   157  			}
   158  			bi.Dtrmv(blas.Lower, blas.NoTrans, blas.NonUnit, k-i-1,
   159  				t[(i+1)*ldt+i+1:], ldt,
   160  				t[(i+1)*ldt+i:], ldt)
   161  			if i > 0 {
   162  				prevlastv = min(prevlastv, lastv)
   163  			} else {
   164  				prevlastv = lastv
   165  			}
   166  		}
   167  		t[i*ldt+i] = tau[i]
   168  	}
   169  }