github.com/smithx10/nomad@v0.9.1-rc1/client/allocrunner/taskrunner/task_runner_hooks.go (about)

     1  package taskrunner
     2  
     3  import (
     4  	"context"
     5  	"fmt"
     6  	"sync"
     7  	"time"
     8  
     9  	multierror "github.com/hashicorp/go-multierror"
    10  	"github.com/hashicorp/nomad/client/allocrunner/interfaces"
    11  	"github.com/hashicorp/nomad/client/allocrunner/taskrunner/state"
    12  	"github.com/hashicorp/nomad/nomad/structs"
    13  	"github.com/hashicorp/nomad/plugins/drivers"
    14  )
    15  
// hookResources captures the resources for the task provided by hooks.
//
// Devices and Mounts are filled in from prestart hook responses. The embedded
// RWMutex guards both fields; access them through the get/set accessors
// rather than reading or writing the fields directly.
type hookResources struct {
	// Devices are the device configurations returned by prestart hooks.
	Devices []*drivers.DeviceConfig
	// Mounts are the mount configurations returned by prestart hooks.
	Mounts  []*drivers.MountConfig
	sync.RWMutex
}
    22  
    23  func (h *hookResources) setDevices(d []*drivers.DeviceConfig) {
    24  	h.Lock()
    25  	h.Devices = d
    26  	h.Unlock()
    27  }
    28  
    29  func (h *hookResources) getDevices() []*drivers.DeviceConfig {
    30  	h.RLock()
    31  	defer h.RUnlock()
    32  	return h.Devices
    33  }
    34  
    35  func (h *hookResources) setMounts(m []*drivers.MountConfig) {
    36  	h.Lock()
    37  	h.Mounts = m
    38  	h.Unlock()
    39  }
    40  
    41  func (h *hookResources) getMounts() []*drivers.MountConfig {
    42  	h.RLock()
    43  	defer h.RUnlock()
    44  	return h.Mounts
    45  }
    46  
    47  // initHooks intializes the tasks hooks.
    48  func (tr *TaskRunner) initHooks() {
    49  	hookLogger := tr.logger.Named("task_hook")
    50  	task := tr.Task()
    51  
    52  	tr.logmonHookConfig = newLogMonHookConfig(task.Name, tr.taskDir.LogDir)
    53  
    54  	// Add the hook resources
    55  	tr.hookResources = &hookResources{}
    56  
    57  	// Create the task directory hook. This is run first to ensure the
    58  	// directory path exists for other hooks.
    59  	tr.runnerHooks = []interfaces.TaskHook{
    60  		newValidateHook(tr.clientConfig, hookLogger),
    61  		newTaskDirHook(tr, hookLogger),
    62  		newLogMonHook(tr.logmonHookConfig, hookLogger),
    63  		newDispatchHook(tr.Alloc(), hookLogger),
    64  		newArtifactHook(tr, hookLogger),
    65  		newStatsHook(tr, tr.clientConfig.StatsCollectionInterval, hookLogger),
    66  		newDeviceHook(tr.devicemanager, hookLogger),
    67  	}
    68  
    69  	// If Vault is enabled, add the hook
    70  	if task.Vault != nil {
    71  		tr.runnerHooks = append(tr.runnerHooks, newVaultHook(&vaultHookConfig{
    72  			vaultStanza: task.Vault,
    73  			client:      tr.vaultClient,
    74  			events:      tr,
    75  			lifecycle:   tr,
    76  			updater:     tr,
    77  			logger:      hookLogger,
    78  			alloc:       tr.Alloc(),
    79  			task:        tr.taskName,
    80  		}))
    81  	}
    82  
    83  	// If there are templates is enabled, add the hook
    84  	if len(task.Templates) != 0 {
    85  		tr.runnerHooks = append(tr.runnerHooks, newTemplateHook(&templateHookConfig{
    86  			logger:       hookLogger,
    87  			lifecycle:    tr,
    88  			events:       tr,
    89  			templates:    task.Templates,
    90  			clientConfig: tr.clientConfig,
    91  			envBuilder:   tr.envBuilder,
    92  		}))
    93  	}
    94  
    95  	// If there are any services, add the hook
    96  	if len(task.Services) != 0 {
    97  		tr.runnerHooks = append(tr.runnerHooks, newServiceHook(serviceHookConfig{
    98  			alloc:     tr.Alloc(),
    99  			task:      tr.Task(),
   100  			consul:    tr.consulClient,
   101  			restarter: tr,
   102  			logger:    hookLogger,
   103  		}))
   104  	}
   105  }
   106  
   107  func (tr *TaskRunner) emitHookError(err error, hookName string) {
   108  	var taskEvent *structs.TaskEvent
   109  	if herr, ok := err.(*hookError); ok {
   110  		taskEvent = herr.taskEvent
   111  	} else {
   112  		message := fmt.Sprintf("%s: %v", hookName, err)
   113  		taskEvent = structs.NewTaskEvent(structs.TaskHookFailed).SetMessage(message)
   114  	}
   115  
   116  	tr.EmitEvent(taskEvent)
   117  }
   118  
   119  // prestart is used to run the runners prestart hooks.
   120  func (tr *TaskRunner) prestart() error {
   121  	// Determine if the allocation is terminaland we should avoid running
   122  	// prestart hooks.
   123  	alloc := tr.Alloc()
   124  	if alloc.TerminalStatus() {
   125  		tr.logger.Trace("skipping prestart hooks since allocation is terminal")
   126  		return nil
   127  	}
   128  
   129  	if tr.logger.IsTrace() {
   130  		start := time.Now()
   131  		tr.logger.Trace("running prestart hooks", "start", start)
   132  		defer func() {
   133  			end := time.Now()
   134  			tr.logger.Trace("finished prestart hooks", "end", end, "duration", end.Sub(start))
   135  		}()
   136  	}
   137  
   138  	for _, hook := range tr.runnerHooks {
   139  		pre, ok := hook.(interfaces.TaskPrestartHook)
   140  		if !ok {
   141  			continue
   142  		}
   143  
   144  		name := pre.Name()
   145  
   146  		// Build the request
   147  		req := interfaces.TaskPrestartRequest{
   148  			Task:          tr.Task(),
   149  			TaskDir:       tr.taskDir,
   150  			TaskEnv:       tr.envBuilder.Build(),
   151  			TaskResources: tr.taskResources,
   152  		}
   153  
   154  		origHookState := tr.hookState(name)
   155  		if origHookState != nil {
   156  			if origHookState.PrestartDone {
   157  				tr.logger.Trace("skipping done prestart hook", "name", pre.Name())
   158  
   159  				// Always set env vars from hooks
   160  				if name == HookNameDevices {
   161  					tr.envBuilder.SetDeviceHookEnv(name, origHookState.Env)
   162  				} else {
   163  					tr.envBuilder.SetHookEnv(name, origHookState.Env)
   164  				}
   165  
   166  				continue
   167  			}
   168  
   169  			// Give the hook it's old data
   170  			req.PreviousState = origHookState.Data
   171  		}
   172  
   173  		req.VaultToken = tr.getVaultToken()
   174  
   175  		// Time the prestart hook
   176  		var start time.Time
   177  		if tr.logger.IsTrace() {
   178  			start = time.Now()
   179  			tr.logger.Trace("running prestart hook", "name", name, "start", start)
   180  		}
   181  
   182  		// Run the prestart hook
   183  		var resp interfaces.TaskPrestartResponse
   184  		if err := pre.Prestart(tr.killCtx, &req, &resp); err != nil {
   185  			tr.emitHookError(err, name)
   186  			return structs.WrapRecoverable(fmt.Sprintf("prestart hook %q failed: %v", name, err), err)
   187  		}
   188  
   189  		// Store the hook state
   190  		{
   191  			hookState := &state.HookState{
   192  				Data:         resp.State,
   193  				PrestartDone: resp.Done,
   194  				Env:          resp.Env,
   195  			}
   196  
   197  			// Store and persist local state if the hook state has changed
   198  			if !hookState.Equal(origHookState) {
   199  				tr.stateLock.Lock()
   200  				tr.localState.Hooks[name] = hookState
   201  				tr.stateLock.Unlock()
   202  
   203  				if err := tr.persistLocalState(); err != nil {
   204  					return err
   205  				}
   206  			}
   207  		}
   208  
   209  		// Store the environment variables returned by the hook
   210  		if name == HookNameDevices {
   211  			tr.envBuilder.SetDeviceHookEnv(name, resp.Env)
   212  		} else {
   213  			tr.envBuilder.SetHookEnv(name, resp.Env)
   214  		}
   215  
   216  		// Store the resources
   217  		if len(resp.Devices) != 0 {
   218  			tr.hookResources.setDevices(resp.Devices)
   219  		}
   220  		if len(resp.Mounts) != 0 {
   221  			tr.hookResources.setMounts(resp.Mounts)
   222  		}
   223  
   224  		if tr.logger.IsTrace() {
   225  			end := time.Now()
   226  			tr.logger.Trace("finished prestart hook", "name", name, "end", end, "duration", end.Sub(start))
   227  		}
   228  	}
   229  
   230  	return nil
   231  }
   232  
   233  // poststart is used to run the runners poststart hooks.
   234  func (tr *TaskRunner) poststart() error {
   235  	if tr.logger.IsTrace() {
   236  		start := time.Now()
   237  		tr.logger.Trace("running poststart hooks", "start", start)
   238  		defer func() {
   239  			end := time.Now()
   240  			tr.logger.Trace("finished poststart hooks", "end", end, "duration", end.Sub(start))
   241  		}()
   242  	}
   243  
   244  	handle := tr.getDriverHandle()
   245  	net := handle.Network()
   246  
   247  	// Pass the lazy handle to the hooks so even if the driver exits and we
   248  	// launch a new one (external plugin), the handle will refresh.
   249  	lazyHandle := NewLazyHandle(tr.shutdownCtx, tr.getDriverHandle, tr.logger)
   250  
   251  	var merr multierror.Error
   252  	for _, hook := range tr.runnerHooks {
   253  		post, ok := hook.(interfaces.TaskPoststartHook)
   254  		if !ok {
   255  			continue
   256  		}
   257  
   258  		name := post.Name()
   259  		var start time.Time
   260  		if tr.logger.IsTrace() {
   261  			start = time.Now()
   262  			tr.logger.Trace("running poststart hook", "name", name, "start", start)
   263  		}
   264  
   265  		req := interfaces.TaskPoststartRequest{
   266  			DriverExec:    lazyHandle,
   267  			DriverNetwork: net,
   268  			DriverStats:   lazyHandle,
   269  			TaskEnv:       tr.envBuilder.Build(),
   270  		}
   271  		var resp interfaces.TaskPoststartResponse
   272  		if err := post.Poststart(tr.killCtx, &req, &resp); err != nil {
   273  			tr.emitHookError(err, name)
   274  			merr.Errors = append(merr.Errors, fmt.Errorf("poststart hook %q failed: %v", name, err))
   275  		}
   276  
   277  		// No need to persist as PoststartResponse is currently empty
   278  
   279  		if tr.logger.IsTrace() {
   280  			end := time.Now()
   281  			tr.logger.Trace("finished poststart hooks", "name", name, "end", end, "duration", end.Sub(start))
   282  		}
   283  	}
   284  
   285  	return merr.ErrorOrNil()
   286  }
   287  
   288  // exited is used to run the exited hooks before a task is stopped.
   289  func (tr *TaskRunner) exited() error {
   290  	if tr.logger.IsTrace() {
   291  		start := time.Now()
   292  		tr.logger.Trace("running exited hooks", "start", start)
   293  		defer func() {
   294  			end := time.Now()
   295  			tr.logger.Trace("finished exited hooks", "end", end, "duration", end.Sub(start))
   296  		}()
   297  	}
   298  
   299  	var merr multierror.Error
   300  	for _, hook := range tr.runnerHooks {
   301  		post, ok := hook.(interfaces.TaskExitedHook)
   302  		if !ok {
   303  			continue
   304  		}
   305  
   306  		name := post.Name()
   307  		var start time.Time
   308  		if tr.logger.IsTrace() {
   309  			start = time.Now()
   310  			tr.logger.Trace("running exited hook", "name", name, "start", start)
   311  		}
   312  
   313  		req := interfaces.TaskExitedRequest{}
   314  		var resp interfaces.TaskExitedResponse
   315  		if err := post.Exited(tr.killCtx, &req, &resp); err != nil {
   316  			tr.emitHookError(err, name)
   317  			merr.Errors = append(merr.Errors, fmt.Errorf("exited hook %q failed: %v", name, err))
   318  		}
   319  
   320  		// No need to persist as TaskExitedResponse is currently empty
   321  
   322  		if tr.logger.IsTrace() {
   323  			end := time.Now()
   324  			tr.logger.Trace("finished exited hooks", "name", name, "end", end, "duration", end.Sub(start))
   325  		}
   326  	}
   327  
   328  	return merr.ErrorOrNil()
   329  
   330  }
   331  
   332  // stop is used to run the stop hooks.
   333  func (tr *TaskRunner) stop() error {
   334  	if tr.logger.IsTrace() {
   335  		start := time.Now()
   336  		tr.logger.Trace("running stop hooks", "start", start)
   337  		defer func() {
   338  			end := time.Now()
   339  			tr.logger.Trace("finished stop hooks", "end", end, "duration", end.Sub(start))
   340  		}()
   341  	}
   342  
   343  	var merr multierror.Error
   344  	for _, hook := range tr.runnerHooks {
   345  		post, ok := hook.(interfaces.TaskStopHook)
   346  		if !ok {
   347  			continue
   348  		}
   349  
   350  		name := post.Name()
   351  		var start time.Time
   352  		if tr.logger.IsTrace() {
   353  			start = time.Now()
   354  			tr.logger.Trace("running stop hook", "name", name, "start", start)
   355  		}
   356  
   357  		req := interfaces.TaskStopRequest{}
   358  
   359  		origHookState := tr.hookState(name)
   360  		if origHookState != nil {
   361  			// Give the hook data provided by prestart
   362  			req.ExistingState = origHookState.Data
   363  		}
   364  
   365  		var resp interfaces.TaskStopResponse
   366  		if err := post.Stop(tr.killCtx, &req, &resp); err != nil {
   367  			tr.emitHookError(err, name)
   368  			merr.Errors = append(merr.Errors, fmt.Errorf("stop hook %q failed: %v", name, err))
   369  		}
   370  
   371  		// Stop hooks cannot alter state and must be idempotent, so
   372  		// unlike prestart there's no state to persist here.
   373  
   374  		if tr.logger.IsTrace() {
   375  			end := time.Now()
   376  			tr.logger.Trace("finished stop hook", "name", name, "end", end, "duration", end.Sub(start))
   377  		}
   378  	}
   379  
   380  	return merr.ErrorOrNil()
   381  }
   382  
   383  // update is used to run the runners update hooks. Should only be called from
   384  // Run(). To trigger an update, update state on the TaskRunner and call
   385  // triggerUpdateHooks.
   386  func (tr *TaskRunner) updateHooks() {
   387  	if tr.logger.IsTrace() {
   388  		start := time.Now()
   389  		tr.logger.Trace("running update hooks", "start", start)
   390  		defer func() {
   391  			end := time.Now()
   392  			tr.logger.Trace("finished update hooks", "end", end, "duration", end.Sub(start))
   393  		}()
   394  	}
   395  
   396  	// Prepare state needed by Update hooks
   397  	alloc := tr.Alloc()
   398  
   399  	// Execute Update hooks
   400  	for _, hook := range tr.runnerHooks {
   401  		upd, ok := hook.(interfaces.TaskUpdateHook)
   402  		if !ok {
   403  			continue
   404  		}
   405  
   406  		name := upd.Name()
   407  
   408  		// Build the request
   409  		req := interfaces.TaskUpdateRequest{
   410  			VaultToken: tr.getVaultToken(),
   411  			Alloc:      alloc,
   412  			TaskEnv:    tr.envBuilder.Build(),
   413  		}
   414  
   415  		// Time the update hook
   416  		var start time.Time
   417  		if tr.logger.IsTrace() {
   418  			start = time.Now()
   419  			tr.logger.Trace("running update hook", "name", name, "start", start)
   420  		}
   421  
   422  		// Run the update hook
   423  		var resp interfaces.TaskUpdateResponse
   424  		if err := upd.Update(tr.killCtx, &req, &resp); err != nil {
   425  			tr.emitHookError(err, name)
   426  			tr.logger.Error("update hook failed", "name", name, "error", err)
   427  		}
   428  
   429  		// No need to persist as TaskUpdateResponse is currently empty
   430  
   431  		if tr.logger.IsTrace() {
   432  			end := time.Now()
   433  			tr.logger.Trace("finished update hooks", "name", name, "end", end, "duration", end.Sub(start))
   434  		}
   435  	}
   436  }
   437  
   438  // preKill is used to run the runners preKill hooks
   439  // preKill hooks contain logic that must be executed before
   440  // a task is killed or restarted
   441  func (tr *TaskRunner) preKill() {
   442  	if tr.logger.IsTrace() {
   443  		start := time.Now()
   444  		tr.logger.Trace("running pre kill hooks", "start", start)
   445  		defer func() {
   446  			end := time.Now()
   447  			tr.logger.Trace("finished pre kill hooks", "end", end, "duration", end.Sub(start))
   448  		}()
   449  	}
   450  
   451  	for _, hook := range tr.runnerHooks {
   452  		killHook, ok := hook.(interfaces.TaskPreKillHook)
   453  		if !ok {
   454  			continue
   455  		}
   456  
   457  		name := killHook.Name()
   458  
   459  		// Time the pre kill hook
   460  		var start time.Time
   461  		if tr.logger.IsTrace() {
   462  			start = time.Now()
   463  			tr.logger.Trace("running prekill hook", "name", name, "start", start)
   464  		}
   465  
   466  		// Run the pre kill hook
   467  		req := interfaces.TaskPreKillRequest{}
   468  		var resp interfaces.TaskPreKillResponse
   469  		if err := killHook.PreKilling(context.Background(), &req, &resp); err != nil {
   470  			tr.emitHookError(err, name)
   471  			tr.logger.Error("prekill hook failed", "name", name, "error", err)
   472  		}
   473  
   474  		// No need to persist as TaskKillResponse is currently empty
   475  
   476  		if tr.logger.IsTrace() {
   477  			end := time.Now()
   478  			tr.logger.Trace("finished prekill hook", "name", name, "end", end, "duration", end.Sub(start))
   479  		}
   480  	}
   481  }
   482  
   483  // shutdownHooks is called when the TaskRunner is gracefully shutdown but the
   484  // task is not being stopped or garbage collected.
   485  func (tr *TaskRunner) shutdownHooks() {
   486  	for _, hook := range tr.runnerHooks {
   487  		sh, ok := hook.(interfaces.ShutdownHook)
   488  		if !ok {
   489  			continue
   490  		}
   491  
   492  		name := sh.Name()
   493  
   494  		// Time the update hook
   495  		var start time.Time
   496  		if tr.logger.IsTrace() {
   497  			start = time.Now()
   498  			tr.logger.Trace("running shutdown hook", "name", name, "start", start)
   499  		}
   500  
   501  		sh.Shutdown()
   502  
   503  		if tr.logger.IsTrace() {
   504  			end := time.Now()
   505  			tr.logger.Trace("finished shutdown hook", "name", name, "end", end, "duration", end.Sub(start))
   506  		}
   507  	}
   508  }