github.com/lestrrat-go/jwx/v2@v2.0.21/jwk/refresh_test.go (about) 1 package jwk_test 2 3 import ( 4 "bytes" 5 "context" 6 "fmt" 7 "net/http" 8 "net/http/httptest" 9 "sync" 10 "testing" 11 "time" 12 13 "github.com/lestrrat-go/jwx/v2/internal/json" 14 "github.com/lestrrat-go/jwx/v2/internal/jwxtest" 15 "github.com/lestrrat-go/jwx/v2/jwk" 16 "github.com/stretchr/testify/assert" 17 ) 18 19 //nolint:revive,golint 20 func checkAccessCount(t *testing.T, ctx context.Context, src jwk.Set, expected ...int) bool { 21 t.Helper() 22 23 iter := src.Keys(ctx) 24 iter.Next(ctx) 25 26 key := iter.Pair().Value.(jwk.Key) 27 v, ok := key.Get(`accessCount`) 28 if !assert.True(t, ok, `key.Get("accessCount") should succeed`) { 29 return false 30 } 31 32 for _, e := range expected { 33 if v == float64(e) { 34 return assert.Equal(t, float64(e), v, `key.Get("accessCount") should be %d`, e) 35 } 36 } 37 38 var buf bytes.Buffer 39 fmt.Fprint(&buf, "[") 40 for i, e := range expected { 41 fmt.Fprintf(&buf, "%d", e) 42 if i < len(expected)-1 { 43 fmt.Fprint(&buf, ", ") 44 } 45 } 46 fmt.Fprintf(&buf, "]") 47 return assert.Failf(t, `checking access count failed`, `key.Get("accessCount") should be one of %s (got %f)`, buf.String(), v) 48 } 49 50 func TestCache(t *testing.T) { 51 t.Parallel() 52 53 t.Run("CachedSet", func(t *testing.T) { 54 const numKeys = 3 55 t.Parallel() 56 ctx, cancel := context.WithTimeout(context.Background(), time.Minute) 57 defer cancel() 58 59 set := jwk.NewSet() 60 for i := 0; i < numKeys; i++ { 61 key, err := jwxtest.GenerateRsaJwk() 62 if !assert.NoError(t, err, `jwxtest.GenerateRsaJwk should succeed`) { 63 return 64 } 65 if !assert.NoError(t, set.AddKey(key), `set.AddKey should succeed`) { 66 return 67 } 68 } 69 70 srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 71 hdrs := w.Header() 72 hdrs.Set(`Content-Type`, `application/json`) 73 hdrs.Set(`Cache-Control`, `max-age=5`) 74 75 json.NewEncoder(w).Encode(set) 76 })) 77 defer srv.Close() 78 79 af := jwk.NewCache(ctx, jwk.WithRefreshWindow(time.Second)) 80 if !assert.NoError(t, af.Register(srv.URL), `af.Register should succeed`) { 81 return 82 } 83 84 cached := jwk.NewCachedSet(af, srv.URL) 85 if !assert.Error(t, cached.Set("bogus", nil), `cached.Set should be an error`) { 86 return 87 } 88 if !assert.Error(t, cached.Remove("bogus"), `cached.Remove should be an error`) { 89 return 90 } 91 if !assert.Error(t, cached.AddKey(nil), `cached.AddKey should be an error`) { 92 return 93 } 94 if !assert.Error(t, cached.RemoveKey(nil), `cached.RemoveKey should be an error`) { 95 return 96 } 97 if !assert.Equal(t, set.Len(), cached.Len(), `value of Len() should be the same`) { 98 return 99 } 100 101 iter := set.Keys(ctx) 102 citer := cached.Keys(ctx) 103 for i := 0; i < numKeys; i++ { 104 k, err := set.Key(i) 105 ck, cerr := cached.Key(i) 106 if !assert.Equal(t, k, ck, `key %d should match`, i) { 107 return 108 } 109 if !assert.Equal(t, err, cerr, `error %d should match`, i) { 110 return 111 } 112 113 if !assert.Equal(t, iter.Next(ctx), citer.Next(ctx), `iter.Next should match`) { 114 return 115 } 116 117 if !assert.Equal(t, iter.Pair(), citer.Pair(), `iter.Pair should match`) { 118 return 119 } 120 } 121 }) 122 t.Run("Specify explicit refresh interval", func(t *testing.T) { 123 t.Parallel() 124 ctx, cancel := context.WithTimeout(context.Background(), time.Minute) 125 defer cancel() 126 127 var accessCount int 128 srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 129 accessCount++ 130 131 key := map[string]interface{}{ 132 "kty": "EC", 133 "crv": "P-256", 134 "x": "SVqB4JcUD6lsfvqMr-OKUNUphdNn64Eay60978ZlL74", 135 "y": "lf0u0pMj4lGAzZix5u4Cm5CMQIgMNpkwy163wtKYVKI", 136 "accessCount": accessCount, 137 } 138 hdrs := w.Header() 139 hdrs.Set(`Content-Type`, `application/json`) 140 hdrs.Set(`Cache-Control`, `max-age=7200`) // Make sure this is ignored 141 142 json.NewEncoder(w).Encode(key) 143 })) 144 defer srv.Close() 145 146 af := jwk.NewCache(ctx, jwk.WithRefreshWindow(time.Second)) 147 if !assert.NoError(t, af.Register(srv.URL, jwk.WithRefreshInterval(3*time.Second)), `af.Register should succeed`) { 148 return 149 } 150 151 retries := 5 152 153 var wg sync.WaitGroup 154 wg.Add(retries) 155 for i := 0; i < retries; i++ { 156 // Run these in separate goroutines to emulate a possible thundering herd 157 go func() { 158 defer wg.Done() 159 ks, err := af.Get(ctx, srv.URL) 160 if !assert.NoError(t, err, `af.Get should succeed`) { 161 return 162 } 163 if !checkAccessCount(t, ctx, ks, 1) { 164 return 165 } 166 }() 167 } 168 169 t.Logf("Waiting for fetching goroutines...") 170 wg.Wait() 171 t.Logf("Waiting for the refresh ...") 172 time.Sleep(4 * time.Second) 173 ks, err := af.Get(ctx, srv.URL) 174 if !assert.NoError(t, err, `af.Get should succeed`) { 175 return 176 } 177 if !checkAccessCount(t, ctx, ks, 2) { 178 return 179 } 180 }) 181 t.Run("Calculate next refresh from Cache-Control header", func(t *testing.T) { 182 t.Parallel() 183 ctx, cancel := context.WithTimeout(context.Background(), time.Minute) 184 defer cancel() 185 186 var accessCount int 187 srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 188 accessCount++ 189 190 key := map[string]interface{}{ 191 "kty": "EC", 192 "crv": "P-256", 193 "x": "SVqB4JcUD6lsfvqMr-OKUNUphdNn64Eay60978ZlL74", 194 "y": "lf0u0pMj4lGAzZix5u4Cm5CMQIgMNpkwy163wtKYVKI", 195 "accessCount": accessCount, 196 } 197 hdrs := w.Header() 198 hdrs.Set(`Content-Type`, `application/json`) 199 hdrs.Set(`Cache-Control`, `max-age=3`) 200 201 json.NewEncoder(w).Encode(key) 202 })) 203 defer srv.Close() 204 205 af := jwk.NewCache(ctx, jwk.WithRefreshWindow(time.Second)) 206 if !assert.NoError(t, af.Register(srv.URL, jwk.WithMinRefreshInterval(time.Second)), `af.Register should succeed`) { 207 return 208 } 209 210 if !assert.True(t, af.IsRegistered(srv.URL), `af.IsRegistered should be true`) { 211 return 212 } 213 214 retries := 5 215 216 var wg sync.WaitGroup 217 wg.Add(retries) 218 for i := 0; i < retries; i++ { 219 // Run these in separate goroutines to emulate a possible thundering herd 220 go func() { 221 defer wg.Done() 222 ks, err := af.Get(ctx, srv.URL) 223 if !assert.NoError(t, err, `af.Get should succeed`) { 224 return 225 } 226 227 if !checkAccessCount(t, ctx, ks, 1) { 228 return 229 } 230 }() 231 } 232 233 t.Logf("Waiting for fetching goroutines...") 234 wg.Wait() 235 t.Logf("Waiting for the refresh ...") 236 time.Sleep(4 * time.Second) 237 ks, err := af.Get(ctx, srv.URL) 238 if !assert.NoError(t, err, `af.Get should succeed`) { 239 return 240 } 241 if !checkAccessCount(t, ctx, ks, 2) { 242 return 243 } 244 }) 245 t.Run("Backoff", func(t *testing.T) { 246 t.Parallel() 247 ctx, cancel := context.WithTimeout(context.Background(), time.Minute) 248 defer cancel() 249 250 var accessCount int 251 srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 252 accessCount++ 253 if accessCount > 1 && accessCount < 4 { 254 http.Error(w, "wait for it....", http.StatusForbidden) 255 return 256 } 257 258 key := map[string]interface{}{ 259 "kty": "EC", 260 "crv": "P-256", 261 "x": "SVqB4JcUD6lsfvqMr-OKUNUphdNn64Eay60978ZlL74", 262 "y": "lf0u0pMj4lGAzZix5u4Cm5CMQIgMNpkwy163wtKYVKI", 263 "accessCount": accessCount, 264 } 265 hdrs := w.Header() 266 hdrs.Set(`Content-Type`, `application/json`) 267 hdrs.Set(`Cache-Control`, `max-age=1`) 268 269 json.NewEncoder(w).Encode(key) 270 })) 271 defer srv.Close() 272 273 af := jwk.NewCache(ctx, jwk.WithRefreshWindow(time.Second)) 274 af.Register(srv.URL, jwk.WithMinRefreshInterval(time.Second)) 275 276 // First fetch should succeed 277 ks, err := af.Get(ctx, srv.URL) 278 if !assert.NoError(t, err, `af.Get (#1) should succeed`) { 279 return 280 } 281 if !checkAccessCount(t, ctx, ks, 1) { 282 return 283 } 284 285 // enough time for 1 refresh to have occurred 286 time.Sleep(1500 * time.Millisecond) 287 ks, err = af.Get(ctx, srv.URL) 288 if !assert.NoError(t, err, `af.Get (#2) should succeed`) { 289 return 290 } 291 // Should be using the cached version 292 if !checkAccessCount(t, ctx, ks, 1) { 293 return 294 } 295 296 // enough time for 2 refreshes to have occurred 297 time.Sleep(2500 * time.Millisecond) 298 299 ks, err = af.Get(ctx, srv.URL) 300 if !assert.NoError(t, err, `af.Get (#3) should succeed`) { 301 return 302 } 303 // should be new 304 if !checkAccessCount(t, ctx, ks, 4, 5) { 305 return 306 } 307 }) 308 } 309 310 func TestRefreshSnapshot(t *testing.T) { 311 ctx, cancel := context.WithTimeout(context.Background(), time.Minute) 312 defer cancel() 313 314 var jwksURLs []string 315 getJwksURL := func(dst *[]string, url string) bool { 316 req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil) 317 if err != nil { 318 return false 319 } 320 321 res, err := http.DefaultClient.Do(req) 322 if err != nil { 323 return false 324 } 325 defer res.Body.Close() 326 327 var m map[string]interface{} 328 if err := json.NewDecoder(res.Body).Decode(&m); err != nil { 329 return false 330 } 331 332 jwksURL, ok := m["jwks_uri"] 333 if !ok { 334 return false 335 } 336 *dst = append(*dst, jwksURL.(string)) 337 return true 338 } 339 if !getJwksURL(&jwksURLs, "https://oidc-sample.onelogin.com/oidc/2/.well-known/openid-configuration") { 340 t.SkipNow() 341 } 342 if !getJwksURL(&jwksURLs, "https://accounts.google.com/.well-known/openid-configuration") { 343 t.SkipNow() 344 } 345 346 ar := jwk.NewCache(ctx, jwk.WithRefreshWindow(time.Second)) 347 for _, url := range jwksURLs { 348 if !assert.NoError(t, ar.Register(url), `ar.Register should succeed`) { 349 return 350 } 351 } 352 353 for _, url := range jwksURLs { 354 _ = ar.Unregister(url) 355 } 356 357 for _, target := range ar.Snapshot().Entries { 358 t.Logf("%s last refreshed at %s", target.URL, target.LastFetched) 359 } 360 361 for _, url := range jwksURLs { 362 ar.Unregister(url) 363 } 364 365 if !assert.Len(t, ar.Snapshot().Entries, 0, `there should be no URLs`) { 366 return 367 } 368 369 if !assert.Error(t, ar.Unregister(`dummy`), `removing a non-existing url should be an error`) { 370 return 371 } 372 } 373 374 type accumulateErrs struct { 375 mu sync.RWMutex 376 errs []error 377 } 378 379 func (e *accumulateErrs) Error(err error) { 380 e.mu.Lock() 381 e.errs = append(e.errs, err) 382 e.mu.Unlock() 383 } 384 385 func (e *accumulateErrs) Len() int { 386 e.mu.RLock() 387 l := len(e.errs) 388 e.mu.RUnlock() 389 return l 390 } 391 func TestErrorSink(t *testing.T) { 392 t.Parallel() 393 394 k, err := jwxtest.GenerateRsaJwk() 395 if !assert.NoError(t, err, `jwxtest.GenerateRsaJwk should succeed`) { 396 return 397 } 398 set := jwk.NewSet() 399 _ = set.AddKey(k) 400 testcases := []struct { 401 Name string 402 Options func() []jwk.RegisterOption 403 Handler http.Handler 404 }{ 405 /* 406 { 407 Name: "non-200 response", 408 Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 409 w.WriteHeader(http.StatusForbidden) 410 }), 411 }, 412 { 413 Name: "invalid JWK", 414 Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 415 w.WriteHeader(http.StatusOK) 416 w.Write([]byte(`{"empty": "nonthingness"}`)) 417 }), 418 }, 419 */ 420 { 421 Name: `rejected by whitelist`, 422 Options: func() []jwk.RegisterOption { 423 return []jwk.RegisterOption{ 424 jwk.WithFetchWhitelist(jwk.WhitelistFunc(func(_ string) bool { 425 return false 426 })), 427 } 428 }, 429 Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 430 w.WriteHeader(http.StatusOK) 431 json.NewEncoder(w).Encode(k) 432 }), 433 }, 434 } 435 436 for _, tc := range testcases { 437 tc := tc 438 t.Run(tc.Name, func(t *testing.T) { 439 t.Parallel() 440 ctx, cancel := context.WithTimeout(context.Background(), time.Minute) 441 defer cancel() 442 srv := httptest.NewServer(tc.Handler) 443 defer srv.Close() 444 445 var errSink accumulateErrs 446 ar := jwk.NewCache(ctx, jwk.WithErrSink(&errSink), jwk.WithRefreshWindow(time.Second)) 447 448 var options []jwk.RegisterOption 449 if f := tc.Options; f != nil { 450 options = f() 451 } 452 options = append(options, jwk.WithRefreshInterval(time.Second)) 453 if !assert.NoError(t, ar.Register(srv.URL, options...), `ar.Register should succeed`) { 454 return 455 } 456 457 _, _ = ar.Get(ctx, srv.URL) 458 459 timer := time.NewTimer(6 * time.Second) 460 461 select { 462 case <-ctx.Done(): 463 t.Errorf(`ctx.Done before timer`) 464 case <-timer.C: 465 } 466 467 cancel() // forcefully end context, and thus the Cache 468 469 // timing issues can cause this to be non-deterministic... 470 // we'll say it's okay as long as we're in +/- 1 range 471 l := errSink.Len() 472 if !assert.True(t, l <= 7, "number of errors shold be less than or equal to 7 (%d)", l) { 473 return 474 } 475 if !assert.True(t, l >= 5, "number of errors shold be greather than or equal to 5 (%d)", l) { 476 return 477 } 478 }) 479 } 480 } 481 482 func TestPostFetch(t *testing.T) { 483 t.Parallel() 484 485 set := jwk.NewSet() 486 for i := 0; i < 3; i++ { 487 key, err := jwk.FromRaw([]byte(fmt.Sprintf(`abracadabra-%d`, i))) 488 if !assert.NoError(t, err, `jwk.FromRaw should succeed`) { 489 return 490 } 491 _ = set.AddKey(key) 492 } 493 494 testcases := []struct { 495 Name string 496 Options []jwk.RegisterOption 497 ExpectKid bool 498 }{ 499 { 500 Name: "No PostFetch", 501 }, 502 { 503 Name: "With PostFetch", 504 Options: []jwk.RegisterOption{jwk.WithPostFetcher(jwk.PostFetchFunc(func(_ string, set jwk.Set) (jwk.Set, error) { 505 for i := 0; i < set.Len(); i++ { 506 key, _ := set.Key(i) 507 key.Set(jwk.KeyIDKey, fmt.Sprintf(`key-%d`, i)) 508 } 509 return set, nil 510 }))}, 511 ExpectKid: true, 512 }, 513 } 514 515 for _, tc := range testcases { 516 tc := tc 517 t.Run(tc.Name, func(t *testing.T) { 518 t.Parallel() 519 520 srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { 521 w.WriteHeader(http.StatusOK) 522 json.NewEncoder(w).Encode(set) 523 })) 524 defer srv.Close() 525 526 ctx, cancel := context.WithTimeout(context.Background(), time.Minute) 527 defer cancel() 528 529 ar := jwk.NewCache(ctx) 530 531 ar.Register(srv.URL, tc.Options...) 532 set, err := ar.Get(ctx, srv.URL) 533 if !assert.NoError(t, err, `ar.Fetch should succeed`) { 534 return 535 } 536 537 for i := 0; i < set.Len(); i++ { 538 key, _ := set.Key(i) 539 if tc.ExpectKid { 540 if !assert.NotEmpty(t, key.KeyID(), `key.KeyID should not be empty`) { 541 return 542 } 543 } else { 544 if !assert.Empty(t, key.KeyID(), `key.KeyID should be empty`) { 545 return 546 } 547 } 548 } 549 }) 550 } 551 }