github.com/manicqin/nomad@v0.9.5/client/allocrunner/taskrunner/task_runner_hooks.go (about) 1 package taskrunner 2 3 import ( 4 "context" 5 "fmt" 6 "sync" 7 "time" 8 9 "github.com/LK4D4/joincontext" 10 multierror "github.com/hashicorp/go-multierror" 11 "github.com/hashicorp/nomad/client/allocrunner/interfaces" 12 "github.com/hashicorp/nomad/client/allocrunner/taskrunner/state" 13 "github.com/hashicorp/nomad/nomad/structs" 14 "github.com/hashicorp/nomad/plugins/drivers" 15 ) 16 17 // hookResources captures the resources for the task provided by hooks. 18 type hookResources struct { 19 Devices []*drivers.DeviceConfig 20 Mounts []*drivers.MountConfig 21 sync.RWMutex 22 } 23 24 func (h *hookResources) setDevices(d []*drivers.DeviceConfig) { 25 h.Lock() 26 h.Devices = d 27 h.Unlock() 28 } 29 30 func (h *hookResources) getDevices() []*drivers.DeviceConfig { 31 h.RLock() 32 defer h.RUnlock() 33 return h.Devices 34 } 35 36 func (h *hookResources) setMounts(m []*drivers.MountConfig) { 37 h.Lock() 38 h.Mounts = m 39 h.Unlock() 40 } 41 42 func (h *hookResources) getMounts() []*drivers.MountConfig { 43 h.RLock() 44 defer h.RUnlock() 45 return h.Mounts 46 } 47 48 // initHooks intializes the tasks hooks. 49 func (tr *TaskRunner) initHooks() { 50 hookLogger := tr.logger.Named("task_hook") 51 task := tr.Task() 52 53 tr.logmonHookConfig = newLogMonHookConfig(task.Name, tr.taskDir.LogDir) 54 55 // Add the hook resources 56 tr.hookResources = &hookResources{} 57 58 // Create the task directory hook. This is run first to ensure the 59 // directory path exists for other hooks. 
60 alloc := tr.Alloc() 61 tr.runnerHooks = []interfaces.TaskHook{ 62 newValidateHook(tr.clientConfig, hookLogger), 63 newTaskDirHook(tr, hookLogger), 64 newLogMonHook(tr, hookLogger), 65 newDispatchHook(alloc, hookLogger), 66 newVolumeHook(tr, hookLogger), 67 newArtifactHook(tr, hookLogger), 68 newStatsHook(tr, tr.clientConfig.StatsCollectionInterval, hookLogger), 69 newDeviceHook(tr.devicemanager, hookLogger), 70 newEnvoyBootstrapHook(alloc, tr.clientConfig.ConsulConfig.Addr, hookLogger), 71 } 72 73 // If Vault is enabled, add the hook 74 if task.Vault != nil { 75 tr.runnerHooks = append(tr.runnerHooks, newVaultHook(&vaultHookConfig{ 76 vaultStanza: task.Vault, 77 client: tr.vaultClient, 78 events: tr, 79 lifecycle: tr, 80 updater: tr, 81 logger: hookLogger, 82 alloc: tr.Alloc(), 83 task: tr.taskName, 84 })) 85 } 86 87 // If there are templates is enabled, add the hook 88 if len(task.Templates) != 0 { 89 tr.runnerHooks = append(tr.runnerHooks, newTemplateHook(&templateHookConfig{ 90 logger: hookLogger, 91 lifecycle: tr, 92 events: tr, 93 templates: task.Templates, 94 clientConfig: tr.clientConfig, 95 envBuilder: tr.envBuilder, 96 })) 97 } 98 99 // If there are any services, add the hook 100 if len(task.Services) != 0 { 101 tr.runnerHooks = append(tr.runnerHooks, newServiceHook(serviceHookConfig{ 102 alloc: tr.Alloc(), 103 task: tr.Task(), 104 consul: tr.consulClient, 105 restarter: tr, 106 logger: hookLogger, 107 })) 108 } 109 110 // If there are any script checks, add the hook 111 scriptCheckHook := newScriptCheckHook(scriptCheckHookConfig{ 112 alloc: tr.Alloc(), 113 task: tr.Task(), 114 consul: tr.consulClient, 115 logger: hookLogger, 116 }) 117 tr.runnerHooks = append(tr.runnerHooks, scriptCheckHook) 118 } 119 120 func (tr *TaskRunner) emitHookError(err error, hookName string) { 121 var taskEvent *structs.TaskEvent 122 if herr, ok := err.(*hookError); ok { 123 taskEvent = herr.taskEvent 124 } else { 125 message := fmt.Sprintf("%s: %v", hookName, err) 126 
taskEvent = structs.NewTaskEvent(structs.TaskHookFailed).SetMessage(message) 127 } 128 129 tr.EmitEvent(taskEvent) 130 } 131 132 // prestart is used to run the runners prestart hooks. 133 func (tr *TaskRunner) prestart() error { 134 // Determine if the allocation is terminaland we should avoid running 135 // prestart hooks. 136 alloc := tr.Alloc() 137 if alloc.TerminalStatus() { 138 tr.logger.Trace("skipping prestart hooks since allocation is terminal") 139 return nil 140 } 141 142 if tr.logger.IsTrace() { 143 start := time.Now() 144 tr.logger.Trace("running prestart hooks", "start", start) 145 defer func() { 146 end := time.Now() 147 tr.logger.Trace("finished prestart hooks", "end", end, "duration", end.Sub(start)) 148 }() 149 } 150 151 for _, hook := range tr.runnerHooks { 152 pre, ok := hook.(interfaces.TaskPrestartHook) 153 if !ok { 154 continue 155 } 156 157 name := pre.Name() 158 159 // Build the request 160 req := interfaces.TaskPrestartRequest{ 161 Task: tr.Task(), 162 TaskDir: tr.taskDir, 163 TaskEnv: tr.envBuilder.Build(), 164 TaskResources: tr.taskResources, 165 } 166 167 origHookState := tr.hookState(name) 168 if origHookState != nil { 169 if origHookState.PrestartDone { 170 tr.logger.Trace("skipping done prestart hook", "name", pre.Name()) 171 172 // Always set env vars from hooks 173 if name == HookNameDevices { 174 tr.envBuilder.SetDeviceHookEnv(name, origHookState.Env) 175 } else { 176 tr.envBuilder.SetHookEnv(name, origHookState.Env) 177 } 178 179 continue 180 } 181 182 // Give the hook it's old data 183 req.PreviousState = origHookState.Data 184 } 185 186 req.VaultToken = tr.getVaultToken() 187 188 // Time the prestart hook 189 var start time.Time 190 if tr.logger.IsTrace() { 191 start = time.Now() 192 tr.logger.Trace("running prestart hook", "name", name, "start", start) 193 } 194 195 // Run the prestart hook 196 // use a joint context to allow any blocking pre-start hooks 197 // to be canceled by either killCtx or shutdownCtx 198 joinedCtx, _ := 
joincontext.Join(tr.killCtx, tr.shutdownCtx) 199 var resp interfaces.TaskPrestartResponse 200 if err := pre.Prestart(joinedCtx, &req, &resp); err != nil { 201 tr.emitHookError(err, name) 202 return structs.WrapRecoverable(fmt.Sprintf("prestart hook %q failed: %v", name, err), err) 203 } 204 205 // Store the hook state 206 { 207 hookState := &state.HookState{ 208 Data: resp.State, 209 PrestartDone: resp.Done, 210 Env: resp.Env, 211 } 212 213 // Store and persist local state if the hook state has changed 214 if !hookState.Equal(origHookState) { 215 tr.stateLock.Lock() 216 tr.localState.Hooks[name] = hookState 217 tr.stateLock.Unlock() 218 219 if err := tr.persistLocalState(); err != nil { 220 return err 221 } 222 } 223 } 224 225 // Store the environment variables returned by the hook 226 if name == HookNameDevices { 227 tr.envBuilder.SetDeviceHookEnv(name, resp.Env) 228 } else { 229 tr.envBuilder.SetHookEnv(name, resp.Env) 230 } 231 232 // Store the resources 233 if len(resp.Devices) != 0 { 234 tr.hookResources.setDevices(resp.Devices) 235 } 236 if len(resp.Mounts) != 0 { 237 tr.hookResources.setMounts(resp.Mounts) 238 } 239 240 if tr.logger.IsTrace() { 241 end := time.Now() 242 tr.logger.Trace("finished prestart hook", "name", name, "end", end, "duration", end.Sub(start)) 243 } 244 } 245 246 return nil 247 } 248 249 // poststart is used to run the runners poststart hooks. 250 func (tr *TaskRunner) poststart() error { 251 if tr.logger.IsTrace() { 252 start := time.Now() 253 tr.logger.Trace("running poststart hooks", "start", start) 254 defer func() { 255 end := time.Now() 256 tr.logger.Trace("finished poststart hooks", "end", end, "duration", end.Sub(start)) 257 }() 258 } 259 260 handle := tr.getDriverHandle() 261 net := handle.Network() 262 263 // Pass the lazy handle to the hooks so even if the driver exits and we 264 // launch a new one (external plugin), the handle will refresh. 
265 lazyHandle := NewLazyHandle(tr.shutdownCtx, tr.getDriverHandle, tr.logger) 266 267 var merr multierror.Error 268 for _, hook := range tr.runnerHooks { 269 post, ok := hook.(interfaces.TaskPoststartHook) 270 if !ok { 271 continue 272 } 273 274 name := post.Name() 275 var start time.Time 276 if tr.logger.IsTrace() { 277 start = time.Now() 278 tr.logger.Trace("running poststart hook", "name", name, "start", start) 279 } 280 281 req := interfaces.TaskPoststartRequest{ 282 DriverExec: lazyHandle, 283 DriverNetwork: net, 284 DriverStats: lazyHandle, 285 TaskEnv: tr.envBuilder.Build(), 286 } 287 var resp interfaces.TaskPoststartResponse 288 if err := post.Poststart(tr.killCtx, &req, &resp); err != nil { 289 tr.emitHookError(err, name) 290 merr.Errors = append(merr.Errors, fmt.Errorf("poststart hook %q failed: %v", name, err)) 291 } 292 293 // No need to persist as PoststartResponse is currently empty 294 295 if tr.logger.IsTrace() { 296 end := time.Now() 297 tr.logger.Trace("finished poststart hooks", "name", name, "end", end, "duration", end.Sub(start)) 298 } 299 } 300 301 return merr.ErrorOrNil() 302 } 303 304 // exited is used to run the exited hooks before a task is stopped. 
func (tr *TaskRunner) exited() error {
	if tr.logger.IsTrace() {
		start := time.Now()
		tr.logger.Trace("running exited hooks", "start", start)
		defer func() {
			end := time.Now()
			tr.logger.Trace("finished exited hooks", "end", end, "duration", end.Sub(start))
		}()
	}

	var merr multierror.Error
	for _, hook := range tr.runnerHooks {
		exitedHook, ok := hook.(interfaces.TaskExitedHook)
		if !ok {
			continue
		}

		name := exitedHook.Name()
		var start time.Time
		if tr.logger.IsTrace() {
			start = time.Now()
			tr.logger.Trace("running exited hook", "name", name, "start", start)
		}

		// A failure is reported and collected, but the remaining hooks still run.
		req := interfaces.TaskExitedRequest{}
		var resp interfaces.TaskExitedResponse
		if err := exitedHook.Exited(tr.killCtx, &req, &resp); err != nil {
			tr.emitHookError(err, name)
			merr.Errors = append(merr.Errors, fmt.Errorf("exited hook %q failed: %v", name, err))
		}

		// No need to persist as TaskExitedResponse is currently empty

		if tr.logger.IsTrace() {
			end := time.Now()
			tr.logger.Trace("finished exited hook", "name", name, "end", end, "duration", end.Sub(start))
		}
	}

	return merr.ErrorOrNil()
}

// stop runs the stop hooks, passing each hook the state data stored for it (if
// any). Every hook is run even if an earlier one fails; failures are emitted as
// task events and returned combined.
func (tr *TaskRunner) stop() error {
	if tr.logger.IsTrace() {
		phaseStart := time.Now()
		tr.logger.Trace("running stop hooks", "start", phaseStart)
		defer func() {
			phaseEnd := time.Now()
			tr.logger.Trace("finished stop hooks", "end", phaseEnd, "duration", phaseEnd.Sub(phaseStart))
		}()
	}

	var merr multierror.Error
	for _, h := range tr.runnerHooks {
		stopHook, isStopHook := h.(interfaces.TaskStopHook)
		if !isStopHook {
			continue
		}

		hookName := stopHook.Name()

		var hookStart time.Time
		if tr.logger.IsTrace() {
			hookStart = time.Now()
			tr.logger.Trace("running stop hook", "name", hookName, "start", hookStart)
		}

		// Hand the hook whatever data was stored for it during prestart, if any.
		var req interfaces.TaskStopRequest
		if prior := tr.hookState(hookName); prior != nil {
			req.ExistingState = prior.Data
		}

		var resp interfaces.TaskStopResponse
		err := stopHook.Stop(tr.killCtx, &req, &resp)
		if err != nil {
			tr.emitHookError(err, hookName)
			merr.Errors = append(merr.Errors, fmt.Errorf("stop hook %q failed: %v", hookName, err))
		}

		// Stop hooks may not alter state and must be idempotent, so there is
		// nothing to persist here (unlike prestart).

		if !tr.logger.IsTrace() {
			continue
		}
		hookEnd := time.Now()
		tr.logger.Trace("finished stop hook", "name", hookName, "end", hookEnd, "duration", hookEnd.Sub(hookStart))
	}

	return merr.ErrorOrNil()
}

// updateHooks runs the runner's update hooks. It should only be called from
// Run(). To trigger an update, update state on the TaskRunner and call
// triggerUpdateHooks.
402 func (tr *TaskRunner) updateHooks() { 403 if tr.logger.IsTrace() { 404 start := time.Now() 405 tr.logger.Trace("running update hooks", "start", start) 406 defer func() { 407 end := time.Now() 408 tr.logger.Trace("finished update hooks", "end", end, "duration", end.Sub(start)) 409 }() 410 } 411 412 // Prepare state needed by Update hooks 413 alloc := tr.Alloc() 414 415 // Execute Update hooks 416 for _, hook := range tr.runnerHooks { 417 upd, ok := hook.(interfaces.TaskUpdateHook) 418 if !ok { 419 continue 420 } 421 422 name := upd.Name() 423 424 // Build the request 425 req := interfaces.TaskUpdateRequest{ 426 VaultToken: tr.getVaultToken(), 427 Alloc: alloc, 428 TaskEnv: tr.envBuilder.Build(), 429 } 430 431 // Time the update hook 432 var start time.Time 433 if tr.logger.IsTrace() { 434 start = time.Now() 435 tr.logger.Trace("running update hook", "name", name, "start", start) 436 } 437 438 // Run the update hook 439 var resp interfaces.TaskUpdateResponse 440 if err := upd.Update(tr.killCtx, &req, &resp); err != nil { 441 tr.emitHookError(err, name) 442 tr.logger.Error("update hook failed", "name", name, "error", err) 443 } 444 445 // No need to persist as TaskUpdateResponse is currently empty 446 447 if tr.logger.IsTrace() { 448 end := time.Now() 449 tr.logger.Trace("finished update hooks", "name", name, "end", end, "duration", end.Sub(start)) 450 } 451 } 452 } 453 454 // preKill is used to run the runners preKill hooks 455 // preKill hooks contain logic that must be executed before 456 // a task is killed or restarted 457 func (tr *TaskRunner) preKill() { 458 if tr.logger.IsTrace() { 459 start := time.Now() 460 tr.logger.Trace("running pre kill hooks", "start", start) 461 defer func() { 462 end := time.Now() 463 tr.logger.Trace("finished pre kill hooks", "end", end, "duration", end.Sub(start)) 464 }() 465 } 466 467 for _, hook := range tr.runnerHooks { 468 killHook, ok := hook.(interfaces.TaskPreKillHook) 469 if !ok { 470 continue 471 } 472 473 name := 
killHook.Name() 474 475 // Time the pre kill hook 476 var start time.Time 477 if tr.logger.IsTrace() { 478 start = time.Now() 479 tr.logger.Trace("running prekill hook", "name", name, "start", start) 480 } 481 482 // Run the pre kill hook 483 req := interfaces.TaskPreKillRequest{} 484 var resp interfaces.TaskPreKillResponse 485 if err := killHook.PreKilling(context.Background(), &req, &resp); err != nil { 486 tr.emitHookError(err, name) 487 tr.logger.Error("prekill hook failed", "name", name, "error", err) 488 } 489 490 // No need to persist as TaskKillResponse is currently empty 491 492 if tr.logger.IsTrace() { 493 end := time.Now() 494 tr.logger.Trace("finished prekill hook", "name", name, "end", end, "duration", end.Sub(start)) 495 } 496 } 497 } 498 499 // shutdownHooks is called when the TaskRunner is gracefully shutdown but the 500 // task is not being stopped or garbage collected. 501 func (tr *TaskRunner) shutdownHooks() { 502 for _, hook := range tr.runnerHooks { 503 sh, ok := hook.(interfaces.ShutdownHook) 504 if !ok { 505 continue 506 } 507 508 name := sh.Name() 509 510 // Time the update hook 511 var start time.Time 512 if tr.logger.IsTrace() { 513 start = time.Now() 514 tr.logger.Trace("running shutdown hook", "name", name, "start", start) 515 } 516 517 sh.Shutdown() 518 519 if tr.logger.IsTrace() { 520 end := time.Now() 521 tr.logger.Trace("finished shutdown hook", "name", name, "end", end, "duration", end.Sub(start)) 522 } 523 } 524 }