gonum.org/v1/gonum@v0.14.0/blas/testblas/zsyr2k.go (about)

     1  // Copyright ©2019 The Gonum Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package testblas
     6  
     7  import (
     8  	"fmt"
     9  	"testing"
    10  
    11  	"golang.org/x/exp/rand"
    12  
    13  	"gonum.org/v1/gonum/blas"
    14  )
    15  
    16  type Zsyr2ker interface {
    17  	Zsyr2k(uplo blas.Uplo, trans blas.Transpose, n, k int, alpha complex128, a []complex128, lda int, b []complex128, ldb int, beta complex128, c []complex128, ldc int)
    18  }
    19  
    20  func Zsyr2kTest(t *testing.T, impl Zsyr2ker) {
    21  	for _, uplo := range []blas.Uplo{blas.Upper, blas.Lower} {
    22  		for _, trans := range []blas.Transpose{blas.NoTrans, blas.Trans} {
    23  			name := uploString(uplo) + "-" + transString(trans)
    24  			t.Run(name, func(t *testing.T) {
    25  				for _, n := range []int{0, 1, 2, 3, 4, 5} {
    26  					for _, k := range []int{0, 1, 2, 3, 4, 5, 7} {
    27  						zsyr2kTest(t, impl, uplo, trans, n, k)
    28  					}
    29  				}
    30  			})
    31  		}
    32  	}
    33  }
    34  
    35  func zsyr2kTest(t *testing.T, impl Zsyr2ker, uplo blas.Uplo, trans blas.Transpose, n, k int) {
    36  	const tol = 1e-13
    37  
    38  	rnd := rand.New(rand.NewSource(1))
    39  
    40  	row, col := n, k
    41  	if trans == blas.Trans {
    42  		row, col = k, n
    43  	}
    44  	for _, lda := range []int{max(1, col), col + 2} {
    45  		for _, ldb := range []int{max(1, col), col + 3} {
    46  			for _, ldc := range []int{max(1, n), n + 4} {
    47  				for _, alpha := range []complex128{0, 1, complex(0.7, -0.9)} {
    48  					for _, beta := range []complex128{0, 1, complex(1.3, -1.1)} {
    49  						// Allocate the matrix A and fill it with random numbers.
    50  						a := make([]complex128, row*lda)
    51  						for i := range a {
    52  							a[i] = rndComplex128(rnd)
    53  						}
    54  						// Create a copy of A for checking that
    55  						// Zsyr2k does not modify A.
    56  						aCopy := make([]complex128, len(a))
    57  						copy(aCopy, a)
    58  
    59  						// Allocate the matrix B and fill it with random numbers.
    60  						b := make([]complex128, row*ldb)
    61  						for i := range b {
    62  							b[i] = rndComplex128(rnd)
    63  						}
    64  						// Create a copy of B for checking that
    65  						// Zsyr2k does not modify B.
    66  						bCopy := make([]complex128, len(b))
    67  						copy(bCopy, b)
    68  
    69  						// Allocate the matrix C and fill it with random numbers.
    70  						c := make([]complex128, n*ldc)
    71  						for i := range c {
    72  							c[i] = rndComplex128(rnd)
    73  						}
    74  						// Create a copy of C for checking that
    75  						// Zsyr2k does not modify its triangle
    76  						// opposite to uplo.
    77  						cCopy := make([]complex128, len(c))
    78  						copy(cCopy, c)
    79  						// Create a copy of C expanded into a
    80  						// full symmetric matrix for computing
    81  						// the expected result using zmm.
    82  						cSym := make([]complex128, len(c))
    83  						copy(cSym, c)
    84  						if uplo == blas.Upper {
    85  							for i := 0; i < n-1; i++ {
    86  								for j := i + 1; j < n; j++ {
    87  									cSym[j*ldc+i] = cSym[i*ldc+j]
    88  								}
    89  							}
    90  						} else {
    91  							for i := 1; i < n; i++ {
    92  								for j := 0; j < i; j++ {
    93  									cSym[j*ldc+i] = cSym[i*ldc+j]
    94  								}
    95  							}
    96  						}
    97  
    98  						// Compute the expected result using an internal Zgemm implementation.
    99  						var want []complex128
   100  						if trans == blas.NoTrans {
   101  							//  C = alpha*A*Bᵀ + alpha*B*Aᵀ + beta*C
   102  							tmp := zmm(blas.NoTrans, blas.Trans, n, n, k, alpha, a, lda, b, ldb, beta, cSym, ldc)
   103  							want = zmm(blas.NoTrans, blas.Trans, n, n, k, alpha, b, ldb, a, lda, 1, tmp, ldc)
   104  						} else {
   105  							//  C = alpha*Aᵀ*B + alpha*Bᵀ*A + beta*C
   106  							tmp := zmm(blas.Trans, blas.NoTrans, n, n, k, alpha, a, lda, b, ldb, beta, cSym, ldc)
   107  							want = zmm(blas.Trans, blas.NoTrans, n, n, k, alpha, b, ldb, a, lda, 1, tmp, ldc)
   108  						}
   109  
   110  						// Compute the result using Zsyr2k.
   111  						impl.Zsyr2k(uplo, trans, n, k, alpha, a, lda, b, ldb, beta, c, ldc)
   112  
   113  						prefix := fmt.Sprintf("n=%v,k=%v,lda=%v,ldb=%v,ldc=%v,alpha=%v,beta=%v", n, k, lda, ldb, ldc, alpha, beta)
   114  
   115  						if !zsame(a, aCopy) {
   116  							t.Errorf("%v: unexpected modification of A", prefix)
   117  							continue
   118  						}
   119  						if !zsame(b, bCopy) {
   120  							t.Errorf("%v: unexpected modification of B", prefix)
   121  							continue
   122  						}
   123  						if uplo == blas.Upper && !zSameLowerTri(n, c, ldc, cCopy, ldc) {
   124  							t.Errorf("%v: unexpected modification in lower triangle of C", prefix)
   125  							continue
   126  						}
   127  						if uplo == blas.Lower && !zSameUpperTri(n, c, ldc, cCopy, ldc) {
   128  							t.Errorf("%v: unexpected modification in upper triangle of C", prefix)
   129  							continue
   130  						}
   131  
   132  						// Expand C into a full symmetric matrix
   133  						// for comparison with the result from zmm.
   134  						if uplo == blas.Upper {
   135  							for i := 0; i < n-1; i++ {
   136  								for j := i + 1; j < n; j++ {
   137  									c[j*ldc+i] = c[i*ldc+j]
   138  								}
   139  							}
   140  						} else {
   141  							for i := 1; i < n; i++ {
   142  								for j := 0; j < i; j++ {
   143  									c[j*ldc+i] = c[i*ldc+j]
   144  								}
   145  							}
   146  						}
   147  						if !zEqualApprox(c, want, tol) {
   148  							t.Errorf("%v: unexpected result", prefix)
   149  						}
   150  					}
   151  				}
   152  			}
   153  		}
   154  	}
   155  }