github.com/bigcommerce/nomad@v0.9.3-bc/e2e/vault/vault_test.go (about)

     1  package vault
     2  
     3  import (
     4  	"archive/zip"
     5  	"bytes"
     6  	"context"
     7  	"flag"
     8  	"fmt"
     9  	"io"
    10  	"io/ioutil"
    11  	"net/http"
    12  	"os"
    13  	"path/filepath"
    14  	"runtime"
    15  	"testing"
    16  	"time"
    17  
    18  	"github.com/hashicorp/go-version"
    19  
    20  	"github.com/hashicorp/nomad/command/agent"
    21  	"github.com/hashicorp/nomad/helper"
    22  	"github.com/hashicorp/nomad/nomad/structs/config"
    23  	"github.com/hashicorp/nomad/testutil"
    24  	"github.com/stretchr/testify/require"
    25  	"golang.org/x/sync/errgroup"
    26  
    27  	vapi "github.com/hashicorp/vault/api"
    28  )
    29  
    30  var integration = flag.Bool("integration", false, "run integration tests")
    31  
    32  // harness is used to retrieve the required Vault test binaries
    33  type harness struct {
    34  	t      *testing.T
    35  	binDir string
    36  	os     string
    37  	arch   string
    38  }
    39  
    40  // newHarness returns a new Vault test harness.
    41  func newHarness(t *testing.T) *harness {
    42  	return &harness{
    43  		t:      t,
    44  		binDir: filepath.Join(os.TempDir(), "vault-bins/"),
    45  		os:     runtime.GOOS,
    46  		arch:   runtime.GOARCH,
    47  	}
    48  }
    49  
    50  // reconcile retrieves the desired binaries, returning a map of version to
    51  // binary path
    52  func (h *harness) reconcile() map[string]string {
    53  	// Get the binaries we need to download
    54  	missing := h.diff()
    55  
    56  	// Create the directory for the binaries
    57  	h.createBinDir()
    58  
    59  	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute)
    60  	defer cancel()
    61  
    62  	g, _ := errgroup.WithContext(ctx)
    63  	for _, v := range missing {
    64  		version := v
    65  		g.Go(func() error {
    66  			return h.get(version)
    67  		})
    68  	}
    69  	if err := g.Wait(); err != nil {
    70  		h.t.Fatalf("failed getting versions: %v", err)
    71  	}
    72  
    73  	binaries := make(map[string]string, len(versions))
    74  	for _, v := range versions {
    75  		binaries[v] = filepath.Join(h.binDir, v)
    76  	}
    77  	return binaries
    78  }
    79  
    80  // createBinDir creates the binary directory
    81  func (h *harness) createBinDir() {
    82  	// Check if the directory exists, otherwise create it
    83  	f, err := os.Stat(h.binDir)
    84  	if err != nil && !os.IsNotExist(err) {
    85  		h.t.Fatalf("failed to stat directory: %v", err)
    86  	}
    87  
    88  	if f != nil && f.IsDir() {
    89  		return
    90  	} else if f != nil {
    91  		if err := os.RemoveAll(h.binDir); err != nil {
    92  			h.t.Fatalf("failed to remove file at directory path: %v", err)
    93  		}
    94  	}
    95  
    96  	// Create the directory
    97  	if err := os.Mkdir(h.binDir, 0700); err != nil {
    98  		h.t.Fatalf("failed to make directory: %v", err)
    99  	}
   100  	if err := os.Chmod(h.binDir, 0700); err != nil {
   101  		h.t.Fatalf("failed to chmod: %v", err)
   102  	}
   103  }
   104  
   105  // diff returns the binaries that must be downloaded
   106  func (h *harness) diff() (missing []string) {
   107  	files, err := ioutil.ReadDir(h.binDir)
   108  	if err != nil {
   109  		if os.IsNotExist(err) {
   110  			return versions
   111  		}
   112  
   113  		h.t.Fatalf("failed to stat directory: %v", err)
   114  	}
   115  
   116  	// Build the set we need
   117  	missingSet := make(map[string]struct{}, len(versions))
   118  	for _, v := range versions {
   119  		missingSet[v] = struct{}{}
   120  	}
   121  
   122  	for _, f := range files {
   123  		delete(missingSet, f.Name())
   124  	}
   125  
   126  	for k := range missingSet {
   127  		missing = append(missing, k)
   128  	}
   129  
   130  	return missing
   131  }
   132  
   133  // get retrieves the given Vault binary
   134  func (h *harness) get(version string) error {
   135  	resp, err := http.Get(
   136  		fmt.Sprintf("https://releases.hashicorp.com/vault/%s/vault_%s_%s_%s.zip",
   137  			version, version, h.os, h.arch))
   138  	if err != nil {
   139  		return err
   140  	}
   141  	defer resp.Body.Close()
   142  
   143  	// Wrap in an in-mem buffer
   144  	b := bytes.NewBuffer(nil)
   145  	io.Copy(b, resp.Body)
   146  	resp.Body.Close()
   147  
   148  	zreader, err := zip.NewReader(bytes.NewReader(b.Bytes()), resp.ContentLength)
   149  	if err != nil {
   150  		return err
   151  	}
   152  
   153  	if l := len(zreader.File); l != 1 {
   154  		return fmt.Errorf("unexpected number of files in zip: %v", l)
   155  	}
   156  
   157  	// Copy the file to its destination
   158  	file := filepath.Join(h.binDir, version)
   159  	out, err := os.OpenFile(file, os.O_RDWR|os.O_CREATE|os.O_EXCL, 0777)
   160  	if err != nil {
   161  		return err
   162  	}
   163  	defer out.Close()
   164  
   165  	zfile, err := zreader.File[0].Open()
   166  	if err != nil {
   167  		return fmt.Errorf("failed to open zip file: %v", err)
   168  	}
   169  
   170  	if _, err := io.Copy(out, zfile); err != nil {
   171  		return fmt.Errorf("failed to decompress file to destination: %v", err)
   172  	}
   173  
   174  	return nil
   175  }
   176  
   177  // TestVaultCompatibility tests compatibility across Vault versions
   178  func TestVaultCompatibility(t *testing.T) {
   179  	if !*integration {
   180  		t.Skip("skipping test in non-integration mode.")
   181  	}
   182  
   183  	h := newHarness(t)
   184  	vaultBinaries := h.reconcile()
   185  
   186  	for version, vaultBin := range vaultBinaries {
   187  		vbin := vaultBin
   188  		t.Run(version, func(t *testing.T) {
   189  			testVaultCompatibility(t, vbin, version)
   190  		})
   191  	}
   192  }
   193  
   194  // testVaultCompatibility tests compatibility with the given vault binary
   195  func testVaultCompatibility(t *testing.T, vault string, version string) {
   196  	require := require.New(t)
   197  
   198  	// Create a Vault server
   199  	v := testutil.NewTestVaultFromPath(t, vault)
   200  	defer v.Stop()
   201  
   202  	token := setupVault(t, v.Client, version)
   203  
   204  	// Create a Nomad agent using the created vault
   205  	nomad := agent.NewTestAgent(t, t.Name(), func(c *agent.Config) {
   206  		if c.Vault == nil {
   207  			c.Vault = &config.VaultConfig{}
   208  		}
   209  		c.Vault.Enabled = helper.BoolToPtr(true)
   210  		c.Vault.Token = token
   211  		c.Vault.Role = "nomad-cluster"
   212  		c.Vault.AllowUnauthenticated = helper.BoolToPtr(true)
   213  		c.Vault.Addr = v.HTTPAddr
   214  	})
   215  	defer nomad.Shutdown()
   216  
   217  	// Submit the Nomad job that requests a Vault token and cats that the Vault
   218  	// token is there
   219  	c := nomad.Client()
   220  	j := c.Jobs()
   221  	_, _, err := j.Register(job, nil)
   222  	require.NoError(err)
   223  
   224  	// Wait for there to be an allocation terminated successfully
   225  	//var allocID string
   226  	testutil.WaitForResult(func() (bool, error) {
   227  		// Get the allocations for the job
   228  		allocs, _, err := j.Allocations(*job.ID, false, nil)
   229  		if err != nil {
   230  			return false, err
   231  		}
   232  		l := len(allocs)
   233  		switch l {
   234  		case 0:
   235  			return false, fmt.Errorf("want one alloc; got zero")
   236  		case 1:
   237  		default:
   238  			// exit early
   239  			t.Fatalf("too many allocations; something failed")
   240  		}
   241  		alloc := allocs[0]
   242  		//allocID = alloc.ID
   243  		if alloc.ClientStatus == "complete" {
   244  			return true, nil
   245  		}
   246  
   247  		return false, fmt.Errorf("client status %q", alloc.ClientStatus)
   248  	}, func(err error) {
   249  		t.Fatalf("allocation did not finish: %v", err)
   250  	})
   251  
   252  }
   253  
   254  // setupVault takes the Vault client and creates the required policies and
   255  // roles. It returns the token that should be used by Nomad
   256  func setupVault(t *testing.T, client *vapi.Client, vaultVersion string) string {
   257  	// Write the policy
   258  	sys := client.Sys()
   259  
   260  	// pre-0.9.0 vault servers do not work with our new vault client for the policy endpoint
   261  	// perform this using a raw HTTP request
   262  	newApi, _ := version.NewVersion("0.9.0")
   263  	testVersion, err := version.NewVersion(vaultVersion)
   264  	if err != nil {
   265  		t.Fatalf("failed to parse test version from '%v': %v", t.Name(), err)
   266  	}
   267  	if testVersion.LessThan(newApi) {
   268  		body := map[string]string{
   269  			"rules": policy,
   270  		}
   271  		request := client.NewRequest("PUT", fmt.Sprintf("/v1/sys/policy/%s", "nomad-server"))
   272  		if err := request.SetJSONBody(body); err != nil {
   273  			t.Fatalf("failed to set JSON body on legacy policy creation: %v", err)
   274  		}
   275  		if _, err := client.RawRequest(request); err != nil {
   276  			t.Fatalf("failed to create legacy policy: %v", err)
   277  		}
   278  	} else {
   279  		if err := sys.PutPolicy("nomad-server", policy); err != nil {
   280  			t.Fatalf("failed to create policy: %v", err)
   281  		}
   282  	}
   283  
   284  	// Build the role
   285  	l := client.Logical()
   286  	l.Write("auth/token/roles/nomad-cluster", role)
   287  
   288  	// Create a new token with the role
   289  	a := client.Auth().Token()
   290  	req := vapi.TokenCreateRequest{
   291  		Policies: []string{"nomad-server"},
   292  		Period:   "72h",
   293  		NoParent: true,
   294  	}
   295  	s, err := a.Create(&req)
   296  	if err != nil {
   297  		t.Fatalf("failed to create child token: %v", err)
   298  	}
   299  
   300  	// Get the client token
   301  	if s == nil || s.Auth == nil {
   302  		t.Fatalf("bad secret response: %+v", s)
   303  	}
   304  
   305  	return s.Auth.ClientToken
   306  }