gonum.org/v1/gonum@v0.14.0/blas/testblas/zgemm.go (about)

     1  // Copyright ©2019 The Gonum Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package testblas
     6  
     7  import (
     8  	"fmt"
     9  	"testing"
    10  
    11  	"golang.org/x/exp/rand"
    12  	"gonum.org/v1/gonum/blas"
    13  )
    14  
    15  type Zgemmer interface {
    16  	Zgemm(tA, tB blas.Transpose, m, n, k int, alpha complex128, a []complex128, lda int, b []complex128, ldb int, beta complex128, c []complex128, ldc int)
    17  }
    18  
    19  func ZgemmTest(t *testing.T, impl Zgemmer) {
    20  	for _, tA := range []blas.Transpose{blas.NoTrans, blas.Trans, blas.ConjTrans} {
    21  		for _, tB := range []blas.Transpose{blas.NoTrans, blas.Trans, blas.ConjTrans} {
    22  			name := transString(tA) + "-" + transString(tB)
    23  			t.Run(name, func(t *testing.T) {
    24  				for _, m := range []int{0, 1, 2, 5, 10} {
    25  					for _, n := range []int{0, 1, 2, 5, 10} {
    26  						for _, k := range []int{0, 1, 2, 7, 11} {
    27  							zgemmTest(t, impl, tA, tB, m, n, k)
    28  						}
    29  					}
    30  				}
    31  			})
    32  		}
    33  	}
    34  }
    35  
    36  func zgemmTest(t *testing.T, impl Zgemmer, tA, tB blas.Transpose, m, n, k int) {
    37  	const tol = 1e-13
    38  
    39  	rnd := rand.New(rand.NewSource(1))
    40  
    41  	rowA, colA := m, k
    42  	if tA != blas.NoTrans {
    43  		rowA, colA = k, m
    44  	}
    45  	rowB, colB := k, n
    46  	if tB != blas.NoTrans {
    47  		rowB, colB = n, k
    48  	}
    49  
    50  	for _, lda := range []int{max(1, colA), colA + 2} {
    51  		for _, ldb := range []int{max(1, colB), colB + 3} {
    52  			for _, ldc := range []int{max(1, n), n + 4} {
    53  				for _, alpha := range []complex128{0, 1, complex(0.7, -0.9)} {
    54  					for _, beta := range []complex128{0, 1, complex(1.3, -1.1)} {
    55  						// Allocate the matrix A and fill it with random numbers.
    56  						a := make([]complex128, rowA*lda)
    57  						for i := range a {
    58  							a[i] = rndComplex128(rnd)
    59  						}
    60  						// Create a copy of A.
    61  						aCopy := make([]complex128, len(a))
    62  						copy(aCopy, a)
    63  
    64  						// Allocate the matrix B and fill it with random numbers.
    65  						b := make([]complex128, rowB*ldb)
    66  						for i := range b {
    67  							b[i] = rndComplex128(rnd)
    68  						}
    69  						// Create a copy of B.
    70  						bCopy := make([]complex128, len(b))
    71  						copy(bCopy, b)
    72  
    73  						// Allocate the matrix C and fill it with random numbers.
    74  						c := make([]complex128, m*ldc)
    75  						for i := range c {
    76  							c[i] = rndComplex128(rnd)
    77  						}
    78  
    79  						// Compute the expected result using an internal Zgemm implementation.
    80  						want := zmm(tA, tB, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc)
    81  
    82  						// Compute a result using Zgemm.
    83  						impl.Zgemm(tA, tB, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc)
    84  
    85  						prefix := fmt.Sprintf("m=%v,n=%v,k=%v,lda=%v,ldb=%v,ldc=%v,alpha=%v,beta=%v", m, n, k, lda, ldb, ldc, alpha, beta)
    86  
    87  						if !zsame(a, aCopy) {
    88  							t.Errorf("%v: unexpected modification of A", prefix)
    89  							continue
    90  						}
    91  						if !zsame(b, bCopy) {
    92  							t.Errorf("%v: unexpected modification of B", prefix)
    93  							continue
    94  						}
    95  
    96  						if !zEqualApprox(c, want, tol) {
    97  							t.Errorf("%v: unexpected result,\nwant=%v\ngot =%v\n", prefix, want, c)
    98  						}
    99  					}
   100  				}
   101  			}
   102  		}
   103  	}
   104  }