github.com/lestrrat-go/jwx/v2@v2.0.21/jwe/internal/aescbc/aescbc.go (about)

     1  package aescbc
     2  
     3  import (
     4  	"crypto/cipher"
     5  	"crypto/hmac"
     6  	"crypto/sha256"
     7  	"crypto/sha512"
     8  	"crypto/subtle"
     9  	"encoding/binary"
    10  	"errors"
    11  	"fmt"
    12  	"hash"
    13  	"sync/atomic"
    14  )
    15  
    16  const (
    17  	NonceSize = 16
    18  )
    19  
    20  const defaultBufSize int64 = 256 * 1024 * 1024
    21  
    22  // Grr, we would like to use atomic.Int64, but that's only available
    23  // from Go 1.19. Yes, we will cut support for Go 1.19 at some point,
    24  // but not today (probably going to up the minimum required Go version
    25  // some time after 1.22 is released)
    26  var maxBufSize int64
    27  
    28  func init() {
    29  	atomic.StoreInt64(&maxBufSize, defaultBufSize)
    30  }
    31  
    32  func SetMaxBufferSize(siz int64) {
    33  	if siz <= 0 {
    34  		siz = defaultBufSize
    35  	}
    36  	atomic.StoreInt64(&maxBufSize, siz)
    37  }
    38  
    39  func pad(buf []byte, n int) []byte {
    40  	rem := n - len(buf)%n
    41  	if rem == 0 {
    42  		return buf
    43  	}
    44  
    45  	mbs := atomic.LoadInt64(&maxBufSize)
    46  	if int64(len(buf)+rem) > mbs {
    47  		panic(fmt.Errorf("failed to allocate buffer"))
    48  	}
    49  	newbuf := make([]byte, len(buf)+rem)
    50  	copy(newbuf, buf)
    51  
    52  	for i := len(buf); i < len(newbuf); i++ {
    53  		newbuf[i] = byte(rem)
    54  	}
    55  	return newbuf
    56  }
    57  
    58  // ref. https://github.com/golang/go/blob/c3db64c0f45e8f2d75c5b59401e0fc925701b6f4/src/crypto/tls/conn.go#L279-L324
    59  //
    60  // extractPadding returns, in constant time, the length of the padding to remove
    61  // from the end of payload. It also returns a byte which is equal to 255 if the
    62  // padding was valid and 0 otherwise. See RFC 2246, Section 6.2.3.2.
    63  func extractPadding(payload []byte) (toRemove int, good byte) {
    64  	if len(payload) < 1 {
    65  		return 0, 0
    66  	}
    67  
    68  	paddingLen := payload[len(payload)-1]
    69  	t := uint(len(payload)) - uint(paddingLen)
    70  	// if len(payload) > paddingLen then the MSB of t is zero
    71  	good = byte(int32(^t) >> 31)
    72  
    73  	// The maximum possible padding length plus the actual length field
    74  	toCheck := 256
    75  	// The length of the padded data is public, so we can use an if here
    76  	if toCheck > len(payload) {
    77  		toCheck = len(payload)
    78  	}
    79  
    80  	for i := 1; i <= toCheck; i++ {
    81  		t := uint(paddingLen) - uint(i)
    82  		// if i <= paddingLen then the MSB of t is zero
    83  		mask := byte(int32(^t) >> 31)
    84  		b := payload[len(payload)-i]
    85  		good &^= mask&paddingLen ^ mask&b
    86  	}
    87  
    88  	// We AND together the bits of good and replicate the result across
    89  	// all the bits.
    90  	good &= good << 4
    91  	good &= good << 2
    92  	good &= good << 1
    93  	good = uint8(int8(good) >> 7)
    94  
    95  	// Zero the padding length on error. This ensures any unchecked bytes
    96  	// are included in the MAC. Otherwise, an attacker that could
    97  	// distinguish MAC failures from padding failures could mount an attack
    98  	// similar to POODLE in SSL 3.0: given a good ciphertext that uses a
    99  	// full block's worth of padding, replace the final block with another
   100  	// block. If the MAC check passed but the padding check failed, the
   101  	// last byte of that block decrypted to the block size.
   102  	//
   103  	// See also macAndPaddingGood logic below.
   104  	paddingLen &= good
   105  
   106  	toRemove = int(paddingLen)
   107  	return
   108  }
   109  
   110  type Hmac struct {
   111  	blockCipher  cipher.Block
   112  	hash         func() hash.Hash
   113  	keysize      int
   114  	tagsize      int
   115  	integrityKey []byte
   116  }
   117  
   118  type BlockCipherFunc func([]byte) (cipher.Block, error)
   119  
   120  func New(key []byte, f BlockCipherFunc) (hmac *Hmac, err error) {
   121  	keysize := len(key) / 2
   122  	ikey := key[:keysize]
   123  	ekey := key[keysize:]
   124  
   125  	bc, ciphererr := f(ekey)
   126  	if ciphererr != nil {
   127  		err = fmt.Errorf(`failed to execute block cipher function: %w`, ciphererr)
   128  		return
   129  	}
   130  
   131  	var hfunc func() hash.Hash
   132  	switch keysize {
   133  	case 16:
   134  		hfunc = sha256.New
   135  	case 24:
   136  		hfunc = sha512.New384
   137  	case 32:
   138  		hfunc = sha512.New
   139  	default:
   140  		return nil, fmt.Errorf("unsupported key size %d", keysize)
   141  	}
   142  
   143  	return &Hmac{
   144  		blockCipher:  bc,
   145  		hash:         hfunc,
   146  		integrityKey: ikey,
   147  		keysize:      keysize,
   148  		tagsize:      keysize, // NonceSize,
   149  		// While investigating GH #207, I stumbled upon another problem where
   150  		// the computed tags don't match on decrypt. After poking through the
   151  		// code using a bunch of debug statements, I've finally found out that
   152  		// tagsize = keysize makes the whole thing work.
   153  	}, nil
   154  }
   155  
   156  // NonceSize fulfills the crypto.AEAD interface
   157  func (c Hmac) NonceSize() int {
   158  	return NonceSize
   159  }
   160  
   161  // Overhead fulfills the crypto.AEAD interface
   162  func (c Hmac) Overhead() int {
   163  	return c.blockCipher.BlockSize() + c.tagsize
   164  }
   165  
   166  func (c Hmac) ComputeAuthTag(aad, nonce, ciphertext []byte) ([]byte, error) {
   167  	var buf [8]byte
   168  	binary.BigEndian.PutUint64(buf[:], uint64(len(aad)*8))
   169  
   170  	h := hmac.New(c.hash, c.integrityKey)
   171  
   172  	// compute the tag
   173  	// no need to check errors because Write never returns an error: https://pkg.go.dev/hash#Hash
   174  	//
   175  	// > Write (via the embedded io.Writer interface) adds more data to the running hash.
   176  	// > It never returns an error.
   177  	h.Write(aad)
   178  	h.Write(nonce)
   179  	h.Write(ciphertext)
   180  	h.Write(buf[:])
   181  	s := h.Sum(nil)
   182  	return s[:c.tagsize], nil
   183  }
   184  
   185  func ensureSize(dst []byte, n int) []byte {
   186  	// if the dst buffer has enough length just copy the relevant parts to it.
   187  	// Otherwise create a new slice that's big enough, and operate on that
   188  	// Note: I think go-jose has a bug in that it checks for cap(), but not len().
   189  	ret := dst
   190  	if diff := n - len(dst); diff > 0 {
   191  		// dst is not big enough
   192  		ret = make([]byte, n)
   193  		copy(ret, dst)
   194  	}
   195  	return ret
   196  }
   197  
   198  // Seal fulfills the crypto.AEAD interface
   199  func (c Hmac) Seal(dst, nonce, plaintext, data []byte) []byte {
   200  	ctlen := len(plaintext)
   201  	bufsiz := ctlen + c.Overhead()
   202  	mbs := atomic.LoadInt64(&maxBufSize)
   203  	if int64(bufsiz) > mbs {
   204  		panic(fmt.Errorf("failed to allocate buffer"))
   205  	}
   206  	ciphertext := make([]byte, ctlen+c.Overhead())[:ctlen]
   207  	copy(ciphertext, plaintext)
   208  	ciphertext = pad(ciphertext, c.blockCipher.BlockSize())
   209  
   210  	cbc := cipher.NewCBCEncrypter(c.blockCipher, nonce)
   211  	cbc.CryptBlocks(ciphertext, ciphertext)
   212  
   213  	authtag, err := c.ComputeAuthTag(data, nonce, ciphertext)
   214  	if err != nil {
   215  		// Hmac implements cipher.AEAD interface. Seal can't return error.
   216  		// But currently it never reach here because of Hmac.ComputeAuthTag doesn't return error.
   217  		panic(fmt.Errorf("failed to seal on hmac: %v", err))
   218  	}
   219  
   220  	retlen := len(dst) + len(ciphertext) + len(authtag)
   221  
   222  	ret := ensureSize(dst, retlen)
   223  	out := ret[len(dst):]
   224  	n := copy(out, ciphertext)
   225  	copy(out[n:], authtag)
   226  
   227  	return ret
   228  }
   229  
   230  // Open fulfills the crypto.AEAD interface
   231  func (c Hmac) Open(dst, nonce, ciphertext, data []byte) ([]byte, error) {
   232  	if len(ciphertext) < c.keysize {
   233  		return nil, fmt.Errorf(`invalid ciphertext (too short)`)
   234  	}
   235  
   236  	tagOffset := len(ciphertext) - c.tagsize
   237  	if tagOffset%c.blockCipher.BlockSize() != 0 {
   238  		return nil, fmt.Errorf(
   239  			"invalid ciphertext (invalid length: %d %% %d != 0)",
   240  			tagOffset,
   241  			c.blockCipher.BlockSize(),
   242  		)
   243  	}
   244  	tag := ciphertext[tagOffset:]
   245  	ciphertext = ciphertext[:tagOffset]
   246  
   247  	expectedTag, err := c.ComputeAuthTag(data, nonce, ciphertext[:tagOffset])
   248  	if err != nil {
   249  		return nil, fmt.Errorf(`failed to compute auth tag: %w`, err)
   250  	}
   251  
   252  	cbc := cipher.NewCBCDecrypter(c.blockCipher, nonce)
   253  	buf := make([]byte, tagOffset)
   254  	cbc.CryptBlocks(buf, ciphertext)
   255  
   256  	toRemove, good := extractPadding(buf)
   257  	cmp := subtle.ConstantTimeCompare(expectedTag, tag) & int(good)
   258  	if cmp != 1 {
   259  		return nil, errors.New(`invalid ciphertext`)
   260  	}
   261  
   262  	plaintext := buf[:len(buf)-toRemove]
   263  	ret := ensureSize(dst, len(plaintext))
   264  	out := ret[len(dst):]
   265  	copy(out, plaintext)
   266  	return ret, nil
   267  }