gorgonia.org/gorgonia@v0.9.17/example_err_test.go (about)

     1  package gorgonia_test
     2  
     3  import (
     4  	"fmt"
     5  
     6  	. "gorgonia.org/gorgonia"
     7  )
     8  
     9  // Gorgonia provides an API that is fairly idiomatic - most of the functions in in the API return (T, error).
    10  // This is useful for many cases, such as an interactive shell for deep learning.
    11  // However, it must also be acknowledged that this makes composing functions together a bit cumbersome.
    12  //
    13  // To that end, Gorgonia provides two alternative methods. First, the `Lift` based functions; Second the `Must` function
    14  func Example_errorHandling() {
    15  	// Lift
    16  	g := NewGraph()
    17  	x := NewMatrix(g, Float32, WithShape(2, 3), WithInit(RangedFrom(0)), WithName("a"))
    18  	y := NewMatrix(g, Float32, WithShape(3, 2), WithInit(ValuesOf(float32(2))), WithName("b"))
    19  	z := NewMatrix(g, Float32, WithShape(2, 1), WithInit(Zeroes()), WithName("bias"))
    20  	wrong := NewMatrix(g, Float64, WithShape(2, 3), WithInit(RangedFrom(0)), WithName("wrong"))
    21  
    22  	// Different LiftXXX functions exist for different API signatures
    23  	// A good way to do this is to have some instantiated functions at the top level of the package
    24  	mul := Lift2(Mul)
    25  	add := Lift2(Add)
    26  	addB := Lift2Broadcast(BroadcastAdd)
    27  	sq := Lift1(Square)
    28  	sm := Lift1Axial(SoftMax)
    29  
    30  	nn := sm(sq(addB(mul(x, y), z, nil, []byte{1}))) // OK
    31  	nnPlusWrong := add(nn, wrong)                    // Wrong types. Will Error
    32  	fmt.Printf("nn: %v\nAn error occurs: %v\n", nn, nnPlusWrong.Err())
    33  
    34  	// Must()
    35  	h := NewGraph()
    36  	a := NewMatrix(h, Float32, WithShape(2, 3), WithInit(RangedFrom(0)), WithName("a"))
    37  	b := NewMatrix(h, Float32, WithShape(3, 2), WithInit(ValuesOf(float32(2))), WithName("b"))
    38  	c := NewMatrix(h, Float32, WithShape(2, 1), WithInit(RangedFrom(0)), WithName("c"))
    39  	wrong2 := NewMatrix(h, Float64, WithShape(2, 3), WithInit(RangedFrom(0)), WithName("wrong"))
    40  
    41  	// This is OK
    42  	nn2 := Must(SoftMax(
    43  		Must(Square(
    44  			Must(BroadcastAdd(
    45  				Must(Mul(a, b)),
    46  				c,
    47  				nil, []byte{1},
    48  			)),
    49  		)),
    50  	))
    51  
    52  	fmt.Printf("nn2: %v\n", nn2)
    53  
    54  	defer func() {
    55  		if r := recover(); r != nil {
    56  			fmt.Printf("An error occurs (caught by recover()): %v\n", r)
    57  		}
    58  	}()
    59  
    60  	nn2PlusWrong := Must(Add(nn2, wrong2))
    61  	_ = nn2PlusWrong
    62  
    63  	// Output:
    64  	// nn: Softmax{-1, false}()(%9) :: Matrix float32
    65  	// An error occurs: Type inference error. Op: + false. Children: [Matrix float32, Matrix float64], OpType:Matrix a → Matrix a → Matrix a: Unable to unify while inferring type of + false: Unification Fail: float64 ~ float32 cannot be unified
    66  	// nn2: Softmax{-1, false}()(%9) :: Matrix float32
    67  	// An error occurs (caught by recover()): Type inference error. Op: + false. Children: [Matrix float32, Matrix float64], OpType:Matrix a → Matrix a → Matrix a: Unable to unify while inferring type of + false: Unification Fail: float64 ~ float32 cannot be unified
    68  
    69  }