github.com/iqoqo/nomad@v0.11.3-0.20200911112621-d7021c74d101/e2e/vault/vault_test.go (about)

     1  package vault
     2  
     3  import (
     4  	"archive/zip"
     5  	"bytes"
     6  	"encoding/json"
     7  	"flag"
     8  	"fmt"
     9  	"io"
    10  	"io/ioutil"
    11  	"net/http"
    12  	"os"
    13  	"path/filepath"
    14  	"runtime"
    15  	"sort"
    16  	"testing"
    17  	"time"
    18  
    19  	"github.com/hashicorp/go-version"
    20  	"github.com/hashicorp/nomad/command/agent"
    21  	"github.com/hashicorp/nomad/helper"
    22  	"github.com/hashicorp/nomad/nomad/structs/config"
    23  	"github.com/hashicorp/nomad/testutil"
    24  	vapi "github.com/hashicorp/vault/api"
    25  	"github.com/stretchr/testify/require"
    26  )
    27  
    28  var (
    29  	integration = flag.Bool("integration", false, "run integration tests")
    30  	minVaultVer = version.Must(version.NewVersion("0.6.2"))
    31  )
    32  
    33  // syncVault discovers available versions of Vault, downloads the binaries,
    34  // returns a map of version to binary path as well as a sorted list of
    35  // versions.
    36  func syncVault(t *testing.T) ([]*version.Version, map[string]string) {
    37  
    38  	binDir := filepath.Join(os.TempDir(), "vault-bins/")
    39  
    40  	urls := vaultVersions(t)
    41  
    42  	sorted, versions, err := pruneVersions(urls)
    43  	require.NoError(t, err)
    44  
    45  	// Get the binaries we need to download
    46  	missing, err := missingVault(binDir, versions)
    47  	require.NoError(t, err)
    48  
    49  	// Create the directory for the binaries
    50  	require.NoError(t, createBinDir(binDir))
    51  
    52  	// Download in parallel
    53  	start := time.Now()
    54  	errCh := make(chan error, len(missing))
    55  	for ver, url := range missing {
    56  		go func(dst, url string) {
    57  			errCh <- getVault(dst, url)
    58  		}(filepath.Join(binDir, ver), url)
    59  	}
    60  	for i := 0; i < len(missing); i++ {
    61  		select {
    62  		case err := <-errCh:
    63  			require.NoError(t, err)
    64  		case <-time.After(5 * time.Minute):
    65  			t.Fatalf("timed out downloading Vault binaries")
    66  		}
    67  	}
    68  	if n := len(missing); n > 0 {
    69  		t.Logf("Downloaded %d versions of Vault in %s", n, time.Now().Sub(start))
    70  	}
    71  
    72  	binaries := make(map[string]string, len(versions))
    73  	for ver, _ := range versions {
    74  		binaries[ver] = filepath.Join(binDir, ver)
    75  	}
    76  	return sorted, binaries
    77  }
    78  
    79  // vaultVersions discovers available Vault versions from releases.hashicorp.com
    80  // and returns a map of version to url.
    81  func vaultVersions(t *testing.T) map[string]string {
    82  	resp, err := http.Get("https://releases.hashicorp.com/vault/index.json")
    83  	require.NoError(t, err)
    84  
    85  	respJson := struct {
    86  		Versions map[string]struct {
    87  			Builds []struct {
    88  				Version string `json:"version"`
    89  				Os      string `json:"os"`
    90  				Arch    string `json:"arch"`
    91  				URL     string `json:"url"`
    92  			} `json:"builds"`
    93  		}
    94  	}{}
    95  	require.NoError(t, json.NewDecoder(resp.Body).Decode(&respJson))
    96  	require.NoError(t, resp.Body.Close())
    97  
    98  	versions := map[string]string{}
    99  	for vk, vv := range respJson.Versions {
   100  		gover, err := version.NewVersion(vk)
   101  		if err != nil {
   102  			t.Logf("error parsing Vault version %q -> %v", vk, err)
   103  			continue
   104  		}
   105  
   106  		// Skip ancient versions
   107  		if gover.LessThan(minVaultVer) {
   108  			continue
   109  		}
   110  
   111  		// Skip prerelease and enterprise versions
   112  		if gover.Prerelease() != "" || gover.Metadata() != "" {
   113  			continue
   114  		}
   115  
   116  		url := ""
   117  		for _, b := range vv.Builds {
   118  			buildver, err := version.NewVersion(b.Version)
   119  			if err != nil {
   120  				t.Logf("error parsing Vault build version %q -> %v", b.Version, err)
   121  				continue
   122  			}
   123  
   124  			if buildver.Prerelease() != "" {
   125  				continue
   126  			}
   127  
   128  			if buildver.Metadata() != "" {
   129  				continue
   130  			}
   131  
   132  			if b.Os != runtime.GOOS {
   133  				continue
   134  			}
   135  
   136  			if b.Arch != runtime.GOARCH {
   137  				continue
   138  			}
   139  
   140  			// Match!
   141  			url = b.URL
   142  			break
   143  		}
   144  
   145  		if url != "" {
   146  			versions[vk] = url
   147  		}
   148  	}
   149  
   150  	return versions
   151  }
   152  
   153  // pruneVersions only takes the latest Z for each X.Y.Z release. Returns a
   154  // sorted list and map of kept versions.
   155  func pruneVersions(all map[string]string) ([]*version.Version, map[string]string, error) {
   156  	if len(all) == 0 {
   157  		return nil, nil, fmt.Errorf("0 Vault versions")
   158  	}
   159  
   160  	sorted := make([]*version.Version, 0, len(all))
   161  
   162  	for k := range all {
   163  		sorted = append(sorted, version.Must(version.NewVersion(k)))
   164  	}
   165  
   166  	sort.Sort(version.Collection(sorted))
   167  
   168  	keep := make([]*version.Version, 0, len(all))
   169  
   170  	for _, v := range sorted {
   171  		segments := v.Segments()
   172  		if len(segments) < 3 {
   173  			// Drop malformed versions
   174  			continue
   175  		}
   176  
   177  		if len(keep) == 0 {
   178  			keep = append(keep, v)
   179  			continue
   180  		}
   181  
   182  		last := keep[len(keep)-1].Segments()
   183  
   184  		if segments[0] == last[0] && segments[1] == last[1] {
   185  			// current X.Y == last X.Y, replace last with current
   186  			keep[len(keep)-1] = v
   187  		} else {
   188  			// current X.Y != last X.Y, append
   189  			keep = append(keep, v)
   190  		}
   191  	}
   192  
   193  	// Create a new map of canonicalized versions to urls
   194  	urls := make(map[string]string, len(keep))
   195  	for _, v := range keep {
   196  		origURL := all[v.Original()]
   197  		if origURL == "" {
   198  			return nil, nil, fmt.Errorf("missing version %s", v.Original())
   199  		}
   200  		urls[v.String()] = origURL
   201  	}
   202  
   203  	return keep, urls, nil
   204  }
   205  
   206  // createBinDir creates the binary directory
   207  func createBinDir(binDir string) error {
   208  	// Check if the directory exists, otherwise create it
   209  	f, err := os.Stat(binDir)
   210  	if err != nil && !os.IsNotExist(err) {
   211  		return fmt.Errorf("failed to stat directory: %v", err)
   212  	}
   213  
   214  	if f != nil && f.IsDir() {
   215  		return nil
   216  	} else if f != nil {
   217  		if err := os.RemoveAll(binDir); err != nil {
   218  			return fmt.Errorf("failed to remove file at directory path: %v", err)
   219  		}
   220  	}
   221  
   222  	// Create the directory
   223  	if err := os.Mkdir(binDir, 075); err != nil {
   224  		return fmt.Errorf("failed to make directory: %v", err)
   225  	}
   226  	if err := os.Chmod(binDir, 0755); err != nil {
   227  		return fmt.Errorf("failed to chmod: %v", err)
   228  	}
   229  
   230  	return nil
   231  }
   232  
   233  // missingVault returns the binaries that must be downloaded. versions key must
   234  // be the Vault version.
   235  func missingVault(binDir string, versions map[string]string) (map[string]string, error) {
   236  	files, err := ioutil.ReadDir(binDir)
   237  	if err != nil {
   238  		if os.IsNotExist(err) {
   239  			return versions, nil
   240  		}
   241  
   242  		return nil, fmt.Errorf("failed to stat directory: %v", err)
   243  	}
   244  
   245  	// Copy versions so we don't mutate it
   246  	missingSet := make(map[string]string, len(versions))
   247  	for k, v := range versions {
   248  		missingSet[k] = v
   249  	}
   250  
   251  	for _, f := range files {
   252  		delete(missingSet, f.Name())
   253  	}
   254  
   255  	return missingSet, nil
   256  }
   257  
   258  // getVault downloads the given Vault binary
   259  func getVault(dst, url string) error {
   260  	resp, err := http.Get(url)
   261  	if err != nil {
   262  		return err
   263  	}
   264  	defer resp.Body.Close()
   265  
   266  	// Wrap in an in-mem buffer
   267  	b := bytes.NewBuffer(nil)
   268  	if _, err := io.Copy(b, resp.Body); err != nil {
   269  		return fmt.Errorf("error reading response body: %v", err)
   270  	}
   271  	resp.Body.Close()
   272  
   273  	zreader, err := zip.NewReader(bytes.NewReader(b.Bytes()), resp.ContentLength)
   274  	if err != nil {
   275  		return err
   276  	}
   277  
   278  	if l := len(zreader.File); l != 1 {
   279  		return fmt.Errorf("unexpected number of files in zip: %v", l)
   280  	}
   281  
   282  	// Copy the file to its destination
   283  	out, err := os.OpenFile(dst, os.O_RDWR|os.O_CREATE|os.O_EXCL, 0777)
   284  	if err != nil {
   285  		return err
   286  	}
   287  	defer out.Close()
   288  
   289  	zfile, err := zreader.File[0].Open()
   290  	if err != nil {
   291  		return fmt.Errorf("failed to open zip file: %v", err)
   292  	}
   293  
   294  	if _, err := io.Copy(out, zfile); err != nil {
   295  		return fmt.Errorf("failed to decompress file to destination: %v", err)
   296  	}
   297  
   298  	return nil
   299  }
   300  
   301  // TestVaultCompatibility tests compatibility across Vault versions
   302  func TestVaultCompatibility(t *testing.T) {
   303  	if !*integration {
   304  		t.Skip("skipping test in non-integration mode: add -integration flag to run")
   305  	}
   306  
   307  	sorted, vaultBinaries := syncVault(t)
   308  
   309  	for _, v := range sorted {
   310  		ver := v.String()
   311  		bin := vaultBinaries[ver]
   312  		require.NotZerof(t, bin, "missing version: %s", ver)
   313  		t.Run(ver, func(t *testing.T) {
   314  			testVaultCompatibility(t, bin, ver)
   315  		})
   316  	}
   317  }
   318  
   319  // testVaultCompatibility tests compatibility with the given vault binary
   320  func testVaultCompatibility(t *testing.T, vault string, version string) {
   321  	require := require.New(t)
   322  
   323  	// Create a Vault server
   324  	v := testutil.NewTestVaultFromPath(t, vault)
   325  	defer v.Stop()
   326  
   327  	token := setupVault(t, v.Client, version)
   328  
   329  	// Create a Nomad agent using the created vault
   330  	nomad := agent.NewTestAgent(t, t.Name(), func(c *agent.Config) {
   331  		if c.Vault == nil {
   332  			c.Vault = &config.VaultConfig{}
   333  		}
   334  		c.Vault.Enabled = helper.BoolToPtr(true)
   335  		c.Vault.Token = token
   336  		c.Vault.Role = "nomad-cluster"
   337  		c.Vault.AllowUnauthenticated = helper.BoolToPtr(true)
   338  		c.Vault.Addr = v.HTTPAddr
   339  	})
   340  	defer nomad.Shutdown()
   341  
   342  	// Submit the Nomad job that requests a Vault token and cats that the Vault
   343  	// token is there
   344  	c := nomad.Client()
   345  	j := c.Jobs()
   346  	_, _, err := j.Register(job, nil)
   347  	require.NoError(err)
   348  
   349  	// Wait for there to be an allocation terminated successfully
   350  	//var allocID string
   351  	testutil.WaitForResult(func() (bool, error) {
   352  		// Get the allocations for the job
   353  		allocs, _, err := j.Allocations(*job.ID, false, nil)
   354  		if err != nil {
   355  			return false, err
   356  		}
   357  		l := len(allocs)
   358  		switch l {
   359  		case 0:
   360  			return false, fmt.Errorf("want one alloc; got zero")
   361  		case 1:
   362  		default:
   363  			// exit early
   364  			t.Fatalf("too many allocations; something failed")
   365  		}
   366  		alloc := allocs[0]
   367  		//allocID = alloc.ID
   368  		if alloc.ClientStatus == "complete" {
   369  			return true, nil
   370  		}
   371  
   372  		return false, fmt.Errorf("client status %q", alloc.ClientStatus)
   373  	}, func(err error) {
   374  		t.Fatalf("allocation did not finish: %v", err)
   375  	})
   376  
   377  }
   378  
   379  // setupVault takes the Vault client and creates the required policies and
   380  // roles. It returns the token that should be used by Nomad
   381  func setupVault(t *testing.T, client *vapi.Client, vaultVersion string) string {
   382  	// Write the policy
   383  	sys := client.Sys()
   384  
   385  	// pre-0.9.0 vault servers do not work with our new vault client for the policy endpoint
   386  	// perform this using a raw HTTP request
   387  	newApi := version.Must(version.NewVersion("0.9.0"))
   388  	testVersion := version.Must(version.NewVersion(vaultVersion))
   389  	if testVersion.LessThan(newApi) {
   390  		body := map[string]string{
   391  			"rules": policy,
   392  		}
   393  		request := client.NewRequest("PUT", "/v1/sys/policy/nomad-server")
   394  		if err := request.SetJSONBody(body); err != nil {
   395  			t.Fatalf("failed to set JSON body on legacy policy creation: %v", err)
   396  		}
   397  		if _, err := client.RawRequest(request); err != nil {
   398  			t.Fatalf("failed to create legacy policy: %v", err)
   399  		}
   400  	} else {
   401  		if err := sys.PutPolicy("nomad-server", policy); err != nil {
   402  			t.Fatalf("failed to create policy: %v", err)
   403  		}
   404  	}
   405  
   406  	// Build the role
   407  	l := client.Logical()
   408  	l.Write("auth/token/roles/nomad-cluster", role)
   409  
   410  	// Create a new token with the role
   411  	a := client.Auth().Token()
   412  	req := vapi.TokenCreateRequest{
   413  		Policies: []string{"nomad-server"},
   414  		Period:   "72h",
   415  		NoParent: true,
   416  	}
   417  	s, err := a.Create(&req)
   418  	if err != nil {
   419  		t.Fatalf("failed to create child token: %v", err)
   420  	}
   421  
   422  	// Get the client token
   423  	if s == nil || s.Auth == nil {
   424  		t.Fatalf("bad secret response: %+v", s)
   425  	}
   426  
   427  	return s.Auth.ClientToken
   428  }