github.com/lestrrat-go/jwx/v2@v2.0.21/jwe/key_provider.go (about)

     1  package jwe
     2  
     3  import (
     4  	"context"
     5  	"fmt"
     6  	"sync"
     7  
     8  	"github.com/lestrrat-go/jwx/v2/jwa"
     9  	"github.com/lestrrat-go/jwx/v2/jwk"
    10  )
    11  
    12  // KeyProvider is responsible for providing key(s) to encrypt or decrypt a payload.
    13  // Multiple `jwe.KeyProvider`s can be passed to `jwe.Encrypt()` or `jwe.Decrypt()`
    14  //
    15  // `jwe.Encrypt()` can only accept static key providers via `jwe.WithKey()`,
    16  // while `jwe.Derypt()` can accept `jwe.WithKey()`, `jwe.WithKeySet()`,
    17  // and `jwe.WithKeyProvider()`.
    18  //
    19  // Understanding how this works is crucial to learn how this package works.
    20  // Here we will use `jwe.Decrypt()` as an example to show how the `KeyProvider`
    21  // works.
    22  //
    23  // `jwe.Encrypt()` is straightforward: the content encryption key is encrypted
    24  // using the provided keys, and JWS recipient objects are created for each.
    25  //
    26  // `jwe.Decrypt()` is a bit more involved, because there are cases you
    27  // will want to compute/deduce/guess the keys that you would like to
    28  // use for decryption.
    29  //
    30  // The first thing that `jwe.Decrypt()` needs to do is to collect the
    31  // KeyProviders from the option list that the user provided (presented in pseudocode):
    32  //
    33  //	keyProviders := filterKeyProviders(options)
    34  //
    35  // Then, remember that a JWE message may contain multiple recipients in the
    36  // message. For each recipient, we call on the KeyProviders to give us
    37  // the key(s) to use on this signature:
    38  //
    39  //	for r in msg.Recipients {
    40  //	  for kp in keyProviders {
    41  //	    kp.FetcKeys(ctx, sink, r, msg)
    42  //	    ...
    43  //	  }
    44  //	}
    45  //
    46  // The `sink` argument passed to the KeyProvider is a temporary storage
    47  // for the keys (either a jwk.Key or a "raw" key). The `KeyProvider`
    48  // is responsible for sending keys into the `sink`.
    49  //
    50  // When called, the `KeyProvider` created by `jwe.WithKey()` sends the same key,
    51  // `jwe.WithKeySet()` sends keys that matches a particular `kid` and `alg`,
    52  // and finally `jwe.WithKeyProvider()` allows you to execute arbitrary
    53  // logic to provide keys. If you are providing a custom `KeyProvider`,
    54  // you should execute the necessary checks or retrieval of keys, and
    55  // then send the key(s) to the sink:
    56  //
    57  //	sink.Key(alg, key)
    58  //
    59  // These keys are then retrieved and tried for each signature, until
    60  // a match is found:
    61  //
    62  //	keys := sink.Keys()
    63  //	for key in keys {
    64  //	  if decryptJWEKey(recipient.EncryptedKey(), key) {
    65  //	    return OK
    66  //	  }
    67  //	}
    68  type KeyProvider interface {
    69  	FetchKeys(context.Context, KeySink, Recipient, *Message) error
    70  }
    71  
    72  // KeySink is a data storage where `jwe.KeyProvider` objects should
    73  // send their keys to.
    74  type KeySink interface {
    75  	Key(jwa.KeyEncryptionAlgorithm, interface{})
    76  }
    77  
    78  type algKeyPair struct {
    79  	alg jwa.KeyAlgorithm
    80  	key interface{}
    81  }
    82  
    83  type algKeySink struct {
    84  	mu   sync.Mutex
    85  	list []algKeyPair
    86  }
    87  
    88  func (s *algKeySink) Key(alg jwa.KeyEncryptionAlgorithm, key interface{}) {
    89  	s.mu.Lock()
    90  	s.list = append(s.list, algKeyPair{alg, key})
    91  	s.mu.Unlock()
    92  }
    93  
    94  type staticKeyProvider struct {
    95  	alg jwa.KeyEncryptionAlgorithm
    96  	key interface{}
    97  }
    98  
    99  func (kp *staticKeyProvider) FetchKeys(_ context.Context, sink KeySink, _ Recipient, _ *Message) error {
   100  	sink.Key(kp.alg, kp.key)
   101  	return nil
   102  }
   103  
   104  type keySetProvider struct {
   105  	set        jwk.Set
   106  	requireKid bool
   107  }
   108  
   109  func (kp *keySetProvider) selectKey(sink KeySink, key jwk.Key, _ Recipient, _ *Message) error {
   110  	if usage := key.KeyUsage(); usage != "" && usage != jwk.ForEncryption.String() {
   111  		return nil
   112  	}
   113  
   114  	if v := key.Algorithm(); v.String() != "" {
   115  		var alg jwa.KeyEncryptionAlgorithm
   116  		if err := alg.Accept(v); err != nil {
   117  			return fmt.Errorf(`invalid key encryption algorithm %s: %w`, key.Algorithm(), err)
   118  		}
   119  
   120  		sink.Key(alg, key)
   121  		return nil
   122  	}
   123  
   124  	return nil
   125  }
   126  
   127  func (kp *keySetProvider) FetchKeys(_ context.Context, sink KeySink, r Recipient, msg *Message) error {
   128  	if kp.requireKid {
   129  		var key jwk.Key
   130  
   131  		wantedKid := r.Headers().KeyID()
   132  		if wantedKid == "" {
   133  			return fmt.Errorf(`failed to find matching key: no key ID ("kid") specified in token but multiple keys available in key set`)
   134  		}
   135  		// Otherwise we better be able to look up the key, baby.
   136  		v, ok := kp.set.LookupKeyID(wantedKid)
   137  		if !ok {
   138  			return fmt.Errorf(`failed to find key with key ID %q in key set`, wantedKid)
   139  		}
   140  		key = v
   141  
   142  		return kp.selectKey(sink, key, r, msg)
   143  	}
   144  
   145  	for i := 0; i < kp.set.Len(); i++ {
   146  		key, _ := kp.set.Key(i)
   147  		if err := kp.selectKey(sink, key, r, msg); err != nil {
   148  			continue
   149  		}
   150  	}
   151  	return nil
   152  }
   153  
   154  // KeyProviderFunc is a type of KeyProvider that is implemented by
   155  // a single function. You can use this to create ad-hoc `KeyProvider`
   156  // instances.
   157  type KeyProviderFunc func(context.Context, KeySink, Recipient, *Message) error
   158  
   159  func (kp KeyProviderFunc) FetchKeys(ctx context.Context, sink KeySink, r Recipient, msg *Message) error {
   160  	return kp(ctx, sink, r, msg)
   161  }