github.com/anth0d/nomad@v0.0.0-20221214183521-ae3a0a2cad06/client/allocrunner/taskrunner/task_runner_hooks.go (about)

     1  package taskrunner
     2  
import (
	"context"
	"errors"
	"fmt"
	"sync"
	"time"

	"github.com/LK4D4/joincontext"
	multierror "github.com/hashicorp/go-multierror"
	"github.com/hashicorp/nomad/client/allocrunner/interfaces"
	"github.com/hashicorp/nomad/client/allocrunner/taskrunner/state"
	"github.com/hashicorp/nomad/nomad/structs"
	"github.com/hashicorp/nomad/plugins/drivers"
)
    16  
    17  // hookResources captures the resources for the task provided by hooks.
    18  type hookResources struct {
    19  	Devices []*drivers.DeviceConfig
    20  	Mounts  []*drivers.MountConfig
    21  	sync.RWMutex
    22  }
    23  
    24  func (h *hookResources) setDevices(d []*drivers.DeviceConfig) {
    25  	h.Lock()
    26  	h.Devices = d
    27  	h.Unlock()
    28  }
    29  
    30  func (h *hookResources) getDevices() []*drivers.DeviceConfig {
    31  	h.RLock()
    32  	defer h.RUnlock()
    33  	return h.Devices
    34  }
    35  
    36  func (h *hookResources) setMounts(m []*drivers.MountConfig) {
    37  	h.Lock()
    38  	h.Mounts = m
    39  	h.Unlock()
    40  }
    41  
    42  func (h *hookResources) getMounts() []*drivers.MountConfig {
    43  	h.RLock()
    44  	defer h.RUnlock()
    45  	return h.Mounts
    46  }
    47  
    48  // initHooks initializes the tasks hooks.
    49  func (tr *TaskRunner) initHooks() {
    50  	hookLogger := tr.logger.Named("task_hook")
    51  	task := tr.Task()
    52  
    53  	tr.logmonHookConfig = newLogMonHookConfig(task.Name, tr.taskDir.LogDir)
    54  
    55  	// Add the hook resources
    56  	tr.hookResources = &hookResources{}
    57  
    58  	// Create the task directory hook. This is run first to ensure the
    59  	// directory path exists for other hooks.
    60  	alloc := tr.Alloc()
    61  	tr.runnerHooks = []interfaces.TaskHook{
    62  		newValidateHook(tr.clientConfig, hookLogger),
    63  		newTaskDirHook(tr, hookLogger),
    64  		newIdentityHook(tr, hookLogger),
    65  		newLogMonHook(tr, hookLogger),
    66  		newDispatchHook(alloc, hookLogger),
    67  		newVolumeHook(tr, hookLogger),
    68  		newArtifactHook(tr, tr.getter, hookLogger),
    69  		newStatsHook(tr, tr.clientConfig.StatsCollectionInterval, hookLogger),
    70  		newDeviceHook(tr.devicemanager, hookLogger),
    71  	}
    72  
    73  	// If the task has a CSI stanza, add the hook.
    74  	if task.CSIPluginConfig != nil {
    75  		tr.runnerHooks = append(tr.runnerHooks, newCSIPluginSupervisorHook(
    76  			&csiPluginSupervisorHookConfig{
    77  				clientStateDirPath: tr.clientConfig.StateDir,
    78  				events:             tr,
    79  				runner:             tr,
    80  				lifecycle:          tr,
    81  				capabilities:       tr.driverCapabilities,
    82  				logger:             hookLogger,
    83  			}))
    84  	}
    85  
    86  	// If Vault is enabled, add the hook
    87  	if task.Vault != nil {
    88  		tr.runnerHooks = append(tr.runnerHooks, newVaultHook(&vaultHookConfig{
    89  			vaultStanza: task.Vault,
    90  			client:      tr.vaultClient,
    91  			events:      tr,
    92  			lifecycle:   tr,
    93  			updater:     tr,
    94  			logger:      hookLogger,
    95  			alloc:       tr.Alloc(),
    96  			task:        tr.taskName,
    97  		}))
    98  	}
    99  
   100  	// Get the consul namespace for the TG of the allocation.
   101  	consulNamespace := tr.alloc.ConsulNamespace()
   102  
   103  	// Identify the service registration provider, which can differ from the
   104  	// Consul namespace depending on which provider is used.
   105  	serviceProviderNamespace := tr.alloc.ServiceProviderNamespace()
   106  
   107  	// If there are templates is enabled, add the hook
   108  	if len(task.Templates) != 0 {
   109  		tr.runnerHooks = append(tr.runnerHooks, newTemplateHook(&templateHookConfig{
   110  			logger:          hookLogger,
   111  			lifecycle:       tr,
   112  			events:          tr,
   113  			templates:       task.Templates,
   114  			clientConfig:    tr.clientConfig,
   115  			envBuilder:      tr.envBuilder,
   116  			consulNamespace: consulNamespace,
   117  			nomadNamespace:  tr.alloc.Job.Namespace,
   118  		}))
   119  	}
   120  
   121  	// Always add the service hook. A task with no services on initial registration
   122  	// may be updated to include services, which must be handled with this hook.
   123  	tr.runnerHooks = append(tr.runnerHooks, newServiceHook(serviceHookConfig{
   124  		alloc:             tr.Alloc(),
   125  		task:              tr.Task(),
   126  		providerNamespace: serviceProviderNamespace,
   127  		serviceRegWrapper: tr.serviceRegWrapper,
   128  		restarter:         tr,
   129  		logger:            hookLogger,
   130  	}))
   131  
   132  	// If this is a Connect sidecar proxy (or a Connect Native) service,
   133  	// add the sidsHook for requesting a Service Identity token (if ACLs).
   134  	if task.UsesConnect() {
   135  		// Enable the Service Identity hook only if the Nomad client is configured
   136  		// with a consul token, indicating that Consul ACLs are enabled
   137  		if tr.clientConfig.ConsulConfig.Token != "" {
   138  			tr.runnerHooks = append(tr.runnerHooks, newSIDSHook(sidsHookConfig{
   139  				alloc:      tr.Alloc(),
   140  				task:       tr.Task(),
   141  				sidsClient: tr.siClient,
   142  				lifecycle:  tr,
   143  				logger:     hookLogger,
   144  			}))
   145  		}
   146  
   147  		if task.UsesConnectSidecar() {
   148  			tr.runnerHooks = append(tr.runnerHooks,
   149  				newEnvoyVersionHook(newEnvoyVersionHookConfig(alloc, tr.consulProxiesClient, hookLogger)),
   150  				newEnvoyBootstrapHook(newEnvoyBootstrapHookConfig(alloc, tr.clientConfig.ConsulConfig, consulNamespace, hookLogger)),
   151  			)
   152  		} else if task.Kind.IsConnectNative() {
   153  			tr.runnerHooks = append(tr.runnerHooks, newConnectNativeHook(
   154  				newConnectNativeHookConfig(alloc, tr.clientConfig.ConsulConfig, hookLogger),
   155  			))
   156  		}
   157  	}
   158  
   159  	// Always add the script checks hook. A task with no script check hook on
   160  	// initial registration may be updated to include script checks, which must
   161  	// be handled with this hook.
   162  	tr.runnerHooks = append(tr.runnerHooks, newScriptCheckHook(scriptCheckHookConfig{
   163  		alloc:  tr.Alloc(),
   164  		task:   tr.Task(),
   165  		consul: tr.consulServiceClient,
   166  		logger: hookLogger,
   167  	}))
   168  
   169  	// If this task driver has remote capabilities, add the remote task
   170  	// hook.
   171  	if tr.driverCapabilities.RemoteTasks {
   172  		tr.runnerHooks = append(tr.runnerHooks, newRemoteTaskHook(tr, hookLogger))
   173  	}
   174  }
   175  
   176  func (tr *TaskRunner) emitHookError(err error, hookName string) {
   177  	var taskEvent *structs.TaskEvent
   178  	if herr, ok := err.(*hookError); ok {
   179  		taskEvent = herr.taskEvent
   180  	} else {
   181  		message := fmt.Sprintf("%s: %v", hookName, err)
   182  		taskEvent = structs.NewTaskEvent(structs.TaskHookFailed).SetMessage(message)
   183  	}
   184  
   185  	tr.EmitEvent(taskEvent)
   186  }
   187  
   188  // prestart is used to run the runners prestart hooks.
   189  func (tr *TaskRunner) prestart() error {
   190  	// Determine if the allocation is terminal and we should avoid running
   191  	// prestart hooks.
   192  	if tr.shouldShutdown() {
   193  		tr.logger.Trace("skipping prestart hooks since allocation is terminal")
   194  		return nil
   195  	}
   196  
   197  	if tr.logger.IsTrace() {
   198  		start := time.Now()
   199  		tr.logger.Trace("running prestart hooks", "start", start)
   200  		defer func() {
   201  			end := time.Now()
   202  			tr.logger.Trace("finished prestart hooks", "end", end, "duration", end.Sub(start))
   203  		}()
   204  	}
   205  
   206  	// use a join context to allow any blocking pre-start hooks
   207  	// to be canceled by either killCtx or shutdownCtx
   208  	joinedCtx, joinedCancel := joincontext.Join(tr.killCtx, tr.shutdownCtx)
   209  	defer joinedCancel()
   210  
   211  	for _, hook := range tr.runnerHooks {
   212  		pre, ok := hook.(interfaces.TaskPrestartHook)
   213  		if !ok {
   214  			continue
   215  		}
   216  
   217  		name := pre.Name()
   218  
   219  		// Build the request
   220  		req := interfaces.TaskPrestartRequest{
   221  			Task:          tr.Task(),
   222  			TaskDir:       tr.taskDir,
   223  			TaskEnv:       tr.envBuilder.Build(),
   224  			TaskResources: tr.taskResources,
   225  		}
   226  
   227  		origHookState := tr.hookState(name)
   228  		if origHookState != nil {
   229  			if origHookState.PrestartDone {
   230  				tr.logger.Trace("skipping done prestart hook", "name", pre.Name())
   231  
   232  				// Always set env vars from hooks
   233  				if name == HookNameDevices {
   234  					tr.envBuilder.SetDeviceHookEnv(name, origHookState.Env)
   235  				} else {
   236  					tr.envBuilder.SetHookEnv(name, origHookState.Env)
   237  				}
   238  
   239  				continue
   240  			}
   241  
   242  			// Give the hook it's old data
   243  			req.PreviousState = origHookState.Data
   244  		}
   245  
   246  		req.VaultToken = tr.getVaultToken()
   247  		req.NomadToken = tr.getNomadToken()
   248  
   249  		// Time the prestart hook
   250  		var start time.Time
   251  		if tr.logger.IsTrace() {
   252  			start = time.Now()
   253  			tr.logger.Trace("running prestart hook", "name", name, "start", start)
   254  		}
   255  
   256  		// Run the prestart hook
   257  		var resp interfaces.TaskPrestartResponse
   258  		if err := pre.Prestart(joinedCtx, &req, &resp); err != nil {
   259  			tr.emitHookError(err, name)
   260  			return structs.WrapRecoverable(fmt.Sprintf("prestart hook %q failed: %v", name, err), err)
   261  		}
   262  
   263  		// Store the hook state
   264  		{
   265  			hookState := &state.HookState{
   266  				Data:         resp.State,
   267  				PrestartDone: resp.Done,
   268  				Env:          resp.Env,
   269  			}
   270  
   271  			// Store and persist local state if the hook state has changed
   272  			if !hookState.Equal(origHookState) {
   273  				tr.stateLock.Lock()
   274  				tr.localState.Hooks[name] = hookState
   275  				tr.stateLock.Unlock()
   276  
   277  				if err := tr.persistLocalState(); err != nil {
   278  					return err
   279  				}
   280  			}
   281  		}
   282  
   283  		// Store the environment variables returned by the hook
   284  		if name == HookNameDevices {
   285  			tr.envBuilder.SetDeviceHookEnv(name, resp.Env)
   286  		} else {
   287  			tr.envBuilder.SetHookEnv(name, resp.Env)
   288  		}
   289  
   290  		// Store the resources
   291  		if len(resp.Devices) != 0 {
   292  			tr.hookResources.setDevices(resp.Devices)
   293  		}
   294  		if len(resp.Mounts) != 0 {
   295  			tr.hookResources.setMounts(resp.Mounts)
   296  		}
   297  
   298  		if tr.logger.IsTrace() {
   299  			end := time.Now()
   300  			tr.logger.Trace("finished prestart hook", "name", name, "end", end, "duration", end.Sub(start))
   301  		}
   302  	}
   303  
   304  	return nil
   305  }
   306  
   307  // poststart is used to run the runners poststart hooks.
   308  func (tr *TaskRunner) poststart() error {
   309  	if tr.logger.IsTrace() {
   310  		start := time.Now()
   311  		tr.logger.Trace("running poststart hooks", "start", start)
   312  		defer func() {
   313  			end := time.Now()
   314  			tr.logger.Trace("finished poststart hooks", "end", end, "duration", end.Sub(start))
   315  		}()
   316  	}
   317  
   318  	handle := tr.getDriverHandle()
   319  	net := handle.Network()
   320  
   321  	// Pass the lazy handle to the hooks so even if the driver exits and we
   322  	// launch a new one (external plugin), the handle will refresh.
   323  	lazyHandle := NewLazyHandle(tr.shutdownCtx, tr.getDriverHandle, tr.logger)
   324  
   325  	var merr multierror.Error
   326  	for _, hook := range tr.runnerHooks {
   327  		post, ok := hook.(interfaces.TaskPoststartHook)
   328  		if !ok {
   329  			continue
   330  		}
   331  
   332  		name := post.Name()
   333  		var start time.Time
   334  		if tr.logger.IsTrace() {
   335  			start = time.Now()
   336  			tr.logger.Trace("running poststart hook", "name", name, "start", start)
   337  		}
   338  
   339  		req := interfaces.TaskPoststartRequest{
   340  			DriverExec:    lazyHandle,
   341  			DriverNetwork: net,
   342  			DriverStats:   lazyHandle,
   343  			TaskEnv:       tr.envBuilder.Build(),
   344  		}
   345  		var resp interfaces.TaskPoststartResponse
   346  		if err := post.Poststart(tr.killCtx, &req, &resp); err != nil {
   347  			tr.emitHookError(err, name)
   348  			merr.Errors = append(merr.Errors, fmt.Errorf("poststart hook %q failed: %v", name, err))
   349  		}
   350  
   351  		// No need to persist as PoststartResponse is currently empty
   352  
   353  		if tr.logger.IsTrace() {
   354  			end := time.Now()
   355  			tr.logger.Trace("finished poststart hooks", "name", name, "end", end, "duration", end.Sub(start))
   356  		}
   357  	}
   358  
   359  	return merr.ErrorOrNil()
   360  }
   361  
   362  // exited is used to run the exited hooks before a task is stopped.
   363  func (tr *TaskRunner) exited() error {
   364  	if tr.logger.IsTrace() {
   365  		start := time.Now()
   366  		tr.logger.Trace("running exited hooks", "start", start)
   367  		defer func() {
   368  			end := time.Now()
   369  			tr.logger.Trace("finished exited hooks", "end", end, "duration", end.Sub(start))
   370  		}()
   371  	}
   372  
   373  	var merr multierror.Error
   374  	for _, hook := range tr.runnerHooks {
   375  		post, ok := hook.(interfaces.TaskExitedHook)
   376  		if !ok {
   377  			continue
   378  		}
   379  
   380  		name := post.Name()
   381  		var start time.Time
   382  		if tr.logger.IsTrace() {
   383  			start = time.Now()
   384  			tr.logger.Trace("running exited hook", "name", name, "start", start)
   385  		}
   386  
   387  		req := interfaces.TaskExitedRequest{}
   388  		var resp interfaces.TaskExitedResponse
   389  		if err := post.Exited(tr.killCtx, &req, &resp); err != nil {
   390  			tr.emitHookError(err, name)
   391  			merr.Errors = append(merr.Errors, fmt.Errorf("exited hook %q failed: %v", name, err))
   392  		}
   393  
   394  		// No need to persist as TaskExitedResponse is currently empty
   395  
   396  		if tr.logger.IsTrace() {
   397  			end := time.Now()
   398  			tr.logger.Trace("finished exited hooks", "name", name, "end", end, "duration", end.Sub(start))
   399  		}
   400  	}
   401  
   402  	return merr.ErrorOrNil()
   403  
   404  }
   405  
   406  // stop is used to run the stop hooks.
   407  func (tr *TaskRunner) stop() error {
   408  	if tr.logger.IsTrace() {
   409  		start := time.Now()
   410  		tr.logger.Trace("running stop hooks", "start", start)
   411  		defer func() {
   412  			end := time.Now()
   413  			tr.logger.Trace("finished stop hooks", "end", end, "duration", end.Sub(start))
   414  		}()
   415  	}
   416  
   417  	var merr multierror.Error
   418  	for _, hook := range tr.runnerHooks {
   419  		post, ok := hook.(interfaces.TaskStopHook)
   420  		if !ok {
   421  			continue
   422  		}
   423  
   424  		name := post.Name()
   425  		var start time.Time
   426  		if tr.logger.IsTrace() {
   427  			start = time.Now()
   428  			tr.logger.Trace("running stop hook", "name", name, "start", start)
   429  		}
   430  
   431  		req := interfaces.TaskStopRequest{}
   432  
   433  		origHookState := tr.hookState(name)
   434  		if origHookState != nil {
   435  			// Give the hook data provided by prestart
   436  			req.ExistingState = origHookState.Data
   437  		}
   438  
   439  		var resp interfaces.TaskStopResponse
   440  		if err := post.Stop(tr.killCtx, &req, &resp); err != nil {
   441  			tr.emitHookError(err, name)
   442  			merr.Errors = append(merr.Errors, fmt.Errorf("stop hook %q failed: %v", name, err))
   443  		}
   444  
   445  		// Stop hooks cannot alter state and must be idempotent, so
   446  		// unlike prestart there's no state to persist here.
   447  
   448  		if tr.logger.IsTrace() {
   449  			end := time.Now()
   450  			tr.logger.Trace("finished stop hook", "name", name, "end", end, "duration", end.Sub(start))
   451  		}
   452  	}
   453  
   454  	return merr.ErrorOrNil()
   455  }
   456  
   457  // update is used to run the runners update hooks. Should only be called from
   458  // Run(). To trigger an update, update state on the TaskRunner and call
   459  // triggerUpdateHooks.
   460  func (tr *TaskRunner) updateHooks() {
   461  	if tr.logger.IsTrace() {
   462  		start := time.Now()
   463  		tr.logger.Trace("running update hooks", "start", start)
   464  		defer func() {
   465  			end := time.Now()
   466  			tr.logger.Trace("finished update hooks", "end", end, "duration", end.Sub(start))
   467  		}()
   468  	}
   469  
   470  	// Prepare state needed by Update hooks
   471  	alloc := tr.Alloc()
   472  
   473  	// Execute Update hooks
   474  	for _, hook := range tr.runnerHooks {
   475  		upd, ok := hook.(interfaces.TaskUpdateHook)
   476  		if !ok {
   477  			continue
   478  		}
   479  
   480  		name := upd.Name()
   481  
   482  		// Build the request
   483  		req := interfaces.TaskUpdateRequest{
   484  			VaultToken: tr.getVaultToken(),
   485  			Alloc:      alloc,
   486  			TaskEnv:    tr.envBuilder.Build(),
   487  		}
   488  
   489  		// Time the update hook
   490  		var start time.Time
   491  		if tr.logger.IsTrace() {
   492  			start = time.Now()
   493  			tr.logger.Trace("running update hook", "name", name, "start", start)
   494  		}
   495  
   496  		// Run the update hook
   497  		var resp interfaces.TaskUpdateResponse
   498  		if err := upd.Update(tr.killCtx, &req, &resp); err != nil {
   499  			tr.emitHookError(err, name)
   500  			tr.logger.Error("update hook failed", "name", name, "error", err)
   501  		}
   502  
   503  		// No need to persist as TaskUpdateResponse is currently empty
   504  
   505  		if tr.logger.IsTrace() {
   506  			end := time.Now()
   507  			tr.logger.Trace("finished update hooks", "name", name, "end", end, "duration", end.Sub(start))
   508  		}
   509  	}
   510  }
   511  
   512  // preKill is used to run the runners preKill hooks
   513  // preKill hooks contain logic that must be executed before
   514  // a task is killed or restarted
   515  func (tr *TaskRunner) preKill() {
   516  	if tr.logger.IsTrace() {
   517  		start := time.Now()
   518  		tr.logger.Trace("running pre kill hooks", "start", start)
   519  		defer func() {
   520  			end := time.Now()
   521  			tr.logger.Trace("finished pre kill hooks", "end", end, "duration", end.Sub(start))
   522  		}()
   523  	}
   524  
   525  	for _, hook := range tr.runnerHooks {
   526  		killHook, ok := hook.(interfaces.TaskPreKillHook)
   527  		if !ok {
   528  			continue
   529  		}
   530  
   531  		name := killHook.Name()
   532  
   533  		// Time the pre kill hook
   534  		var start time.Time
   535  		if tr.logger.IsTrace() {
   536  			start = time.Now()
   537  			tr.logger.Trace("running prekill hook", "name", name, "start", start)
   538  		}
   539  
   540  		// Run the pre kill hook
   541  		req := interfaces.TaskPreKillRequest{}
   542  		var resp interfaces.TaskPreKillResponse
   543  		if err := killHook.PreKilling(context.Background(), &req, &resp); err != nil {
   544  			tr.emitHookError(err, name)
   545  			tr.logger.Error("prekill hook failed", "name", name, "error", err)
   546  		}
   547  
   548  		// No need to persist as TaskKillResponse is currently empty
   549  
   550  		if tr.logger.IsTrace() {
   551  			end := time.Now()
   552  			tr.logger.Trace("finished prekill hook", "name", name, "end", end, "duration", end.Sub(start))
   553  		}
   554  	}
   555  }
   556  
   557  // shutdownHooks is called when the TaskRunner is gracefully shutdown but the
   558  // task is not being stopped or garbage collected.
   559  func (tr *TaskRunner) shutdownHooks() {
   560  	for _, hook := range tr.runnerHooks {
   561  		sh, ok := hook.(interfaces.ShutdownHook)
   562  		if !ok {
   563  			continue
   564  		}
   565  
   566  		name := sh.Name()
   567  
   568  		// Time the update hook
   569  		var start time.Time
   570  		if tr.logger.IsTrace() {
   571  			start = time.Now()
   572  			tr.logger.Trace("running shutdown hook", "name", name, "start", start)
   573  		}
   574  
   575  		sh.Shutdown()
   576  
   577  		if tr.logger.IsTrace() {
   578  			end := time.Now()
   579  			tr.logger.Trace("finished shutdown hook", "name", name, "end", end, "duration", end.Sub(start))
   580  		}
   581  	}
   582  }