github.com/anth0d/nomad@v0.0.0-20221214183521-ae3a0a2cad06/client/allocrunner/taskrunner/script_check_hook.go (about)

     1  package taskrunner
     2  
     3  import (
     4  	"context"
     5  	"fmt"
     6  	"sync"
     7  	"time"
     8  
     9  	"github.com/hashicorp/consul/api"
    10  	log "github.com/hashicorp/go-hclog"
    11  	"github.com/hashicorp/nomad/client/allocrunner/interfaces"
    12  	tinterfaces "github.com/hashicorp/nomad/client/allocrunner/taskrunner/interfaces"
    13  	"github.com/hashicorp/nomad/client/serviceregistration"
    14  	"github.com/hashicorp/nomad/client/taskenv"
    15  	agentconsul "github.com/hashicorp/nomad/command/agent/consul"
    16  	"github.com/hashicorp/nomad/nomad/structs"
    17  )
    18  
    19  var _ interfaces.TaskPoststartHook = &scriptCheckHook{}
    20  var _ interfaces.TaskUpdateHook = &scriptCheckHook{}
    21  var _ interfaces.TaskStopHook = &scriptCheckHook{}
    22  
    23  // default max amount of time to wait for all scripts on shutdown.
    24  const defaultShutdownWait = time.Minute
    25  
    26  type scriptCheckHookConfig struct {
    27  	alloc        *structs.Allocation
    28  	task         *structs.Task
    29  	consul       serviceregistration.Handler
    30  	logger       log.Logger
    31  	shutdownWait time.Duration
    32  }
    33  
    34  // scriptCheckHook implements a task runner hook for running script
    35  // checks in the context of a task
    36  type scriptCheckHook struct {
    37  	consul          serviceregistration.Handler
    38  	consulNamespace string
    39  	alloc           *structs.Allocation
    40  	task            *structs.Task
    41  	logger          log.Logger
    42  	shutdownWait    time.Duration // max time to wait for scripts to shutdown
    43  	shutdownCh      chan struct{} // closed when all scripts should shutdown
    44  
    45  	// The following fields can be changed by Update()
    46  	driverExec tinterfaces.ScriptExecutor
    47  	taskEnv    *taskenv.TaskEnv
    48  
    49  	// These maintain state and are populated by Poststart() or Update()
    50  	scripts        map[string]*scriptCheck
    51  	runningScripts map[string]*taskletHandle
    52  
    53  	// Since Update() may be called concurrently with any other hook all
    54  	// hook methods must be fully serialized
    55  	mu sync.Mutex
    56  }
    57  
    58  // newScriptCheckHook returns a hook without any scriptChecks.
    59  // They will get created only once their task environment is ready
    60  // in Poststart() or Update()
    61  func newScriptCheckHook(c scriptCheckHookConfig) *scriptCheckHook {
    62  	h := &scriptCheckHook{
    63  		consul:          c.consul,
    64  		consulNamespace: c.alloc.Job.LookupTaskGroup(c.alloc.TaskGroup).Consul.GetNamespace(),
    65  		alloc:           c.alloc,
    66  		task:            c.task,
    67  		scripts:         make(map[string]*scriptCheck),
    68  		runningScripts:  make(map[string]*taskletHandle),
    69  		shutdownWait:    defaultShutdownWait,
    70  		shutdownCh:      make(chan struct{}),
    71  	}
    72  
    73  	if c.shutdownWait != 0 {
    74  		h.shutdownWait = c.shutdownWait // override for testing
    75  	}
    76  	h.logger = c.logger.Named(h.Name())
    77  	return h
    78  }
    79  
    80  func (h *scriptCheckHook) Name() string {
    81  	return "script_checks"
    82  }
    83  
    84  // Prestart implements interfaces.TaskPrestartHook. It stores the
    85  // initial structs.Task
    86  func (h *scriptCheckHook) Prestart(ctx context.Context, req *interfaces.TaskPrestartRequest, _ *interfaces.TaskPrestartResponse) error {
    87  	h.mu.Lock()
    88  	defer h.mu.Unlock()
    89  	h.task = req.Task
    90  	return nil
    91  }
    92  
    93  // PostStart implements interfaces.TaskPoststartHook. It creates new
    94  // script checks with the current task context (driver and env), and
    95  // starts up the scripts.
    96  func (h *scriptCheckHook) Poststart(ctx context.Context, req *interfaces.TaskPoststartRequest, _ *interfaces.TaskPoststartResponse) error {
    97  	h.mu.Lock()
    98  	defer h.mu.Unlock()
    99  
   100  	if req.DriverExec == nil {
   101  		h.logger.Debug("driver doesn't support script checks")
   102  		return nil
   103  	}
   104  	h.driverExec = req.DriverExec
   105  	h.taskEnv = req.TaskEnv
   106  
   107  	return h.upsertChecks()
   108  }
   109  
   110  // Updated implements interfaces.TaskUpdateHook. It creates new
   111  // script checks with the current task context (driver and env and possibly
   112  // new structs.Task), and starts up the scripts.
   113  func (h *scriptCheckHook) Update(ctx context.Context, req *interfaces.TaskUpdateRequest, _ *interfaces.TaskUpdateResponse) error {
   114  	h.mu.Lock()
   115  	defer h.mu.Unlock()
   116  
   117  	task := req.Alloc.LookupTask(h.task.Name)
   118  	if task == nil {
   119  		return fmt.Errorf("task %q not found in updated alloc", h.task.Name)
   120  	}
   121  	h.alloc = req.Alloc
   122  	h.task = task
   123  	h.taskEnv = req.TaskEnv
   124  
   125  	return h.upsertChecks()
   126  }
   127  
   128  func (h *scriptCheckHook) upsertChecks() error {
   129  	// Create new script checks struct with new task context
   130  	oldScriptChecks := h.scripts
   131  	h.scripts = h.newScriptChecks()
   132  
   133  	// Run new or replacement scripts
   134  	for id, script := range h.scripts {
   135  		// If it's already running, cancel and replace
   136  		if oldScript, running := h.runningScripts[id]; running {
   137  			oldScript.cancel()
   138  		}
   139  		// Start and store the handle
   140  		h.runningScripts[id] = script.run()
   141  	}
   142  
   143  	// Cancel scripts we no longer want
   144  	for id := range oldScriptChecks {
   145  		if _, ok := h.scripts[id]; !ok {
   146  			if oldScript, running := h.runningScripts[id]; running {
   147  				oldScript.cancel()
   148  			}
   149  		}
   150  	}
   151  	return nil
   152  }
   153  
   154  // Stop implements interfaces.TaskStopHook and blocks waiting for running
   155  // scripts to finish (or for the shutdownWait timeout to expire).
   156  func (h *scriptCheckHook) Stop(ctx context.Context, req *interfaces.TaskStopRequest, resp *interfaces.TaskStopResponse) error {
   157  	h.mu.Lock()
   158  	defer h.mu.Unlock()
   159  	close(h.shutdownCh)
   160  	deadline := time.After(h.shutdownWait)
   161  	err := fmt.Errorf("timed out waiting for script checks to exit")
   162  	for _, script := range h.runningScripts {
   163  		select {
   164  		case <-script.wait():
   165  		case <-ctx.Done():
   166  			// the caller is passing the background context, so
   167  			// we should never really see this outside of testing
   168  		case <-deadline:
   169  			// at this point the Consul client has been cleaned
   170  			// up so we don't want to hang onto this.
   171  			return err
   172  		}
   173  	}
   174  	return nil
   175  }
   176  
   177  func (h *scriptCheckHook) newScriptChecks() map[string]*scriptCheck {
   178  	scriptChecks := make(map[string]*scriptCheck)
   179  	interpolatedTaskServices := taskenv.InterpolateServices(h.taskEnv, h.task.Services)
   180  	for _, service := range interpolatedTaskServices {
   181  		for _, check := range service.Checks {
   182  			if check.Type != structs.ServiceCheckScript {
   183  				continue
   184  			}
   185  			serviceID := serviceregistration.MakeAllocServiceID(
   186  				h.alloc.ID, h.task.Name, service)
   187  			sc := newScriptCheck(&scriptCheckConfig{
   188  				consulNamespace: h.consulNamespace,
   189  				allocID:         h.alloc.ID,
   190  				taskName:        h.task.Name,
   191  				check:           check,
   192  				serviceID:       serviceID,
   193  				ttlUpdater:      h.consul,
   194  				driverExec:      h.driverExec,
   195  				taskEnv:         h.taskEnv,
   196  				logger:          h.logger,
   197  				shutdownCh:      h.shutdownCh,
   198  			})
   199  			if sc != nil {
   200  				scriptChecks[sc.id] = sc
   201  			}
   202  		}
   203  	}
   204  
   205  	// Walk back through the task group to see if there are script checks
   206  	// associated with the task. If so, we'll create scriptCheck tasklets
   207  	// for them. The group-level service and any check restart behaviors it
   208  	// needs are entirely encapsulated within the group service hook which
   209  	// watches Consul for status changes.
   210  	//
   211  	// The script check is associated with a group task if the service.task or
   212  	// service.check.task matches the task name. The service.check.task takes
   213  	// precedence.
   214  	tg := h.alloc.Job.LookupTaskGroup(h.alloc.TaskGroup)
   215  	interpolatedGroupServices := taskenv.InterpolateServices(h.taskEnv, tg.Services)
   216  	for _, service := range interpolatedGroupServices {
   217  		for _, check := range service.Checks {
   218  			if check.Type != structs.ServiceCheckScript {
   219  				continue
   220  			}
   221  			if !h.associated(h.task.Name, service.TaskName, check.TaskName) {
   222  				continue
   223  			}
   224  			groupTaskName := "group-" + tg.Name
   225  			serviceID := serviceregistration.MakeAllocServiceID(
   226  				h.alloc.ID, groupTaskName, service)
   227  			sc := newScriptCheck(&scriptCheckConfig{
   228  				consulNamespace: h.consulNamespace,
   229  				allocID:         h.alloc.ID,
   230  				taskName:        groupTaskName,
   231  				check:           check,
   232  				serviceID:       serviceID,
   233  				ttlUpdater:      h.consul,
   234  				driverExec:      h.driverExec,
   235  				taskEnv:         h.taskEnv,
   236  				logger:          h.logger,
   237  				shutdownCh:      h.shutdownCh,
   238  				isGroup:         true,
   239  			})
   240  			if sc != nil {
   241  				scriptChecks[sc.id] = sc
   242  			}
   243  		}
   244  	}
   245  	return scriptChecks
   246  }
   247  
   248  // associated returns true if the script check is associated with the task. This
   249  // would be the case if the check.task is the same as task, or if the service.task
   250  // is the same as the task _and_ check.task is not configured (i.e. the check
   251  // inherits the task of the service).
   252  func (*scriptCheckHook) associated(task, serviceTask, checkTask string) bool {
   253  	if checkTask == task {
   254  		return true
   255  	}
   256  	if serviceTask == task && checkTask == "" {
   257  		return true
   258  	}
   259  	return false
   260  }
   261  
   262  // TTLUpdater is the subset of consul agent functionality needed by script
   263  // checks to heartbeat
   264  type TTLUpdater interface {
   265  	UpdateTTL(id, namespace, output, status string) error
   266  }
   267  
   268  // scriptCheck runs script checks via a interfaces.ScriptExecutor and updates the
   269  // appropriate check's TTL when the script succeeds.
   270  type scriptCheck struct {
   271  	id              string
   272  	consulNamespace string
   273  	ttlUpdater      TTLUpdater
   274  	check           *structs.ServiceCheck
   275  	lastCheckOk     bool // true if the last check was ok; otherwise false
   276  	tasklet
   277  }
   278  
   279  // scriptCheckConfig is a parameter struct for newScriptCheck
   280  type scriptCheckConfig struct {
   281  	allocID         string
   282  	taskName        string
   283  	serviceID       string
   284  	consulNamespace string
   285  	check           *structs.ServiceCheck
   286  	ttlUpdater      TTLUpdater
   287  	driverExec      tinterfaces.ScriptExecutor
   288  	taskEnv         *taskenv.TaskEnv
   289  	logger          log.Logger
   290  	shutdownCh      chan struct{}
   291  	isGroup         bool
   292  }
   293  
   294  // newScriptCheck constructs a scriptCheck. we're only going to
   295  // configure the immutable fields of scriptCheck here, with the
   296  // rest being configured during the Poststart hook so that we have
   297  // the rest of the task execution environment
   298  func newScriptCheck(config *scriptCheckConfig) *scriptCheck {
   299  
   300  	// Guard against not having a valid taskEnv. This can be the case if the
   301  	// PreKilling or Exited hook is run before Poststart.
   302  	if config.taskEnv == nil || config.driverExec == nil {
   303  		return nil
   304  	}
   305  
   306  	orig := config.check
   307  	sc := &scriptCheck{
   308  		ttlUpdater:  config.ttlUpdater,
   309  		check:       config.check.Copy(),
   310  		lastCheckOk: true, // start logging on first failure
   311  	}
   312  
   313  	// we can't use the promoted fields of tasklet in the struct literal
   314  	sc.Command = config.taskEnv.ReplaceEnv(config.check.Command)
   315  	sc.Args = config.taskEnv.ParseAndReplace(config.check.Args)
   316  	sc.Interval = config.check.Interval
   317  	sc.Timeout = config.check.Timeout
   318  	sc.exec = config.driverExec
   319  	sc.callback = newScriptCheckCallback(sc)
   320  	sc.logger = config.logger
   321  	sc.shutdownCh = config.shutdownCh
   322  	sc.check.Command = sc.Command
   323  	sc.check.Args = sc.Args
   324  
   325  	if config.isGroup {
   326  		// group services don't have access to a task environment
   327  		// at creation, so their checks get registered before the
   328  		// check can be interpolated here. if we don't use the
   329  		// original checkID, they can't be updated.
   330  		sc.id = agentconsul.MakeCheckID(config.serviceID, orig)
   331  	} else {
   332  		sc.id = agentconsul.MakeCheckID(config.serviceID, sc.check)
   333  	}
   334  	sc.consulNamespace = config.consulNamespace
   335  	return sc
   336  }
   337  
   338  // Copy does a *shallow* copy of script checks.
   339  func (sc *scriptCheck) Copy() *scriptCheck {
   340  	newSc := sc
   341  	return newSc
   342  }
   343  
   344  // closes over the script check and returns the taskletCallback for
   345  // when the script check executes.
   346  func newScriptCheckCallback(s *scriptCheck) taskletCallback {
   347  
   348  	return func(ctx context.Context, params execResult) {
   349  		output := params.output
   350  		code := params.code
   351  		err := params.err
   352  
   353  		state := api.HealthCritical
   354  		switch code {
   355  		case 0:
   356  			state = api.HealthPassing
   357  		case 1:
   358  			state = api.HealthWarning
   359  		}
   360  
   361  		var outputMsg string
   362  		if err != nil {
   363  			state = api.HealthCritical
   364  			outputMsg = err.Error()
   365  		} else {
   366  			outputMsg = string(output)
   367  		}
   368  
   369  		// heartbeat the check to Consul
   370  		err = s.updateTTL(ctx, outputMsg, state)
   371  		select {
   372  		case <-ctx.Done():
   373  			// check has been removed; don't report errors
   374  			return
   375  		default:
   376  		}
   377  
   378  		if err != nil {
   379  			if s.lastCheckOk {
   380  				s.lastCheckOk = false
   381  				s.logger.Warn("updating check failed", "error", err)
   382  			} else {
   383  				s.logger.Debug("updating check still failing", "error", err)
   384  			}
   385  
   386  		} else if !s.lastCheckOk {
   387  			// Succeeded for the first time or after failing; log
   388  			s.lastCheckOk = true
   389  			s.logger.Info("updating check succeeded")
   390  		}
   391  	}
   392  }
   393  
   394  const (
   395  	updateTTLBackoffBaseline = 1 * time.Second
   396  	updateTTLBackoffLimit    = 3 * time.Second
   397  )
   398  
   399  // updateTTL updates the state to Consul, performing an exponential backoff
   400  // in the case where the check isn't registered in Consul to avoid a race between
   401  // service registration and the first check.
   402  func (sc *scriptCheck) updateTTL(ctx context.Context, msg, state string) error {
   403  	for attempts := 0; ; attempts++ {
   404  		err := sc.ttlUpdater.UpdateTTL(sc.id, sc.consulNamespace, msg, state)
   405  		if err == nil {
   406  			return nil
   407  		}
   408  
   409  		// Handle the retry case
   410  		backoff := (1 << (2 * uint64(attempts))) * updateTTLBackoffBaseline
   411  		if backoff > updateTTLBackoffLimit {
   412  			return err
   413  		}
   414  
   415  		// Wait till retrying
   416  		select {
   417  		case <-ctx.Done():
   418  			return err
   419  		case <-time.After(backoff):
   420  		}
   421  	}
   422  }