github.com/consensys/gnark-crypto@v0.14.0/ecc/stark-curve/marshal.go (about) 1 // Copyright 2020 ConsenSys Software Inc. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 // FOO 16 17 package starkcurve 18 19 import ( 20 "encoding/binary" 21 "errors" 22 "io" 23 "reflect" 24 "sync/atomic" 25 26 "github.com/consensys/gnark-crypto/ecc/stark-curve/fp" 27 "github.com/consensys/gnark-crypto/ecc/stark-curve/fr" 28 "github.com/consensys/gnark-crypto/internal/parallel" 29 ) 30 31 // To encode G1Affine points, we mask the most significant bits with these bits to specify without ambiguity 32 // metadata needed for point (de)compression 33 // we have less than 3 bits available on the msw, so we can't follow BLS12-381 style encoding. 34 // the difference is the case where a point is infinity and uncompressed is not flagged 35 const ( 36 mMask byte = 0b11 << 6 37 mUncompressed byte = 0b00 << 6 38 mCompressedSmallest byte = 0b10 << 6 39 mCompressedLargest byte = 0b11 << 6 40 mCompressedInfinity byte = 0b01 << 6 41 ) 42 43 // Encoder writes stark-curve object values to an output stream 44 type Encoder struct { 45 w io.Writer 46 n int64 // written bytes 47 raw bool // raw vs compressed encoding 48 } 49 50 // Decoder reads stark-curve object values from an inbound stream 51 type Decoder struct { 52 r io.Reader 53 n int64 // read bytes 54 subGroupCheck bool // default to true 55 } 56 57 // NewDecoder returns a binary decoder supporting curve stark-curve objects in both 58 // compressed and uncompressed (raw) forms 59 func NewDecoder(r io.Reader, options ...func(*Decoder)) *Decoder { 60 d := &Decoder{r: r, subGroupCheck: true} 61 62 for _, o := range options { 63 o(d) 64 } 65 66 return d 67 } 68 69 // Decode reads the binary encoding of v from the stream 70 // type must be *uint64, *fr.Element, *fp.Element, *G1Affine or *[]G1Affine 71 func (dec *Decoder) Decode(v interface{}) (err error) { 72 rv := reflect.ValueOf(v) 73 if v == nil || rv.Kind() != reflect.Ptr || rv.IsNil() || !rv.Elem().CanSet() { 74 return errors.New("stark-curve decoder: unsupported type, need pointer") 75 } 76 77 // implementation note: code is a bit verbose (abusing code generation), but minimize allocations on the heap 78 // in particular, careful attention must be given to usage of Bytes() method on Elements and Points 79 // that return an array (not a slice) of bytes. Using this is beneficial to minimize memallocs 80 // in very large (de)serialization upstream in gnark. 81 // (but detrimental to code visibility here) 82 83 var buf [SizeOfG1AffineUncompressed]byte 84 var read int 85 86 switch t := v.(type) { 87 case *fr.Element: 88 read, err = io.ReadFull(dec.r, buf[:fr.Bytes]) 89 dec.n += int64(read) 90 if err != nil { 91 return 92 } 93 err = t.SetBytesCanonical(buf[:fr.Bytes]) 94 return 95 case *fp.Element: 96 read, err = io.ReadFull(dec.r, buf[:fp.Bytes]) 97 dec.n += int64(read) 98 if err != nil { 99 return 100 } 101 err = t.SetBytesCanonical(buf[:fp.Bytes]) 102 return 103 case *[]fr.Element: 104 var sliceLen uint32 105 sliceLen, err = dec.readUint32() 106 if err != nil { 107 return 108 } 109 if len(*t) != int(sliceLen) { 110 *t = make([]fr.Element, sliceLen) 111 } 112 113 for i := 0; i < len(*t); i++ { 114 read, err = io.ReadFull(dec.r, buf[:fr.Bytes]) 115 dec.n += int64(read) 116 if err != nil { 117 return 118 } 119 if err = (*t)[i].SetBytesCanonical(buf[:fr.Bytes]); err != nil { 120 return 121 } 122 } 123 return 124 case *[]fp.Element: 125 var sliceLen uint32 126 sliceLen, err = dec.readUint32() 127 if err != nil { 128 return 129 } 130 if len(*t) != int(sliceLen) { 131 *t = make([]fp.Element, sliceLen) 132 } 133 134 for i := 0; i < len(*t); i++ { 135 read, err = io.ReadFull(dec.r, buf[:fp.Bytes]) 136 dec.n += int64(read) 137 if err != nil { 138 return 139 } 140 if err = (*t)[i].SetBytesCanonical(buf[:fp.Bytes]); err != nil { 141 return 142 } 143 } 144 return 145 case *G1Affine: 146 // we start by reading compressed point size, if metadata tells us it is uncompressed, we read more. 147 read, err = io.ReadFull(dec.r, buf[:SizeOfG1AffineCompressed]) 148 dec.n += int64(read) 149 if err != nil { 150 return 151 } 152 nbBytes := SizeOfG1AffineCompressed 153 // most significant byte contains metadata 154 if !isCompressed(buf[0]) { 155 nbBytes = SizeOfG1AffineUncompressed 156 // we read more. 157 read, err = io.ReadFull(dec.r, buf[SizeOfG1AffineCompressed:SizeOfG1AffineUncompressed]) 158 dec.n += int64(read) 159 if err != nil { 160 return 161 } 162 } 163 _, err = t.setBytes(buf[:nbBytes], dec.subGroupCheck) 164 return 165 case *[]G1Affine: 166 var sliceLen uint32 167 sliceLen, err = dec.readUint32() 168 if err != nil { 169 return 170 } 171 if len(*t) != int(sliceLen) { 172 *t = make([]G1Affine, sliceLen) 173 } 174 compressed := make([]bool, sliceLen) 175 for i := 0; i < len(*t); i++ { 176 177 // we start by reading compressed point size, if metadata tells us it is uncompressed, we read more. 178 read, err = io.ReadFull(dec.r, buf[:SizeOfG1AffineCompressed]) 179 dec.n += int64(read) 180 if err != nil { 181 return 182 } 183 nbBytes := SizeOfG1AffineCompressed 184 // most significant byte contains metadata 185 if !isCompressed(buf[0]) { 186 nbBytes = SizeOfG1AffineUncompressed 187 // we read more. 188 read, err = io.ReadFull(dec.r, buf[SizeOfG1AffineCompressed:SizeOfG1AffineUncompressed]) 189 dec.n += int64(read) 190 if err != nil { 191 return 192 } 193 _, err = (*t)[i].setBytes(buf[:nbBytes], false) 194 if err != nil { 195 return 196 } 197 } else { 198 var r bool 199 if r, err = ((*t)[i].unsafeSetCompressedBytes(buf[:nbBytes])); err != nil { 200 return 201 } 202 compressed[i] = !r 203 } 204 } 205 var nbErrs uint64 206 parallel.Execute(len(compressed), func(start, end int) { 207 for i := start; i < end; i++ { 208 if compressed[i] { 209 if err := (*t)[i].unsafeComputeY(dec.subGroupCheck); err != nil { 210 atomic.AddUint64(&nbErrs, 1) 211 } 212 } else if dec.subGroupCheck { 213 if !(*t)[i].IsInSubGroup() { 214 atomic.AddUint64(&nbErrs, 1) 215 } 216 } 217 } 218 }) 219 if nbErrs != 0 { 220 return errors.New("point decompression failed") 221 } 222 223 return nil 224 default: 225 n := binary.Size(t) 226 if n == -1 { 227 return errors.New("stark-curve encoder: unsupported type") 228 } 229 err = binary.Read(dec.r, binary.BigEndian, t) 230 if err == nil { 231 dec.n += int64(n) 232 } 233 return 234 } 235 } 236 237 // BytesRead return total bytes read from reader 238 func (dec *Decoder) BytesRead() int64 { 239 return dec.n 240 } 241 242 func (dec *Decoder) readUint32() (r uint32, err error) { 243 var read int 244 var buf [4]byte 245 read, err = io.ReadFull(dec.r, buf[:4]) 246 dec.n += int64(read) 247 if err != nil { 248 return 249 } 250 r = binary.BigEndian.Uint32(buf[:4]) 251 return 252 } 253 254 func isCompressed(msb byte) bool { 255 mData := msb & mMask 256 return !(mData == mUncompressed) 257 } 258 259 // NewEncoder returns a binary encoder supporting curve stark-curve objects 260 func NewEncoder(w io.Writer, options ...func(*Encoder)) *Encoder { 261 // default settings 262 enc := &Encoder{ 263 w: w, 264 n: 0, 265 raw: false, 266 } 267 268 // handle options 269 for _, option := range options { 270 option(enc) 271 } 272 273 return enc 274 } 275 276 // Encode writes the binary encoding of v to the stream 277 // type must be uint64, *fr.Element, *fp.Element, *G1Affine, *G2Affine, []G1Affine or []G2Affine 278 func (enc *Encoder) Encode(v interface{}) (err error) { 279 if enc.raw { 280 return enc.encodeRaw(v) 281 } 282 return enc.encode(v) 283 } 284 285 // BytesWritten return total bytes written on writer 286 func (enc *Encoder) BytesWritten() int64 { 287 return enc.n 288 } 289 290 // RawEncoding returns an option to use in NewEncoder(...) which sets raw encoding mode to true 291 // points will not be compressed using this option 292 func RawEncoding() func(*Encoder) { 293 return func(enc *Encoder) { 294 enc.raw = true 295 } 296 } 297 298 // NoSubgroupChecks returns an option to use in NewDecoder(...) which disable subgroup checks on the points 299 // the decoder will read. Use with caution, as crafted points from an untrusted source can lead to crypto-attacks. 300 func NoSubgroupChecks() func(*Decoder) { 301 return func(dec *Decoder) { 302 dec.subGroupCheck = false 303 } 304 } 305 306 func (enc *Encoder) encode(v interface{}) (err error) { 307 rv := reflect.ValueOf(v) 308 if v == nil || (rv.Kind() == reflect.Ptr && rv.IsNil()) { 309 return errors.New("<no value> encoder: can't encode <nil>") 310 } 311 312 // implementation note: code is a bit verbose (abusing code generation), but minimize allocations on the heap 313 314 var written int 315 switch t := v.(type) { 316 case *fr.Element: 317 buf := t.Bytes() 318 written, err = enc.w.Write(buf[:]) 319 enc.n += int64(written) 320 return 321 case *fp.Element: 322 buf := t.Bytes() 323 written, err = enc.w.Write(buf[:]) 324 enc.n += int64(written) 325 return 326 case *G1Affine: 327 buf := t.Bytes() 328 written, err = enc.w.Write(buf[:]) 329 enc.n += int64(written) 330 return 331 case []fr.Element: 332 // write slice length 333 err = binary.Write(enc.w, binary.BigEndian, uint32(len(t))) 334 if err != nil { 335 return 336 } 337 enc.n += 4 338 var buf [fr.Bytes]byte 339 for i := 0; i < len(t); i++ { 340 buf = t[i].Bytes() 341 written, err = enc.w.Write(buf[:]) 342 enc.n += int64(written) 343 if err != nil { 344 return 345 } 346 } 347 return nil 348 case []fp.Element: 349 // write slice length 350 err = binary.Write(enc.w, binary.BigEndian, uint32(len(t))) 351 if err != nil { 352 return 353 } 354 enc.n += 4 355 var buf [fp.Bytes]byte 356 for i := 0; i < len(t); i++ { 357 buf = t[i].Bytes() 358 written, err = enc.w.Write(buf[:]) 359 enc.n += int64(written) 360 if err != nil { 361 return 362 } 363 } 364 return nil 365 366 case []G1Affine: 367 // write slice length 368 err = binary.Write(enc.w, binary.BigEndian, uint32(len(t))) 369 if err != nil { 370 return 371 } 372 enc.n += 4 373 374 var buf [SizeOfG1AffineCompressed]byte 375 376 for i := 0; i < len(t); i++ { 377 buf = t[i].Bytes() 378 written, err = enc.w.Write(buf[:]) 379 enc.n += int64(written) 380 if err != nil { 381 return 382 } 383 } 384 return nil 385 default: 386 n := binary.Size(t) 387 if n == -1 { 388 return errors.New("<no value> encoder: unsupported type") 389 } 390 err = binary.Write(enc.w, binary.BigEndian, t) 391 enc.n += int64(n) 392 return 393 } 394 } 395 396 func (enc *Encoder) encodeRaw(v interface{}) (err error) { 397 rv := reflect.ValueOf(v) 398 if v == nil || (rv.Kind() == reflect.Ptr && rv.IsNil()) { 399 return errors.New("<no value> encoder: can't encode <nil>") 400 } 401 402 // implementation note: code is a bit verbose (abusing code generation), but minimize allocations on the heap 403 404 var written int 405 switch t := v.(type) { 406 case *fr.Element: 407 buf := t.Bytes() 408 written, err = enc.w.Write(buf[:]) 409 enc.n += int64(written) 410 return 411 case *fp.Element: 412 buf := t.Bytes() 413 written, err = enc.w.Write(buf[:]) 414 enc.n += int64(written) 415 return 416 case *G1Affine: 417 buf := t.RawBytes() 418 written, err = enc.w.Write(buf[:]) 419 enc.n += int64(written) 420 return 421 case []fr.Element: 422 // write slice length 423 err = binary.Write(enc.w, binary.BigEndian, uint32(len(t))) 424 if err != nil { 425 return 426 } 427 enc.n += 4 428 var buf [fr.Bytes]byte 429 for i := 0; i < len(t); i++ { 430 buf = t[i].Bytes() 431 written, err = enc.w.Write(buf[:]) 432 enc.n += int64(written) 433 if err != nil { 434 return 435 } 436 } 437 return nil 438 case []fp.Element: 439 // write slice length 440 err = binary.Write(enc.w, binary.BigEndian, uint32(len(t))) 441 if err != nil { 442 return 443 } 444 enc.n += 4 445 var buf [fp.Bytes]byte 446 for i := 0; i < len(t); i++ { 447 buf = t[i].Bytes() 448 written, err = enc.w.Write(buf[:]) 449 enc.n += int64(written) 450 if err != nil { 451 return 452 } 453 } 454 return nil 455 456 case []G1Affine: 457 // write slice length 458 err = binary.Write(enc.w, binary.BigEndian, uint32(len(t))) 459 if err != nil { 460 return 461 } 462 enc.n += 4 463 464 var buf [SizeOfG1AffineUncompressed]byte 465 466 for i := 0; i < len(t); i++ { 467 buf = t[i].RawBytes() 468 written, err = enc.w.Write(buf[:]) 469 enc.n += int64(written) 470 if err != nil { 471 return 472 } 473 } 474 return nil 475 default: 476 n := binary.Size(t) 477 if n == -1 { 478 return errors.New("<no value> encoder: unsupported type") 479 } 480 err = binary.Write(enc.w, binary.BigEndian, t) 481 enc.n += int64(n) 482 return 483 } 484 } 485 486 // SizeOfG1AffineCompressed represents the size in bytes that a G1Affine need in binary form, compressed 487 const SizeOfG1AffineCompressed = 32 488 489 // SizeOfG1AffineUncompressed represents the size in bytes that a G1Affine need in binary form, uncompressed 490 const SizeOfG1AffineUncompressed = SizeOfG1AffineCompressed * 2 491 492 // Marshal converts p to a byte slice (without point compression) 493 func (p *G1Affine) Marshal() []byte { 494 b := p.RawBytes() 495 return b[:] 496 } 497 498 // Unmarshal is an alias to SetBytes() 499 func (p *G1Affine) Unmarshal(buf []byte) error { 500 _, err := p.SetBytes(buf) 501 return err 502 } 503 504 // Bytes returns binary representation of p 505 // will store X coordinate in regular form and a parity bit 506 // as we have less than 3 bits available in our coordinate, we can't follow BLS12-381 style encoding (ZCash/IETF) 507 // 508 // we use the 2 most significant bits instead 509 // 510 // 00 -> uncompressed 511 // 10 -> compressed, use smallest lexicographically square root of Y^2 512 // 11 -> compressed, use largest lexicographically square root of Y^2 513 // 01 -> compressed infinity point 514 // the "uncompressed infinity point" will just have 00 (uncompressed) followed by zeroes (infinity = 0,0 in affine coordinates) 515 func (p *G1Affine) Bytes() (res [SizeOfG1AffineCompressed]byte) { 516 517 // check if p is infinity point 518 if p.X.IsZero() && p.Y.IsZero() { 519 res[0] = mCompressedInfinity 520 return 521 } 522 523 msbMask := mCompressedSmallest 524 // compressed, we need to know if Y is lexicographically bigger than -Y 525 // if p.Y ">" -p.Y 526 if p.Y.LexicographicallyLargest() { 527 msbMask = mCompressedLargest 528 } 529 530 // we store X and mask the most significant word with our metadata mask 531 fp.BigEndian.PutElement((*[fp.Bytes]byte)(res[0:0+fp.Bytes]), p.X) 532 533 res[0] |= msbMask 534 535 return 536 } 537 538 // RawBytes returns binary representation of p (stores X and Y coordinate) 539 // see Bytes() for a compressed representation 540 func (p *G1Affine) RawBytes() (res [SizeOfG1AffineUncompressed]byte) { 541 542 // check if p is infinity point 543 if p.X.IsZero() && p.Y.IsZero() { 544 545 res[0] = mUncompressed 546 547 return 548 } 549 550 // not compressed 551 // we store the Y coordinate 552 fp.BigEndian.PutElement((*[fp.Bytes]byte)(res[32:32+fp.Bytes]), p.Y) 553 554 // we store X and mask the most significant word with our metadata mask 555 fp.BigEndian.PutElement((*[fp.Bytes]byte)(res[0:0+fp.Bytes]), p.X) 556 557 res[0] |= mUncompressed 558 559 return 560 } 561 562 // SetBytes sets p from binary representation in buf and returns number of consumed bytes 563 // 564 // bytes in buf must match either RawBytes() or Bytes() output 565 // 566 // if buf is too short io.ErrShortBuffer is returned 567 // 568 // if buf contains compressed representation (output from Bytes()) and we're unable to compute 569 // the Y coordinate (i.e the square root doesn't exist) this function returns an error 570 // 571 // this check if the resulting point is on the curve and in the correct subgroup 572 func (p *G1Affine) SetBytes(buf []byte) (int, error) { 573 return p.setBytes(buf, true) 574 } 575 576 func (p *G1Affine) setBytes(buf []byte, subGroupCheck bool) (int, error) { 577 if len(buf) < SizeOfG1AffineCompressed { 578 return 0, io.ErrShortBuffer 579 } 580 581 // most significant byte 582 mData := buf[0] & mMask 583 584 // check buffer size 585 if mData == mUncompressed { 586 if len(buf) < SizeOfG1AffineUncompressed { 587 return 0, io.ErrShortBuffer 588 } 589 } 590 591 // if infinity is encoded in the metadata, we don't need to read the buffer 592 if mData == mCompressedInfinity { 593 p.X.SetZero() 594 p.Y.SetZero() 595 return SizeOfG1AffineCompressed, nil 596 } 597 598 // uncompressed point 599 if mData == mUncompressed { 600 // read X and Y coordinates 601 if err := p.X.SetBytesCanonical(buf[:fp.Bytes]); err != nil { 602 return 0, err 603 } 604 if err := p.Y.SetBytesCanonical(buf[fp.Bytes : fp.Bytes*2]); err != nil { 605 return 0, err 606 } 607 608 // subgroup check 609 if subGroupCheck && !p.IsInSubGroup() { 610 return 0, errors.New("invalid point: subgroup check failed") 611 } 612 613 return SizeOfG1AffineUncompressed, nil 614 } 615 616 // we have a compressed coordinate 617 // we need to 618 // 1. copy the buffer (to keep this method thread safe) 619 // 2. we need to solve the curve equation to compute Y 620 621 var bufX [fp.Bytes]byte 622 copy(bufX[:fp.Bytes], buf[:fp.Bytes]) 623 bufX[0] &= ^mMask 624 625 // read X coordinate 626 if err := p.X.SetBytesCanonical(bufX[:fp.Bytes]); err != nil { 627 return 0, err 628 } 629 630 var YSquared, Y fp.Element 631 632 // y^2=x^3+x+b 633 YSquared.Square(&p.X).Mul(&YSquared, &p.X) 634 YSquared.Add(&YSquared, &p.X). 635 Add(&YSquared, &bCurveCoeff) 636 637 if Y.Sqrt(&YSquared) == nil { 638 return 0, errors.New("invalid compressed coordinate: square root doesn't exist") 639 } 640 641 if Y.LexicographicallyLargest() { 642 // Y ">" -Y 643 if mData == mCompressedSmallest { 644 Y.Neg(&Y) 645 } 646 } else { 647 // Y "<=" -Y 648 if mData == mCompressedLargest { 649 Y.Neg(&Y) 650 } 651 } 652 653 p.Y = Y 654 655 // subgroup check 656 if subGroupCheck && !p.IsInSubGroup() { 657 return 0, errors.New("invalid point: subgroup check failed") 658 } 659 660 return SizeOfG1AffineCompressed, nil 661 } 662 663 // unsafeComputeY called by Decoder when processing slices of compressed point in parallel (step 2) 664 // it computes the Y coordinate from the already set X coordinate and is compute intensive 665 func (p *G1Affine) unsafeComputeY(subGroupCheck bool) error { 666 // stored in unsafeSetCompressedBytes 667 668 mData := byte(p.Y[0]) 669 670 // we have a compressed coordinate, we need to solve the curve equation to compute Y 671 var YSquared, Y fp.Element 672 673 // y^2=x^3+x+b 674 YSquared.Square(&p.X).Mul(&YSquared, &p.X) 675 YSquared.Add(&YSquared, &p.X). 676 Add(&YSquared, &bCurveCoeff) 677 678 if Y.Sqrt(&YSquared) == nil { 679 return errors.New("invalid compressed coordinate: square root doesn't exist") 680 } 681 682 if Y.LexicographicallyLargest() { 683 // Y ">" -Y 684 if mData == mCompressedSmallest { 685 Y.Neg(&Y) 686 } 687 } else { 688 // Y "<=" -Y 689 if mData == mCompressedLargest { 690 Y.Neg(&Y) 691 } 692 } 693 694 p.Y = Y 695 696 // subgroup check 697 if subGroupCheck && !p.IsInSubGroup() { 698 return errors.New("invalid point: subgroup check failed") 699 } 700 701 return nil 702 } 703 704 // unsafeSetCompressedBytes is called by Decoder when processing slices of compressed point in parallel (step 1) 705 // assumes buf[:8] mask is set to compressed 706 // returns true if point is infinity and need no further processing 707 // it sets X coordinate and uses Y for scratch space to store decompression metadata 708 func (p *G1Affine) unsafeSetCompressedBytes(buf []byte) (isInfinity bool, err error) { 709 710 // read the most significant byte 711 mData := buf[0] & mMask 712 713 if mData == mCompressedInfinity { 714 p.X.SetZero() 715 p.Y.SetZero() 716 isInfinity = true 717 return isInfinity, nil 718 } 719 720 // we need to copy the input buffer (to keep this method thread safe) 721 var bufX [fp.Bytes]byte 722 copy(bufX[:fp.Bytes], buf[:fp.Bytes]) 723 bufX[0] &= ^mMask 724 725 // read X coordinate 726 if err := p.X.SetBytesCanonical(bufX[:fp.Bytes]); err != nil { 727 return false, err 728 } 729 // store mData in p.Y[0] 730 p.Y[0] = uint64(mData) 731 732 // recomputing Y will be done asynchronously 733 return isInfinity, nil 734 }