github.com/Ilhicas/nomad@v1.0.4-0.20210304152020-e86851182bc3/nomad/volumewatcher/volumes_watcher_test.go (about)

     1  package volumewatcher
     2  
     3  import (
     4  	"testing"
     5  	"time"
     6  
     7  	memdb "github.com/hashicorp/go-memdb"
     8  	"github.com/hashicorp/nomad/helper/testlog"
     9  	"github.com/hashicorp/nomad/nomad/mock"
    10  	"github.com/hashicorp/nomad/nomad/state"
    11  	"github.com/hashicorp/nomad/nomad/structs"
    12  	"github.com/stretchr/testify/require"
    13  )
    14  
    15  // TestVolumeWatch_EnableDisable tests the watcher registration logic that needs
    16  // to happen during leader step-up/step-down
    17  func TestVolumeWatch_EnableDisable(t *testing.T) {
    18  	t.Parallel()
    19  	require := require.New(t)
    20  
    21  	srv := &MockRPCServer{}
    22  	srv.state = state.TestStateStore(t)
    23  	index := uint64(100)
    24  
    25  	watcher := NewVolumesWatcher(testlog.HCLogger(t), srv, "")
    26  	watcher.SetEnabled(true, srv.State())
    27  
    28  	plugin := mock.CSIPlugin()
    29  	node := testNode(plugin, srv.State())
    30  	alloc := mock.Alloc()
    31  	alloc.ClientStatus = structs.AllocClientStatusComplete
    32  	vol := testVolume(plugin, alloc, node.ID)
    33  
    34  	index++
    35  	err := srv.State().CSIVolumeRegister(index, []*structs.CSIVolume{vol})
    36  	require.NoError(err)
    37  
    38  	claim := &structs.CSIVolumeClaim{
    39  		Mode:  structs.CSIVolumeClaimGC,
    40  		State: structs.CSIVolumeClaimStateNodeDetached,
    41  	}
    42  	index++
    43  	err = srv.State().CSIVolumeClaim(index, vol.Namespace, vol.ID, claim)
    44  	require.NoError(err)
    45  	require.Eventually(func() bool {
    46  		return 1 == len(watcher.watchers)
    47  	}, time.Second, 10*time.Millisecond)
    48  
    49  	watcher.SetEnabled(false, srv.State())
    50  	require.Equal(0, len(watcher.watchers))
    51  }
    52  
    53  // TestVolumeWatch_Checkpoint tests the checkpointing of progress across
    54  // leader leader step-up/step-down
    55  func TestVolumeWatch_Checkpoint(t *testing.T) {
    56  	t.Parallel()
    57  	require := require.New(t)
    58  
    59  	srv := &MockRPCServer{}
    60  	srv.state = state.TestStateStore(t)
    61  	index := uint64(100)
    62  
    63  	watcher := NewVolumesWatcher(testlog.HCLogger(t), srv, "")
    64  
    65  	plugin := mock.CSIPlugin()
    66  	node := testNode(plugin, srv.State())
    67  	alloc := mock.Alloc()
    68  	alloc.ClientStatus = structs.AllocClientStatusComplete
    69  	vol := testVolume(plugin, alloc, node.ID)
    70  
    71  	watcher.SetEnabled(true, srv.State())
    72  
    73  	index++
    74  	err := srv.State().CSIVolumeRegister(index, []*structs.CSIVolume{vol})
    75  	require.NoError(err)
    76  
    77  	// we should get or start up a watcher when we get an update for
    78  	// the volume from the state store
    79  	require.Eventually(func() bool {
    80  		return 1 == len(watcher.watchers)
    81  	}, time.Second, 10*time.Millisecond)
    82  
    83  	// step-down (this is sync, but step-up is async)
    84  	watcher.SetEnabled(false, srv.State())
    85  	require.Equal(0, len(watcher.watchers))
    86  
    87  	// step-up again
    88  	watcher.SetEnabled(true, srv.State())
    89  	require.Eventually(func() bool {
    90  		return 1 == len(watcher.watchers) &&
    91  			!watcher.watchers[vol.ID+vol.Namespace].isRunning()
    92  	}, time.Second, 10*time.Millisecond)
    93  }
    94  
    95  // TestVolumeWatch_StartStop tests the start and stop of the watcher when
    96  // it receives notifcations and has completed its work
    97  func TestVolumeWatch_StartStop(t *testing.T) {
    98  	t.Parallel()
    99  	require := require.New(t)
   100  
   101  	srv := &MockStatefulRPCServer{}
   102  	srv.state = state.TestStateStore(t)
   103  	index := uint64(100)
   104  	watcher := NewVolumesWatcher(testlog.HCLogger(t), srv, "")
   105  
   106  	watcher.SetEnabled(true, srv.State())
   107  	require.Equal(0, len(watcher.watchers))
   108  
   109  	plugin := mock.CSIPlugin()
   110  	node := testNode(plugin, srv.State())
   111  	alloc1 := mock.Alloc()
   112  	alloc1.ClientStatus = structs.AllocClientStatusRunning
   113  	alloc2 := mock.Alloc()
   114  	alloc2.Job = alloc1.Job
   115  	alloc2.ClientStatus = structs.AllocClientStatusRunning
   116  	index++
   117  	err := srv.State().UpsertJob(structs.MsgTypeTestSetup, index, alloc1.Job)
   118  	require.NoError(err)
   119  	index++
   120  	err = srv.State().UpsertAllocs(structs.MsgTypeTestSetup, index, []*structs.Allocation{alloc1, alloc2})
   121  	require.NoError(err)
   122  
   123  	// register a volume
   124  	vol := testVolume(plugin, alloc1, node.ID)
   125  	index++
   126  	err = srv.State().CSIVolumeRegister(index, []*structs.CSIVolume{vol})
   127  	require.NoError(err)
   128  
   129  	// assert we get a watcher; there are no claims so it should immediately stop
   130  	require.Eventually(func() bool {
   131  		return 1 == len(watcher.watchers) &&
   132  			!watcher.watchers[vol.ID+vol.Namespace].isRunning()
   133  	}, time.Second*2, 10*time.Millisecond)
   134  
   135  	// claim the volume for both allocs
   136  	claim := &structs.CSIVolumeClaim{
   137  		AllocationID: alloc1.ID,
   138  		NodeID:       node.ID,
   139  		Mode:         structs.CSIVolumeClaimRead,
   140  	}
   141  	index++
   142  	err = srv.State().CSIVolumeClaim(index, vol.Namespace, vol.ID, claim)
   143  	require.NoError(err)
   144  	claim.AllocationID = alloc2.ID
   145  	index++
   146  	err = srv.State().CSIVolumeClaim(index, vol.Namespace, vol.ID, claim)
   147  	require.NoError(err)
   148  
   149  	// reap the volume and assert nothing has happened
   150  	claim = &structs.CSIVolumeClaim{
   151  		AllocationID: alloc1.ID,
   152  		NodeID:       node.ID,
   153  	}
   154  	index++
   155  	err = srv.State().CSIVolumeClaim(index, vol.Namespace, vol.ID, claim)
   156  	require.NoError(err)
   157  
   158  	ws := memdb.NewWatchSet()
   159  	vol, _ = srv.State().CSIVolumeByID(ws, vol.Namespace, vol.ID)
   160  	require.Equal(2, len(vol.ReadAllocs))
   161  
   162  	// alloc becomes terminal
   163  	alloc1.ClientStatus = structs.AllocClientStatusComplete
   164  	index++
   165  	err = srv.State().UpsertAllocs(structs.MsgTypeTestSetup, index, []*structs.Allocation{alloc1})
   166  	require.NoError(err)
   167  	index++
   168  	claim.State = structs.CSIVolumeClaimStateReadyToFree
   169  	err = srv.State().CSIVolumeClaim(index, vol.Namespace, vol.ID, claim)
   170  	require.NoError(err)
   171  
   172  	// 1 claim has been released and watcher stops
   173  	require.Eventually(func() bool {
   174  		ws := memdb.NewWatchSet()
   175  		vol, _ := srv.State().CSIVolumeByID(ws, vol.Namespace, vol.ID)
   176  		return len(vol.ReadAllocs) == 1 && len(vol.PastClaims) == 0
   177  	}, time.Second*2, 10*time.Millisecond)
   178  
   179  	require.Eventually(func() bool {
   180  		return !watcher.watchers[vol.ID+vol.Namespace].isRunning()
   181  	}, time.Second*5, 10*time.Millisecond)
   182  }
   183  
   184  // TestVolumeWatch_RegisterDeregister tests the start and stop of
   185  // watchers around registration
   186  func TestVolumeWatch_RegisterDeregister(t *testing.T) {
   187  	t.Parallel()
   188  	require := require.New(t)
   189  
   190  	srv := &MockStatefulRPCServer{}
   191  	srv.state = state.TestStateStore(t)
   192  
   193  	index := uint64(100)
   194  
   195  	watcher := NewVolumesWatcher(testlog.HCLogger(t), srv, "")
   196  
   197  	watcher.SetEnabled(true, srv.State())
   198  	require.Equal(0, len(watcher.watchers))
   199  
   200  	plugin := mock.CSIPlugin()
   201  	alloc := mock.Alloc()
   202  	alloc.ClientStatus = structs.AllocClientStatusComplete
   203  
   204  	// register a volume without claims
   205  	vol := mock.CSIVolume(plugin)
   206  	index++
   207  	err := srv.State().CSIVolumeRegister(index, []*structs.CSIVolume{vol})
   208  	require.NoError(err)
   209  
   210  	// watcher should be started but immediately stopped
   211  	require.Eventually(func() bool {
   212  		return 1 == len(watcher.watchers)
   213  	}, time.Second, 10*time.Millisecond)
   214  
   215  	require.False(watcher.watchers[vol.ID+vol.Namespace].isRunning())
   216  }