gorgonia.org/gorgonia@v0.9.17/broadcast.go (about)

     1  package gorgonia
     2  
     3  import (
     4  	"github.com/pkg/errors"
     5  	"gorgonia.org/tensor"
     6  )
     7  
     8  const (
     9  	bcAllowableAxes = 4
    10  )
    11  
    12  // BroadcastPattern is actually a bit array.
    13  // It's split into 2 nibbles - the left nibble represents the left operand, the right nibble represents the right operand:
    14  //		xxxx|xxxx
    15  // The least significant bit of each nibble is elem 0.
    16  // Concrete examples:
    17  //		00000010 (0x02) = broadcast axis 1 of the right operand
    18  //		00000001 (0x01) = broadcast axis 0 of the right operand
    19  //		00000101 (0x09) = broadcast axis 0 AND axis 2 of the right operand
    20  //		00010000 (0x10) = broadcast axis 0 of the left operand
    21  //		00110000 (0x30) = broadcast axis 0 and axis 1 of the lef operand
    22  // You get the drill.
    23  //
    24  // Do note that the current limitation of the BroadcastPattern allows only up to 4 dimensions per operand.
    25  type BroadcastPattern byte
    26  
    27  // NewBroadcastPattern is a helper function to create broadcast patterns
    28  func NewBroadcastPattern(leftAxes, rightAxes []byte) BroadcastPattern {
    29  	var start byte
    30  	for _, a := range leftAxes {
    31  		a += bcAllowableAxes
    32  		start |= byte(1) << a
    33  	}
    34  	for _, a := range rightAxes {
    35  		start |= byte(1) << a
    36  	}
    37  	return BroadcastPattern(start)
    38  }
    39  
    40  func (bcpat BroadcastPattern) bc(left bool, axis byte) bool {
    41  	operand := axis
    42  	if left {
    43  		operand += bcAllowableAxes
    44  	}
    45  	return (byte(bcpat)>>operand)&byte(1) == 1
    46  }
    47  
    48  func (bcpat BroadcastPattern) on() (retVal [2][]int) {
    49  	for i := 0; i < bcAllowableAxes; i++ {
    50  		if bcpat.bc(true, byte(i)) {
    51  			retVal[0] = append(retVal[0], i)
    52  		}
    53  	}
    54  
    55  	for i := 0; i < bcAllowableAxes; i++ {
    56  		if bcpat.bc(false, byte(i)) {
    57  			retVal[1] = append(retVal[1], i)
    58  		}
    59  	}
    60  
    61  	return
    62  }
    63  
    64  // Broadcast apply the pattern to the input nodes
    65  // and returns two nodes suitable for a binary operator.
    66  // Broadcast works somewhat like Numpy's broadcast, except it's now exposed as a function.
    67  func Broadcast(a, b *Node, pattern BroadcastPattern) (*Node, *Node, error) {
    68  	broadcastOn := pattern.on()
    69  
    70  	var err error
    71  	var newShape tensor.Shape
    72  	x := a
    73  	y := b
    74  	xshape := x.Shape()
    75  	yshape := y.Shape()
    76  
    77  	if len(broadcastOn[0]) > 0 {
    78  
    79  		for _, a := range broadcastOn[0] {
    80  			if a >= yshape.Dims() {
    81  				return nil, nil, errors.Errorf("Attempting to broadcast a on axis %d of b. But b has shape %v", a, yshape)
    82  			}
    83  		}
    84  		newShape = calcBroadcastShape(x, yshape.Dims(), broadcastOn[0])
    85  		if x, err = Reshape(x, newShape); err != nil {
    86  			return nil, nil, errors.Wrapf(err, "Cannot reshape x to %v for broadcasting", newShape)
    87  		}
    88  		children := Nodes{x}
    89  		for _, a := range broadcastOn[0] {
    90  			var size *Node
    91  			if size, err = SizeOf(a, y); err != nil {
    92  				return nil, nil, errors.Wrap(err, operationError)
    93  			}
    94  			children = append(children, size)
    95  		}
    96  		if x, err = repeatedApply(broadcastOn[0], children); err != nil {
    97  			return nil, nil, errors.Wrap(err, operationError)
    98  		}
    99  	}
   100  
   101  	if len(broadcastOn[1]) > 0 {
   102  		for _, a := range broadcastOn[1] {
   103  			if a >= xshape.Dims() {
   104  				return nil, nil, errors.Errorf("Attempting to broadcast b on axis %d of a. But a has shape %v", a, xshape)
   105  			}
   106  		}
   107  
   108  		newShape = calcBroadcastShape(y, xshape.Dims(), broadcastOn[1])
   109  
   110  		if y, err = Reshape(y, newShape); err != nil {
   111  			return nil, nil, errors.Wrapf(err, "Cannot reshape y to %v for broadcast", newShape)
   112  		}
   113  		children := Nodes{y}
   114  		for _, a := range broadcastOn[1] {
   115  			var size *Node
   116  			if size, err = SizeOf(a, x); err != nil {
   117  				return nil, nil, errors.Wrap(err, operationError)
   118  			}
   119  			children = append(children, size)
   120  		}
   121  
   122  		if y, err = repeatedApply(broadcastOn[1], children); err != nil {
   123  			return nil, nil, errors.Wrap(err, operationError)
   124  		}
   125  	}
   126  	return x, y, nil
   127  }