gitee.com/lh-her-team/common@v1.5.1/crypto/paillier/paillier.go (about)

     1  package paillier
     2  
     3  import (
     4  	"crypto/rand"
     5  	"crypto/sha256"
     6  	"errors"
     7  	"io"
     8  	"math/big"
     9  	"reflect"
    10  )
    11  
    12  var (
    13  	one = big.NewInt(1)
    14  )
    15  
    16  // ErrMessageTooLong is returned when attempting to encrypt a message which is
    17  // too large for the size of the public key.
    18  var ErrMessageTooLong = errors.New("paillier: message too long for Paillier public key size")
    19  var ErrInvalidCiphertext = errors.New("paillier: invalid ciphertext")
    20  var ErrInvalidPlaintext = errors.New("paillier: invalid plaintext")
    21  var ErrInvalidPublicKey = errors.New("paillier: invalid public key")
    22  var ErrInvalidPrivateKey = errors.New("paillier: invalid private key")
    23  var ErrInvalidMismatch = errors.New("paillier: key mismatch")
    24  
    25  // PubKey represents the public part of a Paillier key.
    26  type PubKey struct {
    27  	N        *big.Int // modulus
    28  	G        *big.Int // n+1, since p and q are same length
    29  	NSquared *big.Int
    30  }
    31  
    32  // PrvKey represents a Paillier key.
    33  type PrvKey struct {
    34  	*PubKey
    35  	p         *big.Int
    36  	pp        *big.Int
    37  	pminusone *big.Int
    38  	q         *big.Int
    39  	qq        *big.Int
    40  	qminusone *big.Int
    41  	pinvq     *big.Int
    42  	hp        *big.Int
    43  	hq        *big.Int
    44  	n         *big.Int
    45  }
    46  
    47  type Ciphertext struct {
    48  	Ct       *big.Int
    49  	Checksum []byte
    50  }
    51  
    52  func GenKey() (*PrvKey, error) {
    53  	return generateKey(rand.Reader, 256)
    54  }
    55  
    56  // generateKey generates an Paillier keypair of the given bit size using the
    57  // random source random (for example, crypto/rand.Reader).
    58  func generateKey(random io.Reader, bits int) (*PrvKey, error) {
    59  	// First, begin generation of p in the background.
    60  	var p *big.Int
    61  	var errChan = make(chan error, 1)
    62  	go func() {
    63  		var err error
    64  		p, err = rand.Prime(random, bits/2)
    65  		errChan <- err
    66  	}()
    67  	// Now, find a prime q in the foreground.
    68  	q, err := rand.Prime(random, bits/2)
    69  	if err != nil {
    70  		return nil, err
    71  	}
    72  	// Wait for generation of p to complete successfully.
    73  	if err := <-errChan; err != nil {
    74  		return nil, err
    75  	}
    76  	n := new(big.Int).Mul(p, q)
    77  	pp := new(big.Int).Mul(p, p)
    78  	qq := new(big.Int).Mul(q, q)
    79  	return &PrvKey{
    80  		PubKey: &PubKey{
    81  			N:        n,
    82  			NSquared: new(big.Int).Mul(n, n),
    83  			G:        new(big.Int).Add(n, one), // g = n + 1
    84  		},
    85  		p:         p,
    86  		pp:        pp,
    87  		pminusone: new(big.Int).Sub(p, one),
    88  		q:         q,
    89  		qq:        qq,
    90  		qminusone: new(big.Int).Sub(q, one),
    91  		pinvq:     new(big.Int).ModInverse(p, q),
    92  		hp:        h(p, pp, n),
    93  		hq:        h(q, qq, n),
    94  		n:         n,
    95  	}, nil
    96  }
    97  
    98  // hp hq
    99  func h(p *big.Int, pp *big.Int, n *big.Int) *big.Int {
   100  	gp := new(big.Int).Mod(new(big.Int).Sub(one, n), pp)
   101  	lp := l(gp, p)
   102  	hp := new(big.Int).ModInverse(lp, p)
   103  	return hp
   104  }
   105  
   106  func l(u *big.Int, n *big.Int) *big.Int {
   107  	return new(big.Int).Div(new(big.Int).Sub(u, one), n)
   108  }
   109  
   110  // Encrypt encrypts a plain text represented as a byte array. The passed plain
   111  // text MUST NOT be larger than the modulus of the passed public key.
   112  func (key *PubKey) Encrypt(plainText *big.Int) (*Ciphertext, error) {
   113  	if err := validatePubKey(key); err != nil {
   114  		return nil, err
   115  	}
   116  	if err := validatePlaintext(plainText); err != nil {
   117  		return nil, err
   118  	}
   119  	plaintext, err := AdjustPlaintextDomain(key, plainText)
   120  	if err != nil {
   121  		return nil, err
   122  	}
   123  	c, _, err := EncryptAndNonce(key, plaintext)
   124  	if err != nil {
   125  		return nil, err
   126  	}
   127  	checksum, err := key.bindingCtPubKey(c.Bytes())
   128  	ct := &Ciphertext{
   129  		Ct:       c,
   130  		Checksum: checksum,
   131  	}
   132  	return ct, err
   133  }
   134  
   135  // EncryptAndNonce encrypts a plain text represented as a byte array, and in
   136  // addition, returns the nonce used during encryption. The passed plain text
   137  // MUST NOT be larger than the modulus of the passed public key.
   138  func EncryptAndNonce(pubKey *PubKey, plainText *big.Int) (*big.Int, *big.Int, error) {
   139  	r, err := rand.Int(rand.Reader, pubKey.N)
   140  	if err != nil {
   141  		return nil, nil, err
   142  	}
   143  	for new(big.Int).GCD(nil, nil, r, pubKey.N).Cmp(one) != 0 {
   144  		r = new(big.Int).Mod(new(big.Int).Add(r, one), pubKey.N)
   145  	}
   146  	c, err := EncryptWithNonce(pubKey, r, plainText)
   147  	if err != nil {
   148  		return nil, nil, err
   149  	}
   150  	return c, r, nil
   151  }
   152  
   153  // EncryptWithNonce encrypts a plain text represented as a byte array using the
   154  // provided nonce to perform encryption. The passed plain text MUST NOT be
   155  // larger than the modulus of the passed public key.
   156  func EncryptWithNonce(pubKey *PubKey, r *big.Int, pt *big.Int) (*big.Int, error) {
   157  	if pubKey.N.Cmp(pt) < 1 { // N < m
   158  		return nil, ErrMessageTooLong
   159  	}
   160  	// c = g^m * r^n mod n^2 = ((m*n+1) mod n^2) * r^n mod n^2
   161  	n := pubKey.N
   162  	c := new(big.Int).Mod(
   163  		new(big.Int).Mul(
   164  			new(big.Int).Mod(new(big.Int).Add(one, new(big.Int).Mul(pt, n)), pubKey.NSquared),
   165  			new(big.Int).Exp(r, n, pubKey.NSquared),
   166  		),
   167  		pubKey.NSquared,
   168  	)
   169  	return c, nil
   170  }
   171  
   172  // Decrypt decrypts the passed cipher text.
   173  func (key *PrvKey) Decrypt(cipherText *Ciphertext) (*big.Int, error) {
   174  	if err := validatePrvKey(key); err != nil {
   175  		return nil, err
   176  	}
   177  	if err := validateCiphertext(cipherText); err != nil {
   178  		return nil, err
   179  	}
   180  	if key.NSquared.Cmp(cipherText.Ct) < 1 { // c > n^2
   181  		return nil, ErrMessageTooLong
   182  	}
   183  	cp := new(big.Int).Exp(cipherText.Ct, key.pminusone, key.pp)
   184  	lp := l(cp, key.p)
   185  	mp := new(big.Int).Mod(new(big.Int).Mul(lp, key.hp), key.p)
   186  	cq := new(big.Int).Exp(cipherText.Ct, key.qminusone, key.qq)
   187  	lq := l(cq, key.q)
   188  	mqq := new(big.Int).Mul(lq, key.hq)
   189  	mq := new(big.Int).Mod(mqq, key.q)
   190  	m := crt(mp, mq, key)
   191  	plaintext, err := AdjustDecryptedDomain(key.PubKey, m)
   192  	return plaintext, err
   193  }
   194  
   195  func crt(mp *big.Int, mq *big.Int, privKey *PrvKey) *big.Int {
   196  	u := new(big.Int).Mod(new(big.Int).Mul(new(big.Int).Sub(mq, mp), privKey.pinvq), privKey.q)
   197  	m := new(big.Int).Add(mp, new(big.Int).Mul(u, privKey.p))
   198  	return new(big.Int).Mod(m, privKey.n)
   199  }
   200  
   201  func Neg(pk *PubKey, cipher *Ciphertext) (*Ciphertext, error) {
   202  	cipher.Ct = new(big.Int).ModInverse(cipher.Ct, pk.NSquared)
   203  	return cipher, nil
   204  }
   205  
   206  func (key *PrvKey) GetPubKey() (*PubKey, error) {
   207  	if err := validatePrvKey(key); err != nil {
   208  		return nil, err
   209  	}
   210  
   211  	return key.PubKey, nil
   212  }
   213  
   214  // Marshal encodes the PubKey as a byte slice.
   215  func (key *PubKey) Marshal() ([]byte, error) {
   216  	if err := validatePubKey(key); err != nil {
   217  		return nil, err
   218  	}
   219  	// public key io
   220  	return []byte(GetPublicKeyHex(key)), nil
   221  }
   222  
   223  // Unmarshal recovers the PubKey from an encoded byte slice.
   224  func (key *PubKey) Unmarshal(pubKeyBytes []byte) error {
   225  	k, err := GetPublicKeyFromHex(string(pubKeyBytes))
   226  	if err != nil {
   227  		return err
   228  	}
   229  	key.N = k.N
   230  	key.NSquared = k.NSquared
   231  	key.G = k.G
   232  	return nil
   233  }
   234  
   235  func (ct *Ciphertext) Marshal() ([]byte, error) {
   236  	if err := validateCiphertext(ct); err != nil {
   237  		return nil, ErrInvalidCiphertext
   238  	}
   239  	ctBytes := ct.Ct.Bytes()
   240  	return append(ct.Checksum, ctBytes...), nil
   241  }
   242  
   243  func (ct *Ciphertext) Unmarshal(ctBytes []byte) error {
   244  	if ctBytes == nil {
   245  		return ErrInvalidCiphertext
   246  	}
   247  	if ct.Ct == nil {
   248  		ct.Ct = new(big.Int)
   249  	}
   250  	ct.Ct.SetBytes(ctBytes[defaultChecksumSize:])
   251  	ct.Checksum = ctBytes[:defaultChecksumSize]
   252  	return nil
   253  }
   254  
   255  // Marshal encodes the PrvKey as a byte slice.
   256  func (key *PrvKey) Marshal() ([]byte, error) {
   257  	if err := validatePrvKey(key); err != nil {
   258  		return nil, err
   259  	}
   260  	tempBytes := []byte(GetPrivateKeyHex(key))
   261  	return tempBytes, nil
   262  }
   263  
   264  // Unmarshal recovers the PrvKey from an encoded byte slice.
   265  func (key *PrvKey) Unmarshal(prvKeyBytes []byte) error {
   266  	if prvKeyBytes == nil {
   267  		return ErrInvalidPrivateKey
   268  	}
   269  	k, err := GetPrivateKeyFromHex(string(prvKeyBytes))
   270  	if err != nil {
   271  		return ErrInvalidPrivateKey
   272  	}
   273  	key.PubKey = k.PubKey
   274  	key.p = k.p
   275  	key.pp = k.pp
   276  	key.pminusone = k.pminusone
   277  	key.q = k.q
   278  	key.qq = k.qq
   279  	key.qminusone = k.qminusone
   280  	key.pinvq = k.pinvq
   281  	key.hp = k.hp
   282  	key.hq = k.hq
   283  	key.n = k.n
   284  	return nil
   285  }
   286  
   287  func (ct *Ciphertext) GetChecksum() ([]byte, error) {
   288  	if err := validateCiphertext(ct); err != nil {
   289  		return nil, err
   290  	}
   291  	return ct.Checksum, nil
   292  }
   293  
   294  func (ct *Ciphertext) GetCtBytes() ([]byte, error) {
   295  	if err := validateCiphertext(ct); err != nil {
   296  		return nil, err
   297  	}
   298  	return ct.Ct.Bytes(), nil
   299  }
   300  
   301  func (ct *Ciphertext) GetCtStr() (string, error) {
   302  	if err := validateCiphertext(ct); err != nil {
   303  		return "", err
   304  	}
   305  	return ct.Ct.String(), nil
   306  }
   307  
   308  // AddCiphertext homomorphically adds together two cipher texts.
   309  // To do this we multiply the two cipher texts, upon decryption, the resulting
   310  // plain text will be the sum of the corresponding plain texts.
   311  func (key *PubKey) AddCiphertext(cipher1, cipher2 *Ciphertext) (*Ciphertext, error) {
   312  	if err := validatePubKey(key); err != nil {
   313  		return nil, err
   314  	}
   315  	if err := validateCiphertext(cipher1, cipher2); err != nil {
   316  		return nil, err
   317  	}
   318  	if !key.checkOperand(cipher1, cipher2) {
   319  		return nil, ErrInvalidMismatch
   320  	}
   321  	x := cipher1.Ct
   322  	y := cipher2.Ct
   323  	// x * y mod n^2
   324  	c := new(big.Int).Mod(
   325  		new(big.Int).Mul(x, y),
   326  		key.NSquared,
   327  	)
   328  	return key.constructCiphertext(c)
   329  }
   330  
   331  // AddPlaintext homomorphically adds a passed constant to the encrypted integer
   332  // (our cipher text). We do this by multiplying the constant with our
   333  // ciphertext. Upon decryption, the resulting plain text will be the sum of
   334  // the plaintext integer and the constant.
   335  func (key *PubKey) AddPlaintext(cipher *Ciphertext, constant *big.Int) (*Ciphertext, error) {
   336  	if err := validatePubKey(key); err != nil {
   337  		return nil, err
   338  	}
   339  	if err := validateCiphertext(cipher); err != nil {
   340  		return nil, err
   341  	}
   342  	if err := validatePlaintext(constant); err != nil {
   343  		return nil, err
   344  	}
   345  	if !key.checkOperand(cipher) {
   346  		return nil, ErrInvalidMismatch
   347  	}
   348  	c := cipher.Ct
   349  	x := constant
   350  	// c * g ^ x mod n^2
   351  	c = new(big.Int).Mod(
   352  		new(big.Int).Mul(c, new(big.Int).Exp(key.G, x, key.NSquared)),
   353  		key.NSquared,
   354  	)
   355  	return key.constructCiphertext(c)
   356  }
   357  
   358  func (key *PubKey) SubCiphertext(cipher1, cipher2 *Ciphertext) (*Ciphertext, error) {
   359  	if err := validatePubKey(key); err != nil {
   360  		return nil, err
   361  	}
   362  	if err := validateCiphertext(cipher1, cipher2); err != nil {
   363  		return nil, err
   364  	}
   365  	if !key.checkOperand(cipher1, cipher2) {
   366  		return nil, ErrInvalidMismatch
   367  	}
   368  	c1 := cipher1.Ct
   369  	c2 := cipher2.Ct
   370  	c2Inversed := new(big.Int).ModInverse(c2, key.NSquared)
   371  	c := new(big.Int).Mod(new(big.Int).Mul(c1, c2Inversed), key.NSquared)
   372  	return key.constructCiphertext(c)
   373  }
   374  
   375  func (key *PubKey) SubPlaintext(cipher *Ciphertext, constant *big.Int) (*Ciphertext, error) {
   376  	if err := validatePubKey(key); err != nil {
   377  		return nil, err
   378  	}
   379  	if err := validateCiphertext(cipher); err != nil {
   380  		return nil, err
   381  	}
   382  	if err := validatePlaintext(constant); err != nil {
   383  		return nil, err
   384  	}
   385  	if !key.checkOperand(cipher) {
   386  		return nil, ErrInvalidMismatch
   387  	}
   388  	plain := constant
   389  	plain = new(big.Int).Mod(new(big.Int).Add(new(big.Int).Mul(plain, key.N), one), key.NSquared)
   390  	c := cipher.Ct
   391  	c = new(big.Int).Mod(new(big.Int).Mul(c, new(big.Int).ModInverse(plain, key.NSquared)), key.NSquared)
   392  	return key.constructCiphertext(c)
   393  }
   394  
   395  func (key *PubKey) SubByConstant(pubKey *PubKey, cipher *Ciphertext, constant *big.Int) (*Ciphertext, error) {
   396  	cipherNeg, err := Neg(pubKey, cipher)
   397  	if err != nil {
   398  		return nil, err
   399  	}
   400  	return key.AddPlaintext(cipherNeg, constant)
   401  }
   402  
   403  // NumMul homomorphically multiplies an encrypted integer (cipher text) by a
   404  // constant. We do this by raising our cipher text to the power of the passed
   405  // constant. Upon decryption, the resulting plain text will be the product of
   406  // the plaintext integer and the constant.
   407  func (key *PubKey) NumMul(cipher *Ciphertext, constant *big.Int) (*Ciphertext, error) {
   408  	if err := validatePubKey(key); err != nil {
   409  		return nil, err
   410  	}
   411  	if err := validateCiphertext(cipher); err != nil {
   412  		return nil, err
   413  	}
   414  	if err := validatePlaintext(constant); err != nil {
   415  		return nil, err
   416  	}
   417  	if !key.checkOperand(cipher) {
   418  		return nil, ErrInvalidMismatch
   419  	}
   420  	c := new(big.Int).Exp(cipher.Ct, constant, key.NSquared)
   421  	return key.constructCiphertext(c)
   422  }
   423  
   424  func (key *PubKey) constructCiphertext(ciphertext *big.Int) (*Ciphertext, error) {
   425  	checksum, err := key.bindingCtPubKey(ciphertext.Bytes())
   426  	if err != nil {
   427  		return nil, err
   428  	}
   429  	ct := &Ciphertext{
   430  		Ct:       ciphertext,
   431  		Checksum: checksum,
   432  	}
   433  	return ct, nil
   434  }
   435  
   436  func (key *PubKey) bindingCtPubKey(ciphertext []byte) ([]byte, error) {
   437  	pubKeyBytes, err := key.Marshal()
   438  	if ciphertext == nil {
   439  		return nil, ErrInvalidCiphertext
   440  	}
   441  	if err != nil {
   442  		return nil, err
   443  	}
   444  	checksum := sha256.Sum256(append(pubKeyBytes, ciphertext[:]...))
   445  	return checksum[:defaultChecksumSize], nil
   446  }
   447  
   448  func (key *PubKey) checkOperand(cts ...*Ciphertext) bool {
   449  	for _, ct := range cts {
   450  		if !key.ChecksumVerify(ct) {
   451  			return false
   452  		}
   453  	}
   454  	return true
   455  }
   456  
   457  // ChecksumVerify verifying public key ciphertext pairs
   458  func (key *PubKey) ChecksumVerify(ct *Ciphertext) bool {
   459  	if err := validatePubKey(key); err != nil {
   460  		return false
   461  	}
   462  	if err := validateCiphertext(ct); err != nil {
   463  		return false
   464  	}
   465  	keyBytes, err := key.Marshal()
   466  	if err != nil {
   467  		return false
   468  	}
   469  	ctBytes, err := ct.GetCtBytes()
   470  	if err != nil {
   471  		return false
   472  	}
   473  	currentChecksum, err := ct.GetChecksum()
   474  	if err != nil {
   475  		return false
   476  	}
   477  	checksum := sha256.Sum256(append(keyBytes, ctBytes...))
   478  	return reflect.DeepEqual(checksum[:defaultChecksumSize], currentChecksum)
   479  }