github.com/cloudflare/circl@v1.5.0/hpke/util.go (about) 1 package hpke 2 3 import ( 4 "encoding/binary" 5 "errors" 6 "fmt" 7 ) 8 9 func (st state) keySchedule(ss, info, psk, pskID []byte) (*encdecContext, error) { 10 if err := st.verifyPSKInputs(psk, pskID); err != nil { 11 return nil, err 12 } 13 14 pskIDHash := st.labeledExtract(nil, []byte("psk_id_hash"), pskID) 15 infoHash := st.labeledExtract(nil, []byte("info_hash"), info) 16 keySchCtx := append(append( 17 []byte{st.modeID}, 18 pskIDHash...), 19 infoHash...) 20 21 secret := st.labeledExtract(ss, []byte("secret"), psk) 22 23 Nk := uint16(st.aeadID.KeySize()) 24 key := st.labeledExpand(secret, []byte("key"), keySchCtx, Nk) 25 26 aead, err := st.aeadID.New(key) 27 if err != nil { 28 return nil, err 29 } 30 31 Nn := uint16(aead.NonceSize()) 32 baseNonce := st.labeledExpand(secret, []byte("base_nonce"), keySchCtx, Nn) 33 exporterSecret := st.labeledExpand( 34 secret, 35 []byte("exp"), 36 keySchCtx, 37 uint16(st.kdfID.ExtractSize()), 38 ) 39 40 return &encdecContext{ 41 st.Suite, 42 ss, 43 secret, 44 keySchCtx, 45 exporterSecret, 46 key, 47 baseNonce, 48 make([]byte, Nn), 49 aead, 50 make([]byte, Nn), 51 }, nil 52 } 53 54 func (st state) verifyPSKInputs(psk, pskID []byte) error { 55 gotPSK := psk != nil 56 gotPSKID := pskID != nil 57 if gotPSK != gotPSKID { 58 return errors.New("inconsistent PSK inputs") 59 } 60 switch st.modeID { 61 case modeBase | modeAuth: 62 if gotPSK { 63 return errors.New("PSK input provided when not needed") 64 } 65 case modePSK | modeAuthPSK: 66 if !gotPSK { 67 return errors.New("missing required PSK input") 68 } 69 } 70 return nil 71 } 72 73 // Params returns the codepoints for the algorithms comprising the suite. 74 func (suite Suite) Params() (KEM, KDF, AEAD) { 75 return suite.kemID, suite.kdfID, suite.aeadID 76 } 77 78 func (suite Suite) String() string { 79 return fmt.Sprintf( 80 "kem_id: %v kdf_id: %v aead_id: %v", 81 suite.kemID, suite.kdfID, suite.aeadID, 82 ) 83 } 84 85 func (suite Suite) getSuiteID() (id [10]byte) { 86 id[0], id[1], id[2], id[3] = 'H', 'P', 'K', 'E' 87 binary.BigEndian.PutUint16(id[4:6], uint16(suite.kemID)) 88 binary.BigEndian.PutUint16(id[6:8], uint16(suite.kdfID)) 89 binary.BigEndian.PutUint16(id[8:10], uint16(suite.aeadID)) 90 return 91 } 92 93 func (suite Suite) isValid() bool { 94 return suite.kemID.IsValid() && 95 suite.kdfID.IsValid() && 96 suite.aeadID.IsValid() 97 } 98 99 func (suite Suite) labeledExtract(salt, label, ikm []byte) []byte { 100 suiteID := suite.getSuiteID() 101 labeledIKM := append(append(append(append( 102 make([]byte, 0, len(versionLabel)+len(suiteID)+len(label)+len(ikm)), 103 versionLabel...), 104 suiteID[:]...), 105 label...), 106 ikm...) 107 return suite.kdfID.Extract(labeledIKM, salt) 108 } 109 110 func (suite Suite) labeledExpand(prk, label, info []byte, l uint16) []byte { 111 suiteID := suite.getSuiteID() 112 labeledInfo := make([]byte, 113 2, 2+len(versionLabel)+len(suiteID)+len(label)+len(info)) 114 binary.BigEndian.PutUint16(labeledInfo[0:2], l) 115 labeledInfo = append(append(append(append(labeledInfo, 116 versionLabel...), 117 suiteID[:]...), 118 label...), 119 info...) 120 return suite.kdfID.Expand(prk, labeledInfo, uint(l)) 121 }