github.com/jingcheng-WU/gonum@v0.9.1-0.20210323123734-f1a2a11a8f7b/mat/product_test.go (about) 1 // Copyright ©2015 The Gonum Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style 3 // license that can be found in the LICENSE file. 4 5 package mat 6 7 import ( 8 "fmt" 9 "testing" 10 11 "golang.org/x/exp/rand" 12 ) 13 14 type dims struct{ r, c int } 15 16 var productTests = []struct { 17 n int 18 factors []dims 19 product dims 20 panics bool 21 }{ 22 { 23 n: 1, 24 factors: []dims{{3, 4}}, 25 product: dims{3, 4}, 26 panics: false, 27 }, 28 { 29 n: 1, 30 factors: []dims{{2, 4}}, 31 product: dims{3, 4}, 32 panics: true, 33 }, 34 { 35 n: 3, 36 factors: []dims{{10, 30}, {30, 5}, {5, 60}}, 37 product: dims{10, 60}, 38 panics: false, 39 }, 40 { 41 n: 3, 42 factors: []dims{{100, 30}, {30, 5}, {5, 60}}, 43 product: dims{10, 60}, 44 panics: true, 45 }, 46 { 47 n: 7, 48 factors: []dims{{60, 5}, {5, 5}, {5, 4}, {4, 10}, {10, 22}, {22, 45}, {45, 10}}, 49 product: dims{60, 10}, 50 panics: false, 51 }, 52 { 53 n: 7, 54 factors: []dims{{60, 5}, {5, 5}, {5, 400}, {4, 10}, {10, 22}, {22, 45}, {45, 10}}, 55 product: dims{60, 10}, 56 panics: true, 57 }, 58 { 59 n: 3, 60 factors: []dims{{1, 1000}, {1000, 2}, {2, 2}}, 61 product: dims{1, 2}, 62 panics: false, 63 }, 64 65 // Random chains. 66 { 67 n: 0, 68 product: dims{0, 0}, 69 panics: false, 70 }, 71 { 72 n: 2, 73 product: dims{60, 10}, 74 panics: false, 75 }, 76 { 77 n: 3, 78 product: dims{60, 10}, 79 panics: false, 80 }, 81 { 82 n: 4, 83 product: dims{60, 10}, 84 panics: false, 85 }, 86 { 87 n: 10, 88 product: dims{60, 10}, 89 panics: false, 90 }, 91 } 92 93 func TestProduct(t *testing.T) { 94 t.Parallel() 95 rnd := rand.New(rand.NewSource(1)) 96 for _, test := range productTests { 97 dimensions := test.factors 98 if dimensions == nil && test.n > 0 { 99 dimensions = make([]dims, test.n) 100 for i := range dimensions { 101 if i != 0 { 102 dimensions[i].r = dimensions[i-1].c 103 } 104 dimensions[i].c = rnd.Intn(50) + 1 105 } 106 dimensions[0].r = test.product.r 107 dimensions[test.n-1].c = test.product.c 108 } 109 factors := make([]Matrix, test.n) 110 for i, d := range dimensions { 111 data := make([]float64, d.r*d.c) 112 for i := range data { 113 data[i] = rnd.Float64() 114 } 115 factors[i] = NewDense(d.r, d.c, data) 116 } 117 118 want := &Dense{} 119 if !test.panics { 120 var a *Dense 121 for i, b := range factors { 122 if i == 0 { 123 want.CloneFrom(b) 124 continue 125 } 126 a, want = want, &Dense{} 127 want.Mul(a, b) 128 } 129 } 130 131 got := &Dense{} 132 if test.product.r != 0 && test.product.c != 0 { 133 got = NewDense(test.product.r, test.product.c, nil) 134 } 135 panicked, message := panics(func() { 136 got.Product(factors...) 137 }) 138 if test.panics { 139 if !panicked { 140 t.Errorf("fail to panic with product chain dimensions: %+v result dimension: %+v", 141 dimensions, test.product) 142 } 143 continue 144 } else if panicked { 145 t.Errorf("unexpected panic %q with product chain dimensions: %+v result dimension: %+v", 146 message, dimensions, test.product) 147 continue 148 } 149 150 if len(factors) > 0 { 151 p := newMultiplier(NewDense(test.product.r, test.product.c, nil), factors) 152 p.optimize() 153 gotCost := p.table.at(0, len(factors)-1).cost 154 expr, wantCost, ok := bestExpressionFor(dimensions) 155 if !ok { 156 t.Fatal("unexpected number of expressions in brute force expression search") 157 } 158 if gotCost != wantCost { 159 t.Errorf("unexpected cost for chain dimensions: %+v got: %v want: %v\n%s", 160 dimensions, got, want, expr) 161 } 162 } 163 164 if !EqualApprox(got, want, 1e-14) { 165 t.Errorf("unexpected result from product chain dimensions: %+v", dimensions) 166 } 167 } 168 } 169 170 // node is a subexpression node. 171 type node struct { 172 dims 173 left, right *node 174 } 175 176 func (n *node) String() string { 177 if n.left == nil || n.right == nil { 178 rows, cols := n.shape() 179 return fmt.Sprintf("[%d×%d]", rows, cols) 180 } 181 rows, cols := n.shape() 182 return fmt.Sprintf("(%s * %s):[%d×%d]", n.left, n.right, rows, cols) 183 } 184 185 // shape returns the dimensions of the result of the subexpression. 186 func (n *node) shape() (rows, cols int) { 187 if n.left == nil || n.right == nil { 188 return n.r, n.c 189 } 190 rows, _ = n.left.shape() 191 _, cols = n.right.shape() 192 return rows, cols 193 } 194 195 // cost returns the cost to evaluate the subexpression. 196 func (n *node) cost() int { 197 if n.left == nil || n.right == nil { 198 return 0 199 } 200 lr, lc := n.left.shape() 201 _, rc := n.right.shape() 202 return lr*lc*rc + n.left.cost() + n.right.cost() 203 } 204 205 // expressionsFor returns a channel that can be used to iterate over all 206 // expressions of the given factor dimensions. 207 func expressionsFor(factors []dims) chan *node { 208 if len(factors) == 1 { 209 c := make(chan *node, 1) 210 c <- &node{dims: factors[0]} 211 close(c) 212 return c 213 } 214 c := make(chan *node) 215 go func() { 216 for i := 1; i < len(factors); i++ { 217 for left := range expressionsFor(factors[:i]) { 218 for right := range expressionsFor(factors[i:]) { 219 c <- &node{left: left, right: right} 220 } 221 } 222 } 223 close(c) 224 }() 225 return c 226 } 227 228 // catalan returns the nth 0-based Catalan number. 229 func catalan(n int) int { 230 // Work in 64-bit integers since we overflow 32-bits for some tests. 231 p := int64(1) 232 for k := n + 1; k < 2*n+1; k++ { 233 p *= int64(k) 234 } 235 for k := 2; k < n+2; k++ { 236 p /= int64(k) 237 } 238 return int(p) 239 } 240 241 // bestExpressonFor returns the lowest cost expression for the given expression 242 // factor dimensions, the cost of the expression and whether the number of 243 // expressions searched matches the Catalan number for the number of factors. 244 func bestExpressionFor(factors []dims) (exp *node, cost int, ok bool) { 245 const maxInt = int(^uint(0) >> 1) 246 min := maxInt 247 var best *node 248 var n int 249 for exp := range expressionsFor(factors) { 250 n++ 251 cost := exp.cost() 252 if cost < min { 253 min = cost 254 best = exp 255 } 256 } 257 return best, min, n == catalan(len(factors)-1) 258 }