github.com/cloudflare/circl@v1.5.0/hpke/util.go (about)

     1  package hpke
     2  
     3  import (
     4  	"encoding/binary"
     5  	"errors"
     6  	"fmt"
     7  )
     8  
     9  func (st state) keySchedule(ss, info, psk, pskID []byte) (*encdecContext, error) {
    10  	if err := st.verifyPSKInputs(psk, pskID); err != nil {
    11  		return nil, err
    12  	}
    13  
    14  	pskIDHash := st.labeledExtract(nil, []byte("psk_id_hash"), pskID)
    15  	infoHash := st.labeledExtract(nil, []byte("info_hash"), info)
    16  	keySchCtx := append(append(
    17  		[]byte{st.modeID},
    18  		pskIDHash...),
    19  		infoHash...)
    20  
    21  	secret := st.labeledExtract(ss, []byte("secret"), psk)
    22  
    23  	Nk := uint16(st.aeadID.KeySize())
    24  	key := st.labeledExpand(secret, []byte("key"), keySchCtx, Nk)
    25  
    26  	aead, err := st.aeadID.New(key)
    27  	if err != nil {
    28  		return nil, err
    29  	}
    30  
    31  	Nn := uint16(aead.NonceSize())
    32  	baseNonce := st.labeledExpand(secret, []byte("base_nonce"), keySchCtx, Nn)
    33  	exporterSecret := st.labeledExpand(
    34  		secret,
    35  		[]byte("exp"),
    36  		keySchCtx,
    37  		uint16(st.kdfID.ExtractSize()),
    38  	)
    39  
    40  	return &encdecContext{
    41  		st.Suite,
    42  		ss,
    43  		secret,
    44  		keySchCtx,
    45  		exporterSecret,
    46  		key,
    47  		baseNonce,
    48  		make([]byte, Nn),
    49  		aead,
    50  		make([]byte, Nn),
    51  	}, nil
    52  }
    53  
    54  func (st state) verifyPSKInputs(psk, pskID []byte) error {
    55  	gotPSK := psk != nil
    56  	gotPSKID := pskID != nil
    57  	if gotPSK != gotPSKID {
    58  		return errors.New("inconsistent PSK inputs")
    59  	}
    60  	switch st.modeID {
    61  	case modeBase | modeAuth:
    62  		if gotPSK {
    63  			return errors.New("PSK input provided when not needed")
    64  		}
    65  	case modePSK | modeAuthPSK:
    66  		if !gotPSK {
    67  			return errors.New("missing required PSK input")
    68  		}
    69  	}
    70  	return nil
    71  }
    72  
    73  // Params returns the codepoints for the algorithms comprising the suite.
    74  func (suite Suite) Params() (KEM, KDF, AEAD) {
    75  	return suite.kemID, suite.kdfID, suite.aeadID
    76  }
    77  
    78  func (suite Suite) String() string {
    79  	return fmt.Sprintf(
    80  		"kem_id: %v kdf_id: %v aead_id: %v",
    81  		suite.kemID, suite.kdfID, suite.aeadID,
    82  	)
    83  }
    84  
    85  func (suite Suite) getSuiteID() (id [10]byte) {
    86  	id[0], id[1], id[2], id[3] = 'H', 'P', 'K', 'E'
    87  	binary.BigEndian.PutUint16(id[4:6], uint16(suite.kemID))
    88  	binary.BigEndian.PutUint16(id[6:8], uint16(suite.kdfID))
    89  	binary.BigEndian.PutUint16(id[8:10], uint16(suite.aeadID))
    90  	return
    91  }
    92  
    93  func (suite Suite) isValid() bool {
    94  	return suite.kemID.IsValid() &&
    95  		suite.kdfID.IsValid() &&
    96  		suite.aeadID.IsValid()
    97  }
    98  
    99  func (suite Suite) labeledExtract(salt, label, ikm []byte) []byte {
   100  	suiteID := suite.getSuiteID()
   101  	labeledIKM := append(append(append(append(
   102  		make([]byte, 0, len(versionLabel)+len(suiteID)+len(label)+len(ikm)),
   103  		versionLabel...),
   104  		suiteID[:]...),
   105  		label...),
   106  		ikm...)
   107  	return suite.kdfID.Extract(labeledIKM, salt)
   108  }
   109  
   110  func (suite Suite) labeledExpand(prk, label, info []byte, l uint16) []byte {
   111  	suiteID := suite.getSuiteID()
   112  	labeledInfo := make([]byte,
   113  		2, 2+len(versionLabel)+len(suiteID)+len(label)+len(info))
   114  	binary.BigEndian.PutUint16(labeledInfo[0:2], l)
   115  	labeledInfo = append(append(append(append(labeledInfo,
   116  		versionLabel...),
   117  		suiteID[:]...),
   118  		label...),
   119  		info...)
   120  	return suite.kdfID.Expand(prk, labeledInfo, uint(l))
   121  }