gorgonia.org/gorgonia@v0.9.17/compile_test.go (about) 1 package gorgonia 2 3 import "testing" 4 5 func TestCompile_medium(t *testing.T) { 6 g := NewGraph() 7 x := NewMatrix(g, Float64, WithShape(20, 20), WithName("x")) 8 y := NewMatrix(g, Float64, WithShape(20, 20), WithName("y")) 9 xpy := Must(Add(x, y)) 10 xmy := Must(Sub(x, y)) 11 xpys := Must(Slice(xpy, S(0, 10))) 12 Must(Square(xpys)) 13 xmy2 := Must(Square(xmy)) 14 15 var final Value 16 Set(xmy2, xpy) 17 Read(xmy2, &final) 18 19 prog, _, err := Compile(g) 20 if err != nil { 21 t.Fatalf("error while compiling: %v", err) 22 } 23 t.Log(prog) 24 25 onDev := xpy.Device() != CPU 26 27 // leakage test 28 if onDev { 29 reg0 := register{device: Device(0), id: 0} 30 reg1 := register{device: Device(0), id: 1} 31 reg2 := register{device: Device(0), id: 2} 32 33 if !prog.instructions.has(free{reg0}) { 34 t.Error("Expected GPU(0)0 to be freed") 35 } 36 37 if !prog.instructions.has(free{reg1}) { 38 t.Error("Expected GPU(0)1 to be freed") 39 } 40 41 if !prog.instructions.has(free{reg2}) { 42 t.Error("Expected GPU(0)2 to be freed") 43 } 44 } 45 46 // position tests 47 if onDev { 48 // last two instructions should be free 49 if _, ok := prog.instructions[len(prog.instructions)-1].(free); !ok { 50 t.Error("Expected last instruction to be a Free") 51 } 52 if _, ok := prog.instructions[len(prog.instructions)-2].(free); !ok { 53 t.Error("Expected second last instruction to be a Free") 54 } 55 56 // frag = prog.m[set] 57 // if _, ok := frag[len(frag)-1].(free); !ok { 58 // t.Error("Expected a `free` instruction after LET") 59 // } 60 61 // frag = prog.m[read] 62 // if _, ok := frag[len(frag)-2].(free); !ok { 63 // t.Error("Expected a `free` instruction after READ") 64 // } 65 } 66 } 67 68 func TestCompile_CompileFn(t *testing.T) { 69 g := NewGraph() 70 x := NewScalar(g, Float32, WithName("x")) 71 y := NewScalar(g, Float32, WithName("y")) 72 xpy := Must(Add(x, y)) 73 xmy := Must(Mul(x, y)) 74 x2 := Must(Square(x)) 75 76 progAll, _, err := Compile(g) 77 if err != nil { 78 t.Fatal(err) 79 } 80 81 progAdd, _, err := CompileFunction(g, Nodes{x, y}, Nodes{xpy}) 82 if err != nil { 83 t.Fatal(err) 84 } 85 86 progMul, _, err := CompileFunction(g, Nodes{x, y}, Nodes{xmy}) 87 if err != nil { 88 t.Fatal(err) 89 } 90 91 if _, _, err = CompileFunction(g, Nodes{x, y}, Nodes{x2}); err == nil { 92 t.Error("expected an error when there is an unused node") 93 } 94 95 // properties based testing 96 if len(progAll.sorted) <= len(progAdd.sorted) || len(progAll.sorted) <= len(progMul.sorted) { 97 t.Error("progAll should have more nodes included than progAdd or progMul") 98 } 99 100 if len(progAll.instructions) <= len(progAdd.instructions) || len(progAll.instructions) <= len(progMul.instructions) { 101 t.Error("progAll should have more instructions than either progAdd or progMul") 102 } 103 104 // really this is more checking of the subgraphing 105 if !progAdd.sorted.Contains(x) { 106 t.Error("Expected progAdd to contain x") 107 } 108 if !progAdd.sorted.Contains(y) { 109 t.Error("Expected progAdd to contain y") 110 } 111 if !progAdd.sorted.Contains(xpy) { 112 t.Error("Expected progAdd to contain xpy") 113 } 114 if progAdd.sorted.Contains(xmy) || progAdd.sorted.Contains(x2) { 115 t.Error("Expected progAdd to not contain either x2 or xmy") 116 } 117 118 // same as above 119 if !progMul.sorted.Contains(x) { 120 t.Error("Expected progMul to contain x") 121 } 122 if !progMul.sorted.Contains(y) { 123 t.Error("Expected progMul to contain y") 124 } 125 if !progMul.sorted.Contains(xmy) { 126 t.Error("Expected progMul to contain xmy") 127 } 128 if progMul.sorted.Contains(xpy) || progMul.sorted.Contains(x2) { 129 t.Error("Expected progMul to not contain either x2 or xpy") 130 } 131 }