github.com/hernad/nomad@v1.6.112/nomad/drainer/drain_testing.go (about)

     1  // Copyright (c) HashiCorp, Inc.
     2  // SPDX-License-Identifier: MPL-2.0
     3  
     4  package drainer
     5  
     6  import (
     7  	"context"
     8  	"sync"
     9  	"testing"
    10  	"time"
    11  
    12  	"golang.org/x/time/rate"
    13  
    14  	"github.com/hernad/nomad/helper/testlog"
    15  	"github.com/hernad/nomad/nomad/state"
    16  	"github.com/hernad/nomad/nomad/structs"
    17  )
    18  
    19  // This file contains helpers for testing. The raft shims make it hard to test
    20  // the whole package behavior of the drainer. See also nomad/drainer_int_test.go
    21  // for integration tests.
    22  
    23  type MockJobWatcher struct {
    24  	drainCh    chan *DrainRequest
    25  	migratedCh chan []*structs.Allocation
    26  	jobs       map[structs.NamespacedID]struct{}
    27  	sync.Mutex
    28  }
    29  
    30  // RegisterJobs marks the job as being watched
    31  func (m *MockJobWatcher) RegisterJobs(jobs []structs.NamespacedID) {
    32  	m.Lock()
    33  	defer m.Unlock()
    34  	for _, job := range jobs {
    35  		m.jobs[job] = struct{}{}
    36  	}
    37  }
    38  
    39  // Drain returns the DrainRequest channel. Tests can send on this channel to
    40  // simulate steps through the NodeDrainer watch loop. (Sending on this channel
    41  // will block anywhere else.)
    42  func (m *MockJobWatcher) Drain() <-chan *DrainRequest {
    43  	return m.drainCh
    44  }
    45  
    46  // Migrated returns the channel of migrated allocations. Tests can send on this
    47  // channel to simulate steps through the NodeDrainer watch loop. (Sending on
    48  // this channel will block anywhere else.)
    49  func (m *MockJobWatcher) Migrated() <-chan []*structs.Allocation {
    50  	return m.migratedCh
    51  }
    52  
    53  type MockDeadlineNotifier struct {
    54  	expiredCh <-chan []string
    55  	nodes     map[string]struct{}
    56  	sync.Mutex
    57  }
    58  
    59  // NextBatch returns the channel of expired nodes. Tests can send on this
    60  // channel to simulate timer events in the NodeDrainer watch loop. (Sending on
    61  // this channel will block anywhere else.)
    62  func (m *MockDeadlineNotifier) NextBatch() <-chan []string {
    63  	return m.expiredCh
    64  }
    65  
    66  // Remove removes the given node from being tracked for a deadline.
    67  func (m *MockDeadlineNotifier) Remove(nodeID string) {
    68  	m.Lock()
    69  	defer m.Unlock()
    70  	delete(m.nodes, nodeID)
    71  }
    72  
    73  // Watch marks the node as being watched; this mock throws out the timer in lieu
    74  // of manully sending on the channel to avoid racy tests.
    75  func (m *MockDeadlineNotifier) Watch(nodeID string, _ time.Time) {
    76  	m.Lock()
    77  	defer m.Unlock()
    78  	m.nodes[nodeID] = struct{}{}
    79  }
    80  
    81  type MockRaftApplierShim struct {
    82  	lock  sync.Mutex
    83  	state *state.StateStore
    84  }
    85  
    86  // AllocUpdateDesiredTransition mocks a write to raft as a state store update
    87  func (m *MockRaftApplierShim) AllocUpdateDesiredTransition(
    88  	allocs map[string]*structs.DesiredTransition, evals []*structs.Evaluation) (uint64, error) {
    89  
    90  	m.lock.Lock()
    91  	defer m.lock.Unlock()
    92  
    93  	index, _ := m.state.LatestIndex()
    94  	index++
    95  	err := m.state.UpdateAllocsDesiredTransitions(structs.MsgTypeTestSetup, index, allocs, evals)
    96  	return index, err
    97  }
    98  
    99  // NodesDrainComplete mocks a write to raft as a state store update
   100  func (m *MockRaftApplierShim) NodesDrainComplete(
   101  	nodes []string, event *structs.NodeEvent) (uint64, error) {
   102  
   103  	m.lock.Lock()
   104  	defer m.lock.Unlock()
   105  
   106  	index, _ := m.state.LatestIndex()
   107  	index++
   108  
   109  	updates := make(map[string]*structs.DrainUpdate, len(nodes))
   110  	nodeEvents := make(map[string]*structs.NodeEvent, len(nodes))
   111  	update := &structs.DrainUpdate{}
   112  	for _, node := range nodes {
   113  		updates[node] = update
   114  		if event != nil {
   115  			nodeEvents[node] = event
   116  		}
   117  	}
   118  	now := time.Now().Unix()
   119  
   120  	err := m.state.BatchUpdateNodeDrain(structs.MsgTypeTestSetup, index, now,
   121  		updates, nodeEvents)
   122  
   123  	return index, err
   124  }
   125  
   126  func testNodeDrainWatcher(t *testing.T) (*nodeDrainWatcher, *state.StateStore, *NodeDrainer) {
   127  	t.Helper()
   128  	store := state.TestStateStore(t)
   129  	limiter := rate.NewLimiter(100.0, 100)
   130  	logger := testlog.HCLogger(t)
   131  
   132  	drainer := &NodeDrainer{
   133  		enabled:          false,
   134  		logger:           logger,
   135  		nodes:            map[string]*drainingNode{},
   136  		jobWatcher:       &MockJobWatcher{jobs: map[structs.NamespacedID]struct{}{}},
   137  		deadlineNotifier: &MockDeadlineNotifier{nodes: map[string]struct{}{}},
   138  		state:            store,
   139  		queryLimiter:     limiter,
   140  		raft:             &MockRaftApplierShim{state: store},
   141  		batcher:          allocMigrateBatcher{},
   142  	}
   143  
   144  	w := NewNodeDrainWatcher(context.Background(), limiter, store, logger, drainer)
   145  	drainer.nodeWatcher = w
   146  	return w, store, drainer
   147  }