github.com/blend/go-sdk@v1.20240719.1/oauth/public_key_cache.go (about)

     1  /*
     2  
     3  Copyright (c) 2024 - Present. Blend Labs, Inc. All rights reserved
     4  Use of this source code is governed by a MIT license that can be found in the LICENSE file.
     5  
     6  */
     7  
     8  package oauth
     9  
    10  import (
    11  	"context"
    12  	"crypto/rsa"
    13  	"net/http"
    14  	"sync"
    15  	"time"
    16  
    17  	"github.com/golang-jwt/jwt/v4"
    18  
    19  	"github.com/blend/go-sdk/ex"
    20  	"github.com/blend/go-sdk/jwk"
    21  	"github.com/blend/go-sdk/r2"
    22  )
    23  
    24  // PublicKeyCache holds cached signing certs.
    25  type PublicKeyCache struct {
    26  	FetchPublicKeysDefaults []r2.Option
    27  	mu                      sync.RWMutex
    28  	current                 *PublicKeysResponse
    29  	keyURL                  string
    30  }
    31  
    32  // NewPublicKeyCache creates a new public key cache.
    33  func NewPublicKeyCache(keyURL string) *PublicKeyCache {
    34  	return &PublicKeyCache{
    35  		keyURL: keyURL,
    36  	}
    37  }
    38  
    39  // Keyfunc returns a jwt keyfunc for a specific exchange tied to context.
    40  func (pkc *PublicKeyCache) Keyfunc(ctx context.Context) jwt.Keyfunc {
    41  	return func(token *jwt.Token) (interface{}, error) {
    42  		if token == nil {
    43  			return nil, Error("invalid jwt; token is unset")
    44  		}
    45  		kid, ok := token.Header["kid"]
    46  		if !ok {
    47  			return nil, Error("invalid jwt header; `kid` missing")
    48  		}
    49  		typedKid, ok := kid.(string)
    50  		if !ok {
    51  			return nil, Error("invalid jwt header; `kid` not a string")
    52  		}
    53  		return pkc.Get(ctx, typedKid)
    54  	}
    55  }
    56  
    57  // Get gets a cert by id.
    58  func (pkc *PublicKeyCache) Get(ctx context.Context, id string) (*rsa.PublicKey, error) {
    59  	var jwk jwk.JWK
    60  	var ok bool
    61  	pkc.mu.RLock()
    62  	if pkc.current != nil && !pkc.current.IsExpired() {
    63  		jwk, ok = pkc.current.Keys[id]
    64  	}
    65  	pkc.mu.RUnlock()
    66  	if ok {
    67  		return jwk.RSAPublicKey()
    68  	}
    69  
    70  	pkc.mu.Lock()
    71  	defer pkc.mu.Unlock()
    72  
    73  	// check again after grabbing the lock if
    74  	// the keys have been updated
    75  	if pkc.current != nil && !pkc.current.IsExpired() {
    76  		jwk, ok = pkc.current.Keys[id]
    77  	}
    78  	if ok {
    79  		return jwk.RSAPublicKey()
    80  	}
    81  
    82  	// if we should still refresh after grabbing
    83  	// the write lock
    84  	keys, err := pkc.FetchPublicKeys(ctx, pkc.FetchPublicKeysDefaults...)
    85  	if err != nil {
    86  		return nil, err
    87  	}
    88  	pkc.current = keys
    89  
    90  	jwk, ok = pkc.current.Keys[id]
    91  	if !ok {
    92  		return nil, ex.New("invalid jwt key id; not found in signing keys cache", ex.OptMessagef("Key ID: %s", id))
    93  	}
    94  	return jwk.RSAPublicKey()
    95  }
    96  
    97  // FetchPublicKeys gets the google signing certs.
    98  func (pkc *PublicKeyCache) FetchPublicKeys(ctx context.Context, opts ...r2.Option) (*PublicKeysResponse, error) {
    99  	var jwks fetchPublicKeysResponse
   100  	meta, err := r2.New(pkc.keyURL, opts...).JSON(&jwks)
   101  	if err != nil {
   102  		return nil, err
   103  	}
   104  
   105  	expiresHeader := meta.Header.Get("Expires")
   106  	if expiresHeader == "" {
   107  		return nil, ex.New("invalid google keys response; expires unset")
   108  	}
   109  
   110  	expires, err := time.Parse(http.TimeFormat, expiresHeader)
   111  	if err != nil {
   112  		return nil, ex.New("invalid google keys response; invalid expires value", ex.OptInner(err))
   113  	}
   114  	res := &PublicKeysResponse{
   115  		Keys:         jwkLookup(jwks.Keys),
   116  		CacheControl: meta.Header.Get("Cache-Control"),
   117  		Expires:      expires,
   118  	}
   119  	return res, nil
   120  }
   121  
   122  type fetchPublicKeysResponse struct {
   123  	Keys []jwk.JWK `json:"keys"`
   124  }
   125  
   126  // jwkLookup creates a jwk lookup.
   127  func jwkLookup(jwks []jwk.JWK) map[string]jwk.JWK {
   128  	output := make(map[string]jwk.JWK)
   129  	for _, jwk := range jwks {
   130  		// We don't check that `jwk.KID` collides with an existing key. We trust that
   131  		// the public certs URL (e.g. the one from Google) does not include duplicates.
   132  		output[jwk.KID] = jwk
   133  	}
   134  	return output
   135  }
   136  
   137  // PublicKeysResponse is a response for the google certs api.
   138  type PublicKeysResponse struct {
   139  	CacheControl string
   140  	Expires      time.Time
   141  	Keys         map[string]jwk.JWK
   142  }
   143  
   144  // IsExpired returns if the cert response is expired.
   145  func (pkr PublicKeysResponse) IsExpired() bool {
   146  	if pkr.Expires.IsZero() {
   147  		return true
   148  	}
   149  	return time.Now().UTC().After(pkr.Expires.UTC())
   150  }