github.com/lestrrat-go/jwx/v2@v2.0.21/jwe/decrypt.go (about)

     1  package jwe
     2  
     3  import (
     4  	"crypto/aes"
     5  	cryptocipher "crypto/cipher"
     6  	"crypto/ecdsa"
     7  	"crypto/rsa"
     8  	"crypto/sha256"
     9  	"crypto/sha512"
    10  	"fmt"
    11  	"hash"
    12  
    13  	"golang.org/x/crypto/pbkdf2"
    14  
    15  	"github.com/lestrrat-go/jwx/v2/internal/keyconv"
    16  	"github.com/lestrrat-go/jwx/v2/jwa"
    17  	"github.com/lestrrat-go/jwx/v2/jwe/internal/cipher"
    18  	"github.com/lestrrat-go/jwx/v2/jwe/internal/content_crypt"
    19  	"github.com/lestrrat-go/jwx/v2/jwe/internal/keyenc"
    20  	"github.com/lestrrat-go/jwx/v2/x25519"
    21  )
    22  
    23  // decrypter is responsible for taking various components to decrypt a message.
    24  // its operation is not concurrency safe. You must provide locking yourself
    25  //
    26  //nolint:govet
    27  type decrypter struct {
    28  	aad         []byte
    29  	apu         []byte
    30  	apv         []byte
    31  	cek         *[]byte
    32  	computedAad []byte
    33  	iv          []byte
    34  	keyiv       []byte
    35  	keysalt     []byte
    36  	keytag      []byte
    37  	tag         []byte
    38  	privkey     interface{}
    39  	pubkey      interface{}
    40  	ctalg       jwa.ContentEncryptionAlgorithm
    41  	keyalg      jwa.KeyEncryptionAlgorithm
    42  	cipher      content_crypt.Cipher
    43  	keycount    int
    44  }
    45  
    46  // newDecrypter Creates a new Decrypter instance. You must supply the
    47  // rest of parameters via their respective setter methods before
    48  // calling Decrypt().
    49  //
    50  // privkey must be a private key in its "raw" format (i.e. something like
    51  // *rsa.PrivateKey, instead of jwk.Key)
    52  //
    53  // You should consider this object immutable once you assign values to it.
    54  func newDecrypter(keyalg jwa.KeyEncryptionAlgorithm, ctalg jwa.ContentEncryptionAlgorithm, privkey interface{}) *decrypter {
    55  	return &decrypter{
    56  		ctalg:   ctalg,
    57  		keyalg:  keyalg,
    58  		privkey: privkey,
    59  	}
    60  }
    61  
    62  func (d *decrypter) AgreementPartyUInfo(apu []byte) *decrypter {
    63  	d.apu = apu
    64  	return d
    65  }
    66  
    67  func (d *decrypter) AgreementPartyVInfo(apv []byte) *decrypter {
    68  	d.apv = apv
    69  	return d
    70  }
    71  
    72  func (d *decrypter) AuthenticatedData(aad []byte) *decrypter {
    73  	d.aad = aad
    74  	return d
    75  }
    76  
    77  func (d *decrypter) ComputedAuthenticatedData(aad []byte) *decrypter {
    78  	d.computedAad = aad
    79  	return d
    80  }
    81  
    82  func (d *decrypter) ContentEncryptionAlgorithm(ctalg jwa.ContentEncryptionAlgorithm) *decrypter {
    83  	d.ctalg = ctalg
    84  	return d
    85  }
    86  
    87  func (d *decrypter) InitializationVector(iv []byte) *decrypter {
    88  	d.iv = iv
    89  	return d
    90  }
    91  
    92  func (d *decrypter) KeyCount(keycount int) *decrypter {
    93  	d.keycount = keycount
    94  	return d
    95  }
    96  
    97  func (d *decrypter) KeyInitializationVector(keyiv []byte) *decrypter {
    98  	d.keyiv = keyiv
    99  	return d
   100  }
   101  
   102  func (d *decrypter) KeySalt(keysalt []byte) *decrypter {
   103  	d.keysalt = keysalt
   104  	return d
   105  }
   106  
   107  func (d *decrypter) KeyTag(keytag []byte) *decrypter {
   108  	d.keytag = keytag
   109  	return d
   110  }
   111  
   112  // PublicKey sets the public key to be used in decoding EC based encryptions.
   113  // The key must be in its "raw" format (i.e. *ecdsa.PublicKey, instead of jwk.Key)
   114  func (d *decrypter) PublicKey(pubkey interface{}) *decrypter {
   115  	d.pubkey = pubkey
   116  	return d
   117  }
   118  
   119  func (d *decrypter) Tag(tag []byte) *decrypter {
   120  	d.tag = tag
   121  	return d
   122  }
   123  
   124  func (d *decrypter) CEK(ptr *[]byte) *decrypter {
   125  	d.cek = ptr
   126  	return d
   127  }
   128  
   129  func (d *decrypter) ContentCipher() (content_crypt.Cipher, error) {
   130  	if d.cipher == nil {
   131  		switch d.ctalg {
   132  		case jwa.A128GCM, jwa.A192GCM, jwa.A256GCM, jwa.A128CBC_HS256, jwa.A192CBC_HS384, jwa.A256CBC_HS512:
   133  			cipher, err := cipher.NewAES(d.ctalg)
   134  			if err != nil {
   135  				return nil, fmt.Errorf(`failed to build content cipher for %s: %w`, d.ctalg, err)
   136  			}
   137  			d.cipher = cipher
   138  		default:
   139  			return nil, fmt.Errorf(`invalid content cipher algorithm (%s)`, d.ctalg)
   140  		}
   141  	}
   142  
   143  	return d.cipher, nil
   144  }
   145  
   146  func (d *decrypter) Decrypt(recipient Recipient, ciphertext []byte, msg *Message) (plaintext []byte, err error) {
   147  	cek, keyerr := d.DecryptKey(recipient, msg)
   148  	if keyerr != nil {
   149  		err = fmt.Errorf(`failed to decrypt key: %w`, keyerr)
   150  		return
   151  	}
   152  
   153  	cipher, ciphererr := d.ContentCipher()
   154  	if ciphererr != nil {
   155  		err = fmt.Errorf(`failed to fetch content crypt cipher: %w`, ciphererr)
   156  		return
   157  	}
   158  
   159  	computedAad := d.computedAad
   160  	if d.aad != nil {
   161  		computedAad = append(append(computedAad, '.'), d.aad...)
   162  	}
   163  
   164  	plaintext, err = cipher.Decrypt(cek, d.iv, ciphertext, d.tag, computedAad)
   165  	if err != nil {
   166  		err = fmt.Errorf(`failed to decrypt payload: %w`, err)
   167  		return
   168  	}
   169  
   170  	if d.cek != nil {
   171  		*d.cek = cek
   172  	}
   173  	return plaintext, nil
   174  }
   175  
   176  func (d *decrypter) decryptSymmetricKey(recipientKey, cek []byte) ([]byte, error) {
   177  	switch d.keyalg {
   178  	case jwa.DIRECT:
   179  		return cek, nil
   180  	case jwa.PBES2_HS256_A128KW, jwa.PBES2_HS384_A192KW, jwa.PBES2_HS512_A256KW:
   181  		var hashFunc func() hash.Hash
   182  		var keylen int
   183  		switch d.keyalg {
   184  		case jwa.PBES2_HS256_A128KW:
   185  			hashFunc = sha256.New
   186  			keylen = 16
   187  		case jwa.PBES2_HS384_A192KW:
   188  			hashFunc = sha512.New384
   189  			keylen = 24
   190  		case jwa.PBES2_HS512_A256KW:
   191  			hashFunc = sha512.New
   192  			keylen = 32
   193  		}
   194  		salt := []byte(d.keyalg)
   195  		salt = append(salt, byte(0))
   196  		salt = append(salt, d.keysalt...)
   197  		cek = pbkdf2.Key(cek, salt, d.keycount, keylen, hashFunc)
   198  		fallthrough
   199  	case jwa.A128KW, jwa.A192KW, jwa.A256KW:
   200  		block, err := aes.NewCipher(cek)
   201  		if err != nil {
   202  			return nil, fmt.Errorf(`failed to create new AES cipher: %w`, err)
   203  		}
   204  
   205  		jek, err := keyenc.Unwrap(block, recipientKey)
   206  		if err != nil {
   207  			return nil, fmt.Errorf(`failed to unwrap key: %w`, err)
   208  		}
   209  
   210  		return jek, nil
   211  	case jwa.A128GCMKW, jwa.A192GCMKW, jwa.A256GCMKW:
   212  		if len(d.keyiv) != 12 {
   213  			return nil, fmt.Errorf("GCM requires 96-bit iv, got %d", len(d.keyiv)*8)
   214  		}
   215  		if len(d.keytag) != 16 {
   216  			return nil, fmt.Errorf("GCM requires 128-bit tag, got %d", len(d.keytag)*8)
   217  		}
   218  		block, err := aes.NewCipher(cek)
   219  		if err != nil {
   220  			return nil, fmt.Errorf(`failed to create new AES cipher: %w`, err)
   221  		}
   222  		aesgcm, err := cryptocipher.NewGCM(block)
   223  		if err != nil {
   224  			return nil, fmt.Errorf(`failed to create new GCM wrap: %w`, err)
   225  		}
   226  		ciphertext := recipientKey[:]
   227  		ciphertext = append(ciphertext, d.keytag...)
   228  		jek, err := aesgcm.Open(nil, d.keyiv, ciphertext, nil)
   229  		if err != nil {
   230  			return nil, fmt.Errorf(`failed to decode key: %w`, err)
   231  		}
   232  		return jek, nil
   233  	default:
   234  		return nil, fmt.Errorf("decrypt key: unsupported algorithm %s", d.keyalg)
   235  	}
   236  }
   237  
   238  func (d *decrypter) DecryptKey(recipient Recipient, msg *Message) (cek []byte, err error) {
   239  	recipientKey := recipient.EncryptedKey()
   240  	if kd, ok := d.privkey.(KeyDecrypter); ok {
   241  		return kd.DecryptKey(d.keyalg, recipientKey, recipient, msg)
   242  	}
   243  
   244  	if d.keyalg.IsSymmetric() {
   245  		var ok bool
   246  		cek, ok = d.privkey.([]byte)
   247  		if !ok {
   248  			return nil, fmt.Errorf("decrypt key: []byte is required as the key to build %s key decrypter (got %T)", d.keyalg, d.privkey)
   249  		}
   250  
   251  		return d.decryptSymmetricKey(recipientKey, cek)
   252  	}
   253  
   254  	k, err := d.BuildKeyDecrypter()
   255  	if err != nil {
   256  		return nil, fmt.Errorf(`failed to build key decrypter: %w`, err)
   257  	}
   258  
   259  	cek, err = k.Decrypt(recipientKey)
   260  	if err != nil {
   261  		return nil, fmt.Errorf(`failed to decrypt key: %w`, err)
   262  	}
   263  
   264  	return cek, nil
   265  }
   266  
   267  func (d *decrypter) BuildKeyDecrypter() (keyenc.Decrypter, error) {
   268  	cipher, err := d.ContentCipher()
   269  	if err != nil {
   270  		return nil, fmt.Errorf(`failed to fetch content crypt cipher: %w`, err)
   271  	}
   272  
   273  	switch alg := d.keyalg; alg {
   274  	case jwa.RSA1_5:
   275  		var privkey rsa.PrivateKey
   276  		if err := keyconv.RSAPrivateKey(&privkey, d.privkey); err != nil {
   277  			return nil, fmt.Errorf(`*rsa.PrivateKey is required as the key to build %s key decrypter: %w`, alg, err)
   278  		}
   279  
   280  		return keyenc.NewRSAPKCS15Decrypt(alg, &privkey, cipher.KeySize()/2), nil
   281  	case jwa.RSA_OAEP, jwa.RSA_OAEP_256:
   282  		var privkey rsa.PrivateKey
   283  		if err := keyconv.RSAPrivateKey(&privkey, d.privkey); err != nil {
   284  			return nil, fmt.Errorf(`*rsa.PrivateKey is required as the key to build %s key decrypter: %w`, alg, err)
   285  		}
   286  
   287  		return keyenc.NewRSAOAEPDecrypt(alg, &privkey)
   288  	case jwa.A128KW, jwa.A192KW, jwa.A256KW:
   289  		sharedkey, ok := d.privkey.([]byte)
   290  		if !ok {
   291  			return nil, fmt.Errorf("[]byte is required as the key to build %s key decrypter", alg)
   292  		}
   293  
   294  		return keyenc.NewAES(alg, sharedkey)
   295  	case jwa.ECDH_ES, jwa.ECDH_ES_A128KW, jwa.ECDH_ES_A192KW, jwa.ECDH_ES_A256KW:
   296  		switch d.pubkey.(type) {
   297  		case x25519.PublicKey:
   298  			return keyenc.NewECDHESDecrypt(alg, d.ctalg, d.pubkey, d.apu, d.apv, d.privkey), nil
   299  		default:
   300  			var pubkey ecdsa.PublicKey
   301  			if err := keyconv.ECDSAPublicKey(&pubkey, d.pubkey); err != nil {
   302  				return nil, fmt.Errorf(`*ecdsa.PublicKey is required as the key to build %s key decrypter: %w`, alg, err)
   303  			}
   304  
   305  			var privkey ecdsa.PrivateKey
   306  			if err := keyconv.ECDSAPrivateKey(&privkey, d.privkey); err != nil {
   307  				return nil, fmt.Errorf(`*ecdsa.PrivateKey is required as the key to build %s key decrypter: %w`, alg, err)
   308  			}
   309  
   310  			return keyenc.NewECDHESDecrypt(alg, d.ctalg, &pubkey, d.apu, d.apv, &privkey), nil
   311  		}
   312  	default:
   313  		return nil, fmt.Errorf(`unsupported algorithm for key decryption (%s)`, alg)
   314  	}
   315  }