github.com/iqoqo/nomad@v0.11.3-0.20200911112621-d7021c74d101/client/allocrunner/taskrunner/task_runner_hooks.go (about)

     1  package taskrunner
     2  
     3  import (
     4  	"context"
     5  	"fmt"
     6  	"path/filepath"
     7  	"sync"
     8  	"time"
     9  
    10  	"github.com/LK4D4/joincontext"
    11  	multierror "github.com/hashicorp/go-multierror"
    12  	"github.com/hashicorp/nomad/client/allocrunner/interfaces"
    13  	"github.com/hashicorp/nomad/client/allocrunner/taskrunner/state"
    14  	"github.com/hashicorp/nomad/nomad/structs"
    15  	"github.com/hashicorp/nomad/plugins/drivers"
    16  )
    17  
    18  // hookResources captures the resources for the task provided by hooks.
    19  type hookResources struct {
    20  	Devices []*drivers.DeviceConfig
    21  	Mounts  []*drivers.MountConfig
    22  	sync.RWMutex
    23  }
    24  
    25  func (h *hookResources) setDevices(d []*drivers.DeviceConfig) {
    26  	h.Lock()
    27  	h.Devices = d
    28  	h.Unlock()
    29  }
    30  
    31  func (h *hookResources) getDevices() []*drivers.DeviceConfig {
    32  	h.RLock()
    33  	defer h.RUnlock()
    34  	return h.Devices
    35  }
    36  
    37  func (h *hookResources) setMounts(m []*drivers.MountConfig) {
    38  	h.Lock()
    39  	h.Mounts = m
    40  	h.Unlock()
    41  }
    42  
    43  func (h *hookResources) getMounts() []*drivers.MountConfig {
    44  	h.RLock()
    45  	defer h.RUnlock()
    46  	return h.Mounts
    47  }
    48  
    49  // initHooks initializes the tasks hooks.
    50  func (tr *TaskRunner) initHooks() {
    51  	hookLogger := tr.logger.Named("task_hook")
    52  	task := tr.Task()
    53  
    54  	tr.logmonHookConfig = newLogMonHookConfig(task.Name, tr.taskDir.LogDir)
    55  
    56  	// Add the hook resources
    57  	tr.hookResources = &hookResources{}
    58  
    59  	// Create the task directory hook. This is run first to ensure the
    60  	// directory path exists for other hooks.
    61  	alloc := tr.Alloc()
    62  	tr.runnerHooks = []interfaces.TaskHook{
    63  		newValidateHook(tr.clientConfig, hookLogger),
    64  		newTaskDirHook(tr, hookLogger),
    65  		newLogMonHook(tr, hookLogger),
    66  		newDispatchHook(alloc, hookLogger),
    67  		newVolumeHook(tr, hookLogger),
    68  		newArtifactHook(tr, hookLogger),
    69  		newStatsHook(tr, tr.clientConfig.StatsCollectionInterval, hookLogger),
    70  		newDeviceHook(tr.devicemanager, hookLogger),
    71  	}
    72  
    73  	// If the task has a CSI stanza, add the hook.
    74  	if task.CSIPluginConfig != nil {
    75  		tr.runnerHooks = append(tr.runnerHooks, newCSIPluginSupervisorHook(filepath.Join(tr.clientConfig.StateDir, "csi"), tr, tr, hookLogger))
    76  	}
    77  
    78  	// If Vault is enabled, add the hook
    79  	if task.Vault != nil {
    80  		tr.runnerHooks = append(tr.runnerHooks, newVaultHook(&vaultHookConfig{
    81  			vaultStanza: task.Vault,
    82  			client:      tr.vaultClient,
    83  			events:      tr,
    84  			lifecycle:   tr,
    85  			updater:     tr,
    86  			logger:      hookLogger,
    87  			alloc:       tr.Alloc(),
    88  			task:        tr.taskName,
    89  		}))
    90  	}
    91  
    92  	// If there are templates is enabled, add the hook
    93  	if len(task.Templates) != 0 {
    94  		tr.runnerHooks = append(tr.runnerHooks, newTemplateHook(&templateHookConfig{
    95  			logger:       hookLogger,
    96  			lifecycle:    tr,
    97  			events:       tr,
    98  			templates:    task.Templates,
    99  			clientConfig: tr.clientConfig,
   100  			envBuilder:   tr.envBuilder,
   101  		}))
   102  	}
   103  
   104  	// If there are any services, add the service hook
   105  	if len(task.Services) != 0 {
   106  		tr.runnerHooks = append(tr.runnerHooks, newServiceHook(serviceHookConfig{
   107  			alloc:     tr.Alloc(),
   108  			task:      tr.Task(),
   109  			consul:    tr.consulClient,
   110  			restarter: tr,
   111  			logger:    hookLogger,
   112  		}))
   113  	}
   114  
   115  	// If this is a Connect sidecar proxy (or a Connect Native) service,
   116  	// add the sidsHook for requesting a Service Identity token (if ACLs).
   117  	if task.UsesConnect() {
   118  		// Enable the Service Identity hook only if the Nomad client is configured
   119  		// with a consul token, indicating that Consul ACLs are enabled
   120  		if tr.clientConfig.ConsulConfig.Token != "" {
   121  			tr.runnerHooks = append(tr.runnerHooks, newSIDSHook(sidsHookConfig{
   122  				alloc:      tr.Alloc(),
   123  				task:       tr.Task(),
   124  				sidsClient: tr.siClient,
   125  				lifecycle:  tr,
   126  				logger:     hookLogger,
   127  			}))
   128  		}
   129  
   130  		// envoy bootstrap must execute after sidsHook maybe sets SI token
   131  		tr.runnerHooks = append(tr.runnerHooks, newEnvoyBootstrapHook(
   132  			newEnvoyBootstrapHookConfig(alloc, tr.clientConfig.ConsulConfig, hookLogger),
   133  		))
   134  	}
   135  
   136  	// If there are any script checks, add the hook
   137  	scriptCheckHook := newScriptCheckHook(scriptCheckHookConfig{
   138  		alloc:  tr.Alloc(),
   139  		task:   tr.Task(),
   140  		consul: tr.consulClient,
   141  		logger: hookLogger,
   142  	})
   143  	tr.runnerHooks = append(tr.runnerHooks, scriptCheckHook)
   144  }
   145  
   146  func (tr *TaskRunner) emitHookError(err error, hookName string) {
   147  	var taskEvent *structs.TaskEvent
   148  	if herr, ok := err.(*hookError); ok {
   149  		taskEvent = herr.taskEvent
   150  	} else {
   151  		message := fmt.Sprintf("%s: %v", hookName, err)
   152  		taskEvent = structs.NewTaskEvent(structs.TaskHookFailed).SetMessage(message)
   153  	}
   154  
   155  	tr.EmitEvent(taskEvent)
   156  }
   157  
   158  // prestart is used to run the runners prestart hooks.
   159  func (tr *TaskRunner) prestart() error {
   160  	// Determine if the allocation is terminal and we should avoid running
   161  	// prestart hooks.
   162  	alloc := tr.Alloc()
   163  	if alloc.TerminalStatus() {
   164  		tr.logger.Trace("skipping prestart hooks since allocation is terminal")
   165  		return nil
   166  	}
   167  
   168  	if tr.logger.IsTrace() {
   169  		start := time.Now()
   170  		tr.logger.Trace("running prestart hooks", "start", start)
   171  		defer func() {
   172  			end := time.Now()
   173  			tr.logger.Trace("finished prestart hooks", "end", end, "duration", end.Sub(start))
   174  		}()
   175  	}
   176  
   177  	for _, hook := range tr.runnerHooks {
   178  		pre, ok := hook.(interfaces.TaskPrestartHook)
   179  		if !ok {
   180  			continue
   181  		}
   182  
   183  		name := pre.Name()
   184  
   185  		// Build the request
   186  		req := interfaces.TaskPrestartRequest{
   187  			Task:          tr.Task(),
   188  			TaskDir:       tr.taskDir,
   189  			TaskEnv:       tr.envBuilder.Build(),
   190  			TaskResources: tr.taskResources,
   191  		}
   192  
   193  		origHookState := tr.hookState(name)
   194  		if origHookState != nil {
   195  			if origHookState.PrestartDone {
   196  				tr.logger.Trace("skipping done prestart hook", "name", pre.Name())
   197  
   198  				// Always set env vars from hooks
   199  				if name == HookNameDevices {
   200  					tr.envBuilder.SetDeviceHookEnv(name, origHookState.Env)
   201  				} else {
   202  					tr.envBuilder.SetHookEnv(name, origHookState.Env)
   203  				}
   204  
   205  				continue
   206  			}
   207  
   208  			// Give the hook it's old data
   209  			req.PreviousState = origHookState.Data
   210  		}
   211  
   212  		req.VaultToken = tr.getVaultToken()
   213  
   214  		// Time the prestart hook
   215  		var start time.Time
   216  		if tr.logger.IsTrace() {
   217  			start = time.Now()
   218  			tr.logger.Trace("running prestart hook", "name", name, "start", start)
   219  		}
   220  
   221  		// Run the prestart hook
   222  		// use a joint context to allow any blocking pre-start hooks
   223  		// to be canceled by either killCtx or shutdownCtx
   224  		joinedCtx, _ := joincontext.Join(tr.killCtx, tr.shutdownCtx)
   225  		var resp interfaces.TaskPrestartResponse
   226  		if err := pre.Prestart(joinedCtx, &req, &resp); err != nil {
   227  			tr.emitHookError(err, name)
   228  			return structs.WrapRecoverable(fmt.Sprintf("prestart hook %q failed: %v", name, err), err)
   229  		}
   230  
   231  		// Store the hook state
   232  		{
   233  			hookState := &state.HookState{
   234  				Data:         resp.State,
   235  				PrestartDone: resp.Done,
   236  				Env:          resp.Env,
   237  			}
   238  
   239  			// Store and persist local state if the hook state has changed
   240  			if !hookState.Equal(origHookState) {
   241  				tr.stateLock.Lock()
   242  				tr.localState.Hooks[name] = hookState
   243  				tr.stateLock.Unlock()
   244  
   245  				if err := tr.persistLocalState(); err != nil {
   246  					return err
   247  				}
   248  			}
   249  		}
   250  
   251  		// Store the environment variables returned by the hook
   252  		if name == HookNameDevices {
   253  			tr.envBuilder.SetDeviceHookEnv(name, resp.Env)
   254  		} else {
   255  			tr.envBuilder.SetHookEnv(name, resp.Env)
   256  		}
   257  
   258  		// Store the resources
   259  		if len(resp.Devices) != 0 {
   260  			tr.hookResources.setDevices(resp.Devices)
   261  		}
   262  		if len(resp.Mounts) != 0 {
   263  			tr.hookResources.setMounts(resp.Mounts)
   264  		}
   265  
   266  		if tr.logger.IsTrace() {
   267  			end := time.Now()
   268  			tr.logger.Trace("finished prestart hook", "name", name, "end", end, "duration", end.Sub(start))
   269  		}
   270  	}
   271  
   272  	return nil
   273  }
   274  
   275  // poststart is used to run the runners poststart hooks.
   276  func (tr *TaskRunner) poststart() error {
   277  	if tr.logger.IsTrace() {
   278  		start := time.Now()
   279  		tr.logger.Trace("running poststart hooks", "start", start)
   280  		defer func() {
   281  			end := time.Now()
   282  			tr.logger.Trace("finished poststart hooks", "end", end, "duration", end.Sub(start))
   283  		}()
   284  	}
   285  
   286  	handle := tr.getDriverHandle()
   287  	net := handle.Network()
   288  
   289  	// Pass the lazy handle to the hooks so even if the driver exits and we
   290  	// launch a new one (external plugin), the handle will refresh.
   291  	lazyHandle := NewLazyHandle(tr.shutdownCtx, tr.getDriverHandle, tr.logger)
   292  
   293  	var merr multierror.Error
   294  	for _, hook := range tr.runnerHooks {
   295  		post, ok := hook.(interfaces.TaskPoststartHook)
   296  		if !ok {
   297  			continue
   298  		}
   299  
   300  		name := post.Name()
   301  		var start time.Time
   302  		if tr.logger.IsTrace() {
   303  			start = time.Now()
   304  			tr.logger.Trace("running poststart hook", "name", name, "start", start)
   305  		}
   306  
   307  		req := interfaces.TaskPoststartRequest{
   308  			DriverExec:    lazyHandle,
   309  			DriverNetwork: net,
   310  			DriverStats:   lazyHandle,
   311  			TaskEnv:       tr.envBuilder.Build(),
   312  		}
   313  		var resp interfaces.TaskPoststartResponse
   314  		if err := post.Poststart(tr.killCtx, &req, &resp); err != nil {
   315  			tr.emitHookError(err, name)
   316  			merr.Errors = append(merr.Errors, fmt.Errorf("poststart hook %q failed: %v", name, err))
   317  		}
   318  
   319  		// No need to persist as PoststartResponse is currently empty
   320  
   321  		if tr.logger.IsTrace() {
   322  			end := time.Now()
   323  			tr.logger.Trace("finished poststart hooks", "name", name, "end", end, "duration", end.Sub(start))
   324  		}
   325  	}
   326  
   327  	return merr.ErrorOrNil()
   328  }
   329  
   330  // exited is used to run the exited hooks before a task is stopped.
   331  func (tr *TaskRunner) exited() error {
   332  	if tr.logger.IsTrace() {
   333  		start := time.Now()
   334  		tr.logger.Trace("running exited hooks", "start", start)
   335  		defer func() {
   336  			end := time.Now()
   337  			tr.logger.Trace("finished exited hooks", "end", end, "duration", end.Sub(start))
   338  		}()
   339  	}
   340  
   341  	var merr multierror.Error
   342  	for _, hook := range tr.runnerHooks {
   343  		post, ok := hook.(interfaces.TaskExitedHook)
   344  		if !ok {
   345  			continue
   346  		}
   347  
   348  		name := post.Name()
   349  		var start time.Time
   350  		if tr.logger.IsTrace() {
   351  			start = time.Now()
   352  			tr.logger.Trace("running exited hook", "name", name, "start", start)
   353  		}
   354  
   355  		req := interfaces.TaskExitedRequest{}
   356  		var resp interfaces.TaskExitedResponse
   357  		if err := post.Exited(tr.killCtx, &req, &resp); err != nil {
   358  			tr.emitHookError(err, name)
   359  			merr.Errors = append(merr.Errors, fmt.Errorf("exited hook %q failed: %v", name, err))
   360  		}
   361  
   362  		// No need to persist as TaskExitedResponse is currently empty
   363  
   364  		if tr.logger.IsTrace() {
   365  			end := time.Now()
   366  			tr.logger.Trace("finished exited hooks", "name", name, "end", end, "duration", end.Sub(start))
   367  		}
   368  	}
   369  
   370  	return merr.ErrorOrNil()
   371  
   372  }
   373  
   374  // stop is used to run the stop hooks.
   375  func (tr *TaskRunner) stop() error {
   376  	if tr.logger.IsTrace() {
   377  		start := time.Now()
   378  		tr.logger.Trace("running stop hooks", "start", start)
   379  		defer func() {
   380  			end := time.Now()
   381  			tr.logger.Trace("finished stop hooks", "end", end, "duration", end.Sub(start))
   382  		}()
   383  	}
   384  
   385  	var merr multierror.Error
   386  	for _, hook := range tr.runnerHooks {
   387  		post, ok := hook.(interfaces.TaskStopHook)
   388  		if !ok {
   389  			continue
   390  		}
   391  
   392  		name := post.Name()
   393  		var start time.Time
   394  		if tr.logger.IsTrace() {
   395  			start = time.Now()
   396  			tr.logger.Trace("running stop hook", "name", name, "start", start)
   397  		}
   398  
   399  		req := interfaces.TaskStopRequest{}
   400  
   401  		origHookState := tr.hookState(name)
   402  		if origHookState != nil {
   403  			// Give the hook data provided by prestart
   404  			req.ExistingState = origHookState.Data
   405  		}
   406  
   407  		var resp interfaces.TaskStopResponse
   408  		if err := post.Stop(tr.killCtx, &req, &resp); err != nil {
   409  			tr.emitHookError(err, name)
   410  			merr.Errors = append(merr.Errors, fmt.Errorf("stop hook %q failed: %v", name, err))
   411  		}
   412  
   413  		// Stop hooks cannot alter state and must be idempotent, so
   414  		// unlike prestart there's no state to persist here.
   415  
   416  		if tr.logger.IsTrace() {
   417  			end := time.Now()
   418  			tr.logger.Trace("finished stop hook", "name", name, "end", end, "duration", end.Sub(start))
   419  		}
   420  	}
   421  
   422  	return merr.ErrorOrNil()
   423  }
   424  
   425  // update is used to run the runners update hooks. Should only be called from
   426  // Run(). To trigger an update, update state on the TaskRunner and call
   427  // triggerUpdateHooks.
   428  func (tr *TaskRunner) updateHooks() {
   429  	if tr.logger.IsTrace() {
   430  		start := time.Now()
   431  		tr.logger.Trace("running update hooks", "start", start)
   432  		defer func() {
   433  			end := time.Now()
   434  			tr.logger.Trace("finished update hooks", "end", end, "duration", end.Sub(start))
   435  		}()
   436  	}
   437  
   438  	// Prepare state needed by Update hooks
   439  	alloc := tr.Alloc()
   440  
   441  	// Execute Update hooks
   442  	for _, hook := range tr.runnerHooks {
   443  		upd, ok := hook.(interfaces.TaskUpdateHook)
   444  		if !ok {
   445  			continue
   446  		}
   447  
   448  		name := upd.Name()
   449  
   450  		// Build the request
   451  		req := interfaces.TaskUpdateRequest{
   452  			VaultToken: tr.getVaultToken(),
   453  			Alloc:      alloc,
   454  			TaskEnv:    tr.envBuilder.Build(),
   455  		}
   456  
   457  		// Time the update hook
   458  		var start time.Time
   459  		if tr.logger.IsTrace() {
   460  			start = time.Now()
   461  			tr.logger.Trace("running update hook", "name", name, "start", start)
   462  		}
   463  
   464  		// Run the update hook
   465  		var resp interfaces.TaskUpdateResponse
   466  		if err := upd.Update(tr.killCtx, &req, &resp); err != nil {
   467  			tr.emitHookError(err, name)
   468  			tr.logger.Error("update hook failed", "name", name, "error", err)
   469  		}
   470  
   471  		// No need to persist as TaskUpdateResponse is currently empty
   472  
   473  		if tr.logger.IsTrace() {
   474  			end := time.Now()
   475  			tr.logger.Trace("finished update hooks", "name", name, "end", end, "duration", end.Sub(start))
   476  		}
   477  	}
   478  }
   479  
   480  // preKill is used to run the runners preKill hooks
   481  // preKill hooks contain logic that must be executed before
   482  // a task is killed or restarted
   483  func (tr *TaskRunner) preKill() {
   484  	if tr.logger.IsTrace() {
   485  		start := time.Now()
   486  		tr.logger.Trace("running pre kill hooks", "start", start)
   487  		defer func() {
   488  			end := time.Now()
   489  			tr.logger.Trace("finished pre kill hooks", "end", end, "duration", end.Sub(start))
   490  		}()
   491  	}
   492  
   493  	for _, hook := range tr.runnerHooks {
   494  		killHook, ok := hook.(interfaces.TaskPreKillHook)
   495  		if !ok {
   496  			continue
   497  		}
   498  
   499  		name := killHook.Name()
   500  
   501  		// Time the pre kill hook
   502  		var start time.Time
   503  		if tr.logger.IsTrace() {
   504  			start = time.Now()
   505  			tr.logger.Trace("running prekill hook", "name", name, "start", start)
   506  		}
   507  
   508  		// Run the pre kill hook
   509  		req := interfaces.TaskPreKillRequest{}
   510  		var resp interfaces.TaskPreKillResponse
   511  		if err := killHook.PreKilling(context.Background(), &req, &resp); err != nil {
   512  			tr.emitHookError(err, name)
   513  			tr.logger.Error("prekill hook failed", "name", name, "error", err)
   514  		}
   515  
   516  		// No need to persist as TaskKillResponse is currently empty
   517  
   518  		if tr.logger.IsTrace() {
   519  			end := time.Now()
   520  			tr.logger.Trace("finished prekill hook", "name", name, "end", end, "duration", end.Sub(start))
   521  		}
   522  	}
   523  }
   524  
   525  // shutdownHooks is called when the TaskRunner is gracefully shutdown but the
   526  // task is not being stopped or garbage collected.
   527  func (tr *TaskRunner) shutdownHooks() {
   528  	for _, hook := range tr.runnerHooks {
   529  		sh, ok := hook.(interfaces.ShutdownHook)
   530  		if !ok {
   531  			continue
   532  		}
   533  
   534  		name := sh.Name()
   535  
   536  		// Time the update hook
   537  		var start time.Time
   538  		if tr.logger.IsTrace() {
   539  			start = time.Now()
   540  			tr.logger.Trace("running shutdown hook", "name", name, "start", start)
   541  		}
   542  
   543  		sh.Shutdown()
   544  
   545  		if tr.logger.IsTrace() {
   546  			end := time.Now()
   547  			tr.logger.Trace("finished shutdown hook", "name", name, "end", end, "duration", end.Sub(start))
   548  		}
   549  	}
   550  }