github.com/uchennaokeke444/nomad@v0.11.8/testutil/vault.go (about)

     1  package testutil
     2  
     3  import (
     4  	"fmt"
     5  	"math/rand"
     6  	"os"
     7  	"os/exec"
     8  	"time"
     9  
    10  	"github.com/hashicorp/nomad/helper/freeport"
    11  	"github.com/hashicorp/nomad/helper/testlog"
    12  	"github.com/hashicorp/nomad/helper/uuid"
    13  	"github.com/hashicorp/nomad/nomad/structs/config"
    14  	vapi "github.com/hashicorp/vault/api"
    15  	testing "github.com/mitchellh/go-testing-interface"
    16  	"github.com/stretchr/testify/require"
    17  )
    18  
    19  // TestVault is a test helper. It uses a fork/exec model to create a test Vault
    20  // server instance in the background and can be initialized with policies, roles
    21  // and backends mounted. The test Vault instances can be used to run a unit test
    22  // and offers and easy API to tear itself down on test end. The only
    23  // prerequisite is that the Vault binary is on the $PATH.
    24  
    25  // TestVault wraps a test Vault server launched in dev mode, suitable for
    26  // testing.
    27  type TestVault struct {
    28  	cmd    *exec.Cmd
    29  	t      testing.T
    30  	waitCh chan error
    31  
    32  	// ports (if any) that are reserved through freeport that must be returned
    33  	// at the end of a test, done when Stop() is called.
    34  	ports []int
    35  
    36  	Addr      string
    37  	HTTPAddr  string
    38  	RootToken string
    39  	Config    *config.VaultConfig
    40  	Client    *vapi.Client
    41  }
    42  
    43  func NewTestVaultFromPath(t testing.T, binary string) *TestVault {
    44  	var ports []int
    45  	nextPort := func() int {
    46  		next := freeport.MustTake(1)
    47  		ports = append(ports, next...)
    48  		return next[0]
    49  	}
    50  
    51  	for i := 10; i >= 0; i-- {
    52  
    53  		port := nextPort() // collect every port for cleanup after the test
    54  
    55  		token := uuid.Generate()
    56  		bind := fmt.Sprintf("-dev-listen-address=127.0.0.1:%d", port)
    57  		http := fmt.Sprintf("http://127.0.0.1:%d", port)
    58  		root := fmt.Sprintf("-dev-root-token-id=%s", token)
    59  
    60  		cmd := exec.Command(binary, "server", "-dev", bind, root)
    61  		cmd.Stdout = testlog.NewWriter(t)
    62  		cmd.Stderr = testlog.NewWriter(t)
    63  
    64  		// Build the config
    65  		conf := vapi.DefaultConfig()
    66  		conf.Address = http
    67  
    68  		// Make the client and set the token to the root token
    69  		client, err := vapi.NewClient(conf)
    70  		if err != nil {
    71  			t.Fatalf("failed to build Vault API client: %v", err)
    72  		}
    73  		client.SetToken(token)
    74  
    75  		enable := true
    76  		tv := &TestVault{
    77  			cmd:       cmd,
    78  			t:         t,
    79  			ports:     ports,
    80  			Addr:      bind,
    81  			HTTPAddr:  http,
    82  			RootToken: token,
    83  			Client:    client,
    84  			Config: &config.VaultConfig{
    85  				Enabled: &enable,
    86  				Token:   token,
    87  				Addr:    http,
    88  			},
    89  		}
    90  
    91  		if err := tv.cmd.Start(); err != nil {
    92  			tv.t.Fatalf("failed to start vault: %v", err)
    93  		}
    94  
    95  		// Start the waiter
    96  		tv.waitCh = make(chan error, 1)
    97  		go func() {
    98  			err := tv.cmd.Wait()
    99  			tv.waitCh <- err
   100  		}()
   101  
   102  		// Ensure Vault started
   103  		var startErr error
   104  		select {
   105  		case startErr = <-tv.waitCh:
   106  		case <-time.After(time.Duration(500*TestMultiplier()) * time.Millisecond):
   107  		}
   108  
   109  		if startErr != nil && i == 0 {
   110  			t.Fatalf("failed to start vault: %v", startErr)
   111  		} else if startErr != nil {
   112  			wait := time.Duration(rand.Int31n(2000)) * time.Millisecond
   113  			time.Sleep(wait)
   114  			continue
   115  		}
   116  
   117  		waitErr := tv.waitForAPI()
   118  		if waitErr != nil && i == 0 {
   119  			t.Fatalf("failed to start vault: %v", waitErr)
   120  		} else if waitErr != nil {
   121  			wait := time.Duration(rand.Int31n(2000)) * time.Millisecond
   122  			time.Sleep(wait)
   123  			continue
   124  		}
   125  
   126  		return tv
   127  	}
   128  
   129  	return nil
   130  
   131  }
   132  
   133  // NewTestVault returns a new TestVault instance that has yet to be started
   134  func NewTestVault(t testing.T) *TestVault {
   135  	// Lookup vault from the path
   136  	return NewTestVaultFromPath(t, "vault")
   137  }
   138  
   139  // NewTestVaultDelayed returns a test Vault server that has not been started.
   140  // Start must be called and it is the callers responsibility to deal with any
   141  // port conflicts that may occur and retry accordingly.
   142  func NewTestVaultDelayed(t testing.T) *TestVault {
   143  	port := freeport.MustTake(1)[0]
   144  	token := uuid.Generate()
   145  	bind := fmt.Sprintf("-dev-listen-address=127.0.0.1:%d", port)
   146  	http := fmt.Sprintf("http://127.0.0.1:%d", port)
   147  	root := fmt.Sprintf("-dev-root-token-id=%s", token)
   148  
   149  	cmd := exec.Command("vault", "server", "-dev", bind, root)
   150  	cmd.Stdout = os.Stdout
   151  	cmd.Stderr = os.Stderr
   152  
   153  	// Build the config
   154  	conf := vapi.DefaultConfig()
   155  	conf.Address = http
   156  
   157  	// Make the client and set the token to the root token
   158  	client, err := vapi.NewClient(conf)
   159  	if err != nil {
   160  		t.Fatalf("failed to build Vault API client: %v", err)
   161  	}
   162  	client.SetToken(token)
   163  
   164  	enable := true
   165  	tv := &TestVault{
   166  		cmd:       cmd,
   167  		t:         t,
   168  		Addr:      bind,
   169  		HTTPAddr:  http,
   170  		RootToken: token,
   171  		Client:    client,
   172  		Config: &config.VaultConfig{
   173  			Enabled: &enable,
   174  			Token:   token,
   175  			Addr:    http,
   176  		},
   177  	}
   178  
   179  	return tv
   180  }
   181  
   182  // Start starts the test Vault server and waits for it to respond to its HTTP
   183  // API
   184  func (tv *TestVault) Start() error {
   185  	// Start the waiter
   186  	tv.waitCh = make(chan error, 1)
   187  
   188  	go func() {
   189  		// Must call Start and Wait in the same goroutine on Windows #5174
   190  		if err := tv.cmd.Start(); err != nil {
   191  			tv.waitCh <- err
   192  			return
   193  		}
   194  
   195  		err := tv.cmd.Wait()
   196  		tv.waitCh <- err
   197  	}()
   198  
   199  	// Ensure Vault started
   200  	select {
   201  	case err := <-tv.waitCh:
   202  		return err
   203  	case <-time.After(time.Duration(500*TestMultiplier()) * time.Millisecond):
   204  	}
   205  
   206  	return tv.waitForAPI()
   207  }
   208  
   209  // Stop stops the test Vault server
   210  func (tv *TestVault) Stop() {
   211  	defer freeport.Return(tv.ports)
   212  
   213  	if tv.cmd.Process == nil {
   214  		return
   215  	}
   216  
   217  	if err := tv.cmd.Process.Kill(); err != nil {
   218  		tv.t.Errorf("err: %s", err)
   219  	}
   220  	if tv.waitCh != nil {
   221  		select {
   222  		case <-tv.waitCh:
   223  			return
   224  		case <-time.After(1 * time.Second):
   225  			require.Fail(tv.t, "Timed out waiting for vault to terminate")
   226  		}
   227  	}
   228  }
   229  
   230  // waitForAPI waits for the Vault HTTP endpoint to start
   231  // responding. This is an indication that the agent has started.
   232  func (tv *TestVault) waitForAPI() error {
   233  	var waitErr error
   234  	WaitForResult(func() (bool, error) {
   235  		inited, err := tv.Client.Sys().InitStatus()
   236  		if err != nil {
   237  			return false, err
   238  		}
   239  		return inited, nil
   240  	}, func(err error) {
   241  		waitErr = err
   242  	})
   243  	return waitErr
   244  }
   245  
   246  // VaultVersion returns the Vault version as a string or an error if it couldn't
   247  // be determined
   248  func VaultVersion() (string, error) {
   249  	cmd := exec.Command("vault", "version")
   250  	out, err := cmd.Output()
   251  	return string(out), err
   252  }