gorgonia.org/gorgonia@v0.9.17/solvers.go

package gorgonia

import (
	"math"

	"github.com/chewxy/math32"
	"github.com/pkg/errors"
	"gorgonia.org/tensor"
)

// Solver is anything that does gradient updates.
// The name solvers is stolen from Caffe. It is a much shorter name than GradientUpdaters.
type Solver interface {
	Step([]ValueGrad) error
}

// ValueGrad is any type that has a value and a grad. This is used for Solvers.
type ValueGrad interface {
	Valuer
	Grad() (Value, error)
}

// Namer is anything that has a name.
type Namer interface {
	Name() string
}

func newCachedDV(n ValueGrad, weights, grad Value, zero bool) (cached *dualValue, err error) {
	cached = new(dualValue)
	if cached.Value, err = CloneValue(weights); err != nil {
		if nm, ok := n.(Namer); ok {
			return nil, errors.Wrapf(err, "Failed to clone weights of %v", nm.Name())
		}
		return nil, errors.Wrap(err, "Failed to clone weights")
	}
	if cached.d, err = CloneValue(grad); err != nil {
		if nm, ok := n.(Namer); ok {
			return nil, errors.Wrapf(err, "Failed to clone grad of %v", nm.Name())
		}
		return nil, errors.Wrap(err, "Failed to clone grad")
	}
	if zero {
		cached.Value = ZeroValue(cached.Value)
		cached.d = ZeroValue(cached.d)
	}
	return
}

func extractWeightGrad(n ValueGrad) (weights, grad Value, err error) {
	weights = n.Value()
	if grad, err = n.Grad(); err != nil {
		if nm, ok := n.(Namer); ok {
			return weights, nil, errors.Wrapf(err, "No Grad found for %v", nm.Name())
		}
		return weights, nil, errors.Wrap(err, "No Grad found")
	}
	return
}

// SolverOpt is a function that provides construction options for a Solver.
type SolverOpt func(s Solver)

// WithL2Reg adds an L2 regularization parameter to the solver. By default, the solvers do not use any regularization param.
func WithL2Reg(l2reg float64) SolverOpt {
	f := func(s Solver) {
		switch st := s.(type) {
		case *RMSPropSolver:
			st.l2reg = l2reg
			st.useL2Reg = true
		case *AdamSolver:
			st.l2reg = l2reg
			st.useL2Reg = true
		case *VanillaSolver:
			st.l2reg = l2reg
			st.useL2Reg = true
		case *Momentum:
			st.l2reg = l2reg
			st.useL2Reg = true
		}
	}
	return f
}

// WithL1Reg adds an L1 regularization parameter to the solver. By default, the solvers do not use any regularization param.
func WithL1Reg(l1reg float64) SolverOpt {
	f := func(s Solver) {
		switch st := s.(type) {
		case *AdamSolver:
			st.l1reg = l1reg
			st.useL1Reg = true
		case *VanillaSolver:
			st.l1reg = l1reg
			st.useL1Reg = true
		case *Momentum:
			st.l1reg = l1reg
			st.useL1Reg = true
		}
	}
	return f
}

// WithBatchSize sets the batch size for the solver. Currently only Adam, Vanilla (basic SGD) and Momentum have batch size support.
func WithBatchSize(batch float64) SolverOpt {
	f := func(s Solver) {
		switch st := s.(type) {
		case *AdamSolver:
			st.batch = batch
		case *VanillaSolver:
			st.batch = batch
		case *Momentum:
			st.batch = batch
		}
	}
	return f
}

// WithEps sets the smoothing factor for the solver.
func WithEps(eps float64) SolverOpt {
	f := func(s Solver) {
		switch st := s.(type) {
		case *RMSPropSolver:
			st.eps = eps
		case *AdamSolver:
			st.eps = eps
		}
	}
	return f
}
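
// solverOptsExample is a hedged usage sketch (it is not part of the original
// file): it shows how SolverOpts compose. Each option mutates the freshly
// constructed solver in order, and an option that the concrete solver type
// does not understand (here, WithBeta1 on a *VanillaSolver) falls through the
// type switch and is silently ignored.
func solverOptsExample() Solver {
	return NewVanillaSolver(
		WithLearnRate(0.01), // eta
		WithBatchSize(32),   // gradients will be scaled by 1/32
		WithClip(5),         // gradients will be clamped to [-5, 5]
		WithBeta1(0.9),      // no-op: only *AdamSolver handles beta1
	)
}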
// WithClip clips the gradient if it gets too crazy. By default all solvers do not have any clips attached.
func WithClip(clip float64) SolverOpt {
	f := func(s Solver) {
		switch st := s.(type) {
		case *RMSPropSolver:
			st.clip = clip
			st.useClip = true
		case *AdamSolver:
			st.clip = clip
			st.useClip = true
		case *VanillaSolver:
			st.clip = clip
			st.useClip = true
		case *BarzilaiBorweinSolver:
			st.clip = clip
			st.useClip = true
		case *Momentum:
			st.clip = clip
			st.useClip = true
		}
	}
	return f
}

// WithLearnRate sets the learn rate or step size for the solver.
func WithLearnRate(eta float64) SolverOpt {
	f := func(s Solver) {
		switch st := s.(type) {
		case *RMSPropSolver:
			st.eta = eta
		case *AdamSolver:
			st.eta = eta
		case *VanillaSolver:
			st.eta = eta
		case *BarzilaiBorweinSolver:
			st.eta = eta
		case *Momentum:
			st.eta = eta
		}
	}
	return f
}

// WithBeta1 sets the beta1 param of the solver. Only works with Adam.
func WithBeta1(beta1 float64) SolverOpt {
	f := func(s Solver) {
		switch st := s.(type) {
		case *AdamSolver:
			st.beta1 = beta1
		}
	}
	return f
}

// WithBeta2 sets the beta2 param of the solver. Only works with Adam.
func WithBeta2(beta2 float64) SolverOpt {
	f := func(s Solver) {
		switch st := s.(type) {
		case *AdamSolver:
			st.beta2 = beta2
		}
	}
	return f
}

// WithRho sets the decay parameter of the RMSProp solver.
func WithRho(rho float64) SolverOpt {
	f := func(s Solver) {
		switch st := s.(type) {
		case *RMSPropSolver:
			st.decay = rho
		}
	}
	return f
}

// WithMomentum sets the momentum of the solver. It is a no-op if the solver's type is not Momentum.
func WithMomentum(momentum float64) SolverOpt {
	f := func(s Solver) {
		switch st := s.(type) {
		case *Momentum:
			st.momentum = momentum
		}
	}
	return f
}

// RMSPropSolver is a solver that implements Geoffrey Hinton's RMSProp gradient descent optimization algorithm.
// http://www.cs.toronto.edu/~tijmen/csc321/slides/lecture_slides_lec6.pdf
type RMSPropSolver struct {
	decay float64 // decay rate/rho
	eps   float64 // smoothing factor
	l2reg float64 // l2 regularization
	clip  float64 // clip value
	eta   float64 // learn rate

	useClip, useL2Reg bool

	// unsettable
	cache []*dualValue
}

// NewRMSPropSolver creates an RMSProp solver with these default values:
//	eta (learn rate)       : 0.001
//	eps (smoothing factor) : 1e-8
//	rho (decay factor)     : 0.999
func NewRMSPropSolver(opts ...SolverOpt) *RMSPropSolver {
	s := &RMSPropSolver{
		decay: 0.999,
		eps:   1e-8,
		eta:   0.001,
	}

	for _, opt := range opts {
		opt(s)
	}
	return s
}
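
// The per-element update implemented by Step below is, summarised from the
// code rather than quoted from Hinton's slides:
//
//	cache = decay*cache + (1-decay)*grad²
//	upd   = -eta * grad / sqrt(cache + eps)   (minus l2reg*w when L2 is enabled)
//	w     = w + upd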
// Step steps through each node in the model and applies the RMSProp gradient descent algorithm on the value.
//
// This function will error out if the nodes do not have an associated Grad value.
func (s *RMSPropSolver) Step(model []ValueGrad) (err error) {
	if s.cache == nil {
		s.cache = make([]*dualValue, len(model))
	}

	for i, n := range model {
		var weights, grad Value
		if weights, grad, err = extractWeightGrad(n); err != nil {
			return err
		}

		var cached *dualValue
		if cached = s.cache[i]; cached == nil {
			if cached, err = newCachedDV(n, weights, grad, true); err != nil {
				return err
			}
			s.cache[i] = cached
		}

		cv := cached.Value
		// cw = cw*decay + (1-decay) * grad²
		switch cw := cv.(type) {
		case *tensor.Dense:
			var gt, gt2, w, regularized tensor.Tensor
			var decay, omdecay, stepSize, eps, l2reg, clip, negClip interface{}
			switch cw.Dtype() {
			case tensor.Float64:
				decay = s.decay
				omdecay = 1.0 - s.decay
				stepSize = -s.eta
				eps = s.eps
				l2reg = s.l2reg
				clip = s.clip
				negClip = -s.clip
			case tensor.Float32:
				decay = float32(s.decay)
				omdecay = float32(1.0 - s.decay)
				stepSize = float32(-s.eta)
				eps = float32(s.eps)
				l2reg = float32(s.l2reg)
				clip = float32(s.clip)
				negClip = float32(-s.clip)
			}

			gt = grad.(tensor.Tensor)
			if gt2, err = tensor.Square(gt); err != nil {
				return errors.Wrap(err, pointWiseSquareFail)
			}
			tensor.Mul(cw, decay, tensor.UseUnsafe())
			tensor.Mul(gt2, omdecay, tensor.UseUnsafe())
			tensor.Add(cw, gt2, tensor.UseUnsafe())
			defer returnTensor(gt2)

			if s.useClip {
				if _, err = tensor.Clamp(gt, negClip, clip, tensor.UseUnsafe()); err != nil {
					return errors.Wrap(err, clampFail)
				}
			}

			// compute the update: -eta * grad / sqrt(cache + eps)
			var upd tensor.Tensor
			if upd, err = tensor.Add(cw, eps); err != nil {
				return errors.Wrap(err, addFail)
			}

			if _, err = tensor.InvSqrt(upd, tensor.UseUnsafe()); err != nil {
				return errors.Wrap(err, invSqrtFail)
			}
			if _, err = tensor.Mul(gt, stepSize, tensor.UseUnsafe()); err != nil {
				return errors.Wrap(err, pointWiseMulFail)
			}
			if _, err = tensor.Mul(upd, gt, tensor.UseUnsafe()); err != nil {
				return errors.Wrap(err, pointWiseMulFail)
			}

			// regularize and apply the update
			w = weights.(*tensor.Dense)
			if s.useL2Reg {
				if regularized, err = tensor.Mul(w, l2reg); err != nil {
					return errors.Wrap(err, pointWiseMulFail)
				}
				if _, err = tensor.Sub(upd, regularized, tensor.UseUnsafe()); err != nil {
					return errors.Wrap(err, subFail)
				}
				defer returnTensor(regularized)
			}

			if _, err = tensor.Add(w, upd, tensor.UseUnsafe()); err != nil {
				return errors.Wrap(err, addFail)
			}
			defer returnTensor(upd)

			// zero all
			gt.Zero()

		case *F32:
			decay := float32(s.decay)
			omdecay := float32(1.0 - s.decay)
			stepSize := float32(s.eta)
			eps := float32(s.eps)
			l2reg := float32(s.l2reg)

			gs := grad.(*F32).any()
			c := cw.any()
			c = c*decay + omdecay*gs*gs

			cached.Value, _ = anyToScalar(c)

			w := weights.(*F32).any()
			upd := -stepSize*gs/math32.Sqrt(c+eps) - l2reg*w
			w += upd

			// because scalar values are copies, and not pointers, we have to actually re-update the dualValue in model[i]
			*(weights.(*F32)) = F32(w)
			*(grad.(*F32)) = F32(0.0)
		case *F64:
			decay := s.decay
			omdecay := 1.0 - s.decay
			stepSize := s.eta
			eps := s.eps
			l2reg := s.l2reg

			gs := grad.(*F64).any()
			c := cw.any()
			c = c*decay + omdecay*gs*gs

			cached.Value, _ = anyToScalar(c)

			w := weights.(*F64).any()
			upd := -stepSize*gs/math.Sqrt(c+eps) - l2reg*w
			w += upd

			// because scalar values are copies, and not pointers, we have to actually re-update the dualValue in model[i]
			*(weights.(*F64)) = F64(w)
			*(grad.(*F64)) = F64(0.0)
		default:
		}
		solverLogf("AFTER %1.1s", n)
	}
	return nil
}
// AdamSolver is the Adaptive Moment Estimation solver (basically RMSProp on steroids).
// Paper: http://arxiv.org/abs/1412.6980
//
// We overload the purpose of the existing *dualValue data structure. Instead of just holding a value and its derivative,
// the cache's *dualValues hold the means of the gradients (in .Value) and the variances of the gradients (in .d).
type AdamSolver struct {
	eta   float64 // learn rate
	eps   float64 // smoothing
	beta1 float64 // modifier for means
	beta2 float64 // modifier for variances
	clip  float64 // clip gradients
	l1reg float64 // l1 regularization parameter
	l2reg float64 // l2 regularization parameter
	batch float64 // batch size

	useClip, useL1Reg, useL2Reg bool

	// unsettable
	iter  int
	cache []*dualValue
}

// NewAdamSolver creates an Adam solver with these default values:
//	eta (learn rate)       : 0.001
//	eps (smoothing factor) : 1e-8
//	beta1                  : 0.9
//	beta2                  : 0.999
//	batch                  : 1
func NewAdamSolver(opts ...SolverOpt) *AdamSolver {
	s := &AdamSolver{
		eta:   0.001,
		eps:   1e-8,
		beta1: 0.9,
		beta2: 0.999,
		batch: 1,
	}

	for _, opt := range opts {
		opt(s)
	}
	return s
}
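
// The bias-corrected update implemented by Step below is, per element (a
// summary written out from the code; see the paper for the derivation):
//
//	m    = β1*m + (1-β1)*g
//	v    = β2*v + (1-β2)*g²
//	mHat = m / (1 - β1^t)
//	vHat = v / (1 - β2^t)
//	w    = w - eta * mHat / (sqrt(vHat) + eps)
//
// where t is the iteration count s.iter.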
// Step steps through each node in the model and applies the Adaptive Moment Estimation gradient descent algorithm on the value.
//
// This function will error out if the nodes do not have an associated Grad value.
func (s *AdamSolver) Step(model []ValueGrad) (err error) {
	if s.cache == nil {
		s.cache = make([]*dualValue, len(model))
	}

	s.iter++
	correction1 := (1 - math.Pow(s.beta1, float64(s.iter)))
	correction2 := (1 - math.Pow(s.beta2, float64(s.iter)))

	for i, n := range model {
		var weights, grad Value
		if weights, grad, err = extractWeightGrad(n); err != nil {
			return err
		}

		var cached *dualValue
		if cached = s.cache[i]; cached == nil {
			if cached, err = newCachedDV(n, weights, grad, true); err != nil {
				return err
			}
			s.cache[i] = cached
		}

		cvm := cached.Value // means of gradients
		cvv := cached.d     // variances of gradients

		switch m := cvm.(type) {
		case *tensor.Dense:
			g := grad.(*tensor.Dense)
			w := weights.(*tensor.Dense)
			v := cvv.(*tensor.Dense)

			var l1reg, l2reg, clip, negClip, beta1, beta2, omβ1, omβ2, eps, eta, onePerBatch interface{}
			var correctionV1, correctionV2 interface{}
			switch m.Dtype() {
			case tensor.Float64:
				l1reg = s.l1reg
				l2reg = s.l2reg
				clip = s.clip
				negClip = -s.clip
				beta1 = s.beta1
				beta2 = s.beta2
				omβ1 = float64(1) - s.beta1
				omβ2 = float64(1) - s.beta2
				eps = s.eps
				eta = -s.eta
				onePerBatch = float64(1) / s.batch
				correctionV1 = float64(1) / float64(correction1)
				correctionV2 = float64(1) / float64(correction2)
			case tensor.Float32:
				l1reg = float32(s.l1reg)
				l2reg = float32(s.l2reg)
				clip = float32(s.clip)
				negClip = -float32(s.clip)
				beta1 = float32(s.beta1)
				beta2 = float32(s.beta2)
				omβ1 = float32(1) - float32(s.beta1)
				omβ2 = float32(1) - float32(s.beta2)
				eps = float32(s.eps)
				eta = -float32(s.eta)
				onePerBatch = float32(1) / float32(s.batch)
				correctionV1 = float32(1) / float32(correction1)
				correctionV2 = float32(1) / float32(correction2)
			}

			// prep the regularization of gradients
			if s.useL1Reg {
				var l1regs tensor.Tensor
				if l1regs, err = tensor.Sign(w); err != nil {
					return errors.Wrap(err, signFail)
				}
				if l1regs, err = tensor.Mul(l1reg, l1regs, tensor.UseUnsafe()); err != nil {
					return errors.Wrap(err, pointWiseMulFail)
				}
				if _, err = tensor.Add(g, l1regs, tensor.UseUnsafe()); err != nil {
					return errors.Wrap(err, addFail)
				}
				defer returnTensor(l1regs)
			}

			if s.useL2Reg {
				var l2regs tensor.Tensor
				if l2regs, err = tensor.Mul(w, l2reg); err != nil {
					return errors.Wrap(err, pointWiseMulFail)
				}

				if _, err = tensor.Add(g, l2regs, tensor.UseUnsafe()); err != nil {
					return errors.Wrap(err, addFail)
				}

				defer returnTensor(l2regs)
			}

			if s.batch > 1 {
				if _, err = tensor.Mul(g, onePerBatch, tensor.UseUnsafe()); err != nil {
					return errors.Wrap(err, pointWiseMulFail)
				}
			}

			if s.useClip && s.clip > 0 {
				if _, err = tensor.Clamp(g, negClip, clip, tensor.UseUnsafe()); err != nil {
					return errors.Wrap(err, clampFail)
				}
			}

			// prep done. Now let's apply the formula:
			// the formula is
			//	(β1 * m_t-1) + (1 - β1) * g_t ............. 1
			//	(β2 * v_t-1) + (1 - β2) * g_t² ............ 2
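			//
			// Note (added explanation): tensor.WithIncr(t) makes an op accumulate its
			// result into t. So equation (1) is assembled as t1 = (1-β1)*g followed by
			// t1 += m*β1, and equation (2) as g = (1-β2)*g² followed by g += v*β2.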

			// equation (1)
			t1 := g.Clone().(*tensor.Dense)
			if _, err = tensor.Mul(t1, omβ1, tensor.UseUnsafe()); err != nil {
				return errors.Wrap(err, pointWiseMulFail)
			}

			// equation (2)
			if _, err = tensor.Mul(g, g, tensor.UseUnsafe()); err != nil {
				return errors.Wrap(err, pointWiseMulFail)
			}
			if _, err = tensor.Mul(g, omβ2, tensor.UseUnsafe()); err != nil {
				return errors.Wrap(err, pointWiseMulFail)
			}

			// equation (1)
			if _, err = tensor.Mul(m, beta1, tensor.WithIncr(t1)); err != nil {
				return errors.Wrap(err, pointWiseMulFail)
			}

			// equation (2)
			if _, err = tensor.Mul(v, beta2, tensor.WithIncr(g)); err != nil {
				return errors.Wrap(err, pointWiseMulFail)
			}

			defer returnTensor(m)
			defer returnTensor(v)
			cached.SetValue(t1)
			cached.SetDeriv(g.Clone().(*tensor.Dense))

			// now deal with the hats
			mHats := t1.Clone().(*tensor.Dense)
			vHats := g.Clone().(*tensor.Dense)

			if _, err = tensor.Mul(mHats, correctionV1, tensor.UseUnsafe()); err != nil {
				return errors.Wrap(err, pointWiseMulFail)
			}

			if _, err = tensor.Mul(vHats, correctionV2, tensor.UseUnsafe()); err != nil {
				return errors.Wrap(err, pointWiseMulFail)
			}

			// update := -eta * mHat / (sqrt(vHat) + epsilon)
			if _, err = tensor.Sqrt(vHats, tensor.UseUnsafe()); err != nil {
				return // TODO: rewrite this to use InvSqrt
			}

			if _, err = tensor.Add(vHats, eps, tensor.UseUnsafe()); err != nil {
				return
			}

			if _, err = tensor.Mul(mHats, eta, tensor.UseUnsafe()); err != nil {
				return errors.Wrap(err, pointWiseMulFail)
			}

			if _, err = tensor.Div(mHats, vHats, tensor.UseUnsafe()); err != nil {
				return
			}

			defer returnTensor(vHats)
			defer returnTensor(mHats)

			if _, err = tensor.Add(w, mHats, tensor.UseUnsafe()); err != nil {
				return errors.Wrap(err, addFail)
			}

			g.Zero()

		case *F32:
			g := grad.(*F32).any()
			w := weights.(*F32).any()
			v := cvv.(*F32).any()
			mm := m.any()

			l1reg := float32(s.l1reg)
			l2reg := float32(s.l2reg)
			batch := float32(s.batch)
			clip := float32(s.clip)
			beta1 := float32(s.beta1)
			beta2 := float32(s.beta2)
			eps := float32(s.eps)
			eta := float32(s.eta)

			if s.useL1Reg {
				if w < 0 {
					l1reg = -l1reg
				}
				g += l1reg
			}

			if s.useL2Reg {
				l2reg *= w
				g += l2reg
			}

			if batch > 1 {
				g *= (1 / batch)
			}

			if s.useClip {
				if g > clip {
					g = clip
				} else if g < -clip {
					g = -clip
				}
			}

			newM := (beta1 * mm) + (1-beta1)*g
			newV := (beta2 * v) + (1-beta2)*g*g

			cached.Value, _ = anyToScalar(newM)
			cached.d, _ = anyToScalar(newV)

			mHat := (1 / float32(correction1)) * newM
			vHat := (1 / float32(correction2)) * newV

			upd := -eta * mHat / (float32(math.Sqrt(float64(vHat))) + eps)
			w += upd

			*(weights.(*F32)) = F32(w)
			*(grad.(*F32)) = F32(0.0)
		case *F64:
			g := grad.(*F64).any()
			w := weights.(*F64).any()
			v := cvv.(*F64).any()
			mm := m.any()

			l1reg := s.l1reg
			l2reg := s.l2reg
			batch := s.batch
			clip := s.clip
			beta1 := s.beta1
			beta2 := s.beta2
			eps := s.eps
			eta := s.eta

			if s.useL1Reg {
				if w < 0 {
					l1reg = -l1reg
				}
				g += l1reg
			}

			if s.useL2Reg {
				l2reg *= w
				g += l2reg
			}

			if batch > 1 {
				g *= (1 / batch)
			}

			if s.useClip {
				if g > clip {
					g = clip
				} else if g < -clip {
					g = -clip
				}
			}

			newM := (beta1 * mm) + (1-beta1)*g
			newV := (beta2 * v) + (1-beta2)*g*g

			cached.Value, _ = anyToScalar(newM)
			cached.d, _ = anyToScalar(newV)

			mHat := (1 / correction1) * newM
			vHat := (1 / correction2) * newV

			upd := -eta * mHat / (math.Sqrt(vHat) + eps)
			w += upd

			*(weights.(*F64)) = F64(w)
			*(grad.(*F64)) = F64(0.0)

		default:
			err = errors.Errorf(nyiTypeFail, "AdamSolver", cvm)
			return
		}

	}
	return
}

// VanillaSolver is your bog standard stochastic gradient descent optimizer. There are no fancy features to this one.
type VanillaSolver struct {
	eta   float64 // learn rate
	clip  float64 // clip gradients
	l1reg float64 // l1 regularization parameter
	l2reg float64 // l2 regularization parameter
	batch float64 // batch size

	useClip, useL1Reg, useL2Reg bool
}

// NewVanillaSolver creates a new VanillaSolver with sane-ish default values.
func NewVanillaSolver(opts ...SolverOpt) *VanillaSolver {
	s := &VanillaSolver{
		batch: 1,
		eta:   0.001,
	}
	for _, opt := range opts {
		opt(s)
	}
	return s
}
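
// The per-element update implemented by Step below is plain SGD, summarised
// from the code (regularization and clipping only apply when the
// corresponding flags are set):
//
//	g = g + l1reg*sign(w) + l2reg*w
//	g = clamp(g * 1/batch, -clip, clip)
//	w = w - eta*g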
// Step steps through each node in the model and applies the most basic gradient descent algorithm on the value.
//
// This function will error out if the nodes do not have an associated Grad value.
func (s *VanillaSolver) Step(model []ValueGrad) (err error) {
	for _, n := range model {
		var weights, grad Value
		if weights, grad, err = extractWeightGrad(n); err != nil {
			return err
		}
		switch w := weights.(type) {
		case *tensor.Dense:
			g := grad.(*tensor.Dense)

			var l1reg, l2reg, clip, negClip, eta interface{}
			var onePerBatch interface{}
			switch w.Dtype() {
			case tensor.Float64:
				l1reg = s.l1reg
				l2reg = s.l2reg
				clip = s.clip
				negClip = -s.clip
				eta = -s.eta
				onePerBatch = float64(1) / s.batch
			case tensor.Float32:
				l1reg = float32(s.l1reg)
				l2reg = float32(s.l2reg)
				clip = float32(s.clip)
				negClip = float32(-s.clip)
				eta = float32(-s.eta)
				onePerBatch = float32(1) / float32(s.batch)
			}
			// prep the regularization of gradients
			var l1regs, l2regs tensor.Tensor
			if s.useL1Reg {
				if l1regs, err = tensor.Sign(w); err != nil {
					return errors.Wrap(err, signFail)
				}

				if l1regs, err = tensor.Mul(l1reg, l1regs, tensor.UseUnsafe()); err != nil {
					return errors.Wrap(err, pointWiseMulFail)
				}

				if _, err = tensor.Add(g, l1regs, tensor.UseUnsafe()); err != nil {
					return errors.Wrap(err, addFail)
				}

				defer returnTensor(l1regs)
			}

			if s.useL2Reg {
				if l2regs, err = tensor.Mul(w, l2reg); err != nil {
					return errors.Wrap(err, pointWiseMulFail)
				}

				if _, err = tensor.Add(g, l2regs, tensor.UseUnsafe()); err != nil {
					return errors.Wrap(err, addFail)
				}

				defer returnTensor(l2regs)
			}

			if s.batch > 1 {
				if _, err = tensor.Mul(g, onePerBatch, tensor.UseUnsafe()); err != nil {
					return errors.Wrap(err, pointWiseMulFail)
				}
			}

			if s.useClip && s.clip > 0 {
				if _, err = tensor.Clamp(g, negClip, clip, tensor.UseUnsafe()); err != nil {
					return errors.Wrap(err, clampFail)
				}
			}

			if _, err = tensor.Mul(g, eta, tensor.UseUnsafe()); err != nil {
				return errors.Wrap(err, pointWiseMulFail)
			}

			if _, err = tensor.Add(w, g, tensor.UseUnsafe()); err != nil {
				return errors.Wrap(err, addFail)
			}

			g.Zero()

		case *F32:
			g := grad.(*F32).any()
			wv := w.any()

			l1reg := float32(s.l1reg)
			l2reg := float32(s.l2reg)
			batch := float32(s.batch)
			clip := float32(s.clip)
			eta := float32(s.eta)

			if s.useL1Reg {
				if wv < 0 {
					l1reg = -l1reg
				}
				g += l1reg
			}

			if s.useL2Reg {
				l2reg *= wv
				g += l2reg
			}

			if batch > 1 {
				g *= (1 / batch)
			}

			if s.useClip {
				if g > clip {
					g = clip
				} else if g < -clip {
					g = -clip
				}
			}

			upd := -eta * g
			wv += upd

			*(weights.(*F32)) = F32(wv)
			*(grad.(*F32)) = F32(0.0)
		case *F64:
			g := grad.(*F64).any()
			wv := w.any()

			l1reg := s.l1reg
			l2reg := s.l2reg
			batch := s.batch
			clip := s.clip
			eta := s.eta

			if s.useL1Reg {
				if wv < 0 {
					l1reg = -l1reg
				}
				g += l1reg
			}

			if s.useL2Reg {
				l2reg *= wv
				g += l2reg
			}

			if batch > 1 {
				g *= (1 / batch)
			}

			if s.useClip {
				if g > clip {
					g = clip
				} else if g < -clip {
					g = -clip
				}
			}

			upd := -eta * g
			wv += upd

			*(weights.(*F64)) = F64(wv)
			*(grad.(*F64)) = F64(0.0)
		default:
			return errors.Errorf(nyiFail, "VanillaSolver.step", w)
		}
	}
	return
}

// Momentum is the stochastic gradient descent optimizer with a momentum term.
type Momentum struct {
	eta      float64 // learn rate
	momentum float64 // momentum
	clip     float64 // clip gradients
	l1reg    float64 // l1 regularization parameter
	l2reg    float64 // l2 regularization parameter
	batch    float64 // batch size

	useClip, useL1Reg, useL2Reg bool

	cache []*dualValue
}

// NewMomentum creates a new Momentum with sane-ish default values.
func NewMomentum(opts ...SolverOpt) *Momentum {
	s := &Momentum{
		batch:    1,
		eta:      0.001,
		momentum: 0.9,
	}
	for _, opt := range opts {
		opt(s)
	}
	return s
}
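
// The per-element update implemented by Step below is classical momentum,
// summarised from the code: the cache cw carries the velocity between calls.
//
//	cw = momentum*cw - eta*grad
//	w  = w + cw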
// Step steps through each node in the model and applies the Momentum stochastic gradient descent algorithm on the value.
//
// This function will error out if the nodes do not have an associated Grad value.
func (s *Momentum) Step(model []ValueGrad) (err error) {
	if s.cache == nil {
		s.cache = make([]*dualValue, len(model))
	}

	for i, n := range model {
		var weights, grad Value
		if weights, grad, err = extractWeightGrad(n); err != nil {
			return err
		}

		var cached *dualValue
		if cached = s.cache[i]; cached == nil {
			if cached, err = newCachedDV(n, weights, grad, true); err != nil {
				return err
			}
			s.cache[i] = cached
		}

		cv := cached.Value
		// cw = cw * momentum - eta * grad
		// w = w + cw
		switch cw := cv.(type) {
		case *tensor.Dense:
			w := weights.(*tensor.Dense)
			g := grad.(*tensor.Dense)

			var l1reg, l2reg, clip, negClip, eta, momentum, onePerBatch interface{}
			switch cw.Dtype() {
			case tensor.Float64:
				l1reg = s.l1reg
				l2reg = s.l2reg
				clip = s.clip
				negClip = -s.clip
				eta = -s.eta
				momentum = s.momentum
				onePerBatch = float64(1) / s.batch
			case tensor.Float32:
				l1reg = float32(s.l1reg)
				l2reg = float32(s.l2reg)
				clip = float32(s.clip)
				negClip = float32(-s.clip)
				eta = float32(-s.eta)
				momentum = float32(s.momentum)
				onePerBatch = float32(1) / float32(s.batch)
			}

			// prep the regularization of gradients
			var l1regs, l2regs tensor.Tensor
			if s.useL1Reg {
				if l1regs, err = tensor.Sign(w); err != nil {
					return errors.Wrap(err, signFail)
				}

				if l1regs, err = tensor.Mul(l1reg, l1regs, tensor.UseUnsafe()); err != nil {
					return errors.Wrap(err, pointWiseMulFail)
				}

				if _, err = tensor.Add(g, l1regs, tensor.UseUnsafe()); err != nil {
					return errors.Wrap(err, addFail)
				}

				defer returnTensor(l1regs)
			}

			if s.useL2Reg {
				if l2regs, err = tensor.Mul(w, l2reg); err != nil {
					return errors.Wrap(err, pointWiseMulFail)
				}

				if _, err = tensor.Add(g, l2regs, tensor.UseUnsafe()); err != nil {
					return errors.Wrap(err, addFail)
				}

				defer returnTensor(l2regs)
			}

			if s.batch > 1 {
				if _, err = tensor.Mul(g, onePerBatch, tensor.UseUnsafe()); err != nil {
					return errors.Wrap(err, pointWiseMulFail)
				}
			}

			if s.useClip && s.clip > 0 {
				if _, err = tensor.Clamp(g, negClip, clip, tensor.UseUnsafe()); err != nil {
					return errors.Wrap(err, clampFail)
				}
			}

			// g = -eta * grad
			if _, err = tensor.Mul(g, eta, tensor.UseUnsafe()); err != nil {
				return errors.Wrap(err, pointWiseMulFail)
			}

			// cw = cw * momentum
			if _, err = tensor.Mul(cw, momentum, tensor.UseUnsafe()); err != nil {
				return errors.Wrap(err, pointWiseMulFail)
			}

			// cw = cw*momentum - eta*grad
			if _, err = tensor.Add(cw, g, tensor.UseUnsafe()); err != nil {
				return errors.Wrap(err, addFail)
			}

			// w = w + cw
			if _, err = tensor.Add(w, cw, tensor.UseUnsafe()); err != nil {
				return errors.Wrap(err, addFail)
			}

			g.Zero()

		case *F32:
			l1reg := float32(s.l1reg)
			l2reg := float32(s.l2reg)
			batch := float32(s.batch)
			clip := float32(s.clip)
			eta := float32(s.eta)
			momentum := float32(s.momentum)

			g := grad.(*F32).any()
			w := weights.(*F32).any()
			c := cw.any()

			if s.useL1Reg {
				if w < 0 {
					l1reg = -l1reg
				}
				g += l1reg
			}

			if s.useL2Reg {
				l2reg *= w
				g += l2reg
			}

			if batch > 1 {
				g *= (1 / batch)
			}

			if s.useClip {
				if g > clip {
					g = clip
				} else if g < -clip {
					g = -clip
				}
			}

			c = c*momentum - eta*g
			w += c

			*(weights.(*F32)) = F32(w)
			*(grad.(*F32)) = F32(0.0)
		case *F64:
			l1reg := s.l1reg
			l2reg := s.l2reg
			batch := s.batch
			clip := s.clip
			eta := s.eta
			momentum := s.momentum

			g := grad.(*F64).any()
			w := weights.(*F64).any()
			c := cw.any()

			if s.useL1Reg {
				if w < 0 {
					l1reg = -l1reg
				}
				g += l1reg
			}

			if s.useL2Reg {
				l2reg *= w
				g += l2reg
			}

			if batch > 1 {
				g *= (1 / batch)
			}

			if s.useClip {
				if g > clip {
					g = clip
				} else if g < -clip {
					g = -clip
				}
			}

			c = c*momentum - eta*g
			w += c

			*(weights.(*F64)) = F64(w)
			*(grad.(*F64)) = F64(0.0)
		default:
			return errors.Errorf(nyiFail, "Momentum.step", cv)
		}
	}
	return
}

// AdaGradSolver is the solver that does adaptive gradient descent. Read the paper: http://jmlr.org/papers/v12/duchi11a.html
type AdaGradSolver struct {
	eta   float64 // learn rate
	eps   float64 // smoothing factor
	l1Reg float64 // l1reg param
	l2reg float64 // l2reg param
	clip  float64 // clip at

	useL2Reg, useClip bool

	cache []*dualValue
}

// NewAdaGradSolver creates a new AdaGradSolver with sane-ish default values.
func NewAdaGradSolver(opts ...SolverOpt) *AdaGradSolver {
	s := &AdaGradSolver{
		eta: 0.001,
		eps: 1e-8,
	}

	for _, opt := range opts {
		opt(s)
	}
	return s
}
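
// The per-element update implemented by Step below is, summarised from the
// code: the cache accumulates squared gradients over all steps, so the
// effective step size for each weight shrinks as training goes on.
//
//	cache = cache + grad²
//	upd   = -eta * grad / sqrt(cache + eps)   (minus l2reg*w when L2 is enabled)
//	w     = w + upd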
// Step steps through each node in the model and applies the Adaptive Gradient gradient descent algorithm on the value.
//
// This function will error out if the nodes do not have an associated Grad value.
func (s *AdaGradSolver) Step(model []ValueGrad) (err error) {
	if s.cache == nil {
		s.cache = make([]*dualValue, len(model))
	}

	for i, n := range model {
		var weights, grad Value
		if weights, grad, err = extractWeightGrad(n); err != nil {
			return err
		}

		var cached *dualValue
		if cached = s.cache[i]; cached == nil {
			if cached, err = newCachedDV(n, weights, grad, true); err != nil {
				return err
			}
			s.cache[i] = cached
		}

		cv := cached.Value

		switch cw := cv.(type) {
		case *tensor.Dense:
			var w, g, c, g2, regularized tensor.Tensor

			var l2reg, clip, negClip, eps, eta interface{}
			switch cw.Dtype() {
			case tensor.Float64:
				l2reg = s.l2reg
				clip = s.clip
				negClip = -s.clip
				eps = s.eps
				eta = -s.eta
			case tensor.Float32:
				l2reg = float32(s.l2reg)
				clip = float32(s.clip)
				negClip = float32(-s.clip)
				eps = float32(s.eps)
				eta = float32(-s.eta)
			}

			g = grad.(*tensor.Dense)
			if g2, err = tensor.Square(g); err != nil {
				return errors.Wrap(err, pointWiseSquareFail)
			}

			c = cw
			tensor.Add(c, g2, tensor.UseUnsafe())
			defer returnTensor(g2)

			if s.useClip {
				if _, err = tensor.Clamp(g, negClip, clip, tensor.UseUnsafe()); err != nil {
					return errors.Wrap(err, clampFail)
				}
			}

			// update
			var upd tensor.Tensor
			if upd, err = tensor.Add(c, eps); err != nil {
				return errors.Wrap(err, addFail)
			}

			if _, err = tensor.InvSqrt(upd, tensor.UseUnsafe()); err != nil {
				return errors.Wrap(err, invSqrtFail)
			}
			if _, err = tensor.Mul(g, eta, tensor.UseUnsafe()); err != nil {
				return errors.Wrap(err, pointWiseMulFail)
			}

			if _, err = tensor.Mul(upd, g, tensor.UseUnsafe()); err != nil {
				return errors.Wrap(err, pointWiseMulFail)
			}

			// regularize
			w = weights.(*tensor.Dense)

			if s.useL2Reg {
				if regularized, err = tensor.Mul(w, l2reg); err != nil {
					return errors.Wrap(err, pointWiseMulFail)
				}

				if _, err = tensor.Sub(upd, regularized, tensor.UseUnsafe()); err != nil {
					return errors.Wrap(err, subFail)
				}

				defer returnTensor(regularized)
			}

			if _, err = tensor.Add(w, upd, tensor.UseUnsafe()); err != nil {
				return errors.Wrap(err, addFail)
			}
			defer returnTensor(upd)

			// zero all
			g.Zero()

		case *F32:
			var w, g, c float32

			l2reg := float32(s.l2reg)
			clip := float32(s.clip)
			eps := float32(s.eps)
			eta := float32(s.eta)

			c = cw.any()
			g = grad.(*F32).any()

			c += g * g
			cached.Value, _ = anyToScalar(c)

			if s.useClip {
				if g > clip {
					g = clip
				} else if g < -clip {
					g = -clip
				}
			}

			w = weights.(*F32).any()

			upd := -eta * g / math32.Sqrt(c+eps)

			if s.useL2Reg {
				upd -= w * l2reg
			}

			w += upd

			// because scalar values are copies, and not pointers, we have to actually re-update the dualValue in model[i]
			*(weights.(*F32)) = F32(w)
			*(grad.(*F32)) = F32(0.0)
		case *F64:
			var w, g, c float64

			l2reg := s.l2reg
			clip := s.clip
			eps := s.eps
			eta := s.eta

			c = cw.any()
			g = grad.(*F64).any()

			c += g * g
			cached.Value, _ = anyToScalar(c)

			if s.useClip {
				if g > clip {
					g = clip
				} else if g < -clip {
					g = -clip
				}
			}

			w = weights.(*F64).any()
			upd := -eta * g / math.Sqrt(c+eps)
			if s.useL2Reg {
				upd -= w * l2reg
			}

			w += upd

			// because scalar values are copies, and not pointers, we have to actually re-update the dualValue in model[i]
			*(weights.(*F64)) = F64(w)
			*(grad.(*F64)) = F64(0.0)

		default:
			return errors.Errorf(nyiFail, "Adagrad step", cv)
		}

	}

	return
}
// BarzilaiBorweinSolver / Barzilai-Borwein performs Gradient Descent in the steepest descent direction.
// It solves 0 = Grad(F)(x) by iterating
//	xᵢ₊₁ = xᵢ - eta * Grad(F)(xᵢ)
// where the learn rate eta is calculated by the Barzilai-Borwein method:
//	eta(xᵢ) = <(xᵢ - xᵢ₋₁), (Grad(F)(xᵢ) - Grad(F)(xᵢ₋₁))> / ∥(Grad(F)(xᵢ) - Grad(F)(xᵢ₋₁))∥²
// The input learn rate is used for the first iteration.
//
// TODO: Check out stochastic implementations, e.g. "Barzilai-Borwein Step Size for Stochastic Gradient Descent" https://arxiv.org/abs/1605.04131
type BarzilaiBorweinSolver struct {
	eta     float64 // initial learn rate
	clip    float64 // clip value
	useClip bool
	prevDV  []*dualValue // dual value for the xᵢ₋₁ step
}

// NewBarzilaiBorweinSolver creates a new Barzilai-Borwein solver with some default values:
// the learn rate is set to 0.001 and the solver does not use clipping.
func NewBarzilaiBorweinSolver(opts ...SolverOpt) *BarzilaiBorweinSolver {
	s := &BarzilaiBorweinSolver{
		eta:     0.001,
		useClip: false,
	}

	for _, opt := range opts {
		opt(s)
	}
	return s
}
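
// scalarProdExample is a hedged sketch (it is not part of the original file)
// of the trick Step uses below: contracting two equal-shaped tensors over all
// of their axes is a total tensor contraction, i.e. the scalar product <a, b>.
// Like Step, it assumes float64-backed tensors.
func scalarProdExample(a, b tensor.Tensor) (float64, error) {
	axes := make([]int, a.Dims())
	for i := range axes {
		axes[i] = i // contract over every axis
	}
	prod, err := tensor.Contract(a, b, axes, axes)
	if err != nil {
		return 0, err
	}
	defer returnTensor(prod)
	return prod.Data().([]float64)[0], nil
}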
// Step steps through each node in the model and applies the Barzilai-Borwein gradient descent algorithm on the value.
//
// This function will error out if the nodes do not have an associated Grad value.
func (s *BarzilaiBorweinSolver) Step(model []ValueGrad) (err error) {
	firstRun := false
	if s.prevDV == nil {
		firstRun = true
		s.prevDV = make([]*dualValue, len(model))
	}

	// Update the learning rate
	if !firstRun {
		numerator := float64(0.0)
		denominator := float64(0.0)

		for nodeNr, node := range model {
			var weights, grad Value
			if weights, grad, err = extractWeightGrad(node); err != nil {
				return err
			}

			switch w := weights.(type) {
			case *tensor.Dense:
				g, ok := grad.(*tensor.Dense)
				if !ok {
					return errors.Errorf("Expected a *tensor.Dense in %v. Got %T instead", node, grad)
				}

				wOld, ok := s.prevDV[nodeNr].Value.(*tensor.Dense)
				if !ok {
					return errors.Errorf("Expected a *tensor.Dense in %v. Got %T instead", node, s.prevDV[nodeNr].Value)
				}

				gOld, ok := s.prevDV[nodeNr].d.(*tensor.Dense)
				if !ok {
					return errors.Errorf("Expected a *tensor.Dense in %v. Got %T instead", node, s.prevDV[nodeNr].d)
				}

				valueDiff, err := tensor.Sub(w, wOld)
				defer returnTensor(valueDiff)
				if err != nil {
					return errors.Wrap(err, subFail)
				}

				gradDiff, err := tensor.Sub(g, gOld)
				defer returnTensor(gradDiff)
				if err != nil {
					return errors.Wrap(err, subFail)
				}

				// <(xᵢ - xᵢ₋₁), (Grad(F)(xᵢ) - Grad(F)(xᵢ₋₁))>

				// Scalar product == total tensor contraction
				dims := valueDiff.Dims()
				contractionAxes := make([]int, dims)
				for axis := 0; axis < len(contractionAxes); axis++ {
					contractionAxes[axis] = axis
				}

				valGradDiffscalarProd, err := tensor.Contract(valueDiff, gradDiff, contractionAxes, contractionAxes)
				if err != nil {
					return errors.New("operationError, Contracting value / gradient difference")
				}
				defer returnTensor(valGradDiffscalarProd)

				numerator += valGradDiffscalarProd.Data().([]float64)[0]

				// ∥(Grad(F)(xᵢ) - Grad(F)(xᵢ₋₁))∥²
				gradDiffscalarProd, err := tensor.Contract(gradDiff, gradDiff, contractionAxes, contractionAxes)
				if err != nil {
					return errors.New("operationError, Contracting gradient / gradient difference")
				}
				defer returnTensor(gradDiffscalarProd)

				denominator += gradDiffscalarProd.Data().([]float64)[0]

			default:
				return errors.Errorf(nyiFail, "Barzilai-Borwein step", w)
			}
		}

		s.eta = numerator / denominator

		if s.useClip && (math.Abs(s.eta) > s.clip) {
			if math.Signbit(s.eta) {
				s.eta = -s.clip
			} else {
				s.eta = s.clip
			}
		}
	}

	// Save this iteration's values for the next run
	for nodeNr, node := range model {
		var weights, grad Value
		if weights, grad, err = extractWeightGrad(node); err != nil {
			return err
		}

		if !firstRun {
			// return memory for the old dual value used in this iteration
			returnDV(s.prevDV[nodeNr])
		}
		var oldDV *dualValue
		if oldDV, err = newCachedDV(node, weights, grad, false); err != nil {
			return err
		}
		s.prevDV[nodeNr] = oldDV
	}

	// Update the weights
	for _, node := range model {
		var weights, grad Value
		if weights, grad, err = extractWeightGrad(node); err != nil {
			return err
		}

		switch w := weights.(type) {
		case *tensor.Dense:
			g, ok := grad.(*tensor.Dense)
			if !ok {
				return errors.Errorf("Expected a *tensor.Dense in %v. Got %T instead", node, grad)
			}

			upd, err := tensor.Mul(g, s.eta)
			defer returnTensor(upd)

			if err != nil {
				return errors.Wrap(err, pointWiseMulFail)
			}

			if _, err = tensor.Sub(w, upd, tensor.UseUnsafe()); err != nil {
				return errors.Wrap(err, subFail)
			}

			g.Zero()

		default:
			return errors.Errorf(nyiFail, "Barzilai-Borwein step", w)
		}
	}

	return nil
}
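
// trainLoopExample is a hedged end-to-end sketch (it is not part of the
// original file): the caller computes gradients first (the runBackprop
// argument stands in for whatever machinery populates them), then hands the
// same []ValueGrad slice to Step on every iteration. Solvers that keep
// per-weight caches (RMSProp, Adam, Momentum, AdaGrad) size those caches off
// the first Step call, so the model slice must keep the same length and
// order across calls.
func trainLoopExample(model []ValueGrad, iters int, runBackprop func() error) error {
	solver := NewAdamSolver(WithLearnRate(1e-3), WithL2Reg(1e-6))
	for i := 0; i < iters; i++ {
		if err := runBackprop(); err != nil { // populate the gradients
			return err
		}
		if err := solver.Step(model); err != nil {
			return errors.Wrapf(err, "solver step %d failed", i)
		}
	}
	return nil
}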