github.com/cloudflare/circl@v1.5.0/hpke/kembase.go (about) 1 package hpke 2 3 import ( 4 "crypto" 5 "crypto/rand" 6 "encoding/binary" 7 "io" 8 9 "github.com/cloudflare/circl/kem" 10 "golang.org/x/crypto/hkdf" 11 ) 12 13 type dhKEM interface { 14 sizeDH() int 15 calcDH(dh []byte, sk kem.PrivateKey, pk kem.PublicKey) error 16 SeedSize() int 17 DeriveKeyPair(seed []byte) (kem.PublicKey, kem.PrivateKey) 18 UnmarshalBinaryPrivateKey(data []byte) (kem.PrivateKey, error) 19 UnmarshalBinaryPublicKey(data []byte) (kem.PublicKey, error) 20 } 21 22 type kemBase struct { 23 id KEM 24 name string 25 crypto.Hash 26 } 27 28 type dhKemBase struct { 29 kemBase 30 dhKEM 31 } 32 33 func (k kemBase) Name() string { return k.name } 34 func (k kemBase) SharedKeySize() int { return k.Hash.Size() } 35 36 func (k kemBase) getSuiteID() (sid [5]byte) { 37 sid[0], sid[1], sid[2] = 'K', 'E', 'M' 38 binary.BigEndian.PutUint16(sid[3:5], uint16(k.id)) 39 return 40 } 41 42 func (k kemBase) extractExpand(dh, kemCtx []byte) []byte { 43 eaePkr := k.labeledExtract([]byte(""), []byte("eae_prk"), dh) 44 return k.labeledExpand( 45 eaePkr, 46 []byte("shared_secret"), 47 kemCtx, 48 uint16(k.Size()), 49 ) 50 } 51 52 func (k kemBase) labeledExtract(salt, label, info []byte) []byte { 53 suiteID := k.getSuiteID() 54 labeledIKM := append(append(append(append( 55 make([]byte, 0, len(versionLabel)+len(suiteID)+len(label)+len(info)), 56 versionLabel...), 57 suiteID[:]...), 58 label...), 59 info...) 60 return hkdf.Extract(k.New, labeledIKM, salt) 61 } 62 63 func (k kemBase) labeledExpand(prk, label, info []byte, l uint16) []byte { 64 suiteID := k.getSuiteID() 65 labeledInfo := make( 66 []byte, 67 2, 68 2+len(versionLabel)+len(suiteID)+len(label)+len(info), 69 ) 70 binary.BigEndian.PutUint16(labeledInfo[0:2], l) 71 labeledInfo = append(append(append(append(labeledInfo, 72 versionLabel...), 73 suiteID[:]...), 74 label...), 75 info...) 76 b := make([]byte, l) 77 rd := hkdf.Expand(k.New, prk, labeledInfo) 78 if _, err := io.ReadFull(rd, b); err != nil { 79 panic(err) 80 } 81 return b 82 } 83 84 func (k dhKemBase) AuthEncapsulate(pkr kem.PublicKey, sks kem.PrivateKey) ( 85 ct []byte, ss []byte, err error, 86 ) { 87 seed := make([]byte, k.SeedSize()) 88 _, err = io.ReadFull(rand.Reader, seed) 89 if err != nil { 90 return nil, nil, err 91 } 92 93 return k.authEncap(pkr, sks, seed) 94 } 95 96 func (k dhKemBase) Encapsulate(pkr kem.PublicKey) ( 97 ct []byte, ss []byte, err error, 98 ) { 99 seed := make([]byte, k.SeedSize()) 100 _, err = io.ReadFull(rand.Reader, seed) 101 if err != nil { 102 return nil, nil, err 103 } 104 105 return k.encap(pkr, seed) 106 } 107 108 func (k dhKemBase) AuthEncapsulateDeterministically( 109 pkr kem.PublicKey, sks kem.PrivateKey, seed []byte, 110 ) (ct, ss []byte, err error) { 111 return k.authEncap(pkr, sks, seed) 112 } 113 114 func (k dhKemBase) EncapsulateDeterministically( 115 pkr kem.PublicKey, seed []byte, 116 ) (ct, ss []byte, err error) { 117 return k.encap(pkr, seed) 118 } 119 120 func (k dhKemBase) encap( 121 pkR kem.PublicKey, 122 seed []byte, 123 ) (ct []byte, ss []byte, err error) { 124 dh := make([]byte, k.sizeDH()) 125 enc, kemCtx, err := k.coreEncap(dh, pkR, seed) 126 if err != nil { 127 return nil, nil, err 128 } 129 ss = k.extractExpand(dh, kemCtx) 130 return enc, ss, nil 131 } 132 133 func (k dhKemBase) authEncap( 134 pkR kem.PublicKey, 135 skS kem.PrivateKey, 136 seed []byte, 137 ) (ct []byte, ss []byte, err error) { 138 dhLen := k.sizeDH() 139 dh := make([]byte, 2*dhLen) 140 enc, kemCtx, err := k.coreEncap(dh[:dhLen], pkR, seed) 141 if err != nil { 142 return nil, nil, err 143 } 144 145 err = k.calcDH(dh[dhLen:], skS, pkR) 146 if err != nil { 147 return nil, nil, err 148 } 149 150 pkS := skS.Public() 151 pkSm, err := pkS.MarshalBinary() 152 if err != nil { 153 return nil, nil, err 154 } 155 kemCtx = append(kemCtx, pkSm...) 156 157 ss = k.extractExpand(dh, kemCtx) 158 return enc, ss, nil 159 } 160 161 func (k dhKemBase) coreEncap( 162 dh []byte, 163 pkR kem.PublicKey, 164 seed []byte, 165 ) (enc []byte, kemCtx []byte, err error) { 166 pkE, skE := k.DeriveKeyPair(seed) 167 err = k.calcDH(dh, skE, pkR) 168 if err != nil { 169 return nil, nil, err 170 } 171 172 enc, err = pkE.MarshalBinary() 173 if err != nil { 174 return nil, nil, err 175 } 176 pkRm, err := pkR.MarshalBinary() 177 if err != nil { 178 return nil, nil, err 179 } 180 kemCtx = append(append([]byte{}, enc...), pkRm...) 181 182 return enc, kemCtx, nil 183 } 184 185 func (k dhKemBase) Decapsulate(skr kem.PrivateKey, ct []byte) ([]byte, error) { 186 dh := make([]byte, k.sizeDH()) 187 kemCtx, err := k.coreDecap(dh, skr, ct) 188 if err != nil { 189 return nil, err 190 } 191 return k.extractExpand(dh, kemCtx), nil 192 } 193 194 func (k dhKemBase) AuthDecapsulate( 195 skR kem.PrivateKey, 196 ct []byte, 197 pkS kem.PublicKey, 198 ) ([]byte, error) { 199 dhLen := k.sizeDH() 200 dh := make([]byte, 2*dhLen) 201 kemCtx, err := k.coreDecap(dh[:dhLen], skR, ct) 202 if err != nil { 203 return nil, err 204 } 205 206 err = k.calcDH(dh[dhLen:], skR, pkS) 207 if err != nil { 208 return nil, err 209 } 210 211 pkSm, err := pkS.MarshalBinary() 212 if err != nil { 213 return nil, err 214 } 215 kemCtx = append(kemCtx, pkSm...) 216 return k.extractExpand(dh, kemCtx), nil 217 } 218 219 func (k dhKemBase) coreDecap( 220 dh []byte, 221 skR kem.PrivateKey, 222 ct []byte, 223 ) ([]byte, error) { 224 pkE, err := k.UnmarshalBinaryPublicKey(ct) 225 if err != nil { 226 return nil, err 227 } 228 229 err = k.calcDH(dh, skR, pkE) 230 if err != nil { 231 return nil, err 232 } 233 234 pkR := skR.Public() 235 pkRm, err := pkR.MarshalBinary() 236 if err != nil { 237 return nil, err 238 } 239 240 return append(append([]byte{}, ct...), pkRm...), nil 241 }