github.com/anth0d/nomad@v0.0.0-20221214183521-ae3a0a2cad06/client/allocrunner/taskrunner/vault_hook.go (about) 1 package taskrunner 2 3 import ( 4 "context" 5 "fmt" 6 "io/ioutil" 7 "os" 8 "path/filepath" 9 "sync" 10 "time" 11 12 "github.com/hashicorp/consul-template/signals" 13 log "github.com/hashicorp/go-hclog" 14 15 "github.com/hashicorp/nomad/client/allocrunner/interfaces" 16 ti "github.com/hashicorp/nomad/client/allocrunner/taskrunner/interfaces" 17 "github.com/hashicorp/nomad/client/vaultclient" 18 "github.com/hashicorp/nomad/nomad/structs" 19 ) 20 21 const ( 22 // vaultBackoffBaseline is the baseline time for exponential backoff when 23 // attempting to retrieve a Vault token 24 vaultBackoffBaseline = 5 * time.Second 25 26 // vaultBackoffLimit is the limit of the exponential backoff when attempting 27 // to retrieve a Vault token 28 vaultBackoffLimit = 3 * time.Minute 29 30 // vaultTokenFile is the name of the file holding the Vault token inside the 31 // task's secret directory 32 vaultTokenFile = "vault_token" 33 ) 34 35 type vaultTokenUpdateHandler interface { 36 updatedVaultToken(token string) 37 } 38 39 func (tr *TaskRunner) updatedVaultToken(token string) { 40 // Update the task runner and environment 41 tr.setVaultToken(token) 42 43 // Trigger update hooks with the new Vault token 44 tr.triggerUpdateHooks() 45 } 46 47 type vaultHookConfig struct { 48 vaultStanza *structs.Vault 49 client vaultclient.VaultClient 50 events ti.EventEmitter 51 lifecycle ti.TaskLifecycle 52 updater vaultTokenUpdateHandler 53 logger log.Logger 54 alloc *structs.Allocation 55 task string 56 } 57 58 type vaultHook struct { 59 // vaultStanza is the vault stanza for the task 60 vaultStanza *structs.Vault 61 62 // eventEmitter is used to emit events to the task 63 eventEmitter ti.EventEmitter 64 65 // lifecycle is used to signal, restart and kill a task 66 lifecycle ti.TaskLifecycle 67 68 // updater is used to update the Vault token 69 updater vaultTokenUpdateHandler 70 71 // client is the Vault client to retrieve and renew the Vault token 72 client vaultclient.VaultClient 73 74 // logger is used to log 75 logger log.Logger 76 77 // ctx and cancel are used to kill the long running token manager 78 ctx context.Context 79 cancel context.CancelFunc 80 81 // tokenPath is the path in which to read and write the token 82 tokenPath string 83 84 // alloc is the allocation 85 alloc *structs.Allocation 86 87 // taskName is the name of the task 88 taskName string 89 90 // firstRun stores whether it is the first run for the hook 91 firstRun bool 92 93 // future is used to wait on retrieving a Vault token 94 future *tokenFuture 95 } 96 97 func newVaultHook(config *vaultHookConfig) *vaultHook { 98 ctx, cancel := context.WithCancel(context.Background()) 99 h := &vaultHook{ 100 vaultStanza: config.vaultStanza, 101 client: config.client, 102 eventEmitter: config.events, 103 lifecycle: config.lifecycle, 104 updater: config.updater, 105 alloc: config.alloc, 106 taskName: config.task, 107 firstRun: true, 108 ctx: ctx, 109 cancel: cancel, 110 future: newTokenFuture(), 111 } 112 h.logger = config.logger.Named(h.Name()) 113 return h 114 } 115 116 func (*vaultHook) Name() string { 117 return "vault" 118 } 119 120 func (h *vaultHook) Prestart(ctx context.Context, req *interfaces.TaskPrestartRequest, resp *interfaces.TaskPrestartResponse) error { 121 // If we have already run prestart before exit early. We do not use the 122 // PrestartDone value because we want to recover the token on restoration. 123 first := h.firstRun 124 h.firstRun = false 125 if !first { 126 return nil 127 } 128 129 // Try to recover a token if it was previously written in the secrets 130 // directory 131 recoveredToken := "" 132 h.tokenPath = filepath.Join(req.TaskDir.SecretsDir, vaultTokenFile) 133 data, err := ioutil.ReadFile(h.tokenPath) 134 if err != nil { 135 if !os.IsNotExist(err) { 136 return fmt.Errorf("failed to recover vault token: %v", err) 137 } 138 139 // Token file doesn't exist 140 } else { 141 // Store the recovered token 142 recoveredToken = string(data) 143 } 144 145 // Launch the token manager 146 go h.run(recoveredToken) 147 148 // Block until we get a token 149 select { 150 case <-h.future.Wait(): 151 case <-ctx.Done(): 152 return nil 153 } 154 155 h.updater.updatedVaultToken(h.future.Get()) 156 return nil 157 } 158 159 func (h *vaultHook) Stop(ctx context.Context, req *interfaces.TaskStopRequest, resp *interfaces.TaskStopResponse) error { 160 // Shutdown any created manager 161 h.cancel() 162 return nil 163 } 164 165 func (h *vaultHook) Shutdown() { 166 h.cancel() 167 } 168 169 // run should be called in a go-routine and manages the derivation, renewal and 170 // handling of errors with the Vault token. The optional parameter allows 171 // setting the initial Vault token. This is useful when the Vault token is 172 // recovered off disk. 173 func (h *vaultHook) run(token string) { 174 // Helper for stopping token renewal 175 stopRenewal := func() { 176 if err := h.client.StopRenewToken(h.future.Get()); err != nil { 177 h.logger.Warn("failed to stop token renewal", "error", err) 178 } 179 } 180 181 // updatedToken lets us store state between loops. If true, a new token 182 // has been retrieved and we need to apply the Vault change mode 183 var updatedToken bool 184 185 OUTER: 186 for { 187 // Check if we should exit 188 if h.ctx.Err() != nil { 189 stopRenewal() 190 return 191 } 192 193 // Clear the token 194 h.future.Clear() 195 196 // Check if there already is a token which can be the case for 197 // restoring the TaskRunner 198 if token == "" { 199 // Get a token 200 var exit bool 201 token, exit = h.deriveVaultToken() 202 if exit { 203 // Exit the manager 204 return 205 } 206 207 // Write the token to disk 208 if err := h.writeToken(token); err != nil { 209 errorString := "failed to write Vault token to disk" 210 h.logger.Error(errorString, "error", err) 211 h.lifecycle.Kill(h.ctx, 212 structs.NewTaskEvent(structs.TaskKilling). 213 SetFailsTask(). 214 SetDisplayMessage(fmt.Sprintf("Vault %v", errorString))) 215 return 216 } 217 } 218 219 // Start the renewal process. 220 // 221 // This is the initial renew of the token which we derived from the 222 // server. The client does not know how long it took for the token to 223 // be generated and derived and also wants to gain control of the 224 // process quickly, but not too quickly. We therefore use a hardcoded 225 // increment value of 30; this value without a suffix is in seconds. 226 // 227 // If Vault is having availability issues or is overloaded, a large 228 // number of initial token renews can exacerbate the problem. 229 renewCh, err := h.client.RenewToken(token, 30) 230 231 // An error returned means the token is not being renewed 232 if err != nil { 233 h.logger.Error("failed to start renewal of Vault token", "error", err) 234 token = "" 235 goto OUTER 236 } 237 238 // The Vault token is valid now, so set it 239 h.future.Set(token) 240 241 if updatedToken { 242 switch h.vaultStanza.ChangeMode { 243 case structs.VaultChangeModeSignal: 244 s, err := signals.Parse(h.vaultStanza.ChangeSignal) 245 if err != nil { 246 h.logger.Error("failed to parse signal", "error", err) 247 h.lifecycle.Kill(h.ctx, 248 structs.NewTaskEvent(structs.TaskKilling). 249 SetFailsTask(). 250 SetDisplayMessage(fmt.Sprintf("Vault: failed to parse signal: %v", err))) 251 return 252 } 253 254 event := structs.NewTaskEvent(structs.TaskSignaling).SetTaskSignal(s).SetDisplayMessage("Vault: new Vault token acquired") 255 if err := h.lifecycle.Signal(event, h.vaultStanza.ChangeSignal); err != nil { 256 h.logger.Error("failed to send signal", "error", err) 257 h.lifecycle.Kill(h.ctx, 258 structs.NewTaskEvent(structs.TaskKilling). 259 SetFailsTask(). 260 SetDisplayMessage(fmt.Sprintf("Vault: failed to send signal: %v", err))) 261 return 262 } 263 case structs.VaultChangeModeRestart: 264 const noFailure = false 265 h.lifecycle.Restart(h.ctx, 266 structs.NewTaskEvent(structs.TaskRestartSignal). 267 SetDisplayMessage("Vault: new Vault token acquired"), false) 268 case structs.VaultChangeModeNoop: 269 fallthrough 270 default: 271 h.logger.Error("invalid Vault change mode", "mode", h.vaultStanza.ChangeMode) 272 } 273 274 // We have handled it 275 updatedToken = false 276 277 // Call the handler 278 h.updater.updatedVaultToken(token) 279 } 280 281 // Start watching for renewal errors 282 select { 283 case err := <-renewCh: 284 // Clear the token 285 token = "" 286 h.logger.Error("failed to renew Vault token", "error", err) 287 stopRenewal() 288 updatedToken = true 289 case <-h.ctx.Done(): 290 stopRenewal() 291 return 292 } 293 } 294 } 295 296 // deriveVaultToken derives the Vault token using exponential backoffs. It 297 // returns the Vault token and whether the manager should exit. 298 func (h *vaultHook) deriveVaultToken() (token string, exit bool) { 299 attempts := 0 300 for { 301 tokens, err := h.client.DeriveToken(h.alloc, []string{h.taskName}) 302 if err == nil { 303 return tokens[h.taskName], false 304 } 305 306 // Check if this is a server side error 307 if structs.IsServerSide(err) { 308 h.logger.Error("failed to derive Vault token", "error", err, "server_side", true) 309 h.lifecycle.Kill(h.ctx, 310 structs.NewTaskEvent(structs.TaskKilling). 311 SetFailsTask(). 312 SetDisplayMessage(fmt.Sprintf("Vault: server failed to derive vault token: %v", err))) 313 return "", true 314 } 315 316 // Check if we can't recover from the error 317 if !structs.IsRecoverable(err) { 318 h.logger.Error("failed to derive Vault token", "error", err, "recoverable", false) 319 h.lifecycle.Kill(h.ctx, 320 structs.NewTaskEvent(structs.TaskKilling). 321 SetFailsTask(). 322 SetDisplayMessage(fmt.Sprintf("Vault: failed to derive vault token: %v", err))) 323 return "", true 324 } 325 326 // Handle the retry case 327 backoff := (1 << (2 * uint64(attempts))) * vaultBackoffBaseline 328 if backoff > vaultBackoffLimit { 329 backoff = vaultBackoffLimit 330 } 331 h.logger.Error("failed to derive Vault token", "error", err, "recoverable", true, "backoff", backoff) 332 333 attempts++ 334 335 // Wait till retrying 336 select { 337 case <-h.ctx.Done(): 338 return "", true 339 case <-time.After(backoff): 340 } 341 } 342 } 343 344 // writeToken writes the given token to disk 345 func (h *vaultHook) writeToken(token string) error { 346 if err := ioutil.WriteFile(h.tokenPath, []byte(token), 0666); err != nil { 347 return fmt.Errorf("failed to write vault token: %v", err) 348 } 349 350 return nil 351 } 352 353 // tokenFuture stores the Vault token and allows consumers to block till a valid 354 // token exists 355 type tokenFuture struct { 356 waiting []chan struct{} 357 token string 358 set bool 359 m sync.Mutex 360 } 361 362 // newTokenFuture returns a new token future without any token set 363 func newTokenFuture() *tokenFuture { 364 return &tokenFuture{} 365 } 366 367 // Wait returns a channel that can be waited on. When this channel unblocks, a 368 // valid token will be available via the Get method 369 func (f *tokenFuture) Wait() <-chan struct{} { 370 f.m.Lock() 371 defer f.m.Unlock() 372 373 c := make(chan struct{}) 374 if f.set { 375 close(c) 376 return c 377 } 378 379 f.waiting = append(f.waiting, c) 380 return c 381 } 382 383 // Set sets the token value and unblocks any caller of Wait 384 func (f *tokenFuture) Set(token string) *tokenFuture { 385 f.m.Lock() 386 defer f.m.Unlock() 387 388 f.set = true 389 f.token = token 390 for _, w := range f.waiting { 391 close(w) 392 } 393 f.waiting = nil 394 return f 395 } 396 397 // Clear clears the set vault token. 398 func (f *tokenFuture) Clear() *tokenFuture { 399 f.m.Lock() 400 defer f.m.Unlock() 401 402 f.token = "" 403 f.set = false 404 return f 405 } 406 407 // Get returns the set Vault token 408 func (f *tokenFuture) Get() string { 409 f.m.Lock() 410 defer f.m.Unlock() 411 return f.token 412 }