github.com/dusk-network/dusk-crypto@v0.1.3/mlsag/mlsag.go (about) 1 package mlsag 2 3 import ( 4 "bytes" 5 "encoding/binary" 6 "errors" 7 "fmt" 8 "io" 9 10 ristretto "github.com/bwesterb/go-ristretto" 11 ) 12 13 type Signature struct { 14 c ristretto.Scalar 15 r []Responses 16 PubKeys []PubKeys 17 Msg []byte 18 } 19 20 func (s *Signature) Encode(w io.Writer, encodeKeys bool) error { 21 err := binary.Write(w, binary.BigEndian, s.c.Bytes()) 22 if err != nil { 23 return err 24 } 25 26 // lenR is the number of response vectors == num users = num pubkey vectors 27 lenR := uint32(len(s.r)) 28 err = binary.Write(w, binary.BigEndian, lenR) 29 if err != nil { 30 return err 31 } 32 33 if lenR <= 0 { 34 return nil 35 } 36 37 // numResponses is the number of responses per user == num pubkeys 38 numResponses := uint32(s.r[0].Len()) 39 err = binary.Write(w, binary.BigEndian, numResponses) 40 if err != nil { 41 return err 42 } 43 44 // Encode the responses 45 for i := range s.r { 46 err = s.r[i].Encode(w) 47 if err != nil { 48 return err 49 } 50 } 51 52 if !encodeKeys { 53 return nil 54 } 55 56 // Encode the pubkeys 57 for i := range s.PubKeys { 58 err = s.PubKeys[i].Encode(w) 59 if err != nil { 60 return err 61 } 62 } 63 64 return nil 65 } 66 67 func (s *Signature) Decode(r io.Reader, decodeKeys bool) error { 68 69 if s == nil { 70 return errors.New("struct is nil") 71 } 72 73 err := readerToScalar(r, &s.c) 74 if err != nil { 75 return err 76 } 77 78 var lenR, numResponses uint32 79 err = binary.Read(r, binary.BigEndian, &lenR) 80 if err != nil { 81 return err 82 } 83 err = binary.Read(r, binary.BigEndian, &numResponses) 84 if err != nil { 85 return err 86 } 87 88 // Decode the responses 89 s.r = make([]Responses, lenR) 90 for i := uint32(0); i < lenR; i++ { 91 err = s.r[i].Decode(r, numResponses) 92 if err != nil { 93 return err 94 } 95 } 96 97 if !decodeKeys { 98 return nil 99 } 100 101 // Decode pubkeys 102 s.PubKeys = make([]PubKeys, lenR) 103 for i := uint32(0); i < lenR; i++ { 104 err = s.PubKeys[i].Decode(r, numResponses) 105 if err != nil { 106 return err 107 } 108 } 109 return nil 110 } 111 112 func (s Signature) Equals(other Signature, includeKeys bool) bool { 113 ok := s.c.Equals(&other.c) 114 if !ok { 115 return ok 116 } 117 118 for i := range s.r { 119 ok = s.r[i].Equals(other.r[i]) 120 if !ok { 121 return ok 122 } 123 } 124 125 if !includeKeys { 126 return true 127 } 128 129 if len(s.PubKeys) != len(other.PubKeys) { 130 return false 131 } 132 133 for i := 0; i < len(s.PubKeys); i++ { 134 ok = s.PubKeys[i].Equals(other.PubKeys[i]) 135 if !ok { 136 return ok 137 } 138 } 139 return true 140 } 141 142 func (proof *Proof) prove(skipLastKeyImage bool) (*Signature, []ristretto.Point, error) { 143 144 proof.addSignerPubKey() 145 146 // Shuffle the PubKeys and update the index for our corresponding key 147 err := proof.shuffleSet() 148 if err != nil { 149 return nil, nil, err 150 } 151 152 // Check that all key vectors are the same size in pubkey matrix 153 pubKeyVecLen := proof.privKeys.Len() 154 for i := range proof.pubKeysMatrix { 155 if proof.pubKeysMatrix[i].Len() != pubKeyVecLen { 156 return nil, []ristretto.Point{}, errors.New("all vectors in the pubkey matrix must be the same size") 157 } 158 } 159 160 keyImages := proof.calculateKeyImages(skipLastKeyImage) 161 nonces := generateNonces(len(proof.privKeys)) 162 163 numUsers := len(proof.pubKeysMatrix) 164 numKeysPerUser := len(proof.privKeys) 165 166 // We will overwrite the signers responses 167 responses := generateResponses(numUsers, numKeysPerUser, proof.index) 168 169 // Let secretIndex = index of signer 170 secretIndex := proof.index 171 172 // Generate C_{secretIndex+1} 173 buf := &bytes.Buffer{} 174 buf.Write(proof.msg) 175 signersPubKeys := proof.pubKeysMatrix[secretIndex] 176 177 for i := 0; i < len(nonces); i++ { 178 179 nonce := nonces[i] 180 181 // P = nonce * G 182 var P ristretto.Point 183 P.ScalarMultBase(&nonce) 184 _, err = buf.Write(P.Bytes()) 185 if err != nil { 186 return nil, nil, err 187 } 188 } 189 190 for i := 0; i < len(keyImages); i++ { 191 192 nonce := nonces[i] 193 194 // P = nonce * H(K) 195 var P, hK ristretto.Point 196 hK.Derive(signersPubKeys.keys[i].Bytes()) 197 P.ScalarMult(&hK, &nonce) 198 _, err = buf.Write(P.Bytes()) 199 if err != nil { 200 return nil, nil, err 201 } 202 } 203 204 var CjPlusOne ristretto.Scalar 205 CjPlusOne.Derive(buf.Bytes()) 206 207 // generate challenges 208 challenges := make([]ristretto.Scalar, numUsers) 209 challenges[(secretIndex+1)%numUsers] = CjPlusOne 210 211 var prevChallenge ristretto.Scalar 212 prevChallenge.Set(&CjPlusOne) 213 214 for k := secretIndex + 2; k != (secretIndex+1)%numUsers; k = (k + 1) % numUsers { 215 i := k % numUsers 216 217 prevIndex := (i - 1) % numUsers 218 if prevIndex < 0 { 219 prevIndex = prevIndex + numUsers 220 } 221 fakeResponses := responses[prevIndex] 222 decoyPubKeys := proof.pubKeysMatrix[prevIndex] 223 224 c, err := generateChallenge(proof.msg, fakeResponses, keyImages, decoyPubKeys, prevChallenge) 225 if err != nil { 226 return nil, nil, err 227 } 228 229 challenges[i].Set(&c) 230 prevChallenge.Set(&c) 231 } 232 233 // Set the real response for signer 234 var realResponse Responses 235 for i := 0; i < numKeysPerUser; i++ { 236 challenge := challenges[proof.index] 237 privKey := proof.privKeys[i] 238 nonce := nonces[i] 239 var r ristretto.Scalar 240 241 // r = nonce - challenge*privKey 242 r.Mul(&challenge, &privKey) 243 r.Neg(&r) 244 r.Add(&r, &nonce) 245 realResponse.AddResponse(r) 246 } 247 248 // replace real response in responses array 249 responses[proof.index] = realResponse 250 251 sig := &Signature{ 252 c: challenges[0], 253 r: responses, 254 PubKeys: proof.pubKeysMatrix, 255 Msg: proof.msg, 256 } 257 258 return sig, keyImages, nil 259 } 260 261 func (sig *Signature) Verify(keyImages []ristretto.Point) (bool, error) { 262 263 if len(sig.PubKeys) == 0 || len(sig.r) == 0 || len(keyImages) == 0 { 264 return false, errors.New("cannot have zero length for responses, pubkeys or key images") 265 } 266 267 numUsers := len(sig.r) 268 index := 0 269 270 var prevChallenge = sig.c 271 272 for k := index + 1; k != (index)%numUsers; k = (k + 1) % numUsers { 273 i := k % numUsers 274 prevIndex := (i - 1) % numUsers 275 if prevIndex < 0 { 276 prevIndex = prevIndex + numUsers 277 } 278 279 fakeResponses := sig.r[prevIndex] 280 decoyPubKeys := sig.PubKeys[prevIndex] 281 challenge, err := generateChallenge(sig.Msg, fakeResponses, keyImages, decoyPubKeys, prevChallenge) 282 if err != nil { 283 return false, err 284 } 285 prevChallenge = challenge 286 } 287 288 // Calculate c' 289 prevIndex := (index - 1) % numUsers 290 if prevIndex < 0 { 291 prevIndex = prevIndex + numUsers 292 } 293 fakeResponses := sig.r[prevIndex] 294 decoyPubKeys := sig.PubKeys[prevIndex] 295 296 challenge, err := generateChallenge(sig.Msg, fakeResponses, keyImages, decoyPubKeys, prevChallenge) 297 if err != nil { 298 return false, err 299 } 300 301 if !challenge.Equals(&sig.c) { 302 return false, fmt.Errorf("c'0 does not equal c0, %s != %s", challenge.String(), sig.c.String()) 303 } 304 305 return true, nil 306 } 307 308 func generateNonces(n int) []ristretto.Scalar { 309 var nonces []ristretto.Scalar 310 for i := 0; i < n; i++ { 311 var nonce ristretto.Scalar 312 nonce.Rand() 313 nonces = append(nonces, nonce) 314 } 315 return nonces 316 } 317 318 // XXX: Test should check that random numbers are not all zero 319 //A bug in ristretto lib that may not be fixed 320 // Check the same for points too 321 // skip skips the singers responses 322 func generateResponses(m int, n, skip int) []Responses { 323 var matrixResponses []Responses 324 for i := 0; i < m; i++ { 325 if i == skip { 326 matrixResponses = append(matrixResponses, Responses{}) 327 continue 328 } 329 var resp Responses 330 for i := 0; i < n; i++ { 331 var r ristretto.Scalar 332 r.Rand() 333 resp.AddResponse(r) 334 } 335 matrixResponses = append(matrixResponses, resp) 336 } 337 return matrixResponses 338 } 339 340 func generateChallenge( 341 msg []byte, 342 respsonses Responses, 343 keyImages []ristretto.Point, 344 pubKeys PubKeys, 345 prevChallenge ristretto.Scalar) (ristretto.Scalar, error) { 346 347 buf := &bytes.Buffer{} 348 _, err := buf.Write(msg) 349 if err != nil { 350 return ristretto.Scalar{}, err 351 } 352 353 for i := 0; i < pubKeys.Len(); i++ { 354 355 r := respsonses[i] 356 357 // P = r * G + c * PubKey 358 var P, cK ristretto.Point 359 P.ScalarMultBase(&r) 360 cK.ScalarMult(&pubKeys.keys[i], &prevChallenge) 361 P.Add(&P, &cK) 362 _, err = buf.Write(P.Bytes()) 363 if err != nil { 364 return ristretto.Scalar{}, err 365 } 366 367 } 368 369 for i := 0; i < len(keyImages); i++ { 370 r := respsonses[i] 371 372 // P = r * H(K) + c * Ki 373 var P, cK ristretto.Point 374 var hK ristretto.Point 375 hK.Derive(pubKeys.keys[i].Bytes()) 376 P.ScalarMult(&hK, &r) 377 cK.ScalarMult(&keyImages[i], &prevChallenge) 378 P.Add(&P, &cK) 379 _, err = buf.Write(P.Bytes()) 380 if err != nil { 381 return ristretto.Scalar{}, err 382 } 383 } 384 385 var challenge ristretto.Scalar 386 challenge.Derive(buf.Bytes()) 387 388 return challenge, nil 389 } 390 391 func (proof *Proof) calculateKeyImages(skipLastKeyImage bool) []ristretto.Point { 392 var keyImages []ristretto.Point 393 394 privKeys := proof.privKeys 395 pubKeys := proof.signerPubKeys 396 397 for i := 0; i < len(privKeys); i++ { 398 keyImages = append(keyImages, CalculateKeyImage(privKeys[i], pubKeys.keys[i])) 399 } 400 401 if !skipLastKeyImage { 402 return keyImages 403 } 404 405 // Here we assume that there will be atleast one privkey 406 // which means there will be atleast one key image 407 keyImages = keyImages[:len(keyImages)-1] 408 return keyImages 409 } 410 411 func CalculateKeyImage(privKey ristretto.Scalar, pubKey ristretto.Point) ristretto.Point { 412 var keyImage ristretto.Point 413 keyImage.Set(&pubKey) 414 // P = H(xG) 415 keyImage.Derive(keyImage.Bytes()) 416 // P = xH(P) 417 keyImage.ScalarMult(&keyImage, &privKey) 418 return keyImage 419 } 420 421 func isNumInList(x int, numList []int) bool { 422 for _, b := range numList { 423 if b == x { 424 return true 425 } 426 } 427 return false 428 } 429 430 func readerToPoint(r io.Reader, p *ristretto.Point) error { 431 var x [32]byte 432 err := binary.Read(r, binary.BigEndian, &x) 433 if err != nil { 434 return err 435 } 436 ok := p.SetBytes(&x) 437 if !ok { 438 return errors.New("point not encodable") 439 } 440 return nil 441 } 442 func readerToScalar(r io.Reader, s *ristretto.Scalar) error { 443 var x [32]byte 444 err := binary.Read(r, binary.BigEndian, &x) 445 if err != nil { 446 return err 447 } 448 s.SetBytes(&x) 449 return nil 450 }