gonum.org/v1/gonum@v0.14.0/mat/mul_test.go (about)

     1  // Copyright ©2015 The Gonum Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package mat
     6  
     7  import (
     8  	"testing"
     9  
    10  	"golang.org/x/exp/rand"
    11  
    12  	"gonum.org/v1/gonum/blas"
    13  	"gonum.org/v1/gonum/blas/blas64"
    14  	"gonum.org/v1/gonum/floats"
    15  )
    16  
    17  // TODO: Need to add tests where one is overwritten.
    18  func TestMulTypes(t *testing.T) {
    19  	t.Parallel()
    20  	src := rand.NewSource(1)
    21  	for _, test := range []struct {
    22  		ar     int
    23  		ac     int
    24  		br     int
    25  		bc     int
    26  		Panics bool
    27  	}{
    28  		{
    29  			ar:     5,
    30  			ac:     5,
    31  			br:     5,
    32  			bc:     5,
    33  			Panics: false,
    34  		},
    35  		{
    36  			ar:     10,
    37  			ac:     5,
    38  			br:     5,
    39  			bc:     3,
    40  			Panics: false,
    41  		},
    42  		{
    43  			ar:     10,
    44  			ac:     5,
    45  			br:     5,
    46  			bc:     8,
    47  			Panics: false,
    48  		},
    49  		{
    50  			ar:     8,
    51  			ac:     10,
    52  			br:     10,
    53  			bc:     3,
    54  			Panics: false,
    55  		},
    56  		{
    57  			ar:     8,
    58  			ac:     3,
    59  			br:     3,
    60  			bc:     10,
    61  			Panics: false,
    62  		},
    63  		{
    64  			ar:     5,
    65  			ac:     8,
    66  			br:     8,
    67  			bc:     10,
    68  			Panics: false,
    69  		},
    70  		{
    71  			ar:     5,
    72  			ac:     12,
    73  			br:     12,
    74  			bc:     8,
    75  			Panics: false,
    76  		},
    77  		{
    78  			ar:     5,
    79  			ac:     7,
    80  			br:     8,
    81  			bc:     10,
    82  			Panics: true,
    83  		},
    84  	} {
    85  		ar := test.ar
    86  		ac := test.ac
    87  		br := test.br
    88  		bc := test.bc
    89  
    90  		// Generate random matrices
    91  		avec := make([]float64, ar*ac)
    92  		randomSlice(avec, src)
    93  		a := NewDense(ar, ac, avec)
    94  
    95  		bvec := make([]float64, br*bc)
    96  		randomSlice(bvec, src)
    97  
    98  		b := NewDense(br, bc, bvec)
    99  
   100  		// Check that it panics if it is supposed to
   101  		if test.Panics {
   102  			c := &Dense{}
   103  			fn := func() {
   104  				c.Mul(a, b)
   105  			}
   106  			pan, _ := panics(fn)
   107  			if !pan {
   108  				t.Errorf("Mul did not panic with dimension mismatch")
   109  			}
   110  			continue
   111  		}
   112  
   113  		cvec := make([]float64, ar*bc)
   114  
   115  		// Get correct matrix multiply answer from blas64.Gemm
   116  		blas64.Gemm(blas.NoTrans, blas.NoTrans,
   117  			1, a.mat, b.mat,
   118  			0, blas64.General{Rows: ar, Cols: bc, Stride: bc, Data: cvec},
   119  		)
   120  
   121  		avecCopy := append([]float64{}, avec...)
   122  		bvecCopy := append([]float64{}, bvec...)
   123  		cvecCopy := append([]float64{}, cvec...)
   124  
   125  		acomp := matComp{r: ar, c: ac, data: avecCopy}
   126  		bcomp := matComp{r: br, c: bc, data: bvecCopy}
   127  		ccomp := matComp{r: ar, c: bc, data: cvecCopy}
   128  
   129  		// Do normal multiply with empty dense
   130  		d := &Dense{}
   131  
   132  		testMul(t, a, b, d, acomp, bcomp, ccomp, false, "empty receiver")
   133  
   134  		// Normal multiply with existing receiver
   135  		c := NewDense(ar, bc, cvec)
   136  		randomSlice(cvec, src)
   137  		testMul(t, a, b, c, acomp, bcomp, ccomp, false, "existing receiver")
   138  
   139  		// Cast a as a basic matrix
   140  		am := (*basicMatrix)(a)
   141  		bm := (*basicMatrix)(b)
   142  		d.Reset()
   143  		testMul(t, am, b, d, acomp, bcomp, ccomp, true, "a is basic, receiver is empty")
   144  		d.Reset()
   145  		testMul(t, a, bm, d, acomp, bcomp, ccomp, true, "b is basic, receiver is empty")
   146  		d.Reset()
   147  		testMul(t, am, bm, d, acomp, bcomp, ccomp, true, "both basic, receiver is empty")
   148  		randomSlice(cvec, src)
   149  		testMul(t, am, b, d, acomp, bcomp, ccomp, true, "a is basic, receiver is full")
   150  		randomSlice(cvec, src)
   151  		testMul(t, a, bm, d, acomp, bcomp, ccomp, true, "b is basic, receiver is full")
   152  		randomSlice(cvec, src)
   153  		testMul(t, am, bm, d, acomp, bcomp, ccomp, true, "both basic, receiver is full")
   154  	}
   155  }
   156  
   157  func randomSlice(s []float64, src rand.Source) {
   158  	rnd := rand.New(src)
   159  	for i := range s {
   160  		s[i] = rnd.NormFloat64()
   161  	}
   162  }
   163  
   164  type matComp struct {
   165  	r, c int
   166  	data []float64
   167  }
   168  
   169  func testMul(t *testing.T, a, b Matrix, c *Dense, acomp, bcomp, ccomp matComp, cvecApprox bool, name string) {
   170  	c.Mul(a, b)
   171  	var aDense *Dense
   172  	switch t := a.(type) {
   173  	case *Dense:
   174  		aDense = t
   175  	case *basicMatrix:
   176  		aDense = (*Dense)(t)
   177  	}
   178  
   179  	var bDense *Dense
   180  	switch t := b.(type) {
   181  	case *Dense:
   182  		bDense = t
   183  	case *basicMatrix:
   184  		bDense = (*Dense)(t)
   185  	}
   186  
   187  	if !denseEqual(aDense, acomp) {
   188  		t.Errorf("a changed unexpectedly for %v", name)
   189  	}
   190  	if !denseEqual(bDense, bcomp) {
   191  		t.Errorf("b changed unexpectedly for %v", name)
   192  	}
   193  	if cvecApprox {
   194  		if !denseEqualApprox(c, ccomp, 1e-14) {
   195  			t.Errorf("mul answer not within tol for %v", name)
   196  		}
   197  		return
   198  	}
   199  
   200  	if !denseEqual(c, ccomp) {
   201  		t.Errorf("mul answer not equal for %v", name)
   202  	}
   203  }
   204  
   205  func denseEqual(a *Dense, acomp matComp) bool {
   206  	ar2, ac2 := a.Dims()
   207  	if ar2 != acomp.r {
   208  		return false
   209  	}
   210  	if ac2 != acomp.c {
   211  		return false
   212  	}
   213  	if !floats.Equal(a.mat.Data, acomp.data) {
   214  		return false
   215  	}
   216  	return true
   217  }
   218  
   219  func denseEqualApprox(a *Dense, acomp matComp, tol float64) bool {
   220  	ar2, ac2 := a.Dims()
   221  	if ar2 != acomp.r {
   222  		return false
   223  	}
   224  	if ac2 != acomp.c {
   225  		return false
   226  	}
   227  	if !floats.EqualApprox(a.mat.Data, acomp.data, tol) {
   228  		return false
   229  	}
   230  	return true
   231  }