github.com/anth0d/nomad@v0.0.0-20221214183521-ae3a0a2cad06/client/allocrunner/taskrunner/plugin_supervisor_hook.go

     1  package taskrunner
     2  
     3  import (
     4  	"context"
     5  	"fmt"
     6  	"os"
     7  	"path/filepath"
     8  	"sync"
     9  	"time"
    10  
    11  	hclog "github.com/hashicorp/go-hclog"
    12  	"github.com/hashicorp/nomad/client/allocrunner/interfaces"
    13  	ti "github.com/hashicorp/nomad/client/allocrunner/taskrunner/interfaces"
    14  	"github.com/hashicorp/nomad/client/dynamicplugins"
    15  	"github.com/hashicorp/nomad/nomad/structs"
    16  	"github.com/hashicorp/nomad/plugins/csi"
    17  	"github.com/hashicorp/nomad/plugins/drivers"
    18  )
    19  
    20  // csiPluginSupervisorHook supervises plugins that run as Nomad tasks. The hook
    21  // fingerprints these plugins and manages connecting them to their requisite
    22  // plugin manager.
    23  //
    24  // It provides a few things to a plugin task running inside Nomad:
    25  //   - A mount of the `csi_plugin.mount_dir` where the plugin will create its csi.sock.
    26  //   - A mount of the StagePublishBaseDir (default `local/csi`) that node plugins
    27  //     use to stage volume mounts.
    28  //   - Once the task has started, a loop that connects to the plugin and performs
    29  //     initial fingerprinting of its capabilities before notifying the plugin manager.
    30  type csiPluginSupervisorHook struct {
    31  	logger           hclog.Logger
    32  	alloc            *structs.Allocation
    33  	task             *structs.Task
    34  	runner           *TaskRunner
    35  	mountPoint       string
    36  	socketMountPoint string
    37  	socketPath       string
    38  
    39  	caps *drivers.Capabilities
    40  
    41  	// eventEmitter is used to emit events to the task
    42  	eventEmitter ti.EventEmitter
    43  	lifecycle    ti.TaskLifecycle
    44  
    45  	shutdownCtx      context.Context
    46  	shutdownCancelFn context.CancelFunc
    47  	runOnce          sync.Once
    48  
    49  	// previousHealthState is used by the supervisor goroutine to track the
    50  	// previous health state for gating task events.
    51  	previousHealthState bool
    52  }
    53  
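        // csiPluginSupervisorHookConfig collects the dependencies needed to construct
        // a csiPluginSupervisorHook.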
    54  type csiPluginSupervisorHookConfig struct {
    55  	clientStateDirPath string
    56  	events             ti.EventEmitter
    57  	runner             *TaskRunner
    58  	lifecycle          ti.TaskLifecycle
    59  	capabilities       *drivers.Capabilities
    60  	logger             hclog.Logger
    61  }
    62  
    63  // The plugin supervisor uses the PrestartHook mechanism to set up the requisite
    64  // mount points and configuration for the task that exposes a CSI plugin.
    65  var _ interfaces.TaskPrestartHook = &csiPluginSupervisorHook{}
    66  
    67  // The plugin supervisor uses the PoststartHook mechanism to start polling the
    68  // plugin for readiness and supported functionality before registering the
    69  // plugin with the catalog.
    70  var _ interfaces.TaskPoststartHook = &csiPluginSupervisorHook{}
    71  
    72  // The plugin supervisor uses the StopHook mechanism to deregister the plugin
    73  // with the catalog and to ensure any mounts are cleaned up.
    74  var _ interfaces.TaskStopHook = &csiPluginSupervisorHook{}
    75  
    76  // This hook creates a csi/ directory within the client's datadir used to
    77  // manage plugins and volume mount points. The layout is as follows:
    78  
    79  // plugins/
    80  //    {alloc-id}/csi.sock
    81  //       Per-allocation directories of unix domain sockets used to communicate
    82  //       with the CSI plugin. Nomad creates the directory and the plugin creates
    83  //       the socket file. This directory is bind-mounted to the
    84  //       csi_plugin.mount_dir in the plugin task.
    85  //
    86  // {plugin-type}/{plugin-id}/
    87  //    staging/
    88  //       {volume-id}/{usage-mode}/
    89  //          Intermediate mount point used by node plugins that support
    90  //          NODE_STAGE_UNSTAGE capability.
    91  //
    92  //    per-alloc/
    93  //       {alloc-id}/{volume-id}/{usage-mode}/
    94  //          Mount point bound from the staging directory into tasks that use
    95  //          the mounted volumes
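        //
        // For illustration only (the plugin ID "ebs0" and alloc ID "aaaa-bbbb" below
        // are made up), a node plugin would see host paths shaped like:
        //
        //    {datadir}/csi/plugins/aaaa-bbbb/csi.sock
        //    {datadir}/csi/node/ebs0/staging/{volume-id}/{usage-mode}/
        //    {datadir}/csi/node/ebs0/per-alloc/aaaa-bbbb/{volume-id}/{usage-mode}/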
    96  
    97  func newCSIPluginSupervisorHook(config *csiPluginSupervisorHookConfig) *csiPluginSupervisorHook {
    98  	task := config.runner.Task()
    99  
   100  	pluginRoot := filepath.Join(config.clientStateDirPath, "csi",
   101  		string(task.CSIPluginConfig.Type), task.CSIPluginConfig.ID)
   102  
   103  	socketMountPoint := filepath.Join(config.clientStateDirPath, "csi",
   104  		"plugins", config.runner.Alloc().ID)
   105  
   106  	// In v1.3.0, Nomad started instructing CSI plugins to stage and publish
   107  	// within /local/csi. Plugins deployed after the introduction of
   108  	// StagePublishBaseDir default to StagePublishBaseDir = /local/csi. However,
   109  	// plugins deployed between v1.3.0 and the introduction of
   110  	// StagePublishBaseDir have StagePublishBaseDir = "". Default to /local/csi here
   111  	// to avoid breaking plugins that aren't redeployed.
   112  	if task.CSIPluginConfig.StagePublishBaseDir == "" {
   113  		task.CSIPluginConfig.StagePublishBaseDir = filepath.Join("/local", "csi")
   114  	}
   115  
   116  	if task.CSIPluginConfig.HealthTimeout == 0 {
   117  		task.CSIPluginConfig.HealthTimeout = 30 * time.Second
   118  	}
   119  
   120  	shutdownCtx, cancelFn := context.WithCancel(context.Background())
   121  
   122  	hook := &csiPluginSupervisorHook{
   123  		alloc:            config.runner.Alloc(),
   124  		runner:           config.runner,
   125  		lifecycle:        config.lifecycle,
   126  		logger:           config.logger,
   127  		task:             task,
   128  		mountPoint:       pluginRoot,
   129  		socketMountPoint: socketMountPoint,
   130  		caps:             config.capabilities,
   131  		shutdownCtx:      shutdownCtx,
   132  		shutdownCancelFn: cancelFn,
   133  		eventEmitter:     config.events,
   134  	}
   135  
   136  	return hook
   137  }
   138  
   139  func (*csiPluginSupervisorHook) Name() string {
   140  	return "csi_plugin_supervisor"
   141  }
   142  
   143  // Prestart is called before the task is started, including after every
   144  // restart, so the mount paths for a plugin must be idempotent. To achieve
   145  // this, the paths are derived from the plugin type and ID (and, for the
   146  // socket directory, the allocation ID), which are stable across restarts
   147  // of the same allocation.
   148  func (h *csiPluginSupervisorHook) Prestart(ctx context.Context,
   149  	req *interfaces.TaskPrestartRequest, resp *interfaces.TaskPrestartResponse) error {
   150  
   151  	// Create the mount directory that the container will access if it doesn't
   152  	// already exist. Default to only nomad user access.
   153  	if err := os.MkdirAll(h.mountPoint, 0700); err != nil && !os.IsExist(err) {
   154  		return fmt.Errorf("failed to create mount point: %v", err)
   155  	}
   156  
   157  	if err := os.MkdirAll(h.socketMountPoint, 0700); err != nil && !os.IsExist(err) {
   158  		return fmt.Errorf("failed to create socket mount point: %v", err)
   159  	}
   160  
   161  	// where the socket will be mounted
   162  	configMount := &drivers.MountConfig{
   163  		TaskPath:        h.task.CSIPluginConfig.MountDir,
   164  		HostPath:        h.socketMountPoint,
   165  		Readonly:        false,
   166  		PropagationMode: "bidirectional",
   167  	}
   168  	// where the staging and per-alloc directories will be mounted
   169  	volumeStagingMounts := &drivers.MountConfig{
   170  		TaskPath:        h.task.CSIPluginConfig.StagePublishBaseDir,
   171  		HostPath:        h.mountPoint,
   172  		Readonly:        false,
   173  		PropagationMode: "bidirectional",
   174  	}
   175  	// devices from the host
   176  	devMount := &drivers.MountConfig{
   177  		TaskPath: "/dev",
   178  		HostPath: "/dev",
   179  		Readonly: false,
   180  	}
   181  
   182  	h.setSocketHook()
   183  
   184  	if _, ok := h.task.Env["CSI_ENDPOINT"]; !ok {
   185  		switch h.caps.FSIsolation {
   186  		case drivers.FSIsolationNone:
   187  			// Plugin tasks with no filesystem isolation won't have the
   188  			// plugin dir bind-mounted to their alloc dir, but we can
   189  			// provide them the path to the socket. These Nomad-only
   190  			// plugins will need to be aware of the csi directory layout
   191  			// in the client data dir
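        			// For example, this resolves to something like
        			// unix://<csi-root>/plugins/{alloc-id}/csi.sock, where <csi-root> is
        			// the csi/ directory in the client data dir described above.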
   192  			resp.Env = map[string]string{
   193  				"CSI_ENDPOINT": "unix://" + h.socketPath}
   194  		default:
   195  			resp.Env = map[string]string{
   196  				"CSI_ENDPOINT": "unix://" + filepath.Join(
   197  					h.task.CSIPluginConfig.MountDir, structs.CSISocketName)}
   198  		}
   199  	}
   200  
   201  	mounts := ensureMountpointInserted(h.runner.hookResources.getMounts(), configMount)
   202  	mounts = ensureMountpointInserted(mounts, volumeStagingMounts)
   203  	mounts = ensureMountpointInserted(mounts, devMount)
   204  
   205  	// We would normally set resp.Mounts here, but if hookResources isn't set
   206  	// before returning, a postrun hook could run without seeing these
   207  	// resources.
   208  	h.runner.hookResources.setMounts(mounts)
   209  
   210  	return nil
   211  }
   212  
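        // setSocketHook sets the hook's socketPath. If this allocation's plugin is
        // already registered with a socket path (e.g. after a client restart), that
        // path is reused; otherwise the socket is placed in the per-alloc socket
        // mount directory.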
   213  func (h *csiPluginSupervisorHook) setSocketHook() {
   214  
   215  	// TODO(tgross): https://github.com/hashicorp/nomad/issues/11786
   216  	// If we're already registered, we should be able to update the
   217  	// definition in the update hook
   218  
   219  	// For backwards compatibility, ensure that we don't overwrite the
   220  	// socketPath on client restart for existing plugin allocations.
   221  	pluginInfo, _ := h.runner.dynamicRegistry.PluginForAlloc(
   222  		string(h.task.CSIPluginConfig.Type), h.task.CSIPluginConfig.ID, h.alloc.ID)
   223  	if pluginInfo != nil && pluginInfo.ConnectionInfo.SocketPath != "" {
   224  		h.socketPath = pluginInfo.ConnectionInfo.SocketPath
   225  		return
   226  	}
   227  	h.socketPath = filepath.Join(h.socketMountPoint, structs.CSISocketName)
   228  }
   229  
   230  // Poststart is called after the task has started. Poststart is not
   231  // called if the allocation is terminal.
   232  //
   233  // The context is cancelled if the task is killed.
   234  func (h *csiPluginSupervisorHook) Poststart(_ context.Context, _ *interfaces.TaskPoststartRequest, _ *interfaces.TaskPoststartResponse) error {
   235  
   236  	// If we're already running the supervisor routine, we don't need to try to
   237  	// restart it here, as it only terminates when the Stop hook runs.
   238  	h.runOnce.Do(func() {
   239  		h.setSocketHook()
   240  		go h.ensureSupervisorLoop(h.shutdownCtx)
   241  	})
   242  
   243  	return nil
   244  }
   245  
   246  // ensureSupervisorLoop should be called in a goroutine. It will terminate when
   247  // the passed-in context is cancelled.
   248  //
   249  // The supervisor works by:
   250  //   - Initially waiting for the plugin to become available. This loop is expensive
   251  //     and may do things like create new gRPC clients on every iteration.
   252  //   - After receiving an initial healthy status, informing the plugin catalog of
   253  //     the plugin, registering it with the plugin's fingerprinted capabilities.
   254  //   - Then performing a more lightweight check, simply probing the plugin at a
   255  //     less frequent interval to ensure it is still alive, and emitting task events
   256  //     when this status changes.
   257  //
   258  // Deeper fingerprinting of the plugin is implemented by the csimanager.
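        //
        // Timing, as implemented in the loop below: while waiting for the first
        // healthy probe the plugin is re-probed every 5s, bounded by the plugin's
        // HealthTimeout (30s by default); after registration it is probed every 30s.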
   259  func (h *csiPluginSupervisorHook) ensureSupervisorLoop(ctx context.Context) {
   260  	client := csi.NewClient(h.socketPath, h.logger.Named("csi_client").With(
   261  		"plugin.name", h.task.CSIPluginConfig.ID,
   262  		"plugin.type", h.task.CSIPluginConfig.Type))
   263  	defer client.Close()
   264  
   265  	t := time.NewTimer(0)
   266  
   267  	// We're in Poststart at this point, so if we can't connect within
   268  	// this deadline, assume it's broken so we can restart the task
   269  	startCtx, startCancelFn := context.WithTimeout(ctx, h.task.CSIPluginConfig.HealthTimeout)
   270  	defer startCancelFn()
   271  
   272  	var err error
   273  	var pluginHealthy bool
   274  
   275  	// Step 1: Wait for the plugin to initially become available.
   276  WAITFORREADY:
   277  	for {
   278  		select {
   279  		case <-startCtx.Done():
   280  			h.kill(ctx, fmt.Errorf("CSI plugin failed probe: %v", err))
   281  			return
   282  		case <-t.C:
   283  			pluginHealthy, err = h.supervisorLoopOnce(startCtx, client)
   284  			if err != nil || !pluginHealthy {
   285  				h.logger.Debug("CSI plugin not ready", "error", err)
   286  				// Use only a short delay here to optimize for quickly
   287  				// bringing up a plugin
   288  				t.Reset(5 * time.Second)
   289  				continue
   290  			}
   291  
   292  			// Mark the plugin as healthy in a task event
   293  			h.logger.Debug("CSI plugin is ready")
   294  			h.previousHealthState = pluginHealthy
   295  			event := structs.NewTaskEvent(structs.TaskPluginHealthy)
   296  			event.SetMessage(fmt.Sprintf("plugin: %s", h.task.CSIPluginConfig.ID))
   297  			h.eventEmitter.EmitEvent(event)
   298  
   299  			break WAITFORREADY
   300  		}
   301  	}
   302  
   303  	// Step 2: Register the plugin with the catalog.
   304  	deregisterPluginFn, err := h.registerPlugin(client, h.socketPath)
   305  	if err != nil {
   306  		h.kill(ctx, fmt.Errorf("CSI plugin failed to register: %v", err))
   307  		return
   308  	}
   309  	// De-register plugins on task shutdown
   310  	defer deregisterPluginFn()
   311  
   312  	// Step 3: Start the lightweight supervisor loop. At this point,
   313  	// probe failures don't cause the task to restart
   314  	t.Reset(0)
   315  	for {
   316  		select {
   317  		case <-ctx.Done():
   318  			return
   319  		case <-t.C:
   320  			pluginHealthy, err := h.supervisorLoopOnce(ctx, client)
   321  			if err != nil {
   322  				h.logger.Error("CSI plugin fingerprinting failed", "error", err)
   323  			}
   324  
   325  			// The plugin has transitioned to a healthy state. Emit an event.
   326  			if !h.previousHealthState && pluginHealthy {
   327  				event := structs.NewTaskEvent(structs.TaskPluginHealthy)
   328  				event.SetMessage(fmt.Sprintf("plugin: %s", h.task.CSIPluginConfig.ID))
   329  				h.eventEmitter.EmitEvent(event)
   330  			}
   331  
   332  			// The plugin has transitioned to an unhealthy state. Emit an event.
   333  			if h.previousHealthState && !pluginHealthy {
   334  				event := structs.NewTaskEvent(structs.TaskPluginUnhealthy)
   335  				if err != nil {
   336  					event.SetMessage(fmt.Sprintf("Error: %v", err))
   337  				} else {
   338  					event.SetMessage("Unknown Reason")
   339  				}
   340  				h.eventEmitter.EmitEvent(event)
   341  			}
   342  
   343  			h.previousHealthState = pluginHealthy
   344  
   345  			// This loop is informational, and for some plugins these probes may be
   346  			// expensive. We use a longer interval (30s) to avoid causing undue work.
   347  			t.Reset(30 * time.Second)
   348  		}
   349  	}
   350  }
   351  
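        // registerPlugin fingerprints the plugin for its vendor name and version, then
        // registers it with the client's dynamic plugin registry as a controller
        // plugin, a node plugin, or both (for monoliths). It returns a function that
        // deregisters everything it successfully registered.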
   352  func (h *csiPluginSupervisorHook) registerPlugin(client csi.CSIPlugin, socketPath string) (func(), error) {
   353  	// At this point we know the plugin is ready and we can fingerprint it
   354  	// to get its vendor name and version
   355  	info, err := client.PluginInfo()
   356  	if err != nil {
   357  		return nil, fmt.Errorf("failed to probe plugin: %v", err)
   358  	}
   359  
   360  	mkInfoFn := func(pluginType string) *dynamicplugins.PluginInfo {
   361  		return &dynamicplugins.PluginInfo{
   362  			Type:    pluginType,
   363  			Name:    h.task.CSIPluginConfig.ID,
   364  			Version: info.PluginVersion,
   365  			ConnectionInfo: &dynamicplugins.PluginConnectionInfo{
   366  				SocketPath: socketPath,
   367  			},
   368  			AllocID: h.alloc.ID,
   369  			Options: map[string]string{
   370  				"Provider":            info.Name, // vendor name
   371  				"MountPoint":          h.mountPoint,
   372  				"ContainerMountPoint": h.task.CSIPluginConfig.StagePublishBaseDir,
   373  			},
   374  		}
   375  	}
   376  
   377  	registrations := []*dynamicplugins.PluginInfo{}
   378  
   379  	switch h.task.CSIPluginConfig.Type {
   380  	case structs.CSIPluginTypeController:
   381  		registrations = append(registrations, mkInfoFn(dynamicplugins.PluginTypeCSIController))
   382  	case structs.CSIPluginTypeNode:
   383  		registrations = append(registrations, mkInfoFn(dynamicplugins.PluginTypeCSINode))
   384  	case structs.CSIPluginTypeMonolith:
   385  		registrations = append(registrations, mkInfoFn(dynamicplugins.PluginTypeCSIController))
   386  		registrations = append(registrations, mkInfoFn(dynamicplugins.PluginTypeCSINode))
   387  	}
   388  
   389  	deregistrationFns := []func(){}
   390  
   391  	for _, reg := range registrations {
   392  		if err := h.runner.dynamicRegistry.RegisterPlugin(reg); err != nil {
   393  			for _, fn := range deregistrationFns {
   394  				fn()
   395  			}
   396  			return nil, err
   397  		}
   398  
   399  		// need to rebind these so that each deregistration function
   400  		// closes over its own registration
   401  		rname := reg.Name
   402  		rtype := reg.Type
   403  		allocID := reg.AllocID
   404  		deregistrationFns = append(deregistrationFns, func() {
   405  			err := h.runner.dynamicRegistry.DeregisterPlugin(rtype, rname, allocID)
   406  			if err != nil {
   407  				h.logger.Error("failed to deregister csi plugin", "name", rname, "type", rtype, "error", err)
   408  			}
   409  		})
   410  	}
   411  
   412  	return func() {
   413  		for _, fn := range deregistrationFns {
   414  			fn()
   415  		}
   416  	}, nil
   417  }
   418  
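        // supervisorLoopOnce probes the plugin once over its socket, with a 5s
        // timeout, and reports whether the plugin considers itself healthy.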
   419  func (h *csiPluginSupervisorHook) supervisorLoopOnce(ctx context.Context, client csi.CSIPlugin) (bool, error) {
   420  	probeCtx, probeCancelFn := context.WithTimeout(ctx, 5*time.Second)
   421  	defer probeCancelFn()
   422  
   423  	healthy, err := client.PluginProbe(probeCtx)
   424  	if err != nil {
   425  		return false, err
   426  	}
   427  
   428  	return healthy, nil
   429  }
   430  
   431  // Stop is called after the task has exited and will not be started
   432  // again. It is the only hook guaranteed to be executed whenever
   433  // TaskRunner.Run is called (unless the client is gracefully shutting
   434  // down), so it may be called even when prestart and the other hooks
   435  // have not been.
   436  //
   437  // Stop hooks must be idempotent. The context is cancelled prematurely if the
   438  // task is killed.
   439  func (h *csiPluginSupervisorHook) Stop(_ context.Context, req *interfaces.TaskStopRequest, _ *interfaces.TaskStopResponse) error {
   440  	err := os.RemoveAll(h.socketMountPoint)
   441  	if err != nil {
   442  		h.logger.Error("could not remove plugin socket directory", "dir", h.socketMountPoint, "error", err)
   443  	}
   444  	h.shutdownCancelFn()
   445  	return nil
   446  }
   447  
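        // kill emits a plugin-unhealthy task event with the given reason and then
        // asks the task lifecycle to kill (and fail) the task.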
   448  func (h *csiPluginSupervisorHook) kill(ctx context.Context, reason error) {
   449  	h.logger.Error("killing task because plugin failed", "error", reason)
   450  	event := structs.NewTaskEvent(structs.TaskPluginUnhealthy)
   451  	event.SetMessage(fmt.Sprintf("Error: %v", reason.Error()))
   452  	h.eventEmitter.EmitEvent(event)
   453  
   454  	if err := h.lifecycle.Kill(ctx,
   455  		structs.NewTaskEvent(structs.TaskKilling).
   456  			SetFailsTask().
   457  			SetDisplayMessage(fmt.Sprintf("CSI plugin did not become healthy before configured %v health timeout", h.task.CSIPluginConfig.HealthTimeout.String())),
   458  	); err != nil {
   459  		h.logger.Error("failed to kill task", "kill_reason", reason, "error", err)
   460  	}
   461  }
   462  
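        // ensureMountpointInserted appends mount to mounts unless an equal mount
        // configuration is already present.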
   463  func ensureMountpointInserted(mounts []*drivers.MountConfig, mount *drivers.MountConfig) []*drivers.MountConfig {
   464  	for _, mnt := range mounts {
   465  		if mnt.IsEqual(mount) {
   466  			return mounts
   467  		}
   468  	}
   469  
   470  	mounts = append(mounts, mount)
   471  	return mounts
   472  }