github.com/twelsh-aw/go/src@v0.0.0-20230516233729-a56fe86a7c81/crypto/internal/edwards25519/field/fe_test.go (about) 1 // Copyright (c) 2017 The Go Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style 3 // license that can be found in the LICENSE file. 4 5 package field 6 7 import ( 8 "bytes" 9 "crypto/rand" 10 "encoding/hex" 11 "io" 12 "math/big" 13 "math/bits" 14 mathrand "math/rand" 15 "reflect" 16 "testing" 17 "testing/quick" 18 ) 19 20 func (v Element) String() string { 21 return hex.EncodeToString(v.Bytes()) 22 } 23 24 // quickCheckConfig1024 will make each quickcheck test run (1024 * -quickchecks) 25 // times. The default value of -quickchecks is 100. 26 var quickCheckConfig1024 = &quick.Config{MaxCountScale: 1 << 10} 27 28 func generateFieldElement(rand *mathrand.Rand) Element { 29 const maskLow52Bits = (1 << 52) - 1 30 return Element{ 31 rand.Uint64() & maskLow52Bits, 32 rand.Uint64() & maskLow52Bits, 33 rand.Uint64() & maskLow52Bits, 34 rand.Uint64() & maskLow52Bits, 35 rand.Uint64() & maskLow52Bits, 36 } 37 } 38 39 // weirdLimbs can be combined to generate a range of edge-case field elements. 40 // 0 and -1 are intentionally more weighted, as they combine well. 41 var ( 42 weirdLimbs51 = []uint64{ 43 0, 0, 0, 0, 44 1, 45 19 - 1, 46 19, 47 0x2aaaaaaaaaaaa, 48 0x5555555555555, 49 (1 << 51) - 20, 50 (1 << 51) - 19, 51 (1 << 51) - 1, (1 << 51) - 1, 52 (1 << 51) - 1, (1 << 51) - 1, 53 } 54 weirdLimbs52 = []uint64{ 55 0, 0, 0, 0, 0, 0, 56 1, 57 19 - 1, 58 19, 59 0x2aaaaaaaaaaaa, 60 0x5555555555555, 61 (1 << 51) - 20, 62 (1 << 51) - 19, 63 (1 << 51) - 1, (1 << 51) - 1, 64 (1 << 51) - 1, (1 << 51) - 1, 65 (1 << 51) - 1, (1 << 51) - 1, 66 1 << 51, 67 (1 << 51) + 1, 68 (1 << 52) - 19, 69 (1 << 52) - 1, 70 } 71 ) 72 73 func generateWeirdFieldElement(rand *mathrand.Rand) Element { 74 return Element{ 75 weirdLimbs52[rand.Intn(len(weirdLimbs52))], 76 weirdLimbs51[rand.Intn(len(weirdLimbs51))], 77 weirdLimbs51[rand.Intn(len(weirdLimbs51))], 78 weirdLimbs51[rand.Intn(len(weirdLimbs51))], 79 weirdLimbs51[rand.Intn(len(weirdLimbs51))], 80 } 81 } 82 83 func (Element) Generate(rand *mathrand.Rand, size int) reflect.Value { 84 if rand.Intn(2) == 0 { 85 return reflect.ValueOf(generateWeirdFieldElement(rand)) 86 } 87 return reflect.ValueOf(generateFieldElement(rand)) 88 } 89 90 // isInBounds returns whether the element is within the expected bit size bounds 91 // after a light reduction. 92 func isInBounds(x *Element) bool { 93 return bits.Len64(x.l0) <= 52 && 94 bits.Len64(x.l1) <= 52 && 95 bits.Len64(x.l2) <= 52 && 96 bits.Len64(x.l3) <= 52 && 97 bits.Len64(x.l4) <= 52 98 } 99 100 func TestMultiplyDistributesOverAdd(t *testing.T) { 101 multiplyDistributesOverAdd := func(x, y, z Element) bool { 102 // Compute t1 = (x+y)*z 103 t1 := new(Element) 104 t1.Add(&x, &y) 105 t1.Multiply(t1, &z) 106 107 // Compute t2 = x*z + y*z 108 t2 := new(Element) 109 t3 := new(Element) 110 t2.Multiply(&x, &z) 111 t3.Multiply(&y, &z) 112 t2.Add(t2, t3) 113 114 return t1.Equal(t2) == 1 && isInBounds(t1) && isInBounds(t2) 115 } 116 117 if err := quick.Check(multiplyDistributesOverAdd, quickCheckConfig1024); err != nil { 118 t.Error(err) 119 } 120 } 121 122 func TestMul64to128(t *testing.T) { 123 a := uint64(5) 124 b := uint64(5) 125 r := mul64(a, b) 126 if r.lo != 0x19 || r.hi != 0 { 127 t.Errorf("lo-range wide mult failed, got %d + %d*(2**64)", r.lo, r.hi) 128 } 129 130 a = uint64(18014398509481983) // 2^54 - 1 131 b = uint64(18014398509481983) // 2^54 - 1 132 r = mul64(a, b) 133 if r.lo != 0xff80000000000001 || r.hi != 0xfffffffffff { 134 t.Errorf("hi-range wide mult failed, got %d + %d*(2**64)", r.lo, r.hi) 135 } 136 137 a = uint64(1125899906842661) 138 b = uint64(2097155) 139 r = mul64(a, b) 140 r = addMul64(r, a, b) 141 r = addMul64(r, a, b) 142 r = addMul64(r, a, b) 143 r = addMul64(r, a, b) 144 if r.lo != 16888498990613035 || r.hi != 640 { 145 t.Errorf("wrong answer: %d + %d*(2**64)", r.lo, r.hi) 146 } 147 } 148 149 func TestSetBytesRoundTrip(t *testing.T) { 150 f1 := func(in [32]byte, fe Element) bool { 151 fe.SetBytes(in[:]) 152 153 // Mask the most significant bit as it's ignored by SetBytes. (Now 154 // instead of earlier so we check the masking in SetBytes is working.) 155 in[len(in)-1] &= (1 << 7) - 1 156 157 return bytes.Equal(in[:], fe.Bytes()) && isInBounds(&fe) 158 } 159 if err := quick.Check(f1, nil); err != nil { 160 t.Errorf("failed bytes->FE->bytes round-trip: %v", err) 161 } 162 163 f2 := func(fe, r Element) bool { 164 r.SetBytes(fe.Bytes()) 165 166 // Intentionally not using Equal not to go through Bytes again. 167 // Calling reduce because both Generate and SetBytes can produce 168 // non-canonical representations. 169 fe.reduce() 170 r.reduce() 171 return fe == r 172 } 173 if err := quick.Check(f2, nil); err != nil { 174 t.Errorf("failed FE->bytes->FE round-trip: %v", err) 175 } 176 177 // Check some fixed vectors from dalek 178 type feRTTest struct { 179 fe Element 180 b []byte 181 } 182 var tests = []feRTTest{ 183 { 184 fe: Element{358744748052810, 1691584618240980, 977650209285361, 1429865912637724, 560044844278676}, 185 b: []byte{74, 209, 69, 197, 70, 70, 161, 222, 56, 226, 229, 19, 112, 60, 25, 92, 187, 74, 222, 56, 50, 153, 51, 233, 40, 74, 57, 6, 160, 185, 213, 31}, 186 }, 187 { 188 fe: Element{84926274344903, 473620666599931, 365590438845504, 1028470286882429, 2146499180330972}, 189 b: []byte{199, 23, 106, 112, 61, 77, 216, 79, 186, 60, 11, 118, 13, 16, 103, 15, 42, 32, 83, 250, 44, 57, 204, 198, 78, 199, 253, 119, 146, 172, 3, 122}, 190 }, 191 } 192 193 for _, tt := range tests { 194 b := tt.fe.Bytes() 195 fe, _ := new(Element).SetBytes(tt.b) 196 if !bytes.Equal(b, tt.b) || fe.Equal(&tt.fe) != 1 { 197 t.Errorf("Failed fixed roundtrip: %v", tt) 198 } 199 } 200 } 201 202 func swapEndianness(buf []byte) []byte { 203 for i := 0; i < len(buf)/2; i++ { 204 buf[i], buf[len(buf)-i-1] = buf[len(buf)-i-1], buf[i] 205 } 206 return buf 207 } 208 209 func TestBytesBigEquivalence(t *testing.T) { 210 f1 := func(in [32]byte, fe, fe1 Element) bool { 211 fe.SetBytes(in[:]) 212 213 in[len(in)-1] &= (1 << 7) - 1 // mask the most significant bit 214 b := new(big.Int).SetBytes(swapEndianness(in[:])) 215 fe1.fromBig(b) 216 217 if fe != fe1 { 218 return false 219 } 220 221 buf := make([]byte, 32) 222 buf = swapEndianness(fe1.toBig().FillBytes(buf)) 223 224 return bytes.Equal(fe.Bytes(), buf) && isInBounds(&fe) && isInBounds(&fe1) 225 } 226 if err := quick.Check(f1, nil); err != nil { 227 t.Error(err) 228 } 229 } 230 231 // fromBig sets v = n, and returns v. The bit length of n must not exceed 256. 232 func (v *Element) fromBig(n *big.Int) *Element { 233 if n.BitLen() > 32*8 { 234 panic("edwards25519: invalid field element input size") 235 } 236 237 buf := make([]byte, 0, 32) 238 for _, word := range n.Bits() { 239 for i := 0; i < bits.UintSize; i += 8 { 240 if len(buf) >= cap(buf) { 241 break 242 } 243 buf = append(buf, byte(word)) 244 word >>= 8 245 } 246 } 247 248 v.SetBytes(buf[:32]) 249 return v 250 } 251 252 func (v *Element) fromDecimal(s string) *Element { 253 n, ok := new(big.Int).SetString(s, 10) 254 if !ok { 255 panic("not a valid decimal: " + s) 256 } 257 return v.fromBig(n) 258 } 259 260 // toBig returns v as a big.Int. 261 func (v *Element) toBig() *big.Int { 262 buf := v.Bytes() 263 264 words := make([]big.Word, 32*8/bits.UintSize) 265 for n := range words { 266 for i := 0; i < bits.UintSize; i += 8 { 267 if len(buf) == 0 { 268 break 269 } 270 words[n] |= big.Word(buf[0]) << big.Word(i) 271 buf = buf[1:] 272 } 273 } 274 275 return new(big.Int).SetBits(words) 276 } 277 278 func TestDecimalConstants(t *testing.T) { 279 sqrtM1String := "19681161376707505956807079304988542015446066515923890162744021073123829784752" 280 if exp := new(Element).fromDecimal(sqrtM1String); sqrtM1.Equal(exp) != 1 { 281 t.Errorf("sqrtM1 is %v, expected %v", sqrtM1, exp) 282 } 283 // d is in the parent package, and we don't want to expose d or fromDecimal. 284 // dString := "37095705934669439343138083508754565189542113879843219016388785533085940283555" 285 // if exp := new(Element).fromDecimal(dString); d.Equal(exp) != 1 { 286 // t.Errorf("d is %v, expected %v", d, exp) 287 // } 288 } 289 290 func TestSetBytesRoundTripEdgeCases(t *testing.T) { 291 // TODO: values close to 0, close to 2^255-19, between 2^255-19 and 2^255-1, 292 // and between 2^255 and 2^256-1. Test both the documented SetBytes 293 // behavior, and that Bytes reduces them. 294 } 295 296 // Tests self-consistency between Multiply and Square. 297 func TestConsistency(t *testing.T) { 298 var x Element 299 var x2, x2sq Element 300 301 x = Element{1, 1, 1, 1, 1} 302 x2.Multiply(&x, &x) 303 x2sq.Square(&x) 304 305 if x2 != x2sq { 306 t.Fatalf("all ones failed\nmul: %x\nsqr: %x\n", x2, x2sq) 307 } 308 309 var bytes [32]byte 310 311 _, err := io.ReadFull(rand.Reader, bytes[:]) 312 if err != nil { 313 t.Fatal(err) 314 } 315 x.SetBytes(bytes[:]) 316 317 x2.Multiply(&x, &x) 318 x2sq.Square(&x) 319 320 if x2 != x2sq { 321 t.Fatalf("all ones failed\nmul: %x\nsqr: %x\n", x2, x2sq) 322 } 323 } 324 325 func TestEqual(t *testing.T) { 326 x := Element{1, 1, 1, 1, 1} 327 y := Element{5, 4, 3, 2, 1} 328 329 eq := x.Equal(&x) 330 if eq != 1 { 331 t.Errorf("wrong about equality") 332 } 333 334 eq = x.Equal(&y) 335 if eq != 0 { 336 t.Errorf("wrong about inequality") 337 } 338 } 339 340 func TestInvert(t *testing.T) { 341 x := Element{1, 1, 1, 1, 1} 342 one := Element{1, 0, 0, 0, 0} 343 var xinv, r Element 344 345 xinv.Invert(&x) 346 r.Multiply(&x, &xinv) 347 r.reduce() 348 349 if one != r { 350 t.Errorf("inversion identity failed, got: %x", r) 351 } 352 353 var bytes [32]byte 354 355 _, err := io.ReadFull(rand.Reader, bytes[:]) 356 if err != nil { 357 t.Fatal(err) 358 } 359 x.SetBytes(bytes[:]) 360 361 xinv.Invert(&x) 362 r.Multiply(&x, &xinv) 363 r.reduce() 364 365 if one != r { 366 t.Errorf("random inversion identity failed, got: %x for field element %x", r, x) 367 } 368 369 zero := Element{} 370 x.Set(&zero) 371 if xx := xinv.Invert(&x); xx != &xinv { 372 t.Errorf("inverting zero did not return the receiver") 373 } else if xinv.Equal(&zero) != 1 { 374 t.Errorf("inverting zero did not return zero") 375 } 376 } 377 378 func TestSelectSwap(t *testing.T) { 379 a := Element{358744748052810, 1691584618240980, 977650209285361, 1429865912637724, 560044844278676} 380 b := Element{84926274344903, 473620666599931, 365590438845504, 1028470286882429, 2146499180330972} 381 382 var c, d Element 383 384 c.Select(&a, &b, 1) 385 d.Select(&a, &b, 0) 386 387 if c.Equal(&a) != 1 || d.Equal(&b) != 1 { 388 t.Errorf("Select failed") 389 } 390 391 c.Swap(&d, 0) 392 393 if c.Equal(&a) != 1 || d.Equal(&b) != 1 { 394 t.Errorf("Swap failed") 395 } 396 397 c.Swap(&d, 1) 398 399 if c.Equal(&b) != 1 || d.Equal(&a) != 1 { 400 t.Errorf("Swap failed") 401 } 402 } 403 404 func TestMult32(t *testing.T) { 405 mult32EquivalentToMul := func(x Element, y uint32) bool { 406 t1 := new(Element) 407 for i := 0; i < 100; i++ { 408 t1.Mult32(&x, y) 409 } 410 411 ty := new(Element) 412 ty.l0 = uint64(y) 413 414 t2 := new(Element) 415 for i := 0; i < 100; i++ { 416 t2.Multiply(&x, ty) 417 } 418 419 return t1.Equal(t2) == 1 && isInBounds(t1) && isInBounds(t2) 420 } 421 422 if err := quick.Check(mult32EquivalentToMul, quickCheckConfig1024); err != nil { 423 t.Error(err) 424 } 425 } 426 427 func TestSqrtRatio(t *testing.T) { 428 // From draft-irtf-cfrg-ristretto255-decaf448-00, Appendix A.4. 429 type test struct { 430 u, v string 431 wasSquare int 432 r string 433 } 434 var tests = []test{ 435 // If u is 0, the function is defined to return (0, TRUE), even if v 436 // is zero. Note that where used in this package, the denominator v 437 // is never zero. 438 { 439 "0000000000000000000000000000000000000000000000000000000000000000", 440 "0000000000000000000000000000000000000000000000000000000000000000", 441 1, "0000000000000000000000000000000000000000000000000000000000000000", 442 }, 443 // 0/1 == 0² 444 { 445 "0000000000000000000000000000000000000000000000000000000000000000", 446 "0100000000000000000000000000000000000000000000000000000000000000", 447 1, "0000000000000000000000000000000000000000000000000000000000000000", 448 }, 449 // If u is non-zero and v is zero, defined to return (0, FALSE). 450 { 451 "0100000000000000000000000000000000000000000000000000000000000000", 452 "0000000000000000000000000000000000000000000000000000000000000000", 453 0, "0000000000000000000000000000000000000000000000000000000000000000", 454 }, 455 // 2/1 is not square in this field. 456 { 457 "0200000000000000000000000000000000000000000000000000000000000000", 458 "0100000000000000000000000000000000000000000000000000000000000000", 459 0, "3c5ff1b5d8e4113b871bd052f9e7bcd0582804c266ffb2d4f4203eb07fdb7c54", 460 }, 461 // 4/1 == 2² 462 { 463 "0400000000000000000000000000000000000000000000000000000000000000", 464 "0100000000000000000000000000000000000000000000000000000000000000", 465 1, "0200000000000000000000000000000000000000000000000000000000000000", 466 }, 467 // 1/4 == (2⁻¹)² == (2^(p-2))² per Euler's theorem 468 { 469 "0100000000000000000000000000000000000000000000000000000000000000", 470 "0400000000000000000000000000000000000000000000000000000000000000", 471 1, "f6ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff3f", 472 }, 473 } 474 475 for i, tt := range tests { 476 u, _ := new(Element).SetBytes(decodeHex(tt.u)) 477 v, _ := new(Element).SetBytes(decodeHex(tt.v)) 478 want, _ := new(Element).SetBytes(decodeHex(tt.r)) 479 got, wasSquare := new(Element).SqrtRatio(u, v) 480 if got.Equal(want) == 0 || wasSquare != tt.wasSquare { 481 t.Errorf("%d: got (%v, %v), want (%v, %v)", i, got, wasSquare, want, tt.wasSquare) 482 } 483 } 484 } 485 486 func TestCarryPropagate(t *testing.T) { 487 asmLikeGeneric := func(a [5]uint64) bool { 488 t1 := &Element{a[0], a[1], a[2], a[3], a[4]} 489 t2 := &Element{a[0], a[1], a[2], a[3], a[4]} 490 491 t1.carryPropagate() 492 t2.carryPropagateGeneric() 493 494 if *t1 != *t2 { 495 t.Logf("got: %#v,\nexpected: %#v", t1, t2) 496 } 497 498 return *t1 == *t2 && isInBounds(t2) 499 } 500 501 if err := quick.Check(asmLikeGeneric, quickCheckConfig1024); err != nil { 502 t.Error(err) 503 } 504 505 if !asmLikeGeneric([5]uint64{0xffffffffffffffff, 0xffffffffffffffff, 0xffffffffffffffff, 0xffffffffffffffff, 0xffffffffffffffff}) { 506 t.Errorf("failed for {0xffffffffffffffff, 0xffffffffffffffff, 0xffffffffffffffff, 0xffffffffffffffff, 0xffffffffffffffff}") 507 } 508 } 509 510 func TestFeSquare(t *testing.T) { 511 asmLikeGeneric := func(a Element) bool { 512 t1 := a 513 t2 := a 514 515 feSquareGeneric(&t1, &t1) 516 feSquare(&t2, &t2) 517 518 if t1 != t2 { 519 t.Logf("got: %#v,\nexpected: %#v", t1, t2) 520 } 521 522 return t1 == t2 && isInBounds(&t2) 523 } 524 525 if err := quick.Check(asmLikeGeneric, quickCheckConfig1024); err != nil { 526 t.Error(err) 527 } 528 } 529 530 func TestFeMul(t *testing.T) { 531 asmLikeGeneric := func(a, b Element) bool { 532 a1 := a 533 a2 := a 534 b1 := b 535 b2 := b 536 537 feMulGeneric(&a1, &a1, &b1) 538 feMul(&a2, &a2, &b2) 539 540 if a1 != a2 || b1 != b2 { 541 t.Logf("got: %#v,\nexpected: %#v", a1, a2) 542 t.Logf("got: %#v,\nexpected: %#v", b1, b2) 543 } 544 545 return a1 == a2 && isInBounds(&a2) && 546 b1 == b2 && isInBounds(&b2) 547 } 548 549 if err := quick.Check(asmLikeGeneric, quickCheckConfig1024); err != nil { 550 t.Error(err) 551 } 552 } 553 554 func decodeHex(s string) []byte { 555 b, err := hex.DecodeString(s) 556 if err != nil { 557 panic(err) 558 } 559 return b 560 }