github.com/Ilhicas/nomad@v1.0.4-0.20210304152020-e86851182bc3/client/allocrunner/taskrunner/task_runner_hooks.go (about) 1 package taskrunner 2 3 import ( 4 "context" 5 "fmt" 6 "path/filepath" 7 "sync" 8 "time" 9 10 "github.com/LK4D4/joincontext" 11 multierror "github.com/hashicorp/go-multierror" 12 "github.com/hashicorp/nomad/client/allocrunner/interfaces" 13 "github.com/hashicorp/nomad/client/allocrunner/taskrunner/state" 14 "github.com/hashicorp/nomad/nomad/structs" 15 "github.com/hashicorp/nomad/plugins/drivers" 16 ) 17 18 // hookResources captures the resources for the task provided by hooks. 19 type hookResources struct { 20 Devices []*drivers.DeviceConfig 21 Mounts []*drivers.MountConfig 22 sync.RWMutex 23 } 24 25 func (h *hookResources) setDevices(d []*drivers.DeviceConfig) { 26 h.Lock() 27 h.Devices = d 28 h.Unlock() 29 } 30 31 func (h *hookResources) getDevices() []*drivers.DeviceConfig { 32 h.RLock() 33 defer h.RUnlock() 34 return h.Devices 35 } 36 37 func (h *hookResources) setMounts(m []*drivers.MountConfig) { 38 h.Lock() 39 h.Mounts = m 40 h.Unlock() 41 } 42 43 func (h *hookResources) getMounts() []*drivers.MountConfig { 44 h.RLock() 45 defer h.RUnlock() 46 return h.Mounts 47 } 48 49 // initHooks initializes the tasks hooks. 50 func (tr *TaskRunner) initHooks() { 51 hookLogger := tr.logger.Named("task_hook") 52 task := tr.Task() 53 54 tr.logmonHookConfig = newLogMonHookConfig(task.Name, tr.taskDir.LogDir) 55 56 // Add the hook resources 57 tr.hookResources = &hookResources{} 58 59 // Create the task directory hook. This is run first to ensure the 60 // directory path exists for other hooks. 61 alloc := tr.Alloc() 62 tr.runnerHooks = []interfaces.TaskHook{ 63 newValidateHook(tr.clientConfig, hookLogger), 64 newTaskDirHook(tr, hookLogger), 65 newLogMonHook(tr, hookLogger), 66 newDispatchHook(alloc, hookLogger), 67 newVolumeHook(tr, hookLogger), 68 newArtifactHook(tr, hookLogger), 69 newStatsHook(tr, tr.clientConfig.StatsCollectionInterval, hookLogger), 70 newDeviceHook(tr.devicemanager, hookLogger), 71 } 72 73 // If the task has a CSI stanza, add the hook. 74 if task.CSIPluginConfig != nil { 75 tr.runnerHooks = append(tr.runnerHooks, newCSIPluginSupervisorHook(filepath.Join(tr.clientConfig.StateDir, "csi"), tr, tr, hookLogger)) 76 } 77 78 // If Vault is enabled, add the hook 79 if task.Vault != nil { 80 tr.runnerHooks = append(tr.runnerHooks, newVaultHook(&vaultHookConfig{ 81 vaultStanza: task.Vault, 82 client: tr.vaultClient, 83 events: tr, 84 lifecycle: tr, 85 updater: tr, 86 logger: hookLogger, 87 alloc: tr.Alloc(), 88 task: tr.taskName, 89 })) 90 } 91 92 // If there are templates is enabled, add the hook 93 if len(task.Templates) != 0 { 94 tr.runnerHooks = append(tr.runnerHooks, newTemplateHook(&templateHookConfig{ 95 logger: hookLogger, 96 lifecycle: tr, 97 events: tr, 98 templates: task.Templates, 99 clientConfig: tr.clientConfig, 100 envBuilder: tr.envBuilder, 101 })) 102 } 103 104 // Always add the service hook. A task with no services on initial registration 105 // may be updated to include services, which must be handled with this hook. 106 tr.runnerHooks = append(tr.runnerHooks, newServiceHook(serviceHookConfig{ 107 alloc: tr.Alloc(), 108 task: tr.Task(), 109 consul: tr.consulServiceClient, 110 restarter: tr, 111 logger: hookLogger, 112 })) 113 114 // If this is a Connect sidecar proxy (or a Connect Native) service, 115 // add the sidsHook for requesting a Service Identity token (if ACLs). 116 if task.UsesConnect() { 117 // Enable the Service Identity hook only if the Nomad client is configured 118 // with a consul token, indicating that Consul ACLs are enabled 119 if tr.clientConfig.ConsulConfig.Token != "" { 120 tr.runnerHooks = append(tr.runnerHooks, newSIDSHook(sidsHookConfig{ 121 alloc: tr.Alloc(), 122 task: tr.Task(), 123 sidsClient: tr.siClient, 124 lifecycle: tr, 125 logger: hookLogger, 126 })) 127 } 128 129 if task.UsesConnectSidecar() { 130 tr.runnerHooks = append(tr.runnerHooks, 131 newEnvoyVersionHook(newEnvoyVersionHookConfig(alloc, tr.consulProxiesClient, hookLogger)), 132 newEnvoyBootstrapHook(newEnvoyBootstrapHookConfig(alloc, tr.clientConfig.ConsulConfig, hookLogger)), 133 ) 134 } else if task.Kind.IsConnectNative() { 135 tr.runnerHooks = append(tr.runnerHooks, newConnectNativeHook( 136 newConnectNativeHookConfig(alloc, tr.clientConfig.ConsulConfig, hookLogger), 137 )) 138 } 139 } 140 141 // If there are any script checks, add the hook 142 scriptCheckHook := newScriptCheckHook(scriptCheckHookConfig{ 143 alloc: tr.Alloc(), 144 task: tr.Task(), 145 consul: tr.consulServiceClient, 146 logger: hookLogger, 147 }) 148 tr.runnerHooks = append(tr.runnerHooks, scriptCheckHook) 149 } 150 151 func (tr *TaskRunner) emitHookError(err error, hookName string) { 152 var taskEvent *structs.TaskEvent 153 if herr, ok := err.(*hookError); ok { 154 taskEvent = herr.taskEvent 155 } else { 156 message := fmt.Sprintf("%s: %v", hookName, err) 157 taskEvent = structs.NewTaskEvent(structs.TaskHookFailed).SetMessage(message) 158 } 159 160 tr.EmitEvent(taskEvent) 161 } 162 163 // prestart is used to run the runners prestart hooks. 164 func (tr *TaskRunner) prestart() error { 165 // Determine if the allocation is terminal and we should avoid running 166 // prestart hooks. 167 if tr.shouldShutdown() { 168 tr.logger.Trace("skipping prestart hooks since allocation is terminal") 169 return nil 170 } 171 172 if tr.logger.IsTrace() { 173 start := time.Now() 174 tr.logger.Trace("running prestart hooks", "start", start) 175 defer func() { 176 end := time.Now() 177 tr.logger.Trace("finished prestart hooks", "end", end, "duration", end.Sub(start)) 178 }() 179 } 180 181 for _, hook := range tr.runnerHooks { 182 pre, ok := hook.(interfaces.TaskPrestartHook) 183 if !ok { 184 continue 185 } 186 187 name := pre.Name() 188 189 // Build the request 190 req := interfaces.TaskPrestartRequest{ 191 Task: tr.Task(), 192 TaskDir: tr.taskDir, 193 TaskEnv: tr.envBuilder.Build(), 194 TaskResources: tr.taskResources, 195 } 196 197 origHookState := tr.hookState(name) 198 if origHookState != nil { 199 if origHookState.PrestartDone { 200 tr.logger.Trace("skipping done prestart hook", "name", pre.Name()) 201 202 // Always set env vars from hooks 203 if name == HookNameDevices { 204 tr.envBuilder.SetDeviceHookEnv(name, origHookState.Env) 205 } else { 206 tr.envBuilder.SetHookEnv(name, origHookState.Env) 207 } 208 209 continue 210 } 211 212 // Give the hook it's old data 213 req.PreviousState = origHookState.Data 214 } 215 216 req.VaultToken = tr.getVaultToken() 217 218 // Time the prestart hook 219 var start time.Time 220 if tr.logger.IsTrace() { 221 start = time.Now() 222 tr.logger.Trace("running prestart hook", "name", name, "start", start) 223 } 224 225 // Run the prestart hook 226 // use a joint context to allow any blocking pre-start hooks 227 // to be canceled by either killCtx or shutdownCtx 228 joinedCtx, _ := joincontext.Join(tr.killCtx, tr.shutdownCtx) 229 var resp interfaces.TaskPrestartResponse 230 if err := pre.Prestart(joinedCtx, &req, &resp); err != nil { 231 tr.emitHookError(err, name) 232 return structs.WrapRecoverable(fmt.Sprintf("prestart hook %q failed: %v", name, err), err) 233 } 234 235 // Store the hook state 236 { 237 hookState := &state.HookState{ 238 Data: resp.State, 239 PrestartDone: resp.Done, 240 Env: resp.Env, 241 } 242 243 // Store and persist local state if the hook state has changed 244 if !hookState.Equal(origHookState) { 245 tr.stateLock.Lock() 246 tr.localState.Hooks[name] = hookState 247 tr.stateLock.Unlock() 248 249 if err := tr.persistLocalState(); err != nil { 250 return err 251 } 252 } 253 } 254 255 // Store the environment variables returned by the hook 256 if name == HookNameDevices { 257 tr.envBuilder.SetDeviceHookEnv(name, resp.Env) 258 } else { 259 tr.envBuilder.SetHookEnv(name, resp.Env) 260 } 261 262 // Store the resources 263 if len(resp.Devices) != 0 { 264 tr.hookResources.setDevices(resp.Devices) 265 } 266 if len(resp.Mounts) != 0 { 267 tr.hookResources.setMounts(resp.Mounts) 268 } 269 270 if tr.logger.IsTrace() { 271 end := time.Now() 272 tr.logger.Trace("finished prestart hook", "name", name, "end", end, "duration", end.Sub(start)) 273 } 274 } 275 276 return nil 277 } 278 279 // poststart is used to run the runners poststart hooks. 280 func (tr *TaskRunner) poststart() error { 281 if tr.logger.IsTrace() { 282 start := time.Now() 283 tr.logger.Trace("running poststart hooks", "start", start) 284 defer func() { 285 end := time.Now() 286 tr.logger.Trace("finished poststart hooks", "end", end, "duration", end.Sub(start)) 287 }() 288 } 289 290 handle := tr.getDriverHandle() 291 net := handle.Network() 292 293 // Pass the lazy handle to the hooks so even if the driver exits and we 294 // launch a new one (external plugin), the handle will refresh. 295 lazyHandle := NewLazyHandle(tr.shutdownCtx, tr.getDriverHandle, tr.logger) 296 297 var merr multierror.Error 298 for _, hook := range tr.runnerHooks { 299 post, ok := hook.(interfaces.TaskPoststartHook) 300 if !ok { 301 continue 302 } 303 304 name := post.Name() 305 var start time.Time 306 if tr.logger.IsTrace() { 307 start = time.Now() 308 tr.logger.Trace("running poststart hook", "name", name, "start", start) 309 } 310 311 req := interfaces.TaskPoststartRequest{ 312 DriverExec: lazyHandle, 313 DriverNetwork: net, 314 DriverStats: lazyHandle, 315 TaskEnv: tr.envBuilder.Build(), 316 } 317 var resp interfaces.TaskPoststartResponse 318 if err := post.Poststart(tr.killCtx, &req, &resp); err != nil { 319 tr.emitHookError(err, name) 320 merr.Errors = append(merr.Errors, fmt.Errorf("poststart hook %q failed: %v", name, err)) 321 } 322 323 // No need to persist as PoststartResponse is currently empty 324 325 if tr.logger.IsTrace() { 326 end := time.Now() 327 tr.logger.Trace("finished poststart hooks", "name", name, "end", end, "duration", end.Sub(start)) 328 } 329 } 330 331 return merr.ErrorOrNil() 332 } 333 334 // exited is used to run the exited hooks before a task is stopped. 335 func (tr *TaskRunner) exited() error { 336 if tr.logger.IsTrace() { 337 start := time.Now() 338 tr.logger.Trace("running exited hooks", "start", start) 339 defer func() { 340 end := time.Now() 341 tr.logger.Trace("finished exited hooks", "end", end, "duration", end.Sub(start)) 342 }() 343 } 344 345 var merr multierror.Error 346 for _, hook := range tr.runnerHooks { 347 post, ok := hook.(interfaces.TaskExitedHook) 348 if !ok { 349 continue 350 } 351 352 name := post.Name() 353 var start time.Time 354 if tr.logger.IsTrace() { 355 start = time.Now() 356 tr.logger.Trace("running exited hook", "name", name, "start", start) 357 } 358 359 req := interfaces.TaskExitedRequest{} 360 var resp interfaces.TaskExitedResponse 361 if err := post.Exited(tr.killCtx, &req, &resp); err != nil { 362 tr.emitHookError(err, name) 363 merr.Errors = append(merr.Errors, fmt.Errorf("exited hook %q failed: %v", name, err)) 364 } 365 366 // No need to persist as TaskExitedResponse is currently empty 367 368 if tr.logger.IsTrace() { 369 end := time.Now() 370 tr.logger.Trace("finished exited hooks", "name", name, "end", end, "duration", end.Sub(start)) 371 } 372 } 373 374 return merr.ErrorOrNil() 375 376 } 377 378 // stop is used to run the stop hooks. 379 func (tr *TaskRunner) stop() error { 380 if tr.logger.IsTrace() { 381 start := time.Now() 382 tr.logger.Trace("running stop hooks", "start", start) 383 defer func() { 384 end := time.Now() 385 tr.logger.Trace("finished stop hooks", "end", end, "duration", end.Sub(start)) 386 }() 387 } 388 389 var merr multierror.Error 390 for _, hook := range tr.runnerHooks { 391 post, ok := hook.(interfaces.TaskStopHook) 392 if !ok { 393 continue 394 } 395 396 name := post.Name() 397 var start time.Time 398 if tr.logger.IsTrace() { 399 start = time.Now() 400 tr.logger.Trace("running stop hook", "name", name, "start", start) 401 } 402 403 req := interfaces.TaskStopRequest{} 404 405 origHookState := tr.hookState(name) 406 if origHookState != nil { 407 // Give the hook data provided by prestart 408 req.ExistingState = origHookState.Data 409 } 410 411 var resp interfaces.TaskStopResponse 412 if err := post.Stop(tr.killCtx, &req, &resp); err != nil { 413 tr.emitHookError(err, name) 414 merr.Errors = append(merr.Errors, fmt.Errorf("stop hook %q failed: %v", name, err)) 415 } 416 417 // Stop hooks cannot alter state and must be idempotent, so 418 // unlike prestart there's no state to persist here. 419 420 if tr.logger.IsTrace() { 421 end := time.Now() 422 tr.logger.Trace("finished stop hook", "name", name, "end", end, "duration", end.Sub(start)) 423 } 424 } 425 426 return merr.ErrorOrNil() 427 } 428 429 // update is used to run the runners update hooks. Should only be called from 430 // Run(). To trigger an update, update state on the TaskRunner and call 431 // triggerUpdateHooks. 432 func (tr *TaskRunner) updateHooks() { 433 if tr.logger.IsTrace() { 434 start := time.Now() 435 tr.logger.Trace("running update hooks", "start", start) 436 defer func() { 437 end := time.Now() 438 tr.logger.Trace("finished update hooks", "end", end, "duration", end.Sub(start)) 439 }() 440 } 441 442 // Prepare state needed by Update hooks 443 alloc := tr.Alloc() 444 445 // Execute Update hooks 446 for _, hook := range tr.runnerHooks { 447 upd, ok := hook.(interfaces.TaskUpdateHook) 448 if !ok { 449 continue 450 } 451 452 name := upd.Name() 453 454 // Build the request 455 req := interfaces.TaskUpdateRequest{ 456 VaultToken: tr.getVaultToken(), 457 Alloc: alloc, 458 TaskEnv: tr.envBuilder.Build(), 459 } 460 461 // Time the update hook 462 var start time.Time 463 if tr.logger.IsTrace() { 464 start = time.Now() 465 tr.logger.Trace("running update hook", "name", name, "start", start) 466 } 467 468 // Run the update hook 469 var resp interfaces.TaskUpdateResponse 470 if err := upd.Update(tr.killCtx, &req, &resp); err != nil { 471 tr.emitHookError(err, name) 472 tr.logger.Error("update hook failed", "name", name, "error", err) 473 } 474 475 // No need to persist as TaskUpdateResponse is currently empty 476 477 if tr.logger.IsTrace() { 478 end := time.Now() 479 tr.logger.Trace("finished update hooks", "name", name, "end", end, "duration", end.Sub(start)) 480 } 481 } 482 } 483 484 // preKill is used to run the runners preKill hooks 485 // preKill hooks contain logic that must be executed before 486 // a task is killed or restarted 487 func (tr *TaskRunner) preKill() { 488 if tr.logger.IsTrace() { 489 start := time.Now() 490 tr.logger.Trace("running pre kill hooks", "start", start) 491 defer func() { 492 end := time.Now() 493 tr.logger.Trace("finished pre kill hooks", "end", end, "duration", end.Sub(start)) 494 }() 495 } 496 497 for _, hook := range tr.runnerHooks { 498 killHook, ok := hook.(interfaces.TaskPreKillHook) 499 if !ok { 500 continue 501 } 502 503 name := killHook.Name() 504 505 // Time the pre kill hook 506 var start time.Time 507 if tr.logger.IsTrace() { 508 start = time.Now() 509 tr.logger.Trace("running prekill hook", "name", name, "start", start) 510 } 511 512 // Run the pre kill hook 513 req := interfaces.TaskPreKillRequest{} 514 var resp interfaces.TaskPreKillResponse 515 if err := killHook.PreKilling(context.Background(), &req, &resp); err != nil { 516 tr.emitHookError(err, name) 517 tr.logger.Error("prekill hook failed", "name", name, "error", err) 518 } 519 520 // No need to persist as TaskKillResponse is currently empty 521 522 if tr.logger.IsTrace() { 523 end := time.Now() 524 tr.logger.Trace("finished prekill hook", "name", name, "end", end, "duration", end.Sub(start)) 525 } 526 } 527 } 528 529 // shutdownHooks is called when the TaskRunner is gracefully shutdown but the 530 // task is not being stopped or garbage collected. 531 func (tr *TaskRunner) shutdownHooks() { 532 for _, hook := range tr.runnerHooks { 533 sh, ok := hook.(interfaces.ShutdownHook) 534 if !ok { 535 continue 536 } 537 538 name := sh.Name() 539 540 // Time the update hook 541 var start time.Time 542 if tr.logger.IsTrace() { 543 start = time.Now() 544 tr.logger.Trace("running shutdown hook", "name", name, "start", start) 545 } 546 547 sh.Shutdown() 548 549 if tr.logger.IsTrace() { 550 end := time.Now() 551 tr.logger.Trace("finished shutdown hook", "name", name, "end", end, "duration", end.Sub(start)) 552 } 553 } 554 }