github.com/lestrrat-go/jwx/v2@v2.0.21/jwe/internal/cipher/cipher.go (about) 1 package cipher 2 3 import ( 4 "crypto/aes" 5 "crypto/cipher" 6 "fmt" 7 8 "github.com/lestrrat-go/jwx/v2/jwa" 9 "github.com/lestrrat-go/jwx/v2/jwe/internal/aescbc" 10 "github.com/lestrrat-go/jwx/v2/jwe/internal/keygen" 11 ) 12 13 var gcm = &gcmFetcher{} 14 var cbc = &cbcFetcher{} 15 16 func (f gcmFetcher) Fetch(key []byte) (cipher.AEAD, error) { 17 aescipher, err := aes.NewCipher(key) 18 if err != nil { 19 return nil, fmt.Errorf(`cipher: failed to create AES cipher for GCM: %w`, err) 20 } 21 22 aead, err := cipher.NewGCM(aescipher) 23 if err != nil { 24 return nil, fmt.Errorf(`failed to create GCM for cipher: %w`, err) 25 } 26 return aead, nil 27 } 28 29 func (f cbcFetcher) Fetch(key []byte) (cipher.AEAD, error) { 30 aead, err := aescbc.New(key, aes.NewCipher) 31 if err != nil { 32 return nil, fmt.Errorf(`cipher: failed to create AES cipher for CBC: %w`, err) 33 } 34 return aead, nil 35 } 36 37 func (c AesContentCipher) KeySize() int { 38 return c.keysize 39 } 40 41 func (c AesContentCipher) TagSize() int { 42 return c.tagsize 43 } 44 45 func NewAES(alg jwa.ContentEncryptionAlgorithm) (*AesContentCipher, error) { 46 var keysize int 47 var tagsize int 48 var fetcher Fetcher 49 switch alg { 50 case jwa.A128GCM: 51 keysize = 16 52 tagsize = 16 53 fetcher = gcm 54 case jwa.A192GCM: 55 keysize = 24 56 tagsize = 16 57 fetcher = gcm 58 case jwa.A256GCM: 59 keysize = 32 60 tagsize = 16 61 fetcher = gcm 62 case jwa.A128CBC_HS256: 63 tagsize = 16 64 keysize = tagsize * 2 65 fetcher = cbc 66 case jwa.A192CBC_HS384: 67 tagsize = 24 68 keysize = tagsize * 2 69 fetcher = cbc 70 case jwa.A256CBC_HS512: 71 tagsize = 32 72 keysize = tagsize * 2 73 fetcher = cbc 74 default: 75 return nil, fmt.Errorf("failed to create AES content cipher: invalid algorithm (%s)", alg) 76 } 77 78 return &AesContentCipher{ 79 keysize: keysize, 80 tagsize: tagsize, 81 fetch: fetcher, 82 }, nil 83 } 84 85 func (c AesContentCipher) Encrypt(cek, plaintext, aad []byte) (iv, ciphertxt, tag []byte, err error) { 86 var aead cipher.AEAD 87 aead, err = c.fetch.Fetch(cek) 88 if err != nil { 89 return nil, nil, nil, fmt.Errorf(`failed to fetch AEAD: %w`, err) 90 } 91 92 // Seal may panic (argh!), so protect ourselves from that 93 defer func() { 94 if e := recover(); e != nil { 95 switch e := e.(type) { 96 case error: 97 err = e 98 default: 99 err = fmt.Errorf("%s", e) 100 } 101 err = fmt.Errorf(`failed to encrypt: %w`, err) 102 } 103 }() 104 105 var bs keygen.ByteSource 106 if c.NonceGenerator == nil { 107 bs, err = keygen.NewRandom(aead.NonceSize()).Generate() 108 } else { 109 bs, err = c.NonceGenerator.Generate() 110 } 111 if err != nil { 112 return nil, nil, nil, fmt.Errorf(`failed to generate nonce: %w`, err) 113 } 114 iv = bs.Bytes() 115 116 combined := aead.Seal(nil, iv, plaintext, aad) 117 tagoffset := len(combined) - c.TagSize() 118 119 if tagoffset < 0 { 120 panic(fmt.Sprintf("tag offset is less than 0 (combined len = %d, tagsize = %d)", len(combined), c.TagSize())) 121 } 122 123 tag = combined[tagoffset:] 124 ciphertxt = make([]byte, tagoffset) 125 copy(ciphertxt, combined[:tagoffset]) 126 127 return 128 } 129 130 func (c AesContentCipher) Decrypt(cek, iv, ciphertxt, tag, aad []byte) (plaintext []byte, err error) { 131 aead, err := c.fetch.Fetch(cek) 132 if err != nil { 133 return nil, fmt.Errorf(`failed to fetch AEAD data: %w`, err) 134 } 135 136 // Open may panic (argh!), so protect ourselves from that 137 defer func() { 138 if e := recover(); e != nil { 139 switch e := e.(type) { 140 case error: 141 err = e 142 default: 143 err = fmt.Errorf(`%s`, e) 144 } 145 err = fmt.Errorf(`failed to decrypt: %w`, err) 146 return 147 } 148 }() 149 150 combined := make([]byte, len(ciphertxt)+len(tag)) 151 copy(combined, ciphertxt) 152 copy(combined[len(ciphertxt):], tag) 153 154 buf, aeaderr := aead.Open(nil, iv, combined, aad) 155 if aeaderr != nil { 156 err = fmt.Errorf(`aead.Open failed: %w`, aeaderr) 157 return 158 } 159 plaintext = buf 160 return 161 }