github.com/gopherd/gonum@v0.0.4/mat/product.go (about)

     1  // Copyright ©2015 The Gonum Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package mat
     6  
     7  import "fmt"
     8  
     9  // Product calculates the product of the given factors and places the result in
    10  // the receiver. The order of multiplication operations is optimized to minimize
    11  // the number of floating point operations on the basis that all matrix
    12  // multiplications are general.
    13  func (m *Dense) Product(factors ...Matrix) {
    14  	// The operation order optimisation is the naive O(n^3) dynamic
    15  	// programming approach and does not take into consideration
    16  	// finer-grained optimisations that might be available.
    17  	//
    18  	// TODO(kortschak) Consider using the O(nlogn) or O(mlogn)
    19  	// algorithms that are available. e.g.
    20  	//
    21  	// e.g. http://www.jofcis.com/publishedpapers/2014_10_10_4299_4306.pdf
    22  	//
    23  	// In the case that this is replaced, retain this code in
    24  	// tests to compare against.
    25  
    26  	r, c := m.Dims()
    27  	switch len(factors) {
    28  	case 0:
    29  		if r != 0 || c != 0 {
    30  			panic(ErrShape)
    31  		}
    32  		return
    33  	case 1:
    34  		m.reuseAsNonZeroed(factors[0].Dims())
    35  		m.Copy(factors[0])
    36  		return
    37  	case 2:
    38  		// Don't do work that we know the answer to.
    39  		m.Mul(factors[0], factors[1])
    40  		return
    41  	}
    42  
    43  	p := newMultiplier(m, factors)
    44  	p.optimize()
    45  	result := p.multiply()
    46  	m.reuseAsNonZeroed(result.Dims())
    47  	m.Copy(result)
    48  	putDenseWorkspace(result)
    49  }
    50  
    51  // debugProductWalk enables debugging output for Product.
    52  const debugProductWalk = false
    53  
    54  // multiplier performs operation order optimisation and tree traversal.
    55  type multiplier struct {
    56  	// factors is the ordered set of
    57  	// factors to multiply.
    58  	factors []Matrix
    59  	// dims is the chain of factor
    60  	// dimensions.
    61  	dims []int
    62  
    63  	// table contains the dynamic
    64  	// programming costs and subchain
    65  	// division indices.
    66  	table table
    67  }
    68  
    69  func newMultiplier(m *Dense, factors []Matrix) *multiplier {
    70  	// Check size early, but don't yet
    71  	// allocate data for m.
    72  	r, c := m.Dims()
    73  	fr, fc := factors[0].Dims() // newMultiplier is only called with len(factors) > 2.
    74  	if !m.IsEmpty() {
    75  		if fr != r {
    76  			panic(ErrShape)
    77  		}
    78  		if _, lc := factors[len(factors)-1].Dims(); lc != c {
    79  			panic(ErrShape)
    80  		}
    81  	}
    82  
    83  	dims := make([]int, len(factors)+1)
    84  	dims[0] = r
    85  	dims[len(dims)-1] = c
    86  	pc := fc
    87  	for i, f := range factors[1:] {
    88  		cr, cc := f.Dims()
    89  		dims[i+1] = cr
    90  		if pc != cr {
    91  			panic(ErrShape)
    92  		}
    93  		pc = cc
    94  	}
    95  
    96  	return &multiplier{
    97  		factors: factors,
    98  		dims:    dims,
    99  		table:   newTable(len(factors)),
   100  	}
   101  }
   102  
   103  // optimize determines an optimal matrix multiply operation order.
   104  func (p *multiplier) optimize() {
   105  	if debugProductWalk {
   106  		fmt.Printf("chain dims: %v\n", p.dims)
   107  	}
   108  	const maxInt = int(^uint(0) >> 1)
   109  	for f := 1; f < len(p.factors); f++ {
   110  		for i := 0; i < len(p.factors)-f; i++ {
   111  			j := i + f
   112  			p.table.set(i, j, entry{cost: maxInt})
   113  			for k := i; k < j; k++ {
   114  				cost := p.table.at(i, k).cost + p.table.at(k+1, j).cost + p.dims[i]*p.dims[k+1]*p.dims[j+1]
   115  				if cost < p.table.at(i, j).cost {
   116  					p.table.set(i, j, entry{cost: cost, k: k})
   117  				}
   118  			}
   119  		}
   120  	}
   121  }
   122  
   123  // multiply walks the optimal operation tree found by optimize,
   124  // leaving the final result in the stack. It returns the
   125  // product, which may be copied but should be returned to
   126  // the workspace pool.
   127  func (p *multiplier) multiply() *Dense {
   128  	result, _ := p.multiplySubchain(0, len(p.factors)-1)
   129  	if debugProductWalk {
   130  		r, c := result.Dims()
   131  		fmt.Printf("\tpop result (%d×%d) cost=%d\n", r, c, p.table.at(0, len(p.factors)-1).cost)
   132  	}
   133  	return result.(*Dense)
   134  }
   135  
   136  func (p *multiplier) multiplySubchain(i, j int) (m Matrix, intermediate bool) {
   137  	if i == j {
   138  		return p.factors[i], false
   139  	}
   140  
   141  	a, aTmp := p.multiplySubchain(i, p.table.at(i, j).k)
   142  	b, bTmp := p.multiplySubchain(p.table.at(i, j).k+1, j)
   143  
   144  	ar, ac := a.Dims()
   145  	br, bc := b.Dims()
   146  	if ac != br {
   147  		// Panic with a string since this
   148  		// is not a user-facing panic.
   149  		panic(ErrShape.Error())
   150  	}
   151  
   152  	if debugProductWalk {
   153  		fmt.Printf("\tpush f[%d] (%d×%d)%s * f[%d] (%d×%d)%s\n",
   154  			i, ar, ac, result(aTmp), j, br, bc, result(bTmp))
   155  	}
   156  
   157  	r := getDenseWorkspace(ar, bc, false)
   158  	r.Mul(a, b)
   159  	if aTmp {
   160  		putDenseWorkspace(a.(*Dense))
   161  	}
   162  	if bTmp {
   163  		putDenseWorkspace(b.(*Dense))
   164  	}
   165  	return r, true
   166  }
   167  
   168  type entry struct {
   169  	k    int // is the chain subdivision index.
   170  	cost int // cost is the cost of the operation.
   171  }
   172  
   173  // table is a row major n×n dynamic programming table.
   174  type table struct {
   175  	n       int
   176  	entries []entry
   177  }
   178  
   179  func newTable(n int) table {
   180  	return table{n: n, entries: make([]entry, n*n)}
   181  }
   182  
   183  func (t table) at(i, j int) entry     { return t.entries[i*t.n+j] }
   184  func (t table) set(i, j int, e entry) { t.entries[i*t.n+j] = e }
   185  
   186  type result bool
   187  
   188  func (r result) String() string {
   189  	if r {
   190  		return " (popped result)"
   191  	}
   192  	return ""
   193  }