github.com/anth0d/nomad@v0.0.0-20221214183521-ae3a0a2cad06/client/vaultclient/vaultclient_test.go (about)

     1  package vaultclient
     2  
import (
	"strings"
	"sync"
	"testing"
	"time"

	"github.com/hashicorp/nomad/ci"
	"github.com/hashicorp/nomad/client/config"
	"github.com/hashicorp/nomad/helper/testlog"
	"github.com/hashicorp/nomad/testutil"
	vaultapi "github.com/hashicorp/vault/api"
	vaultconsts "github.com/hashicorp/vault/sdk/helper/consts"
	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
)
    17  
    18  func TestVaultClient_TokenRenewals(t *testing.T) {
    19  	ci.Parallel(t)
    20  
    21  	require := require.New(t)
    22  	v := testutil.NewTestVault(t)
    23  	defer v.Stop()
    24  
    25  	logger := testlog.HCLogger(t)
    26  	v.Config.ConnectionRetryIntv = 100 * time.Millisecond
    27  	v.Config.TaskTokenTTL = "4s"
    28  	c, err := NewVaultClient(v.Config, logger, nil)
    29  	if err != nil {
    30  		t.Fatalf("failed to build vault client: %v", err)
    31  	}
    32  
    33  	c.Start()
    34  	defer c.Stop()
    35  
    36  	// Sleep a little while to ensure that the renewal loop is active
    37  	time.Sleep(time.Duration(testutil.TestMultiplier()) * time.Second)
    38  
    39  	tcr := &vaultapi.TokenCreateRequest{
    40  		Policies:    []string{"foo", "bar"},
    41  		TTL:         "2s",
    42  		DisplayName: "derived-for-task",
    43  		Renewable:   new(bool),
    44  	}
    45  	*tcr.Renewable = true
    46  
    47  	num := 5
    48  	tokens := make([]string, num)
    49  	for i := 0; i < num; i++ {
    50  		c.client.SetToken(v.Config.Token)
    51  
    52  		if err := c.client.SetAddress(v.Config.Addr); err != nil {
    53  			t.Fatal(err)
    54  		}
    55  
    56  		secret, err := c.client.Auth().Token().Create(tcr)
    57  		if err != nil {
    58  			t.Fatalf("failed to create vault token: %v", err)
    59  		}
    60  
    61  		if secret == nil || secret.Auth == nil || secret.Auth.ClientToken == "" {
    62  			t.Fatal("failed to derive a wrapped vault token")
    63  		}
    64  
    65  		tokens[i] = secret.Auth.ClientToken
    66  
    67  		errCh, err := c.RenewToken(tokens[i], secret.Auth.LeaseDuration)
    68  		if err != nil {
    69  			t.Fatalf("Unexpected error: %v", err)
    70  		}
    71  
    72  		go func(errCh <-chan error) {
    73  			for {
    74  				select {
    75  				case err := <-errCh:
    76  					require.NoError(err, "unexpected error while renewing vault token")
    77  				}
    78  			}
    79  		}(errCh)
    80  	}
    81  
    82  	c.lock.Lock()
    83  	length := c.heap.Length()
    84  	c.lock.Unlock()
    85  	if length != num {
    86  		t.Fatalf("bad: heap length: expected: %d, actual: %d", num, length)
    87  	}
    88  
    89  	time.Sleep(time.Duration(testutil.TestMultiplier()) * time.Second)
    90  
    91  	for i := 0; i < num; i++ {
    92  		if err := c.StopRenewToken(tokens[i]); err != nil {
    93  			require.NoError(err)
    94  		}
    95  	}
    96  
    97  	c.lock.Lock()
    98  	length = c.heap.Length()
    99  	c.lock.Unlock()
   100  	if length != 0 {
   101  		t.Fatalf("bad: heap length: expected: 0, actual: %d", length)
   102  	}
   103  }
   104  
   105  // TestVaultClient_NamespaceSupport tests that the Vault namespace config, if present, will result in the
   106  // namespace header being set on the created Vault client.
   107  func TestVaultClient_NamespaceSupport(t *testing.T) {
   108  	ci.Parallel(t)
   109  
   110  	require := require.New(t)
   111  	tr := true
   112  	testNs := "test-namespace"
   113  
   114  	logger := testlog.HCLogger(t)
   115  
   116  	conf := config.DefaultConfig()
   117  	conf.VaultConfig.Enabled = &tr
   118  	conf.VaultConfig.Token = "testvaulttoken"
   119  	conf.VaultConfig.Namespace = testNs
   120  	c, err := NewVaultClient(conf.VaultConfig, logger, nil)
   121  	require.NoError(err)
   122  	require.Equal(testNs, c.client.Headers().Get(vaultconsts.NamespaceHeaderName))
   123  }
   124  
   125  func TestVaultClient_Heap(t *testing.T) {
   126  	ci.Parallel(t)
   127  
   128  	tr := true
   129  	conf := config.DefaultConfig()
   130  	conf.VaultConfig.Enabled = &tr
   131  	conf.VaultConfig.Token = "testvaulttoken"
   132  	conf.VaultConfig.TaskTokenTTL = "10s"
   133  
   134  	logger := testlog.HCLogger(t)
   135  	c, err := NewVaultClient(conf.VaultConfig, logger, nil)
   136  	if err != nil {
   137  		t.Fatal(err)
   138  	}
   139  	if c == nil {
   140  		t.Fatal("failed to create vault client")
   141  	}
   142  
   143  	now := time.Now()
   144  
   145  	renewalReq1 := &vaultClientRenewalRequest{
   146  		errCh:     make(chan error, 1),
   147  		id:        "id1",
   148  		increment: 10,
   149  	}
   150  	if err := c.heap.Push(renewalReq1, now.Add(50*time.Second)); err != nil {
   151  		t.Fatal(err)
   152  	}
   153  	if !c.isTracked("id1") {
   154  		t.Fatalf("id1 should have been tracked")
   155  	}
   156  
   157  	renewalReq2 := &vaultClientRenewalRequest{
   158  		errCh:     make(chan error, 1),
   159  		id:        "id2",
   160  		increment: 10,
   161  	}
   162  	if err := c.heap.Push(renewalReq2, now.Add(40*time.Second)); err != nil {
   163  		t.Fatal(err)
   164  	}
   165  	if !c.isTracked("id2") {
   166  		t.Fatalf("id2 should have been tracked")
   167  	}
   168  
   169  	renewalReq3 := &vaultClientRenewalRequest{
   170  		errCh:     make(chan error, 1),
   171  		id:        "id3",
   172  		increment: 10,
   173  	}
   174  	if err := c.heap.Push(renewalReq3, now.Add(60*time.Second)); err != nil {
   175  		t.Fatal(err)
   176  	}
   177  	if !c.isTracked("id3") {
   178  		t.Fatalf("id3 should have been tracked")
   179  	}
   180  
   181  	// Reading elements should yield id2, id1 and id3 in order
   182  	req, _ := c.nextRenewal()
   183  	if req != renewalReq2 {
   184  		t.Fatalf("bad: expected: %#v, actual: %#v", renewalReq2, req)
   185  	}
   186  	if err := c.heap.Update(req, now.Add(70*time.Second)); err != nil {
   187  		t.Fatal(err)
   188  	}
   189  
   190  	req, _ = c.nextRenewal()
   191  	if req != renewalReq1 {
   192  		t.Fatalf("bad: expected: %#v, actual: %#v", renewalReq1, req)
   193  	}
   194  	if err := c.heap.Update(req, now.Add(80*time.Second)); err != nil {
   195  		t.Fatal(err)
   196  	}
   197  
   198  	req, _ = c.nextRenewal()
   199  	if req != renewalReq3 {
   200  		t.Fatalf("bad: expected: %#v, actual: %#v", renewalReq3, req)
   201  	}
   202  	if err := c.heap.Update(req, now.Add(90*time.Second)); err != nil {
   203  		t.Fatal(err)
   204  	}
   205  
   206  	if err := c.StopRenewToken("id1"); err != nil {
   207  		t.Fatal(err)
   208  	}
   209  
   210  	if err := c.StopRenewToken("id2"); err != nil {
   211  		t.Fatal(err)
   212  	}
   213  
   214  	if err := c.StopRenewToken("id3"); err != nil {
   215  		t.Fatal(err)
   216  	}
   217  
   218  	if c.isTracked("id1") {
   219  		t.Fatalf("id1 should not have been tracked")
   220  	}
   221  
   222  	if c.isTracked("id1") {
   223  		t.Fatalf("id1 should not have been tracked")
   224  	}
   225  
   226  	if c.isTracked("id1") {
   227  		t.Fatalf("id1 should not have been tracked")
   228  	}
   229  
   230  }
   231  
   232  func TestVaultClient_RenewNonRenewableLease(t *testing.T) {
   233  	ci.Parallel(t)
   234  
   235  	v := testutil.NewTestVault(t)
   236  	defer v.Stop()
   237  
   238  	logger := testlog.HCLogger(t)
   239  	v.Config.ConnectionRetryIntv = 100 * time.Millisecond
   240  	v.Config.TaskTokenTTL = "4s"
   241  	c, err := NewVaultClient(v.Config, logger, nil)
   242  	if err != nil {
   243  		t.Fatalf("failed to build vault client: %v", err)
   244  	}
   245  
   246  	c.Start()
   247  	defer c.Stop()
   248  
   249  	// Sleep a little while to ensure that the renewal loop is active
   250  	time.Sleep(time.Duration(testutil.TestMultiplier()) * time.Second)
   251  
   252  	tcr := &vaultapi.TokenCreateRequest{
   253  		Policies:    []string{"foo", "bar"},
   254  		TTL:         "2s",
   255  		DisplayName: "derived-for-task",
   256  		Renewable:   new(bool),
   257  	}
   258  
   259  	c.client.SetToken(v.Config.Token)
   260  
   261  	if err := c.client.SetAddress(v.Config.Addr); err != nil {
   262  		t.Fatal(err)
   263  	}
   264  
   265  	secret, err := c.client.Auth().Token().Create(tcr)
   266  	if err != nil {
   267  		t.Fatalf("failed to create vault token: %v", err)
   268  	}
   269  
   270  	if secret == nil || secret.Auth == nil || secret.Auth.ClientToken == "" {
   271  		t.Fatal("failed to derive a wrapped vault token")
   272  	}
   273  
   274  	_, err = c.RenewToken(secret.Auth.ClientToken, secret.Auth.LeaseDuration)
   275  	if err == nil {
   276  		t.Fatalf("expected error, got nil")
   277  	} else if !strings.Contains(err.Error(), "lease is not renewable") {
   278  		t.Fatalf("expected \"%s\" in error message, got \"%v\"", "lease is not renewable", err)
   279  	}
   280  }
   281  
   282  func TestVaultClient_RenewNonexistentLease(t *testing.T) {
   283  	ci.Parallel(t)
   284  
   285  	v := testutil.NewTestVault(t)
   286  	defer v.Stop()
   287  
   288  	logger := testlog.HCLogger(t)
   289  	v.Config.ConnectionRetryIntv = 100 * time.Millisecond
   290  	v.Config.TaskTokenTTL = "4s"
   291  	c, err := NewVaultClient(v.Config, logger, nil)
   292  	if err != nil {
   293  		t.Fatalf("failed to build vault client: %v", err)
   294  	}
   295  
   296  	c.Start()
   297  	defer c.Stop()
   298  
   299  	// Sleep a little while to ensure that the renewal loop is active
   300  	time.Sleep(time.Duration(testutil.TestMultiplier()) * time.Second)
   301  
   302  	c.client.SetToken(v.Config.Token)
   303  
   304  	if err := c.client.SetAddress(v.Config.Addr); err != nil {
   305  		t.Fatal(err)
   306  	}
   307  
   308  	_, err = c.RenewToken(c.client.Token(), 10)
   309  	if err == nil {
   310  		t.Fatalf("expected error, got nil")
   311  		// The Vault error message changed between 0.10.2 and 1.0.1
   312  	} else if !strings.Contains(err.Error(), "lease not found") && !strings.Contains(err.Error(), "lease is not renewable") {
   313  		t.Fatalf("expected \"%s\" or \"%s\" in error message, got \"%v\"", "lease not found", "lease is not renewable", err.Error())
   314  	}
   315  }
   316  
   317  // TestVaultClient_RenewalTime_Long asserts that for leases over 1m the renewal
   318  // time is jittered.
   319  func TestVaultClient_RenewalTime_Long(t *testing.T) {
   320  	ci.Parallel(t)
   321  
   322  	// highRoller is a randIntn func that always returns the max value
   323  	highRoller := func(n int) int {
   324  		return n - 1
   325  	}
   326  
   327  	// lowRoller is a randIntn func that always returns the min value (0)
   328  	lowRoller := func(int) int {
   329  		return 0
   330  	}
   331  
   332  	assert.Equal(t, 39*time.Second, renewalTime(highRoller, 60))
   333  	assert.Equal(t, 20*time.Second, renewalTime(lowRoller, 60))
   334  
   335  	assert.Equal(t, 309*time.Second, renewalTime(highRoller, 600))
   336  	assert.Equal(t, 290*time.Second, renewalTime(lowRoller, 600))
   337  
   338  	const days3 = 60 * 60 * 24 * 3
   339  	assert.Equal(t, (days3/2+9)*time.Second, renewalTime(highRoller, days3))
   340  	assert.Equal(t, (days3/2-10)*time.Second, renewalTime(lowRoller, days3))
   341  }
   342  
   343  // TestVaultClient_RenewalTime_Short asserts that for leases under 1m the renewal
   344  // time is lease/2.
   345  func TestVaultClient_RenewalTime_Short(t *testing.T) {
   346  	ci.Parallel(t)
   347  
   348  	dice := func(int) int {
   349  		require.Fail(t, "dice should not have been called")
   350  		panic("unreachable")
   351  	}
   352  
   353  	assert.Equal(t, 29*time.Second, renewalTime(dice, 58))
   354  	assert.Equal(t, 15*time.Second, renewalTime(dice, 30))
   355  	assert.Equal(t, 1*time.Second, renewalTime(dice, 2))
   356  }