github.com/hashicorp/vault/sdk@v0.11.0/helper/testcluster/util.go

// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: MPL-2.0

package testcluster

import (
	"context"
	"encoding/base64"
	"encoding/hex"
	"fmt"
	"sync/atomic"
	"time"

	"github.com/hashicorp/go-multierror"
	"github.com/hashicorp/go-uuid"
	"github.com/hashicorp/vault/api"
	"github.com/hashicorp/vault/sdk/helper/xor"
)

// SealNode seals the node at nodeIdx and waits until it reports itself
// sealed. Note that OSS standbys will not accept seal requests, and Enterprise
// perf standbys may fail them as well if they haven't yet been able to get
// "elected" as perf standbys.
func SealNode(ctx context.Context, cluster VaultCluster, nodeIdx int) error {
	if nodeIdx >= len(cluster.Nodes()) {
		return fmt.Errorf("invalid nodeIdx %d for cluster", nodeIdx)
	}
	node := cluster.Nodes()[nodeIdx]
	client := node.APIClient()

	err := client.Sys().SealWithContext(ctx)
	if err != nil {
		return err
	}

	return NodeSealed(ctx, cluster, nodeIdx)
}

// SealAllNodes seals every node in the cluster, in order.
func SealAllNodes(ctx context.Context, cluster VaultCluster) error {
	for i := range cluster.Nodes() {
		if err := SealNode(ctx, cluster, i); err != nil {
			return err
		}
	}
	return nil
}

// UnsealNode submits each of the cluster's barrier (or recovery) keys to the
// node at nodeIdx and then waits for it to become healthy.
func UnsealNode(ctx context.Context, cluster VaultCluster, nodeIdx int) error {
	if nodeIdx >= len(cluster.Nodes()) {
		return fmt.Errorf("invalid nodeIdx %d for cluster", nodeIdx)
	}
	node := cluster.Nodes()[nodeIdx]
	client := node.APIClient()

	for _, key := range cluster.GetBarrierOrRecoveryKeys() {
		_, err := client.Sys().UnsealWithContext(ctx, hex.EncodeToString(key))
		if err != nil {
			return err
		}
	}

	return NodeHealthy(ctx, cluster, nodeIdx)
}

// UnsealAllNodes unseals every node in the cluster, in order.
func UnsealAllNodes(ctx context.Context, cluster VaultCluster) error {
	for i := range cluster.Nodes() {
		if err := UnsealNode(ctx, cluster, i); err != nil {
			return err
		}
	}
	return nil
}
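
// Example (a hypothetical test sketch, not part of this package): a
// seal/unseal round trip using the helpers above, assuming a VaultCluster
// value named cluster and a *testing.T named t. Per the note on SealNode,
// sealing standbys may fail on OSS clusters, so only the active node is
// sealed here.
//
//	ctx, cancel := context.WithTimeout(context.Background(), time.Minute)
//	defer cancel()
//	leaderIdx, err := testcluster.WaitForActiveNode(ctx, cluster)
//	if err != nil {
//		t.Fatal(err)
//	}
//	if err := testcluster.SealNode(ctx, cluster, leaderIdx); err != nil {
//		t.Fatal(err)
//	}
//	if err := testcluster.UnsealNode(ctx, cluster, leaderIdx); err != nil {
//		t.Fatal(err)
//	}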

// NodeSealed polls the node at nodeIdx until it reports itself sealed, or
// until ctx is done.
func NodeSealed(ctx context.Context, cluster VaultCluster, nodeIdx int) error {
	if nodeIdx >= len(cluster.Nodes()) {
		return fmt.Errorf("invalid nodeIdx %d for cluster", nodeIdx)
	}
	node := cluster.Nodes()[nodeIdx]
	client := node.APIClient()

	var health *api.HealthResponse
	var err error
	for ctx.Err() == nil {
		health, err = client.Sys().HealthWithContext(ctx)
		switch {
		case err != nil:
		case !health.Sealed:
			err = fmt.Errorf("unsealed: %#v", health)
		default:
			return nil
		}
		time.Sleep(500 * time.Millisecond)
	}
	return fmt.Errorf("node %d is not sealed: %v", nodeIdx, err)
}

// WaitForNCoresSealed blocks until at least n cores in the cluster report
// themselves sealed, or until ctx is done.
func WaitForNCoresSealed(ctx context.Context, cluster VaultCluster, n int) error {
	ctx, cancel := context.WithCancel(ctx)
	defer cancel()

	// Buffer the channel so that goroutines still polling after we've already
	// seen n sealed cores don't block forever on their final send.
	errs := make(chan error, len(cluster.Nodes()))
	for i := range cluster.Nodes() {
		go func(i int) {
			var err error
			for ctx.Err() == nil {
				err = NodeSealed(ctx, cluster, i)
				if err == nil {
					errs <- nil
					return
				}
				time.Sleep(100 * time.Millisecond)
			}
			if err == nil {
				err = ctx.Err()
			}
			errs <- err
		}(i)
	}

	var merr *multierror.Error
	var sealed int
	for range cluster.Nodes() {
		err := <-errs
		if err != nil {
			merr = multierror.Append(merr, err)
		} else {
			sealed++
			if sealed == n {
				return nil
			}
		}
	}

	return fmt.Errorf("fewer than %d cores sealed, errs: %v", n, merr.ErrorOrNil())
}

// NodeHealthy polls the node at nodeIdx until a health check succeeds and
// shows the node unsealed, or until ctx is done.
func NodeHealthy(ctx context.Context, cluster VaultCluster, nodeIdx int) error {
	if nodeIdx >= len(cluster.Nodes()) {
		return fmt.Errorf("invalid nodeIdx %d for cluster", nodeIdx)
	}
	node := cluster.Nodes()[nodeIdx]
	client := node.APIClient()

	var health *api.HealthResponse
	var err error
	for ctx.Err() == nil {
		health, err = client.Sys().HealthWithContext(ctx)
		switch {
		case err != nil:
		case health == nil:
			err = fmt.Errorf("nil response to health check")
		case health.Sealed:
			err = fmt.Errorf("sealed: %#v", health)
		default:
			return nil
		}
		time.Sleep(500 * time.Millisecond)
	}
	return fmt.Errorf("node %d is unhealthy: %v", nodeIdx, err)
}

// LeaderNode returns the index of the node that currently believes it is the
// active node. If several nodes claim to be active, the one with the most
// recent ActiveTime wins; if none do, it returns -1 and an error.
func LeaderNode(ctx context.Context, cluster VaultCluster) (int, error) {
	// Be robust to multiple nodes thinking they are active. This is possible in
	// certain network partition situations where the old leader has not yet
	// discovered that it has lost leadership. In tests this is only likely to
	// come up when we are specifically provoking it, but it's possible it could
	// happen at any point if leadership flaps or connectivity suffers transient
	// errors, etc., so be robust against it. The best solution would be to have
	// some sort of epoch like the raft term that is guaranteed to be
	// monotonically increasing through elections, however we don't have that
	// abstraction for all HABackends in general. The best we have is the
	// ActiveTime. In a distributed systems textbook this would be bad to rely
	// on due to clock sync issues etc., but for our tests it's likely fine
	// because even if we are running separate Vault containers, they are all
	// using the same hardware clock in the system.
	leaderActiveTimes := make(map[int]time.Time)
	for i, node := range cluster.Nodes() {
		client := node.APIClient()
		ctx, cancel := context.WithTimeout(ctx, 500*time.Millisecond)
		resp, err := client.Sys().LeaderWithContext(ctx)
		cancel()
		if err != nil || resp == nil || !resp.IsSelf {
			continue
		}
		leaderActiveTimes[i] = resp.ActiveTime
	}
	if len(leaderActiveTimes) == 0 {
		return -1, fmt.Errorf("no leader found")
	}
	// At least one node thinks it is active. If multiple, pick the one with the
	// most recent ActiveTime. Note if there is only one then this just returns
	// it.
	var newestLeaderIdx int
	var newestActiveTime time.Time
	for i, at := range leaderActiveTimes {
		if at.After(newestActiveTime) {
			newestActiveTime = at
			newestLeaderIdx = i
		}
	}
	return newestLeaderIdx, nil
}

// WaitForActiveNode polls the cluster until some node reports that it is the
// active node, returning that node's index.
func WaitForActiveNode(ctx context.Context, cluster VaultCluster) (int, error) {
	for ctx.Err() == nil {
		if idx, _ := LeaderNode(ctx, cluster); idx != -1 {
			return idx, nil
		}
		time.Sleep(500 * time.Millisecond)
	}
	return -1, ctx.Err()
}

// WaitForStandbyNode polls the node at nodeIdx until it reports a known
// leader address. It fails immediately if the node says it is the leader
// itself.
func WaitForStandbyNode(ctx context.Context, cluster VaultCluster, nodeIdx int) error {
	if nodeIdx >= len(cluster.Nodes()) {
		return fmt.Errorf("invalid nodeIdx %d for cluster", nodeIdx)
	}
	node := cluster.Nodes()[nodeIdx]
	client := node.APIClient()

	var err error
	for ctx.Err() == nil {
		var resp *api.LeaderResponse

		resp, err = client.Sys().LeaderWithContext(ctx)
		switch {
		case err != nil:
		case resp.IsSelf:
			return fmt.Errorf("waiting for standby but node is leader")
		case resp.LeaderAddress == "":
			err = fmt.Errorf("node doesn't know leader address")
		default:
			return nil
		}

		time.Sleep(100 * time.Millisecond)
	}
	if err == nil {
		err = ctx.Err()
	}
	return err
}

// WaitForActiveNodeAndStandbys waits for a node to become active and then for
// every other node to see it as the leader, returning the active node's index.
func WaitForActiveNodeAndStandbys(ctx context.Context, cluster VaultCluster) (int, error) {
	ctx, cancel := context.WithCancel(ctx)
	defer cancel()

	leaderIdx, err := WaitForActiveNode(ctx, cluster)
	if err != nil {
		return 0, err
	}

	if len(cluster.Nodes()) == 1 {
		return 0, nil
	}

	errs := make(chan error)
	for i := range cluster.Nodes() {
		if i == leaderIdx {
			continue
		}
		go func(i int) {
			errs <- WaitForStandbyNode(ctx, cluster, i)
		}(i)
	}

	var merr *multierror.Error
	expectedStandbys := len(cluster.Nodes()) - 1
	for i := 0; i < expectedStandbys; i++ {
		merr = multierror.Append(merr, <-errs)
	}

	return leaderIdx, merr.ErrorOrNil()
}
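
// Example (a hypothetical test sketch): after restarting or unsealing a
// multi-node cluster, wait for a leader to be elected and for every other
// node to learn the leader's address before making assertions. Assumes a
// VaultCluster value named cluster and a *testing.T named t.
//
//	ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
//	defer cancel()
//	leaderIdx, err := testcluster.WaitForActiveNodeAndStandbys(ctx, cluster)
//	if err != nil {
//		t.Fatal(err)
//	}
//	// Direct subsequent writes at the active node.
//	leaderClient := cluster.Nodes()[leaderIdx].APIClient()
//	_ = leaderClient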

// WaitForActiveNodeAndPerfStandbys mounts a temporary local KV engine on the
// active node and repeatedly writes to it until every node has either
// identified itself as the active node or, as a performance standby, seen its
// last remote WAL advance.
func WaitForActiveNodeAndPerfStandbys(ctx context.Context, cluster VaultCluster) error {
	logger := cluster.NamedLogger("WaitForActiveNodeAndPerfStandbys")
	// This WaitForActiveNode was added because after a Raft cluster is sealed
	// and then unsealed, when it comes up it may have a different leader than
	// Core0, making this helper fail.
	// A sleep before calling WaitForActiveNodeAndPerfStandbys seems to sort
	// things out, but so apparently does this. We should be able to eliminate
	// this call to WaitForActiveNode by reworking the logic in this method.
	leaderIdx, err := WaitForActiveNode(ctx, cluster)
	if err != nil {
		return err
	}

	if len(cluster.Nodes()) == 1 {
		return nil
	}

	expectedStandbys := len(cluster.Nodes()) - 1

	mountPoint, err := uuid.GenerateUUID()
	if err != nil {
		return err
	}
	leaderClient := cluster.Nodes()[leaderIdx].APIClient()

	for ctx.Err() == nil {
		err = leaderClient.Sys().MountWithContext(ctx, mountPoint, &api.MountInput{
			Type:  "kv",
			Local: true,
		})
		if err == nil {
			break
		}
		time.Sleep(1 * time.Second)
	}
	if err != nil {
		return fmt.Errorf("unable to mount KV engine: %v", err)
	}
	path := mountPoint + "/waitforactivenodeandperfstandbys"
	var standbys, actives int64
	errchan := make(chan error, len(cluster.Nodes()))
	for i := range cluster.Nodes() {
		go func(coreNo int) {
			node := cluster.Nodes()[coreNo]
			client := node.APIClient()
			val := 1
			var err error
			defer func() {
				errchan <- err
			}()

			var lastWAL uint64
			for ctx.Err() == nil {
				_, err = leaderClient.Logical().WriteWithContext(ctx, path, map[string]interface{}{
					"bar": val,
				})
				val++
				time.Sleep(250 * time.Millisecond)
				if err != nil {
					continue
				}
				var leader *api.LeaderResponse
				leader, err = client.Sys().LeaderWithContext(ctx)
				if err != nil {
					logger.Trace("waiting for core", "core", coreNo, "err", err)
					continue
				}
				switch {
				case leader.IsSelf:
					logger.Trace("waiting for core", "core", coreNo, "isLeader", true)
					atomic.AddInt64(&actives, 1)
					return
				case leader.PerfStandby && leader.PerfStandbyLastRemoteWAL > 0:
					switch {
					case lastWAL == 0:
						lastWAL = leader.PerfStandbyLastRemoteWAL
						logger.Trace("waiting for core", "core", coreNo, "lastRemoteWAL", leader.PerfStandbyLastRemoteWAL, "lastWAL", lastWAL)
					case lastWAL < leader.PerfStandbyLastRemoteWAL:
						logger.Trace("waiting for core", "core", coreNo, "lastRemoteWAL", leader.PerfStandbyLastRemoteWAL, "lastWAL", lastWAL)
						atomic.AddInt64(&standbys, 1)
						return
					}
				default:
					logger.Trace("waiting for core", "core", coreNo,
						"ha_enabled", leader.HAEnabled,
						"is_self", leader.IsSelf,
						"perf_standby", leader.PerfStandby,
						"perf_standby_remote_wal", leader.PerfStandbyLastRemoteWAL)
				}
			}
		}(i)
	}

	errs := make([]error, 0, len(cluster.Nodes()))
	for range cluster.Nodes() {
		errs = append(errs, <-errchan)
	}
	if actives != 1 || int(standbys) != expectedStandbys {
		return fmt.Errorf("expected 1 active core and %d standbys, got %d active and %d standbys, errs: %v",
			expectedStandbys, actives, standbys, errs)
	}

	for ctx.Err() == nil {
		err = leaderClient.Sys().UnmountWithContext(ctx, mountPoint)
		if err == nil {
			break
		}
		time.Sleep(time.Second)
	}
	if err != nil {
		return fmt.Errorf("unable to unmount KV engine on primary: %v", err)
	}
	return nil
}
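
// Example (a hypothetical test sketch): on an Enterprise cluster with
// performance standbys, block until the active node and every perf standby
// are serving before exercising replication-sensitive behavior. Assumes a
// VaultCluster value named cluster and a *testing.T named t.
//
//	ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
//	defer cancel()
//	if err := testcluster.WaitForActiveNodeAndPerfStandbys(ctx, cluster); err != nil {
//		t.Fatal(err)
//	}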

// GenerateRootKind selects which kind of operation token GenerateRoot
// generates.
type GenerateRootKind int

const (
	GenerateRootRegular GenerateRootKind = iota
	GenerateRootDR
	GenerateRecovery
)

// GenerateRoot runs the generate-root workflow against the cluster's first
// node using its barrier or recovery keys, and returns the decoded operation
// token.
func GenerateRoot(cluster VaultCluster, kind GenerateRootKind) (string, error) {
	// If recovery keys are supported, use those to perform root token
	// generation instead.
	keys := cluster.GetBarrierOrRecoveryKeys()

	client := cluster.Nodes()[0].APIClient()

	var err error
	var status *api.GenerateRootStatusResponse
	switch kind {
	case GenerateRootRegular:
		status, err = client.Sys().GenerateRootInit("", "")
	case GenerateRootDR:
		status, err = client.Sys().GenerateDROperationTokenInit("", "")
	case GenerateRecovery:
		status, err = client.Sys().GenerateRecoveryOperationTokenInit("", "")
	}
	if err != nil {
		return "", err
	}

	if status.Required > len(keys) {
		return "", fmt.Errorf("need more keys than we have: need %d, have %d", status.Required, len(keys))
	}

	otp := status.OTP

	for i, key := range keys {
		if i >= status.Required {
			break
		}

		strKey := base64.StdEncoding.EncodeToString(key)
		switch kind {
		case GenerateRootRegular:
			status, err = client.Sys().GenerateRootUpdate(strKey, status.Nonce)
		case GenerateRootDR:
			status, err = client.Sys().GenerateDROperationTokenUpdate(strKey, status.Nonce)
		case GenerateRecovery:
			status, err = client.Sys().GenerateRecoveryOperationTokenUpdate(strKey, status.Nonce)
		}
		if err != nil {
			return "", err
		}
	}
	if !status.Complete {
		return "", fmt.Errorf("generate root operation did not end successfully")
	}

	// The token comes back base64-encoded and XORed with the one-time password
	// from the init call; undo both to recover the plaintext token.
	tokenBytes, err := base64.RawStdEncoding.DecodeString(status.EncodedToken)
	if err != nil {
		return "", err
	}
	tokenBytes, err = xor.XORBytes(tokenBytes, []byte(otp))
	if err != nil {
		return "", err
	}
	return string(tokenBytes), nil
}
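
// Example (a hypothetical test sketch): generate a fresh root token and use
// it on a client. GenerateRoot submits the required number of keys and
// decodes the returned token by XORing it with the one-time password from
// the init call. Assumes a VaultCluster value named cluster and a *testing.T
// named t.
//
//	rootToken, err := testcluster.GenerateRoot(cluster, testcluster.GenerateRootRegular)
//	if err != nil {
//		t.Fatal(err)
//	}
//	client := cluster.Nodes()[0].APIClient()
//	client.SetToken(rootToken)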