github.com/bigcommerce/nomad@v0.9.3-bc/client/vaultclient/vaultclient_test.go (about)

     1  package vaultclient
     2  
import (
	"strings"
	"sync"
	"testing"
	"time"

	"github.com/hashicorp/nomad/client/config"
	"github.com/hashicorp/nomad/helper/testlog"
	"github.com/hashicorp/nomad/testutil"
	vaultapi "github.com/hashicorp/vault/api"
	vaultconsts "github.com/hashicorp/vault/helper/consts"
	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
)
    16  
    17  func TestVaultClient_TokenRenewals(t *testing.T) {
    18  	t.Parallel()
    19  	require := require.New(t)
    20  	v := testutil.NewTestVault(t)
    21  	defer v.Stop()
    22  
    23  	logger := testlog.HCLogger(t)
    24  	v.Config.ConnectionRetryIntv = 100 * time.Millisecond
    25  	v.Config.TaskTokenTTL = "4s"
    26  	c, err := NewVaultClient(v.Config, logger, nil)
    27  	if err != nil {
    28  		t.Fatalf("failed to build vault client: %v", err)
    29  	}
    30  
    31  	c.Start()
    32  	defer c.Stop()
    33  
    34  	// Sleep a little while to ensure that the renewal loop is active
    35  	time.Sleep(time.Duration(testutil.TestMultiplier()) * time.Second)
    36  
    37  	tcr := &vaultapi.TokenCreateRequest{
    38  		Policies:    []string{"foo", "bar"},
    39  		TTL:         "2s",
    40  		DisplayName: "derived-for-task",
    41  		Renewable:   new(bool),
    42  	}
    43  	*tcr.Renewable = true
    44  
    45  	num := 5
    46  	tokens := make([]string, num)
    47  	for i := 0; i < num; i++ {
    48  		c.client.SetToken(v.Config.Token)
    49  
    50  		if err := c.client.SetAddress(v.Config.Addr); err != nil {
    51  			t.Fatal(err)
    52  		}
    53  
    54  		secret, err := c.client.Auth().Token().Create(tcr)
    55  		if err != nil {
    56  			t.Fatalf("failed to create vault token: %v", err)
    57  		}
    58  
    59  		if secret == nil || secret.Auth == nil || secret.Auth.ClientToken == "" {
    60  			t.Fatal("failed to derive a wrapped vault token")
    61  		}
    62  
    63  		tokens[i] = secret.Auth.ClientToken
    64  
    65  		errCh, err := c.RenewToken(tokens[i], secret.Auth.LeaseDuration)
    66  		if err != nil {
    67  			t.Fatalf("Unexpected error: %v", err)
    68  		}
    69  
    70  		go func(errCh <-chan error) {
    71  			for {
    72  				select {
    73  				case err := <-errCh:
    74  					require.NoError(err, "unexpected error while renewing vault token")
    75  				}
    76  			}
    77  		}(errCh)
    78  	}
    79  
    80  	c.lock.Lock()
    81  	length := c.heap.Length()
    82  	c.lock.Unlock()
    83  	if length != num {
    84  		t.Fatalf("bad: heap length: expected: %d, actual: %d", num, length)
    85  	}
    86  
    87  	time.Sleep(time.Duration(testutil.TestMultiplier()) * time.Second)
    88  
    89  	for i := 0; i < num; i++ {
    90  		if err := c.StopRenewToken(tokens[i]); err != nil {
    91  			require.NoError(err)
    92  		}
    93  	}
    94  
    95  	c.lock.Lock()
    96  	length = c.heap.Length()
    97  	c.lock.Unlock()
    98  	if length != 0 {
    99  		t.Fatalf("bad: heap length: expected: 0, actual: %d", length)
   100  	}
   101  }
   102  
   103  // TestVaultClient_NamespaceSupport tests that the Vault namespace config, if present, will result in the
   104  // namespace header being set on the created Vault client.
   105  func TestVaultClient_NamespaceSupport(t *testing.T) {
   106  	t.Parallel()
   107  	require := require.New(t)
   108  	tr := true
   109  	testNs := "test-namespace"
   110  
   111  	logger := testlog.HCLogger(t)
   112  
   113  	conf := config.DefaultConfig()
   114  	conf.VaultConfig.Enabled = &tr
   115  	conf.VaultConfig.Token = "testvaulttoken"
   116  	conf.VaultConfig.Namespace = testNs
   117  	c, err := NewVaultClient(conf.VaultConfig, logger, nil)
   118  	require.NoError(err)
   119  	require.Equal(testNs, c.client.Headers().Get(vaultconsts.NamespaceHeaderName))
   120  }
   121  
   122  func TestVaultClient_Heap(t *testing.T) {
   123  	t.Parallel()
   124  	tr := true
   125  	conf := config.DefaultConfig()
   126  	conf.VaultConfig.Enabled = &tr
   127  	conf.VaultConfig.Token = "testvaulttoken"
   128  	conf.VaultConfig.TaskTokenTTL = "10s"
   129  
   130  	logger := testlog.HCLogger(t)
   131  	c, err := NewVaultClient(conf.VaultConfig, logger, nil)
   132  	if err != nil {
   133  		t.Fatal(err)
   134  	}
   135  	if c == nil {
   136  		t.Fatal("failed to create vault client")
   137  	}
   138  
   139  	now := time.Now()
   140  
   141  	renewalReq1 := &vaultClientRenewalRequest{
   142  		errCh:     make(chan error, 1),
   143  		id:        "id1",
   144  		increment: 10,
   145  	}
   146  	if err := c.heap.Push(renewalReq1, now.Add(50*time.Second)); err != nil {
   147  		t.Fatal(err)
   148  	}
   149  	if !c.isTracked("id1") {
   150  		t.Fatalf("id1 should have been tracked")
   151  	}
   152  
   153  	renewalReq2 := &vaultClientRenewalRequest{
   154  		errCh:     make(chan error, 1),
   155  		id:        "id2",
   156  		increment: 10,
   157  	}
   158  	if err := c.heap.Push(renewalReq2, now.Add(40*time.Second)); err != nil {
   159  		t.Fatal(err)
   160  	}
   161  	if !c.isTracked("id2") {
   162  		t.Fatalf("id2 should have been tracked")
   163  	}
   164  
   165  	renewalReq3 := &vaultClientRenewalRequest{
   166  		errCh:     make(chan error, 1),
   167  		id:        "id3",
   168  		increment: 10,
   169  	}
   170  	if err := c.heap.Push(renewalReq3, now.Add(60*time.Second)); err != nil {
   171  		t.Fatal(err)
   172  	}
   173  	if !c.isTracked("id3") {
   174  		t.Fatalf("id3 should have been tracked")
   175  	}
   176  
   177  	// Reading elements should yield id2, id1 and id3 in order
   178  	req, _ := c.nextRenewal()
   179  	if req != renewalReq2 {
   180  		t.Fatalf("bad: expected: %#v, actual: %#v", renewalReq2, req)
   181  	}
   182  	if err := c.heap.Update(req, now.Add(70*time.Second)); err != nil {
   183  		t.Fatal(err)
   184  	}
   185  
   186  	req, _ = c.nextRenewal()
   187  	if req != renewalReq1 {
   188  		t.Fatalf("bad: expected: %#v, actual: %#v", renewalReq1, req)
   189  	}
   190  	if err := c.heap.Update(req, now.Add(80*time.Second)); err != nil {
   191  		t.Fatal(err)
   192  	}
   193  
   194  	req, _ = c.nextRenewal()
   195  	if req != renewalReq3 {
   196  		t.Fatalf("bad: expected: %#v, actual: %#v", renewalReq3, req)
   197  	}
   198  	if err := c.heap.Update(req, now.Add(90*time.Second)); err != nil {
   199  		t.Fatal(err)
   200  	}
   201  
   202  	if err := c.StopRenewToken("id1"); err != nil {
   203  		t.Fatal(err)
   204  	}
   205  
   206  	if err := c.StopRenewToken("id2"); err != nil {
   207  		t.Fatal(err)
   208  	}
   209  
   210  	if err := c.StopRenewToken("id3"); err != nil {
   211  		t.Fatal(err)
   212  	}
   213  
   214  	if c.isTracked("id1") {
   215  		t.Fatalf("id1 should not have been tracked")
   216  	}
   217  
   218  	if c.isTracked("id1") {
   219  		t.Fatalf("id1 should not have been tracked")
   220  	}
   221  
   222  	if c.isTracked("id1") {
   223  		t.Fatalf("id1 should not have been tracked")
   224  	}
   225  
   226  }
   227  
   228  func TestVaultClient_RenewNonRenewableLease(t *testing.T) {
   229  	t.Parallel()
   230  	v := testutil.NewTestVault(t)
   231  	defer v.Stop()
   232  
   233  	logger := testlog.HCLogger(t)
   234  	v.Config.ConnectionRetryIntv = 100 * time.Millisecond
   235  	v.Config.TaskTokenTTL = "4s"
   236  	c, err := NewVaultClient(v.Config, logger, nil)
   237  	if err != nil {
   238  		t.Fatalf("failed to build vault client: %v", err)
   239  	}
   240  
   241  	c.Start()
   242  	defer c.Stop()
   243  
   244  	// Sleep a little while to ensure that the renewal loop is active
   245  	time.Sleep(time.Duration(testutil.TestMultiplier()) * time.Second)
   246  
   247  	tcr := &vaultapi.TokenCreateRequest{
   248  		Policies:    []string{"foo", "bar"},
   249  		TTL:         "2s",
   250  		DisplayName: "derived-for-task",
   251  		Renewable:   new(bool),
   252  	}
   253  
   254  	c.client.SetToken(v.Config.Token)
   255  
   256  	if err := c.client.SetAddress(v.Config.Addr); err != nil {
   257  		t.Fatal(err)
   258  	}
   259  
   260  	secret, err := c.client.Auth().Token().Create(tcr)
   261  	if err != nil {
   262  		t.Fatalf("failed to create vault token: %v", err)
   263  	}
   264  
   265  	if secret == nil || secret.Auth == nil || secret.Auth.ClientToken == "" {
   266  		t.Fatal("failed to derive a wrapped vault token")
   267  	}
   268  
   269  	_, err = c.RenewToken(secret.Auth.ClientToken, secret.Auth.LeaseDuration)
   270  	if err == nil {
   271  		t.Fatalf("expected error, got nil")
   272  	} else if !strings.Contains(err.Error(), "lease is not renewable") {
   273  		t.Fatalf("expected \"%s\" in error message, got \"%v\"", "lease is not renewable", err)
   274  	}
   275  }
   276  
   277  func TestVaultClient_RenewNonexistentLease(t *testing.T) {
   278  	t.Parallel()
   279  	v := testutil.NewTestVault(t)
   280  	defer v.Stop()
   281  
   282  	logger := testlog.HCLogger(t)
   283  	v.Config.ConnectionRetryIntv = 100 * time.Millisecond
   284  	v.Config.TaskTokenTTL = "4s"
   285  	c, err := NewVaultClient(v.Config, logger, nil)
   286  	if err != nil {
   287  		t.Fatalf("failed to build vault client: %v", err)
   288  	}
   289  
   290  	c.Start()
   291  	defer c.Stop()
   292  
   293  	// Sleep a little while to ensure that the renewal loop is active
   294  	time.Sleep(time.Duration(testutil.TestMultiplier()) * time.Second)
   295  
   296  	c.client.SetToken(v.Config.Token)
   297  
   298  	if err := c.client.SetAddress(v.Config.Addr); err != nil {
   299  		t.Fatal(err)
   300  	}
   301  
   302  	_, err = c.RenewToken(c.client.Token(), 10)
   303  	if err == nil {
   304  		t.Fatalf("expected error, got nil")
   305  		// The Vault error message changed between 0.10.2 and 1.0.1
   306  	} else if !strings.Contains(err.Error(), "lease not found") && !strings.Contains(err.Error(), "lease is not renewable") {
   307  		t.Fatalf("expected \"%s\" or \"%s\" in error message, got \"%v\"", "lease not found", "lease is not renewable", err.Error())
   308  	}
   309  }
   310  
   311  // TestVaultClient_RenewalTime_Long asserts that for leases over 1m the renewal
   312  // time is jittered.
   313  func TestVaultClient_RenewalTime_Long(t *testing.T) {
   314  	t.Parallel()
   315  
   316  	// highRoller is a randIntn func that always returns the max value
   317  	highRoller := func(n int) int {
   318  		return n - 1
   319  	}
   320  
   321  	// lowRoller is a randIntn func that always returns the min value (0)
   322  	lowRoller := func(int) int {
   323  		return 0
   324  	}
   325  
   326  	assert.Equal(t, 39*time.Second, renewalTime(highRoller, 60))
   327  	assert.Equal(t, 20*time.Second, renewalTime(lowRoller, 60))
   328  
   329  	assert.Equal(t, 309*time.Second, renewalTime(highRoller, 600))
   330  	assert.Equal(t, 290*time.Second, renewalTime(lowRoller, 600))
   331  
   332  	const days3 = 60 * 60 * 24 * 3
   333  	assert.Equal(t, (days3/2+9)*time.Second, renewalTime(highRoller, days3))
   334  	assert.Equal(t, (days3/2-10)*time.Second, renewalTime(lowRoller, days3))
   335  }
   336  
   337  // TestVaultClient_RenewalTime_Short asserts that for leases under 1m the renewal
   338  // time is lease/2.
   339  func TestVaultClient_RenewalTime_Short(t *testing.T) {
   340  	t.Parallel()
   341  
   342  	dice := func(int) int {
   343  		require.Fail(t, "dice should not have been called")
   344  		panic("unreachable")
   345  	}
   346  
   347  	assert.Equal(t, 29*time.Second, renewalTime(dice, 58))
   348  	assert.Equal(t, 15*time.Second, renewalTime(dice, 30))
   349  	assert.Equal(t, 1*time.Second, renewalTime(dice, 2))
   350  }