gorgonia.org/gorgonia@v0.9.17/analysis_test.go (about) 1 package gorgonia 2 3 import ( 4 "bytes" 5 "fmt" 6 "testing" 7 8 "github.com/stretchr/testify/assert" 9 ) 10 11 func TestBuildIntervals(t *testing.T) { 12 assert := assert.New(t) 13 var err error 14 g, x, y, z := simpleVecEqn() 15 var readVal Value 16 r := Read(z, &readVal) 17 18 z2 := Must(Square(z)) 19 z2y := Must(HadamardProd(z2, y)) 20 c := NewConstant(1.0, WithName("FOOO")) // const 21 g.addToAll(c) // this is a hack because there is no good way to get a constant into a graph since In() won't work on constatns 22 23 // because sorting is unstable, we need to test many times 24 var sorted Nodes 25 var intervals map[*Node]*interval 26 27 for i := 0; i < 100; i++ { 28 if sorted, err = Sort(g); err != nil { 29 t.Fatal(err) 30 } 31 reverseNodes(sorted) 32 33 df := analyze(g, sorted) 34 df.buildIntervals(sorted) 35 df.debugIntervals(sorted) // prints intervals on debug mode 36 intervals = df.intervals 37 38 // inputs are live until the last instruction 39 assert.Equal(len(intervals), intervals[x].end, "%v", len(sorted)) 40 if intervals[x].start != 1 && intervals[x].start != 0 { 41 t.Errorf("x starts at 1 or 0 (depending on how the sort allocates it)") 42 } 43 44 assert.Equal(len(g.AllNodes()), intervals[y].end) 45 if intervals[y].start != 1 && intervals[y].start != 0 { 46 t.Errorf("y starts at 1 or 0 (depending on how the sort allocates it)") 47 } 48 49 // constants should be live until the last instruction 50 assert.Equal(len(intervals), intervals[c].end, "%v", len(sorted)) 51 52 assert.Equal(2, intervals[z].start) 53 if intervals[z2].start > intervals[z].end { 54 t.Error("z2 should start before z ends") 55 } 56 57 assert.Equal(intervals[r].start, intervals[r].end) 58 if intervals[r].start < intervals[z].start { 59 t.Error("z should have an earlier start than r") 60 } 61 if intervals[r].start > intervals[z].end { 62 t.Error("z should end before r starts (or at the same as r start") 63 } 64 65 if intervals[z2].end <= intervals[z2].start { 66 t.Error("Given that z2y uses z2, the intervals should not end at the same as its start") 67 } 68 if intervals[z2].start < intervals[z].start { 69 t.Error("z should have an earlier start than z2") 70 } 71 if intervals[z2].start > intervals[z].end { 72 t.Error("z should end before r starts (or at the same as z2 start") 73 } 74 75 assert.Equal(intervals[z2y].start, intervals[z2y].end) 76 if intervals[z2y].start < intervals[z2].start { 77 t.Error("z2 should have an earlier start than z2y") 78 } 79 if intervals[z2y].start > intervals[z2].end { 80 t.Error("z2 should end before r starts (or at the same as z2y start") 81 } 82 83 if t.Failed() { 84 break 85 } 86 87 } 88 89 // visual reminder 90 var buf bytes.Buffer 91 buf.WriteString("VISUAL REMINDER OF INTERVALS\n") 92 sorted.reverse() 93 for i, n := range sorted { 94 in := intervals[n] 95 fmt.Fprintf(&buf, "%d\t%v\tfrom %v to %v \n", i, n, in.start, in.end) 96 97 } 98 t.Log(buf.String()) 99 }