github.com/lestrrat-go/jwx/v2@v2.0.21/jwk/fetch.go (about) 1 package jwk 2 3 import ( 4 "context" 5 "fmt" 6 "io" 7 "math" 8 "os" 9 "strconv" 10 "sync" 11 "sync/atomic" 12 13 "github.com/lestrrat-go/httprc" 14 ) 15 16 type Fetcher interface { 17 Fetch(context.Context, string, ...FetchOption) (Set, error) 18 } 19 20 type FetchFunc func(context.Context, string, ...FetchOption) (Set, error) 21 22 func (f FetchFunc) Fetch(ctx context.Context, u string, options ...FetchOption) (Set, error) { 23 return f(ctx, u, options...) 24 } 25 26 var globalFetcher httprc.Fetcher 27 var muGlobalFetcher sync.Mutex 28 var fetcherChanged uint32 29 30 func init() { 31 atomic.StoreUint32(&fetcherChanged, 1) 32 } 33 34 func getGlobalFetcher() httprc.Fetcher { 35 if v := atomic.LoadUint32(&fetcherChanged); v == 0 { 36 return globalFetcher 37 } 38 39 muGlobalFetcher.Lock() 40 defer muGlobalFetcher.Unlock() 41 if globalFetcher == nil { 42 var nworkers int 43 v := os.Getenv(`JWK_FETCHER_WORKER_COUNT`) 44 if c, err := strconv.ParseInt(v, 10, 64); err == nil { 45 if c > math.MaxInt { 46 nworkers = math.MaxInt 47 } else { 48 nworkers = int(c) 49 } 50 } 51 if nworkers < 1 { 52 nworkers = 3 53 } 54 55 globalFetcher = httprc.NewFetcher(context.Background(), httprc.WithFetcherWorkerCount(nworkers)) 56 } 57 58 atomic.StoreUint32(&fetcherChanged, 0) 59 return globalFetcher 60 } 61 62 // SetGlobalFetcher allows users to specify a custom global fetcher, 63 // which is used by the `Fetch` function. Assigning `nil` forces 64 // the default fetcher to be (re)created when the next call to 65 // `jwk.Fetch` occurs 66 // 67 // You only need to call this function when you want to 68 // either change the fetching behavior (for example, you want to change 69 // how the default whitelist is handled), or when you want to control 70 // the lifetime of the global fetcher, for example for tests 71 // that require a clean shutdown. 72 // 73 // If you do use this function to set a custom fetcher and you 74 // control its termination, make sure that you call `jwk.SetGlobalFetcher()` 75 // one more time (possibly with `nil`) to assign a valid fetcher. 76 // Otherwise, once the fetcher is invalidated, subsequent calls to `jwk.Fetch` 77 // may hang, causing very hard to debug problems. 78 // 79 // If you are sure you no longer need `jwk.Fetch` after terminating the 80 // fetcher, then you the above caution is not necessary. 81 func SetGlobalFetcher(f httprc.Fetcher) { 82 muGlobalFetcher.Lock() 83 globalFetcher = f 84 muGlobalFetcher.Unlock() 85 atomic.StoreUint32(&fetcherChanged, 1) 86 } 87 88 // Fetch fetches a JWK resource specified by a URL. The url must be 89 // pointing to a resource that is supported by `net/http`. 90 // 91 // If you are using the same `jwk.Set` for long periods of time during 92 // the lifecycle of your program, and would like to periodically refresh the 93 // contents of the object with the data at the remote resource, 94 // consider using `jwk.Cache`, which automatically refreshes 95 // jwk.Set objects asynchronously. 96 // 97 // Please note that underneath the `jwk.Fetch` function, it uses a global 98 // object that spawns goroutines that are present until the go runtime 99 // exits. Initially this global variable is uninitialized, but upon 100 // calling `jwk.Fetch` once, it is initialized and goroutines are spawned. 101 // If you want to control the lifetime of these goroutines, you can 102 // call `jwk.SetGlobalFetcher` with a custom fetcher which is tied to 103 // a `context.Context` object that you can control. 104 func Fetch(ctx context.Context, u string, options ...FetchOption) (Set, error) { 105 var hrfopts []httprc.FetchOption 106 var parseOptions []ParseOption 107 for _, option := range options { 108 if parseOpt, ok := option.(ParseOption); ok { 109 parseOptions = append(parseOptions, parseOpt) 110 continue 111 } 112 113 //nolint:forcetypeassert 114 switch option.Ident() { 115 case identHTTPClient{}: 116 hrfopts = append(hrfopts, httprc.WithHTTPClient(option.Value().(HTTPClient))) 117 case identFetchWhitelist{}: 118 hrfopts = append(hrfopts, httprc.WithWhitelist(option.Value().(httprc.Whitelist))) 119 } 120 } 121 122 res, err := getGlobalFetcher().Fetch(ctx, u, hrfopts...) 123 if err != nil { 124 return nil, fmt.Errorf(`failed to fetch %q: %w`, u, err) 125 } 126 127 buf, err := io.ReadAll(res.Body) 128 defer res.Body.Close() 129 if err != nil { 130 return nil, fmt.Errorf(`failed to read response body for %q: %w`, u, err) 131 } 132 133 return Parse(buf, parseOptions...) 134 }