github.com/cloudflare/circl@v1.5.0/hpke/aead.go (about)

     1  package hpke
     2  
     3  import (
     4  	"crypto/cipher"
     5  	"fmt"
     6  )
     7  
     8  type encdecContext struct {
     9  	// Serialized parameters
    10  	suite              Suite
    11  	sharedSecret       []byte
    12  	secret             []byte
    13  	keyScheduleContext []byte
    14  	exporterSecret     []byte
    15  	key                []byte
    16  	baseNonce          []byte
    17  	sequenceNumber     []byte
    18  
    19  	// Operational parameters
    20  	cipher.AEAD
    21  	nonce []byte
    22  }
    23  
    24  type (
    25  	sealContext struct{ *encdecContext }
    26  	openContext struct{ *encdecContext }
    27  )
    28  
    29  // Export takes a context string exporterContext and a desired length (in
    30  // bytes), and produces a secret derived from the internal exporter secret
    31  // using the corresponding KDF Expand function. It panics if length is
    32  // greater than 255*N bytes, where N is the size (in bytes) of the KDF's
    33  // output.
    34  func (c *encdecContext) Export(exporterContext []byte, length uint) []byte {
    35  	maxLength := uint(255 * c.suite.kdfID.ExtractSize())
    36  	if length > maxLength {
    37  		panic(fmt.Errorf("output length must be lesser than %v bytes", maxLength))
    38  	}
    39  	return c.suite.labeledExpand(c.exporterSecret, []byte("sec"),
    40  		exporterContext, uint16(length))
    41  }
    42  
    43  func (c *encdecContext) Suite() Suite {
    44  	return c.suite
    45  }
    46  
    47  func (c *encdecContext) calcNonce() []byte {
    48  	for i := range c.baseNonce {
    49  		c.nonce[i] = c.baseNonce[i] ^ c.sequenceNumber[i]
    50  	}
    51  	return c.nonce
    52  }
    53  
    54  func (c *encdecContext) increment() error {
    55  	// tests whether the sequence number is all-ones, which prevents an
    56  	// overflow after the increment.
    57  	allOnes := byte(0xFF)
    58  	for i := range c.sequenceNumber {
    59  		allOnes &= c.sequenceNumber[i]
    60  	}
    61  	if allOnes == byte(0xFF) {
    62  		return ErrAEADSeqOverflows
    63  	}
    64  
    65  	// performs an increment by 1 and verifies whether the sequence overflows.
    66  	carry := uint(1)
    67  	for i := len(c.sequenceNumber) - 1; i >= 0; i-- {
    68  		sum := uint(c.sequenceNumber[i]) + carry
    69  		carry = sum >> 8
    70  		c.sequenceNumber[i] = byte(sum & 0xFF)
    71  	}
    72  	if carry != 0 {
    73  		return ErrAEADSeqOverflows
    74  	}
    75  	return nil
    76  }
    77  
    78  func (c *sealContext) Seal(pt, aad []byte) ([]byte, error) {
    79  	ct := c.AEAD.Seal(nil, c.calcNonce(), pt, aad)
    80  	err := c.increment()
    81  	if err != nil {
    82  		for i := range ct {
    83  			ct[i] = 0
    84  		}
    85  		return nil, err
    86  	}
    87  	return ct, nil
    88  }
    89  
    90  func (c *openContext) Open(ct, aad []byte) ([]byte, error) {
    91  	pt, err := c.AEAD.Open(nil, c.calcNonce(), ct, aad)
    92  	if err != nil {
    93  		return nil, err
    94  	}
    95  	err = c.increment()
    96  	if err != nil {
    97  		for i := range pt {
    98  			pt[i] = 0
    99  		}
   100  		return nil, err
   101  	}
   102  	return pt, nil
   103  }