github.com/cloudflare/circl@v1.5.0/sign/dilithium/templates/acvp.templ.go (about)

     1  // +build ignore
     2  // The previous line (and this one up to the warning below) is removed by the
     3  // template generator.
     4  
     5  // Code generated from acvp.templ.go. DO NOT EDIT.
     6  
     7  package {{.Pkg}}
     8  
     9  import (
    10  	"bytes"
    11  	"compress/gzip"
    12  	"encoding/hex"
    13  	"encoding/json"
    14  	"io"
    15  	"os"
    16  	"testing"
    17  )
    18  
    19  // []byte but is encoded in hex for JSON
    20  type HexBytes []byte
    21  
    22  func (b HexBytes) MarshalJSON() ([]byte, error) {
    23  	return json.Marshal(hex.EncodeToString(b))
    24  }
    25  
    26  func (b *HexBytes) UnmarshalJSON(data []byte) (err error) {
    27  	var s string
    28  	if err = json.Unmarshal(data, &s); err != nil {
    29  		return err
    30  	}
    31  	*b, err = hex.DecodeString(s)
    32  	return err
    33  }
    34  
    35  func gunzip(in []byte) ([]byte, error) {
    36  	buf := bytes.NewBuffer(in)
    37  	r, err := gzip.NewReader(buf)
    38  	if err != nil {
    39  		return nil, err
    40  	}
    41  	return io.ReadAll(r)
    42  }
    43  
    44  func readGzip(path string) ([]byte, error) {
    45  	buf, err := os.ReadFile(path)
    46  	if err != nil {
    47  		return nil, err
    48  	}
    49  	return gunzip(buf)
    50  }
    51  
    52  func TestACVP(t *testing.T) {
    53  	for _, sub := range []string{
    54  		"keyGen",
    55  		"sigGen",
    56  		"sigVer",
    57  	} {
    58  		t.Run(sub, func(t *testing.T) {
    59  			testACVP(t, sub)
    60  		})
    61  	}
    62  }
    63  
    64  // nolint:funlen,gocyclo
    65  func testACVP(t *testing.T, sub string) {
    66  	buf, err := readGzip("../testdata/ML-DSA-" + sub + "-FIPS204/prompt.json.gz")
    67  	if err != nil {
    68  		t.Fatal(err)
    69  	}
    70  
    71  	var prompt struct {
    72  		TestGroups []json.RawMessage `json:"testGroups"`
    73  	}
    74  
    75  	if err = json.Unmarshal(buf, &prompt); err != nil {
    76  		t.Fatal(err)
    77  	}
    78  
    79  	buf, err = readGzip("../testdata/ML-DSA-" + sub + "-FIPS204/expectedResults.json.gz")
    80  	if err != nil {
    81  		t.Fatal(err)
    82  	}
    83  
    84  	var results struct {
    85  		TestGroups []json.RawMessage `json:"testGroups"`
    86  	}
    87  
    88  	if err := json.Unmarshal(buf, &results); err != nil {
    89  		t.Fatal(err)
    90  	}
    91  
    92  	rawResults := make(map[int]json.RawMessage)
    93  
    94  	for _, rawGroup := range results.TestGroups {
    95  		var abstractGroup struct {
    96  			Tests []json.RawMessage `json:"tests"`
    97  		}
    98  		if err := json.Unmarshal(rawGroup, &abstractGroup); err != nil {
    99  			t.Fatal(err)
   100  		}
   101  		for _, rawTest := range abstractGroup.Tests {
   102  			var abstractTest struct {
   103  				TcID int `json:"tcId"`
   104  			}
   105  			if err := json.Unmarshal(rawTest, &abstractTest); err != nil {
   106  				t.Fatal(err)
   107  			}
   108  			if _, exists := rawResults[abstractTest.TcID]; exists {
   109  				t.Fatalf("Duplicate test id: %d", abstractTest.TcID)
   110  			}
   111  			rawResults[abstractTest.TcID] = rawTest
   112  		}
   113  	}
   114  
   115  	scheme := Scheme()
   116  
   117  	for _, rawGroup := range prompt.TestGroups {
   118  		var abstractGroup struct {
   119  			TestType string `json:"testType"`
   120  		}
   121  		if err := json.Unmarshal(rawGroup, &abstractGroup); err != nil {
   122  			t.Fatal(err)
   123  		}
   124  		switch {
   125  		case abstractGroup.TestType == "AFT" && sub == "keyGen":
   126  			var group struct {
   127  				TgID         int    `json:"tgId"`
   128  				ParameterSet string `json:"parameterSet"`
   129  				Tests        []struct {
   130  					TcID int      `json:"tcId"`
   131  					Seed HexBytes `json:"seed"`
   132  				}
   133  			}
   134  			if err := json.Unmarshal(rawGroup, &group); err != nil {
   135  				t.Fatal(err)
   136  			}
   137  
   138  			if group.ParameterSet != scheme.Name() {
   139  				continue
   140  			}
   141  
   142  			for _, test := range group.Tests {
   143  				var result struct {
   144  					Pk HexBytes `json:"pk"`
   145  					Sk HexBytes `json:"sk"`
   146  				}
   147  				rawResult, ok := rawResults[test.TcID]
   148  				if !ok {
   149  					t.Fatalf("Missing result: %d", test.TcID)
   150  				}
   151  				if err := json.Unmarshal(rawResult, &result); err != nil {
   152  					t.Fatal(err)
   153  				}
   154  
   155  				pk, sk := scheme.DeriveKey(test.Seed)
   156  
   157  				pk2, err := scheme.UnmarshalBinaryPublicKey(result.Pk)
   158  				if err != nil {
   159  					t.Fatalf("tc=%d: %v", test.TcID, err)
   160  				}
   161  				sk2, err := scheme.UnmarshalBinaryPrivateKey(result.Sk)
   162  				if err != nil {
   163  					t.Fatal(err)
   164  				}
   165  
   166  				if !pk.Equal(pk2) {
   167  					t.Fatal("pk does not match")
   168  				}
   169  				if !sk.Equal(sk2) {
   170  					t.Fatal("sk does not match")
   171  				}
   172  			}
   173  		case abstractGroup.TestType == "AFT" && sub == "sigGen":
   174  			var group struct {
   175  				TgID          int    `json:"tgId"`
   176  				ParameterSet  string `json:"parameterSet"`
   177  				Deterministic bool   `json:"deterministic"`
   178  				Tests         []struct {
   179  					TcID    int      `json:"tcId"`
   180  					Sk      HexBytes `json:"sk"`
   181  					Message HexBytes `json:"message"`
   182  					Rnd     HexBytes `json:"rnd"`
   183  				}
   184  			}
   185  			if err := json.Unmarshal(rawGroup, &group); err != nil {
   186  				t.Fatal(err)
   187  			}
   188  
   189  			if group.ParameterSet != scheme.Name() {
   190  				continue
   191  			}
   192  
   193  			for _, test := range group.Tests {
   194  				var result struct {
   195  					Signature HexBytes `json:"signature"`
   196  				}
   197  				rawResult, ok := rawResults[test.TcID]
   198  				if !ok {
   199  					t.Fatalf("Missing result: %d", test.TcID)
   200  				}
   201  				if err := json.Unmarshal(rawResult, &result); err != nil {
   202  					t.Fatal(err)
   203  				}
   204  
   205  				sk, err := scheme.UnmarshalBinaryPrivateKey(test.Sk)
   206  				if err != nil {
   207  					t.Fatal(err)
   208  				}
   209  
   210  				var rnd [32]byte
   211  				if !group.Deterministic {
   212  					copy(rnd[:], test.Rnd)
   213  				}
   214  
   215  				sig2 := sk.(*PrivateKey).unsafeSignInternal(test.Message, rnd)
   216  
   217  				if !bytes.Equal(sig2, result.Signature) {
   218  					t.Fatalf("signature doesn't match: %x ≠ %x",
   219  						sig2, result.Signature)
   220  				}
   221  			}
   222  		case abstractGroup.TestType == "AFT" && sub == "sigVer":
   223  			var group struct {
   224  				TgID         int      `json:"tgId"`
   225  				ParameterSet string   `json:"parameterSet"`
   226  				Pk           HexBytes `json:"pk"`
   227  				Tests        []struct {
   228  					TcID      int      `json:"tcId"`
   229  					Message   HexBytes `json:"message"`
   230  					Signature HexBytes `json:"signature"`
   231  				}
   232  			}
   233  			if err := json.Unmarshal(rawGroup, &group); err != nil {
   234  				t.Fatal(err)
   235  			}
   236  
   237  			if group.ParameterSet != scheme.Name() {
   238  				continue
   239  			}
   240  
   241  			pk, err := scheme.UnmarshalBinaryPublicKey(group.Pk)
   242  			if err != nil {
   243  				t.Fatal(err)
   244  			}
   245  
   246  			for _, test := range group.Tests {
   247  				var result struct {
   248  					TestPassed bool `json:"testPassed"`
   249  				}
   250  				rawResult, ok := rawResults[test.TcID]
   251  				if !ok {
   252  					t.Fatalf("Missing result: %d", test.TcID)
   253  				}
   254  				if err := json.Unmarshal(rawResult, &result); err != nil {
   255  					t.Fatal(err)
   256  				}
   257  
   258  				passed2 := unsafeVerifyInternal(pk.(*PublicKey), test.Message, test.Signature)
   259  				if passed2 != result.TestPassed {
   260  					t.Fatalf("verification %v ≠ %v", passed2, result.TestPassed)
   261  				}
   262  			}
   263  		default:
   264  			t.Fatalf("unknown type %s for %s", abstractGroup.TestType, sub)
   265  		}
   266  	}
   267  }