gorgonia.org/gorgonia@v0.9.17/example_monad_test.go (about)

     1  package gorgonia_test
     2  
     3  import (
     4  	"fmt"
     5  
     6  	. "gorgonia.org/gorgonia"
     7  )
     8  
     9  // This example showcases the reasons for the more confusing functions.
    10  func Example_monad_raison_detre() {
    11  	// The main reason for the following function is to make it easier to create APIs.
    12  	// Gorgonia;s APIs are very explicit hence not very user friendly.
    13  
    14  	const (
    15  		n        = 32
    16  		features = 784
    17  		size     = 100
    18  	)
    19  
    20  	// The following is an example of how to set up a neural network
    21  
    22  	// First, we set up the components
    23  	g := NewGraph()
    24  	w1 := NewMatrix(g, Float32, WithShape(features, size), WithName("w"), WithInit(GlorotU(1)))
    25  	b1 := NewMatrix(g, Float32, WithShape(1, size), WithName("b"), WithInit(Zeroes()))
    26  	x1 := NewMatrix(g, Float32, WithShape(n, features), WithName("x"))
    27  
    28  	// Then we write the expression:
    29  	var xw, xwb, act *Node
    30  	var err error
    31  	if xw, err = Mul(x1, w1); err != nil {
    32  		fmt.Printf("Err while Mul: %v\n", err)
    33  	}
    34  	if xwb, err = BroadcastAdd(xw, b1, nil, []byte{0}); err != nil {
    35  		fmt.Printf("Err while Add: %v\n", err)
    36  	}
    37  	if act, err = Tanh(xwb); err != nil {
    38  		fmt.Printf("Err while Tanh: %v\n", err)
    39  	}
    40  	fmt.Printf("act is a %T\n", act)
    41  
    42  	// The following is how to set up the exact same network
    43  
    44  	// First we set up our environment
    45  	//
    46  	// These LiftXXX functions transforms Gorgonia's default API into functions that return `Result`
    47  	var mul = Lift2(Mul)                   // Lift2 turns a func(*Node, *Node) (*Node, error)
    48  	var tanh = Lift1(Tanh)                 // Lift1 turns a func(*Node) (*Node, error)
    49  	var add = Lift2Broadcast(BroadcastAdd) // Lift2Broadcast turns a func(*Node, *Node, []byte, []byte) (*Nide, error)
    50  
    51  	// First we set up the components
    52  	h := NewGraph()
    53  	w2 := NewMatrix(h, Float32, WithShape(features, size), WithName("w"), WithInit(GlorotU(1)))
    54  	b2 := NewMatrix(h, Float32, WithShape(1, size), WithName("b"), WithInit(Zeroes()))
    55  	x2 := NewMatrix(h, Float32, WithShape(n, features), WithName("x"))
    56  
    57  	// Then we write the expression
    58  	act2 := tanh(add(mul(x2, w2), b2, nil, []byte{0}))
    59  	fmt.Printf("act2 is a %T (note it's wrapped in the `Result` type)\n", act2)
    60  	fmt.Println()
    61  	// both g and h are the same graph but the expression is easier to write for act2
    62  	fmt.Printf("Both g and h are the same graph:\ng: %v\nh: %v\n", g.AllNodes(), h.AllNodes())
    63  
    64  	// Output:
    65  	// act is a *gorgonia.Node
    66  	// act2 is a *gorgonia.Node (note it's wrapped in the `Result` type)
    67  	//
    68  	// Both g and h are the same graph:
    69  	// g: [w, b, x, A × B(%2, %0), Reshape(1, 100)(%1), SizeOf=32(%3), Repeat0(%4, %5), + false(%3, %6), tanh(%7)]
    70  	// h: [w, b, x, A × B(%2, %0), Reshape(1, 100)(%1), SizeOf=32(%3), Repeat0(%4, %5), + false(%3, %6), tanh(%7)]
    71  }
    72  
    73  // This example showcases dealing with errors. This is part 2 of the raison d'être of the more complicated functions - dealing with errors
    74  func Example_monad_raison_detre_errors() {
    75  	// Observe that in a similar example, errors are manually controllable in the original case,
    76  	// but automated in the second case
    77  	const (
    78  		n        = 32
    79  		features = 784
    80  		size     = 100
    81  	)
    82  
    83  	// The following is an example of how to set up a neural network
    84  
    85  	// First, we set up the components
    86  	g := NewGraph()
    87  	w1 := NewMatrix(g, Float32, WithShape(features, size), WithName("w"), WithInit(GlorotU(1)))
    88  	b1 := NewMatrix(g, Float32, WithShape(1, size), WithName("b"), WithInit(Zeroes()))
    89  	x1 := NewMatrix(g, Float32, WithShape(n, features), WithName("x"))
    90  
    91  	// Then we write the expression:
    92  	var xw, xwb, act *Node
    93  	var err error
    94  	if xw, err = Mul(x1, w1); err != nil {
    95  		fmt.Printf("Err while Mul: %v\n", err)
    96  	}
    97  	// we introduce an error here - it should be []byte{0}
    98  	if xwb, err = BroadcastAdd(xw, b1, nil, []byte{1}); err != nil {
    99  		fmt.Printf("Err while Add: %v\n", err)
   100  		goto case2
   101  	}
   102  	if act, err = Tanh(xwb); err != nil {
   103  		fmt.Printf("Err while Tanh: %v\n", err)
   104  	}
   105  	_ = act // will never happen
   106  
   107  case2:
   108  
   109  	// The following is how to set up the exact same network
   110  
   111  	// First we set up our environment
   112  	//
   113  	// Now, remember all these functions no longer return (*Node, error). Instead they return `Result`
   114  	var mul = Lift2(Mul)
   115  	var tanh = Lift1(Tanh)
   116  	var add = Lift2Broadcast(BroadcastAdd)
   117  
   118  	// First we set up the components
   119  	h := NewGraph()
   120  	w2 := NewMatrix(h, Float32, WithShape(features, size), WithName("w"), WithInit(GlorotU(1)))
   121  	b2 := NewMatrix(h, Float32, WithShape(1, size), WithName("b"), WithInit(Zeroes()))
   122  	x2 := NewMatrix(h, Float32, WithShape(n, features), WithName("x"))
   123  
   124  	// Then we write the expression
   125  	act2 := tanh(add(mul(x2, w2), b2, nil, []byte{1}))
   126  
   127  	// REMEMBER: act2 is not a *Node! It is a Result
   128  	fmt.Printf("act2: %v\n", act2)
   129  
   130  	// To extract error, use CheckOne
   131  	fmt.Printf("error: %v\n", CheckOne(act2))
   132  
   133  	// If you extract the *Node from an error, you get nil
   134  	fmt.Printf("Node: %v\n", act2.Node())
   135  
   136  	// Output:
   137  	// Err while Add: Failed to infer shape. Op: + false: Shape mismatch: (32, 100) and (1, 10000)
   138  	// act2: Failed to infer shape. Op: + false: Shape mismatch: (32, 100) and (1, 10000)
   139  	// error: Failed to infer shape. Op: + false: Shape mismatch: (32, 100) and (1, 10000)
   140  	// Node: <nil>
   141  }