github.com/lestrrat-go/jwx/v2@v2.0.21/jws/key_provider.go (about)

     1  package jws
     2  
     3  import (
     4  	"context"
     5  	"fmt"
     6  	"net/url"
     7  	"sync"
     8  
     9  	"github.com/lestrrat-go/jwx/v2/jwa"
    10  	"github.com/lestrrat-go/jwx/v2/jwk"
    11  )
    12  
    13  // KeyProvider is responsible for providing key(s) to sign or verify a payload.
    14  // Multiple `jws.KeyProvider`s can be passed to `jws.Verify()` or `jws.Sign()`
    15  //
    16  // `jws.Sign()` can only accept static key providers via `jws.WithKey()`,
    17  // while `jws.Verify()` can accept `jws.WithKey()`, `jws.WithKeySet()`,
    18  // `jws.WithVerifyAuto()`, and `jws.WithKeyProvider()`.
    19  //
    20  // Understanding how this works is crucial to learn how this package works.
    21  //
    22  // `jws.Sign()` is straightforward: signatures are created for each
    23  // provided key.
    24  //
    25  // `jws.Verify()` is a bit more involved, because there are cases you
    26  // will want to compute/deduce/guess the keys that you would like to
    27  // use for verification.
    28  //
    29  // The first thing that `jws.Verify()` does is to collect the
    30  // KeyProviders from the option list that the user provided (presented in pseudocode):
    31  //
    32  //	keyProviders := filterKeyProviders(options)
    33  //
    34  // Then, remember that a JWS message may contain multiple signatures in the
    35  // message. For each signature, we call on the KeyProviders to give us
    36  // the key(s) to use on this signature:
    37  //
    38  //	for sig in msg.Signatures {
    39  //	  for kp in keyProviders {
    40  //	    kp.FetcKeys(ctx, sink, sig, msg)
    41  //	    ...
    42  //	  }
    43  //	}
    44  //
    45  // The `sink` argument passed to the KeyProvider is a temporary storage
    46  // for the keys (either a jwk.Key or a "raw" key). The `KeyProvider`
    47  // is responsible for sending keys into the `sink`.
    48  //
    49  // When called, the `KeyProvider` created by `jws.WithKey()` sends the same key,
    50  // `jws.WithKeySet()` sends keys that matches a particular `kid` and `alg`,
    51  // `jws.WithVerifyAuto()` fetchs a JWK from the `jku` URL,
    52  // and finally `jws.WithKeyProvider()` allows you to execute arbitrary
    53  // logic to provide keys. If you are providing a custom `KeyProvider`,
    54  // you should execute the necessary checks or retrieval of keys, and
    55  // then send the key(s) to the sink:
    56  //
    57  //	sink.Key(alg, key)
    58  //
    59  // These keys are then retrieved and tried for each signature, until
    60  // a match is found:
    61  //
    62  //	keys := sink.Keys()
    63  //	for key in keys {
    64  //	  if givenSignature == makeSignatre(key, payload, ...)) {
    65  //	    return OK
    66  //	  }
    67  //	}
    68  type KeyProvider interface {
    69  	FetchKeys(context.Context, KeySink, *Signature, *Message) error
    70  }
    71  
    72  // KeySink is a data storage where `jws.KeyProvider` objects should
    73  // send their keys to.
    74  type KeySink interface {
    75  	Key(jwa.SignatureAlgorithm, interface{})
    76  }
    77  
    78  type algKeyPair struct {
    79  	alg jwa.KeyAlgorithm
    80  	key interface{}
    81  }
    82  
    83  type algKeySink struct {
    84  	mu   sync.Mutex
    85  	list []algKeyPair
    86  }
    87  
    88  func (s *algKeySink) Key(alg jwa.SignatureAlgorithm, key interface{}) {
    89  	s.mu.Lock()
    90  	s.list = append(s.list, algKeyPair{alg, key})
    91  	s.mu.Unlock()
    92  }
    93  
    94  type staticKeyProvider struct {
    95  	alg jwa.SignatureAlgorithm
    96  	key interface{}
    97  }
    98  
    99  func (kp *staticKeyProvider) FetchKeys(_ context.Context, sink KeySink, _ *Signature, _ *Message) error {
   100  	sink.Key(kp.alg, kp.key)
   101  	return nil
   102  }
   103  
   104  type keySetProvider struct {
   105  	set                  jwk.Set
   106  	requireKid           bool // true if `kid` must be specified
   107  	useDefault           bool // true if the first key should be used iff there's exactly one key in set
   108  	inferAlgorithm       bool // true if the algorithm should be inferred from key type
   109  	multipleKeysPerKeyID bool // true if we should attempt to match multiple keys per key ID. if false we assume that only one key exists for a given key ID
   110  }
   111  
   112  func (kp *keySetProvider) selectKey(sink KeySink, key jwk.Key, sig *Signature, _ *Message) error {
   113  	if usage := key.KeyUsage(); usage != "" && usage != jwk.ForSignature.String() {
   114  		return nil
   115  	}
   116  
   117  	if v := key.Algorithm(); v.String() != "" {
   118  		var alg jwa.SignatureAlgorithm
   119  		if err := alg.Accept(v); err != nil {
   120  			return fmt.Errorf(`invalid signature algorithm %s: %w`, key.Algorithm(), err)
   121  		}
   122  
   123  		sink.Key(alg, key)
   124  		return nil
   125  	}
   126  
   127  	if kp.inferAlgorithm {
   128  		algs, err := AlgorithmsForKey(key)
   129  		if err != nil {
   130  			return fmt.Errorf(`failed to get a list of signature methods for key type %s: %w`, key.KeyType(), err)
   131  		}
   132  
   133  		// bail out if the JWT has a `alg` field, and it doesn't match
   134  		if tokAlg := sig.ProtectedHeaders().Algorithm(); tokAlg != "" {
   135  			for _, alg := range algs {
   136  				if tokAlg == alg {
   137  					sink.Key(alg, key)
   138  					return nil
   139  				}
   140  			}
   141  			return fmt.Errorf(`algorithm in the message does not match any of the inferred algorithms`)
   142  		}
   143  
   144  		// Yes, you get to try them all!!!!!!!
   145  		for _, alg := range algs {
   146  			sink.Key(alg, key)
   147  		}
   148  		return nil
   149  	}
   150  	return nil
   151  }
   152  
   153  func (kp *keySetProvider) FetchKeys(_ context.Context, sink KeySink, sig *Signature, msg *Message) error {
   154  	if kp.requireKid {
   155  		wantedKid := sig.ProtectedHeaders().KeyID()
   156  		if wantedKid == "" {
   157  			// If the kid is NOT specified... kp.useDefault needs to be true, and the
   158  			// JWKs must have exactly one key in it
   159  			if !kp.useDefault {
   160  				return fmt.Errorf(`failed to find matching key: no key ID ("kid") specified in token`)
   161  			} else if kp.useDefault && kp.set.Len() > 1 {
   162  				return fmt.Errorf(`failed to find matching key: no key ID ("kid") specified in token but multiple keys available in key set`)
   163  			}
   164  
   165  			// if we got here, then useDefault == true AND there is exactly
   166  			// one key in the set.
   167  			key, _ := kp.set.Key(0)
   168  			return kp.selectKey(sink, key, sig, msg)
   169  		}
   170  
   171  		// Otherwise we better be able to look up the key.
   172  		// <= v2.0.3 backwards compatible case: only match a single key
   173  		// whose key ID matches `wantedKid`
   174  		if !kp.multipleKeysPerKeyID {
   175  			key, ok := kp.set.LookupKeyID(wantedKid)
   176  			if !ok {
   177  				return fmt.Errorf(`failed to find key with key ID %q in key set`, wantedKid)
   178  			}
   179  			return kp.selectKey(sink, key, sig, msg)
   180  		}
   181  
   182  		// if multipleKeysPerKeyID is true, we attempt all keys whose key ID matches
   183  		// the wantedKey
   184  		var ok bool
   185  		for i := 0; i < kp.set.Len(); i++ {
   186  			key, _ := kp.set.Key(i)
   187  			if key.KeyID() != wantedKid {
   188  				continue
   189  			}
   190  
   191  			if err := kp.selectKey(sink, key, sig, msg); err != nil {
   192  				continue
   193  			}
   194  			ok = true
   195  			// continue processing so that we try all keys with the same key ID
   196  		}
   197  		if !ok {
   198  			return fmt.Errorf(`failed to find key with key ID %q in key set`, wantedKid)
   199  		}
   200  		return nil
   201  	}
   202  
   203  	// Otherwise just try all keys
   204  	for i := 0; i < kp.set.Len(); i++ {
   205  		key, _ := kp.set.Key(i)
   206  		if err := kp.selectKey(sink, key, sig, msg); err != nil {
   207  			continue
   208  		}
   209  	}
   210  	return nil
   211  }
   212  
   213  type jkuProvider struct {
   214  	fetcher jwk.Fetcher
   215  	options []jwk.FetchOption
   216  }
   217  
   218  func (kp jkuProvider) FetchKeys(ctx context.Context, sink KeySink, sig *Signature, _ *Message) error {
   219  	kid := sig.ProtectedHeaders().KeyID()
   220  	if kid == "" {
   221  		return fmt.Errorf(`use of "jku" requires that the payload contain a "kid" field in the protected header`)
   222  	}
   223  
   224  	// errors here can't be reliablly passed to the consumers.
   225  	// it's unfortunate, but if you need this control, you are
   226  	// going to have to write your own fetcher
   227  	u := sig.ProtectedHeaders().JWKSetURL()
   228  	if u == "" {
   229  		return fmt.Errorf(`use of "jku" field specified, but the field is empty`)
   230  	}
   231  	uo, err := url.Parse(u)
   232  	if err != nil {
   233  		return fmt.Errorf(`failed to parse "jku": %w`, err)
   234  	}
   235  	if uo.Scheme != "https" {
   236  		return fmt.Errorf(`url in "jku" must be HTTPS`)
   237  	}
   238  
   239  	set, err := kp.fetcher.Fetch(ctx, u, kp.options...)
   240  	if err != nil {
   241  		return fmt.Errorf(`failed to fetch %q: %w`, u, err)
   242  	}
   243  
   244  	key, ok := set.LookupKeyID(kid)
   245  	if !ok {
   246  		// It is not an error if the key with the kid doesn't exist
   247  		return nil
   248  	}
   249  
   250  	algs, err := AlgorithmsForKey(key)
   251  	if err != nil {
   252  		return fmt.Errorf(`failed to get a list of signature methods for key type %s: %w`, key.KeyType(), err)
   253  	}
   254  
   255  	hdrAlg := sig.ProtectedHeaders().Algorithm()
   256  	for _, alg := range algs {
   257  		// if we have a "alg" field in the JWS, we can only proceed if
   258  		// the inferred algorithm matches
   259  		if hdrAlg != "" && hdrAlg != alg {
   260  			continue
   261  		}
   262  
   263  		sink.Key(alg, key)
   264  		break
   265  	}
   266  	return nil
   267  }
   268  
   269  // KeyProviderFunc is a type of KeyProvider that is implemented by
   270  // a single function. You can use this to create ad-hoc `KeyProvider`
   271  // instances.
   272  type KeyProviderFunc func(context.Context, KeySink, *Signature, *Message) error
   273  
   274  func (kp KeyProviderFunc) FetchKeys(ctx context.Context, sink KeySink, sig *Signature, msg *Message) error {
   275  	return kp(ctx, sink, sig, msg)
   276  }