github.com/lestrrat-go/jwx/v2@v2.0.21/jwe/internal/cipher/cipher.go (about)

     1  package cipher
     2  
     3  import (
     4  	"crypto/aes"
     5  	"crypto/cipher"
     6  	"fmt"
     7  
     8  	"github.com/lestrrat-go/jwx/v2/jwa"
     9  	"github.com/lestrrat-go/jwx/v2/jwe/internal/aescbc"
    10  	"github.com/lestrrat-go/jwx/v2/jwe/internal/keygen"
    11  )
    12  
    13  var gcm = &gcmFetcher{}
    14  var cbc = &cbcFetcher{}
    15  
    16  func (f gcmFetcher) Fetch(key []byte) (cipher.AEAD, error) {
    17  	aescipher, err := aes.NewCipher(key)
    18  	if err != nil {
    19  		return nil, fmt.Errorf(`cipher: failed to create AES cipher for GCM: %w`, err)
    20  	}
    21  
    22  	aead, err := cipher.NewGCM(aescipher)
    23  	if err != nil {
    24  		return nil, fmt.Errorf(`failed to create GCM for cipher: %w`, err)
    25  	}
    26  	return aead, nil
    27  }
    28  
    29  func (f cbcFetcher) Fetch(key []byte) (cipher.AEAD, error) {
    30  	aead, err := aescbc.New(key, aes.NewCipher)
    31  	if err != nil {
    32  		return nil, fmt.Errorf(`cipher: failed to create AES cipher for CBC: %w`, err)
    33  	}
    34  	return aead, nil
    35  }
    36  
    37  func (c AesContentCipher) KeySize() int {
    38  	return c.keysize
    39  }
    40  
    41  func (c AesContentCipher) TagSize() int {
    42  	return c.tagsize
    43  }
    44  
    45  func NewAES(alg jwa.ContentEncryptionAlgorithm) (*AesContentCipher, error) {
    46  	var keysize int
    47  	var tagsize int
    48  	var fetcher Fetcher
    49  	switch alg {
    50  	case jwa.A128GCM:
    51  		keysize = 16
    52  		tagsize = 16
    53  		fetcher = gcm
    54  	case jwa.A192GCM:
    55  		keysize = 24
    56  		tagsize = 16
    57  		fetcher = gcm
    58  	case jwa.A256GCM:
    59  		keysize = 32
    60  		tagsize = 16
    61  		fetcher = gcm
    62  	case jwa.A128CBC_HS256:
    63  		tagsize = 16
    64  		keysize = tagsize * 2
    65  		fetcher = cbc
    66  	case jwa.A192CBC_HS384:
    67  		tagsize = 24
    68  		keysize = tagsize * 2
    69  		fetcher = cbc
    70  	case jwa.A256CBC_HS512:
    71  		tagsize = 32
    72  		keysize = tagsize * 2
    73  		fetcher = cbc
    74  	default:
    75  		return nil, fmt.Errorf("failed to create AES content cipher: invalid algorithm (%s)", alg)
    76  	}
    77  
    78  	return &AesContentCipher{
    79  		keysize: keysize,
    80  		tagsize: tagsize,
    81  		fetch:   fetcher,
    82  	}, nil
    83  }
    84  
    85  func (c AesContentCipher) Encrypt(cek, plaintext, aad []byte) (iv, ciphertxt, tag []byte, err error) {
    86  	var aead cipher.AEAD
    87  	aead, err = c.fetch.Fetch(cek)
    88  	if err != nil {
    89  		return nil, nil, nil, fmt.Errorf(`failed to fetch AEAD: %w`, err)
    90  	}
    91  
    92  	// Seal may panic (argh!), so protect ourselves from that
    93  	defer func() {
    94  		if e := recover(); e != nil {
    95  			switch e := e.(type) {
    96  			case error:
    97  				err = e
    98  			default:
    99  				err = fmt.Errorf("%s", e)
   100  			}
   101  			err = fmt.Errorf(`failed to encrypt: %w`, err)
   102  		}
   103  	}()
   104  
   105  	var bs keygen.ByteSource
   106  	if c.NonceGenerator == nil {
   107  		bs, err = keygen.NewRandom(aead.NonceSize()).Generate()
   108  	} else {
   109  		bs, err = c.NonceGenerator.Generate()
   110  	}
   111  	if err != nil {
   112  		return nil, nil, nil, fmt.Errorf(`failed to generate nonce: %w`, err)
   113  	}
   114  	iv = bs.Bytes()
   115  
   116  	combined := aead.Seal(nil, iv, plaintext, aad)
   117  	tagoffset := len(combined) - c.TagSize()
   118  
   119  	if tagoffset < 0 {
   120  		panic(fmt.Sprintf("tag offset is less than 0 (combined len = %d, tagsize = %d)", len(combined), c.TagSize()))
   121  	}
   122  
   123  	tag = combined[tagoffset:]
   124  	ciphertxt = make([]byte, tagoffset)
   125  	copy(ciphertxt, combined[:tagoffset])
   126  
   127  	return
   128  }
   129  
   130  func (c AesContentCipher) Decrypt(cek, iv, ciphertxt, tag, aad []byte) (plaintext []byte, err error) {
   131  	aead, err := c.fetch.Fetch(cek)
   132  	if err != nil {
   133  		return nil, fmt.Errorf(`failed to fetch AEAD data: %w`, err)
   134  	}
   135  
   136  	// Open may panic (argh!), so protect ourselves from that
   137  	defer func() {
   138  		if e := recover(); e != nil {
   139  			switch e := e.(type) {
   140  			case error:
   141  				err = e
   142  			default:
   143  				err = fmt.Errorf(`%s`, e)
   144  			}
   145  			err = fmt.Errorf(`failed to decrypt: %w`, err)
   146  			return
   147  		}
   148  	}()
   149  
   150  	combined := make([]byte, len(ciphertxt)+len(tag))
   151  	copy(combined, ciphertxt)
   152  	copy(combined[len(ciphertxt):], tag)
   153  
   154  	buf, aeaderr := aead.Open(nil, iv, combined, aad)
   155  	if aeaderr != nil {
   156  		err = fmt.Errorf(`aead.Open failed: %w`, aeaderr)
   157  		return
   158  	}
   159  	plaintext = buf
   160  	return
   161  }