gorgonia.org/gorgonia@v0.9.17/operations_broadcast_test.go (about) 1 package gorgonia 2 3 import ( 4 "fmt" 5 "testing" 6 7 "github.com/stretchr/testify/assert" 8 "gorgonia.org/tensor" 9 ) 10 11 type broadcastOpTest struct { 12 name string 13 a Value 14 b Value 15 16 // broadcast axes 17 left, right []byte 18 19 // results 20 ab Value 21 err bool 22 } 23 24 var broadcastAddTests = []broadcastOpTest{ 25 {name: "vec-mat", 26 a: tensor.New(tensor.WithShape(2), tensor.WithBacking([]float64{100, 200})), 27 b: tensor.New(tensor.WithShape(2, 2), tensor.WithBacking([]float64{1, 2, 3, 4})), 28 left: []byte{1}, 29 right: nil, 30 ab: tensor.New(tensor.WithShape(2, 2), tensor.WithBacking([]float64{101, 102, 203, 204})), 31 err: false, 32 }, 33 34 {name: "mat-vec", 35 a: tensor.New(tensor.WithShape(2, 2), tensor.WithBacking([]float64{1, 2, 3, 4})), 36 b: tensor.New(tensor.WithShape(2), tensor.WithBacking([]float64{100, 200})), 37 left: nil, 38 right: []byte{1}, 39 ab: tensor.New(tensor.WithShape(2, 2), tensor.WithBacking([]float64{101, 102, 203, 204})), 40 err: false, 41 }, 42 {name: "rowvec-mat", 43 a: tensor.New(tensor.WithShape(2, 1), tensor.WithBacking([]float64{100, 200})), 44 b: tensor.New(tensor.WithShape(2, 2), tensor.WithBacking([]float64{1, 2, 3, 4})), 45 left: []byte{1}, 46 right: nil, 47 ab: tensor.New(tensor.WithShape(2, 2), tensor.WithBacking([]float64{101, 102, 203, 204})), 48 err: false, 49 }, 50 {name: "mat-rowvec", 51 a: tensor.New(tensor.WithShape(2, 2), tensor.WithBacking([]float64{1, 2, 3, 4})), 52 b: tensor.New(tensor.WithShape(2, 1), tensor.WithBacking([]float64{100, 200})), 53 left: nil, 54 right: []byte{1}, 55 ab: tensor.New(tensor.WithShape(2, 2), tensor.WithBacking([]float64{101, 102, 203, 204})), 56 err: false, 57 }, 58 {name: "colvec-mat", 59 a: tensor.New(tensor.WithShape(1, 2), tensor.WithBacking([]float64{100, 200})), 60 b: tensor.New(tensor.WithShape(2, 2), tensor.WithBacking([]float64{1, 2, 3, 4})), 61 left: []byte{0}, 62 right: nil, 63 ab: tensor.New(tensor.WithShape(2, 2), tensor.WithBacking([]float64{101, 202, 103, 204})), 64 err: false, 65 }, 66 {name: "mat-colvec", 67 a: tensor.New(tensor.WithShape(2, 2), tensor.WithBacking([]float64{1, 2, 3, 4})), 68 b: tensor.New(tensor.WithShape(1, 2), tensor.WithBacking([]float64{100, 200})), 69 left: nil, 70 right: []byte{0}, 71 ab: tensor.New(tensor.WithShape(2, 2), tensor.WithBacking([]float64{101, 202, 103, 204})), 72 err: false, 73 }, 74 /* // SKIPPED UNTIL WE CAN FIX BROADCAST SEMANTICS 75 {name: "3col-3tensor", 76 a: tensor.New(tensor.WithShape(1, 1, 2), tensor.WithBacking([]float64{100, 200})), 77 b: tensor.New(tensor.WithShape(2, 2, 2), tensor.WithBacking([]float64{1, 2, 3, 4, 5, 6, 7, 8})), 78 left: []byte{0, 1}, 79 right: nil, 80 ab: tensor.New(tensor.WithShape(2, 2, 2), tensor.WithBacking([]float64{101, 202, 103, 204, 105, 206, 107, 208})), 81 err: false, 82 }, 83 {name: "3vec-3tensor", 84 a: tensor.New(tensor.WithShape(2, 1, 1), tensor.WithBacking([]float64{100, 200})), 85 b: tensor.New(tensor.WithShape(2, 2, 2), tensor.WithBacking([]float64{1, 2, 3, 4, 5, 6, 7, 8})), 86 left: []byte{1, 2}, 87 right: nil, 88 ab: tensor.New(tensor.WithShape(2, 2, 2), tensor.WithBacking([]float64{101, 102, 103, 104, 205, 206, 207, 208})), 89 err: false, 90 }, 91 {name: "colmat-3tensor", 92 a: tensor.New(tensor.WithShape(1, 2, 2), tensor.WithBacking([]float64{100, 200, 300, 400})), 93 b: tensor.New(tensor.WithShape(2, 2, 2), tensor.WithBacking([]float64{1, 2, 3, 4, 5, 6, 7, 8})), 94 left: []byte{0}, 95 right: nil, 96 ab: tensor.New(tensor.WithShape(2, 2, 2), tensor.WithBacking([]float64{101, 202, 303, 404, 105, 206, 307, 408})), 97 err: false, 98 }, 99 {name: "3tensor-colmat", 100 a: tensor.New(tensor.WithShape(2, 2, 2), tensor.WithBacking([]float64{1, 2, 3, 4, 5, 6, 7, 8})), 101 b: tensor.New(tensor.WithShape(1, 2, 2), tensor.WithBacking([]float64{100, 200, 300, 400})), 102 left: nil, 103 right: []byte{0}, 104 ab: tensor.New(tensor.WithShape(2, 2, 2), tensor.WithBacking([]float64{101, 202, 303, 404, 105, 206, 307, 408})), 105 err: false, 106 }, 107 {name: "rowmat-3tensor", 108 a: tensor.New(tensor.WithShape(2, 2, 1), tensor.WithBacking([]float64{100, 200, 300, 400})), 109 b: tensor.New(tensor.WithShape(2, 2, 2), tensor.WithBacking([]float64{1, 2, 3, 4, 5, 6, 7, 8})), 110 left: []byte{2}, 111 right: nil, 112 ab: tensor.New(tensor.WithShape(2, 2, 2), tensor.WithBacking([]float64{101, 102, 203, 204, 305, 306, 407, 408})), 113 err: false, 114 }, 115 {name: "3tensor-rowmat", 116 a: tensor.New(tensor.WithShape(2, 2, 2), tensor.WithBacking([]float64{1, 2, 3, 4, 5, 6, 7, 8})), 117 b: tensor.New(tensor.WithShape(2, 2, 1), tensor.WithBacking([]float64{100, 200, 300, 400})), 118 left: nil, 119 right: []byte{2}, 120 ab: tensor.New(tensor.WithShape(2, 2, 2), tensor.WithBacking([]float64{101, 102, 203, 204, 305, 306, 407, 408})), 121 err: false, 122 }, 123 {name: "vec-3tensor", 124 a: tensor.New(tensor.WithShape(2), tensor.WithBacking([]float64{100, 200})), 125 b: tensor.New(tensor.WithShape(2, 2, 2), tensor.WithBacking([]float64{1, 2, 3, 4, 5, 6, 7, 8})), 126 left: []byte{1, 2}, 127 right: nil, 128 ab: tensor.New(tensor.WithShape(2, 2, 2), tensor.WithBacking([]float64{101, 202, 103, 204, 105, 206, 107, 208})), 129 err: false, 130 }, 131 */ 132 // TODO (these would give coverage to all broadcast applications) 133 // vec-3tensor 134 // 3tensor-vec 135 // mat-3tensor 136 // 3-tensor-mat 137 // and their corresponding errors 138 139 // WILL ERR 140 // {name: "vec-mat- wrong left pattern axis", 141 // a: tensor.New(tensor.WithShape(2), tensor.WithBacking([]float64{100, 200})), 142 // b: tensor.New(tensor.WithShape(2, 2), tensor.WithBacking([]float64{1, 2, 3, 4})), 143 // left: []byte{0}, 144 // right: nil, 145 // ab: tensor.New(tensor.WithShape(2, 2), tensor.WithBacking([]float64{101, 102, 203, 204})), 146 // err: true, 147 // }, 148 {name: "rowvec-mat: wrong axis", 149 a: tensor.New(tensor.WithShape(2, 1), tensor.WithBacking([]float64{100, 200})), 150 b: tensor.New(tensor.WithShape(2, 2), tensor.WithBacking([]float64{1, 2, 3, 4})), 151 left: []byte{2}, 152 right: nil, 153 ab: tensor.New(tensor.WithShape(2, 2), tensor.WithBacking([]float64{101, 102, 203, 204})), 154 err: true, 155 }, 156 157 {name: "impossible mat-mat", 158 a: tensor.New(tensor.WithShape(2, 4), tensor.WithBacking([]float64{1, 2, 3, 4, 5, 6, 7, 8})), 159 b: tensor.New(tensor.WithShape(1, 2), tensor.WithBacking([]float64{100, 200})), 160 left: nil, 161 right: []byte{0, 1}, 162 ab: tensor.New(tensor.WithShape(2, 2), tensor.WithBacking([]float64{101, 102, 203, 204})), 163 err: true, 164 }, 165 } 166 167 func TestBroadcastAdd(t *testing.T) { 168 assert := assert.New(t) 169 for i, bat := range broadcastAddTests { 170 //if bat.name != "impossible mat-mat" { 171 // continue 172 // } 173 g := NewGraph() 174 a := NodeFromAny(g, bat.a, WithName("a")) 175 b := NodeFromAny(g, bat.b, WithName("b")) 176 c, err := BroadcastAdd(a, b, bat.left, bat.right) 177 if checkErr(t, bat.err, err, bat.name, i) { 178 continue 179 } 180 machine := NewTapeMachine(g) 181 182 if err = machine.RunAll(); err != nil { 183 t.Errorf("Test %v(%d): %v", bat.name, i, err) 184 } 185 assert.Equal(bat.ab.Data(), c.Value().Data(), "Test %v(%v)", bat.name, i) 186 machine.Close() 187 } 188 } 189 190 var broadcastMulTests = []broadcastOpTest{ 191 {name: "vec-mat", 192 a: tensor.New(tensor.WithShape(2), tensor.WithBacking([]float64{10, 20})), 193 b: tensor.New(tensor.WithShape(2, 2), tensor.WithBacking([]float64{1, 2, 3, 4})), 194 left: []byte{1}, 195 right: nil, 196 ab: tensor.New(tensor.WithShape(2, 2), tensor.WithBacking([]float64{10, 20, 60, 80})), 197 err: false, 198 }, 199 200 {name: "mat-vec", 201 a: tensor.New(tensor.WithShape(2, 2), tensor.WithBacking([]float64{1, 2, 3, 4})), 202 b: tensor.New(tensor.WithShape(2), tensor.WithBacking([]float64{10, 20})), 203 left: nil, 204 right: []byte{1}, 205 ab: tensor.New(tensor.WithShape(2, 2), tensor.WithBacking([]float64{10, 20, 60, 80})), 206 err: false, 207 }, 208 {name: "rowvec-mat", 209 a: tensor.New(tensor.WithShape(2, 1), tensor.WithBacking([]float64{10, 20})), 210 b: tensor.New(tensor.WithShape(2, 2), tensor.WithBacking([]float64{1, 2, 3, 4})), 211 left: []byte{1}, 212 right: nil, 213 ab: tensor.New(tensor.WithShape(2, 2), tensor.WithBacking([]float64{10, 20, 60, 80})), 214 err: false, 215 }, 216 {name: "mat-rowvec", 217 a: tensor.New(tensor.WithShape(2, 2), tensor.WithBacking([]float64{1, 2, 3, 4})), 218 b: tensor.New(tensor.WithShape(2, 1), tensor.WithBacking([]float64{10, 20})), 219 left: nil, 220 right: []byte{1}, 221 ab: tensor.New(tensor.WithShape(2, 2), tensor.WithBacking([]float64{10, 20, 60, 80})), 222 err: false, 223 }, 224 {name: "colvec-mat", 225 a: tensor.New(tensor.WithShape(1, 2), tensor.WithBacking([]float64{10, 20})), 226 b: tensor.New(tensor.WithShape(2, 2), tensor.WithBacking([]float64{1, 2, 3, 4})), 227 left: []byte{0}, 228 right: nil, 229 ab: tensor.New(tensor.WithShape(2, 2), tensor.WithBacking([]float64{10, 40, 30, 80})), 230 err: false, 231 }, 232 {name: "mat-colvec", 233 a: tensor.New(tensor.WithShape(2, 2), tensor.WithBacking([]float64{1, 2, 3, 4})), 234 b: tensor.New(tensor.WithShape(1, 2), tensor.WithBacking([]float64{10, 20})), 235 left: nil, 236 right: []byte{0}, 237 ab: tensor.New(tensor.WithShape(2, 2), tensor.WithBacking([]float64{10, 40, 30, 80})), 238 err: false, 239 }, 240 241 // TODO (these would give coverage to all broadcast applications) 242 // vec-3tensor 243 // 3tensor-vec 244 // mat-3tensor 245 // 3-tensor-mat 246 // and their corresponding errors 247 248 // WILL ERR 249 // {name: "vec-mat- wrong left pattern axis", 250 // a: tensor.New(tensor.WithShape(2), tensor.WithBacking([]float64{10, 20})), 251 // b: tensor.New(tensor.WithShape(2, 2), tensor.WithBacking([]float64{1, 2, 3, 4})), 252 // left: []byte{0}, 253 // right: nil, 254 // err: true, 255 // }, 256 {name: "rowvec-mat: wrong axis", 257 a: tensor.New(tensor.WithShape(2, 1), tensor.WithBacking([]float64{10, 20})), 258 b: tensor.New(tensor.WithShape(2, 2), tensor.WithBacking([]float64{1, 2, 3, 4})), 259 left: []byte{2}, 260 right: nil, 261 err: true, 262 }, 263 264 {name: "impossible mat-mat", 265 a: tensor.New(tensor.WithShape(2, 4), tensor.WithBacking([]float64{1, 2, 3, 4, 5, 6, 7, 8})), 266 b: tensor.New(tensor.WithShape(1, 2), tensor.WithBacking([]float64{10, 20})), 267 left: nil, 268 right: []byte{0, 1}, 269 err: true, 270 }, 271 } 272 273 func TestBroadcastHadamardProd(t *testing.T) { 274 assert := assert.New(t) 275 for i, bat := range broadcastMulTests { 276 g := NewGraph() 277 a := NodeFromAny(g, bat.a, WithName("a")) 278 b := NodeFromAny(g, bat.b, WithName("b")) 279 c, err := BroadcastHadamardProd(a, b, bat.left, bat.right) 280 if checkErr(t, bat.err, err, bat.name, i) { 281 continue 282 } 283 machine := NewTapeMachine(g) 284 285 if err = machine.RunAll(); err != nil { 286 t.Errorf("Test %v(%d): %v", bat.name, i, err) 287 } 288 assert.Equal(bat.ab.Data(), c.Value().Data(), "Test %v(%v)", bat.name, i) 289 machine.Close() 290 } 291 } 292 293 // Broadcasts with nils in both left and right patterns will yield the original inputs. 294 func ExampleBroadcast_nils() { 295 g := NewGraph() 296 x := NewMatrix(g, Float64, WithShape(2, 3), WithName("x")) 297 y := NewMatrix(g, Float64, WithShape(2, 3), WithName("y")) 298 a, b, err := Broadcast(x, y, NewBroadcastPattern(nil, nil)) 299 if err != nil { 300 fmt.Printf("Error: %v\n", err) 301 return 302 } 303 fmt.Printf("a == x %t; b == y %t", a == x, b == y) 304 // Output: 305 // a == x true; b == y true 306 }