gorgonia.org/gorgonia@v0.9.17/engine_test.go (about) 1 package gorgonia 2 3 import ( 4 "reflect" 5 "testing" 6 7 "gorgonia.org/tensor" 8 ) 9 10 var stdengType reflect.Type 11 12 func init() { 13 stdengType = reflect.TypeOf(StandardEngine{}) 14 } 15 16 func assertEngine(v Value, eT reflect.Type) bool { 17 te := engineOf(v) 18 if te == nil { 19 return true 20 } 21 teT := reflect.TypeOf(te) 22 return eT == teT 23 } 24 25 func assertGraphEngine(t *testing.T, g *ExprGraph, eT reflect.Type) { 26 for _, n := range g.AllNodes() { 27 if n.isInput() { 28 inputEng := reflect.TypeOf(engineOf(n.Value())) 29 if grad, err := n.Grad(); err == nil { 30 if !assertEngine(grad, inputEng) { 31 t.Errorf("Expected input %v value and gradient to share the same engine %v: Got %T", n.Name(), inputEng, engineOf(grad)) 32 return 33 } 34 } 35 continue 36 } 37 if !assertEngine(n.Value(), eT) { 38 t.Errorf("Expected node %v to be %v. Got %T instead", n, eT, engineOf(n.Value())) 39 return 40 } 41 42 if grad, err := n.Grad(); err == nil { 43 if !assertEngine(grad, eT) { 44 t.Errorf("Expected gradient of node %v to be %v. Got %T instead", n, eT, engineOf(grad)) 45 return 46 } 47 } 48 } 49 } 50 51 func engineOf(v Value) tensor.Engine { 52 if t, ok := v.(tensor.Tensor); ok { 53 return t.Engine() 54 } 55 return nil 56 } 57 58 func TestBasicEngine(t *testing.T) { 59 g, x, y, _ := simpleVecEqn() 60 61 Let(x, tensor.New(tensor.WithBacking([]float64{0, 1}))) 62 Let(y, tensor.New(tensor.WithBacking([]float64{3, 2}))) 63 m := NewTapeMachine(g, TraceExec()) 64 defer m.Close() 65 if err := m.RunAll(); err != nil { 66 t.Fatal(err) 67 } 68 if assertGraphEngine(t, g, stdengType); t.Failed() { 69 t.FailNow() 70 } 71 }