gorgonia.org/gorgonia@v0.9.17/broadcast.go (about) 1 package gorgonia 2 3 import ( 4 "github.com/pkg/errors" 5 "gorgonia.org/tensor" 6 ) 7 8 const ( 9 bcAllowableAxes = 4 10 ) 11 12 // BroadcastPattern is actually a bit array. 13 // It's split into 2 nibbles - the left nibble represents the left operand, the right nibble represents the right operand: 14 // xxxx|xxxx 15 // The least significant bit of each nibble is elem 0. 16 // Concrete examples: 17 // 00000010 (0x02) = broadcast axis 1 of the right operand 18 // 00000001 (0x01) = broadcast axis 0 of the right operand 19 // 00000101 (0x09) = broadcast axis 0 AND axis 2 of the right operand 20 // 00010000 (0x10) = broadcast axis 0 of the left operand 21 // 00110000 (0x30) = broadcast axis 0 and axis 1 of the lef operand 22 // You get the drill. 23 // 24 // Do note that the current limitation of the BroadcastPattern allows only up to 4 dimensions per operand. 25 type BroadcastPattern byte 26 27 // NewBroadcastPattern is a helper function to create broadcast patterns 28 func NewBroadcastPattern(leftAxes, rightAxes []byte) BroadcastPattern { 29 var start byte 30 for _, a := range leftAxes { 31 a += bcAllowableAxes 32 start |= byte(1) << a 33 } 34 for _, a := range rightAxes { 35 start |= byte(1) << a 36 } 37 return BroadcastPattern(start) 38 } 39 40 func (bcpat BroadcastPattern) bc(left bool, axis byte) bool { 41 operand := axis 42 if left { 43 operand += bcAllowableAxes 44 } 45 return (byte(bcpat)>>operand)&byte(1) == 1 46 } 47 48 func (bcpat BroadcastPattern) on() (retVal [2][]int) { 49 for i := 0; i < bcAllowableAxes; i++ { 50 if bcpat.bc(true, byte(i)) { 51 retVal[0] = append(retVal[0], i) 52 } 53 } 54 55 for i := 0; i < bcAllowableAxes; i++ { 56 if bcpat.bc(false, byte(i)) { 57 retVal[1] = append(retVal[1], i) 58 } 59 } 60 61 return 62 } 63 64 // Broadcast apply the pattern to the input nodes 65 // and returns two nodes suitable for a binary operator. 66 // Broadcast works somewhat like Numpy's broadcast, except it's now exposed as a function. 67 func Broadcast(a, b *Node, pattern BroadcastPattern) (*Node, *Node, error) { 68 broadcastOn := pattern.on() 69 70 var err error 71 var newShape tensor.Shape 72 x := a 73 y := b 74 xshape := x.Shape() 75 yshape := y.Shape() 76 77 if len(broadcastOn[0]) > 0 { 78 79 for _, a := range broadcastOn[0] { 80 if a >= yshape.Dims() { 81 return nil, nil, errors.Errorf("Attempting to broadcast a on axis %d of b. But b has shape %v", a, yshape) 82 } 83 } 84 newShape = calcBroadcastShape(x, yshape.Dims(), broadcastOn[0]) 85 if x, err = Reshape(x, newShape); err != nil { 86 return nil, nil, errors.Wrapf(err, "Cannot reshape x to %v for broadcasting", newShape) 87 } 88 children := Nodes{x} 89 for _, a := range broadcastOn[0] { 90 var size *Node 91 if size, err = SizeOf(a, y); err != nil { 92 return nil, nil, errors.Wrap(err, operationError) 93 } 94 children = append(children, size) 95 } 96 if x, err = repeatedApply(broadcastOn[0], children); err != nil { 97 return nil, nil, errors.Wrap(err, operationError) 98 } 99 } 100 101 if len(broadcastOn[1]) > 0 { 102 for _, a := range broadcastOn[1] { 103 if a >= xshape.Dims() { 104 return nil, nil, errors.Errorf("Attempting to broadcast b on axis %d of a. But a has shape %v", a, xshape) 105 } 106 } 107 108 newShape = calcBroadcastShape(y, xshape.Dims(), broadcastOn[1]) 109 110 if y, err = Reshape(y, newShape); err != nil { 111 return nil, nil, errors.Wrapf(err, "Cannot reshape y to %v for broadcast", newShape) 112 } 113 children := Nodes{y} 114 for _, a := range broadcastOn[1] { 115 var size *Node 116 if size, err = SizeOf(a, x); err != nil { 117 return nil, nil, errors.Wrap(err, operationError) 118 } 119 children = append(children, size) 120 } 121 122 if y, err = repeatedApply(broadcastOn[1], children); err != nil { 123 return nil, nil, errors.Wrap(err, operationError) 124 } 125 } 126 return x, y, nil 127 }