github.com/jingcheng-WU/gonum@v0.9.1-0.20210323123734-f1a2a11a8f7b/blas/testblas/dgemm.go (about)

     1  // Copyright ©2014 The Gonum Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package testblas
     6  
     7  import (
     8  	"testing"
     9  
    10  	"github.com/jingcheng-WU/gonum/blas"
    11  )
    12  
    13  type Dgemmer interface {
    14  	Dgemm(tA, tB blas.Transpose, m, n, k int, alpha float64, a []float64, lda int, b []float64, ldb int, beta float64, c []float64, ldc int)
    15  }
    16  
    17  type DgemmCase struct {
    18  	m, n, k     int
    19  	alpha, beta float64
    20  	a           [][]float64
    21  	b           [][]float64
    22  	c           [][]float64
    23  	ans         [][]float64
    24  }
    25  
    26  var DgemmCases = []DgemmCase{
    27  
    28  	{
    29  		m:     4,
    30  		n:     3,
    31  		k:     2,
    32  		alpha: 2,
    33  		beta:  0.5,
    34  		a: [][]float64{
    35  			{1, 2},
    36  			{4, 5},
    37  			{7, 8},
    38  			{10, 11},
    39  		},
    40  		b: [][]float64{
    41  			{1, 5, 6},
    42  			{5, -8, 8},
    43  		},
    44  		c: [][]float64{
    45  			{4, 8, -9},
    46  			{12, 16, -8},
    47  			{1, 5, 15},
    48  			{-3, -4, 7},
    49  		},
    50  		ans: [][]float64{
    51  			{24, -18, 39.5},
    52  			{64, -32, 124},
    53  			{94.5, -55.5, 219.5},
    54  			{128.5, -78, 299.5},
    55  		},
    56  	},
    57  	{
    58  		m:     4,
    59  		n:     2,
    60  		k:     3,
    61  		alpha: 2,
    62  		beta:  0.5,
    63  		a: [][]float64{
    64  			{1, 2, 3},
    65  			{4, 5, 6},
    66  			{7, 8, 9},
    67  			{10, 11, 12},
    68  		},
    69  		b: [][]float64{
    70  			{1, 5},
    71  			{5, -8},
    72  			{6, 2},
    73  		},
    74  		c: [][]float64{
    75  			{4, 8},
    76  			{12, 16},
    77  			{1, 5},
    78  			{-3, -4},
    79  		},
    80  		ans: [][]float64{
    81  			{60, -6},
    82  			{136, -8},
    83  			{202.5, -19.5},
    84  			{272.5, -30},
    85  		},
    86  	},
    87  	{
    88  		m:     3,
    89  		n:     2,
    90  		k:     4,
    91  		alpha: 2,
    92  		beta:  0.5,
    93  		a: [][]float64{
    94  			{1, 2, 3, 4},
    95  			{4, 5, 6, 7},
    96  			{8, 9, 10, 11},
    97  		},
    98  		b: [][]float64{
    99  			{1, 5},
   100  			{5, -8},
   101  			{6, 2},
   102  			{8, 10},
   103  		},
   104  		c: [][]float64{
   105  			{4, 8},
   106  			{12, 16},
   107  			{9, -10},
   108  		},
   109  		ans: [][]float64{
   110  			{124, 74},
   111  			{248, 132},
   112  			{406.5, 191},
   113  		},
   114  	},
   115  	{
   116  		m:     3,
   117  		n:     4,
   118  		k:     2,
   119  		alpha: 2,
   120  		beta:  0.5,
   121  		a: [][]float64{
   122  			{1, 2},
   123  			{4, 5},
   124  			{8, 9},
   125  		},
   126  		b: [][]float64{
   127  			{1, 5, 2, 1},
   128  			{5, -8, 2, 1},
   129  		},
   130  		c: [][]float64{
   131  			{4, 8, 2, 2},
   132  			{12, 16, 8, 9},
   133  			{9, -10, 10, 10},
   134  		},
   135  		ans: [][]float64{
   136  			{24, -18, 13, 7},
   137  			{64, -32, 40, 22.5},
   138  			{110.5, -69, 73, 39},
   139  		},
   140  	},
   141  	{
   142  		m:     2,
   143  		n:     4,
   144  		k:     3,
   145  		alpha: 2,
   146  		beta:  0.5,
   147  		a: [][]float64{
   148  			{1, 2, 3},
   149  			{4, 5, 6},
   150  		},
   151  		b: [][]float64{
   152  			{1, 5, 8, 8},
   153  			{5, -8, 9, 10},
   154  			{6, 2, -3, 2},
   155  		},
   156  		c: [][]float64{
   157  			{4, 8, 7, 8},
   158  			{12, 16, -2, 6},
   159  		},
   160  		ans: [][]float64{
   161  			{60, -6, 37.5, 72},
   162  			{136, -8, 117, 191},
   163  		},
   164  	},
   165  	{
   166  		m:     2,
   167  		n:     3,
   168  		k:     4,
   169  		alpha: 2,
   170  		beta:  0.5,
   171  		a: [][]float64{
   172  			{1, 2, 3, 4},
   173  			{4, 5, 6, 7},
   174  		},
   175  		b: [][]float64{
   176  			{1, 5, 8},
   177  			{5, -8, 9},
   178  			{6, 2, -3},
   179  			{8, 10, 2},
   180  		},
   181  		c: [][]float64{
   182  			{4, 8, 1},
   183  			{12, 16, 6},
   184  		},
   185  		ans: [][]float64{
   186  			{124, 74, 50.5},
   187  			{248, 132, 149},
   188  		},
   189  	},
   190  }
   191  
   192  // assumes [][]float64 is actually a matrix
   193  func transpose(a [][]float64) [][]float64 {
   194  	b := make([][]float64, len(a[0]))
   195  	for i := range b {
   196  		b[i] = make([]float64, len(a))
   197  		for j := range b[i] {
   198  			b[i][j] = a[j][i]
   199  		}
   200  	}
   201  	return b
   202  }
   203  
   204  func TestDgemm(t *testing.T, blasser Dgemmer) {
   205  	for i, test := range DgemmCases {
   206  		// Test that it passes row major
   207  		dgemmcomp(i, "RowMajorNoTrans", t, blasser, blas.NoTrans, blas.NoTrans,
   208  			test.m, test.n, test.k, test.alpha, test.beta, test.a, test.b, test.c, test.ans)
   209  		// Try with A transposed
   210  		dgemmcomp(i, "RowMajorTransA", t, blasser, blas.Trans, blas.NoTrans,
   211  			test.m, test.n, test.k, test.alpha, test.beta, transpose(test.a), test.b, test.c, test.ans)
   212  		// Try with B transposed
   213  		dgemmcomp(i, "RowMajorTransB", t, blasser, blas.NoTrans, blas.Trans,
   214  			test.m, test.n, test.k, test.alpha, test.beta, test.a, transpose(test.b), test.c, test.ans)
   215  		// Try with both transposed
   216  		dgemmcomp(i, "RowMajorTransBoth", t, blasser, blas.Trans, blas.Trans,
   217  			test.m, test.n, test.k, test.alpha, test.beta, transpose(test.a), transpose(test.b), test.c, test.ans)
   218  	}
   219  }
   220  
   221  func dgemmcomp(i int, name string, t *testing.T, blasser Dgemmer, tA, tB blas.Transpose, m, n, k int,
   222  	alpha, beta float64, a [][]float64, b [][]float64, c [][]float64, ans [][]float64) {
   223  
   224  	aFlat := flatten(a)
   225  	aCopy := flatten(a)
   226  	bFlat := flatten(b)
   227  	bCopy := flatten(b)
   228  	cFlat := flatten(c)
   229  	ansFlat := flatten(ans)
   230  	lda := len(a[0])
   231  	ldb := len(b[0])
   232  	ldc := len(c[0])
   233  
   234  	// Compute the matrix multiplication
   235  	blasser.Dgemm(tA, tB, m, n, k, alpha, aFlat, lda, bFlat, ldb, beta, cFlat, ldc)
   236  
   237  	if !dSliceEqual(aFlat, aCopy) {
   238  		t.Errorf("Test %v case %v: a changed during call to Dgemm", i, name)
   239  	}
   240  	if !dSliceEqual(bFlat, bCopy) {
   241  		t.Errorf("Test %v case %v: b changed during call to Dgemm", i, name)
   242  	}
   243  
   244  	if !dSliceTolEqual(ansFlat, cFlat) {
   245  		t.Errorf("Test %v case %v: answer mismatch. Expected %v, Found %v", i, name, ansFlat, cFlat)
   246  	}
   247  	// TODO: Need to add a sub-slice test where don't use up full matrix
   248  }