gorgonia.org/gorgonia@v0.9.17/vm_tape_test.go (about)

     1  package gorgonia
     2  
     3  import (
     4  	"reflect"
     5  	"testing"
     6  )
     7  
     8  func Test_tapeMachine_Reset(t *testing.T) {
     9  	g := NewGraph()
    10  
    11  	var x, y, z *Node
    12  	var err error
    13  
    14  	// define the expression
    15  	x = NewScalar(g, Float64, WithName("x"))
    16  	y = NewScalar(g, Float64, WithName("y"))
    17  	if z, err = Add(x, y); err != nil {
    18  		t.Fatal(err)
    19  	}
    20  
    21  	// create a VM to run the program on
    22  	m1 := NewTapeMachine(g)
    23  	m2 := NewTapeMachine(g)
    24  	defer m1.Close()
    25  	defer m2.Close()
    26  
    27  	// set initial values then run
    28  	Let(x, 2.0)
    29  	Let(y, 2.5)
    30  	if err = m1.RunAll(); err != nil {
    31  		t.Fatal(err)
    32  	}
    33  	if z.Value().Data().(float64) != 4.5 {
    34  		t.Fatalf("Expected %v, got %v", 4.5, z.Value())
    35  	}
    36  	m1.Reset()
    37  	if !reflect.DeepEqual(m1.locMap, m2.locMap) {
    38  		t.Fatalf("expected locmap\n\n%#v, got\n\n%#v", m1, m2)
    39  	}
    40  	if !reflect.DeepEqual(m1.p, m2.p) {
    41  		t.Fatalf("expected program\n\n%#v, got\n\n%#v", m1, m2)
    42  	}
    43  	if !reflect.DeepEqual(m1.cpumem, m2.cpumem) {
    44  		t.Fatalf("expected cpumem\n\n%#v, got\n\n%#v", m1, m2)
    45  	}
    46  	if !reflect.DeepEqual(m1.gpumem, m2.gpumem) {
    47  		t.Fatalf("expected gpumem\n\n%#v, got\n\n%#v", m1, m2)
    48  	}
    49  	if !reflect.DeepEqual(m1.pc, m2.pc) {
    50  		t.Fatalf("expected pc\n\n%#v, got\n\n%#v", m1, m2)
    51  	}
    52  }