github.com/dusk-network/dusk-crypto@v0.1.3/mlsag/mlsag.go (about)

     1  package mlsag
     2  
     3  import (
     4  	"bytes"
     5  	"encoding/binary"
     6  	"errors"
     7  	"fmt"
     8  	"io"
     9  
    10  	ristretto "github.com/bwesterb/go-ristretto"
    11  )
    12  
// Signature is an MLSAG ring signature as produced by (*Proof).prove.
//
// Msg is not serialised by Encode; PubKeys is serialised only on request.
type Signature struct {
	// c is the challenge at ring position 0; Verify recomputes the ring
	// starting from it and checks that it wraps back to the same value.
	c ristretto.Scalar
	// r holds one response vector per ring member; r[i] is used together
	// with PubKeys[i] when recomputing challenges.
	r []Responses
	// PubKeys is the ring: one public key vector per member, in ring order.
	PubKeys []PubKeys
	// Msg is the signed message; it is hashed into every challenge.
	Msg []byte
}
    19  
    20  func (s *Signature) Encode(w io.Writer, encodeKeys bool) error {
    21  	err := binary.Write(w, binary.BigEndian, s.c.Bytes())
    22  	if err != nil {
    23  		return err
    24  	}
    25  
    26  	// lenR is the number of response vectors == num users = num pubkey vectors
    27  	lenR := uint32(len(s.r))
    28  	err = binary.Write(w, binary.BigEndian, lenR)
    29  	if err != nil {
    30  		return err
    31  	}
    32  
    33  	if lenR <= 0 {
    34  		return nil
    35  	}
    36  
    37  	// numResponses is the number of responses per user  == num pubkeys
    38  	numResponses := uint32(s.r[0].Len())
    39  	err = binary.Write(w, binary.BigEndian, numResponses)
    40  	if err != nil {
    41  		return err
    42  	}
    43  
    44  	// Encode the responses
    45  	for i := range s.r {
    46  		err = s.r[i].Encode(w)
    47  		if err != nil {
    48  			return err
    49  		}
    50  	}
    51  
    52  	if !encodeKeys {
    53  		return nil
    54  	}
    55  
    56  	// Encode the pubkeys
    57  	for i := range s.PubKeys {
    58  		err = s.PubKeys[i].Encode(w)
    59  		if err != nil {
    60  			return err
    61  		}
    62  	}
    63  
    64  	return nil
    65  }
    66  
    67  func (s *Signature) Decode(r io.Reader, decodeKeys bool) error {
    68  
    69  	if s == nil {
    70  		return errors.New("struct is nil")
    71  	}
    72  
    73  	err := readerToScalar(r, &s.c)
    74  	if err != nil {
    75  		return err
    76  	}
    77  
    78  	var lenR, numResponses uint32
    79  	err = binary.Read(r, binary.BigEndian, &lenR)
    80  	if err != nil {
    81  		return err
    82  	}
    83  	err = binary.Read(r, binary.BigEndian, &numResponses)
    84  	if err != nil {
    85  		return err
    86  	}
    87  
    88  	// Decode the responses
    89  	s.r = make([]Responses, lenR)
    90  	for i := uint32(0); i < lenR; i++ {
    91  		err = s.r[i].Decode(r, numResponses)
    92  		if err != nil {
    93  			return err
    94  		}
    95  	}
    96  
    97  	if !decodeKeys {
    98  		return nil
    99  	}
   100  
   101  	// Decode pubkeys
   102  	s.PubKeys = make([]PubKeys, lenR)
   103  	for i := uint32(0); i < lenR; i++ {
   104  		err = s.PubKeys[i].Decode(r, numResponses)
   105  		if err != nil {
   106  			return err
   107  		}
   108  	}
   109  	return nil
   110  }
   111  
   112  func (s Signature) Equals(other Signature, includeKeys bool) bool {
   113  	ok := s.c.Equals(&other.c)
   114  	if !ok {
   115  		return ok
   116  	}
   117  
   118  	for i := range s.r {
   119  		ok = s.r[i].Equals(other.r[i])
   120  		if !ok {
   121  			return ok
   122  		}
   123  	}
   124  
   125  	if !includeKeys {
   126  		return true
   127  	}
   128  
   129  	if len(s.PubKeys) != len(other.PubKeys) {
   130  		return false
   131  	}
   132  
   133  	for i := 0; i < len(s.PubKeys); i++ {
   134  		ok = s.PubKeys[i].Equals(other.PubKeys[i])
   135  		if !ok {
   136  			return ok
   137  		}
   138  	}
   139  	return true
   140  }
   141  
   142  func (proof *Proof) prove(skipLastKeyImage bool) (*Signature, []ristretto.Point, error) {
   143  
   144  	proof.addSignerPubKey()
   145  
   146  	// Shuffle the PubKeys and update the index for our corresponding key
   147  	err := proof.shuffleSet()
   148  	if err != nil {
   149  		return nil, nil, err
   150  	}
   151  
   152  	// Check that all key vectors are the same size in pubkey matrix
   153  	pubKeyVecLen := proof.privKeys.Len()
   154  	for i := range proof.pubKeysMatrix {
   155  		if proof.pubKeysMatrix[i].Len() != pubKeyVecLen {
   156  			return nil, []ristretto.Point{}, errors.New("all vectors in the pubkey matrix must be the same size")
   157  		}
   158  	}
   159  
   160  	keyImages := proof.calculateKeyImages(skipLastKeyImage)
   161  	nonces := generateNonces(len(proof.privKeys))
   162  
   163  	numUsers := len(proof.pubKeysMatrix)
   164  	numKeysPerUser := len(proof.privKeys)
   165  
   166  	// We will overwrite the signers responses
   167  	responses := generateResponses(numUsers, numKeysPerUser, proof.index)
   168  
   169  	// Let secretIndex = index of signer
   170  	secretIndex := proof.index
   171  
   172  	// Generate C_{secretIndex+1}
   173  	buf := &bytes.Buffer{}
   174  	buf.Write(proof.msg)
   175  	signersPubKeys := proof.pubKeysMatrix[secretIndex]
   176  
   177  	for i := 0; i < len(nonces); i++ {
   178  
   179  		nonce := nonces[i]
   180  
   181  		// P = nonce * G
   182  		var P ristretto.Point
   183  		P.ScalarMultBase(&nonce)
   184  		_, err = buf.Write(P.Bytes())
   185  		if err != nil {
   186  			return nil, nil, err
   187  		}
   188  	}
   189  
   190  	for i := 0; i < len(keyImages); i++ {
   191  
   192  		nonce := nonces[i]
   193  
   194  		// P = nonce * H(K)
   195  		var P, hK ristretto.Point
   196  		hK.Derive(signersPubKeys.keys[i].Bytes())
   197  		P.ScalarMult(&hK, &nonce)
   198  		_, err = buf.Write(P.Bytes())
   199  		if err != nil {
   200  			return nil, nil, err
   201  		}
   202  	}
   203  
   204  	var CjPlusOne ristretto.Scalar
   205  	CjPlusOne.Derive(buf.Bytes())
   206  
   207  	// generate challenges
   208  	challenges := make([]ristretto.Scalar, numUsers)
   209  	challenges[(secretIndex+1)%numUsers] = CjPlusOne
   210  
   211  	var prevChallenge ristretto.Scalar
   212  	prevChallenge.Set(&CjPlusOne)
   213  
   214  	for k := secretIndex + 2; k != (secretIndex+1)%numUsers; k = (k + 1) % numUsers {
   215  		i := k % numUsers
   216  
   217  		prevIndex := (i - 1) % numUsers
   218  		if prevIndex < 0 {
   219  			prevIndex = prevIndex + numUsers
   220  		}
   221  		fakeResponses := responses[prevIndex]
   222  		decoyPubKeys := proof.pubKeysMatrix[prevIndex]
   223  
   224  		c, err := generateChallenge(proof.msg, fakeResponses, keyImages, decoyPubKeys, prevChallenge)
   225  		if err != nil {
   226  			return nil, nil, err
   227  		}
   228  
   229  		challenges[i].Set(&c)
   230  		prevChallenge.Set(&c)
   231  	}
   232  
   233  	// Set the real response for signer
   234  	var realResponse Responses
   235  	for i := 0; i < numKeysPerUser; i++ {
   236  		challenge := challenges[proof.index]
   237  		privKey := proof.privKeys[i]
   238  		nonce := nonces[i]
   239  		var r ristretto.Scalar
   240  
   241  		// r = nonce - challenge*privKey
   242  		r.Mul(&challenge, &privKey)
   243  		r.Neg(&r)
   244  		r.Add(&r, &nonce)
   245  		realResponse.AddResponse(r)
   246  	}
   247  
   248  	// replace real response in responses array
   249  	responses[proof.index] = realResponse
   250  
   251  	sig := &Signature{
   252  		c:       challenges[0],
   253  		r:       responses,
   254  		PubKeys: proof.pubKeysMatrix,
   255  		Msg:     proof.msg,
   256  	}
   257  
   258  	return sig, keyImages, nil
   259  }
   260  
   261  func (sig *Signature) Verify(keyImages []ristretto.Point) (bool, error) {
   262  
   263  	if len(sig.PubKeys) == 0 || len(sig.r) == 0 || len(keyImages) == 0 {
   264  		return false, errors.New("cannot have zero length for responses, pubkeys or key images")
   265  	}
   266  
   267  	numUsers := len(sig.r)
   268  	index := 0
   269  
   270  	var prevChallenge = sig.c
   271  
   272  	for k := index + 1; k != (index)%numUsers; k = (k + 1) % numUsers {
   273  		i := k % numUsers
   274  		prevIndex := (i - 1) % numUsers
   275  		if prevIndex < 0 {
   276  			prevIndex = prevIndex + numUsers
   277  		}
   278  
   279  		fakeResponses := sig.r[prevIndex]
   280  		decoyPubKeys := sig.PubKeys[prevIndex]
   281  		challenge, err := generateChallenge(sig.Msg, fakeResponses, keyImages, decoyPubKeys, prevChallenge)
   282  		if err != nil {
   283  			return false, err
   284  		}
   285  		prevChallenge = challenge
   286  	}
   287  
   288  	// Calculate c'
   289  	prevIndex := (index - 1) % numUsers
   290  	if prevIndex < 0 {
   291  		prevIndex = prevIndex + numUsers
   292  	}
   293  	fakeResponses := sig.r[prevIndex]
   294  	decoyPubKeys := sig.PubKeys[prevIndex]
   295  
   296  	challenge, err := generateChallenge(sig.Msg, fakeResponses, keyImages, decoyPubKeys, prevChallenge)
   297  	if err != nil {
   298  		return false, err
   299  	}
   300  
   301  	if !challenge.Equals(&sig.c) {
   302  		return false, fmt.Errorf("c'0 does not equal c0, %s != %s", challenge.String(), sig.c.String())
   303  	}
   304  
   305  	return true, nil
   306  }
   307  
   308  func generateNonces(n int) []ristretto.Scalar {
   309  	var nonces []ristretto.Scalar
   310  	for i := 0; i < n; i++ {
   311  		var nonce ristretto.Scalar
   312  		nonce.Rand()
   313  		nonces = append(nonces, nonce)
   314  	}
   315  	return nonces
   316  }
   317  
   318  // XXX: Test should check that random numbers are not all zero
   319  //A bug in ristretto lib that may not be fixed
   320  // Check the same for points too
   321  // skip skips the singers responses
   322  func generateResponses(m int, n, skip int) []Responses {
   323  	var matrixResponses []Responses
   324  	for i := 0; i < m; i++ {
   325  		if i == skip {
   326  			matrixResponses = append(matrixResponses, Responses{})
   327  			continue
   328  		}
   329  		var resp Responses
   330  		for i := 0; i < n; i++ {
   331  			var r ristretto.Scalar
   332  			r.Rand()
   333  			resp.AddResponse(r)
   334  		}
   335  		matrixResponses = append(matrixResponses, resp)
   336  	}
   337  	return matrixResponses
   338  }
   339  
   340  func generateChallenge(
   341  	msg []byte,
   342  	respsonses Responses,
   343  	keyImages []ristretto.Point,
   344  	pubKeys PubKeys,
   345  	prevChallenge ristretto.Scalar) (ristretto.Scalar, error) {
   346  
   347  	buf := &bytes.Buffer{}
   348  	_, err := buf.Write(msg)
   349  	if err != nil {
   350  		return ristretto.Scalar{}, err
   351  	}
   352  
   353  	for i := 0; i < pubKeys.Len(); i++ {
   354  
   355  		r := respsonses[i]
   356  
   357  		// P = r * G + c * PubKey
   358  		var P, cK ristretto.Point
   359  		P.ScalarMultBase(&r)
   360  		cK.ScalarMult(&pubKeys.keys[i], &prevChallenge)
   361  		P.Add(&P, &cK)
   362  		_, err = buf.Write(P.Bytes())
   363  		if err != nil {
   364  			return ristretto.Scalar{}, err
   365  		}
   366  
   367  	}
   368  
   369  	for i := 0; i < len(keyImages); i++ {
   370  		r := respsonses[i]
   371  
   372  		// P = r * H(K) + c * Ki
   373  		var P, cK ristretto.Point
   374  		var hK ristretto.Point
   375  		hK.Derive(pubKeys.keys[i].Bytes())
   376  		P.ScalarMult(&hK, &r)
   377  		cK.ScalarMult(&keyImages[i], &prevChallenge)
   378  		P.Add(&P, &cK)
   379  		_, err = buf.Write(P.Bytes())
   380  		if err != nil {
   381  			return ristretto.Scalar{}, err
   382  		}
   383  	}
   384  
   385  	var challenge ristretto.Scalar
   386  	challenge.Derive(buf.Bytes())
   387  
   388  	return challenge, nil
   389  }
   390  
   391  func (proof *Proof) calculateKeyImages(skipLastKeyImage bool) []ristretto.Point {
   392  	var keyImages []ristretto.Point
   393  
   394  	privKeys := proof.privKeys
   395  	pubKeys := proof.signerPubKeys
   396  
   397  	for i := 0; i < len(privKeys); i++ {
   398  		keyImages = append(keyImages, CalculateKeyImage(privKeys[i], pubKeys.keys[i]))
   399  	}
   400  
   401  	if !skipLastKeyImage {
   402  		return keyImages
   403  	}
   404  
   405  	// Here we assume that there will be atleast one privkey
   406  	// which means there will be atleast one key image
   407  	keyImages = keyImages[:len(keyImages)-1]
   408  	return keyImages
   409  }
   410  
   411  func CalculateKeyImage(privKey ristretto.Scalar, pubKey ristretto.Point) ristretto.Point {
   412  	var keyImage ristretto.Point
   413  	keyImage.Set(&pubKey)
   414  	// P = H(xG)
   415  	keyImage.Derive(keyImage.Bytes())
   416  	// P = xH(P)
   417  	keyImage.ScalarMult(&keyImage, &privKey)
   418  	return keyImage
   419  }
   420  
   421  func isNumInList(x int, numList []int) bool {
   422  	for _, b := range numList {
   423  		if b == x {
   424  			return true
   425  		}
   426  	}
   427  	return false
   428  }
   429  
   430  func readerToPoint(r io.Reader, p *ristretto.Point) error {
   431  	var x [32]byte
   432  	err := binary.Read(r, binary.BigEndian, &x)
   433  	if err != nil {
   434  		return err
   435  	}
   436  	ok := p.SetBytes(&x)
   437  	if !ok {
   438  		return errors.New("point not encodable")
   439  	}
   440  	return nil
   441  }
   442  func readerToScalar(r io.Reader, s *ristretto.Scalar) error {
   443  	var x [32]byte
   444  	err := binary.Read(r, binary.BigEndian, &x)
   445  	if err != nil {
   446  		return err
   447  	}
   448  	s.SetBytes(&x)
   449  	return nil
   450  }