gonum.org/v1/gonum@v0.14.0/blas/testblas/ztrmm.go (about)

     1  // Copyright ©2019 The Gonum Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package testblas
     6  
     7  import (
     8  	"fmt"
     9  	"testing"
    10  
    11  	"golang.org/x/exp/rand"
    12  
    13  	"gonum.org/v1/gonum/blas"
    14  )
    15  
    16  type Ztrmmer interface {
    17  	Ztrmm(side blas.Side, uplo blas.Uplo, trans blas.Transpose, diag blas.Diag, m, n int, alpha complex128, a []complex128, lda int, b []complex128, ldb int)
    18  }
    19  
    20  func ZtrmmTest(t *testing.T, impl Ztrmmer) {
    21  	for _, side := range []blas.Side{blas.Left, blas.Right} {
    22  		for _, uplo := range []blas.Uplo{blas.Lower, blas.Upper} {
    23  			for _, trans := range []blas.Transpose{blas.NoTrans, blas.Trans, blas.ConjTrans} {
    24  				for _, diag := range []blas.Diag{blas.Unit, blas.NonUnit} {
    25  					name := sideString(side) + "-" + uploString(uplo) + "-" + transString(trans) + "-" + diagString(diag)
    26  					t.Run(name, func(t *testing.T) {
    27  						for _, m := range []int{0, 1, 2, 3, 4, 5} {
    28  							for _, n := range []int{0, 1, 2, 3, 4, 5} {
    29  								ztrmmTest(t, impl, side, uplo, trans, diag, m, n)
    30  							}
    31  						}
    32  					})
    33  				}
    34  			}
    35  		}
    36  	}
    37  }
    38  
    39  func ztrmmTest(t *testing.T, impl Ztrmmer, side blas.Side, uplo blas.Uplo, trans blas.Transpose, diag blas.Diag, m, n int) {
    40  	const tol = 1e-13
    41  
    42  	rnd := rand.New(rand.NewSource(1))
    43  
    44  	nA := m
    45  	if side == blas.Right {
    46  		nA = n
    47  	}
    48  	for _, lda := range []int{max(1, nA), nA + 2} {
    49  		for _, ldb := range []int{max(1, n), n + 3} {
    50  			for _, alpha := range []complex128{0, 1, complex(0.7, -0.9)} {
    51  				// Allocate the matrix A and fill it with random numbers.
    52  				a := make([]complex128, nA*lda)
    53  				for i := range a {
    54  					a[i] = rndComplex128(rnd)
    55  				}
    56  				// Put a zero into A to cover special cases in Ztrmm.
    57  				if nA > 1 {
    58  					if uplo == blas.Upper {
    59  						a[nA-1] = 0
    60  					} else {
    61  						a[(nA-1)*lda] = 0
    62  					}
    63  				}
    64  				// Create a copy of A for checking that Ztrmm
    65  				// does not modify its triangle opposite to
    66  				// uplo.
    67  				aCopy := make([]complex128, len(a))
    68  				copy(aCopy, a)
    69  				// Create a dense representation of A for
    70  				// computing the expected result using zmm.
    71  				aTri := make([]complex128, len(a))
    72  				copy(aTri, a)
    73  				if uplo == blas.Upper {
    74  					for i := 0; i < nA; i++ {
    75  						// Zero out the lower triangle.
    76  						for j := 0; j < i; j++ {
    77  							aTri[i*lda+j] = 0
    78  						}
    79  						if diag == blas.Unit {
    80  							aTri[i*lda+i] = 1
    81  						}
    82  					}
    83  				} else {
    84  					for i := 0; i < nA; i++ {
    85  						if diag == blas.Unit {
    86  							aTri[i*lda+i] = 1
    87  						}
    88  						// Zero out the upper triangle.
    89  						for j := i + 1; j < nA; j++ {
    90  							aTri[i*lda+j] = 0
    91  						}
    92  					}
    93  				}
    94  
    95  				// Allocate the matrix B and fill it with random numbers.
    96  				b := make([]complex128, m*ldb)
    97  				for i := range b {
    98  					b[i] = rndComplex128(rnd)
    99  				}
   100  				// Put a zero into B to cover special cases in Ztrmm.
   101  				if m > 0 && n > 0 {
   102  					b[0] = 0
   103  				}
   104  
   105  				// Compute the expected result using an internal Zgemm implementation.
   106  				var want []complex128
   107  				if side == blas.Left {
   108  					want = zmm(trans, blas.NoTrans, m, n, m, alpha, aTri, lda, b, ldb, 0, b, ldb)
   109  				} else {
   110  					want = zmm(blas.NoTrans, trans, m, n, n, alpha, b, ldb, aTri, lda, 0, b, ldb)
   111  				}
   112  
   113  				// Compute the result using Ztrmm.
   114  				impl.Ztrmm(side, uplo, trans, diag, m, n, alpha, a, lda, b, ldb)
   115  
   116  				prefix := fmt.Sprintf("m=%v,n=%v,lda=%v,ldb=%v,alpha=%v", m, n, lda, ldb, alpha)
   117  				if !zsame(a, aCopy) {
   118  					t.Errorf("%v: unexpected modification of A", prefix)
   119  					continue
   120  				}
   121  
   122  				if !zEqualApprox(b, want, tol) {
   123  					t.Errorf("%v: unexpected result", prefix)
   124  				}
   125  			}
   126  		}
   127  	}
   128  }