github.com/bigzoro/my_simplechain@v0.0.0-20240315012955-8ad0a2a29bb9/core/access_contoller/crypto/paillier/paillier.go (about)

     1  /*
     2  Copyright (C) BABEC. All rights reserved.
     3  Copyright (C) THL A29 Limited, a Tencent company. All rights reserved.
     4  
     5  SPDX-License-Identifier: Apache-2.0
     6  */
     7  
     8  package paillier
     9  
    10  import (
    11  	"crypto/rand"
    12  	"crypto/sha256"
    13  	"errors"
    14  	"io"
    15  	"math/big"
    16  	"reflect"
    17  )
    18  
    19  var (
    20  	one = big.NewInt(1)
    21  )
    22  
    23  // ErrMessageTooLong is returned when attempting to encrypt a message which is
    24  // too large for the size of the public key.
    25  var ErrMessageTooLong = errors.New("paillier: message too long for Paillier public key size")
    26  var ErrInvalidCiphertext = errors.New("paillier: invalid ciphertext")
    27  var ErrInvalidPlaintext = errors.New("paillier: invalid plaintext")
    28  var ErrInvalidPublicKey = errors.New("paillier: invalid public key")
    29  var ErrInvalidPrivateKey = errors.New("paillier: invalid private key")
    30  var ErrInvalidMismatch = errors.New("paillier: key mismatch")
    31  
    32  // PubKey represents the public part of a Paillier key.
    33  type PubKey struct {
    34  	N        *big.Int // modulus
    35  	G        *big.Int // n+1, since p and q are same length
    36  	NSquared *big.Int
    37  }
    38  
    39  // PrvKey represents a Paillier key.
    40  type PrvKey struct {
    41  	*PubKey
    42  	p         *big.Int
    43  	pp        *big.Int
    44  	pminusone *big.Int
    45  	q         *big.Int
    46  	qq        *big.Int
    47  	qminusone *big.Int
    48  	pinvq     *big.Int
    49  	hp        *big.Int
    50  	hq        *big.Int
    51  	n         *big.Int
    52  }
    53  
    54  type Ciphertext struct {
    55  	Ct       *big.Int
    56  	Checksum []byte
    57  }
    58  
    59  func GenKey() (*PrvKey, error) {
    60  	return generateKey(rand.Reader, 256)
    61  }
    62  
    63  // generateKey generates an Paillier keypair of the given bit size using the
    64  // random source random (for example, crypto/rand.Reader).
    65  func generateKey(random io.Reader, bits int) (*PrvKey, error) {
    66  	// First, begin generation of p in the background.
    67  	var p *big.Int
    68  	var errChan = make(chan error, 1)
    69  	go func() {
    70  		var err error
    71  		p, err = rand.Prime(random, bits/2)
    72  		errChan <- err
    73  	}()
    74  
    75  	// Now, find a prime q in the foreground.
    76  	q, err := rand.Prime(random, bits/2)
    77  	if err != nil {
    78  		return nil, err
    79  	}
    80  
    81  	// Wait for generation of p to complete successfully.
    82  	if err := <-errChan; err != nil {
    83  		return nil, err
    84  	}
    85  
    86  	n := new(big.Int).Mul(p, q)
    87  	pp := new(big.Int).Mul(p, p)
    88  	qq := new(big.Int).Mul(q, q)
    89  
    90  	return &PrvKey{
    91  		PubKey: &PubKey{
    92  			N:        n,
    93  			NSquared: new(big.Int).Mul(n, n),
    94  			G:        new(big.Int).Add(n, one), // g = n + 1
    95  		},
    96  		p:         p,
    97  		pp:        pp,
    98  		pminusone: new(big.Int).Sub(p, one),
    99  		q:         q,
   100  		qq:        qq,
   101  		qminusone: new(big.Int).Sub(q, one),
   102  		pinvq:     new(big.Int).ModInverse(p, q),
   103  		hp:        h(p, pp, n),
   104  		hq:        h(q, qq, n),
   105  		n:         n,
   106  	}, nil
   107  
   108  }
   109  
   110  // hp hq
   111  func h(p *big.Int, pp *big.Int, n *big.Int) *big.Int {
   112  	gp := new(big.Int).Mod(new(big.Int).Sub(one, n), pp)
   113  
   114  	lp := l(gp, p)
   115  
   116  	hp := new(big.Int).ModInverse(lp, p)
   117  	return hp
   118  }
   119  
   120  func l(u *big.Int, n *big.Int) *big.Int {
   121  	return new(big.Int).Div(new(big.Int).Sub(u, one), n)
   122  }
   123  
   124  // Encrypt encrypts a plain text represented as a byte array. The passed plain
   125  // text MUST NOT be larger than the modulus of the passed public key.
   126  func (key *PubKey) Encrypt(plainText *big.Int) (*Ciphertext, error) {
   127  	if err := validatePubKey(key); err != nil {
   128  		return nil, err
   129  	}
   130  
   131  	if err := validatePlaintext(plainText); err != nil {
   132  		return nil, err
   133  	}
   134  
   135  	plaintext, err := AdjustPlaintextDomain(key, plainText)
   136  	if err != nil {
   137  		return nil, err
   138  	}
   139  	c, _, err := EncryptAndNonce(key, plaintext)
   140  	if err != nil {
   141  		return nil, err
   142  	}
   143  
   144  	checksum, err := key.bindingCtPubKey(c.Bytes())
   145  
   146  	ct := &Ciphertext{
   147  		Ct:       c,
   148  		Checksum: checksum,
   149  	}
   150  	return ct, err
   151  }
   152  
   153  // EncryptAndNonce encrypts a plain text represented as a byte array, and in
   154  // addition, returns the nonce used during encryption. The passed plain text
   155  // MUST NOT be larger than the modulus of the passed public key.
   156  func EncryptAndNonce(pubKey *PubKey, plainText *big.Int) (*big.Int, *big.Int, error) {
   157  	r, err := rand.Int(rand.Reader, pubKey.N)
   158  	if err != nil {
   159  		return nil, nil, err
   160  	}
   161  	for new(big.Int).GCD(nil, nil, r, pubKey.N).Cmp(one) != 0 {
   162  		r = new(big.Int).Mod(new(big.Int).Add(r, one), pubKey.N)
   163  	}
   164  
   165  	c, err := EncryptWithNonce(pubKey, r, plainText)
   166  	if err != nil {
   167  		return nil, nil, err
   168  	}
   169  
   170  	return c, r, nil
   171  }
   172  
   173  // EncryptWithNonce encrypts a plain text represented as a byte array using the
   174  // provided nonce to perform encryption. The passed plain text MUST NOT be
   175  // larger than the modulus of the passed public key.
   176  func EncryptWithNonce(pubKey *PubKey, r *big.Int, pt *big.Int) (*big.Int, error) {
   177  	if pubKey.N.Cmp(pt) < 1 { // N < m
   178  		return nil, ErrMessageTooLong
   179  	}
   180  
   181  	// c = g^m * r^n mod n^2 = ((m*n+1) mod n^2) * r^n mod n^2
   182  	n := pubKey.N
   183  	c := new(big.Int).Mod(
   184  		new(big.Int).Mul(
   185  			new(big.Int).Mod(new(big.Int).Add(one, new(big.Int).Mul(pt, n)), pubKey.NSquared),
   186  			new(big.Int).Exp(r, n, pubKey.NSquared),
   187  		),
   188  		pubKey.NSquared,
   189  	)
   190  
   191  	return c, nil
   192  }
   193  
   194  // Decrypt decrypts the passed cipher text.
   195  func (key *PrvKey) Decrypt(cipherText *Ciphertext) (*big.Int, error) {
   196  	if err := validatePrvKey(key); err != nil {
   197  		return nil, err
   198  	}
   199  
   200  	if err := validateCiphertext(cipherText); err != nil {
   201  		return nil, err
   202  	}
   203  
   204  	if key.NSquared.Cmp(cipherText.Ct) < 1 { // c > n^2
   205  		return nil, ErrMessageTooLong
   206  	}
   207  
   208  	cp := new(big.Int).Exp(cipherText.Ct, key.pminusone, key.pp)
   209  	lp := l(cp, key.p)
   210  	mp := new(big.Int).Mod(new(big.Int).Mul(lp, key.hp), key.p)
   211  	cq := new(big.Int).Exp(cipherText.Ct, key.qminusone, key.qq)
   212  	lq := l(cq, key.q)
   213  
   214  	mqq := new(big.Int).Mul(lq, key.hq)
   215  	mq := new(big.Int).Mod(mqq, key.q)
   216  	m := crt(mp, mq, key)
   217  
   218  	plaintext, err := AdjustDecryptedDomain(key.PubKey, m)
   219  	return plaintext, err
   220  }
   221  
   222  func crt(mp *big.Int, mq *big.Int, privKey *PrvKey) *big.Int {
   223  	u := new(big.Int).Mod(new(big.Int).Mul(new(big.Int).Sub(mq, mp), privKey.pinvq), privKey.q)
   224  	m := new(big.Int).Add(mp, new(big.Int).Mul(u, privKey.p))
   225  	return new(big.Int).Mod(m, privKey.n)
   226  }
   227  
   228  func Neg(pk *PubKey, cipher *Ciphertext) (*Ciphertext, error) {
   229  	cipher.Ct = new(big.Int).ModInverse(cipher.Ct, pk.NSquared)
   230  	return cipher, nil
   231  }
   232  
   233  func (key *PrvKey) GetPubKey() (*PubKey, error) {
   234  	if err := validatePrvKey(key); err != nil {
   235  		return nil, err
   236  	}
   237  
   238  	return key.PubKey, nil
   239  }
   240  
   241  // Marshal encodes the PubKey as a byte slice.
   242  func (key *PubKey) Marshal() ([]byte, error) {
   243  	if err := validatePubKey(key); err != nil {
   244  		return nil, err
   245  	}
   246  	// public key io
   247  	return []byte(GetPublicKeyHex(key)), nil
   248  }
   249  
   250  // Unmarshal recovers the PubKey from an encoded byte slice.
   251  func (key *PubKey) Unmarshal(pubKeyBytes []byte) error {
   252  	k, err := GetPublicKeyFromHex(string(pubKeyBytes))
   253  	if err != nil {
   254  		return err
   255  	}
   256  
   257  	key.N = k.N
   258  	key.NSquared = k.NSquared
   259  	key.G = k.G
   260  
   261  	return nil
   262  }
   263  
   264  func (ct *Ciphertext) Marshal() ([]byte, error) {
   265  	if err := validateCiphertext(ct); err != nil {
   266  		return nil, ErrInvalidCiphertext
   267  	}
   268  
   269  	ctBytes := ct.Ct.Bytes()
   270  	return append(ct.Checksum, ctBytes...), nil
   271  }
   272  
   273  func (ct *Ciphertext) Unmarshal(ctBytes []byte) error {
   274  	if ctBytes == nil {
   275  		return ErrInvalidCiphertext
   276  	}
   277  
   278  	if ct.Ct == nil {
   279  		ct.Ct = new(big.Int)
   280  	}
   281  
   282  	ct.Ct.SetBytes(ctBytes[defaultChecksumSize:])
   283  	ct.Checksum = ctBytes[:defaultChecksumSize]
   284  	return nil
   285  }
   286  
   287  // Marshal encodes the PrvKey as a byte slice.
   288  func (key *PrvKey) Marshal() ([]byte, error) {
   289  	if err := validatePrvKey(key); err != nil {
   290  		return nil, err
   291  	}
   292  
   293  	tempBytes := []byte(GetPrivateKeyHex(key))
   294  
   295  	return tempBytes, nil
   296  }
   297  
   298  // Unmarshal recovers the PrvKey from an encoded byte slice.
   299  func (key *PrvKey) Unmarshal(prvKeyBytes []byte) error {
   300  	if prvKeyBytes == nil {
   301  		return ErrInvalidPrivateKey
   302  	}
   303  
   304  	k, err := GetPrivateKeyFromHex(string(prvKeyBytes))
   305  	if err != nil {
   306  		return ErrInvalidPrivateKey
   307  	}
   308  
   309  	key.PubKey = k.PubKey
   310  	key.p = k.p
   311  	key.pp = k.pp
   312  	key.pminusone = k.pminusone
   313  	key.q = k.q
   314  	key.qq = k.qq
   315  	key.qminusone = k.qminusone
   316  	key.pinvq = k.pinvq
   317  	key.hp = k.hp
   318  	key.hq = k.hq
   319  	key.n = k.n
   320  
   321  	return nil
   322  }
   323  
   324  func (ct *Ciphertext) GetChecksum() ([]byte, error) {
   325  	if err := validateCiphertext(ct); err != nil {
   326  		return nil, err
   327  	}
   328  
   329  	return ct.Checksum, nil
   330  }
   331  
   332  func (ct *Ciphertext) GetCtBytes() ([]byte, error) {
   333  	if err := validateCiphertext(ct); err != nil {
   334  		return nil, err
   335  	}
   336  
   337  	return ct.Ct.Bytes(), nil
   338  }
   339  
   340  func (ct *Ciphertext) GetCtStr() (string, error) {
   341  	if err := validateCiphertext(ct); err != nil {
   342  		return "", err
   343  	}
   344  
   345  	return ct.Ct.String(), nil
   346  }
   347  
   348  // AddCiphertext homomorphically adds together two cipher texts.
   349  // To do this we multiply the two cipher texts, upon decryption, the resulting
   350  // plain text will be the sum of the corresponding plain texts.
   351  func (key *PubKey) AddCiphertext(cipher1, cipher2 *Ciphertext) (*Ciphertext, error) {
   352  	if err := validatePubKey(key); err != nil {
   353  		return nil, err
   354  	}
   355  
   356  	if err := validateCiphertext(cipher1, cipher2); err != nil {
   357  		return nil, err
   358  	}
   359  
   360  	if !key.checkOperand(cipher1, cipher2) {
   361  		return nil, ErrInvalidMismatch
   362  	}
   363  
   364  	x := cipher1.Ct
   365  	y := cipher2.Ct
   366  	// x * y mod n^2
   367  	c := new(big.Int).Mod(
   368  		new(big.Int).Mul(x, y),
   369  		key.NSquared,
   370  	)
   371  
   372  	return key.constructCiphertext(c)
   373  }
   374  
   375  // AddPlaintext homomorphically adds a passed constant to the encrypted integer
   376  // (our cipher text). We do this by multiplying the constant with our
   377  // ciphertext. Upon decryption, the resulting plain text will be the sum of
   378  // the plaintext integer and the constant.
   379  func (key *PubKey) AddPlaintext(cipher *Ciphertext, constant *big.Int) (*Ciphertext, error) {
   380  	if err := validatePubKey(key); err != nil {
   381  		return nil, err
   382  	}
   383  
   384  	if err := validateCiphertext(cipher); err != nil {
   385  		return nil, err
   386  	}
   387  
   388  	if err := validatePlaintext(constant); err != nil {
   389  		return nil, err
   390  	}
   391  
   392  	if !key.checkOperand(cipher) {
   393  		return nil, ErrInvalidMismatch
   394  	}
   395  
   396  	c := cipher.Ct
   397  	x := constant
   398  
   399  	// c * g ^ x mod n^2
   400  	c = new(big.Int).Mod(
   401  		new(big.Int).Mul(c, new(big.Int).Exp(key.G, x, key.NSquared)),
   402  		key.NSquared,
   403  	)
   404  
   405  	return key.constructCiphertext(c)
   406  }
   407  
   408  func (key *PubKey) SubCiphertext(cipher1, cipher2 *Ciphertext) (*Ciphertext, error) {
   409  	if err := validatePubKey(key); err != nil {
   410  		return nil, err
   411  	}
   412  
   413  	if err := validateCiphertext(cipher1, cipher2); err != nil {
   414  		return nil, err
   415  	}
   416  
   417  	if !key.checkOperand(cipher1, cipher2) {
   418  		return nil, ErrInvalidMismatch
   419  	}
   420  
   421  	c1 := cipher1.Ct
   422  	c2 := cipher2.Ct
   423  	c2Inversed := new(big.Int).ModInverse(c2, key.NSquared)
   424  	c := new(big.Int).Mod(new(big.Int).Mul(c1, c2Inversed), key.NSquared)
   425  
   426  	return key.constructCiphertext(c)
   427  }
   428  
   429  func (key *PubKey) SubPlaintext(cipher *Ciphertext, constant *big.Int) (*Ciphertext, error) {
   430  	if err := validatePubKey(key); err != nil {
   431  		return nil, err
   432  	}
   433  
   434  	if err := validateCiphertext(cipher); err != nil {
   435  		return nil, err
   436  	}
   437  
   438  	if err := validatePlaintext(constant); err != nil {
   439  		return nil, err
   440  	}
   441  
   442  	if !key.checkOperand(cipher) {
   443  		return nil, ErrInvalidMismatch
   444  	}
   445  
   446  	plain := constant
   447  	plain = new(big.Int).Mod(new(big.Int).Add(new(big.Int).Mul(plain, key.N), one), key.NSquared)
   448  	c := cipher.Ct
   449  	c = new(big.Int).Mod(new(big.Int).Mul(c, new(big.Int).ModInverse(plain, key.NSquared)), key.NSquared)
   450  
   451  	return key.constructCiphertext(c)
   452  }
   453  
   454  func (key *PubKey) SubByConstant(pubKey *PubKey, cipher *Ciphertext, constant *big.Int) (*Ciphertext, error) {
   455  	cipherNeg, err := Neg(pubKey, cipher)
   456  	if err != nil {
   457  		return nil, err
   458  	}
   459  	return key.AddPlaintext(cipherNeg, constant)
   460  }
   461  
   462  // NumMul homomorphically multiplies an encrypted integer (cipher text) by a
   463  // constant. We do this by raising our cipher text to the power of the passed
   464  // constant. Upon decryption, the resulting plain text will be the product of
   465  // the plaintext integer and the constant.
   466  func (key *PubKey) NumMul(cipher *Ciphertext, constant *big.Int) (*Ciphertext, error) {
   467  	if err := validatePubKey(key); err != nil {
   468  		return nil, err
   469  	}
   470  
   471  	if err := validateCiphertext(cipher); err != nil {
   472  		return nil, err
   473  	}
   474  
   475  	if err := validatePlaintext(constant); err != nil {
   476  		return nil, err
   477  	}
   478  
   479  	if !key.checkOperand(cipher) {
   480  		return nil, ErrInvalidMismatch
   481  	}
   482  
   483  	c := new(big.Int).Exp(cipher.Ct, constant, key.NSquared)
   484  
   485  	return key.constructCiphertext(c)
   486  }
   487  
   488  func (key *PubKey) constructCiphertext(ciphertext *big.Int) (*Ciphertext, error) {
   489  	checksum, err := key.bindingCtPubKey(ciphertext.Bytes())
   490  	if err != nil {
   491  		return nil, err
   492  	}
   493  
   494  	ct := &Ciphertext{
   495  		Ct:       ciphertext,
   496  		Checksum: checksum,
   497  	}
   498  
   499  	return ct, nil
   500  }
   501  
   502  func (key *PubKey) bindingCtPubKey(ciphertext []byte) ([]byte, error) {
   503  	pubKeyBytes, err := key.Marshal()
   504  	if ciphertext == nil {
   505  		return nil, ErrInvalidCiphertext
   506  	}
   507  
   508  	if err != nil {
   509  		return nil, err
   510  	}
   511  
   512  	checksum := sha256.Sum256(append(pubKeyBytes, ciphertext[:]...))
   513  	return checksum[:defaultChecksumSize], nil
   514  }
   515  
   516  func (key *PubKey) checkOperand(cts ...*Ciphertext) bool {
   517  	for _, ct := range cts {
   518  		if !key.ChecksumVerify(ct) {
   519  			return false
   520  		}
   521  	}
   522  	return true
   523  }
   524  
   525  // ChecksumVerify verifying public key ciphertext pairs
   526  func (key *PubKey) ChecksumVerify(ct *Ciphertext) bool {
   527  	if err := validatePubKey(key); err != nil {
   528  		return false
   529  	}
   530  
   531  	if err := validateCiphertext(ct); err != nil {
   532  		return false
   533  	}
   534  
   535  	keyBytes, err := key.Marshal()
   536  	if err != nil {
   537  		return false
   538  	}
   539  
   540  	ctBytes, err := ct.GetCtBytes()
   541  	if err != nil {
   542  		return false
   543  	}
   544  
   545  	currentChecksum, err := ct.GetChecksum()
   546  	if err != nil {
   547  		return false
   548  	}
   549  
   550  	checksum := sha256.Sum256(append(keyBytes, ctBytes...))
   551  	return reflect.DeepEqual(checksum[:defaultChecksumSize], currentChecksum)
   552  }