gorgonia.org/gorgonia@v0.9.17/operatorPointwise_binary_test.go (about) 1 package gorgonia 2 3 import ( 4 "fmt" 5 "io/ioutil" 6 "math/rand" 7 "runtime" 8 "testing" 9 10 "github.com/pkg/errors" 11 "github.com/stretchr/testify/assert" 12 "gorgonia.org/tensor" 13 ) 14 15 func ssBinOpTest(t *testing.T, op ʘBinaryOperatorType, dt tensor.Dtype) (err error) { 16 defer runtime.GC() 17 assert := assert.New(t) 18 var randX, randY interface{} 19 switch dt { 20 case Float64: 21 randX = rand.ExpFloat64() 22 randY = rand.ExpFloat64() 23 case Float32: 24 randX = float32(rand.ExpFloat64()) 25 randY = float32(rand.ExpFloat64()) 26 default: 27 return errors.Errorf("op %v Test not yet implemented for %v ", op, dt) 28 } 29 30 binOp := newEBOByType(op, dt, dt) 31 t.Logf("ssBinOp %v %v %v", randX, op, randY) 32 33 var g, g2 *ExprGraph 34 var x, y, z *Node 35 var a, b, c *Node 36 var i, j, k *Node 37 g = NewGraph() 38 x = NewScalar(g, dt, WithName("x")) 39 y = NewScalar(g, dt, WithName("y")) 40 if z, err = ApplyOp(binOp, x, y); err != nil { 41 return err 42 } 43 44 g2 = NewGraph() 45 a = NewScalar(g2, dt, WithName("a")) 46 b = NewScalar(g2, dt, WithName("b")) 47 if c, err = ApplyOp(binOp, a, b); err != nil { 48 return err 49 } 50 51 i = NewScalar(g, dt, WithName("i")) 52 j = NewScalar(g, dt, WithName("j")) 53 binOp.retSame = true 54 if k, err = ApplyOp(binOp, i, j); err != nil { 55 return err 56 } 57 58 // var grads Nodes 59 var m1 VM 60 if op.isArith() { 61 if _, err = Grad(c, a, b); err != nil { 62 return err 63 } 64 m1 = NewLispMachine(g) 65 } else { 66 m1 = NewLispMachine(g, ExecuteFwdOnly()) 67 } 68 69 m2 := NewTapeMachine(g2, TraceExec(), BindDualValues()) 70 defer m2.Close() 71 defer m1.Close() 72 73 Let(x, randX) 74 Let(y, randY) 75 Let(i, randX) 76 Let(j, randY) 77 if err = m1.RunAll(); err != nil { 78 return 79 } 80 81 Let(a, randX) 82 Let(b, randY) 83 if err = m2.RunAll(); err != nil { 84 return 85 } 86 87 var xG, aG, yG, bG, zG, cG Value 88 if op.isArith() { 89 if xG, err = x.Grad(); err != nil { 90 return 91 } 92 if yG, err = y.Grad(); err != nil { 93 return 94 } 95 if aG, err = a.Grad(); err != nil { 96 return 97 } 98 if bG, err = b.Grad(); err != nil { 99 return 100 } 101 102 if zG, err = z.Grad(); err != nil { 103 return 104 } 105 if cG, err = c.Grad(); err != nil { 106 return 107 } 108 109 if _, err = i.Grad(); err != nil { 110 return 111 } 112 113 if _, err = j.Grad(); err != nil { 114 return 115 } 116 if _, err = k.Grad(); err != nil { 117 return 118 } 119 120 assert.True(ValueClose(xG, aG), "Test ssDiff of %v. xG != aG. Got %v and %v", op, xG, aG) 121 assert.True(ValueClose(yG, bG), "Test ssDiff of %v. yG != bG. Got %v and %v", op, yG, bG) 122 assert.True(ValueClose(zG, cG), "Test ssDiff of %v. zG != cG. Got %v and %v", op, zG, cG) 123 } 124 125 assert.True(ValueClose(x.Value(), a.Value()), "Test ss op %v. Values are different: x: %v, a %v", op, x.Value(), a.Value()) 126 assert.True(ValueClose(y.Value(), b.Value()), "Test ss op %v. Values are different: y: %v, b %v", op, y.Value(), b.Value()) 127 assert.True(ValueClose(z.Value(), c.Value()), "Test ss op %v. Values are different: z: %v, c %v", op, z.Value(), c.Value()) 128 129 return nil 130 } 131 132 func ttBinOpTest(t *testing.T, op ʘBinaryOperatorType, dt tensor.Dtype) (err error) { 133 defer runtime.GC() 134 assert := assert.New(t) 135 var x, y, z, a, b, c, cost *Node 136 var g, g2 *ExprGraph 137 138 var randX, randY interface{} 139 switch dt { 140 case Float32: 141 randX = []float32{1, 2, 3, 4} 142 randY = []float32{2, 2, 2, 2} 143 case Float64: 144 randX = []float64{1, 2, 3, 4} 145 randY = []float64{2, 2, 2, 2} 146 } 147 148 t.Logf("ttBinOp: %v %v %v", randX, op, randY) 149 // randX := Gaussian(0, 1)(dt, 2, 2) 150 // randY := Gaussian(0, 1)(dt, 2, 2) 151 152 xV := tensor.New(tensor.WithShape(2, 2), tensor.WithBacking(randX)) 153 yV := tensor.New(tensor.WithShape(2, 2), tensor.WithBacking(randY)) 154 155 g = NewGraph() 156 g2 = NewGraph() 157 x = NewMatrix(g, dt, WithName("x"), WithShape(2, 2)) 158 y = NewMatrix(g, dt, WithName("y"), WithShape(2, 2)) 159 a = NewMatrix(g2, dt, WithName("a"), WithShape(2, 2)) 160 b = NewMatrix(g2, dt, WithName("b"), WithShape(2, 2)) 161 162 binOp := newEBOByType(op, x.t, y.t) 163 if z, err = ApplyOp(binOp, x, y); err != nil { 164 return err 165 } 166 if c, err = ApplyOp(binOp, a, b); err != nil { 167 return err 168 } 169 170 var m1 VM 171 if op.isArith() { 172 if _, err = Sum(z); err != nil { 173 return err 174 } 175 if cost, err = Sum(c); err != nil { 176 return err 177 } 178 179 if _, err = Grad(cost, a, b); err != nil { 180 return err 181 } 182 m1 = NewLispMachine(g) 183 } else { 184 m1 = NewLispMachine(g, ExecuteFwdOnly()) 185 } 186 187 // lg := log.New(os.Stderr, "", 0) 188 m2 := NewTapeMachine(g2, TraceExec()) 189 defer m2.Close() 190 defer m1.Close() 191 192 // m2 := NewTapeMachine(prog, locMap, TraceExec(), WithLogger(logger), WithWatchlist()) 193 194 Let(x, xV) 195 Let(y, yV) 196 if err = m1.RunAll(); err != nil { 197 return 198 } 199 200 Let(a, xV) 201 Let(b, yV) 202 if err = m2.RunAll(); err != nil { 203 return 204 } 205 206 var xG, aG, yG, bG, zG, cG Value 207 if op.isArith() { 208 if xG, err = x.Grad(); err != nil { 209 return 210 } 211 if yG, err = y.Grad(); err != nil { 212 return 213 } 214 if aG, err = a.Grad(); err != nil { 215 return 216 } 217 if bG, err = b.Grad(); err != nil { 218 return 219 } 220 221 if zG, err = z.Grad(); err != nil { 222 return 223 } 224 if cG, err = c.Grad(); err != nil { 225 return 226 } 227 assert.True(ValueClose(xG, aG), "Test ttDiff of %v. xG != aG. Got %+v \nand %+v", op, xG, aG) 228 assert.True(ValueClose(yG, bG), "Test ttDiff of %v. yG != bG. Got %+v \nand %+v", op, yG, bG) 229 assert.True(ValueClose(zG, cG), "Test ttDiff of %v. zG != cG. Got %+v \nand %+v", op, zG, cG) 230 } 231 232 assert.True(ValueClose(x.Value(), a.Value()), "Test tt op %v. Values are different: x: %+v\n a %+v", op, x.Value(), a.Value()) 233 assert.True(ValueClose(y.Value(), b.Value()), "Test tt op %v. Values are different: y: %+v\n b %+v", op, y.Value(), b.Value()) 234 assert.True(ValueClose(z.Value(), c.Value()), "Test tt op %v. Values are different: z: %+v\n c %+v", op, z.Value(), c.Value()) 235 236 if t.Failed() { 237 ioutil.WriteFile(fmt.Sprintf("Test_%v_tt.dot", op), []byte(g2.ToDot()), 0644) 238 } 239 240 return nil 241 } 242 243 func TestBinOps(t *testing.T) { 244 for op := addOpType; op < maxʘBinaryOpType; op++ { 245 t.Logf("OP: %v", op) 246 247 // if op != addOpType { 248 // continue 249 // } 250 251 // for op := subOpType; op < mulOpType; op++ { 252 var err error 253 err = ssBinOpTest(t, op, Float64) 254 if err != nil { 255 t.Errorf("Float64 version err: %v", err) 256 } 257 258 err = ssBinOpTest(t, op, Float32) 259 if err != nil { 260 t.Errorf("Float32 version err: %v", err) 261 } 262 263 t.Logf("Float64 T-T test for %v", op) 264 err = ttBinOpTest(t, op, Float64) 265 if err != nil { 266 t.Errorf("ttBinOp Float64 version err %v", err) 267 } 268 269 t.Logf("Float32 T-T test") 270 err = ttBinOpTest(t, op, Float32) 271 if err != nil { 272 t.Errorf("ttBinOp Float64 version err %v", err) 273 } 274 } 275 276 // single tests 277 278 // ttBinOpTest(t, subOpType, Float64) 279 }