github.com/cloudflare/circl@v1.5.0/blindsign/blindrsa/partiallyblindrsa/pbrsa_test.go (about)

     1  package partiallyblindrsa
     2  
     3  import (
     4  	"bytes"
     5  	"crypto"
     6  	"crypto/rand"
     7  	"crypto/rsa"
     8  	"encoding/hex"
     9  	"encoding/json"
    10  	"fmt"
    11  	"io"
    12  	"math/big"
    13  	"os"
    14  	"testing"
    15  
    16  	"github.com/cloudflare/circl/blindsign/blindrsa/internal/keys"
    17  )
    18  
    19  const (
    20  	pbrsaTestVectorOutEnvironmentKey = "PBRSA_TEST_VECTORS_OUT"
    21  	pbrsaTestVectorInEnvironmentKey  = "PBRSA_TEST_VECTORS_IN"
    22  )
    23  
    24  func loadStrongRSAKey() *rsa.PrivateKey {
    25  	// https://gist.github.com/chris-wood/b77536febb25a5a11af428afff77820a
    26  	pEnc := "dcd90af1be463632c0d5ea555256a20605af3db667475e190e3af12a34a3324c46a3094062c59fb4b249e0ee6afba8bee14e0276d126c99f4784b23009bf6168ff628ac1486e5ae8e23ce4d362889de4df63109cbd90ef93db5ae64372bfe1c55f832766f21e94ea3322eb2182f10a891546536ba907ad74b8d72469bea396f3"
    27  	qEnc := "f8ba5c89bd068f57234a3cf54a1c89d5b4cd0194f2633ca7c60b91a795a56fa8c8686c0e37b1c4498b851e3420d08bea29f71d195cfbd3671c6ddc49cf4c1db5b478231ea9d91377ffa98fe95685fca20ba4623212b2f2def4da5b281ed0100b651f6db32112e4017d831c0da668768afa7141d45bbc279f1e0f8735d74395b3"
    28  	NEnc := "d6930820f71fe517bf3259d14d40209b02a5c0d3d61991c731dd7da39f8d69821552e2318d6c9ad897e603887a476ea3162c1205da9ac96f02edf31df049bd55f142134c17d4382a0e78e275345f165fbe8e49cdca6cf5c726c599dd39e09e75e0f330a33121e73976e4facba9cfa001c28b7c96f8134f9981db6750b43a41710f51da4240fe03106c12acb1e7bb53d75ec7256da3fddd0718b89c365410fce61bc7c99b115fb4c3c318081fa7e1b65a37774e8e50c96e8ce2b2cc6b3b367982366a2bf9924c4bafdb3ff5e722258ab705c76d43e5f1f121b984814e98ea2b2b8725cd9bc905c0bc3d75c2a8db70a7153213c39ae371b2b5dc1dafcb19d6fae9"
    29  	eEnc := "010001"
    30  	dEnc := "4e21356983722aa1adedb084a483401c1127b781aac89eab103e1cfc52215494981d18dd8028566d9d499469c25476358de23821c78a6ae43005e26b394e3051b5ca206aa9968d68cae23b5affd9cbb4cb16d64ac7754b3cdba241b72ad6ddfc000facdb0f0dd03abd4efcfee1730748fcc47b7621182ef8af2eeb7c985349f62ce96ab373d2689baeaea0e28ea7d45f2d605451920ca4ea1f0c08b0f1f6711eaa4b7cca66d58a6b916f9985480f90aca97210685ac7b12d2ec3e30a1c7b97b65a18d38a93189258aa346bf2bc572cd7e7359605c20221b8909d599ed9d38164c9c4abf396f897b9993c1e805e574d704649985b600fa0ced8e5427071d7049d"
    31  
    32  	p := new(big.Int).SetBytes(mustDecodeHex(pEnc))
    33  	q := new(big.Int).SetBytes(mustDecodeHex(qEnc))
    34  	N := new(big.Int).SetBytes(mustDecodeHex(NEnc))
    35  	e := new(big.Int).SetBytes(mustDecodeHex(eEnc))
    36  	d := new(big.Int).SetBytes(mustDecodeHex(dEnc))
    37  
    38  	primes := make([]*big.Int, 2)
    39  	primes[0] = p
    40  	primes[1] = q
    41  
    42  	key := &rsa.PrivateKey{
    43  		PublicKey: rsa.PublicKey{
    44  			N: N,
    45  			E: int(e.Int64()),
    46  		},
    47  		D:      d,
    48  		Primes: primes,
    49  	}
    50  
    51  	return key
    52  }
    53  
    54  func runPBRSA(signer Signer, verifier Verifier, message, metadata []byte, random io.Reader) ([]byte, error) {
    55  	blindedMsg, state, err := verifier.Blind(random, message, metadata)
    56  	if err != nil {
    57  		return nil, err
    58  	}
    59  
    60  	kLen := (signer.sk.Pk.N.BitLen() + 7) / 8
    61  	if len(blindedMsg) != kLen {
    62  		return nil, fmt.Errorf("Protocol message (blind message) length mismatch, expected %d, got %d", kLen, len(blindedMsg))
    63  	}
    64  
    65  	blindedSig, err := signer.BlindSign(blindedMsg, metadata)
    66  	if err != nil {
    67  		return nil, err
    68  	}
    69  
    70  	if len(blindedSig) != kLen {
    71  		return nil, fmt.Errorf("Protocol message (blind signature) length mismatch, expected %d, got %d", kLen, len(blindedMsg))
    72  	}
    73  
    74  	sig, err := state.Finalize(blindedSig)
    75  	if err != nil {
    76  		return nil, err
    77  	}
    78  
    79  	err = verifier.Verify(message, metadata, sig)
    80  	if err != nil {
    81  		return nil, err
    82  	}
    83  
    84  	return sig, nil
    85  }
    86  
    87  func mustDecodeHex(h string) []byte {
    88  	b, err := hex.DecodeString(h)
    89  	if err != nil {
    90  		panic(err)
    91  	}
    92  	return b
    93  }
    94  
    95  func TestPBRSARoundTrip(t *testing.T) {
    96  	message := []byte("hello world")
    97  	metadata := []byte("metadata")
    98  	key := loadStrongRSAKey()
    99  
   100  	hash := crypto.SHA384
   101  	verifier := NewVerifier(&key.PublicKey, hash)
   102  	signer, err := NewSigner(key, hash)
   103  	if err != nil {
   104  		t.Fatal(err)
   105  	}
   106  
   107  	sig, err := runPBRSA(signer, verifier, message, metadata, rand.Reader)
   108  	if err != nil {
   109  		t.Fatal(err)
   110  	}
   111  	if sig == nil {
   112  		t.Fatal("nil signature output")
   113  	}
   114  }
   115  
   116  type encodedPBRSATestVector struct {
   117  	Message   string `json:"msg"`
   118  	Info      string `json:"info"`
   119  	P         string `json:"p"`
   120  	Q         string `json:"q"`
   121  	D         string `json:"d"`
   122  	E         string `json:"e"`
   123  	N         string `json:"N"`
   124  	Eprime    string `json:"eprime"`
   125  	Blind     string `json:"blind"`
   126  	Salt      string `json:"salt"`
   127  	Request   string `json:"blinded_msg"`
   128  	Response  string `json:"blinded_sig"`
   129  	Signature string `json:"sig"`
   130  }
   131  
   132  type rawPBRSATestVector struct {
   133  	privateKey *rsa.PrivateKey
   134  	message    []byte
   135  	info       []byte
   136  	infoKey    []byte
   137  	blind      []byte
   138  	salt       []byte
   139  	request    []byte
   140  	response   []byte
   141  	signature  []byte
   142  }
   143  
   144  func mustHex(d []byte) string {
   145  	return hex.EncodeToString(d)
   146  }
   147  
   148  func (tv rawPBRSATestVector) MarshalJSON() ([]byte, error) {
   149  	pEnc := mustHex(tv.privateKey.Primes[0].Bytes())
   150  	qEnc := mustHex(tv.privateKey.Primes[1].Bytes())
   151  	nEnc := mustHex(tv.privateKey.N.Bytes())
   152  	e := new(big.Int).SetInt64(int64(tv.privateKey.PublicKey.E))
   153  	eEnc := mustHex(e.Bytes())
   154  	dEnc := mustHex(tv.privateKey.D.Bytes())
   155  	ePrimeEnc := mustHex(tv.infoKey)
   156  	return json.Marshal(encodedPBRSATestVector{
   157  		P:         pEnc,
   158  		Q:         qEnc,
   159  		D:         dEnc,
   160  		E:         eEnc,
   161  		N:         nEnc,
   162  		Eprime:    ePrimeEnc,
   163  		Message:   mustHex(tv.message),
   164  		Info:      mustHex(tv.info),
   165  		Blind:     mustHex(tv.blind),
   166  		Salt:      mustHex(tv.salt),
   167  		Request:   mustHex(tv.request),
   168  		Response:  mustHex(tv.response),
   169  		Signature: mustHex(tv.signature),
   170  	})
   171  }
   172  
   173  func generatePBRSATestVector(t *testing.T, msg, metadata []byte) rawPBRSATestVector {
   174  	key := loadStrongRSAKey()
   175  
   176  	hash := crypto.SHA384
   177  	verifier := NewVerifier(&key.PublicKey, hash)
   178  	signer, err := NewSigner(key, hash)
   179  	if err != nil {
   180  		t.Fatal(err)
   181  	}
   182  
   183  	publicKey := keys.NewBigPublicKey(&key.PublicKey)
   184  	metadataKey := derivePublicKey(hash, publicKey, metadata)
   185  
   186  	blindedMsg, state, err := verifier.Blind(rand.Reader, msg, metadata)
   187  	if err != nil {
   188  		t.Fatal(err)
   189  	}
   190  
   191  	blindedSig, err := signer.BlindSign(blindedMsg, metadata)
   192  	if err != nil {
   193  		t.Fatal(err)
   194  	}
   195  
   196  	sig, err := state.Finalize(blindedSig)
   197  	if err != nil {
   198  		t.Fatal(err)
   199  	}
   200  
   201  	err = verifier.Verify(msg, metadata, sig)
   202  	if err != nil {
   203  		t.Fatal(err)
   204  	}
   205  
   206  	return rawPBRSATestVector{
   207  		message:    msg,
   208  		info:       metadata,
   209  		privateKey: key,
   210  		infoKey:    metadataKey.Marshal(),
   211  		salt:       state.CopySalt(),
   212  		blind:      state.CopyBlind(),
   213  		request:    blindedMsg,
   214  		response:   blindedSig,
   215  		signature:  sig,
   216  	}
   217  }
   218  
   219  func verifyTestVector(t *testing.T, vector rawPBRSATestVector) {
   220  	key := loadStrongRSAKey()
   221  
   222  	key.PublicKey.N = vector.privateKey.N
   223  	key.PublicKey.E = vector.privateKey.E
   224  	key.D = vector.privateKey.D
   225  	key.Primes[0] = vector.privateKey.Primes[0]
   226  	key.Primes[1] = vector.privateKey.Primes[1]
   227  	key.Precomputed.Dp = nil // Remove precomputed CRT values
   228  
   229  	hash := crypto.SHA384
   230  	signer, err := NewSigner(key, hash)
   231  	if err != nil {
   232  		t.Fatal(err)
   233  	}
   234  	verifier := NewVerifier(&key.PublicKey, crypto.SHA384)
   235  
   236  	r := new(big.Int).SetBytes(vector.blind)
   237  	rInv := new(big.Int).ModInverse(r, key.N)
   238  	if r == nil {
   239  		t.Fatal("Failed to compute blind inverse")
   240  	}
   241  
   242  	blindedMsg, state, err := verifier.FixedBlind(vector.message, vector.info, vector.salt, r.Bytes(), rInv.Bytes())
   243  	if err != nil {
   244  		t.Fatal(err)
   245  	}
   246  
   247  	blindSig, err := signer.BlindSign(blindedMsg, vector.info)
   248  	if err != nil {
   249  		t.Fatal(err)
   250  	}
   251  
   252  	sig, err := state.Finalize(blindSig)
   253  	if err != nil {
   254  		t.Fatal(err)
   255  	}
   256  
   257  	if !bytes.Equal(sig, vector.signature) {
   258  		t.Errorf("Signature mismatch: expected %x, got %x", sig, vector.signature)
   259  	}
   260  }
   261  
   262  func TestPBRSAGenerateTestVector(t *testing.T) {
   263  	testCases := []struct {
   264  		msg      []byte
   265  		metadata []byte
   266  	}{
   267  		{
   268  			[]byte("hello world"),
   269  			[]byte("metadata"),
   270  		},
   271  		{
   272  			[]byte("hello world"),
   273  			[]byte(""),
   274  		},
   275  		{
   276  			[]byte(""),
   277  			[]byte("metadata"),
   278  		},
   279  		{
   280  			[]byte(""),
   281  			[]byte(""),
   282  		},
   283  	}
   284  
   285  	vectors := []rawPBRSATestVector{}
   286  	for _, testCase := range testCases {
   287  		vectors = append(vectors, generatePBRSATestVector(t, testCase.msg, testCase.metadata))
   288  	}
   289  
   290  	for _, vector := range vectors {
   291  		verifyTestVector(t, vector)
   292  	}
   293  
   294  	// Encode the test vectors
   295  	encoded, err := json.Marshal(vectors)
   296  	if err != nil {
   297  		t.Fatalf("Error producing test vectors: %v", err)
   298  	}
   299  
   300  	var outputFile string
   301  	if outputFile = os.Getenv(pbrsaTestVectorOutEnvironmentKey); len(outputFile) > 0 {
   302  		err := os.WriteFile(outputFile, encoded, 0o600)
   303  		if err != nil {
   304  			t.Fatalf("Error writing test vectors: %v", err)
   305  		}
   306  	}
   307  }
   308  
   309  func BenchmarkPBRSA(b *testing.B) {
   310  	message := []byte("hello world")
   311  	metadata := []byte("good doggo")
   312  	key := loadStrongRSAKey()
   313  
   314  	hash := crypto.SHA384
   315  	verifier := NewVerifier(&key.PublicKey, hash)
   316  	signer, err := NewSigner(key, hash)
   317  	if err != nil {
   318  		b.Fatal(err)
   319  	}
   320  
   321  	var blindedMsg []byte
   322  	var state VerifierState
   323  	b.Run("Blind", func(b *testing.B) {
   324  		for n := 0; n < b.N; n++ {
   325  			blindedMsg, state, err = verifier.Blind(rand.Reader, message, metadata)
   326  			if err != nil {
   327  				b.Fatal(err)
   328  			}
   329  		}
   330  	})
   331  
   332  	var blindedSig []byte
   333  	b.Run("BlindSign", func(b *testing.B) {
   334  		for n := 0; n < b.N; n++ {
   335  			blindedSig, err = signer.BlindSign(blindedMsg, metadata)
   336  			if err != nil {
   337  				b.Fatal(err)
   338  			}
   339  		}
   340  	})
   341  
   342  	var sig []byte
   343  	b.Run("Finalize", func(b *testing.B) {
   344  		for n := 0; n < b.N; n++ {
   345  			sig, err = state.Finalize(blindedSig)
   346  			if err != nil {
   347  				b.Fatal(err)
   348  			}
   349  		}
   350  	})
   351  
   352  	err = verifier.Verify(message, metadata, sig)
   353  	if err != nil {
   354  		b.Fatal(err)
   355  	}
   356  }