github.com/anth0d/nomad@v0.0.0-20221214183521-ae3a0a2cad06/client/allocrunner/taskrunner/sids_hook.go (about) 1 package taskrunner 2 3 import ( 4 "context" 5 "errors" 6 "fmt" 7 "io/ioutil" 8 "os" 9 "path/filepath" 10 "sync" 11 "time" 12 13 "github.com/hashicorp/go-hclog" 14 "github.com/hashicorp/nomad/client/allocrunner/interfaces" 15 ti "github.com/hashicorp/nomad/client/allocrunner/taskrunner/interfaces" 16 "github.com/hashicorp/nomad/client/consul" 17 "github.com/hashicorp/nomad/nomad/structs" 18 ) 19 20 const ( 21 // the name of this hook, used in logs 22 sidsHookName = "consul_si_token" 23 24 // sidsBackoffBaseline is the baseline time for exponential backoff when 25 // attempting to retrieve a Consul SI token 26 sidsBackoffBaseline = 5 * time.Second 27 28 // sidsBackoffLimit is the limit of the exponential backoff when attempting 29 // to retrieve a Consul SI token 30 sidsBackoffLimit = 3 * time.Minute 31 32 // sidsDerivationTimeout limits the amount of time we may spend trying to 33 // derive a SI token. If the hook does not get a token within this amount of 34 // time, the result is a failure. 35 sidsDerivationTimeout = 5 * time.Minute 36 37 // sidsTokenFile is the name of the file holding the Consul SI token inside 38 // the task's secret directory 39 sidsTokenFile = "si_token" 40 41 // sidsTokenFilePerms is the level of file permissions granted on the file 42 // in the secrets directory for the task 43 sidsTokenFilePerms = 0440 44 ) 45 46 type sidsHookConfig struct { 47 alloc *structs.Allocation 48 task *structs.Task 49 sidsClient consul.ServiceIdentityAPI 50 lifecycle ti.TaskLifecycle 51 logger hclog.Logger 52 } 53 54 // Service Identities hook for managing SI tokens of connect enabled tasks. 55 type sidsHook struct { 56 // alloc is the allocation 57 alloc *structs.Allocation 58 59 // taskName is the name of the task 60 task *structs.Task 61 62 // sidsClient is the Consul client [proxy] for requesting SI tokens 63 sidsClient consul.ServiceIdentityAPI 64 65 // lifecycle is used to signal, restart, and kill a task 66 lifecycle ti.TaskLifecycle 67 68 // derivationTimeout is the amount of time we may wait for Consul to successfully 69 // provide a SI token. Making this configurable for testing, otherwise 70 // default to sidsDerivationTimeout 71 derivationTimeout time.Duration 72 73 // logger is used to log 74 logger hclog.Logger 75 76 // lock variables that can be manipulated after hook creation 77 lock sync.Mutex 78 // firstRun keeps track of whether the hook is being called for the first 79 // time (for this task) during the lifespan of the Nomad Client process. 80 firstRun bool 81 } 82 83 func newSIDSHook(c sidsHookConfig) *sidsHook { 84 return &sidsHook{ 85 alloc: c.alloc, 86 task: c.task, 87 sidsClient: c.sidsClient, 88 lifecycle: c.lifecycle, 89 derivationTimeout: sidsDerivationTimeout, 90 logger: c.logger.Named(sidsHookName), 91 firstRun: true, 92 } 93 } 94 95 func (h *sidsHook) Name() string { 96 return sidsHookName 97 } 98 99 func (h *sidsHook) Prestart( 100 ctx context.Context, 101 req *interfaces.TaskPrestartRequest, 102 resp *interfaces.TaskPrestartResponse) error { 103 104 h.lock.Lock() 105 defer h.lock.Unlock() 106 107 // do nothing if we have already done things 108 if h.earlyExit() { 109 resp.Done = true 110 return nil 111 } 112 113 // optimistically try to recover token from disk 114 token, err := h.recoverToken(req.TaskDir.SecretsDir) 115 if err != nil { 116 return err 117 } 118 119 // need to ask for a new SI token & persist it to disk 120 if token == "" { 121 if token, err = h.deriveSIToken(ctx); err != nil { 122 return err 123 } 124 if err := h.writeToken(req.TaskDir.SecretsDir, token); err != nil { 125 return err 126 } 127 } 128 129 h.logger.Info("derived SI token", "task", h.task.Name, "si_task", h.task.Kind.Value()) 130 131 resp.Done = true 132 return nil 133 } 134 135 // earlyExit returns true if the Prestart hook has already been executed during 136 // the instantiation of this task runner. 137 // 138 // assumes h is locked 139 func (h *sidsHook) earlyExit() bool { 140 if h.firstRun { 141 h.firstRun = false 142 return false 143 } 144 return true 145 } 146 147 // writeToken writes token into the secrets directory for the task. 148 func (h *sidsHook) writeToken(dir string, token string) error { 149 tokenPath := filepath.Join(dir, sidsTokenFile) 150 if err := ioutil.WriteFile(tokenPath, []byte(token), sidsTokenFilePerms); err != nil { 151 return fmt.Errorf("failed to write SI token: %w", err) 152 } 153 return nil 154 } 155 156 // recoverToken returns the token saved to disk in the secrets directory for the 157 // task if it exists, or the empty string if the file does not exist. an error 158 // is returned only for some other (e.g. disk IO) error. 159 func (h *sidsHook) recoverToken(dir string) (string, error) { 160 tokenPath := filepath.Join(dir, sidsTokenFile) 161 token, err := ioutil.ReadFile(tokenPath) 162 if err != nil { 163 if !os.IsNotExist(err) { 164 h.logger.Error("failed to recover SI token", "error", err) 165 return "", fmt.Errorf("failed to recover SI token: %w", err) 166 } 167 h.logger.Trace("no pre-existing SI token to recover", "task", h.task.Name) 168 return "", nil // token file does not exist yet 169 } 170 h.logger.Trace("recovered pre-existing SI token", "task", h.task.Name) 171 return string(token), nil 172 } 173 174 // siDerivationResult is used to pass along the result of attempting to derive 175 // an SI token between the goroutine doing the derivation and its caller 176 type siDerivationResult struct { 177 token string 178 err error 179 } 180 181 // deriveSIToken spawns and waits on a goroutine which will make attempts to 182 // derive an SI token until a token is successfully created, or ctx is signaled 183 // done. 184 func (h *sidsHook) deriveSIToken(ctx context.Context) (string, error) { 185 ctx, cancel := context.WithTimeout(ctx, h.derivationTimeout) 186 defer cancel() 187 188 resultCh := make(chan siDerivationResult) 189 190 // keep trying to get the token in the background 191 go h.tryDerive(ctx, resultCh) 192 193 // wait until we get a token, or we get a signal to quit 194 for { 195 select { 196 case result := <-resultCh: 197 if result.err != nil { 198 h.logger.Error("failed to derive SI token", "error", result.err) 199 h.kill(ctx, fmt.Errorf("failed to derive SI token: %w", result.err)) 200 return "", result.err 201 } 202 return result.token, nil 203 case <-ctx.Done(): 204 return "", ctx.Err() 205 } 206 } 207 } 208 209 func (h *sidsHook) kill(ctx context.Context, reason error) { 210 if err := h.lifecycle.Kill(ctx, 211 structs.NewTaskEvent(structs.TaskKilling). 212 SetFailsTask(). 213 SetDisplayMessage(reason.Error()), 214 ); err != nil { 215 h.logger.Error("failed to kill task", "kill_reason", reason, "error", err) 216 } 217 } 218 219 // tryDerive loops forever until a token is created, or ctx is done. 220 func (h *sidsHook) tryDerive(ctx context.Context, ch chan<- siDerivationResult) { 221 for attempt := 0; backoff(ctx, attempt); attempt++ { 222 223 tokens, err := h.sidsClient.DeriveSITokens(h.alloc, []string{h.task.Name}) 224 225 switch { 226 case err == nil: 227 token, exists := tokens[h.task.Name] 228 if !exists { 229 err := errors.New("response does not include token for task") 230 h.logger.Error("derive SI token is missing token for task", "error", err, "task", h.task.Name) 231 ch <- siDerivationResult{token: "", err: err} 232 return 233 } 234 ch <- siDerivationResult{token: token, err: nil} 235 return 236 case structs.IsServerSide(err): 237 // the error is known to be a server problem, just die 238 h.logger.Error("failed to derive SI token", "error", err, "task", h.task.Name, "server_side", true) 239 ch <- siDerivationResult{token: "", err: err} 240 return 241 case !structs.IsRecoverable(err): 242 // the error is known not to be recoverable, just die 243 h.logger.Error("failed to derive SI token", "error", err, "task", h.task.Name, "recoverable", false) 244 ch <- siDerivationResult{token: "", err: err} 245 return 246 247 default: 248 // the error is marked recoverable, retry after some backoff 249 h.logger.Error("failed attempt to derive SI token", "error", err, "recoverable", true) 250 } 251 } 252 } 253 254 func backoff(ctx context.Context, attempt int) bool { 255 next := computeBackoff(attempt) 256 select { 257 case <-ctx.Done(): 258 return false 259 case <-time.After(next): 260 return true 261 } 262 } 263 264 func computeBackoff(attempt int) time.Duration { 265 switch attempt { 266 case 0: 267 return 0 268 case 1: 269 // go fast on first retry, because a unit test should be fast 270 return 100 * time.Millisecond 271 default: 272 wait := time.Duration(attempt) * sidsBackoffBaseline 273 if wait > sidsBackoffLimit { 274 wait = sidsBackoffLimit 275 } 276 return wait 277 } 278 }