gorgonia.org/gorgonia@v0.9.17/analysis_test.go (about)

     1  package gorgonia
     2  
     3  import (
     4  	"bytes"
     5  	"fmt"
     6  	"testing"
     7  
     8  	"github.com/stretchr/testify/assert"
     9  )
    10  
    11  func TestBuildIntervals(t *testing.T) {
    12  	assert := assert.New(t)
    13  	var err error
    14  	g, x, y, z := simpleVecEqn()
    15  	var readVal Value
    16  	r := Read(z, &readVal)
    17  
    18  	z2 := Must(Square(z))
    19  	z2y := Must(HadamardProd(z2, y))
    20  	c := NewConstant(1.0, WithName("FOOO")) // const
    21  	g.addToAll(c)                           // this is a hack because there is no good way to get a constant into a graph since In() won't work on constatns
    22  
    23  	// because sorting is unstable, we need to test many times
    24  	var sorted Nodes
    25  	var intervals map[*Node]*interval
    26  
    27  	for i := 0; i < 100; i++ {
    28  		if sorted, err = Sort(g); err != nil {
    29  			t.Fatal(err)
    30  		}
    31  		reverseNodes(sorted)
    32  
    33  		df := analyze(g, sorted)
    34  		df.buildIntervals(sorted)
    35  		df.debugIntervals(sorted) // prints intervals on debug mode
    36  		intervals = df.intervals
    37  
    38  		// inputs are live until the last instruction
    39  		assert.Equal(len(intervals), intervals[x].end, "%v", len(sorted))
    40  		if intervals[x].start != 1 && intervals[x].start != 0 {
    41  			t.Errorf("x starts at 1 or 0 (depending on how the sort allocates it)")
    42  		}
    43  
    44  		assert.Equal(len(g.AllNodes()), intervals[y].end)
    45  		if intervals[y].start != 1 && intervals[y].start != 0 {
    46  			t.Errorf("y starts at 1 or 0 (depending on how the sort allocates it)")
    47  		}
    48  
    49  		// constants should be live until the last instruction
    50  		assert.Equal(len(intervals), intervals[c].end, "%v", len(sorted))
    51  
    52  		assert.Equal(2, intervals[z].start)
    53  		if intervals[z2].start > intervals[z].end {
    54  			t.Error("z2 should start before z ends")
    55  		}
    56  
    57  		assert.Equal(intervals[r].start, intervals[r].end)
    58  		if intervals[r].start < intervals[z].start {
    59  			t.Error("z should have an earlier start than r")
    60  		}
    61  		if intervals[r].start > intervals[z].end {
    62  			t.Error("z should end before r starts (or at the same as r start")
    63  		}
    64  
    65  		if intervals[z2].end <= intervals[z2].start {
    66  			t.Error("Given that z2y uses z2, the intervals should not end at the same as its start")
    67  		}
    68  		if intervals[z2].start < intervals[z].start {
    69  			t.Error("z should have an earlier start than z2")
    70  		}
    71  		if intervals[z2].start > intervals[z].end {
    72  			t.Error("z should end before r starts (or at the same as z2 start")
    73  		}
    74  
    75  		assert.Equal(intervals[z2y].start, intervals[z2y].end)
    76  		if intervals[z2y].start < intervals[z2].start {
    77  			t.Error("z2 should have an earlier start than z2y")
    78  		}
    79  		if intervals[z2y].start > intervals[z2].end {
    80  			t.Error("z2 should end before r starts (or at the same as z2y start")
    81  		}
    82  
    83  		if t.Failed() {
    84  			break
    85  		}
    86  
    87  	}
    88  
    89  	// visual reminder
    90  	var buf bytes.Buffer
    91  	buf.WriteString("VISUAL REMINDER OF INTERVALS\n")
    92  	sorted.reverse()
    93  	for i, n := range sorted {
    94  		in := intervals[n]
    95  		fmt.Fprintf(&buf, "%d\t%v\tfrom %v to %v \n", i, n, in.start, in.end)
    96  
    97  	}
    98  	t.Log(buf.String())
    99  }