github.com/gonum/matrix@v0.0.0-20181209220409-c518dec07be9/mat64/product_test.go (about)

     1  // Copyright ©2015 The gonum Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package mat64
     6  
     7  import (
     8  	"fmt"
     9  	"math/rand"
    10  	"testing"
    11  )
    12  
    13  type dims struct{ r, c int }
    14  
    15  var productTests = []struct {
    16  	n       int
    17  	factors []dims
    18  	product dims
    19  	panics  bool
    20  }{
    21  	{
    22  		n:       1,
    23  		factors: []dims{{3, 4}},
    24  		product: dims{3, 4},
    25  		panics:  false,
    26  	},
    27  	{
    28  		n:       1,
    29  		factors: []dims{{2, 4}},
    30  		product: dims{3, 4},
    31  		panics:  true,
    32  	},
    33  	{
    34  		n:       3,
    35  		factors: []dims{{10, 30}, {30, 5}, {5, 60}},
    36  		product: dims{10, 60},
    37  		panics:  false,
    38  	},
    39  	{
    40  		n:       3,
    41  		factors: []dims{{100, 30}, {30, 5}, {5, 60}},
    42  		product: dims{10, 60},
    43  		panics:  true,
    44  	},
    45  	{
    46  		n:       7,
    47  		factors: []dims{{60, 5}, {5, 5}, {5, 4}, {4, 10}, {10, 22}, {22, 45}, {45, 10}},
    48  		product: dims{60, 10},
    49  		panics:  false,
    50  	},
    51  	{
    52  		n:       7,
    53  		factors: []dims{{60, 5}, {5, 5}, {5, 400}, {4, 10}, {10, 22}, {22, 45}, {45, 10}},
    54  		product: dims{60, 10},
    55  		panics:  true,
    56  	},
    57  	{
    58  		n:       3,
    59  		factors: []dims{{1, 1000}, {1000, 2}, {2, 2}},
    60  		product: dims{1, 2},
    61  		panics:  false,
    62  	},
    63  
    64  	// Random chains.
    65  	{
    66  		n:       0,
    67  		product: dims{0, 0},
    68  		panics:  false,
    69  	},
    70  	{
    71  		n:       2,
    72  		product: dims{60, 10},
    73  		panics:  false,
    74  	},
    75  	{
    76  		n:       3,
    77  		product: dims{60, 10},
    78  		panics:  false,
    79  	},
    80  	{
    81  		n:       4,
    82  		product: dims{60, 10},
    83  		panics:  false,
    84  	},
    85  	{
    86  		n:       10,
    87  		product: dims{60, 10},
    88  		panics:  false,
    89  	},
    90  }
    91  
    92  func TestProduct(t *testing.T) {
    93  	for _, test := range productTests {
    94  		dimensions := test.factors
    95  		if dimensions == nil && test.n > 0 {
    96  			dimensions = make([]dims, test.n)
    97  			for i := range dimensions {
    98  				if i != 0 {
    99  					dimensions[i].r = dimensions[i-1].c
   100  				}
   101  				dimensions[i].c = rand.Intn(50) + 1
   102  			}
   103  			dimensions[0].r = test.product.r
   104  			dimensions[test.n-1].c = test.product.c
   105  		}
   106  		factors := make([]Matrix, test.n)
   107  		for i, d := range dimensions {
   108  			data := make([]float64, d.r*d.c)
   109  			for i := range data {
   110  				data[i] = rand.Float64()
   111  			}
   112  			factors[i] = NewDense(d.r, d.c, data)
   113  		}
   114  
   115  		want := &Dense{}
   116  		if !test.panics {
   117  			a := &Dense{}
   118  			for i, b := range factors {
   119  				if i == 0 {
   120  					want.Clone(b)
   121  					continue
   122  				}
   123  				a, want = want, &Dense{}
   124  				want.Mul(a, b)
   125  			}
   126  		}
   127  
   128  		got := NewDense(test.product.r, test.product.c, nil)
   129  		panicked, message := panics(func() {
   130  			got.Product(factors...)
   131  		})
   132  		if test.panics {
   133  			if !panicked {
   134  				t.Errorf("fail to panic with product chain dimensions: %+v result dimension: %+v",
   135  					dimensions, test.product)
   136  			}
   137  			continue
   138  		} else if panicked {
   139  			t.Errorf("unexpected panic %q with product chain dimensions: %+v result dimension: %+v",
   140  				message, dimensions, test.product)
   141  			continue
   142  		}
   143  
   144  		if len(factors) > 0 {
   145  			p := newMultiplier(NewDense(test.product.r, test.product.c, nil), factors)
   146  			p.optimize()
   147  			gotCost := p.table.at(0, len(factors)-1).cost
   148  			expr, wantCost, ok := bestExpressionFor(dimensions)
   149  			if !ok {
   150  				t.Fatal("unexpected number of expressions in brute force expression search")
   151  			}
   152  			if gotCost != wantCost {
   153  				t.Errorf("unexpected cost for chain dimensions: %+v got: %v want: %v\n%s",
   154  					dimensions, got, want, expr)
   155  			}
   156  		}
   157  
   158  		if !EqualApprox(got, want, 1e-14) {
   159  			t.Errorf("unexpected result from product chain dimensions: %+v", dimensions)
   160  		}
   161  	}
   162  }
   163  
   164  // node is a subexpression node.
   165  type node struct {
   166  	dims
   167  	left, right *node
   168  }
   169  
   170  func (n *node) String() string {
   171  	if n.left == nil || n.right == nil {
   172  		rows, cols := n.shape()
   173  		return fmt.Sprintf("[%d×%d]", rows, cols)
   174  	}
   175  	rows, cols := n.shape()
   176  	return fmt.Sprintf("(%s * %s):[%d×%d]", n.left, n.right, rows, cols)
   177  }
   178  
   179  // shape returns the dimensions of the result of the subexpression.
   180  func (n *node) shape() (rows, cols int) {
   181  	if n.left == nil || n.right == nil {
   182  		return n.r, n.c
   183  	}
   184  	rows, _ = n.left.shape()
   185  	_, cols = n.right.shape()
   186  	return rows, cols
   187  }
   188  
   189  // cost returns the cost to evaluate the subexpression.
   190  func (n *node) cost() int {
   191  	if n.left == nil || n.right == nil {
   192  		return 0
   193  	}
   194  	lr, lc := n.left.shape()
   195  	_, rc := n.right.shape()
   196  	return lr*lc*rc + n.left.cost() + n.right.cost()
   197  }
   198  
   199  // expressionsFor returns a channel that can be used to iterate over all
   200  // expressions of the given factor dimensions.
   201  func expressionsFor(factors []dims) chan *node {
   202  	if len(factors) == 1 {
   203  		c := make(chan *node, 1)
   204  		c <- &node{dims: factors[0]}
   205  		close(c)
   206  		return c
   207  	}
   208  	c := make(chan *node)
   209  	go func() {
   210  		for i := 1; i < len(factors); i++ {
   211  			for left := range expressionsFor(factors[:i]) {
   212  				for right := range expressionsFor(factors[i:]) {
   213  					c <- &node{left: left, right: right}
   214  				}
   215  			}
   216  		}
   217  		close(c)
   218  	}()
   219  	return c
   220  }
   221  
   222  // catalan returns the nth 0-based Catalan number.
   223  func catalan(n int) int {
   224  	p := 1
   225  	for k := n + 1; k < 2*n+1; k++ {
   226  		p *= k
   227  	}
   228  	for k := 2; k < n+2; k++ {
   229  		p /= k
   230  	}
   231  	return p
   232  }
   233  
   234  // bestExpressonFor returns the lowest cost expression for the given expression
   235  // factor dimensions, the cost of the expression and whether the number of
   236  // expressions searched matches the Catalan number for the number of factors.
   237  func bestExpressionFor(factors []dims) (exp *node, cost int, ok bool) {
   238  	const maxInt = int(^uint(0) >> 1)
   239  	min := maxInt
   240  	var best *node
   241  	var n int
   242  	for exp := range expressionsFor(factors) {
   243  		n++
   244  		cost := exp.cost()
   245  		if cost < min {
   246  			min = cost
   247  			best = exp
   248  		}
   249  	}
   250  	return best, min, n == catalan(len(factors)-1)
   251  }