gonum.org/v1/gonum@v0.14.0/blas/testblas/ztrsm.go (about)

     1  // Copyright ©2019 The Gonum Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package testblas
     6  
     7  import (
     8  	"fmt"
     9  	"testing"
    10  
    11  	"golang.org/x/exp/rand"
    12  
    13  	"gonum.org/v1/gonum/blas"
    14  )
    15  
    16  type Ztrsmer interface {
    17  	Ztrsm(side blas.Side, uplo blas.Uplo, transA blas.Transpose, diag blas.Diag, m, n int, alpha complex128, a []complex128, lda int, b []complex128, ldb int)
    18  }
    19  
    20  func ZtrsmTest(t *testing.T, impl Ztrsmer) {
    21  	for _, side := range []blas.Side{blas.Left, blas.Right} {
    22  		for _, uplo := range []blas.Uplo{blas.Lower, blas.Upper} {
    23  			for _, trans := range []blas.Transpose{blas.NoTrans, blas.Trans, blas.ConjTrans} {
    24  				for _, diag := range []blas.Diag{blas.Unit, blas.NonUnit} {
    25  					name := sideString(side) + "-" + uploString(uplo) + "-" + transString(trans) + "-" + diagString(diag)
    26  					t.Run(name, func(t *testing.T) {
    27  						for _, m := range []int{0, 1, 2, 3, 4, 5} {
    28  							for _, n := range []int{0, 1, 2, 3, 4, 5} {
    29  								ztrsmTest(t, impl, side, uplo, trans, diag, m, n)
    30  							}
    31  						}
    32  					})
    33  				}
    34  			}
    35  		}
    36  	}
    37  }
    38  
    39  func ztrsmTest(t *testing.T, impl Ztrsmer, side blas.Side, uplo blas.Uplo, trans blas.Transpose, diag blas.Diag, m, n int) {
    40  	const tol = 1e-13
    41  
    42  	rnd := rand.New(rand.NewSource(1))
    43  
    44  	nA := m
    45  	if side == blas.Right {
    46  		nA = n
    47  	}
    48  	for _, lda := range []int{max(1, nA), nA + 2} {
    49  		for _, ldb := range []int{max(1, n), n + 3} {
    50  			for _, alpha := range []complex128{0, 1, complex(0.7, -0.9)} {
    51  				// Allocate the matrix A and fill it with random numbers.
    52  				a := make([]complex128, nA*lda)
    53  				for i := range a {
    54  					a[i] = rndComplex128(rnd)
    55  				}
    56  				// Set some elements of A to 0 and 1 to cover special cases in Ztrsm.
    57  				if nA > 2 {
    58  					if uplo == blas.Upper {
    59  						a[nA-2] = 1
    60  						a[nA-1] = 0
    61  					} else {
    62  						a[(nA-2)*lda] = 1
    63  						a[(nA-1)*lda] = 0
    64  					}
    65  				}
    66  				// Create a copy of A for checking that Ztrsm
    67  				// does not modify its triangle opposite to uplo.
    68  				aCopy := make([]complex128, len(a))
    69  				copy(aCopy, a)
    70  				// Create a dense representation of A for
    71  				// computing the right-hand side matrix using zmm.
    72  				aTri := make([]complex128, len(a))
    73  				copy(aTri, a)
    74  				if uplo == blas.Upper {
    75  					for i := 0; i < nA; i++ {
    76  						// Zero out the lower triangle.
    77  						for j := 0; j < i; j++ {
    78  							aTri[i*lda+j] = 0
    79  						}
    80  						if diag == blas.Unit {
    81  							aTri[i*lda+i] = 1
    82  						}
    83  					}
    84  				} else {
    85  					for i := 0; i < nA; i++ {
    86  						if diag == blas.Unit {
    87  							aTri[i*lda+i] = 1
    88  						}
    89  						// Zero out the upper triangle.
    90  						for j := i + 1; j < nA; j++ {
    91  							aTri[i*lda+j] = 0
    92  						}
    93  					}
    94  				}
    95  
    96  				// Allocate the right-hand side matrix B and fill it with random numbers.
    97  				b := make([]complex128, m*ldb)
    98  				for i := range b {
    99  					b[i] = rndComplex128(rnd)
   100  				}
   101  				// Set some elements of B to 0 to cover special cases in Ztrsm.
   102  				if m > 1 && n > 1 {
   103  					b[0] = 0
   104  					b[(m-1)*ldb+n-1] = 0
   105  				}
   106  				bCopy := make([]complex128, len(b))
   107  				copy(bCopy, b)
   108  
   109  				// Compute the solution matrix X using Ztrsm.
   110  				// X is overwritten on B.
   111  				impl.Ztrsm(side, uplo, trans, diag, m, n, alpha, a, lda, b, ldb)
   112  				x := b
   113  
   114  				prefix := fmt.Sprintf("m=%v,n=%v,lda=%v,ldb=%v,alpha=%v", m, n, lda, ldb, alpha)
   115  
   116  				if !zsame(a, aCopy) {
   117  					t.Errorf("%v: unexpected modification of A", prefix)
   118  					continue
   119  				}
   120  
   121  				// Compute the left-hand side matrix of op(A)*X=alpha*B or X*op(A)=alpha*B
   122  				// using an internal Zgemm implementation.
   123  				var lhs []complex128
   124  				if side == blas.Left {
   125  					lhs = zmm(trans, blas.NoTrans, m, n, m, 1, aTri, lda, x, ldb, 0, b, ldb)
   126  				} else {
   127  					lhs = zmm(blas.NoTrans, trans, m, n, n, 1, x, ldb, aTri, lda, 0, b, ldb)
   128  				}
   129  
   130  				// Compute the right-hand side matrix alpha*B.
   131  				rhs := bCopy
   132  				for i := 0; i < m; i++ {
   133  					for j := 0; j < n; j++ {
   134  						rhs[i*ldb+j] *= alpha
   135  					}
   136  				}
   137  
   138  				if !zEqualApprox(lhs, rhs, tol) {
   139  					t.Errorf("%v: unexpected result", prefix)
   140  				}
   141  			}
   142  		}
   143  	}
   144  }