github.com/jingcheng-WU/gonum@v0.9.1-0.20210323123734-f1a2a11a8f7b/lapack/testlapack/dlarf.go (about)

     1  // Copyright ©2015 The Gonum Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package testlapack
     6  
     7  import (
     8  	"fmt"
     9  	"testing"
    10  
    11  	"golang.org/x/exp/rand"
    12  
    13  	"github.com/jingcheng-WU/gonum/blas"
    14  	"github.com/jingcheng-WU/gonum/blas/blas64"
    15  	"github.com/jingcheng-WU/gonum/floats"
    16  	"github.com/jingcheng-WU/gonum/lapack"
    17  )
    18  
    19  type Dlarfer interface {
    20  	Dlarf(side blas.Side, m, n int, v []float64, incv int, tau float64, c []float64, ldc int, work []float64)
    21  }
    22  
    23  func DlarfTest(t *testing.T, impl Dlarfer) {
    24  	for _, side := range []blas.Side{blas.Left, blas.Right} {
    25  		name := sideToString(side)
    26  		t.Run(name, func(t *testing.T) {
    27  			runDlarfTest(t, impl, side)
    28  		})
    29  	}
    30  }
    31  
    32  func runDlarfTest(t *testing.T, impl Dlarfer, side blas.Side) {
    33  	rnd := rand.New(rand.NewSource(1))
    34  	for _, m := range []int{0, 1, 2, 3, 4, 5, 10} {
    35  		for _, n := range []int{0, 1, 2, 3, 4, 5, 10} {
    36  			for _, incv := range []int{1, 4} {
    37  				for _, ldc := range []int{max(1, n), n + 3} {
    38  					for _, nnzv := range []int{0, 1, 2} {
    39  						for _, nnzc := range []int{0, 1, 2} {
    40  							for _, tau := range []float64{0, rnd.NormFloat64()} {
    41  								dlarfTest(t, impl, rnd, side, m, n, incv, ldc, nnzv, nnzc, tau)
    42  							}
    43  						}
    44  					}
    45  				}
    46  			}
    47  		}
    48  	}
    49  }
    50  
    51  func dlarfTest(t *testing.T, impl Dlarfer, rnd *rand.Rand, side blas.Side, m, n, incv, ldc, nnzv, nnzc int, tau float64) {
    52  	const tol = 1e-14
    53  
    54  	c := make([]float64, m*ldc)
    55  	for i := range c {
    56  		c[i] = rnd.NormFloat64()
    57  	}
    58  	switch nnzc {
    59  	case 0:
    60  		// Zero out all of C.
    61  		for i := 0; i < m; i++ {
    62  			for j := 0; j < n; j++ {
    63  				c[i*ldc+j] = 0
    64  			}
    65  		}
    66  	case 1:
    67  		// Zero out right or bottom half of C.
    68  		if side == blas.Left {
    69  			for i := 0; i < m; i++ {
    70  				for j := n / 2; j < n; j++ {
    71  					c[i*ldc+j] = 0
    72  				}
    73  			}
    74  		} else {
    75  			for i := m / 2; i < m; i++ {
    76  				for j := 0; j < n; j++ {
    77  					c[i*ldc+j] = 0
    78  				}
    79  			}
    80  		}
    81  	default:
    82  		// Leave C with random content.
    83  	}
    84  	cCopy := make([]float64, len(c))
    85  	copy(cCopy, c)
    86  
    87  	var work []float64
    88  	if side == blas.Left {
    89  		work = make([]float64, n)
    90  	} else {
    91  		work = make([]float64, m)
    92  	}
    93  
    94  	vlen := n
    95  	if side == blas.Left {
    96  		vlen = m
    97  	}
    98  	vlen = max(1, vlen)
    99  	v := make([]float64, 1+(vlen-1)*incv)
   100  	for i := range v {
   101  		v[i] = rnd.NormFloat64()
   102  	}
   103  	switch nnzv {
   104  	case 0:
   105  		// Zero out all of v.
   106  		for i := 0; i < vlen; i++ {
   107  			v[i*incv] = 0
   108  		}
   109  	case 1:
   110  		// Zero out half of v.
   111  		for i := vlen / 2; i < vlen; i++ {
   112  			v[i*incv] = 0
   113  		}
   114  	default:
   115  		// Leave v with random content.
   116  	}
   117  	vCopy := make([]float64, len(v))
   118  	copy(vCopy, v)
   119  
   120  	impl.Dlarf(side, m, n, v, incv, tau, c, ldc, work)
   121  	got := c
   122  
   123  	name := fmt.Sprintf("m=%d,n=%d,incv=%d,tau=%f,ldc=%d", m, n, incv, tau, ldc)
   124  
   125  	if !floats.Equal(v, vCopy) {
   126  		t.Errorf("%v: unexpected modification of v", name)
   127  	}
   128  	if tau == 0 && !floats.Equal(got, cCopy) {
   129  		t.Errorf("%v: unexpected modification of C", name)
   130  	}
   131  
   132  	if m == 0 || n == 0 || tau == 0 {
   133  		return
   134  	}
   135  
   136  	bi := blas64.Implementation()
   137  
   138  	want := make([]float64, len(cCopy))
   139  	if side == blas.Left {
   140  		// Compute want = (I - tau * v * vᵀ) * C
   141  
   142  		// vtc = -tau * vᵀ * C = -tau * Cᵀ * v
   143  		vtc := make([]float64, n)
   144  		bi.Dgemv(blas.Trans, m, n, -tau, cCopy, ldc, v, incv, 0, vtc, 1)
   145  
   146  		// want = C + v * vtcᵀ
   147  		for i := 0; i < m; i++ {
   148  			for j := 0; j < n; j++ {
   149  				want[i*ldc+j] = cCopy[i*ldc+j] + v[i*incv]*vtc[j]
   150  			}
   151  		}
   152  	} else {
   153  		// Compute want = C * (I - tau * v * vᵀ)
   154  
   155  		// cv = -tau * C * v
   156  		cv := make([]float64, m)
   157  		bi.Dgemv(blas.NoTrans, m, n, -tau, cCopy, ldc, v, incv, 0, cv, 1)
   158  
   159  		// want = C + cv * vᵀ
   160  		for i := 0; i < m; i++ {
   161  			for j := 0; j < n; j++ {
   162  				want[i*ldc+j] = cCopy[i*ldc+j] + cv[i]*v[j*incv]
   163  			}
   164  		}
   165  	}
   166  	diff := make([]float64, m*n)
   167  	for i := 0; i < m; i++ {
   168  		for j := 0; j < n; j++ {
   169  			diff[i*n+j] = got[i*ldc+j] - want[i*ldc+j]
   170  		}
   171  	}
   172  	resid := dlange(lapack.MaxColumnSum, m, n, diff, n)
   173  	if resid > tol*float64(max(m, n)) {
   174  		t.Errorf("%v: unexpected result; resid=%v, want<=%v", name, resid, tol*float64(max(m, n)))
   175  	}
   176  }