gorgonia.org/gorgonia@v0.9.17/compile_test.go (about)

     1  package gorgonia
     2  
     3  import "testing"
     4  
     5  func TestCompile_medium(t *testing.T) {
     6  	g := NewGraph()
     7  	x := NewMatrix(g, Float64, WithShape(20, 20), WithName("x"))
     8  	y := NewMatrix(g, Float64, WithShape(20, 20), WithName("y"))
     9  	xpy := Must(Add(x, y))
    10  	xmy := Must(Sub(x, y))
    11  	xpys := Must(Slice(xpy, S(0, 10)))
    12  	Must(Square(xpys))
    13  	xmy2 := Must(Square(xmy))
    14  
    15  	var final Value
    16  	Set(xmy2, xpy)
    17  	Read(xmy2, &final)
    18  
    19  	prog, _, err := Compile(g)
    20  	if err != nil {
    21  		t.Fatalf("error while compiling: %v", err)
    22  	}
    23  	t.Log(prog)
    24  
    25  	onDev := xpy.Device() != CPU
    26  
    27  	// leakage test
    28  	if onDev {
    29  		reg0 := register{device: Device(0), id: 0}
    30  		reg1 := register{device: Device(0), id: 1}
    31  		reg2 := register{device: Device(0), id: 2}
    32  
    33  		if !prog.instructions.has(free{reg0}) {
    34  			t.Error("Expected GPU(0)0 to be freed")
    35  		}
    36  
    37  		if !prog.instructions.has(free{reg1}) {
    38  			t.Error("Expected GPU(0)1 to be freed")
    39  		}
    40  
    41  		if !prog.instructions.has(free{reg2}) {
    42  			t.Error("Expected GPU(0)2 to be freed")
    43  		}
    44  	}
    45  
    46  	// position tests
    47  	if onDev {
    48  		// last two instructions should be free
    49  		if _, ok := prog.instructions[len(prog.instructions)-1].(free); !ok {
    50  			t.Error("Expected last instruction to be a Free")
    51  		}
    52  		if _, ok := prog.instructions[len(prog.instructions)-2].(free); !ok {
    53  			t.Error("Expected second last instruction to be a Free")
    54  		}
    55  
    56  		// frag = prog.m[set]
    57  		// if _, ok := frag[len(frag)-1].(free); !ok {
    58  		// 	t.Error("Expected a `free` instruction after LET")
    59  		// }
    60  
    61  		// frag = prog.m[read]
    62  		// if _, ok := frag[len(frag)-2].(free); !ok {
    63  		// 	t.Error("Expected a `free` instruction after READ")
    64  		// }
    65  	}
    66  }
    67  
    68  func TestCompile_CompileFn(t *testing.T) {
    69  	g := NewGraph()
    70  	x := NewScalar(g, Float32, WithName("x"))
    71  	y := NewScalar(g, Float32, WithName("y"))
    72  	xpy := Must(Add(x, y))
    73  	xmy := Must(Mul(x, y))
    74  	x2 := Must(Square(x))
    75  
    76  	progAll, _, err := Compile(g)
    77  	if err != nil {
    78  		t.Fatal(err)
    79  	}
    80  
    81  	progAdd, _, err := CompileFunction(g, Nodes{x, y}, Nodes{xpy})
    82  	if err != nil {
    83  		t.Fatal(err)
    84  	}
    85  
    86  	progMul, _, err := CompileFunction(g, Nodes{x, y}, Nodes{xmy})
    87  	if err != nil {
    88  		t.Fatal(err)
    89  	}
    90  
    91  	if _, _, err = CompileFunction(g, Nodes{x, y}, Nodes{x2}); err == nil {
    92  		t.Error("expected an error when there is an unused node")
    93  	}
    94  
    95  	// properties based testing
    96  	if len(progAll.sorted) <= len(progAdd.sorted) || len(progAll.sorted) <= len(progMul.sorted) {
    97  		t.Error("progAll should have more nodes included than progAdd or progMul")
    98  	}
    99  
   100  	if len(progAll.instructions) <= len(progAdd.instructions) || len(progAll.instructions) <= len(progMul.instructions) {
   101  		t.Error("progAll should have more instructions than either progAdd or progMul")
   102  	}
   103  
   104  	// really this is more checking of the subgraphing
   105  	if !progAdd.sorted.Contains(x) {
   106  		t.Error("Expected progAdd to contain x")
   107  	}
   108  	if !progAdd.sorted.Contains(y) {
   109  		t.Error("Expected progAdd to contain y")
   110  	}
   111  	if !progAdd.sorted.Contains(xpy) {
   112  		t.Error("Expected progAdd to contain xpy")
   113  	}
   114  	if progAdd.sorted.Contains(xmy) || progAdd.sorted.Contains(x2) {
   115  		t.Error("Expected progAdd to not contain either x2 or xmy")
   116  	}
   117  
   118  	// same as above
   119  	if !progMul.sorted.Contains(x) {
   120  		t.Error("Expected progMul to contain x")
   121  	}
   122  	if !progMul.sorted.Contains(y) {
   123  		t.Error("Expected progMul to contain y")
   124  	}
   125  	if !progMul.sorted.Contains(xmy) {
   126  		t.Error("Expected progMul to contain xmy")
   127  	}
   128  	if progMul.sorted.Contains(xpy) || progMul.sorted.Contains(x2) {
   129  		t.Error("Expected progMul to not contain either x2 or xpy")
   130  	}
   131  }