github.com/lestrrat-go/jwx/v2@v2.0.21/jwk/refresh_test.go (about)

     1  package jwk_test
     2  
     3  import (
     4  	"bytes"
     5  	"context"
     6  	"fmt"
     7  	"net/http"
     8  	"net/http/httptest"
     9  	"sync"
    10  	"testing"
    11  	"time"
    12  
    13  	"github.com/lestrrat-go/jwx/v2/internal/json"
    14  	"github.com/lestrrat-go/jwx/v2/internal/jwxtest"
    15  	"github.com/lestrrat-go/jwx/v2/jwk"
    16  	"github.com/stretchr/testify/assert"
    17  )
    18  
    19  //nolint:revive,golint
    20  func checkAccessCount(t *testing.T, ctx context.Context, src jwk.Set, expected ...int) bool {
    21  	t.Helper()
    22  
    23  	iter := src.Keys(ctx)
    24  	iter.Next(ctx)
    25  
    26  	key := iter.Pair().Value.(jwk.Key)
    27  	v, ok := key.Get(`accessCount`)
    28  	if !assert.True(t, ok, `key.Get("accessCount") should succeed`) {
    29  		return false
    30  	}
    31  
    32  	for _, e := range expected {
    33  		if v == float64(e) {
    34  			return assert.Equal(t, float64(e), v, `key.Get("accessCount") should be %d`, e)
    35  		}
    36  	}
    37  
    38  	var buf bytes.Buffer
    39  	fmt.Fprint(&buf, "[")
    40  	for i, e := range expected {
    41  		fmt.Fprintf(&buf, "%d", e)
    42  		if i < len(expected)-1 {
    43  			fmt.Fprint(&buf, ", ")
    44  		}
    45  	}
    46  	fmt.Fprintf(&buf, "]")
    47  	return assert.Failf(t, `checking access count failed`, `key.Get("accessCount") should be one of %s (got %f)`, buf.String(), v)
    48  }
    49  
    50  func TestCache(t *testing.T) {
    51  	t.Parallel()
    52  
    53  	t.Run("CachedSet", func(t *testing.T) {
    54  		const numKeys = 3
    55  		t.Parallel()
    56  		ctx, cancel := context.WithTimeout(context.Background(), time.Minute)
    57  		defer cancel()
    58  
    59  		set := jwk.NewSet()
    60  		for i := 0; i < numKeys; i++ {
    61  			key, err := jwxtest.GenerateRsaJwk()
    62  			if !assert.NoError(t, err, `jwxtest.GenerateRsaJwk should succeed`) {
    63  				return
    64  			}
    65  			if !assert.NoError(t, set.AddKey(key), `set.AddKey should succeed`) {
    66  				return
    67  			}
    68  		}
    69  
    70  		srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
    71  			hdrs := w.Header()
    72  			hdrs.Set(`Content-Type`, `application/json`)
    73  			hdrs.Set(`Cache-Control`, `max-age=5`)
    74  
    75  			json.NewEncoder(w).Encode(set)
    76  		}))
    77  		defer srv.Close()
    78  
    79  		af := jwk.NewCache(ctx, jwk.WithRefreshWindow(time.Second))
    80  		if !assert.NoError(t, af.Register(srv.URL), `af.Register should succeed`) {
    81  			return
    82  		}
    83  
    84  		cached := jwk.NewCachedSet(af, srv.URL)
    85  		if !assert.Error(t, cached.Set("bogus", nil), `cached.Set should be an error`) {
    86  			return
    87  		}
    88  		if !assert.Error(t, cached.Remove("bogus"), `cached.Remove should be an error`) {
    89  			return
    90  		}
    91  		if !assert.Error(t, cached.AddKey(nil), `cached.AddKey should be an error`) {
    92  			return
    93  		}
    94  		if !assert.Error(t, cached.RemoveKey(nil), `cached.RemoveKey should be an error`) {
    95  			return
    96  		}
    97  		if !assert.Equal(t, set.Len(), cached.Len(), `value of Len() should be the same`) {
    98  			return
    99  		}
   100  
   101  		iter := set.Keys(ctx)
   102  		citer := cached.Keys(ctx)
   103  		for i := 0; i < numKeys; i++ {
   104  			k, err := set.Key(i)
   105  			ck, cerr := cached.Key(i)
   106  			if !assert.Equal(t, k, ck, `key %d should match`, i) {
   107  				return
   108  			}
   109  			if !assert.Equal(t, err, cerr, `error %d should match`, i) {
   110  				return
   111  			}
   112  
   113  			if !assert.Equal(t, iter.Next(ctx), citer.Next(ctx), `iter.Next should match`) {
   114  				return
   115  			}
   116  
   117  			if !assert.Equal(t, iter.Pair(), citer.Pair(), `iter.Pair should match`) {
   118  				return
   119  			}
   120  		}
   121  	})
   122  	t.Run("Specify explicit refresh interval", func(t *testing.T) {
   123  		t.Parallel()
   124  		ctx, cancel := context.WithTimeout(context.Background(), time.Minute)
   125  		defer cancel()
   126  
   127  		var accessCount int
   128  		srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   129  			accessCount++
   130  
   131  			key := map[string]interface{}{
   132  				"kty":         "EC",
   133  				"crv":         "P-256",
   134  				"x":           "SVqB4JcUD6lsfvqMr-OKUNUphdNn64Eay60978ZlL74",
   135  				"y":           "lf0u0pMj4lGAzZix5u4Cm5CMQIgMNpkwy163wtKYVKI",
   136  				"accessCount": accessCount,
   137  			}
   138  			hdrs := w.Header()
   139  			hdrs.Set(`Content-Type`, `application/json`)
   140  			hdrs.Set(`Cache-Control`, `max-age=7200`) // Make sure this is ignored
   141  
   142  			json.NewEncoder(w).Encode(key)
   143  		}))
   144  		defer srv.Close()
   145  
   146  		af := jwk.NewCache(ctx, jwk.WithRefreshWindow(time.Second))
   147  		if !assert.NoError(t, af.Register(srv.URL, jwk.WithRefreshInterval(3*time.Second)), `af.Register should succeed`) {
   148  			return
   149  		}
   150  
   151  		retries := 5
   152  
   153  		var wg sync.WaitGroup
   154  		wg.Add(retries)
   155  		for i := 0; i < retries; i++ {
   156  			// Run these in separate goroutines to emulate a possible thundering herd
   157  			go func() {
   158  				defer wg.Done()
   159  				ks, err := af.Get(ctx, srv.URL)
   160  				if !assert.NoError(t, err, `af.Get should succeed`) {
   161  					return
   162  				}
   163  				if !checkAccessCount(t, ctx, ks, 1) {
   164  					return
   165  				}
   166  			}()
   167  		}
   168  
   169  		t.Logf("Waiting for fetching goroutines...")
   170  		wg.Wait()
   171  		t.Logf("Waiting for the refresh ...")
   172  		time.Sleep(4 * time.Second)
   173  		ks, err := af.Get(ctx, srv.URL)
   174  		if !assert.NoError(t, err, `af.Get should succeed`) {
   175  			return
   176  		}
   177  		if !checkAccessCount(t, ctx, ks, 2) {
   178  			return
   179  		}
   180  	})
   181  	t.Run("Calculate next refresh from Cache-Control header", func(t *testing.T) {
   182  		t.Parallel()
   183  		ctx, cancel := context.WithTimeout(context.Background(), time.Minute)
   184  		defer cancel()
   185  
   186  		var accessCount int
   187  		srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   188  			accessCount++
   189  
   190  			key := map[string]interface{}{
   191  				"kty":         "EC",
   192  				"crv":         "P-256",
   193  				"x":           "SVqB4JcUD6lsfvqMr-OKUNUphdNn64Eay60978ZlL74",
   194  				"y":           "lf0u0pMj4lGAzZix5u4Cm5CMQIgMNpkwy163wtKYVKI",
   195  				"accessCount": accessCount,
   196  			}
   197  			hdrs := w.Header()
   198  			hdrs.Set(`Content-Type`, `application/json`)
   199  			hdrs.Set(`Cache-Control`, `max-age=3`)
   200  
   201  			json.NewEncoder(w).Encode(key)
   202  		}))
   203  		defer srv.Close()
   204  
   205  		af := jwk.NewCache(ctx, jwk.WithRefreshWindow(time.Second))
   206  		if !assert.NoError(t, af.Register(srv.URL, jwk.WithMinRefreshInterval(time.Second)), `af.Register should succeed`) {
   207  			return
   208  		}
   209  
   210  		if !assert.True(t, af.IsRegistered(srv.URL), `af.IsRegistered should be true`) {
   211  			return
   212  		}
   213  
   214  		retries := 5
   215  
   216  		var wg sync.WaitGroup
   217  		wg.Add(retries)
   218  		for i := 0; i < retries; i++ {
   219  			// Run these in separate goroutines to emulate a possible thundering herd
   220  			go func() {
   221  				defer wg.Done()
   222  				ks, err := af.Get(ctx, srv.URL)
   223  				if !assert.NoError(t, err, `af.Get should succeed`) {
   224  					return
   225  				}
   226  
   227  				if !checkAccessCount(t, ctx, ks, 1) {
   228  					return
   229  				}
   230  			}()
   231  		}
   232  
   233  		t.Logf("Waiting for fetching goroutines...")
   234  		wg.Wait()
   235  		t.Logf("Waiting for the refresh ...")
   236  		time.Sleep(4 * time.Second)
   237  		ks, err := af.Get(ctx, srv.URL)
   238  		if !assert.NoError(t, err, `af.Get should succeed`) {
   239  			return
   240  		}
   241  		if !checkAccessCount(t, ctx, ks, 2) {
   242  			return
   243  		}
   244  	})
   245  	t.Run("Backoff", func(t *testing.T) {
   246  		t.Parallel()
   247  		ctx, cancel := context.WithTimeout(context.Background(), time.Minute)
   248  		defer cancel()
   249  
   250  		var accessCount int
   251  		srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   252  			accessCount++
   253  			if accessCount > 1 && accessCount < 4 {
   254  				http.Error(w, "wait for it....", http.StatusForbidden)
   255  				return
   256  			}
   257  
   258  			key := map[string]interface{}{
   259  				"kty":         "EC",
   260  				"crv":         "P-256",
   261  				"x":           "SVqB4JcUD6lsfvqMr-OKUNUphdNn64Eay60978ZlL74",
   262  				"y":           "lf0u0pMj4lGAzZix5u4Cm5CMQIgMNpkwy163wtKYVKI",
   263  				"accessCount": accessCount,
   264  			}
   265  			hdrs := w.Header()
   266  			hdrs.Set(`Content-Type`, `application/json`)
   267  			hdrs.Set(`Cache-Control`, `max-age=1`)
   268  
   269  			json.NewEncoder(w).Encode(key)
   270  		}))
   271  		defer srv.Close()
   272  
   273  		af := jwk.NewCache(ctx, jwk.WithRefreshWindow(time.Second))
   274  		af.Register(srv.URL, jwk.WithMinRefreshInterval(time.Second))
   275  
   276  		// First fetch should succeed
   277  		ks, err := af.Get(ctx, srv.URL)
   278  		if !assert.NoError(t, err, `af.Get (#1) should succeed`) {
   279  			return
   280  		}
   281  		if !checkAccessCount(t, ctx, ks, 1) {
   282  			return
   283  		}
   284  
   285  		// enough time for 1 refresh to have occurred
   286  		time.Sleep(1500 * time.Millisecond)
   287  		ks, err = af.Get(ctx, srv.URL)
   288  		if !assert.NoError(t, err, `af.Get (#2) should succeed`) {
   289  			return
   290  		}
   291  		// Should be using the cached version
   292  		if !checkAccessCount(t, ctx, ks, 1) {
   293  			return
   294  		}
   295  
   296  		// enough time for 2 refreshes to have occurred
   297  		time.Sleep(2500 * time.Millisecond)
   298  
   299  		ks, err = af.Get(ctx, srv.URL)
   300  		if !assert.NoError(t, err, `af.Get (#3) should succeed`) {
   301  			return
   302  		}
   303  		// should be new
   304  		if !checkAccessCount(t, ctx, ks, 4, 5) {
   305  			return
   306  		}
   307  	})
   308  }
   309  
   310  func TestRefreshSnapshot(t *testing.T) {
   311  	ctx, cancel := context.WithTimeout(context.Background(), time.Minute)
   312  	defer cancel()
   313  
   314  	var jwksURLs []string
   315  	getJwksURL := func(dst *[]string, url string) bool {
   316  		req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
   317  		if err != nil {
   318  			return false
   319  		}
   320  
   321  		res, err := http.DefaultClient.Do(req)
   322  		if err != nil {
   323  			return false
   324  		}
   325  		defer res.Body.Close()
   326  
   327  		var m map[string]interface{}
   328  		if err := json.NewDecoder(res.Body).Decode(&m); err != nil {
   329  			return false
   330  		}
   331  
   332  		jwksURL, ok := m["jwks_uri"]
   333  		if !ok {
   334  			return false
   335  		}
   336  		*dst = append(*dst, jwksURL.(string))
   337  		return true
   338  	}
   339  	if !getJwksURL(&jwksURLs, "https://oidc-sample.onelogin.com/oidc/2/.well-known/openid-configuration") {
   340  		t.SkipNow()
   341  	}
   342  	if !getJwksURL(&jwksURLs, "https://accounts.google.com/.well-known/openid-configuration") {
   343  		t.SkipNow()
   344  	}
   345  
   346  	ar := jwk.NewCache(ctx, jwk.WithRefreshWindow(time.Second))
   347  	for _, url := range jwksURLs {
   348  		if !assert.NoError(t, ar.Register(url), `ar.Register should succeed`) {
   349  			return
   350  		}
   351  	}
   352  
   353  	for _, url := range jwksURLs {
   354  		_ = ar.Unregister(url)
   355  	}
   356  
   357  	for _, target := range ar.Snapshot().Entries {
   358  		t.Logf("%s last refreshed at %s", target.URL, target.LastFetched)
   359  	}
   360  
   361  	for _, url := range jwksURLs {
   362  		ar.Unregister(url)
   363  	}
   364  
   365  	if !assert.Len(t, ar.Snapshot().Entries, 0, `there should be no URLs`) {
   366  		return
   367  	}
   368  
   369  	if !assert.Error(t, ar.Unregister(`dummy`), `removing a non-existing url should be an error`) {
   370  		return
   371  	}
   372  }
   373  
   374  type accumulateErrs struct {
   375  	mu   sync.RWMutex
   376  	errs []error
   377  }
   378  
   379  func (e *accumulateErrs) Error(err error) {
   380  	e.mu.Lock()
   381  	e.errs = append(e.errs, err)
   382  	e.mu.Unlock()
   383  }
   384  
   385  func (e *accumulateErrs) Len() int {
   386  	e.mu.RLock()
   387  	l := len(e.errs)
   388  	e.mu.RUnlock()
   389  	return l
   390  }
   391  func TestErrorSink(t *testing.T) {
   392  	t.Parallel()
   393  
   394  	k, err := jwxtest.GenerateRsaJwk()
   395  	if !assert.NoError(t, err, `jwxtest.GenerateRsaJwk should succeed`) {
   396  		return
   397  	}
   398  	set := jwk.NewSet()
   399  	_ = set.AddKey(k)
   400  	testcases := []struct {
   401  		Name    string
   402  		Options func() []jwk.RegisterOption
   403  		Handler http.Handler
   404  	}{
   405  		/*
   406  			{
   407  				Name: "non-200 response",
   408  				Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   409  					w.WriteHeader(http.StatusForbidden)
   410  				}),
   411  			},
   412  			{
   413  				Name: "invalid JWK",
   414  				Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   415  					w.WriteHeader(http.StatusOK)
   416  					w.Write([]byte(`{"empty": "nonthingness"}`))
   417  				}),
   418  			},
   419  		*/
   420  		{
   421  			Name: `rejected by whitelist`,
   422  			Options: func() []jwk.RegisterOption {
   423  				return []jwk.RegisterOption{
   424  					jwk.WithFetchWhitelist(jwk.WhitelistFunc(func(_ string) bool {
   425  						return false
   426  					})),
   427  				}
   428  			},
   429  			Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   430  				w.WriteHeader(http.StatusOK)
   431  				json.NewEncoder(w).Encode(k)
   432  			}),
   433  		},
   434  	}
   435  
   436  	for _, tc := range testcases {
   437  		tc := tc
   438  		t.Run(tc.Name, func(t *testing.T) {
   439  			t.Parallel()
   440  			ctx, cancel := context.WithTimeout(context.Background(), time.Minute)
   441  			defer cancel()
   442  			srv := httptest.NewServer(tc.Handler)
   443  			defer srv.Close()
   444  
   445  			var errSink accumulateErrs
   446  			ar := jwk.NewCache(ctx, jwk.WithErrSink(&errSink), jwk.WithRefreshWindow(time.Second))
   447  
   448  			var options []jwk.RegisterOption
   449  			if f := tc.Options; f != nil {
   450  				options = f()
   451  			}
   452  			options = append(options, jwk.WithRefreshInterval(time.Second))
   453  			if !assert.NoError(t, ar.Register(srv.URL, options...), `ar.Register should succeed`) {
   454  				return
   455  			}
   456  
   457  			_, _ = ar.Get(ctx, srv.URL)
   458  
   459  			timer := time.NewTimer(6 * time.Second)
   460  
   461  			select {
   462  			case <-ctx.Done():
   463  				t.Errorf(`ctx.Done before timer`)
   464  			case <-timer.C:
   465  			}
   466  
   467  			cancel() // forcefully end context, and thus the Cache
   468  
   469  			// timing issues can cause this to be non-deterministic...
   470  			// we'll say it's okay as long as we're in +/- 1 range
   471  			l := errSink.Len()
   472  			if !assert.True(t, l <= 7, "number of errors shold be less than or equal to 7 (%d)", l) {
   473  				return
   474  			}
   475  			if !assert.True(t, l >= 5, "number of errors shold be greather than or equal to 5 (%d)", l) {
   476  				return
   477  			}
   478  		})
   479  	}
   480  }
   481  
   482  func TestPostFetch(t *testing.T) {
   483  	t.Parallel()
   484  
   485  	set := jwk.NewSet()
   486  	for i := 0; i < 3; i++ {
   487  		key, err := jwk.FromRaw([]byte(fmt.Sprintf(`abracadabra-%d`, i)))
   488  		if !assert.NoError(t, err, `jwk.FromRaw should succeed`) {
   489  			return
   490  		}
   491  		_ = set.AddKey(key)
   492  	}
   493  
   494  	testcases := []struct {
   495  		Name      string
   496  		Options   []jwk.RegisterOption
   497  		ExpectKid bool
   498  	}{
   499  		{
   500  			Name: "No PostFetch",
   501  		},
   502  		{
   503  			Name: "With PostFetch",
   504  			Options: []jwk.RegisterOption{jwk.WithPostFetcher(jwk.PostFetchFunc(func(_ string, set jwk.Set) (jwk.Set, error) {
   505  				for i := 0; i < set.Len(); i++ {
   506  					key, _ := set.Key(i)
   507  					key.Set(jwk.KeyIDKey, fmt.Sprintf(`key-%d`, i))
   508  				}
   509  				return set, nil
   510  			}))},
   511  			ExpectKid: true,
   512  		},
   513  	}
   514  
   515  	for _, tc := range testcases {
   516  		tc := tc
   517  		t.Run(tc.Name, func(t *testing.T) {
   518  			t.Parallel()
   519  
   520  			srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
   521  				w.WriteHeader(http.StatusOK)
   522  				json.NewEncoder(w).Encode(set)
   523  			}))
   524  			defer srv.Close()
   525  
   526  			ctx, cancel := context.WithTimeout(context.Background(), time.Minute)
   527  			defer cancel()
   528  
   529  			ar := jwk.NewCache(ctx)
   530  
   531  			ar.Register(srv.URL, tc.Options...)
   532  			set, err := ar.Get(ctx, srv.URL)
   533  			if !assert.NoError(t, err, `ar.Fetch should succeed`) {
   534  				return
   535  			}
   536  
   537  			for i := 0; i < set.Len(); i++ {
   538  				key, _ := set.Key(i)
   539  				if tc.ExpectKid {
   540  					if !assert.NotEmpty(t, key.KeyID(), `key.KeyID should not be empty`) {
   541  						return
   542  					}
   543  				} else {
   544  					if !assert.Empty(t, key.KeyID(), `key.KeyID should be empty`) {
   545  						return
   546  					}
   547  				}
   548  			}
   549  		})
   550  	}
   551  }