github.com/smithx10/nomad@v0.9.1-rc1/client/allocrunner/taskrunner/task_runner_hooks.go (about) 1 package taskrunner 2 3 import ( 4 "context" 5 "fmt" 6 "sync" 7 "time" 8 9 multierror "github.com/hashicorp/go-multierror" 10 "github.com/hashicorp/nomad/client/allocrunner/interfaces" 11 "github.com/hashicorp/nomad/client/allocrunner/taskrunner/state" 12 "github.com/hashicorp/nomad/nomad/structs" 13 "github.com/hashicorp/nomad/plugins/drivers" 14 ) 15 16 // hookResources captures the resources for the task provided by hooks. 17 type hookResources struct { 18 Devices []*drivers.DeviceConfig 19 Mounts []*drivers.MountConfig 20 sync.RWMutex 21 } 22 23 func (h *hookResources) setDevices(d []*drivers.DeviceConfig) { 24 h.Lock() 25 h.Devices = d 26 h.Unlock() 27 } 28 29 func (h *hookResources) getDevices() []*drivers.DeviceConfig { 30 h.RLock() 31 defer h.RUnlock() 32 return h.Devices 33 } 34 35 func (h *hookResources) setMounts(m []*drivers.MountConfig) { 36 h.Lock() 37 h.Mounts = m 38 h.Unlock() 39 } 40 41 func (h *hookResources) getMounts() []*drivers.MountConfig { 42 h.RLock() 43 defer h.RUnlock() 44 return h.Mounts 45 } 46 47 // initHooks intializes the tasks hooks. 48 func (tr *TaskRunner) initHooks() { 49 hookLogger := tr.logger.Named("task_hook") 50 task := tr.Task() 51 52 tr.logmonHookConfig = newLogMonHookConfig(task.Name, tr.taskDir.LogDir) 53 54 // Add the hook resources 55 tr.hookResources = &hookResources{} 56 57 // Create the task directory hook. This is run first to ensure the 58 // directory path exists for other hooks. 59 tr.runnerHooks = []interfaces.TaskHook{ 60 newValidateHook(tr.clientConfig, hookLogger), 61 newTaskDirHook(tr, hookLogger), 62 newLogMonHook(tr.logmonHookConfig, hookLogger), 63 newDispatchHook(tr.Alloc(), hookLogger), 64 newArtifactHook(tr, hookLogger), 65 newStatsHook(tr, tr.clientConfig.StatsCollectionInterval, hookLogger), 66 newDeviceHook(tr.devicemanager, hookLogger), 67 } 68 69 // If Vault is enabled, add the hook 70 if task.Vault != nil { 71 tr.runnerHooks = append(tr.runnerHooks, newVaultHook(&vaultHookConfig{ 72 vaultStanza: task.Vault, 73 client: tr.vaultClient, 74 events: tr, 75 lifecycle: tr, 76 updater: tr, 77 logger: hookLogger, 78 alloc: tr.Alloc(), 79 task: tr.taskName, 80 })) 81 } 82 83 // If there are templates is enabled, add the hook 84 if len(task.Templates) != 0 { 85 tr.runnerHooks = append(tr.runnerHooks, newTemplateHook(&templateHookConfig{ 86 logger: hookLogger, 87 lifecycle: tr, 88 events: tr, 89 templates: task.Templates, 90 clientConfig: tr.clientConfig, 91 envBuilder: tr.envBuilder, 92 })) 93 } 94 95 // If there are any services, add the hook 96 if len(task.Services) != 0 { 97 tr.runnerHooks = append(tr.runnerHooks, newServiceHook(serviceHookConfig{ 98 alloc: tr.Alloc(), 99 task: tr.Task(), 100 consul: tr.consulClient, 101 restarter: tr, 102 logger: hookLogger, 103 })) 104 } 105 } 106 107 func (tr *TaskRunner) emitHookError(err error, hookName string) { 108 var taskEvent *structs.TaskEvent 109 if herr, ok := err.(*hookError); ok { 110 taskEvent = herr.taskEvent 111 } else { 112 message := fmt.Sprintf("%s: %v", hookName, err) 113 taskEvent = structs.NewTaskEvent(structs.TaskHookFailed).SetMessage(message) 114 } 115 116 tr.EmitEvent(taskEvent) 117 } 118 119 // prestart is used to run the runners prestart hooks. 120 func (tr *TaskRunner) prestart() error { 121 // Determine if the allocation is terminaland we should avoid running 122 // prestart hooks. 123 alloc := tr.Alloc() 124 if alloc.TerminalStatus() { 125 tr.logger.Trace("skipping prestart hooks since allocation is terminal") 126 return nil 127 } 128 129 if tr.logger.IsTrace() { 130 start := time.Now() 131 tr.logger.Trace("running prestart hooks", "start", start) 132 defer func() { 133 end := time.Now() 134 tr.logger.Trace("finished prestart hooks", "end", end, "duration", end.Sub(start)) 135 }() 136 } 137 138 for _, hook := range tr.runnerHooks { 139 pre, ok := hook.(interfaces.TaskPrestartHook) 140 if !ok { 141 continue 142 } 143 144 name := pre.Name() 145 146 // Build the request 147 req := interfaces.TaskPrestartRequest{ 148 Task: tr.Task(), 149 TaskDir: tr.taskDir, 150 TaskEnv: tr.envBuilder.Build(), 151 TaskResources: tr.taskResources, 152 } 153 154 origHookState := tr.hookState(name) 155 if origHookState != nil { 156 if origHookState.PrestartDone { 157 tr.logger.Trace("skipping done prestart hook", "name", pre.Name()) 158 159 // Always set env vars from hooks 160 if name == HookNameDevices { 161 tr.envBuilder.SetDeviceHookEnv(name, origHookState.Env) 162 } else { 163 tr.envBuilder.SetHookEnv(name, origHookState.Env) 164 } 165 166 continue 167 } 168 169 // Give the hook it's old data 170 req.PreviousState = origHookState.Data 171 } 172 173 req.VaultToken = tr.getVaultToken() 174 175 // Time the prestart hook 176 var start time.Time 177 if tr.logger.IsTrace() { 178 start = time.Now() 179 tr.logger.Trace("running prestart hook", "name", name, "start", start) 180 } 181 182 // Run the prestart hook 183 var resp interfaces.TaskPrestartResponse 184 if err := pre.Prestart(tr.killCtx, &req, &resp); err != nil { 185 tr.emitHookError(err, name) 186 return structs.WrapRecoverable(fmt.Sprintf("prestart hook %q failed: %v", name, err), err) 187 } 188 189 // Store the hook state 190 { 191 hookState := &state.HookState{ 192 Data: resp.State, 193 PrestartDone: resp.Done, 194 Env: resp.Env, 195 } 196 197 // Store and persist local state if the hook state has changed 198 if !hookState.Equal(origHookState) { 199 tr.stateLock.Lock() 200 tr.localState.Hooks[name] = hookState 201 tr.stateLock.Unlock() 202 203 if err := tr.persistLocalState(); err != nil { 204 return err 205 } 206 } 207 } 208 209 // Store the environment variables returned by the hook 210 if name == HookNameDevices { 211 tr.envBuilder.SetDeviceHookEnv(name, resp.Env) 212 } else { 213 tr.envBuilder.SetHookEnv(name, resp.Env) 214 } 215 216 // Store the resources 217 if len(resp.Devices) != 0 { 218 tr.hookResources.setDevices(resp.Devices) 219 } 220 if len(resp.Mounts) != 0 { 221 tr.hookResources.setMounts(resp.Mounts) 222 } 223 224 if tr.logger.IsTrace() { 225 end := time.Now() 226 tr.logger.Trace("finished prestart hook", "name", name, "end", end, "duration", end.Sub(start)) 227 } 228 } 229 230 return nil 231 } 232 233 // poststart is used to run the runners poststart hooks. 234 func (tr *TaskRunner) poststart() error { 235 if tr.logger.IsTrace() { 236 start := time.Now() 237 tr.logger.Trace("running poststart hooks", "start", start) 238 defer func() { 239 end := time.Now() 240 tr.logger.Trace("finished poststart hooks", "end", end, "duration", end.Sub(start)) 241 }() 242 } 243 244 handle := tr.getDriverHandle() 245 net := handle.Network() 246 247 // Pass the lazy handle to the hooks so even if the driver exits and we 248 // launch a new one (external plugin), the handle will refresh. 249 lazyHandle := NewLazyHandle(tr.shutdownCtx, tr.getDriverHandle, tr.logger) 250 251 var merr multierror.Error 252 for _, hook := range tr.runnerHooks { 253 post, ok := hook.(interfaces.TaskPoststartHook) 254 if !ok { 255 continue 256 } 257 258 name := post.Name() 259 var start time.Time 260 if tr.logger.IsTrace() { 261 start = time.Now() 262 tr.logger.Trace("running poststart hook", "name", name, "start", start) 263 } 264 265 req := interfaces.TaskPoststartRequest{ 266 DriverExec: lazyHandle, 267 DriverNetwork: net, 268 DriverStats: lazyHandle, 269 TaskEnv: tr.envBuilder.Build(), 270 } 271 var resp interfaces.TaskPoststartResponse 272 if err := post.Poststart(tr.killCtx, &req, &resp); err != nil { 273 tr.emitHookError(err, name) 274 merr.Errors = append(merr.Errors, fmt.Errorf("poststart hook %q failed: %v", name, err)) 275 } 276 277 // No need to persist as PoststartResponse is currently empty 278 279 if tr.logger.IsTrace() { 280 end := time.Now() 281 tr.logger.Trace("finished poststart hooks", "name", name, "end", end, "duration", end.Sub(start)) 282 } 283 } 284 285 return merr.ErrorOrNil() 286 } 287 288 // exited is used to run the exited hooks before a task is stopped. 289 func (tr *TaskRunner) exited() error { 290 if tr.logger.IsTrace() { 291 start := time.Now() 292 tr.logger.Trace("running exited hooks", "start", start) 293 defer func() { 294 end := time.Now() 295 tr.logger.Trace("finished exited hooks", "end", end, "duration", end.Sub(start)) 296 }() 297 } 298 299 var merr multierror.Error 300 for _, hook := range tr.runnerHooks { 301 post, ok := hook.(interfaces.TaskExitedHook) 302 if !ok { 303 continue 304 } 305 306 name := post.Name() 307 var start time.Time 308 if tr.logger.IsTrace() { 309 start = time.Now() 310 tr.logger.Trace("running exited hook", "name", name, "start", start) 311 } 312 313 req := interfaces.TaskExitedRequest{} 314 var resp interfaces.TaskExitedResponse 315 if err := post.Exited(tr.killCtx, &req, &resp); err != nil { 316 tr.emitHookError(err, name) 317 merr.Errors = append(merr.Errors, fmt.Errorf("exited hook %q failed: %v", name, err)) 318 } 319 320 // No need to persist as TaskExitedResponse is currently empty 321 322 if tr.logger.IsTrace() { 323 end := time.Now() 324 tr.logger.Trace("finished exited hooks", "name", name, "end", end, "duration", end.Sub(start)) 325 } 326 } 327 328 return merr.ErrorOrNil() 329 330 } 331 332 // stop is used to run the stop hooks. 333 func (tr *TaskRunner) stop() error { 334 if tr.logger.IsTrace() { 335 start := time.Now() 336 tr.logger.Trace("running stop hooks", "start", start) 337 defer func() { 338 end := time.Now() 339 tr.logger.Trace("finished stop hooks", "end", end, "duration", end.Sub(start)) 340 }() 341 } 342 343 var merr multierror.Error 344 for _, hook := range tr.runnerHooks { 345 post, ok := hook.(interfaces.TaskStopHook) 346 if !ok { 347 continue 348 } 349 350 name := post.Name() 351 var start time.Time 352 if tr.logger.IsTrace() { 353 start = time.Now() 354 tr.logger.Trace("running stop hook", "name", name, "start", start) 355 } 356 357 req := interfaces.TaskStopRequest{} 358 359 origHookState := tr.hookState(name) 360 if origHookState != nil { 361 // Give the hook data provided by prestart 362 req.ExistingState = origHookState.Data 363 } 364 365 var resp interfaces.TaskStopResponse 366 if err := post.Stop(tr.killCtx, &req, &resp); err != nil { 367 tr.emitHookError(err, name) 368 merr.Errors = append(merr.Errors, fmt.Errorf("stop hook %q failed: %v", name, err)) 369 } 370 371 // Stop hooks cannot alter state and must be idempotent, so 372 // unlike prestart there's no state to persist here. 373 374 if tr.logger.IsTrace() { 375 end := time.Now() 376 tr.logger.Trace("finished stop hook", "name", name, "end", end, "duration", end.Sub(start)) 377 } 378 } 379 380 return merr.ErrorOrNil() 381 } 382 383 // update is used to run the runners update hooks. Should only be called from 384 // Run(). To trigger an update, update state on the TaskRunner and call 385 // triggerUpdateHooks. 386 func (tr *TaskRunner) updateHooks() { 387 if tr.logger.IsTrace() { 388 start := time.Now() 389 tr.logger.Trace("running update hooks", "start", start) 390 defer func() { 391 end := time.Now() 392 tr.logger.Trace("finished update hooks", "end", end, "duration", end.Sub(start)) 393 }() 394 } 395 396 // Prepare state needed by Update hooks 397 alloc := tr.Alloc() 398 399 // Execute Update hooks 400 for _, hook := range tr.runnerHooks { 401 upd, ok := hook.(interfaces.TaskUpdateHook) 402 if !ok { 403 continue 404 } 405 406 name := upd.Name() 407 408 // Build the request 409 req := interfaces.TaskUpdateRequest{ 410 VaultToken: tr.getVaultToken(), 411 Alloc: alloc, 412 TaskEnv: tr.envBuilder.Build(), 413 } 414 415 // Time the update hook 416 var start time.Time 417 if tr.logger.IsTrace() { 418 start = time.Now() 419 tr.logger.Trace("running update hook", "name", name, "start", start) 420 } 421 422 // Run the update hook 423 var resp interfaces.TaskUpdateResponse 424 if err := upd.Update(tr.killCtx, &req, &resp); err != nil { 425 tr.emitHookError(err, name) 426 tr.logger.Error("update hook failed", "name", name, "error", err) 427 } 428 429 // No need to persist as TaskUpdateResponse is currently empty 430 431 if tr.logger.IsTrace() { 432 end := time.Now() 433 tr.logger.Trace("finished update hooks", "name", name, "end", end, "duration", end.Sub(start)) 434 } 435 } 436 } 437 438 // preKill is used to run the runners preKill hooks 439 // preKill hooks contain logic that must be executed before 440 // a task is killed or restarted 441 func (tr *TaskRunner) preKill() { 442 if tr.logger.IsTrace() { 443 start := time.Now() 444 tr.logger.Trace("running pre kill hooks", "start", start) 445 defer func() { 446 end := time.Now() 447 tr.logger.Trace("finished pre kill hooks", "end", end, "duration", end.Sub(start)) 448 }() 449 } 450 451 for _, hook := range tr.runnerHooks { 452 killHook, ok := hook.(interfaces.TaskPreKillHook) 453 if !ok { 454 continue 455 } 456 457 name := killHook.Name() 458 459 // Time the pre kill hook 460 var start time.Time 461 if tr.logger.IsTrace() { 462 start = time.Now() 463 tr.logger.Trace("running prekill hook", "name", name, "start", start) 464 } 465 466 // Run the pre kill hook 467 req := interfaces.TaskPreKillRequest{} 468 var resp interfaces.TaskPreKillResponse 469 if err := killHook.PreKilling(context.Background(), &req, &resp); err != nil { 470 tr.emitHookError(err, name) 471 tr.logger.Error("prekill hook failed", "name", name, "error", err) 472 } 473 474 // No need to persist as TaskKillResponse is currently empty 475 476 if tr.logger.IsTrace() { 477 end := time.Now() 478 tr.logger.Trace("finished prekill hook", "name", name, "end", end, "duration", end.Sub(start)) 479 } 480 } 481 } 482 483 // shutdownHooks is called when the TaskRunner is gracefully shutdown but the 484 // task is not being stopped or garbage collected. 485 func (tr *TaskRunner) shutdownHooks() { 486 for _, hook := range tr.runnerHooks { 487 sh, ok := hook.(interfaces.ShutdownHook) 488 if !ok { 489 continue 490 } 491 492 name := sh.Name() 493 494 // Time the update hook 495 var start time.Time 496 if tr.logger.IsTrace() { 497 start = time.Now() 498 tr.logger.Trace("running shutdown hook", "name", name, "start", start) 499 } 500 501 sh.Shutdown() 502 503 if tr.logger.IsTrace() { 504 end := time.Now() 505 tr.logger.Trace("finished shutdown hook", "name", name, "end", end, "duration", end.Sub(start)) 506 } 507 } 508 }