github.com/cloudflare/circl@v1.5.0/hpke/vectors_test.go (about)

     1  package hpke
     2  
     3  import (
     4  	"bytes"
     5  	"encoding/hex"
     6  	"encoding/json"
     7  	"fmt"
     8  	"io"
     9  	"os"
    10  	"testing"
    11  
    12  	"github.com/cloudflare/circl/internal/test"
    13  	"github.com/cloudflare/circl/kem"
    14  	"golang.org/x/crypto/sha3"
    15  )
    16  
    17  var (
    18  	outputTestVectorEnvironmentKey = "HPKE_TEST_VECTORS_OUT"
    19  	testVectorEncryptionCount      = 257
    20  	testVectorExportLength         = 32
    21  )
    22  
    23  func TestVectors(t *testing.T) {
    24  	// Test vectors from
    25  	// https://github.com/cfrg/draft-irtf-cfrg-hpke/blob/master/test-vectors.json
    26  	vectors := readFile(t, "testdata/vectors_rfc9180_5f503c5.json")
    27  	for i, v := range vectors {
    28  		t.Run(fmt.Sprintf("v%v", i), v.verify)
    29  	}
    30  }
    31  
    32  func (v *vector) verify(t *testing.T) {
    33  	m := v.ModeID
    34  	kem, kdf, aead := KEM(v.KemID), KDF(v.KdfID), AEAD(v.AeadID)
    35  	if !kem.IsValid() {
    36  		t.Skipf("Skipping test with unknown KEM: %x", kem)
    37  	}
    38  	if !kdf.IsValid() {
    39  		t.Skipf("Skipping test with unknown KDF: %x", kdf)
    40  	}
    41  	if !aead.IsValid() {
    42  		t.Skipf("Skipping test with unknown AEAD: %x", aead)
    43  	}
    44  	s := NewSuite(kem, kdf, aead)
    45  
    46  	sender, recv := v.getActors(t, kem.Scheme(), s)
    47  	sealer, opener := v.setup(t, kem.Scheme(), sender, recv, m, s)
    48  
    49  	v.checkAead(t, (sealer.(*sealContext)).encdecContext, m)
    50  	v.checkAead(t, (opener.(*openContext)).encdecContext, m)
    51  	v.checkEncryptions(t, sealer, opener, m)
    52  	v.checkExports(t, sealer, m)
    53  	v.checkExports(t, opener, m)
    54  }
    55  
    56  func (v *vector) getActors(
    57  	t *testing.T, dhkem kem.Scheme, s Suite,
    58  ) (*Sender, *Receiver) {
    59  	h := s.String() + "\n"
    60  
    61  	pkR, err := dhkem.UnmarshalBinaryPublicKey(hexB(t, v.PkRm))
    62  	test.CheckNoErr(t, err, h+"bad public key")
    63  
    64  	skR, err := dhkem.UnmarshalBinaryPrivateKey(hexB(t, v.SkRm))
    65  	test.CheckNoErr(t, err, h+"bad private key")
    66  
    67  	info := hexB(t, v.Info)
    68  	sender, err := s.NewSender(pkR, info)
    69  	test.CheckNoErr(t, err, h+"err sender")
    70  
    71  	recv, err := s.NewReceiver(skR, info)
    72  	test.CheckNoErr(t, err, h+"err receiver")
    73  
    74  	return sender, recv
    75  }
    76  
    77  func (v *vector) setup(t *testing.T, k kem.Scheme,
    78  	se *Sender, re *Receiver,
    79  	m modeID, s Suite,
    80  ) (sealer Sealer, opener Opener) {
    81  	seed := hexB(t, v.IkmE)
    82  	rd := bytes.NewReader(seed)
    83  
    84  	var enc []byte
    85  	var skS kem.PrivateKey
    86  	var pkS kem.PublicKey
    87  	var errS, errR, errPK, errSK error
    88  
    89  	switch v.ModeID {
    90  	case modeBase:
    91  		enc, sealer, errS = se.Setup(rd)
    92  		if errS == nil {
    93  			opener, errR = re.Setup(enc)
    94  		}
    95  
    96  	case modePSK:
    97  		psk, pskid := hexB(t, v.Psk), hexB(t, v.PskID)
    98  		enc, sealer, errS = se.SetupPSK(rd, psk, pskid)
    99  		if errS == nil {
   100  			opener, errR = re.SetupPSK(enc, psk, pskid)
   101  		}
   102  
   103  	case modeAuth:
   104  		skS, errSK = k.UnmarshalBinaryPrivateKey(hexB(t, v.SkSm))
   105  		if errSK == nil {
   106  			pkS, errPK = k.UnmarshalBinaryPublicKey(hexB(t, v.PkSm))
   107  			if errPK == nil {
   108  				enc, sealer, errS = se.SetupAuth(rd, skS)
   109  				if errS == nil {
   110  					opener, errR = re.SetupAuth(enc, pkS)
   111  				}
   112  			}
   113  		}
   114  
   115  	case modeAuthPSK:
   116  		psk, pskid := hexB(t, v.Psk), hexB(t, v.PskID)
   117  		skS, errSK = k.UnmarshalBinaryPrivateKey(hexB(t, v.SkSm))
   118  		if errSK == nil {
   119  			pkS, errPK = k.UnmarshalBinaryPublicKey(hexB(t, v.PkSm))
   120  			if errPK == nil {
   121  				enc, sealer, errS = se.SetupAuthPSK(rd, skS, psk, pskid)
   122  				if errS == nil {
   123  					opener, errR = re.SetupAuthPSK(enc, psk, pskid, pkS)
   124  				}
   125  			}
   126  		}
   127  	}
   128  
   129  	h := fmt.Sprintf("mode: %v %v\n", m, s)
   130  	test.CheckNoErr(t, errS, h+"error on sender setup")
   131  	test.CheckNoErr(t, errR, h+"error on receiver setup")
   132  	test.CheckNoErr(t, errSK, h+"bad private key")
   133  	test.CheckNoErr(t, errPK, h+"bad public key")
   134  
   135  	return sealer, opener
   136  }
   137  
   138  func (v *vector) checkAead(t *testing.T, e *encdecContext, m modeID) {
   139  	got := e.baseNonce
   140  	want := hexB(t, v.BaseNonce)
   141  	if !bytes.Equal(got, want) {
   142  		test.ReportError(t, got, want, m, e.Suite())
   143  	}
   144  
   145  	got = e.exporterSecret
   146  	want = hexB(t, v.ExporterSecret)
   147  	if !bytes.Equal(got, want) {
   148  		test.ReportError(t, got, want, m, e.Suite())
   149  	}
   150  }
   151  
   152  func (v *vector) checkEncryptions(
   153  	t *testing.T,
   154  	se Sealer,
   155  	op Opener,
   156  	m modeID,
   157  ) {
   158  	for j, encv := range v.Encryptions {
   159  		pt := hexB(t, encv.Plaintext)
   160  		aad := hexB(t, encv.Aad)
   161  
   162  		ct, err := se.Seal(pt, aad)
   163  		test.CheckNoErr(t, err, "error on sealing")
   164  
   165  		got, err := op.Open(ct, aad)
   166  		test.CheckNoErr(t, err, "error on opening")
   167  
   168  		want := pt
   169  		if !bytes.Equal(got, want) {
   170  			test.ReportError(t, got, want, m, se.Suite(), j)
   171  		}
   172  	}
   173  }
   174  
   175  func (v *vector) checkExports(t *testing.T, context Context, m modeID) {
   176  	for j, expv := range v.Exports {
   177  		ctx := hexB(t, expv.ExportContext)
   178  		want := hexB(t, expv.ExportValue)
   179  
   180  		got := context.Export(ctx, uint(expv.ExportLength))
   181  		if !bytes.Equal(got, want) {
   182  			test.ReportError(t, got, want, m, context.Suite(), j)
   183  		}
   184  	}
   185  }
   186  
   187  func hexB(t *testing.T, x string) []byte {
   188  	t.Helper()
   189  	z, err := hex.DecodeString(x)
   190  	test.CheckNoErr(t, err, "")
   191  	return z
   192  }
   193  
   194  func readFile(t *testing.T, fileName string) []vector {
   195  	jsonFile, err := os.Open(fileName)
   196  	if err != nil {
   197  		t.Fatalf("File %v can not be opened. Error: %v", fileName, err)
   198  	}
   199  	defer jsonFile.Close()
   200  	input, err := io.ReadAll(jsonFile)
   201  	if err != nil {
   202  		t.Fatalf("File %v can not be read. Error: %v", fileName, err)
   203  	}
   204  	var vectors []vector
   205  	err = json.Unmarshal(input, &vectors)
   206  	if err != nil {
   207  		t.Fatalf("File %v can not be loaded. Error: %v", fileName, err)
   208  	}
   209  	return vectors
   210  }
   211  
   212  type encryptionVector struct {
   213  	Aad        string `json:"aad"`
   214  	Ciphertext string `json:"ct"`
   215  	Nonce      string `json:"nonce"`
   216  	Plaintext  string `json:"pt"`
   217  }
   218  
   219  type exportVector struct {
   220  	ExportContext string `json:"exporter_context"`
   221  	ExportLength  int    `json:"L"`
   222  	ExportValue   string `json:"exported_value"`
   223  }
   224  
   225  type vector struct {
   226  	ModeID             uint8              `json:"mode"`
   227  	KemID              uint16             `json:"kem_id"`
   228  	KdfID              uint16             `json:"kdf_id"`
   229  	AeadID             uint16             `json:"aead_id"`
   230  	Info               string             `json:"info"`
   231  	Ier                string             `json:"ier,omitempty"`
   232  	IkmR               string             `json:"ikmR"`
   233  	IkmE               string             `json:"ikmE,omitempty"`
   234  	SkRm               string             `json:"skRm"`
   235  	SkEm               string             `json:"skEm,omitempty"`
   236  	SkSm               string             `json:"skSm,omitempty"`
   237  	Psk                string             `json:"psk,omitempty"`
   238  	PskID              string             `json:"psk_id,omitempty"`
   239  	PkSm               string             `json:"pkSm,omitempty"`
   240  	PkRm               string             `json:"pkRm"`
   241  	PkEm               string             `json:"pkEm,omitempty"`
   242  	Enc                string             `json:"enc"`
   243  	SharedSecret       string             `json:"shared_secret"`
   244  	KeyScheduleContext string             `json:"key_schedule_context"`
   245  	Secret             string             `json:"secret"`
   246  	Key                string             `json:"key"`
   247  	BaseNonce          string             `json:"base_nonce"`
   248  	ExporterSecret     string             `json:"exporter_secret"`
   249  	Encryptions        []encryptionVector `json:"encryptions"`
   250  	Exports            []exportVector     `json:"exports"`
   251  }
   252  
   253  func generateHybridKeyPair(rnd io.Reader, h kem.Scheme) ([]byte, kem.PublicKey, kem.PrivateKey, error) {
   254  	seed := make([]byte, h.SeedSize())
   255  	_, err := rnd.Read(seed)
   256  	if err != nil {
   257  		return nil, nil, nil, err
   258  	}
   259  
   260  	pk, sk := h.DeriveKeyPair(seed)
   261  	return seed, pk, sk, nil
   262  }
   263  
   264  func mustEncodePublicKey(pk kem.PublicKey) []byte {
   265  	enc, err := pk.MarshalBinary()
   266  	if err != nil {
   267  		panic(err)
   268  	}
   269  	return enc
   270  }
   271  
   272  func mustEncodePrivateKey(sk kem.PrivateKey) []byte {
   273  	enc, err := sk.MarshalBinary()
   274  	if err != nil {
   275  		panic(err)
   276  	}
   277  	return enc
   278  }
   279  
   280  func generateEncryptions(sealer Sealer, opener Opener, msg []byte) ([]encryptionVector, error) {
   281  	vectors := make([]encryptionVector, testVectorEncryptionCount)
   282  	for i := 0; i < len(vectors); i++ {
   283  		aad := []byte(fmt.Sprintf("Count-%d", i))
   284  		innerSealer := sealer.(*sealContext)
   285  		nonce := innerSealer.calcNonce()
   286  		encrypted, err := sealer.Seal(msg, aad)
   287  		if err != nil {
   288  			return nil, err
   289  		}
   290  		decrypted, err := opener.Open(encrypted, aad)
   291  		if err != nil {
   292  			return nil, err
   293  		}
   294  		if !bytes.Equal(decrypted, msg) {
   295  			return nil, fmt.Errorf("Mismatch messages %d", i)
   296  		}
   297  		vectors[i] = encryptionVector{
   298  			Plaintext:  hex.EncodeToString(msg),
   299  			Aad:        hex.EncodeToString(aad),
   300  			Nonce:      hex.EncodeToString(nonce),
   301  			Ciphertext: hex.EncodeToString(encrypted),
   302  		}
   303  	}
   304  
   305  	return vectors, nil
   306  }
   307  
   308  func generateExports(sealer Sealer, opener Opener) ([]exportVector, error) {
   309  	exportContexts := [][]byte{
   310  		[]byte(""),
   311  		{0x00},
   312  		[]byte("TestContext"),
   313  	}
   314  	vectors := make([]exportVector, len(exportContexts))
   315  	for i := 0; i < len(vectors); i++ {
   316  		senderValue := sealer.Export(exportContexts[i], uint(testVectorExportLength))
   317  		receiverValue := opener.Export(exportContexts[i], uint(testVectorExportLength))
   318  		if !bytes.Equal(senderValue, receiverValue) {
   319  			return nil, fmt.Errorf("Mismatch export values")
   320  		}
   321  		vectors[i] = exportVector{
   322  			ExportContext: hex.EncodeToString(exportContexts[i]),
   323  			ExportLength:  testVectorExportLength,
   324  			ExportValue:   hex.EncodeToString(senderValue),
   325  		}
   326  	}
   327  
   328  	return vectors, nil
   329  }
   330  
   331  func TestHybridKemRoundTrip(t *testing.T) {
   332  	kemID := KEM_X25519_KYBER768_DRAFT00
   333  	kdfID := KDF_HKDF_SHA256
   334  	aeadID := AEAD_AES128GCM
   335  	rnd := sha3.NewShake128()
   336  	suite := NewSuite(kemID, kdfID, aeadID)
   337  	msg := []byte("To the universal deployment of PQC")
   338  	info := []byte("Hear hear")
   339  	pskid := []byte("before everybody for everybody for everything")
   340  	psk := make([]byte, 32)
   341  	_, _ = rnd.Read(psk)
   342  
   343  	ikmR, pkR, skR, err := generateHybridKeyPair(rnd, kemID.Scheme())
   344  	if err != nil {
   345  		t.Error(err)
   346  	}
   347  
   348  	ier := make([]byte, 64)
   349  	_, _ = rnd.Read(ier)
   350  
   351  	receiver, err := suite.NewReceiver(skR, info)
   352  	if err != nil {
   353  		t.Error(err)
   354  	}
   355  
   356  	sender, err := suite.NewSender(pkR, info)
   357  	if err != nil {
   358  		t.Error(err)
   359  	}
   360  
   361  	generateVector := func(mode uint8) vector {
   362  		var (
   363  			err2   error
   364  			sealer Sealer
   365  			opener Opener
   366  			enc    []byte
   367  		)
   368  		rnd2 := bytes.NewBuffer(ier)
   369  		switch mode {
   370  		case modeBase:
   371  			enc, sealer, err2 = sender.Setup(rnd2)
   372  			if err2 != nil {
   373  				t.Error(err2)
   374  			}
   375  			opener, err2 = receiver.Setup(enc)
   376  			if err2 != nil {
   377  				t.Error(err2)
   378  			}
   379  		case modePSK:
   380  			enc, sealer, err2 = sender.SetupPSK(rnd2, psk, pskid)
   381  			if err2 != nil {
   382  				t.Error(err2)
   383  			}
   384  			opener, err2 = receiver.SetupPSK(enc, psk, pskid)
   385  			if err2 != nil {
   386  				t.Error(err2)
   387  			}
   388  		default:
   389  			panic("unsupported mode")
   390  		}
   391  
   392  		if rnd2.Len() != 0 {
   393  			t.Fatal()
   394  		}
   395  
   396  		innerSealer := sealer.(*sealContext)
   397  
   398  		encryptions, err2 := generateEncryptions(sealer, opener, msg)
   399  		if err2 != nil {
   400  			t.Error(err2)
   401  		}
   402  		exports, err2 := generateExports(sealer, opener)
   403  		if err2 != nil {
   404  			t.Error(err2)
   405  		}
   406  
   407  		ret := vector{
   408  			ModeID:             mode,
   409  			KemID:              uint16(kemID),
   410  			KdfID:              uint16(kdfID),
   411  			AeadID:             uint16(aeadID),
   412  			Ier:                hex.EncodeToString(ier),
   413  			Info:               hex.EncodeToString(info),
   414  			IkmR:               hex.EncodeToString(ikmR),
   415  			SkRm:               hex.EncodeToString(mustEncodePrivateKey(skR)),
   416  			PkRm:               hex.EncodeToString(mustEncodePublicKey(pkR)),
   417  			Enc:                hex.EncodeToString(enc),
   418  			SharedSecret:       hex.EncodeToString(innerSealer.sharedSecret),
   419  			KeyScheduleContext: hex.EncodeToString(innerSealer.keyScheduleContext),
   420  			Secret:             hex.EncodeToString(innerSealer.secret),
   421  			Key:                hex.EncodeToString(innerSealer.key),
   422  			BaseNonce:          hex.EncodeToString(innerSealer.baseNonce),
   423  			ExporterSecret:     hex.EncodeToString(innerSealer.exporterSecret),
   424  			Encryptions:        encryptions,
   425  			Exports:            exports,
   426  		}
   427  
   428  		if mode == modePSK {
   429  			ret.Psk = hex.EncodeToString(psk)
   430  			ret.PskID = hex.EncodeToString(pskid)
   431  		}
   432  
   433  		return ret
   434  	}
   435  
   436  	encodedVector, err := json.Marshal([]vector{
   437  		generateVector(modeBase),
   438  		generateVector(modePSK),
   439  	})
   440  	if err != nil {
   441  		t.Error(err)
   442  	}
   443  
   444  	var outputFile string
   445  	if outputFile = os.Getenv(outputTestVectorEnvironmentKey); len(outputFile) > 0 {
   446  		// nolint: gosec
   447  		err = os.WriteFile(outputFile, encodedVector, 0o644)
   448  		if err != nil {
   449  			t.Error(err)
   450  		}
   451  	}
   452  }