github.com/Ilhicas/nomad@v1.0.4-0.20210304152020-e86851182bc3/client/allocrunner/taskrunner/task_runner_hooks.go (about)

     1  package taskrunner
     2  
     3  import (
     4  	"context"
     5  	"fmt"
     6  	"path/filepath"
     7  	"sync"
     8  	"time"
     9  
    10  	"github.com/LK4D4/joincontext"
    11  	multierror "github.com/hashicorp/go-multierror"
    12  	"github.com/hashicorp/nomad/client/allocrunner/interfaces"
    13  	"github.com/hashicorp/nomad/client/allocrunner/taskrunner/state"
    14  	"github.com/hashicorp/nomad/nomad/structs"
    15  	"github.com/hashicorp/nomad/plugins/drivers"
    16  )
    17  
    18  // hookResources captures the resources for the task provided by hooks.
    19  type hookResources struct {
    20  	Devices []*drivers.DeviceConfig
    21  	Mounts  []*drivers.MountConfig
    22  	sync.RWMutex
    23  }
    24  
    25  func (h *hookResources) setDevices(d []*drivers.DeviceConfig) {
    26  	h.Lock()
    27  	h.Devices = d
    28  	h.Unlock()
    29  }
    30  
    31  func (h *hookResources) getDevices() []*drivers.DeviceConfig {
    32  	h.RLock()
    33  	defer h.RUnlock()
    34  	return h.Devices
    35  }
    36  
    37  func (h *hookResources) setMounts(m []*drivers.MountConfig) {
    38  	h.Lock()
    39  	h.Mounts = m
    40  	h.Unlock()
    41  }
    42  
    43  func (h *hookResources) getMounts() []*drivers.MountConfig {
    44  	h.RLock()
    45  	defer h.RUnlock()
    46  	return h.Mounts
    47  }
    48  
// initHooks builds the ordered list of task hooks stored in tr.runnerHooks.
// Every lifecycle method in this file (prestart, poststart, exited, stop,
// updateHooks, preKill, shutdownHooks) ranges over tr.runnerHooks, so the
// order built here is the order in which hooks run in each phase.
//
// Before any hook is constructed it also sets tr.logmonHookConfig and
// initializes tr.hookResources (where prestart records hook-provided devices
// and mounts).
func (tr *TaskRunner) initHooks() {
	hookLogger := tr.logger.Named("task_hook")
	task := tr.Task()

	tr.logmonHookConfig = newLogMonHookConfig(task.Name, tr.taskDir.LogDir)

	// Add the hook resources
	tr.hookResources = &hookResources{}

	// Hooks present for every task. The validate hook comes first, followed
	// by the task directory hook, which must run before the remaining hooks
	// to ensure the directory paths they rely on exist.
	alloc := tr.Alloc()
	tr.runnerHooks = []interfaces.TaskHook{
		newValidateHook(tr.clientConfig, hookLogger),
		newTaskDirHook(tr, hookLogger),
		newLogMonHook(tr, hookLogger),
		newDispatchHook(alloc, hookLogger),
		newVolumeHook(tr, hookLogger),
		newArtifactHook(tr, hookLogger),
		newStatsHook(tr, tr.clientConfig.StatsCollectionInterval, hookLogger),
		newDeviceHook(tr.devicemanager, hookLogger),
	}

	// If the task has a CSI plugin stanza, add the CSI plugin supervisor
	// hook. It is given the directory <client StateDir>/csi.
	if task.CSIPluginConfig != nil {
		tr.runnerHooks = append(tr.runnerHooks, newCSIPluginSupervisorHook(filepath.Join(tr.clientConfig.StateDir, "csi"), tr, tr, hookLogger))
	}

	// If the task has a vault stanza, add the Vault hook.
	if task.Vault != nil {
		tr.runnerHooks = append(tr.runnerHooks, newVaultHook(&vaultHookConfig{
			vaultStanza: task.Vault,
			client:      tr.vaultClient,
			events:      tr,
			lifecycle:   tr,
			updater:     tr,
			logger:      hookLogger,
			alloc:       tr.Alloc(),
			task:        tr.taskName,
		}))
	}

	// If the task has any templates, add the template hook.
	if len(task.Templates) != 0 {
		tr.runnerHooks = append(tr.runnerHooks, newTemplateHook(&templateHookConfig{
			logger:       hookLogger,
			lifecycle:    tr,
			events:       tr,
			templates:    task.Templates,
			clientConfig: tr.clientConfig,
			envBuilder:   tr.envBuilder,
		}))
	}

	// Always add the service hook. A task with no services on initial registration
	// may be updated to include services, which must be handled with this hook.
	tr.runnerHooks = append(tr.runnerHooks, newServiceHook(serviceHookConfig{
		alloc:     tr.Alloc(),
		task:      tr.Task(),
		consul:    tr.consulServiceClient,
		restarter: tr,
		logger:    hookLogger,
	}))

	// If this is a Connect sidecar proxy (or a Connect Native) service,
	// add the hooks Connect needs.
	if task.UsesConnect() {
		// Enable the Service Identity hook only if the Nomad client is configured
		// with a consul token, indicating that Consul ACLs are enabled
		if tr.clientConfig.ConsulConfig.Token != "" {
			tr.runnerHooks = append(tr.runnerHooks, newSIDSHook(sidsHookConfig{
				alloc:      tr.Alloc(),
				task:       tr.Task(),
				sidsClient: tr.siClient,
				lifecycle:  tr,
				logger:     hookLogger,
			}))
		}

		// Sidecar proxy tasks get the Envoy version hook followed by the
		// Envoy bootstrap hook; Connect Native tasks get the Connect Native
		// hook instead.
		if task.UsesConnectSidecar() {
			tr.runnerHooks = append(tr.runnerHooks,
				newEnvoyVersionHook(newEnvoyVersionHookConfig(alloc, tr.consulProxiesClient, hookLogger)),
				newEnvoyBootstrapHook(newEnvoyBootstrapHookConfig(alloc, tr.clientConfig.ConsulConfig, hookLogger)),
			)
		} else if task.Kind.IsConnectNative() {
			tr.runnerHooks = append(tr.runnerHooks, newConnectNativeHook(
				newConnectNativeHookConfig(alloc, tr.clientConfig.ConsulConfig, hookLogger),
			))
		}
	}

	// Always add the script check hook, even when the task currently has no
	// script checks. NOTE(review): presumably the hook tolerates an empty set
	// of checks (and, like the service hook, picks up checks added by later
	// updates) — confirm in newScriptCheckHook.
	scriptCheckHook := newScriptCheckHook(scriptCheckHookConfig{
		alloc:  tr.Alloc(),
		task:   tr.Task(),
		consul: tr.consulServiceClient,
		logger: hookLogger,
	})
	tr.runnerHooks = append(tr.runnerHooks, scriptCheckHook)
}
   150  
   151  func (tr *TaskRunner) emitHookError(err error, hookName string) {
   152  	var taskEvent *structs.TaskEvent
   153  	if herr, ok := err.(*hookError); ok {
   154  		taskEvent = herr.taskEvent
   155  	} else {
   156  		message := fmt.Sprintf("%s: %v", hookName, err)
   157  		taskEvent = structs.NewTaskEvent(structs.TaskHookFailed).SetMessage(message)
   158  	}
   159  
   160  	tr.EmitEvent(taskEvent)
   161  }
   162  
   163  // prestart is used to run the runners prestart hooks.
   164  func (tr *TaskRunner) prestart() error {
   165  	// Determine if the allocation is terminal and we should avoid running
   166  	// prestart hooks.
   167  	if tr.shouldShutdown() {
   168  		tr.logger.Trace("skipping prestart hooks since allocation is terminal")
   169  		return nil
   170  	}
   171  
   172  	if tr.logger.IsTrace() {
   173  		start := time.Now()
   174  		tr.logger.Trace("running prestart hooks", "start", start)
   175  		defer func() {
   176  			end := time.Now()
   177  			tr.logger.Trace("finished prestart hooks", "end", end, "duration", end.Sub(start))
   178  		}()
   179  	}
   180  
   181  	for _, hook := range tr.runnerHooks {
   182  		pre, ok := hook.(interfaces.TaskPrestartHook)
   183  		if !ok {
   184  			continue
   185  		}
   186  
   187  		name := pre.Name()
   188  
   189  		// Build the request
   190  		req := interfaces.TaskPrestartRequest{
   191  			Task:          tr.Task(),
   192  			TaskDir:       tr.taskDir,
   193  			TaskEnv:       tr.envBuilder.Build(),
   194  			TaskResources: tr.taskResources,
   195  		}
   196  
   197  		origHookState := tr.hookState(name)
   198  		if origHookState != nil {
   199  			if origHookState.PrestartDone {
   200  				tr.logger.Trace("skipping done prestart hook", "name", pre.Name())
   201  
   202  				// Always set env vars from hooks
   203  				if name == HookNameDevices {
   204  					tr.envBuilder.SetDeviceHookEnv(name, origHookState.Env)
   205  				} else {
   206  					tr.envBuilder.SetHookEnv(name, origHookState.Env)
   207  				}
   208  
   209  				continue
   210  			}
   211  
   212  			// Give the hook it's old data
   213  			req.PreviousState = origHookState.Data
   214  		}
   215  
   216  		req.VaultToken = tr.getVaultToken()
   217  
   218  		// Time the prestart hook
   219  		var start time.Time
   220  		if tr.logger.IsTrace() {
   221  			start = time.Now()
   222  			tr.logger.Trace("running prestart hook", "name", name, "start", start)
   223  		}
   224  
   225  		// Run the prestart hook
   226  		// use a joint context to allow any blocking pre-start hooks
   227  		// to be canceled by either killCtx or shutdownCtx
   228  		joinedCtx, _ := joincontext.Join(tr.killCtx, tr.shutdownCtx)
   229  		var resp interfaces.TaskPrestartResponse
   230  		if err := pre.Prestart(joinedCtx, &req, &resp); err != nil {
   231  			tr.emitHookError(err, name)
   232  			return structs.WrapRecoverable(fmt.Sprintf("prestart hook %q failed: %v", name, err), err)
   233  		}
   234  
   235  		// Store the hook state
   236  		{
   237  			hookState := &state.HookState{
   238  				Data:         resp.State,
   239  				PrestartDone: resp.Done,
   240  				Env:          resp.Env,
   241  			}
   242  
   243  			// Store and persist local state if the hook state has changed
   244  			if !hookState.Equal(origHookState) {
   245  				tr.stateLock.Lock()
   246  				tr.localState.Hooks[name] = hookState
   247  				tr.stateLock.Unlock()
   248  
   249  				if err := tr.persistLocalState(); err != nil {
   250  					return err
   251  				}
   252  			}
   253  		}
   254  
   255  		// Store the environment variables returned by the hook
   256  		if name == HookNameDevices {
   257  			tr.envBuilder.SetDeviceHookEnv(name, resp.Env)
   258  		} else {
   259  			tr.envBuilder.SetHookEnv(name, resp.Env)
   260  		}
   261  
   262  		// Store the resources
   263  		if len(resp.Devices) != 0 {
   264  			tr.hookResources.setDevices(resp.Devices)
   265  		}
   266  		if len(resp.Mounts) != 0 {
   267  			tr.hookResources.setMounts(resp.Mounts)
   268  		}
   269  
   270  		if tr.logger.IsTrace() {
   271  			end := time.Now()
   272  			tr.logger.Trace("finished prestart hook", "name", name, "end", end, "duration", end.Sub(start))
   273  		}
   274  	}
   275  
   276  	return nil
   277  }
   278  
   279  // poststart is used to run the runners poststart hooks.
   280  func (tr *TaskRunner) poststart() error {
   281  	if tr.logger.IsTrace() {
   282  		start := time.Now()
   283  		tr.logger.Trace("running poststart hooks", "start", start)
   284  		defer func() {
   285  			end := time.Now()
   286  			tr.logger.Trace("finished poststart hooks", "end", end, "duration", end.Sub(start))
   287  		}()
   288  	}
   289  
   290  	handle := tr.getDriverHandle()
   291  	net := handle.Network()
   292  
   293  	// Pass the lazy handle to the hooks so even if the driver exits and we
   294  	// launch a new one (external plugin), the handle will refresh.
   295  	lazyHandle := NewLazyHandle(tr.shutdownCtx, tr.getDriverHandle, tr.logger)
   296  
   297  	var merr multierror.Error
   298  	for _, hook := range tr.runnerHooks {
   299  		post, ok := hook.(interfaces.TaskPoststartHook)
   300  		if !ok {
   301  			continue
   302  		}
   303  
   304  		name := post.Name()
   305  		var start time.Time
   306  		if tr.logger.IsTrace() {
   307  			start = time.Now()
   308  			tr.logger.Trace("running poststart hook", "name", name, "start", start)
   309  		}
   310  
   311  		req := interfaces.TaskPoststartRequest{
   312  			DriverExec:    lazyHandle,
   313  			DriverNetwork: net,
   314  			DriverStats:   lazyHandle,
   315  			TaskEnv:       tr.envBuilder.Build(),
   316  		}
   317  		var resp interfaces.TaskPoststartResponse
   318  		if err := post.Poststart(tr.killCtx, &req, &resp); err != nil {
   319  			tr.emitHookError(err, name)
   320  			merr.Errors = append(merr.Errors, fmt.Errorf("poststart hook %q failed: %v", name, err))
   321  		}
   322  
   323  		// No need to persist as PoststartResponse is currently empty
   324  
   325  		if tr.logger.IsTrace() {
   326  			end := time.Now()
   327  			tr.logger.Trace("finished poststart hooks", "name", name, "end", end, "duration", end.Sub(start))
   328  		}
   329  	}
   330  
   331  	return merr.ErrorOrNil()
   332  }
   333  
   334  // exited is used to run the exited hooks before a task is stopped.
   335  func (tr *TaskRunner) exited() error {
   336  	if tr.logger.IsTrace() {
   337  		start := time.Now()
   338  		tr.logger.Trace("running exited hooks", "start", start)
   339  		defer func() {
   340  			end := time.Now()
   341  			tr.logger.Trace("finished exited hooks", "end", end, "duration", end.Sub(start))
   342  		}()
   343  	}
   344  
   345  	var merr multierror.Error
   346  	for _, hook := range tr.runnerHooks {
   347  		post, ok := hook.(interfaces.TaskExitedHook)
   348  		if !ok {
   349  			continue
   350  		}
   351  
   352  		name := post.Name()
   353  		var start time.Time
   354  		if tr.logger.IsTrace() {
   355  			start = time.Now()
   356  			tr.logger.Trace("running exited hook", "name", name, "start", start)
   357  		}
   358  
   359  		req := interfaces.TaskExitedRequest{}
   360  		var resp interfaces.TaskExitedResponse
   361  		if err := post.Exited(tr.killCtx, &req, &resp); err != nil {
   362  			tr.emitHookError(err, name)
   363  			merr.Errors = append(merr.Errors, fmt.Errorf("exited hook %q failed: %v", name, err))
   364  		}
   365  
   366  		// No need to persist as TaskExitedResponse is currently empty
   367  
   368  		if tr.logger.IsTrace() {
   369  			end := time.Now()
   370  			tr.logger.Trace("finished exited hooks", "name", name, "end", end, "duration", end.Sub(start))
   371  		}
   372  	}
   373  
   374  	return merr.ErrorOrNil()
   375  
   376  }
   377  
   378  // stop is used to run the stop hooks.
   379  func (tr *TaskRunner) stop() error {
   380  	if tr.logger.IsTrace() {
   381  		start := time.Now()
   382  		tr.logger.Trace("running stop hooks", "start", start)
   383  		defer func() {
   384  			end := time.Now()
   385  			tr.logger.Trace("finished stop hooks", "end", end, "duration", end.Sub(start))
   386  		}()
   387  	}
   388  
   389  	var merr multierror.Error
   390  	for _, hook := range tr.runnerHooks {
   391  		post, ok := hook.(interfaces.TaskStopHook)
   392  		if !ok {
   393  			continue
   394  		}
   395  
   396  		name := post.Name()
   397  		var start time.Time
   398  		if tr.logger.IsTrace() {
   399  			start = time.Now()
   400  			tr.logger.Trace("running stop hook", "name", name, "start", start)
   401  		}
   402  
   403  		req := interfaces.TaskStopRequest{}
   404  
   405  		origHookState := tr.hookState(name)
   406  		if origHookState != nil {
   407  			// Give the hook data provided by prestart
   408  			req.ExistingState = origHookState.Data
   409  		}
   410  
   411  		var resp interfaces.TaskStopResponse
   412  		if err := post.Stop(tr.killCtx, &req, &resp); err != nil {
   413  			tr.emitHookError(err, name)
   414  			merr.Errors = append(merr.Errors, fmt.Errorf("stop hook %q failed: %v", name, err))
   415  		}
   416  
   417  		// Stop hooks cannot alter state and must be idempotent, so
   418  		// unlike prestart there's no state to persist here.
   419  
   420  		if tr.logger.IsTrace() {
   421  			end := time.Now()
   422  			tr.logger.Trace("finished stop hook", "name", name, "end", end, "duration", end.Sub(start))
   423  		}
   424  	}
   425  
   426  	return merr.ErrorOrNil()
   427  }
   428  
   429  // update is used to run the runners update hooks. Should only be called from
   430  // Run(). To trigger an update, update state on the TaskRunner and call
   431  // triggerUpdateHooks.
   432  func (tr *TaskRunner) updateHooks() {
   433  	if tr.logger.IsTrace() {
   434  		start := time.Now()
   435  		tr.logger.Trace("running update hooks", "start", start)
   436  		defer func() {
   437  			end := time.Now()
   438  			tr.logger.Trace("finished update hooks", "end", end, "duration", end.Sub(start))
   439  		}()
   440  	}
   441  
   442  	// Prepare state needed by Update hooks
   443  	alloc := tr.Alloc()
   444  
   445  	// Execute Update hooks
   446  	for _, hook := range tr.runnerHooks {
   447  		upd, ok := hook.(interfaces.TaskUpdateHook)
   448  		if !ok {
   449  			continue
   450  		}
   451  
   452  		name := upd.Name()
   453  
   454  		// Build the request
   455  		req := interfaces.TaskUpdateRequest{
   456  			VaultToken: tr.getVaultToken(),
   457  			Alloc:      alloc,
   458  			TaskEnv:    tr.envBuilder.Build(),
   459  		}
   460  
   461  		// Time the update hook
   462  		var start time.Time
   463  		if tr.logger.IsTrace() {
   464  			start = time.Now()
   465  			tr.logger.Trace("running update hook", "name", name, "start", start)
   466  		}
   467  
   468  		// Run the update hook
   469  		var resp interfaces.TaskUpdateResponse
   470  		if err := upd.Update(tr.killCtx, &req, &resp); err != nil {
   471  			tr.emitHookError(err, name)
   472  			tr.logger.Error("update hook failed", "name", name, "error", err)
   473  		}
   474  
   475  		// No need to persist as TaskUpdateResponse is currently empty
   476  
   477  		if tr.logger.IsTrace() {
   478  			end := time.Now()
   479  			tr.logger.Trace("finished update hooks", "name", name, "end", end, "duration", end.Sub(start))
   480  		}
   481  	}
   482  }
   483  
   484  // preKill is used to run the runners preKill hooks
   485  // preKill hooks contain logic that must be executed before
   486  // a task is killed or restarted
   487  func (tr *TaskRunner) preKill() {
   488  	if tr.logger.IsTrace() {
   489  		start := time.Now()
   490  		tr.logger.Trace("running pre kill hooks", "start", start)
   491  		defer func() {
   492  			end := time.Now()
   493  			tr.logger.Trace("finished pre kill hooks", "end", end, "duration", end.Sub(start))
   494  		}()
   495  	}
   496  
   497  	for _, hook := range tr.runnerHooks {
   498  		killHook, ok := hook.(interfaces.TaskPreKillHook)
   499  		if !ok {
   500  			continue
   501  		}
   502  
   503  		name := killHook.Name()
   504  
   505  		// Time the pre kill hook
   506  		var start time.Time
   507  		if tr.logger.IsTrace() {
   508  			start = time.Now()
   509  			tr.logger.Trace("running prekill hook", "name", name, "start", start)
   510  		}
   511  
   512  		// Run the pre kill hook
   513  		req := interfaces.TaskPreKillRequest{}
   514  		var resp interfaces.TaskPreKillResponse
   515  		if err := killHook.PreKilling(context.Background(), &req, &resp); err != nil {
   516  			tr.emitHookError(err, name)
   517  			tr.logger.Error("prekill hook failed", "name", name, "error", err)
   518  		}
   519  
   520  		// No need to persist as TaskKillResponse is currently empty
   521  
   522  		if tr.logger.IsTrace() {
   523  			end := time.Now()
   524  			tr.logger.Trace("finished prekill hook", "name", name, "end", end, "duration", end.Sub(start))
   525  		}
   526  	}
   527  }
   528  
   529  // shutdownHooks is called when the TaskRunner is gracefully shutdown but the
   530  // task is not being stopped or garbage collected.
   531  func (tr *TaskRunner) shutdownHooks() {
   532  	for _, hook := range tr.runnerHooks {
   533  		sh, ok := hook.(interfaces.ShutdownHook)
   534  		if !ok {
   535  			continue
   536  		}
   537  
   538  		name := sh.Name()
   539  
   540  		// Time the update hook
   541  		var start time.Time
   542  		if tr.logger.IsTrace() {
   543  			start = time.Now()
   544  			tr.logger.Trace("running shutdown hook", "name", name, "start", start)
   545  		}
   546  
   547  		sh.Shutdown()
   548  
   549  		if tr.logger.IsTrace() {
   550  			end := time.Now()
   551  			tr.logger.Trace("finished shutdown hook", "name", name, "end", end, "duration", end.Sub(start))
   552  		}
   553  	}
   554  }