github.com/iqoqo/nomad@v0.11.3-0.20200911112621-d7021c74d101/client/allocrunner/taskrunner/sids_hook.go (about) 1 package taskrunner 2 3 import ( 4 "context" 5 "io/ioutil" 6 "os" 7 "path/filepath" 8 "sync" 9 "time" 10 11 "github.com/hashicorp/go-hclog" 12 "github.com/hashicorp/nomad/client/allocrunner/interfaces" 13 ti "github.com/hashicorp/nomad/client/allocrunner/taskrunner/interfaces" 14 "github.com/hashicorp/nomad/client/consul" 15 "github.com/hashicorp/nomad/nomad/structs" 16 "github.com/pkg/errors" 17 ) 18 19 const ( 20 // the name of this hook, used in logs 21 sidsHookName = "consul_si_token" 22 23 // sidsBackoffBaseline is the baseline time for exponential backoff when 24 // attempting to retrieve a Consul SI token 25 sidsBackoffBaseline = 5 * time.Second 26 27 // sidsBackoffLimit is the limit of the exponential backoff when attempting 28 // to retrieve a Consul SI token 29 sidsBackoffLimit = 3 * time.Minute 30 31 // sidsDerivationTimeout limits the amount of time we may spend trying to 32 // derive a SI token. If the hook does not get a token within this amount of 33 // time, the result is a failure. 34 sidsDerivationTimeout = 5 * time.Minute 35 36 // sidsTokenFile is the name of the file holding the Consul SI token inside 37 // the task's secret directory 38 sidsTokenFile = "si_token" 39 40 // sidsTokenFilePerms is the level of file permissions granted on the file 41 // in the secrets directory for the task 42 sidsTokenFilePerms = 0440 43 ) 44 45 type sidsHookConfig struct { 46 alloc *structs.Allocation 47 task *structs.Task 48 sidsClient consul.ServiceIdentityAPI 49 lifecycle ti.TaskLifecycle 50 logger hclog.Logger 51 } 52 53 // Service Identities hook for managing SI tokens of connect enabled tasks. 54 type sidsHook struct { 55 // alloc is the allocation 56 alloc *structs.Allocation 57 58 // taskName is the name of the task 59 task *structs.Task 60 61 // sidsClient is the Consul client [proxy] for requesting SI tokens 62 sidsClient consul.ServiceIdentityAPI 63 64 // lifecycle is used to signal, restart, and kill a task 65 lifecycle ti.TaskLifecycle 66 67 // derivationTimeout is the amount of time we may wait for Consul to successfully 68 // provide a SI token. Making this configurable for testing, otherwise 69 // default to sidsDerivationTimeout 70 derivationTimeout time.Duration 71 72 // logger is used to log 73 logger hclog.Logger 74 75 // lock variables that can be manipulated after hook creation 76 lock sync.Mutex 77 // firstRun keeps track of whether the hook is being called for the first 78 // time (for this task) during the lifespan of the Nomad Client process. 79 firstRun bool 80 } 81 82 func newSIDSHook(c sidsHookConfig) *sidsHook { 83 return &sidsHook{ 84 alloc: c.alloc, 85 task: c.task, 86 sidsClient: c.sidsClient, 87 lifecycle: c.lifecycle, 88 derivationTimeout: sidsDerivationTimeout, 89 logger: c.logger.Named(sidsHookName), 90 firstRun: true, 91 } 92 } 93 94 func (h *sidsHook) Name() string { 95 return sidsHookName 96 } 97 98 func (h *sidsHook) Prestart( 99 ctx context.Context, 100 req *interfaces.TaskPrestartRequest, 101 resp *interfaces.TaskPrestartResponse) error { 102 103 h.lock.Lock() 104 defer h.lock.Unlock() 105 106 // do nothing if we have already done things 107 if h.earlyExit() { 108 resp.Done = true 109 return nil 110 } 111 112 // optimistically try to recover token from disk 113 token, err := h.recoverToken(req.TaskDir.SecretsDir) 114 if err != nil { 115 return err 116 } 117 118 // need to ask for a new SI token & persist it to disk 119 if token == "" { 120 if token, err = h.deriveSIToken(ctx); err != nil { 121 return err 122 } 123 if err := h.writeToken(req.TaskDir.SecretsDir, token); err != nil { 124 return err 125 } 126 } 127 128 h.logger.Info("derived SI token", "task", h.task.Name, "si_task", h.task.Kind.Value()) 129 130 resp.Done = true 131 return nil 132 } 133 134 // earlyExit returns true if the Prestart hook has already been executed during 135 // the instantiation of this task runner. 136 // 137 // assumes h is locked 138 func (h *sidsHook) earlyExit() bool { 139 if h.firstRun { 140 h.firstRun = false 141 return false 142 } 143 return true 144 } 145 146 // writeToken writes token into the secrets directory for the task. 147 func (h *sidsHook) writeToken(dir string, token string) error { 148 tokenPath := filepath.Join(dir, sidsTokenFile) 149 if err := ioutil.WriteFile(tokenPath, []byte(token), sidsTokenFilePerms); err != nil { 150 return errors.Wrap(err, "failed to write SI token") 151 } 152 return nil 153 } 154 155 // recoverToken returns the token saved to disk in the secrets directory for the 156 // task if it exists, or the empty string if the file does not exist. an error 157 // is returned only for some other (e.g. disk IO) error. 158 func (h *sidsHook) recoverToken(dir string) (string, error) { 159 tokenPath := filepath.Join(dir, sidsTokenFile) 160 token, err := ioutil.ReadFile(tokenPath) 161 if err != nil { 162 if !os.IsNotExist(err) { 163 h.logger.Error("failed to recover SI token", "error", err) 164 return "", errors.Wrap(err, "failed to recover SI token") 165 } 166 h.logger.Trace("no pre-existing SI token to recover", "task", h.task.Name) 167 return "", nil // token file does not exist yet 168 } 169 h.logger.Trace("recovered pre-existing SI token", "task", h.task.Name) 170 return string(token), nil 171 } 172 173 // siDerivationResult is used to pass along the result of attempting to derive 174 // an SI token between the goroutine doing the derivation and its caller 175 type siDerivationResult struct { 176 token string 177 err error 178 } 179 180 // deriveSIToken spawns and waits on a goroutine which will make attempts to 181 // derive an SI token until a token is successfully created, or ctx is signaled 182 // done. 183 func (h *sidsHook) deriveSIToken(ctx context.Context) (string, error) { 184 ctx, cancel := context.WithTimeout(ctx, h.derivationTimeout) 185 defer cancel() 186 187 resultCh := make(chan siDerivationResult) 188 189 // keep trying to get the token in the background 190 go h.tryDerive(ctx, resultCh) 191 192 // wait until we get a token, or we get a signal to quit 193 for { 194 select { 195 case result := <-resultCh: 196 if result.err != nil { 197 h.logger.Error("failed to derive SI token", "error", result.err) 198 h.kill(ctx, errors.Wrap(result.err, "failed to derive SI token")) 199 return "", result.err 200 } 201 return result.token, nil 202 case <-ctx.Done(): 203 return "", ctx.Err() 204 } 205 } 206 } 207 208 func (h *sidsHook) kill(ctx context.Context, reason error) { 209 if err := h.lifecycle.Kill(ctx, 210 structs.NewTaskEvent(structs.TaskKilling). 211 SetFailsTask(). 212 SetDisplayMessage(reason.Error()), 213 ); err != nil { 214 h.logger.Error("failed to kill task", "kill_reason", reason, "error", err) 215 } 216 } 217 218 // tryDerive loops forever until a token is created, or ctx is done. 219 func (h *sidsHook) tryDerive(ctx context.Context, ch chan<- siDerivationResult) { 220 for attempt := 0; backoff(ctx, attempt); attempt++ { 221 222 tokens, err := h.sidsClient.DeriveSITokens(h.alloc, []string{h.task.Name}) 223 224 switch { 225 case err == nil: 226 token, exists := tokens[h.task.Name] 227 if !exists { 228 err := errors.New("response does not include token for task") 229 h.logger.Error("derive SI token is missing token for task", "error", err, "task", h.task.Name) 230 ch <- siDerivationResult{token: "", err: err} 231 return 232 } 233 ch <- siDerivationResult{token: token, err: nil} 234 return 235 case structs.IsServerSide(err): 236 // the error is known to be a server problem, just die 237 h.logger.Error("failed to derive SI token", "error", err, "task", h.task.Name, "server_side", true) 238 ch <- siDerivationResult{token: "", err: err} 239 return 240 case !structs.IsRecoverable(err): 241 // the error is known not to be recoverable, just die 242 h.logger.Error("failed to derive SI token", "error", err, "task", h.task.Name, "recoverable", false) 243 ch <- siDerivationResult{token: "", err: err} 244 return 245 246 default: 247 // the error is marked recoverable, retry after some backoff 248 h.logger.Error("failed attempt to derive SI token", "error", err, "recoverable", true) 249 } 250 } 251 } 252 253 func backoff(ctx context.Context, attempt int) bool { 254 next := computeBackoff(attempt) 255 select { 256 case <-ctx.Done(): 257 return false 258 case <-time.After(next): 259 return true 260 } 261 } 262 263 func computeBackoff(attempt int) time.Duration { 264 switch attempt { 265 case 0: 266 return 0 267 case 1: 268 // go fast on first retry, because a unit test should be fast 269 return 100 * time.Millisecond 270 default: 271 wait := time.Duration(attempt) * sidsBackoffBaseline 272 if wait > sidsBackoffLimit { 273 wait = sidsBackoffLimit 274 } 275 return wait 276 } 277 }