gorgonia.org/gorgonia@v0.9.17/vm_genera_test.go (about)

     1  package gorgonia
     2  
     3  import (
     4  	"bytes"
     5  	"log"
     6  	"runtime"
     7  	"testing"
     8  
     9  	"github.com/stretchr/testify/assert"
    10  	"gorgonia.org/tensor"
    11  )
    12  
    13  func TestLispMachineBasics(t *testing.T) {
    14  	assert := assert.New(t)
    15  	var m *lispMachine
    16  	// var err error
    17  	var buf bytes.Buffer
    18  
    19  	// test various flags first
    20  	g := NewGraph()
    21  	m = NewLispMachine(g)
    22  	defer m.Close()
    23  	assert.Equal(byte(0x3), m.runFlags)
    24  	assert.True(m.runFwd())
    25  	assert.True(m.runBwd())
    26  
    27  	logger := log.New(&buf, "", 0)
    28  	m = NewLispMachine(g, WithLogger(logger))
    29  	defer m.Close()
    30  	assert.Equal(logger, m.logger)
    31  	assert.Equal(byte(0x0), m.logFlags) // if you pass in a logger without telling which direction to log... nothing gets logged
    32  
    33  	m = NewLispMachine(g, WithLogger(nil))
    34  	defer m.Close()
    35  	assert.NotNil(m.logger)
    36  
    37  	m = NewLispMachine(g, WithValueFmt("%v"))
    38  	defer m.Close()
    39  	assert.Equal("%v", m.valueFmt)
    40  
    41  	m = NewLispMachine(g, WithNaNWatch())
    42  	defer m.Close()
    43  	assert.Equal(byte(0x7), m.runFlags)
    44  	assert.True(m.watchNaN())
    45  
    46  	m = NewLispMachine(g, WithInfWatch())
    47  	defer m.Close()
    48  	assert.Equal(byte(0xb), m.runFlags)
    49  	assert.True(m.watchInf())
    50  
    51  	m = NewLispMachine(g, ExecuteFwdOnly())
    52  	defer m.Close()
    53  	assert.Equal(byte(0x1), m.runFlags)
    54  	assert.True(m.runFwd())
    55  	assert.False(m.runBwd())
    56  
    57  	m = NewLispMachine(g, ExecuteBwdOnly())
    58  	defer m.Close()
    59  	assert.Equal(byte(0x2), m.runFlags)
    60  	assert.True(m.runBwd())
    61  	assert.False(m.runFwd())
    62  
    63  	m = NewLispMachine(g, LogFwd())
    64  	defer m.Close()
    65  	assert.Equal(byte(0x1), m.logFlags)
    66  	assert.Equal(byte(0x3), m.runFlags)
    67  	assert.True(m.logFwd())
    68  	assert.False(m.logBwd())
    69  
    70  	m = NewLispMachine(g, LogBwd())
    71  	defer m.Close()
    72  	assert.Equal(byte(0x2), m.logFlags)
    73  	assert.Equal(byte(0x3), m.runFlags)
    74  	assert.True(m.logBwd())
    75  	assert.False(m.logFwd())
    76  
    77  	// if you pass in a watchlist, but don't have any logger, well, it's not gonna log anything
    78  	m = NewLispMachine(g, WithWatchlist())
    79  	defer m.Close()
    80  	assert.Equal(byte(0x80), m.logFlags)
    81  	assert.Equal(byte(0x3), m.runFlags)
    82  	assert.True(m.watchAll())
    83  
    84  }
    85  
    86  func TestLispMachineMechanics(t *testing.T) {
    87  	assert := assert.New(t)
    88  	var err error
    89  	g, x, y, z := simpleVecEqn()
    90  
    91  	sz := Must(Sum(z))
    92  
    93  	xBack := []float64{1, 5}
    94  	yBack := []float64{2, 4}
    95  	Let(x, tensor.New(tensor.WithShape(x.shape...), tensor.WithBacking(xBack)))
    96  	Let(y, tensor.New(tensor.WithShape(y.shape...), tensor.WithBacking(yBack)))
    97  
    98  	machine := NewLispMachine(g)
    99  	defer machine.Close()
   100  	if err = machine.RunAll(); err != nil {
   101  		t.Error(err)
   102  	}
   103  
   104  	gBack := []float64{1, 1}
   105  	grad := tensor.New(tensor.WithShape(x.shape...), tensor.WithBacking(gBack))
   106  	xG, _ := x.Grad()
   107  	yG, _ := y.Grad()
   108  
   109  	assert.True(ValueEq(grad, xG))
   110  	assert.True(ValueEq(grad, yG))
   111  
   112  	// tack more items onto the graph, and execute it again
   113  	szp2 := Must(Add(sz, twof64))
   114  	szp3 := Must(Add(sz, threef64))
   115  
   116  	var szp2Val Value
   117  	readSzp2 := Read(szp2, &szp2Val)
   118  
   119  	sg := g.SubgraphRoots(readSzp2, szp2)
   120  	machine = NewLispMachine(sg)
   121  	defer machine.Close()
   122  	if err = machine.RunAll(); err != nil {
   123  		t.Error(err)
   124  	}
   125  
   126  	assert.NotNil(szp2Val)
   127  	assert.Equal(szp2.Value(), szp2Val)
   128  	assert.Nil(szp3.boundTo) // node that was not executed on should not have any values bound to it
   129  
   130  	// play it again, sam!
   131  	// this is to test that if given the same root that had previously been executed on, it will not reallocate a new *dv
   132  	sg = g.SubgraphRoots(szp3)
   133  	machine = NewLispMachine(sg)
   134  	defer machine.Close()
   135  
   136  	if err = machine.RunAll(); err != nil {
   137  		t.Error(err)
   138  	}
   139  
   140  	// save szp3's value
   141  	szp3dv := szp3.boundTo.(*dualValue)
   142  	szp3dvv := szp3dv.Value
   143  
   144  	if err = machine.RunAll(); err != nil {
   145  		t.Error(err)
   146  	}
   147  
   148  	if dv := szp3.boundTo.(*dualValue); dv != szp3dv {
   149  		t.Error("A new *dualValue had been allocated for szp3dv. That's not supposed to happen")
   150  	} else if dv.Value != szp3dvv {
   151  		t.Error("A new value for szp3dv.Value has been allocated. That ain't supposed to happen")
   152  	}
   153  
   154  	// idiotsville
   155  
   156  	// non scalar costs
   157  	cost := Must(Add(sz, x))
   158  	sg = g.Subgraph(cost)
   159  	machine = NewLispMachine(sg)
   160  	defer machine.Close()
   161  	if err = machine.RunAll(); err == nil {
   162  		t.Error("Expected a AutoDiff error")
   163  	}
   164  }
   165  
   166  func TestLispMachineRepeatedRuns(t *testing.T) {
   167  	assert := assert.New(t)
   168  	var err error
   169  	g := NewGraph()
   170  	x := NewVector(g, Float64, WithShape(2), WithName("x"), WithInit(RangedFrom(0)))
   171  	y := NewMatrix(g, Float64, WithShape(2, 3), WithName("y"), WithInit(RangedFrom(0)))
   172  	z := Must(Mul(x, y))
   173  	cost := Must(Slice(z, S(1))) // this simulates the more complex cost functions
   174  
   175  	reps := 10
   176  
   177  	for i := 0; i < reps; i++ {
   178  		m := NewLispMachine(g)
   179  		if err := m.RunAll(); err != nil {
   180  			t.Errorf("Repetition %d error: %+v", i, err)
   181  			continue
   182  		}
   183  
   184  		var gradX, gradY, gradZ, gradC Value
   185  		if gradX, err = x.Grad(); err != nil {
   186  			t.Errorf("No gradient for x in repetition %d. Error: %v", i, err)
   187  			continue
   188  		}
   189  		if gradY, err = y.Grad(); err != nil {
   190  			t.Errorf("No gradient for y in repetition %d. Error: %v", i, err)
   191  			continue
   192  		}
   193  		if gradZ, err = z.Grad(); err != nil {
   194  			t.Errorf("No gradient for z in repetition %d. Error: %v", i, err)
   195  			continue
   196  		}
   197  		if gradC, err = cost.Grad(); err != nil {
   198  			t.Errorf("No gradient for cost in repetition %d. Error: %v", i, err)
   199  			continue
   200  		}
   201  
   202  		assert.Equal([]float64{1, 4}, gradX.Data(), "run %d", i)
   203  		assert.Equal([]float64{0, 0, 0, 0, 1, 0}, gradY.Data(), "run %d", i)
   204  		assert.Equal([]float64{0, 1, 0}, gradZ.Data(), "run %d", i)
   205  		assert.Equal(1.0, gradC.Data(), "run %d", i)
   206  
   207  		// assert that the data has been unchanged
   208  		assert.Equal([]float64{0, 1}, x.Value().Data())
   209  		assert.Equal([]float64{0, 1, 2, 3, 4, 5}, y.Value().Data())
   210  		assert.Equal([]float64{3, 4, 5}, z.Value().Data())
   211  		assert.Equal(float64(4), cost.Value().Data())
   212  
   213  		// This simulates the cloberring of of the gradients of the nodes. The next iteration should STILL reveal the same results
   214  		model := Nodes{x, y, z, cost}
   215  		for _, n := range model {
   216  			dv := n.boundTo.(*dualValue)
   217  			if err = dv.SetDeriv(ZeroValue(dv.d)); err != nil {
   218  				t.Errorf("Unable to set the gradient to 0 for %v. Error : %v", n, err)
   219  				continue
   220  			}
   221  		}
   222  		m.Close()
   223  		runtime.GC()
   224  	}
   225  
   226  }