github.com/anth0d/nomad@v0.0.0-20221214183521-ae3a0a2cad06/client/vaultclient/vaultclient_test.go (about) 1 package vaultclient 2 3 import ( 4 "strings" 5 "testing" 6 "time" 7 8 "github.com/hashicorp/nomad/ci" 9 "github.com/hashicorp/nomad/client/config" 10 "github.com/hashicorp/nomad/helper/testlog" 11 "github.com/hashicorp/nomad/testutil" 12 vaultapi "github.com/hashicorp/vault/api" 13 vaultconsts "github.com/hashicorp/vault/sdk/helper/consts" 14 "github.com/stretchr/testify/assert" 15 "github.com/stretchr/testify/require" 16 ) 17 18 func TestVaultClient_TokenRenewals(t *testing.T) { 19 ci.Parallel(t) 20 21 require := require.New(t) 22 v := testutil.NewTestVault(t) 23 defer v.Stop() 24 25 logger := testlog.HCLogger(t) 26 v.Config.ConnectionRetryIntv = 100 * time.Millisecond 27 v.Config.TaskTokenTTL = "4s" 28 c, err := NewVaultClient(v.Config, logger, nil) 29 if err != nil { 30 t.Fatalf("failed to build vault client: %v", err) 31 } 32 33 c.Start() 34 defer c.Stop() 35 36 // Sleep a little while to ensure that the renewal loop is active 37 time.Sleep(time.Duration(testutil.TestMultiplier()) * time.Second) 38 39 tcr := &vaultapi.TokenCreateRequest{ 40 Policies: []string{"foo", "bar"}, 41 TTL: "2s", 42 DisplayName: "derived-for-task", 43 Renewable: new(bool), 44 } 45 *tcr.Renewable = true 46 47 num := 5 48 tokens := make([]string, num) 49 for i := 0; i < num; i++ { 50 c.client.SetToken(v.Config.Token) 51 52 if err := c.client.SetAddress(v.Config.Addr); err != nil { 53 t.Fatal(err) 54 } 55 56 secret, err := c.client.Auth().Token().Create(tcr) 57 if err != nil { 58 t.Fatalf("failed to create vault token: %v", err) 59 } 60 61 if secret == nil || secret.Auth == nil || secret.Auth.ClientToken == "" { 62 t.Fatal("failed to derive a wrapped vault token") 63 } 64 65 tokens[i] = secret.Auth.ClientToken 66 67 errCh, err := c.RenewToken(tokens[i], secret.Auth.LeaseDuration) 68 if err != nil { 69 t.Fatalf("Unexpected error: %v", err) 70 } 71 72 go func(errCh <-chan error) { 73 for { 74 select { 75 case err := <-errCh: 76 require.NoError(err, "unexpected error while renewing vault token") 77 } 78 } 79 }(errCh) 80 } 81 82 c.lock.Lock() 83 length := c.heap.Length() 84 c.lock.Unlock() 85 if length != num { 86 t.Fatalf("bad: heap length: expected: %d, actual: %d", num, length) 87 } 88 89 time.Sleep(time.Duration(testutil.TestMultiplier()) * time.Second) 90 91 for i := 0; i < num; i++ { 92 if err := c.StopRenewToken(tokens[i]); err != nil { 93 require.NoError(err) 94 } 95 } 96 97 c.lock.Lock() 98 length = c.heap.Length() 99 c.lock.Unlock() 100 if length != 0 { 101 t.Fatalf("bad: heap length: expected: 0, actual: %d", length) 102 } 103 } 104 105 // TestVaultClient_NamespaceSupport tests that the Vault namespace config, if present, will result in the 106 // namespace header being set on the created Vault client. 107 func TestVaultClient_NamespaceSupport(t *testing.T) { 108 ci.Parallel(t) 109 110 require := require.New(t) 111 tr := true 112 testNs := "test-namespace" 113 114 logger := testlog.HCLogger(t) 115 116 conf := config.DefaultConfig() 117 conf.VaultConfig.Enabled = &tr 118 conf.VaultConfig.Token = "testvaulttoken" 119 conf.VaultConfig.Namespace = testNs 120 c, err := NewVaultClient(conf.VaultConfig, logger, nil) 121 require.NoError(err) 122 require.Equal(testNs, c.client.Headers().Get(vaultconsts.NamespaceHeaderName)) 123 } 124 125 func TestVaultClient_Heap(t *testing.T) { 126 ci.Parallel(t) 127 128 tr := true 129 conf := config.DefaultConfig() 130 conf.VaultConfig.Enabled = &tr 131 conf.VaultConfig.Token = "testvaulttoken" 132 conf.VaultConfig.TaskTokenTTL = "10s" 133 134 logger := testlog.HCLogger(t) 135 c, err := NewVaultClient(conf.VaultConfig, logger, nil) 136 if err != nil { 137 t.Fatal(err) 138 } 139 if c == nil { 140 t.Fatal("failed to create vault client") 141 } 142 143 now := time.Now() 144 145 renewalReq1 := &vaultClientRenewalRequest{ 146 errCh: make(chan error, 1), 147 id: "id1", 148 increment: 10, 149 } 150 if err := c.heap.Push(renewalReq1, now.Add(50*time.Second)); err != nil { 151 t.Fatal(err) 152 } 153 if !c.isTracked("id1") { 154 t.Fatalf("id1 should have been tracked") 155 } 156 157 renewalReq2 := &vaultClientRenewalRequest{ 158 errCh: make(chan error, 1), 159 id: "id2", 160 increment: 10, 161 } 162 if err := c.heap.Push(renewalReq2, now.Add(40*time.Second)); err != nil { 163 t.Fatal(err) 164 } 165 if !c.isTracked("id2") { 166 t.Fatalf("id2 should have been tracked") 167 } 168 169 renewalReq3 := &vaultClientRenewalRequest{ 170 errCh: make(chan error, 1), 171 id: "id3", 172 increment: 10, 173 } 174 if err := c.heap.Push(renewalReq3, now.Add(60*time.Second)); err != nil { 175 t.Fatal(err) 176 } 177 if !c.isTracked("id3") { 178 t.Fatalf("id3 should have been tracked") 179 } 180 181 // Reading elements should yield id2, id1 and id3 in order 182 req, _ := c.nextRenewal() 183 if req != renewalReq2 { 184 t.Fatalf("bad: expected: %#v, actual: %#v", renewalReq2, req) 185 } 186 if err := c.heap.Update(req, now.Add(70*time.Second)); err != nil { 187 t.Fatal(err) 188 } 189 190 req, _ = c.nextRenewal() 191 if req != renewalReq1 { 192 t.Fatalf("bad: expected: %#v, actual: %#v", renewalReq1, req) 193 } 194 if err := c.heap.Update(req, now.Add(80*time.Second)); err != nil { 195 t.Fatal(err) 196 } 197 198 req, _ = c.nextRenewal() 199 if req != renewalReq3 { 200 t.Fatalf("bad: expected: %#v, actual: %#v", renewalReq3, req) 201 } 202 if err := c.heap.Update(req, now.Add(90*time.Second)); err != nil { 203 t.Fatal(err) 204 } 205 206 if err := c.StopRenewToken("id1"); err != nil { 207 t.Fatal(err) 208 } 209 210 if err := c.StopRenewToken("id2"); err != nil { 211 t.Fatal(err) 212 } 213 214 if err := c.StopRenewToken("id3"); err != nil { 215 t.Fatal(err) 216 } 217 218 if c.isTracked("id1") { 219 t.Fatalf("id1 should not have been tracked") 220 } 221 222 if c.isTracked("id1") { 223 t.Fatalf("id1 should not have been tracked") 224 } 225 226 if c.isTracked("id1") { 227 t.Fatalf("id1 should not have been tracked") 228 } 229 230 } 231 232 func TestVaultClient_RenewNonRenewableLease(t *testing.T) { 233 ci.Parallel(t) 234 235 v := testutil.NewTestVault(t) 236 defer v.Stop() 237 238 logger := testlog.HCLogger(t) 239 v.Config.ConnectionRetryIntv = 100 * time.Millisecond 240 v.Config.TaskTokenTTL = "4s" 241 c, err := NewVaultClient(v.Config, logger, nil) 242 if err != nil { 243 t.Fatalf("failed to build vault client: %v", err) 244 } 245 246 c.Start() 247 defer c.Stop() 248 249 // Sleep a little while to ensure that the renewal loop is active 250 time.Sleep(time.Duration(testutil.TestMultiplier()) * time.Second) 251 252 tcr := &vaultapi.TokenCreateRequest{ 253 Policies: []string{"foo", "bar"}, 254 TTL: "2s", 255 DisplayName: "derived-for-task", 256 Renewable: new(bool), 257 } 258 259 c.client.SetToken(v.Config.Token) 260 261 if err := c.client.SetAddress(v.Config.Addr); err != nil { 262 t.Fatal(err) 263 } 264 265 secret, err := c.client.Auth().Token().Create(tcr) 266 if err != nil { 267 t.Fatalf("failed to create vault token: %v", err) 268 } 269 270 if secret == nil || secret.Auth == nil || secret.Auth.ClientToken == "" { 271 t.Fatal("failed to derive a wrapped vault token") 272 } 273 274 _, err = c.RenewToken(secret.Auth.ClientToken, secret.Auth.LeaseDuration) 275 if err == nil { 276 t.Fatalf("expected error, got nil") 277 } else if !strings.Contains(err.Error(), "lease is not renewable") { 278 t.Fatalf("expected \"%s\" in error message, got \"%v\"", "lease is not renewable", err) 279 } 280 } 281 282 func TestVaultClient_RenewNonexistentLease(t *testing.T) { 283 ci.Parallel(t) 284 285 v := testutil.NewTestVault(t) 286 defer v.Stop() 287 288 logger := testlog.HCLogger(t) 289 v.Config.ConnectionRetryIntv = 100 * time.Millisecond 290 v.Config.TaskTokenTTL = "4s" 291 c, err := NewVaultClient(v.Config, logger, nil) 292 if err != nil { 293 t.Fatalf("failed to build vault client: %v", err) 294 } 295 296 c.Start() 297 defer c.Stop() 298 299 // Sleep a little while to ensure that the renewal loop is active 300 time.Sleep(time.Duration(testutil.TestMultiplier()) * time.Second) 301 302 c.client.SetToken(v.Config.Token) 303 304 if err := c.client.SetAddress(v.Config.Addr); err != nil { 305 t.Fatal(err) 306 } 307 308 _, err = c.RenewToken(c.client.Token(), 10) 309 if err == nil { 310 t.Fatalf("expected error, got nil") 311 // The Vault error message changed between 0.10.2 and 1.0.1 312 } else if !strings.Contains(err.Error(), "lease not found") && !strings.Contains(err.Error(), "lease is not renewable") { 313 t.Fatalf("expected \"%s\" or \"%s\" in error message, got \"%v\"", "lease not found", "lease is not renewable", err.Error()) 314 } 315 } 316 317 // TestVaultClient_RenewalTime_Long asserts that for leases over 1m the renewal 318 // time is jittered. 319 func TestVaultClient_RenewalTime_Long(t *testing.T) { 320 ci.Parallel(t) 321 322 // highRoller is a randIntn func that always returns the max value 323 highRoller := func(n int) int { 324 return n - 1 325 } 326 327 // lowRoller is a randIntn func that always returns the min value (0) 328 lowRoller := func(int) int { 329 return 0 330 } 331 332 assert.Equal(t, 39*time.Second, renewalTime(highRoller, 60)) 333 assert.Equal(t, 20*time.Second, renewalTime(lowRoller, 60)) 334 335 assert.Equal(t, 309*time.Second, renewalTime(highRoller, 600)) 336 assert.Equal(t, 290*time.Second, renewalTime(lowRoller, 600)) 337 338 const days3 = 60 * 60 * 24 * 3 339 assert.Equal(t, (days3/2+9)*time.Second, renewalTime(highRoller, days3)) 340 assert.Equal(t, (days3/2-10)*time.Second, renewalTime(lowRoller, days3)) 341 } 342 343 // TestVaultClient_RenewalTime_Short asserts that for leases under 1m the renewal 344 // time is lease/2. 345 func TestVaultClient_RenewalTime_Short(t *testing.T) { 346 ci.Parallel(t) 347 348 dice := func(int) int { 349 require.Fail(t, "dice should not have been called") 350 panic("unreachable") 351 } 352 353 assert.Equal(t, 29*time.Second, renewalTime(dice, 58)) 354 assert.Equal(t, 15*time.Second, renewalTime(dice, 30)) 355 assert.Equal(t, 1*time.Second, renewalTime(dice, 2)) 356 }