github.com/dusk-network/dusk-crypto@v0.1.3/rangeproof/rangeproof.go (about)

     1  package rangeproof
     2  
     3  import (
     4  	"encoding/binary"
     5  	"fmt"
     6  	"io"
     7  	"math/big"
     8  
     9  	"github.com/pkg/errors"
    10  
    11  	ristretto "github.com/bwesterb/go-ristretto"
    12  	"github.com/dusk-network/dusk-crypto/rangeproof/fiatshamir"
    13  	"github.com/dusk-network/dusk-crypto/rangeproof/innerproduct"
    14  	"github.com/dusk-network/dusk-crypto/rangeproof/pedersen"
    15  	"github.com/dusk-network/dusk-crypto/rangeproof/vector"
    16  )
    17  
    18  // N is number of bits in range
    19  // So amount will be between 0...2^(N-1)
    20  const N = 64
    21  
    22  // M is the number of outputs for one bulletproof
    23  var M = 1
    24  
    25  // M is the maximum number of values allowed per rangeproof
    26  const maxM = 16
    27  
    28  // Proof is the constructed BulletProof
    29  type Proof struct {
    30  	V        []pedersen.Commitment // Curve points 32 bytes
    31  	Blinders []ristretto.Scalar
    32  	A        ristretto.Point // Curve point 32 bytes
    33  	S        ristretto.Point // Curve point 32 bytes
    34  	T1       ristretto.Point // Curve point 32 bytes
    35  	T2       ristretto.Point // Curve point 32 bytes
    36  
    37  	taux ristretto.Scalar //scalar
    38  	mu   ristretto.Scalar //scalar
    39  	t    ristretto.Scalar
    40  
    41  	IPProof *innerproduct.Proof
    42  }
    43  
    44  // Prove will take a set of scalars as a parameter and prove that it is [0, 2^N)
    45  func Prove(v []ristretto.Scalar, debug bool) (Proof, error) {
    46  
    47  	if len(v) < 1 {
    48  		return Proof{}, errors.New("length of slice v is zero")
    49  	}
    50  
    51  	M = len(v)
    52  	if M > maxM {
    53  		return Proof{}, fmt.Errorf("maximum amount of values must be less than %d", maxM)
    54  	}
    55  
    56  	// Pad zero values until we have power of two
    57  	padAmount := innerproduct.DiffNextPow2(uint32(M))
    58  	M = M + int(padAmount)
    59  	for i := uint32(0); i < padAmount; i++ {
    60  		var zeroScalar ristretto.Scalar
    61  		zeroScalar.SetZero()
    62  		v = append(v, zeroScalar)
    63  	}
    64  
    65  	// commitment to values v
    66  	Vs := make([]pedersen.Commitment, 0, M)
    67  	genData := []byte("dusk.BulletProof.vec1")
    68  	ped := pedersen.New(genData)
    69  	ped.BaseVector.Compute(uint32((N * M)))
    70  
    71  	// Hash for Fiat-Shamir
    72  	hs := fiatshamir.HashCacher{Cache: []byte{}}
    73  
    74  	for _, amount := range v {
    75  		// compute commmitment to v
    76  		V := ped.CommitToScalar(amount)
    77  
    78  		Vs = append(Vs, V)
    79  
    80  		// update Fiat-Shamir
    81  		hs.Append(V.Value.Bytes())
    82  	}
    83  
    84  	aLs := make([]ristretto.Scalar, 0, N*M)
    85  	aRs := make([]ristretto.Scalar, 0, N*M)
    86  
    87  	for i := range v {
    88  		// Compute Bitcommits aL and aR to v
    89  		BC := BitCommit(v[i].BigInt())
    90  		aLs = append(aLs, BC.AL...)
    91  		aRs = append(aRs, BC.AR...)
    92  	}
    93  
    94  	// Compute A
    95  	A := computeA(ped, aLs, aRs)
    96  
    97  	// // Compute S
    98  	S, sL, sR := computeS(ped)
    99  
   100  	// // update Fiat-Shamir
   101  	hs.Append(A.Value.Bytes(), S.Value.Bytes())
   102  
   103  	// compute y and z
   104  	y, z := computeYAndZ(hs)
   105  
   106  	// compute polynomial
   107  	poly, err := computePoly(aLs, aRs, sL, sR, y, z)
   108  	if err != nil {
   109  		return Proof{}, errors.Wrap(err, "[Prove] - poly")
   110  	}
   111  
   112  	// Compute T1 and T2
   113  	T1 := ped.CommitToScalar(poly.t1)
   114  	T2 := ped.CommitToScalar(poly.t2)
   115  
   116  	// update Fiat-Shamir
   117  	hs.Append(z.Bytes(), T1.Value.Bytes(), T2.Value.Bytes())
   118  
   119  	// compute x
   120  	x := computeX(hs)
   121  	// compute taux which is just the polynomial for the blinding factors at a point x
   122  	taux := computeTaux(x, z, T1.BlindingFactor, T2.BlindingFactor, Vs)
   123  	// compute mu
   124  	mu := computeMu(x, A.BlindingFactor, S.BlindingFactor)
   125  
   126  	// compute l dot r
   127  	l, err := poly.computeL(x)
   128  	if err != nil {
   129  		return Proof{}, errors.Wrap(err, "[Prove] - l")
   130  	}
   131  	r, err := poly.computeR(x)
   132  	if err != nil {
   133  		return Proof{}, errors.Wrap(err, "[Prove] - r")
   134  	}
   135  	t, err := vector.InnerProduct(l, r)
   136  	if err != nil {
   137  		return Proof{}, errors.Wrap(err, "[Prove] - t")
   138  	}
   139  
   140  	// START DEBUG
   141  	if debug {
   142  		err := debugProve(x, y, z, v, l, r, aLs, aRs, sL, sR)
   143  		if err != nil {
   144  			return Proof{}, errors.Wrap(err, "[Prove] - debugProve")
   145  		}
   146  
   147  		// DEBUG T0
   148  		testT0, err := debugT0(aLs, aRs, y, z)
   149  		if err != nil {
   150  			return Proof{}, errors.Wrap(err, "[Prove] - testT0")
   151  
   152  		}
   153  		if !testT0.Equals(&poly.t0) {
   154  			return Proof{}, errors.New("[Prove]: Test t0 value does not match the value calculated from the polynomial")
   155  		}
   156  
   157  		polyt0 := poly.computeT0(y, z, v, N, uint32(M))
   158  		if !polyt0.Equals(&poly.t0) {
   159  			return Proof{}, errors.New("[Prove]: t0 value from delta function, does not match the polynomial t0 value(Correct)")
   160  		}
   161  
   162  		tPoly := poly.eval(x)
   163  		if !t.Equals(&tPoly) {
   164  			return Proof{}, errors.New("[Prove]: The t value computed from the t-poly, does not match the t value computed from the inner product of l and r")
   165  		}
   166  	}
   167  	// End DEBUG
   168  
   169  	// check if any challenge scalars are zero
   170  	if x.IsNonZeroI() == 0 || y.IsNonZeroI() == 0 || z.IsNonZeroI() == 0 {
   171  		return Proof{}, errors.New("[Prove] - One of the challenge scalars, x, y, or z was equal to zero. Generate proof again")
   172  	}
   173  
   174  	hs.Append(x.Bytes(), taux.Bytes(), mu.Bytes(), t.Bytes())
   175  
   176  	// calculate inner product proof
   177  	Q := ristretto.Point{}
   178  	w := hs.Derive()
   179  	Q.ScalarMult(&ped.BasePoint, &w)
   180  
   181  	var yinv ristretto.Scalar
   182  	yinv.Inverse(&y)
   183  	Hpf := vector.ScalarPowers(yinv, uint32(N*M))
   184  
   185  	genData = append(genData, uint8(1))
   186  	ped2 := pedersen.New(genData)
   187  	ped2.BaseVector.Compute(uint32(N * M))
   188  
   189  	H := ped2.BaseVector.Bases
   190  	G := ped.BaseVector.Bases
   191  
   192  	ip, err := innerproduct.Generate(G, H, l, r, Hpf, Q)
   193  	if err != nil {
   194  		return Proof{}, errors.Wrap(err, "[Prove] -  ipproof")
   195  	}
   196  
   197  	return Proof{
   198  		V:       Vs,
   199  		A:       A.Value,
   200  		S:       S.Value,
   201  		T1:      T1.Value,
   202  		T2:      T2.Value,
   203  		t:       t,
   204  		taux:    taux,
   205  		mu:      mu,
   206  		IPProof: ip,
   207  	}, nil
   208  }
   209  
   210  // A = kH + aL*G + aR*H
   211  func computeA(ped *pedersen.Pedersen, aLs, aRs []ristretto.Scalar) pedersen.Commitment {
   212  
   213  	cA := ped.CommitToVectors(aLs, aRs)
   214  
   215  	return cA
   216  }
   217  
   218  // S = kH + sL*G + sR * H
   219  func computeS(ped *pedersen.Pedersen) (pedersen.Commitment, []ristretto.Scalar, []ristretto.Scalar) {
   220  
   221  	sL, sR := make([]ristretto.Scalar, N*M), make([]ristretto.Scalar, N*M)
   222  	for i := 0; i < N*M; i++ {
   223  		var randA ristretto.Scalar
   224  		randA.Rand()
   225  		sL[i] = randA
   226  
   227  		var randB ristretto.Scalar
   228  		randB.Rand()
   229  		sR[i] = randB
   230  	}
   231  
   232  	cS := ped.CommitToVectors(sL, sR)
   233  
   234  	return cS, sL, sR
   235  }
   236  
   237  func computeYAndZ(hs fiatshamir.HashCacher) (ristretto.Scalar, ristretto.Scalar) {
   238  
   239  	var y ristretto.Scalar
   240  	y.Derive(hs.Result())
   241  
   242  	var z ristretto.Scalar
   243  	z.Derive(y.Bytes())
   244  
   245  	return y, z
   246  }
   247  
   248  func computeX(hs fiatshamir.HashCacher) ristretto.Scalar {
   249  	var x ristretto.Scalar
   250  	x.Derive(hs.Result())
   251  	return x
   252  }
   253  
   254  // compute polynomial for blinding factors l61
   255  // N.B. tau1 means tau superscript 1
   256  // taux = t1Blind * x + t2Blind * x^2 + (sum(z^n+1 * vBlind[n-1])) from n = 1 to n = m
   257  func computeTaux(x, z, t1Blind, t2Blind ristretto.Scalar, vBlinds []pedersen.Commitment) ristretto.Scalar {
   258  	tau1X := t1Blind.Mul(&x, &t1Blind)
   259  
   260  	var xsq ristretto.Scalar
   261  	xsq.Square(&x)
   262  
   263  	tau2Xsq := t2Blind.Mul(&xsq, &t2Blind)
   264  
   265  	var zN ristretto.Scalar
   266  	zN.Square(&z) // start at zSq
   267  
   268  	var zNBlindSum ristretto.Scalar
   269  	zNBlindSum.SetZero()
   270  
   271  	for i := range vBlinds {
   272  		zNBlindSum.MulAdd(&zN, &vBlinds[i].BlindingFactor, &zNBlindSum)
   273  		zN.Mul(&zN, &z)
   274  	}
   275  
   276  	var res ristretto.Scalar
   277  	res.Add(tau1X, tau2Xsq)
   278  	res.Add(&res, &zNBlindSum)
   279  
   280  	return res
   281  }
   282  
   283  // alpha is the blinding factor for A
   284  // rho is the blinding factor for S
   285  // mu = alpha + rho * x
   286  func computeMu(x, alpha, rho ristretto.Scalar) ristretto.Scalar {
   287  
   288  	var mu ristretto.Scalar
   289  
   290  	mu.MulAdd(&rho, &x, &alpha)
   291  
   292  	return mu
   293  }
   294  
   295  // computeHprime will take a a slice of points H, with a scalar y
   296  // and return a slice of points Hprime,  such that Hprime = y^-n * H
   297  func computeHprime(H []ristretto.Point, y ristretto.Scalar) []ristretto.Point {
   298  	Hprimes := make([]ristretto.Point, len(H))
   299  
   300  	var yInv ristretto.Scalar
   301  	yInv.Inverse(&y)
   302  
   303  	invYInt := yInv.BigInt()
   304  
   305  	for i, p := range H {
   306  		// compute y^-i
   307  		var invYPowInt big.Int
   308  		invYPowInt.Exp(invYInt, big.NewInt(int64(i)), nil)
   309  
   310  		var invY ristretto.Scalar
   311  		invY.SetBigInt(&invYPowInt)
   312  
   313  		var hprime ristretto.Point
   314  		hprime.ScalarMult(&p, &invY)
   315  
   316  		Hprimes[i] = hprime
   317  	}
   318  
   319  	return Hprimes
   320  }
   321  
   322  // Verify takes a bullet proof and returns true only if the proof was valid
   323  func Verify(p Proof) (bool, error) {
   324  
   325  	genData := []byte("dusk.BulletProof.vec1")
   326  	ped := pedersen.New(genData)
   327  	ped.BaseVector.Compute(uint32(N * M))
   328  
   329  	genData = append(genData, uint8(1))
   330  
   331  	ped2 := pedersen.New(genData)
   332  	ped2.BaseVector.Compute(uint32(N * M))
   333  
   334  	G := ped.BaseVector.Bases
   335  	H := ped2.BaseVector.Bases
   336  
   337  	// Reconstruct the challenges
   338  	hs := fiatshamir.HashCacher{Cache: []byte{}}
   339  	for _, V := range p.V {
   340  		hs.Append(V.Value.Bytes())
   341  	}
   342  
   343  	hs.Append(p.A.Bytes(), p.S.Bytes())
   344  	y, z := computeYAndZ(hs)
   345  	hs.Append(z.Bytes(), p.T1.Bytes(), p.T2.Bytes())
   346  	x := computeX(hs)
   347  	hs.Append(x.Bytes(), p.taux.Bytes(), p.mu.Bytes(), p.t.Bytes())
   348  	w := hs.Derive()
   349  
   350  	return megacheckWithC(p.IPProof, p.mu, x, y, z, p.t, p.taux, w, p.A, ped.BasePoint, ped.BlindPoint, p.S, p.T1, p.T2, G, H, p.V)
   351  }
   352  
   353  func megacheckWithC(ipproof *innerproduct.Proof, mu, x, y, z, t, taux, w ristretto.Scalar, A, G, H, S, T1, T2 ristretto.Point, GVec, HVec []ristretto.Point, V []pedersen.Commitment) (bool, error) {
   354  
   355  	var c1, c2, c3, c4, c5, c6, c7, c8, c9, c10, c11 ristretto.Point
   356  
   357  	var c ristretto.Scalar
   358  	c.Rand()
   359  
   360  	uSq, uInvSq, s := ipproof.VerifScalars()
   361  	sInv := make([]ristretto.Scalar, len(s))
   362  	copy(sInv, s)
   363  
   364  	// reverse s
   365  	for i, j := 0, len(sInv)-1; i < j; i, j = i+1, j-1 {
   366  		sInv[i], sInv[j] = sInv[j], sInv[i]
   367  	}
   368  
   369  	// g vector scalars : as + z points : G
   370  	as := vector.MulScalar(s, ipproof.A)
   371  	g := vector.AddScalar(as, z)
   372  	g = vector.MulScalar(g, c)
   373  
   374  	c1, err := vector.Exp(g, GVec, len(GVec), 1)
   375  	if err != nil {
   376  		return false, err
   377  	}
   378  
   379  	// h vector scalars : y Had (bsInv - zM2N) - z points : H
   380  	bs := vector.MulScalar(sInv, ipproof.B)
   381  	zAnd2 := sumZMTwoN(z)
   382  	h, err := vector.Sub(bs, zAnd2)
   383  	if err != nil {
   384  		return false, errors.Wrap(err, "[h1]")
   385  	}
   386  
   387  	var yinv ristretto.Scalar
   388  	yinv.Inverse(&y)
   389  	Hpf := vector.ScalarPowers(yinv, uint32(N*M))
   390  
   391  	h, err = vector.Hadamard(h, Hpf)
   392  	if err != nil {
   393  		return false, errors.Wrap(err, "[h2]")
   394  	}
   395  	h = vector.SubScalar(h, z)
   396  	h = vector.MulScalar(h, c)
   397  
   398  	c2, err = vector.Exp(h, HVec, len(HVec), 1)
   399  	if err != nil {
   400  		return false, err
   401  	}
   402  
   403  	// G basepoint gbp : (c * w(ab-t)) + t-D(y,z) point : G
   404  	delta := computeDelta(y, z, N, uint32(M))
   405  	var tMinusDelta ristretto.Scalar
   406  	tMinusDelta.Sub(&t, &delta)
   407  
   408  	var abMinusT ristretto.Scalar
   409  	abMinusT.Mul(&ipproof.A, &ipproof.B)
   410  	abMinusT.Sub(&abMinusT, &t)
   411  
   412  	var cw ristretto.Scalar
   413  	cw.Mul(&c, &w)
   414  
   415  	var gBP ristretto.Scalar
   416  	gBP.MulAdd(&cw, &abMinusT, &tMinusDelta)
   417  
   418  	c3.ScalarMult(&G, &gBP)
   419  
   420  	// H basepoint hbp : c * mu + taux point: H
   421  	var cmu ristretto.Scalar
   422  	cmu.Mul(&mu, &c)
   423  
   424  	var hBP ristretto.Scalar
   425  	hBP.Add(&cmu, &taux)
   426  
   427  	c4.ScalarMult(&H, &hBP)
   428  
   429  	// scalar :c point: A
   430  	c5.ScalarMult(&A, &c)
   431  
   432  	//  scalar: cx point : S
   433  	var cx ristretto.Scalar
   434  	cx.Mul(&c, &x)
   435  	c6.ScalarMult(&S, &cx)
   436  
   437  	// scalar: uSq challenges  points: Lj
   438  	c7, err = vector.Exp(uSq, ipproof.L, len(ipproof.L), 1)
   439  	if err != nil {
   440  		return false, err
   441  	}
   442  	c7.PublicScalarMult(&c7, &c)
   443  
   444  	// scalar : uInvSq challenges points: Rj
   445  	c8, err = vector.Exp(uInvSq, ipproof.R, len(ipproof.R), 1)
   446  	if err != nil {
   447  		return false, err
   448  	}
   449  	c8.PublicScalarMult(&c8, &c)
   450  
   451  	// scalar: z_j+2  points: Vj
   452  	zM := vector.ScalarPowers(z, uint32(M))
   453  	var zSq ristretto.Scalar
   454  	zSq.Square(&z)
   455  	zM = vector.MulScalar(zM, zSq)
   456  	c9.SetZero()
   457  	for i := range zM {
   458  		var temp ristretto.Point
   459  		temp.PublicScalarMult(&V[i].Value, &zM[i])
   460  		c9.Add(&c9, &temp)
   461  	}
   462  
   463  	// scalar : x point: T1
   464  	c10.PublicScalarMult(&T1, &x)
   465  
   466  	// scalar : xSq point: T2
   467  	var xSq ristretto.Scalar
   468  	xSq.Square(&x)
   469  	c11.PublicScalarMult(&T2, &xSq)
   470  
   471  	var sum ristretto.Point
   472  	sum.SetZero()
   473  	sum.Add(&c1, &c2)
   474  	sum.Add(&sum, &c3)
   475  	sum.Add(&sum, &c4)
   476  	sum.Sub(&sum, &c5)
   477  	sum.Sub(&sum, &c6)
   478  	sum.Sub(&sum, &c7)
   479  	sum.Sub(&sum, &c8)
   480  	sum.Sub(&sum, &c9)
   481  	sum.Sub(&sum, &c10)
   482  	sum.Sub(&sum, &c11)
   483  
   484  	var zero ristretto.Point
   485  	zero.SetZero()
   486  
   487  	ok := zero.Equals(&sum)
   488  	if !ok {
   489  		return false, errors.New("megacheck failed")
   490  	}
   491  
   492  	return true, nil
   493  }
   494  
   495  // Encode a Proof
   496  func (p *Proof) Encode(w io.Writer, includeCommits bool) error {
   497  
   498  	if includeCommits {
   499  		err := pedersen.EncodeCommitments(w, p.V)
   500  		if err != nil {
   501  			return err
   502  		}
   503  	}
   504  
   505  	err := binary.Write(w, binary.BigEndian, p.A.Bytes())
   506  	if err != nil {
   507  		return err
   508  	}
   509  	err = binary.Write(w, binary.BigEndian, p.S.Bytes())
   510  	if err != nil {
   511  		return err
   512  	}
   513  	err = binary.Write(w, binary.BigEndian, p.T1.Bytes())
   514  	if err != nil {
   515  		return err
   516  	}
   517  	err = binary.Write(w, binary.BigEndian, p.T2.Bytes())
   518  	if err != nil {
   519  		return err
   520  	}
   521  	err = binary.Write(w, binary.BigEndian, p.taux.Bytes())
   522  	if err != nil {
   523  		return err
   524  	}
   525  	err = binary.Write(w, binary.BigEndian, p.mu.Bytes())
   526  	if err != nil {
   527  		return err
   528  	}
   529  	err = binary.Write(w, binary.BigEndian, p.t.Bytes())
   530  	if err != nil {
   531  		return err
   532  	}
   533  	return p.IPProof.Encode(w)
   534  }
   535  
   536  // Decode a Proof
   537  func (p *Proof) Decode(r io.Reader, includeCommits bool) error {
   538  
   539  	if p == nil {
   540  		return errors.New("struct is nil")
   541  	}
   542  
   543  	if includeCommits {
   544  		comms, err := pedersen.DecodeCommitments(r)
   545  		if err != nil {
   546  			return err
   547  		}
   548  		p.V = comms
   549  	}
   550  
   551  	err := readerToPoint(r, &p.A)
   552  	if err != nil {
   553  		return err
   554  	}
   555  	err = readerToPoint(r, &p.S)
   556  	if err != nil {
   557  		return err
   558  	}
   559  	err = readerToPoint(r, &p.T1)
   560  	if err != nil {
   561  		return err
   562  	}
   563  	err = readerToPoint(r, &p.T2)
   564  	if err != nil {
   565  		return err
   566  	}
   567  	err = readerToScalar(r, &p.taux)
   568  	if err != nil {
   569  		return err
   570  	}
   571  	err = readerToScalar(r, &p.mu)
   572  	if err != nil {
   573  		return err
   574  	}
   575  	err = readerToScalar(r, &p.t)
   576  	if err != nil {
   577  		return err
   578  	}
   579  	p.IPProof = &innerproduct.Proof{}
   580  	return p.IPProof.Decode(r)
   581  }
   582  
   583  // Equals returns proof equality with commitments
   584  func (p *Proof) Equals(other Proof, includeCommits bool) bool {
   585  	if len(p.V) != len(other.V) && includeCommits {
   586  		return false
   587  	}
   588  
   589  	for i := range p.V {
   590  		ok := p.V[i].EqualValue(other.V[i])
   591  		if !ok {
   592  			return ok
   593  		}
   594  	}
   595  
   596  	ok := p.A.Equals(&other.A)
   597  	if !ok {
   598  		return ok
   599  	}
   600  	ok = p.S.Equals(&other.S)
   601  	if !ok {
   602  		return ok
   603  	}
   604  	ok = p.T1.Equals(&other.T1)
   605  	if !ok {
   606  		return ok
   607  	}
   608  	ok = p.T2.Equals(&other.T2)
   609  	if !ok {
   610  		return ok
   611  	}
   612  	ok = p.taux.Equals(&other.taux)
   613  	if !ok {
   614  		return ok
   615  	}
   616  	ok = p.mu.Equals(&other.mu)
   617  	if !ok {
   618  		return ok
   619  	}
   620  	ok = p.t.Equals(&other.t)
   621  	if !ok {
   622  		return ok
   623  	}
   624  	return true
   625  	// return p.IPProof.Equals(*other.IPProof)
   626  }
   627  
   628  func readerToPoint(r io.Reader, p *ristretto.Point) error {
   629  	var x [32]byte
   630  	err := binary.Read(r, binary.BigEndian, &x)
   631  	if err != nil {
   632  		return err
   633  	}
   634  	ok := p.SetBytes(&x)
   635  	if !ok {
   636  		return errors.New("point not encodable")
   637  	}
   638  	return nil
   639  }
   640  func readerToScalar(r io.Reader, s *ristretto.Scalar) error {
   641  	var x [32]byte
   642  	err := binary.Read(r, binary.BigEndian, &x)
   643  	if err != nil {
   644  		return err
   645  	}
   646  	s.SetBytes(&x)
   647  	return nil
   648  }