github.com/jingcheng-WU/gonum@v0.9.1-0.20210323123734-f1a2a11a8f7b/blas/gonum/pardgemm_test.go (about)

     1  // Copyright ©2014 The Gonum Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package gonum
     6  
     7  import (
     8  	"testing"
     9  
    10  	"golang.org/x/exp/rand"
    11  
    12  	"github.com/jingcheng-WU/gonum/blas"
    13  	"github.com/jingcheng-WU/gonum/floats"
    14  )
    15  
    16  func TestDgemmParallel(t *testing.T) {
    17  	rnd := rand.New(rand.NewSource(1))
    18  	for i, test := range []struct {
    19  		m     int
    20  		n     int
    21  		k     int
    22  		alpha float64
    23  		tA    blas.Transpose
    24  		tB    blas.Transpose
    25  	}{
    26  		{
    27  			m:     3,
    28  			n:     4,
    29  			k:     2,
    30  			alpha: 2.5,
    31  			tA:    blas.NoTrans,
    32  			tB:    blas.NoTrans,
    33  		},
    34  		{
    35  			m:     blockSize*2 + 5,
    36  			n:     3,
    37  			k:     2,
    38  			alpha: 2.5,
    39  			tA:    blas.NoTrans,
    40  			tB:    blas.NoTrans,
    41  		},
    42  		{
    43  			m:     3,
    44  			n:     blockSize * 2,
    45  			k:     2,
    46  			alpha: 2.5,
    47  			tA:    blas.NoTrans,
    48  			tB:    blas.NoTrans,
    49  		},
    50  		{
    51  			m:     2,
    52  			n:     3,
    53  			k:     blockSize*3 - 2,
    54  			alpha: 2.5,
    55  			tA:    blas.NoTrans,
    56  			tB:    blas.NoTrans,
    57  		},
    58  		{
    59  			m:     blockSize * minParBlock,
    60  			n:     3,
    61  			k:     2,
    62  			alpha: 2.5,
    63  			tA:    blas.NoTrans,
    64  			tB:    blas.NoTrans,
    65  		},
    66  		{
    67  			m:     3,
    68  			n:     blockSize * minParBlock,
    69  			k:     2,
    70  			alpha: 2.5,
    71  			tA:    blas.NoTrans,
    72  			tB:    blas.NoTrans,
    73  		},
    74  		{
    75  			m:     2,
    76  			n:     3,
    77  			k:     blockSize * minParBlock,
    78  			alpha: 2.5,
    79  			tA:    blas.NoTrans,
    80  			tB:    blas.NoTrans,
    81  		},
    82  		{
    83  			m:     blockSize*minParBlock + 1,
    84  			n:     blockSize * minParBlock,
    85  			k:     3,
    86  			alpha: 2.5,
    87  			tA:    blas.NoTrans,
    88  			tB:    blas.NoTrans,
    89  		},
    90  		{
    91  			m:     3,
    92  			n:     blockSize*minParBlock + 2,
    93  			k:     blockSize * 3,
    94  			alpha: 2.5,
    95  			tA:    blas.NoTrans,
    96  			tB:    blas.NoTrans,
    97  		},
    98  		{
    99  			m:     blockSize * minParBlock,
   100  			n:     3,
   101  			k:     blockSize * minParBlock,
   102  			alpha: 2.5,
   103  			tA:    blas.NoTrans,
   104  			tB:    blas.NoTrans,
   105  		},
   106  		{
   107  			m:     blockSize * minParBlock,
   108  			n:     blockSize * minParBlock,
   109  			k:     blockSize * 3,
   110  			alpha: 2.5,
   111  			tA:    blas.NoTrans,
   112  			tB:    blas.NoTrans,
   113  		},
   114  		{
   115  			m:     blockSize + blockSize/2,
   116  			n:     blockSize + blockSize/2,
   117  			k:     blockSize + blockSize/2,
   118  			alpha: 2.5,
   119  			tA:    blas.NoTrans,
   120  			tB:    blas.NoTrans,
   121  		},
   122  	} {
   123  		testMatchParallelSerial(t, rnd, i, blas.NoTrans, blas.NoTrans, test.m, test.n, test.k, test.alpha)
   124  		testMatchParallelSerial(t, rnd, i, blas.Trans, blas.NoTrans, test.m, test.n, test.k, test.alpha)
   125  		testMatchParallelSerial(t, rnd, i, blas.NoTrans, blas.Trans, test.m, test.n, test.k, test.alpha)
   126  		testMatchParallelSerial(t, rnd, i, blas.Trans, blas.Trans, test.m, test.n, test.k, test.alpha)
   127  	}
   128  }
   129  
   130  func testMatchParallelSerial(t *testing.T, rnd *rand.Rand, i int, tA, tB blas.Transpose, m, n, k int, alpha float64) {
   131  	var (
   132  		rowA, colA int
   133  		rowB, colB int
   134  	)
   135  	if tA == blas.NoTrans {
   136  		rowA = m
   137  		colA = k
   138  	} else {
   139  		rowA = k
   140  		colA = m
   141  	}
   142  	if tB == blas.NoTrans {
   143  		rowB = k
   144  		colB = n
   145  	} else {
   146  		rowB = n
   147  		colB = k
   148  	}
   149  
   150  	lda := colA
   151  	a := randmat(rowA, colA, lda, rnd)
   152  	aCopy := make([]float64, len(a))
   153  	copy(aCopy, a)
   154  
   155  	ldb := colB
   156  	b := randmat(rowB, colB, ldb, rnd)
   157  	bCopy := make([]float64, len(b))
   158  	copy(bCopy, b)
   159  
   160  	ldc := n
   161  	c := randmat(m, n, ldc, rnd)
   162  	want := make([]float64, len(c))
   163  	copy(want, c)
   164  
   165  	dgemmSerial(tA == blas.Trans, tB == blas.Trans, m, n, k, a, lda, b, ldb, want, ldc, alpha)
   166  	dgemmParallel(tA == blas.Trans, tB == blas.Trans, m, n, k, a, lda, b, ldb, c, ldc, alpha)
   167  
   168  	if !floats.Equal(a, aCopy) {
   169  		t.Errorf("Case %v: a changed during call to dgemmParallel", i)
   170  	}
   171  	if !floats.Equal(b, bCopy) {
   172  		t.Errorf("Case %v: b changed during call to dgemmParallel", i)
   173  	}
   174  	if !floats.EqualApprox(c, want, 1e-12) {
   175  		t.Errorf("Case %v: answer not equal parallel and serial", i)
   176  	}
   177  }
   178  
   179  func randmat(r, c, stride int, rnd *rand.Rand) []float64 {
   180  	data := make([]float64, r*stride+c)
   181  	for i := range data {
   182  		data[i] = rnd.NormFloat64()
   183  	}
   184  	return data
   185  }