github.com/iqoqo/nomad@v0.11.3-0.20200911112621-d7021c74d101/client/allocrunner/taskrunner/vault_hook.go (about)

     1  package taskrunner
     2  
     3  import (
     4  	"context"
     5  	"fmt"
     6  	"io/ioutil"
     7  	"os"
     8  	"path/filepath"
     9  	"sync"
    10  	"time"
    11  
    12  	"github.com/hashicorp/consul-template/signals"
    13  	log "github.com/hashicorp/go-hclog"
    14  
    15  	"github.com/hashicorp/nomad/client/allocrunner/interfaces"
    16  	ti "github.com/hashicorp/nomad/client/allocrunner/taskrunner/interfaces"
    17  	"github.com/hashicorp/nomad/client/vaultclient"
    18  	"github.com/hashicorp/nomad/nomad/structs"
    19  )
    20  
    21  const (
    22  	// vaultBackoffBaseline is the baseline time for exponential backoff when
    23  	// attempting to retrieve a Vault token
    24  	vaultBackoffBaseline = 5 * time.Second
    25  
    26  	// vaultBackoffLimit is the limit of the exponential backoff when attempting
    27  	// to retrieve a Vault token
    28  	vaultBackoffLimit = 3 * time.Minute
    29  
    30  	// vaultTokenFile is the name of the file holding the Vault token inside the
    31  	// task's secret directory
    32  	vaultTokenFile = "vault_token"
    33  )
    34  
    35  type vaultTokenUpdateHandler interface {
    36  	updatedVaultToken(token string)
    37  }
    38  
    39  func (tr *TaskRunner) updatedVaultToken(token string) {
    40  	// Update the task runner and environment
    41  	tr.setVaultToken(token)
    42  
    43  	// Trigger update hooks with the new Vault token
    44  	tr.triggerUpdateHooks()
    45  }
    46  
    47  type vaultHookConfig struct {
    48  	vaultStanza *structs.Vault
    49  	client      vaultclient.VaultClient
    50  	events      ti.EventEmitter
    51  	lifecycle   ti.TaskLifecycle
    52  	updater     vaultTokenUpdateHandler
    53  	logger      log.Logger
    54  	alloc       *structs.Allocation
    55  	task        string
    56  }
    57  
    58  type vaultHook struct {
    59  	// vaultStanza is the vault stanza for the task
    60  	vaultStanza *structs.Vault
    61  
    62  	// eventEmitter is used to emit events to the task
    63  	eventEmitter ti.EventEmitter
    64  
    65  	// lifecycle is used to signal, restart and kill a task
    66  	lifecycle ti.TaskLifecycle
    67  
    68  	// updater is used to update the Vault token
    69  	updater vaultTokenUpdateHandler
    70  
    71  	// client is the Vault client to retrieve and renew the Vault token
    72  	client vaultclient.VaultClient
    73  
    74  	// logger is used to log
    75  	logger log.Logger
    76  
    77  	// ctx and cancel are used to kill the long running token manager
    78  	ctx    context.Context
    79  	cancel context.CancelFunc
    80  
    81  	// tokenPath is the path in which to read and write the token
    82  	tokenPath string
    83  
    84  	// alloc is the allocation
    85  	alloc *structs.Allocation
    86  
    87  	// taskName is the name of the task
    88  	taskName string
    89  
    90  	// firstRun stores whether it is the first run for the hook
    91  	firstRun bool
    92  
    93  	// future is used to wait on retrieving a Vault token
    94  	future *tokenFuture
    95  }
    96  
    97  func newVaultHook(config *vaultHookConfig) *vaultHook {
    98  	ctx, cancel := context.WithCancel(context.Background())
    99  	h := &vaultHook{
   100  		vaultStanza:  config.vaultStanza,
   101  		client:       config.client,
   102  		eventEmitter: config.events,
   103  		lifecycle:    config.lifecycle,
   104  		updater:      config.updater,
   105  		alloc:        config.alloc,
   106  		taskName:     config.task,
   107  		firstRun:     true,
   108  		ctx:          ctx,
   109  		cancel:       cancel,
   110  		future:       newTokenFuture(),
   111  	}
   112  	h.logger = config.logger.Named(h.Name())
   113  	return h
   114  }
   115  
   116  func (*vaultHook) Name() string {
   117  	return "vault"
   118  }
   119  
   120  func (h *vaultHook) Prestart(ctx context.Context, req *interfaces.TaskPrestartRequest, resp *interfaces.TaskPrestartResponse) error {
   121  	// If we have already run prestart before exit early. We do not use the
   122  	// PrestartDone value because we want to recover the token on restoration.
   123  	first := h.firstRun
   124  	h.firstRun = false
   125  	if !first {
   126  		return nil
   127  	}
   128  
   129  	// Try to recover a token if it was previously written in the secrets
   130  	// directory
   131  	recoveredToken := ""
   132  	h.tokenPath = filepath.Join(req.TaskDir.SecretsDir, vaultTokenFile)
   133  	data, err := ioutil.ReadFile(h.tokenPath)
   134  	if err != nil {
   135  		if !os.IsNotExist(err) {
   136  			return fmt.Errorf("failed to recover vault token: %v", err)
   137  		}
   138  
   139  		// Token file doesn't exist
   140  	} else {
   141  		// Store the recovered token
   142  		recoveredToken = string(data)
   143  	}
   144  
   145  	// Launch the token manager
   146  	go h.run(recoveredToken)
   147  
   148  	// Block until we get a token
   149  	select {
   150  	case <-h.future.Wait():
   151  	case <-ctx.Done():
   152  		return nil
   153  	}
   154  
   155  	h.updater.updatedVaultToken(h.future.Get())
   156  	return nil
   157  }
   158  
   159  func (h *vaultHook) Stop(ctx context.Context, req *interfaces.TaskStopRequest, resp *interfaces.TaskStopResponse) error {
   160  	// Shutdown any created manager
   161  	h.cancel()
   162  	return nil
   163  }
   164  
   165  func (h *vaultHook) Shutdown() {
   166  	h.cancel()
   167  }
   168  
   169  // run should be called in a go-routine and manages the derivation, renewal and
   170  // handling of errors with the Vault token. The optional parameter allows
   171  // setting the initial Vault token. This is useful when the Vault token is
   172  // recovered off disk.
   173  func (h *vaultHook) run(token string) {
   174  	// Helper for stopping token renewal
   175  	stopRenewal := func() {
   176  		if err := h.client.StopRenewToken(h.future.Get()); err != nil {
   177  			h.logger.Warn("failed to stop token renewal", "error", err)
   178  		}
   179  	}
   180  
   181  	// updatedToken lets us store state between loops. If true, a new token
   182  	// has been retrieved and we need to apply the Vault change mode
   183  	var updatedToken bool
   184  
   185  OUTER:
   186  	for {
   187  		// Check if we should exit
   188  		if h.ctx.Err() != nil {
   189  			stopRenewal()
   190  			return
   191  		}
   192  
   193  		// Clear the token
   194  		h.future.Clear()
   195  
   196  		// Check if there already is a token which can be the case for
   197  		// restoring the TaskRunner
   198  		if token == "" {
   199  			// Get a token
   200  			var exit bool
   201  			token, exit = h.deriveVaultToken()
   202  			if exit {
   203  				// Exit the manager
   204  				return
   205  			}
   206  
   207  			// Write the token to disk
   208  			if err := h.writeToken(token); err != nil {
   209  				errorString := "failed to write Vault token to disk"
   210  				h.logger.Error(errorString, "error", err)
   211  				h.lifecycle.Kill(h.ctx,
   212  					structs.NewTaskEvent(structs.TaskKilling).
   213  						SetFailsTask().
   214  						SetDisplayMessage(fmt.Sprintf("Vault %v", errorString)))
   215  				return
   216  			}
   217  		}
   218  
   219  		// Start the renewal process
   220  		renewCh, err := h.client.RenewToken(token, 30)
   221  
   222  		// An error returned means the token is not being renewed
   223  		if err != nil {
   224  			h.logger.Error("failed to start renewal of Vault token", "error", err)
   225  			token = ""
   226  			goto OUTER
   227  		}
   228  
   229  		// The Vault token is valid now, so set it
   230  		h.future.Set(token)
   231  
   232  		if updatedToken {
   233  			switch h.vaultStanza.ChangeMode {
   234  			case structs.VaultChangeModeSignal:
   235  				s, err := signals.Parse(h.vaultStanza.ChangeSignal)
   236  				if err != nil {
   237  					h.logger.Error("failed to parse signal", "error", err)
   238  					h.lifecycle.Kill(h.ctx,
   239  						structs.NewTaskEvent(structs.TaskKilling).
   240  							SetFailsTask().
   241  							SetDisplayMessage(fmt.Sprintf("Vault: failed to parse signal: %v", err)))
   242  					return
   243  				}
   244  
   245  				event := structs.NewTaskEvent(structs.TaskSignaling).SetTaskSignal(s).SetDisplayMessage("Vault: new Vault token acquired")
   246  				if err := h.lifecycle.Signal(event, h.vaultStanza.ChangeSignal); err != nil {
   247  					h.logger.Error("failed to send signal", "error", err)
   248  					h.lifecycle.Kill(h.ctx,
   249  						structs.NewTaskEvent(structs.TaskKilling).
   250  							SetFailsTask().
   251  							SetDisplayMessage(fmt.Sprintf("Vault: failed to send signal: %v", err)))
   252  					return
   253  				}
   254  			case structs.VaultChangeModeRestart:
   255  				const noFailure = false
   256  				h.lifecycle.Restart(h.ctx,
   257  					structs.NewTaskEvent(structs.TaskRestartSignal).
   258  						SetDisplayMessage("Vault: new Vault token acquired"), false)
   259  			case structs.VaultChangeModeNoop:
   260  				fallthrough
   261  			default:
   262  				h.logger.Error("invalid Vault change mode", "mode", h.vaultStanza.ChangeMode)
   263  			}
   264  
   265  			// We have handled it
   266  			updatedToken = false
   267  
   268  			// Call the handler
   269  			h.updater.updatedVaultToken(token)
   270  		}
   271  
   272  		// Start watching for renewal errors
   273  		select {
   274  		case err := <-renewCh:
   275  			// Clear the token
   276  			token = ""
   277  			h.logger.Error("failed to renew Vault token", "error", err)
   278  			stopRenewal()
   279  
   280  			// Check if we have to do anything
   281  			if h.vaultStanza.ChangeMode != structs.VaultChangeModeNoop {
   282  				updatedToken = true
   283  			}
   284  		case <-h.ctx.Done():
   285  			stopRenewal()
   286  			return
   287  		}
   288  	}
   289  }
   290  
   291  // deriveVaultToken derives the Vault token using exponential backoffs. It
   292  // returns the Vault token and whether the manager should exit.
   293  func (h *vaultHook) deriveVaultToken() (token string, exit bool) {
   294  	attempts := 0
   295  	for {
   296  		tokens, err := h.client.DeriveToken(h.alloc, []string{h.taskName})
   297  		if err == nil {
   298  			return tokens[h.taskName], false
   299  		}
   300  
   301  		// Check if this is a server side error
   302  		if structs.IsServerSide(err) {
   303  			h.logger.Error("failed to derive Vault token", "error", err, "server_side", true)
   304  			h.lifecycle.Kill(h.ctx,
   305  				structs.NewTaskEvent(structs.TaskKilling).
   306  					SetFailsTask().
   307  					SetDisplayMessage(fmt.Sprintf("Vault: server failed to derive vault token: %v", err)))
   308  			return "", true
   309  		}
   310  
   311  		// Check if we can't recover from the error
   312  		if !structs.IsRecoverable(err) {
   313  			h.logger.Error("failed to derive Vault token", "error", err, "recoverable", false)
   314  			h.lifecycle.Kill(h.ctx,
   315  				structs.NewTaskEvent(structs.TaskKilling).
   316  					SetFailsTask().
   317  					SetDisplayMessage(fmt.Sprintf("Vault: failed to derive vault token: %v", err)))
   318  			return "", true
   319  		}
   320  
   321  		// Handle the retry case
   322  		backoff := (1 << (2 * uint64(attempts))) * vaultBackoffBaseline
   323  		if backoff > vaultBackoffLimit {
   324  			backoff = vaultBackoffLimit
   325  		}
   326  		h.logger.Error("failed to derive Vault token", "error", err, "recoverable", true, "backoff", backoff)
   327  
   328  		attempts++
   329  
   330  		// Wait till retrying
   331  		select {
   332  		case <-h.ctx.Done():
   333  			return "", true
   334  		case <-time.After(backoff):
   335  		}
   336  	}
   337  }
   338  
   339  // writeToken writes the given token to disk
   340  func (h *vaultHook) writeToken(token string) error {
   341  	if err := ioutil.WriteFile(h.tokenPath, []byte(token), 0666); err != nil {
   342  		return fmt.Errorf("failed to write vault token: %v", err)
   343  	}
   344  
   345  	return nil
   346  }
   347  
   348  // tokenFuture stores the Vault token and allows consumers to block till a valid
   349  // token exists
   350  type tokenFuture struct {
   351  	waiting []chan struct{}
   352  	token   string
   353  	set     bool
   354  	m       sync.Mutex
   355  }
   356  
   357  // newTokenFuture returns a new token future without any token set
   358  func newTokenFuture() *tokenFuture {
   359  	return &tokenFuture{}
   360  }
   361  
   362  // Wait returns a channel that can be waited on. When this channel unblocks, a
   363  // valid token will be available via the Get method
   364  func (f *tokenFuture) Wait() <-chan struct{} {
   365  	f.m.Lock()
   366  	defer f.m.Unlock()
   367  
   368  	c := make(chan struct{})
   369  	if f.set {
   370  		close(c)
   371  		return c
   372  	}
   373  
   374  	f.waiting = append(f.waiting, c)
   375  	return c
   376  }
   377  
   378  // Set sets the token value and unblocks any caller of Wait
   379  func (f *tokenFuture) Set(token string) *tokenFuture {
   380  	f.m.Lock()
   381  	defer f.m.Unlock()
   382  
   383  	f.set = true
   384  	f.token = token
   385  	for _, w := range f.waiting {
   386  		close(w)
   387  	}
   388  	f.waiting = nil
   389  	return f
   390  }
   391  
   392  // Clear clears the set vault token.
   393  func (f *tokenFuture) Clear() *tokenFuture {
   394  	f.m.Lock()
   395  	defer f.m.Unlock()
   396  
   397  	f.token = ""
   398  	f.set = false
   399  	return f
   400  }
   401  
   402  // Get returns the set Vault token
   403  func (f *tokenFuture) Get() string {
   404  	f.m.Lock()
   405  	defer f.m.Unlock()
   406  	return f.token
   407  }