github.com/gonum/matrix@v0.0.0-20181209220409-c518dec07be9/mat64/product.go (about) 1 // Copyright ©2015 The gonum Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style 3 // license that can be found in the LICENSE file. 4 5 package mat64 6 7 import ( 8 "fmt" 9 10 "github.com/gonum/matrix" 11 ) 12 13 // Product calculates the product of the given factors and places the result in 14 // the receiver. The order of multiplication operations is optimized to minimize 15 // the number of floating point operations on the basis that all matrix 16 // multiplications are general. 17 func (m *Dense) Product(factors ...Matrix) { 18 // The operation order optimisation is the naive O(n^3) dynamic 19 // programming approach and does not take into consideration 20 // finer-grained optimisations that might be available. 21 // 22 // TODO(kortschak) Consider using the O(nlogn) or O(mlogn) 23 // algorithms that are available. e.g. 24 // 25 // e.g. http://www.jofcis.com/publishedpapers/2014_10_10_4299_4306.pdf 26 // 27 // In the case that this is replaced, retain this code in 28 // tests to compare against. 29 30 r, c := m.Dims() 31 switch len(factors) { 32 case 0: 33 if r != 0 || c != 0 { 34 panic(matrix.ErrShape) 35 } 36 return 37 case 1: 38 m.reuseAs(factors[0].Dims()) 39 m.Copy(factors[0]) 40 return 41 case 2: 42 // Don't do work that we know the answer to. 43 m.Mul(factors[0], factors[1]) 44 return 45 } 46 47 p := newMultiplier(m, factors) 48 p.optimize() 49 result := p.multiply() 50 m.reuseAs(result.Dims()) 51 m.Copy(result) 52 putWorkspace(result) 53 } 54 55 // debugProductWalk enables debugging output for Product. 56 const debugProductWalk = false 57 58 // multiplier performs operation order optimisation and tree traversal. 59 type multiplier struct { 60 // factors is the ordered set of 61 // factors to multiply. 62 factors []Matrix 63 // dims is the chain of factor 64 // dimensions. 65 dims []int 66 67 // table contains the dynamic 68 // programming costs and subchain 69 // division indices. 70 table table 71 } 72 73 func newMultiplier(m *Dense, factors []Matrix) *multiplier { 74 // Check size early, but don't yet 75 // allocate data for m. 76 r, c := m.Dims() 77 fr, fc := factors[0].Dims() // newMultiplier is only called with len(factors) > 2. 78 if !m.isZero() { 79 if fr != r { 80 panic(matrix.ErrShape) 81 } 82 if _, lc := factors[len(factors)-1].Dims(); lc != c { 83 panic(matrix.ErrShape) 84 } 85 } 86 87 dims := make([]int, len(factors)+1) 88 dims[0] = r 89 dims[len(dims)-1] = c 90 pc := fc 91 for i, f := range factors[1:] { 92 cr, cc := f.Dims() 93 dims[i+1] = cr 94 if pc != cr { 95 panic(matrix.ErrShape) 96 } 97 pc = cc 98 } 99 100 return &multiplier{ 101 factors: factors, 102 dims: dims, 103 table: newTable(len(factors)), 104 } 105 } 106 107 // optimize determines an optimal matrix multiply operation order. 108 func (p *multiplier) optimize() { 109 if debugProductWalk { 110 fmt.Printf("chain dims: %v\n", p.dims) 111 } 112 const maxInt = int(^uint(0) >> 1) 113 for f := 1; f < len(p.factors); f++ { 114 for i := 0; i < len(p.factors)-f; i++ { 115 j := i + f 116 p.table.set(i, j, entry{cost: maxInt}) 117 for k := i; k < j; k++ { 118 cost := p.table.at(i, k).cost + p.table.at(k+1, j).cost + p.dims[i]*p.dims[k+1]*p.dims[j+1] 119 if cost < p.table.at(i, j).cost { 120 p.table.set(i, j, entry{cost: cost, k: k}) 121 } 122 } 123 } 124 } 125 } 126 127 // multiply walks the optimal operation tree found by optimize, 128 // leaving the final result in the stack. It returns the 129 // product, which may be copied but should be returned to 130 // the workspace pool. 131 func (p *multiplier) multiply() *Dense { 132 result, _ := p.multiplySubchain(0, len(p.factors)-1) 133 if debugProductWalk { 134 r, c := result.Dims() 135 fmt.Printf("\tpop result (%d×%d) cost=%d\n", r, c, p.table.at(0, len(p.factors)-1).cost) 136 } 137 return result.(*Dense) 138 } 139 140 func (p *multiplier) multiplySubchain(i, j int) (m Matrix, intermediate bool) { 141 if i == j { 142 return p.factors[i], false 143 } 144 145 a, aTmp := p.multiplySubchain(i, p.table.at(i, j).k) 146 b, bTmp := p.multiplySubchain(p.table.at(i, j).k+1, j) 147 148 ar, ac := a.Dims() 149 br, bc := b.Dims() 150 if ac != br { 151 // Panic with a string since this 152 // is not a user-facing panic. 153 panic(matrix.ErrShape.Error()) 154 } 155 156 if debugProductWalk { 157 fmt.Printf("\tpush f[%d] (%d×%d)%s * f[%d] (%d×%d)%s\n", 158 i, ar, ac, result(aTmp), j, br, bc, result(bTmp)) 159 } 160 161 r := getWorkspace(ar, bc, false) 162 r.Mul(a, b) 163 if aTmp { 164 putWorkspace(a.(*Dense)) 165 } 166 if bTmp { 167 putWorkspace(b.(*Dense)) 168 } 169 return r, true 170 } 171 172 type entry struct { 173 k int // is the chain subdivision index. 174 cost int // cost is the cost of the operation. 175 } 176 177 // table is a row major n×n dynamic programming table. 178 type table struct { 179 n int 180 entries []entry 181 } 182 183 func newTable(n int) table { 184 return table{n: n, entries: make([]entry, n*n)} 185 } 186 187 func (t table) at(i, j int) entry { return t.entries[i*t.n+j] } 188 func (t table) set(i, j int, e entry) { t.entries[i*t.n+j] = e } 189 190 type result bool 191 192 func (r result) String() string { 193 if r { 194 return " (popped result)" 195 } 196 return "" 197 }