github.com/hernad/nomad@v1.6.112/nomad/volumewatcher/volume_watcher_test.go (about)

     1  // Copyright (c) HashiCorp, Inc.
     2  // SPDX-License-Identifier: MPL-2.0
     3  
     4  package volumewatcher
     5  
     6  import (
     7  	"context"
     8  	"testing"
     9  
    10  	"github.com/hernad/nomad/ci"
    11  	"github.com/hernad/nomad/helper/testlog"
    12  	"github.com/hernad/nomad/nomad/mock"
    13  	"github.com/hernad/nomad/nomad/state"
    14  	"github.com/hernad/nomad/nomad/structs"
    15  	"github.com/stretchr/testify/require"
    16  )
    17  
    18  func TestVolumeWatch_Reap(t *testing.T) {
    19  	ci.Parallel(t)
    20  	require := require.New(t)
    21  
    22  	srv := &MockRPCServer{
    23  		state: state.TestStateStore(t),
    24  	}
    25  
    26  	plugin := mock.CSIPlugin()
    27  	node := testNode(plugin, srv.State())
    28  	alloc := mock.Alloc()
    29  	alloc.NodeID = node.ID
    30  	alloc.ClientStatus = structs.AllocClientStatusComplete
    31  	vol := testVolume(plugin, alloc, node.ID)
    32  	vol.PastClaims = vol.ReadClaims
    33  
    34  	ctx, exitFn := context.WithCancel(context.Background())
    35  	w := &volumeWatcher{
    36  		v:      vol,
    37  		rpc:    srv,
    38  		state:  srv.State(),
    39  		ctx:    ctx,
    40  		exitFn: exitFn,
    41  		logger: testlog.HCLogger(t),
    42  	}
    43  
    44  	vol, _ = srv.State().CSIVolumeDenormalize(nil, vol.Copy())
    45  	err := w.volumeReapImpl(vol)
    46  	require.NoError(err)
    47  
    48  	// past claim from a previous pass
    49  	vol.PastClaims = map[string]*structs.CSIVolumeClaim{
    50  		alloc.ID: {
    51  			NodeID: node.ID,
    52  			Mode:   structs.CSIVolumeClaimRead,
    53  			State:  structs.CSIVolumeClaimStateNodeDetached,
    54  		},
    55  	}
    56  	vol, _ = srv.State().CSIVolumeDenormalize(nil, vol.Copy())
    57  	err = w.volumeReapImpl(vol)
    58  	require.NoError(err)
    59  	require.Len(vol.PastClaims, 1)
    60  
    61  	// claim emitted by a GC event
    62  	vol.PastClaims = map[string]*structs.CSIVolumeClaim{
    63  		"": {
    64  			NodeID: node.ID,
    65  			Mode:   structs.CSIVolumeClaimGC,
    66  		},
    67  	}
    68  	vol, _ = srv.State().CSIVolumeDenormalize(nil, vol.Copy())
    69  	err = w.volumeReapImpl(vol)
    70  	require.NoError(err)
    71  	require.Len(vol.PastClaims, 2) // alloc claim + GC claim
    72  
    73  	// release claims of a previously GC'd allocation
    74  	vol.ReadAllocs[alloc.ID] = nil
    75  	vol.PastClaims = map[string]*structs.CSIVolumeClaim{
    76  		"": {
    77  			NodeID: node.ID,
    78  			Mode:   structs.CSIVolumeClaimRead,
    79  		},
    80  	}
    81  	vol, _ = srv.State().CSIVolumeDenormalize(nil, vol.Copy())
    82  	err = w.volumeReapImpl(vol)
    83  	require.NoError(err)
    84  	require.Len(vol.PastClaims, 2) // alloc claim + GC claim
    85  }
    86  
    87  func TestVolumeReapBadState(t *testing.T) {
    88  	ci.Parallel(t)
    89  
    90  	store := state.TestStateStore(t)
    91  	err := state.TestBadCSIState(t, store)
    92  	require.NoError(t, err)
    93  	srv := &MockRPCServer{
    94  		state: store,
    95  	}
    96  
    97  	vol, err := srv.state.CSIVolumeByID(nil,
    98  		structs.DefaultNamespace, "csi-volume-nfs0")
    99  	require.NoError(t, err)
   100  	srv.state.CSIVolumeDenormalize(nil, vol)
   101  
   102  	ctx, exitFn := context.WithCancel(context.Background())
   103  	w := &volumeWatcher{
   104  		v:      vol,
   105  		rpc:    srv,
   106  		state:  srv.State(),
   107  		ctx:    ctx,
   108  		exitFn: exitFn,
   109  		logger: testlog.HCLogger(t),
   110  	}
   111  
   112  	err = w.volumeReapImpl(vol)
   113  	require.NoError(t, err)
   114  	require.Equal(t, 2, srv.countCSIUnpublish)
   115  }