gorgonia.org/gorgonia@v0.9.17/op_tensor_test.go (about) 1 package gorgonia 2 3 import ( 4 "crypto/sha256" 5 "runtime" 6 "testing" 7 8 "github.com/stretchr/testify/assert" 9 "gorgonia.org/tensor" 10 ) 11 12 var repeatOpTests = []struct { 13 name string 14 rep int 15 axes int 16 val Value 17 18 correct Value 19 expectedShape tensor.Shape 20 err bool 21 }{ 22 { 23 "repeat matrix on axis 0", 2, 0, 24 tensor.New(tensor.WithBacking([]float64{1, 2, 3, 4}), tensor.WithShape(2, 2)), 25 tensor.New(tensor.WithBacking([]float64{1, 2, 1, 2, 3, 4, 3, 4}), tensor.WithShape(4, 2)), 26 tensor.Shape{4, 2}, false, 27 }, 28 29 { 30 "repeat matrix on axis 1", 2, 1, 31 tensor.New(tensor.WithBacking([]float64{1, 2, 3, 4}), tensor.WithShape(2, 2)), 32 tensor.New(tensor.WithBacking([]float64{1, 1, 2, 2, 3, 3, 4, 4}), tensor.WithShape(2, 4)), 33 tensor.Shape{2, 4}, false, 34 }, 35 36 { 37 "repeat col vec on axis 0", 2, 0, 38 tensor.New(tensor.WithBacking([]float64{1, 2}), tensor.WithShape(2, 1)), 39 tensor.New(tensor.WithBacking([]float64{1, 1, 2, 2}), tensor.WithShape(4, 1)), 40 tensor.Shape{4, 1}, false, 41 }, 42 43 { 44 "repeat col vec on axis 1", 2, 1, 45 tensor.New(tensor.WithBacking([]float64{1, 2}), tensor.WithShape(2, 1)), 46 tensor.New(tensor.WithBacking([]float64{1, 1, 2, 2}), tensor.WithShape(2, 2)), 47 tensor.Shape{2, 2}, false, 48 }, 49 50 { 51 "repeat row vec on axis 0", 2, 0, 52 tensor.New(tensor.WithBacking([]float64{1, 2}), tensor.WithShape(1, 2)), 53 tensor.New(tensor.WithBacking([]float64{1, 2, 1, 2}), tensor.WithShape(2, 2)), 54 tensor.Shape{2, 2}, false, 55 }, 56 57 { 58 "repeat row vec on axis 1", 2, 1, 59 tensor.New(tensor.WithBacking([]float64{1, 2}), tensor.WithShape(1, 2)), 60 tensor.New(tensor.WithBacking([]float64{1, 1, 2, 2}), tensor.WithShape(1, 4)), 61 tensor.Shape{1, 4}, false, 62 }, 63 64 { 65 "repeat vector on axis 0", 2, 0, 66 tensor.New(tensor.WithBacking([]float64{1, 2}), tensor.WithShape(2)), 67 tensor.New(tensor.WithBacking([]float64{1, 1, 2, 2}), tensor.WithShape(4)), 68 tensor.Shape{4}, false, 69 }, 70 71 { 72 "repeat vector on axis 1", 2, 1, 73 tensor.New(tensor.WithBacking([]float64{1, 2}), tensor.WithShape(2)), 74 tensor.New(tensor.WithBacking([]float64{1, 1, 2, 2}), tensor.WithShape(2, 2)), 75 tensor.Shape{2, 2}, false, 76 }, 77 78 { 79 "repeat scalar", 2, 0, 80 NewF64(3.14), tensor.New(tensor.WithBacking([]float64{3.14, 3.14}), tensor.WithShape(2)), 81 tensor.Shape{2}, false, 82 }, 83 } 84 85 func TestRepeatOp(t *testing.T) { 86 // assert := assert.New(t) 87 88 for _, rots := range repeatOpTests { 89 // if rots.name != "repeat matrix on axis 1" { 90 // continue 91 // } 92 g := NewGraph() 93 var res Value 94 var err error 95 var repeat *repeatOp 96 97 rep := NewI(rots.rep) 98 n := NodeFromAny(g, rots.val) 99 100 repeat = newRepeatOp(rots.axes, n) 101 102 res, err = repeat.Do(rots.val, rep) 103 switch { 104 case rots.err: 105 if err == nil { 106 t.Errorf("Test %q: Expected an error", rots.name) 107 } 108 goto infershape 109 case !rots.err && err != nil: 110 t.Errorf("%+v", err) 111 goto infershape 112 } 113 114 if !ValueEq(res, rots.correct) { 115 t.Errorf("Test %q: Expected %v. Got %v", rots.name, rots.correct, res) 116 } 117 118 infershape: 119 var s tensor.Shape 120 size := sizeOp{axis: rots.axes, val: rots.rep} 121 s, err = repeat.InferShape(rots.val.Shape(), size) 122 switch { 123 case rots.err: 124 if err == nil { 125 t.Error("Expected an error") 126 } 127 continue 128 case !rots.err && err != nil: 129 t.Errorf("Test %q %+v", rots.name, err) 130 continue 131 } 132 133 if !rots.expectedShape.Eq(s) { 134 t.Errorf("Test %q InferShape: Expected %v. Got %v instead", rots.name, rots.expectedShape, s) 135 } 136 } 137 } 138 139 func repeatOpDiff(repeatOn int, shape tensor.Shape, xV, yV interface{}) (g *ExprGraph, x, y *Node, err error) { 140 g = NewGraph() 141 switch shape.Dims() { 142 case 0: 143 x = NewScalar(g, Float64, WithName("x")) 144 case 1: 145 // vanilla vector 146 x = NewVector(g, Float64, WithName("x"), WithShape(shape...)) 147 case 2: 148 x = NewMatrix(g, Float64, WithName("x"), WithShape(shape...)) 149 default: 150 //matrix and tensors 151 x = NewTensor(g, Float64, shape.Dims(), WithName("x"), WithShape(shape...)) 152 } 153 154 repOp := sizeOp{axis: repeatOn, val: 2} 155 repN := NewScalar(g, Float64, WithName("REPCONST"), WithOp(repOp), WithValue(2.0)) 156 repeat := newRepeatOp(repeatOn, x) 157 158 if y, err = ApplyOp(repeat, x, repN); err != nil { 159 return 160 } 161 xVal, _, _, _ := anyToValue(xV) 162 yVal, _, _, _ := anyToValue(yV) 163 x.bind(dvUnit(xVal)) 164 y.bind(dvUnitVar(yVal)) 165 if err = repeat.DoDiff(ExecutionContext{}, Nodes{x, repN}, y); err != nil { 166 return 167 } 168 return 169 } 170 171 func TestRepeatOpDoDiff(t *testing.T) { 172 //t.SkipNow() 173 assert := assert.New(t) 174 // var g *ExprGraph 175 // var x, y, repN *Node 176 // var repeat *repeatOp 177 var x *Node 178 var err error 179 180 var xG Value 181 var xT, yT *tensor.Dense 182 183 yT = tensor.New(tensor.WithShape(2), tensor.WithBacking([]float64{3.14, 3.14})) 184 185 // scalar repeated into a vec/colvec 186 if _, x, _, err = repeatOpDiff(0, scalarShape, 3.14, yT); err != nil { 187 t.Fatal(err) 188 } 189 xG, _ = x.Grad() 190 assert.Equal(2.0, extractF64(xG)) 191 192 // scalar repeated into a rowvec 193 // if _, x, _, err = repeatOpDiff(1, scalarShape, 3.14, yT); err != nil { 194 // t.Fatal(err) 195 // } 196 // xG, _ = x.Grad() 197 // assert.Equal(2.0, extractF64(xG)) 198 199 // vector repeated unto itself 200 xT = tensor.New(tensor.WithShape(2), tensor.WithBacking([]float64{3.14, 3.14})) 201 yT = tensor.New(tensor.WithShape(4), tensor.WithBacking([]float64{3.14, 3.14, 3.14, 3.14})) 202 if _, x, _, err = repeatOpDiff(0, tensor.Shape{2}, xT, yT); err != nil { 203 t.Fatal(err) 204 } 205 xG, _ = x.Grad() 206 assert.Equal([]float64{2, 2}, extractF64s(xG)) 207 208 // colvec repeated unto itself 209 xT = tensor.New(tensor.WithShape(2, 1), tensor.WithBacking([]float64{3.14, 3.14})) 210 yT = tensor.New(tensor.WithShape(4, 1), tensor.WithBacking([]float64{3.14, 3.14, 3.14, 3.14})) 211 if _, x, _, err = repeatOpDiff(0, tensor.Shape{2}, xT, yT); err != nil { 212 t.Fatal(err) 213 } 214 xG, _ = x.Grad() 215 assert.Equal([]float64{2, 2}, extractF64s(xG)) 216 217 // rowvec repeated unto itself 218 xT = tensor.New(tensor.WithShape(1, 2), tensor.WithBacking([]float64{3.14, 3.14})) 219 yT = tensor.New(tensor.WithShape(1, 4), tensor.WithBacking([]float64{3.14, 3.14, 3.14, 3.14})) 220 if _, x, _, err = repeatOpDiff(1, tensor.Shape{1, 2}, xT, yT); err != nil { 221 t.Fatal(err) 222 } 223 xG, _ = x.Grad() 224 assert.Equal([]float64{2, 2}, extractF64s(xG)) 225 226 // matrix on axis 0 227 xT = tensor.New(tensor.WithShape(2, 2), tensor.WithBacking([]float64{3.14, 2.718, 1.618, 1.414})) 228 yT = tensor.New(tensor.WithShape(4, 2), tensor.WithBacking([]float64{3.14, 2.718, 3.14, 2.718, 1.618, 1.414, 1.618, 1.414})) 229 if _, x, _, err = repeatOpDiff(0, tensor.Shape{1, 2}, xT, yT); err != nil { 230 t.Fatal(err) 231 } 232 xG, _ = x.Grad() 233 assert.Equal([]float64{2, 2, 2, 2}, extractF64s(xG)) 234 235 // matrix on axis 1 236 xT = tensor.New(tensor.WithShape(2, 2), tensor.WithBacking([]float64{3.14, 2.718, 1.618, 1.414})) 237 yT = tensor.New(tensor.WithShape(4, 2), tensor.WithBacking([]float64{3.14, 2.718, 3.14, 2.718, 1.618, 1.414, 1.618, 1.414})) 238 if _, x, _, err = repeatOpDiff(1, tensor.Shape{1, 2}, xT, yT); err != nil { 239 t.Fatal(err) 240 } 241 xG, _ = x.Grad() 242 assert.Equal([]float64{2, 2, 2, 2}, extractF64s(xG)) 243 244 } 245 246 func TestTransposeOp(t *testing.T) { 247 assert := assert.New(t) 248 g := NewGraph() 249 A := NewMatrix(g, Float64, WithShape(2, 3), WithInit(RangedFrom(0))) 250 AT := Must(Transpose(A)) 251 cost1 := Must(Sum(AT)) 252 253 var m VM 254 var err error 255 256 m = NewLispMachine(g) 257 defer m.Close() 258 if err = m.RunAll(); err != nil { 259 t.Error(err) 260 } 261 262 assert.Equal(tensor.Shape{3, 2}, AT.shape) 263 264 h := NewGraph() 265 B := NewMatrix(h, Float64, WithShape(2, 3), WithInit(RangedFrom(0))) 266 BT := Must(Transpose(B)) 267 cost2 := Must(Sum(BT)) 268 Grad(cost2, B) 269 270 m = NewTapeMachine(h) 271 defer m.Close() 272 if err = m.RunAll(); err != nil { 273 t.Error(err) 274 } 275 assert.Equal(tensor.Shape{3, 2}, BT.shape) 276 277 var ag, bg Value 278 if ag, err = A.Grad(); err != nil { 279 t.Fatalf("Cannot get grad of A. Err: %v", err) 280 } 281 282 if bg, err = B.Grad(); err != nil { 283 t.Fatalf("Cannot get grad of B. Err: %v", err) 284 } 285 286 var costGrad1, costGrad2 Value 287 if costGrad1, err = cost1.Grad(); err != nil { 288 t.Fatalf("Cannot get grad of Cost1. Err %v", err) 289 } 290 291 if costGrad2, err = cost2.Grad(); err != nil { 292 t.Fatalf("Cannot get grad of Cost2. Err %v", err) 293 } 294 295 t.Logf("%v %v", cost1.Value(), cost2.Value()) 296 t.Logf("%v %v", costGrad1, costGrad2) 297 298 assert.True(ValueEq(ag, bg)) 299 } 300 301 func TestConcatOp(t *testing.T) { 302 defer runtime.GC() 303 304 assert := assert.New(t) 305 g := NewGraph() 306 x := NewVector(g, Float64, WithShape(2)) 307 xx, err := Concat(0, x, x) 308 if err != nil { 309 t.Fatalf("%+v", err) 310 } 311 312 cost := Must(Sum(xx)) 313 Grad(cost, x) 314 315 g2 := NewGraph() 316 a := NewVector(g2, Float64, WithShape(2)) 317 aa, err := Concat(0, a, a) 318 if err != nil { 319 t.Fatalf("%+v", err) 320 } 321 Must(Sum(aa)) // cost 322 323 aBack := []float64{1, 2} 324 aT := tensor.New(tensor.WithShape(2), tensor.WithBacking(aBack)) 325 326 xBack := []float64{1, 2} 327 xT := tensor.New(tensor.WithShape(2), tensor.WithBacking(xBack)) 328 329 Let(a, aT) 330 Let(x, xT) 331 m1 := NewTapeMachine(g) 332 m2 := NewLispMachine(g2) 333 defer m1.Close() 334 defer m2.Close() 335 336 if err = m1.RunAll(); err != nil { 337 t.Fatal(err) 338 } 339 340 if err = m2.RunAll(); err != nil { 341 t.Fatalf("%+v", err) 342 } 343 344 xG, _ := x.Grad() 345 aG, _ := a.Grad() 346 assert.True(ValueEq(xG, aG)) 347 assert.True(ValueEq(xx.Value(), aa.Value())) 348 } 349 350 func Test_atOp_WriteHash(t *testing.T) { 351 defer func() { 352 if r := recover(); r != nil { 353 t.Fail() 354 } 355 }() 356 h := sha256.New() 357 358 at := &atOp{} 359 at.WriteHash(h) 360 }