gorgonia.org/gorgonia@v0.9.17/broadcast_test.go (about) 1 package gorgonia 2 3 import ( 4 "io/ioutil" 5 "testing" 6 7 "github.com/stretchr/testify/assert" 8 "gorgonia.org/tensor" 9 ) 10 11 func TestBroadcastPattern(t *testing.T) { 12 assert := assert.New(t) 13 var bcpat BroadcastPattern 14 15 // make sure that the basics work 16 bcpat = NewBroadcastPattern(nil, []byte{1}) 17 assert.Equal(BroadcastPattern(0x02), bcpat) 18 19 bcpat = NewBroadcastPattern(nil, []byte{0}) 20 assert.Equal(BroadcastPattern(0x01), bcpat) 21 22 bcpat = NewBroadcastPattern([]byte{1, 0}, nil) 23 assert.Equal(BroadcastPattern(0x30), bcpat) 24 25 bcpat = NewBroadcastPattern([]byte{0}, nil) 26 assert.Equal(BroadcastPattern(0x10), bcpat) 27 28 // checks 29 bcpat = NewBroadcastPattern(nil, []byte{1}) 30 assert.True(bcpat.bc(false, 1)) 31 assert.False(bcpat.bc(true, 1)) 32 33 // ons 34 bcpat = NewBroadcastPattern(nil, []byte{1}) 35 assert.Equal([]int{1}, bcpat.on()[1]) 36 assert.Nil(bcpat.on()[0]) 37 38 bcpat = NewBroadcastPattern([]byte{2, 1}, []byte{1}) 39 assert.Equal([]int{1, 2}, bcpat.on()[0]) 40 assert.Equal([]int{1}, bcpat.on()[1]) 41 } 42 43 func TestBroadcast(t *testing.T) { 44 if CUDA { 45 t.SkipNow() 46 } 47 48 assert := assert.New(t) 49 var g *ExprGraph 50 var x, y, a, b, z *Node 51 var m *lispMachine 52 var err error 53 54 xT := tensor.New(tensor.WithShape(2, 3), tensor.WithBacking(tensor.Range(tensor.Float64, 0, 6))) 55 yT := tensor.New(tensor.WithShape(2), tensor.WithBacking([]float64{100, 200})) 56 57 g = NewGraph() 58 x = NewMatrix(g, Float64, WithShape(2, 3), WithValue(xT), WithName("x")) 59 y = NewVector(g, Float64, WithShape(2), WithValue(yT), WithName("y")) 60 if a, b, err = Broadcast(x, y, NewBroadcastPattern(nil, []byte{1})); err != nil { 61 ioutil.WriteFile("Broadcast.dot", []byte(g.ToDot()), 0644) 62 t.Fatal(err) 63 } 64 z, err = Add(a, b) 65 if err != nil { 66 t.Fatalf("Error: %v. a %v + b %v", err, a.Shape(), b.Shape()) 67 } 68 if _, _, err = Broadcast(x, y, NewBroadcastPattern(nil, []byte{1})); err != nil { 69 ioutil.WriteFile("Broadcast.dot", []byte(g.ToDot()), 0644) 70 t.Fatal(err) 71 } 72 73 m = NewLispMachine(g, ExecuteFwdOnly()) 74 defer m.Close() 75 if err := m.RunAll(); err != nil { 76 t.Fatal(err) 77 } 78 assert.Equal([]float64{100, 101, 102, 203, 204, 205}, extractF64s(z.Value())) 79 80 g = NewGraph() 81 x = NewMatrix(g, Float64, WithShape(2, 3), WithValue(xT), WithName("x")) 82 y = NewVector(g, Float64, WithShape(2), WithValue(yT), WithName("y")) 83 if a, b, err = Broadcast(y, x, NewBroadcastPattern([]byte{1}, nil)); err != nil { 84 ioutil.WriteFile("Broadcast.dot", []byte(g.ToDot()), 0644) 85 t.Fatalf("%+v", err) 86 } 87 // TODO: Check the error returned by Add? 88 z, _ = Add(a, b) 89 if _, _, err = Broadcast(x, y, NewBroadcastPattern(nil, []byte{1})); err != nil { 90 ioutil.WriteFile("Broadcast.dot", []byte(g.ToDot()), 0644) 91 t.Fatal(err) 92 } 93 94 m = NewLispMachine(g, ExecuteFwdOnly()) 95 defer m.Close() 96 if err := m.RunAll(); err != nil { 97 t.Fatal(err) 98 } 99 assert.Equal([]float64{100, 101, 102, 203, 204, 205}, extractF64s(z.Value())) 100 }