gonum.org/v1/gonum@v0.14.0/mat/product_test.go (about)

     1  // Copyright ©2015 The Gonum Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package mat
     6  
     7  import (
     8  	"fmt"
     9  	"testing"
    10  
    11  	"golang.org/x/exp/rand"
    12  )
    13  
    14  type dims struct{ r, c int }
    15  
    16  var productTests = []struct {
    17  	n       int
    18  	factors []dims
    19  	product dims
    20  	panics  bool
    21  }{
    22  	{
    23  		n:       1,
    24  		factors: []dims{{3, 4}},
    25  		product: dims{3, 4},
    26  		panics:  false,
    27  	},
    28  	{
    29  		n:       1,
    30  		factors: []dims{{2, 4}},
    31  		product: dims{3, 4},
    32  		panics:  true,
    33  	},
    34  	{
    35  		n:       3,
    36  		factors: []dims{{10, 30}, {30, 5}, {5, 60}},
    37  		product: dims{10, 60},
    38  		panics:  false,
    39  	},
    40  	{
    41  		n:       3,
    42  		factors: []dims{{100, 30}, {30, 5}, {5, 60}},
    43  		product: dims{10, 60},
    44  		panics:  true,
    45  	},
    46  	{
    47  		n:       7,
    48  		factors: []dims{{60, 5}, {5, 5}, {5, 4}, {4, 10}, {10, 22}, {22, 45}, {45, 10}},
    49  		product: dims{60, 10},
    50  		panics:  false,
    51  	},
    52  	{
    53  		n:       7,
    54  		factors: []dims{{60, 5}, {5, 5}, {5, 400}, {4, 10}, {10, 22}, {22, 45}, {45, 10}},
    55  		product: dims{60, 10},
    56  		panics:  true,
    57  	},
    58  	{
    59  		n:       3,
    60  		factors: []dims{{1, 1000}, {1000, 2}, {2, 2}},
    61  		product: dims{1, 2},
    62  		panics:  false,
    63  	},
    64  
    65  	// Random chains.
    66  	{
    67  		n:       0,
    68  		product: dims{0, 0},
    69  		panics:  false,
    70  	},
    71  	{
    72  		n:       2,
    73  		product: dims{60, 10},
    74  		panics:  false,
    75  	},
    76  	{
    77  		n:       3,
    78  		product: dims{60, 10},
    79  		panics:  false,
    80  	},
    81  	{
    82  		n:       4,
    83  		product: dims{60, 10},
    84  		panics:  false,
    85  	},
    86  	{
    87  		n:       10,
    88  		product: dims{60, 10},
    89  		panics:  false,
    90  	},
    91  }
    92  
    93  func TestProduct(t *testing.T) {
    94  	t.Parallel()
    95  	rnd := rand.New(rand.NewSource(1))
    96  	for _, test := range productTests {
    97  		dimensions := test.factors
    98  		if dimensions == nil && test.n > 0 {
    99  			dimensions = make([]dims, test.n)
   100  			for i := range dimensions {
   101  				if i != 0 {
   102  					dimensions[i].r = dimensions[i-1].c
   103  				}
   104  				dimensions[i].c = rnd.Intn(50) + 1
   105  			}
   106  			dimensions[0].r = test.product.r
   107  			dimensions[test.n-1].c = test.product.c
   108  		}
   109  		factors := make([]Matrix, test.n)
   110  		for i, d := range dimensions {
   111  			data := make([]float64, d.r*d.c)
   112  			for i := range data {
   113  				data[i] = rnd.Float64()
   114  			}
   115  			factors[i] = NewDense(d.r, d.c, data)
   116  		}
   117  
   118  		want := &Dense{}
   119  		if !test.panics {
   120  			var a *Dense
   121  			for i, b := range factors {
   122  				if i == 0 {
   123  					want.CloneFrom(b)
   124  					continue
   125  				}
   126  				a, want = want, &Dense{}
   127  				want.Mul(a, b)
   128  			}
   129  		}
   130  
   131  		got := &Dense{}
   132  		if test.product.r != 0 && test.product.c != 0 {
   133  			got = NewDense(test.product.r, test.product.c, nil)
   134  		}
   135  		panicked, message := panics(func() {
   136  			got.Product(factors...)
   137  		})
   138  		if test.panics {
   139  			if !panicked {
   140  				t.Errorf("fail to panic with product chain dimensions: %+v result dimension: %+v",
   141  					dimensions, test.product)
   142  			}
   143  			continue
   144  		} else if panicked {
   145  			t.Errorf("unexpected panic %q with product chain dimensions: %+v result dimension: %+v",
   146  				message, dimensions, test.product)
   147  			continue
   148  		}
   149  
   150  		if len(factors) > 0 {
   151  			p := newMultiplier(NewDense(test.product.r, test.product.c, nil), factors)
   152  			p.optimize()
   153  			gotCost := p.table.at(0, len(factors)-1).cost
   154  			expr, wantCost, ok := bestExpressionFor(dimensions)
   155  			if !ok {
   156  				t.Fatal("unexpected number of expressions in brute force expression search")
   157  			}
   158  			if gotCost != wantCost {
   159  				t.Errorf("unexpected cost for chain dimensions: %+v got: %v want: %v\n%s",
   160  					dimensions, got, want, expr)
   161  			}
   162  		}
   163  
   164  		if !EqualApprox(got, want, 1e-14) {
   165  			t.Errorf("unexpected result from product chain dimensions: %+v", dimensions)
   166  		}
   167  	}
   168  }
   169  
   170  // node is a subexpression node.
   171  type node struct {
   172  	dims
   173  	left, right *node
   174  }
   175  
   176  func (n *node) String() string {
   177  	if n.left == nil || n.right == nil {
   178  		rows, cols := n.shape()
   179  		return fmt.Sprintf("[%d×%d]", rows, cols)
   180  	}
   181  	rows, cols := n.shape()
   182  	return fmt.Sprintf("(%s * %s):[%d×%d]", n.left, n.right, rows, cols)
   183  }
   184  
   185  // shape returns the dimensions of the result of the subexpression.
   186  func (n *node) shape() (rows, cols int) {
   187  	if n.left == nil || n.right == nil {
   188  		return n.r, n.c
   189  	}
   190  	rows, _ = n.left.shape()
   191  	_, cols = n.right.shape()
   192  	return rows, cols
   193  }
   194  
   195  // cost returns the cost to evaluate the subexpression.
   196  func (n *node) cost() int {
   197  	if n.left == nil || n.right == nil {
   198  		return 0
   199  	}
   200  	lr, lc := n.left.shape()
   201  	_, rc := n.right.shape()
   202  	return lr*lc*rc + n.left.cost() + n.right.cost()
   203  }
   204  
   205  // expressionsFor returns a channel that can be used to iterate over all
   206  // expressions of the given factor dimensions.
   207  func expressionsFor(factors []dims) chan *node {
   208  	if len(factors) == 1 {
   209  		c := make(chan *node, 1)
   210  		c <- &node{dims: factors[0]}
   211  		close(c)
   212  		return c
   213  	}
   214  	c := make(chan *node)
   215  	go func() {
   216  		for i := 1; i < len(factors); i++ {
   217  			for left := range expressionsFor(factors[:i]) {
   218  				for right := range expressionsFor(factors[i:]) {
   219  					c <- &node{left: left, right: right}
   220  				}
   221  			}
   222  		}
   223  		close(c)
   224  	}()
   225  	return c
   226  }
   227  
   228  // catalan returns the nth 0-based Catalan number.
   229  func catalan(n int) int {
   230  	// Work in 64-bit integers since we overflow 32-bits for some tests.
   231  	p := int64(1)
   232  	for k := n + 1; k < 2*n+1; k++ {
   233  		p *= int64(k)
   234  	}
   235  	for k := 2; k < n+2; k++ {
   236  		p /= int64(k)
   237  	}
   238  	return int(p)
   239  }
   240  
   241  // bestExpressonFor returns the lowest cost expression for the given expression
   242  // factor dimensions, the cost of the expression and whether the number of
   243  // expressions searched matches the Catalan number for the number of factors.
   244  func bestExpressionFor(factors []dims) (exp *node, cost int, ok bool) {
   245  	const maxInt = int(^uint(0) >> 1)
   246  	min := maxInt
   247  	var best *node
   248  	var n int
   249  	for exp := range expressionsFor(factors) {
   250  		n++
   251  		cost := exp.cost()
   252  		if cost < min {
   253  			min = cost
   254  			best = exp
   255  		}
   256  	}
   257  	return best, min, n == catalan(len(factors)-1)
   258  }