github.com/hashicorp/vault/sdk@v0.11.0/helper/testcluster/util.go

// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: MPL-2.0

package testcluster

import (
	"context"
	"encoding/base64"
	"encoding/hex"
	"fmt"
	"sync/atomic"
	"time"

	"github.com/hashicorp/go-multierror"
	"github.com/hashicorp/go-uuid"
	"github.com/hashicorp/vault/api"
	"github.com/hashicorp/vault/sdk/helper/xor"
)

// SealNode seals the node at nodeIdx and waits until it reports sealed.
// Note that OSS standbys will not accept seal requests, and ent perf standbys
// may fail as well if they haven't yet been able to get "elected" as perf
// standbys.
func SealNode(ctx context.Context, cluster VaultCluster, nodeIdx int) error {
	if nodeIdx >= len(cluster.Nodes()) {
		return fmt.Errorf("invalid nodeIdx %d for cluster", nodeIdx)
	}
	node := cluster.Nodes()[nodeIdx]
	client := node.APIClient()

	err := client.Sys().SealWithContext(ctx)
	if err != nil {
		return err
	}

	return NodeSealed(ctx, cluster, nodeIdx)
}

// SealAllNodes seals every node in the cluster, in order.
func SealAllNodes(ctx context.Context, cluster VaultCluster) error {
	for i := range cluster.Nodes() {
		if err := SealNode(ctx, cluster, i); err != nil {
			return err
		}
	}
	return nil
}

// UnsealNode submits the cluster's barrier or recovery key shares to the node
// at nodeIdx and waits for it to become healthy.
func UnsealNode(ctx context.Context, cluster VaultCluster, nodeIdx int) error {
	if nodeIdx >= len(cluster.Nodes()) {
		return fmt.Errorf("invalid nodeIdx %d for cluster", nodeIdx)
	}
	node := cluster.Nodes()[nodeIdx]
	client := node.APIClient()

	for _, key := range cluster.GetBarrierOrRecoveryKeys() {
		_, err := client.Sys().UnsealWithContext(ctx, hex.EncodeToString(key))
		if err != nil {
			return err
		}
	}

	return NodeHealthy(ctx, cluster, nodeIdx)
}

// UnsealAllNodes unseals every node in the cluster, in order.
func UnsealAllNodes(ctx context.Context, cluster VaultCluster) error {
	for i := range cluster.Nodes() {
		if err := UnsealNode(ctx, cluster, i); err != nil {
			return err
		}
	}
	return nil
}

// NodeSealed polls the health endpoint of the node at nodeIdx until it
// reports sealed, or until ctx expires.
func NodeSealed(ctx context.Context, cluster VaultCluster, nodeIdx int) error {
	if nodeIdx >= len(cluster.Nodes()) {
		return fmt.Errorf("invalid nodeIdx %d for cluster", nodeIdx)
	}
	node := cluster.Nodes()[nodeIdx]
	client := node.APIClient()

	var health *api.HealthResponse
	var err error
	for ctx.Err() == nil {
		health, err = client.Sys().HealthWithContext(ctx)
		switch {
		case err != nil:
		case !health.Sealed:
			err = fmt.Errorf("unsealed: %#v", health)
		default:
			return nil
		}
		time.Sleep(500 * time.Millisecond)
	}
	return fmt.Errorf("node %d is not sealed: %v", nodeIdx, err)
}

// WaitForNCoresSealed waits until n nodes in the cluster report sealed.
func WaitForNCoresSealed(ctx context.Context, cluster VaultCluster, n int) error {
	ctx, cancel := context.WithCancel(ctx)
	defer cancel()

	errs := make(chan error)
	for i := range cluster.Nodes() {
		go func(i int) {
			var err error
			for ctx.Err() == nil {
				err = NodeSealed(ctx, cluster, i)
				if err == nil {
					errs <- nil
					return
				}
				time.Sleep(100 * time.Millisecond)
			}
			if err == nil {
				err = ctx.Err()
			}
			errs <- err
		}(i)
	}

	var merr *multierror.Error
	var sealed int
	for range cluster.Nodes() {
		err := <-errs
		if err != nil {
			merr = multierror.Append(merr, err)
		} else {
			sealed++
			if sealed == n {
				return nil
			}
		}
	}

	return fmt.Errorf("%d cores were not sealed, errs: %v", n, merr.ErrorOrNil())
}
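// The seal/unseal helpers above are typically combined in tests roughly as
// follows. This is a hypothetical sketch, not part of the upstream helpers:
// the function name and the choice of node index are illustrative only.
func exampleSealUnsealRoundTrip(ctx context.Context, cluster VaultCluster) error {
	// Seal node 0 (assumed here, for illustration, to be the active node; on
	// OSS a standby would reject the seal request) and wait until its health
	// endpoint reports sealed.
	if err := SealNode(ctx, cluster, 0); err != nil {
		return err
	}
	// Feed the barrier/recovery key shares back in and wait for a healthy,
	// unsealed health response.
	return UnsealNode(ctx, cluster, 0)
}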
// NodeHealthy polls the health endpoint of the node at nodeIdx until it
// returns a non-nil, unsealed response, or until ctx expires.
func NodeHealthy(ctx context.Context, cluster VaultCluster, nodeIdx int) error {
	if nodeIdx >= len(cluster.Nodes()) {
		return fmt.Errorf("invalid nodeIdx %d for cluster", nodeIdx)
	}
	node := cluster.Nodes()[nodeIdx]
	client := node.APIClient()

	var health *api.HealthResponse
	var err error
	for ctx.Err() == nil {
		health, err = client.Sys().HealthWithContext(ctx)
		switch {
		case err != nil:
		case health == nil:
			err = fmt.Errorf("nil response to health check")
		case health.Sealed:
			err = fmt.Errorf("sealed: %#v", health)
		default:
			return nil
		}
		time.Sleep(500 * time.Millisecond)
	}
	return fmt.Errorf("node %d is unhealthy: %v", nodeIdx, err)
}

// LeaderNode returns the index of the node that currently believes it is the
// active node. If more than one node claims leadership, the one with the most
// recent ActiveTime wins.
func LeaderNode(ctx context.Context, cluster VaultCluster) (int, error) {
	// Be robust to multiple nodes thinking they are active. This is possible in
	// certain network partition situations where the old leader has not
	// discovered it has lost leadership yet. In tests this is only likely to come
	// up when we are specifically provoking it, but it's possible it could happen
	// at any point if leadership flaps or connectivity suffers transient errors,
	// etc., so be robust against it. The best solution would be to have some sort
	// of epoch, like the raft term, that is guaranteed to be monotonically
	// increasing through elections; however, we don't have that abstraction for
	// all HABackends in general. The best we have is the ActiveTime. In a
	// distributed-systems textbook this would be a bad thing to rely on due to
	// clock sync issues etc., but for our tests it's likely fine because even if
	// we are running separate Vault containers, they are all using the same
	// hardware clock in the system.
	leaderActiveTimes := make(map[int]time.Time)
	for i, node := range cluster.Nodes() {
		client := node.APIClient()
		ctx, cancel := context.WithTimeout(ctx, 500*time.Millisecond)
		resp, err := client.Sys().LeaderWithContext(ctx)
		cancel()
		if err != nil || resp == nil || !resp.IsSelf {
			continue
		}
		leaderActiveTimes[i] = resp.ActiveTime
	}
	if len(leaderActiveTimes) == 0 {
		return -1, fmt.Errorf("no leader found")
	}
	// At least one node thinks it is active. If multiple, pick the one with the
	// most recent ActiveTime. Note if there is only one then this just returns
	// it.
	var newestLeaderIdx int
	var newestActiveTime time.Time
	for i, at := range leaderActiveTimes {
		if at.After(newestActiveTime) {
			newestActiveTime = at
			newestLeaderIdx = i
		}
	}
	return newestLeaderIdx, nil
}
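// Hypothetical usage sketch (not part of the upstream helpers): resolve the
// current leader with LeaderNode and direct a write at it. The
// "secret/example" path and the data written are placeholders.
func exampleWriteToLeader(ctx context.Context, cluster VaultCluster) error {
	leaderIdx, err := LeaderNode(ctx, cluster)
	if err != nil {
		return err
	}
	client := cluster.Nodes()[leaderIdx].APIClient()
	_, err = client.Logical().WriteWithContext(ctx, "secret/example", map[string]interface{}{
		"foo": "bar",
	})
	return err
}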
// WaitForActiveNode polls until some node reports itself as the active node,
// returning its index.
func WaitForActiveNode(ctx context.Context, cluster VaultCluster) (int, error) {
	for ctx.Err() == nil {
		if idx, _ := LeaderNode(ctx, cluster); idx != -1 {
			return idx, nil
		}
		time.Sleep(500 * time.Millisecond)
	}
	return -1, ctx.Err()
}

// WaitForStandbyNode waits until the node at nodeIdx knows the leader's
// address without believing that it is itself the leader.
func WaitForStandbyNode(ctx context.Context, cluster VaultCluster, nodeIdx int) error {
	if nodeIdx >= len(cluster.Nodes()) {
		return fmt.Errorf("invalid nodeIdx %d for cluster", nodeIdx)
	}
	node := cluster.Nodes()[nodeIdx]
	client := node.APIClient()

	var err error
	for ctx.Err() == nil {
		var resp *api.LeaderResponse

		resp, err = client.Sys().LeaderWithContext(ctx)
		switch {
		case err != nil:
		case resp.IsSelf:
			return fmt.Errorf("waiting for standby but node is leader")
		case resp.LeaderAddress == "":
			err = fmt.Errorf("node doesn't know leader address")
		default:
			return nil
		}

		time.Sleep(100 * time.Millisecond)
	}
	if err == nil {
		err = ctx.Err()
	}
	return err
}

// WaitForActiveNodeAndStandbys waits for a leader to be elected and for every
// other node to settle as a standby, returning the leader's index.
func WaitForActiveNodeAndStandbys(ctx context.Context, cluster VaultCluster) (int, error) {
	ctx, cancel := context.WithCancel(ctx)
	defer cancel()

	leaderIdx, err := WaitForActiveNode(ctx, cluster)
	if err != nil {
		return 0, err
	}

	if len(cluster.Nodes()) == 1 {
		return 0, nil
	}

	errs := make(chan error)
	for i := range cluster.Nodes() {
		if i == leaderIdx {
			continue
		}
		go func(i int) {
			errs <- WaitForStandbyNode(ctx, cluster, i)
		}(i)
	}

	var merr *multierror.Error
	expectedStandbys := len(cluster.Nodes()) - 1
	for i := 0; i < expectedStandbys; i++ {
		merr = multierror.Append(merr, <-errs)
	}

	return leaderIdx, merr.ErrorOrNil()
}
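// Hypothetical sketch of a restart-style cycle on a multi-node cluster, built
// from the helpers above: seal every node, confirm all cores report sealed,
// unseal them, and block until a leader and the expected standbys settle.
// The timeout is an arbitrary example value, not part of the upstream helpers.
func exampleClusterRestartCycle(ctx context.Context, cluster VaultCluster) (int, error) {
	ctx, cancel := context.WithTimeout(ctx, 2*time.Minute)
	defer cancel()

	if err := SealAllNodes(ctx, cluster); err != nil {
		return -1, err
	}
	// Belt-and-braces check that every core reports sealed; SealAllNodes
	// already waits per node, so this is shown mainly to illustrate the call.
	if err := WaitForNCoresSealed(ctx, cluster, len(cluster.Nodes())); err != nil {
		return -1, err
	}
	if err := UnsealAllNodes(ctx, cluster); err != nil {
		return -1, err
	}
	return WaitForActiveNodeAndStandbys(ctx, cluster)
}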
// WaitForActiveNodeAndPerfStandbys waits for a leader and, on ent clusters,
// for every other node to act as a performance standby that has applied at
// least one new remote WAL entry. It does this by mounting a temporary local
// KV engine on the leader and writing to it until each standby's
// PerfStandbyLastRemoteWAL advances.
func WaitForActiveNodeAndPerfStandbys(ctx context.Context, cluster VaultCluster) error {
	logger := cluster.NamedLogger("WaitForActiveNodeAndPerfStandbys")
	// This WaitForActiveNode was added because after a Raft cluster is sealed
	// and then unsealed, when it comes up it may have a different leader than
	// Core0, making this helper fail.
	// A sleep before calling WaitForActiveNodeAndPerfStandbys seems to sort
	// things out, but so apparently does this. We should be able to eliminate
	// this call to WaitForActiveNode by reworking the logic in this method.
	leaderIdx, err := WaitForActiveNode(ctx, cluster)
	if err != nil {
		return err
	}

	if len(cluster.Nodes()) == 1 {
		return nil
	}

	expectedStandbys := len(cluster.Nodes()) - 1

	mountPoint, err := uuid.GenerateUUID()
	if err != nil {
		return err
	}
	leaderClient := cluster.Nodes()[leaderIdx].APIClient()

	for ctx.Err() == nil {
		err = leaderClient.Sys().MountWithContext(ctx, mountPoint, &api.MountInput{
			Type:  "kv",
			Local: true,
		})
		if err == nil {
			break
		}
		time.Sleep(1 * time.Second)
	}
	if err != nil {
		return fmt.Errorf("unable to mount KV engine: %v", err)
	}
	path := mountPoint + "/waitforactivenodeandperfstandbys"
	var standbys, actives int64
	errchan := make(chan error, len(cluster.Nodes()))
	for i := range cluster.Nodes() {
		go func(coreNo int) {
			node := cluster.Nodes()[coreNo]
			client := node.APIClient()
			val := 1
			var err error
			defer func() {
				errchan <- err
			}()

			var lastWAL uint64
			for ctx.Err() == nil {
				_, err = leaderClient.Logical().WriteWithContext(ctx, path, map[string]interface{}{
					"bar": val,
				})
				val++
				time.Sleep(250 * time.Millisecond)
				if err != nil {
					continue
				}
				var leader *api.LeaderResponse
				leader, err = client.Sys().LeaderWithContext(ctx)
				if err != nil {
					logger.Trace("waiting for core", "core", coreNo, "err", err)
					continue
				}
				switch {
				case leader.IsSelf:
					logger.Trace("waiting for core", "core", coreNo, "isLeader", true)
					atomic.AddInt64(&actives, 1)
					return
				case leader.PerfStandby && leader.PerfStandbyLastRemoteWAL > 0:
					switch {
					case lastWAL == 0:
						lastWAL = leader.PerfStandbyLastRemoteWAL
						logger.Trace("waiting for core", "core", coreNo, "lastRemoteWAL", leader.PerfStandbyLastRemoteWAL, "lastWAL", lastWAL)
					case lastWAL < leader.PerfStandbyLastRemoteWAL:
						logger.Trace("waiting for core", "core", coreNo, "lastRemoteWAL", leader.PerfStandbyLastRemoteWAL, "lastWAL", lastWAL)
						atomic.AddInt64(&standbys, 1)
						return
					}
				default:
					logger.Trace("waiting for core", "core", coreNo,
						"ha_enabled", leader.HAEnabled,
						"is_self", leader.IsSelf,
						"perf_standby", leader.PerfStandby,
						"perf_standby_remote_wal", leader.PerfStandbyLastRemoteWAL)
				}
			}
		}(i)
	}

	errs := make([]error, 0, len(cluster.Nodes()))
	for range cluster.Nodes() {
		errs = append(errs, <-errchan)
	}
	if actives != 1 || int(standbys) != expectedStandbys {
		return fmt.Errorf("expected 1 active core and %d standbys, got %d active and %d standbys, errs: %v",
			expectedStandbys, actives, standbys, errs)
	}

	for ctx.Err() == nil {
		err = leaderClient.Sys().UnmountWithContext(ctx, mountPoint)
		if err == nil {
			break
		}
		time.Sleep(time.Second)
	}
	if err != nil {
		return fmt.Errorf("unable to unmount KV engine on primary")
	}
	return nil
}
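// Hypothetical sketch (ent-oriented, not part of the upstream helpers):
// replication-sensitive test steps can be gated on
// WaitForActiveNodeAndPerfStandbys, bounded by a timeout, so that performance
// standbys have demonstrably applied fresh WAL entries before assertions run.
// The timeout value is an arbitrary example.
func examplePerfStandbyGate(ctx context.Context, cluster VaultCluster) error {
	ctx, cancel := context.WithTimeout(ctx, 5*time.Minute)
	defer cancel()
	return WaitForActiveNodeAndPerfStandbys(ctx, cluster)
}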
// GenerateRootKind selects which root-token generation flow to use.
type GenerateRootKind int

const (
	GenerateRootRegular GenerateRootKind = iota
	GenerateRootDR
	GenerateRecovery
)

// GenerateRoot generates a root token (or a DR/recovery operation token,
// depending on kind) using the cluster's barrier or recovery key shares, and
// returns the decoded token.
func GenerateRoot(cluster VaultCluster, kind GenerateRootKind) (string, error) {
	// If recovery keys are supported, use those to perform root token generation instead.
	keys := cluster.GetBarrierOrRecoveryKeys()

	client := cluster.Nodes()[0].APIClient()

	var err error
	var status *api.GenerateRootStatusResponse
	switch kind {
	case GenerateRootRegular:
		status, err = client.Sys().GenerateRootInit("", "")
	case GenerateRootDR:
		status, err = client.Sys().GenerateDROperationTokenInit("", "")
	case GenerateRecovery:
		status, err = client.Sys().GenerateRecoveryOperationTokenInit("", "")
	}
	if err != nil {
		return "", err
	}

	if status.Required > len(keys) {
		return "", fmt.Errorf("need more keys than have, need %d have %d", status.Required, len(keys))
	}

	otp := status.OTP

	for i, key := range keys {
		if i >= status.Required {
			break
		}

		strKey := base64.StdEncoding.EncodeToString(key)
		switch kind {
		case GenerateRootRegular:
			status, err = client.Sys().GenerateRootUpdate(strKey, status.Nonce)
		case GenerateRootDR:
			status, err = client.Sys().GenerateDROperationTokenUpdate(strKey, status.Nonce)
		case GenerateRecovery:
			status, err = client.Sys().GenerateRecoveryOperationTokenUpdate(strKey, status.Nonce)
		}
		if err != nil {
			return "", err
		}
	}
	if !status.Complete {
		return "", fmt.Errorf("generate root operation did not end successfully")
	}

	tokenBytes, err := base64.RawStdEncoding.DecodeString(status.EncodedToken)
	if err != nil {
		return "", err
	}
	tokenBytes, err = xor.XORBytes(tokenBytes, []byte(otp))
	if err != nil {
		return "", err
	}
	return string(tokenBytes), nil
}
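// Hypothetical usage sketch (not part of the upstream helpers): regenerate a
// root token and install it on the leader's API client. GenerateRoot combines
// the key shares with the server-supplied OTP to recover the token.
func exampleNewRootToken(ctx context.Context, cluster VaultCluster) error {
	token, err := GenerateRoot(cluster, GenerateRootRegular)
	if err != nil {
		return err
	}
	leaderIdx, err := WaitForActiveNode(ctx, cluster)
	if err != nil {
		return err
	}
	cluster.Nodes()[leaderIdx].APIClient().SetToken(token)
	return nil
}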