github.com/dolthub/dolt/go@v0.40.5-0.20240520175717-68db7794bea6/libraries/utils/jwtauth/jwks.go (about)

     1  // Copyright 2022 Dolthub, Inc.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package jwtauth
    16  
    17  import (
    18  	"errors"
    19  	"fmt"
    20  	"io"
    21  	"net/http"
    22  	"os"
    23  	"sync"
    24  	"time"
    25  
    26  	"github.com/sirupsen/logrus"
    27  	jose "gopkg.in/square/go-jose.v2"
    28  	"gopkg.in/square/go-jose.v2/json"
    29  )
    30  
// cachedJWKS holds the most recently fetched key set for a fetchedJWKS,
// together with the state used to decide when to fetch it again.
type cachedJWKS struct {
	// value is the last successfully fetched key set; newCachedJWKS leaves it
	// nil, and GetJWKS only assigns it after a fetch fully succeeds.
	value   *jose.JSONWebKeySet
	// expires is compared against time.Now() by fetchedJWKS.needsRefresh.
	// NOTE(review): newCachedJWKS sets it to the construction time and nothing
	// visible here advances it, so the cache looks permanently stale — confirm
	// whether a TTL was intended.
	expires time.Time
	// mutex is held by GetJWKS for its whole duration, serializing fetches.
	mutex   *sync.Mutex
}
    36  
    37  func newCachedJWKS() *cachedJWKS {
    38  	return &cachedJWKS{value: nil, expires: time.Now(), mutex: &sync.Mutex{}}
    39  }
    40  
// fetchedJWKS serves a single JSON Web Key Set fetched from URL, keeping the
// last result in cache.
type fetchedJWKS struct {
	// URL is the JWKS location. The transport built by newFetchedJWKS also
	// understands file:// URLs.
	URL           string
	// HTTPTransport is used as the Transport of the http.Client that GetJWKS
	// creates for each fetch.
	HTTPTransport *http.Transport
	cache         *cachedJWKS
}
    46  
    47  func newJWKS(provider JWTProvider) (*fetchedJWKS, error) {
    48  	return newFetchedJWKS(provider.URL)
    49  }
    50  
    51  func newFetchedJWKS(url string) (*fetchedJWKS, error) {
    52  	ret := &fetchedJWKS{
    53  		URL:   url,
    54  		cache: newCachedJWKS(),
    55  	}
    56  
    57  	pwd, err := os.Getwd()
    58  	if err != nil {
    59  		return nil, err
    60  	}
    61  
    62  	// Allows use of file:// for jwks location  url for tests
    63  	tr := &http.Transport{}
    64  	tr.RegisterProtocol("file", http.NewFileTransport(http.Dir(pwd)))
    65  	ret.HTTPTransport = tr
    66  
    67  	return ret, nil
    68  }
    69  
    70  func (f *fetchedJWKS) needsRefresh() bool {
    71  	return f.cache.value == nil || time.Now().After(f.cache.expires)
    72  }
    73  
    74  func (f *fetchedJWKS) GetJWKS() (*jose.JSONWebKeySet, error) {
    75  	f.cache.mutex.Lock()
    76  	defer f.cache.mutex.Unlock()
    77  	if f.needsRefresh() {
    78  		client := &http.Client{Transport: f.HTTPTransport}
    79  
    80  		request, err := http.NewRequest("GET", f.URL, nil)
    81  		if err != nil {
    82  			return nil, err
    83  		}
    84  
    85  		response, err := client.Do(request)
    86  		if err != nil {
    87  			return nil, err
    88  		} else if response.StatusCode/100 != 2 {
    89  			return nil, errors.New("FetchedJWKS: Non-2xx status code from JWKS fetch")
    90  		} else {
    91  			defer response.Body.Close()
    92  			contents, err := io.ReadAll(response.Body)
    93  			if err != nil {
    94  				return nil, err
    95  			}
    96  
    97  			jwks := jose.JSONWebKeySet{}
    98  			err = json.Unmarshal(contents, &jwks)
    99  			if err != nil {
   100  				return nil, err
   101  			}
   102  			f.cache.value = &jwks
   103  		}
   104  	}
   105  	return f.cache.value, nil
   106  }
   107  
   108  func (f *fetchedJWKS) GetKey(kid string) ([]jose.JSONWebKey, error) {
   109  	jwks, err := f.GetJWKS()
   110  	if err != nil {
   111  		return nil, err
   112  	}
   113  	return jwks.Key(kid), nil
   114  }
   115  
// MultiJWKS sources JWKS from multiple URLs and makes them all available
// through GetKey(). Its GetKey() cannot error, but it can return no results.
//
// The URLs in the refresh list are static. Run starts one goroutine per URL.
// Each goroutine refetches its URL 30 seconds after a successful fetch, or
// 1 second after a failed one. The results are aggregated into the JWKS
// view. If a key no longer appears at a URL, it is dropped from the keys
// available through GetKey() after the next successful fetch of that URL.
//
// Requesting a key which is not currently in the key set signals every URL
// to refresh. GetKey() then blocks until those refreshes complete before
// looking the key up again. A URL whose refresh queue is already full is
// not waited on.
//
// GracefulStop() stops the per-URL goroutines and returns once they have
// all exited. A fetch that is already in progress is allowed to finish.
type MultiJWKS struct {
	// client performs every fetch.
	client  *http.Client
	// wg tracks the per-URL goroutines started by Run.
	wg      sync.WaitGroup
	// stop is closed by GracefulStop to tell the goroutines to exit.
	stop    chan struct{}
	// refresh[i] carries refresh requests for urls[i]. The goroutine for
	// urls[i] calls Done on each received WaitGroup after its next fetch, or
	// while shutting down.
	refresh []chan *sync.WaitGroup
	urls    []string
	// sets[i] is the last key set successfully fetched from urls[i].
	sets    []jose.JSONWebKeySet
	// agg is the concatenation of all sets, rebuilt by store.
	agg     jose.JSONWebKeySet
	// mu guards sets, agg and stopped.
	mu      sync.RWMutex
	lgr     *logrus.Entry
	// stopped is set by GracefulStop. Once true, GetKey no longer signals
	// refreshes.
	stopped bool
}
   141  
   142  func NewMultiJWKS(lgr *logrus.Entry, urls []string, client *http.Client) *MultiJWKS {
   143  	res := new(MultiJWKS)
   144  	res.lgr = lgr
   145  	res.client = client
   146  	res.urls = urls
   147  	res.stop = make(chan struct{})
   148  	res.refresh = make([]chan *sync.WaitGroup, len(urls))
   149  	for i := range res.refresh {
   150  		res.refresh[i] = make(chan *sync.WaitGroup, 3)
   151  	}
   152  	res.sets = make([]jose.JSONWebKeySet, len(urls))
   153  	return res
   154  }
   155  
   156  func (t *MultiJWKS) Run() {
   157  	t.wg.Add(len(t.urls))
   158  	for i := 0; i < len(t.urls); i++ {
   159  		go t.thread(i)
   160  	}
   161  	t.wg.Wait()
   162  }
   163  
   164  func (t *MultiJWKS) GracefulStop() {
   165  	t.mu.Lock()
   166  	t.stopped = true
   167  	t.mu.Unlock()
   168  	close(t.stop)
   169  	t.wg.Wait()
   170  	// TODO: Potentially clear t.refresh channels, ensure nothing else can call GetKey()...
   171  }
   172  
   173  func (t *MultiJWKS) needsRefresh() *sync.WaitGroup {
   174  	wg := new(sync.WaitGroup)
   175  	if t.stopped {
   176  		return wg
   177  	}
   178  	wg.Add(len(t.refresh))
   179  	for _, c := range t.refresh {
   180  		select {
   181  		case c <- wg:
   182  		default:
   183  			wg.Done()
   184  		}
   185  	}
   186  	return wg
   187  }
   188  
   189  func (t *MultiJWKS) store(i int, jwks jose.JSONWebKeySet) {
   190  	t.mu.Lock()
   191  	defer t.mu.Unlock()
   192  	t.sets[i] = jwks
   193  	sum := 0
   194  	for _, s := range t.sets {
   195  		sum += len(s.Keys)
   196  	}
   197  	t.agg.Keys = make([]jose.JSONWebKey, 0, sum)
   198  	for _, s := range t.sets {
   199  		t.agg.Keys = append(t.agg.Keys, s.Keys...)
   200  	}
   201  }
   202  
   203  func (t *MultiJWKS) GetKey(kid string) ([]jose.JSONWebKey, error) {
   204  	t.mu.RLock()
   205  	defer t.mu.RUnlock()
   206  	res := t.agg.Key(kid)
   207  	if len(res) == 0 {
   208  		t.lgr.Infof("fetched key %s, found no key, signaling refresh", kid)
   209  		refresh := t.needsRefresh()
   210  		t.mu.RUnlock()
   211  		refresh.Wait()
   212  		t.mu.RLock()
   213  		res = t.agg.Key(kid)
   214  		t.lgr.Infof("refresh for key %s done, found %d keys", kid, len(res))
   215  	}
   216  	return res, nil
   217  }
   218  
   219  func (t *MultiJWKS) fetch(i int) error {
   220  	request, err := http.NewRequest("GET", t.urls[i], nil)
   221  	if err != nil {
   222  		return err
   223  	}
   224  	response, err := t.client.Do(request)
   225  	if err != nil {
   226  		return err
   227  	}
   228  	defer response.Body.Close()
   229  	if response.StatusCode/100 != 2 {
   230  		return fmt.Errorf("http request failed: StatusCode: %d", response.StatusCode)
   231  	}
   232  	contents, err := io.ReadAll(response.Body)
   233  	if err != nil {
   234  		return err
   235  	}
   236  	var jwks jose.JSONWebKeySet
   237  	err = json.Unmarshal(contents, &jwks)
   238  	if err != nil {
   239  		return err
   240  	}
   241  	t.store(i, jwks)
   242  	return nil
   243  }
   244  
// thread is the fetcher loop for t.urls[i]; Run starts one per URL.
//
// Each iteration fetches the URL (storing the result on success). It then
// completes the refresh request, if any, that triggered the iteration, and
// waits for the next trigger:
//   - a refresh request from needsRefresh on t.refresh[i]: fetch again now;
//   - the timer: 30s after a successful fetch, 1s after a failed one;
//   - t.stop being closed: release any queued refresh requests and return.
//
// The first fetch happens as soon as the goroutine starts. A fetch in
// progress is not interrupted by t.stop.
func (t *MultiJWKS) thread(i int) {
	defer t.wg.Done()
	// The timer is Reset after every fetch. The initial 30s only matters if
	// the first fetch outlasts it (see the NOTE at the Reset below).
	timer := time.NewTimer(30 * time.Second)
	// refresh is the request that triggered the current iteration, or nil
	// for the initial and timer-driven iterations.
	var refresh *sync.WaitGroup
	for {
		nextRefresh := 30 * time.Second
		err := t.fetch(i)
		if err != nil {
			// Retry sooner after a failure; fetch did not call store, so the
			// previously stored set for this URL is kept.
			t.lgr.Warnf("error fetching %s: %v", t.urls[i], err)
			nextRefresh = 1 * time.Second
		}
		// NOTE(review): on the first iteration the timer is Reset without a
		// prior Stop/drain. If the initial fetch took longer than 30s, a stale
		// tick may trigger an immediate extra fetch — confirm this is harmless.
		timer.Reset(nextRefresh)
		// Release the GetKey caller (if any) waiting on this refresh.
		if refresh != nil {
			refresh.Done()
		}
		refresh = nil
		select {
		case <-t.stop:
			// Stop the timer, draining a tick that already fired.
			if !timer.Stop() {
				<-timer.C
			}
			// Complete every queued refresh request without fetching, so no
			// GetKey caller stays blocked on the returned WaitGroup.
			for {
				select {
				case refresh = <-t.refresh[i]:
					refresh.Done()
				default:
					return
				}
			}
		case refresh = <-t.refresh[i]:
			// Refresh requested: stop (and drain) the timer. The loop fetches
			// immediately and then Resets the timer.
			if !timer.Stop() {
				<-timer.C
			}
		case <-timer.C:
		}
	}
}