gorgonia.org/gorgonia@v0.9.17/broadcast_test.go (about)

     1  package gorgonia
     2  
     3  import (
     4  	"io/ioutil"
     5  	"testing"
     6  
     7  	"github.com/stretchr/testify/assert"
     8  	"gorgonia.org/tensor"
     9  )
    10  
    11  func TestBroadcastPattern(t *testing.T) {
    12  	assert := assert.New(t)
    13  	var bcpat BroadcastPattern
    14  
    15  	// make sure that the basics work
    16  	bcpat = NewBroadcastPattern(nil, []byte{1})
    17  	assert.Equal(BroadcastPattern(0x02), bcpat)
    18  
    19  	bcpat = NewBroadcastPattern(nil, []byte{0})
    20  	assert.Equal(BroadcastPattern(0x01), bcpat)
    21  
    22  	bcpat = NewBroadcastPattern([]byte{1, 0}, nil)
    23  	assert.Equal(BroadcastPattern(0x30), bcpat)
    24  
    25  	bcpat = NewBroadcastPattern([]byte{0}, nil)
    26  	assert.Equal(BroadcastPattern(0x10), bcpat)
    27  
    28  	// checks
    29  	bcpat = NewBroadcastPattern(nil, []byte{1})
    30  	assert.True(bcpat.bc(false, 1))
    31  	assert.False(bcpat.bc(true, 1))
    32  
    33  	// ons
    34  	bcpat = NewBroadcastPattern(nil, []byte{1})
    35  	assert.Equal([]int{1}, bcpat.on()[1])
    36  	assert.Nil(bcpat.on()[0])
    37  
    38  	bcpat = NewBroadcastPattern([]byte{2, 1}, []byte{1})
    39  	assert.Equal([]int{1, 2}, bcpat.on()[0])
    40  	assert.Equal([]int{1}, bcpat.on()[1])
    41  }
    42  
    43  func TestBroadcast(t *testing.T) {
    44  	if CUDA {
    45  		t.SkipNow()
    46  	}
    47  
    48  	assert := assert.New(t)
    49  	var g *ExprGraph
    50  	var x, y, a, b, z *Node
    51  	var m *lispMachine
    52  	var err error
    53  
    54  	xT := tensor.New(tensor.WithShape(2, 3), tensor.WithBacking(tensor.Range(tensor.Float64, 0, 6)))
    55  	yT := tensor.New(tensor.WithShape(2), tensor.WithBacking([]float64{100, 200}))
    56  
    57  	g = NewGraph()
    58  	x = NewMatrix(g, Float64, WithShape(2, 3), WithValue(xT), WithName("x"))
    59  	y = NewVector(g, Float64, WithShape(2), WithValue(yT), WithName("y"))
    60  	if a, b, err = Broadcast(x, y, NewBroadcastPattern(nil, []byte{1})); err != nil {
    61  		ioutil.WriteFile("Broadcast.dot", []byte(g.ToDot()), 0644)
    62  		t.Fatal(err)
    63  	}
    64  	z, err = Add(a, b)
    65  	if err != nil {
    66  		t.Fatalf("Error: %v. a %v + b %v", err, a.Shape(), b.Shape())
    67  	}
    68  	if _, _, err = Broadcast(x, y, NewBroadcastPattern(nil, []byte{1})); err != nil {
    69  		ioutil.WriteFile("Broadcast.dot", []byte(g.ToDot()), 0644)
    70  		t.Fatal(err)
    71  	}
    72  
    73  	m = NewLispMachine(g, ExecuteFwdOnly())
    74  	defer m.Close()
    75  	if err := m.RunAll(); err != nil {
    76  		t.Fatal(err)
    77  	}
    78  	assert.Equal([]float64{100, 101, 102, 203, 204, 205}, extractF64s(z.Value()))
    79  
    80  	g = NewGraph()
    81  	x = NewMatrix(g, Float64, WithShape(2, 3), WithValue(xT), WithName("x"))
    82  	y = NewVector(g, Float64, WithShape(2), WithValue(yT), WithName("y"))
    83  	if a, b, err = Broadcast(y, x, NewBroadcastPattern([]byte{1}, nil)); err != nil {
    84  		ioutil.WriteFile("Broadcast.dot", []byte(g.ToDot()), 0644)
    85  		t.Fatalf("%+v", err)
    86  	}
    87  	// TODO: Check the error returned by Add?
    88  	z, _ = Add(a, b)
    89  	if _, _, err = Broadcast(x, y, NewBroadcastPattern(nil, []byte{1})); err != nil {
    90  		ioutil.WriteFile("Broadcast.dot", []byte(g.ToDot()), 0644)
    91  		t.Fatal(err)
    92  	}
    93  
    94  	m = NewLispMachine(g, ExecuteFwdOnly())
    95  	defer m.Close()
    96  	if err := m.RunAll(); err != nil {
    97  		t.Fatal(err)
    98  	}
    99  	assert.Equal([]float64{100, 101, 102, 203, 204, 205}, extractF64s(z.Value()))
   100  }