github.com/gonum/lapack@v0.0.0-20181123203213-e4cdc5a0bff9/native/dlarft.go (about)

     1  // Copyright ©2015 The gonum Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package native
     6  
     7  import (
     8  	"github.com/gonum/blas"
     9  	"github.com/gonum/blas/blas64"
    10  	"github.com/gonum/lapack"
    11  )
    12  
    13  // Dlarft forms the triangular factor T of a block reflector H, storing the answer
    14  // in t.
    15  //  H = I - V * T * V^T  if store == lapack.ColumnWise
    16  //  H = I - V^T * T * V  if store == lapack.RowWise
    17  // H is defined by a product of the elementary reflectors where
    18  //  H = H_0 * H_1 * ... * H_{k-1}  if direct == lapack.Forward
    19  //  H = H_{k-1} * ... * H_1 * H_0  if direct == lapack.Backward
    20  //
    21  // t is a k×k triangular matrix. t is upper triangular if direct = lapack.Forward
    22  // and lower triangular otherwise. This function will panic if t is not of
    23  // sufficient size.
    24  //
    25  // store describes the storage of the elementary reflectors in v. Please see
    26  // Dlarfb for a description of layout.
    27  //
    28  // tau contains the scalar factors of the elementary reflectors H_i.
    29  //
    30  // Dlarft is an internal routine. It is exported for testing purposes.
    31  func (Implementation) Dlarft(direct lapack.Direct, store lapack.StoreV, n, k int,
    32  	v []float64, ldv int, tau []float64, t []float64, ldt int) {
    33  	if n == 0 {
    34  		return
    35  	}
    36  	if n < 0 || k < 0 {
    37  		panic(negDimension)
    38  	}
    39  	if direct != lapack.Forward && direct != lapack.Backward {
    40  		panic(badDirect)
    41  	}
    42  	if store != lapack.RowWise && store != lapack.ColumnWise {
    43  		panic(badStore)
    44  	}
    45  	if len(tau) < k {
    46  		panic(badTau)
    47  	}
    48  	checkMatrix(k, k, t, ldt)
    49  	bi := blas64.Implementation()
    50  	// TODO(btracey): There are a number of minor obvious loop optimizations here.
    51  	// TODO(btracey): It may be possible to rearrange some of the code so that
    52  	// index of 1 is more common in the Dgemv.
    53  	if direct == lapack.Forward {
    54  		prevlastv := n - 1
    55  		for i := 0; i < k; i++ {
    56  			prevlastv = max(i, prevlastv)
    57  			if tau[i] == 0 {
    58  				for j := 0; j <= i; j++ {
    59  					t[j*ldt+i] = 0
    60  				}
    61  				continue
    62  			}
    63  			var lastv int
    64  			if store == lapack.ColumnWise {
    65  				// skip trailing zeros
    66  				for lastv = n - 1; lastv >= i+1; lastv-- {
    67  					if v[lastv*ldv+i] != 0 {
    68  						break
    69  					}
    70  				}
    71  				for j := 0; j < i; j++ {
    72  					t[j*ldt+i] = -tau[i] * v[i*ldv+j]
    73  				}
    74  				j := min(lastv, prevlastv)
    75  				bi.Dgemv(blas.Trans, j-i, i,
    76  					-tau[i], v[(i+1)*ldv:], ldv, v[(i+1)*ldv+i:], ldv,
    77  					1, t[i:], ldt)
    78  			} else {
    79  				for lastv = n - 1; lastv >= i+1; lastv-- {
    80  					if v[i*ldv+lastv] != 0 {
    81  						break
    82  					}
    83  				}
    84  				for j := 0; j < i; j++ {
    85  					t[j*ldt+i] = -tau[i] * v[j*ldv+i]
    86  				}
    87  				j := min(lastv, prevlastv)
    88  				bi.Dgemv(blas.NoTrans, i, j-i,
    89  					-tau[i], v[i+1:], ldv, v[i*ldv+i+1:], 1,
    90  					1, t[i:], ldt)
    91  			}
    92  			bi.Dtrmv(blas.Upper, blas.NoTrans, blas.NonUnit, i, t, ldt, t[i:], ldt)
    93  			t[i*ldt+i] = tau[i]
    94  			if i > 1 {
    95  				prevlastv = max(prevlastv, lastv)
    96  			} else {
    97  				prevlastv = lastv
    98  			}
    99  		}
   100  		return
   101  	}
   102  	prevlastv := 0
   103  	for i := k - 1; i >= 0; i-- {
   104  		if tau[i] == 0 {
   105  			for j := i; j < k; j++ {
   106  				t[j*ldt+i] = 0
   107  			}
   108  			continue
   109  		}
   110  		var lastv int
   111  		if i < k-1 {
   112  			if store == lapack.ColumnWise {
   113  				for lastv = 0; lastv < i; lastv++ {
   114  					if v[lastv*ldv+i] != 0 {
   115  						break
   116  					}
   117  				}
   118  				for j := i + 1; j < k; j++ {
   119  					t[j*ldt+i] = -tau[i] * v[(n-k+i)*ldv+j]
   120  				}
   121  				j := max(lastv, prevlastv)
   122  				bi.Dgemv(blas.Trans, n-k+i-j, k-i-1,
   123  					-tau[i], v[j*ldv+i+1:], ldv, v[j*ldv+i:], ldv,
   124  					1, t[(i+1)*ldt+i:], ldt)
   125  			} else {
   126  				for lastv = 0; lastv < i; lastv++ {
   127  					if v[i*ldv+lastv] != 0 {
   128  						break
   129  					}
   130  				}
   131  				for j := i + 1; j < k; j++ {
   132  					t[j*ldt+i] = -tau[i] * v[j*ldv+n-k+i]
   133  				}
   134  				j := max(lastv, prevlastv)
   135  				bi.Dgemv(blas.NoTrans, k-i-1, n-k+i-j,
   136  					-tau[i], v[(i+1)*ldv+j:], ldv, v[i*ldv+j:], 1,
   137  					1, t[(i+1)*ldt+i:], ldt)
   138  			}
   139  			bi.Dtrmv(blas.Lower, blas.NoTrans, blas.NonUnit, k-i-1,
   140  				t[(i+1)*ldt+i+1:], ldt,
   141  				t[(i+1)*ldt+i:], ldt)
   142  			if i > 0 {
   143  				prevlastv = min(prevlastv, lastv)
   144  			} else {
   145  				prevlastv = lastv
   146  			}
   147  		}
   148  		t[i*ldt+i] = tau[i]
   149  	}
   150  }