gorgonia.org/gorgonia@v0.9.17/example_broadcast_op_test.go (about)

     1  package gorgonia
     2  
     3  import (
     4  	"fmt"
     5  	"log"
     6  
     7  	// . "gorgonia.org/gorgonia"
     8  	"gorgonia.org/tensor"
     9  )
    10  
    11  // By default, Gorgonia operations do not perform broadcasting.
    12  // To do broadcasting, you would need to manually specify the operation
    13  func ExampleBroadcastAdd() {
    14  	g := NewGraph()
    15  	a := NewVector(g, tensor.Float64, WithShape(2), WithName("a"), WithValue(tensor.New(tensor.WithBacking([]float64{100, 100}))))
    16  	b := NewMatrix(g, tensor.Float64, WithShape(2, 2), WithName("b"), WithValue(tensor.New(tensor.WithShape(2, 2), tensor.WithBacking([]float64{1, 1, 2, 2}))))
    17  
    18  	fmt.Printf("a = %v\nb =\n%v\n", a.Value(), b.Value())
    19  
    20  	_, err := Add(a, b)
    21  	fmt.Printf("a + b yields an error: %v\n\n", err)
    22  
    23  	// Note here the broadcasting of a is on the first axis, not the zeroth axis. Simply put, assume that it's already a (2,1) matrix.
    24  	ab, err := BroadcastAdd(a, b, []byte{1}, nil)
    25  	if err != nil {
    26  		fmt.Printf("uh oh, something went wrong: %v\n", err)
    27  	}
    28  
    29  	ba, err := BroadcastAdd(b, a, nil, []byte{1})
    30  	if err != nil {
    31  		fmt.Printf("uh oh, something went wrong: %v\n", err)
    32  	}
    33  
    34  	// Now, let's run the program
    35  	machine := NewTapeMachine(g)
    36  	defer machine.Close()
    37  	if err = machine.RunAll(); err != nil {
    38  		log.Fatal(err)
    39  	}
    40  
    41  	fmt.Printf("a +⃗ b =\n%v\n", ab.Value())
    42  	fmt.Printf("b +⃗ a =\n%v", ba.Value())
    43  
    44  	// Output:
    45  	// a = [100  100]
    46  	// b =
    47  	// ⎡1  1⎤
    48  	// ⎣2  2⎦
    49  	//
    50  	// a + b yields an error: Failed to infer shape. Op: + false: Shape mismatch: (2) and (2, 2)
    51  	//
    52  	// a +⃗ b =
    53  	// ⎡101  101⎤
    54  	// ⎣102  102⎦
    55  	//
    56  	// b +⃗ a =
    57  	// ⎡101  101⎤
    58  	// ⎣102  102⎦
    59  
    60  }
    61  
    62  func ExampleBroadcastGte_creatingTriangleMatrices() {
    63  	// Broadcasting is useful. We can create triangular dense matrices simply
    64  
    65  	g := NewGraph()
    66  	a := NewMatrix(g, tensor.Float64, WithShape(3, 1), WithName("a"), WithInit(RangedFrom(0)))
    67  	b := NewMatrix(g, tensor.Float64, WithShape(1, 4), WithName("b"), WithInit(RangedFrom(0)))
    68  	tl, err := BroadcastGte(a, b, true, []byte{1}, []byte{0})
    69  	if err != nil {
    70  		log.Fatalf("uh oh. Something went wrong %v", err)
    71  	}
    72  
    73  	tu, err := BroadcastLt(a, b, true, []byte{1}, []byte{0})
    74  	if err != nil {
    75  		log.Fatalf("uh oh. Something went wrong %v", err)
    76  	}
    77  
    78  	m := NewTapeMachine(g)
    79  
    80  	// PEDAGOGICAL:
    81  	// Uncomment the following code if you want to see what happens behind the scenes
    82  	// m.Close()
    83  	// logger := log.New(os.Stderr, "",0)
    84  	// m = NewTapeMachine(g, WithLogger(logger), WithWatchlist())
    85  
    86  	defer m.Close()
    87  	if err = m.RunAll(); err != nil {
    88  		log.Fatal(err)
    89  	}
    90  
    91  	fmt.Printf("triangular, lower:\n%v\n", tl.Value())
    92  	fmt.Printf("triangular, upper:\n%v\n", tu.Value())
    93  
    94  	// Output:
    95  	// triangular, lower:
    96  	// ⎡1  0  0  0⎤
    97  	// ⎢1  1  0  0⎥
    98  	// ⎣1  1  1  0⎦
    99  	//
   100  	// triangular, upper:
   101  	// ⎡0  1  1  1⎤
   102  	// ⎢0  0  1  1⎥
   103  	// ⎣0  0  0  1⎦
   104  
   105  }