github.com/xmidt-org/webpa-common@v1.11.9/device/drain/drainer.go

package drain

import (
	"errors"
	"sync"
	"sync/atomic"
	"time"

	"github.com/go-kit/kit/log"
	"github.com/go-kit/kit/log/level"
	"github.com/go-kit/kit/metrics/discard"

	"github.com/xmidt-org/webpa-common/device"
	"github.com/xmidt-org/webpa-common/device/devicegate"
	"github.com/xmidt-org/webpa-common/logging"
	"github.com/xmidt-org/webpa-common/xmetrics"
)

var (
	ErrActive    error = errors.New("A drain operation is already running")
	ErrNotActive error = errors.New("No drain operation is running")
)

const (
	StateNotActive uint32 = 0
	StateActive    uint32 = 1

	MetricNotDraining float64 = 0.0
	MetricDraining    float64 = 1.0

	Drained = "drained"

	// disconnectBatchSize is the arbitrary size of batches used when no rate is associated with the drain,
	// i.e. disconnect as fast as possible
	disconnectBatchSize int = 1000
)

type Option func(*drainer)

func WithLogger(l log.Logger) Option {
	return func(dr *drainer) {
		if l != nil {
			dr.logger = l
		} else {
			dr.logger = logging.DefaultLogger()
		}
	}
}

func WithRegistry(r device.Registry) Option {
	if r == nil {
		panic("A device.Registry is required")
	}

	return func(dr *drainer) {
		dr.registry = r
	}
}

func WithConnector(c device.Connector) Option {
	if c == nil {
		panic("A device.Connector is required")
	}

	return func(dr *drainer) {
		dr.connector = c
	}
}

func WithManager(m device.Manager) Option {
	if m == nil {
		panic("A device.Manager is required")
	}

	return func(dr *drainer) {
		dr.registry = m
		dr.connector = m
	}
}

func WithStateGauge(s xmetrics.Setter) Option {
	return func(dr *drainer) {
		if s != nil {
			dr.m.state = s
		} else {
			dr.m.state = discard.NewGauge()
		}
	}
}

func WithDrainCounter(a xmetrics.Adder) Option {
	return func(dr *drainer) {
		if a != nil {
			dr.m.counter = a
		} else {
			dr.m.counter = discard.NewCounter()
		}
	}
}

// DrainFilter contains the filter information for a drain job
type DrainFilter interface {
	device.Filter
	GetFilterRequest() devicegate.FilterRequest
}

type Job struct {
	// Count is the total number of devices to disconnect. If this field is nonpositive and Percent is unset,
	// the count of connected devices at the start of job execution is used. If Percent is set, this field's
	// original value is ignored and it is set to that percentage of total devices connected at the time the
	// job starts.
	Count int `json:"count" schema:"count"`

	// Percent is the percentage (0-100) of connected devices to drain. If this field is set, Count's original
	// value is ignored and replaced with the computed percentage of connected devices at the time the job starts.
	Percent int `json:"percent,omitempty" schema:"percent"`

	// Rate is the number of devices per tick to disconnect. If this field is nonpositive,
	// devices are disconnected as fast as possible.
	Rate int `json:"rate,omitempty" schema:"rate"`

	// Tick is the time unit for the Rate field. If Rate is set but this field is not,
	// a tick of 1 second is used as the default.
	Tick time.Duration `json:"tick,omitempty" schema:"tick"`

	// DrainFilter holds the filter to drain devices by. If this is set for the job,
	// only devices that match the filter will be drained.
	DrainFilter DrainFilter `json:"filter,omitempty" schema:"filter"`
}
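
// Illustrative sketch (not part of the original source): a Job asking to drain
// 10% of currently connected devices at 50 devices per second could be built as:
//
//	job := Job{Percent: 10, Rate: 50, Tick: time.Second}
//
// After normalize runs with, say, 20000 connected devices, job.Count becomes
// 2000, while Rate and Tick are left as given.
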
// ToMap returns a map representation of this Job appropriate for marshaling to formats like JSON.
// This method makes some fields prettier for display, such as rendering Tick as a human-readable
// duration string.
func (j Job) ToMap() map[string]interface{} {
	m := map[string]interface{}{
		"count": j.Count,
	}

	if j.Percent > 0 {
		m["percent"] = j.Percent
	}

	if j.Rate > 0 {
		m["rate"] = j.Rate
	}

	if j.Tick > 0 {
		m["tick"] = j.Tick.String()
	}

	if j.DrainFilter != nil {
		m["filter"] = j.DrainFilter.GetFilterRequest()
	}

	return m
}

// normalize applies some basic logic to interpret defaults and set values appropriately for a given device count
func (j *Job) normalize(deviceCount int) {
	if j.Percent > 0 {
		j.Count = int((float64(deviceCount) / 100.0) * float64(j.Percent))
	} else if j.Count <= 0 {
		j.Count = deviceCount
	}

	if j.Rate > 0 {
		if j.Tick <= 0 {
			j.Tick = time.Second
		}
	} else {
		j.Rate = 0
		j.Tick = 0
	}
}

// Interface describes the behavior of a component which can execute a Job to drain devices.
// Only one drain Job is allowed to run at any time.
type Interface interface {
	// Start attempts to begin draining devices. The supplied Job describes how the drain will proceed.
	// The returned channel can be used to wait for the drain job to complete. The returned Job will be
	// the result of applying defaults and will represent the actual Job being executed. For example, if Job.Rate
	// is set but Job.Tick is not, the returned Job will reflect the default of 1 second for Job.Tick.
	Start(Job) (<-chan struct{}, Job, error)

	// Status returns information about the current drain job, if any. The boolean return indicates whether
	// the job is currently active, while the returned Job describes the actual options used in starting the drainer.
	// This returned Job instance will not necessarily be the same as that passed to Start, as certain fields
	// may be computed or defaulted.
	Status() (bool, Job, Progress)

	// Cancel asynchronously halts any running drain job. The returned channel can be used to wait for the job to actually exit.
	// If no job is running, an error is returned along with a nil channel.
	Cancel() (<-chan struct{}, error)
}
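
// Illustrative usage sketch (assumption: "d" is an already-configured Interface;
// this example is not part of the original source):
//
//	done, actual, err := d.Start(Job{Count: 1000, Rate: 100, Tick: time.Second})
//	if err != nil {
//		// ErrActive: another drain job is already running
//		return
//	}
//
//	active, job, progress := d.Status() // inspect the running job
//	<-done                              // or simply block until it finishes
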
func defaultNewTicker(d time.Duration) (<-chan time.Time, func()) {
	ticker := time.NewTicker(d)
	return ticker.C, ticker.Stop
}

// New constructs a drainer using the supplied options
func New(options ...Option) Interface {
	dr := &drainer{
		logger:    logging.DefaultLogger(),
		now:       time.Now,
		newTicker: defaultNewTicker,
		m: metrics{
			state:   discard.NewGauge(),
			counter: discard.NewCounter(),
		},
	}

	for _, f := range options {
		f(dr)
	}

	if dr.registry == nil {
		panic("A device.Registry is required")
	}

	if dr.connector == nil {
		panic("A device.Connector is required")
	}

	dr.m.state.Set(MetricNotDraining)
	return dr
}

type metrics struct {
	state   xmetrics.Setter
	counter xmetrics.Adder
}

// jobContext stores all the runtime information for a drain job
type jobContext struct {
	id        uint32
	logger    log.Logger
	t         *tracker
	j         Job
	batchSize int
	ticker    <-chan time.Time
	stop      func()
	cancel    chan struct{}
	done      chan struct{}
}

// drainer is the internal implementation of Interface
type drainer struct {
	logger    log.Logger
	connector device.Connector
	registry  device.Registry
	now       func() time.Time
	newTicker func(time.Duration) (<-chan time.Time, func())
	m         metrics

	controlLock sync.RWMutex
	active      uint32
	currentID   uint32
	current     atomic.Value
}

// drainFilter is a concrete implementation of the DrainFilter interface
type drainFilter struct {
	filter        device.Filter
	filterRequest devicegate.FilterRequest
}

func (df *drainFilter) GetFilterRequest() devicegate.FilterRequest {
	return df.filterRequest
}

func (df *drainFilter) AllowConnection(d device.Interface) (bool, device.MatchResult) {
	if df.filter == nil {
		return false, device.MatchResult{}
	}

	return df.filter.AllowConnection(d)
}
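
// Construction sketch (illustrative, not part of the original source; "manager"
// is assumed to be an existing device.Manager):
//
//	dr := New(
//		WithLogger(logger),
//		WithManager(manager), // supplies both the Registry and the Connector
//	)
//
// New panics unless both a device.Registry and a device.Connector are
// configured, either individually or together via WithManager.
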
// nextBatch grabs a batch of devices, bounded by the size of the supplied batch channel, and attempts
// to disconnect each of them. This method is sensitive to the jc.cancel channel. If canceled, or if
// no more devices are available, this method returns false.
func (dr *drainer) nextBatch(jc jobContext, batch chan device.ID) (more bool, visited int, skipped int) {
	jc.logger.Log(level.Key(), level.DebugValue(), logging.MessageKey(), "nextBatch starting")

	more = true
	skipped = 0
	dr.registry.VisitAll(func(d device.Interface) bool {
		// if a drain filter is set, skip any device the filter allows to remain connected
		if jc.j.DrainFilter != nil {
			if allow, _ := jc.j.DrainFilter.AllowConnection(d); allow {
				skipped++
				return true
			}
		}

		select {
		case batch <- d.ID():
			return true
		case <-jc.cancel:
			jc.logger.Log(level.Key(), level.ErrorValue(), logging.MessageKey(), "job canceled")
			more = false
			return false
		default:
			return false
		}
	})

	visited = len(batch)
	if !more {
		return
	}

	if visited > 0 {
		jc.logger.Log(level.Key(), level.DebugValue(), logging.MessageKey(), "nextBatch", "visited", visited)
		drained := 0
		for finished := false; more && !finished; {
			select {
			case id := <-batch:
				if dr.connector.Disconnect(id, device.CloseReason{Text: Drained}) {
					drained++
				}
			case <-jc.cancel:
				jc.logger.Log(level.Key(), level.ErrorValue(), logging.MessageKey(), "job canceled")
				more = false
			default:
				finished = true
			}
		}

		jc.logger.Log(level.Key(), level.DebugValue(), logging.MessageKey(), "nextBatch", "visited", visited, "drained", drained)
		jc.t.addVisited(visited)
		jc.t.addDrained(drained)
	} else {
		// if no devices were visited (or enqueued), then we must be done:
		// either a cancellation occurred or no devices are left
		dr.logger.Log(level.Key(), level.DebugValue(), logging.MessageKey(), "no devices visited")
		more = false
	}

	return
}

func (dr *drainer) jobFinished(jc jobContext) {
	if jc.stop != nil {
		jc.stop()
	}

	jc.t.done(dr.now().UTC())

	// we need to contend on the control lock to avoid clobbering state from Start/Cancel code
	dr.controlLock.Lock()
	if jc.id == dr.currentID && atomic.CompareAndSwapUint32(&dr.active, StateActive, StateNotActive) {
		dr.m.state.Set(MetricNotDraining)
	}

	dr.controlLock.Unlock()

	// only close the done channel when all cleanup is complete
	close(jc.done)

	p := jc.t.Progress()
	jc.logger.Log(level.Key(), level.InfoValue(), logging.MessageKey(), "drain complete", "visited", p.Visited, "drained", p.Drained)
}

// drain is run as a goroutine to drain devices at a particular rate
func (dr *drainer) drain(jc jobContext) {
	defer dr.jobFinished(jc)
	jc.logger.Log(level.Key(), level.InfoValue(), logging.MessageKey(), "drain starting", "count", jc.j.Count, "rate", jc.j.Rate, "tick", jc.j.Tick)

	var (
		remaining = jc.j.Count
		visited   = 0
		skipped   = 0
		more      = true
		batch     = make(chan device.ID, jc.j.Rate)
	)

	for more && remaining > 0 {
		if remaining < jc.j.Rate {
			batch = make(chan device.ID, remaining)
		}

		select {
		case <-jc.ticker:
			more, visited, skipped = dr.nextBatch(jc, batch)
			remaining -= visited

			// If the number skipped equals the number of devices in the registry,
			// then there are no more devices that need to be disconnected.
			if skipped == dr.registry.Len() {
				more = false
			}
		case <-jc.cancel:
			jc.logger.Log(level.Key(), level.ErrorValue(), logging.MessageKey(), "job canceled")
			more = false
		}
	}
}
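
// Testing sketch (illustrative, not part of the original source): because the
// ticker is injectable through the newTicker field, a test can release drain
// batches deterministically instead of waiting on real time:
//
//	ticks := make(chan time.Time)
//	dr := New(WithManager(manager)).(*drainer)
//	dr.newTicker = func(time.Duration) (<-chan time.Time, func()) {
//		return ticks, func() {}
//	}
//	// after Start, each send on ticks releases one batch of up to Rate devices
//	ticks <- time.Now()
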
// disconnect is run as a goroutine to drain devices without a rate, i.e. as fast as possible
func (dr *drainer) disconnect(jc jobContext) {
	defer dr.jobFinished(jc)
	jc.logger.Log(level.Key(), level.InfoValue(), logging.MessageKey(), "drain starting", "count", jc.j.Count)

	var (
		remaining = jc.j.Count
		visited   = 0
		more      = true
		batch     = make(chan device.ID, jc.batchSize)
	)

	for more && remaining > 0 {
		if remaining < jc.batchSize {
			batch = make(chan device.ID, remaining)
		}

		more, visited, _ = dr.nextBatch(jc, batch)
		remaining -= visited
	}
}

func (dr *drainer) Start(j Job) (<-chan struct{}, Job, error) {
	j.normalize(dr.registry.Len())

	defer dr.controlLock.Unlock()
	dr.controlLock.Lock()

	if !atomic.CompareAndSwapUint32(&dr.active, StateNotActive, StateActive) {
		return nil, Job{}, ErrActive
	}

	dr.currentID++
	jc := jobContext{
		id:     dr.currentID,
		logger: log.With(dr.logger, "id", dr.currentID),
		t: &tracker{
			started: dr.now().UTC(),
			counter: dr.m.counter,
		},
		j:      j,
		cancel: make(chan struct{}),
		done:   make(chan struct{}),
	}

	if jc.j.Rate > 0 {
		jc.ticker, jc.stop = dr.newTicker(j.Tick)
		go dr.drain(jc)
	} else {
		jc.batchSize = disconnectBatchSize
		go dr.disconnect(jc)
	}

	dr.m.state.Set(MetricDraining)
	dr.current.Store(jc)
	return jc.done, jc.j, nil
}

func (dr *drainer) Status() (bool, Job, Progress) {
	defer dr.controlLock.RUnlock()
	dr.controlLock.RLock()

	if jc, ok := dr.current.Load().(jobContext); ok {
		return atomic.LoadUint32(&dr.active) == StateActive,
			jc.j,
			jc.t.Progress()
	}

	// if the job has never run, this result will be returned
	return false, Job{}, Progress{}
}

func (dr *drainer) Cancel() (<-chan struct{}, error) {
	defer dr.controlLock.Unlock()
	dr.controlLock.Lock()

	if !atomic.CompareAndSwapUint32(&dr.active, StateActive, StateNotActive) {
		return nil, ErrNotActive
	}

	dr.m.state.Set(MetricNotDraining)
	jc := dr.current.Load().(jobContext)
	close(jc.cancel)
	return jc.done, nil
}
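
// Cancellation sketch (illustrative, not part of the original source): Cancel
// only signals the running job; callers that need to know when the drain
// goroutine has fully exited should wait on the returned channel:
//
//	if done, err := dr.Cancel(); err == nil {
//		<-done // jobFinished has run: the ticker is stopped and progress is final
//	}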