github.com/cloudflare/circl@v1.5.0/hpke/aead.go (about) 1 package hpke 2 3 import ( 4 "crypto/cipher" 5 "fmt" 6 ) 7 8 type encdecContext struct { 9 // Serialized parameters 10 suite Suite 11 sharedSecret []byte 12 secret []byte 13 keyScheduleContext []byte 14 exporterSecret []byte 15 key []byte 16 baseNonce []byte 17 sequenceNumber []byte 18 19 // Operational parameters 20 cipher.AEAD 21 nonce []byte 22 } 23 24 type ( 25 sealContext struct{ *encdecContext } 26 openContext struct{ *encdecContext } 27 ) 28 29 // Export takes a context string exporterContext and a desired length (in 30 // bytes), and produces a secret derived from the internal exporter secret 31 // using the corresponding KDF Expand function. It panics if length is 32 // greater than 255*N bytes, where N is the size (in bytes) of the KDF's 33 // output. 34 func (c *encdecContext) Export(exporterContext []byte, length uint) []byte { 35 maxLength := uint(255 * c.suite.kdfID.ExtractSize()) 36 if length > maxLength { 37 panic(fmt.Errorf("output length must be lesser than %v bytes", maxLength)) 38 } 39 return c.suite.labeledExpand(c.exporterSecret, []byte("sec"), 40 exporterContext, uint16(length)) 41 } 42 43 func (c *encdecContext) Suite() Suite { 44 return c.suite 45 } 46 47 func (c *encdecContext) calcNonce() []byte { 48 for i := range c.baseNonce { 49 c.nonce[i] = c.baseNonce[i] ^ c.sequenceNumber[i] 50 } 51 return c.nonce 52 } 53 54 func (c *encdecContext) increment() error { 55 // tests whether the sequence number is all-ones, which prevents an 56 // overflow after the increment. 57 allOnes := byte(0xFF) 58 for i := range c.sequenceNumber { 59 allOnes &= c.sequenceNumber[i] 60 } 61 if allOnes == byte(0xFF) { 62 return ErrAEADSeqOverflows 63 } 64 65 // performs an increment by 1 and verifies whether the sequence overflows. 66 carry := uint(1) 67 for i := len(c.sequenceNumber) - 1; i >= 0; i-- { 68 sum := uint(c.sequenceNumber[i]) + carry 69 carry = sum >> 8 70 c.sequenceNumber[i] = byte(sum & 0xFF) 71 } 72 if carry != 0 { 73 return ErrAEADSeqOverflows 74 } 75 return nil 76 } 77 78 func (c *sealContext) Seal(pt, aad []byte) ([]byte, error) { 79 ct := c.AEAD.Seal(nil, c.calcNonce(), pt, aad) 80 err := c.increment() 81 if err != nil { 82 for i := range ct { 83 ct[i] = 0 84 } 85 return nil, err 86 } 87 return ct, nil 88 } 89 90 func (c *openContext) Open(ct, aad []byte) ([]byte, error) { 91 pt, err := c.AEAD.Open(nil, c.calcNonce(), ct, aad) 92 if err != nil { 93 return nil, err 94 } 95 err = c.increment() 96 if err != nil { 97 for i := range pt { 98 pt[i] = 0 99 } 100 return nil, err 101 } 102 return pt, nil 103 }