github.com/hashicorp/vault/sdk@v0.11.0/helper/testcluster/exec.go (about)

     1  // Copyright (c) HashiCorp, Inc.
     2  // SPDX-License-Identifier: MPL-2.0
     3  
     4  package testcluster
     5  
     6  import (
     7  	"bufio"
     8  	"context"
     9  	"crypto/tls"
    10  	"fmt"
    11  	"os"
    12  	"os/exec"
    13  	"path/filepath"
    14  	"strings"
    15  	"testing"
    16  	"time"
    17  
    18  	log "github.com/hashicorp/go-hclog"
    19  	"github.com/hashicorp/vault/api"
    20  	"github.com/hashicorp/vault/sdk/helper/jsonutil"
    21  	"github.com/hashicorp/vault/sdk/helper/logging"
    22  )
    23  
// ExecDevCluster is a VaultCluster whose nodes are served by a single
// `vault server -dev` subprocess launched by NewExecDevCluster.
type ExecDevCluster struct {
	// ID is what ClusterID returns. NOTE(review): nothing in this file
	// assigns it — confirm whether callers are expected to set it.
	ID                 string
	// ClusterName is copied from ExecDevClusterOptions.ClusterName.
	ClusterName        string
	// ClusterNodes has one entry per node listed in cluster.json, named
	// core-0, core-1, ... in that order.
	ClusterNodes       []*execDevClusterNode
	// CACertPEMFile is the CA certificate path read from cluster.json.
	CACertPEMFile      string
	// barrierKeys and recoveryKeys are only populated through
	// SetBarrierKeys / SetRecoveryKeys in this file.
	barrierKeys        [][]byte
	recoveryKeys       [][]byte
	// tmpDir holds cluster.json and is passed to Vault as VAULT_DEV_TEMP_DIR.
	tmpDir             string
	// NOTE(review): clientAuthRequired is not referenced in this file.
	clientAuthRequired bool
	// rootToken comes from cluster.json and can be overridden by SetRootToken.
	rootToken          string
	// stop cancels the subprocess context and closes stopCh. It stays nil
	// until setup has created that context; Cleanup invokes it.
	stop               func()
	// stopCh is closed by stop. NOTE(review): nothing in this file reads it.
	stopCh             chan struct{}
	// Logger is opts.Logger, or a null logger when none was supplied.
	Logger             log.Logger
}
    38  
    39  func (dc *ExecDevCluster) SetRootToken(token string) {
    40  	dc.rootToken = token
    41  }
    42  
    43  func (dc *ExecDevCluster) NamedLogger(s string) log.Logger {
    44  	return dc.Logger.Named(s)
    45  }
    46  
    47  var _ VaultCluster = &ExecDevCluster{}
    48  
// ExecDevClusterOptions configures NewExecDevCluster and NewTestExecDevCluster.
type ExecDevClusterOptions struct {
	// Shared cluster options; this file consults ClusterName, Logger,
	// VaultLicense, TmpDir and NumCores from it.
	ClusterOptions
	// BinaryPath is the Vault executable to run; "vault" (looked up on PATH
	// by os/exec) when empty.
	BinaryPath string
	// this is -dev-listen-address, defaults to "127.0.0.1:8200"
	// When empty the flag is not passed at all, so the binary's own default
	// applies. NOTE(review): the "127.0.0.1:8200" default above is not set
	// here — confirm against the Vault version in use.
	BaseListenAddress string
}
    55  
    56  func NewTestExecDevCluster(t *testing.T, opts *ExecDevClusterOptions) *ExecDevCluster {
    57  	if opts == nil {
    58  		opts = &ExecDevClusterOptions{}
    59  	}
    60  	if opts.ClusterName == "" {
    61  		opts.ClusterName = strings.ReplaceAll(t.Name(), "/", "-")
    62  	}
    63  	if opts.Logger == nil {
    64  		opts.Logger = logging.NewVaultLogger(log.Trace).Named(t.Name()) // .Named("container")
    65  	}
    66  	if opts.VaultLicense == "" {
    67  		opts.VaultLicense = os.Getenv(EnvVaultLicenseCI)
    68  	}
    69  
    70  	ctx, cancel := context.WithTimeout(context.Background(), 60*time.Second)
    71  	t.Cleanup(cancel)
    72  
    73  	dc, err := NewExecDevCluster(ctx, opts)
    74  	if err != nil {
    75  		t.Fatal(err)
    76  	}
    77  	return dc
    78  }
    79  
    80  func NewExecDevCluster(ctx context.Context, opts *ExecDevClusterOptions) (*ExecDevCluster, error) {
    81  	dc := &ExecDevCluster{
    82  		ClusterName: opts.ClusterName,
    83  		stopCh:      make(chan struct{}),
    84  	}
    85  
    86  	if opts == nil {
    87  		opts = &ExecDevClusterOptions{}
    88  	}
    89  	if opts.NumCores == 0 {
    90  		opts.NumCores = 3
    91  	}
    92  	if err := dc.setupExecDevCluster(ctx, opts); err != nil {
    93  		dc.Cleanup()
    94  		return nil, err
    95  	}
    96  
    97  	return dc, nil
    98  }
    99  
   100  func (dc *ExecDevCluster) setupExecDevCluster(ctx context.Context, opts *ExecDevClusterOptions) (retErr error) {
   101  	if opts == nil {
   102  		opts = &ExecDevClusterOptions{}
   103  	}
   104  	if opts.Logger == nil {
   105  		opts.Logger = log.NewNullLogger()
   106  	}
   107  	dc.Logger = opts.Logger
   108  
   109  	if opts.TmpDir != "" {
   110  		if _, err := os.Stat(opts.TmpDir); os.IsNotExist(err) {
   111  			if err := os.MkdirAll(opts.TmpDir, 0o700); err != nil {
   112  				return err
   113  			}
   114  		}
   115  		dc.tmpDir = opts.TmpDir
   116  	} else {
   117  		tempDir, err := os.MkdirTemp("", "vault-test-cluster-")
   118  		if err != nil {
   119  			return err
   120  		}
   121  		dc.tmpDir = tempDir
   122  	}
   123  
   124  	// This context is used to stop the subprocess
   125  	execCtx, cancel := context.WithCancel(context.Background())
   126  	dc.stop = func() {
   127  		cancel()
   128  		close(dc.stopCh)
   129  	}
   130  	defer func() {
   131  		if retErr != nil {
   132  			cancel()
   133  		}
   134  	}()
   135  
   136  	bin := opts.BinaryPath
   137  	if bin == "" {
   138  		bin = "vault"
   139  	}
   140  
   141  	clusterJsonPath := filepath.Join(dc.tmpDir, "cluster.json")
   142  	args := []string{"server", "-dev", "-dev-cluster-json", clusterJsonPath}
   143  	switch {
   144  	case opts.NumCores == 3:
   145  		args = append(args, "-dev-three-node")
   146  	case opts.NumCores == 1:
   147  		args = append(args, "-dev-tls")
   148  	default:
   149  		return fmt.Errorf("NumCores=1 and NumCores=3 are the only supported options right now")
   150  	}
   151  	if opts.BaseListenAddress != "" {
   152  		args = append(args, "-dev-listen-address", opts.BaseListenAddress)
   153  	}
   154  	cmd := exec.CommandContext(execCtx, bin, args...)
   155  	cmd.Env = os.Environ()
   156  	cmd.Env = append(cmd.Env, "VAULT_LICENSE="+opts.VaultLicense)
   157  	cmd.Env = append(cmd.Env, "VAULT_LOG_FORMAT=json")
   158  	cmd.Env = append(cmd.Env, "VAULT_DEV_TEMP_DIR="+dc.tmpDir)
   159  	if opts.Logger != nil {
   160  		stdout, err := cmd.StdoutPipe()
   161  		if err != nil {
   162  			return err
   163  		}
   164  		go func() {
   165  			outlog := opts.Logger.Named("stdout")
   166  			scanner := bufio.NewScanner(stdout)
   167  			for scanner.Scan() {
   168  				outlog.Trace(scanner.Text())
   169  			}
   170  		}()
   171  		stderr, err := cmd.StderrPipe()
   172  		if err != nil {
   173  			return err
   174  		}
   175  		go func() {
   176  			errlog := opts.Logger.Named("stderr")
   177  			scanner := bufio.NewScanner(stderr)
   178  			// The default buffer is 4k, and Vault can emit bigger log lines
   179  			scanner.Buffer(make([]byte, 64*1024), bufio.MaxScanTokenSize)
   180  			for scanner.Scan() {
   181  				JSONLogNoTimestamp(errlog, scanner.Text())
   182  			}
   183  		}()
   184  	}
   185  
   186  	if err := cmd.Start(); err != nil {
   187  		return err
   188  	}
   189  
   190  	for ctx.Err() == nil {
   191  		if b, err := os.ReadFile(clusterJsonPath); err == nil && len(b) > 0 {
   192  			var clusterJson ClusterJson
   193  			if err := jsonutil.DecodeJSON(b, &clusterJson); err != nil {
   194  				continue
   195  			}
   196  			dc.CACertPEMFile = clusterJson.CACertPath
   197  			dc.rootToken = clusterJson.RootToken
   198  			for i, node := range clusterJson.Nodes {
   199  				config := api.DefaultConfig()
   200  				config.Address = node.APIAddress
   201  				err := config.ConfigureTLS(&api.TLSConfig{
   202  					CACert: clusterJson.CACertPath,
   203  				})
   204  				if err != nil {
   205  					return err
   206  				}
   207  				client, err := api.NewClient(config)
   208  				if err != nil {
   209  					return err
   210  				}
   211  				client.SetToken(dc.rootToken)
   212  				_, err = client.Sys().ListMounts()
   213  				if err != nil {
   214  					return err
   215  				}
   216  
   217  				dc.ClusterNodes = append(dc.ClusterNodes, &execDevClusterNode{
   218  					name:   fmt.Sprintf("core-%d", i),
   219  					client: client,
   220  				})
   221  			}
   222  			return nil
   223  		}
   224  		time.Sleep(500 * time.Millisecond)
   225  	}
   226  	return ctx.Err()
   227  }
   228  
// execDevClusterNode is one node of an ExecDevCluster, reached through an API
// client built during setup.
type execDevClusterNode struct {
	// name is "core-<index>", following the node order in cluster.json.
	name   string
	// client is the template client; APIClient returns fresh copies of it.
	client *api.Client
}
   233  
   234  var _ VaultClusterNode = &execDevClusterNode{}
   235  
   236  func (e *execDevClusterNode) Name() string {
   237  	return e.name
   238  }
   239  
   240  func (e *execDevClusterNode) APIClient() *api.Client {
   241  	// We clone to ensure that whenever this method is called, the caller gets
   242  	// back a pristine client, without e.g. any namespace or token changes that
   243  	// might pollute a shared client.  We clone the config instead of the
   244  	// client because (1) Client.clone propagates the replicationStateStore and
   245  	// the httpClient pointers, (2) it doesn't copy the tlsConfig at all, and
   246  	// (3) if clone returns an error, it doesn't feel as appropriate to panic
   247  	// below.  Who knows why clone might return an error?
   248  	cfg := e.client.CloneConfig()
   249  	client, err := api.NewClient(cfg)
   250  	if err != nil {
   251  		// It seems fine to panic here, since this should be the same input
   252  		// we provided to NewClient when we were setup, and we didn't panic then.
   253  		// Better not to completely ignore the error though, suppose there's a
   254  		// bug in CloneConfig?
   255  		panic(fmt.Sprintf("NewClient error on cloned config: %v", err))
   256  	}
   257  	client.SetToken(e.client.Token())
   258  	return client
   259  }
   260  
   261  func (e *execDevClusterNode) TLSConfig() *tls.Config {
   262  	return e.client.CloneConfig().TLSConfig()
   263  }
   264  
   265  func (dc *ExecDevCluster) ClusterID() string {
   266  	return dc.ID
   267  }
   268  
   269  func (dc *ExecDevCluster) Nodes() []VaultClusterNode {
   270  	ret := make([]VaultClusterNode, len(dc.ClusterNodes))
   271  	for i := range dc.ClusterNodes {
   272  		ret[i] = dc.ClusterNodes[i]
   273  	}
   274  	return ret
   275  }
   276  
   277  func (dc *ExecDevCluster) GetBarrierKeys() [][]byte {
   278  	return dc.barrierKeys
   279  }
   280  
   281  func copyKey(key []byte) []byte {
   282  	result := make([]byte, len(key))
   283  	copy(result, key)
   284  	return result
   285  }
   286  
   287  func (dc *ExecDevCluster) GetRecoveryKeys() [][]byte {
   288  	ret := make([][]byte, len(dc.recoveryKeys))
   289  	for i, k := range dc.recoveryKeys {
   290  		ret[i] = copyKey(k)
   291  	}
   292  	return ret
   293  }
   294  
   295  func (dc *ExecDevCluster) GetBarrierOrRecoveryKeys() [][]byte {
   296  	return dc.GetBarrierKeys()
   297  }
   298  
   299  func (dc *ExecDevCluster) SetBarrierKeys(keys [][]byte) {
   300  	dc.barrierKeys = make([][]byte, len(keys))
   301  	for i, k := range keys {
   302  		dc.barrierKeys[i] = copyKey(k)
   303  	}
   304  }
   305  
   306  func (dc *ExecDevCluster) SetRecoveryKeys(keys [][]byte) {
   307  	dc.recoveryKeys = make([][]byte, len(keys))
   308  	for i, k := range keys {
   309  		dc.recoveryKeys[i] = copyKey(k)
   310  	}
   311  }
   312  
   313  func (dc *ExecDevCluster) GetCACertPEMFile() string {
   314  	return dc.CACertPEMFile
   315  }
   316  
   317  func (dc *ExecDevCluster) Cleanup() {
   318  	dc.stop()
   319  }
   320  
   321  // GetRootToken returns the root token of the cluster, if set
   322  func (dc *ExecDevCluster) GetRootToken() string {
   323  	return dc.rootToken
   324  }