github.com/dolthub/dolt/go@v0.40.5-0.20240520175717-68db7794bea6/libraries/utils/jwtauth/jwks.go (about) 1 // Copyright 2022 Dolthub, Inc. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 package jwtauth 16 17 import ( 18 "errors" 19 "fmt" 20 "io" 21 "net/http" 22 "os" 23 "sync" 24 "time" 25 26 "github.com/sirupsen/logrus" 27 jose "gopkg.in/square/go-jose.v2" 28 "gopkg.in/square/go-jose.v2/json" 29 ) 30 31 type cachedJWKS struct { 32 value *jose.JSONWebKeySet 33 expires time.Time 34 mutex *sync.Mutex 35 } 36 37 func newCachedJWKS() *cachedJWKS { 38 return &cachedJWKS{value: nil, expires: time.Now(), mutex: &sync.Mutex{}} 39 } 40 41 type fetchedJWKS struct { 42 URL string 43 HTTPTransport *http.Transport 44 cache *cachedJWKS 45 } 46 47 func newJWKS(provider JWTProvider) (*fetchedJWKS, error) { 48 return newFetchedJWKS(provider.URL) 49 } 50 51 func newFetchedJWKS(url string) (*fetchedJWKS, error) { 52 ret := &fetchedJWKS{ 53 URL: url, 54 cache: newCachedJWKS(), 55 } 56 57 pwd, err := os.Getwd() 58 if err != nil { 59 return nil, err 60 } 61 62 // Allows use of file:// for jwks location url for tests 63 tr := &http.Transport{} 64 tr.RegisterProtocol("file", http.NewFileTransport(http.Dir(pwd))) 65 ret.HTTPTransport = tr 66 67 return ret, nil 68 } 69 70 func (f *fetchedJWKS) needsRefresh() bool { 71 return f.cache.value == nil || time.Now().After(f.cache.expires) 72 } 73 74 func (f *fetchedJWKS) GetJWKS() (*jose.JSONWebKeySet, error) { 75 f.cache.mutex.Lock() 76 defer f.cache.mutex.Unlock() 77 if f.needsRefresh() { 78 client := &http.Client{Transport: f.HTTPTransport} 79 80 request, err := http.NewRequest("GET", f.URL, nil) 81 if err != nil { 82 return nil, err 83 } 84 85 response, err := client.Do(request) 86 if err != nil { 87 return nil, err 88 } else if response.StatusCode/100 != 2 { 89 return nil, errors.New("FetchedJWKS: Non-2xx status code from JWKS fetch") 90 } else { 91 defer response.Body.Close() 92 contents, err := io.ReadAll(response.Body) 93 if err != nil { 94 return nil, err 95 } 96 97 jwks := jose.JSONWebKeySet{} 98 err = json.Unmarshal(contents, &jwks) 99 if err != nil { 100 return nil, err 101 } 102 f.cache.value = &jwks 103 } 104 } 105 return f.cache.value, nil 106 } 107 108 func (f *fetchedJWKS) GetKey(kid string) ([]jose.JSONWebKey, error) { 109 jwks, err := f.GetJWKS() 110 if err != nil { 111 return nil, err 112 } 113 return jwks.Key(kid), nil 114 } 115 116 // The MultiJWKS will source JWKS from multiple URLs and will make them all 117 // available through GetKey(). It's GetKey() cannot error, but it can return no 118 // results. 119 // 120 // The URLs in the refresh list are static. Each URL will be periodically 121 // refreshed and the results will be aggregated into the JWKS view. If a key no 122 // longer appears at the URL, it may eventually be removed from the set of keys 123 // available through GetKey(). Requesting a key which is not currently in the 124 // key set will generally hint that the URLs should be more aggressively 125 // refreshed, but there is no blocking on refreshing the URLs. 126 // 127 // GracefulStop() will shutdown any ongoing fetching work and will return when 128 // everything is cleanly shutdown. 129 type MultiJWKS struct { 130 client *http.Client 131 wg sync.WaitGroup 132 stop chan struct{} 133 refresh []chan *sync.WaitGroup 134 urls []string 135 sets []jose.JSONWebKeySet 136 agg jose.JSONWebKeySet 137 mu sync.RWMutex 138 lgr *logrus.Entry 139 stopped bool 140 } 141 142 func NewMultiJWKS(lgr *logrus.Entry, urls []string, client *http.Client) *MultiJWKS { 143 res := new(MultiJWKS) 144 res.lgr = lgr 145 res.client = client 146 res.urls = urls 147 res.stop = make(chan struct{}) 148 res.refresh = make([]chan *sync.WaitGroup, len(urls)) 149 for i := range res.refresh { 150 res.refresh[i] = make(chan *sync.WaitGroup, 3) 151 } 152 res.sets = make([]jose.JSONWebKeySet, len(urls)) 153 return res 154 } 155 156 func (t *MultiJWKS) Run() { 157 t.wg.Add(len(t.urls)) 158 for i := 0; i < len(t.urls); i++ { 159 go t.thread(i) 160 } 161 t.wg.Wait() 162 } 163 164 func (t *MultiJWKS) GracefulStop() { 165 t.mu.Lock() 166 t.stopped = true 167 t.mu.Unlock() 168 close(t.stop) 169 t.wg.Wait() 170 // TODO: Potentially clear t.refresh channels, ensure nothing else can call GetKey()... 171 } 172 173 func (t *MultiJWKS) needsRefresh() *sync.WaitGroup { 174 wg := new(sync.WaitGroup) 175 if t.stopped { 176 return wg 177 } 178 wg.Add(len(t.refresh)) 179 for _, c := range t.refresh { 180 select { 181 case c <- wg: 182 default: 183 wg.Done() 184 } 185 } 186 return wg 187 } 188 189 func (t *MultiJWKS) store(i int, jwks jose.JSONWebKeySet) { 190 t.mu.Lock() 191 defer t.mu.Unlock() 192 t.sets[i] = jwks 193 sum := 0 194 for _, s := range t.sets { 195 sum += len(s.Keys) 196 } 197 t.agg.Keys = make([]jose.JSONWebKey, 0, sum) 198 for _, s := range t.sets { 199 t.agg.Keys = append(t.agg.Keys, s.Keys...) 200 } 201 } 202 203 func (t *MultiJWKS) GetKey(kid string) ([]jose.JSONWebKey, error) { 204 t.mu.RLock() 205 defer t.mu.RUnlock() 206 res := t.agg.Key(kid) 207 if len(res) == 0 { 208 t.lgr.Infof("fetched key %s, found no key, signaling refresh", kid) 209 refresh := t.needsRefresh() 210 t.mu.RUnlock() 211 refresh.Wait() 212 t.mu.RLock() 213 res = t.agg.Key(kid) 214 t.lgr.Infof("refresh for key %s done, found %d keys", kid, len(res)) 215 } 216 return res, nil 217 } 218 219 func (t *MultiJWKS) fetch(i int) error { 220 request, err := http.NewRequest("GET", t.urls[i], nil) 221 if err != nil { 222 return err 223 } 224 response, err := t.client.Do(request) 225 if err != nil { 226 return err 227 } 228 defer response.Body.Close() 229 if response.StatusCode/100 != 2 { 230 return fmt.Errorf("http request failed: StatusCode: %d", response.StatusCode) 231 } 232 contents, err := io.ReadAll(response.Body) 233 if err != nil { 234 return err 235 } 236 var jwks jose.JSONWebKeySet 237 err = json.Unmarshal(contents, &jwks) 238 if err != nil { 239 return err 240 } 241 t.store(i, jwks) 242 return nil 243 } 244 245 func (t *MultiJWKS) thread(i int) { 246 defer t.wg.Done() 247 timer := time.NewTimer(30 * time.Second) 248 var refresh *sync.WaitGroup 249 for { 250 nextRefresh := 30 * time.Second 251 err := t.fetch(i) 252 if err != nil { 253 // Something bad... 254 t.lgr.Warnf("error fetching %s: %v", t.urls[i], err) 255 nextRefresh = 1 * time.Second 256 } 257 timer.Reset(nextRefresh) 258 if refresh != nil { 259 refresh.Done() 260 } 261 refresh = nil 262 select { 263 case <-t.stop: 264 if !timer.Stop() { 265 <-timer.C 266 } 267 for { 268 select { 269 case refresh = <-t.refresh[i]: 270 refresh.Done() 271 default: 272 return 273 } 274 } 275 case refresh = <-t.refresh[i]: 276 if !timer.Stop() { 277 <-timer.C 278 } 279 case <-timer.C: 280 } 281 } 282 }