github.com/lestrrat-go/jwx/v2@v2.0.21/jwk/fetch.go (about)

     1  package jwk
     2  
     3  import (
     4  	"context"
     5  	"fmt"
     6  	"io"
     7  	"math"
     8  	"os"
     9  	"strconv"
    10  	"sync"
    11  	"sync/atomic"
    12  
    13  	"github.com/lestrrat-go/httprc"
    14  )
    15  
    16  type Fetcher interface {
    17  	Fetch(context.Context, string, ...FetchOption) (Set, error)
    18  }
    19  
    20  type FetchFunc func(context.Context, string, ...FetchOption) (Set, error)
    21  
    22  func (f FetchFunc) Fetch(ctx context.Context, u string, options ...FetchOption) (Set, error) {
    23  	return f(ctx, u, options...)
    24  }
    25  
    26  var globalFetcher httprc.Fetcher
    27  var muGlobalFetcher sync.Mutex
    28  var fetcherChanged uint32
    29  
    30  func init() {
    31  	atomic.StoreUint32(&fetcherChanged, 1)
    32  }
    33  
    34  func getGlobalFetcher() httprc.Fetcher {
    35  	if v := atomic.LoadUint32(&fetcherChanged); v == 0 {
    36  		return globalFetcher
    37  	}
    38  
    39  	muGlobalFetcher.Lock()
    40  	defer muGlobalFetcher.Unlock()
    41  	if globalFetcher == nil {
    42  		var nworkers int
    43  		v := os.Getenv(`JWK_FETCHER_WORKER_COUNT`)
    44  		if c, err := strconv.ParseInt(v, 10, 64); err == nil {
    45  			if c > math.MaxInt {
    46  				nworkers = math.MaxInt
    47  			} else {
    48  				nworkers = int(c)
    49  			}
    50  		}
    51  		if nworkers < 1 {
    52  			nworkers = 3
    53  		}
    54  
    55  		globalFetcher = httprc.NewFetcher(context.Background(), httprc.WithFetcherWorkerCount(nworkers))
    56  	}
    57  
    58  	atomic.StoreUint32(&fetcherChanged, 0)
    59  	return globalFetcher
    60  }
    61  
    62  // SetGlobalFetcher allows users to specify a custom global fetcher,
    63  // which is used by the `Fetch` function. Assigning `nil` forces
    64  // the default fetcher to be (re)created when the next call to
    65  // `jwk.Fetch` occurs
    66  //
    67  // You only need to call this function when you want to
    68  // either change the fetching behavior (for example, you want to change
    69  // how the default whitelist is handled), or when you want to control
    70  // the lifetime of the global fetcher, for example for tests
    71  // that require a clean shutdown.
    72  //
    73  // If you do use this function to set a custom fetcher and you
    74  // control its termination, make sure that you call `jwk.SetGlobalFetcher()`
    75  // one more time (possibly with `nil`) to assign a valid fetcher.
    76  // Otherwise, once the fetcher is invalidated, subsequent calls to `jwk.Fetch`
    77  // may hang, causing very hard to debug problems.
    78  //
    79  // If you are sure you no longer need `jwk.Fetch` after terminating the
    80  // fetcher, then you the above caution is not necessary.
    81  func SetGlobalFetcher(f httprc.Fetcher) {
    82  	muGlobalFetcher.Lock()
    83  	globalFetcher = f
    84  	muGlobalFetcher.Unlock()
    85  	atomic.StoreUint32(&fetcherChanged, 1)
    86  }
    87  
    88  // Fetch fetches a JWK resource specified by a URL. The url must be
    89  // pointing to a resource that is supported by `net/http`.
    90  //
    91  // If you are using the same `jwk.Set` for long periods of time during
    92  // the lifecycle of your program, and would like to periodically refresh the
    93  // contents of the object with the data at the remote resource,
    94  // consider using `jwk.Cache`, which automatically refreshes
    95  // jwk.Set objects asynchronously.
    96  //
    97  // Please note that underneath the `jwk.Fetch` function, it uses a global
    98  // object that spawns goroutines that are present until the go runtime
    99  // exits. Initially this global variable is uninitialized, but upon
   100  // calling `jwk.Fetch` once, it is initialized and goroutines are spawned.
   101  // If you want to control the lifetime of these goroutines, you can
   102  // call `jwk.SetGlobalFetcher` with a custom fetcher which is tied to
   103  // a `context.Context` object that you can control.
   104  func Fetch(ctx context.Context, u string, options ...FetchOption) (Set, error) {
   105  	var hrfopts []httprc.FetchOption
   106  	var parseOptions []ParseOption
   107  	for _, option := range options {
   108  		if parseOpt, ok := option.(ParseOption); ok {
   109  			parseOptions = append(parseOptions, parseOpt)
   110  			continue
   111  		}
   112  
   113  		//nolint:forcetypeassert
   114  		switch option.Ident() {
   115  		case identHTTPClient{}:
   116  			hrfopts = append(hrfopts, httprc.WithHTTPClient(option.Value().(HTTPClient)))
   117  		case identFetchWhitelist{}:
   118  			hrfopts = append(hrfopts, httprc.WithWhitelist(option.Value().(httprc.Whitelist)))
   119  		}
   120  	}
   121  
   122  	res, err := getGlobalFetcher().Fetch(ctx, u, hrfopts...)
   123  	if err != nil {
   124  		return nil, fmt.Errorf(`failed to fetch %q: %w`, u, err)
   125  	}
   126  
   127  	buf, err := io.ReadAll(res.Body)
   128  	defer res.Body.Close()
   129  	if err != nil {
   130  		return nil, fmt.Errorf(`failed to read response body for %q: %w`, u, err)
   131  	}
   132  
   133  	return Parse(buf, parseOptions...)
   134  }