github.com/gopherd/gonum@v0.0.4/mat/product.go (about) 1 // Copyright ©2015 The Gonum Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style 3 // license that can be found in the LICENSE file. 4 5 package mat 6 7 import "fmt" 8 9 // Product calculates the product of the given factors and places the result in 10 // the receiver. The order of multiplication operations is optimized to minimize 11 // the number of floating point operations on the basis that all matrix 12 // multiplications are general. 13 func (m *Dense) Product(factors ...Matrix) { 14 // The operation order optimisation is the naive O(n^3) dynamic 15 // programming approach and does not take into consideration 16 // finer-grained optimisations that might be available. 17 // 18 // TODO(kortschak) Consider using the O(nlogn) or O(mlogn) 19 // algorithms that are available. e.g. 20 // 21 // e.g. http://www.jofcis.com/publishedpapers/2014_10_10_4299_4306.pdf 22 // 23 // In the case that this is replaced, retain this code in 24 // tests to compare against. 25 26 r, c := m.Dims() 27 switch len(factors) { 28 case 0: 29 if r != 0 || c != 0 { 30 panic(ErrShape) 31 } 32 return 33 case 1: 34 m.reuseAsNonZeroed(factors[0].Dims()) 35 m.Copy(factors[0]) 36 return 37 case 2: 38 // Don't do work that we know the answer to. 39 m.Mul(factors[0], factors[1]) 40 return 41 } 42 43 p := newMultiplier(m, factors) 44 p.optimize() 45 result := p.multiply() 46 m.reuseAsNonZeroed(result.Dims()) 47 m.Copy(result) 48 putDenseWorkspace(result) 49 } 50 51 // debugProductWalk enables debugging output for Product. 52 const debugProductWalk = false 53 54 // multiplier performs operation order optimisation and tree traversal. 55 type multiplier struct { 56 // factors is the ordered set of 57 // factors to multiply. 58 factors []Matrix 59 // dims is the chain of factor 60 // dimensions. 61 dims []int 62 63 // table contains the dynamic 64 // programming costs and subchain 65 // division indices. 66 table table 67 } 68 69 func newMultiplier(m *Dense, factors []Matrix) *multiplier { 70 // Check size early, but don't yet 71 // allocate data for m. 72 r, c := m.Dims() 73 fr, fc := factors[0].Dims() // newMultiplier is only called with len(factors) > 2. 74 if !m.IsEmpty() { 75 if fr != r { 76 panic(ErrShape) 77 } 78 if _, lc := factors[len(factors)-1].Dims(); lc != c { 79 panic(ErrShape) 80 } 81 } 82 83 dims := make([]int, len(factors)+1) 84 dims[0] = r 85 dims[len(dims)-1] = c 86 pc := fc 87 for i, f := range factors[1:] { 88 cr, cc := f.Dims() 89 dims[i+1] = cr 90 if pc != cr { 91 panic(ErrShape) 92 } 93 pc = cc 94 } 95 96 return &multiplier{ 97 factors: factors, 98 dims: dims, 99 table: newTable(len(factors)), 100 } 101 } 102 103 // optimize determines an optimal matrix multiply operation order. 104 func (p *multiplier) optimize() { 105 if debugProductWalk { 106 fmt.Printf("chain dims: %v\n", p.dims) 107 } 108 const maxInt = int(^uint(0) >> 1) 109 for f := 1; f < len(p.factors); f++ { 110 for i := 0; i < len(p.factors)-f; i++ { 111 j := i + f 112 p.table.set(i, j, entry{cost: maxInt}) 113 for k := i; k < j; k++ { 114 cost := p.table.at(i, k).cost + p.table.at(k+1, j).cost + p.dims[i]*p.dims[k+1]*p.dims[j+1] 115 if cost < p.table.at(i, j).cost { 116 p.table.set(i, j, entry{cost: cost, k: k}) 117 } 118 } 119 } 120 } 121 } 122 123 // multiply walks the optimal operation tree found by optimize, 124 // leaving the final result in the stack. It returns the 125 // product, which may be copied but should be returned to 126 // the workspace pool. 127 func (p *multiplier) multiply() *Dense { 128 result, _ := p.multiplySubchain(0, len(p.factors)-1) 129 if debugProductWalk { 130 r, c := result.Dims() 131 fmt.Printf("\tpop result (%d×%d) cost=%d\n", r, c, p.table.at(0, len(p.factors)-1).cost) 132 } 133 return result.(*Dense) 134 } 135 136 func (p *multiplier) multiplySubchain(i, j int) (m Matrix, intermediate bool) { 137 if i == j { 138 return p.factors[i], false 139 } 140 141 a, aTmp := p.multiplySubchain(i, p.table.at(i, j).k) 142 b, bTmp := p.multiplySubchain(p.table.at(i, j).k+1, j) 143 144 ar, ac := a.Dims() 145 br, bc := b.Dims() 146 if ac != br { 147 // Panic with a string since this 148 // is not a user-facing panic. 149 panic(ErrShape.Error()) 150 } 151 152 if debugProductWalk { 153 fmt.Printf("\tpush f[%d] (%d×%d)%s * f[%d] (%d×%d)%s\n", 154 i, ar, ac, result(aTmp), j, br, bc, result(bTmp)) 155 } 156 157 r := getDenseWorkspace(ar, bc, false) 158 r.Mul(a, b) 159 if aTmp { 160 putDenseWorkspace(a.(*Dense)) 161 } 162 if bTmp { 163 putDenseWorkspace(b.(*Dense)) 164 } 165 return r, true 166 } 167 168 type entry struct { 169 k int // is the chain subdivision index. 170 cost int // cost is the cost of the operation. 171 } 172 173 // table is a row major n×n dynamic programming table. 174 type table struct { 175 n int 176 entries []entry 177 } 178 179 func newTable(n int) table { 180 return table{n: n, entries: make([]entry, n*n)} 181 } 182 183 func (t table) at(i, j int) entry { return t.entries[i*t.n+j] } 184 func (t table) set(i, j int, e entry) { t.entries[i*t.n+j] = e } 185 186 type result bool 187 188 func (r result) String() string { 189 if r { 190 return " (popped result)" 191 } 192 return "" 193 }