github.com/bigcommerce/nomad@v0.9.3-bc/client/vaultclient/vaultclient_test.go (about) 1 package vaultclient 2 3 import ( 4 "strings" 5 "testing" 6 "time" 7 8 "github.com/hashicorp/nomad/client/config" 9 "github.com/hashicorp/nomad/helper/testlog" 10 "github.com/hashicorp/nomad/testutil" 11 vaultapi "github.com/hashicorp/vault/api" 12 vaultconsts "github.com/hashicorp/vault/helper/consts" 13 "github.com/stretchr/testify/assert" 14 "github.com/stretchr/testify/require" 15 ) 16 17 func TestVaultClient_TokenRenewals(t *testing.T) { 18 t.Parallel() 19 require := require.New(t) 20 v := testutil.NewTestVault(t) 21 defer v.Stop() 22 23 logger := testlog.HCLogger(t) 24 v.Config.ConnectionRetryIntv = 100 * time.Millisecond 25 v.Config.TaskTokenTTL = "4s" 26 c, err := NewVaultClient(v.Config, logger, nil) 27 if err != nil { 28 t.Fatalf("failed to build vault client: %v", err) 29 } 30 31 c.Start() 32 defer c.Stop() 33 34 // Sleep a little while to ensure that the renewal loop is active 35 time.Sleep(time.Duration(testutil.TestMultiplier()) * time.Second) 36 37 tcr := &vaultapi.TokenCreateRequest{ 38 Policies: []string{"foo", "bar"}, 39 TTL: "2s", 40 DisplayName: "derived-for-task", 41 Renewable: new(bool), 42 } 43 *tcr.Renewable = true 44 45 num := 5 46 tokens := make([]string, num) 47 for i := 0; i < num; i++ { 48 c.client.SetToken(v.Config.Token) 49 50 if err := c.client.SetAddress(v.Config.Addr); err != nil { 51 t.Fatal(err) 52 } 53 54 secret, err := c.client.Auth().Token().Create(tcr) 55 if err != nil { 56 t.Fatalf("failed to create vault token: %v", err) 57 } 58 59 if secret == nil || secret.Auth == nil || secret.Auth.ClientToken == "" { 60 t.Fatal("failed to derive a wrapped vault token") 61 } 62 63 tokens[i] = secret.Auth.ClientToken 64 65 errCh, err := c.RenewToken(tokens[i], secret.Auth.LeaseDuration) 66 if err != nil { 67 t.Fatalf("Unexpected error: %v", err) 68 } 69 70 go func(errCh <-chan error) { 71 for { 72 select { 73 case err := <-errCh: 74 require.NoError(err, "unexpected error while renewing vault token") 75 } 76 } 77 }(errCh) 78 } 79 80 c.lock.Lock() 81 length := c.heap.Length() 82 c.lock.Unlock() 83 if length != num { 84 t.Fatalf("bad: heap length: expected: %d, actual: %d", num, length) 85 } 86 87 time.Sleep(time.Duration(testutil.TestMultiplier()) * time.Second) 88 89 for i := 0; i < num; i++ { 90 if err := c.StopRenewToken(tokens[i]); err != nil { 91 require.NoError(err) 92 } 93 } 94 95 c.lock.Lock() 96 length = c.heap.Length() 97 c.lock.Unlock() 98 if length != 0 { 99 t.Fatalf("bad: heap length: expected: 0, actual: %d", length) 100 } 101 } 102 103 // TestVaultClient_NamespaceSupport tests that the Vault namespace config, if present, will result in the 104 // namespace header being set on the created Vault client. 105 func TestVaultClient_NamespaceSupport(t *testing.T) { 106 t.Parallel() 107 require := require.New(t) 108 tr := true 109 testNs := "test-namespace" 110 111 logger := testlog.HCLogger(t) 112 113 conf := config.DefaultConfig() 114 conf.VaultConfig.Enabled = &tr 115 conf.VaultConfig.Token = "testvaulttoken" 116 conf.VaultConfig.Namespace = testNs 117 c, err := NewVaultClient(conf.VaultConfig, logger, nil) 118 require.NoError(err) 119 require.Equal(testNs, c.client.Headers().Get(vaultconsts.NamespaceHeaderName)) 120 } 121 122 func TestVaultClient_Heap(t *testing.T) { 123 t.Parallel() 124 tr := true 125 conf := config.DefaultConfig() 126 conf.VaultConfig.Enabled = &tr 127 conf.VaultConfig.Token = "testvaulttoken" 128 conf.VaultConfig.TaskTokenTTL = "10s" 129 130 logger := testlog.HCLogger(t) 131 c, err := NewVaultClient(conf.VaultConfig, logger, nil) 132 if err != nil { 133 t.Fatal(err) 134 } 135 if c == nil { 136 t.Fatal("failed to create vault client") 137 } 138 139 now := time.Now() 140 141 renewalReq1 := &vaultClientRenewalRequest{ 142 errCh: make(chan error, 1), 143 id: "id1", 144 increment: 10, 145 } 146 if err := c.heap.Push(renewalReq1, now.Add(50*time.Second)); err != nil { 147 t.Fatal(err) 148 } 149 if !c.isTracked("id1") { 150 t.Fatalf("id1 should have been tracked") 151 } 152 153 renewalReq2 := &vaultClientRenewalRequest{ 154 errCh: make(chan error, 1), 155 id: "id2", 156 increment: 10, 157 } 158 if err := c.heap.Push(renewalReq2, now.Add(40*time.Second)); err != nil { 159 t.Fatal(err) 160 } 161 if !c.isTracked("id2") { 162 t.Fatalf("id2 should have been tracked") 163 } 164 165 renewalReq3 := &vaultClientRenewalRequest{ 166 errCh: make(chan error, 1), 167 id: "id3", 168 increment: 10, 169 } 170 if err := c.heap.Push(renewalReq3, now.Add(60*time.Second)); err != nil { 171 t.Fatal(err) 172 } 173 if !c.isTracked("id3") { 174 t.Fatalf("id3 should have been tracked") 175 } 176 177 // Reading elements should yield id2, id1 and id3 in order 178 req, _ := c.nextRenewal() 179 if req != renewalReq2 { 180 t.Fatalf("bad: expected: %#v, actual: %#v", renewalReq2, req) 181 } 182 if err := c.heap.Update(req, now.Add(70*time.Second)); err != nil { 183 t.Fatal(err) 184 } 185 186 req, _ = c.nextRenewal() 187 if req != renewalReq1 { 188 t.Fatalf("bad: expected: %#v, actual: %#v", renewalReq1, req) 189 } 190 if err := c.heap.Update(req, now.Add(80*time.Second)); err != nil { 191 t.Fatal(err) 192 } 193 194 req, _ = c.nextRenewal() 195 if req != renewalReq3 { 196 t.Fatalf("bad: expected: %#v, actual: %#v", renewalReq3, req) 197 } 198 if err := c.heap.Update(req, now.Add(90*time.Second)); err != nil { 199 t.Fatal(err) 200 } 201 202 if err := c.StopRenewToken("id1"); err != nil { 203 t.Fatal(err) 204 } 205 206 if err := c.StopRenewToken("id2"); err != nil { 207 t.Fatal(err) 208 } 209 210 if err := c.StopRenewToken("id3"); err != nil { 211 t.Fatal(err) 212 } 213 214 if c.isTracked("id1") { 215 t.Fatalf("id1 should not have been tracked") 216 } 217 218 if c.isTracked("id1") { 219 t.Fatalf("id1 should not have been tracked") 220 } 221 222 if c.isTracked("id1") { 223 t.Fatalf("id1 should not have been tracked") 224 } 225 226 } 227 228 func TestVaultClient_RenewNonRenewableLease(t *testing.T) { 229 t.Parallel() 230 v := testutil.NewTestVault(t) 231 defer v.Stop() 232 233 logger := testlog.HCLogger(t) 234 v.Config.ConnectionRetryIntv = 100 * time.Millisecond 235 v.Config.TaskTokenTTL = "4s" 236 c, err := NewVaultClient(v.Config, logger, nil) 237 if err != nil { 238 t.Fatalf("failed to build vault client: %v", err) 239 } 240 241 c.Start() 242 defer c.Stop() 243 244 // Sleep a little while to ensure that the renewal loop is active 245 time.Sleep(time.Duration(testutil.TestMultiplier()) * time.Second) 246 247 tcr := &vaultapi.TokenCreateRequest{ 248 Policies: []string{"foo", "bar"}, 249 TTL: "2s", 250 DisplayName: "derived-for-task", 251 Renewable: new(bool), 252 } 253 254 c.client.SetToken(v.Config.Token) 255 256 if err := c.client.SetAddress(v.Config.Addr); err != nil { 257 t.Fatal(err) 258 } 259 260 secret, err := c.client.Auth().Token().Create(tcr) 261 if err != nil { 262 t.Fatalf("failed to create vault token: %v", err) 263 } 264 265 if secret == nil || secret.Auth == nil || secret.Auth.ClientToken == "" { 266 t.Fatal("failed to derive a wrapped vault token") 267 } 268 269 _, err = c.RenewToken(secret.Auth.ClientToken, secret.Auth.LeaseDuration) 270 if err == nil { 271 t.Fatalf("expected error, got nil") 272 } else if !strings.Contains(err.Error(), "lease is not renewable") { 273 t.Fatalf("expected \"%s\" in error message, got \"%v\"", "lease is not renewable", err) 274 } 275 } 276 277 func TestVaultClient_RenewNonexistentLease(t *testing.T) { 278 t.Parallel() 279 v := testutil.NewTestVault(t) 280 defer v.Stop() 281 282 logger := testlog.HCLogger(t) 283 v.Config.ConnectionRetryIntv = 100 * time.Millisecond 284 v.Config.TaskTokenTTL = "4s" 285 c, err := NewVaultClient(v.Config, logger, nil) 286 if err != nil { 287 t.Fatalf("failed to build vault client: %v", err) 288 } 289 290 c.Start() 291 defer c.Stop() 292 293 // Sleep a little while to ensure that the renewal loop is active 294 time.Sleep(time.Duration(testutil.TestMultiplier()) * time.Second) 295 296 c.client.SetToken(v.Config.Token) 297 298 if err := c.client.SetAddress(v.Config.Addr); err != nil { 299 t.Fatal(err) 300 } 301 302 _, err = c.RenewToken(c.client.Token(), 10) 303 if err == nil { 304 t.Fatalf("expected error, got nil") 305 // The Vault error message changed between 0.10.2 and 1.0.1 306 } else if !strings.Contains(err.Error(), "lease not found") && !strings.Contains(err.Error(), "lease is not renewable") { 307 t.Fatalf("expected \"%s\" or \"%s\" in error message, got \"%v\"", "lease not found", "lease is not renewable", err.Error()) 308 } 309 } 310 311 // TestVaultClient_RenewalTime_Long asserts that for leases over 1m the renewal 312 // time is jittered. 313 func TestVaultClient_RenewalTime_Long(t *testing.T) { 314 t.Parallel() 315 316 // highRoller is a randIntn func that always returns the max value 317 highRoller := func(n int) int { 318 return n - 1 319 } 320 321 // lowRoller is a randIntn func that always returns the min value (0) 322 lowRoller := func(int) int { 323 return 0 324 } 325 326 assert.Equal(t, 39*time.Second, renewalTime(highRoller, 60)) 327 assert.Equal(t, 20*time.Second, renewalTime(lowRoller, 60)) 328 329 assert.Equal(t, 309*time.Second, renewalTime(highRoller, 600)) 330 assert.Equal(t, 290*time.Second, renewalTime(lowRoller, 600)) 331 332 const days3 = 60 * 60 * 24 * 3 333 assert.Equal(t, (days3/2+9)*time.Second, renewalTime(highRoller, days3)) 334 assert.Equal(t, (days3/2-10)*time.Second, renewalTime(lowRoller, days3)) 335 } 336 337 // TestVaultClient_RenewalTime_Short asserts that for leases under 1m the renewal 338 // time is lease/2. 339 func TestVaultClient_RenewalTime_Short(t *testing.T) { 340 t.Parallel() 341 342 dice := func(int) int { 343 require.Fail(t, "dice should not have been called") 344 panic("unreachable") 345 } 346 347 assert.Equal(t, 29*time.Second, renewalTime(dice, 58)) 348 assert.Equal(t, 15*time.Second, renewalTime(dice, 30)) 349 assert.Equal(t, 1*time.Second, renewalTime(dice, 2)) 350 }