gorgonia.org/gorgonia@v0.9.17/solvers_test.go (about) 1 package gorgonia 2 3 import ( 4 "log" 5 "math" 6 "runtime" 7 "testing" 8 9 "github.com/chewxy/math32" 10 "github.com/stretchr/testify/assert" 11 "gorgonia.org/dawson" 12 "gorgonia.org/tensor" 13 ) 14 15 func clampFloat64(v, min, max float64) float64 { 16 if v < min { 17 return min 18 } 19 if v > max { 20 return max 21 } 22 return v 23 } 24 25 func clampFloat32(v, min, max float32) float32 { 26 if v < min { 27 return min 28 } 29 if v > max { 30 return max 31 } 32 return v 33 } 34 35 func tf64Node() []ValueGrad { 36 backingV := []float64{1, 2, 3, 4} 37 backingD := []float64{0.5, -10, 10, 0.5} 38 v := tensor.New(tensor.WithBacking(backingV), tensor.WithShape(2, 2)) 39 d := tensor.New(tensor.WithBacking(backingD), tensor.WithShape(2, 2)) 40 41 dv := dvUnit0(v) 42 dv.d = d 43 44 n := new(Node) 45 n.boundTo = dv 46 47 model := []ValueGrad{n} 48 return model 49 } 50 51 func tf32Node() []ValueGrad { 52 backingV := []float32{1, 2, 3, 4} 53 backingD := []float32{0.5, -10, 10, 0.5} 54 55 v := tensor.New(tensor.WithBacking(backingV), tensor.WithShape(2, 2)) 56 d := tensor.New(tensor.WithBacking(backingD), tensor.WithShape(2, 2)) 57 58 dv := dvUnit0(v) 59 dv.d = d 60 61 n := new(Node) 62 n.boundTo = dv 63 64 model := []ValueGrad{n} 65 return model 66 } 67 68 func manualRMSProp64(t *testing.T, s *RMSPropSolver, model []ValueGrad) { 69 assert := assert.New(t) 70 correct := make([]float64, 4) 71 cached := make([]float64, 4) 72 73 grad0, _ := model[0].Grad() 74 backingV := model[0].Value().Data().([]float64) 75 backingD := grad0.Data().([]float64) 76 77 for i := 0; i < 5; i++ { 78 for j, v := range backingV { 79 grad := backingD[j] 80 cw := cached[j] 81 82 decayed := cw*s.decay + (1.0-s.decay)*grad*grad 83 cached[j] = decayed 84 85 grad = clampFloat64(grad, -s.clip, s.clip) 86 upd := -s.eta*grad/math.Sqrt(decayed+s.eps) - s.l2reg*v 87 correct[j] = v + upd 88 } 89 90 err := s.Step(model) 91 if err != nil { 92 t.Error(err) 93 } 94 95 sCache := s.cache[0].Value.(tensor.Tensor) 96 assert.Equal(correct, backingV, "Iteration: %d", i) 97 assert.Equal(cached, sCache.Data(), "Iteration: %d", i) 98 99 } 100 } 101 102 func manualRMSProp32(t *testing.T, s *RMSPropSolver, model []ValueGrad) { 103 assert := assert.New(t) 104 correct := make([]float32, 4) 105 cached := make([]float32, 4) 106 107 grad0, _ := model[0].Grad() 108 backingV := model[0].Value().Data().([]float32) 109 backingD := grad0.Data().([]float32) 110 111 decay := float32(s.decay) 112 l2reg := float32(s.l2reg) 113 eta := float32(s.eta) 114 eps := float32(s.eps) 115 clip := float32(s.clip) 116 117 // NOTE: THIS IS NAUGHTY. A proper comparison using 1e-5 should be used but that causes errors. 118 closef32 := func(a, b float32) bool { 119 return dawson.ToleranceF32(a, b, 1e-4) 120 } 121 122 for i := 0; i < 5; i++ { 123 for j, v := range backingV { 124 grad := backingD[j] 125 cw := cached[j] 126 127 decayed := cw*decay + (1.0-decay)*grad*grad 128 cached[j] = decayed 129 130 grad = clampFloat32(grad, -clip, clip) 131 upd := -eta*grad/math32.Sqrt(decayed+eps) - l2reg*v 132 correct[j] = v + upd 133 } 134 135 err := s.Step(model) 136 if err != nil { 137 t.Error(err) 138 } 139 140 sCache := s.cache[0].Value.(tensor.Tensor) 141 assert.True(dawson.AllClose(correct, backingV, closef32)) 142 assert.True(dawson.AllClose(cached, sCache.Data().([]float32), closef32)) 143 } 144 } 145 146 func TestRMSPropSolverManual(t *testing.T) { 147 148 stepSize := 0.01 149 l2Reg := 0.000001 150 clip := 5.0 151 152 var s *RMSPropSolver 153 var model []ValueGrad 154 155 s = NewRMSPropSolver(WithLearnRate(stepSize), WithL2Reg(l2Reg), WithClip(clip)) 156 model = tf64Node() 157 manualRMSProp64(t, s, model) 158 159 s = NewRMSPropSolver(WithLearnRate(stepSize), WithL2Reg(l2Reg), WithClip(clip)) 160 model = tf32Node() 161 manualRMSProp32(t, s, model) 162 163 } 164 165 func TestRMSPropSolver(t *testing.T) { 166 assert := assert.New(t) 167 168 z, cost, m, err := model2dRosenbrock(1, 100, -0.5, 0.5) 169 defer m.Close() 170 const costThreshold = 0.68 171 if nil != err { 172 t.Fatal(err) 173 } 174 175 solver := NewRMSPropSolver() 176 177 maxIterations := 1000 178 179 costFloat := 42.0 180 for 0 != maxIterations { 181 m.Reset() 182 err = m.RunAll() 183 if nil != err { 184 t.Fatal(err) 185 } 186 187 costFloat = cost.Value().Data().(float64) 188 if costThreshold > math.Abs(costFloat) { 189 break 190 } 191 192 err = solver.Step([]ValueGrad{z}) 193 if nil != err { 194 t.Fatal(err) 195 } 196 197 maxIterations-- 198 } 199 200 assert.InDelta(0, costFloat, costThreshold) 201 } 202 203 func TestAdaGradSolver(t *testing.T) { 204 assert := assert.New(t) 205 206 z, cost, m, err := model2dSquare(-0.5, 0.5) 207 defer m.Close() 208 const costThreshold = 0.39 209 if nil != err { 210 t.Fatal(err) 211 } 212 213 solver := NewAdaGradSolver() 214 215 maxIterations := 1000 216 217 costFloat := 42.0 218 for 0 != maxIterations { 219 m.Reset() 220 err = m.RunAll() 221 if nil != err { 222 t.Fatal(err) 223 } 224 225 costFloat = cost.Value().Data().(float64) 226 if costThreshold > math.Abs(costFloat) { 227 break 228 } 229 230 err = solver.Step([]ValueGrad{z}) 231 if nil != err { 232 t.Fatal(err) 233 } 234 235 maxIterations-- 236 } 237 238 assert.InDelta(0, costFloat, costThreshold) 239 } 240 241 func TestVanillaSolver(t *testing.T) { 242 assert := assert.New(t) 243 244 z, cost, m, err := model2dRosenbrock(1, 100, -0.5, 0.5) 245 defer m.Close() 246 const costThreshold = 0.185 247 if nil != err { 248 t.Fatal(err) 249 } 250 251 solver := NewVanillaSolver() 252 253 maxIterations := 1000 254 255 costFloat := 42.0 256 for 0 != maxIterations { 257 m.Reset() 258 err = m.RunAll() 259 if nil != err { 260 t.Fatal(err) 261 } 262 263 costFloat = cost.Value().Data().(float64) 264 if costThreshold > math.Abs(costFloat) { 265 break 266 } 267 268 err = solver.Step([]ValueGrad{z}) 269 if nil != err { 270 t.Fatal(err) 271 } 272 273 maxIterations-- 274 } 275 276 assert.InDelta(0, costFloat, costThreshold) 277 } 278 279 func TestMomentum(t *testing.T) { 280 assert := assert.New(t) 281 282 z, cost, m, err := model2dRosenbrock(1, 100, -0.5, 0.5) 283 defer m.Close() 284 const costThreshold = 0.39 285 if nil != err { 286 t.Fatal(err) 287 } 288 289 solver := NewMomentum() 290 291 maxIterations := 1000 292 293 costFloat := 42.0 294 for 0 != maxIterations { 295 m.Reset() 296 err = m.RunAll() 297 if nil != err { 298 t.Fatal(err) 299 } 300 301 costFloat = cost.Value().Data().(float64) 302 if costThreshold > math.Abs(costFloat) { 303 break 304 } 305 306 err = solver.Step([]ValueGrad{z}) 307 if nil != err { 308 t.Fatal(err) 309 } 310 311 maxIterations-- 312 } 313 314 assert.InDelta(0, costFloat, costThreshold) 315 } 316 317 func TestAdamSolver(t *testing.T) { 318 assert := assert.New(t) 319 320 z, cost, m, err := model2dRosenbrock(1, 100, -0.5, 0.5) 321 defer m.Close() 322 const costThreshold = 0.113 323 if nil != err { 324 t.Fatal(err) 325 } 326 327 solver := NewAdamSolver() 328 329 maxIterations := 1000 330 331 costFloat := 42.0 332 for 0 != maxIterations { 333 m.Reset() 334 err = m.RunAll() 335 if nil != err { 336 t.Fatal(err) 337 } 338 339 costFloat = cost.Value().Data().(float64) 340 if costThreshold > math.Abs(costFloat) { 341 break 342 } 343 344 err = solver.Step([]ValueGrad{z}) 345 if nil != err { 346 t.Fatal(err) 347 } 348 349 maxIterations-- 350 } 351 352 assert.InDelta(0, costFloat, costThreshold) 353 } 354 355 func TestBarzilaiBorweinSolver(t *testing.T) { 356 assert := assert.New(t) 357 358 z, cost, m, err := model2dRosenbrock(1, 100, -0.5, 0.5) 359 defer m.Close() 360 const costThreshold = 0.00002 361 if nil != err { 362 t.Fatal(err) 363 } 364 365 solver := NewBarzilaiBorweinSolver(WithLearnRate(0.0001)) 366 iterations := 0 367 costFloat := 42.0 368 369 // NOTE: due to precision issues with floating-point arithmetic, 370 // amd64 reaches the minimum expected cost at iteration #198 371 // arm64 reaches the minimum expected cost at iteration #210 372 // In some other cases arm converges faster than amd 373 // See https://github.com/golang/go/issues/18354#issuecomment-267705645 374 375 for iterations < 250 { 376 m.Reset() 377 err = m.RunAll() 378 if nil != err { 379 t.Fatal(err) 380 } 381 382 costFloat = cost.Value().Data().(float64) 383 if costThreshold > math.Abs(costFloat) { 384 break 385 } 386 387 err = solver.Step([]ValueGrad{z}) 388 if nil != err { 389 t.Fatal(err) 390 } 391 392 iterations++ 393 } 394 395 log.Printf("Found minimum cost at iteration %d. arch=%s", iterations, runtime.GOARCH) 396 397 assert.InDelta(0, costFloat, costThreshold) 398 } 399 400 // The Rosenbrock function is a non-convex function, 401 // which is used as a performance test problem for optimization algorithms. 402 // https://en.wikipedia.org/wiki/Rosenbrock_function 403 // 404 // f(x,y) = (a-x)² + b(y-x²)² 405 // It has a global minimum at (x, y) = (a, a²), where f(x,y) = 0. 406 // Usually a = 1, b = 100, then the minimum is at x = y = 1 407 // TODO: There is also an n-dimensional version...see wiki 408 func model2dRosenbrock(a, b, xInit, yInit float64) (z, cost *Node, machine *tapeMachine, err error) { 409 g := NewGraph() 410 411 z = NewTensor(g, Float64, 1, WithShape(2), WithName("z")) 412 413 aN := NewConstant(a, WithName("a")) 414 bN := NewConstant(b, WithName("b")) 415 416 xProjFloat := []float64{1, 0} 417 xProj := NewConstant(tensor.New(tensor.WithBacking(xProjFloat), tensor.WithShape(2))) 418 419 yProjFloat := []float64{0, 1} 420 yProj := NewConstant(tensor.New(tensor.WithBacking(yProjFloat), tensor.WithShape(2))) 421 422 x := Must(Mul(z, xProj)) 423 y := Must(Mul(z, yProj)) 424 425 // First term 426 427 sqrt1stTerm := Must(Sub(aN, x)) 428 429 firstTerm := Must(Square(sqrt1stTerm)) 430 431 // Second term 432 433 xSquared := Must(Square(x)) 434 435 yMinusxSquared := Must(Sub(y, xSquared)) 436 437 yMinusxSquaredSqu := Must(Square(yMinusxSquared)) 438 439 secondTerm := Must(Mul(bN, yMinusxSquaredSqu)) 440 441 // cost 442 cost = Must(Add(firstTerm, secondTerm)) 443 444 dcost, err := Grad(cost, z) 445 if nil != err { 446 return nil, nil, nil, err 447 } 448 449 prog, locMap, err := CompileFunction(g, Nodes{z}, Nodes{cost, dcost[0]}) 450 if nil != err { 451 return nil, nil, nil, err 452 } 453 454 machine = NewTapeMachine(g, WithPrecompiled(prog, locMap), BindDualValues(z)) 455 456 err = machine.Let(z, tensor.New(tensor.WithBacking([]float64{xInit, yInit}), tensor.WithShape(2))) 457 if nil != err { 458 return nil, nil, nil, err 459 } 460 461 return 462 } 463 464 func model2dSquare(xInit, yInit float64) (z, cost *Node, machine *tapeMachine, err error) { 465 g := NewGraph() 466 467 z = NewTensor(g, Float64, 1, WithShape(2), WithName("z")) 468 469 // cost 470 cost = Must(Mul(z, z)) 471 472 dcost, err := Grad(cost, z) 473 if nil != err { 474 return nil, nil, nil, err 475 } 476 477 prog, locMap, err := CompileFunction(g, Nodes{z}, Nodes{cost, dcost[0]}) 478 if nil != err { 479 return nil, nil, nil, err 480 } 481 482 machine = NewTapeMachine(g, WithPrecompiled(prog, locMap), BindDualValues(z)) 483 484 err = machine.Let(z, tensor.New(tensor.WithBacking([]float64{xInit, yInit}), tensor.WithShape(2))) 485 if nil != err { 486 return nil, nil, nil, err 487 } 488 489 return 490 }