gorgonia.org/gorgonia@v0.9.17/operations_broadcast_test.go (about)

     1  package gorgonia
     2  
     3  import (
     4  	"fmt"
     5  	"testing"
     6  
     7  	"github.com/stretchr/testify/assert"
     8  	"gorgonia.org/tensor"
     9  )
    10  
    11  type broadcastOpTest struct {
    12  	name string
    13  	a    Value
    14  	b    Value
    15  
    16  	// broadcast axes
    17  	left, right []byte
    18  
    19  	// results
    20  	ab  Value
    21  	err bool
    22  }
    23  
    24  var broadcastAddTests = []broadcastOpTest{
    25  	{name: "vec-mat",
    26  		a:     tensor.New(tensor.WithShape(2), tensor.WithBacking([]float64{100, 200})),
    27  		b:     tensor.New(tensor.WithShape(2, 2), tensor.WithBacking([]float64{1, 2, 3, 4})),
    28  		left:  []byte{1},
    29  		right: nil,
    30  		ab:    tensor.New(tensor.WithShape(2, 2), tensor.WithBacking([]float64{101, 102, 203, 204})),
    31  		err:   false,
    32  	},
    33  
    34  	{name: "mat-vec",
    35  		a:     tensor.New(tensor.WithShape(2, 2), tensor.WithBacking([]float64{1, 2, 3, 4})),
    36  		b:     tensor.New(tensor.WithShape(2), tensor.WithBacking([]float64{100, 200})),
    37  		left:  nil,
    38  		right: []byte{1},
    39  		ab:    tensor.New(tensor.WithShape(2, 2), tensor.WithBacking([]float64{101, 102, 203, 204})),
    40  		err:   false,
    41  	},
    42  	{name: "rowvec-mat",
    43  		a:     tensor.New(tensor.WithShape(2, 1), tensor.WithBacking([]float64{100, 200})),
    44  		b:     tensor.New(tensor.WithShape(2, 2), tensor.WithBacking([]float64{1, 2, 3, 4})),
    45  		left:  []byte{1},
    46  		right: nil,
    47  		ab:    tensor.New(tensor.WithShape(2, 2), tensor.WithBacking([]float64{101, 102, 203, 204})),
    48  		err:   false,
    49  	},
    50  	{name: "mat-rowvec",
    51  		a:     tensor.New(tensor.WithShape(2, 2), tensor.WithBacking([]float64{1, 2, 3, 4})),
    52  		b:     tensor.New(tensor.WithShape(2, 1), tensor.WithBacking([]float64{100, 200})),
    53  		left:  nil,
    54  		right: []byte{1},
    55  		ab:    tensor.New(tensor.WithShape(2, 2), tensor.WithBacking([]float64{101, 102, 203, 204})),
    56  		err:   false,
    57  	},
    58  	{name: "colvec-mat",
    59  		a:     tensor.New(tensor.WithShape(1, 2), tensor.WithBacking([]float64{100, 200})),
    60  		b:     tensor.New(tensor.WithShape(2, 2), tensor.WithBacking([]float64{1, 2, 3, 4})),
    61  		left:  []byte{0},
    62  		right: nil,
    63  		ab:    tensor.New(tensor.WithShape(2, 2), tensor.WithBacking([]float64{101, 202, 103, 204})),
    64  		err:   false,
    65  	},
    66  	{name: "mat-colvec",
    67  		a:     tensor.New(tensor.WithShape(2, 2), tensor.WithBacking([]float64{1, 2, 3, 4})),
    68  		b:     tensor.New(tensor.WithShape(1, 2), tensor.WithBacking([]float64{100, 200})),
    69  		left:  nil,
    70  		right: []byte{0},
    71  		ab:    tensor.New(tensor.WithShape(2, 2), tensor.WithBacking([]float64{101, 202, 103, 204})),
    72  		err:   false,
    73  	},
    74  	/* // SKIPPED UNTIL WE CAN FIX BROADCAST SEMANTICS
    75  	{name: "3col-3tensor",
    76  		a:     tensor.New(tensor.WithShape(1, 1, 2), tensor.WithBacking([]float64{100, 200})),
    77  		b:     tensor.New(tensor.WithShape(2, 2, 2), tensor.WithBacking([]float64{1, 2, 3, 4, 5, 6, 7, 8})),
    78  		left:  []byte{0, 1},
    79  		right: nil,
    80  		ab:    tensor.New(tensor.WithShape(2, 2, 2), tensor.WithBacking([]float64{101, 202, 103, 204, 105, 206, 107, 208})),
    81  		err:   false,
    82  	},
    83  	{name: "3vec-3tensor",
    84  		a:     tensor.New(tensor.WithShape(2, 1, 1), tensor.WithBacking([]float64{100, 200})),
    85  		b:     tensor.New(tensor.WithShape(2, 2, 2), tensor.WithBacking([]float64{1, 2, 3, 4, 5, 6, 7, 8})),
    86  		left:  []byte{1, 2},
    87  		right: nil,
    88  		ab:    tensor.New(tensor.WithShape(2, 2, 2), tensor.WithBacking([]float64{101, 102, 103, 104, 205, 206, 207, 208})),
    89  		err:   false,
    90  	},
    91  	{name: "colmat-3tensor",
    92  		a:     tensor.New(tensor.WithShape(1, 2, 2), tensor.WithBacking([]float64{100, 200, 300, 400})),
    93  		b:     tensor.New(tensor.WithShape(2, 2, 2), tensor.WithBacking([]float64{1, 2, 3, 4, 5, 6, 7, 8})),
    94  		left:  []byte{0},
    95  		right: nil,
    96  		ab:    tensor.New(tensor.WithShape(2, 2, 2), tensor.WithBacking([]float64{101, 202, 303, 404, 105, 206, 307, 408})),
    97  		err:   false,
    98  	},
    99  	{name: "3tensor-colmat",
   100  		a:     tensor.New(tensor.WithShape(2, 2, 2), tensor.WithBacking([]float64{1, 2, 3, 4, 5, 6, 7, 8})),
   101  		b:     tensor.New(tensor.WithShape(1, 2, 2), tensor.WithBacking([]float64{100, 200, 300, 400})),
   102  		left:  nil,
   103  		right: []byte{0},
   104  		ab:    tensor.New(tensor.WithShape(2, 2, 2), tensor.WithBacking([]float64{101, 202, 303, 404, 105, 206, 307, 408})),
   105  		err:   false,
   106  	},
   107  	{name: "rowmat-3tensor",
   108  		a:     tensor.New(tensor.WithShape(2, 2, 1), tensor.WithBacking([]float64{100, 200, 300, 400})),
   109  		b:     tensor.New(tensor.WithShape(2, 2, 2), tensor.WithBacking([]float64{1, 2, 3, 4, 5, 6, 7, 8})),
   110  		left:  []byte{2},
   111  		right: nil,
   112  		ab:    tensor.New(tensor.WithShape(2, 2, 2), tensor.WithBacking([]float64{101, 102, 203, 204, 305, 306, 407, 408})),
   113  		err:   false,
   114  	},
   115  	{name: "3tensor-rowmat",
   116  		a:     tensor.New(tensor.WithShape(2, 2, 2), tensor.WithBacking([]float64{1, 2, 3, 4, 5, 6, 7, 8})),
   117  		b:     tensor.New(tensor.WithShape(2, 2, 1), tensor.WithBacking([]float64{100, 200, 300, 400})),
   118  		left:  nil,
   119  		right: []byte{2},
   120  		ab:    tensor.New(tensor.WithShape(2, 2, 2), tensor.WithBacking([]float64{101, 102, 203, 204, 305, 306, 407, 408})),
   121  		err:   false,
   122  	},
   123  	{name: "vec-3tensor",
   124  		a:     tensor.New(tensor.WithShape(2), tensor.WithBacking([]float64{100, 200})),
   125  		b:     tensor.New(tensor.WithShape(2, 2, 2), tensor.WithBacking([]float64{1, 2, 3, 4, 5, 6, 7, 8})),
   126  		left:  []byte{1, 2},
   127  		right: nil,
   128  		ab:    tensor.New(tensor.WithShape(2, 2, 2), tensor.WithBacking([]float64{101, 202, 103, 204, 105, 206, 107, 208})),
   129  		err:   false,
   130  	},
   131  	*/
   132  	// TODO (these would give coverage to all broadcast applications)
   133  	// 	vec-3tensor
   134  	// 	3tensor-vec
   135  	// 	mat-3tensor
   136  	// 	3-tensor-mat
   137  	// and their corresponding errors
   138  
   139  	// WILL ERR
   140  	// {name: "vec-mat- wrong left pattern axis",
   141  	// 	a:     tensor.New(tensor.WithShape(2), tensor.WithBacking([]float64{100, 200})),
   142  	// 	b:     tensor.New(tensor.WithShape(2, 2), tensor.WithBacking([]float64{1, 2, 3, 4})),
   143  	// 	left:  []byte{0},
   144  	// 	right: nil,
   145  	// 	ab:    tensor.New(tensor.WithShape(2, 2), tensor.WithBacking([]float64{101, 102, 203, 204})),
   146  	// 	err:   true,
   147  	// },
   148  	{name: "rowvec-mat: wrong axis",
   149  		a:     tensor.New(tensor.WithShape(2, 1), tensor.WithBacking([]float64{100, 200})),
   150  		b:     tensor.New(tensor.WithShape(2, 2), tensor.WithBacking([]float64{1, 2, 3, 4})),
   151  		left:  []byte{2},
   152  		right: nil,
   153  		ab:    tensor.New(tensor.WithShape(2, 2), tensor.WithBacking([]float64{101, 102, 203, 204})),
   154  		err:   true,
   155  	},
   156  
   157  	{name: "impossible mat-mat",
   158  		a:     tensor.New(tensor.WithShape(2, 4), tensor.WithBacking([]float64{1, 2, 3, 4, 5, 6, 7, 8})),
   159  		b:     tensor.New(tensor.WithShape(1, 2), tensor.WithBacking([]float64{100, 200})),
   160  		left:  nil,
   161  		right: []byte{0, 1},
   162  		ab:    tensor.New(tensor.WithShape(2, 2), tensor.WithBacking([]float64{101, 102, 203, 204})),
   163  		err:   true,
   164  	},
   165  }
   166  
   167  func TestBroadcastAdd(t *testing.T) {
   168  	assert := assert.New(t)
   169  	for i, bat := range broadcastAddTests {
   170  		//if bat.name != "impossible mat-mat" {
   171  		//		continue
   172  		//	}
   173  		g := NewGraph()
   174  		a := NodeFromAny(g, bat.a, WithName("a"))
   175  		b := NodeFromAny(g, bat.b, WithName("b"))
   176  		c, err := BroadcastAdd(a, b, bat.left, bat.right)
   177  		if checkErr(t, bat.err, err, bat.name, i) {
   178  			continue
   179  		}
   180  		machine := NewTapeMachine(g)
   181  
   182  		if err = machine.RunAll(); err != nil {
   183  			t.Errorf("Test %v(%d): %v", bat.name, i, err)
   184  		}
   185  		assert.Equal(bat.ab.Data(), c.Value().Data(), "Test %v(%v)", bat.name, i)
   186  		machine.Close()
   187  	}
   188  }
   189  
   190  var broadcastMulTests = []broadcastOpTest{
   191  	{name: "vec-mat",
   192  		a:     tensor.New(tensor.WithShape(2), tensor.WithBacking([]float64{10, 20})),
   193  		b:     tensor.New(tensor.WithShape(2, 2), tensor.WithBacking([]float64{1, 2, 3, 4})),
   194  		left:  []byte{1},
   195  		right: nil,
   196  		ab:    tensor.New(tensor.WithShape(2, 2), tensor.WithBacking([]float64{10, 20, 60, 80})),
   197  		err:   false,
   198  	},
   199  
   200  	{name: "mat-vec",
   201  		a:     tensor.New(tensor.WithShape(2, 2), tensor.WithBacking([]float64{1, 2, 3, 4})),
   202  		b:     tensor.New(tensor.WithShape(2), tensor.WithBacking([]float64{10, 20})),
   203  		left:  nil,
   204  		right: []byte{1},
   205  		ab:    tensor.New(tensor.WithShape(2, 2), tensor.WithBacking([]float64{10, 20, 60, 80})),
   206  		err:   false,
   207  	},
   208  	{name: "rowvec-mat",
   209  		a:     tensor.New(tensor.WithShape(2, 1), tensor.WithBacking([]float64{10, 20})),
   210  		b:     tensor.New(tensor.WithShape(2, 2), tensor.WithBacking([]float64{1, 2, 3, 4})),
   211  		left:  []byte{1},
   212  		right: nil,
   213  		ab:    tensor.New(tensor.WithShape(2, 2), tensor.WithBacking([]float64{10, 20, 60, 80})),
   214  		err:   false,
   215  	},
   216  	{name: "mat-rowvec",
   217  		a:     tensor.New(tensor.WithShape(2, 2), tensor.WithBacking([]float64{1, 2, 3, 4})),
   218  		b:     tensor.New(tensor.WithShape(2, 1), tensor.WithBacking([]float64{10, 20})),
   219  		left:  nil,
   220  		right: []byte{1},
   221  		ab:    tensor.New(tensor.WithShape(2, 2), tensor.WithBacking([]float64{10, 20, 60, 80})),
   222  		err:   false,
   223  	},
   224  	{name: "colvec-mat",
   225  		a:     tensor.New(tensor.WithShape(1, 2), tensor.WithBacking([]float64{10, 20})),
   226  		b:     tensor.New(tensor.WithShape(2, 2), tensor.WithBacking([]float64{1, 2, 3, 4})),
   227  		left:  []byte{0},
   228  		right: nil,
   229  		ab:    tensor.New(tensor.WithShape(2, 2), tensor.WithBacking([]float64{10, 40, 30, 80})),
   230  		err:   false,
   231  	},
   232  	{name: "mat-colvec",
   233  		a:     tensor.New(tensor.WithShape(2, 2), tensor.WithBacking([]float64{1, 2, 3, 4})),
   234  		b:     tensor.New(tensor.WithShape(1, 2), tensor.WithBacking([]float64{10, 20})),
   235  		left:  nil,
   236  		right: []byte{0},
   237  		ab:    tensor.New(tensor.WithShape(2, 2), tensor.WithBacking([]float64{10, 40, 30, 80})),
   238  		err:   false,
   239  	},
   240  
   241  	// TODO (these would give coverage to all broadcast applications)
   242  	// 	vec-3tensor
   243  	// 	3tensor-vec
   244  	// 	mat-3tensor
   245  	// 	3-tensor-mat
   246  	// and their corresponding errors
   247  
   248  	// WILL ERR
   249  	// {name: "vec-mat- wrong left pattern axis",
   250  	// 	a:     tensor.New(tensor.WithShape(2), tensor.WithBacking([]float64{10, 20})),
   251  	// 	b:     tensor.New(tensor.WithShape(2, 2), tensor.WithBacking([]float64{1, 2, 3, 4})),
   252  	// 	left:  []byte{0},
   253  	// 	right: nil,
   254  	// 	err:   true,
   255  	// },
   256  	{name: "rowvec-mat: wrong axis",
   257  		a:     tensor.New(tensor.WithShape(2, 1), tensor.WithBacking([]float64{10, 20})),
   258  		b:     tensor.New(tensor.WithShape(2, 2), tensor.WithBacking([]float64{1, 2, 3, 4})),
   259  		left:  []byte{2},
   260  		right: nil,
   261  		err:   true,
   262  	},
   263  
   264  	{name: "impossible mat-mat",
   265  		a:     tensor.New(tensor.WithShape(2, 4), tensor.WithBacking([]float64{1, 2, 3, 4, 5, 6, 7, 8})),
   266  		b:     tensor.New(tensor.WithShape(1, 2), tensor.WithBacking([]float64{10, 20})),
   267  		left:  nil,
   268  		right: []byte{0, 1},
   269  		err:   true,
   270  	},
   271  }
   272  
   273  func TestBroadcastHadamardProd(t *testing.T) {
   274  	assert := assert.New(t)
   275  	for i, bat := range broadcastMulTests {
   276  		g := NewGraph()
   277  		a := NodeFromAny(g, bat.a, WithName("a"))
   278  		b := NodeFromAny(g, bat.b, WithName("b"))
   279  		c, err := BroadcastHadamardProd(a, b, bat.left, bat.right)
   280  		if checkErr(t, bat.err, err, bat.name, i) {
   281  			continue
   282  		}
   283  		machine := NewTapeMachine(g)
   284  
   285  		if err = machine.RunAll(); err != nil {
   286  			t.Errorf("Test %v(%d): %v", bat.name, i, err)
   287  		}
   288  		assert.Equal(bat.ab.Data(), c.Value().Data(), "Test %v(%v)", bat.name, i)
   289  		machine.Close()
   290  	}
   291  }
   292  
   293  // Broadcasts with nils in both left and right patterns will yield the original inputs.
   294  func ExampleBroadcast_nils() {
   295  	g := NewGraph()
   296  	x := NewMatrix(g, Float64, WithShape(2, 3), WithName("x"))
   297  	y := NewMatrix(g, Float64, WithShape(2, 3), WithName("y"))
   298  	a, b, err := Broadcast(x, y, NewBroadcastPattern(nil, nil))
   299  	if err != nil {
   300  		fmt.Printf("Error: %v\n", err)
   301  		return
   302  	}
   303  	fmt.Printf("a == x %t; b == y %t", a == x, b == y)
   304  	//  Output:
   305  	// a == x true; b == y true
   306  }