github.com/gopherd/gonum@v0.0.4/blas/testblas/zher2k.go (about)

     1  // Copyright ©2019 The Gonum Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package testblas
     6  
     7  import (
     8  	"fmt"
     9  	"math/cmplx"
    10  	"testing"
    11  
    12  	"math/rand"
    13  	"github.com/gopherd/gonum/blas"
    14  )
    15  
    16  type Zher2ker interface {
    17  	Zher2k(uplo blas.Uplo, trans blas.Transpose, n, k int, alpha complex128, a []complex128, lda int, b []complex128, ldb int, beta float64, c []complex128, ldc int)
    18  }
    19  
    20  func Zher2kTest(t *testing.T, impl Zher2ker) {
    21  	for _, uplo := range []blas.Uplo{blas.Upper, blas.Lower} {
    22  		for _, trans := range []blas.Transpose{blas.NoTrans, blas.ConjTrans} {
    23  			name := uploString(uplo) + "-" + transString(trans)
    24  			t.Run(name, func(t *testing.T) {
    25  				for _, n := range []int{0, 1, 2, 3, 4, 5} {
    26  					for _, k := range []int{0, 1, 2, 3, 4, 5, 7} {
    27  						zher2kTest(t, impl, uplo, trans, n, k)
    28  					}
    29  				}
    30  			})
    31  		}
    32  	}
    33  }
    34  
    35  func zher2kTest(t *testing.T, impl Zher2ker, uplo blas.Uplo, trans blas.Transpose, n, k int) {
    36  	const tol = 1e-13
    37  
    38  	rnd := rand.New(rand.NewSource(1))
    39  
    40  	row, col := n, k
    41  	if trans == blas.ConjTrans {
    42  		row, col = k, n
    43  	}
    44  	for _, lda := range []int{max(1, col), col + 2} {
    45  		for _, ldb := range []int{max(1, col), col + 3} {
    46  			for _, ldc := range []int{max(1, n), n + 4} {
    47  				for _, alpha := range []complex128{0, 1, complex(0.7, -0.9)} {
    48  					for _, beta := range []float64{0, 1, 1.3} {
    49  						// Allocate the matrix A and fill it with random numbers.
    50  						a := make([]complex128, row*lda)
    51  						for i := range a {
    52  							a[i] = rndComplex128(rnd)
    53  						}
    54  						// Create a copy of A for checking that
    55  						// Zher2k does not modify A.
    56  						aCopy := make([]complex128, len(a))
    57  						copy(aCopy, a)
    58  
    59  						// Allocate the matrix B and fill it with random numbers.
    60  						b := make([]complex128, row*ldb)
    61  						for i := range b {
    62  							b[i] = rndComplex128(rnd)
    63  						}
    64  						// Create a copy of B for checking that
    65  						// Zher2k does not modify B.
    66  						bCopy := make([]complex128, len(b))
    67  						copy(bCopy, b)
    68  
    69  						// Allocate the matrix C and fill it with random numbers.
    70  						c := make([]complex128, n*ldc)
    71  						for i := range c {
    72  							c[i] = rndComplex128(rnd)
    73  						}
    74  						if (alpha == 0 || k == 0) && beta == 1 {
    75  							// In case of a quick return
    76  							// zero out the diagonal.
    77  							for i := 0; i < n; i++ {
    78  								c[i*ldc+i] = complex(real(c[i*ldc+i]), 0)
    79  							}
    80  						}
    81  						// Create a copy of C for checking that
    82  						// Zher2k does not modify its triangle
    83  						// opposite to uplo.
    84  						cCopy := make([]complex128, len(c))
    85  						copy(cCopy, c)
    86  						// Create a copy of C expanded into a
    87  						// full hermitian matrix for computing
    88  						// the expected result using zmm.
    89  						cHer := make([]complex128, len(c))
    90  						copy(cHer, c)
    91  						if uplo == blas.Upper {
    92  							for i := 0; i < n; i++ {
    93  								cHer[i*ldc+i] = complex(real(cHer[i*ldc+i]), 0)
    94  								for j := i + 1; j < n; j++ {
    95  									cHer[j*ldc+i] = cmplx.Conj(cHer[i*ldc+j])
    96  								}
    97  							}
    98  						} else {
    99  							for i := 0; i < n; i++ {
   100  								for j := 0; j < i; j++ {
   101  									cHer[j*ldc+i] = cmplx.Conj(cHer[i*ldc+j])
   102  								}
   103  								cHer[i*ldc+i] = complex(real(cHer[i*ldc+i]), 0)
   104  							}
   105  						}
   106  
   107  						// Compute the expected result using an internal Zgemm implementation.
   108  						var want []complex128
   109  						if trans == blas.NoTrans {
   110  							//  C = alpha*A*Bᴴ + conj(alpha)*B*Aᴴ + beta*C
   111  							tmp := zmm(blas.NoTrans, blas.ConjTrans, n, n, k, alpha, a, lda, b, ldb, complex(beta, 0), cHer, ldc)
   112  							want = zmm(blas.NoTrans, blas.ConjTrans, n, n, k, cmplx.Conj(alpha), b, ldb, a, lda, 1, tmp, ldc)
   113  						} else {
   114  							//  C = alpha*Aᴴ*B + conj(alpha)*Bᴴ*A + beta*C
   115  							tmp := zmm(blas.ConjTrans, blas.NoTrans, n, n, k, alpha, a, lda, b, ldb, complex(beta, 0), cHer, ldc)
   116  							want = zmm(blas.ConjTrans, blas.NoTrans, n, n, k, cmplx.Conj(alpha), b, ldb, a, lda, 1, tmp, ldc)
   117  						}
   118  
   119  						// Compute the result using Zher2k.
   120  						impl.Zher2k(uplo, trans, n, k, alpha, a, lda, b, ldb, beta, c, ldc)
   121  
   122  						prefix := fmt.Sprintf("n=%v,k=%v,lda=%v,ldb=%v,ldc=%v,alpha=%v,beta=%v", n, k, lda, ldb, ldc, alpha, beta)
   123  
   124  						if !zsame(a, aCopy) {
   125  							t.Errorf("%v: unexpected modification of A", prefix)
   126  							continue
   127  						}
   128  						if !zsame(b, bCopy) {
   129  							t.Errorf("%v: unexpected modification of B", prefix)
   130  							continue
   131  						}
   132  						if uplo == blas.Upper && !zSameLowerTri(n, c, ldc, cCopy, ldc) {
   133  							t.Errorf("%v: unexpected modification in lower triangle of C", prefix)
   134  							continue
   135  						}
   136  						if uplo == blas.Lower && !zSameUpperTri(n, c, ldc, cCopy, ldc) {
   137  							t.Errorf("%v: unexpected modification in upper triangle of C", prefix)
   138  							continue
   139  						}
   140  
   141  						// Check that the diagonal of C has only real elements.
   142  						hasRealDiag := true
   143  						for i := 0; i < n; i++ {
   144  							if imag(c[i*ldc+i]) != 0 {
   145  								hasRealDiag = false
   146  								break
   147  							}
   148  						}
   149  						if !hasRealDiag {
   150  							t.Errorf("%v: diagonal of C has imaginary elements\ngot=%v", prefix, c)
   151  							continue
   152  						}
   153  
   154  						// Expand C into a full hermitian matrix
   155  						// for comparison with the result from zmm.
   156  						if uplo == blas.Upper {
   157  							for i := 0; i < n-1; i++ {
   158  								for j := i + 1; j < n; j++ {
   159  									c[j*ldc+i] = cmplx.Conj(c[i*ldc+j])
   160  								}
   161  							}
   162  						} else {
   163  							for i := 1; i < n; i++ {
   164  								for j := 0; j < i; j++ {
   165  									c[j*ldc+i] = cmplx.Conj(c[i*ldc+j])
   166  								}
   167  							}
   168  						}
   169  						if !zEqualApprox(c, want, tol) {
   170  							t.Errorf("%v: unexpected result\nwant=%v\ngot= %v", prefix, want, c)
   171  						}
   172  					}
   173  				}
   174  			}
   175  		}
   176  	}
   177  }