github.com/blend/go-sdk@v1.20240719.1/oauth/public_key_cache.go (about) 1 /* 2 3 Copyright (c) 2024 - Present. Blend Labs, Inc. All rights reserved 4 Use of this source code is governed by a MIT license that can be found in the LICENSE file. 5 6 */ 7 8 package oauth 9 10 import ( 11 "context" 12 "crypto/rsa" 13 "net/http" 14 "sync" 15 "time" 16 17 "github.com/golang-jwt/jwt/v4" 18 19 "github.com/blend/go-sdk/ex" 20 "github.com/blend/go-sdk/jwk" 21 "github.com/blend/go-sdk/r2" 22 ) 23 24 // PublicKeyCache holds cached signing certs. 25 type PublicKeyCache struct { 26 FetchPublicKeysDefaults []r2.Option 27 mu sync.RWMutex 28 current *PublicKeysResponse 29 keyURL string 30 } 31 32 // NewPublicKeyCache creates a new public key cache. 33 func NewPublicKeyCache(keyURL string) *PublicKeyCache { 34 return &PublicKeyCache{ 35 keyURL: keyURL, 36 } 37 } 38 39 // Keyfunc returns a jwt keyfunc for a specific exchange tied to context. 40 func (pkc *PublicKeyCache) Keyfunc(ctx context.Context) jwt.Keyfunc { 41 return func(token *jwt.Token) (interface{}, error) { 42 if token == nil { 43 return nil, Error("invalid jwt; token is unset") 44 } 45 kid, ok := token.Header["kid"] 46 if !ok { 47 return nil, Error("invalid jwt header; `kid` missing") 48 } 49 typedKid, ok := kid.(string) 50 if !ok { 51 return nil, Error("invalid jwt header; `kid` not a string") 52 } 53 return pkc.Get(ctx, typedKid) 54 } 55 } 56 57 // Get gets a cert by id. 
//
// The cached key set is consulted under a read lock first. On a miss (no
// cached set, an expired set, or an unknown key ID) the write lock is taken,
// the cache is re-checked in case another goroutine refreshed it while this
// one waited, and otherwise the key set is re-fetched and replaces the cache.
// An unknown key ID therefore always triggers a fetch, even when the cached
// set has not yet expired.
func (pkc *PublicKeyCache) Get(ctx context.Context, id string) (*rsa.PublicKey, error) {
	pkc.mu.RLock()
	key, ok := pkc.lookupLocked(id)
	pkc.mu.RUnlock()
	if ok {
		return key.RSAPublicKey()
	}

	pkc.mu.Lock()
	defer pkc.mu.Unlock()

	// Another goroutine may have refreshed the keys between releasing the
	// read lock and acquiring the write lock.
	if key, ok = pkc.lookupLocked(id); ok {
		return key.RSAPublicKey()
	}

	keys, err := pkc.FetchPublicKeys(ctx, pkc.FetchPublicKeysDefaults...)
	if err != nil {
		return nil, err
	}
	pkc.current = keys

	key, ok = keys.Keys[id]
	if !ok {
		return nil, ex.New("invalid jwt key id; not found in signing keys cache", ex.OptMessagef("Key ID: %s", id))
	}
	return key.RSAPublicKey()
}

// lookupLocked returns the cached key for id when a cached key set exists
// and has not expired. The caller must hold pkc.mu (read or write).
func (pkc *PublicKeyCache) lookupLocked(id string) (jwk.JWK, bool) {
	var zero jwk.JWK
	if pkc.current == nil || pkc.current.IsExpired() {
		return zero, false
	}
	key, ok := pkc.current.Keys[id]
	return key, ok
}

// FetchPublicKeys fetches the JSON Web Key Set published at the cache's key
// URL.
//
// ctx is attached to the outbound request so that cancelling it aborts the
// fetch. It is applied before opts, so a context option supplied in opts is
// applied afterwards. The response must carry an `Expires` header, which is
// parsed with http.ParseTime (any HTTP/1.1 date format).
func (pkc *PublicKeyCache) FetchPublicKeys(ctx context.Context, opts ...r2.Option) (*PublicKeysResponse, error) {
	requestOpts := make([]r2.Option, 0, len(opts)+1)
	if ctx != nil { // (*http.Request).WithContext panics on a nil context
		requestOpts = append(requestOpts, r2.OptContext(ctx))
	}
	requestOpts = append(requestOpts, opts...)

	var jwks fetchPublicKeysResponse
	meta, err := r2.New(pkc.keyURL, requestOpts...).JSON(&jwks)
	if err != nil {
		return nil, err
	}

	expiresHeader := meta.Header.Get("Expires")
	if expiresHeader == "" {
		return nil, ex.New("invalid google keys response; expires unset")
	}

	expires, err := http.ParseTime(expiresHeader)
	if err != nil {
		return nil, ex.New("invalid google keys response; invalid expires value", ex.OptInner(err))
	}
	return &PublicKeysResponse{
		Keys:         jwkLookup(jwks.Keys),
		CacheControl: meta.Header.Get("Cache-Control"),
		Expires:      expires,
	}, nil
}

// fetchPublicKeysResponse is the JSON body of a key-set (JWKS) response.
type fetchPublicKeysResponse struct {
	Keys []jwk.JWK `json:"keys"`
}

// jwkLookup creates a jwk lookup.
//
// The result is keyed by each key's KID. If several keys share a KID, the
// last one in jwks wins; the key source is trusted not to publish
// duplicates. A nil or empty input yields an empty, non-nil map.
func jwkLookup(jwks []jwk.JWK) map[string]jwk.JWK {
	byKID := make(map[string]jwk.JWK, len(jwks))
	for i := range jwks {
		byKID[jwks[i].KID] = jwks[i]
	}
	return byKID
}

// PublicKeysResponse is a fetched public key set together with the caching
// headers returned alongside it.
type PublicKeysResponse struct {
	CacheControl string             // raw Cache-Control response header
	Expires      time.Time          // parsed Expires response header
	Keys         map[string]jwk.JWK // keys indexed by KID
}

// IsExpired reports whether the response should no longer be used. A zero
// Expires counts as expired; otherwise the response expires once the current
// time passes Expires.
func (pkr PublicKeysResponse) IsExpired() bool {
	return pkr.Expires.IsZero() || time.Now().UTC().After(pkr.Expires.UTC())
}