github.com/iqoqo/nomad@v0.11.3-0.20200911112621-d7021c74d101/client/allocrunner/taskrunner/sids_hook.go (about)

     1  package taskrunner
     2  
     3  import (
     4  	"context"
     5  	"io/ioutil"
     6  	"os"
     7  	"path/filepath"
     8  	"sync"
     9  	"time"
    10  
    11  	"github.com/hashicorp/go-hclog"
    12  	"github.com/hashicorp/nomad/client/allocrunner/interfaces"
    13  	ti "github.com/hashicorp/nomad/client/allocrunner/taskrunner/interfaces"
    14  	"github.com/hashicorp/nomad/client/consul"
    15  	"github.com/hashicorp/nomad/nomad/structs"
    16  	"github.com/pkg/errors"
    17  )
    18  
    19  const (
    20  	// the name of this hook, used in logs
    21  	sidsHookName = "consul_si_token"
    22  
    23  	// sidsBackoffBaseline is the baseline time for exponential backoff when
    24  	// attempting to retrieve a Consul SI token
    25  	sidsBackoffBaseline = 5 * time.Second
    26  
    27  	// sidsBackoffLimit is the limit of the exponential backoff when attempting
    28  	// to retrieve a Consul SI token
    29  	sidsBackoffLimit = 3 * time.Minute
    30  
    31  	// sidsDerivationTimeout limits the amount of time we may spend trying to
    32  	// derive a SI token. If the hook does not get a token within this amount of
    33  	// time, the result is a failure.
    34  	sidsDerivationTimeout = 5 * time.Minute
    35  
    36  	// sidsTokenFile is the name of the file holding the Consul SI token inside
    37  	// the task's secret directory
    38  	sidsTokenFile = "si_token"
    39  
    40  	// sidsTokenFilePerms is the level of file permissions granted on the file
    41  	// in the secrets directory for the task
    42  	sidsTokenFilePerms = 0440
    43  )
    44  
    45  type sidsHookConfig struct {
    46  	alloc      *structs.Allocation
    47  	task       *structs.Task
    48  	sidsClient consul.ServiceIdentityAPI
    49  	lifecycle  ti.TaskLifecycle
    50  	logger     hclog.Logger
    51  }
    52  
    53  // Service Identities hook for managing SI tokens of connect enabled tasks.
    54  type sidsHook struct {
    55  	// alloc is the allocation
    56  	alloc *structs.Allocation
    57  
    58  	// taskName is the name of the task
    59  	task *structs.Task
    60  
    61  	// sidsClient is the Consul client [proxy] for requesting SI tokens
    62  	sidsClient consul.ServiceIdentityAPI
    63  
    64  	// lifecycle is used to signal, restart, and kill a task
    65  	lifecycle ti.TaskLifecycle
    66  
    67  	// derivationTimeout is the amount of time we may wait for Consul to successfully
    68  	// provide a SI token. Making this configurable for testing, otherwise
    69  	// default to sidsDerivationTimeout
    70  	derivationTimeout time.Duration
    71  
    72  	// logger is used to log
    73  	logger hclog.Logger
    74  
    75  	// lock variables that can be manipulated after hook creation
    76  	lock sync.Mutex
    77  	// firstRun keeps track of whether the hook is being called for the first
    78  	// time (for this task) during the lifespan of the Nomad Client process.
    79  	firstRun bool
    80  }
    81  
    82  func newSIDSHook(c sidsHookConfig) *sidsHook {
    83  	return &sidsHook{
    84  		alloc:             c.alloc,
    85  		task:              c.task,
    86  		sidsClient:        c.sidsClient,
    87  		lifecycle:         c.lifecycle,
    88  		derivationTimeout: sidsDerivationTimeout,
    89  		logger:            c.logger.Named(sidsHookName),
    90  		firstRun:          true,
    91  	}
    92  }
    93  
    94  func (h *sidsHook) Name() string {
    95  	return sidsHookName
    96  }
    97  
    98  func (h *sidsHook) Prestart(
    99  	ctx context.Context,
   100  	req *interfaces.TaskPrestartRequest,
   101  	resp *interfaces.TaskPrestartResponse) error {
   102  
   103  	h.lock.Lock()
   104  	defer h.lock.Unlock()
   105  
   106  	// do nothing if we have already done things
   107  	if h.earlyExit() {
   108  		resp.Done = true
   109  		return nil
   110  	}
   111  
   112  	// optimistically try to recover token from disk
   113  	token, err := h.recoverToken(req.TaskDir.SecretsDir)
   114  	if err != nil {
   115  		return err
   116  	}
   117  
   118  	// need to ask for a new SI token & persist it to disk
   119  	if token == "" {
   120  		if token, err = h.deriveSIToken(ctx); err != nil {
   121  			return err
   122  		}
   123  		if err := h.writeToken(req.TaskDir.SecretsDir, token); err != nil {
   124  			return err
   125  		}
   126  	}
   127  
   128  	h.logger.Info("derived SI token", "task", h.task.Name, "si_task", h.task.Kind.Value())
   129  
   130  	resp.Done = true
   131  	return nil
   132  }
   133  
   134  // earlyExit returns true if the Prestart hook has already been executed during
   135  // the instantiation of this task runner.
   136  //
   137  // assumes h is locked
   138  func (h *sidsHook) earlyExit() bool {
   139  	if h.firstRun {
   140  		h.firstRun = false
   141  		return false
   142  	}
   143  	return true
   144  }
   145  
   146  // writeToken writes token into the secrets directory for the task.
   147  func (h *sidsHook) writeToken(dir string, token string) error {
   148  	tokenPath := filepath.Join(dir, sidsTokenFile)
   149  	if err := ioutil.WriteFile(tokenPath, []byte(token), sidsTokenFilePerms); err != nil {
   150  		return errors.Wrap(err, "failed to write SI token")
   151  	}
   152  	return nil
   153  }
   154  
   155  // recoverToken returns the token saved to disk in the secrets directory for the
   156  // task if it exists, or the empty string if the file does not exist. an error
   157  // is returned only for some other (e.g. disk IO) error.
   158  func (h *sidsHook) recoverToken(dir string) (string, error) {
   159  	tokenPath := filepath.Join(dir, sidsTokenFile)
   160  	token, err := ioutil.ReadFile(tokenPath)
   161  	if err != nil {
   162  		if !os.IsNotExist(err) {
   163  			h.logger.Error("failed to recover SI token", "error", err)
   164  			return "", errors.Wrap(err, "failed to recover SI token")
   165  		}
   166  		h.logger.Trace("no pre-existing SI token to recover", "task", h.task.Name)
   167  		return "", nil // token file does not exist yet
   168  	}
   169  	h.logger.Trace("recovered pre-existing SI token", "task", h.task.Name)
   170  	return string(token), nil
   171  }
   172  
   173  // siDerivationResult is used to pass along the result of attempting to derive
   174  // an SI token between the goroutine doing the derivation and its caller
   175  type siDerivationResult struct {
   176  	token string
   177  	err   error
   178  }
   179  
   180  // deriveSIToken spawns and waits on a goroutine which will make attempts to
   181  // derive an SI token until a token is successfully created, or ctx is signaled
   182  // done.
   183  func (h *sidsHook) deriveSIToken(ctx context.Context) (string, error) {
   184  	ctx, cancel := context.WithTimeout(ctx, h.derivationTimeout)
   185  	defer cancel()
   186  
   187  	resultCh := make(chan siDerivationResult)
   188  
   189  	// keep trying to get the token in the background
   190  	go h.tryDerive(ctx, resultCh)
   191  
   192  	// wait until we get a token, or we get a signal to quit
   193  	for {
   194  		select {
   195  		case result := <-resultCh:
   196  			if result.err != nil {
   197  				h.logger.Error("failed to derive SI token", "error", result.err)
   198  				h.kill(ctx, errors.Wrap(result.err, "failed to derive SI token"))
   199  				return "", result.err
   200  			}
   201  			return result.token, nil
   202  		case <-ctx.Done():
   203  			return "", ctx.Err()
   204  		}
   205  	}
   206  }
   207  
   208  func (h *sidsHook) kill(ctx context.Context, reason error) {
   209  	if err := h.lifecycle.Kill(ctx,
   210  		structs.NewTaskEvent(structs.TaskKilling).
   211  			SetFailsTask().
   212  			SetDisplayMessage(reason.Error()),
   213  	); err != nil {
   214  		h.logger.Error("failed to kill task", "kill_reason", reason, "error", err)
   215  	}
   216  }
   217  
   218  // tryDerive loops forever until a token is created, or ctx is done.
   219  func (h *sidsHook) tryDerive(ctx context.Context, ch chan<- siDerivationResult) {
   220  	for attempt := 0; backoff(ctx, attempt); attempt++ {
   221  
   222  		tokens, err := h.sidsClient.DeriveSITokens(h.alloc, []string{h.task.Name})
   223  
   224  		switch {
   225  		case err == nil:
   226  			token, exists := tokens[h.task.Name]
   227  			if !exists {
   228  				err := errors.New("response does not include token for task")
   229  				h.logger.Error("derive SI token is missing token for task", "error", err, "task", h.task.Name)
   230  				ch <- siDerivationResult{token: "", err: err}
   231  				return
   232  			}
   233  			ch <- siDerivationResult{token: token, err: nil}
   234  			return
   235  		case structs.IsServerSide(err):
   236  			// the error is known to be a server problem, just die
   237  			h.logger.Error("failed to derive SI token", "error", err, "task", h.task.Name, "server_side", true)
   238  			ch <- siDerivationResult{token: "", err: err}
   239  			return
   240  		case !structs.IsRecoverable(err):
   241  			// the error is known not to be recoverable, just die
   242  			h.logger.Error("failed to derive SI token", "error", err, "task", h.task.Name, "recoverable", false)
   243  			ch <- siDerivationResult{token: "", err: err}
   244  			return
   245  
   246  		default:
   247  			// the error is marked recoverable, retry after some backoff
   248  			h.logger.Error("failed attempt to derive SI token", "error", err, "recoverable", true)
   249  		}
   250  	}
   251  }
   252  
   253  func backoff(ctx context.Context, attempt int) bool {
   254  	next := computeBackoff(attempt)
   255  	select {
   256  	case <-ctx.Done():
   257  		return false
   258  	case <-time.After(next):
   259  		return true
   260  	}
   261  }
   262  
   263  func computeBackoff(attempt int) time.Duration {
   264  	switch attempt {
   265  	case 0:
   266  		return 0
   267  	case 1:
   268  		// go fast on first retry, because a unit test should be fast
   269  		return 100 * time.Millisecond
   270  	default:
   271  		wait := time.Duration(attempt) * sidsBackoffBaseline
   272  		if wait > sidsBackoffLimit {
   273  			wait = sidsBackoffLimit
   274  		}
   275  		return wait
   276  	}
   277  }