github.com/xmidt-org/webpa-common@v1.11.9/device/drain/drainer_test.go (about) 1 package drain 2 3 import ( 4 "fmt" 5 "strconv" 6 "testing" 7 "time" 8 9 "github.com/xmidt-org/webpa-common/device" 10 11 "github.com/stretchr/testify/assert" 12 "github.com/stretchr/testify/require" 13 "github.com/xmidt-org/webpa-common/device/devicegate" 14 "github.com/xmidt-org/webpa-common/logging" 15 "github.com/xmidt-org/webpa-common/xmetrics/xmetricstest" 16 ) 17 18 type deviceInfo struct { 19 claims map[string]interface{} 20 count int 21 } 22 23 func testJobNormalize(t *testing.T) { 24 testDrainFilter := &drainFilter{ 25 filter: &devicegate.FilterGate{ 26 FilterStore: devicegate.FilterStore(map[string]devicegate.Set{ 27 "test": &devicegate.FilterSet{Set: map[interface{}]bool{ 28 "testValue": true, 29 "testValue2": true, 30 }}, 31 }), 32 }, 33 filterRequest: devicegate.FilterRequest{ 34 Key: "test", 35 Values: []interface{}{"testValue", "testValue2"}, 36 }, 37 } 38 39 testData := []struct { 40 deviceCount int 41 actual Job 42 expected Job 43 }{ 44 {1000, Job{}, Job{Count: 1000}}, 45 {972, Job{Count: -1, Rate: -1}, Job{Count: 972}}, 46 {1873, Job{Rate: 52}, Job{Count: 1873, Rate: 52, Tick: time.Second}}, 47 {438742, Job{Tick: 15 * time.Minute}, Job{Count: 438742}}, 48 {0, Job{Percent: 0}, Job{Count: 0}}, 49 {123752, Job{Percent: 17}, Job{Count: 21037, Percent: 17}}, 50 {73, Job{Percent: 100}, Job{Count: 73, Percent: 100}}, 51 {90, Job{DrainFilter: testDrainFilter}, Job{Count: 90, DrainFilter: testDrainFilter}}, 52 } 53 54 for i, record := range testData { 55 t.Run(strconv.Itoa(i), func(t *testing.T) { 56 var ( 57 assert = assert.New(t) 58 actual = record.actual 59 ) 60 61 actual.normalize(record.deviceCount) 62 assert.Equal(record.expected, actual) 63 }) 64 } 65 } 66 67 func TestJob(t *testing.T) { 68 t.Run("Normalize", testJobNormalize) 69 } 70 71 func testWithLoggerDefault(t *testing.T) { 72 var ( 73 assert = assert.New(t) 74 d = new(drainer) 75 ) 76 77 WithLogger(nil)(d) 78 assert.NotNil(d.logger) 79 } 80 81 func testWithLoggerCustom(t *testing.T) { 82 var ( 83 assert = assert.New(t) 84 logger = logging.NewTestLogger(nil, t) 85 d = new(drainer) 86 ) 87 88 WithLogger(logger)(d) 89 assert.Equal(logger, d.logger) 90 } 91 92 func TestWithLogger(t *testing.T) { 93 t.Run("Default", testWithLoggerDefault) 94 t.Run("Custom", testWithLoggerCustom) 95 } 96 97 func testWithRegistryNil(t *testing.T) { 98 assert.Panics(t, func() { 99 WithRegistry(nil) 100 }) 101 } 102 103 func testWithRegistryCustom(t *testing.T) { 104 var ( 105 assert = assert.New(t) 106 d = new(drainer) 107 manager = new(stubManager) 108 ) 109 110 WithRegistry(manager)(d) 111 assert.Equal(manager, d.registry) 112 } 113 114 func TestWithRegistry(t *testing.T) { 115 t.Run("Nil", testWithRegistryNil) 116 t.Run("Custom", testWithRegistryCustom) 117 } 118 119 func testWithConnectorNil(t *testing.T) { 120 assert.Panics(t, func() { 121 WithConnector(nil) 122 }) 123 } 124 125 func testWithConnectorCustom(t *testing.T) { 126 var ( 127 assert = assert.New(t) 128 d = new(drainer) 129 manager = new(stubManager) 130 ) 131 132 WithConnector(manager)(d) 133 assert.Equal(manager, d.connector) 134 } 135 136 func TestWithConnector(t *testing.T) { 137 t.Run("Nil", testWithConnectorNil) 138 t.Run("Custom", testWithConnectorCustom) 139 } 140 141 func testWithManagerNil(t *testing.T) { 142 assert.Panics(t, func() { 143 WithManager(nil) 144 }) 145 } 146 147 func testWithManagerCustom(t *testing.T) { 148 var ( 149 assert = assert.New(t) 150 d = new(drainer) 151 manager = new(stubManager) 152 ) 153 154 WithManager(manager)(d) 155 assert.Equal(manager, d.registry) 156 assert.Equal(manager, d.connector) 157 } 158 159 func TestWithManager(t *testing.T) { 160 t.Run("Nil", testWithManagerNil) 161 t.Run("Custom", testWithManagerCustom) 162 } 163 164 func testWithStateGaugeDefault(t *testing.T) { 165 var ( 166 assert = assert.New(t) 167 d = new(drainer) 168 ) 169 170 WithStateGauge(nil)(d) 171 assert.NotNil(d.m.state) 172 } 173 174 func testWithStateGaugeCustom(t *testing.T) { 175 var ( 176 assert = assert.New(t) 177 d = new(drainer) 178 provider = xmetricstest.NewProvider(nil) 179 gauge = provider.NewGauge("test") 180 ) 181 182 WithStateGauge(gauge)(d) 183 assert.Equal(gauge, d.m.state) 184 } 185 186 func TestWithStateGauge(t *testing.T) { 187 t.Run("Default", testWithStateGaugeDefault) 188 t.Run("Custom", testWithStateGaugeCustom) 189 } 190 191 func testWithDrainCounterDefault(t *testing.T) { 192 var ( 193 assert = assert.New(t) 194 d = new(drainer) 195 ) 196 197 WithDrainCounter(nil)(d) 198 assert.NotNil(d.m.counter) 199 } 200 201 func testWithDrainCounterCustom(t *testing.T) { 202 var ( 203 assert = assert.New(t) 204 d = new(drainer) 205 provider = xmetricstest.NewProvider(nil) 206 counter = provider.NewCounter("test") 207 ) 208 209 WithDrainCounter(counter)(d) 210 assert.Equal(counter, d.m.counter) 211 } 212 213 func TestWithDrainCounter(t *testing.T) { 214 t.Run("Default", testWithDrainCounterDefault) 215 t.Run("Custom", testWithDrainCounterCustom) 216 } 217 218 func testNewNoRegistry(t *testing.T) { 219 var ( 220 assert = assert.New(t) 221 manager = generateManager(assert, 0) 222 ) 223 224 assert.Panics(func() { 225 New(WithConnector(manager)) 226 }) 227 } 228 229 func testNewNoConnector(t *testing.T) { 230 var ( 231 assert = assert.New(t) 232 manager = generateManager(assert, 0) 233 ) 234 235 assert.Panics(func() { 236 New(WithRegistry(manager)) 237 }) 238 } 239 240 func TestNew(t *testing.T) { 241 t.Run("NoRegistry", testNewNoRegistry) 242 t.Run("NoConnector", testNewNoConnector) 243 } 244 245 func testDrainerDrainAll(t *testing.T, deviceCount int) { 246 var ( 247 assert = assert.New(t) 248 require = require.New(t) 249 provider = xmetricstest.NewProvider(nil) 250 logger = logging.NewTestLogger(nil, t) 251 252 manager = generateManager(assert, uint64(deviceCount)) 253 254 firstTime = true 255 expectedStarted = time.Now() 256 expectedFinished = expectedStarted.Add(10 * time.Minute) 257 258 stopCalled = false 259 stop = func() { 260 stopCalled = true 261 } 262 263 ticker = make(chan time.Time, 1) 264 265 d = New( 266 WithLogger(logger), 267 WithRegistry(manager), 268 WithConnector(manager), 269 WithStateGauge(provider.NewGauge("state")), 270 WithDrainCounter(provider.NewCounter("counter")), 271 ) 272 ) 273 274 require.NotNil(d) 275 d.(*drainer).now = func() time.Time { 276 if firstTime { 277 firstTime = false 278 return expectedStarted 279 } 280 281 return expectedFinished 282 } 283 284 d.(*drainer).newTicker = func(d time.Duration) (<-chan time.Time, func()) { 285 assert.Equal(time.Second, d) 286 return ticker, stop 287 } 288 289 defer d.Cancel() // cleanup in case of horribleness 290 291 done, err := d.Cancel() 292 assert.Nil(done) 293 assert.Error(err) 294 295 active, job, progress := d.Status() 296 assert.False(active) 297 assert.Equal(Job{}, job) 298 assert.Equal(Progress{}, progress) 299 300 provider.Assert(t, "state")(xmetricstest.Value(MetricNotDraining)) 301 provider.Assert(t, "counter")(xmetricstest.Value(0.0)) 302 303 done, job, err = d.Start(Job{Rate: 100, Tick: time.Second}) 304 require.NoError(err) 305 require.NotNil(done) 306 assert.Equal(Job{Count: deviceCount, Rate: 100, Tick: time.Second}, job) 307 308 provider.Assert(t, "state")(xmetricstest.Value(MetricDraining)) 309 provider.Assert(t, "counter")(xmetricstest.Value(0.0)) 310 311 { 312 done, job, err := d.Start(Job{Rate: 123, Tick: time.Minute}) 313 assert.Nil(done) 314 assert.Error(err) 315 assert.Equal(Job{}, job) 316 } 317 318 active, job, progress = d.Status() 319 assert.True(active) 320 assert.Equal(Job{Count: deviceCount, Rate: 100, Tick: time.Second}, job) 321 assert.Equal(Progress{Visited: 0, Drained: 0, Started: expectedStarted.UTC(), Finished: nil}, progress) 322 323 go func() { 324 ticks := deviceCount / 100 325 if (deviceCount % 100) > 0 { 326 ticks++ 327 } 328 329 for i := 0; i < ticks; i++ { 330 ticker <- time.Time{} 331 } 332 }() 333 334 close(manager.pauseDisconnect) 335 close(manager.pauseVisit) 336 select { 337 case <-done: 338 // passed 339 case <-time.After(5 * time.Second): 340 assert.Fail("Drain failed to complete") 341 return 342 } 343 344 provider.Assert(t, "state")(xmetricstest.Value(MetricNotDraining)) 345 provider.Assert(t, "counter")(xmetricstest.Value(float64(deviceCount))) 346 347 done, err = d.Cancel() 348 assert.Nil(done) 349 assert.Error(err) 350 351 active, job, progress = d.Status() 352 assert.False(active) 353 assert.Equal(Job{Count: deviceCount, Rate: 100, Tick: time.Second}, job) 354 assert.Equal(deviceCount, progress.Visited) 355 assert.Equal(deviceCount, progress.Drained) 356 assert.Equal(expectedStarted.UTC(), progress.Started) 357 require.NotNil(progress.Finished) 358 assert.Equal(expectedFinished.UTC(), *progress.Finished) 359 360 assert.Empty(manager.devices) 361 assert.True(stopCalled) 362 } 363 364 func testDrainerDisconnectAll(t *testing.T, deviceCount int) { 365 var ( 366 assert = assert.New(t) 367 require = require.New(t) 368 provider = xmetricstest.NewProvider(nil) 369 logger = logging.NewTestLogger(nil, t) 370 371 manager = generateManager(assert, uint64(deviceCount)) 372 373 firstTime = true 374 expectedStarted = time.Now() 375 expectedFinished = expectedStarted.Add(10 * time.Minute) 376 377 d = New( 378 WithLogger(logger), 379 WithRegistry(manager), 380 WithConnector(manager), 381 WithStateGauge(provider.NewGauge("state")), 382 WithDrainCounter(provider.NewCounter("counter")), 383 ) 384 ) 385 386 require.NotNil(d) 387 d.(*drainer).now = func() time.Time { 388 if firstTime { 389 firstTime = false 390 return expectedStarted 391 } 392 393 return expectedFinished 394 } 395 396 defer d.Cancel() // cleanup in case of panic 397 398 done, err := d.Cancel() 399 assert.Nil(done) 400 assert.Error(err) 401 402 active, job, progress := d.Status() 403 assert.False(active) 404 assert.Equal(Job{}, job) 405 assert.Equal(Progress{}, progress) 406 407 provider.Assert(t, "state")(xmetricstest.Value(MetricNotDraining)) 408 provider.Assert(t, "counter")(xmetricstest.Value(0.0)) 409 410 done, job, err = d.Start(Job{}) 411 require.NoError(err) 412 require.NotNil(done) 413 assert.Equal(Job{Count: deviceCount}, job) 414 415 provider.Assert(t, "state")(xmetricstest.Value(MetricDraining)) 416 provider.Assert(t, "counter")(xmetricstest.Value(0.0)) 417 418 { 419 done, job, err := d.Start(Job{Rate: 123, Tick: time.Minute}) 420 assert.Nil(done) 421 assert.Error(err) 422 assert.Equal(Job{}, job) 423 } 424 425 active, job, progress = d.Status() 426 assert.True(active) 427 assert.Equal(Job{Count: deviceCount}, job) 428 assert.Equal(Progress{Visited: 0, Drained: 0, Started: expectedStarted.UTC(), Finished: nil}, progress) 429 430 close(manager.pauseDisconnect) 431 close(manager.pauseVisit) 432 select { 433 case <-done: 434 // passed 435 case <-time.After(5 * time.Second): 436 assert.Fail("Disconnect all failed to complete") 437 return 438 } 439 440 provider.Assert(t, "state")(xmetricstest.Value(MetricNotDraining)) 441 provider.Assert(t, "counter")(xmetricstest.Value(float64(deviceCount))) 442 443 done, err = d.Cancel() 444 assert.Nil(done) 445 assert.Error(err) 446 447 active, job, progress = d.Status() 448 assert.False(active) 449 assert.Equal(Job{Count: deviceCount}, job) 450 assert.Equal(deviceCount, progress.Visited) 451 assert.Equal(deviceCount, progress.Drained) 452 assert.Equal(expectedStarted.UTC(), progress.Started) 453 require.NotNil(progress.Finished) 454 assert.Equal(expectedFinished.UTC(), *progress.Finished) 455 456 assert.Empty(manager.devices) 457 } 458 459 func testDrainerVisitCancel(t *testing.T) { 460 var ( 461 assert = assert.New(t) 462 require = require.New(t) 463 provider = xmetricstest.NewProvider(nil) 464 logger = logging.NewTestLogger(nil, t) 465 466 manager = generateManager(assert, 100) 467 468 d = New( 469 WithLogger(logger), 470 WithManager(manager), 471 WithStateGauge(provider.NewGauge("state")), 472 WithDrainCounter(provider.NewCounter("counter")), 473 ) 474 ) 475 476 require.NotNil(d) 477 d.Start(Job{}) 478 done, err := d.Cancel() 479 require.NoError(err) 480 require.NotNil(done) 481 close(manager.pauseVisit) 482 483 select { 484 case <-done: 485 // passing 486 case <-time.After(5 * time.Second): 487 assert.Fail("The job did not complete after being canceled") 488 return 489 } 490 491 provider.Assert(t, "state")(xmetricstest.Value(MetricNotDraining)) 492 provider.Assert(t, "counter")(xmetricstest.Value(0.0)) 493 } 494 495 func testDrainerDisconnectCancel(t *testing.T) { 496 var ( 497 assert = assert.New(t) 498 require = require.New(t) 499 provider = xmetricstest.NewProvider(nil) 500 logger = logging.NewTestLogger(nil, t) 501 502 manager = generateManager(assert, 100) 503 504 d = New( 505 WithLogger(logger), 506 WithManager(manager), 507 WithStateGauge(provider.NewGauge("state")), 508 WithDrainCounter(provider.NewCounter("counter")), 509 ) 510 ) 511 512 require.NotNil(d) 513 defer d.Cancel() 514 d.Start(Job{}) 515 close(manager.pauseVisit) 516 517 select { 518 case <-manager.disconnect: 519 case <-time.After(5 * time.Second): 520 assert.Fail("Disconnect was not called") 521 return 522 } 523 524 done, err := d.Cancel() 525 require.NoError(err) 526 require.NotNil(done) 527 close(manager.pauseDisconnect) 528 529 select { 530 case <-done: 531 // passing 532 case <-time.After(5 * time.Second): 533 assert.Fail("The job did not complete after being canceled") 534 return 535 } 536 537 provider.Assert(t, "state")(xmetricstest.Value(MetricNotDraining)) 538 provider.Assert(t, "counter")(xmetricstest.Minimum(1.0)) 539 } 540 541 func testDrainerDrainCancel(t *testing.T) { 542 var ( 543 assert = assert.New(t) 544 require = require.New(t) 545 provider = xmetricstest.NewProvider(nil) 546 logger = logging.NewTestLogger(nil, t) 547 548 manager = generateManager(assert, 100) 549 550 stopCalled = false 551 stop = func() { 552 stopCalled = true 553 } 554 ticker = make(chan time.Time, 1) 555 556 d = New( 557 WithLogger(logger), 558 WithManager(manager), 559 WithStateGauge(provider.NewGauge("state")), 560 WithDrainCounter(provider.NewCounter("counter")), 561 ) 562 ) 563 564 require.NotNil(d) 565 defer d.Cancel() 566 567 d.(*drainer).newTicker = func(d time.Duration) (<-chan time.Time, func()) { 568 assert.Equal(time.Second, d) 569 return ticker, stop 570 } 571 572 done, job, err := d.Start(Job{Percent: 20, Rate: 5}) 573 require.NoError(err) 574 require.NotNil(done) 575 assert.Equal( 576 Job{Count: 20, Percent: 20, Rate: 5, Tick: time.Second}, 577 job, 578 ) 579 580 active, job, _ := d.Status() 581 assert.True(active) 582 assert.Equal( 583 Job{Count: 20, Percent: 20, Rate: 5, Tick: time.Second}, 584 job, 585 ) 586 587 done, err = d.Cancel() 588 require.NotNil(done) 589 require.NoError(err) 590 ticker <- time.Time{} 591 close(manager.pauseVisit) 592 close(manager.pauseDisconnect) 593 594 select { 595 case <-done: 596 // passing 597 case <-time.After(5 * time.Second): 598 assert.Fail("Drain failed to complete") 599 return 600 } 601 602 provider.Assert(t, "state")(xmetricstest.Value(MetricNotDraining)) 603 provider.Assert(t, "counter")(xmetricstest.Minimum(0.0)) 604 605 assert.True(stopCalled) 606 } 607 608 func TestDrainer(t *testing.T) { 609 deviceCounts := []int{0, 1, 2, disconnectBatchSize - 1, disconnectBatchSize, disconnectBatchSize + 1, 1709} 610 611 t.Run("DisconnectAll", func(t *testing.T) { 612 for _, deviceCount := range deviceCounts { 613 t.Run(fmt.Sprintf("deviceCount=%d", deviceCount), func(t *testing.T) { 614 testDrainerDisconnectAll(t, deviceCount) 615 }) 616 } 617 }) 618 619 t.Run("DrainAll", func(t *testing.T) { 620 for _, deviceCount := range deviceCounts { 621 t.Run(fmt.Sprintf("deviceCount=%d", deviceCount), func(t *testing.T) { 622 testDrainerDrainAll(t, deviceCount) 623 }) 624 } 625 }) 626 627 t.Run("VisitCancel", testDrainerVisitCancel) 628 t.Run("DisconnectCancel", testDrainerDisconnectCancel) 629 t.Run("DrainCancel", testDrainerDrainCancel) 630 } 631 632 func testDrainFilter(t *testing.T, deviceTypeOne deviceInfo, deviceTypeTwo deviceInfo, df DrainFilter, expectedSkipped int, count int) { 633 var ( 634 assert = assert.New(t) 635 require = require.New(t) 636 provider = xmetricstest.NewProvider(nil) 637 logger = logging.NewTestLogger(nil, t) 638 639 // generate manager with devices that have two different metadatas 640 manager = generateManagerWithDifferentDevices(assert, deviceTypeOne.claims, uint64(deviceTypeOne.count), deviceTypeTwo.claims, uint64(deviceTypeTwo.count)) 641 642 firstTime = true 643 expectedStarted = time.Now() 644 expectedFinished = expectedStarted.Add(10 * time.Minute) 645 646 stopCalled = false 647 stop = func() { 648 stopCalled = true 649 } 650 651 ticker = make(chan time.Time, 1) 652 totalCount = deviceTypeOne.count + deviceTypeTwo.count 653 realCount = totalCount 654 655 d = New( 656 WithLogger(logger), 657 WithRegistry(manager), 658 WithConnector(manager), 659 WithStateGauge(provider.NewGauge("state")), 660 WithDrainCounter(provider.NewCounter("counter")), 661 ) 662 ) 663 664 if count > 0 { 665 realCount = count 666 } 667 668 require.NotNil(d) 669 d.(*drainer).now = func() time.Time { 670 if firstTime { 671 firstTime = false 672 return expectedStarted 673 } 674 675 return expectedFinished 676 } 677 678 d.(*drainer).newTicker = func(d time.Duration) (<-chan time.Time, func()) { 679 assert.Equal(time.Second, d) 680 return ticker, stop 681 } 682 683 defer d.Cancel() // cleanup in case of horribleness 684 685 // test that cancel will error if there is not a drain job in progress 686 done, err := d.Cancel() 687 assert.Nil(done) 688 assert.Error(err) 689 690 // test status when drain hasn't started 691 active, job, progress := d.Status() 692 assert.False(active) 693 assert.Equal(Job{}, job) 694 assert.Equal(Progress{}, progress) 695 696 provider.Assert(t, "state")(xmetricstest.Value(MetricNotDraining)) 697 provider.Assert(t, "counter")(xmetricstest.Value(0.0)) 698 699 // start drain job 700 if count > 0 { 701 done, job, err = d.Start(Job{Count: count, Rate: 100, Tick: time.Second, DrainFilter: df}) 702 } else { 703 done, job, err = d.Start(Job{Rate: 100, Tick: time.Second, DrainFilter: df}) 704 } 705 706 require.NoError(err) 707 require.NotNil(done) 708 709 assert.Equal(Job{Count: realCount, Rate: 100, Tick: time.Second, DrainFilter: df}, job) 710 711 provider.Assert(t, "state")(xmetricstest.Value(MetricDraining)) 712 provider.Assert(t, "counter")(xmetricstest.Value(0.0)) 713 714 { 715 // test starting another drain job when there is one in progress 716 done, job, err := d.Start(Job{Rate: 123, Tick: time.Minute}) 717 assert.Nil(done) 718 assert.Error(err) 719 assert.Equal(Job{}, job) 720 } 721 722 // get status of drain job in progress 723 active, job, progress = d.Status() 724 assert.True(active) 725 assert.Equal(Job{Count: realCount, Rate: 100, Tick: time.Second, DrainFilter: df}, job) 726 727 assert.Equal(Progress{Visited: 0, Drained: 0, Started: expectedStarted.UTC(), Finished: nil}, progress) 728 729 go func() { 730 ticks := realCount / 100 731 if (realCount % 100) > 0 { 732 ticks++ 733 } 734 735 for i := 0; i < ticks; i++ { 736 ticker <- time.Time{} 737 } 738 }() 739 740 close(manager.pauseDisconnect) 741 close(manager.pauseVisit) 742 743 // make sure jobFinished is called and done channel is closed 744 select { 745 case <-done: 746 // passed 747 case <-time.After(5 * time.Second): 748 assert.Fail("Drain failed to complete") 749 return 750 } 751 752 provider.Assert(t, "state")(xmetricstest.Value(MetricNotDraining)) 753 754 if count > 0 && count <= totalCount-expectedSkipped { 755 provider.Assert(t, "counter")(xmetricstest.Value(float64(count))) 756 } else { 757 provider.Assert(t, "counter")(xmetricstest.Value(float64(totalCount - expectedSkipped))) 758 } 759 760 // test cancel when not draining 761 done, err = d.Cancel() 762 assert.Nil(done) 763 assert.Error(err) 764 765 active, job, progress = d.Status() 766 assert.False(active) 767 768 assert.Equal(Job{Count: realCount, Rate: 100, Tick: time.Second, DrainFilter: df}, job) 769 770 if count > 0 && count <= (totalCount-expectedSkipped) { 771 assert.Equal(count, progress.Visited) 772 assert.Equal(count, progress.Drained) 773 assert.Equal(totalCount-count, len(manager.devices)) 774 } else { 775 assert.Equal(totalCount-expectedSkipped, progress.Visited) 776 assert.Equal(totalCount-expectedSkipped, progress.Drained) 777 assert.Equal(expectedSkipped, len(manager.devices)) 778 779 } 780 781 assert.Equal(expectedStarted.UTC(), progress.Started) 782 require.NotNil(progress.Finished) 783 assert.Equal(expectedFinished.UTC(), *progress.Finished) 784 785 assert.True(stopCalled) 786 787 } 788 789 func testDisconnectFilter(t *testing.T, deviceTypeOne deviceInfo, deviceTypeTwo deviceInfo, df DrainFilter, expectedSkipped int, count int) { 790 var ( 791 assert = assert.New(t) 792 require = require.New(t) 793 provider = xmetricstest.NewProvider(nil) 794 logger = logging.NewTestLogger(nil, t) 795 796 // generate manager with devices that have two different metadatas 797 manager = generateManagerWithDifferentDevices(assert, deviceTypeOne.claims, uint64(deviceTypeOne.count), deviceTypeTwo.claims, uint64(deviceTypeTwo.count)) 798 799 firstTime = true 800 expectedStarted = time.Now() 801 expectedFinished = expectedStarted.Add(10 * time.Minute) 802 803 totalCount = deviceTypeOne.count + deviceTypeTwo.count 804 805 d = New( 806 WithLogger(logger), 807 WithRegistry(manager), 808 WithConnector(manager), 809 WithStateGauge(provider.NewGauge("state")), 810 WithDrainCounter(provider.NewCounter("counter")), 811 ) 812 ) 813 814 require.NotNil(d) 815 d.(*drainer).now = func() time.Time { 816 if firstTime { 817 firstTime = false 818 return expectedStarted 819 } 820 821 return expectedFinished 822 } 823 824 defer d.Cancel() // cleanup in case of horribleness 825 826 // test that cancel will error if there is not a drain job in progress 827 done, err := d.Cancel() 828 assert.Nil(done) 829 assert.Error(err) 830 831 // test status when drain hasn't started 832 active, job, progress := d.Status() 833 assert.False(active) 834 assert.Equal(Job{}, job) 835 assert.Equal(Progress{}, progress) 836 837 provider.Assert(t, "state")(xmetricstest.Value(MetricNotDraining)) 838 provider.Assert(t, "counter")(xmetricstest.Value(0.0)) 839 840 // start drain job 841 if count > 0 { 842 done, job, err = d.Start(Job{Count: count, DrainFilter: df}) 843 } else { 844 done, job, err = d.Start(Job{DrainFilter: df}) 845 } 846 847 require.NoError(err) 848 require.NotNil(done) 849 850 if count > 0 { 851 assert.Equal(Job{Count: count, DrainFilter: df}, job) 852 } else { 853 assert.Equal(Job{Count: totalCount, DrainFilter: df}, job) 854 } 855 856 provider.Assert(t, "state")(xmetricstest.Value(MetricDraining)) 857 provider.Assert(t, "counter")(xmetricstest.Value(0.0)) 858 859 { 860 // test starting another drain job when there is one in progress 861 done, job, err := d.Start(Job{Rate: 123, Tick: time.Minute}) 862 assert.Nil(done) 863 assert.Error(err) 864 assert.Equal(Job{}, job) 865 } 866 867 // get status of drain job in progress 868 active, job, progress = d.Status() 869 assert.True(active) 870 if count > 0 { 871 assert.Equal(Job{Count: count, DrainFilter: df}, job) 872 } else { 873 assert.Equal(Job{Count: totalCount, DrainFilter: df}, job) 874 } 875 876 assert.Equal(Progress{Visited: 0, Drained: 0, Started: expectedStarted.UTC(), Finished: nil}, progress) 877 878 close(manager.pauseDisconnect) 879 close(manager.pauseVisit) 880 881 // make sure jobFinished is called and done channel is closed 882 select { 883 case <-done: 884 // passed 885 case <-time.After(5 * time.Second): 886 assert.Fail("Drain failed to complete") 887 return 888 } 889 890 provider.Assert(t, "state")(xmetricstest.Value(MetricNotDraining)) 891 892 if count > 0 && count <= totalCount-expectedSkipped { 893 provider.Assert(t, "counter")(xmetricstest.Value(float64(count))) 894 } else { 895 provider.Assert(t, "counter")(xmetricstest.Value(float64(totalCount - expectedSkipped))) 896 } 897 898 // test cancel when not draining 899 done, err = d.Cancel() 900 assert.Nil(done) 901 assert.Error(err) 902 903 active, job, progress = d.Status() 904 assert.False(active) 905 906 if count > 0 { 907 assert.Equal(Job{Count: count, DrainFilter: df}, job) 908 } else { 909 assert.Equal(Job{Count: totalCount, DrainFilter: df}, job) 910 } 911 912 if count > 0 && count <= (totalCount-expectedSkipped) { 913 assert.Equal(count, progress.Visited) 914 assert.Equal(count, progress.Drained) 915 assert.Equal(totalCount-count, len(manager.devices)) 916 } else { 917 assert.Equal(totalCount-expectedSkipped, progress.Visited) 918 assert.Equal(totalCount-expectedSkipped, progress.Drained) 919 assert.Equal(expectedSkipped, len(manager.devices)) 920 921 } 922 923 assert.Equal(expectedStarted.UTC(), progress.Started) 924 require.NotNil(progress.Finished) 925 assert.Equal(expectedFinished.UTC(), *progress.Finished) 926 } 927 928 func TestDrainerWithFilter(t *testing.T) { 929 var ( 930 filterKey = "test" 931 filterValue = "test1" 932 df = drainFilter{ 933 filter: &devicegate.FilterGate{ 934 FilterStore: devicegate.FilterStore(map[string]devicegate.Set{ 935 filterKey: &devicegate.FilterSet{Set: map[interface{}]bool{ 936 filterValue: true, 937 }}, 938 }), 939 }, 940 filterRequest: devicegate.FilterRequest{ 941 Key: filterKey, 942 Values: []interface{}{filterValue}, 943 }, 944 } 945 946 metadata1 = map[string]interface{}{filterKey: "test"} 947 metadata2 = map[string]interface{}{filterKey: filterValue} 948 949 counts = [][]int{ 950 []int{0, 0, 100}, 951 []int{1, 0, 1}, 952 []int{2, 0, 9}, 953 []int{0, 1, 100}, 954 []int{0, 2, 1}, 955 []int{1, 1, 19}, 956 []int{0, disconnectBatchSize - 1, 100}, 957 []int{disconnectBatchSize - 1, 0, 20}, 958 []int{0, disconnectBatchSize, 20}, 959 []int{disconnectBatchSize, 0, 53}, 960 []int{0, disconnectBatchSize + 1, 120}, 961 []int{disconnectBatchSize + 1, 0, 400}, 962 []int{89, 1709, 1091}, 963 []int{1704, 43, 1000}, 964 } 965 ) 966 967 for _, deviceCount := range counts { 968 expectedSkip := deviceCount[0] 969 devices := []deviceInfo{ 970 deviceInfo{count: deviceCount[0], claims: metadata1}, 971 deviceInfo{count: deviceCount[1], claims: metadata2}, 972 } 973 974 t.Run(fmt.Sprintf("deviceCount=%d", deviceCount[0]+deviceCount[1]), func(t *testing.T) { 975 t.Run("DrainAll", func(t *testing.T) { 976 testDrainFilter(t, devices[0], devices[1], &df, expectedSkip, -1) 977 }) 978 t.Run("DrainWithCount", func(t *testing.T) { 979 testDrainFilter(t, devices[0], devices[1], &df, expectedSkip, deviceCount[2]) 980 }) 981 t.Run("DisconnectAll", func(t *testing.T) { 982 testDisconnectFilter(t, devices[0], devices[1], &df, expectedSkip, -1) 983 }) 984 t.Run("DisconnectWithCount", func(t *testing.T) { 985 testDisconnectFilter(t, devices[0], devices[1], &df, expectedSkip, deviceCount[2]) 986 }) 987 }) 988 } 989 } 990 991 func TestDrainFilterNilFilter(t *testing.T) { 992 assert := assert.New(t) 993 mockDevice := new(device.MockDevice) 994 995 df := drainFilter{} 996 allow, _ := df.AllowConnection(mockDevice) 997 assert.False(allow) 998 }