gorgonia.org/gorgonia@v0.9.17/example_broadcast_op_test.go (about) 1 package gorgonia 2 3 import ( 4 "fmt" 5 "log" 6 7 // . "gorgonia.org/gorgonia" 8 "gorgonia.org/tensor" 9 ) 10 11 // By default, Gorgonia operations do not perform broadcasting. 12 // To do broadcasting, you would need to manually specify the operation 13 func ExampleBroadcastAdd() { 14 g := NewGraph() 15 a := NewVector(g, tensor.Float64, WithShape(2), WithName("a"), WithValue(tensor.New(tensor.WithBacking([]float64{100, 100})))) 16 b := NewMatrix(g, tensor.Float64, WithShape(2, 2), WithName("b"), WithValue(tensor.New(tensor.WithShape(2, 2), tensor.WithBacking([]float64{1, 1, 2, 2})))) 17 18 fmt.Printf("a = %v\nb =\n%v\n", a.Value(), b.Value()) 19 20 _, err := Add(a, b) 21 fmt.Printf("a + b yields an error: %v\n\n", err) 22 23 // Note here the broadcasting of a is on the first axis, not the zeroth axis. Simply put, assume that it's already a (2,1) matrix. 24 ab, err := BroadcastAdd(a, b, []byte{1}, nil) 25 if err != nil { 26 fmt.Printf("uh oh, something went wrong: %v\n", err) 27 } 28 29 ba, err := BroadcastAdd(b, a, nil, []byte{1}) 30 if err != nil { 31 fmt.Printf("uh oh, something went wrong: %v\n", err) 32 } 33 34 // Now, let's run the program 35 machine := NewTapeMachine(g) 36 defer machine.Close() 37 if err = machine.RunAll(); err != nil { 38 log.Fatal(err) 39 } 40 41 fmt.Printf("a +⃗ b =\n%v\n", ab.Value()) 42 fmt.Printf("b +⃗ a =\n%v", ba.Value()) 43 44 // Output: 45 // a = [100 100] 46 // b = 47 // ⎡1 1⎤ 48 // ⎣2 2⎦ 49 // 50 // a + b yields an error: Failed to infer shape. Op: + false: Shape mismatch: (2) and (2, 2) 51 // 52 // a +⃗ b = 53 // ⎡101 101⎤ 54 // ⎣102 102⎦ 55 // 56 // b +⃗ a = 57 // ⎡101 101⎤ 58 // ⎣102 102⎦ 59 60 } 61 62 func ExampleBroadcastGte_creatingTriangleMatrices() { 63 // Broadcasting is useful. We can create triangular dense matrices simply 64 65 g := NewGraph() 66 a := NewMatrix(g, tensor.Float64, WithShape(3, 1), WithName("a"), WithInit(RangedFrom(0))) 67 b := NewMatrix(g, tensor.Float64, WithShape(1, 4), WithName("b"), WithInit(RangedFrom(0))) 68 tl, err := BroadcastGte(a, b, true, []byte{1}, []byte{0}) 69 if err != nil { 70 log.Fatalf("uh oh. Something went wrong %v", err) 71 } 72 73 tu, err := BroadcastLt(a, b, true, []byte{1}, []byte{0}) 74 if err != nil { 75 log.Fatalf("uh oh. Something went wrong %v", err) 76 } 77 78 m := NewTapeMachine(g) 79 80 // PEDAGOGICAL: 81 // Uncomment the following code if you want to see what happens behind the scenes 82 // m.Close() 83 // logger := log.New(os.Stderr, "",0) 84 // m = NewTapeMachine(g, WithLogger(logger), WithWatchlist()) 85 86 defer m.Close() 87 if err = m.RunAll(); err != nil { 88 log.Fatal(err) 89 } 90 91 fmt.Printf("triangular, lower:\n%v\n", tl.Value()) 92 fmt.Printf("triangular, upper:\n%v\n", tu.Value()) 93 94 // Output: 95 // triangular, lower: 96 // ⎡1 0 0 0⎤ 97 // ⎢1 1 0 0⎥ 98 // ⎣1 1 1 0⎦ 99 // 100 // triangular, upper: 101 // ⎡0 1 1 1⎤ 102 // ⎢0 0 1 1⎥ 103 // ⎣0 0 0 1⎦ 104 105 }