github.com/lestrrat-go/jwx/v2@v2.0.21/jws/key_provider.go (about) 1 package jws 2 3 import ( 4 "context" 5 "fmt" 6 "net/url" 7 "sync" 8 9 "github.com/lestrrat-go/jwx/v2/jwa" 10 "github.com/lestrrat-go/jwx/v2/jwk" 11 ) 12 13 // KeyProvider is responsible for providing key(s) to sign or verify a payload. 14 // Multiple `jws.KeyProvider`s can be passed to `jws.Verify()` or `jws.Sign()` 15 // 16 // `jws.Sign()` can only accept static key providers via `jws.WithKey()`, 17 // while `jws.Verify()` can accept `jws.WithKey()`, `jws.WithKeySet()`, 18 // `jws.WithVerifyAuto()`, and `jws.WithKeyProvider()`. 19 // 20 // Understanding how this works is crucial to learn how this package works. 21 // 22 // `jws.Sign()` is straightforward: signatures are created for each 23 // provided key. 24 // 25 // `jws.Verify()` is a bit more involved, because there are cases you 26 // will want to compute/deduce/guess the keys that you would like to 27 // use for verification. 28 // 29 // The first thing that `jws.Verify()` does is to collect the 30 // KeyProviders from the option list that the user provided (presented in pseudocode): 31 // 32 // keyProviders := filterKeyProviders(options) 33 // 34 // Then, remember that a JWS message may contain multiple signatures in the 35 // message. For each signature, we call on the KeyProviders to give us 36 // the key(s) to use on this signature: 37 // 38 // for sig in msg.Signatures { 39 // for kp in keyProviders { 40 // kp.FetcKeys(ctx, sink, sig, msg) 41 // ... 42 // } 43 // } 44 // 45 // The `sink` argument passed to the KeyProvider is a temporary storage 46 // for the keys (either a jwk.Key or a "raw" key). The `KeyProvider` 47 // is responsible for sending keys into the `sink`. 48 // 49 // When called, the `KeyProvider` created by `jws.WithKey()` sends the same key, 50 // `jws.WithKeySet()` sends keys that matches a particular `kid` and `alg`, 51 // `jws.WithVerifyAuto()` fetchs a JWK from the `jku` URL, 52 // and finally `jws.WithKeyProvider()` allows you to execute arbitrary 53 // logic to provide keys. If you are providing a custom `KeyProvider`, 54 // you should execute the necessary checks or retrieval of keys, and 55 // then send the key(s) to the sink: 56 // 57 // sink.Key(alg, key) 58 // 59 // These keys are then retrieved and tried for each signature, until 60 // a match is found: 61 // 62 // keys := sink.Keys() 63 // for key in keys { 64 // if givenSignature == makeSignatre(key, payload, ...)) { 65 // return OK 66 // } 67 // } 68 type KeyProvider interface { 69 FetchKeys(context.Context, KeySink, *Signature, *Message) error 70 } 71 72 // KeySink is a data storage where `jws.KeyProvider` objects should 73 // send their keys to. 74 type KeySink interface { 75 Key(jwa.SignatureAlgorithm, interface{}) 76 } 77 78 type algKeyPair struct { 79 alg jwa.KeyAlgorithm 80 key interface{} 81 } 82 83 type algKeySink struct { 84 mu sync.Mutex 85 list []algKeyPair 86 } 87 88 func (s *algKeySink) Key(alg jwa.SignatureAlgorithm, key interface{}) { 89 s.mu.Lock() 90 s.list = append(s.list, algKeyPair{alg, key}) 91 s.mu.Unlock() 92 } 93 94 type staticKeyProvider struct { 95 alg jwa.SignatureAlgorithm 96 key interface{} 97 } 98 99 func (kp *staticKeyProvider) FetchKeys(_ context.Context, sink KeySink, _ *Signature, _ *Message) error { 100 sink.Key(kp.alg, kp.key) 101 return nil 102 } 103 104 type keySetProvider struct { 105 set jwk.Set 106 requireKid bool // true if `kid` must be specified 107 useDefault bool // true if the first key should be used iff there's exactly one key in set 108 inferAlgorithm bool // true if the algorithm should be inferred from key type 109 multipleKeysPerKeyID bool // true if we should attempt to match multiple keys per key ID. if false we assume that only one key exists for a given key ID 110 } 111 112 func (kp *keySetProvider) selectKey(sink KeySink, key jwk.Key, sig *Signature, _ *Message) error { 113 if usage := key.KeyUsage(); usage != "" && usage != jwk.ForSignature.String() { 114 return nil 115 } 116 117 if v := key.Algorithm(); v.String() != "" { 118 var alg jwa.SignatureAlgorithm 119 if err := alg.Accept(v); err != nil { 120 return fmt.Errorf(`invalid signature algorithm %s: %w`, key.Algorithm(), err) 121 } 122 123 sink.Key(alg, key) 124 return nil 125 } 126 127 if kp.inferAlgorithm { 128 algs, err := AlgorithmsForKey(key) 129 if err != nil { 130 return fmt.Errorf(`failed to get a list of signature methods for key type %s: %w`, key.KeyType(), err) 131 } 132 133 // bail out if the JWT has a `alg` field, and it doesn't match 134 if tokAlg := sig.ProtectedHeaders().Algorithm(); tokAlg != "" { 135 for _, alg := range algs { 136 if tokAlg == alg { 137 sink.Key(alg, key) 138 return nil 139 } 140 } 141 return fmt.Errorf(`algorithm in the message does not match any of the inferred algorithms`) 142 } 143 144 // Yes, you get to try them all!!!!!!! 145 for _, alg := range algs { 146 sink.Key(alg, key) 147 } 148 return nil 149 } 150 return nil 151 } 152 153 func (kp *keySetProvider) FetchKeys(_ context.Context, sink KeySink, sig *Signature, msg *Message) error { 154 if kp.requireKid { 155 wantedKid := sig.ProtectedHeaders().KeyID() 156 if wantedKid == "" { 157 // If the kid is NOT specified... kp.useDefault needs to be true, and the 158 // JWKs must have exactly one key in it 159 if !kp.useDefault { 160 return fmt.Errorf(`failed to find matching key: no key ID ("kid") specified in token`) 161 } else if kp.useDefault && kp.set.Len() > 1 { 162 return fmt.Errorf(`failed to find matching key: no key ID ("kid") specified in token but multiple keys available in key set`) 163 } 164 165 // if we got here, then useDefault == true AND there is exactly 166 // one key in the set. 167 key, _ := kp.set.Key(0) 168 return kp.selectKey(sink, key, sig, msg) 169 } 170 171 // Otherwise we better be able to look up the key. 172 // <= v2.0.3 backwards compatible case: only match a single key 173 // whose key ID matches `wantedKid` 174 if !kp.multipleKeysPerKeyID { 175 key, ok := kp.set.LookupKeyID(wantedKid) 176 if !ok { 177 return fmt.Errorf(`failed to find key with key ID %q in key set`, wantedKid) 178 } 179 return kp.selectKey(sink, key, sig, msg) 180 } 181 182 // if multipleKeysPerKeyID is true, we attempt all keys whose key ID matches 183 // the wantedKey 184 var ok bool 185 for i := 0; i < kp.set.Len(); i++ { 186 key, _ := kp.set.Key(i) 187 if key.KeyID() != wantedKid { 188 continue 189 } 190 191 if err := kp.selectKey(sink, key, sig, msg); err != nil { 192 continue 193 } 194 ok = true 195 // continue processing so that we try all keys with the same key ID 196 } 197 if !ok { 198 return fmt.Errorf(`failed to find key with key ID %q in key set`, wantedKid) 199 } 200 return nil 201 } 202 203 // Otherwise just try all keys 204 for i := 0; i < kp.set.Len(); i++ { 205 key, _ := kp.set.Key(i) 206 if err := kp.selectKey(sink, key, sig, msg); err != nil { 207 continue 208 } 209 } 210 return nil 211 } 212 213 type jkuProvider struct { 214 fetcher jwk.Fetcher 215 options []jwk.FetchOption 216 } 217 218 func (kp jkuProvider) FetchKeys(ctx context.Context, sink KeySink, sig *Signature, _ *Message) error { 219 kid := sig.ProtectedHeaders().KeyID() 220 if kid == "" { 221 return fmt.Errorf(`use of "jku" requires that the payload contain a "kid" field in the protected header`) 222 } 223 224 // errors here can't be reliablly passed to the consumers. 225 // it's unfortunate, but if you need this control, you are 226 // going to have to write your own fetcher 227 u := sig.ProtectedHeaders().JWKSetURL() 228 if u == "" { 229 return fmt.Errorf(`use of "jku" field specified, but the field is empty`) 230 } 231 uo, err := url.Parse(u) 232 if err != nil { 233 return fmt.Errorf(`failed to parse "jku": %w`, err) 234 } 235 if uo.Scheme != "https" { 236 return fmt.Errorf(`url in "jku" must be HTTPS`) 237 } 238 239 set, err := kp.fetcher.Fetch(ctx, u, kp.options...) 240 if err != nil { 241 return fmt.Errorf(`failed to fetch %q: %w`, u, err) 242 } 243 244 key, ok := set.LookupKeyID(kid) 245 if !ok { 246 // It is not an error if the key with the kid doesn't exist 247 return nil 248 } 249 250 algs, err := AlgorithmsForKey(key) 251 if err != nil { 252 return fmt.Errorf(`failed to get a list of signature methods for key type %s: %w`, key.KeyType(), err) 253 } 254 255 hdrAlg := sig.ProtectedHeaders().Algorithm() 256 for _, alg := range algs { 257 // if we have a "alg" field in the JWS, we can only proceed if 258 // the inferred algorithm matches 259 if hdrAlg != "" && hdrAlg != alg { 260 continue 261 } 262 263 sink.Key(alg, key) 264 break 265 } 266 return nil 267 } 268 269 // KeyProviderFunc is a type of KeyProvider that is implemented by 270 // a single function. You can use this to create ad-hoc `KeyProvider` 271 // instances. 272 type KeyProviderFunc func(context.Context, KeySink, *Signature, *Message) error 273 274 func (kp KeyProviderFunc) FetchKeys(ctx context.Context, sink KeySink, sig *Signature, msg *Message) error { 275 return kp(ctx, sink, sig, msg) 276 }