github.com/dusk-network/dusk-crypto@v0.1.3/rangeproof/rangeproof.go (about) 1 package rangeproof 2 3 import ( 4 "encoding/binary" 5 "fmt" 6 "io" 7 "math/big" 8 9 "github.com/pkg/errors" 10 11 ristretto "github.com/bwesterb/go-ristretto" 12 "github.com/dusk-network/dusk-crypto/rangeproof/fiatshamir" 13 "github.com/dusk-network/dusk-crypto/rangeproof/innerproduct" 14 "github.com/dusk-network/dusk-crypto/rangeproof/pedersen" 15 "github.com/dusk-network/dusk-crypto/rangeproof/vector" 16 ) 17 18 // N is number of bits in range 19 // So amount will be between 0...2^(N-1) 20 const N = 64 21 22 // M is the number of outputs for one bulletproof 23 var M = 1 24 25 // M is the maximum number of values allowed per rangeproof 26 const maxM = 16 27 28 // Proof is the constructed BulletProof 29 type Proof struct { 30 V []pedersen.Commitment // Curve points 32 bytes 31 Blinders []ristretto.Scalar 32 A ristretto.Point // Curve point 32 bytes 33 S ristretto.Point // Curve point 32 bytes 34 T1 ristretto.Point // Curve point 32 bytes 35 T2 ristretto.Point // Curve point 32 bytes 36 37 taux ristretto.Scalar //scalar 38 mu ristretto.Scalar //scalar 39 t ristretto.Scalar 40 41 IPProof *innerproduct.Proof 42 } 43 44 // Prove will take a set of scalars as a parameter and prove that it is [0, 2^N) 45 func Prove(v []ristretto.Scalar, debug bool) (Proof, error) { 46 47 if len(v) < 1 { 48 return Proof{}, errors.New("length of slice v is zero") 49 } 50 51 M = len(v) 52 if M > maxM { 53 return Proof{}, fmt.Errorf("maximum amount of values must be less than %d", maxM) 54 } 55 56 // Pad zero values until we have power of two 57 padAmount := innerproduct.DiffNextPow2(uint32(M)) 58 M = M + int(padAmount) 59 for i := uint32(0); i < padAmount; i++ { 60 var zeroScalar ristretto.Scalar 61 zeroScalar.SetZero() 62 v = append(v, zeroScalar) 63 } 64 65 // commitment to values v 66 Vs := make([]pedersen.Commitment, 0, M) 67 genData := []byte("dusk.BulletProof.vec1") 68 ped := pedersen.New(genData) 69 ped.BaseVector.Compute(uint32((N * M))) 70 71 // Hash for Fiat-Shamir 72 hs := fiatshamir.HashCacher{Cache: []byte{}} 73 74 for _, amount := range v { 75 // compute commmitment to v 76 V := ped.CommitToScalar(amount) 77 78 Vs = append(Vs, V) 79 80 // update Fiat-Shamir 81 hs.Append(V.Value.Bytes()) 82 } 83 84 aLs := make([]ristretto.Scalar, 0, N*M) 85 aRs := make([]ristretto.Scalar, 0, N*M) 86 87 for i := range v { 88 // Compute Bitcommits aL and aR to v 89 BC := BitCommit(v[i].BigInt()) 90 aLs = append(aLs, BC.AL...) 91 aRs = append(aRs, BC.AR...) 92 } 93 94 // Compute A 95 A := computeA(ped, aLs, aRs) 96 97 // // Compute S 98 S, sL, sR := computeS(ped) 99 100 // // update Fiat-Shamir 101 hs.Append(A.Value.Bytes(), S.Value.Bytes()) 102 103 // compute y and z 104 y, z := computeYAndZ(hs) 105 106 // compute polynomial 107 poly, err := computePoly(aLs, aRs, sL, sR, y, z) 108 if err != nil { 109 return Proof{}, errors.Wrap(err, "[Prove] - poly") 110 } 111 112 // Compute T1 and T2 113 T1 := ped.CommitToScalar(poly.t1) 114 T2 := ped.CommitToScalar(poly.t2) 115 116 // update Fiat-Shamir 117 hs.Append(z.Bytes(), T1.Value.Bytes(), T2.Value.Bytes()) 118 119 // compute x 120 x := computeX(hs) 121 // compute taux which is just the polynomial for the blinding factors at a point x 122 taux := computeTaux(x, z, T1.BlindingFactor, T2.BlindingFactor, Vs) 123 // compute mu 124 mu := computeMu(x, A.BlindingFactor, S.BlindingFactor) 125 126 // compute l dot r 127 l, err := poly.computeL(x) 128 if err != nil { 129 return Proof{}, errors.Wrap(err, "[Prove] - l") 130 } 131 r, err := poly.computeR(x) 132 if err != nil { 133 return Proof{}, errors.Wrap(err, "[Prove] - r") 134 } 135 t, err := vector.InnerProduct(l, r) 136 if err != nil { 137 return Proof{}, errors.Wrap(err, "[Prove] - t") 138 } 139 140 // START DEBUG 141 if debug { 142 err := debugProve(x, y, z, v, l, r, aLs, aRs, sL, sR) 143 if err != nil { 144 return Proof{}, errors.Wrap(err, "[Prove] - debugProve") 145 } 146 147 // DEBUG T0 148 testT0, err := debugT0(aLs, aRs, y, z) 149 if err != nil { 150 return Proof{}, errors.Wrap(err, "[Prove] - testT0") 151 152 } 153 if !testT0.Equals(&poly.t0) { 154 return Proof{}, errors.New("[Prove]: Test t0 value does not match the value calculated from the polynomial") 155 } 156 157 polyt0 := poly.computeT0(y, z, v, N, uint32(M)) 158 if !polyt0.Equals(&poly.t0) { 159 return Proof{}, errors.New("[Prove]: t0 value from delta function, does not match the polynomial t0 value(Correct)") 160 } 161 162 tPoly := poly.eval(x) 163 if !t.Equals(&tPoly) { 164 return Proof{}, errors.New("[Prove]: The t value computed from the t-poly, does not match the t value computed from the inner product of l and r") 165 } 166 } 167 // End DEBUG 168 169 // check if any challenge scalars are zero 170 if x.IsNonZeroI() == 0 || y.IsNonZeroI() == 0 || z.IsNonZeroI() == 0 { 171 return Proof{}, errors.New("[Prove] - One of the challenge scalars, x, y, or z was equal to zero. Generate proof again") 172 } 173 174 hs.Append(x.Bytes(), taux.Bytes(), mu.Bytes(), t.Bytes()) 175 176 // calculate inner product proof 177 Q := ristretto.Point{} 178 w := hs.Derive() 179 Q.ScalarMult(&ped.BasePoint, &w) 180 181 var yinv ristretto.Scalar 182 yinv.Inverse(&y) 183 Hpf := vector.ScalarPowers(yinv, uint32(N*M)) 184 185 genData = append(genData, uint8(1)) 186 ped2 := pedersen.New(genData) 187 ped2.BaseVector.Compute(uint32(N * M)) 188 189 H := ped2.BaseVector.Bases 190 G := ped.BaseVector.Bases 191 192 ip, err := innerproduct.Generate(G, H, l, r, Hpf, Q) 193 if err != nil { 194 return Proof{}, errors.Wrap(err, "[Prove] - ipproof") 195 } 196 197 return Proof{ 198 V: Vs, 199 A: A.Value, 200 S: S.Value, 201 T1: T1.Value, 202 T2: T2.Value, 203 t: t, 204 taux: taux, 205 mu: mu, 206 IPProof: ip, 207 }, nil 208 } 209 210 // A = kH + aL*G + aR*H 211 func computeA(ped *pedersen.Pedersen, aLs, aRs []ristretto.Scalar) pedersen.Commitment { 212 213 cA := ped.CommitToVectors(aLs, aRs) 214 215 return cA 216 } 217 218 // S = kH + sL*G + sR * H 219 func computeS(ped *pedersen.Pedersen) (pedersen.Commitment, []ristretto.Scalar, []ristretto.Scalar) { 220 221 sL, sR := make([]ristretto.Scalar, N*M), make([]ristretto.Scalar, N*M) 222 for i := 0; i < N*M; i++ { 223 var randA ristretto.Scalar 224 randA.Rand() 225 sL[i] = randA 226 227 var randB ristretto.Scalar 228 randB.Rand() 229 sR[i] = randB 230 } 231 232 cS := ped.CommitToVectors(sL, sR) 233 234 return cS, sL, sR 235 } 236 237 func computeYAndZ(hs fiatshamir.HashCacher) (ristretto.Scalar, ristretto.Scalar) { 238 239 var y ristretto.Scalar 240 y.Derive(hs.Result()) 241 242 var z ristretto.Scalar 243 z.Derive(y.Bytes()) 244 245 return y, z 246 } 247 248 func computeX(hs fiatshamir.HashCacher) ristretto.Scalar { 249 var x ristretto.Scalar 250 x.Derive(hs.Result()) 251 return x 252 } 253 254 // compute polynomial for blinding factors l61 255 // N.B. tau1 means tau superscript 1 256 // taux = t1Blind * x + t2Blind * x^2 + (sum(z^n+1 * vBlind[n-1])) from n = 1 to n = m 257 func computeTaux(x, z, t1Blind, t2Blind ristretto.Scalar, vBlinds []pedersen.Commitment) ristretto.Scalar { 258 tau1X := t1Blind.Mul(&x, &t1Blind) 259 260 var xsq ristretto.Scalar 261 xsq.Square(&x) 262 263 tau2Xsq := t2Blind.Mul(&xsq, &t2Blind) 264 265 var zN ristretto.Scalar 266 zN.Square(&z) // start at zSq 267 268 var zNBlindSum ristretto.Scalar 269 zNBlindSum.SetZero() 270 271 for i := range vBlinds { 272 zNBlindSum.MulAdd(&zN, &vBlinds[i].BlindingFactor, &zNBlindSum) 273 zN.Mul(&zN, &z) 274 } 275 276 var res ristretto.Scalar 277 res.Add(tau1X, tau2Xsq) 278 res.Add(&res, &zNBlindSum) 279 280 return res 281 } 282 283 // alpha is the blinding factor for A 284 // rho is the blinding factor for S 285 // mu = alpha + rho * x 286 func computeMu(x, alpha, rho ristretto.Scalar) ristretto.Scalar { 287 288 var mu ristretto.Scalar 289 290 mu.MulAdd(&rho, &x, &alpha) 291 292 return mu 293 } 294 295 // computeHprime will take a a slice of points H, with a scalar y 296 // and return a slice of points Hprime, such that Hprime = y^-n * H 297 func computeHprime(H []ristretto.Point, y ristretto.Scalar) []ristretto.Point { 298 Hprimes := make([]ristretto.Point, len(H)) 299 300 var yInv ristretto.Scalar 301 yInv.Inverse(&y) 302 303 invYInt := yInv.BigInt() 304 305 for i, p := range H { 306 // compute y^-i 307 var invYPowInt big.Int 308 invYPowInt.Exp(invYInt, big.NewInt(int64(i)), nil) 309 310 var invY ristretto.Scalar 311 invY.SetBigInt(&invYPowInt) 312 313 var hprime ristretto.Point 314 hprime.ScalarMult(&p, &invY) 315 316 Hprimes[i] = hprime 317 } 318 319 return Hprimes 320 } 321 322 // Verify takes a bullet proof and returns true only if the proof was valid 323 func Verify(p Proof) (bool, error) { 324 325 genData := []byte("dusk.BulletProof.vec1") 326 ped := pedersen.New(genData) 327 ped.BaseVector.Compute(uint32(N * M)) 328 329 genData = append(genData, uint8(1)) 330 331 ped2 := pedersen.New(genData) 332 ped2.BaseVector.Compute(uint32(N * M)) 333 334 G := ped.BaseVector.Bases 335 H := ped2.BaseVector.Bases 336 337 // Reconstruct the challenges 338 hs := fiatshamir.HashCacher{Cache: []byte{}} 339 for _, V := range p.V { 340 hs.Append(V.Value.Bytes()) 341 } 342 343 hs.Append(p.A.Bytes(), p.S.Bytes()) 344 y, z := computeYAndZ(hs) 345 hs.Append(z.Bytes(), p.T1.Bytes(), p.T2.Bytes()) 346 x := computeX(hs) 347 hs.Append(x.Bytes(), p.taux.Bytes(), p.mu.Bytes(), p.t.Bytes()) 348 w := hs.Derive() 349 350 return megacheckWithC(p.IPProof, p.mu, x, y, z, p.t, p.taux, w, p.A, ped.BasePoint, ped.BlindPoint, p.S, p.T1, p.T2, G, H, p.V) 351 } 352 353 func megacheckWithC(ipproof *innerproduct.Proof, mu, x, y, z, t, taux, w ristretto.Scalar, A, G, H, S, T1, T2 ristretto.Point, GVec, HVec []ristretto.Point, V []pedersen.Commitment) (bool, error) { 354 355 var c1, c2, c3, c4, c5, c6, c7, c8, c9, c10, c11 ristretto.Point 356 357 var c ristretto.Scalar 358 c.Rand() 359 360 uSq, uInvSq, s := ipproof.VerifScalars() 361 sInv := make([]ristretto.Scalar, len(s)) 362 copy(sInv, s) 363 364 // reverse s 365 for i, j := 0, len(sInv)-1; i < j; i, j = i+1, j-1 { 366 sInv[i], sInv[j] = sInv[j], sInv[i] 367 } 368 369 // g vector scalars : as + z points : G 370 as := vector.MulScalar(s, ipproof.A) 371 g := vector.AddScalar(as, z) 372 g = vector.MulScalar(g, c) 373 374 c1, err := vector.Exp(g, GVec, len(GVec), 1) 375 if err != nil { 376 return false, err 377 } 378 379 // h vector scalars : y Had (bsInv - zM2N) - z points : H 380 bs := vector.MulScalar(sInv, ipproof.B) 381 zAnd2 := sumZMTwoN(z) 382 h, err := vector.Sub(bs, zAnd2) 383 if err != nil { 384 return false, errors.Wrap(err, "[h1]") 385 } 386 387 var yinv ristretto.Scalar 388 yinv.Inverse(&y) 389 Hpf := vector.ScalarPowers(yinv, uint32(N*M)) 390 391 h, err = vector.Hadamard(h, Hpf) 392 if err != nil { 393 return false, errors.Wrap(err, "[h2]") 394 } 395 h = vector.SubScalar(h, z) 396 h = vector.MulScalar(h, c) 397 398 c2, err = vector.Exp(h, HVec, len(HVec), 1) 399 if err != nil { 400 return false, err 401 } 402 403 // G basepoint gbp : (c * w(ab-t)) + t-D(y,z) point : G 404 delta := computeDelta(y, z, N, uint32(M)) 405 var tMinusDelta ristretto.Scalar 406 tMinusDelta.Sub(&t, &delta) 407 408 var abMinusT ristretto.Scalar 409 abMinusT.Mul(&ipproof.A, &ipproof.B) 410 abMinusT.Sub(&abMinusT, &t) 411 412 var cw ristretto.Scalar 413 cw.Mul(&c, &w) 414 415 var gBP ristretto.Scalar 416 gBP.MulAdd(&cw, &abMinusT, &tMinusDelta) 417 418 c3.ScalarMult(&G, &gBP) 419 420 // H basepoint hbp : c * mu + taux point: H 421 var cmu ristretto.Scalar 422 cmu.Mul(&mu, &c) 423 424 var hBP ristretto.Scalar 425 hBP.Add(&cmu, &taux) 426 427 c4.ScalarMult(&H, &hBP) 428 429 // scalar :c point: A 430 c5.ScalarMult(&A, &c) 431 432 // scalar: cx point : S 433 var cx ristretto.Scalar 434 cx.Mul(&c, &x) 435 c6.ScalarMult(&S, &cx) 436 437 // scalar: uSq challenges points: Lj 438 c7, err = vector.Exp(uSq, ipproof.L, len(ipproof.L), 1) 439 if err != nil { 440 return false, err 441 } 442 c7.PublicScalarMult(&c7, &c) 443 444 // scalar : uInvSq challenges points: Rj 445 c8, err = vector.Exp(uInvSq, ipproof.R, len(ipproof.R), 1) 446 if err != nil { 447 return false, err 448 } 449 c8.PublicScalarMult(&c8, &c) 450 451 // scalar: z_j+2 points: Vj 452 zM := vector.ScalarPowers(z, uint32(M)) 453 var zSq ristretto.Scalar 454 zSq.Square(&z) 455 zM = vector.MulScalar(zM, zSq) 456 c9.SetZero() 457 for i := range zM { 458 var temp ristretto.Point 459 temp.PublicScalarMult(&V[i].Value, &zM[i]) 460 c9.Add(&c9, &temp) 461 } 462 463 // scalar : x point: T1 464 c10.PublicScalarMult(&T1, &x) 465 466 // scalar : xSq point: T2 467 var xSq ristretto.Scalar 468 xSq.Square(&x) 469 c11.PublicScalarMult(&T2, &xSq) 470 471 var sum ristretto.Point 472 sum.SetZero() 473 sum.Add(&c1, &c2) 474 sum.Add(&sum, &c3) 475 sum.Add(&sum, &c4) 476 sum.Sub(&sum, &c5) 477 sum.Sub(&sum, &c6) 478 sum.Sub(&sum, &c7) 479 sum.Sub(&sum, &c8) 480 sum.Sub(&sum, &c9) 481 sum.Sub(&sum, &c10) 482 sum.Sub(&sum, &c11) 483 484 var zero ristretto.Point 485 zero.SetZero() 486 487 ok := zero.Equals(&sum) 488 if !ok { 489 return false, errors.New("megacheck failed") 490 } 491 492 return true, nil 493 } 494 495 // Encode a Proof 496 func (p *Proof) Encode(w io.Writer, includeCommits bool) error { 497 498 if includeCommits { 499 err := pedersen.EncodeCommitments(w, p.V) 500 if err != nil { 501 return err 502 } 503 } 504 505 err := binary.Write(w, binary.BigEndian, p.A.Bytes()) 506 if err != nil { 507 return err 508 } 509 err = binary.Write(w, binary.BigEndian, p.S.Bytes()) 510 if err != nil { 511 return err 512 } 513 err = binary.Write(w, binary.BigEndian, p.T1.Bytes()) 514 if err != nil { 515 return err 516 } 517 err = binary.Write(w, binary.BigEndian, p.T2.Bytes()) 518 if err != nil { 519 return err 520 } 521 err = binary.Write(w, binary.BigEndian, p.taux.Bytes()) 522 if err != nil { 523 return err 524 } 525 err = binary.Write(w, binary.BigEndian, p.mu.Bytes()) 526 if err != nil { 527 return err 528 } 529 err = binary.Write(w, binary.BigEndian, p.t.Bytes()) 530 if err != nil { 531 return err 532 } 533 return p.IPProof.Encode(w) 534 } 535 536 // Decode a Proof 537 func (p *Proof) Decode(r io.Reader, includeCommits bool) error { 538 539 if p == nil { 540 return errors.New("struct is nil") 541 } 542 543 if includeCommits { 544 comms, err := pedersen.DecodeCommitments(r) 545 if err != nil { 546 return err 547 } 548 p.V = comms 549 } 550 551 err := readerToPoint(r, &p.A) 552 if err != nil { 553 return err 554 } 555 err = readerToPoint(r, &p.S) 556 if err != nil { 557 return err 558 } 559 err = readerToPoint(r, &p.T1) 560 if err != nil { 561 return err 562 } 563 err = readerToPoint(r, &p.T2) 564 if err != nil { 565 return err 566 } 567 err = readerToScalar(r, &p.taux) 568 if err != nil { 569 return err 570 } 571 err = readerToScalar(r, &p.mu) 572 if err != nil { 573 return err 574 } 575 err = readerToScalar(r, &p.t) 576 if err != nil { 577 return err 578 } 579 p.IPProof = &innerproduct.Proof{} 580 return p.IPProof.Decode(r) 581 } 582 583 // Equals returns proof equality with commitments 584 func (p *Proof) Equals(other Proof, includeCommits bool) bool { 585 if len(p.V) != len(other.V) && includeCommits { 586 return false 587 } 588 589 for i := range p.V { 590 ok := p.V[i].EqualValue(other.V[i]) 591 if !ok { 592 return ok 593 } 594 } 595 596 ok := p.A.Equals(&other.A) 597 if !ok { 598 return ok 599 } 600 ok = p.S.Equals(&other.S) 601 if !ok { 602 return ok 603 } 604 ok = p.T1.Equals(&other.T1) 605 if !ok { 606 return ok 607 } 608 ok = p.T2.Equals(&other.T2) 609 if !ok { 610 return ok 611 } 612 ok = p.taux.Equals(&other.taux) 613 if !ok { 614 return ok 615 } 616 ok = p.mu.Equals(&other.mu) 617 if !ok { 618 return ok 619 } 620 ok = p.t.Equals(&other.t) 621 if !ok { 622 return ok 623 } 624 return true 625 // return p.IPProof.Equals(*other.IPProof) 626 } 627 628 func readerToPoint(r io.Reader, p *ristretto.Point) error { 629 var x [32]byte 630 err := binary.Read(r, binary.BigEndian, &x) 631 if err != nil { 632 return err 633 } 634 ok := p.SetBytes(&x) 635 if !ok { 636 return errors.New("point not encodable") 637 } 638 return nil 639 } 640 func readerToScalar(r io.Reader, s *ristretto.Scalar) error { 641 var x [32]byte 642 err := binary.Read(r, binary.BigEndian, &x) 643 if err != nil { 644 return err 645 } 646 s.SetBytes(&x) 647 return nil 648 }