gorgonia.org/gorgonia@v0.9.17/engine_test.go (about)

     1  package gorgonia
     2  
     3  import (
     4  	"reflect"
     5  	"testing"
     6  
     7  	"gorgonia.org/tensor"
     8  )
     9  
    10  var stdengType reflect.Type
    11  
    12  func init() {
    13  	stdengType = reflect.TypeOf(StandardEngine{})
    14  }
    15  
    16  func assertEngine(v Value, eT reflect.Type) bool {
    17  	te := engineOf(v)
    18  	if te == nil {
    19  		return true
    20  	}
    21  	teT := reflect.TypeOf(te)
    22  	return eT == teT
    23  }
    24  
    25  func assertGraphEngine(t *testing.T, g *ExprGraph, eT reflect.Type) {
    26  	for _, n := range g.AllNodes() {
    27  		if n.isInput() {
    28  			inputEng := reflect.TypeOf(engineOf(n.Value()))
    29  			if grad, err := n.Grad(); err == nil {
    30  				if !assertEngine(grad, inputEng) {
    31  					t.Errorf("Expected input %v value and gradient to share the same engine %v: Got %T", n.Name(), inputEng, engineOf(grad))
    32  					return
    33  				}
    34  			}
    35  			continue
    36  		}
    37  		if !assertEngine(n.Value(), eT) {
    38  			t.Errorf("Expected node %v to be %v. Got %T instead", n, eT, engineOf(n.Value()))
    39  			return
    40  		}
    41  
    42  		if grad, err := n.Grad(); err == nil {
    43  			if !assertEngine(grad, eT) {
    44  				t.Errorf("Expected gradient of node %v to be %v. Got %T instead", n, eT, engineOf(grad))
    45  				return
    46  			}
    47  		}
    48  	}
    49  }
    50  
    51  func engineOf(v Value) tensor.Engine {
    52  	if t, ok := v.(tensor.Tensor); ok {
    53  		return t.Engine()
    54  	}
    55  	return nil
    56  }
    57  
    58  func TestBasicEngine(t *testing.T) {
    59  	g, x, y, _ := simpleVecEqn()
    60  
    61  	Let(x, tensor.New(tensor.WithBacking([]float64{0, 1})))
    62  	Let(y, tensor.New(tensor.WithBacking([]float64{3, 2})))
    63  	m := NewTapeMachine(g, TraceExec())
    64  	defer m.Close()
    65  	if err := m.RunAll(); err != nil {
    66  		t.Fatal(err)
    67  	}
    68  	if assertGraphEngine(t, g, stdengType); t.Failed() {
    69  		t.FailNow()
    70  	}
    71  }