gorgonia.org/gorgonia@v0.9.17/vm_genera_test.go (about) 1 package gorgonia 2 3 import ( 4 "bytes" 5 "log" 6 "runtime" 7 "testing" 8 9 "github.com/stretchr/testify/assert" 10 "gorgonia.org/tensor" 11 ) 12 13 func TestLispMachineBasics(t *testing.T) { 14 assert := assert.New(t) 15 var m *lispMachine 16 // var err error 17 var buf bytes.Buffer 18 19 // test various flags first 20 g := NewGraph() 21 m = NewLispMachine(g) 22 defer m.Close() 23 assert.Equal(byte(0x3), m.runFlags) 24 assert.True(m.runFwd()) 25 assert.True(m.runBwd()) 26 27 logger := log.New(&buf, "", 0) 28 m = NewLispMachine(g, WithLogger(logger)) 29 defer m.Close() 30 assert.Equal(logger, m.logger) 31 assert.Equal(byte(0x0), m.logFlags) // if you pass in a logger without telling which direction to log... nothing gets logged 32 33 m = NewLispMachine(g, WithLogger(nil)) 34 defer m.Close() 35 assert.NotNil(m.logger) 36 37 m = NewLispMachine(g, WithValueFmt("%v")) 38 defer m.Close() 39 assert.Equal("%v", m.valueFmt) 40 41 m = NewLispMachine(g, WithNaNWatch()) 42 defer m.Close() 43 assert.Equal(byte(0x7), m.runFlags) 44 assert.True(m.watchNaN()) 45 46 m = NewLispMachine(g, WithInfWatch()) 47 defer m.Close() 48 assert.Equal(byte(0xb), m.runFlags) 49 assert.True(m.watchInf()) 50 51 m = NewLispMachine(g, ExecuteFwdOnly()) 52 defer m.Close() 53 assert.Equal(byte(0x1), m.runFlags) 54 assert.True(m.runFwd()) 55 assert.False(m.runBwd()) 56 57 m = NewLispMachine(g, ExecuteBwdOnly()) 58 defer m.Close() 59 assert.Equal(byte(0x2), m.runFlags) 60 assert.True(m.runBwd()) 61 assert.False(m.runFwd()) 62 63 m = NewLispMachine(g, LogFwd()) 64 defer m.Close() 65 assert.Equal(byte(0x1), m.logFlags) 66 assert.Equal(byte(0x3), m.runFlags) 67 assert.True(m.logFwd()) 68 assert.False(m.logBwd()) 69 70 m = NewLispMachine(g, LogBwd()) 71 defer m.Close() 72 assert.Equal(byte(0x2), m.logFlags) 73 assert.Equal(byte(0x3), m.runFlags) 74 assert.True(m.logBwd()) 75 assert.False(m.logFwd()) 76 77 // if you pass in a watchlist, but don't have any logger, well, it's not gonna log anything 78 m = NewLispMachine(g, WithWatchlist()) 79 defer m.Close() 80 assert.Equal(byte(0x80), m.logFlags) 81 assert.Equal(byte(0x3), m.runFlags) 82 assert.True(m.watchAll()) 83 84 } 85 86 func TestLispMachineMechanics(t *testing.T) { 87 assert := assert.New(t) 88 var err error 89 g, x, y, z := simpleVecEqn() 90 91 sz := Must(Sum(z)) 92 93 xBack := []float64{1, 5} 94 yBack := []float64{2, 4} 95 Let(x, tensor.New(tensor.WithShape(x.shape...), tensor.WithBacking(xBack))) 96 Let(y, tensor.New(tensor.WithShape(y.shape...), tensor.WithBacking(yBack))) 97 98 machine := NewLispMachine(g) 99 defer machine.Close() 100 if err = machine.RunAll(); err != nil { 101 t.Error(err) 102 } 103 104 gBack := []float64{1, 1} 105 grad := tensor.New(tensor.WithShape(x.shape...), tensor.WithBacking(gBack)) 106 xG, _ := x.Grad() 107 yG, _ := y.Grad() 108 109 assert.True(ValueEq(grad, xG)) 110 assert.True(ValueEq(grad, yG)) 111 112 // tack more items onto the graph, and execute it again 113 szp2 := Must(Add(sz, twof64)) 114 szp3 := Must(Add(sz, threef64)) 115 116 var szp2Val Value 117 readSzp2 := Read(szp2, &szp2Val) 118 119 sg := g.SubgraphRoots(readSzp2, szp2) 120 machine = NewLispMachine(sg) 121 defer machine.Close() 122 if err = machine.RunAll(); err != nil { 123 t.Error(err) 124 } 125 126 assert.NotNil(szp2Val) 127 assert.Equal(szp2.Value(), szp2Val) 128 assert.Nil(szp3.boundTo) // node that was not executed on should not have any values bound to it 129 130 // play it again, sam! 131 // this is to test that if given the same root that had previously been executed on, it will not reallocate a new *dv 132 sg = g.SubgraphRoots(szp3) 133 machine = NewLispMachine(sg) 134 defer machine.Close() 135 136 if err = machine.RunAll(); err != nil { 137 t.Error(err) 138 } 139 140 // save szp3's value 141 szp3dv := szp3.boundTo.(*dualValue) 142 szp3dvv := szp3dv.Value 143 144 if err = machine.RunAll(); err != nil { 145 t.Error(err) 146 } 147 148 if dv := szp3.boundTo.(*dualValue); dv != szp3dv { 149 t.Error("A new *dualValue had been allocated for szp3dv. That's not supposed to happen") 150 } else if dv.Value != szp3dvv { 151 t.Error("A new value for szp3dv.Value has been allocated. That ain't supposed to happen") 152 } 153 154 // idiotsville 155 156 // non scalar costs 157 cost := Must(Add(sz, x)) 158 sg = g.Subgraph(cost) 159 machine = NewLispMachine(sg) 160 defer machine.Close() 161 if err = machine.RunAll(); err == nil { 162 t.Error("Expected a AutoDiff error") 163 } 164 } 165 166 func TestLispMachineRepeatedRuns(t *testing.T) { 167 assert := assert.New(t) 168 var err error 169 g := NewGraph() 170 x := NewVector(g, Float64, WithShape(2), WithName("x"), WithInit(RangedFrom(0))) 171 y := NewMatrix(g, Float64, WithShape(2, 3), WithName("y"), WithInit(RangedFrom(0))) 172 z := Must(Mul(x, y)) 173 cost := Must(Slice(z, S(1))) // this simulates the more complex cost functions 174 175 reps := 10 176 177 for i := 0; i < reps; i++ { 178 m := NewLispMachine(g) 179 if err := m.RunAll(); err != nil { 180 t.Errorf("Repetition %d error: %+v", i, err) 181 continue 182 } 183 184 var gradX, gradY, gradZ, gradC Value 185 if gradX, err = x.Grad(); err != nil { 186 t.Errorf("No gradient for x in repetition %d. Error: %v", i, err) 187 continue 188 } 189 if gradY, err = y.Grad(); err != nil { 190 t.Errorf("No gradient for y in repetition %d. Error: %v", i, err) 191 continue 192 } 193 if gradZ, err = z.Grad(); err != nil { 194 t.Errorf("No gradient for z in repetition %d. Error: %v", i, err) 195 continue 196 } 197 if gradC, err = cost.Grad(); err != nil { 198 t.Errorf("No gradient for cost in repetition %d. Error: %v", i, err) 199 continue 200 } 201 202 assert.Equal([]float64{1, 4}, gradX.Data(), "run %d", i) 203 assert.Equal([]float64{0, 0, 0, 0, 1, 0}, gradY.Data(), "run %d", i) 204 assert.Equal([]float64{0, 1, 0}, gradZ.Data(), "run %d", i) 205 assert.Equal(1.0, gradC.Data(), "run %d", i) 206 207 // assert that the data has been unchanged 208 assert.Equal([]float64{0, 1}, x.Value().Data()) 209 assert.Equal([]float64{0, 1, 2, 3, 4, 5}, y.Value().Data()) 210 assert.Equal([]float64{3, 4, 5}, z.Value().Data()) 211 assert.Equal(float64(4), cost.Value().Data()) 212 213 // This simulates the cloberring of of the gradients of the nodes. The next iteration should STILL reveal the same results 214 model := Nodes{x, y, z, cost} 215 for _, n := range model { 216 dv := n.boundTo.(*dualValue) 217 if err = dv.SetDeriv(ZeroValue(dv.d)); err != nil { 218 t.Errorf("Unable to set the gradient to 0 for %v. Error : %v", n, err) 219 continue 220 } 221 } 222 m.Close() 223 runtime.GC() 224 } 225 226 }