github.com/dusk-network/dusk-crypto@v0.1.3/rangeproof/poly.go (about)

     1  package rangeproof
     2  
     3  import (
     4  	"math/big"
     5  
     6  	"github.com/pkg/errors"
     7  
     8  	ristretto "github.com/bwesterb/go-ristretto"
     9  	"github.com/dusk-network/dusk-crypto/rangeproof/vector"
    10  )
    11  
    12  // Polynomial construction
    13  type polynomial struct {
    14  	l0, l1, r0, r1 []ristretto.Scalar
    15  	t0, t1, t2     ristretto.Scalar
    16  }
    17  
    18  func computePoly(aL, aR, sL, sR []ristretto.Scalar, y, z ristretto.Scalar) (*polynomial, error) {
    19  
    20  	// calculate l_0
    21  	l0 := vector.SubScalar(aL, z)
    22  
    23  	// calculate l_1
    24  	l1 := sL
    25  
    26  	// calculate r_0
    27  	yNM := vector.ScalarPowers(y, uint32(N*M))
    28  
    29  	zMTwoN := sumZMTwoN(z)
    30  
    31  	r0 := vector.AddScalar(aR, z)
    32  
    33  	r0, err := vector.Hadamard(r0, yNM)
    34  	if err != nil {
    35  		return nil, errors.Wrap(err, "[ComputePoly] - r0 (1)")
    36  	}
    37  	r0, err = vector.Add(r0, zMTwoN)
    38  	if err != nil {
    39  		return nil, errors.Wrap(err, "[ComputePoly] - r0 (2)")
    40  	}
    41  
    42  	// calculate r_1
    43  	r1, err := vector.Hadamard(yNM, sR)
    44  	if err != nil {
    45  		return nil, errors.Wrap(err, "[ComputePoly] - r1")
    46  	}
    47  
    48  	// calculate t0 // t_0 = <l_0, r_0>
    49  	t0, err := vector.InnerProduct(l0, r0)
    50  	if err != nil {
    51  		return nil, errors.Wrap(err, "[ComputePoly] - t0")
    52  	}
    53  
    54  	// calculate t1 // t_1 = <l_0, r_1> + <l_1, r_0>
    55  	t1Left, err := vector.InnerProduct(l1, r0)
    56  	if err != nil {
    57  		return nil, errors.Wrap(err, "[ComputePoly] - t1Left")
    58  	}
    59  	t1Right, err := vector.InnerProduct(l0, r1)
    60  	if err != nil {
    61  		return nil, errors.Wrap(err, "[ComputePoly] - t1Right")
    62  	}
    63  	var t1 ristretto.Scalar
    64  	t1.Add(&t1Left, &t1Right)
    65  
    66  	// calculate t2 // t_2 = <l_1, r_1>
    67  	t2, err := vector.InnerProduct(l1, r1)
    68  	if err != nil {
    69  		return nil, errors.Wrap(err, "[ComputePoly] - t2")
    70  	}
    71  	return &polynomial{
    72  		l0: l0,
    73  		l1: l1[:],
    74  		r0: r0,
    75  		r1: r1,
    76  		t0: t0,
    77  		t1: t1,
    78  		t2: t2,
    79  	}, nil
    80  }
    81  
    82  // evalute the polynomial with coefficients t
    83  // t = t_0 + t_1 * x + t_2 x^2
    84  func (p *polynomial) eval(x ristretto.Scalar) ristretto.Scalar {
    85  
    86  	var t1x ristretto.Scalar
    87  	t1x.Mul(&x, &p.t1)
    88  
    89  	var xsq ristretto.Scalar
    90  	xsq.Square(&x)
    91  
    92  	var t2xsq ristretto.Scalar
    93  	t2xsq.Mul(&xsq, &p.t2)
    94  
    95  	var t ristretto.Scalar
    96  	t.Add(&t1x, &t2xsq)
    97  	t.Add(&t, &p.t0)
    98  
    99  	return t
   100  }
   101  
   102  // l = l_0 + l_1 * x
   103  func (p *polynomial) computeL(x ristretto.Scalar) ([]ristretto.Scalar, error) {
   104  
   105  	lLeft := p.l0
   106  
   107  	lRight := vector.MulScalar(p.l1, x)
   108  
   109  	l, err := vector.Add(lLeft, lRight)
   110  	if err != nil {
   111  		return nil, errors.Wrap(err, "[ComputeL]")
   112  	}
   113  	return l, nil
   114  }
   115  
   116  // r = r_0 + r_1 * x
   117  func (p *polynomial) computeR(x ristretto.Scalar) ([]ristretto.Scalar, error) {
   118  	rLeft := p.r0
   119  
   120  	rRight := vector.MulScalar(p.r1, x)
   121  
   122  	r, err := vector.Add(rLeft, rRight)
   123  	if err != nil {
   124  		return nil, errors.Wrap(err, "[computeR]")
   125  	}
   126  	return r, nil
   127  }
   128  
   129  // t_0 = z^2 * v + D(y,z)
   130  func (p *polynomial) computeT0(y, z ristretto.Scalar, v []ristretto.Scalar, n, m uint32) ristretto.Scalar {
   131  
   132  	delta := computeDelta(y, z, n, uint32(m))
   133  
   134  	var zSq ristretto.Scalar
   135  	zSq.Square(&z)
   136  
   137  	zM := vector.ScalarPowers(z, uint32(len(v)))
   138  	zM = vector.MulScalar(zM, zSq)
   139  
   140  	var sumZmV ristretto.Scalar
   141  	sumZmV.SetZero()
   142  
   143  	for i := range v {
   144  		sumZmV.MulAdd(&zM[i], &v[i], &sumZmV)
   145  	}
   146  
   147  	var t0 ristretto.Scalar
   148  	t0.SetZero()
   149  
   150  	t0.Add(&delta, &sumZmV)
   151  
   152  	return t0
   153  }
   154  
   155  // calculates sum( z^(1+j) * ( 0^(j-1)n || 2 ^n || 0^(m-j)n ) ) from j = 1 to j=M (71)
   156  // implementation taken directly from java implementation.
   157  // XXX: Look into ways to speed this up, and improve readability
   158  // XXX: pass n and m as parameters
   159  func sumZMTwoN(z ristretto.Scalar) []ristretto.Scalar {
   160  
   161  	res := make([]ristretto.Scalar, N*M)
   162  
   163  	zM := vector.ScalarPowers(z, uint32(M+3))
   164  
   165  	var two ristretto.Scalar
   166  	two.SetBigInt(big.NewInt(2))
   167  	twoN := vector.ScalarPowers(two, N)
   168  
   169  	for i := 0; i < M*N; i++ {
   170  		res[i].SetZero()
   171  		for j := 1; j <= M; j++ {
   172  			if (i >= (j-1)*N) && (i < j*N) {
   173  				res[i].MulAdd(&zM[j+1], &twoN[i-(j-1)*N], &res[i])
   174  			}
   175  		}
   176  
   177  	}
   178  	return res
   179  
   180  	// Below is the old variation of the above code , for clarity on what the above is doing
   181  
   182  	// res := make([]ristretto.Scalar, 0, N*M)
   183  
   184  	// var zSq ristretto.Scalar
   185  	// zSq.Square(&z)
   186  
   187  	// zM := vector.ScalarPowers(z, uint32(M))
   188  	// zM = vector.MulScalar(zM, z)
   189  
   190  	// var two ristretto.Scalar
   191  	// two.SetBigInt(big.NewInt(2))
   192  	// twoN := vector.ScalarPowers(two, N)
   193  
   194  	// for i := 0; i < M; i++ {
   195  	// 	a := vector.MulScalar(twoN, zM[i])
   196  	// 	res = append(res, a...)
   197  	// }
   198  
   199  	// return res
   200  
   201  }
   202  
   203  // D(y,z) - This is the data shared by both prover and verifier
   204  // ported from rust impl
   205  func computeDelta(y, z ristretto.Scalar, n, m uint32) ristretto.Scalar {
   206  
   207  	var res ristretto.Scalar
   208  	res.SetZero()
   209  
   210  	sumY := vector.ScalarPowersSum(y, uint64(n*m))
   211  	sumZ := vector.ScalarPowersSum(z, uint64(m))
   212  
   213  	var two ristretto.Scalar
   214  	two.SetBigInt(big.NewInt(2))
   215  	sum2 := vector.ScalarPowersSum(two, uint64(n))
   216  	var zsq ristretto.Scalar
   217  	zsq.Square(&z)
   218  
   219  	var zcu ristretto.Scalar
   220  	zcu.Mul(&z, &zsq)
   221  
   222  	var resA, resB ristretto.Scalar
   223  	resA.Sub(&z, &zsq)
   224  
   225  	resA.Mul(&resA, &sumY)
   226  
   227  	resB.Mul(&sum2, &sumZ)
   228  	resB.Mul(&resB, &zcu)
   229  
   230  	res.Sub(&resA, &resB)
   231  
   232  	return res
   233  }