github.com/cloudflare/circl@v1.5.0/blindsign/blindrsa/brsa_test.go (about)

     1  package blindrsa
     2  
     3  import (
     4  	"bytes"
     5  	"crypto/rand"
     6  	"crypto/rsa"
     7  	"crypto/x509"
     8  	"encoding/hex"
     9  	"encoding/json"
    10  	"encoding/pem"
    11  	"fmt"
    12  	"io"
    13  	"math/big"
    14  	"os"
    15  	"strings"
    16  	"testing"
    17  
    18  	"github.com/cloudflare/circl/internal/test"
    19  )
    20  
    21  func loadPrivateKey() (*rsa.PrivateKey, error) {
    22  	file, err := os.ReadFile("./testdata/testRSA2048.rfc9500.pem")
    23  	if err != nil {
    24  		return nil, err
    25  	}
    26  
    27  	block, _ := pem.Decode(file)
    28  	if block == nil || block.Type != "RSA TESTING KEY" {
    29  		return nil, fmt.Errorf("PEM private key decoding failed")
    30  	}
    31  
    32  	privateKey, err := x509.ParsePKCS1PrivateKey(block.Bytes)
    33  	if err != nil {
    34  		return nil, err
    35  	}
    36  
    37  	return privateKey, nil
    38  }
    39  
    40  func mustDecodeHex(h string) []byte {
    41  	b, err := hex.DecodeString(h)
    42  	if err != nil {
    43  		panic(err)
    44  	}
    45  	return b
    46  }
    47  
    48  func loadStrongRSAKey() *rsa.PrivateKey {
    49  	// https://gist.github.com/chris-wood/b77536febb25a5a11af428afff77820a
    50  	pEnc := "dcd90af1be463632c0d5ea555256a20605af3db667475e190e3af12a34a3324c46a3094062c59fb4b249e0ee6afba8bee14e0276d126c99f4784b23009bf6168ff628ac1486e5ae8e23ce4d362889de4df63109cbd90ef93db5ae64372bfe1c55f832766f21e94ea3322eb2182f10a891546536ba907ad74b8d72469bea396f3"
    51  	qEnc := "f8ba5c89bd068f57234a3cf54a1c89d5b4cd0194f2633ca7c60b91a795a56fa8c8686c0e37b1c4498b851e3420d08bea29f71d195cfbd3671c6ddc49cf4c1db5b478231ea9d91377ffa98fe95685fca20ba4623212b2f2def4da5b281ed0100b651f6db32112e4017d831c0da668768afa7141d45bbc279f1e0f8735d74395b3"
    52  	NEnc := "d6930820f71fe517bf3259d14d40209b02a5c0d3d61991c731dd7da39f8d69821552e2318d6c9ad897e603887a476ea3162c1205da9ac96f02edf31df049bd55f142134c17d4382a0e78e275345f165fbe8e49cdca6cf5c726c599dd39e09e75e0f330a33121e73976e4facba9cfa001c28b7c96f8134f9981db6750b43a41710f51da4240fe03106c12acb1e7bb53d75ec7256da3fddd0718b89c365410fce61bc7c99b115fb4c3c318081fa7e1b65a37774e8e50c96e8ce2b2cc6b3b367982366a2bf9924c4bafdb3ff5e722258ab705c76d43e5f1f121b984814e98ea2b2b8725cd9bc905c0bc3d75c2a8db70a7153213c39ae371b2b5dc1dafcb19d6fae9"
    53  	eEnc := "010001"
    54  	dEnc := "4e21356983722aa1adedb084a483401c1127b781aac89eab103e1cfc52215494981d18dd8028566d9d499469c25476358de23821c78a6ae43005e26b394e3051b5ca206aa9968d68cae23b5affd9cbb4cb16d64ac7754b3cdba241b72ad6ddfc000facdb0f0dd03abd4efcfee1730748fcc47b7621182ef8af2eeb7c985349f62ce96ab373d2689baeaea0e28ea7d45f2d605451920ca4ea1f0c08b0f1f6711eaa4b7cca66d58a6b916f9985480f90aca97210685ac7b12d2ec3e30a1c7b97b65a18d38a93189258aa346bf2bc572cd7e7359605c20221b8909d599ed9d38164c9c4abf396f897b9993c1e805e574d704649985b600fa0ced8e5427071d7049d"
    55  
    56  	p := new(big.Int).SetBytes(mustDecodeHex(pEnc))
    57  	q := new(big.Int).SetBytes(mustDecodeHex(qEnc))
    58  	N := new(big.Int).SetBytes(mustDecodeHex(NEnc))
    59  	e := new(big.Int).SetBytes(mustDecodeHex(eEnc))
    60  	d := new(big.Int).SetBytes(mustDecodeHex(dEnc))
    61  
    62  	primes := make([]*big.Int, 2)
    63  	primes[0] = p
    64  	primes[1] = q
    65  
    66  	key := &rsa.PrivateKey{
    67  		PublicKey: rsa.PublicKey{
    68  			N: N,
    69  			E: int(e.Int64()),
    70  		},
    71  		D:      d,
    72  		Primes: primes,
    73  	}
    74  
    75  	return key
    76  }
    77  
    78  func runSignatureProtocol(signer Signer, client Client, message []byte, random io.Reader) ([]byte, error) {
    79  	inputMsg, err := client.Prepare(random, message)
    80  	if err != nil {
    81  		return nil, fmt.Errorf("prepare failed: %w", err)
    82  	}
    83  
    84  	blindedMsg, state, err := client.Blind(random, inputMsg)
    85  	if err != nil {
    86  		return nil, fmt.Errorf("blind failed: %w", err)
    87  	}
    88  
    89  	kLen := (signer.sk.N.BitLen() + 7) / 8
    90  	if len(blindedMsg) != kLen {
    91  		return nil, fmt.Errorf("Protocol message (blind message) length mismatch, expected %d, got %d", kLen, len(blindedMsg))
    92  	}
    93  
    94  	blindedSig, err := signer.BlindSign(blindedMsg)
    95  	if err != nil {
    96  		return nil, fmt.Errorf("blindSign failed: %w", err)
    97  	}
    98  
    99  	if len(blindedSig) != kLen {
   100  		return nil, fmt.Errorf("Protocol message (blind signature) length mismatch, expected %d, got %d", kLen, len(blindedMsg))
   101  	}
   102  
   103  	sig, err := client.Finalize(state, blindedSig)
   104  	if err != nil {
   105  		return nil, fmt.Errorf("finalize failed: %w", err)
   106  	}
   107  
   108  	err = client.Verify(inputMsg, sig)
   109  	if err != nil {
   110  		return nil, fmt.Errorf("verification failed: %w", err)
   111  	}
   112  
   113  	return sig, nil
   114  }
   115  
   116  func TestRoundTrip(t *testing.T) {
   117  	message := []byte("hello world")
   118  	key, err := loadPrivateKey()
   119  	if err != nil {
   120  		t.Fatal(err)
   121  	}
   122  
   123  	for _, variant := range []Variant{
   124  		SHA384PSSDeterministic,
   125  		SHA384PSSZeroDeterministic,
   126  		SHA384PSSRandomized,
   127  		SHA384PSSZeroRandomized,
   128  	} {
   129  		t.Run(variant.String(), func(tt *testing.T) {
   130  			client, err := NewClient(variant, &key.PublicKey)
   131  			if err != nil {
   132  				t.Fatal(err)
   133  			}
   134  			signer := NewSigner(key)
   135  
   136  			sig, err := runSignatureProtocol(signer, client, message, rand.Reader)
   137  			if err != nil {
   138  				t.Fatal(err)
   139  			}
   140  			if sig == nil {
   141  				t.Fatal("nil signature output")
   142  			}
   143  		})
   144  	}
   145  }
   146  
   147  func TestDeterministicRoundTrip(t *testing.T) {
   148  	message := []byte("hello world")
   149  	key, err := loadPrivateKey()
   150  	if err != nil {
   151  		t.Fatal(err)
   152  	}
   153  
   154  	client, err := NewClient(SHA384PSSDeterministic, &key.PublicKey)
   155  	if err != nil {
   156  		t.Fatal(err)
   157  	}
   158  	signer := NewSigner(key)
   159  
   160  	sig, err := runSignatureProtocol(signer, client, message, rand.Reader)
   161  	if err != nil {
   162  		t.Fatal(err)
   163  	}
   164  	if sig == nil {
   165  		t.Fatal("nil signature output")
   166  	}
   167  }
   168  
   169  func TestDeterministicBlindFailure(t *testing.T) {
   170  	message := []byte("hello world")
   171  	key, err := loadPrivateKey()
   172  	if err != nil {
   173  		t.Fatal(err)
   174  	}
   175  
   176  	client, err := NewClient(SHA384PSSDeterministic, &key.PublicKey)
   177  	if err != nil {
   178  		t.Fatal(err)
   179  	}
   180  	signer := NewSigner(key)
   181  
   182  	_, err = runSignatureProtocol(signer, client, message, nil)
   183  	if err == nil {
   184  		t.Fatal("Expected signature generation to fail with empty randomness")
   185  	}
   186  }
   187  
   188  func TestRandomSignVerify(t *testing.T) {
   189  	message := []byte("hello world")
   190  	key, err := loadPrivateKey()
   191  	if err != nil {
   192  		t.Fatal(err)
   193  	}
   194  
   195  	client, err := NewClient(SHA384PSSRandomized, &key.PublicKey)
   196  	if err != nil {
   197  		t.Fatal(err)
   198  	}
   199  	signer := NewSigner(key)
   200  
   201  	sig1, err := runSignatureProtocol(signer, client, message, rand.Reader)
   202  	if err != nil {
   203  		t.Fatal(err)
   204  	}
   205  	sig2, err := runSignatureProtocol(signer, client, message, rand.Reader)
   206  	if err != nil {
   207  		t.Fatal(err)
   208  	}
   209  
   210  	if sig1 == nil || sig2 == nil {
   211  		t.Fatal("nil signature output")
   212  	}
   213  	if bytes.Equal(sig1, sig2) {
   214  		t.Fatal("random signatures matched when they should differ")
   215  	}
   216  }
   217  
   218  type mockRandom struct {
   219  	counter uint8
   220  }
   221  
   222  func (r *mockRandom) Read(p []byte) (n int, err error) {
   223  	for i := range p {
   224  		p[i] = r.counter
   225  		r.counter = r.counter + 1
   226  	}
   227  	return len(p), nil
   228  }
   229  
   230  func TestFixedRandomSignVerify(t *testing.T) {
   231  	message := []byte("hello world")
   232  	key, err := loadPrivateKey()
   233  	if err != nil {
   234  		t.Fatal(err)
   235  	}
   236  
   237  	client, err := NewClient(SHA384PSSRandomized, &key.PublicKey)
   238  	if err != nil {
   239  		t.Fatal(err)
   240  	}
   241  	signer := NewSigner(key)
   242  
   243  	mockRand := &mockRandom{0}
   244  	sig1, err := runSignatureProtocol(signer, client, message, mockRand)
   245  	if err != nil {
   246  		t.Fatal(err)
   247  	}
   248  	mockRand = &mockRandom{0}
   249  	sig2, err := runSignatureProtocol(signer, client, message, mockRand)
   250  	if err != nil {
   251  		t.Fatal(err)
   252  	}
   253  
   254  	if sig1 == nil || sig2 == nil {
   255  		t.Fatal("nil signature output")
   256  	}
   257  	if !bytes.Equal(sig1, sig2) {
   258  		t.Fatal("random signatures with fixed random seeds differ when they should be equal")
   259  	}
   260  }
   261  
// rawTestVector is the JSON wire form of a single test vector. Every field is
// read as a string; testVector.UnmarshalJSON hex-decodes all of them except
// Name, which is copied as-is.
type rawTestVector struct {
	Name           string `json:"name"`
	P              string `json:"p"`
	Q              string `json:"q"`
	N              string `json:"n"`
	E              string `json:"e"`
	D              string `json:"d"`
	Msg            string `json:"msg"`
	MsgPrefix      string `json:"msg_prefix"`
	InputMsg       string `json:"input_msg"`
	Salt           string `json:"salt"`
	SaltLen        string `json:"sLen"`
	IsRandomized   string `json:"is_randomized"`
	Inv            string `json:"inv"`
	BlindedMessage string `json:"blinded_msg"`
	BlindSig       string `json:"blind_sig"`
	Sig            string `json:"sig"`
}
   280  
// testVector is the decoded form of a rawTestVector: hex strings are turned
// into big integers, ints, byte slices or a bool by UnmarshalJSON.
type testVector struct {
	// t is filled in by testVectorList.UnmarshalJSON from the list's own t.
	t              *testing.T
	name           string
	p              *big.Int
	q              *big.Int
	n              *big.Int
	e              int
	d              *big.Int
	msg            []byte
	msgPrefix      []byte
	inputMsg       []byte
	salt           []byte
	saltLen        int
	isRandomized   bool
	// blindInverse is decoded from the "inv" JSON field.
	blindInverse   *big.Int
	blindedMessage []byte
	blindSig       []byte
	sig            []byte
}
   300  
// testVectorList is a collection of test vectors that is (un)marshalled as a
// plain JSON array. Its t is copied into every vector when the list is
// unmarshalled.
type testVectorList struct {
	t       *testing.T
	vectors []testVector
}
   305  
   306  func mustUnhexBigInt(number string) *big.Int {
   307  	data := mustUnhex(number)
   308  	value := new(big.Int)
   309  	value.SetBytes(data)
   310  	return value
   311  }
   312  
   313  func mustUnhex(value string) []byte {
   314  	value = strings.TrimPrefix(value, "0x")
   315  	data, err := hex.DecodeString(value)
   316  	if err != nil {
   317  		panic(err)
   318  	}
   319  
   320  	return data
   321  }
   322  
   323  func mustUnhexInt(value string) int {
   324  	number := mustUnhexBigInt(value)
   325  	result := int(number.Int64())
   326  	return result
   327  }
   328  
   329  func (tv *testVector) UnmarshalJSON(data []byte) error {
   330  	raw := rawTestVector{}
   331  	err := json.Unmarshal(data, &raw)
   332  	if err != nil {
   333  		return err
   334  	}
   335  
   336  	tv.name = raw.Name
   337  	tv.p = mustUnhexBigInt(raw.P)
   338  	tv.q = mustUnhexBigInt(raw.Q)
   339  	tv.n = mustUnhexBigInt(raw.N)
   340  	tv.e = mustUnhexInt(raw.E)
   341  	tv.d = mustUnhexBigInt(raw.D)
   342  	tv.msg = mustUnhex(raw.Msg)
   343  	tv.msgPrefix = mustUnhex(raw.MsgPrefix)
   344  	tv.inputMsg = mustUnhex(raw.InputMsg)
   345  	tv.salt = mustUnhex(raw.Salt)
   346  	tv.saltLen = mustUnhexInt(raw.SaltLen)
   347  	tv.isRandomized = mustUnhexInt(raw.IsRandomized) != 0
   348  	tv.blindedMessage = mustUnhex(raw.BlindedMessage)
   349  	tv.blindInverse = mustUnhexBigInt(raw.Inv)
   350  	tv.blindSig = mustUnhex(raw.BlindSig)
   351  	tv.sig = mustUnhex(raw.Sig)
   352  
   353  	return nil
   354  }
   355  
   356  func (tvl testVectorList) MarshalJSON() ([]byte, error) {
   357  	return json.Marshal(tvl.vectors)
   358  }
   359  
   360  func (tvl *testVectorList) UnmarshalJSON(data []byte) error {
   361  	err := json.Unmarshal(data, &tvl.vectors)
   362  	if err != nil {
   363  		return err
   364  	}
   365  
   366  	for i := range tvl.vectors {
   367  		tvl.vectors[i].t = tvl.t
   368  	}
   369  
   370  	return nil
   371  }
   372  
   373  func verifyTestVector(t *testing.T, vector testVector) {
   374  	key := new(rsa.PrivateKey)
   375  	key.PublicKey.N = vector.n
   376  	key.PublicKey.E = vector.e
   377  	key.D = vector.d
   378  	key.Primes = []*big.Int{vector.p, vector.q}
   379  	key.Precomputed.Dp = nil // Remove precomputed CRT values
   380  
   381  	// Recompute the original blind
   382  	rInv := new(big.Int).Set(vector.blindInverse)
   383  	r := new(big.Int).ModInverse(rInv, key.N)
   384  	if r == nil {
   385  		t.Fatal("Failed to compute blind inverse")
   386  	}
   387  
   388  	var variant Variant
   389  	switch vector.name {
   390  	case "RSABSSA-SHA384-PSS-Deterministic":
   391  		variant = SHA384PSSDeterministic
   392  	case "RSABSSA-SHA384-PSSZERO-Deterministic":
   393  		variant = SHA384PSSZeroDeterministic
   394  	case "RSABSSA-SHA384-PSS-Randomized":
   395  		variant = SHA384PSSRandomized
   396  	case "RSABSSA-SHA384-PSSZERO-Randomized":
   397  		variant = SHA384PSSZeroRandomized
   398  	default:
   399  		t.Fatal("variant not supported")
   400  	}
   401  
   402  	signer := NewSigner(key)
   403  
   404  	client, err := NewClient(variant, &key.PublicKey)
   405  	test.CheckNoErr(t, err, "new client failed")
   406  
   407  	blindedMsg, state, err := client.fixedBlind(vector.inputMsg, vector.salt, r, rInv)
   408  	test.CheckNoErr(t, err, "fixedBlind failed")
   409  	got := hex.EncodeToString(blindedMsg)
   410  	want := hex.EncodeToString(vector.blindedMessage)
   411  	if got != want {
   412  		test.ReportError(t, got, want)
   413  	}
   414  
   415  	blindSig, err := signer.BlindSign(blindedMsg)
   416  	test.CheckNoErr(t, err, "blindSign failed")
   417  	got = hex.EncodeToString(blindSig)
   418  	want = hex.EncodeToString(vector.blindSig)
   419  	if got != want {
   420  		test.ReportError(t, got, want)
   421  	}
   422  
   423  	sig, err := client.Finalize(state, blindSig)
   424  	test.CheckNoErr(t, err, "finalize failed")
   425  	got = hex.EncodeToString(sig)
   426  	want = hex.EncodeToString(vector.sig)
   427  	if got != want {
   428  		test.ReportError(t, got, want)
   429  	}
   430  
   431  	verifier, err := NewVerifier(variant, &key.PublicKey)
   432  	test.CheckNoErr(t, err, "new verifier failed")
   433  
   434  	test.CheckNoErr(t, verifier.Verify(vector.inputMsg, sig), "verification failed")
   435  }
   436  
   437  func TestVectors(t *testing.T) {
   438  	data, err := os.ReadFile("testdata/test_vectors_rfc9474.json")
   439  	if err != nil {
   440  		t.Fatal("Failed reading test vectors:", err)
   441  	}
   442  
   443  	tvl := &testVectorList{}
   444  	err = tvl.UnmarshalJSON(data)
   445  	if err != nil {
   446  		t.Fatal("Failed deserializing test vectors:", err)
   447  	}
   448  
   449  	for _, vector := range tvl.vectors {
   450  		t.Run(vector.name, func(tt *testing.T) {
   451  			verifyTestVector(tt, vector)
   452  		})
   453  	}
   454  }
   455  
   456  func BenchmarkBRSA(b *testing.B) {
   457  	message := []byte("hello world")
   458  	key := loadStrongRSAKey()
   459  	server := NewSigner(key)
   460  
   461  	client, err := NewClient(SHA384PSSRandomized, &key.PublicKey)
   462  	if err != nil {
   463  		b.Fatal(err)
   464  	}
   465  
   466  	inputMsg, err := client.Prepare(rand.Reader, message)
   467  	if err != nil {
   468  		b.Errorf("prepare failed: %v", err)
   469  	}
   470  
   471  	blindedMsg, state, err := client.Blind(rand.Reader, inputMsg)
   472  	if err != nil {
   473  		b.Errorf("blind failed: %v", err)
   474  	}
   475  
   476  	blindedSig, err := server.BlindSign(blindedMsg)
   477  	if err != nil {
   478  		b.Errorf("blindSign failed: %v", err)
   479  	}
   480  
   481  	sig, err := client.Finalize(state, blindedSig)
   482  	if err != nil {
   483  		b.Errorf("finalize failed: %v", err)
   484  	}
   485  
   486  	err = client.Verify(inputMsg, sig)
   487  	if err != nil {
   488  		b.Errorf("verification failed: %v", err)
   489  	}
   490  
   491  	b.Run("Blind", func(b *testing.B) {
   492  		for n := 0; n < b.N; n++ {
   493  			_, _, err := client.Blind(rand.Reader, inputMsg)
   494  			if err != nil {
   495  				b.Fatal(err)
   496  			}
   497  		}
   498  	})
   499  
   500  	b.Run("BlindSign", func(b *testing.B) {
   501  		for n := 0; n < b.N; n++ {
   502  			_, err := server.BlindSign(blindedMsg)
   503  			if err != nil {
   504  				b.Fatal(err)
   505  			}
   506  		}
   507  	})
   508  
   509  	b.Run("Finalize", func(b *testing.B) {
   510  		for n := 0; n < b.N; n++ {
   511  			_, err := client.Finalize(state, blindedSig)
   512  			if err != nil {
   513  				b.Fatal(err)
   514  			}
   515  		}
   516  	})
   517  
   518  	b.Run("Verify", func(b *testing.B) {
   519  		for n := 0; n < b.N; n++ {
   520  			err := client.Verify(inputMsg, sig)
   521  			if err != nil {
   522  				b.Fatal(err)
   523  			}
   524  		}
   525  	})
   526  }
   527  
   528  func Example_blindrsa() {
   529  	// Setup (offline)
   530  
   531  	// Server: generate an RSA keypair.
   532  	sk, err := rsa.GenerateKey(rand.Reader, 2048)
   533  	if err != nil {
   534  		fmt.Printf("failed to generate RSA key: %v", err)
   535  		return
   536  	}
   537  	pk := &sk.PublicKey
   538  	server := NewSigner(sk)
   539  
   540  	// Client: stores Server's public key.
   541  	client, err := NewClient(SHA384PSSRandomized, pk)
   542  	if err != nil {
   543  		fmt.Printf("failed to invoke a client: %v", err)
   544  		return
   545  	}
   546  
   547  	// Protocol (online)
   548  
   549  	// Client prepares the message to be signed.
   550  	msg := []byte("alice and bob")
   551  	preparedMessage, err := client.Prepare(rand.Reader, msg)
   552  	if err != nil {
   553  		fmt.Printf("client failed to prepare the message: %v", err)
   554  		return
   555  	}
   556  
   557  	// Client blinds a message.
   558  	blindedMsg, state, err := client.Blind(rand.Reader, preparedMessage)
   559  	if err != nil {
   560  		fmt.Printf("client failed to generate blinded message: %v", err)
   561  		return
   562  	}
   563  
   564  	// Server signs a blinded message, and produces a blinded signature.
   565  	blindedSignature, err := server.BlindSign(blindedMsg)
   566  	if err != nil {
   567  		fmt.Printf("server failed to sign: %v", err)
   568  		return
   569  	}
   570  
   571  	// Client build a signature from the previous state and blinded signature.
   572  	signature, err := client.Finalize(state, blindedSignature)
   573  	if err != nil {
   574  		fmt.Printf("client failed to obtain signature: %v", err)
   575  		return
   576  	}
   577  
   578  	// Client build a signature from the previous state and blinded signature.
   579  	ok := client.Verify(preparedMessage, signature)
   580  
   581  	fmt.Printf("Valid signature: %v", ok == nil)
   582  	// Output: Valid signature: true
   583  }