github.com/gopherd/gonum@v0.0.4/blas/testblas/zsyrk.go (about)

     1  // Copyright ©2019 The Gonum Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package testblas
     6  
     7  import (
     8  	"fmt"
     9  	"math/cmplx"
    10  	"testing"
    11  
    12  	"math/rand"
    13  
    14  	"github.com/gopherd/gonum/blas"
    15  )
    16  
    17  type Zsyrker interface {
    18  	Zsyrk(uplo blas.Uplo, trans blas.Transpose, n, k int, alpha complex128, a []complex128, lda int, beta complex128, c []complex128, ldc int)
    19  }
    20  
    21  func ZsyrkTest(t *testing.T, impl Zsyrker) {
    22  	for _, uplo := range []blas.Uplo{blas.Upper, blas.Lower} {
    23  		for _, trans := range []blas.Transpose{blas.NoTrans, blas.Trans} {
    24  			name := uploString(uplo) + "-" + transString(trans)
    25  			t.Run(name, func(t *testing.T) {
    26  				for _, n := range []int{0, 1, 2, 3, 4, 5} {
    27  					for _, k := range []int{0, 1, 2, 3, 4, 5, 7} {
    28  						zsyrkTest(t, impl, uplo, trans, n, k)
    29  					}
    30  				}
    31  			})
    32  		}
    33  	}
    34  }
    35  
    36  func zsyrkTest(t *testing.T, impl Zsyrker, uplo blas.Uplo, trans blas.Transpose, n, k int) {
    37  	const tol = 1e-13
    38  
    39  	rnd := rand.New(rand.NewSource(1))
    40  
    41  	rowA, colA := n, k
    42  	if trans == blas.Trans {
    43  		rowA, colA = k, n
    44  	}
    45  	for _, lda := range []int{max(1, colA), colA + 2} {
    46  		for _, ldc := range []int{max(1, n), n + 4} {
    47  			for _, alpha := range []complex128{0, 1, complex(0.7, -0.9)} {
    48  				for _, beta := range []complex128{0, 1, complex(1.3, -1.1)} {
    49  					for _, nanC := range []bool{false, true} {
    50  						if nanC && beta != 0 {
    51  							// Skip tests with C containing NaN values
    52  							// unless beta would zero out the NaNs.
    53  							continue
    54  						}
    55  
    56  						// Allocate the matrix A and fill it with random numbers.
    57  						a := make([]complex128, rowA*lda)
    58  						for i := range a {
    59  							a[i] = rndComplex128(rnd)
    60  						}
    61  						// Create a copy of A for checking that
    62  						// Zsyrk does not modify A.
    63  						aCopy := make([]complex128, len(a))
    64  						copy(aCopy, a)
    65  
    66  						// Allocate the matrix C and fill it with random numbers.
    67  						c := make([]complex128, n*ldc)
    68  						for i := range c {
    69  							c[i] = rndComplex128(rnd)
    70  						}
    71  						if nanC {
    72  							for i := 0; i < n; i++ {
    73  								for j := 0; j < n; j++ {
    74  									c[i+j*ldc] = cmplx.NaN()
    75  								}
    76  							}
    77  						}
    78  
    79  						// Create a copy of C for checking that
    80  						// Zsyrk does not modify its triangle
    81  						// opposite to uplo.
    82  						cCopy := make([]complex128, len(c))
    83  						copy(cCopy, c)
    84  						// Create a copy of C expanded into a
    85  						// full symmetric matrix for computing
    86  						// the expected result using zmm.
    87  						cSym := make([]complex128, len(c))
    88  						copy(cSym, c)
    89  						if uplo == blas.Upper {
    90  							for i := 0; i < n-1; i++ {
    91  								for j := i + 1; j < n; j++ {
    92  									cSym[j*ldc+i] = cSym[i*ldc+j]
    93  								}
    94  							}
    95  						} else {
    96  							for i := 1; i < n; i++ {
    97  								for j := 0; j < i; j++ {
    98  									cSym[j*ldc+i] = cSym[i*ldc+j]
    99  								}
   100  							}
   101  						}
   102  
   103  						// Compute the expected result using an internal Zgemm implementation.
   104  						var want []complex128
   105  						if trans == blas.NoTrans {
   106  							want = zmm(blas.NoTrans, blas.Trans, n, n, k, alpha, a, lda, a, lda, beta, cSym, ldc)
   107  						} else {
   108  							want = zmm(blas.Trans, blas.NoTrans, n, n, k, alpha, a, lda, a, lda, beta, cSym, ldc)
   109  						}
   110  
   111  						// Compute the result using Zsyrk.
   112  						impl.Zsyrk(uplo, trans, n, k, alpha, a, lda, beta, c, ldc)
   113  
   114  						prefix := fmt.Sprintf("n=%v,k=%v,lda=%v,ldc=%v,alpha=%v,beta=%v", n, k, lda, ldc, alpha, beta)
   115  
   116  						if !zsame(a, aCopy) {
   117  							t.Errorf("%v: unexpected modification of A", prefix)
   118  							continue
   119  						}
   120  						if uplo == blas.Upper && !zSameLowerTri(n, c, ldc, cCopy, ldc) {
   121  							t.Errorf("%v: unexpected modification in lower triangle of C", prefix)
   122  							continue
   123  						}
   124  						if uplo == blas.Lower && !zSameUpperTri(n, c, ldc, cCopy, ldc) {
   125  							t.Errorf("%v: unexpected modification in upper triangle of C", prefix)
   126  							continue
   127  						}
   128  
   129  						// Expand C into a full symmetric matrix
   130  						// for comparison with the result from zmm.
   131  						if uplo == blas.Upper {
   132  							for i := 0; i < n-1; i++ {
   133  								for j := i + 1; j < n; j++ {
   134  									c[j*ldc+i] = c[i*ldc+j]
   135  								}
   136  							}
   137  						} else {
   138  							for i := 1; i < n; i++ {
   139  								for j := 0; j < i; j++ {
   140  									c[j*ldc+i] = c[i*ldc+j]
   141  								}
   142  							}
   143  						}
   144  						if !zEqualApprox(c, want, tol) {
   145  							t.Errorf("%v: unexpected result", prefix)
   146  						}
   147  					}
   148  				}
   149  			}
   150  		}
   151  	}
   152  }