github.com/anth0d/nomad@v0.0.0-20221214183521-ae3a0a2cad06/client/allocrunner/taskrunner/sids_hook.go (about)

     1  package taskrunner
     2  
     3  import (
     4  	"context"
     5  	"errors"
     6  	"fmt"
     7  	"io/ioutil"
     8  	"os"
     9  	"path/filepath"
    10  	"sync"
    11  	"time"
    12  
    13  	"github.com/hashicorp/go-hclog"
    14  	"github.com/hashicorp/nomad/client/allocrunner/interfaces"
    15  	ti "github.com/hashicorp/nomad/client/allocrunner/taskrunner/interfaces"
    16  	"github.com/hashicorp/nomad/client/consul"
    17  	"github.com/hashicorp/nomad/nomad/structs"
    18  )
    19  
    20  const (
    21  	// the name of this hook, used in logs
    22  	sidsHookName = "consul_si_token"
    23  
    24  	// sidsBackoffBaseline is the baseline time for exponential backoff when
    25  	// attempting to retrieve a Consul SI token
    26  	sidsBackoffBaseline = 5 * time.Second
    27  
    28  	// sidsBackoffLimit is the limit of the exponential backoff when attempting
    29  	// to retrieve a Consul SI token
    30  	sidsBackoffLimit = 3 * time.Minute
    31  
    32  	// sidsDerivationTimeout limits the amount of time we may spend trying to
    33  	// derive a SI token. If the hook does not get a token within this amount of
    34  	// time, the result is a failure.
    35  	sidsDerivationTimeout = 5 * time.Minute
    36  
    37  	// sidsTokenFile is the name of the file holding the Consul SI token inside
    38  	// the task's secret directory
    39  	sidsTokenFile = "si_token"
    40  
    41  	// sidsTokenFilePerms is the level of file permissions granted on the file
    42  	// in the secrets directory for the task
    43  	sidsTokenFilePerms = 0440
    44  )
    45  
    46  type sidsHookConfig struct {
    47  	alloc      *structs.Allocation
    48  	task       *structs.Task
    49  	sidsClient consul.ServiceIdentityAPI
    50  	lifecycle  ti.TaskLifecycle
    51  	logger     hclog.Logger
    52  }
    53  
    54  // Service Identities hook for managing SI tokens of connect enabled tasks.
    55  type sidsHook struct {
    56  	// alloc is the allocation
    57  	alloc *structs.Allocation
    58  
    59  	// taskName is the name of the task
    60  	task *structs.Task
    61  
    62  	// sidsClient is the Consul client [proxy] for requesting SI tokens
    63  	sidsClient consul.ServiceIdentityAPI
    64  
    65  	// lifecycle is used to signal, restart, and kill a task
    66  	lifecycle ti.TaskLifecycle
    67  
    68  	// derivationTimeout is the amount of time we may wait for Consul to successfully
    69  	// provide a SI token. Making this configurable for testing, otherwise
    70  	// default to sidsDerivationTimeout
    71  	derivationTimeout time.Duration
    72  
    73  	// logger is used to log
    74  	logger hclog.Logger
    75  
    76  	// lock variables that can be manipulated after hook creation
    77  	lock sync.Mutex
    78  	// firstRun keeps track of whether the hook is being called for the first
    79  	// time (for this task) during the lifespan of the Nomad Client process.
    80  	firstRun bool
    81  }
    82  
    83  func newSIDSHook(c sidsHookConfig) *sidsHook {
    84  	return &sidsHook{
    85  		alloc:             c.alloc,
    86  		task:              c.task,
    87  		sidsClient:        c.sidsClient,
    88  		lifecycle:         c.lifecycle,
    89  		derivationTimeout: sidsDerivationTimeout,
    90  		logger:            c.logger.Named(sidsHookName),
    91  		firstRun:          true,
    92  	}
    93  }
    94  
    95  func (h *sidsHook) Name() string {
    96  	return sidsHookName
    97  }
    98  
    99  func (h *sidsHook) Prestart(
   100  	ctx context.Context,
   101  	req *interfaces.TaskPrestartRequest,
   102  	resp *interfaces.TaskPrestartResponse) error {
   103  
   104  	h.lock.Lock()
   105  	defer h.lock.Unlock()
   106  
   107  	// do nothing if we have already done things
   108  	if h.earlyExit() {
   109  		resp.Done = true
   110  		return nil
   111  	}
   112  
   113  	// optimistically try to recover token from disk
   114  	token, err := h.recoverToken(req.TaskDir.SecretsDir)
   115  	if err != nil {
   116  		return err
   117  	}
   118  
   119  	// need to ask for a new SI token & persist it to disk
   120  	if token == "" {
   121  		if token, err = h.deriveSIToken(ctx); err != nil {
   122  			return err
   123  		}
   124  		if err := h.writeToken(req.TaskDir.SecretsDir, token); err != nil {
   125  			return err
   126  		}
   127  	}
   128  
   129  	h.logger.Info("derived SI token", "task", h.task.Name, "si_task", h.task.Kind.Value())
   130  
   131  	resp.Done = true
   132  	return nil
   133  }
   134  
   135  // earlyExit returns true if the Prestart hook has already been executed during
   136  // the instantiation of this task runner.
   137  //
   138  // assumes h is locked
   139  func (h *sidsHook) earlyExit() bool {
   140  	if h.firstRun {
   141  		h.firstRun = false
   142  		return false
   143  	}
   144  	return true
   145  }
   146  
   147  // writeToken writes token into the secrets directory for the task.
   148  func (h *sidsHook) writeToken(dir string, token string) error {
   149  	tokenPath := filepath.Join(dir, sidsTokenFile)
   150  	if err := ioutil.WriteFile(tokenPath, []byte(token), sidsTokenFilePerms); err != nil {
   151  		return fmt.Errorf("failed to write SI token: %w", err)
   152  	}
   153  	return nil
   154  }
   155  
   156  // recoverToken returns the token saved to disk in the secrets directory for the
   157  // task if it exists, or the empty string if the file does not exist. an error
   158  // is returned only for some other (e.g. disk IO) error.
   159  func (h *sidsHook) recoverToken(dir string) (string, error) {
   160  	tokenPath := filepath.Join(dir, sidsTokenFile)
   161  	token, err := ioutil.ReadFile(tokenPath)
   162  	if err != nil {
   163  		if !os.IsNotExist(err) {
   164  			h.logger.Error("failed to recover SI token", "error", err)
   165  			return "", fmt.Errorf("failed to recover SI token: %w", err)
   166  		}
   167  		h.logger.Trace("no pre-existing SI token to recover", "task", h.task.Name)
   168  		return "", nil // token file does not exist yet
   169  	}
   170  	h.logger.Trace("recovered pre-existing SI token", "task", h.task.Name)
   171  	return string(token), nil
   172  }
   173  
   174  // siDerivationResult is used to pass along the result of attempting to derive
   175  // an SI token between the goroutine doing the derivation and its caller
   176  type siDerivationResult struct {
   177  	token string
   178  	err   error
   179  }
   180  
   181  // deriveSIToken spawns and waits on a goroutine which will make attempts to
   182  // derive an SI token until a token is successfully created, or ctx is signaled
   183  // done.
   184  func (h *sidsHook) deriveSIToken(ctx context.Context) (string, error) {
   185  	ctx, cancel := context.WithTimeout(ctx, h.derivationTimeout)
   186  	defer cancel()
   187  
   188  	resultCh := make(chan siDerivationResult)
   189  
   190  	// keep trying to get the token in the background
   191  	go h.tryDerive(ctx, resultCh)
   192  
   193  	// wait until we get a token, or we get a signal to quit
   194  	for {
   195  		select {
   196  		case result := <-resultCh:
   197  			if result.err != nil {
   198  				h.logger.Error("failed to derive SI token", "error", result.err)
   199  				h.kill(ctx, fmt.Errorf("failed to derive SI token: %w", result.err))
   200  				return "", result.err
   201  			}
   202  			return result.token, nil
   203  		case <-ctx.Done():
   204  			return "", ctx.Err()
   205  		}
   206  	}
   207  }
   208  
   209  func (h *sidsHook) kill(ctx context.Context, reason error) {
   210  	if err := h.lifecycle.Kill(ctx,
   211  		structs.NewTaskEvent(structs.TaskKilling).
   212  			SetFailsTask().
   213  			SetDisplayMessage(reason.Error()),
   214  	); err != nil {
   215  		h.logger.Error("failed to kill task", "kill_reason", reason, "error", err)
   216  	}
   217  }
   218  
   219  // tryDerive loops forever until a token is created, or ctx is done.
   220  func (h *sidsHook) tryDerive(ctx context.Context, ch chan<- siDerivationResult) {
   221  	for attempt := 0; backoff(ctx, attempt); attempt++ {
   222  
   223  		tokens, err := h.sidsClient.DeriveSITokens(h.alloc, []string{h.task.Name})
   224  
   225  		switch {
   226  		case err == nil:
   227  			token, exists := tokens[h.task.Name]
   228  			if !exists {
   229  				err := errors.New("response does not include token for task")
   230  				h.logger.Error("derive SI token is missing token for task", "error", err, "task", h.task.Name)
   231  				ch <- siDerivationResult{token: "", err: err}
   232  				return
   233  			}
   234  			ch <- siDerivationResult{token: token, err: nil}
   235  			return
   236  		case structs.IsServerSide(err):
   237  			// the error is known to be a server problem, just die
   238  			h.logger.Error("failed to derive SI token", "error", err, "task", h.task.Name, "server_side", true)
   239  			ch <- siDerivationResult{token: "", err: err}
   240  			return
   241  		case !structs.IsRecoverable(err):
   242  			// the error is known not to be recoverable, just die
   243  			h.logger.Error("failed to derive SI token", "error", err, "task", h.task.Name, "recoverable", false)
   244  			ch <- siDerivationResult{token: "", err: err}
   245  			return
   246  
   247  		default:
   248  			// the error is marked recoverable, retry after some backoff
   249  			h.logger.Error("failed attempt to derive SI token", "error", err, "recoverable", true)
   250  		}
   251  	}
   252  }
   253  
   254  func backoff(ctx context.Context, attempt int) bool {
   255  	next := computeBackoff(attempt)
   256  	select {
   257  	case <-ctx.Done():
   258  		return false
   259  	case <-time.After(next):
   260  		return true
   261  	}
   262  }
   263  
   264  func computeBackoff(attempt int) time.Duration {
   265  	switch attempt {
   266  	case 0:
   267  		return 0
   268  	case 1:
   269  		// go fast on first retry, because a unit test should be fast
   270  		return 100 * time.Millisecond
   271  	default:
   272  		wait := time.Duration(attempt) * sidsBackoffBaseline
   273  		if wait > sidsBackoffLimit {
   274  			wait = sidsBackoffLimit
   275  		}
   276  		return wait
   277  	}
   278  }