github.com/cloudflare/circl@v1.5.0/kem/mlkem/acvp_test.go (about)

     1  package mlkem
     2  
     3  import (
     4  	"bytes"
     5  	"compress/gzip"
     6  	"encoding/hex"
     7  	"encoding/json"
     8  	"io"
     9  	"os"
    10  	"testing"
    11  
    12  	"github.com/cloudflare/circl/kem/schemes"
    13  )
    14  
    15  // []byte but is encoded in hex for JSON
    16  type HexBytes []byte
    17  
    18  func (b HexBytes) MarshalJSON() ([]byte, error) {
    19  	return json.Marshal(hex.EncodeToString(b))
    20  }
    21  
    22  func (b *HexBytes) UnmarshalJSON(data []byte) (err error) {
    23  	var s string
    24  	if err = json.Unmarshal(data, &s); err != nil {
    25  		return err
    26  	}
    27  	*b, err = hex.DecodeString(s)
    28  	return err
    29  }
    30  
    31  func gunzip(in []byte) ([]byte, error) {
    32  	buf := bytes.NewBuffer(in)
    33  	r, err := gzip.NewReader(buf)
    34  	if err != nil {
    35  		return nil, err
    36  	}
    37  	return io.ReadAll(r)
    38  }
    39  
    40  func readGzip(path string) ([]byte, error) {
    41  	buf, err := os.ReadFile(path)
    42  	if err != nil {
    43  		return nil, err
    44  	}
    45  	return gunzip(buf)
    46  }
    47  
    48  func TestACVP(t *testing.T) {
    49  	for _, sub := range []string{
    50  		"keyGen",
    51  		"encapDecap",
    52  	} {
    53  		t.Run(sub, func(t *testing.T) {
    54  			testACVP(t, sub)
    55  		})
    56  	}
    57  }
    58  
    59  // nolint:funlen,gocyclo
    60  func testACVP(t *testing.T, sub string) {
    61  	buf, err := readGzip("testdata/ML-KEM-" + sub + "-FIPS203/prompt.json.gz")
    62  	if err != nil {
    63  		t.Fatal(err)
    64  	}
    65  
    66  	var prompt struct {
    67  		TestGroups []json.RawMessage `json:"testGroups"`
    68  	}
    69  
    70  	if err = json.Unmarshal(buf, &prompt); err != nil {
    71  		t.Fatal(err)
    72  	}
    73  
    74  	buf, err = readGzip("testdata/ML-KEM-" + sub + "-FIPS203/expectedResults.json.gz")
    75  	if err != nil {
    76  		t.Fatal(err)
    77  	}
    78  
    79  	var results struct {
    80  		TestGroups []json.RawMessage `json:"testGroups"`
    81  	}
    82  
    83  	if err := json.Unmarshal(buf, &results); err != nil {
    84  		t.Fatal(err)
    85  	}
    86  
    87  	rawResults := make(map[int]json.RawMessage)
    88  
    89  	for _, rawGroup := range results.TestGroups {
    90  		var abstractGroup struct {
    91  			Tests []json.RawMessage `json:"tests"`
    92  		}
    93  		if err := json.Unmarshal(rawGroup, &abstractGroup); err != nil {
    94  			t.Fatal(err)
    95  		}
    96  		for _, rawTest := range abstractGroup.Tests {
    97  			var abstractTest struct {
    98  				TcID int `json:"tcId"`
    99  			}
   100  			if err := json.Unmarshal(rawTest, &abstractTest); err != nil {
   101  				t.Fatal(err)
   102  			}
   103  			if _, exists := rawResults[abstractTest.TcID]; exists {
   104  				t.Fatalf("Duplicate test id: %d", abstractTest.TcID)
   105  			}
   106  			rawResults[abstractTest.TcID] = rawTest
   107  		}
   108  	}
   109  
   110  	for _, rawGroup := range prompt.TestGroups {
   111  		var abstractGroup struct {
   112  			TestType string `json:"testType"`
   113  		}
   114  		if err := json.Unmarshal(rawGroup, &abstractGroup); err != nil {
   115  			t.Fatal(err)
   116  		}
   117  		switch {
   118  		case abstractGroup.TestType == "AFT" && sub == "keyGen":
   119  			var group struct {
   120  				TgID         int    `json:"tgId"`
   121  				ParameterSet string `json:"parameterSet"`
   122  				Tests        []struct {
   123  					TcID int      `json:"tcId"`
   124  					Z    HexBytes `json:"z"`
   125  					D    HexBytes `json:"d"`
   126  				}
   127  			}
   128  			if err := json.Unmarshal(rawGroup, &group); err != nil {
   129  				t.Fatal(err)
   130  			}
   131  
   132  			scheme := schemes.ByName(group.ParameterSet)
   133  			if scheme == nil {
   134  				t.Fatalf("No such scheme: %s", group.ParameterSet)
   135  			}
   136  
   137  			for _, test := range group.Tests {
   138  				var result struct {
   139  					Ek HexBytes `json:"ek"`
   140  					Dk HexBytes `json:"dk"`
   141  				}
   142  				rawResult, ok := rawResults[test.TcID]
   143  				if !ok {
   144  					t.Fatalf("Missing result: %d", test.TcID)
   145  				}
   146  				if err := json.Unmarshal(rawResult, &result); err != nil {
   147  					t.Fatal(err)
   148  				}
   149  
   150  				var seed [64]byte
   151  				copy(seed[:], test.D)
   152  				copy(seed[32:], test.Z)
   153  
   154  				ek, dk := scheme.DeriveKeyPair(seed[:])
   155  
   156  				ek2, err := scheme.UnmarshalBinaryPublicKey(result.Ek)
   157  				if err != nil {
   158  					t.Fatalf("tc=%d: %v", test.TcID, err)
   159  				}
   160  				dk2, err := scheme.UnmarshalBinaryPrivateKey(result.Dk)
   161  				if err != nil {
   162  					t.Fatal(err)
   163  				}
   164  
   165  				if !dk.Equal(dk2) {
   166  					t.Fatal("dk does not match")
   167  				}
   168  				if !ek.Equal(ek2) {
   169  					t.Fatal("ek does not match")
   170  				}
   171  			}
   172  		case abstractGroup.TestType == "AFT" && sub == "encapDecap":
   173  			var group struct {
   174  				TgID         int    `json:"tgId"`
   175  				ParameterSet string `json:"parameterSet"`
   176  				Tests        []struct {
   177  					TcID int      `json:"tcId"`
   178  					Ek   HexBytes `json:"ek"`
   179  					M    HexBytes `json:"m"`
   180  				}
   181  			}
   182  			if err := json.Unmarshal(rawGroup, &group); err != nil {
   183  				t.Fatal(err)
   184  			}
   185  
   186  			scheme := schemes.ByName(group.ParameterSet)
   187  			if scheme == nil {
   188  				t.Fatalf("No such scheme: %s", group.ParameterSet)
   189  			}
   190  
   191  			for _, test := range group.Tests {
   192  				var result struct {
   193  					C HexBytes `json:"c"`
   194  					K HexBytes `json:"k"`
   195  				}
   196  				rawResult, ok := rawResults[test.TcID]
   197  				if !ok {
   198  					t.Fatalf("Missing result: %d", test.TcID)
   199  				}
   200  				if err := json.Unmarshal(rawResult, &result); err != nil {
   201  					t.Fatal(err)
   202  				}
   203  
   204  				ek, err := scheme.UnmarshalBinaryPublicKey(test.Ek)
   205  				if err != nil {
   206  					t.Fatal(err)
   207  				}
   208  
   209  				ct, ss, err := scheme.EncapsulateDeterministically(ek, test.M)
   210  				if err != nil {
   211  					t.Fatal(err)
   212  				}
   213  
   214  				if !bytes.Equal(ct, result.C) {
   215  					t.Fatalf("ciphertext doesn't match: %x ≠ %x", ct, result.C)
   216  				}
   217  				if !bytes.Equal(ss, result.K) {
   218  					t.Fatalf("shared secret doesn't match: %x ≠ %x", ss, result.K)
   219  				}
   220  			}
   221  		case abstractGroup.TestType == "VAL" && sub == "encapDecap":
   222  			var group struct {
   223  				TgID         int      `json:"tgId"`
   224  				ParameterSet string   `json:"parameterSet"`
   225  				Dk           HexBytes `json:"dk"`
   226  				Tests        []struct {
   227  					TcID int      `json:"tcId"`
   228  					C    HexBytes `json:"c"`
   229  				}
   230  			}
   231  			if err := json.Unmarshal(rawGroup, &group); err != nil {
   232  				t.Fatal(err)
   233  			}
   234  
   235  			scheme := schemes.ByName(group.ParameterSet)
   236  			if scheme == nil {
   237  				t.Fatalf("No such scheme: %s", group.ParameterSet)
   238  			}
   239  
   240  			dk, err := scheme.UnmarshalBinaryPrivateKey(group.Dk)
   241  			if err != nil {
   242  				t.Fatal(err)
   243  			}
   244  
   245  			for _, test := range group.Tests {
   246  				var result struct {
   247  					K HexBytes `json:"k"`
   248  				}
   249  				rawResult, ok := rawResults[test.TcID]
   250  				if !ok {
   251  					t.Fatalf("Missing rawResult: %d", test.TcID)
   252  				}
   253  				if err := json.Unmarshal(rawResult, &result); err != nil {
   254  					t.Fatal(err)
   255  				}
   256  
   257  				ss, err := scheme.Decapsulate(dk, test.C)
   258  				if err != nil {
   259  					t.Fatal(err)
   260  				}
   261  
   262  				if !bytes.Equal(ss, result.K) {
   263  					t.Fatalf("shared secret doesn't match: %x ≠ %x", ss, result.K)
   264  				}
   265  			}
   266  		default:
   267  			t.Fatalf("unknown type %s for %s", abstractGroup.TestType, sub)
   268  		}
   269  	}
   270  }