github.com/hernad/nomad@v1.6.112/e2e/vaultcompat/vault_test.go (about)

     1  // Copyright (c) HashiCorp, Inc.
     2  // SPDX-License-Identifier: MPL-2.0
     3  
     4  package vaultcompat
     5  
     6  import (
     7  	"archive/zip"
     8  	"bytes"
     9  	"encoding/json"
    10  	"flag"
    11  	"fmt"
    12  	"io"
    13  	"net/http"
    14  	"os"
    15  	"path/filepath"
    16  	"runtime"
    17  	"sort"
    18  	"testing"
    19  	"time"
    20  
    21  	"github.com/hashicorp/go-version"
    22  	"github.com/hernad/nomad/command/agent"
    23  	"github.com/hernad/nomad/helper/pointer"
    24  	"github.com/hernad/nomad/nomad/structs/config"
    25  	"github.com/hernad/nomad/testutil"
    26  	vapi "github.com/hashicorp/vault/api"
    27  	"github.com/stretchr/testify/require"
    28  )
    29  
    30  var (
    31  	integration = flag.Bool("integration", false, "run integration tests")
    32  	minVaultVer = version.Must(version.NewVersion("0.6.2"))
    33  )
    34  
    35  // syncVault discovers available versions of Vault, downloads the binaries,
    36  // returns a map of version to binary path as well as a sorted list of
    37  // versions.
    38  func syncVault(t *testing.T) ([]*version.Version, map[string]string) {
    39  
    40  	binDir := filepath.Join(os.TempDir(), "vault-bins/")
    41  
    42  	urls := vaultVersions(t)
    43  
    44  	sorted, versions, err := pruneVersions(urls)
    45  	require.NoError(t, err)
    46  
    47  	// Get the binaries we need to download
    48  	missing, err := missingVault(binDir, versions)
    49  	require.NoError(t, err)
    50  
    51  	// Create the directory for the binaries
    52  	require.NoError(t, createBinDir(binDir))
    53  
    54  	// Download in parallel
    55  	start := time.Now()
    56  	errCh := make(chan error, len(missing))
    57  	for ver, url := range missing {
    58  		go func(dst, url string) {
    59  			errCh <- getVault(dst, url)
    60  		}(filepath.Join(binDir, ver), url)
    61  	}
    62  	for i := 0; i < len(missing); i++ {
    63  		select {
    64  		case err := <-errCh:
    65  			require.NoError(t, err)
    66  		case <-time.After(5 * time.Minute):
    67  			require.Fail(t, "timed out downloading Vault binaries")
    68  		}
    69  	}
    70  	if n := len(missing); n > 0 {
    71  		t.Logf("Downloaded %d versions of Vault in %s", n, time.Now().Sub(start))
    72  	}
    73  
    74  	binaries := make(map[string]string, len(versions))
    75  	for ver := range versions {
    76  		binaries[ver] = filepath.Join(binDir, ver)
    77  	}
    78  	return sorted, binaries
    79  }
    80  
    81  // vaultVersions discovers available Vault versions from releases.hashicorp.com
    82  // and returns a map of version to url.
    83  func vaultVersions(t *testing.T) map[string]string {
    84  	resp, err := http.Get("https://releases.hashicorp.com/vault/index.json")
    85  	require.NoError(t, err)
    86  
    87  	respJson := struct {
    88  		Versions map[string]struct {
    89  			Builds []struct {
    90  				Version string `json:"version"`
    91  				Os      string `json:"os"`
    92  				Arch    string `json:"arch"`
    93  				URL     string `json:"url"`
    94  			} `json:"builds"`
    95  		}
    96  	}{}
    97  	require.NoError(t, json.NewDecoder(resp.Body).Decode(&respJson))
    98  	require.NoError(t, resp.Body.Close())
    99  
   100  	versions := map[string]string{}
   101  	for vk, vv := range respJson.Versions {
   102  		gover, err := version.NewVersion(vk)
   103  		if err != nil {
   104  			t.Logf("error parsing Vault version %q -> %v", vk, err)
   105  			continue
   106  		}
   107  
   108  		// Skip ancient versions
   109  		if gover.LessThan(minVaultVer) {
   110  			continue
   111  		}
   112  
   113  		// Skip prerelease and enterprise versions
   114  		if gover.Prerelease() != "" || gover.Metadata() != "" {
   115  			continue
   116  		}
   117  
   118  		url := ""
   119  		for _, b := range vv.Builds {
   120  			buildver, err := version.NewVersion(b.Version)
   121  			if err != nil {
   122  				t.Logf("error parsing Vault build version %q -> %v", b.Version, err)
   123  				continue
   124  			}
   125  
   126  			if buildver.Prerelease() != "" {
   127  				continue
   128  			}
   129  
   130  			if buildver.Metadata() != "" {
   131  				continue
   132  			}
   133  
   134  			if b.Os != runtime.GOOS {
   135  				continue
   136  			}
   137  
   138  			if b.Arch != runtime.GOARCH {
   139  				continue
   140  			}
   141  
   142  			// Match!
   143  			url = b.URL
   144  			break
   145  		}
   146  
   147  		if url != "" {
   148  			versions[vk] = url
   149  		}
   150  	}
   151  
   152  	return versions
   153  }
   154  
   155  // pruneVersions only takes the latest Z for each X.Y.Z release. Returns a
   156  // sorted list and map of kept versions.
   157  func pruneVersions(all map[string]string) ([]*version.Version, map[string]string, error) {
   158  	if len(all) == 0 {
   159  		return nil, nil, fmt.Errorf("0 Vault versions")
   160  	}
   161  
   162  	sorted := make([]*version.Version, 0, len(all))
   163  
   164  	for k := range all {
   165  		sorted = append(sorted, version.Must(version.NewVersion(k)))
   166  	}
   167  
   168  	sort.Sort(version.Collection(sorted))
   169  
   170  	keep := make([]*version.Version, 0, len(all))
   171  
   172  	for _, v := range sorted {
   173  		segments := v.Segments()
   174  		if len(segments) < 3 {
   175  			// Drop malformed versions
   176  			continue
   177  		}
   178  
   179  		if len(keep) == 0 {
   180  			keep = append(keep, v)
   181  			continue
   182  		}
   183  
   184  		last := keep[len(keep)-1].Segments()
   185  
   186  		if segments[0] == last[0] && segments[1] == last[1] {
   187  			// current X.Y == last X.Y, replace last with current
   188  			keep[len(keep)-1] = v
   189  		} else {
   190  			// current X.Y != last X.Y, append
   191  			keep = append(keep, v)
   192  		}
   193  	}
   194  
   195  	// Create a new map of canonicalized versions to urls
   196  	urls := make(map[string]string, len(keep))
   197  	for _, v := range keep {
   198  		origURL := all[v.Original()]
   199  		if origURL == "" {
   200  			return nil, nil, fmt.Errorf("missing version %s", v.Original())
   201  		}
   202  		urls[v.String()] = origURL
   203  	}
   204  
   205  	return keep, urls, nil
   206  }
   207  
   208  // createBinDir creates the binary directory
   209  func createBinDir(binDir string) error {
   210  	// Check if the directory exists, otherwise create it
   211  	f, err := os.Stat(binDir)
   212  	if err != nil && !os.IsNotExist(err) {
   213  		return fmt.Errorf("failed to stat directory: %v", err)
   214  	}
   215  
   216  	if f != nil && f.IsDir() {
   217  		return nil
   218  	} else if f != nil {
   219  		if err := os.RemoveAll(binDir); err != nil {
   220  			return fmt.Errorf("failed to remove file at directory path: %v", err)
   221  		}
   222  	}
   223  
   224  	// Create the directory
   225  	if err := os.Mkdir(binDir, 075); err != nil {
   226  		return fmt.Errorf("failed to make directory: %v", err)
   227  	}
   228  	if err := os.Chmod(binDir, 0755); err != nil {
   229  		return fmt.Errorf("failed to chmod: %v", err)
   230  	}
   231  
   232  	return nil
   233  }
   234  
   235  // missingVault returns the binaries that must be downloaded. versions key must
   236  // be the Vault version.
   237  func missingVault(binDir string, versions map[string]string) (map[string]string, error) {
   238  	files, err := os.ReadDir(binDir)
   239  	if err != nil {
   240  		if os.IsNotExist(err) {
   241  			return versions, nil
   242  		}
   243  
   244  		return nil, fmt.Errorf("failed to stat directory: %v", err)
   245  	}
   246  
   247  	// Copy versions so we don't mutate it
   248  	missingSet := make(map[string]string, len(versions))
   249  	for k, v := range versions {
   250  		missingSet[k] = v
   251  	}
   252  
   253  	for _, f := range files {
   254  		delete(missingSet, f.Name())
   255  	}
   256  
   257  	return missingSet, nil
   258  }
   259  
   260  // getVault downloads the given Vault binary
   261  func getVault(dst, url string) error {
   262  	resp, err := http.Get(url)
   263  	if err != nil {
   264  		return err
   265  	}
   266  	defer resp.Body.Close()
   267  
   268  	// Wrap in an in-mem buffer
   269  	b := bytes.NewBuffer(nil)
   270  	if _, err := io.Copy(b, resp.Body); err != nil {
   271  		return fmt.Errorf("error reading response body: %v", err)
   272  	}
   273  	resp.Body.Close()
   274  
   275  	zreader, err := zip.NewReader(bytes.NewReader(b.Bytes()), resp.ContentLength)
   276  	if err != nil {
   277  		return err
   278  	}
   279  
   280  	if l := len(zreader.File); l != 1 {
   281  		return fmt.Errorf("unexpected number of files in zip: %v", l)
   282  	}
   283  
   284  	// Copy the file to its destination
   285  	out, err := os.OpenFile(dst, os.O_RDWR|os.O_CREATE|os.O_EXCL, 0777)
   286  	if err != nil {
   287  		return err
   288  	}
   289  	defer out.Close()
   290  
   291  	zfile, err := zreader.File[0].Open()
   292  	if err != nil {
   293  		return fmt.Errorf("failed to open zip file: %v", err)
   294  	}
   295  
   296  	if _, err := io.Copy(out, zfile); err != nil {
   297  		return fmt.Errorf("failed to decompress file to destination: %v", err)
   298  	}
   299  
   300  	return nil
   301  }
   302  
   303  // TestVaultCompatibility tests compatibility across Vault versions
   304  func TestVaultCompatibility(t *testing.T) {
   305  	if !*integration {
   306  		t.Skip("skipping test in non-integration mode: add -integration flag to run")
   307  	}
   308  
   309  	sorted, vaultBinaries := syncVault(t)
   310  
   311  	for _, v := range sorted {
   312  		ver := v.String()
   313  		bin := vaultBinaries[ver]
   314  		require.NotZerof(t, bin, "missing version: %s", ver)
   315  		t.Run(ver, func(t *testing.T) {
   316  			testVaultCompatibility(t, bin, ver)
   317  		})
   318  	}
   319  }
   320  
   321  // testVaultCompatibility tests compatibility with the given vault binary
   322  func testVaultCompatibility(t *testing.T, vault string, version string) {
   323  	require := require.New(t)
   324  
   325  	// Create a Vault server
   326  	v := testutil.NewTestVaultFromPath(t, vault)
   327  	defer v.Stop()
   328  
   329  	token := setupVault(t, v.Client, version)
   330  
   331  	// Create a Nomad agent using the created vault
   332  	nomad := agent.NewTestAgent(t, t.Name(), func(c *agent.Config) {
   333  		if c.Vault == nil {
   334  			c.Vault = &config.VaultConfig{}
   335  		}
   336  		c.Vault.Enabled = pointer.Of(true)
   337  		c.Vault.Token = token
   338  		c.Vault.Role = "nomad-cluster"
   339  		c.Vault.AllowUnauthenticated = pointer.Of(true)
   340  		c.Vault.Addr = v.HTTPAddr
   341  	})
   342  	defer nomad.Shutdown()
   343  
   344  	// Submit the Nomad job that requests a Vault token and cats that the Vault
   345  	// token is there
   346  	c := nomad.Client()
   347  	j := c.Jobs()
   348  	_, _, err := j.Register(job, nil)
   349  	require.NoError(err)
   350  
   351  	// Wait for there to be an allocation terminated successfully
   352  	//var allocID string
   353  	testutil.WaitForResult(func() (bool, error) {
   354  		// Get the allocations for the job
   355  		allocs, _, err := j.Allocations(*job.ID, false, nil)
   356  		if err != nil {
   357  			return false, err
   358  		}
   359  		l := len(allocs)
   360  		switch l {
   361  		case 0:
   362  			return false, fmt.Errorf("want one alloc; got zero")
   363  		case 1:
   364  		default:
   365  			// exit early
   366  			require.Fail("too many allocations; something failed")
   367  		}
   368  		alloc := allocs[0]
   369  		//allocID = alloc.ID
   370  		if alloc.ClientStatus == "complete" {
   371  			return true, nil
   372  		}
   373  
   374  		return false, fmt.Errorf("client status %q", alloc.ClientStatus)
   375  	}, func(err error) {
   376  		require.NoError(err, "allocation did not finish")
   377  	})
   378  
   379  }
   380  
   381  // setupVault takes the Vault client and creates the required policies and
   382  // roles. It returns the token that should be used by Nomad
   383  func setupVault(t *testing.T, client *vapi.Client, vaultVersion string) string {
   384  	// Write the policy
   385  	sys := client.Sys()
   386  
   387  	// pre-0.9.0 vault servers do not work with our new vault client for the policy endpoint
   388  	// perform this using a raw HTTP request
   389  	newApi := version.Must(version.NewVersion("0.9.0"))
   390  	testVersion := version.Must(version.NewVersion(vaultVersion))
   391  	if testVersion.LessThan(newApi) {
   392  		body := map[string]string{
   393  			"rules": policy,
   394  		}
   395  		request := client.NewRequest("PUT", "/v1/sys/policy/nomad-server")
   396  		if err := request.SetJSONBody(body); err != nil {
   397  			require.NoError(t, err, "failed to set JSON body on legacy policy creation")
   398  		}
   399  		if _, err := client.RawRequest(request); err != nil {
   400  			require.NoError(t, err, "failed to create legacy policy")
   401  		}
   402  	} else {
   403  		if err := sys.PutPolicy("nomad-server", policy); err != nil {
   404  			require.NoError(t, err, "failed to create policy")
   405  		}
   406  	}
   407  
   408  	// Build the role
   409  	l := client.Logical()
   410  	l.Write("auth/token/roles/nomad-cluster", role)
   411  
   412  	// Create a new token with the role
   413  	a := client.Auth().Token()
   414  	req := vapi.TokenCreateRequest{
   415  		Policies: []string{"nomad-server"},
   416  		Period:   "72h",
   417  		NoParent: true,
   418  	}
   419  	s, err := a.Create(&req)
   420  	if err != nil {
   421  		require.NoError(t, err, "failed to create child token")
   422  	}
   423  
   424  	// Get the client token
   425  	if s == nil || s.Auth == nil {
   426  		require.NoError(t, err, "bad secret response")
   427  	}
   428  
   429  	return s.Auth.ClientToken
   430  }