github.com/bigzoro/my_simplechain@v0.0.0-20240315012955-8ad0a2a29bb9/core/access_contoller/crypto/paillier/paillier.go (about) 1 /* 2 Copyright (C) BABEC. All rights reserved. 3 Copyright (C) THL A29 Limited, a Tencent company. All rights reserved. 4 5 SPDX-License-Identifier: Apache-2.0 6 */ 7 8 package paillier 9 10 import ( 11 "crypto/rand" 12 "crypto/sha256" 13 "errors" 14 "io" 15 "math/big" 16 "reflect" 17 ) 18 19 var ( 20 one = big.NewInt(1) 21 ) 22 23 // ErrMessageTooLong is returned when attempting to encrypt a message which is 24 // too large for the size of the public key. 25 var ErrMessageTooLong = errors.New("paillier: message too long for Paillier public key size") 26 var ErrInvalidCiphertext = errors.New("paillier: invalid ciphertext") 27 var ErrInvalidPlaintext = errors.New("paillier: invalid plaintext") 28 var ErrInvalidPublicKey = errors.New("paillier: invalid public key") 29 var ErrInvalidPrivateKey = errors.New("paillier: invalid private key") 30 var ErrInvalidMismatch = errors.New("paillier: key mismatch") 31 32 // PubKey represents the public part of a Paillier key. 33 type PubKey struct { 34 N *big.Int // modulus 35 G *big.Int // n+1, since p and q are same length 36 NSquared *big.Int 37 } 38 39 // PrvKey represents a Paillier key. 40 type PrvKey struct { 41 *PubKey 42 p *big.Int 43 pp *big.Int 44 pminusone *big.Int 45 q *big.Int 46 qq *big.Int 47 qminusone *big.Int 48 pinvq *big.Int 49 hp *big.Int 50 hq *big.Int 51 n *big.Int 52 } 53 54 type Ciphertext struct { 55 Ct *big.Int 56 Checksum []byte 57 } 58 59 func GenKey() (*PrvKey, error) { 60 return generateKey(rand.Reader, 256) 61 } 62 63 // generateKey generates an Paillier keypair of the given bit size using the 64 // random source random (for example, crypto/rand.Reader). 65 func generateKey(random io.Reader, bits int) (*PrvKey, error) { 66 // First, begin generation of p in the background. 67 var p *big.Int 68 var errChan = make(chan error, 1) 69 go func() { 70 var err error 71 p, err = rand.Prime(random, bits/2) 72 errChan <- err 73 }() 74 75 // Now, find a prime q in the foreground. 76 q, err := rand.Prime(random, bits/2) 77 if err != nil { 78 return nil, err 79 } 80 81 // Wait for generation of p to complete successfully. 82 if err := <-errChan; err != nil { 83 return nil, err 84 } 85 86 n := new(big.Int).Mul(p, q) 87 pp := new(big.Int).Mul(p, p) 88 qq := new(big.Int).Mul(q, q) 89 90 return &PrvKey{ 91 PubKey: &PubKey{ 92 N: n, 93 NSquared: new(big.Int).Mul(n, n), 94 G: new(big.Int).Add(n, one), // g = n + 1 95 }, 96 p: p, 97 pp: pp, 98 pminusone: new(big.Int).Sub(p, one), 99 q: q, 100 qq: qq, 101 qminusone: new(big.Int).Sub(q, one), 102 pinvq: new(big.Int).ModInverse(p, q), 103 hp: h(p, pp, n), 104 hq: h(q, qq, n), 105 n: n, 106 }, nil 107 108 } 109 110 // hp hq 111 func h(p *big.Int, pp *big.Int, n *big.Int) *big.Int { 112 gp := new(big.Int).Mod(new(big.Int).Sub(one, n), pp) 113 114 lp := l(gp, p) 115 116 hp := new(big.Int).ModInverse(lp, p) 117 return hp 118 } 119 120 func l(u *big.Int, n *big.Int) *big.Int { 121 return new(big.Int).Div(new(big.Int).Sub(u, one), n) 122 } 123 124 // Encrypt encrypts a plain text represented as a byte array. The passed plain 125 // text MUST NOT be larger than the modulus of the passed public key. 126 func (key *PubKey) Encrypt(plainText *big.Int) (*Ciphertext, error) { 127 if err := validatePubKey(key); err != nil { 128 return nil, err 129 } 130 131 if err := validatePlaintext(plainText); err != nil { 132 return nil, err 133 } 134 135 plaintext, err := AdjustPlaintextDomain(key, plainText) 136 if err != nil { 137 return nil, err 138 } 139 c, _, err := EncryptAndNonce(key, plaintext) 140 if err != nil { 141 return nil, err 142 } 143 144 checksum, err := key.bindingCtPubKey(c.Bytes()) 145 146 ct := &Ciphertext{ 147 Ct: c, 148 Checksum: checksum, 149 } 150 return ct, err 151 } 152 153 // EncryptAndNonce encrypts a plain text represented as a byte array, and in 154 // addition, returns the nonce used during encryption. The passed plain text 155 // MUST NOT be larger than the modulus of the passed public key. 156 func EncryptAndNonce(pubKey *PubKey, plainText *big.Int) (*big.Int, *big.Int, error) { 157 r, err := rand.Int(rand.Reader, pubKey.N) 158 if err != nil { 159 return nil, nil, err 160 } 161 for new(big.Int).GCD(nil, nil, r, pubKey.N).Cmp(one) != 0 { 162 r = new(big.Int).Mod(new(big.Int).Add(r, one), pubKey.N) 163 } 164 165 c, err := EncryptWithNonce(pubKey, r, plainText) 166 if err != nil { 167 return nil, nil, err 168 } 169 170 return c, r, nil 171 } 172 173 // EncryptWithNonce encrypts a plain text represented as a byte array using the 174 // provided nonce to perform encryption. The passed plain text MUST NOT be 175 // larger than the modulus of the passed public key. 176 func EncryptWithNonce(pubKey *PubKey, r *big.Int, pt *big.Int) (*big.Int, error) { 177 if pubKey.N.Cmp(pt) < 1 { // N < m 178 return nil, ErrMessageTooLong 179 } 180 181 // c = g^m * r^n mod n^2 = ((m*n+1) mod n^2) * r^n mod n^2 182 n := pubKey.N 183 c := new(big.Int).Mod( 184 new(big.Int).Mul( 185 new(big.Int).Mod(new(big.Int).Add(one, new(big.Int).Mul(pt, n)), pubKey.NSquared), 186 new(big.Int).Exp(r, n, pubKey.NSquared), 187 ), 188 pubKey.NSquared, 189 ) 190 191 return c, nil 192 } 193 194 // Decrypt decrypts the passed cipher text. 195 func (key *PrvKey) Decrypt(cipherText *Ciphertext) (*big.Int, error) { 196 if err := validatePrvKey(key); err != nil { 197 return nil, err 198 } 199 200 if err := validateCiphertext(cipherText); err != nil { 201 return nil, err 202 } 203 204 if key.NSquared.Cmp(cipherText.Ct) < 1 { // c > n^2 205 return nil, ErrMessageTooLong 206 } 207 208 cp := new(big.Int).Exp(cipherText.Ct, key.pminusone, key.pp) 209 lp := l(cp, key.p) 210 mp := new(big.Int).Mod(new(big.Int).Mul(lp, key.hp), key.p) 211 cq := new(big.Int).Exp(cipherText.Ct, key.qminusone, key.qq) 212 lq := l(cq, key.q) 213 214 mqq := new(big.Int).Mul(lq, key.hq) 215 mq := new(big.Int).Mod(mqq, key.q) 216 m := crt(mp, mq, key) 217 218 plaintext, err := AdjustDecryptedDomain(key.PubKey, m) 219 return plaintext, err 220 } 221 222 func crt(mp *big.Int, mq *big.Int, privKey *PrvKey) *big.Int { 223 u := new(big.Int).Mod(new(big.Int).Mul(new(big.Int).Sub(mq, mp), privKey.pinvq), privKey.q) 224 m := new(big.Int).Add(mp, new(big.Int).Mul(u, privKey.p)) 225 return new(big.Int).Mod(m, privKey.n) 226 } 227 228 func Neg(pk *PubKey, cipher *Ciphertext) (*Ciphertext, error) { 229 cipher.Ct = new(big.Int).ModInverse(cipher.Ct, pk.NSquared) 230 return cipher, nil 231 } 232 233 func (key *PrvKey) GetPubKey() (*PubKey, error) { 234 if err := validatePrvKey(key); err != nil { 235 return nil, err 236 } 237 238 return key.PubKey, nil 239 } 240 241 // Marshal encodes the PubKey as a byte slice. 242 func (key *PubKey) Marshal() ([]byte, error) { 243 if err := validatePubKey(key); err != nil { 244 return nil, err 245 } 246 // public key io 247 return []byte(GetPublicKeyHex(key)), nil 248 } 249 250 // Unmarshal recovers the PubKey from an encoded byte slice. 251 func (key *PubKey) Unmarshal(pubKeyBytes []byte) error { 252 k, err := GetPublicKeyFromHex(string(pubKeyBytes)) 253 if err != nil { 254 return err 255 } 256 257 key.N = k.N 258 key.NSquared = k.NSquared 259 key.G = k.G 260 261 return nil 262 } 263 264 func (ct *Ciphertext) Marshal() ([]byte, error) { 265 if err := validateCiphertext(ct); err != nil { 266 return nil, ErrInvalidCiphertext 267 } 268 269 ctBytes := ct.Ct.Bytes() 270 return append(ct.Checksum, ctBytes...), nil 271 } 272 273 func (ct *Ciphertext) Unmarshal(ctBytes []byte) error { 274 if ctBytes == nil { 275 return ErrInvalidCiphertext 276 } 277 278 if ct.Ct == nil { 279 ct.Ct = new(big.Int) 280 } 281 282 ct.Ct.SetBytes(ctBytes[defaultChecksumSize:]) 283 ct.Checksum = ctBytes[:defaultChecksumSize] 284 return nil 285 } 286 287 // Marshal encodes the PrvKey as a byte slice. 288 func (key *PrvKey) Marshal() ([]byte, error) { 289 if err := validatePrvKey(key); err != nil { 290 return nil, err 291 } 292 293 tempBytes := []byte(GetPrivateKeyHex(key)) 294 295 return tempBytes, nil 296 } 297 298 // Unmarshal recovers the PrvKey from an encoded byte slice. 299 func (key *PrvKey) Unmarshal(prvKeyBytes []byte) error { 300 if prvKeyBytes == nil { 301 return ErrInvalidPrivateKey 302 } 303 304 k, err := GetPrivateKeyFromHex(string(prvKeyBytes)) 305 if err != nil { 306 return ErrInvalidPrivateKey 307 } 308 309 key.PubKey = k.PubKey 310 key.p = k.p 311 key.pp = k.pp 312 key.pminusone = k.pminusone 313 key.q = k.q 314 key.qq = k.qq 315 key.qminusone = k.qminusone 316 key.pinvq = k.pinvq 317 key.hp = k.hp 318 key.hq = k.hq 319 key.n = k.n 320 321 return nil 322 } 323 324 func (ct *Ciphertext) GetChecksum() ([]byte, error) { 325 if err := validateCiphertext(ct); err != nil { 326 return nil, err 327 } 328 329 return ct.Checksum, nil 330 } 331 332 func (ct *Ciphertext) GetCtBytes() ([]byte, error) { 333 if err := validateCiphertext(ct); err != nil { 334 return nil, err 335 } 336 337 return ct.Ct.Bytes(), nil 338 } 339 340 func (ct *Ciphertext) GetCtStr() (string, error) { 341 if err := validateCiphertext(ct); err != nil { 342 return "", err 343 } 344 345 return ct.Ct.String(), nil 346 } 347 348 // AddCiphertext homomorphically adds together two cipher texts. 349 // To do this we multiply the two cipher texts, upon decryption, the resulting 350 // plain text will be the sum of the corresponding plain texts. 351 func (key *PubKey) AddCiphertext(cipher1, cipher2 *Ciphertext) (*Ciphertext, error) { 352 if err := validatePubKey(key); err != nil { 353 return nil, err 354 } 355 356 if err := validateCiphertext(cipher1, cipher2); err != nil { 357 return nil, err 358 } 359 360 if !key.checkOperand(cipher1, cipher2) { 361 return nil, ErrInvalidMismatch 362 } 363 364 x := cipher1.Ct 365 y := cipher2.Ct 366 // x * y mod n^2 367 c := new(big.Int).Mod( 368 new(big.Int).Mul(x, y), 369 key.NSquared, 370 ) 371 372 return key.constructCiphertext(c) 373 } 374 375 // AddPlaintext homomorphically adds a passed constant to the encrypted integer 376 // (our cipher text). We do this by multiplying the constant with our 377 // ciphertext. Upon decryption, the resulting plain text will be the sum of 378 // the plaintext integer and the constant. 379 func (key *PubKey) AddPlaintext(cipher *Ciphertext, constant *big.Int) (*Ciphertext, error) { 380 if err := validatePubKey(key); err != nil { 381 return nil, err 382 } 383 384 if err := validateCiphertext(cipher); err != nil { 385 return nil, err 386 } 387 388 if err := validatePlaintext(constant); err != nil { 389 return nil, err 390 } 391 392 if !key.checkOperand(cipher) { 393 return nil, ErrInvalidMismatch 394 } 395 396 c := cipher.Ct 397 x := constant 398 399 // c * g ^ x mod n^2 400 c = new(big.Int).Mod( 401 new(big.Int).Mul(c, new(big.Int).Exp(key.G, x, key.NSquared)), 402 key.NSquared, 403 ) 404 405 return key.constructCiphertext(c) 406 } 407 408 func (key *PubKey) SubCiphertext(cipher1, cipher2 *Ciphertext) (*Ciphertext, error) { 409 if err := validatePubKey(key); err != nil { 410 return nil, err 411 } 412 413 if err := validateCiphertext(cipher1, cipher2); err != nil { 414 return nil, err 415 } 416 417 if !key.checkOperand(cipher1, cipher2) { 418 return nil, ErrInvalidMismatch 419 } 420 421 c1 := cipher1.Ct 422 c2 := cipher2.Ct 423 c2Inversed := new(big.Int).ModInverse(c2, key.NSquared) 424 c := new(big.Int).Mod(new(big.Int).Mul(c1, c2Inversed), key.NSquared) 425 426 return key.constructCiphertext(c) 427 } 428 429 func (key *PubKey) SubPlaintext(cipher *Ciphertext, constant *big.Int) (*Ciphertext, error) { 430 if err := validatePubKey(key); err != nil { 431 return nil, err 432 } 433 434 if err := validateCiphertext(cipher); err != nil { 435 return nil, err 436 } 437 438 if err := validatePlaintext(constant); err != nil { 439 return nil, err 440 } 441 442 if !key.checkOperand(cipher) { 443 return nil, ErrInvalidMismatch 444 } 445 446 plain := constant 447 plain = new(big.Int).Mod(new(big.Int).Add(new(big.Int).Mul(plain, key.N), one), key.NSquared) 448 c := cipher.Ct 449 c = new(big.Int).Mod(new(big.Int).Mul(c, new(big.Int).ModInverse(plain, key.NSquared)), key.NSquared) 450 451 return key.constructCiphertext(c) 452 } 453 454 func (key *PubKey) SubByConstant(pubKey *PubKey, cipher *Ciphertext, constant *big.Int) (*Ciphertext, error) { 455 cipherNeg, err := Neg(pubKey, cipher) 456 if err != nil { 457 return nil, err 458 } 459 return key.AddPlaintext(cipherNeg, constant) 460 } 461 462 // NumMul homomorphically multiplies an encrypted integer (cipher text) by a 463 // constant. We do this by raising our cipher text to the power of the passed 464 // constant. Upon decryption, the resulting plain text will be the product of 465 // the plaintext integer and the constant. 466 func (key *PubKey) NumMul(cipher *Ciphertext, constant *big.Int) (*Ciphertext, error) { 467 if err := validatePubKey(key); err != nil { 468 return nil, err 469 } 470 471 if err := validateCiphertext(cipher); err != nil { 472 return nil, err 473 } 474 475 if err := validatePlaintext(constant); err != nil { 476 return nil, err 477 } 478 479 if !key.checkOperand(cipher) { 480 return nil, ErrInvalidMismatch 481 } 482 483 c := new(big.Int).Exp(cipher.Ct, constant, key.NSquared) 484 485 return key.constructCiphertext(c) 486 } 487 488 func (key *PubKey) constructCiphertext(ciphertext *big.Int) (*Ciphertext, error) { 489 checksum, err := key.bindingCtPubKey(ciphertext.Bytes()) 490 if err != nil { 491 return nil, err 492 } 493 494 ct := &Ciphertext{ 495 Ct: ciphertext, 496 Checksum: checksum, 497 } 498 499 return ct, nil 500 } 501 502 func (key *PubKey) bindingCtPubKey(ciphertext []byte) ([]byte, error) { 503 pubKeyBytes, err := key.Marshal() 504 if ciphertext == nil { 505 return nil, ErrInvalidCiphertext 506 } 507 508 if err != nil { 509 return nil, err 510 } 511 512 checksum := sha256.Sum256(append(pubKeyBytes, ciphertext[:]...)) 513 return checksum[:defaultChecksumSize], nil 514 } 515 516 func (key *PubKey) checkOperand(cts ...*Ciphertext) bool { 517 for _, ct := range cts { 518 if !key.ChecksumVerify(ct) { 519 return false 520 } 521 } 522 return true 523 } 524 525 // ChecksumVerify verifying public key ciphertext pairs 526 func (key *PubKey) ChecksumVerify(ct *Ciphertext) bool { 527 if err := validatePubKey(key); err != nil { 528 return false 529 } 530 531 if err := validateCiphertext(ct); err != nil { 532 return false 533 } 534 535 keyBytes, err := key.Marshal() 536 if err != nil { 537 return false 538 } 539 540 ctBytes, err := ct.GetCtBytes() 541 if err != nil { 542 return false 543 } 544 545 currentChecksum, err := ct.GetChecksum() 546 if err != nil { 547 return false 548 } 549 550 checksum := sha256.Sum256(append(keyBytes, ctBytes...)) 551 return reflect.DeepEqual(checksum[:defaultChecksumSize], currentChecksum) 552 }