github.com/manicqin/nomad@v0.9.5/client/allocrunner/taskrunner/task_runner_hooks.go (about)

     1  package taskrunner
     2  
     3  import (
     4  	"context"
     5  	"fmt"
     6  	"sync"
     7  	"time"
     8  
     9  	"github.com/LK4D4/joincontext"
    10  	multierror "github.com/hashicorp/go-multierror"
    11  	"github.com/hashicorp/nomad/client/allocrunner/interfaces"
    12  	"github.com/hashicorp/nomad/client/allocrunner/taskrunner/state"
    13  	"github.com/hashicorp/nomad/nomad/structs"
    14  	"github.com/hashicorp/nomad/plugins/drivers"
    15  )
    16  
    17  // hookResources captures the resources for the task provided by hooks.
    18  type hookResources struct {
    19  	Devices []*drivers.DeviceConfig
    20  	Mounts  []*drivers.MountConfig
    21  	sync.RWMutex
    22  }
    23  
    24  func (h *hookResources) setDevices(d []*drivers.DeviceConfig) {
    25  	h.Lock()
    26  	h.Devices = d
    27  	h.Unlock()
    28  }
    29  
    30  func (h *hookResources) getDevices() []*drivers.DeviceConfig {
    31  	h.RLock()
    32  	defer h.RUnlock()
    33  	return h.Devices
    34  }
    35  
    36  func (h *hookResources) setMounts(m []*drivers.MountConfig) {
    37  	h.Lock()
    38  	h.Mounts = m
    39  	h.Unlock()
    40  }
    41  
    42  func (h *hookResources) getMounts() []*drivers.MountConfig {
    43  	h.RLock()
    44  	defer h.RUnlock()
    45  	return h.Mounts
    46  }
    47  
    48  // initHooks intializes the tasks hooks.
    49  func (tr *TaskRunner) initHooks() {
    50  	hookLogger := tr.logger.Named("task_hook")
    51  	task := tr.Task()
    52  
    53  	tr.logmonHookConfig = newLogMonHookConfig(task.Name, tr.taskDir.LogDir)
    54  
    55  	// Add the hook resources
    56  	tr.hookResources = &hookResources{}
    57  
    58  	// Create the task directory hook. This is run first to ensure the
    59  	// directory path exists for other hooks.
    60  	alloc := tr.Alloc()
    61  	tr.runnerHooks = []interfaces.TaskHook{
    62  		newValidateHook(tr.clientConfig, hookLogger),
    63  		newTaskDirHook(tr, hookLogger),
    64  		newLogMonHook(tr, hookLogger),
    65  		newDispatchHook(alloc, hookLogger),
    66  		newVolumeHook(tr, hookLogger),
    67  		newArtifactHook(tr, hookLogger),
    68  		newStatsHook(tr, tr.clientConfig.StatsCollectionInterval, hookLogger),
    69  		newDeviceHook(tr.devicemanager, hookLogger),
    70  		newEnvoyBootstrapHook(alloc, tr.clientConfig.ConsulConfig.Addr, hookLogger),
    71  	}
    72  
    73  	// If Vault is enabled, add the hook
    74  	if task.Vault != nil {
    75  		tr.runnerHooks = append(tr.runnerHooks, newVaultHook(&vaultHookConfig{
    76  			vaultStanza: task.Vault,
    77  			client:      tr.vaultClient,
    78  			events:      tr,
    79  			lifecycle:   tr,
    80  			updater:     tr,
    81  			logger:      hookLogger,
    82  			alloc:       tr.Alloc(),
    83  			task:        tr.taskName,
    84  		}))
    85  	}
    86  
    87  	// If there are templates is enabled, add the hook
    88  	if len(task.Templates) != 0 {
    89  		tr.runnerHooks = append(tr.runnerHooks, newTemplateHook(&templateHookConfig{
    90  			logger:       hookLogger,
    91  			lifecycle:    tr,
    92  			events:       tr,
    93  			templates:    task.Templates,
    94  			clientConfig: tr.clientConfig,
    95  			envBuilder:   tr.envBuilder,
    96  		}))
    97  	}
    98  
    99  	// If there are any services, add the hook
   100  	if len(task.Services) != 0 {
   101  		tr.runnerHooks = append(tr.runnerHooks, newServiceHook(serviceHookConfig{
   102  			alloc:     tr.Alloc(),
   103  			task:      tr.Task(),
   104  			consul:    tr.consulClient,
   105  			restarter: tr,
   106  			logger:    hookLogger,
   107  		}))
   108  	}
   109  
   110  	// If there are any script checks, add the hook
   111  	scriptCheckHook := newScriptCheckHook(scriptCheckHookConfig{
   112  		alloc:  tr.Alloc(),
   113  		task:   tr.Task(),
   114  		consul: tr.consulClient,
   115  		logger: hookLogger,
   116  	})
   117  	tr.runnerHooks = append(tr.runnerHooks, scriptCheckHook)
   118  }
   119  
   120  func (tr *TaskRunner) emitHookError(err error, hookName string) {
   121  	var taskEvent *structs.TaskEvent
   122  	if herr, ok := err.(*hookError); ok {
   123  		taskEvent = herr.taskEvent
   124  	} else {
   125  		message := fmt.Sprintf("%s: %v", hookName, err)
   126  		taskEvent = structs.NewTaskEvent(structs.TaskHookFailed).SetMessage(message)
   127  	}
   128  
   129  	tr.EmitEvent(taskEvent)
   130  }
   131  
   132  // prestart is used to run the runners prestart hooks.
   133  func (tr *TaskRunner) prestart() error {
   134  	// Determine if the allocation is terminaland we should avoid running
   135  	// prestart hooks.
   136  	alloc := tr.Alloc()
   137  	if alloc.TerminalStatus() {
   138  		tr.logger.Trace("skipping prestart hooks since allocation is terminal")
   139  		return nil
   140  	}
   141  
   142  	if tr.logger.IsTrace() {
   143  		start := time.Now()
   144  		tr.logger.Trace("running prestart hooks", "start", start)
   145  		defer func() {
   146  			end := time.Now()
   147  			tr.logger.Trace("finished prestart hooks", "end", end, "duration", end.Sub(start))
   148  		}()
   149  	}
   150  
   151  	for _, hook := range tr.runnerHooks {
   152  		pre, ok := hook.(interfaces.TaskPrestartHook)
   153  		if !ok {
   154  			continue
   155  		}
   156  
   157  		name := pre.Name()
   158  
   159  		// Build the request
   160  		req := interfaces.TaskPrestartRequest{
   161  			Task:          tr.Task(),
   162  			TaskDir:       tr.taskDir,
   163  			TaskEnv:       tr.envBuilder.Build(),
   164  			TaskResources: tr.taskResources,
   165  		}
   166  
   167  		origHookState := tr.hookState(name)
   168  		if origHookState != nil {
   169  			if origHookState.PrestartDone {
   170  				tr.logger.Trace("skipping done prestart hook", "name", pre.Name())
   171  
   172  				// Always set env vars from hooks
   173  				if name == HookNameDevices {
   174  					tr.envBuilder.SetDeviceHookEnv(name, origHookState.Env)
   175  				} else {
   176  					tr.envBuilder.SetHookEnv(name, origHookState.Env)
   177  				}
   178  
   179  				continue
   180  			}
   181  
   182  			// Give the hook it's old data
   183  			req.PreviousState = origHookState.Data
   184  		}
   185  
   186  		req.VaultToken = tr.getVaultToken()
   187  
   188  		// Time the prestart hook
   189  		var start time.Time
   190  		if tr.logger.IsTrace() {
   191  			start = time.Now()
   192  			tr.logger.Trace("running prestart hook", "name", name, "start", start)
   193  		}
   194  
   195  		// Run the prestart hook
   196  		// use a joint context to allow any blocking pre-start hooks
   197  		// to be canceled by either killCtx or shutdownCtx
   198  		joinedCtx, _ := joincontext.Join(tr.killCtx, tr.shutdownCtx)
   199  		var resp interfaces.TaskPrestartResponse
   200  		if err := pre.Prestart(joinedCtx, &req, &resp); err != nil {
   201  			tr.emitHookError(err, name)
   202  			return structs.WrapRecoverable(fmt.Sprintf("prestart hook %q failed: %v", name, err), err)
   203  		}
   204  
   205  		// Store the hook state
   206  		{
   207  			hookState := &state.HookState{
   208  				Data:         resp.State,
   209  				PrestartDone: resp.Done,
   210  				Env:          resp.Env,
   211  			}
   212  
   213  			// Store and persist local state if the hook state has changed
   214  			if !hookState.Equal(origHookState) {
   215  				tr.stateLock.Lock()
   216  				tr.localState.Hooks[name] = hookState
   217  				tr.stateLock.Unlock()
   218  
   219  				if err := tr.persistLocalState(); err != nil {
   220  					return err
   221  				}
   222  			}
   223  		}
   224  
   225  		// Store the environment variables returned by the hook
   226  		if name == HookNameDevices {
   227  			tr.envBuilder.SetDeviceHookEnv(name, resp.Env)
   228  		} else {
   229  			tr.envBuilder.SetHookEnv(name, resp.Env)
   230  		}
   231  
   232  		// Store the resources
   233  		if len(resp.Devices) != 0 {
   234  			tr.hookResources.setDevices(resp.Devices)
   235  		}
   236  		if len(resp.Mounts) != 0 {
   237  			tr.hookResources.setMounts(resp.Mounts)
   238  		}
   239  
   240  		if tr.logger.IsTrace() {
   241  			end := time.Now()
   242  			tr.logger.Trace("finished prestart hook", "name", name, "end", end, "duration", end.Sub(start))
   243  		}
   244  	}
   245  
   246  	return nil
   247  }
   248  
   249  // poststart is used to run the runners poststart hooks.
   250  func (tr *TaskRunner) poststart() error {
   251  	if tr.logger.IsTrace() {
   252  		start := time.Now()
   253  		tr.logger.Trace("running poststart hooks", "start", start)
   254  		defer func() {
   255  			end := time.Now()
   256  			tr.logger.Trace("finished poststart hooks", "end", end, "duration", end.Sub(start))
   257  		}()
   258  	}
   259  
   260  	handle := tr.getDriverHandle()
   261  	net := handle.Network()
   262  
   263  	// Pass the lazy handle to the hooks so even if the driver exits and we
   264  	// launch a new one (external plugin), the handle will refresh.
   265  	lazyHandle := NewLazyHandle(tr.shutdownCtx, tr.getDriverHandle, tr.logger)
   266  
   267  	var merr multierror.Error
   268  	for _, hook := range tr.runnerHooks {
   269  		post, ok := hook.(interfaces.TaskPoststartHook)
   270  		if !ok {
   271  			continue
   272  		}
   273  
   274  		name := post.Name()
   275  		var start time.Time
   276  		if tr.logger.IsTrace() {
   277  			start = time.Now()
   278  			tr.logger.Trace("running poststart hook", "name", name, "start", start)
   279  		}
   280  
   281  		req := interfaces.TaskPoststartRequest{
   282  			DriverExec:    lazyHandle,
   283  			DriverNetwork: net,
   284  			DriverStats:   lazyHandle,
   285  			TaskEnv:       tr.envBuilder.Build(),
   286  		}
   287  		var resp interfaces.TaskPoststartResponse
   288  		if err := post.Poststart(tr.killCtx, &req, &resp); err != nil {
   289  			tr.emitHookError(err, name)
   290  			merr.Errors = append(merr.Errors, fmt.Errorf("poststart hook %q failed: %v", name, err))
   291  		}
   292  
   293  		// No need to persist as PoststartResponse is currently empty
   294  
   295  		if tr.logger.IsTrace() {
   296  			end := time.Now()
   297  			tr.logger.Trace("finished poststart hooks", "name", name, "end", end, "duration", end.Sub(start))
   298  		}
   299  	}
   300  
   301  	return merr.ErrorOrNil()
   302  }
   303  
   304  // exited is used to run the exited hooks before a task is stopped.
   305  func (tr *TaskRunner) exited() error {
   306  	if tr.logger.IsTrace() {
   307  		start := time.Now()
   308  		tr.logger.Trace("running exited hooks", "start", start)
   309  		defer func() {
   310  			end := time.Now()
   311  			tr.logger.Trace("finished exited hooks", "end", end, "duration", end.Sub(start))
   312  		}()
   313  	}
   314  
   315  	var merr multierror.Error
   316  	for _, hook := range tr.runnerHooks {
   317  		post, ok := hook.(interfaces.TaskExitedHook)
   318  		if !ok {
   319  			continue
   320  		}
   321  
   322  		name := post.Name()
   323  		var start time.Time
   324  		if tr.logger.IsTrace() {
   325  			start = time.Now()
   326  			tr.logger.Trace("running exited hook", "name", name, "start", start)
   327  		}
   328  
   329  		req := interfaces.TaskExitedRequest{}
   330  		var resp interfaces.TaskExitedResponse
   331  		if err := post.Exited(tr.killCtx, &req, &resp); err != nil {
   332  			tr.emitHookError(err, name)
   333  			merr.Errors = append(merr.Errors, fmt.Errorf("exited hook %q failed: %v", name, err))
   334  		}
   335  
   336  		// No need to persist as TaskExitedResponse is currently empty
   337  
   338  		if tr.logger.IsTrace() {
   339  			end := time.Now()
   340  			tr.logger.Trace("finished exited hooks", "name", name, "end", end, "duration", end.Sub(start))
   341  		}
   342  	}
   343  
   344  	return merr.ErrorOrNil()
   345  
   346  }
   347  
   348  // stop is used to run the stop hooks.
   349  func (tr *TaskRunner) stop() error {
   350  	if tr.logger.IsTrace() {
   351  		start := time.Now()
   352  		tr.logger.Trace("running stop hooks", "start", start)
   353  		defer func() {
   354  			end := time.Now()
   355  			tr.logger.Trace("finished stop hooks", "end", end, "duration", end.Sub(start))
   356  		}()
   357  	}
   358  
   359  	var merr multierror.Error
   360  	for _, hook := range tr.runnerHooks {
   361  		post, ok := hook.(interfaces.TaskStopHook)
   362  		if !ok {
   363  			continue
   364  		}
   365  
   366  		name := post.Name()
   367  		var start time.Time
   368  		if tr.logger.IsTrace() {
   369  			start = time.Now()
   370  			tr.logger.Trace("running stop hook", "name", name, "start", start)
   371  		}
   372  
   373  		req := interfaces.TaskStopRequest{}
   374  
   375  		origHookState := tr.hookState(name)
   376  		if origHookState != nil {
   377  			// Give the hook data provided by prestart
   378  			req.ExistingState = origHookState.Data
   379  		}
   380  
   381  		var resp interfaces.TaskStopResponse
   382  		if err := post.Stop(tr.killCtx, &req, &resp); err != nil {
   383  			tr.emitHookError(err, name)
   384  			merr.Errors = append(merr.Errors, fmt.Errorf("stop hook %q failed: %v", name, err))
   385  		}
   386  
   387  		// Stop hooks cannot alter state and must be idempotent, so
   388  		// unlike prestart there's no state to persist here.
   389  
   390  		if tr.logger.IsTrace() {
   391  			end := time.Now()
   392  			tr.logger.Trace("finished stop hook", "name", name, "end", end, "duration", end.Sub(start))
   393  		}
   394  	}
   395  
   396  	return merr.ErrorOrNil()
   397  }
   398  
   399  // update is used to run the runners update hooks. Should only be called from
   400  // Run(). To trigger an update, update state on the TaskRunner and call
   401  // triggerUpdateHooks.
   402  func (tr *TaskRunner) updateHooks() {
   403  	if tr.logger.IsTrace() {
   404  		start := time.Now()
   405  		tr.logger.Trace("running update hooks", "start", start)
   406  		defer func() {
   407  			end := time.Now()
   408  			tr.logger.Trace("finished update hooks", "end", end, "duration", end.Sub(start))
   409  		}()
   410  	}
   411  
   412  	// Prepare state needed by Update hooks
   413  	alloc := tr.Alloc()
   414  
   415  	// Execute Update hooks
   416  	for _, hook := range tr.runnerHooks {
   417  		upd, ok := hook.(interfaces.TaskUpdateHook)
   418  		if !ok {
   419  			continue
   420  		}
   421  
   422  		name := upd.Name()
   423  
   424  		// Build the request
   425  		req := interfaces.TaskUpdateRequest{
   426  			VaultToken: tr.getVaultToken(),
   427  			Alloc:      alloc,
   428  			TaskEnv:    tr.envBuilder.Build(),
   429  		}
   430  
   431  		// Time the update hook
   432  		var start time.Time
   433  		if tr.logger.IsTrace() {
   434  			start = time.Now()
   435  			tr.logger.Trace("running update hook", "name", name, "start", start)
   436  		}
   437  
   438  		// Run the update hook
   439  		var resp interfaces.TaskUpdateResponse
   440  		if err := upd.Update(tr.killCtx, &req, &resp); err != nil {
   441  			tr.emitHookError(err, name)
   442  			tr.logger.Error("update hook failed", "name", name, "error", err)
   443  		}
   444  
   445  		// No need to persist as TaskUpdateResponse is currently empty
   446  
   447  		if tr.logger.IsTrace() {
   448  			end := time.Now()
   449  			tr.logger.Trace("finished update hooks", "name", name, "end", end, "duration", end.Sub(start))
   450  		}
   451  	}
   452  }
   453  
   454  // preKill is used to run the runners preKill hooks
   455  // preKill hooks contain logic that must be executed before
   456  // a task is killed or restarted
   457  func (tr *TaskRunner) preKill() {
   458  	if tr.logger.IsTrace() {
   459  		start := time.Now()
   460  		tr.logger.Trace("running pre kill hooks", "start", start)
   461  		defer func() {
   462  			end := time.Now()
   463  			tr.logger.Trace("finished pre kill hooks", "end", end, "duration", end.Sub(start))
   464  		}()
   465  	}
   466  
   467  	for _, hook := range tr.runnerHooks {
   468  		killHook, ok := hook.(interfaces.TaskPreKillHook)
   469  		if !ok {
   470  			continue
   471  		}
   472  
   473  		name := killHook.Name()
   474  
   475  		// Time the pre kill hook
   476  		var start time.Time
   477  		if tr.logger.IsTrace() {
   478  			start = time.Now()
   479  			tr.logger.Trace("running prekill hook", "name", name, "start", start)
   480  		}
   481  
   482  		// Run the pre kill hook
   483  		req := interfaces.TaskPreKillRequest{}
   484  		var resp interfaces.TaskPreKillResponse
   485  		if err := killHook.PreKilling(context.Background(), &req, &resp); err != nil {
   486  			tr.emitHookError(err, name)
   487  			tr.logger.Error("prekill hook failed", "name", name, "error", err)
   488  		}
   489  
   490  		// No need to persist as TaskKillResponse is currently empty
   491  
   492  		if tr.logger.IsTrace() {
   493  			end := time.Now()
   494  			tr.logger.Trace("finished prekill hook", "name", name, "end", end, "duration", end.Sub(start))
   495  		}
   496  	}
   497  }
   498  
   499  // shutdownHooks is called when the TaskRunner is gracefully shutdown but the
   500  // task is not being stopped or garbage collected.
   501  func (tr *TaskRunner) shutdownHooks() {
   502  	for _, hook := range tr.runnerHooks {
   503  		sh, ok := hook.(interfaces.ShutdownHook)
   504  		if !ok {
   505  			continue
   506  		}
   507  
   508  		name := sh.Name()
   509  
   510  		// Time the update hook
   511  		var start time.Time
   512  		if tr.logger.IsTrace() {
   513  			start = time.Now()
   514  			tr.logger.Trace("running shutdown hook", "name", name, "start", start)
   515  		}
   516  
   517  		sh.Shutdown()
   518  
   519  		if tr.logger.IsTrace() {
   520  			end := time.Now()
   521  			tr.logger.Trace("finished shutdown hook", "name", name, "end", end, "duration", end.Sub(start))
   522  		}
   523  	}
   524  }