github.com/cloudflare/circl@v1.5.0/hpke/vectors_test.go (about) 1 package hpke 2 3 import ( 4 "bytes" 5 "encoding/hex" 6 "encoding/json" 7 "fmt" 8 "io" 9 "os" 10 "testing" 11 12 "github.com/cloudflare/circl/internal/test" 13 "github.com/cloudflare/circl/kem" 14 "golang.org/x/crypto/sha3" 15 ) 16 17 var ( 18 outputTestVectorEnvironmentKey = "HPKE_TEST_VECTORS_OUT" 19 testVectorEncryptionCount = 257 20 testVectorExportLength = 32 21 ) 22 23 func TestVectors(t *testing.T) { 24 // Test vectors from 25 // https://github.com/cfrg/draft-irtf-cfrg-hpke/blob/master/test-vectors.json 26 vectors := readFile(t, "testdata/vectors_rfc9180_5f503c5.json") 27 for i, v := range vectors { 28 t.Run(fmt.Sprintf("v%v", i), v.verify) 29 } 30 } 31 32 func (v *vector) verify(t *testing.T) { 33 m := v.ModeID 34 kem, kdf, aead := KEM(v.KemID), KDF(v.KdfID), AEAD(v.AeadID) 35 if !kem.IsValid() { 36 t.Skipf("Skipping test with unknown KEM: %x", kem) 37 } 38 if !kdf.IsValid() { 39 t.Skipf("Skipping test with unknown KDF: %x", kdf) 40 } 41 if !aead.IsValid() { 42 t.Skipf("Skipping test with unknown AEAD: %x", aead) 43 } 44 s := NewSuite(kem, kdf, aead) 45 46 sender, recv := v.getActors(t, kem.Scheme(), s) 47 sealer, opener := v.setup(t, kem.Scheme(), sender, recv, m, s) 48 49 v.checkAead(t, (sealer.(*sealContext)).encdecContext, m) 50 v.checkAead(t, (opener.(*openContext)).encdecContext, m) 51 v.checkEncryptions(t, sealer, opener, m) 52 v.checkExports(t, sealer, m) 53 v.checkExports(t, opener, m) 54 } 55 56 func (v *vector) getActors( 57 t *testing.T, dhkem kem.Scheme, s Suite, 58 ) (*Sender, *Receiver) { 59 h := s.String() + "\n" 60 61 pkR, err := dhkem.UnmarshalBinaryPublicKey(hexB(t, v.PkRm)) 62 test.CheckNoErr(t, err, h+"bad public key") 63 64 skR, err := dhkem.UnmarshalBinaryPrivateKey(hexB(t, v.SkRm)) 65 test.CheckNoErr(t, err, h+"bad private key") 66 67 info := hexB(t, v.Info) 68 sender, err := s.NewSender(pkR, info) 69 test.CheckNoErr(t, err, h+"err sender") 70 71 recv, err := s.NewReceiver(skR, info) 72 test.CheckNoErr(t, err, h+"err receiver") 73 74 return sender, recv 75 } 76 77 func (v *vector) setup(t *testing.T, k kem.Scheme, 78 se *Sender, re *Receiver, 79 m modeID, s Suite, 80 ) (sealer Sealer, opener Opener) { 81 seed := hexB(t, v.IkmE) 82 rd := bytes.NewReader(seed) 83 84 var enc []byte 85 var skS kem.PrivateKey 86 var pkS kem.PublicKey 87 var errS, errR, errPK, errSK error 88 89 switch v.ModeID { 90 case modeBase: 91 enc, sealer, errS = se.Setup(rd) 92 if errS == nil { 93 opener, errR = re.Setup(enc) 94 } 95 96 case modePSK: 97 psk, pskid := hexB(t, v.Psk), hexB(t, v.PskID) 98 enc, sealer, errS = se.SetupPSK(rd, psk, pskid) 99 if errS == nil { 100 opener, errR = re.SetupPSK(enc, psk, pskid) 101 } 102 103 case modeAuth: 104 skS, errSK = k.UnmarshalBinaryPrivateKey(hexB(t, v.SkSm)) 105 if errSK == nil { 106 pkS, errPK = k.UnmarshalBinaryPublicKey(hexB(t, v.PkSm)) 107 if errPK == nil { 108 enc, sealer, errS = se.SetupAuth(rd, skS) 109 if errS == nil { 110 opener, errR = re.SetupAuth(enc, pkS) 111 } 112 } 113 } 114 115 case modeAuthPSK: 116 psk, pskid := hexB(t, v.Psk), hexB(t, v.PskID) 117 skS, errSK = k.UnmarshalBinaryPrivateKey(hexB(t, v.SkSm)) 118 if errSK == nil { 119 pkS, errPK = k.UnmarshalBinaryPublicKey(hexB(t, v.PkSm)) 120 if errPK == nil { 121 enc, sealer, errS = se.SetupAuthPSK(rd, skS, psk, pskid) 122 if errS == nil { 123 opener, errR = re.SetupAuthPSK(enc, psk, pskid, pkS) 124 } 125 } 126 } 127 } 128 129 h := fmt.Sprintf("mode: %v %v\n", m, s) 130 test.CheckNoErr(t, errS, h+"error on sender setup") 131 test.CheckNoErr(t, errR, h+"error on receiver setup") 132 test.CheckNoErr(t, errSK, h+"bad private key") 133 test.CheckNoErr(t, errPK, h+"bad public key") 134 135 return sealer, opener 136 } 137 138 func (v *vector) checkAead(t *testing.T, e *encdecContext, m modeID) { 139 got := e.baseNonce 140 want := hexB(t, v.BaseNonce) 141 if !bytes.Equal(got, want) { 142 test.ReportError(t, got, want, m, e.Suite()) 143 } 144 145 got = e.exporterSecret 146 want = hexB(t, v.ExporterSecret) 147 if !bytes.Equal(got, want) { 148 test.ReportError(t, got, want, m, e.Suite()) 149 } 150 } 151 152 func (v *vector) checkEncryptions( 153 t *testing.T, 154 se Sealer, 155 op Opener, 156 m modeID, 157 ) { 158 for j, encv := range v.Encryptions { 159 pt := hexB(t, encv.Plaintext) 160 aad := hexB(t, encv.Aad) 161 162 ct, err := se.Seal(pt, aad) 163 test.CheckNoErr(t, err, "error on sealing") 164 165 got, err := op.Open(ct, aad) 166 test.CheckNoErr(t, err, "error on opening") 167 168 want := pt 169 if !bytes.Equal(got, want) { 170 test.ReportError(t, got, want, m, se.Suite(), j) 171 } 172 } 173 } 174 175 func (v *vector) checkExports(t *testing.T, context Context, m modeID) { 176 for j, expv := range v.Exports { 177 ctx := hexB(t, expv.ExportContext) 178 want := hexB(t, expv.ExportValue) 179 180 got := context.Export(ctx, uint(expv.ExportLength)) 181 if !bytes.Equal(got, want) { 182 test.ReportError(t, got, want, m, context.Suite(), j) 183 } 184 } 185 } 186 187 func hexB(t *testing.T, x string) []byte { 188 t.Helper() 189 z, err := hex.DecodeString(x) 190 test.CheckNoErr(t, err, "") 191 return z 192 } 193 194 func readFile(t *testing.T, fileName string) []vector { 195 jsonFile, err := os.Open(fileName) 196 if err != nil { 197 t.Fatalf("File %v can not be opened. Error: %v", fileName, err) 198 } 199 defer jsonFile.Close() 200 input, err := io.ReadAll(jsonFile) 201 if err != nil { 202 t.Fatalf("File %v can not be read. Error: %v", fileName, err) 203 } 204 var vectors []vector 205 err = json.Unmarshal(input, &vectors) 206 if err != nil { 207 t.Fatalf("File %v can not be loaded. Error: %v", fileName, err) 208 } 209 return vectors 210 } 211 212 type encryptionVector struct { 213 Aad string `json:"aad"` 214 Ciphertext string `json:"ct"` 215 Nonce string `json:"nonce"` 216 Plaintext string `json:"pt"` 217 } 218 219 type exportVector struct { 220 ExportContext string `json:"exporter_context"` 221 ExportLength int `json:"L"` 222 ExportValue string `json:"exported_value"` 223 } 224 225 type vector struct { 226 ModeID uint8 `json:"mode"` 227 KemID uint16 `json:"kem_id"` 228 KdfID uint16 `json:"kdf_id"` 229 AeadID uint16 `json:"aead_id"` 230 Info string `json:"info"` 231 Ier string `json:"ier,omitempty"` 232 IkmR string `json:"ikmR"` 233 IkmE string `json:"ikmE,omitempty"` 234 SkRm string `json:"skRm"` 235 SkEm string `json:"skEm,omitempty"` 236 SkSm string `json:"skSm,omitempty"` 237 Psk string `json:"psk,omitempty"` 238 PskID string `json:"psk_id,omitempty"` 239 PkSm string `json:"pkSm,omitempty"` 240 PkRm string `json:"pkRm"` 241 PkEm string `json:"pkEm,omitempty"` 242 Enc string `json:"enc"` 243 SharedSecret string `json:"shared_secret"` 244 KeyScheduleContext string `json:"key_schedule_context"` 245 Secret string `json:"secret"` 246 Key string `json:"key"` 247 BaseNonce string `json:"base_nonce"` 248 ExporterSecret string `json:"exporter_secret"` 249 Encryptions []encryptionVector `json:"encryptions"` 250 Exports []exportVector `json:"exports"` 251 } 252 253 func generateHybridKeyPair(rnd io.Reader, h kem.Scheme) ([]byte, kem.PublicKey, kem.PrivateKey, error) { 254 seed := make([]byte, h.SeedSize()) 255 _, err := rnd.Read(seed) 256 if err != nil { 257 return nil, nil, nil, err 258 } 259 260 pk, sk := h.DeriveKeyPair(seed) 261 return seed, pk, sk, nil 262 } 263 264 func mustEncodePublicKey(pk kem.PublicKey) []byte { 265 enc, err := pk.MarshalBinary() 266 if err != nil { 267 panic(err) 268 } 269 return enc 270 } 271 272 func mustEncodePrivateKey(sk kem.PrivateKey) []byte { 273 enc, err := sk.MarshalBinary() 274 if err != nil { 275 panic(err) 276 } 277 return enc 278 } 279 280 func generateEncryptions(sealer Sealer, opener Opener, msg []byte) ([]encryptionVector, error) { 281 vectors := make([]encryptionVector, testVectorEncryptionCount) 282 for i := 0; i < len(vectors); i++ { 283 aad := []byte(fmt.Sprintf("Count-%d", i)) 284 innerSealer := sealer.(*sealContext) 285 nonce := innerSealer.calcNonce() 286 encrypted, err := sealer.Seal(msg, aad) 287 if err != nil { 288 return nil, err 289 } 290 decrypted, err := opener.Open(encrypted, aad) 291 if err != nil { 292 return nil, err 293 } 294 if !bytes.Equal(decrypted, msg) { 295 return nil, fmt.Errorf("Mismatch messages %d", i) 296 } 297 vectors[i] = encryptionVector{ 298 Plaintext: hex.EncodeToString(msg), 299 Aad: hex.EncodeToString(aad), 300 Nonce: hex.EncodeToString(nonce), 301 Ciphertext: hex.EncodeToString(encrypted), 302 } 303 } 304 305 return vectors, nil 306 } 307 308 func generateExports(sealer Sealer, opener Opener) ([]exportVector, error) { 309 exportContexts := [][]byte{ 310 []byte(""), 311 {0x00}, 312 []byte("TestContext"), 313 } 314 vectors := make([]exportVector, len(exportContexts)) 315 for i := 0; i < len(vectors); i++ { 316 senderValue := sealer.Export(exportContexts[i], uint(testVectorExportLength)) 317 receiverValue := opener.Export(exportContexts[i], uint(testVectorExportLength)) 318 if !bytes.Equal(senderValue, receiverValue) { 319 return nil, fmt.Errorf("Mismatch export values") 320 } 321 vectors[i] = exportVector{ 322 ExportContext: hex.EncodeToString(exportContexts[i]), 323 ExportLength: testVectorExportLength, 324 ExportValue: hex.EncodeToString(senderValue), 325 } 326 } 327 328 return vectors, nil 329 } 330 331 func TestHybridKemRoundTrip(t *testing.T) { 332 kemID := KEM_X25519_KYBER768_DRAFT00 333 kdfID := KDF_HKDF_SHA256 334 aeadID := AEAD_AES128GCM 335 rnd := sha3.NewShake128() 336 suite := NewSuite(kemID, kdfID, aeadID) 337 msg := []byte("To the universal deployment of PQC") 338 info := []byte("Hear hear") 339 pskid := []byte("before everybody for everybody for everything") 340 psk := make([]byte, 32) 341 _, _ = rnd.Read(psk) 342 343 ikmR, pkR, skR, err := generateHybridKeyPair(rnd, kemID.Scheme()) 344 if err != nil { 345 t.Error(err) 346 } 347 348 ier := make([]byte, 64) 349 _, _ = rnd.Read(ier) 350 351 receiver, err := suite.NewReceiver(skR, info) 352 if err != nil { 353 t.Error(err) 354 } 355 356 sender, err := suite.NewSender(pkR, info) 357 if err != nil { 358 t.Error(err) 359 } 360 361 generateVector := func(mode uint8) vector { 362 var ( 363 err2 error 364 sealer Sealer 365 opener Opener 366 enc []byte 367 ) 368 rnd2 := bytes.NewBuffer(ier) 369 switch mode { 370 case modeBase: 371 enc, sealer, err2 = sender.Setup(rnd2) 372 if err2 != nil { 373 t.Error(err2) 374 } 375 opener, err2 = receiver.Setup(enc) 376 if err2 != nil { 377 t.Error(err2) 378 } 379 case modePSK: 380 enc, sealer, err2 = sender.SetupPSK(rnd2, psk, pskid) 381 if err2 != nil { 382 t.Error(err2) 383 } 384 opener, err2 = receiver.SetupPSK(enc, psk, pskid) 385 if err2 != nil { 386 t.Error(err2) 387 } 388 default: 389 panic("unsupported mode") 390 } 391 392 if rnd2.Len() != 0 { 393 t.Fatal() 394 } 395 396 innerSealer := sealer.(*sealContext) 397 398 encryptions, err2 := generateEncryptions(sealer, opener, msg) 399 if err2 != nil { 400 t.Error(err2) 401 } 402 exports, err2 := generateExports(sealer, opener) 403 if err2 != nil { 404 t.Error(err2) 405 } 406 407 ret := vector{ 408 ModeID: mode, 409 KemID: uint16(kemID), 410 KdfID: uint16(kdfID), 411 AeadID: uint16(aeadID), 412 Ier: hex.EncodeToString(ier), 413 Info: hex.EncodeToString(info), 414 IkmR: hex.EncodeToString(ikmR), 415 SkRm: hex.EncodeToString(mustEncodePrivateKey(skR)), 416 PkRm: hex.EncodeToString(mustEncodePublicKey(pkR)), 417 Enc: hex.EncodeToString(enc), 418 SharedSecret: hex.EncodeToString(innerSealer.sharedSecret), 419 KeyScheduleContext: hex.EncodeToString(innerSealer.keyScheduleContext), 420 Secret: hex.EncodeToString(innerSealer.secret), 421 Key: hex.EncodeToString(innerSealer.key), 422 BaseNonce: hex.EncodeToString(innerSealer.baseNonce), 423 ExporterSecret: hex.EncodeToString(innerSealer.exporterSecret), 424 Encryptions: encryptions, 425 Exports: exports, 426 } 427 428 if mode == modePSK { 429 ret.Psk = hex.EncodeToString(psk) 430 ret.PskID = hex.EncodeToString(pskid) 431 } 432 433 return ret 434 } 435 436 encodedVector, err := json.Marshal([]vector{ 437 generateVector(modeBase), 438 generateVector(modePSK), 439 }) 440 if err != nil { 441 t.Error(err) 442 } 443 444 var outputFile string 445 if outputFile = os.Getenv(outputTestVectorEnvironmentKey); len(outputFile) > 0 { 446 // nolint: gosec 447 err = os.WriteFile(outputFile, encodedVector, 0o644) 448 if err != nil { 449 t.Error(err) 450 } 451 } 452 }