github.com/gonum/matrix@v0.0.0-20181209220409-c518dec07be9/mat64/product_test.go (about) 1 // Copyright ©2015 The gonum Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style 3 // license that can be found in the LICENSE file. 4 5 package mat64 6 7 import ( 8 "fmt" 9 "math/rand" 10 "testing" 11 ) 12 13 type dims struct{ r, c int } 14 15 var productTests = []struct { 16 n int 17 factors []dims 18 product dims 19 panics bool 20 }{ 21 { 22 n: 1, 23 factors: []dims{{3, 4}}, 24 product: dims{3, 4}, 25 panics: false, 26 }, 27 { 28 n: 1, 29 factors: []dims{{2, 4}}, 30 product: dims{3, 4}, 31 panics: true, 32 }, 33 { 34 n: 3, 35 factors: []dims{{10, 30}, {30, 5}, {5, 60}}, 36 product: dims{10, 60}, 37 panics: false, 38 }, 39 { 40 n: 3, 41 factors: []dims{{100, 30}, {30, 5}, {5, 60}}, 42 product: dims{10, 60}, 43 panics: true, 44 }, 45 { 46 n: 7, 47 factors: []dims{{60, 5}, {5, 5}, {5, 4}, {4, 10}, {10, 22}, {22, 45}, {45, 10}}, 48 product: dims{60, 10}, 49 panics: false, 50 }, 51 { 52 n: 7, 53 factors: []dims{{60, 5}, {5, 5}, {5, 400}, {4, 10}, {10, 22}, {22, 45}, {45, 10}}, 54 product: dims{60, 10}, 55 panics: true, 56 }, 57 { 58 n: 3, 59 factors: []dims{{1, 1000}, {1000, 2}, {2, 2}}, 60 product: dims{1, 2}, 61 panics: false, 62 }, 63 64 // Random chains. 65 { 66 n: 0, 67 product: dims{0, 0}, 68 panics: false, 69 }, 70 { 71 n: 2, 72 product: dims{60, 10}, 73 panics: false, 74 }, 75 { 76 n: 3, 77 product: dims{60, 10}, 78 panics: false, 79 }, 80 { 81 n: 4, 82 product: dims{60, 10}, 83 panics: false, 84 }, 85 { 86 n: 10, 87 product: dims{60, 10}, 88 panics: false, 89 }, 90 } 91 92 func TestProduct(t *testing.T) { 93 for _, test := range productTests { 94 dimensions := test.factors 95 if dimensions == nil && test.n > 0 { 96 dimensions = make([]dims, test.n) 97 for i := range dimensions { 98 if i != 0 { 99 dimensions[i].r = dimensions[i-1].c 100 } 101 dimensions[i].c = rand.Intn(50) + 1 102 } 103 dimensions[0].r = test.product.r 104 dimensions[test.n-1].c = test.product.c 105 } 106 factors := make([]Matrix, test.n) 107 for i, d := range dimensions { 108 data := make([]float64, d.r*d.c) 109 for i := range data { 110 data[i] = rand.Float64() 111 } 112 factors[i] = NewDense(d.r, d.c, data) 113 } 114 115 want := &Dense{} 116 if !test.panics { 117 a := &Dense{} 118 for i, b := range factors { 119 if i == 0 { 120 want.Clone(b) 121 continue 122 } 123 a, want = want, &Dense{} 124 want.Mul(a, b) 125 } 126 } 127 128 got := NewDense(test.product.r, test.product.c, nil) 129 panicked, message := panics(func() { 130 got.Product(factors...) 131 }) 132 if test.panics { 133 if !panicked { 134 t.Errorf("fail to panic with product chain dimensions: %+v result dimension: %+v", 135 dimensions, test.product) 136 } 137 continue 138 } else if panicked { 139 t.Errorf("unexpected panic %q with product chain dimensions: %+v result dimension: %+v", 140 message, dimensions, test.product) 141 continue 142 } 143 144 if len(factors) > 0 { 145 p := newMultiplier(NewDense(test.product.r, test.product.c, nil), factors) 146 p.optimize() 147 gotCost := p.table.at(0, len(factors)-1).cost 148 expr, wantCost, ok := bestExpressionFor(dimensions) 149 if !ok { 150 t.Fatal("unexpected number of expressions in brute force expression search") 151 } 152 if gotCost != wantCost { 153 t.Errorf("unexpected cost for chain dimensions: %+v got: %v want: %v\n%s", 154 dimensions, got, want, expr) 155 } 156 } 157 158 if !EqualApprox(got, want, 1e-14) { 159 t.Errorf("unexpected result from product chain dimensions: %+v", dimensions) 160 } 161 } 162 } 163 164 // node is a subexpression node. 165 type node struct { 166 dims 167 left, right *node 168 } 169 170 func (n *node) String() string { 171 if n.left == nil || n.right == nil { 172 rows, cols := n.shape() 173 return fmt.Sprintf("[%d×%d]", rows, cols) 174 } 175 rows, cols := n.shape() 176 return fmt.Sprintf("(%s * %s):[%d×%d]", n.left, n.right, rows, cols) 177 } 178 179 // shape returns the dimensions of the result of the subexpression. 180 func (n *node) shape() (rows, cols int) { 181 if n.left == nil || n.right == nil { 182 return n.r, n.c 183 } 184 rows, _ = n.left.shape() 185 _, cols = n.right.shape() 186 return rows, cols 187 } 188 189 // cost returns the cost to evaluate the subexpression. 190 func (n *node) cost() int { 191 if n.left == nil || n.right == nil { 192 return 0 193 } 194 lr, lc := n.left.shape() 195 _, rc := n.right.shape() 196 return lr*lc*rc + n.left.cost() + n.right.cost() 197 } 198 199 // expressionsFor returns a channel that can be used to iterate over all 200 // expressions of the given factor dimensions. 201 func expressionsFor(factors []dims) chan *node { 202 if len(factors) == 1 { 203 c := make(chan *node, 1) 204 c <- &node{dims: factors[0]} 205 close(c) 206 return c 207 } 208 c := make(chan *node) 209 go func() { 210 for i := 1; i < len(factors); i++ { 211 for left := range expressionsFor(factors[:i]) { 212 for right := range expressionsFor(factors[i:]) { 213 c <- &node{left: left, right: right} 214 } 215 } 216 } 217 close(c) 218 }() 219 return c 220 } 221 222 // catalan returns the nth 0-based Catalan number. 223 func catalan(n int) int { 224 p := 1 225 for k := n + 1; k < 2*n+1; k++ { 226 p *= k 227 } 228 for k := 2; k < n+2; k++ { 229 p /= k 230 } 231 return p 232 } 233 234 // bestExpressonFor returns the lowest cost expression for the given expression 235 // factor dimensions, the cost of the expression and whether the number of 236 // expressions searched matches the Catalan number for the number of factors. 237 func bestExpressionFor(factors []dims) (exp *node, cost int, ok bool) { 238 const maxInt = int(^uint(0) >> 1) 239 min := maxInt 240 var best *node 241 var n int 242 for exp := range expressionsFor(factors) { 243 n++ 244 cost := exp.cost() 245 if cost < min { 246 min = cost 247 best = exp 248 } 249 } 250 return best, min, n == catalan(len(factors)-1) 251 }