github.com/anth0d/nomad@v0.0.0-20221214183521-ae3a0a2cad06/client/allocrunner/csi_hook.go

package allocrunner

import (
	"context"
	"fmt"
	"strings"
	"sync"
	"time"

	hclog "github.com/hashicorp/go-hclog"
	multierror "github.com/hashicorp/go-multierror"
	"github.com/hashicorp/nomad/client/pluginmanager/csimanager"
	"github.com/hashicorp/nomad/helper"
	"github.com/hashicorp/nomad/nomad/structs"
	"github.com/hashicorp/nomad/plugins/drivers"
)

// csiHook will wait for remote CSI volumes to be attached to the host before
// continuing.
//
// It is a no-op for allocs that do not depend on CSI volumes.
type csiHook struct {
	alloc      *structs.Allocation
	logger     hclog.Logger
	csimanager csimanager.Manager

	// interfaces implemented by the allocRunner
	rpcClient            RPCer
	taskCapabilityGetter taskCapabilityGetter
	updater              hookResourceSetter

	nodeSecret     string
	volumeRequests map[string]*volumeAndRequest
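
	// parameters for the exponential backoff used by claimWithRetry
	// and unmountWithRetry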
	minBackoffInterval time.Duration
	maxBackoffInterval time.Duration
	maxBackoffDuration time.Duration

	shutdownCtx      context.Context
	shutdownCancelFn context.CancelFunc
}

// taskCapabilityGetter is implemented by the allocRunner and exposes the
// capabilities of a task's driver.
type taskCapabilityGetter interface {
	GetTaskDriverCapabilities(string) (*drivers.Capabilities, error)
}

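// newCSIHook returns a csiHook for the given allocation. The alloc runner
// registers it alongside its other runner hooks.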
func newCSIHook(alloc *structs.Allocation, logger hclog.Logger, csi csimanager.Manager, rpcClient RPCer, taskCapabilityGetter taskCapabilityGetter, updater hookResourceSetter, nodeSecret string) *csiHook {

	shutdownCtx, shutdownCancelFn := context.WithCancel(context.Background())

	return &csiHook{
		alloc:                alloc,
		logger:               logger.Named("csi_hook"),
		csimanager:           csi,
		rpcClient:            rpcClient,
		taskCapabilityGetter: taskCapabilityGetter,
		updater:              updater,
		nodeSecret:           nodeSecret,
		volumeRequests:       map[string]*volumeAndRequest{},
		minBackoffInterval:   time.Second,
		maxBackoffInterval:   time.Minute,
		maxBackoffDuration:   time.Hour * 24,
		shutdownCtx:          shutdownCtx,
		shutdownCancelFn:     shutdownCancelFn,
	}
}

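// Name returns the hook's name, used to identify it in logs.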
func (c *csiHook) Name() string {
	return "csi_hook"
}

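// Prerun claims all the CSI volumes the allocation requires and mounts them
// on the host, then makes the resulting mount info available to other hooks
// via the alloc hook resources.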
func (c *csiHook) Prerun() error {
	if !c.shouldRun() {
		return nil
	}

	volumes, err := c.claimVolumesFromAlloc()
	if err != nil {
		return fmt.Errorf("claim volumes: %v", err)
	}
	c.volumeRequests = volumes

	mounts := make(map[string]*csimanager.MountInfo, len(volumes))
	for alias, pair := range volumes {

		// We use this context only to attach hclog to the gRPC
		// context. The lifetime is the lifetime of the gRPC stream,
		// not specific RPC timeouts, but we manage the stream
		// lifetime via Close in the pluginmanager.
		mounter, err := c.csimanager.MounterForPlugin(c.shutdownCtx, pair.volume.PluginID)
		if err != nil {
			return err
		}

		usageOpts := &csimanager.UsageOptions{
			ReadOnly:       pair.request.ReadOnly,
			AttachmentMode: pair.request.AttachmentMode,
			AccessMode:     pair.request.AccessMode,
			MountOptions:   pair.request.MountOptions,
		}

		mountInfo, err := mounter.MountVolume(
			c.shutdownCtx, pair.volume, c.alloc, usageOpts, pair.publishContext)
		if err != nil {
			return err
		}

		mounts[alias] = mountInfo
	}

	res := c.updater.GetAllocHookResources()
	res.CSIMounts = mounts
	c.updater.SetAllocHookResources(res)

	return nil
}

// Postrun sends an RPC to the server to unpublish the volume. This may
// forward client RPCs to the node plugins or to the controller plugins,
// depending on whether other allocations on this node have claims on this
// volume.
func (c *csiHook) Postrun() error {
	if !c.shouldRun() {
		return nil
	}

	var wg sync.WaitGroup
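	// errs is buffered so each per-volume goroutine can send its result
	// without blocking on a receiver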
	errs := make(chan error, len(c.volumeRequests))

	for _, pair := range c.volumeRequests {
		wg.Add(1)
		// CSI RPCs can potentially take a long time. Split the work
		// into goroutines so that one slow volume doesn't block the
		// others, and a volume can be reused as soon as its unpublish
		// completes.
		go func(pair *volumeAndRequest) {
			defer wg.Done()
			err := c.unmountImpl(pair)
			if err != nil {
				// we can recover an unmount failure if the operator
				// brings the plugin back up, so retry every few minutes
				// but eventually give up. Don't block shutdown so that
				// we don't block shutting down the client in -dev mode
				go func(pair *volumeAndRequest) {
					err := c.unmountWithRetry(pair)
					if err != nil {
						c.logger.Error("volume could not be unmounted", "error", err)
					}
					err = c.unpublish(pair)
					if err != nil {
						c.logger.Error("volume could not be unpublished", "error", err)
					}
				}(pair)
			}

			// we can't recover from this RPC error client-side; the
			// volume claim GC job will have to clean up for us once
			// the allocation is marked terminal
			errs <- c.unpublish(pair)
		}(pair)
	}

	wg.Wait()
	close(errs) // so we don't block waiting if there were no errors

	var mErr *multierror.Error
	for err := range errs {
		mErr = multierror.Append(mErr, err)
	}

	return mErr.ErrorOrNil()
}

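// volumeAndRequest pairs a volume fetched from the server with the task
// group's volume request that produced it.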
type volumeAndRequest struct {
	volume  *structs.CSIVolume
	request *structs.VolumeRequest

	// When volumeAndRequest was returned from a volume claim, this field will be
	// populated for plugins that require it.
	publishContext map[string]string
}

// claimVolumesFromAlloc is used by the pre-run hook to fetch all of the volume
// metadata and claim it for use by this alloc/node at the same time.
func (c *csiHook) claimVolumesFromAlloc() (map[string]*volumeAndRequest, error) {
	result := make(map[string]*volumeAndRequest)
	tg := c.alloc.Job.LookupTaskGroup(c.alloc.TaskGroup)
	supportsVolumes := false

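	// a claim is only valid if at least one task driver in the group
	// supports mounting volumes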
	for _, task := range tg.Tasks {
		caps, err := c.taskCapabilityGetter.GetTaskDriverCapabilities(task.Name)
		if err != nil {
			return nil, fmt.Errorf("could not validate task driver capabilities: %v", err)
		}

		if caps.MountConfigs == drivers.MountConfigSupportNone {
			continue
		}

		supportsVolumes = true
		break
	}

	if !supportsVolumes {
		return nil, fmt.Errorf("no task supports CSI")
	}

	// Initially, populate the result map with all of the requests
	for alias, volumeRequest := range tg.Volumes {
		if volumeRequest.Type == structs.VolumeTypeCSI {
			result[alias] = &volumeAndRequest{request: volumeRequest}
		}
	}

	// Iterate over the result map and upsert the volume field as each volume gets
	// claimed by the server.
	for alias, pair := range result {
		claimType := structs.CSIVolumeClaimWrite
		if pair.request.ReadOnly {
			claimType = structs.CSIVolumeClaimRead
		}

		source := pair.request.Source
		if pair.request.PerAlloc {
			source = source + structs.AllocSuffix(c.alloc.Name)
		}

		req := &structs.CSIVolumeClaimRequest{
			VolumeID:       source,
			AllocationID:   c.alloc.ID,
			NodeID:         c.alloc.NodeID,
			Claim:          claimType,
			AccessMode:     pair.request.AccessMode,
			AttachmentMode: pair.request.AttachmentMode,
			WriteRequest: structs.WriteRequest{
				Region:    c.alloc.Job.Region,
				Namespace: c.alloc.Job.Namespace,
				AuthToken: c.nodeSecret,
			},
		}

		resp, err := c.claimWithRetry(req)
		if err != nil {
			return nil, fmt.Errorf("could not claim volume %s: %w", req.VolumeID, err)
		}
		if resp.Volume == nil {
			return nil, fmt.Errorf("unexpected nil volume returned for ID: %v", pair.request.Source)
		}

		result[alias].request = pair.request
		result[alias].volume = resp.Volume
		result[alias].publishContext = resp.PublishContext
	}

	return result, nil
}

// claimWithRetry tries to claim the volume on the server, retrying
// with exponential backoff capped to a maximum interval
func (c *csiHook) claimWithRetry(req *structs.CSIVolumeClaimRequest) (*structs.CSIVolumeClaimResponse, error) {

	ctx, cancel := context.WithTimeout(c.shutdownCtx, c.maxBackoffDuration)
	defer cancel()

	var resp structs.CSIVolumeClaimResponse
	var err error
	backoff := c.minBackoffInterval
	t, stop := helper.NewSafeTimer(0)
	defer stop()
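	// the timer is created with a zero duration so the first attempt
	// fires immediately; each retry then waits for the backoff interval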
	for {
		select {
		case <-ctx.Done():
			return nil, err
		case <-t.C:
		}

		err = c.rpcClient.RPC("CSIVolume.Claim", req, &resp)
		if err == nil {
			break
		}

		if !isRetryableClaimRPCError(err) {
			break
		}

		if backoff < c.maxBackoffInterval {
			backoff = backoff * 2
			if backoff > c.maxBackoffInterval {
				backoff = c.maxBackoffInterval
			}
		}
		c.logger.Debug(
			"volume could not be claimed", "error", err, "retry_in", backoff)
		t.Reset(backoff)
	}
	return &resp, err
}

// isRetryableClaimRPCError looks for errors where we need to retry
// with backoff because we expect them to be eventually resolved.
func isRetryableClaimRPCError(err error) bool {

	// note: because these errors are returned via RPC which breaks error
	// wrapping, we can't check with errors.Is and need to read the string
	errMsg := err.Error()
	if strings.Contains(errMsg, structs.ErrCSIVolumeMaxClaims.Error()) {
		return true
	}
	if strings.Contains(errMsg, structs.ErrCSIClientRPCRetryable.Error()) {
		return true
	}
	if strings.Contains(errMsg, "no servers") {
		return true
	}
	if strings.Contains(errMsg, structs.ErrNoLeader.Error()) {
		return true
	}
	return false
}

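// shouldRun returns true if the allocation's task group requests any CSI
// volumes; otherwise the hook is a no-op.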
func (c *csiHook) shouldRun() bool {
	tg := c.alloc.Job.LookupTaskGroup(c.alloc.TaskGroup)
	for _, vol := range tg.Volumes {
		if vol.Type == structs.VolumeTypeCSI {
			return true
		}
	}

	return false
}

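// unpublish releases this allocation's claim on the volume via the server's
// CSIVolume.Unpublish RPC.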
func (c *csiHook) unpublish(pair *volumeAndRequest) error {

	mode := structs.CSIVolumeClaimRead
	if !pair.request.ReadOnly {
		mode = structs.CSIVolumeClaimWrite
	}

	source := pair.request.Source
	if pair.request.PerAlloc {
		// NOTE: PerAlloc can't be set if we have canaries
		source = source + structs.AllocSuffix(c.alloc.Name)
	}

	req := &structs.CSIVolumeUnpublishRequest{
		VolumeID: source,
		Claim: &structs.CSIVolumeClaim{
			AllocationID: c.alloc.ID,
			NodeID:       c.alloc.NodeID,
			Mode:         mode,
			State:        structs.CSIVolumeClaimStateUnpublishing,
		},
		WriteRequest: structs.WriteRequest{
			Region:    c.alloc.Job.Region,
			Namespace: c.alloc.Job.Namespace,
			AuthToken: c.nodeSecret,
		},
	}

	return c.rpcClient.RPC("CSIVolume.Unpublish",
		req, &structs.CSIVolumeUnpublishResponse{})
}

// unmountWithRetry tries to unmount/unstage the volume, retrying with
// exponential backoff capped to a maximum interval
func (c *csiHook) unmountWithRetry(pair *volumeAndRequest) error {

	ctx, cancel := context.WithTimeout(c.shutdownCtx, c.maxBackoffDuration)
	defer cancel()

	var err error
	backoff := c.minBackoffInterval
	t, stop := helper.NewSafeTimer(0)
	defer stop()
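	// as in claimWithRetry, the zero-duration timer makes the first
	// attempt immediate; each retry then waits for the backoff interval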
	for {
		select {
		case <-ctx.Done():
			return err
		case <-t.C:
		}

		err = c.unmountImpl(pair)
		if err == nil {
			break
		}

		if backoff < c.maxBackoffInterval {
			backoff = backoff * 2
			if backoff > c.maxBackoffInterval {
				backoff = c.maxBackoffInterval
			}
		}
		c.logger.Debug("volume could not be unmounted", "error", err, "retry_in", backoff)
		t.Reset(backoff)
	}
	return nil
}

// unmountImpl implements the call to the CSI plugin manager to
// unmount the volume. Each retry will write an "Unmount volume"
// NodeEvent
func (c *csiHook) unmountImpl(pair *volumeAndRequest) error {

	mounter, err := c.csimanager.MounterForPlugin(c.shutdownCtx, pair.volume.PluginID)
	if err != nil {
		return err
	}

	usageOpts := &csimanager.UsageOptions{
		ReadOnly:       pair.request.ReadOnly,
		AttachmentMode: pair.request.AttachmentMode,
		AccessMode:     pair.request.AccessMode,
		MountOptions:   pair.request.MountOptions,
	}

	return mounter.UnmountVolume(c.shutdownCtx,
		pair.volume.ID, pair.volume.RemoteID(), c.alloc.ID, usageOpts)
}

// Shutdown will get called when the client is gracefully
// stopping. Cancel our shutdown context so that we don't block client
// shutdown while in the CSI RPC retry loop.
func (c *csiHook) Shutdown() {
	c.logger.Trace("shutting down hook")
	c.shutdownCancelFn()
}

// Destroy will get called when an allocation gets GC'd on the client
// or when a -dev mode client is stopped. Cancel our shutdown context
// so that we don't block client shutdown while in the CSI RPC retry
// loop.
func (c *csiHook) Destroy() {
	c.logger.Trace("destroying hook")
	c.shutdownCancelFn()
}