gitee.com/lh-her-team/common@v1.5.1/crypto/paillier/paillier.go (about) 1 package paillier 2 3 import ( 4 "crypto/rand" 5 "crypto/sha256" 6 "errors" 7 "io" 8 "math/big" 9 "reflect" 10 ) 11 12 var ( 13 one = big.NewInt(1) 14 ) 15 16 // ErrMessageTooLong is returned when attempting to encrypt a message which is 17 // too large for the size of the public key. 18 var ErrMessageTooLong = errors.New("paillier: message too long for Paillier public key size") 19 var ErrInvalidCiphertext = errors.New("paillier: invalid ciphertext") 20 var ErrInvalidPlaintext = errors.New("paillier: invalid plaintext") 21 var ErrInvalidPublicKey = errors.New("paillier: invalid public key") 22 var ErrInvalidPrivateKey = errors.New("paillier: invalid private key") 23 var ErrInvalidMismatch = errors.New("paillier: key mismatch") 24 25 // PubKey represents the public part of a Paillier key. 26 type PubKey struct { 27 N *big.Int // modulus 28 G *big.Int // n+1, since p and q are same length 29 NSquared *big.Int 30 } 31 32 // PrvKey represents a Paillier key. 33 type PrvKey struct { 34 *PubKey 35 p *big.Int 36 pp *big.Int 37 pminusone *big.Int 38 q *big.Int 39 qq *big.Int 40 qminusone *big.Int 41 pinvq *big.Int 42 hp *big.Int 43 hq *big.Int 44 n *big.Int 45 } 46 47 type Ciphertext struct { 48 Ct *big.Int 49 Checksum []byte 50 } 51 52 func GenKey() (*PrvKey, error) { 53 return generateKey(rand.Reader, 256) 54 } 55 56 // generateKey generates an Paillier keypair of the given bit size using the 57 // random source random (for example, crypto/rand.Reader). 58 func generateKey(random io.Reader, bits int) (*PrvKey, error) { 59 // First, begin generation of p in the background. 60 var p *big.Int 61 var errChan = make(chan error, 1) 62 go func() { 63 var err error 64 p, err = rand.Prime(random, bits/2) 65 errChan <- err 66 }() 67 // Now, find a prime q in the foreground. 68 q, err := rand.Prime(random, bits/2) 69 if err != nil { 70 return nil, err 71 } 72 // Wait for generation of p to complete successfully. 73 if err := <-errChan; err != nil { 74 return nil, err 75 } 76 n := new(big.Int).Mul(p, q) 77 pp := new(big.Int).Mul(p, p) 78 qq := new(big.Int).Mul(q, q) 79 return &PrvKey{ 80 PubKey: &PubKey{ 81 N: n, 82 NSquared: new(big.Int).Mul(n, n), 83 G: new(big.Int).Add(n, one), // g = n + 1 84 }, 85 p: p, 86 pp: pp, 87 pminusone: new(big.Int).Sub(p, one), 88 q: q, 89 qq: qq, 90 qminusone: new(big.Int).Sub(q, one), 91 pinvq: new(big.Int).ModInverse(p, q), 92 hp: h(p, pp, n), 93 hq: h(q, qq, n), 94 n: n, 95 }, nil 96 } 97 98 // hp hq 99 func h(p *big.Int, pp *big.Int, n *big.Int) *big.Int { 100 gp := new(big.Int).Mod(new(big.Int).Sub(one, n), pp) 101 lp := l(gp, p) 102 hp := new(big.Int).ModInverse(lp, p) 103 return hp 104 } 105 106 func l(u *big.Int, n *big.Int) *big.Int { 107 return new(big.Int).Div(new(big.Int).Sub(u, one), n) 108 } 109 110 // Encrypt encrypts a plain text represented as a byte array. The passed plain 111 // text MUST NOT be larger than the modulus of the passed public key. 112 func (key *PubKey) Encrypt(plainText *big.Int) (*Ciphertext, error) { 113 if err := validatePubKey(key); err != nil { 114 return nil, err 115 } 116 if err := validatePlaintext(plainText); err != nil { 117 return nil, err 118 } 119 plaintext, err := AdjustPlaintextDomain(key, plainText) 120 if err != nil { 121 return nil, err 122 } 123 c, _, err := EncryptAndNonce(key, plaintext) 124 if err != nil { 125 return nil, err 126 } 127 checksum, err := key.bindingCtPubKey(c.Bytes()) 128 ct := &Ciphertext{ 129 Ct: c, 130 Checksum: checksum, 131 } 132 return ct, err 133 } 134 135 // EncryptAndNonce encrypts a plain text represented as a byte array, and in 136 // addition, returns the nonce used during encryption. The passed plain text 137 // MUST NOT be larger than the modulus of the passed public key. 138 func EncryptAndNonce(pubKey *PubKey, plainText *big.Int) (*big.Int, *big.Int, error) { 139 r, err := rand.Int(rand.Reader, pubKey.N) 140 if err != nil { 141 return nil, nil, err 142 } 143 for new(big.Int).GCD(nil, nil, r, pubKey.N).Cmp(one) != 0 { 144 r = new(big.Int).Mod(new(big.Int).Add(r, one), pubKey.N) 145 } 146 c, err := EncryptWithNonce(pubKey, r, plainText) 147 if err != nil { 148 return nil, nil, err 149 } 150 return c, r, nil 151 } 152 153 // EncryptWithNonce encrypts a plain text represented as a byte array using the 154 // provided nonce to perform encryption. The passed plain text MUST NOT be 155 // larger than the modulus of the passed public key. 156 func EncryptWithNonce(pubKey *PubKey, r *big.Int, pt *big.Int) (*big.Int, error) { 157 if pubKey.N.Cmp(pt) < 1 { // N < m 158 return nil, ErrMessageTooLong 159 } 160 // c = g^m * r^n mod n^2 = ((m*n+1) mod n^2) * r^n mod n^2 161 n := pubKey.N 162 c := new(big.Int).Mod( 163 new(big.Int).Mul( 164 new(big.Int).Mod(new(big.Int).Add(one, new(big.Int).Mul(pt, n)), pubKey.NSquared), 165 new(big.Int).Exp(r, n, pubKey.NSquared), 166 ), 167 pubKey.NSquared, 168 ) 169 return c, nil 170 } 171 172 // Decrypt decrypts the passed cipher text. 173 func (key *PrvKey) Decrypt(cipherText *Ciphertext) (*big.Int, error) { 174 if err := validatePrvKey(key); err != nil { 175 return nil, err 176 } 177 if err := validateCiphertext(cipherText); err != nil { 178 return nil, err 179 } 180 if key.NSquared.Cmp(cipherText.Ct) < 1 { // c > n^2 181 return nil, ErrMessageTooLong 182 } 183 cp := new(big.Int).Exp(cipherText.Ct, key.pminusone, key.pp) 184 lp := l(cp, key.p) 185 mp := new(big.Int).Mod(new(big.Int).Mul(lp, key.hp), key.p) 186 cq := new(big.Int).Exp(cipherText.Ct, key.qminusone, key.qq) 187 lq := l(cq, key.q) 188 mqq := new(big.Int).Mul(lq, key.hq) 189 mq := new(big.Int).Mod(mqq, key.q) 190 m := crt(mp, mq, key) 191 plaintext, err := AdjustDecryptedDomain(key.PubKey, m) 192 return plaintext, err 193 } 194 195 func crt(mp *big.Int, mq *big.Int, privKey *PrvKey) *big.Int { 196 u := new(big.Int).Mod(new(big.Int).Mul(new(big.Int).Sub(mq, mp), privKey.pinvq), privKey.q) 197 m := new(big.Int).Add(mp, new(big.Int).Mul(u, privKey.p)) 198 return new(big.Int).Mod(m, privKey.n) 199 } 200 201 func Neg(pk *PubKey, cipher *Ciphertext) (*Ciphertext, error) { 202 cipher.Ct = new(big.Int).ModInverse(cipher.Ct, pk.NSquared) 203 return cipher, nil 204 } 205 206 func (key *PrvKey) GetPubKey() (*PubKey, error) { 207 if err := validatePrvKey(key); err != nil { 208 return nil, err 209 } 210 211 return key.PubKey, nil 212 } 213 214 // Marshal encodes the PubKey as a byte slice. 215 func (key *PubKey) Marshal() ([]byte, error) { 216 if err := validatePubKey(key); err != nil { 217 return nil, err 218 } 219 // public key io 220 return []byte(GetPublicKeyHex(key)), nil 221 } 222 223 // Unmarshal recovers the PubKey from an encoded byte slice. 224 func (key *PubKey) Unmarshal(pubKeyBytes []byte) error { 225 k, err := GetPublicKeyFromHex(string(pubKeyBytes)) 226 if err != nil { 227 return err 228 } 229 key.N = k.N 230 key.NSquared = k.NSquared 231 key.G = k.G 232 return nil 233 } 234 235 func (ct *Ciphertext) Marshal() ([]byte, error) { 236 if err := validateCiphertext(ct); err != nil { 237 return nil, ErrInvalidCiphertext 238 } 239 ctBytes := ct.Ct.Bytes() 240 return append(ct.Checksum, ctBytes...), nil 241 } 242 243 func (ct *Ciphertext) Unmarshal(ctBytes []byte) error { 244 if ctBytes == nil { 245 return ErrInvalidCiphertext 246 } 247 if ct.Ct == nil { 248 ct.Ct = new(big.Int) 249 } 250 ct.Ct.SetBytes(ctBytes[defaultChecksumSize:]) 251 ct.Checksum = ctBytes[:defaultChecksumSize] 252 return nil 253 } 254 255 // Marshal encodes the PrvKey as a byte slice. 256 func (key *PrvKey) Marshal() ([]byte, error) { 257 if err := validatePrvKey(key); err != nil { 258 return nil, err 259 } 260 tempBytes := []byte(GetPrivateKeyHex(key)) 261 return tempBytes, nil 262 } 263 264 // Unmarshal recovers the PrvKey from an encoded byte slice. 265 func (key *PrvKey) Unmarshal(prvKeyBytes []byte) error { 266 if prvKeyBytes == nil { 267 return ErrInvalidPrivateKey 268 } 269 k, err := GetPrivateKeyFromHex(string(prvKeyBytes)) 270 if err != nil { 271 return ErrInvalidPrivateKey 272 } 273 key.PubKey = k.PubKey 274 key.p = k.p 275 key.pp = k.pp 276 key.pminusone = k.pminusone 277 key.q = k.q 278 key.qq = k.qq 279 key.qminusone = k.qminusone 280 key.pinvq = k.pinvq 281 key.hp = k.hp 282 key.hq = k.hq 283 key.n = k.n 284 return nil 285 } 286 287 func (ct *Ciphertext) GetChecksum() ([]byte, error) { 288 if err := validateCiphertext(ct); err != nil { 289 return nil, err 290 } 291 return ct.Checksum, nil 292 } 293 294 func (ct *Ciphertext) GetCtBytes() ([]byte, error) { 295 if err := validateCiphertext(ct); err != nil { 296 return nil, err 297 } 298 return ct.Ct.Bytes(), nil 299 } 300 301 func (ct *Ciphertext) GetCtStr() (string, error) { 302 if err := validateCiphertext(ct); err != nil { 303 return "", err 304 } 305 return ct.Ct.String(), nil 306 } 307 308 // AddCiphertext homomorphically adds together two cipher texts. 309 // To do this we multiply the two cipher texts, upon decryption, the resulting 310 // plain text will be the sum of the corresponding plain texts. 311 func (key *PubKey) AddCiphertext(cipher1, cipher2 *Ciphertext) (*Ciphertext, error) { 312 if err := validatePubKey(key); err != nil { 313 return nil, err 314 } 315 if err := validateCiphertext(cipher1, cipher2); err != nil { 316 return nil, err 317 } 318 if !key.checkOperand(cipher1, cipher2) { 319 return nil, ErrInvalidMismatch 320 } 321 x := cipher1.Ct 322 y := cipher2.Ct 323 // x * y mod n^2 324 c := new(big.Int).Mod( 325 new(big.Int).Mul(x, y), 326 key.NSquared, 327 ) 328 return key.constructCiphertext(c) 329 } 330 331 // AddPlaintext homomorphically adds a passed constant to the encrypted integer 332 // (our cipher text). We do this by multiplying the constant with our 333 // ciphertext. Upon decryption, the resulting plain text will be the sum of 334 // the plaintext integer and the constant. 335 func (key *PubKey) AddPlaintext(cipher *Ciphertext, constant *big.Int) (*Ciphertext, error) { 336 if err := validatePubKey(key); err != nil { 337 return nil, err 338 } 339 if err := validateCiphertext(cipher); err != nil { 340 return nil, err 341 } 342 if err := validatePlaintext(constant); err != nil { 343 return nil, err 344 } 345 if !key.checkOperand(cipher) { 346 return nil, ErrInvalidMismatch 347 } 348 c := cipher.Ct 349 x := constant 350 // c * g ^ x mod n^2 351 c = new(big.Int).Mod( 352 new(big.Int).Mul(c, new(big.Int).Exp(key.G, x, key.NSquared)), 353 key.NSquared, 354 ) 355 return key.constructCiphertext(c) 356 } 357 358 func (key *PubKey) SubCiphertext(cipher1, cipher2 *Ciphertext) (*Ciphertext, error) { 359 if err := validatePubKey(key); err != nil { 360 return nil, err 361 } 362 if err := validateCiphertext(cipher1, cipher2); err != nil { 363 return nil, err 364 } 365 if !key.checkOperand(cipher1, cipher2) { 366 return nil, ErrInvalidMismatch 367 } 368 c1 := cipher1.Ct 369 c2 := cipher2.Ct 370 c2Inversed := new(big.Int).ModInverse(c2, key.NSquared) 371 c := new(big.Int).Mod(new(big.Int).Mul(c1, c2Inversed), key.NSquared) 372 return key.constructCiphertext(c) 373 } 374 375 func (key *PubKey) SubPlaintext(cipher *Ciphertext, constant *big.Int) (*Ciphertext, error) { 376 if err := validatePubKey(key); err != nil { 377 return nil, err 378 } 379 if err := validateCiphertext(cipher); err != nil { 380 return nil, err 381 } 382 if err := validatePlaintext(constant); err != nil { 383 return nil, err 384 } 385 if !key.checkOperand(cipher) { 386 return nil, ErrInvalidMismatch 387 } 388 plain := constant 389 plain = new(big.Int).Mod(new(big.Int).Add(new(big.Int).Mul(plain, key.N), one), key.NSquared) 390 c := cipher.Ct 391 c = new(big.Int).Mod(new(big.Int).Mul(c, new(big.Int).ModInverse(plain, key.NSquared)), key.NSquared) 392 return key.constructCiphertext(c) 393 } 394 395 func (key *PubKey) SubByConstant(pubKey *PubKey, cipher *Ciphertext, constant *big.Int) (*Ciphertext, error) { 396 cipherNeg, err := Neg(pubKey, cipher) 397 if err != nil { 398 return nil, err 399 } 400 return key.AddPlaintext(cipherNeg, constant) 401 } 402 403 // NumMul homomorphically multiplies an encrypted integer (cipher text) by a 404 // constant. We do this by raising our cipher text to the power of the passed 405 // constant. Upon decryption, the resulting plain text will be the product of 406 // the plaintext integer and the constant. 407 func (key *PubKey) NumMul(cipher *Ciphertext, constant *big.Int) (*Ciphertext, error) { 408 if err := validatePubKey(key); err != nil { 409 return nil, err 410 } 411 if err := validateCiphertext(cipher); err != nil { 412 return nil, err 413 } 414 if err := validatePlaintext(constant); err != nil { 415 return nil, err 416 } 417 if !key.checkOperand(cipher) { 418 return nil, ErrInvalidMismatch 419 } 420 c := new(big.Int).Exp(cipher.Ct, constant, key.NSquared) 421 return key.constructCiphertext(c) 422 } 423 424 func (key *PubKey) constructCiphertext(ciphertext *big.Int) (*Ciphertext, error) { 425 checksum, err := key.bindingCtPubKey(ciphertext.Bytes()) 426 if err != nil { 427 return nil, err 428 } 429 ct := &Ciphertext{ 430 Ct: ciphertext, 431 Checksum: checksum, 432 } 433 return ct, nil 434 } 435 436 func (key *PubKey) bindingCtPubKey(ciphertext []byte) ([]byte, error) { 437 pubKeyBytes, err := key.Marshal() 438 if ciphertext == nil { 439 return nil, ErrInvalidCiphertext 440 } 441 if err != nil { 442 return nil, err 443 } 444 checksum := sha256.Sum256(append(pubKeyBytes, ciphertext[:]...)) 445 return checksum[:defaultChecksumSize], nil 446 } 447 448 func (key *PubKey) checkOperand(cts ...*Ciphertext) bool { 449 for _, ct := range cts { 450 if !key.ChecksumVerify(ct) { 451 return false 452 } 453 } 454 return true 455 } 456 457 // ChecksumVerify verifying public key ciphertext pairs 458 func (key *PubKey) ChecksumVerify(ct *Ciphertext) bool { 459 if err := validatePubKey(key); err != nil { 460 return false 461 } 462 if err := validateCiphertext(ct); err != nil { 463 return false 464 } 465 keyBytes, err := key.Marshal() 466 if err != nil { 467 return false 468 } 469 ctBytes, err := ct.GetCtBytes() 470 if err != nil { 471 return false 472 } 473 currentChecksum, err := ct.GetChecksum() 474 if err != nil { 475 return false 476 } 477 checksum := sha256.Sum256(append(keyBytes, ctBytes...)) 478 return reflect.DeepEqual(checksum[:defaultChecksumSize], currentChecksum) 479 }