github.com/cloudflare/circl@v1.5.0/kem/frodo/kat_test.go (about)

     1  package frodo
     2  
     3  // Code to generate the NIST "PQCkemKAT" test vectors.
     4  // See PQCgenKAT_kem.c and rng.c in the NIST reference implementation package.
     5  
     6  import (
     7  	"bytes"
     8  	"crypto/sha256"
     9  	"fmt"
    10  	"testing"
    11  
    12  	"github.com/cloudflare/circl/internal/nist"
    13  	"github.com/cloudflare/circl/kem/schemes"
    14  )
    15  
    16  func TestPQCgenKATKem(t *testing.T) {
    17  	kats := []struct {
    18  		name string
    19  		want string
    20  	}{
    21  		// Computed from:
    22  		// https://github.com/microsoft/PQCrypto-LWEKE/blob/66fc7744c3aae6acfc5fcc587ec7f2cdec48d216/KAT/PQCkemKAT_19888_shake.rsp
    23  		{"FrodoKEM-640-SHAKE", "604a10cfc871dfaed9cb5b057c644ab03b16852cea7f39bc7f9831513b5b1cfa"},
    24  	}
    25  	for _, kat := range kats {
    26  		t.Run(kat.name, func(t *testing.T) {
    27  			testPQCgenKATKem(t, kat.name, kat.want)
    28  		})
    29  	}
    30  }
    31  
    32  func testPQCgenKATKem(t *testing.T, name, expected string) {
    33  	scheme := schemes.ByName(name)
    34  	if scheme == nil {
    35  		t.Fatal()
    36  	}
    37  
    38  	var seed [48]byte
    39  	kseed := make([]byte, scheme.SeedSize())
    40  	eseed := make([]byte, scheme.EncapsulationSeedSize())
    41  	for i := 0; i < 48; i++ {
    42  		seed[i] = byte(i)
    43  	}
    44  	f := sha256.New()
    45  	g := nist.NewDRBG(&seed)
    46  	fmt.Fprintf(f, "# %s\n\n", name)
    47  	for i := 0; i < 100; i++ {
    48  		g.Fill(seed[:])
    49  		fmt.Fprintf(f, "count = %d\n", i)
    50  		fmt.Fprintf(f, "seed = %X\n", seed)
    51  		g2 := nist.NewDRBG(&seed)
    52  
    53  		g2.Fill(kseed[:])
    54  
    55  		pk, sk := scheme.DeriveKeyPair(kseed)
    56  		ppk, _ := pk.MarshalBinary()
    57  		psk, _ := sk.MarshalBinary()
    58  
    59  		g2.Fill(eseed)
    60  		ct, ss, err := scheme.EncapsulateDeterministically(pk, eseed)
    61  		if err != nil {
    62  			t.Fatal(err)
    63  		}
    64  		ss2, _ := scheme.Decapsulate(sk, ct)
    65  		if !bytes.Equal(ss, ss2) {
    66  			t.Fatal()
    67  		}
    68  		fmt.Fprintf(f, "pk = %X\n", ppk)
    69  		fmt.Fprintf(f, "sk = %X\n", psk)
    70  		fmt.Fprintf(f, "ct = %X\n", ct)
    71  		fmt.Fprintf(f, "ss = %X\n\n", ss)
    72  	}
    73  	if fmt.Sprintf("%x", f.Sum(nil)) != expected {
    74  		t.Fatal()
    75  	}
    76  }