github.com/gonum/matrix@v0.0.0-20181209220409-c518dec07be9/mat64/product.go (about)

     1  // Copyright ©2015 The gonum Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package mat64
     6  
     7  import (
     8  	"fmt"
     9  
    10  	"github.com/gonum/matrix"
    11  )
    12  
    13  // Product calculates the product of the given factors and places the result in
    14  // the receiver. The order of multiplication operations is optimized to minimize
    15  // the number of floating point operations on the basis that all matrix
    16  // multiplications are general.
    17  func (m *Dense) Product(factors ...Matrix) {
    18  	// The operation order optimisation is the naive O(n^3) dynamic
    19  	// programming approach and does not take into consideration
    20  	// finer-grained optimisations that might be available.
    21  	//
    22  	// TODO(kortschak) Consider using the O(nlogn) or O(mlogn)
    23  	// algorithms that are available. e.g.
    24  	//
    25  	// e.g. http://www.jofcis.com/publishedpapers/2014_10_10_4299_4306.pdf
    26  	//
    27  	// In the case that this is replaced, retain this code in
    28  	// tests to compare against.
    29  
    30  	r, c := m.Dims()
    31  	switch len(factors) {
    32  	case 0:
    33  		if r != 0 || c != 0 {
    34  			panic(matrix.ErrShape)
    35  		}
    36  		return
    37  	case 1:
    38  		m.reuseAs(factors[0].Dims())
    39  		m.Copy(factors[0])
    40  		return
    41  	case 2:
    42  		// Don't do work that we know the answer to.
    43  		m.Mul(factors[0], factors[1])
    44  		return
    45  	}
    46  
    47  	p := newMultiplier(m, factors)
    48  	p.optimize()
    49  	result := p.multiply()
    50  	m.reuseAs(result.Dims())
    51  	m.Copy(result)
    52  	putWorkspace(result)
    53  }
    54  
    55  // debugProductWalk enables debugging output for Product.
    56  const debugProductWalk = false
    57  
    58  // multiplier performs operation order optimisation and tree traversal.
    59  type multiplier struct {
    60  	// factors is the ordered set of
    61  	// factors to multiply.
    62  	factors []Matrix
    63  	// dims is the chain of factor
    64  	// dimensions.
    65  	dims []int
    66  
    67  	// table contains the dynamic
    68  	// programming costs and subchain
    69  	// division indices.
    70  	table table
    71  }
    72  
    73  func newMultiplier(m *Dense, factors []Matrix) *multiplier {
    74  	// Check size early, but don't yet
    75  	// allocate data for m.
    76  	r, c := m.Dims()
    77  	fr, fc := factors[0].Dims() // newMultiplier is only called with len(factors) > 2.
    78  	if !m.isZero() {
    79  		if fr != r {
    80  			panic(matrix.ErrShape)
    81  		}
    82  		if _, lc := factors[len(factors)-1].Dims(); lc != c {
    83  			panic(matrix.ErrShape)
    84  		}
    85  	}
    86  
    87  	dims := make([]int, len(factors)+1)
    88  	dims[0] = r
    89  	dims[len(dims)-1] = c
    90  	pc := fc
    91  	for i, f := range factors[1:] {
    92  		cr, cc := f.Dims()
    93  		dims[i+1] = cr
    94  		if pc != cr {
    95  			panic(matrix.ErrShape)
    96  		}
    97  		pc = cc
    98  	}
    99  
   100  	return &multiplier{
   101  		factors: factors,
   102  		dims:    dims,
   103  		table:   newTable(len(factors)),
   104  	}
   105  }
   106  
   107  // optimize determines an optimal matrix multiply operation order.
   108  func (p *multiplier) optimize() {
   109  	if debugProductWalk {
   110  		fmt.Printf("chain dims: %v\n", p.dims)
   111  	}
   112  	const maxInt = int(^uint(0) >> 1)
   113  	for f := 1; f < len(p.factors); f++ {
   114  		for i := 0; i < len(p.factors)-f; i++ {
   115  			j := i + f
   116  			p.table.set(i, j, entry{cost: maxInt})
   117  			for k := i; k < j; k++ {
   118  				cost := p.table.at(i, k).cost + p.table.at(k+1, j).cost + p.dims[i]*p.dims[k+1]*p.dims[j+1]
   119  				if cost < p.table.at(i, j).cost {
   120  					p.table.set(i, j, entry{cost: cost, k: k})
   121  				}
   122  			}
   123  		}
   124  	}
   125  }
   126  
   127  // multiply walks the optimal operation tree found by optimize,
   128  // leaving the final result in the stack. It returns the
   129  // product, which may be copied but should be returned to
   130  // the workspace pool.
   131  func (p *multiplier) multiply() *Dense {
   132  	result, _ := p.multiplySubchain(0, len(p.factors)-1)
   133  	if debugProductWalk {
   134  		r, c := result.Dims()
   135  		fmt.Printf("\tpop result (%d×%d) cost=%d\n", r, c, p.table.at(0, len(p.factors)-1).cost)
   136  	}
   137  	return result.(*Dense)
   138  }
   139  
   140  func (p *multiplier) multiplySubchain(i, j int) (m Matrix, intermediate bool) {
   141  	if i == j {
   142  		return p.factors[i], false
   143  	}
   144  
   145  	a, aTmp := p.multiplySubchain(i, p.table.at(i, j).k)
   146  	b, bTmp := p.multiplySubchain(p.table.at(i, j).k+1, j)
   147  
   148  	ar, ac := a.Dims()
   149  	br, bc := b.Dims()
   150  	if ac != br {
   151  		// Panic with a string since this
   152  		// is not a user-facing panic.
   153  		panic(matrix.ErrShape.Error())
   154  	}
   155  
   156  	if debugProductWalk {
   157  		fmt.Printf("\tpush f[%d] (%d×%d)%s * f[%d] (%d×%d)%s\n",
   158  			i, ar, ac, result(aTmp), j, br, bc, result(bTmp))
   159  	}
   160  
   161  	r := getWorkspace(ar, bc, false)
   162  	r.Mul(a, b)
   163  	if aTmp {
   164  		putWorkspace(a.(*Dense))
   165  	}
   166  	if bTmp {
   167  		putWorkspace(b.(*Dense))
   168  	}
   169  	return r, true
   170  }
   171  
   172  type entry struct {
   173  	k    int // is the chain subdivision index.
   174  	cost int // cost is the cost of the operation.
   175  }
   176  
   177  // table is a row major n×n dynamic programming table.
   178  type table struct {
   179  	n       int
   180  	entries []entry
   181  }
   182  
   183  func newTable(n int) table {
   184  	return table{n: n, entries: make([]entry, n*n)}
   185  }
   186  
   187  func (t table) at(i, j int) entry     { return t.entries[i*t.n+j] }
   188  func (t table) set(i, j int, e entry) { t.entries[i*t.n+j] = e }
   189  
   190  type result bool
   191  
   192  func (r result) String() string {
   193  	if r {
   194  		return " (popped result)"
   195  	}
   196  	return ""
   197  }