github.com/xmidt-org/webpa-common@v1.11.9/device/drain/drainer_test.go (about)

     1  package drain
     2  
     3  import (
     4  	"fmt"
     5  	"strconv"
     6  	"testing"
     7  	"time"
     8  
     9  	"github.com/xmidt-org/webpa-common/device"
    10  
    11  	"github.com/stretchr/testify/assert"
    12  	"github.com/stretchr/testify/require"
    13  	"github.com/xmidt-org/webpa-common/device/devicegate"
    14  	"github.com/xmidt-org/webpa-common/logging"
    15  	"github.com/xmidt-org/webpa-common/xmetrics/xmetricstest"
    16  )
    17  
    18  type deviceInfo struct {
    19  	claims map[string]interface{}
    20  	count  int
    21  }
    22  
    23  func testJobNormalize(t *testing.T) {
    24  	testDrainFilter := &drainFilter{
    25  		filter: &devicegate.FilterGate{
    26  			FilterStore: devicegate.FilterStore(map[string]devicegate.Set{
    27  				"test": &devicegate.FilterSet{Set: map[interface{}]bool{
    28  					"testValue":  true,
    29  					"testValue2": true,
    30  				}},
    31  			}),
    32  		},
    33  		filterRequest: devicegate.FilterRequest{
    34  			Key:    "test",
    35  			Values: []interface{}{"testValue", "testValue2"},
    36  		},
    37  	}
    38  
    39  	testData := []struct {
    40  		deviceCount int
    41  		actual      Job
    42  		expected    Job
    43  	}{
    44  		{1000, Job{}, Job{Count: 1000}},
    45  		{972, Job{Count: -1, Rate: -1}, Job{Count: 972}},
    46  		{1873, Job{Rate: 52}, Job{Count: 1873, Rate: 52, Tick: time.Second}},
    47  		{438742, Job{Tick: 15 * time.Minute}, Job{Count: 438742}},
    48  		{0, Job{Percent: 0}, Job{Count: 0}},
    49  		{123752, Job{Percent: 17}, Job{Count: 21037, Percent: 17}},
    50  		{73, Job{Percent: 100}, Job{Count: 73, Percent: 100}},
    51  		{90, Job{DrainFilter: testDrainFilter}, Job{Count: 90, DrainFilter: testDrainFilter}},
    52  	}
    53  
    54  	for i, record := range testData {
    55  		t.Run(strconv.Itoa(i), func(t *testing.T) {
    56  			var (
    57  				assert = assert.New(t)
    58  				actual = record.actual
    59  			)
    60  
    61  			actual.normalize(record.deviceCount)
    62  			assert.Equal(record.expected, actual)
    63  		})
    64  	}
    65  }
    66  
    67  func TestJob(t *testing.T) {
    68  	t.Run("Normalize", testJobNormalize)
    69  }
    70  
    71  func testWithLoggerDefault(t *testing.T) {
    72  	var (
    73  		assert = assert.New(t)
    74  		d      = new(drainer)
    75  	)
    76  
    77  	WithLogger(nil)(d)
    78  	assert.NotNil(d.logger)
    79  }
    80  
    81  func testWithLoggerCustom(t *testing.T) {
    82  	var (
    83  		assert = assert.New(t)
    84  		logger = logging.NewTestLogger(nil, t)
    85  		d      = new(drainer)
    86  	)
    87  
    88  	WithLogger(logger)(d)
    89  	assert.Equal(logger, d.logger)
    90  }
    91  
    92  func TestWithLogger(t *testing.T) {
    93  	t.Run("Default", testWithLoggerDefault)
    94  	t.Run("Custom", testWithLoggerCustom)
    95  }
    96  
    97  func testWithRegistryNil(t *testing.T) {
    98  	assert.Panics(t, func() {
    99  		WithRegistry(nil)
   100  	})
   101  }
   102  
   103  func testWithRegistryCustom(t *testing.T) {
   104  	var (
   105  		assert  = assert.New(t)
   106  		d       = new(drainer)
   107  		manager = new(stubManager)
   108  	)
   109  
   110  	WithRegistry(manager)(d)
   111  	assert.Equal(manager, d.registry)
   112  }
   113  
   114  func TestWithRegistry(t *testing.T) {
   115  	t.Run("Nil", testWithRegistryNil)
   116  	t.Run("Custom", testWithRegistryCustom)
   117  }
   118  
   119  func testWithConnectorNil(t *testing.T) {
   120  	assert.Panics(t, func() {
   121  		WithConnector(nil)
   122  	})
   123  }
   124  
   125  func testWithConnectorCustom(t *testing.T) {
   126  	var (
   127  		assert  = assert.New(t)
   128  		d       = new(drainer)
   129  		manager = new(stubManager)
   130  	)
   131  
   132  	WithConnector(manager)(d)
   133  	assert.Equal(manager, d.connector)
   134  }
   135  
   136  func TestWithConnector(t *testing.T) {
   137  	t.Run("Nil", testWithConnectorNil)
   138  	t.Run("Custom", testWithConnectorCustom)
   139  }
   140  
   141  func testWithManagerNil(t *testing.T) {
   142  	assert.Panics(t, func() {
   143  		WithManager(nil)
   144  	})
   145  }
   146  
   147  func testWithManagerCustom(t *testing.T) {
   148  	var (
   149  		assert  = assert.New(t)
   150  		d       = new(drainer)
   151  		manager = new(stubManager)
   152  	)
   153  
   154  	WithManager(manager)(d)
   155  	assert.Equal(manager, d.registry)
   156  	assert.Equal(manager, d.connector)
   157  }
   158  
   159  func TestWithManager(t *testing.T) {
   160  	t.Run("Nil", testWithManagerNil)
   161  	t.Run("Custom", testWithManagerCustom)
   162  }
   163  
   164  func testWithStateGaugeDefault(t *testing.T) {
   165  	var (
   166  		assert = assert.New(t)
   167  		d      = new(drainer)
   168  	)
   169  
   170  	WithStateGauge(nil)(d)
   171  	assert.NotNil(d.m.state)
   172  }
   173  
   174  func testWithStateGaugeCustom(t *testing.T) {
   175  	var (
   176  		assert   = assert.New(t)
   177  		d        = new(drainer)
   178  		provider = xmetricstest.NewProvider(nil)
   179  		gauge    = provider.NewGauge("test")
   180  	)
   181  
   182  	WithStateGauge(gauge)(d)
   183  	assert.Equal(gauge, d.m.state)
   184  }
   185  
   186  func TestWithStateGauge(t *testing.T) {
   187  	t.Run("Default", testWithStateGaugeDefault)
   188  	t.Run("Custom", testWithStateGaugeCustom)
   189  }
   190  
   191  func testWithDrainCounterDefault(t *testing.T) {
   192  	var (
   193  		assert = assert.New(t)
   194  		d      = new(drainer)
   195  	)
   196  
   197  	WithDrainCounter(nil)(d)
   198  	assert.NotNil(d.m.counter)
   199  }
   200  
   201  func testWithDrainCounterCustom(t *testing.T) {
   202  	var (
   203  		assert   = assert.New(t)
   204  		d        = new(drainer)
   205  		provider = xmetricstest.NewProvider(nil)
   206  		counter  = provider.NewCounter("test")
   207  	)
   208  
   209  	WithDrainCounter(counter)(d)
   210  	assert.Equal(counter, d.m.counter)
   211  }
   212  
   213  func TestWithDrainCounter(t *testing.T) {
   214  	t.Run("Default", testWithDrainCounterDefault)
   215  	t.Run("Custom", testWithDrainCounterCustom)
   216  }
   217  
   218  func testNewNoRegistry(t *testing.T) {
   219  	var (
   220  		assert  = assert.New(t)
   221  		manager = generateManager(assert, 0)
   222  	)
   223  
   224  	assert.Panics(func() {
   225  		New(WithConnector(manager))
   226  	})
   227  }
   228  
   229  func testNewNoConnector(t *testing.T) {
   230  	var (
   231  		assert  = assert.New(t)
   232  		manager = generateManager(assert, 0)
   233  	)
   234  
   235  	assert.Panics(func() {
   236  		New(WithRegistry(manager))
   237  	})
   238  }
   239  
   240  func TestNew(t *testing.T) {
   241  	t.Run("NoRegistry", testNewNoRegistry)
   242  	t.Run("NoConnector", testNewNoConnector)
   243  }
   244  
   245  func testDrainerDrainAll(t *testing.T, deviceCount int) {
   246  	var (
   247  		assert   = assert.New(t)
   248  		require  = require.New(t)
   249  		provider = xmetricstest.NewProvider(nil)
   250  		logger   = logging.NewTestLogger(nil, t)
   251  
   252  		manager = generateManager(assert, uint64(deviceCount))
   253  
   254  		firstTime        = true
   255  		expectedStarted  = time.Now()
   256  		expectedFinished = expectedStarted.Add(10 * time.Minute)
   257  
   258  		stopCalled = false
   259  		stop       = func() {
   260  			stopCalled = true
   261  		}
   262  
   263  		ticker = make(chan time.Time, 1)
   264  
   265  		d = New(
   266  			WithLogger(logger),
   267  			WithRegistry(manager),
   268  			WithConnector(manager),
   269  			WithStateGauge(provider.NewGauge("state")),
   270  			WithDrainCounter(provider.NewCounter("counter")),
   271  		)
   272  	)
   273  
   274  	require.NotNil(d)
   275  	d.(*drainer).now = func() time.Time {
   276  		if firstTime {
   277  			firstTime = false
   278  			return expectedStarted
   279  		}
   280  
   281  		return expectedFinished
   282  	}
   283  
   284  	d.(*drainer).newTicker = func(d time.Duration) (<-chan time.Time, func()) {
   285  		assert.Equal(time.Second, d)
   286  		return ticker, stop
   287  	}
   288  
   289  	defer d.Cancel() // cleanup in case of horribleness
   290  
   291  	done, err := d.Cancel()
   292  	assert.Nil(done)
   293  	assert.Error(err)
   294  
   295  	active, job, progress := d.Status()
   296  	assert.False(active)
   297  	assert.Equal(Job{}, job)
   298  	assert.Equal(Progress{}, progress)
   299  
   300  	provider.Assert(t, "state")(xmetricstest.Value(MetricNotDraining))
   301  	provider.Assert(t, "counter")(xmetricstest.Value(0.0))
   302  
   303  	done, job, err = d.Start(Job{Rate: 100, Tick: time.Second})
   304  	require.NoError(err)
   305  	require.NotNil(done)
   306  	assert.Equal(Job{Count: deviceCount, Rate: 100, Tick: time.Second}, job)
   307  
   308  	provider.Assert(t, "state")(xmetricstest.Value(MetricDraining))
   309  	provider.Assert(t, "counter")(xmetricstest.Value(0.0))
   310  
   311  	{
   312  		done, job, err := d.Start(Job{Rate: 123, Tick: time.Minute})
   313  		assert.Nil(done)
   314  		assert.Error(err)
   315  		assert.Equal(Job{}, job)
   316  	}
   317  
   318  	active, job, progress = d.Status()
   319  	assert.True(active)
   320  	assert.Equal(Job{Count: deviceCount, Rate: 100, Tick: time.Second}, job)
   321  	assert.Equal(Progress{Visited: 0, Drained: 0, Started: expectedStarted.UTC(), Finished: nil}, progress)
   322  
   323  	go func() {
   324  		ticks := deviceCount / 100
   325  		if (deviceCount % 100) > 0 {
   326  			ticks++
   327  		}
   328  
   329  		for i := 0; i < ticks; i++ {
   330  			ticker <- time.Time{}
   331  		}
   332  	}()
   333  
   334  	close(manager.pauseDisconnect)
   335  	close(manager.pauseVisit)
   336  	select {
   337  	case <-done:
   338  		// passed
   339  	case <-time.After(5 * time.Second):
   340  		assert.Fail("Drain failed to complete")
   341  		return
   342  	}
   343  
   344  	provider.Assert(t, "state")(xmetricstest.Value(MetricNotDraining))
   345  	provider.Assert(t, "counter")(xmetricstest.Value(float64(deviceCount)))
   346  
   347  	done, err = d.Cancel()
   348  	assert.Nil(done)
   349  	assert.Error(err)
   350  
   351  	active, job, progress = d.Status()
   352  	assert.False(active)
   353  	assert.Equal(Job{Count: deviceCount, Rate: 100, Tick: time.Second}, job)
   354  	assert.Equal(deviceCount, progress.Visited)
   355  	assert.Equal(deviceCount, progress.Drained)
   356  	assert.Equal(expectedStarted.UTC(), progress.Started)
   357  	require.NotNil(progress.Finished)
   358  	assert.Equal(expectedFinished.UTC(), *progress.Finished)
   359  
   360  	assert.Empty(manager.devices)
   361  	assert.True(stopCalled)
   362  }
   363  
   364  func testDrainerDisconnectAll(t *testing.T, deviceCount int) {
   365  	var (
   366  		assert   = assert.New(t)
   367  		require  = require.New(t)
   368  		provider = xmetricstest.NewProvider(nil)
   369  		logger   = logging.NewTestLogger(nil, t)
   370  
   371  		manager = generateManager(assert, uint64(deviceCount))
   372  
   373  		firstTime        = true
   374  		expectedStarted  = time.Now()
   375  		expectedFinished = expectedStarted.Add(10 * time.Minute)
   376  
   377  		d = New(
   378  			WithLogger(logger),
   379  			WithRegistry(manager),
   380  			WithConnector(manager),
   381  			WithStateGauge(provider.NewGauge("state")),
   382  			WithDrainCounter(provider.NewCounter("counter")),
   383  		)
   384  	)
   385  
   386  	require.NotNil(d)
   387  	d.(*drainer).now = func() time.Time {
   388  		if firstTime {
   389  			firstTime = false
   390  			return expectedStarted
   391  		}
   392  
   393  		return expectedFinished
   394  	}
   395  
   396  	defer d.Cancel() // cleanup in case of panic
   397  
   398  	done, err := d.Cancel()
   399  	assert.Nil(done)
   400  	assert.Error(err)
   401  
   402  	active, job, progress := d.Status()
   403  	assert.False(active)
   404  	assert.Equal(Job{}, job)
   405  	assert.Equal(Progress{}, progress)
   406  
   407  	provider.Assert(t, "state")(xmetricstest.Value(MetricNotDraining))
   408  	provider.Assert(t, "counter")(xmetricstest.Value(0.0))
   409  
   410  	done, job, err = d.Start(Job{})
   411  	require.NoError(err)
   412  	require.NotNil(done)
   413  	assert.Equal(Job{Count: deviceCount}, job)
   414  
   415  	provider.Assert(t, "state")(xmetricstest.Value(MetricDraining))
   416  	provider.Assert(t, "counter")(xmetricstest.Value(0.0))
   417  
   418  	{
   419  		done, job, err := d.Start(Job{Rate: 123, Tick: time.Minute})
   420  		assert.Nil(done)
   421  		assert.Error(err)
   422  		assert.Equal(Job{}, job)
   423  	}
   424  
   425  	active, job, progress = d.Status()
   426  	assert.True(active)
   427  	assert.Equal(Job{Count: deviceCount}, job)
   428  	assert.Equal(Progress{Visited: 0, Drained: 0, Started: expectedStarted.UTC(), Finished: nil}, progress)
   429  
   430  	close(manager.pauseDisconnect)
   431  	close(manager.pauseVisit)
   432  	select {
   433  	case <-done:
   434  		// passed
   435  	case <-time.After(5 * time.Second):
   436  		assert.Fail("Disconnect all failed to complete")
   437  		return
   438  	}
   439  
   440  	provider.Assert(t, "state")(xmetricstest.Value(MetricNotDraining))
   441  	provider.Assert(t, "counter")(xmetricstest.Value(float64(deviceCount)))
   442  
   443  	done, err = d.Cancel()
   444  	assert.Nil(done)
   445  	assert.Error(err)
   446  
   447  	active, job, progress = d.Status()
   448  	assert.False(active)
   449  	assert.Equal(Job{Count: deviceCount}, job)
   450  	assert.Equal(deviceCount, progress.Visited)
   451  	assert.Equal(deviceCount, progress.Drained)
   452  	assert.Equal(expectedStarted.UTC(), progress.Started)
   453  	require.NotNil(progress.Finished)
   454  	assert.Equal(expectedFinished.UTC(), *progress.Finished)
   455  
   456  	assert.Empty(manager.devices)
   457  }
   458  
   459  func testDrainerVisitCancel(t *testing.T) {
   460  	var (
   461  		assert   = assert.New(t)
   462  		require  = require.New(t)
   463  		provider = xmetricstest.NewProvider(nil)
   464  		logger   = logging.NewTestLogger(nil, t)
   465  
   466  		manager = generateManager(assert, 100)
   467  
   468  		d = New(
   469  			WithLogger(logger),
   470  			WithManager(manager),
   471  			WithStateGauge(provider.NewGauge("state")),
   472  			WithDrainCounter(provider.NewCounter("counter")),
   473  		)
   474  	)
   475  
   476  	require.NotNil(d)
   477  	d.Start(Job{})
   478  	done, err := d.Cancel()
   479  	require.NoError(err)
   480  	require.NotNil(done)
   481  	close(manager.pauseVisit)
   482  
   483  	select {
   484  	case <-done:
   485  		// passing
   486  	case <-time.After(5 * time.Second):
   487  		assert.Fail("The job did not complete after being canceled")
   488  		return
   489  	}
   490  
   491  	provider.Assert(t, "state")(xmetricstest.Value(MetricNotDraining))
   492  	provider.Assert(t, "counter")(xmetricstest.Value(0.0))
   493  }
   494  
   495  func testDrainerDisconnectCancel(t *testing.T) {
   496  	var (
   497  		assert   = assert.New(t)
   498  		require  = require.New(t)
   499  		provider = xmetricstest.NewProvider(nil)
   500  		logger   = logging.NewTestLogger(nil, t)
   501  
   502  		manager = generateManager(assert, 100)
   503  
   504  		d = New(
   505  			WithLogger(logger),
   506  			WithManager(manager),
   507  			WithStateGauge(provider.NewGauge("state")),
   508  			WithDrainCounter(provider.NewCounter("counter")),
   509  		)
   510  	)
   511  
   512  	require.NotNil(d)
   513  	defer d.Cancel()
   514  	d.Start(Job{})
   515  	close(manager.pauseVisit)
   516  
   517  	select {
   518  	case <-manager.disconnect:
   519  	case <-time.After(5 * time.Second):
   520  		assert.Fail("Disconnect was not called")
   521  		return
   522  	}
   523  
   524  	done, err := d.Cancel()
   525  	require.NoError(err)
   526  	require.NotNil(done)
   527  	close(manager.pauseDisconnect)
   528  
   529  	select {
   530  	case <-done:
   531  		// passing
   532  	case <-time.After(5 * time.Second):
   533  		assert.Fail("The job did not complete after being canceled")
   534  		return
   535  	}
   536  
   537  	provider.Assert(t, "state")(xmetricstest.Value(MetricNotDraining))
   538  	provider.Assert(t, "counter")(xmetricstest.Minimum(1.0))
   539  }
   540  
   541  func testDrainerDrainCancel(t *testing.T) {
   542  	var (
   543  		assert   = assert.New(t)
   544  		require  = require.New(t)
   545  		provider = xmetricstest.NewProvider(nil)
   546  		logger   = logging.NewTestLogger(nil, t)
   547  
   548  		manager = generateManager(assert, 100)
   549  
   550  		stopCalled = false
   551  		stop       = func() {
   552  			stopCalled = true
   553  		}
   554  		ticker = make(chan time.Time, 1)
   555  
   556  		d = New(
   557  			WithLogger(logger),
   558  			WithManager(manager),
   559  			WithStateGauge(provider.NewGauge("state")),
   560  			WithDrainCounter(provider.NewCounter("counter")),
   561  		)
   562  	)
   563  
   564  	require.NotNil(d)
   565  	defer d.Cancel()
   566  
   567  	d.(*drainer).newTicker = func(d time.Duration) (<-chan time.Time, func()) {
   568  		assert.Equal(time.Second, d)
   569  		return ticker, stop
   570  	}
   571  
   572  	done, job, err := d.Start(Job{Percent: 20, Rate: 5})
   573  	require.NoError(err)
   574  	require.NotNil(done)
   575  	assert.Equal(
   576  		Job{Count: 20, Percent: 20, Rate: 5, Tick: time.Second},
   577  		job,
   578  	)
   579  
   580  	active, job, _ := d.Status()
   581  	assert.True(active)
   582  	assert.Equal(
   583  		Job{Count: 20, Percent: 20, Rate: 5, Tick: time.Second},
   584  		job,
   585  	)
   586  
   587  	done, err = d.Cancel()
   588  	require.NotNil(done)
   589  	require.NoError(err)
   590  	ticker <- time.Time{}
   591  	close(manager.pauseVisit)
   592  	close(manager.pauseDisconnect)
   593  
   594  	select {
   595  	case <-done:
   596  		// passing
   597  	case <-time.After(5 * time.Second):
   598  		assert.Fail("Drain failed to complete")
   599  		return
   600  	}
   601  
   602  	provider.Assert(t, "state")(xmetricstest.Value(MetricNotDraining))
   603  	provider.Assert(t, "counter")(xmetricstest.Minimum(0.0))
   604  
   605  	assert.True(stopCalled)
   606  }
   607  
   608  func TestDrainer(t *testing.T) {
   609  	deviceCounts := []int{0, 1, 2, disconnectBatchSize - 1, disconnectBatchSize, disconnectBatchSize + 1, 1709}
   610  
   611  	t.Run("DisconnectAll", func(t *testing.T) {
   612  		for _, deviceCount := range deviceCounts {
   613  			t.Run(fmt.Sprintf("deviceCount=%d", deviceCount), func(t *testing.T) {
   614  				testDrainerDisconnectAll(t, deviceCount)
   615  			})
   616  		}
   617  	})
   618  
   619  	t.Run("DrainAll", func(t *testing.T) {
   620  		for _, deviceCount := range deviceCounts {
   621  			t.Run(fmt.Sprintf("deviceCount=%d", deviceCount), func(t *testing.T) {
   622  				testDrainerDrainAll(t, deviceCount)
   623  			})
   624  		}
   625  	})
   626  
   627  	t.Run("VisitCancel", testDrainerVisitCancel)
   628  	t.Run("DisconnectCancel", testDrainerDisconnectCancel)
   629  	t.Run("DrainCancel", testDrainerDrainCancel)
   630  }
   631  
   632  func testDrainFilter(t *testing.T, deviceTypeOne deviceInfo, deviceTypeTwo deviceInfo, df DrainFilter, expectedSkipped int, count int) {
   633  	var (
   634  		assert   = assert.New(t)
   635  		require  = require.New(t)
   636  		provider = xmetricstest.NewProvider(nil)
   637  		logger   = logging.NewTestLogger(nil, t)
   638  
   639  		// generate manager with devices that have two different metadatas
   640  		manager = generateManagerWithDifferentDevices(assert, deviceTypeOne.claims, uint64(deviceTypeOne.count), deviceTypeTwo.claims, uint64(deviceTypeTwo.count))
   641  
   642  		firstTime        = true
   643  		expectedStarted  = time.Now()
   644  		expectedFinished = expectedStarted.Add(10 * time.Minute)
   645  
   646  		stopCalled = false
   647  		stop       = func() {
   648  			stopCalled = true
   649  		}
   650  
   651  		ticker     = make(chan time.Time, 1)
   652  		totalCount = deviceTypeOne.count + deviceTypeTwo.count
   653  		realCount  = totalCount
   654  
   655  		d = New(
   656  			WithLogger(logger),
   657  			WithRegistry(manager),
   658  			WithConnector(manager),
   659  			WithStateGauge(provider.NewGauge("state")),
   660  			WithDrainCounter(provider.NewCounter("counter")),
   661  		)
   662  	)
   663  
   664  	if count > 0 {
   665  		realCount = count
   666  	}
   667  
   668  	require.NotNil(d)
   669  	d.(*drainer).now = func() time.Time {
   670  		if firstTime {
   671  			firstTime = false
   672  			return expectedStarted
   673  		}
   674  
   675  		return expectedFinished
   676  	}
   677  
   678  	d.(*drainer).newTicker = func(d time.Duration) (<-chan time.Time, func()) {
   679  		assert.Equal(time.Second, d)
   680  		return ticker, stop
   681  	}
   682  
   683  	defer d.Cancel() // cleanup in case of horribleness
   684  
   685  	// test that cancel will error if there is not a drain job in progress
   686  	done, err := d.Cancel()
   687  	assert.Nil(done)
   688  	assert.Error(err)
   689  
   690  	// test status when drain hasn't started
   691  	active, job, progress := d.Status()
   692  	assert.False(active)
   693  	assert.Equal(Job{}, job)
   694  	assert.Equal(Progress{}, progress)
   695  
   696  	provider.Assert(t, "state")(xmetricstest.Value(MetricNotDraining))
   697  	provider.Assert(t, "counter")(xmetricstest.Value(0.0))
   698  
   699  	// start drain job
   700  	if count > 0 {
   701  		done, job, err = d.Start(Job{Count: count, Rate: 100, Tick: time.Second, DrainFilter: df})
   702  	} else {
   703  		done, job, err = d.Start(Job{Rate: 100, Tick: time.Second, DrainFilter: df})
   704  	}
   705  
   706  	require.NoError(err)
   707  	require.NotNil(done)
   708  
   709  	assert.Equal(Job{Count: realCount, Rate: 100, Tick: time.Second, DrainFilter: df}, job)
   710  
   711  	provider.Assert(t, "state")(xmetricstest.Value(MetricDraining))
   712  	provider.Assert(t, "counter")(xmetricstest.Value(0.0))
   713  
   714  	{
   715  		// test starting another drain job when there is one in progress
   716  		done, job, err := d.Start(Job{Rate: 123, Tick: time.Minute})
   717  		assert.Nil(done)
   718  		assert.Error(err)
   719  		assert.Equal(Job{}, job)
   720  	}
   721  
   722  	// get status of drain job in progress
   723  	active, job, progress = d.Status()
   724  	assert.True(active)
   725  	assert.Equal(Job{Count: realCount, Rate: 100, Tick: time.Second, DrainFilter: df}, job)
   726  
   727  	assert.Equal(Progress{Visited: 0, Drained: 0, Started: expectedStarted.UTC(), Finished: nil}, progress)
   728  
   729  	go func() {
   730  		ticks := realCount / 100
   731  		if (realCount % 100) > 0 {
   732  			ticks++
   733  		}
   734  
   735  		for i := 0; i < ticks; i++ {
   736  			ticker <- time.Time{}
   737  		}
   738  	}()
   739  
   740  	close(manager.pauseDisconnect)
   741  	close(manager.pauseVisit)
   742  
   743  	// make sure jobFinished is called and done channel is closed
   744  	select {
   745  	case <-done:
   746  		// passed
   747  	case <-time.After(5 * time.Second):
   748  		assert.Fail("Drain failed to complete")
   749  		return
   750  	}
   751  
   752  	provider.Assert(t, "state")(xmetricstest.Value(MetricNotDraining))
   753  
   754  	if count > 0 && count <= totalCount-expectedSkipped {
   755  		provider.Assert(t, "counter")(xmetricstest.Value(float64(count)))
   756  	} else {
   757  		provider.Assert(t, "counter")(xmetricstest.Value(float64(totalCount - expectedSkipped)))
   758  	}
   759  
   760  	// test cancel when not draining
   761  	done, err = d.Cancel()
   762  	assert.Nil(done)
   763  	assert.Error(err)
   764  
   765  	active, job, progress = d.Status()
   766  	assert.False(active)
   767  
   768  	assert.Equal(Job{Count: realCount, Rate: 100, Tick: time.Second, DrainFilter: df}, job)
   769  
   770  	if count > 0 && count <= (totalCount-expectedSkipped) {
   771  		assert.Equal(count, progress.Visited)
   772  		assert.Equal(count, progress.Drained)
   773  		assert.Equal(totalCount-count, len(manager.devices))
   774  	} else {
   775  		assert.Equal(totalCount-expectedSkipped, progress.Visited)
   776  		assert.Equal(totalCount-expectedSkipped, progress.Drained)
   777  		assert.Equal(expectedSkipped, len(manager.devices))
   778  
   779  	}
   780  
   781  	assert.Equal(expectedStarted.UTC(), progress.Started)
   782  	require.NotNil(progress.Finished)
   783  	assert.Equal(expectedFinished.UTC(), *progress.Finished)
   784  
   785  	assert.True(stopCalled)
   786  
   787  }
   788  
   789  func testDisconnectFilter(t *testing.T, deviceTypeOne deviceInfo, deviceTypeTwo deviceInfo, df DrainFilter, expectedSkipped int, count int) {
   790  	var (
   791  		assert   = assert.New(t)
   792  		require  = require.New(t)
   793  		provider = xmetricstest.NewProvider(nil)
   794  		logger   = logging.NewTestLogger(nil, t)
   795  
   796  		// generate manager with devices that have two different metadatas
   797  		manager = generateManagerWithDifferentDevices(assert, deviceTypeOne.claims, uint64(deviceTypeOne.count), deviceTypeTwo.claims, uint64(deviceTypeTwo.count))
   798  
   799  		firstTime        = true
   800  		expectedStarted  = time.Now()
   801  		expectedFinished = expectedStarted.Add(10 * time.Minute)
   802  
   803  		totalCount = deviceTypeOne.count + deviceTypeTwo.count
   804  
   805  		d = New(
   806  			WithLogger(logger),
   807  			WithRegistry(manager),
   808  			WithConnector(manager),
   809  			WithStateGauge(provider.NewGauge("state")),
   810  			WithDrainCounter(provider.NewCounter("counter")),
   811  		)
   812  	)
   813  
   814  	require.NotNil(d)
   815  	d.(*drainer).now = func() time.Time {
   816  		if firstTime {
   817  			firstTime = false
   818  			return expectedStarted
   819  		}
   820  
   821  		return expectedFinished
   822  	}
   823  
   824  	defer d.Cancel() // cleanup in case of horribleness
   825  
   826  	// test that cancel will error if there is not a drain job in progress
   827  	done, err := d.Cancel()
   828  	assert.Nil(done)
   829  	assert.Error(err)
   830  
   831  	// test status when drain hasn't started
   832  	active, job, progress := d.Status()
   833  	assert.False(active)
   834  	assert.Equal(Job{}, job)
   835  	assert.Equal(Progress{}, progress)
   836  
   837  	provider.Assert(t, "state")(xmetricstest.Value(MetricNotDraining))
   838  	provider.Assert(t, "counter")(xmetricstest.Value(0.0))
   839  
   840  	// start drain job
   841  	if count > 0 {
   842  		done, job, err = d.Start(Job{Count: count, DrainFilter: df})
   843  	} else {
   844  		done, job, err = d.Start(Job{DrainFilter: df})
   845  	}
   846  
   847  	require.NoError(err)
   848  	require.NotNil(done)
   849  
   850  	if count > 0 {
   851  		assert.Equal(Job{Count: count, DrainFilter: df}, job)
   852  	} else {
   853  		assert.Equal(Job{Count: totalCount, DrainFilter: df}, job)
   854  	}
   855  
   856  	provider.Assert(t, "state")(xmetricstest.Value(MetricDraining))
   857  	provider.Assert(t, "counter")(xmetricstest.Value(0.0))
   858  
   859  	{
   860  		// test starting another drain job when there is one in progress
   861  		done, job, err := d.Start(Job{Rate: 123, Tick: time.Minute})
   862  		assert.Nil(done)
   863  		assert.Error(err)
   864  		assert.Equal(Job{}, job)
   865  	}
   866  
   867  	// get status of drain job in progress
   868  	active, job, progress = d.Status()
   869  	assert.True(active)
   870  	if count > 0 {
   871  		assert.Equal(Job{Count: count, DrainFilter: df}, job)
   872  	} else {
   873  		assert.Equal(Job{Count: totalCount, DrainFilter: df}, job)
   874  	}
   875  
   876  	assert.Equal(Progress{Visited: 0, Drained: 0, Started: expectedStarted.UTC(), Finished: nil}, progress)
   877  
   878  	close(manager.pauseDisconnect)
   879  	close(manager.pauseVisit)
   880  
   881  	// make sure jobFinished is called and done channel is closed
   882  	select {
   883  	case <-done:
   884  		// passed
   885  	case <-time.After(5 * time.Second):
   886  		assert.Fail("Drain failed to complete")
   887  		return
   888  	}
   889  
   890  	provider.Assert(t, "state")(xmetricstest.Value(MetricNotDraining))
   891  
   892  	if count > 0 && count <= totalCount-expectedSkipped {
   893  		provider.Assert(t, "counter")(xmetricstest.Value(float64(count)))
   894  	} else {
   895  		provider.Assert(t, "counter")(xmetricstest.Value(float64(totalCount - expectedSkipped)))
   896  	}
   897  
   898  	// test cancel when not draining
   899  	done, err = d.Cancel()
   900  	assert.Nil(done)
   901  	assert.Error(err)
   902  
   903  	active, job, progress = d.Status()
   904  	assert.False(active)
   905  
   906  	if count > 0 {
   907  		assert.Equal(Job{Count: count, DrainFilter: df}, job)
   908  	} else {
   909  		assert.Equal(Job{Count: totalCount, DrainFilter: df}, job)
   910  	}
   911  
   912  	if count > 0 && count <= (totalCount-expectedSkipped) {
   913  		assert.Equal(count, progress.Visited)
   914  		assert.Equal(count, progress.Drained)
   915  		assert.Equal(totalCount-count, len(manager.devices))
   916  	} else {
   917  		assert.Equal(totalCount-expectedSkipped, progress.Visited)
   918  		assert.Equal(totalCount-expectedSkipped, progress.Drained)
   919  		assert.Equal(expectedSkipped, len(manager.devices))
   920  
   921  	}
   922  
   923  	assert.Equal(expectedStarted.UTC(), progress.Started)
   924  	require.NotNil(progress.Finished)
   925  	assert.Equal(expectedFinished.UTC(), *progress.Finished)
   926  }
   927  
   928  func TestDrainerWithFilter(t *testing.T) {
   929  	var (
   930  		filterKey   = "test"
   931  		filterValue = "test1"
   932  		df          = drainFilter{
   933  			filter: &devicegate.FilterGate{
   934  				FilterStore: devicegate.FilterStore(map[string]devicegate.Set{
   935  					filterKey: &devicegate.FilterSet{Set: map[interface{}]bool{
   936  						filterValue: true,
   937  					}},
   938  				}),
   939  			},
   940  			filterRequest: devicegate.FilterRequest{
   941  				Key:    filterKey,
   942  				Values: []interface{}{filterValue},
   943  			},
   944  		}
   945  
   946  		metadata1 = map[string]interface{}{filterKey: "test"}
   947  		metadata2 = map[string]interface{}{filterKey: filterValue}
   948  
   949  		counts = [][]int{
   950  			[]int{0, 0, 100},
   951  			[]int{1, 0, 1},
   952  			[]int{2, 0, 9},
   953  			[]int{0, 1, 100},
   954  			[]int{0, 2, 1},
   955  			[]int{1, 1, 19},
   956  			[]int{0, disconnectBatchSize - 1, 100},
   957  			[]int{disconnectBatchSize - 1, 0, 20},
   958  			[]int{0, disconnectBatchSize, 20},
   959  			[]int{disconnectBatchSize, 0, 53},
   960  			[]int{0, disconnectBatchSize + 1, 120},
   961  			[]int{disconnectBatchSize + 1, 0, 400},
   962  			[]int{89, 1709, 1091},
   963  			[]int{1704, 43, 1000},
   964  		}
   965  	)
   966  
   967  	for _, deviceCount := range counts {
   968  		expectedSkip := deviceCount[0]
   969  		devices := []deviceInfo{
   970  			deviceInfo{count: deviceCount[0], claims: metadata1},
   971  			deviceInfo{count: deviceCount[1], claims: metadata2},
   972  		}
   973  
   974  		t.Run(fmt.Sprintf("deviceCount=%d", deviceCount[0]+deviceCount[1]), func(t *testing.T) {
   975  			t.Run("DrainAll", func(t *testing.T) {
   976  				testDrainFilter(t, devices[0], devices[1], &df, expectedSkip, -1)
   977  			})
   978  			t.Run("DrainWithCount", func(t *testing.T) {
   979  				testDrainFilter(t, devices[0], devices[1], &df, expectedSkip, deviceCount[2])
   980  			})
   981  			t.Run("DisconnectAll", func(t *testing.T) {
   982  				testDisconnectFilter(t, devices[0], devices[1], &df, expectedSkip, -1)
   983  			})
   984  			t.Run("DisconnectWithCount", func(t *testing.T) {
   985  				testDisconnectFilter(t, devices[0], devices[1], &df, expectedSkip, deviceCount[2])
   986  			})
   987  		})
   988  	}
   989  }
   990  
   991  func TestDrainFilterNilFilter(t *testing.T) {
   992  	assert := assert.New(t)
   993  	mockDevice := new(device.MockDevice)
   994  
   995  	df := drainFilter{}
   996  	allow, _ := df.AllowConnection(mockDevice)
   997  	assert.False(allow)
   998  }