github.com/dusk-network/dusk-crypto@v0.1.3/rangeproof/innerproduct/innerproduct.go (about)

     1  package innerproduct
     2  
     3  import (
     4  	"bytes"
     5  	"encoding/binary"
     6  	"errors"
     7  	"io"
     8  	"math/bits"
     9  
    10  	ristretto "github.com/bwesterb/go-ristretto"
    11  	"github.com/dusk-network/dusk-crypto/rangeproof/fiatshamir"
    12  	"github.com/dusk-network/dusk-crypto/rangeproof/vector"
    13  )
    14  
    // This follows the reference inner product proof implementation written in Rust.
    16  
    17  // Proof represents an innner product proof struct
    18  type Proof struct {
    19  	L, R []ristretto.Point
    20  	A, B ristretto.Scalar // a and b are capitalised so that they are exported, in paper it is `a``b`
    21  }
    22  
    23  // Generate generates an inner product proof or an error
    24  // if proof cannot be constucted
    25  func Generate(GVec, HVec []ristretto.Point, aVec, bVec, HprimeFactors []ristretto.Scalar, Q ristretto.Point) (*Proof, error) {
    26  	n := uint32(len(GVec))
    27  
    28  	// XXX : When n is not a power of two, will the bulletproof struct pad it
    29  	// or will the inner product proof struct?
    30  	if !isPower2(uint32(n)) {
    31  		return nil, errors.New("[IPProof]: size of n (NM) is not a power of 2")
    32  	}
    33  
    34  	a := make([]ristretto.Scalar, len(aVec))
    35  	copy(a, aVec)
    36  	b := make([]ristretto.Scalar, len(bVec))
    37  	copy(b, bVec)
    38  	G := make([]ristretto.Point, len(GVec))
    39  	copy(G, GVec)
    40  	H := make([]ristretto.Point, len(HVec))
    41  	copy(H, HVec)
    42  
    43  	hs := fiatshamir.HashCacher{Cache: []byte{}}
    44  
    45  	lgN := bits.TrailingZeros(nextPow2(uint(n)))
    46  
    47  	Lj := make([]ristretto.Point, 0, lgN)
    48  	Rj := make([]ristretto.Point, 0, lgN)
    49  
    50  	if n != 1 {
    51  		n = n / 2
    52  
    53  		aL, aR, err := vector.SplitScalars(a, n)
    54  		bL, bR, err := vector.SplitScalars(b, n)
    55  		GL, GR, err := vector.SplitPoints(G, n)
    56  		HL, HR, err := vector.SplitPoints(H, n)
    57  
    58  		cL, err := vector.InnerProduct(aL, bR)
    59  		if err != nil {
    60  
    61  		}
    62  		cR, err := vector.InnerProduct(aR, bL)
    63  		if err != nil {
    64  			return nil, err
    65  		}
    66  
    67  		// L = aL * GR + bR * HL * HPrime[0..n] + cL * Q = e1 + e2 + e3
    68  
    69  		e1, err := vector.Exp(aL, GR, int(n), 1)
    70  		if err != nil {
    71  			return nil, err
    72  
    73  		}
    74  
    75  		bRYi := make([]ristretto.Scalar, len(bR))
    76  		copy(bRYi, bR)
    77  
    78  		for i := range bRYi {
    79  			bRYi[i].Mul(&bRYi[i], &HprimeFactors[i])
    80  		}
    81  
    82  		e2, err := vector.Exp(bRYi, HL, int(n), 1)
    83  		if err != nil {
    84  			return nil, err
    85  		}
    86  
    87  		var e3 ristretto.Point
    88  		e3.ScalarMult(&Q, &cL)
    89  
    90  		var L ristretto.Point
    91  		L.SetZero()
    92  		L.Add(&e1, &e2)
    93  		L.Add(&L, &e3)
    94  
    95  		Lj = append(Lj, L)
    96  
    97  		// R = aR * GL + bL * HR * HPrimeFactors[n .. 2n] + cR * Q = e4 + e5 + e6
    98  
    99  		e4, err := vector.Exp(aR, GL, int(n), 1)
   100  		if err != nil {
   101  			return nil, err
   102  		}
   103  
   104  		bLYi := make([]ristretto.Scalar, len(bL))
   105  		copy(bLYi, bL)
   106  
   107  		for i := range bLYi {
   108  			bLYi[i].Mul(&bLYi[i], &HprimeFactors[uint32(i)+n])
   109  		}
   110  
   111  		e5, err := vector.Exp(bLYi, HR, int(n), 1)
   112  		if err != nil {
   113  			return nil, err
   114  		}
   115  
   116  		var e6 ristretto.Point
   117  		e6.ScalarMult(&Q, &cR)
   118  
   119  		var R ristretto.Point
   120  		R.SetZero()
   121  		R.Add(&e4, &e5)
   122  		R.Add(&R, &e6)
   123  		Rj = append(Rj, R)
   124  
   125  		hs.Append(L.Bytes(), R.Bytes())
   126  
   127  		u := hs.Derive()
   128  		var uinv ristretto.Scalar
   129  		uinv.Inverse(&u)
   130  
   131  		var a1, a2, b1, b2, h1a, h2a ristretto.Scalar
   132  		var g1, g2, h1, h2 ristretto.Point
   133  
   134  		for i := uint32(0); i < n; i++ {
   135  
   136  			a1.Mul(&aL[i], &u)
   137  			a2.Mul(&aR[i], &uinv)
   138  			aL[i].Add(&a1, &a2)
   139  
   140  			b1.Mul(&bL[i], &uinv)
   141  			b2.Mul(&bR[i], &u)
   142  			bL[i].Add(&b1, &b2)
   143  
   144  			g1.ScalarMult(&GL[i], &uinv)
   145  			g2.ScalarMult(&GR[i], &u)
   146  			GL[i].Add(&g1, &g2)
   147  
   148  			h1a.Mul(&HprimeFactors[i], &u)
   149  			h1.ScalarMult(&HL[i], &h1a)
   150  			h2a.Mul(&HprimeFactors[i+n], &uinv)
   151  			h2.ScalarMult(&HR[i], &h2a)
   152  			HL[i].Add(&h1, &h2)
   153  		}
   154  
   155  		a = aL
   156  		b = bL
   157  		G = GL
   158  		H = HL
   159  	}
   160  
   161  	for n > 1 {
   162  
   163  		n = n / 2
   164  
   165  		aL, aR, err := vector.SplitScalars(a, n)
   166  		bL, bR, err := vector.SplitScalars(b, n)
   167  		GL, GR, err := vector.SplitPoints(G, n)
   168  		HL, HR, err := vector.SplitPoints(H, n)
   169  
   170  		cL, err := vector.InnerProduct(aL, bR)
   171  		if err != nil {
   172  			return nil, err
   173  		}
   174  		cR, err := vector.InnerProduct(aR, bL)
   175  		if err != nil {
   176  			return nil, err
   177  		}
   178  
   179  		// L = aL * GR + bR * HL + cL * Q = e1 + e2 + e3
   180  
   181  		e1, err := vector.Exp(aL, GR, int(n), 1)
   182  		if err != nil {
   183  			return nil, err
   184  		}
   185  		e2, err := vector.Exp(bR, HL, int(n), 1)
   186  		if err != nil {
   187  			return nil, err
   188  		}
   189  		var e3 ristretto.Point
   190  		e3.ScalarMult(&Q, &cL)
   191  
   192  		var L ristretto.Point
   193  		L.SetZero()
   194  		L.Add(&e1, &e2)
   195  		L.Add(&L, &e3)
   196  
   197  		Lj = append(Lj, L)
   198  
   199  		// R = aR * GL + bL * HR + cR * Q = e4 + e5 + e6
   200  
   201  		e4, err := vector.Exp(aR, GL, int(n), 1)
   202  		if err != nil {
   203  			return nil, err
   204  		}
   205  		e5, err := vector.Exp(bL, HR, int(n), 1)
   206  		if err != nil {
   207  			return nil, err
   208  		}
   209  		var e6 ristretto.Point
   210  		e6.ScalarMult(&Q, &cR)
   211  
   212  		var R ristretto.Point
   213  		R.SetZero()
   214  		R.Add(&e4, &e5)
   215  		R.Add(&R, &e6)
   216  		Rj = append(Rj, R)
   217  
   218  		hs.Append(L.Bytes(), R.Bytes())
   219  
   220  		u := hs.Derive()
   221  		var uinv ristretto.Scalar
   222  		uinv.Inverse(&u)
   223  
   224  		// aL = aL * u + aR *uinv = a1 + a2 - aprime
   225  		// bL = bR * u + bL *uinv = b1 + b2 - bprime
   226  		// GL = GL * uinv + GR * u = g1 + g2 - gprime
   227  		// HL = HL * u + HR * uinv = h1 + h2 - hprime
   228  
   229  		var a1, a2, b1, b2 ristretto.Scalar
   230  		var g1, g2, h1, h2 ristretto.Point
   231  
   232  		for i := uint32(0); i < n; i++ {
   233  
   234  			a1.Mul(&aL[i], &u)
   235  			a2.Mul(&aR[i], &uinv)
   236  			aL[i].Add(&a1, &a2)
   237  
   238  			b1.Mul(&bL[i], &uinv)
   239  			b2.Mul(&bR[i], &u)
   240  			bL[i].Add(&b1, &b2)
   241  
   242  			g1.ScalarMult(&GL[i], &uinv)
   243  			g2.ScalarMult(&GR[i], &u)
   244  			GL[i].Add(&g1, &g2)
   245  
   246  			h1.ScalarMult(&HL[i], &u)
   247  			h2.ScalarMult(&HR[i], &uinv)
   248  			HL[i].Add(&h1, &h2)
   249  		}
   250  
   251  		a = aL
   252  		b = bL
   253  		G = GL
   254  		H = HL
   255  	}
   256  
   257  	return &Proof{
   258  		L: Lj,
   259  		R: Rj,
   260  		A: a[len(a)-1],
   261  		B: b[len(b)-1],
   262  	}, nil
   263  }
   264  
   265  // VerifScalars generates the challenge squared, the inverse challenge squared
   266  // and s for a given inner product proof
   267  func (proof *Proof) VerifScalars() ([]ristretto.Scalar, []ristretto.Scalar, []ristretto.Scalar) {
   268  	// generate scalars for verification
   269  
   270  	if len(proof.L) != len(proof.R) {
   271  		return nil, nil, nil
   272  	}
   273  
   274  	lgN := len(proof.L)
   275  	n := uint32(1 << uint(lgN))
   276  
   277  	hs := fiatshamir.HashCacher{Cache: []byte{}}
   278  
   279  	// 1. compute x's
   280  	xChals := make([]ristretto.Scalar, 0, lgN)
   281  	for k := range proof.L {
   282  		hs.Append(proof.L[k].Bytes(), proof.R[k].Bytes())
   283  		xChals = append(xChals, hs.Derive())
   284  	}
   285  
   286  	// 2. compute inverse of x's
   287  	invXChals := make([]ristretto.Scalar, 0, lgN)
   288  
   289  	var invProd ristretto.Scalar // this will be the product of all of the inverses
   290  	invProd.SetOne()
   291  
   292  	for k := range xChals {
   293  
   294  		var xinv ristretto.Scalar
   295  		xinv.Inverse(&xChals[k])
   296  
   297  		invProd.Mul(&invProd, &xinv)
   298  
   299  		invXChals = append(invXChals, xinv)
   300  	}
   301  
   302  	// 3. compute x^2 and inv(x)^2
   303  	chalSq := make([]ristretto.Scalar, 0, lgN)
   304  	invChalSq := make([]ristretto.Scalar, 0, lgN)
   305  
   306  	for k := range xChals {
   307  		var sq ristretto.Scalar
   308  		var invSq ristretto.Scalar
   309  
   310  		sq.Square(&xChals[k])
   311  		invSq.Square(&invXChals[k])
   312  
   313  		chalSq = append(chalSq, sq)
   314  		invChalSq = append(invChalSq, invSq)
   315  	}
   316  
   317  	// 4. compute s
   318  	s := make([]ristretto.Scalar, 0, n)
   319  
   320  	// push the inverse product
   321  	s = append(s, invProd)
   322  
   323  	for i := uint32(1); i < n; i++ {
   324  
   325  		lgI := 32 - 1 - bits.LeadingZeros32(i)
   326  		k := uint32(1 << uint(lgI))
   327  
   328  		uLgISq := chalSq[(lgN-1)-lgI]
   329  
   330  		var sRes ristretto.Scalar
   331  		sRes.Mul(&s[i-k], &uLgISq)
   332  		s = append(s, sRes)
   333  	}
   334  
   335  	return chalSq, invChalSq, s
   336  }
   337  
   338  // Verify is used for unit tests and verifies that a given proof evaluates to the point P
   339  func (proof *Proof) Verify(G, H, L, R []ristretto.Point, HprimeFactor []ristretto.Scalar, Q, P ristretto.Point, n int) bool {
   340  	uSq, uInvSq, s := proof.VerifScalars()
   341  
   342  	sInv := make([]ristretto.Scalar, len(s))
   343  	copy(sInv, s)
   344  
   345  	// reverse s
   346  	for i, j := 0, len(sInv)-1; i < j; i, j = i+1, j-1 {
   347  		sInv[i], sInv[j] = sInv[j], sInv[i]
   348  	}
   349  
   350  	aTimesS := vector.MulScalar(s, proof.A)
   351  	hTimesbDivS := vector.MulScalar(sInv, proof.B)
   352  	for i, bDivS := range hTimesbDivS {
   353  		hTimesbDivS[i].Mul(&bDivS, &HprimeFactor[i])
   354  	}
   355  
   356  	negUSq := make([]ristretto.Scalar, len(uSq))
   357  	for i := range negUSq {
   358  		negUSq[i].Neg(&uSq[i])
   359  	}
   360  
   361  	negUInvSq := make([]ristretto.Scalar, len(uInvSq))
   362  	for i := range negUInvSq {
   363  		negUInvSq[i].Neg(&uInvSq[i])
   364  	}
   365  
   366  	// Scalars
   367  	scalars := make([]ristretto.Scalar, 0)
   368  
   369  	var baseC ristretto.Scalar
   370  	baseC.Mul(&proof.A, &proof.B)
   371  
   372  	scalars = append(scalars, baseC)
   373  	scalars = append(scalars, aTimesS...)
   374  	scalars = append(scalars, hTimesbDivS...)
   375  	scalars = append(scalars, negUSq...)
   376  	scalars = append(scalars, negUInvSq...)
   377  
   378  	// Points
   379  	points := make([]ristretto.Point, 0)
   380  	points = append(points, Q)
   381  	points = append(points, G...)
   382  	points = append(points, H...)
   383  	points = append(points, proof.L...)
   384  	points = append(points, proof.R...)
   385  
   386  	have, err := vector.Exp(scalars, points, n, 1)
   387  	if err != nil {
   388  		return false
   389  	}
   390  	return have.Equals(&P)
   391  }
   392  
   393  // Encode a Proof
   394  func (proof *Proof) Encode(w io.Writer) error {
   395  
   396  	err := binary.Write(w, binary.BigEndian, proof.A.Bytes())
   397  	if err != nil {
   398  		return err
   399  	}
   400  	err = binary.Write(w, binary.BigEndian, proof.B.Bytes())
   401  	if err != nil {
   402  		return err
   403  	}
   404  	lenL := uint32(len(proof.L))
   405  
   406  	for i := uint32(0); i < lenL; i++ {
   407  		err = binary.Write(w, binary.BigEndian, proof.L[i].Bytes())
   408  		if err != nil {
   409  			return err
   410  		}
   411  		err = binary.Write(w, binary.BigEndian, proof.R[i].Bytes())
   412  		if err != nil {
   413  			return err
   414  		}
   415  	}
   416  	return nil
   417  }
   418  
   419  // Decode a Proof
   420  func (proof *Proof) Decode(r io.Reader) error {
   421  	if proof == nil {
   422  		return errors.New("struct is nil")
   423  	}
   424  
   425  	var ABytes, BBytes [32]byte
   426  	err := binary.Read(r, binary.BigEndian, &ABytes)
   427  	if err != nil {
   428  		return err
   429  	}
   430  	err = binary.Read(r, binary.BigEndian, &BBytes)
   431  	if err != nil {
   432  		return err
   433  	}
   434  	proof.A.SetBytes(&ABytes)
   435  	proof.B.SetBytes(&BBytes)
   436  
   437  	buf := &bytes.Buffer{}
   438  	_, err = buf.ReadFrom(r)
   439  	if err != nil {
   440  		return err
   441  	}
   442  	numBytes := len(buf.Bytes())
   443  	if numBytes%32 != 0 {
   444  		return errors.New("proof was not formatted correctly")
   445  	}
   446  	lenL := uint32(numBytes / 64)
   447  
   448  	proof.L = make([]ristretto.Point, lenL)
   449  	proof.R = make([]ristretto.Point, lenL)
   450  
   451  	for i := uint32(0); i < lenL; i++ {
   452  		var LBytes, RBytes [32]byte
   453  		err = binary.Read(buf, binary.BigEndian, &LBytes)
   454  		if err != nil {
   455  			return err
   456  		}
   457  		err = binary.Read(buf, binary.BigEndian, &RBytes)
   458  		if err != nil {
   459  			return err
   460  		}
   461  		proof.L[i].SetBytes(&LBytes)
   462  		proof.R[i].SetBytes(&RBytes)
   463  	}
   464  
   465  	return nil
   466  }
   467  
   468  // Equals test another proof for equality
   469  func (proof *Proof) Equals(other Proof) bool {
   470  	if ok := proof.A.Equals(&other.A); !ok {
   471  		return false
   472  	}
   473  
   474  	if ok := proof.B.Equals(&other.B); !ok {
   475  		return false
   476  	}
   477  
   478  	for i := range proof.L {
   479  		if ok := proof.L[i].Equals(&other.L[i]); !ok {
   480  			return false
   481  		}
   482  
   483  		if ok := proof.R[i].Equals(&other.R[i]); !ok {
   484  			return false
   485  		}
   486  	}
   487  
   488  	return true
   489  }
   490  
   // nextPow2 ORs n-1 with its right shifts by 1, 2, 4, 8 and 16 bits. When
// n-1 fits in 32 bits, this sets every bit below its highest set bit.
// Despite the name, it omits the final +1 of the usual round-up idiom: for
// 1 <= n <= 1<<32 the result is (smallest power of two >= n) - 1.
// For example 1 -> 0, 5 -> 7 and 8 -> 7. n == 0 wraps around to ^uint(0).
func nextPow2(n uint) uint {
	v := n - 1
	for shift := uint(1); shift <= 16; shift <<= 1 {
		v |= v >> shift
	}
	return v
}

// isPower2 reports whether n is a power of two (1, 2, 4, ...).
// Zero is rejected explicitly because n&(n-1) == 0 also holds for it.
func isPower2(n uint32) bool {
	return n != 0 && n&(n-1) == 0
}

// DiffNextPow2 returns the padding needed to grow n to the smallest power of
// two that is >= n. The result is 0 when n is already a power of two. The
// arithmetic is modulo 2^32, so n == 0 yields 0.
func DiffNextPow2(n uint32) uint32 {
	// nextPow2 yields the next power of two minus one, hence the +1.
	return uint32(nextPow2(uint(n))) + 1 - n
}