github.com/xmidt-org/webpa-common@v1.11.9/device/drain/drainer.go

package drain

import (
	"errors"
	"sync"
	"sync/atomic"
	"time"

	"github.com/go-kit/kit/log"
	"github.com/go-kit/kit/log/level"
	"github.com/go-kit/kit/metrics/discard"

	"github.com/xmidt-org/webpa-common/device"
	"github.com/xmidt-org/webpa-common/device/devicegate"
	"github.com/xmidt-org/webpa-common/logging"
	"github.com/xmidt-org/webpa-common/xmetrics"
)

var (
	// ErrActive is returned by Start when a drain job is already running.
	ErrActive error = errors.New("A drain operation is already running")

	// ErrNotActive is returned by Cancel when no drain job is running.
	ErrNotActive error = errors.New("No drain operation is running")
)

const (
	// StateNotActive and StateActive are the two values of drainer.active.
	StateNotActive uint32 = 0
	StateActive    uint32 = 1

	// MetricNotDraining and MetricDraining are the values reported by the state gauge.
	MetricNotDraining float64 = 0.0
	MetricDraining    float64 = 1.0

	// Drained is the close reason given to devices disconnected by a drain.
	Drained = "drained"

	// disconnectBatchSize is the arbitrary size of batches used when no rate is associated with the drain,
	// i.e. disconnect as fast as possible.
	disconnectBatchSize int = 1000
)

// Option is a configuration option for a drainer.
type Option func(*drainer)

// WithLogger configures the drainer's logger, falling back to logging.DefaultLogger() if l is nil.
func WithLogger(l log.Logger) Option {
	return func(dr *drainer) {
		if l != nil {
			dr.logger = l
		} else {
			dr.logger = logging.DefaultLogger()
		}
	}
}

// WithRegistry configures the device.Registry used to enumerate connected devices.  A nil registry causes a panic.
func WithRegistry(r device.Registry) Option {
	if r == nil {
		panic("A device.Registry is required")
	}

	return func(dr *drainer) {
		dr.registry = r
	}
}

// WithConnector configures the device.Connector used to disconnect devices.  A nil connector causes a panic.
func WithConnector(c device.Connector) Option {
	if c == nil {
		panic("A device.Connector is required")
	}

	return func(dr *drainer) {
		dr.connector = c
	}
}

// WithManager configures both the registry and the connector from a single device.Manager.  A nil manager causes a panic.
func WithManager(m device.Manager) Option {
	if m == nil {
		panic("A device.Manager is required")
	}

	return func(dr *drainer) {
		dr.registry = m
		dr.connector = m
	}
}

// WithStateGauge configures the gauge that reports whether a drain is in progress, discarding metrics if s is nil.
func WithStateGauge(s xmetrics.Setter) Option {
	return func(dr *drainer) {
		if s != nil {
			dr.m.state = s
		} else {
			dr.m.state = discard.NewGauge()
		}
	}
}

// WithDrainCounter configures the counter of drained devices, discarding metrics if a is nil.
func WithDrainCounter(a xmetrics.Adder) Option {
	return func(dr *drainer) {
		if a != nil {
			dr.m.counter = a
		} else {
			dr.m.counter = discard.NewCounter()
		}
	}
}

// DrainFilter contains the filter information for a drain job.
type DrainFilter interface {
	device.Filter
	GetFilterRequest() devicegate.FilterRequest
}

// Job describes a single drain operation.
type Job struct {
	// Count is the total number of devices to disconnect.  If this field is nonpositive and Percent is unset,
	// the count of connected devices at the start of job execution is used.  If Percent is set, this field's
	// original value is ignored and it is set to that percentage of total devices connected at the time the
	// job starts.
	Count int `json:"count" schema:"count"`

	// Percent is the percentage (0-100) of connected devices to drain.  If this field is set, Count's
	// original value is ignored and replaced with the computed percentage of connected devices at the
	// time the job starts.
	Percent int `json:"percent,omitempty" schema:"percent"`

	// Rate is the number of devices per tick to disconnect.  If this field is nonpositive,
	// devices are disconnected as fast as possible.
	Rate int `json:"rate,omitempty" schema:"rate"`

	// Tick is the time unit for the Rate field.  If Rate is set but this field is not,
	// a tick of 1 second is used as the default.
	Tick time.Duration `json:"tick,omitempty" schema:"tick"`

	// DrainFilter holds the filter to drain devices by.  If this is set for the job, only devices
	// that match the filter will be drained.
	DrainFilter DrainFilter `json:"filter,omitempty" schema:"filter"`
}
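
// A Job can be constructed directly in Go or decoded from JSON or query
// parameters via the struct tags above.  A minimal sketch (the values are
// illustrative only):
//
//	// Drain 25% of currently connected devices, 100 devices per 5-second tick:
//	j := Job{
//		Percent: 25,
//		Rate:    100,
//		Tick:    5 * time.Second,
//	}
//
// Since Tick is a time.Duration, the equivalent JSON carries nanoseconds:
//
//	{"count": 0, "percent": 25, "rate": 100, "tick": 5000000000}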

// ToMap returns a map representation of this Job appropriate for marshaling to formats like JSON.
// It renders certain fields, such as Tick, in a friendlier form.
func (j Job) ToMap() map[string]interface{} {
	m := map[string]interface{}{
		"count": j.Count,
	}

	if j.Percent > 0 {
		m["percent"] = j.Percent
	}

	if j.Rate > 0 {
		m["rate"] = j.Rate
	}

	if j.Tick > 0 {
		m["tick"] = j.Tick.String()
	}

	if j.DrainFilter != nil {
		m["filter"] = j.DrainFilter.GetFilterRequest()
	}

	return m
}
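
// For the Job sketched above, after normalization against (say) 1000 connected
// devices, ToMap would yield a map with Tick rendered as a duration string:
//
//	map[string]interface{}{
//		"count":   250,
//		"percent": 25,
//		"rate":    100,
//		"tick":    "5s",
//	}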

// normalize applies some basic logic to interpret defaults and set values appropriately for a given device count.
func (j *Job) normalize(deviceCount int) {
	if j.Percent > 0 {
		j.Count = int((float64(deviceCount) / 100.0) * float64(j.Percent))
	} else if j.Count <= 0 {
		j.Count = deviceCount
	}

	if j.Rate > 0 {
		if j.Tick <= 0 {
			j.Tick = time.Second
		}
	} else {
		j.Rate = 0
		j.Tick = 0
	}
}
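
// As a worked example: given 1000 connected devices, Job{Percent: 25, Rate: 100}
// normalizes to Count == 250 (25% of 1000), and since Rate is set but Tick is not,
// Tick defaults to time.Second.  Conversely, Job{Count: 50, Rate: -1, Tick: time.Minute}
// normalizes to Rate == 0 and Tick == 0, i.e. disconnect as fast as possible.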

// Interface describes the behavior of a component which can execute a Job to drain devices.
// Only one drain Job is allowed to run at any time.
type Interface interface {
	// Start attempts to begin draining devices.  The supplied Job describes how the drain will proceed.
	// The returned channel can be used to wait for the drain job to complete.  The returned Job will be
	// the result of applying defaults and will represent the actual Job being executed.  For example, if Job.Rate
	// is set but Job.Tick is not, the returned Job will reflect the default of 1 second for Job.Tick.
	Start(Job) (<-chan struct{}, Job, error)

	// Status returns information about the current drain job, if any.  The boolean return indicates whether
	// the job is currently active, while the returned Job describes the actual options used in starting the drainer.
	// This returned Job instance will not necessarily be the same as that passed to Start, as certain fields
	// may be computed or defaulted.
	Status() (bool, Job, Progress)

	// Cancel asynchronously halts any running drain job.  The returned channel can be used to wait for the job to actually exit.
	// If no job is running, an error is returned along with a nil channel.
	Cancel() (<-chan struct{}, error)
}
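
// A typical client interaction, sketched under the assumption that d holds a
// previously constructed Interface:
//
//	done, actual, err := d.Start(Job{Count: 500, Rate: 50})
//	if err != nil {
//		return err // ErrActive: another drain job is already running
//	}
//	// actual.Tick is now time.Second, the default applied during normalization
//	<-done // block until the drain finishes or is canceled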

func defaultNewTicker(d time.Duration) (<-chan time.Time, func()) {
	ticker := time.NewTicker(d)
	return ticker.C, ticker.Stop
}

// New constructs a drainer using the supplied options.
func New(options ...Option) Interface {
	dr := &drainer{
		logger:    logging.DefaultLogger(),
		now:       time.Now,
		newTicker: defaultNewTicker,
		m: metrics{
			state:   discard.NewGauge(),
			counter: discard.NewCounter(),
		},
	}

	for _, f := range options {
		f(dr)
	}

	if dr.registry == nil {
		panic("A device.Registry is required")
	}

	if dr.connector == nil {
		panic("A device.Connector is required")
	}

	dr.m.state.Set(MetricNotDraining)
	return dr
}
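
// Construction is typically a single New call.  A minimal sketch, assuming a
// device.Manager named manager and go-kit metrics already wired up elsewhere
// (none of those values are defined in this file):
//
//	d := New(
//		WithManager(manager), // supplies both the registry and the connector
//		WithLogger(logger),
//		WithStateGauge(stateGauge),     // optional; defaults to a discard gauge
//		WithDrainCounter(drainCounter), // optional; defaults to a discard counter
//	)
//
// Omitting both WithManager and the WithRegistry/WithConnector pair makes New panic.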

// metrics holds the instrumentation used by a drainer.
type metrics struct {
	state   xmetrics.Setter
	counter xmetrics.Adder
}

// jobContext stores all the runtime information for a drain job.
type jobContext struct {
	id        uint32
	logger    log.Logger
	t         *tracker
	j         Job
	batchSize int
	ticker    <-chan time.Time
	stop      func()
	cancel    chan struct{}
	done      chan struct{}
}

// drainer is the internal implementation of Interface.
type drainer struct {
	logger    log.Logger
	connector device.Connector
	registry  device.Registry
	now       func() time.Time
	newTicker func(time.Duration) (<-chan time.Time, func())
	m         metrics

	controlLock sync.RWMutex
	active      uint32
	currentID   uint32
	current     atomic.Value
}

// drainFilter is a concrete implementation of the DrainFilter interface.
type drainFilter struct {
	filter        device.Filter
	filterRequest devicegate.FilterRequest
}

func (d *drainFilter) GetFilterRequest() devicegate.FilterRequest {
	return d.filterRequest
}

func (df *drainFilter) AllowConnection(d device.Interface) (bool, device.MatchResult) {
	if df.filter == nil {
		return false, device.MatchResult{}
	}
	return df.filter.AllowConnection(d)
}
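
// Note the semantics implied by nextBatch below: devices whose AllowConnection
// returns true are skipped, so a filter "matches" (and drains) a device by
// returning false.  Within this package, wiring a filter into a job might look
// like this sketch (myFilter and myFilterRequest are assumed to exist):
//
//	df := &drainFilter{
//		filter:        myFilter,        // a device.Filter
//		filterRequest: myFilterRequest, // the devicegate.FilterRequest it came from
//	}
//	j := Job{DrainFilter: df, Rate: 100}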

// nextBatch grabs a batch of devices, bounded by the size of the supplied batch channel, and attempts
// to disconnect each of them.  This method is sensitive to the jc.cancel channel.  If canceled, or if
// no more devices are available, this method returns false.
func (dr *drainer) nextBatch(jc jobContext, batch chan device.ID) (more bool, visited int, skipped int) {
	jc.logger.Log(level.Key(), level.DebugValue(), logging.MessageKey(), "nextBatch starting")

	more = true
	skipped = 0
	dr.registry.VisitAll(func(d device.Interface) bool {
		// if a drain filter is set, see if this device should be drained
		if jc.j.DrainFilter != nil {
			if allow, _ := jc.j.DrainFilter.AllowConnection(d); allow {
				skipped++
				return true
			}
		}

		select {
		case batch <- d.ID():
			return true
		case <-jc.cancel:
			jc.logger.Log(level.Key(), level.ErrorValue(), logging.MessageKey(), "job canceled")
			more = false
			return false
		default:
			return false
		}
	})

	visited = len(batch)
	if !more {
		return
	}

	if visited > 0 {
		jc.logger.Log(level.Key(), level.DebugValue(), logging.MessageKey(), "nextBatch", "visited", visited)
		drained := 0
		for finished := false; more && !finished; {
			select {
			case id := <-batch:
				if dr.connector.Disconnect(id, device.CloseReason{Text: Drained}) {
					drained++
				}
			case <-jc.cancel:
				jc.logger.Log(level.Key(), level.ErrorValue(), logging.MessageKey(), "job canceled")
				more = false
			default:
				finished = true
			}
		}

		jc.logger.Log(level.Key(), level.DebugValue(), logging.MessageKey(), "nextBatch", "visited", visited, "drained", drained)
		jc.t.addVisited(visited)
		jc.t.addDrained(drained)
	} else {
		// if no devices were visited (or enqueued), then we must be done:
		// either a cancellation occurred or no devices are left
		dr.logger.Log(level.Key(), level.DebugValue(), logging.MessageKey(), "no devices visited")
		more = false
	}

	return
}

func (dr *drainer) jobFinished(jc jobContext) {
	if jc.stop != nil {
		jc.stop()
	}

	jc.t.done(dr.now().UTC())

	// we need to contend on the control lock to avoid clobbering state from Start/Cancel code
	dr.controlLock.Lock()
	if jc.id == dr.currentID && atomic.CompareAndSwapUint32(&dr.active, StateActive, StateNotActive) {
		dr.m.state.Set(MetricNotDraining)
	}

	dr.controlLock.Unlock()

	// only close the done channel when all cleanup is complete
	close(jc.done)

	p := jc.t.Progress()
	jc.logger.Log(level.Key(), level.InfoValue(), logging.MessageKey(), "drain complete", "visited", p.Visited, "drained", p.Drained)
}

// drain is run as a goroutine to drain devices at a particular rate
func (dr *drainer) drain(jc jobContext) {
	defer dr.jobFinished(jc)
	jc.logger.Log(level.Key(), level.InfoValue(), logging.MessageKey(), "drain starting", "count", jc.j.Count, "rate", jc.j.Rate, "tick", jc.j.Tick)

	var (
		remaining = jc.j.Count
		visited   = 0
		skipped   = 0
		more      = true
		batch     = make(chan device.ID, jc.j.Rate)
	)

	for more && remaining > 0 {
		if remaining < jc.j.Rate {
			batch = make(chan device.ID, remaining)
		}

		select {
		case <-jc.ticker:
			more, visited, skipped = dr.nextBatch(jc, batch)
			remaining -= visited

			// If the number skipped is the number remaining in the registry,
			// then there are no more devices that need to be disconnected.
			if skipped == dr.registry.Len() {
				more = false
			}
		case <-jc.cancel:
			jc.logger.Log(level.Key(), level.ErrorValue(), logging.MessageKey(), "job canceled")
			more = false
		}
	}
}
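
// For example, with Count == 250, Rate == 100, and Tick == time.Second, drain
// performs at most three passes: two full batches of 100 on the first two ticks,
// then a final batch channel of capacity 50 for the remainder.  Cancellation, an
// exhausted registry, or an all-skipped pass ends the loop earlier via more.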

// disconnect is run as a goroutine to drain devices without a rate, i.e. as fast as possible
func (dr *drainer) disconnect(jc jobContext) {
	defer dr.jobFinished(jc)
	jc.logger.Log(level.Key(), level.InfoValue(), logging.MessageKey(), "drain starting", "count", jc.j.Count)

	var (
		remaining = jc.j.Count
		visited   = 0
		more      = true
		batch     = make(chan device.ID, jc.batchSize)
	)

	for more && remaining > 0 {
		if remaining < jc.batchSize {
			batch = make(chan device.ID, remaining)
		}

		more, visited, _ = dr.nextBatch(jc, batch)
		remaining -= visited
	}
}

func (dr *drainer) Start(j Job) (<-chan struct{}, Job, error) {
	j.normalize(dr.registry.Len())

	defer dr.controlLock.Unlock()
	dr.controlLock.Lock()

	if !atomic.CompareAndSwapUint32(&dr.active, StateNotActive, StateActive) {
		return nil, Job{}, ErrActive
	}

	dr.currentID++
	jc := jobContext{
		id:     dr.currentID,
		logger: log.With(dr.logger, "id", dr.currentID),
		t: &tracker{
			started: dr.now().UTC(),
			counter: dr.m.counter,
		},
		j:      j,
		cancel: make(chan struct{}),
		done:   make(chan struct{}),
	}

	if jc.j.Rate > 0 {
		jc.ticker, jc.stop = dr.newTicker(j.Tick)
		go dr.drain(jc)
	} else {
		jc.batchSize = disconnectBatchSize
		go dr.disconnect(jc)
	}

	dr.m.state.Set(MetricDraining)
	dr.current.Store(jc)
	return jc.done, jc.j, nil
}

func (dr *drainer) Status() (bool, Job, Progress) {
	defer dr.controlLock.RUnlock()
	dr.controlLock.RLock()

	if jc, ok := dr.current.Load().(jobContext); ok {
		return atomic.LoadUint32(&dr.active) == StateActive,
			jc.j,
			jc.t.Progress()
	}

	// if the job has never run, this result will be returned
	return false, Job{}, Progress{}
}

func (dr *drainer) Cancel() (<-chan struct{}, error) {
	defer dr.controlLock.Unlock()
	dr.controlLock.Lock()

	if !atomic.CompareAndSwapUint32(&dr.active, StateActive, StateNotActive) {
		return nil, ErrNotActive
	}

	dr.m.state.Set(MetricNotDraining)
	jc := dr.current.Load().(jobContext)
	close(jc.cancel)
	return jc.done, nil
}
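
// A caller that needs to both stop a drain and wait for the worker goroutine to
// exit can combine the returned channel with the error check (a sketch):
//
//	if done, err := d.Cancel(); err == nil {
//		<-done // closed by jobFinished once cleanup is complete
//	}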