gorgonia.org/gorgonia@v0.9.17/example_monad_test.go

package gorgonia_test

import (
	"fmt"

	. "gorgonia.org/gorgonia"
)

// This example showcases the reasons for the more confusing functions.
func Example_monad_raison_detre() {
	// The main reason for the following functions is to make it easier to create APIs.
	// Gorgonia's APIs are very explicit, hence not very user friendly.

	const (
		n        = 32
		features = 784
		size     = 100
	)

	// The following is an example of how to set up a neural network.

	// First, we set up the components.
	g := NewGraph()
	w1 := NewMatrix(g, Float32, WithShape(features, size), WithName("w"), WithInit(GlorotU(1)))
	b1 := NewMatrix(g, Float32, WithShape(1, size), WithName("b"), WithInit(Zeroes()))
	x1 := NewMatrix(g, Float32, WithShape(n, features), WithName("x"))

	// Then we write the expression:
	var xw, xwb, act *Node
	var err error
	if xw, err = Mul(x1, w1); err != nil {
		fmt.Printf("Err while Mul: %v\n", err)
	}
	if xwb, err = BroadcastAdd(xw, b1, nil, []byte{0}); err != nil {
		fmt.Printf("Err while Add: %v\n", err)
	}
	if act, err = Tanh(xwb); err != nil {
		fmt.Printf("Err while Tanh: %v\n", err)
	}
	fmt.Printf("act is a %T\n", act)

	// The following is how to set up the exact same network.

	// First we set up our environment.
	//
	// These LiftXXX functions transform Gorgonia's default API into functions that return `Result`.
	var mul = Lift2(Mul)                   // Lift2 lifts a func(*Node, *Node) (*Node, error) into one that returns `Result`
	var tanh = Lift1(Tanh)                 // Lift1 lifts a func(*Node) (*Node, error) into one that returns `Result`
	var add = Lift2Broadcast(BroadcastAdd) // Lift2Broadcast lifts a func(*Node, *Node, []byte, []byte) (*Node, error) into one that returns `Result`

	// First we set up the components.
	h := NewGraph()
	w2 := NewMatrix(h, Float32, WithShape(features, size), WithName("w"), WithInit(GlorotU(1)))
	b2 := NewMatrix(h, Float32, WithShape(1, size), WithName("b"), WithInit(Zeroes()))
	x2 := NewMatrix(h, Float32, WithShape(n, features), WithName("x"))

	// Then we write the expression.
	act2 := tanh(add(mul(x2, w2), b2, nil, []byte{0}))
	fmt.Printf("act2 is a %T (note it's wrapped in the `Result` type)\n", act2)
	fmt.Println()
	// Both g and h describe the same graph, but the expression for act2 is easier to write.
	fmt.Printf("Both g and h are the same graph:\ng: %v\nh: %v\n", g.AllNodes(), h.AllNodes())

	// Output:
	// act is a *gorgonia.Node
	// act2 is a *gorgonia.Node (note it's wrapped in the `Result` type)
	//
	// Both g and h are the same graph:
	// g: [w, b, x, A × B(%2, %0), Reshape(1, 100)(%1), SizeOf=32(%3), Repeat0(%4, %5), + false(%3, %6), tanh(%7)]
	// h: [w, b, x, A × B(%2, %0), Reshape(1, 100)(%1), SizeOf=32(%3), Repeat0(%4, %5), + false(%3, %6), tanh(%7)]
}

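// The sketch below is not part of the original example; it is a hedged
// illustration (the layer sizes and names are made up) of why the lifted API
// pays off as networks grow. The original example already passes the Result of
// mul(x2, w2) straight into add, so the output of one layer can feed the next
// with no intermediate `if err != nil` blocks, and a single CheckOne at the
// end surfaces any error raised along the way.
func twoLayerSketch() (*Node, error) {
	var (
		mul  = Lift2(Mul)
		tanh = Lift1(Tanh)
		add  = Lift2Broadcast(BroadcastAdd)
	)

	g := NewGraph()
	x := NewMatrix(g, Float32, WithShape(32, 784), WithName("x"))
	w1 := NewMatrix(g, Float32, WithShape(784, 100), WithName("w1"), WithInit(GlorotU(1)))
	b1 := NewMatrix(g, Float32, WithShape(1, 100), WithName("b1"), WithInit(Zeroes()))
	w2 := NewMatrix(g, Float32, WithShape(100, 10), WithName("w2"), WithInit(GlorotU(1)))
	b2 := NewMatrix(g, Float32, WithShape(1, 10), WithName("b2"), WithInit(Zeroes()))

	// Layer 1 feeds directly into layer 2: the Result returned by tanh is
	// passed straight into the next mul.
	h1 := tanh(add(mul(x, w1), b1, nil, []byte{0}))
	out := tanh(add(mul(h1, w2), b2, nil, []byte{0}))

	// One error check covers the whole expression.
	if err := CheckOne(out); err != nil {
		return nil, err
	}
	return out.Node(), nil
}
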
// This example showcases dealing with errors. This is part 2 of the raison
// d'être of the more complicated functions - dealing with errors.
func Example_monad_raison_detre_errors() {
	// Observe that in a similar example, errors are handled manually in the original case,
	// but automatically in the second case.
	const (
		n        = 32
		features = 784
		size     = 100
	)

	// The following is an example of how to set up a neural network.

	// First, we set up the components.
	g := NewGraph()
	w1 := NewMatrix(g, Float32, WithShape(features, size), WithName("w"), WithInit(GlorotU(1)))
	b1 := NewMatrix(g, Float32, WithShape(1, size), WithName("b"), WithInit(Zeroes()))
	x1 := NewMatrix(g, Float32, WithShape(n, features), WithName("x"))

	// Then we write the expression:
	var xw, xwb, act *Node
	var err error
	if xw, err = Mul(x1, w1); err != nil {
		fmt.Printf("Err while Mul: %v\n", err)
	}
	// We introduce an error here - it should be []byte{0}.
	if xwb, err = BroadcastAdd(xw, b1, nil, []byte{1}); err != nil {
		fmt.Printf("Err while Add: %v\n", err)
		goto case2
	}
	if act, err = Tanh(xwb); err != nil {
		fmt.Printf("Err while Tanh: %v\n", err)
	}
	_ = act // will never happen

case2:

	// The following is how to set up the exact same network.

	// First we set up our environment.
	//
	// Remember: these lifted functions no longer return (*Node, error). Instead they return `Result`.
	var mul = Lift2(Mul)
	var tanh = Lift1(Tanh)
	var add = Lift2Broadcast(BroadcastAdd)

	// First we set up the components.
	h := NewGraph()
	w2 := NewMatrix(h, Float32, WithShape(features, size), WithName("w"), WithInit(GlorotU(1)))
	b2 := NewMatrix(h, Float32, WithShape(1, size), WithName("b"), WithInit(Zeroes()))
	x2 := NewMatrix(h, Float32, WithShape(n, features), WithName("x"))

	// Then we write the expression (with the same deliberate error: []byte{1}).
	act2 := tanh(add(mul(x2, w2), b2, nil, []byte{1}))

	// REMEMBER: act2 is not a *Node! It is a Result.
	fmt.Printf("act2: %v\n", act2)

	// To extract the error, use CheckOne.
	fmt.Printf("error: %v\n", CheckOne(act2))

	// If you extract the *Node from an errored Result, you get nil.
	fmt.Printf("Node: %v\n", act2.Node())

	// Output:
	// Err while Add: Failed to infer shape. Op: + false: Shape mismatch: (32, 100) and (1, 10000)
	// act2: Failed to infer shape. Op: + false: Shape mismatch: (32, 100) and (1, 10000)
	// error: Failed to infer shape. Op: + false: Shape mismatch: (32, 100) and (1, 10000)
	// Node: <nil>
}
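
// The helper below is not part of the original example; it is a hedged sketch
// of a pattern that follows from the behaviour shown above: build the whole
// expression with the lifted API, then convert the final Result back into the
// idiomatic (*Node, error) pair at the function boundary, so callers never
// have to deal with the Result type themselves.
func buildNetworkSketch(x, w, b *Node) (*Node, error) {
	var (
		mul  = Lift2(Mul)
		tanh = Lift1(Tanh)
		add  = Lift2Broadcast(BroadcastAdd)
	)

	act := tanh(add(mul(x, w), b, nil, []byte{0}))
	if err := CheckOne(act); err != nil {
		// Any error raised by Mul, BroadcastAdd or Tanh ends up here.
		return nil, err
	}
	return act.Node(), nil
}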