github.com/anth0d/nomad@v0.0.0-20221214183521-ae3a0a2cad06/client/allocrunner/taskrunner/vault_hook.go (about)

     1  package taskrunner
     2  
     3  import (
     4  	"context"
     5  	"fmt"
     6  	"io/ioutil"
     7  	"os"
     8  	"path/filepath"
     9  	"sync"
    10  	"time"
    11  
    12  	"github.com/hashicorp/consul-template/signals"
    13  	log "github.com/hashicorp/go-hclog"
    14  
    15  	"github.com/hashicorp/nomad/client/allocrunner/interfaces"
    16  	ti "github.com/hashicorp/nomad/client/allocrunner/taskrunner/interfaces"
    17  	"github.com/hashicorp/nomad/client/vaultclient"
    18  	"github.com/hashicorp/nomad/nomad/structs"
    19  )
    20  
    21  const (
    22  	// vaultBackoffBaseline is the baseline time for exponential backoff when
    23  	// attempting to retrieve a Vault token
    24  	vaultBackoffBaseline = 5 * time.Second
    25  
    26  	// vaultBackoffLimit is the limit of the exponential backoff when attempting
    27  	// to retrieve a Vault token
    28  	vaultBackoffLimit = 3 * time.Minute
    29  
    30  	// vaultTokenFile is the name of the file holding the Vault token inside the
    31  	// task's secret directory
    32  	vaultTokenFile = "vault_token"
    33  )
    34  
    35  type vaultTokenUpdateHandler interface {
    36  	updatedVaultToken(token string)
    37  }
    38  
    39  func (tr *TaskRunner) updatedVaultToken(token string) {
    40  	// Update the task runner and environment
    41  	tr.setVaultToken(token)
    42  
    43  	// Trigger update hooks with the new Vault token
    44  	tr.triggerUpdateHooks()
    45  }
    46  
    47  type vaultHookConfig struct {
    48  	vaultStanza *structs.Vault
    49  	client      vaultclient.VaultClient
    50  	events      ti.EventEmitter
    51  	lifecycle   ti.TaskLifecycle
    52  	updater     vaultTokenUpdateHandler
    53  	logger      log.Logger
    54  	alloc       *structs.Allocation
    55  	task        string
    56  }
    57  
    58  type vaultHook struct {
    59  	// vaultStanza is the vault stanza for the task
    60  	vaultStanza *structs.Vault
    61  
    62  	// eventEmitter is used to emit events to the task
    63  	eventEmitter ti.EventEmitter
    64  
    65  	// lifecycle is used to signal, restart and kill a task
    66  	lifecycle ti.TaskLifecycle
    67  
    68  	// updater is used to update the Vault token
    69  	updater vaultTokenUpdateHandler
    70  
    71  	// client is the Vault client to retrieve and renew the Vault token
    72  	client vaultclient.VaultClient
    73  
    74  	// logger is used to log
    75  	logger log.Logger
    76  
    77  	// ctx and cancel are used to kill the long running token manager
    78  	ctx    context.Context
    79  	cancel context.CancelFunc
    80  
    81  	// tokenPath is the path in which to read and write the token
    82  	tokenPath string
    83  
    84  	// alloc is the allocation
    85  	alloc *structs.Allocation
    86  
    87  	// taskName is the name of the task
    88  	taskName string
    89  
    90  	// firstRun stores whether it is the first run for the hook
    91  	firstRun bool
    92  
    93  	// future is used to wait on retrieving a Vault token
    94  	future *tokenFuture
    95  }
    96  
    97  func newVaultHook(config *vaultHookConfig) *vaultHook {
    98  	ctx, cancel := context.WithCancel(context.Background())
    99  	h := &vaultHook{
   100  		vaultStanza:  config.vaultStanza,
   101  		client:       config.client,
   102  		eventEmitter: config.events,
   103  		lifecycle:    config.lifecycle,
   104  		updater:      config.updater,
   105  		alloc:        config.alloc,
   106  		taskName:     config.task,
   107  		firstRun:     true,
   108  		ctx:          ctx,
   109  		cancel:       cancel,
   110  		future:       newTokenFuture(),
   111  	}
   112  	h.logger = config.logger.Named(h.Name())
   113  	return h
   114  }
   115  
   116  func (*vaultHook) Name() string {
   117  	return "vault"
   118  }
   119  
   120  func (h *vaultHook) Prestart(ctx context.Context, req *interfaces.TaskPrestartRequest, resp *interfaces.TaskPrestartResponse) error {
   121  	// If we have already run prestart before exit early. We do not use the
   122  	// PrestartDone value because we want to recover the token on restoration.
   123  	first := h.firstRun
   124  	h.firstRun = false
   125  	if !first {
   126  		return nil
   127  	}
   128  
   129  	// Try to recover a token if it was previously written in the secrets
   130  	// directory
   131  	recoveredToken := ""
   132  	h.tokenPath = filepath.Join(req.TaskDir.SecretsDir, vaultTokenFile)
   133  	data, err := ioutil.ReadFile(h.tokenPath)
   134  	if err != nil {
   135  		if !os.IsNotExist(err) {
   136  			return fmt.Errorf("failed to recover vault token: %v", err)
   137  		}
   138  
   139  		// Token file doesn't exist
   140  	} else {
   141  		// Store the recovered token
   142  		recoveredToken = string(data)
   143  	}
   144  
   145  	// Launch the token manager
   146  	go h.run(recoveredToken)
   147  
   148  	// Block until we get a token
   149  	select {
   150  	case <-h.future.Wait():
   151  	case <-ctx.Done():
   152  		return nil
   153  	}
   154  
   155  	h.updater.updatedVaultToken(h.future.Get())
   156  	return nil
   157  }
   158  
   159  func (h *vaultHook) Stop(ctx context.Context, req *interfaces.TaskStopRequest, resp *interfaces.TaskStopResponse) error {
   160  	// Shutdown any created manager
   161  	h.cancel()
   162  	return nil
   163  }
   164  
   165  func (h *vaultHook) Shutdown() {
   166  	h.cancel()
   167  }
   168  
   169  // run should be called in a go-routine and manages the derivation, renewal and
   170  // handling of errors with the Vault token. The optional parameter allows
   171  // setting the initial Vault token. This is useful when the Vault token is
   172  // recovered off disk.
   173  func (h *vaultHook) run(token string) {
   174  	// Helper for stopping token renewal
   175  	stopRenewal := func() {
   176  		if err := h.client.StopRenewToken(h.future.Get()); err != nil {
   177  			h.logger.Warn("failed to stop token renewal", "error", err)
   178  		}
   179  	}
   180  
   181  	// updatedToken lets us store state between loops. If true, a new token
   182  	// has been retrieved and we need to apply the Vault change mode
   183  	var updatedToken bool
   184  
   185  OUTER:
   186  	for {
   187  		// Check if we should exit
   188  		if h.ctx.Err() != nil {
   189  			stopRenewal()
   190  			return
   191  		}
   192  
   193  		// Clear the token
   194  		h.future.Clear()
   195  
   196  		// Check if there already is a token which can be the case for
   197  		// restoring the TaskRunner
   198  		if token == "" {
   199  			// Get a token
   200  			var exit bool
   201  			token, exit = h.deriveVaultToken()
   202  			if exit {
   203  				// Exit the manager
   204  				return
   205  			}
   206  
   207  			// Write the token to disk
   208  			if err := h.writeToken(token); err != nil {
   209  				errorString := "failed to write Vault token to disk"
   210  				h.logger.Error(errorString, "error", err)
   211  				h.lifecycle.Kill(h.ctx,
   212  					structs.NewTaskEvent(structs.TaskKilling).
   213  						SetFailsTask().
   214  						SetDisplayMessage(fmt.Sprintf("Vault %v", errorString)))
   215  				return
   216  			}
   217  		}
   218  
   219  		// Start the renewal process.
   220  		//
   221  		// This is the initial renew of the token which we derived from the
   222  		// server. The client does not know how long it took for the token to
   223  		// be generated and derived and also wants to gain control of the
   224  		// process quickly, but not too quickly. We therefore use a hardcoded
   225  		// increment value of 30; this value without a suffix is in seconds.
   226  		//
   227  		// If Vault is having availability issues or is overloaded, a large
   228  		// number of initial token renews can exacerbate the problem.
   229  		renewCh, err := h.client.RenewToken(token, 30)
   230  
   231  		// An error returned means the token is not being renewed
   232  		if err != nil {
   233  			h.logger.Error("failed to start renewal of Vault token", "error", err)
   234  			token = ""
   235  			goto OUTER
   236  		}
   237  
   238  		// The Vault token is valid now, so set it
   239  		h.future.Set(token)
   240  
   241  		if updatedToken {
   242  			switch h.vaultStanza.ChangeMode {
   243  			case structs.VaultChangeModeSignal:
   244  				s, err := signals.Parse(h.vaultStanza.ChangeSignal)
   245  				if err != nil {
   246  					h.logger.Error("failed to parse signal", "error", err)
   247  					h.lifecycle.Kill(h.ctx,
   248  						structs.NewTaskEvent(structs.TaskKilling).
   249  							SetFailsTask().
   250  							SetDisplayMessage(fmt.Sprintf("Vault: failed to parse signal: %v", err)))
   251  					return
   252  				}
   253  
   254  				event := structs.NewTaskEvent(structs.TaskSignaling).SetTaskSignal(s).SetDisplayMessage("Vault: new Vault token acquired")
   255  				if err := h.lifecycle.Signal(event, h.vaultStanza.ChangeSignal); err != nil {
   256  					h.logger.Error("failed to send signal", "error", err)
   257  					h.lifecycle.Kill(h.ctx,
   258  						structs.NewTaskEvent(structs.TaskKilling).
   259  							SetFailsTask().
   260  							SetDisplayMessage(fmt.Sprintf("Vault: failed to send signal: %v", err)))
   261  					return
   262  				}
   263  			case structs.VaultChangeModeRestart:
   264  				const noFailure = false
   265  				h.lifecycle.Restart(h.ctx,
   266  					structs.NewTaskEvent(structs.TaskRestartSignal).
   267  						SetDisplayMessage("Vault: new Vault token acquired"), false)
   268  			case structs.VaultChangeModeNoop:
   269  				fallthrough
   270  			default:
   271  				h.logger.Error("invalid Vault change mode", "mode", h.vaultStanza.ChangeMode)
   272  			}
   273  
   274  			// We have handled it
   275  			updatedToken = false
   276  
   277  			// Call the handler
   278  			h.updater.updatedVaultToken(token)
   279  		}
   280  
   281  		// Start watching for renewal errors
   282  		select {
   283  		case err := <-renewCh:
   284  			// Clear the token
   285  			token = ""
   286  			h.logger.Error("failed to renew Vault token", "error", err)
   287  			stopRenewal()
   288  			updatedToken = true
   289  		case <-h.ctx.Done():
   290  			stopRenewal()
   291  			return
   292  		}
   293  	}
   294  }
   295  
   296  // deriveVaultToken derives the Vault token using exponential backoffs. It
   297  // returns the Vault token and whether the manager should exit.
   298  func (h *vaultHook) deriveVaultToken() (token string, exit bool) {
   299  	attempts := 0
   300  	for {
   301  		tokens, err := h.client.DeriveToken(h.alloc, []string{h.taskName})
   302  		if err == nil {
   303  			return tokens[h.taskName], false
   304  		}
   305  
   306  		// Check if this is a server side error
   307  		if structs.IsServerSide(err) {
   308  			h.logger.Error("failed to derive Vault token", "error", err, "server_side", true)
   309  			h.lifecycle.Kill(h.ctx,
   310  				structs.NewTaskEvent(structs.TaskKilling).
   311  					SetFailsTask().
   312  					SetDisplayMessage(fmt.Sprintf("Vault: server failed to derive vault token: %v", err)))
   313  			return "", true
   314  		}
   315  
   316  		// Check if we can't recover from the error
   317  		if !structs.IsRecoverable(err) {
   318  			h.logger.Error("failed to derive Vault token", "error", err, "recoverable", false)
   319  			h.lifecycle.Kill(h.ctx,
   320  				structs.NewTaskEvent(structs.TaskKilling).
   321  					SetFailsTask().
   322  					SetDisplayMessage(fmt.Sprintf("Vault: failed to derive vault token: %v", err)))
   323  			return "", true
   324  		}
   325  
   326  		// Handle the retry case
   327  		backoff := (1 << (2 * uint64(attempts))) * vaultBackoffBaseline
   328  		if backoff > vaultBackoffLimit {
   329  			backoff = vaultBackoffLimit
   330  		}
   331  		h.logger.Error("failed to derive Vault token", "error", err, "recoverable", true, "backoff", backoff)
   332  
   333  		attempts++
   334  
   335  		// Wait till retrying
   336  		select {
   337  		case <-h.ctx.Done():
   338  			return "", true
   339  		case <-time.After(backoff):
   340  		}
   341  	}
   342  }
   343  
   344  // writeToken writes the given token to disk
   345  func (h *vaultHook) writeToken(token string) error {
   346  	if err := ioutil.WriteFile(h.tokenPath, []byte(token), 0666); err != nil {
   347  		return fmt.Errorf("failed to write vault token: %v", err)
   348  	}
   349  
   350  	return nil
   351  }
   352  
   353  // tokenFuture stores the Vault token and allows consumers to block till a valid
   354  // token exists
   355  type tokenFuture struct {
   356  	waiting []chan struct{}
   357  	token   string
   358  	set     bool
   359  	m       sync.Mutex
   360  }
   361  
   362  // newTokenFuture returns a new token future without any token set
   363  func newTokenFuture() *tokenFuture {
   364  	return &tokenFuture{}
   365  }
   366  
   367  // Wait returns a channel that can be waited on. When this channel unblocks, a
   368  // valid token will be available via the Get method
   369  func (f *tokenFuture) Wait() <-chan struct{} {
   370  	f.m.Lock()
   371  	defer f.m.Unlock()
   372  
   373  	c := make(chan struct{})
   374  	if f.set {
   375  		close(c)
   376  		return c
   377  	}
   378  
   379  	f.waiting = append(f.waiting, c)
   380  	return c
   381  }
   382  
   383  // Set sets the token value and unblocks any caller of Wait
   384  func (f *tokenFuture) Set(token string) *tokenFuture {
   385  	f.m.Lock()
   386  	defer f.m.Unlock()
   387  
   388  	f.set = true
   389  	f.token = token
   390  	for _, w := range f.waiting {
   391  		close(w)
   392  	}
   393  	f.waiting = nil
   394  	return f
   395  }
   396  
   397  // Clear clears the set vault token.
   398  func (f *tokenFuture) Clear() *tokenFuture {
   399  	f.m.Lock()
   400  	defer f.m.Unlock()
   401  
   402  	f.token = ""
   403  	f.set = false
   404  	return f
   405  }
   406  
   407  // Get returns the set Vault token
   408  func (f *tokenFuture) Get() string {
   409  	f.m.Lock()
   410  	defer f.m.Unlock()
   411  	return f.token
   412  }