gorgonia.org/gorgonia@v0.9.17/solvers.go (about)

     1  package gorgonia
     2  
     3  import (
     4  	"math"
     5  
     6  	"github.com/chewxy/math32"
     7  	"github.com/pkg/errors"
     8  	"gorgonia.org/tensor"
     9  )
    10  
    11  // Solver is anything that does gradient updates.
    12  // The name solvers is stolen from Caffe. A much shorter name than GradientUpdaters
    13  type Solver interface {
    14  	Step([]ValueGrad) error
    15  }
    16  
    17  // ValueGrad is any type that has a value and a grad. This is used for Solvers
    18  type ValueGrad interface {
    19  	Valuer
    20  	Grad() (Value, error)
    21  }
    22  
    // Namer is implemented by anything that can report a name for itself. The
// solver helpers use it to make error messages more descriptive.
type Namer interface {
	Name() string
}
    27  
    28  func newCachedDV(n ValueGrad, weights, grad Value, zero bool) (cached *dualValue, err error) {
    29  	cached = new(dualValue)
    30  	if cached.Value, err = CloneValue(weights); err != nil {
    31  		if nm, ok := n.(Namer); ok {
    32  			return nil, errors.Wrapf(err, "Failed to clone weights of %v", nm.Name())
    33  		}
    34  		return nil, errors.Wrap(err, "Failed to clone weights")
    35  	}
    36  	if cached.d, err = CloneValue(grad); err != nil {
    37  		if nm, ok := n.(Namer); ok {
    38  			return nil, errors.Wrapf(err, "Failed to clone grad of %v", nm.Name())
    39  		}
    40  		return nil, errors.Wrap(err, "Failed to clone grad")
    41  	}
    42  	if zero {
    43  		cached.Value = ZeroValue(cached.Value)
    44  		cached.d = ZeroValue(cached.d)
    45  	}
    46  	return
    47  }
    48  
    49  func extractWeightGrad(n ValueGrad) (weights, grad Value, err error) {
    50  	weights = n.Value()
    51  	if grad, err = n.Grad(); err != nil {
    52  		if nm, ok := n.(Namer); ok {
    53  			return weights, nil, errors.Wrapf(err, "No Grad found for %v", nm.Name())
    54  		}
    55  		return weights, nil, errors.Wrap(err, "No Grad found")
    56  	}
    57  	return
    58  }
    59  
    60  // SolverOpt is a function that provides construction options for a Solver
    61  type SolverOpt func(s Solver)
    62  
    63  // WithL2Reg adds a L2 regularization parameter to the solver. By default, the solvers do not use any regularization param
    64  func WithL2Reg(l2reg float64) SolverOpt {
    65  	f := func(s Solver) {
    66  		switch st := s.(type) {
    67  		case *RMSPropSolver:
    68  			st.l2reg = l2reg
    69  			st.useL2Reg = true
    70  		case *AdamSolver:
    71  			st.l2reg = l2reg
    72  			st.useL2Reg = true
    73  		case *VanillaSolver:
    74  			st.l2reg = l2reg
    75  			st.useL2Reg = true
    76  		case *Momentum:
    77  			st.l2reg = l2reg
    78  			st.useL2Reg = true
    79  		}
    80  	}
    81  	return f
    82  }
    83  
    84  // WithL1Reg adds a L1 regularization parameter to the solver. By default, the solvers do not use any regularization param
    85  func WithL1Reg(l1reg float64) SolverOpt {
    86  	f := func(s Solver) {
    87  		switch st := s.(type) {
    88  		case *AdamSolver:
    89  			st.l1reg = l1reg
    90  			st.useL1Reg = true
    91  		case *VanillaSolver:
    92  			st.l1reg = l1reg
    93  			st.useL1Reg = true
    94  		case *Momentum:
    95  			st.l1reg = l1reg
    96  			st.useL1Reg = true
    97  		}
    98  	}
    99  	return f
   100  }
   101  
   102  // WithBatchSize sets the batch size for the solver. Currently only Adam and Vanilla (basic SGD) has batch size support
   103  func WithBatchSize(batch float64) SolverOpt {
   104  	f := func(s Solver) {
   105  		switch st := s.(type) {
   106  		case *AdamSolver:
   107  			st.batch = batch
   108  		case *VanillaSolver:
   109  			st.batch = batch
   110  		case *Momentum:
   111  			st.batch = batch
   112  		}
   113  	}
   114  	return f
   115  }
   116  
   117  // WithEps sets the smoothing factor for the solver.
   118  func WithEps(eps float64) SolverOpt {
   119  	f := func(s Solver) {
   120  		switch st := s.(type) {
   121  		case *RMSPropSolver:
   122  			st.eps = eps
   123  		case *AdamSolver:
   124  			st.eps = eps
   125  		}
   126  	}
   127  	return f
   128  }
   129  
   130  // WithClip clips the gradient if it gets too crazy. By default all solvers do not have any clips attached
   131  func WithClip(clip float64) SolverOpt {
   132  	f := func(s Solver) {
   133  		switch st := s.(type) {
   134  		case *RMSPropSolver:
   135  			st.clip = clip
   136  			st.useClip = true
   137  		case *AdamSolver:
   138  			st.clip = clip
   139  			st.useClip = true
   140  		case *VanillaSolver:
   141  			st.clip = clip
   142  			st.useClip = true
   143  		case *BarzilaiBorweinSolver:
   144  			st.clip = clip
   145  			st.useClip = true
   146  		case *Momentum:
   147  			st.clip = clip
   148  			st.useClip = true
   149  		}
   150  	}
   151  	return f
   152  }
   153  
   154  // WithLearnRate sets the learn rate or step size for the solver.
   155  func WithLearnRate(eta float64) SolverOpt {
   156  	f := func(s Solver) {
   157  		switch st := s.(type) {
   158  		case *RMSPropSolver:
   159  			st.eta = eta
   160  		case *AdamSolver:
   161  			st.eta = eta
   162  		case *VanillaSolver:
   163  			st.eta = eta
   164  		case *BarzilaiBorweinSolver:
   165  			st.eta = eta
   166  		case *Momentum:
   167  			st.eta = eta
   168  		}
   169  	}
   170  	return f
   171  }
   172  
   173  // WithBeta1 sets the beta1 param of the solver. Only works with Adam
   174  func WithBeta1(beta1 float64) SolverOpt {
   175  	f := func(s Solver) {
   176  		switch st := s.(type) {
   177  		case *AdamSolver:
   178  			st.beta1 = beta1
   179  		}
   180  	}
   181  	return f
   182  }
   183  
   184  // WithBeta2 sets the beta1 param of the solver. Only works with Adam
   185  func WithBeta2(beta2 float64) SolverOpt {
   186  	f := func(s Solver) {
   187  		switch st := s.(type) {
   188  		case *AdamSolver:
   189  			st.beta2 = beta2
   190  		}
   191  	}
   192  	return f
   193  }
   194  
   195  // WithRho sets the decay parameter of the RMSProp solver
   196  func WithRho(rho float64) SolverOpt {
   197  	f := func(s Solver) {
   198  		switch st := s.(type) {
   199  		case *RMSPropSolver:
   200  			st.decay = rho
   201  		}
   202  	}
   203  	return f
   204  }
   205  
   206  // WithMomentum sets the momentum of the solver. It is a no-op is the solver's type is not Momentum
   207  func WithMomentum(momentum float64) SolverOpt {
   208  	f := func(s Solver) {
   209  		switch st := s.(type) {
   210  		case *Momentum:
   211  			st.momentum = momentum
   212  		}
   213  	}
   214  	return f
   215  }
   216  
   217  // RMSPropSolver is a solver that implements Geoffrey Hinton's RMSProp gradient descent optimization algorithm.
   218  // http://www.cs.toronto.edu/~tijmen/csc321/slides/lecture_slides_lec6.pdf
   219  type RMSPropSolver struct {
   220  	decay float64 // decay rate/rho
   221  	eps   float64 // smoothing factor
   222  	l2reg float64 // l2 regularization
   223  	clip  float64 // clip value
   224  	eta   float64 // learn rate
   225  
   226  	useClip, useL2Reg bool
   227  
   228  	// unsettable
   229  	cache []*dualValue
   230  }
   231  
   232  // NewRMSPropSolver creates an RMSProp solver with these default values:
   233  //		eta (learn rate)	  : 0.001
   234  //		eps (smoothing factor): 1e-8
   235  //		rho (decay factor)    : 0.999
   236  func NewRMSPropSolver(opts ...SolverOpt) *RMSPropSolver {
   237  	s := &RMSPropSolver{
   238  		decay: 0.999,
   239  		eps:   1e-8,
   240  		eta:   0.001,
   241  	}
   242  
   243  	for _, opt := range opts {
   244  		opt(s)
   245  	}
   246  	return s
   247  }
   248  
   249  // Step steps through each node in the model and applies the RMSProp gradient descent algorithm on the value.
   250  //
   251  // This function will error out if the nodes do not have an associated Grad value.
   252  func (s *RMSPropSolver) Step(model []ValueGrad) (err error) {
   253  	if s.cache == nil {
   254  		s.cache = make([]*dualValue, len(model))
   255  	}
   256  
   257  	for i, n := range model {
   258  		var weights, grad Value
   259  		if weights, grad, err = extractWeightGrad(n); err != nil {
   260  			return err
   261  		}
   262  
   263  		var cached *dualValue
   264  		if cached = s.cache[i]; cached == nil {
   265  			if cached, err = newCachedDV(n, weights, grad, true); err != nil {
   266  				return err
   267  			}
   268  			s.cache[i] = cached
   269  		}
   270  
   271  		cv := cached.Value
   272  		// cw = cw*decay + (1-decay) * grad²
   273  		switch cw := cv.(type) {
   274  		case *tensor.Dense:
   275  			var gt, gt2, w, regularized tensor.Tensor
   276  			var decay, omdecay, stepSize, eps, l2reg, clip, negClip interface{}
   277  			switch cw.Dtype() {
   278  			case tensor.Float64:
   279  				decay = s.decay
   280  				omdecay = 1.0 - s.decay
   281  				stepSize = -s.eta
   282  				eps = s.eps
   283  				l2reg = s.l2reg
   284  				clip = s.clip
   285  				negClip = -s.clip
   286  			case tensor.Float32:
   287  				decay = float32(s.decay)
   288  				omdecay = float32(1.0 - s.decay)
   289  				stepSize = float32(-s.eta)
   290  				eps = float32(s.eps)
   291  				l2reg = float32(s.l2reg)
   292  				clip = float32(s.clip)
   293  				negClip = float32(-s.clip)
   294  			}
   295  
   296  			gt = grad.(tensor.Tensor)
   297  			if gt2, err = tensor.Square(gt); err != nil {
   298  				return errors.Wrap(err, pointWiseSquareFail)
   299  			}
   300  			tensor.Mul(cw, decay, tensor.UseUnsafe())
   301  			tensor.Mul(gt2, omdecay, tensor.UseUnsafe())
   302  			tensor.Add(cw, gt2, tensor.UseUnsafe())
   303  			defer returnTensor(gt2)
   304  
   305  			if s.useClip {
   306  				if _, err = tensor.Clamp(gt, negClip, clip, tensor.UseUnsafe()); err != nil {
   307  					return errors.Wrap(err, clampFail)
   308  				}
   309  			}
   310  
   311  			// regularize
   312  			var upd tensor.Tensor
   313  			if upd, err = tensor.Add(cw, eps); err != nil {
   314  				return errors.Wrap(err, "Failed to carry Add()")
   315  			}
   316  
   317  			if _, err = tensor.InvSqrt(upd, tensor.UseUnsafe()); err != nil {
   318  				return errors.Wrap(err, invSqrtFail)
   319  			}
   320  			if _, err = tensor.Mul(gt, stepSize, tensor.UseUnsafe()); err != nil {
   321  				return errors.Wrap(err, pointWiseMulFail)
   322  			}
   323  			if _, err = tensor.Mul(upd, gt, tensor.UseUnsafe()); err != nil {
   324  				return errors.Wrap(err, pointWiseMulFail)
   325  			}
   326  
   327  			// update
   328  			w = weights.(*tensor.Dense)
   329  			if s.useL2Reg {
   330  				if regularized, err = tensor.Mul(w, l2reg); err != nil {
   331  					return errors.Wrap(err, pointWiseMulFail)
   332  				}
   333  				if _, err = tensor.Sub(upd, regularized, tensor.UseUnsafe()); err != nil {
   334  					return errors.Wrap(err, subFail)
   335  				}
   336  				defer returnTensor(regularized)
   337  			}
   338  
   339  			if _, err = tensor.Add(w, upd, tensor.UseUnsafe()); err != nil {
   340  				return errors.Wrap(err, addFail)
   341  			}
   342  			defer returnTensor(upd)
   343  
   344  			// zero all
   345  			gt.Zero()
   346  
   347  		case *F32:
   348  			decay := float32(s.decay)
   349  			omdecay := float32(1.0 - s.decay)
   350  			stepSize := float32(s.eta)
   351  			eps := float32(s.eps)
   352  			l2reg := float32(s.l2reg)
   353  
   354  			gs := grad.(*F32).any()
   355  			c := cw.any()
   356  			c = c*decay + omdecay*gs*gs
   357  
   358  			cached.Value, _ = anyToScalar(c)
   359  
   360  			w := weights.(*F32).any()
   361  			upd := -stepSize*gs/math32.Sqrt(c+eps) - l2reg*w
   362  			w += upd
   363  
   364  			// because scalar values are copies, and not pointers, we have to actually re-update the dualValu in model[i]
   365  			*(weights.(*F32)) = F32(w)
   366  			*(grad.(*F32)) = F32(0.0)
   367  		case *F64:
   368  			decay := s.decay
   369  			omdecay := 1.0 - s.decay
   370  			stepSize := s.eta
   371  			eps := s.eps
   372  			l2reg := s.l2reg
   373  
   374  			gs := grad.(*F64).any()
   375  			c := cw.any()
   376  			c = c*decay + omdecay*gs*gs
   377  
   378  			cached.Value, _ = anyToScalar(c)
   379  
   380  			w := weights.(*F64).any()
   381  			upd := -stepSize*gs/math.Sqrt(c+eps) - l2reg*w
   382  			w += upd
   383  
   384  			// because scalar values are copies, and not pointers, we have to actually re-update the dualValu in model[i]
   385  			*(weights.(*F64)) = F64(w)
   386  			*(grad.(*F64)) = F64(0.0)
   387  		default:
   388  		}
   389  		solverLogf("AFTER %1.1s", n)
   390  	}
   391  	return nil
   392  }
   393  
   394  // AdamSolver is the Adaptive Moment Estimation solver (basically RMSProp on steroids).
   395  // Paper: http://arxiv.org/abs/1412.6980
   396  //
   397  // We overload the purpose of existing data structure of a *dualValue. However, instead of just holding a value and its derivative,
   398  // the cache's *dualValues hold the Means of gradients (in .Value) and the variances of the gradients (in .d)
   399  type AdamSolver struct {
   400  	eta   float64 // learn rate
   401  	eps   float64 // smoothing
   402  	beta1 float64 // modifier for means
   403  	beta2 float64 // modifier for variances
   404  	clip  float64 // clip gradients
   405  	l1reg float64 // l1 regularization parameter
   406  	l2reg float64 // l2 regularization parameter
   407  	batch float64 // batch size
   408  
   409  	useClip, useL1Reg, useL2Reg bool
   410  
   411  	// unsettable
   412  	iter  int
   413  	cache []*dualValue
   414  }
   415  
   416  // NewAdamSolver creates an Adam solver with these default values:
   417  //		eta (learn rate)	  	: 0.001
   418  //		eps (smoothing factor)		: 1e-8
   419  //		beta1				: 0.9
   420  //		beta2 				: 0.999
   421  //		batch				: 1
   422  func NewAdamSolver(opts ...SolverOpt) *AdamSolver {
   423  	s := &AdamSolver{
   424  		eta:   0.001,
   425  		eps:   1e-8,
   426  		beta1: 0.9,
   427  		beta2: 0.999,
   428  		batch: 1,
   429  	}
   430  
   431  	for _, opt := range opts {
   432  		opt(s)
   433  	}
   434  	return s
   435  }
   436  
   437  // Step steps through each node in the model and applies the Adaptive Moment Estimation gradient descent algorithm on the value.
   438  //
   439  // This function will error out if the nodes do not have an associated Grad value.
   440  func (s *AdamSolver) Step(model []ValueGrad) (err error) {
   441  	if s.cache == nil {
   442  		s.cache = make([]*dualValue, len(model))
   443  	}
   444  
   445  	s.iter++
   446  	correction1 := (1 - math.Pow(s.beta1, float64(s.iter)))
   447  	correction2 := (1 - math.Pow(s.beta2, float64(s.iter)))
   448  
   449  	for i, n := range model {
   450  		var weights, grad Value
   451  		if weights, grad, err = extractWeightGrad(n); err != nil {
   452  			return err
   453  		}
   454  
   455  		var cached *dualValue
   456  		if cached = s.cache[i]; cached == nil {
   457  			if cached, err = newCachedDV(n, weights, grad, true); err != nil {
   458  				return err
   459  			}
   460  			s.cache[i] = cached
   461  		}
   462  
   463  		cvm := cached.Value // means of gradients
   464  		cvv := cached.d     // variances of gradients
   465  
   466  		switch m := cvm.(type) {
   467  		case *tensor.Dense:
   468  			g := grad.(*tensor.Dense)
   469  			w := weights.(*tensor.Dense)
   470  			v := cvv.(*tensor.Dense)
   471  
   472  			var l1reg, l2reg, clip, negClip, beta1, beta2, omβ1, omβ2, eps, eta, onePerBatch interface{}
   473  			var correctionV1, correctionV2 interface{}
   474  			switch m.Dtype() {
   475  			case tensor.Float64:
   476  				l1reg = s.l1reg
   477  				l2reg = s.l2reg
   478  				clip = s.clip
   479  				negClip = -s.clip
   480  				beta1 = s.beta1
   481  				beta2 = s.beta2
   482  				omβ1 = float64(1) - s.beta1
   483  				omβ2 = float64(1) - s.beta2
   484  				eps = s.eps
   485  				eta = -s.eta
   486  				onePerBatch = float64(1) / s.batch
   487  				correctionV1 = float64(1) / float64(correction1)
   488  				correctionV2 = float64(1) / float64(correction2)
   489  			case tensor.Float32:
   490  				l1reg = float32(s.l1reg)
   491  				l2reg = float32(s.l2reg)
   492  				clip = float32(s.clip)
   493  				negClip = -float32(s.clip)
   494  				beta1 = float32(s.beta1)
   495  				beta2 = float32(s.beta2)
   496  				omβ1 = float32(1) - float32(s.beta1)
   497  				omβ2 = float32(1) - float32(s.beta2)
   498  				eps = float32(s.eps)
   499  				eta = -float32(s.eta)
   500  				onePerBatch = float32(1) / float32(s.batch)
   501  				correctionV1 = float32(1) / float32(correction1)
   502  				correctionV2 = float32(1) / float32(correction2)
   503  			}
   504  
   505  			// prep the regularization of gradients
   506  			if s.useL1Reg {
   507  				var l1regs tensor.Tensor
   508  				if l1regs, err = tensor.Sign(w); err != nil {
   509  					errors.Wrap(err, signFail)
   510  				}
   511  				if l1regs, err = tensor.Mul(l1reg, l1regs, tensor.UseUnsafe()); err != nil {
   512  					return errors.Wrap(err, pointWiseMulFail)
   513  				}
   514  				if _, err = tensor.Add(g, l1regs, tensor.UseUnsafe()); err != nil {
   515  					return errors.Wrap(err, addFail)
   516  				}
   517  				defer returnTensor(l1regs)
   518  			}
   519  
   520  			if s.useL2Reg {
   521  				var l2regs tensor.Tensor
   522  				if l2regs, err = tensor.Mul(w, l2reg); err != nil {
   523  					return errors.Wrap(err, pointWiseMulFail)
   524  				}
   525  
   526  				if _, err = tensor.Add(g, l2regs, tensor.UseUnsafe()); err != nil {
   527  					return errors.Wrap(err, addFail)
   528  				}
   529  
   530  				defer returnTensor(l2regs)
   531  			}
   532  
   533  			if s.batch > 1 {
   534  				if _, err = tensor.Mul(g, onePerBatch, tensor.UseUnsafe()); err != nil {
   535  					return errors.Wrap(err, pointWiseMulFail)
   536  				}
   537  			}
   538  
   539  			if s.useClip && s.clip > 0 {
   540  				if _, err = tensor.Clamp(g, negClip, clip, tensor.UseUnsafe()); err != nil {
   541  					return errors.Wrap(err, clampFail)
   542  				}
   543  			}
   544  
   545  			// prep done. Now let's apply the formula:
   546  			// the formula is
   547  			//		(β_1 * m_t-1) + (1 - β_1)g_t ..................	1
   548  			//		(β_2 * v_t-1) + (1 - β_2)*(g_t)² .............	2
   549  
   550  			// equation(1)
   551  			t1 := g.Clone().(*tensor.Dense)
   552  			if _, err = tensor.Mul(t1, omβ1, tensor.UseUnsafe()); err != nil {
   553  				return errors.Wrap(err, pointWiseMulFail)
   554  			}
   555  
   556  			// equation(2)
   557  			if _, err = tensor.Mul(g, g, tensor.UseUnsafe()); err != nil {
   558  				return errors.Wrap(err, pointWiseMulFail)
   559  			}
   560  			if _, err = tensor.Mul(g, omβ2, tensor.UseUnsafe()); err != nil {
   561  				return errors.Wrap(err, pointWiseMulFail)
   562  			}
   563  
   564  			// equation (1)
   565  			if _, err = tensor.Mul(m, beta1, tensor.WithIncr(t1)); err != nil {
   566  				return errors.Wrap(err, pointWiseMulFail)
   567  			}
   568  
   569  			// equation (2)
   570  			if _, err = tensor.Mul(v, beta2, tensor.WithIncr(g)); err != nil {
   571  				return errors.Wrap(err, pointWiseMulFail)
   572  			}
   573  
   574  			defer returnTensor(m)
   575  			defer returnTensor(v)
   576  			cached.SetValue(t1)
   577  			cached.SetDeriv(g.Clone().(*tensor.Dense))
   578  
   579  			// now deal with the hats
   580  			mHats := t1.Clone().(*tensor.Dense)
   581  			vHats := g.Clone().(*tensor.Dense)
   582  
   583  			if _, err = tensor.Mul(mHats, correctionV1, tensor.UseUnsafe()); err != nil {
   584  				return errors.Wrap(err, pointWiseMulFail)
   585  			}
   586  
   587  			if _, err = tensor.Mul(vHats, correctionV2, tensor.UseUnsafe()); err != nil {
   588  				return errors.Wrap(err, pointWiseMulFail)
   589  			}
   590  
   591  			// update := -eta * mHat / (sqrt(vHat) + epsilon)
   592  			if _, err = tensor.Sqrt(vHats, tensor.UseUnsafe()); err != nil {
   593  				return // TODO: rewrite this to use InvSqrt
   594  			}
   595  
   596  			if _, err = tensor.Add(vHats, eps, tensor.UseUnsafe()); err != nil {
   597  				return
   598  			}
   599  
   600  			if _, err = tensor.Mul(mHats, eta, tensor.UseUnsafe()); err != nil {
   601  				return errors.Wrap(err, pointWiseMulFail)
   602  			}
   603  
   604  			if _, err = tensor.Div(mHats, vHats, tensor.WithIncr(w)); err != nil {
   605  				return
   606  			}
   607  
   608  			defer returnTensor(vHats)
   609  			defer returnTensor(mHats)
   610  
   611  			if _, err = tensor.Add(w, mHats, tensor.UseUnsafe()); err != nil {
   612  				return errors.Wrap(err, addFail)
   613  			}
   614  
   615  			g.Zero()
   616  
   617  		case *F32:
   618  			g := grad.(*F32).any()
   619  			w := weights.(*F32).any()
   620  			v := cvv.(*F32).any()
   621  			mm := m.any()
   622  
   623  			l1reg := float32(s.l1reg)
   624  			l2reg := float32(s.l2reg)
   625  			batch := float32(s.batch)
   626  			clip := float32(s.clip)
   627  			beta1 := float32(s.beta1)
   628  			beta2 := float32(s.beta2)
   629  			eps := float32(s.eps)
   630  			eta := float32(s.eta)
   631  
   632  			if s.useL1Reg {
   633  				if w < 0 {
   634  					l1reg = -l1reg
   635  				}
   636  				g += l1reg
   637  			}
   638  
   639  			if s.useL2Reg {
   640  				l2reg *= w
   641  				g += l2reg
   642  			}
   643  
   644  			if batch > 1 {
   645  				g *= (1 / batch)
   646  			}
   647  
   648  			if s.useClip {
   649  				if g > clip {
   650  					g = clip
   651  				} else if g < -clip {
   652  					g = -clip
   653  				}
   654  			}
   655  
   656  			newM := (beta1 * mm) + (1-beta1)*g
   657  			newV := (beta2 * v) + (1-beta2)*g*g
   658  
   659  			cached.Value, _ = anyToScalar(newM)
   660  			cached.d, _ = anyToScalar(newV)
   661  
   662  			mHat := (1 / float32(correction1)) * newM
   663  			vHat := (1 / float32(correction2)) * newV
   664  
   665  			upd := -eta * mHat / (float32(math.Sqrt(float64(vHat))) + eps)
   666  			w += upd
   667  
   668  			*(weights.(*F32)) = F32(w)
   669  			*(grad.(*F32)) = F32(0.0)
   670  		case *F64:
   671  			g := grad.(*F64).any()
   672  			w := weights.(*F64).any()
   673  			v := cvv.(*F64).any()
   674  			mm := m.any()
   675  
   676  			l1reg := s.l1reg
   677  			l2reg := s.l2reg
   678  			batch := s.batch
   679  			clip := s.clip
   680  			beta1 := s.beta1
   681  			beta2 := s.beta2
   682  			eps := s.eps
   683  			eta := s.eta
   684  
   685  			if s.useL1Reg {
   686  				if w < 0 {
   687  					l1reg = -l1reg
   688  				}
   689  				g += l1reg
   690  			}
   691  
   692  			if s.useL2Reg {
   693  				l2reg *= w
   694  				g += l2reg
   695  			}
   696  
   697  			if batch > 1 {
   698  				g *= (1 / batch)
   699  			}
   700  
   701  			if s.useClip {
   702  				if g > clip {
   703  					g = clip
   704  				} else if g < -clip {
   705  					g = -clip
   706  				}
   707  			}
   708  
   709  			newM := (beta1 * mm) + (1-beta1)*g
   710  			newV := (beta2 * v) + (1-beta2)*g*g
   711  
   712  			cached.Value, _ = anyToScalar(newM)
   713  			cached.d, _ = anyToScalar(newV)
   714  
   715  			mHat := (1 / correction1) * newM
   716  			vHat := (1 / correction2) * newV
   717  
   718  			upd := -eta * mHat / (math.Sqrt(vHat) + eps)
   719  			w += upd
   720  
   721  			*(weights.(*F64)) = F64(w)
   722  			*(grad.(*F64)) = F64(0.0)
   723  
   724  		default:
   725  			err = errors.Errorf(nyiTypeFail, "AdamSolver", cvm)
   726  			return
   727  		}
   728  
   729  	}
   730  	return
   731  }
   732  
   // VanillaSolver is plain stochastic gradient descent: no momentum and no
// adaptive step sizes. It supports optional clipping, L1/L2 regularization and
// batch-size scaling of the gradient.
type VanillaSolver struct {
	eta   float64 // learning rate
	clip  float64 // gradient clipping bound, used when useClip is set
	l1reg float64 // L1 regularization coefficient
	l2reg float64 // L2 regularization coefficient
	batch float64 // batch size

	useClip  bool
	useL1Reg bool
	useL2Reg bool
}
   743  
   744  // NewVanillaSolver creates a new VanillaSolver with sane-ish default values
   745  func NewVanillaSolver(opts ...SolverOpt) *VanillaSolver {
   746  	s := &VanillaSolver{
   747  		batch: 1,
   748  		eta:   0.001,
   749  	}
   750  	for _, opt := range opts {
   751  		opt(s)
   752  	}
   753  	return s
   754  }
   755  
   756  // Step steps through each node in the model and applies the most basic gradient descent algorithm on the value.
   757  //
   758  // This function will error out if the nodes do not have an associated Grad value.
   759  func (s *VanillaSolver) Step(model []ValueGrad) (err error) {
   760  	for _, n := range model {
   761  		var weights, grad Value
   762  		if weights, grad, err = extractWeightGrad(n); err != nil {
   763  			return err
   764  		}
   765  		switch w := weights.(type) {
   766  		case *tensor.Dense:
   767  			g := grad.(*tensor.Dense)
   768  
   769  			var l1reg, l2reg, clip, negClip, eta interface{}
   770  			var onePerBatch interface{}
   771  			switch w.Dtype() {
   772  			case tensor.Float64:
   773  				l1reg = s.l1reg
   774  				l2reg = s.l2reg
   775  				clip = s.clip
   776  				negClip = -s.clip
   777  				eta = -s.eta
   778  				onePerBatch = float64(1) / s.batch
   779  			case tensor.Float32:
   780  				l1reg = float32(s.l1reg)
   781  				l2reg = float32(s.l2reg)
   782  				clip = float32(s.clip)
   783  				negClip = float32(-s.clip)
   784  				eta = float32(-s.eta)
   785  				onePerBatch = float32(1) / float32(s.batch)
   786  			}
   787  			// prep the regularization of gradients
   788  			var l1regs, l2regs tensor.Tensor
   789  			if s.useL1Reg {
   790  				if l1regs, err = tensor.Sign(w); err != nil {
   791  					return errors.Wrap(err, signFail)
   792  				}
   793  
   794  				if l1regs, err = tensor.Mul(l1reg, l1regs, tensor.UseUnsafe()); err != nil {
   795  					return errors.Wrap(err, pointWiseMulFail)
   796  				}
   797  
   798  				if _, err = tensor.Add(g, l1regs, tensor.UseUnsafe()); err != nil {
   799  					return errors.Wrap(err, addFail)
   800  				}
   801  
   802  				defer returnTensor(l1regs)
   803  			}
   804  
   805  			if s.useL2Reg {
   806  				if l2regs, err = tensor.Mul(w, l2reg); err != nil {
   807  					return errors.Wrap(err, pointWiseMulFail)
   808  				}
   809  
   810  				if _, err = tensor.Add(g, l2regs, tensor.UseUnsafe()); err != nil {
   811  					return errors.Wrap(err, addFail)
   812  				}
   813  
   814  				defer returnTensor(l2regs)
   815  			}
   816  
   817  			if s.batch > 1 {
   818  				if _, err = tensor.Mul(g, onePerBatch, tensor.UseUnsafe()); err != nil {
   819  					return errors.Wrap(err, pointWiseMulFail)
   820  				}
   821  			}
   822  
   823  			if s.useClip && s.clip > 0 {
   824  				if _, err = tensor.Clamp(g, negClip, clip, tensor.UseUnsafe()); err != nil {
   825  					return errors.Wrap(err, clampFail)
   826  				}
   827  			}
   828  
   829  			if _, err = tensor.Mul(g, eta, tensor.UseUnsafe()); err != nil {
   830  				return errors.Wrap(err, pointWiseMulFail)
   831  			}
   832  
   833  			if _, err = tensor.Add(w, g, tensor.UseUnsafe()); err != nil {
   834  				return errors.Wrap(err, addFail)
   835  			}
   836  
   837  			g.Zero()
   838  
   839  		case *F32:
   840  			g := grad.(*F32).any()
   841  			wv := w.any()
   842  
   843  			l1reg := float32(s.l1reg)
   844  			l2reg := float32(s.l2reg)
   845  			batch := float32(s.batch)
   846  			clip := float32(s.clip)
   847  			eta := float32(s.eta)
   848  
   849  			if s.useL1Reg {
   850  				if wv < 0 {
   851  					l1reg = -l1reg
   852  				}
   853  				g += l1reg
   854  			}
   855  
   856  			if s.useL2Reg {
   857  				l2reg *= wv
   858  				g += l2reg
   859  			}
   860  
   861  			if batch > 1 {
   862  				g *= (1 / batch)
   863  			}
   864  
   865  			if s.useClip {
   866  				if g > clip {
   867  					g = clip
   868  				} else if g < -clip {
   869  					g = -clip
   870  				}
   871  			}
   872  
   873  			upd := -eta * g
   874  			wv += upd
   875  
   876  			*(weights.(*F32)) = F32(wv)
   877  			*(grad.(*F32)) = F32(0.0)
   878  		case *F64:
   879  			g := grad.(*F64).any()
   880  			wv := w.any()
   881  
   882  			l1reg := s.l1reg
   883  			l2reg := s.l2reg
   884  			batch := s.batch
   885  			clip := s.clip
   886  			eta := s.eta
   887  
   888  			if s.useL1Reg {
   889  				if wv < 0 {
   890  					l1reg = -l1reg
   891  				}
   892  				g += l1reg
   893  			}
   894  
   895  			if s.useL2Reg {
   896  				l2reg *= wv
   897  				g += l2reg
   898  			}
   899  
   900  			if batch > 1 {
   901  				g *= (1 / batch)
   902  			}
   903  
   904  			if s.useClip {
   905  				if g > clip {
   906  					g = clip
   907  				} else if g < -clip {
   908  					g = -clip
   909  				}
   910  			}
   911  
   912  			upd := -eta * g
   913  			wv += upd
   914  
   915  			*(weights.(*F64)) = F64(wv)
   916  			*(grad.(*F64)) = F64(0.0)
   917  		default:
   918  			return errors.Errorf(nyiFail, "VanillaSolver.step", w)
   919  		}
   920  	}
   921  	return
   922  }
   923  
   924  // Momentum is the stochastic gradient descent optimizer with momentum item.
   925  type Momentum struct {
   926  	eta      float64 // learn rate
   927  	momentum float64 // momentum
   928  	clip     float64 // clip gradients
   929  	l1reg    float64 // l1 regularization parameter
   930  	l2reg    float64 // l2 regularization parameter
   931  	batch    float64 // batch size
   932  
   933  	useClip, useL1Reg, useL2Reg bool
   934  
   935  	cache []*dualValue
   936  }
   937  
   938  // NewMomentum creates a new Momentum with sane-ish default values
   939  func NewMomentum(opts ...SolverOpt) *Momentum {
   940  	s := &Momentum{
   941  		batch:    1,
   942  		eta:      0.001,
   943  		momentum: 0.9,
   944  	}
   945  	for _, opt := range opts {
   946  		opt(s)
   947  	}
   948  	return s
   949  }
   950  
   951  // Step steps through each node in the model and applies the Momentum stochastic gradient descent algorithm on the value.
   952  //
   953  // This function will error out if the nodes do not have an associated Grad value.
   954  func (s *Momentum) Step(model []ValueGrad) (err error) {
   955  	if s.cache == nil {
   956  		s.cache = make([]*dualValue, len(model))
   957  	}
   958  
   959  	for i, n := range model {
   960  		var weights, grad Value
   961  		if weights, grad, err = extractWeightGrad(n); err != nil {
   962  			return err
   963  		}
   964  
   965  		var cached *dualValue
   966  		if cached = s.cache[i]; cached == nil {
   967  			if cached, err = newCachedDV(n, weights, grad, true); err != nil {
   968  				return err
   969  			}
   970  			s.cache[i] = cached
   971  		}
   972  
   973  		cv := cached.Value
   974  		// cw = cw * momentum - eta * grad
   975  		// w = w + cw
   976  		switch cw := cv.(type) {
   977  		case *tensor.Dense:
   978  			w := weights.(*tensor.Dense)
   979  			g := grad.(*tensor.Dense)
   980  
   981  			var l1reg, l2reg, clip, negClip, eta, momentum, onePerBatch interface{}
   982  			switch cw.Dtype() {
   983  			case tensor.Float64:
   984  				l1reg = s.l1reg
   985  				l2reg = s.l2reg
   986  				clip = s.clip
   987  				negClip = -s.clip
   988  				eta = -s.eta
   989  				momentum = s.momentum
   990  				onePerBatch = float64(1) / s.batch
   991  			case tensor.Float32:
   992  				l1reg = float32(s.l1reg)
   993  				l2reg = float32(s.l2reg)
   994  				clip = float32(s.clip)
   995  				negClip = float32(-s.clip)
   996  				eta = float32(-s.eta)
   997  				momentum = float32(s.momentum)
   998  				onePerBatch = float32(1) / float32(s.batch)
   999  			}
  1000  
  1001  			// prep the regularization of gradients
  1002  			var l1regs, l2regs tensor.Tensor
  1003  			if s.useL1Reg {
  1004  				if l1regs, err = tensor.Sign(cw); err != nil {
  1005  					return errors.Wrap(err, signFail)
  1006  				}
  1007  
  1008  				if l1regs, err = tensor.Mul(l1reg, l1regs, tensor.UseUnsafe()); err != nil {
  1009  					return errors.Wrap(err, pointWiseMulFail)
  1010  				}
  1011  
  1012  				if _, err = tensor.Add(g, l1regs, tensor.UseUnsafe()); err != nil {
  1013  					return errors.Wrap(err, addFail)
  1014  				}
  1015  
  1016  				defer returnTensor(l1regs)
  1017  			}
  1018  
  1019  			if s.useL2Reg {
  1020  				if l2regs, err = tensor.Mul(cw, l2reg); err != nil {
  1021  					return errors.Wrap(err, pointWiseMulFail)
  1022  				}
  1023  
  1024  				if _, err = tensor.Add(g, l2regs, tensor.UseUnsafe()); err != nil {
  1025  					return errors.Wrap(err, addFail)
  1026  				}
  1027  
  1028  				defer returnTensor(l2regs)
  1029  			}
  1030  
  1031  			if s.batch > 1 {
  1032  				if _, err = tensor.Mul(g, onePerBatch, tensor.UseUnsafe()); err != nil {
  1033  					return errors.Wrap(err, pointWiseMulFail)
  1034  				}
  1035  			}
  1036  
  1037  			if s.useClip && s.clip > 0 {
  1038  				if _, err = tensor.Clamp(g, negClip, clip, tensor.UseUnsafe()); err != nil {
  1039  					return errors.Wrap(err, clampFail)
  1040  				}
  1041  			}
  1042  
  1043  			// momentum
  1044  			if _, err = tensor.Mul(g, eta, tensor.UseUnsafe()); err != nil {
  1045  				return errors.Wrap(err, pointWiseMulFail)
  1046  			}
  1047  
  1048  			// cw * momentum
  1049  			if _, err = tensor.Mul(cw, momentum, tensor.UseUnsafe()); err != nil {
  1050  				return errors.Wrap(err, pointWiseMulFail)
  1051  			}
  1052  
  1053  			//  cw * momentum - eta * grad
  1054  			if _, err = tensor.Add(cw, g, tensor.UseUnsafe()); err != nil {
  1055  				return errors.Wrap(err, pointWiseMulFail)
  1056  			}
  1057  
  1058  			if _, err = tensor.Add(w, cw, tensor.UseUnsafe()); err != nil {
  1059  				return errors.Wrap(err, addFail)
  1060  			}
  1061  
  1062  			g.Zero()
  1063  
  1064  		case *F32:
  1065  			l1reg := float32(s.l1reg)
  1066  			l2reg := float32(s.l2reg)
  1067  			batch := float32(s.batch)
  1068  			clip := float32(s.clip)
  1069  			eta := float32(s.eta)
  1070  			momentum := float32(s.momentum)
  1071  
  1072  			g := grad.(*F32).any()
  1073  			w := weights.(*F32).any()
  1074  			c := cw.any()
  1075  
  1076  			if s.useL1Reg {
  1077  				if w < 0 {
  1078  					l1reg = -l1reg
  1079  				}
  1080  				g += l1reg
  1081  			}
  1082  
  1083  			if s.useL2Reg {
  1084  				l2reg *= w
  1085  				g += l2reg
  1086  			}
  1087  
  1088  			if batch > 1 {
  1089  				g *= (1 / batch)
  1090  			}
  1091  
  1092  			if s.useClip {
  1093  				if g > clip {
  1094  					g = clip
  1095  				} else if g < -clip {
  1096  					g = -clip
  1097  				}
  1098  			}
  1099  
  1100  			c = c*momentum - eta*g
  1101  			w += c
  1102  
  1103  			*(weights.(*F32)) = F32(w)
  1104  			*(grad.(*F32)) = F32(0.0)
  1105  		case *F64:
  1106  			l1reg := s.l1reg
  1107  			l2reg := s.l2reg
  1108  			batch := s.batch
  1109  			clip := s.clip
  1110  			eta := s.eta
  1111  			momentum := s.momentum
  1112  
  1113  			g := grad.(*F64).any()
  1114  			w := weights.(*F64).any()
  1115  			c := cw.any()
  1116  
  1117  			if s.useL1Reg {
  1118  				if w < 0 {
  1119  					l1reg = -l1reg
  1120  				}
  1121  				g += l1reg
  1122  			}
  1123  
  1124  			if s.useL2Reg {
  1125  				l2reg *= w
  1126  				g += l2reg
  1127  			}
  1128  
  1129  			if batch > 1 {
  1130  				g *= (1 / batch)
  1131  			}
  1132  
  1133  			if s.useClip {
  1134  				if g > clip {
  1135  					g = clip
  1136  				} else if g < -clip {
  1137  					g = -clip
  1138  				}
  1139  			}
  1140  
  1141  			c = c*momentum - eta*g
  1142  			w += c
  1143  
  1144  			*(weights.(*F64)) = F64(w)
  1145  			*(grad.(*F64)) = F64(0.0)
  1146  		default:
  1147  			return errors.Errorf(nyiFail, "Momentum.step", cv)
  1148  		}
  1149  	}
  1150  	return
  1151  }
  1152  
// AdaGradSolver is the solver that does adaptive gradient descent. Read the paper: http://jmlr.org/papers/v12/duchi11a.html
//
// For every node, Step keeps a running sum of squared gradients and scales each update by
// 1/sqrt(sum + eps). NewAdaGradSolver supplies default values for eta and eps; the zero value
// has eta == 0 and eps == 0.
type AdaGradSolver struct {
	eta   float64 // learn rate
	eps   float64 // smoothing factor
	l1Reg float64 // l1reg param. NOTE(review): not read by Step and there is no useL1Reg flag — confirm whether L1 is meant to be supported
	l2reg float64 // l2reg param
	clip  float64 // clip at

	// useL2Reg makes Step subtract l2reg*w from each update; useClip makes Step clip the
	// gradient to [-clip, clip] (after it has been added to the squared-gradient sum).
	useL2Reg, useClip bool

	// cache holds per-node state for Step, indexed like the model slice passed to Step.
	// Each entry's Value is that node's running sum of squared gradients.
	cache []*dualValue
}
  1165  
  1166  // NewAdaGradSolver creates a new AdaGradSolver with sane-ish default values
  1167  func NewAdaGradSolver(opts ...SolverOpt) *AdaGradSolver {
  1168  	s := &AdaGradSolver{
  1169  		eta: 0.001,
  1170  		eps: 1e-8,
  1171  	}
  1172  
  1173  	for _, opt := range opts {
  1174  		opt(s)
  1175  	}
  1176  	return s
  1177  }
  1178  
  1179  // Step steps through each node in the model and applies the Adaptive Gradient gradient descent algorithm on the value.
  1180  //
  1181  // This function will error out if the nodes do not have an associated Grad value.
  1182  func (s *AdaGradSolver) Step(model []ValueGrad) (err error) {
  1183  	if s.cache == nil {
  1184  		s.cache = make([]*dualValue, len(model))
  1185  	}
  1186  
  1187  	for i, n := range model {
  1188  		var weights, grad Value
  1189  		if weights, grad, err = extractWeightGrad(n); err != nil {
  1190  			return err
  1191  		}
  1192  
  1193  		var cached *dualValue
  1194  		if cached = s.cache[i]; cached == nil {
  1195  			if cached, err = newCachedDV(n, weights, grad, true); err != nil {
  1196  				return err
  1197  			}
  1198  			s.cache[i] = cached
  1199  		}
  1200  
  1201  		cv := cached.Value
  1202  
  1203  		switch cw := cv.(type) {
  1204  		case *tensor.Dense:
  1205  			var w, g, c, g2, regularized tensor.Tensor
  1206  
  1207  			var l2reg, clip, negClip, eps, eta interface{}
  1208  			switch cw.Dtype() {
  1209  			case tensor.Float64:
  1210  				l2reg = s.l2reg
  1211  				clip = s.clip
  1212  				negClip = -s.clip
  1213  				eps = s.eps
  1214  				eta = -s.eta
  1215  			case tensor.Float32:
  1216  				l2reg = float32(s.l2reg)
  1217  				clip = float32(s.clip)
  1218  				negClip = float32(-s.clip)
  1219  				eps = float32(s.eps)
  1220  				eta = float32(-s.eta)
  1221  			}
  1222  
  1223  			g = grad.(*tensor.Dense)
  1224  			if g2, err = tensor.Square(g); err != nil {
  1225  				return errors.Wrap(err, pointWiseSquareFail)
  1226  			}
  1227  
  1228  			c = cw
  1229  			tensor.Add(c, g2, tensor.UseUnsafe())
  1230  			defer returnTensor(g2)
  1231  
  1232  			if s.useClip {
  1233  				if _, err = tensor.Clamp(g, negClip, clip, tensor.UseUnsafe()); err != nil {
  1234  					return errors.Wrap(err, clampFail)
  1235  				}
  1236  			}
  1237  
  1238  			// update
  1239  			var upd tensor.Tensor
  1240  			if upd, err = tensor.Add(c, eps); err != nil {
  1241  				return errors.Wrap(err, addFail)
  1242  			}
  1243  
  1244  			if _, err = tensor.InvSqrt(upd, tensor.UseUnsafe()); err != nil {
  1245  				return errors.Wrap(err, invSqrtFail)
  1246  			}
  1247  			if _, err = tensor.Mul(g, eta, tensor.UseUnsafe()); err != nil {
  1248  				return errors.Wrap(err, pointWiseMulFail)
  1249  			}
  1250  
  1251  			if _, err = tensor.Mul(upd, g, tensor.UseUnsafe()); err != nil {
  1252  				return errors.Wrap(err, pointWiseMulFail)
  1253  			}
  1254  
  1255  			// regularize
  1256  			w = weights.(*tensor.Dense)
  1257  
  1258  			if s.useL2Reg {
  1259  				if regularized, err = tensor.Mul(w, l2reg); err != nil {
  1260  					return errors.Wrap(err, pointWiseMulFail)
  1261  				}
  1262  
  1263  				if _, err = tensor.Sub(upd, regularized, tensor.UseUnsafe()); err != nil {
  1264  					return errors.Wrap(err, subFail)
  1265  				}
  1266  
  1267  				defer returnTensor(regularized)
  1268  			}
  1269  
  1270  			if _, err = tensor.Add(w, upd, tensor.UseUnsafe()); err != nil {
  1271  				return errors.Wrap(err, addFail)
  1272  			}
  1273  			defer returnTensor(upd)
  1274  
  1275  			// zero all
  1276  			g.Zero()
  1277  
  1278  		case *F32:
  1279  			var w, g, c float32
  1280  
  1281  			l2reg := float32(s.l2reg)
  1282  			clip := float32(s.clip)
  1283  			eps := float32(s.eps)
  1284  			eta := float32(s.eta)
  1285  
  1286  			c = cw.any()
  1287  			g = grad.(*F32).any()
  1288  
  1289  			c += g * g
  1290  
  1291  			if s.useClip {
  1292  				if g > clip {
  1293  					g = clip
  1294  				} else if g < -clip {
  1295  					g = -clip
  1296  				}
  1297  			}
  1298  
  1299  			w = weights.(*F32).any()
  1300  
  1301  			upd := -eta * g / math32.Sqrt(c+eps)
  1302  
  1303  			if s.useL2Reg {
  1304  				upd -= w * l2reg
  1305  			}
  1306  
  1307  			w += upd
  1308  
  1309  			// because scalar values are copies, and not pointers, we have to actually re-update the dualValu in model[i]
  1310  			*(weights.(*F32)) = F32(w)
  1311  			*(grad.(*F32)) = F32(0.0)
  1312  		case *F64:
  1313  			var w, g, c float64
  1314  
  1315  			l2reg := s.l2reg
  1316  			clip := s.clip
  1317  			eps := s.eps
  1318  			eta := s.eta
  1319  
  1320  			c = cw.any()
  1321  			g = grad.(*F64).any()
  1322  
  1323  			c += g * g
  1324  
  1325  			if s.useClip {
  1326  				if g > clip {
  1327  					g = clip
  1328  				} else if g < -clip {
  1329  					g = -clip
  1330  				}
  1331  			}
  1332  
  1333  			w = weights.(*F64).any()
  1334  			upd := -eta * g / math.Sqrt(c+eps)
  1335  			if s.useL2Reg {
  1336  				upd -= w * l2reg
  1337  			}
  1338  
  1339  			w += upd
  1340  
  1341  			// because scalar values are copies, and not pointers, we have to actually re-update the dualValu in model[i]
  1342  			*(weights.(*F64)) = F64(w)
  1343  			*(grad.(*F64)) = F64(0.0)
  1344  
  1345  		default:
  1346  			return errors.Errorf(nyiFail, "Adagrad step", cv)
  1347  		}
  1348  
  1349  	}
  1350  
  1351  	return
  1352  }
  1353  
// BarzilaiBorweinSolver / Barzilai-Borwein performs Gradient Descent in the steepest descent direction.
// It seeks a stationary point of F, i.e. it solves 0 = Grad(F)(x), by
//  xᵢ₊₁ = xᵢ - eta * Grad(F)(xᵢ)
// where the learn rate eta is calculated by the Barzilai-Borwein method:
//  eta(xᵢ) = <(xᵢ - xᵢ₋₁), (Grad(F)(xᵢ) - Grad(F)(xᵢ₋₁))> /
//                  ∥(Grad(F)(xᵢ) - Grad(F)(xᵢ₋₁))∥²
// The input learn rate is used for the first iteration.
//
// TODO: Check out stochastic implementations, e.g. "Barzilai-Borwein Step Size for Stochastic Gradient Descent" https://arxiv.org/abs/1605.04131
type BarzilaiBorweinSolver struct {
	eta     float64 // initial learn rate; Step overwrites it with the current Barzilai-Borwein rate
	clip    float64 // clip value: bound on |eta| (the learn rate, not the gradient) when useClip is set
	useClip bool
	prevDV  []*dualValue // dual value for xᵢ₋₁ step: weights in Value, gradient in d; nil until the first Step
}
  1369  
  1370  // NewBarzilaiBorweinSolver creates a new Barzilai-Borwein solver withs some default values:
  1371  // the learn rate is set to 0.001 and the solver does not use clipping.
  1372  func NewBarzilaiBorweinSolver(opts ...SolverOpt) *BarzilaiBorweinSolver {
  1373  	s := &BarzilaiBorweinSolver{
  1374  		eta:     0.001,
  1375  		useClip: false,
  1376  	}
  1377  
  1378  	for _, opt := range opts {
  1379  		opt(s)
  1380  	}
  1381  	return s
  1382  }
  1383  
  1384  // Step steps through each node in the model and applies the Barzilai-Borwein gradient descent algorithm on the value.
  1385  //
  1386  // This function will error out if the nodes do not have an associated Grad value.
  1387  func (s *BarzilaiBorweinSolver) Step(model []ValueGrad) (err error) {
  1388  
  1389  	firstRun := false
  1390  	if s.prevDV == nil {
  1391  		firstRun = true
  1392  		s.prevDV = make([]*dualValue, len(model))
  1393  	}
  1394  
  1395  	// Update the learning rate
  1396  	if false == firstRun {
  1397  		nominator := float64(0.0)
  1398  		denominator := float64(0.0)
  1399  
  1400  		for nodeNr, node := range model {
  1401  			var weights, grad Value
  1402  			if weights, grad, err = extractWeightGrad(node); err != nil {
  1403  				return err
  1404  			}
  1405  
  1406  			switch w := weights.(type) {
  1407  			case *tensor.Dense:
  1408  				g, ok := grad.(*tensor.Dense)
  1409  				if !ok {
  1410  					return errors.Errorf("Expected a *tensor.Dense in %v. Got %T instead", node, grad)
  1411  				}
  1412  
  1413  				wOld, ok := s.prevDV[nodeNr].Value.(*tensor.Dense)
  1414  				if !ok {
  1415  					return errors.Errorf("Expected a *tensor.Dense in %v. Got %T instead", node, s.prevDV[nodeNr].Value)
  1416  				}
  1417  
  1418  				gOld, ok := s.prevDV[nodeNr].d.(*tensor.Dense)
  1419  				if !ok {
  1420  					return errors.Errorf("Expected a *tensor.Dense in %v. Got %T instead", node, s.prevDV[nodeNr].d)
  1421  				}
  1422  
  1423  				valueDiff, err := tensor.Sub(w, wOld)
  1424  				defer returnTensor(valueDiff)
  1425  				if err != nil {
  1426  					return errors.Wrap(err, subFail)
  1427  				}
  1428  
  1429  				gradDiff, err := tensor.Sub(g, gOld)
  1430  				defer returnTensor(gradDiff)
  1431  				if err != nil {
  1432  					return errors.Wrap(err, subFail)
  1433  				}
  1434  
  1435  				// <(xᵢ - xᵢ₋₁), (Grad(F)(xᵢ) - Grad(F)(xᵢ₋₁))>
  1436  
  1437  				// Scalar Product == Total tensor contraction
  1438  				dims := valueDiff.Dims()
  1439  				contractionAxes := make([]int, dims, dims)
  1440  				for axis := 0; axis < len(contractionAxes); axis++ {
  1441  					contractionAxes[axis] = axis
  1442  				}
  1443  
  1444  				valGradDiffscalarProd, err := tensor.Contract(valueDiff, gradDiff, contractionAxes, contractionAxes)
  1445  				if err != nil {
  1446  					return errors.New("operationError, Contracting value / gradient difference")
  1447  				}
  1448  				defer returnTensor(valGradDiffscalarProd)
  1449  
  1450  				nominator += valGradDiffscalarProd.Data().([]float64)[0]
  1451  
  1452  				// ∥(Grad(F)(xᵢ) - Grad(F)(xᵢ₋₁))∥²
  1453  				gradDiffscalarProd, err := tensor.Contract(gradDiff, gradDiff, contractionAxes, contractionAxes)
  1454  				if err != nil {
  1455  					return errors.New("operationError, Contracting value / gradient difference")
  1456  				}
  1457  				defer returnTensor(gradDiffscalarProd)
  1458  
  1459  				denominator += gradDiffscalarProd.Data().([]float64)[0]
  1460  
  1461  			default:
  1462  				return errors.Errorf(nyiFail, "Barizai-Borwein step", w)
  1463  			}
  1464  		}
  1465  
  1466  		s.eta = nominator / denominator
  1467  
  1468  		if s.useClip && (math.Abs(s.eta) > s.clip) {
  1469  			if math.Signbit(s.eta) {
  1470  				s.eta = -s.clip
  1471  			} else {
  1472  				s.eta = s.clip
  1473  			}
  1474  		}
  1475  	}
  1476  
  1477  	// Save this iteration's values for the next run
  1478  	for nodeNr, node := range model {
  1479  		var weights, grad Value
  1480  		if weights, grad, err = extractWeightGrad(node); err != nil {
  1481  			return err
  1482  		}
  1483  
  1484  		if false == firstRun {
  1485  			// return memory for the old dual value used in this iteration
  1486  			returnDV(s.prevDV[nodeNr])
  1487  		}
  1488  		var oldDV *dualValue
  1489  		if oldDV, err = newCachedDV(node, weights, grad, false); err != nil {
  1490  			return err
  1491  		}
  1492  		s.prevDV[nodeNr] = oldDV
  1493  	}
  1494  
  1495  	// Update the weights
  1496  	for _, node := range model {
  1497  		var weights, grad Value
  1498  		if weights, grad, err = extractWeightGrad(node); err != nil {
  1499  			return err
  1500  		}
  1501  
  1502  		switch w := weights.(type) {
  1503  		case *tensor.Dense:
  1504  			g, ok := grad.(*tensor.Dense)
  1505  			if !ok {
  1506  				return errors.Errorf("Expected a *tensor.Dense in %v. Got %T instead", node, grad)
  1507  			}
  1508  
  1509  			upd, err := tensor.Mul(g, s.eta)
  1510  			defer returnTensor(upd)
  1511  
  1512  			if err != nil {
  1513  				return errors.Wrap(err, pointWiseMulFail)
  1514  			}
  1515  
  1516  			if _, err = tensor.Sub(w, upd, tensor.UseUnsafe()); err != nil {
  1517  				return errors.Wrap(err, subFail)
  1518  			}
  1519  
  1520  			g.Zero()
  1521  
  1522  		default:
  1523  			return errors.Errorf(nyiFail, "Barizai-Borwein step", w)
  1524  		}
  1525  	}
  1526  
  1527  	return nil
  1528  }