istio.io/istio@v0.0.0-20240520182934-d79c90f27776/pilot/pkg/model/jwks_resolver.go (about)

     1  // Copyright Istio Authors
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package model
    16  
    17  import (
    18  	"context"
    19  	"crypto/tls"
    20  	"crypto/x509"
    21  	"encoding/json"
    22  	"errors"
    23  	"fmt"
    24  	"io"
    25  	"net/http"
    26  	"net/url"
    27  	"os"
    28  	"reflect"
    29  	"sort"
    30  	"strconv"
    31  	"strings"
    32  	"sync"
    33  	"sync/atomic"
    34  	"time"
    35  
    36  	core "github.com/envoyproxy/go-control-plane/envoy/config/core/v3"
    37  	envoy_jwt "github.com/envoyproxy/go-control-plane/envoy/extensions/filters/http/jwt_authn/v3"
    38  
    39  	"istio.io/istio/pilot/pkg/features"
    40  	"istio.io/istio/pkg/monitoring"
    41  )
    42  
const (
	// openIDDiscoveryCfgURLSuffix is appended to the issuer URL to locate its OpenID discovery document.
	// https://openid.net/specs/openid-connect-discovery-1_0.html
	// OpenID Providers supporting Discovery MUST make a JSON document available at the path
	// formed by concatenating the string /.well-known/openid-configuration to the Issuer.
	openIDDiscoveryCfgURLSuffix = "/.well-known/openid-configuration"

	// JwtPubKeyEvictionDuration is the lifetime of a cached item.
	// A cached item is removed from the cache if it has not been used for longer than
	// JwtPubKeyEvictionDuration, or if pilot has failed to refresh it for longer than
	// JwtPubKeyEvictionDuration.
	JwtPubKeyEvictionDuration = 24 * 7 * time.Hour

	// JwtPubKeyRefreshIntervalOnFailure is the interval at which the JWT public key refresh job
	// runs after a failed refresh.
	JwtPubKeyRefreshIntervalOnFailure = time.Minute

	// JwtPubKeyRetryInterval is the pause between successive attempts to fetch the remote
	// content from the network.
	JwtPubKeyRetryInterval = time.Second

	// JwtPubKeyRefreshIntervalOnFailureResetThreshold is the upper bound of the refresh interval
	// while refreshes keep failing: refreshCache doubles the interval on consecutive failures and
	// caps it at this value.
	JwtPubKeyRefreshIntervalOnFailureResetThreshold = 60 * time.Minute

	// networkFetchRetryCountOnMainFlow is how many times a failed network fetch is retried on the
	// main flow. The main flow means it's called when Pilot is pushing configs. Do not retry, to
	// make sure Pilot is not blocked for too long.
	networkFetchRetryCountOnMainFlow = 0

	// networkFetchRetryCountOnRefreshFlow is how many times a failed network fetch is retried on
	// the refresh flow. The refresh flow means it's called when the periodic refresh job is
	// triggered. We can retry more aggressively as it runs separately from the main flow.
	networkFetchRetryCountOnRefreshFlow = 7

	// jwksExtraRootCABundlePath is the path to any additional CA certificates pilot should accept when resolving JWKS URIs
	jwksExtraRootCABundlePath = "/cacerts/extra.pem"
)
    77  
var (
	// closeChan signals a refresher goroutine to stop. It is unbuffered and package-level, so it
	// is shared by every JwksResolver created in this process.
	closeChan = make(chan bool)

	// networkFetchSuccessCounter counts network fetches that completed without error.
	networkFetchSuccessCounter = monitoring.NewSum(
		"pilot_jwks_resolver_network_fetch_success_total",
		"Total number of successfully network fetch by pilot jwks resolver",
	)
	// networkFetchFailCounter counts network fetches that returned an error.
	networkFetchFailCounter = monitoring.NewSum(
		"pilot_jwks_resolver_network_fetch_fail_total",
		"Total number of failed network fetch by pilot jwks resolver",
	)

	// JwtPubKeyRefreshInterval is the running interval of JWT pubKey refresh job.
	JwtPubKeyRefreshInterval = features.PilotJwtPubKeyRefreshInterval

	// jwksuriChannel carries the keys whose fetch failed on the main flow so that the refresher
	// can retry them asynchronously. It is package-level (shared by all resolvers) and buffers
	// up to 5 pending requests.
	jwksuriChannel = make(chan jwtKey, 5)
)
    97  
// jwtPubKeyEntry is a single cached entry holding a JWT public key and the
// request options used to fetch it.
type jwtPubKeyEntry struct {
	// The cached JWKS document as fetched from the network. It is the empty string
	// when the initial fetch failed and the key is still to be fetched in the background.
	pubKey string

	// The time of the last successful refresh of pubKey.
	lastRefreshedTime time.Time

	// Cached item's last used time, which is set in GetPublicKey.
	lastUsedTime time.Time

	// Timeout applied to the web requests (OpenID discovery and JWKS fetch) made for this entry.
	timeout time.Duration
}
   111  
// jwtKey is a key in the JwksResolver keyEntries map.
type jwtKey struct {
	// jwksURI is the URI the JWKS is fetched from; when empty, the URI is resolved
	// from the issuer through OpenID discovery.
	jwksURI string
	// issuer is the JWT issuer the key belongs to.
	issuer string
}
   117  
// JwksResolver resolves jwksURIs and JWT public keys, caching the fetched keys and
// refreshing them from a background goroutine.
type JwksResolver struct {
	// Callback function to invoke when detecting jwt public key change.
	PushFunc func()

	// cache for JWT public key.
	// map key is jwtKey, map value is jwtPubKeyEntry.
	keyEntries sync.Map

	// secureHTTPClient is used for https URIs; it is nil when no root CA could be loaded.
	secureHTTPClient *http.Client
	// httpClient is used for every non-https URI.
	httpClient *http.Client
	// refreshTicker drives the periodic refresh; it is created by refresher.
	refreshTicker *time.Ticker

	// Cached key will be removed from cache if (time.now - cachedItem.lastUsedTime >= evictionDuration)
	// or (time.now - cachedItem.lastRefreshedTime >= evictionDuration); this prevents the key cache
	// from growing indefinitely and from serving keys that have not been refreshed for too long.
	evictionDuration time.Duration

	// Refresher job running interval currently in effect; adjusted by refreshCache.
	refreshInterval time.Duration

	// Refresher job running interval on failure.
	refreshIntervalOnFailure time.Duration

	// Refresher job default running interval without failure.
	refreshDefaultInterval time.Duration

	// Pause between two attempts to fetch the same remote content.
	retryInterval time.Duration

	// How many times refresh job has detected JWT public key change happened, used in unit test.
	refreshJobKeyChangedCount uint64

	// How many times refresh job failed to fetch the public key from network, used in unit test.
	refreshJobFetchFailedCount uint64

	// jwksUribackgroundChannel is true while the refresher is serving a request received on
	// jwksuriChannel (i.e. after a fetch failed on the main flow). While it is set, refresh only
	// processes entries whose public key is still empty and ticker-driven refreshes are skipped.
	jwksUribackgroundChannel bool
}
   154  
   155  func NewJwksResolver(evictionDuration, refreshDefaultInterval, refreshIntervalOnFailure, retryInterval time.Duration) *JwksResolver {
   156  	return newJwksResolverWithCABundlePaths(
   157  		evictionDuration,
   158  		refreshDefaultInterval,
   159  		refreshIntervalOnFailure,
   160  		retryInterval,
   161  		[]string{jwksExtraRootCABundlePath},
   162  	)
   163  }
   164  
   165  func newJwksResolverWithCABundlePaths(
   166  	evictionDuration,
   167  	refreshDefaultInterval,
   168  	refreshIntervalOnFailure,
   169  	retryInterval time.Duration,
   170  	caBundlePaths []string,
   171  ) *JwksResolver {
   172  	ret := &JwksResolver{
   173  		evictionDuration:         evictionDuration,
   174  		refreshInterval:          refreshDefaultInterval,
   175  		refreshDefaultInterval:   refreshDefaultInterval,
   176  		refreshIntervalOnFailure: refreshIntervalOnFailure,
   177  		retryInterval:            retryInterval,
   178  		httpClient: &http.Client{
   179  			Transport: &http.Transport{
   180  				Proxy:             http.ProxyFromEnvironment,
   181  				DisableKeepAlives: true,
   182  			},
   183  		},
   184  	}
   185  
   186  	caCertPool, err := x509.SystemCertPool()
   187  	caCertsFound := true
   188  	if err != nil {
   189  		caCertsFound = false
   190  		log.Errorf("Failed to fetch Cert from SystemCertPool: %v", err)
   191  	}
   192  
   193  	if caCertPool != nil {
   194  		for _, pemFile := range caBundlePaths {
   195  			caCert, err := os.ReadFile(pemFile)
   196  			if err == nil {
   197  				caCertsFound = caCertPool.AppendCertsFromPEM(caCert) || caCertsFound
   198  			}
   199  		}
   200  	}
   201  
   202  	if caCertsFound {
   203  		ret.secureHTTPClient = &http.Client{
   204  			Transport: &http.Transport{
   205  				Proxy:             http.ProxyFromEnvironment,
   206  				DisableKeepAlives: true,
   207  				TLSClientConfig: &tls.Config{
   208  					// nolint: gosec // user explicitly opted into insecure
   209  					InsecureSkipVerify: features.JwksResolverInsecureSkipVerify,
   210  					RootCAs:            caCertPool,
   211  					MinVersion:         tls.VersionTLS12,
   212  				},
   213  			},
   214  		}
   215  	}
   216  
   217  	atomic.StoreUint64(&ret.refreshJobKeyChangedCount, 0)
   218  	atomic.StoreUint64(&ret.refreshJobFetchFailedCount, 0)
   219  	go ret.refresher()
   220  
   221  	return ret
   222  }
   223  
   224  var errEmptyPubKeyFoundInCache = errors.New("empty public key found in cache")
   225  
   226  // GetPublicKey returns the JWT public key if it is available in the cache
   227  // or fetch with from jwksuri if there is a error while fetching then it adds the
   228  // jwksURI in the cache to fetch the public key in the background process
   229  func (r *JwksResolver) GetPublicKey(issuer string, jwksURI string, timeout time.Duration) (string, error) {
   230  	now := time.Now()
   231  	key := jwtKey{issuer: issuer, jwksURI: jwksURI}
   232  	if val, found := r.keyEntries.Load(key); found {
   233  		e := val.(jwtPubKeyEntry)
   234  		// Update cached key's last used time.
   235  		e.lastUsedTime = now
   236  		e.timeout = timeout
   237  		r.keyEntries.Store(key, e)
   238  		if e.pubKey == "" {
   239  			return e.pubKey, errEmptyPubKeyFoundInCache
   240  		}
   241  		return e.pubKey, nil
   242  	}
   243  
   244  	var err error
   245  	var pubKey string
   246  	if jwksURI == "" {
   247  		// Fetch the jwks URI if it is not hardcoded on config.
   248  		jwksURI, err = r.resolveJwksURIUsingOpenID(issuer, timeout)
   249  	}
   250  	if err != nil {
   251  		log.Errorf("Failed to jwks URI from %q: %v", issuer, err)
   252  	} else {
   253  		var resp []byte
   254  		resp, err = r.getRemoteContentWithRetry(jwksURI, networkFetchRetryCountOnMainFlow, timeout)
   255  		if err != nil {
   256  			log.Errorf("Failed to fetch public key from %q: %v", jwksURI, err)
   257  		}
   258  		pubKey = string(resp)
   259  	}
   260  
   261  	r.keyEntries.Store(key, jwtPubKeyEntry{
   262  		pubKey:            pubKey,
   263  		lastRefreshedTime: now,
   264  		lastUsedTime:      now,
   265  		timeout:           timeout,
   266  	})
   267  	if err != nil {
   268  		// fetching the public key in the background
   269  		jwksuriChannel <- key
   270  	}
   271  	return pubKey, err
   272  }
   273  
   274  // BuildLocalJwks builds local Jwks by fetching the Jwt Public Key from the URL passed if it is empty.
   275  func (r *JwksResolver) BuildLocalJwks(jwksURI, jwtIssuer, jwtPubKey string, timeout time.Duration) *envoy_jwt.JwtProvider_LocalJwks {
   276  	var err error
   277  	if jwtPubKey == "" {
   278  		// jwtKeyResolver should never be nil since the function is only called in Discovery Server request processing
   279  		// workflow, where the JWT key resolver should have already been initialized on server creation.
   280  		jwtPubKey, err = r.GetPublicKey(jwtIssuer, jwksURI, timeout)
   281  		if err != nil {
   282  			log.Infof("The JWKS key is not yet fetched for issuer %s (%s), using a fake JWKS for now", jwtIssuer, jwksURI)
   283  			// This is a temporary workaround to reject a request with JWT token by using a fake jwks when istiod failed to fetch it.
   284  			// TODO(xulingqing): Find a better way to reject the request without using the fake jwks.
   285  			jwtPubKey = FakeJwks
   286  		}
   287  	}
   288  	return &envoy_jwt.JwtProvider_LocalJwks{
   289  		LocalJwks: &core.DataSource{
   290  			Specifier: &core.DataSource_InlineString{
   291  				InlineString: jwtPubKey,
   292  			},
   293  		},
   294  	}
   295  }
   296  
// FakeJwks is a fake JWKS, generated by the following code:
   298  /*
   299  	fakeJwksRSAKey, _ := rsa.GenerateKey(rand.Reader, 2048)
   300  	key, _ := jwk.FromRaw(fakeJwksRSAKey)
   301  	rsaKey, _ := key.(jwk.RSAPrivateKey)
   302  	res, _ := json.Marshal(rsaKey)
   303  	fmt.Printf("{\"keys\":[ %s]}\n", string(res))
   304  */
   305  // it should be static across different instances and versions.
   306  // more details can be found: https://github.com/istio/istio/pull/47661.
   307  // nolint: lll
   308  const FakeJwks = `{
   309    "keys": [
   310      {
   311        "d": "T6cYL1_1mWHQLtOcbOgWV6HjhS0HVh3Apt4xEar5beaMBX3IYLFITz684DOHNy5dzaxTRqvGj-zHEgNrgy2T-Izoo2Z-xJ2Zse6wQ4R0xbwd0by8IbhiePcjgNWXXzildMHkBVrxNZhUICpb_r8efTHZfEwc6FPjJDVgJKtEc6WGCOiWnRYcGTTlsB5-QrQQlDFLmrU2Z6QDmqJU33aDJFr_qzmRiVNXeHuhlNca2JnKNPpxjRVsy7Kbc8PorxiPijnLzV8_pccsMyLvA8pWUl5FRtAJNSss7x_81HEcInlj7yA896zMiELSPps1rW68yVvpuKEuYulzGi4z74gz0Q",
   312        "dp": "YkH_MFMlgnGZntOCXLhib1LLW1JJCYmTzebn-JSluFJbG_qQgzuZkUu5s2cYBHmiZkDGmnTDOAYXrOaQSgVIBQMPxMqdUf8WjRIlEb88zvKpM_Curp59wuy6MhI7Ej3xKiixHX3bIq5Qujk3ZdsDbHUi3HH56-V7cdFKccqlg6E",
   313        "dq": "CXCwRpRgbtqzLcsfuy-5IUZosrvEDHCrFh0C-A6OYvKpHzn8PDwb62YGddhiHzSrgr1EUgykQxiIF2xG8dBaq8xXg9Bh4G1kkgIsqJmL5DG1lwyh_-Jt4nPyiLHZ--ERc48cjj515uRpGd-CWXdIf2EWYaJNsEkiNaYEClJQIA8",
   314        "e": "AQAB",
   315        "kty": "RSA",
   316        "n": "vqS7RN4b34i3_5YyhygtBe33gI6GK_0ldW8WMZaunS28T-WAzJOAoZ7E9Y0mHS8vcDES0eZIUpp6Ft9sRPhOlzQfo_7l-3DnaD9LxJVKdXjE1jugxfI9YX1qJpD9S9wRZxQIhPky9UzZDkpFh_KpL6pZUt4cbPtW0VCctjqvpI11yHNk4CEbzw-RRFLMJkLFJqgPa2JPzGZ-TqJdkSDQ7UtRiKzjRcWGnAdLsTq6WabDA1Fn1JVI9TWu-YDbLufDUDco46qyPgpxAqcRQG39cWZAQzMwNEZ-Yec_WiqDYqGTU6K8BBWeEIuMhiWfxGmtqX35rb9Qk_qeYDsqqT95Pw",
   317        "p": "7EK8xaN7qCdWCeQ1ptXWvuc6qotZc6oD-j1ecgel9FqmfkmaioVEbEAfP_N73QAjw-sU60sK3XK8LV4fkGUoJV-MDvmiCzy3wUPe-adSaTCxFykgOm6SPA9NKCqAh8lUm6GUm9RZkjwkv4xzZ8pJjng3d74WXx7zhTEH6yi4E00",
   318        "q": "zpJPbhAn79s_jPm4OhOvvPKT-ISN6EyLu_g6joh1Dzf-HCF149KKQfuLDtwDCsCNf1cE_BCb4qoHAVBLDjbqusQF019zNIFTHeUL8oMpbv-5of7km0K8oo-DQp5b8u05PKaEQu3OXmRZFwuO6dSTPvXO094X-8vm791FLcJ-4Ls",
   319        "qi": "SXz-JeBcTYMcO5lDBlrI9qd2eMQAYfVFDyq523L-RFhdravaxaYutT7dWk5f4Smzbh5KtvKifcFUMnV88On4HCiTrdBjLJJhIYqZQwzP8hYbXZlw4SvCtXKUrvLwLEUQaYg6bopp4VJ5c3XCZD5z3paHlZ45oCDsMeSEWxAD6lo"
   320      }
   321    ]
   322  }`
   323  
   324  // Resolve jwks_uri through openID discovery.
   325  func (r *JwksResolver) resolveJwksURIUsingOpenID(issuer string, timeout time.Duration) (string, error) {
   326  	// Try to get jwks_uri through OpenID Discovery.
   327  	issuer = strings.TrimSuffix(issuer, "/")
   328  	body, err := r.getRemoteContentWithRetry(issuer+openIDDiscoveryCfgURLSuffix, networkFetchRetryCountOnMainFlow, timeout)
   329  	if err != nil {
   330  		log.Errorf("Failed to fetch jwks_uri from %q: %v", issuer+openIDDiscoveryCfgURLSuffix, err)
   331  		return "", err
   332  	}
   333  	var data map[string]any
   334  	if err := json.Unmarshal(body, &data); err != nil {
   335  		return "", err
   336  	}
   337  
   338  	jwksURI, ok := data["jwks_uri"].(string)
   339  	if !ok {
   340  		return "", fmt.Errorf("invalid jwks_uri %v in openID discovery configuration", data["jwks_uri"])
   341  	}
   342  
   343  	return jwksURI, nil
   344  }
   345  
   346  func (r *JwksResolver) getRemoteContentWithRetry(uri string, retry int, timeout time.Duration) ([]byte, error) {
   347  	u, err := url.Parse(uri)
   348  	if err != nil {
   349  		log.Errorf("Failed to parse %q", uri)
   350  		return nil, err
   351  	}
   352  
   353  	ctx, cancel := context.WithTimeout(context.Background(), timeout)
   354  	defer cancel()
   355  
   356  	client := r.httpClient
   357  	if strings.EqualFold(u.Scheme, "https") {
   358  		// https client may be uninitialized because of root CA bundle missing.
   359  		if r.secureHTTPClient == nil {
   360  			return nil, fmt.Errorf("pilot does not support fetch public key through https endpoint %q", uri)
   361  		}
   362  
   363  		client = r.secureHTTPClient
   364  	}
   365  
   366  	getPublicKey := func() (b []byte, e error) {
   367  		defer func() {
   368  			if e != nil {
   369  				networkFetchFailCounter.Increment()
   370  			} else {
   371  				networkFetchSuccessCounter.Increment()
   372  			}
   373  		}()
   374  		req, err := http.NewRequestWithContext(ctx, http.MethodGet, uri, nil)
   375  		if err != nil {
   376  			return nil, err
   377  		}
   378  		resp, err := client.Do(req)
   379  		if err != nil {
   380  			return nil, err
   381  		}
   382  		defer resp.Body.Close()
   383  
   384  		body, err := io.ReadAll(resp.Body)
   385  		if err != nil {
   386  			return nil, err
   387  		}
   388  
   389  		if resp.StatusCode < 200 || resp.StatusCode >= 300 {
   390  			message := strconv.Quote(string(body))
   391  			if len(message) > 100 {
   392  				message = message[:100]
   393  				return nil, fmt.Errorf("status %d, message %s(truncated)", resp.StatusCode, message)
   394  			}
   395  			return nil, fmt.Errorf("status %d, message %s", resp.StatusCode, message)
   396  		}
   397  
   398  		return body, nil
   399  	}
   400  
   401  	for i := 0; i < retry; i++ {
   402  		body, err := getPublicKey()
   403  		if err == nil {
   404  			return body, nil
   405  		}
   406  		log.Warnf("Failed to GET from %q: %s. Retry in %v", uri, err, r.retryInterval)
   407  		time.Sleep(r.retryInterval)
   408  	}
   409  
   410  	// Return the last fetch directly, reaching here means we have tried `retry` times, this will be
   411  	// the last time for the retry.
   412  	return getPublicKey()
   413  }
   414  
   415  func (r *JwksResolver) refresher() {
   416  	// Wake up once in a while and refresh stale items.
   417  	r.refreshTicker = time.NewTicker(r.refreshInterval)
   418  	lastHasError := false
   419  	for {
   420  		select {
   421  		case <-r.refreshTicker.C:
   422  			if !r.jwksUribackgroundChannel {
   423  				lastHasError = r.refreshCache(lastHasError)
   424  			}
   425  		case <-closeChan:
   426  			r.refreshTicker.Stop()
   427  			return
   428  		case <-jwksuriChannel:
   429  			r.jwksUribackgroundChannel = true
   430  			lastHasError = r.refreshCache(lastHasError)
   431  			r.jwksUribackgroundChannel = false
   432  		}
   433  	}
   434  }
   435  
   436  func (r *JwksResolver) refreshCache(lastHasError bool) bool {
   437  	currentHasError := r.refresh()
   438  	if currentHasError {
   439  		if lastHasError {
   440  			// update to exponential backoff if last time also failed.
   441  			r.refreshInterval *= 2
   442  			if r.refreshInterval > JwtPubKeyRefreshIntervalOnFailureResetThreshold {
   443  				r.refreshInterval = JwtPubKeyRefreshIntervalOnFailureResetThreshold
   444  			}
   445  		} else {
   446  			// change to the refreshIntervalOnFailure if failed for the first time.
   447  			r.refreshInterval = r.refreshIntervalOnFailure
   448  		}
   449  	} else {
   450  		// reset the refresh interval if success.
   451  		r.refreshInterval = r.refreshDefaultInterval
   452  	}
   453  	r.refreshTicker.Reset(r.refreshInterval)
   454  	return currentHasError
   455  }
   456  
   457  func (r *JwksResolver) refresh() bool {
   458  	var wg sync.WaitGroup
   459  	var hasChange, hasErrors atomic.Bool
   460  	r.keyEntries.Range(func(key any, value any) bool {
   461  		now := time.Now()
   462  		k := key.(jwtKey)
   463  		e := value.(jwtPubKeyEntry)
   464  
   465  		if e.pubKey != "" && r.jwksUribackgroundChannel {
   466  			return true
   467  		}
   468  		// Remove cached item for either of the following 2 situations
   469  		// 1) it hasn't been used for a while
   470  		// 2) it hasn't been refreshed successfully for a while
   471  		// This makes sure 2 things, we don't grow the cache infinitely and also we don't reuse a cached public key
   472  		// with no success refresh for too much time.
   473  		if now.Sub(e.lastUsedTime) >= r.evictionDuration || now.Sub(e.lastRefreshedTime) >= r.evictionDuration {
   474  			log.Infof("Removed cached JWT public key (lastRefreshed: %s, lastUsed: %s) from %q",
   475  				e.lastRefreshedTime, e.lastUsedTime, k.issuer)
   476  			r.keyEntries.Delete(k)
   477  			return true
   478  		}
   479  
   480  		oldPubKey := e.pubKey
   481  		// Increment the WaitGroup counter.
   482  		wg.Add(1)
   483  
   484  		go func() {
   485  			// Decrement the counter when the goroutine completes.
   486  			defer wg.Done()
   487  			jwksURI := k.jwksURI
   488  			if jwksURI == "" {
   489  				var err error
   490  				jwksURI, err = r.resolveJwksURIUsingOpenID(k.issuer, e.timeout)
   491  				if err != nil {
   492  					hasErrors.Store(true)
   493  					log.Errorf("Failed to resolve Jwks from issuer %q: %v", k.issuer, err)
   494  					atomic.AddUint64(&r.refreshJobFetchFailedCount, 1)
   495  					return
   496  				}
   497  				r.keyEntries.Delete(k)
   498  				k.jwksURI = jwksURI
   499  			}
   500  			resp, err := r.getRemoteContentWithRetry(jwksURI, networkFetchRetryCountOnRefreshFlow, e.timeout)
   501  			if err != nil {
   502  				hasErrors.Store(true)
   503  				log.Errorf("Failed to refresh JWT public key from %q: %v", jwksURI, err)
   504  				atomic.AddUint64(&r.refreshJobFetchFailedCount, 1)
   505  				if oldPubKey == "" {
   506  					r.keyEntries.Delete(k)
   507  				}
   508  				return
   509  			}
   510  			newPubKey := string(resp)
   511  			r.keyEntries.Store(k, jwtPubKeyEntry{
   512  				pubKey:            newPubKey,
   513  				lastRefreshedTime: now,            // update the lastRefreshedTime if we get a success response from the network.
   514  				lastUsedTime:      e.lastUsedTime, // keep original lastUsedTime.
   515  				timeout:           e.timeout,
   516  			})
   517  			isNewKey, err := compareJWKSResponse(oldPubKey, newPubKey)
   518  			if err != nil {
   519  				hasErrors.Store(true)
   520  				log.Errorf("Failed to refresh JWT public key from %q: %v", jwksURI, err)
   521  				return
   522  			}
   523  			if isNewKey {
   524  				hasChange.Store(true)
   525  				log.Infof("Updated cached JWT public key from %q", jwksURI)
   526  			}
   527  		}()
   528  
   529  		return true
   530  	})
   531  
   532  	// Wait for all go routine to complete.
   533  	wg.Wait()
   534  
   535  	if hasChange.Load() {
   536  		atomic.AddUint64(&r.refreshJobKeyChangedCount, 1)
   537  		// Push public key changes to sidecars.
   538  		if r.PushFunc != nil {
   539  			r.PushFunc()
   540  		}
   541  	}
   542  	return hasErrors.Load()
   543  }
   544  
// Close will shut down the refresher job.
// It sends on the package-level, unbuffered closeChan, so the call blocks until a refresher
// goroutine receives the value; because the channel is shared, each call stops one refresher.
// TODO: may need to figure out the right place to call this function.
// (right now calls it from initDiscoveryService in pkg/bootstrap/server.go).
func (r *JwksResolver) Close() {
	closeChan <- true
}
   551  
   552  // Compare two JWKS responses, returning true if there is a difference and false otherwise
   553  func compareJWKSResponse(oldKeyString string, newKeyString string) (bool, error) {
   554  	if oldKeyString == newKeyString {
   555  		return false, nil
   556  	}
   557  
   558  	var oldJWKs map[string]any
   559  	var newJWKs map[string]any
   560  	if err := json.Unmarshal([]byte(newKeyString), &newJWKs); err != nil {
   561  		// If the new key is not parseable as JSON return an error since we will not want to use this key
   562  		log.Warnf("New JWKs public key JSON is not parseable: %s", newKeyString)
   563  		return false, err
   564  	}
   565  	if err := json.Unmarshal([]byte(oldKeyString), &oldJWKs); err != nil {
   566  		log.Warnf("Previous JWKs public key JSON is not parseable: %s", oldKeyString)
   567  		return true, nil
   568  	}
   569  
   570  	// Sort both sets of keys by "kid (key ID)" to be able to directly compare
   571  	oldKeys, oldKeysExists := oldJWKs["keys"].([]any)
   572  	newKeys, newKeysExists := newJWKs["keys"].([]any)
   573  	if oldKeysExists && newKeysExists {
   574  		sort.Slice(oldKeys, func(i, j int) bool {
   575  			key1, ok1 := oldKeys[i].(map[string]any)
   576  			key2, ok2 := oldKeys[j].(map[string]any)
   577  			if ok1 && ok2 {
   578  				key1Id, kid1Exists := key1["kid"]
   579  				key2Id, kid2Exists := key2["kid"]
   580  				if kid1Exists && kid2Exists {
   581  					key1IdStr, ok1 := key1Id.(string)
   582  					key2IdStr, ok2 := key2Id.(string)
   583  					if ok1 && ok2 {
   584  						return key1IdStr < key2IdStr
   585  					}
   586  				}
   587  			}
   588  			return len(key1) < len(key2)
   589  		})
   590  		sort.Slice(newKeys, func(i, j int) bool {
   591  			key1, ok1 := newKeys[i].(map[string]any)
   592  			key2, ok2 := newKeys[j].(map[string]any)
   593  			if ok1 && ok2 {
   594  				key1Id, kid1Exists := key1["kid"]
   595  				key2Id, kid2Exists := key2["kid"]
   596  				if kid1Exists && kid2Exists {
   597  					key1IdStr, ok1 := key1Id.(string)
   598  					key2IdStr, ok2 := key2Id.(string)
   599  					if ok1 && ok2 {
   600  						return key1IdStr < key2IdStr
   601  					}
   602  				}
   603  			}
   604  			return len(key1) < len(key2)
   605  		})
   606  
   607  		// Once sorted, return the result of deep comparison of the arrays of keys
   608  		return !reflect.DeepEqual(oldKeys, newKeys), nil
   609  	}
   610  
   611  	// If we aren't able to compare using keys, we should return true
   612  	// since we already checked exact equality of the responses
   613  	return true, nil
   614  }