github.com/anth0d/nomad@v0.0.0-20221214183521-ae3a0a2cad06/client/allocrunner/taskrunner/task_runner_hooks.go (about) 1 package taskrunner 2 3 import ( 4 "context" 5 "fmt" 6 "sync" 7 "time" 8 9 "github.com/LK4D4/joincontext" 10 multierror "github.com/hashicorp/go-multierror" 11 "github.com/hashicorp/nomad/client/allocrunner/interfaces" 12 "github.com/hashicorp/nomad/client/allocrunner/taskrunner/state" 13 "github.com/hashicorp/nomad/nomad/structs" 14 "github.com/hashicorp/nomad/plugins/drivers" 15 ) 16 17 // hookResources captures the resources for the task provided by hooks. 18 type hookResources struct { 19 Devices []*drivers.DeviceConfig 20 Mounts []*drivers.MountConfig 21 sync.RWMutex 22 } 23 24 func (h *hookResources) setDevices(d []*drivers.DeviceConfig) { 25 h.Lock() 26 h.Devices = d 27 h.Unlock() 28 } 29 30 func (h *hookResources) getDevices() []*drivers.DeviceConfig { 31 h.RLock() 32 defer h.RUnlock() 33 return h.Devices 34 } 35 36 func (h *hookResources) setMounts(m []*drivers.MountConfig) { 37 h.Lock() 38 h.Mounts = m 39 h.Unlock() 40 } 41 42 func (h *hookResources) getMounts() []*drivers.MountConfig { 43 h.RLock() 44 defer h.RUnlock() 45 return h.Mounts 46 } 47 48 // initHooks initializes the tasks hooks. 49 func (tr *TaskRunner) initHooks() { 50 hookLogger := tr.logger.Named("task_hook") 51 task := tr.Task() 52 53 tr.logmonHookConfig = newLogMonHookConfig(task.Name, tr.taskDir.LogDir) 54 55 // Add the hook resources 56 tr.hookResources = &hookResources{} 57 58 // Create the task directory hook. This is run first to ensure the 59 // directory path exists for other hooks. 60 alloc := tr.Alloc() 61 tr.runnerHooks = []interfaces.TaskHook{ 62 newValidateHook(tr.clientConfig, hookLogger), 63 newTaskDirHook(tr, hookLogger), 64 newIdentityHook(tr, hookLogger), 65 newLogMonHook(tr, hookLogger), 66 newDispatchHook(alloc, hookLogger), 67 newVolumeHook(tr, hookLogger), 68 newArtifactHook(tr, tr.getter, hookLogger), 69 newStatsHook(tr, tr.clientConfig.StatsCollectionInterval, hookLogger), 70 newDeviceHook(tr.devicemanager, hookLogger), 71 } 72 73 // If the task has a CSI stanza, add the hook. 74 if task.CSIPluginConfig != nil { 75 tr.runnerHooks = append(tr.runnerHooks, newCSIPluginSupervisorHook( 76 &csiPluginSupervisorHookConfig{ 77 clientStateDirPath: tr.clientConfig.StateDir, 78 events: tr, 79 runner: tr, 80 lifecycle: tr, 81 capabilities: tr.driverCapabilities, 82 logger: hookLogger, 83 })) 84 } 85 86 // If Vault is enabled, add the hook 87 if task.Vault != nil { 88 tr.runnerHooks = append(tr.runnerHooks, newVaultHook(&vaultHookConfig{ 89 vaultStanza: task.Vault, 90 client: tr.vaultClient, 91 events: tr, 92 lifecycle: tr, 93 updater: tr, 94 logger: hookLogger, 95 alloc: tr.Alloc(), 96 task: tr.taskName, 97 })) 98 } 99 100 // Get the consul namespace for the TG of the allocation. 101 consulNamespace := tr.alloc.ConsulNamespace() 102 103 // Identify the service registration provider, which can differ from the 104 // Consul namespace depending on which provider is used. 105 serviceProviderNamespace := tr.alloc.ServiceProviderNamespace() 106 107 // If there are templates is enabled, add the hook 108 if len(task.Templates) != 0 { 109 tr.runnerHooks = append(tr.runnerHooks, newTemplateHook(&templateHookConfig{ 110 logger: hookLogger, 111 lifecycle: tr, 112 events: tr, 113 templates: task.Templates, 114 clientConfig: tr.clientConfig, 115 envBuilder: tr.envBuilder, 116 consulNamespace: consulNamespace, 117 nomadNamespace: tr.alloc.Job.Namespace, 118 })) 119 } 120 121 // Always add the service hook. A task with no services on initial registration 122 // may be updated to include services, which must be handled with this hook. 123 tr.runnerHooks = append(tr.runnerHooks, newServiceHook(serviceHookConfig{ 124 alloc: tr.Alloc(), 125 task: tr.Task(), 126 providerNamespace: serviceProviderNamespace, 127 serviceRegWrapper: tr.serviceRegWrapper, 128 restarter: tr, 129 logger: hookLogger, 130 })) 131 132 // If this is a Connect sidecar proxy (or a Connect Native) service, 133 // add the sidsHook for requesting a Service Identity token (if ACLs). 134 if task.UsesConnect() { 135 // Enable the Service Identity hook only if the Nomad client is configured 136 // with a consul token, indicating that Consul ACLs are enabled 137 if tr.clientConfig.ConsulConfig.Token != "" { 138 tr.runnerHooks = append(tr.runnerHooks, newSIDSHook(sidsHookConfig{ 139 alloc: tr.Alloc(), 140 task: tr.Task(), 141 sidsClient: tr.siClient, 142 lifecycle: tr, 143 logger: hookLogger, 144 })) 145 } 146 147 if task.UsesConnectSidecar() { 148 tr.runnerHooks = append(tr.runnerHooks, 149 newEnvoyVersionHook(newEnvoyVersionHookConfig(alloc, tr.consulProxiesClient, hookLogger)), 150 newEnvoyBootstrapHook(newEnvoyBootstrapHookConfig(alloc, tr.clientConfig.ConsulConfig, consulNamespace, hookLogger)), 151 ) 152 } else if task.Kind.IsConnectNative() { 153 tr.runnerHooks = append(tr.runnerHooks, newConnectNativeHook( 154 newConnectNativeHookConfig(alloc, tr.clientConfig.ConsulConfig, hookLogger), 155 )) 156 } 157 } 158 159 // Always add the script checks hook. A task with no script check hook on 160 // initial registration may be updated to include script checks, which must 161 // be handled with this hook. 162 tr.runnerHooks = append(tr.runnerHooks, newScriptCheckHook(scriptCheckHookConfig{ 163 alloc: tr.Alloc(), 164 task: tr.Task(), 165 consul: tr.consulServiceClient, 166 logger: hookLogger, 167 })) 168 169 // If this task driver has remote capabilities, add the remote task 170 // hook. 171 if tr.driverCapabilities.RemoteTasks { 172 tr.runnerHooks = append(tr.runnerHooks, newRemoteTaskHook(tr, hookLogger)) 173 } 174 } 175 176 func (tr *TaskRunner) emitHookError(err error, hookName string) { 177 var taskEvent *structs.TaskEvent 178 if herr, ok := err.(*hookError); ok { 179 taskEvent = herr.taskEvent 180 } else { 181 message := fmt.Sprintf("%s: %v", hookName, err) 182 taskEvent = structs.NewTaskEvent(structs.TaskHookFailed).SetMessage(message) 183 } 184 185 tr.EmitEvent(taskEvent) 186 } 187 188 // prestart is used to run the runners prestart hooks. 189 func (tr *TaskRunner) prestart() error { 190 // Determine if the allocation is terminal and we should avoid running 191 // prestart hooks. 192 if tr.shouldShutdown() { 193 tr.logger.Trace("skipping prestart hooks since allocation is terminal") 194 return nil 195 } 196 197 if tr.logger.IsTrace() { 198 start := time.Now() 199 tr.logger.Trace("running prestart hooks", "start", start) 200 defer func() { 201 end := time.Now() 202 tr.logger.Trace("finished prestart hooks", "end", end, "duration", end.Sub(start)) 203 }() 204 } 205 206 // use a join context to allow any blocking pre-start hooks 207 // to be canceled by either killCtx or shutdownCtx 208 joinedCtx, joinedCancel := joincontext.Join(tr.killCtx, tr.shutdownCtx) 209 defer joinedCancel() 210 211 for _, hook := range tr.runnerHooks { 212 pre, ok := hook.(interfaces.TaskPrestartHook) 213 if !ok { 214 continue 215 } 216 217 name := pre.Name() 218 219 // Build the request 220 req := interfaces.TaskPrestartRequest{ 221 Task: tr.Task(), 222 TaskDir: tr.taskDir, 223 TaskEnv: tr.envBuilder.Build(), 224 TaskResources: tr.taskResources, 225 } 226 227 origHookState := tr.hookState(name) 228 if origHookState != nil { 229 if origHookState.PrestartDone { 230 tr.logger.Trace("skipping done prestart hook", "name", pre.Name()) 231 232 // Always set env vars from hooks 233 if name == HookNameDevices { 234 tr.envBuilder.SetDeviceHookEnv(name, origHookState.Env) 235 } else { 236 tr.envBuilder.SetHookEnv(name, origHookState.Env) 237 } 238 239 continue 240 } 241 242 // Give the hook it's old data 243 req.PreviousState = origHookState.Data 244 } 245 246 req.VaultToken = tr.getVaultToken() 247 req.NomadToken = tr.getNomadToken() 248 249 // Time the prestart hook 250 var start time.Time 251 if tr.logger.IsTrace() { 252 start = time.Now() 253 tr.logger.Trace("running prestart hook", "name", name, "start", start) 254 } 255 256 // Run the prestart hook 257 var resp interfaces.TaskPrestartResponse 258 if err := pre.Prestart(joinedCtx, &req, &resp); err != nil { 259 tr.emitHookError(err, name) 260 return structs.WrapRecoverable(fmt.Sprintf("prestart hook %q failed: %v", name, err), err) 261 } 262 263 // Store the hook state 264 { 265 hookState := &state.HookState{ 266 Data: resp.State, 267 PrestartDone: resp.Done, 268 Env: resp.Env, 269 } 270 271 // Store and persist local state if the hook state has changed 272 if !hookState.Equal(origHookState) { 273 tr.stateLock.Lock() 274 tr.localState.Hooks[name] = hookState 275 tr.stateLock.Unlock() 276 277 if err := tr.persistLocalState(); err != nil { 278 return err 279 } 280 } 281 } 282 283 // Store the environment variables returned by the hook 284 if name == HookNameDevices { 285 tr.envBuilder.SetDeviceHookEnv(name, resp.Env) 286 } else { 287 tr.envBuilder.SetHookEnv(name, resp.Env) 288 } 289 290 // Store the resources 291 if len(resp.Devices) != 0 { 292 tr.hookResources.setDevices(resp.Devices) 293 } 294 if len(resp.Mounts) != 0 { 295 tr.hookResources.setMounts(resp.Mounts) 296 } 297 298 if tr.logger.IsTrace() { 299 end := time.Now() 300 tr.logger.Trace("finished prestart hook", "name", name, "end", end, "duration", end.Sub(start)) 301 } 302 } 303 304 return nil 305 } 306 307 // poststart is used to run the runners poststart hooks. 308 func (tr *TaskRunner) poststart() error { 309 if tr.logger.IsTrace() { 310 start := time.Now() 311 tr.logger.Trace("running poststart hooks", "start", start) 312 defer func() { 313 end := time.Now() 314 tr.logger.Trace("finished poststart hooks", "end", end, "duration", end.Sub(start)) 315 }() 316 } 317 318 handle := tr.getDriverHandle() 319 net := handle.Network() 320 321 // Pass the lazy handle to the hooks so even if the driver exits and we 322 // launch a new one (external plugin), the handle will refresh. 323 lazyHandle := NewLazyHandle(tr.shutdownCtx, tr.getDriverHandle, tr.logger) 324 325 var merr multierror.Error 326 for _, hook := range tr.runnerHooks { 327 post, ok := hook.(interfaces.TaskPoststartHook) 328 if !ok { 329 continue 330 } 331 332 name := post.Name() 333 var start time.Time 334 if tr.logger.IsTrace() { 335 start = time.Now() 336 tr.logger.Trace("running poststart hook", "name", name, "start", start) 337 } 338 339 req := interfaces.TaskPoststartRequest{ 340 DriverExec: lazyHandle, 341 DriverNetwork: net, 342 DriverStats: lazyHandle, 343 TaskEnv: tr.envBuilder.Build(), 344 } 345 var resp interfaces.TaskPoststartResponse 346 if err := post.Poststart(tr.killCtx, &req, &resp); err != nil { 347 tr.emitHookError(err, name) 348 merr.Errors = append(merr.Errors, fmt.Errorf("poststart hook %q failed: %v", name, err)) 349 } 350 351 // No need to persist as PoststartResponse is currently empty 352 353 if tr.logger.IsTrace() { 354 end := time.Now() 355 tr.logger.Trace("finished poststart hooks", "name", name, "end", end, "duration", end.Sub(start)) 356 } 357 } 358 359 return merr.ErrorOrNil() 360 } 361 362 // exited is used to run the exited hooks before a task is stopped. 363 func (tr *TaskRunner) exited() error { 364 if tr.logger.IsTrace() { 365 start := time.Now() 366 tr.logger.Trace("running exited hooks", "start", start) 367 defer func() { 368 end := time.Now() 369 tr.logger.Trace("finished exited hooks", "end", end, "duration", end.Sub(start)) 370 }() 371 } 372 373 var merr multierror.Error 374 for _, hook := range tr.runnerHooks { 375 post, ok := hook.(interfaces.TaskExitedHook) 376 if !ok { 377 continue 378 } 379 380 name := post.Name() 381 var start time.Time 382 if tr.logger.IsTrace() { 383 start = time.Now() 384 tr.logger.Trace("running exited hook", "name", name, "start", start) 385 } 386 387 req := interfaces.TaskExitedRequest{} 388 var resp interfaces.TaskExitedResponse 389 if err := post.Exited(tr.killCtx, &req, &resp); err != nil { 390 tr.emitHookError(err, name) 391 merr.Errors = append(merr.Errors, fmt.Errorf("exited hook %q failed: %v", name, err)) 392 } 393 394 // No need to persist as TaskExitedResponse is currently empty 395 396 if tr.logger.IsTrace() { 397 end := time.Now() 398 tr.logger.Trace("finished exited hooks", "name", name, "end", end, "duration", end.Sub(start)) 399 } 400 } 401 402 return merr.ErrorOrNil() 403 404 } 405 406 // stop is used to run the stop hooks. 407 func (tr *TaskRunner) stop() error { 408 if tr.logger.IsTrace() { 409 start := time.Now() 410 tr.logger.Trace("running stop hooks", "start", start) 411 defer func() { 412 end := time.Now() 413 tr.logger.Trace("finished stop hooks", "end", end, "duration", end.Sub(start)) 414 }() 415 } 416 417 var merr multierror.Error 418 for _, hook := range tr.runnerHooks { 419 post, ok := hook.(interfaces.TaskStopHook) 420 if !ok { 421 continue 422 } 423 424 name := post.Name() 425 var start time.Time 426 if tr.logger.IsTrace() { 427 start = time.Now() 428 tr.logger.Trace("running stop hook", "name", name, "start", start) 429 } 430 431 req := interfaces.TaskStopRequest{} 432 433 origHookState := tr.hookState(name) 434 if origHookState != nil { 435 // Give the hook data provided by prestart 436 req.ExistingState = origHookState.Data 437 } 438 439 var resp interfaces.TaskStopResponse 440 if err := post.Stop(tr.killCtx, &req, &resp); err != nil { 441 tr.emitHookError(err, name) 442 merr.Errors = append(merr.Errors, fmt.Errorf("stop hook %q failed: %v", name, err)) 443 } 444 445 // Stop hooks cannot alter state and must be idempotent, so 446 // unlike prestart there's no state to persist here. 447 448 if tr.logger.IsTrace() { 449 end := time.Now() 450 tr.logger.Trace("finished stop hook", "name", name, "end", end, "duration", end.Sub(start)) 451 } 452 } 453 454 return merr.ErrorOrNil() 455 } 456 457 // update is used to run the runners update hooks. Should only be called from 458 // Run(). To trigger an update, update state on the TaskRunner and call 459 // triggerUpdateHooks. 460 func (tr *TaskRunner) updateHooks() { 461 if tr.logger.IsTrace() { 462 start := time.Now() 463 tr.logger.Trace("running update hooks", "start", start) 464 defer func() { 465 end := time.Now() 466 tr.logger.Trace("finished update hooks", "end", end, "duration", end.Sub(start)) 467 }() 468 } 469 470 // Prepare state needed by Update hooks 471 alloc := tr.Alloc() 472 473 // Execute Update hooks 474 for _, hook := range tr.runnerHooks { 475 upd, ok := hook.(interfaces.TaskUpdateHook) 476 if !ok { 477 continue 478 } 479 480 name := upd.Name() 481 482 // Build the request 483 req := interfaces.TaskUpdateRequest{ 484 VaultToken: tr.getVaultToken(), 485 Alloc: alloc, 486 TaskEnv: tr.envBuilder.Build(), 487 } 488 489 // Time the update hook 490 var start time.Time 491 if tr.logger.IsTrace() { 492 start = time.Now() 493 tr.logger.Trace("running update hook", "name", name, "start", start) 494 } 495 496 // Run the update hook 497 var resp interfaces.TaskUpdateResponse 498 if err := upd.Update(tr.killCtx, &req, &resp); err != nil { 499 tr.emitHookError(err, name) 500 tr.logger.Error("update hook failed", "name", name, "error", err) 501 } 502 503 // No need to persist as TaskUpdateResponse is currently empty 504 505 if tr.logger.IsTrace() { 506 end := time.Now() 507 tr.logger.Trace("finished update hooks", "name", name, "end", end, "duration", end.Sub(start)) 508 } 509 } 510 } 511 512 // preKill is used to run the runners preKill hooks 513 // preKill hooks contain logic that must be executed before 514 // a task is killed or restarted 515 func (tr *TaskRunner) preKill() { 516 if tr.logger.IsTrace() { 517 start := time.Now() 518 tr.logger.Trace("running pre kill hooks", "start", start) 519 defer func() { 520 end := time.Now() 521 tr.logger.Trace("finished pre kill hooks", "end", end, "duration", end.Sub(start)) 522 }() 523 } 524 525 for _, hook := range tr.runnerHooks { 526 killHook, ok := hook.(interfaces.TaskPreKillHook) 527 if !ok { 528 continue 529 } 530 531 name := killHook.Name() 532 533 // Time the pre kill hook 534 var start time.Time 535 if tr.logger.IsTrace() { 536 start = time.Now() 537 tr.logger.Trace("running prekill hook", "name", name, "start", start) 538 } 539 540 // Run the pre kill hook 541 req := interfaces.TaskPreKillRequest{} 542 var resp interfaces.TaskPreKillResponse 543 if err := killHook.PreKilling(context.Background(), &req, &resp); err != nil { 544 tr.emitHookError(err, name) 545 tr.logger.Error("prekill hook failed", "name", name, "error", err) 546 } 547 548 // No need to persist as TaskKillResponse is currently empty 549 550 if tr.logger.IsTrace() { 551 end := time.Now() 552 tr.logger.Trace("finished prekill hook", "name", name, "end", end, "duration", end.Sub(start)) 553 } 554 } 555 } 556 557 // shutdownHooks is called when the TaskRunner is gracefully shutdown but the 558 // task is not being stopped or garbage collected. 559 func (tr *TaskRunner) shutdownHooks() { 560 for _, hook := range tr.runnerHooks { 561 sh, ok := hook.(interfaces.ShutdownHook) 562 if !ok { 563 continue 564 } 565 566 name := sh.Name() 567 568 // Time the update hook 569 var start time.Time 570 if tr.logger.IsTrace() { 571 start = time.Now() 572 tr.logger.Trace("running shutdown hook", "name", name, "start", start) 573 } 574 575 sh.Shutdown() 576 577 if tr.logger.IsTrace() { 578 end := time.Now() 579 tr.logger.Trace("finished shutdown hook", "name", name, "end", end, "duration", end.Sub(start)) 580 } 581 } 582 }