github.com/cloudflare/circl@v1.5.0/tss/rsa/rsa_threshold.go (about)

     1  // Package rsa provides RSA threshold signature scheme.
     2  //
     3  // This package implements the Protocol 1 of "Practical Threshold Signatures"
     4  // by Victor Shoup [1].
     5  //
     6  // # References
     7  //
     8  // [1] https://www.iacr.org/archive/eurocrypt2000/1807/18070209-new.pdf
     9  package rsa
    10  
    11  import (
    12  	"crypto"
    13  	"crypto/rand"
    14  	"crypto/rsa"
    15  	"errors"
    16  	"fmt"
    17  	"io"
    18  	"math"
    19  	"math/big"
    20  
    21  	cmath "github.com/cloudflare/circl/math"
    22  )
    23  
    24  // GenerateKey generates a RSA keypair for its use in RSA threshold signatures.
    25  // Internally, the modulus is the product of two safe primes. The time
    26  // consumed by this function is relatively longer than the regular
    27  // GenerateKey function from the crypto/rsa package.
    28  func GenerateKey(random io.Reader, bits int) (*rsa.PrivateKey, error) {
    29  	p, err := cmath.SafePrime(random, bits/2)
    30  	if err != nil {
    31  		return nil, err
    32  	}
    33  
    34  	var q *big.Int
    35  	n := new(big.Int)
    36  	found := false
    37  	for !found {
    38  		q, err = cmath.SafePrime(random, bits-p.BitLen())
    39  		if err != nil {
    40  			return nil, err
    41  		}
    42  
    43  		// check for different primes.
    44  		if p.Cmp(q) != 0 {
    45  			n.Mul(p, q)
    46  			// check n has the desired bitlength.
    47  			if n.BitLen() == bits {
    48  				found = true
    49  			}
    50  		}
    51  	}
    52  
    53  	one := big.NewInt(1)
    54  	pminus1 := new(big.Int).Sub(p, one)
    55  	qminus1 := new(big.Int).Sub(q, one)
    56  	totient := new(big.Int).Mul(pminus1, qminus1)
    57  
    58  	priv := new(rsa.PrivateKey)
    59  	priv.Primes = []*big.Int{p, q}
    60  	priv.N = n
    61  	priv.E = 65537
    62  	priv.D = new(big.Int)
    63  	e := big.NewInt(int64(priv.E))
    64  	ok := priv.D.ModInverse(e, totient)
    65  	if ok == nil {
    66  		return nil, errors.New("public key is not coprime to phi(n)")
    67  	}
    68  
    69  	priv.Precompute()
    70  
    71  	return priv, nil
    72  }
    73  
    74  // l or `Players`, the total number of Players.
    75  // t, the number of corrupted Players.
    76  // k=t+1 or `Threshold`, the number of signature shares needed to obtain a signature.
    77  
    78  func validateParams(players, threshold uint) error {
    79  	if players <= 1 {
    80  		return errors.New("rsa_threshold: Players (l) invalid: should be > 1")
    81  	}
    82  	if threshold < 1 || threshold > players {
    83  		return fmt.Errorf("rsa_threshold: Threshold (k) invalid: %d < 1 || %d > %d", threshold, threshold, players)
    84  	}
    85  	return nil
    86  }
    87  
    88  // Deal takes in an existing RSA private key generated elsewhere. If cache is true, cached values are stored in KeyShare taking up more memory by reducing Sign time.
    89  // See KeyShare documentation. Multi-prime RSA keys are unsupported.
    90  func Deal(randSource io.Reader, players, threshold uint, key *rsa.PrivateKey, cache bool) ([]KeyShare, error) {
    91  	err := validateParams(players, threshold)
    92  
    93  	ONE := big.NewInt(1)
    94  
    95  	if err != nil {
    96  		return nil, err
    97  	}
    98  
    99  	if len(key.Primes) != 2 {
   100  		return nil, errors.New("multiprime rsa keys are unsupported")
   101  	}
   102  
   103  	p := key.Primes[0]
   104  	q := key.Primes[1]
   105  	e := int64(key.E)
   106  
   107  	// p = 2p' + 1
   108  	// q = 2q' + 1
   109  	// p' = (p - 1)/2
   110  	// q' = (q - 1)/2
   111  	// m = p'q' = (p - 1)(q - 1)/4
   112  
   113  	var pprime big.Int
   114  	// p - 1
   115  	pprime.Sub(p, ONE)
   116  
   117  	// q - 1
   118  	var m big.Int
   119  	m.Sub(q, ONE)
   120  	// (p - 1)(q - 1)
   121  	m.Mul(&m, &pprime)
   122  	// >> 2 == / 4
   123  	m.Rsh(&m, 2)
   124  
   125  	// de ≡ 1
   126  	var d big.Int
   127  	_d := d.ModInverse(big.NewInt(e), &m)
   128  
   129  	if _d == nil {
   130  		return nil, errors.New("rsa_threshold: no ModInverse for e in Z/Zm")
   131  	}
   132  
   133  	// a_0...a_{k-1}
   134  	a := make([]*big.Int, threshold)
   135  	// a_0 = d
   136  	a[0] = &d
   137  
   138  	// a_0...a_{k-1} = rand from {0, ..., m - 1}
   139  	for i := uint(1); i <= threshold-1; i++ {
   140  		a[i], err = rand.Int(randSource, &m)
   141  		if err != nil {
   142  			return nil, errors.New("rsa_threshold: unable to generate an int within [0, m)")
   143  		}
   144  	}
   145  
   146  	shares := make([]KeyShare, players)
   147  
   148  	// 1 <= i <= l
   149  	for i := uint(1); i <= players; i++ {
   150  		shares[i-1].Players = players
   151  		shares[i-1].Threshold = threshold
   152  		// Σ^{k-1}_{i=0} | a_i * X^i (mod m)
   153  		poly := computePolynomial(threshold, a, i, &m)
   154  		shares[i-1].si = poly
   155  		shares[i-1].Index = i
   156  		if cache {
   157  			shares[i-1].get2DeltaSi(int64(players))
   158  		}
   159  	}
   160  
   161  	return shares, nil
   162  }
   163  
   164  func calcN(p, q *big.Int) big.Int {
   165  	// n = pq
   166  	var n big.Int
   167  	n.Mul(p, q)
   168  	return n
   169  }
   170  
   171  // f(X) = Σ^{k-1}_{i=0} | a_i * X^i (mod m)
   172  func computePolynomial(k uint, a []*big.Int, x uint, m *big.Int) *big.Int {
   173  	// TODO: use Horner's method here.
   174  	sum := big.NewInt(0)
   175  	//  Σ^{k-1}_{i=0}
   176  	for i := uint(0); i <= k-1; i++ {
   177  		// X^i
   178  		// TODO optimize: we can compute x^{n+1} from the previous x^n
   179  		xi := int64(math.Pow(float64(x), float64(i)))
   180  		// a_i * X^i
   181  		prod := big.Int{}
   182  		prod.Mul(a[i], big.NewInt(xi))
   183  		// (mod m)
   184  		prod.Mod(&prod, m) // while not in the spec, we are eventually modding m, so we can mod here for efficiency
   185  		// Σ
   186  		sum.Add(sum, &prod)
   187  	}
   188  
   189  	sum.Mod(sum, m)
   190  
   191  	return sum
   192  }
   193  
   194  // PadHash MUST be called before signing a message
   195  func PadHash(padder Padder, hash crypto.Hash, pub *rsa.PublicKey, msg []byte) ([]byte, error) {
   196  	// Sign(Pad(Hash(M)))
   197  
   198  	hasher := hash.New()
   199  	hasher.Write(msg)
   200  	digest := hasher.Sum(nil)
   201  
   202  	return padder.Pad(pub, hash, digest)
   203  }
   204  
   205  type Signature = []byte
   206  
   207  // CombineSignShares combines t SignShare's to produce a valid signature
   208  func CombineSignShares(pub *rsa.PublicKey, shares []SignShare, msg []byte) (Signature, error) {
   209  	players := shares[0].Players
   210  	threshold := shares[0].Threshold
   211  
   212  	for i := range shares {
   213  		if shares[i].Players != players {
   214  			return nil, errors.New("rsa_threshold: shares didn't have consistent players")
   215  		}
   216  		if shares[i].Threshold != threshold {
   217  			return nil, errors.New("rsa_threshold: shares didn't have consistent threshold")
   218  		}
   219  	}
   220  
   221  	if uint(len(shares)) < threshold {
   222  		return nil, errors.New("rsa_threshold: insufficient shares for the threshold")
   223  	}
   224  
   225  	w := big.NewInt(1)
   226  	delta := calculateDelta(int64(players))
   227  	// i_1 ... i_k
   228  	for _, share := range shares {
   229  		// λ(S, 0, i)
   230  		lambda, err := computeLambda(delta, shares, 0, int64(share.Index))
   231  		if err != nil {
   232  			return nil, err
   233  		}
   234  		// 2λ
   235  		var exp big.Int
   236  		exp.Add(lambda, lambda) // faster than TWO * lambda
   237  
   238  		// we need to handle negative λ's (aka inverse), so abs it, compare, and if necessary modinverse
   239  		abslam := big.Int{}
   240  		abslam.Abs(&exp)
   241  		var tmp big.Int
   242  		// x_i^{|2λ|}
   243  		tmp.Exp(share.xi, &abslam, pub.N)
   244  		if abslam.Cmp(&exp) == 1 {
   245  			tmp.ModInverse(&tmp, pub.N)
   246  		}
   247  		// TODO  first compute all the powers for the negative exponents (but don't invert yet); multiply these together and then invert all at once. This is ok since (ab)^-1 = a^-1 b^-1
   248  
   249  		w.Mul(w, &tmp).Mod(w, pub.N)
   250  	}
   251  	w.Mod(w, pub.N)
   252  
   253  	// e′ = 4∆^2
   254  	eprime := big.Int{}
   255  	eprime.Mul(delta, delta)     // faster than delta^TWO
   256  	eprime.Add(&eprime, &eprime) // faster than FOUR * eprime
   257  	eprime.Add(&eprime, &eprime)
   258  
   259  	// e′a + eb = 1
   260  	a := big.Int{}
   261  	b := big.Int{}
   262  	e := big.NewInt(int64(pub.E))
   263  	tmp := big.Int{}
   264  	tmp.GCD(&a, &b, &eprime, e)
   265  
   266  	// TODO You can compute a earlier and multiply a into the exponents used when computing w.
   267  	// w^a
   268  	wa := big.Int{}
   269  	wa.Exp(w, &a, pub.N) // TODO justification
   270  	// x^b
   271  	x := big.Int{}
   272  	x.SetBytes(msg)
   273  	xb := big.Int{}
   274  	xb.Exp(&x, &b, pub.N) // TODO justification
   275  	// y = w^a * x^b
   276  	y := big.Int{}
   277  	y.Mul(&wa, &xb).Mod(&y, pub.N)
   278  
   279  	// verify that signature is valid by checking x == y^e.
   280  	ye := big.Int{}
   281  	ye.Exp(&y, e, pub.N)
   282  	if ye.Cmp(&x) != 0 {
   283  		return nil, errors.New("rsa: internal error")
   284  	}
   285  
   286  	// ensure signature has the right size.
   287  	sig := y.FillBytes(make([]byte, pub.Size()))
   288  
   289  	return sig, nil
   290  }
   291  
   292  // computes lagrange Interpolation for the shares
   293  // i must be an id 0..l but not in S
   294  // j must be in S
   295  func computeLambda(delta *big.Int, S []SignShare, i, j int64) (*big.Int, error) {
   296  	if i == j {
   297  		return nil, errors.New("rsa_threshold: i and j can't be equal by precondition")
   298  	}
   299  	// these are just to check preconditions
   300  	foundi := false
   301  	foundj := false
   302  
   303  	// λ(s, i, j) = ∆( (  π{j'∈S\{j}} (i - j')  ) /  (  π{j'∈S\{j}} (j - j') ) )
   304  
   305  	num := int64(1)
   306  	den := int64(1)
   307  
   308  	// ∈ S
   309  	for _, s := range S {
   310  		// j'
   311  		jprime := int64(s.Index)
   312  		// S\{j}
   313  		if jprime == j {
   314  			foundj = true
   315  			continue
   316  		}
   317  		if jprime == i {
   318  			foundi = false
   319  			break
   320  		}
   321  		//  (i - j')
   322  		num *= i - jprime
   323  		// (j - j')
   324  		den *= j - jprime
   325  	}
   326  
   327  	// ∆ * (num/den)
   328  	var lambda big.Int
   329  	// (num/den)
   330  	lambda.Div(big.NewInt(num), big.NewInt(den))
   331  	// ∆ * (num/den)
   332  	lambda.Mul(delta, &lambda)
   333  
   334  	if foundi {
   335  		return nil, fmt.Errorf("rsa_threshold: i: %d should not be in S", i)
   336  	}
   337  
   338  	if !foundj {
   339  		return nil, fmt.Errorf("rsa_threshold: j: %d should be in S", j)
   340  	}
   341  
   342  	return &lambda, nil
   343  }