gorgonia.org/gorgonia@v0.9.17/example_err_test.go (about) 1 package gorgonia_test 2 3 import ( 4 "fmt" 5 6 . "gorgonia.org/gorgonia" 7 ) 8 9 // Gorgonia provides an API that is fairly idiomatic - most of the functions in in the API return (T, error). 10 // This is useful for many cases, such as an interactive shell for deep learning. 11 // However, it must also be acknowledged that this makes composing functions together a bit cumbersome. 12 // 13 // To that end, Gorgonia provides two alternative methods. First, the `Lift` based functions; Second the `Must` function 14 func Example_errorHandling() { 15 // Lift 16 g := NewGraph() 17 x := NewMatrix(g, Float32, WithShape(2, 3), WithInit(RangedFrom(0)), WithName("a")) 18 y := NewMatrix(g, Float32, WithShape(3, 2), WithInit(ValuesOf(float32(2))), WithName("b")) 19 z := NewMatrix(g, Float32, WithShape(2, 1), WithInit(Zeroes()), WithName("bias")) 20 wrong := NewMatrix(g, Float64, WithShape(2, 3), WithInit(RangedFrom(0)), WithName("wrong")) 21 22 // Different LiftXXX functions exist for different API signatures 23 // A good way to do this is to have some instantiated functions at the top level of the package 24 mul := Lift2(Mul) 25 add := Lift2(Add) 26 addB := Lift2Broadcast(BroadcastAdd) 27 sq := Lift1(Square) 28 sm := Lift1Axial(SoftMax) 29 30 nn := sm(sq(addB(mul(x, y), z, nil, []byte{1}))) // OK 31 nnPlusWrong := add(nn, wrong) // Wrong types. Will Error 32 fmt.Printf("nn: %v\nAn error occurs: %v\n", nn, nnPlusWrong.Err()) 33 34 // Must() 35 h := NewGraph() 36 a := NewMatrix(h, Float32, WithShape(2, 3), WithInit(RangedFrom(0)), WithName("a")) 37 b := NewMatrix(h, Float32, WithShape(3, 2), WithInit(ValuesOf(float32(2))), WithName("b")) 38 c := NewMatrix(h, Float32, WithShape(2, 1), WithInit(RangedFrom(0)), WithName("c")) 39 wrong2 := NewMatrix(h, Float64, WithShape(2, 3), WithInit(RangedFrom(0)), WithName("wrong")) 40 41 // This is OK 42 nn2 := Must(SoftMax( 43 Must(Square( 44 Must(BroadcastAdd( 45 Must(Mul(a, b)), 46 c, 47 nil, []byte{1}, 48 )), 49 )), 50 )) 51 52 fmt.Printf("nn2: %v\n", nn2) 53 54 defer func() { 55 if r := recover(); r != nil { 56 fmt.Printf("An error occurs (caught by recover()): %v\n", r) 57 } 58 }() 59 60 nn2PlusWrong := Must(Add(nn2, wrong2)) 61 _ = nn2PlusWrong 62 63 // Output: 64 // nn: Softmax{-1, false}()(%9) :: Matrix float32 65 // An error occurs: Type inference error. Op: + false. Children: [Matrix float32, Matrix float64], OpType:Matrix a → Matrix a → Matrix a: Unable to unify while inferring type of + false: Unification Fail: float64 ~ float32 cannot be unified 66 // nn2: Softmax{-1, false}()(%9) :: Matrix float32 67 // An error occurs (caught by recover()): Type inference error. Op: + false. Children: [Matrix float32, Matrix float64], OpType:Matrix a → Matrix a → Matrix a: Unable to unify while inferring type of + false: Unification Fail: float64 ~ float32 cannot be unified 68 69 }