github.com/emmansun/gmsm@v0.29.1/internal/bigmod/nat_test.go (about) 1 // Copyright 2021 The Go Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style 3 // license that can be found in the LICENSE file. 4 5 package bigmod 6 7 import ( 8 "fmt" 9 "math/big" 10 "math/bits" 11 "math/rand" 12 "reflect" 13 "strings" 14 "testing" 15 "testing/quick" 16 ) 17 18 func (n *Nat) String() string { 19 var limbs []string 20 for i := range n.limbs { 21 limbs = append(limbs, fmt.Sprintf("%016X", n.limbs[len(n.limbs)-1-i])) 22 } 23 return "{" + strings.Join(limbs, " ") + "}" 24 } 25 26 // Generate generates an even nat. It's used by testing/quick to produce random 27 // *nat values for quick.Check invocations. 28 func (*Nat) Generate(r *rand.Rand, size int) reflect.Value { 29 limbs := make([]uint, size) 30 for i := 0; i < size; i++ { 31 limbs[i] = uint(r.Uint64()) & ((1 << _W) - 2) 32 } 33 return reflect.ValueOf(&Nat{limbs}) 34 } 35 36 func testModAddCommutative(a *Nat, b *Nat) bool { 37 m := maxModulus(uint(len(a.limbs))) 38 aPlusB := new(Nat).Set(a) 39 aPlusB.Add(b, m) 40 bPlusA := new(Nat).Set(b) 41 bPlusA.Add(a, m) 42 return aPlusB.Equal(bPlusA) == 1 43 } 44 45 func TestModAddCommutative(t *testing.T) { 46 err := quick.Check(testModAddCommutative, &quick.Config{}) 47 if err != nil { 48 t.Error(err) 49 } 50 } 51 52 func testModSubThenAddIdentity(a *Nat, b *Nat) bool { 53 m := maxModulus(uint(len(a.limbs))) 54 original := new(Nat).Set(a) 55 a.Sub(b, m) 56 a.Add(b, m) 57 return a.Equal(original) == 1 58 } 59 60 func TestModSubThenAddIdentity(t *testing.T) { 61 err := quick.Check(testModSubThenAddIdentity, &quick.Config{}) 62 if err != nil { 63 t.Error(err) 64 } 65 } 66 67 func TestMontgomeryRoundtrip(t *testing.T) { 68 err := quick.Check(func(a *Nat) bool { 69 one := &Nat{make([]uint, len(a.limbs))} 70 one.limbs[0] = 1 71 aPlusOne := new(big.Int).SetBytes(natBytes(a)) 72 aPlusOne.Add(aPlusOne, big.NewInt(1)) 73 m, _ := NewModulusFromBig(aPlusOne) 74 monty := new(Nat).Set(a) 75 monty.montgomeryRepresentation(m) 76 aAgain := new(Nat).Set(monty) 77 aAgain.montgomeryMul(monty, one, m) 78 if a.Equal(aAgain) != 1 { 79 t.Errorf("%v != %v", a, aAgain) 80 return false 81 } 82 return true 83 }, &quick.Config{}) 84 if err != nil { 85 t.Error(err) 86 } 87 } 88 89 func TestShiftIn(t *testing.T) { 90 if bits.UintSize != 64 { 91 t.Skip("examples are only valid in 64 bit") 92 } 93 examples := []struct { 94 m, x, expected []byte 95 y uint64 96 }{{ 97 m: []byte{13}, 98 x: []byte{0}, 99 y: 0xFFFF_FFFF_FFFF_FFFF, 100 expected: []byte{2}, 101 }, { 102 m: []byte{13}, 103 x: []byte{7}, 104 y: 0xFFFF_FFFF_FFFF_FFFF, 105 expected: []byte{10}, 106 }, { 107 m: []byte{0x06, 0x80, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x0d}, 108 x: make([]byte, 9), 109 y: 0xFFFF_FFFF_FFFF_FFFF, 110 expected: []byte{0x00, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff}, 111 }, { 112 m: []byte{0x06, 0x80, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x0d}, 113 x: []byte{0x00, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff}, 114 y: 0, 115 expected: []byte{0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x06}, 116 }} 117 118 for i, tt := range examples { 119 m := modulusFromBytes(tt.m) 120 got := natFromBytes(tt.x).ExpandFor(m).shiftIn(uint(tt.y), m) 121 if exp := natFromBytes(tt.expected).ExpandFor(m); got.Equal(exp) != 1 { 122 t.Errorf("%d: got %v, expected %v", i, got, exp) 123 } 124 } 125 } 126 127 func TestModulusAndNatSizes(t *testing.T) { 128 // These are 126 bit (2 * _W on 64-bit architectures) values, serialized as 129 // 128 bits worth of bytes. If leading zeroes are stripped, they fit in two 130 // limbs, if they are not, they fit in three. This can be a problem because 131 // modulus strips leading zeroes and nat does not. 132 m := modulusFromBytes([]byte{ 133 0x3f, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 134 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff}) 135 xb := []byte{0x3f, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 136 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xfe} 137 natFromBytes(xb).ExpandFor(m) // must not panic for shrinking 138 NewNat().SetBytes(xb, m) 139 } 140 141 func TestSetBytes(t *testing.T) { 142 tests := []struct { 143 m, b []byte 144 fail bool 145 }{{ 146 m: []byte{0xff, 0xff}, 147 b: []byte{0x00, 0x01}, 148 }, { 149 m: []byte{0xff, 0xff}, 150 b: []byte{0xff, 0xff}, 151 fail: true, 152 }, { 153 m: []byte{0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff}, 154 b: []byte{0x00, 0x01}, 155 }, { 156 m: []byte{0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff}, 157 b: []byte{0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xfe}, 158 }, { 159 m: []byte{0xff, 0xff}, 160 b: []byte{0xff, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00}, 161 fail: true, 162 }, { 163 m: []byte{0xff, 0xff}, 164 b: []byte{0xff, 0xff, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00}, 165 fail: true, 166 }, { 167 m: []byte{0x7f, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff}, 168 b: []byte{0x7f, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xfe}, 169 }, { 170 m: []byte{0x7f, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff}, 171 b: []byte{0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xfe}, 172 fail: true, 173 }, { 174 m: []byte{0x7f, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff}, 175 b: []byte{0x7f, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff}, 176 fail: true, 177 }, { 178 m: []byte{0x7f, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff}, 179 b: []byte{0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xfe}, 180 fail: true, 181 }, { 182 m: []byte{0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xfd}, 183 b: []byte{0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff}, 184 fail: true, 185 }} 186 187 for i, tt := range tests { 188 m := modulusFromBytes(tt.m) 189 got, err := NewNat().SetBytes(tt.b, m) 190 if err != nil { 191 if !tt.fail { 192 t.Errorf("%d: unexpected error: %v", i, err) 193 } 194 continue 195 } 196 if tt.fail { 197 t.Errorf("%d: unexpected success", i) 198 continue 199 } 200 if expected := natFromBytes(tt.b).ExpandFor(m); got.Equal(expected) != yes { 201 t.Errorf("%d: got %v, expected %v", i, got, expected) 202 } 203 } 204 205 f := func(xBytes []byte) bool { 206 m := maxModulus(uint(len(xBytes)*8/_W + 1)) 207 got, err := NewNat().SetBytes(xBytes, m) 208 if err != nil { 209 return false 210 } 211 return got.Equal(natFromBytes(xBytes).ExpandFor(m)) == yes 212 } 213 214 err := quick.Check(f, &quick.Config{}) 215 if err != nil { 216 t.Error(err) 217 } 218 } 219 220 func TestExpand(t *testing.T) { 221 sliced := []uint{1, 2, 3, 4} 222 examples := []struct { 223 in []uint 224 n int 225 out []uint 226 }{{ 227 []uint{1, 2}, 228 4, 229 []uint{1, 2, 0, 0}, 230 }, { 231 sliced[:2], 232 4, 233 []uint{1, 2, 0, 0}, 234 }, { 235 []uint{1, 2}, 236 2, 237 []uint{1, 2}, 238 }} 239 240 for i, tt := range examples { 241 got := (&Nat{tt.in}).expand(tt.n) 242 if len(got.limbs) != len(tt.out) || got.Equal(&Nat{tt.out}) != 1 { 243 t.Errorf("%d: got %v, expected %v", i, got, tt.out) 244 } 245 } 246 } 247 248 func TestMod(t *testing.T) { 249 m := modulusFromBytes([]byte{0x06, 0x80, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x0d}) 250 x := natFromBytes([]byte{0x40, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x80, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01}) 251 out := new(Nat) 252 out.Mod(x, m) 253 expected := natFromBytes([]byte{0x04, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x09}) 254 if out.Equal(expected) != 1 { 255 t.Errorf("%+v != %+v", out, expected) 256 } 257 } 258 259 func TestModSub(t *testing.T) { 260 m := modulusFromBytes([]byte{13}) 261 x := &Nat{[]uint{6}} 262 y := &Nat{[]uint{7}} 263 x.Sub(y, m) 264 expected := &Nat{[]uint{12}} 265 if x.Equal(expected) != 1 { 266 t.Errorf("%+v != %+v", x, expected) 267 } 268 x.Sub(y, m) 269 expected = &Nat{[]uint{5}} 270 if x.Equal(expected) != 1 { 271 t.Errorf("%+v != %+v", x, expected) 272 } 273 } 274 275 func TestModAdd(t *testing.T) { 276 m := modulusFromBytes([]byte{13}) 277 x := &Nat{[]uint{6}} 278 y := &Nat{[]uint{7}} 279 x.Add(y, m) 280 expected := &Nat{[]uint{0}} 281 if x.Equal(expected) != 1 { 282 t.Errorf("%+v != %+v", x, expected) 283 } 284 x.Add(y, m) 285 expected = &Nat{[]uint{7}} 286 if x.Equal(expected) != 1 { 287 t.Errorf("%+v != %+v", x, expected) 288 } 289 } 290 291 func TestExp(t *testing.T) { 292 m := modulusFromBytes([]byte{13}) 293 x := &Nat{[]uint{3}} 294 out := &Nat{[]uint{0}} 295 out.Exp(x, []byte{12}, m) 296 expected := &Nat{[]uint{1}} 297 if out.Equal(expected) != 1 { 298 t.Errorf("%+v != %+v", out, expected) 299 } 300 } 301 302 func TestExpShort(t *testing.T) { 303 m := modulusFromBytes([]byte{13}) 304 x := &Nat{[]uint{3}} 305 out := &Nat{[]uint{0}} 306 out.ExpShortVarTime(x, 12, m) 307 expected := &Nat{[]uint{1}} 308 if out.Equal(expected) != 1 { 309 t.Errorf("%+v != %+v", out, expected) 310 } 311 } 312 313 // TestMulReductions tests that Mul reduces results equal or slightly greater 314 // than the modulus. Some Montgomery algorithms don't and need extra care to 315 // return correct results. See https://go.dev/issue/13907. 316 func TestMulReductions(t *testing.T) { 317 // Two short but multi-limb primes. 318 a, _ := new(big.Int).SetString("773608962677651230850240281261679752031633236267106044359907", 10) 319 b, _ := new(big.Int).SetString("180692823610368451951102211649591374573781973061758082626801", 10) 320 n := new(big.Int).Mul(a, b) 321 322 N, _ := NewModulusFromBig(n) 323 A := NewNat().SetBig(a).ExpandFor(N) 324 B := NewNat().SetBig(b).ExpandFor(N) 325 326 if A.Mul(B, N).IsZero() != 1 { 327 t.Error("a * b mod (a * b) != 0") 328 } 329 330 i := new(big.Int).ModInverse(a, b) 331 N, _ = NewModulusFromBig(b) 332 A = NewNat().SetBig(a).ExpandFor(N) 333 I := NewNat().SetBig(i).ExpandFor(N) 334 one := NewNat().SetBig(big.NewInt(1)).ExpandFor(N) 335 336 if A.Mul(I, N).Equal(one) != 1 { 337 t.Error("a * inv(a) mod b != 1") 338 } 339 } 340 341 func natBytes(n *Nat) []byte { 342 return n.Bytes(maxModulus(uint(len(n.limbs)))) 343 } 344 345 func natFromBytes(b []byte) *Nat { 346 // Must not use Nat.SetBytes as it's used in TestSetBytes. 347 bb := new(big.Int).SetBytes(b) 348 return NewNat().SetBig(bb) 349 } 350 351 func modulusFromBytes(b []byte) *Modulus { 352 bb := new(big.Int).SetBytes(b) 353 m, _ := NewModulusFromBig(bb) 354 return m 355 } 356 357 // maxModulus returns the biggest modulus that can fit in n limbs. 358 func maxModulus(n uint) *Modulus { 359 b := big.NewInt(1) 360 b.Lsh(b, n*_W) 361 b.Sub(b, big.NewInt(1)) 362 m, _ := NewModulusFromBig(b) 363 return m 364 } 365 366 func makeBenchmarkModulus(n uint) *Modulus { 367 return maxModulus(n) 368 } 369 370 func makeBenchmarkValue(n int) *Nat { 371 x := make([]uint, n) 372 for i := 0; i < n; i++ { 373 x[i]-- 374 } 375 return &Nat{limbs: x} 376 } 377 378 func makeBenchmarkExponent() []byte { 379 e := make([]byte, 256) 380 for i := 0; i < 32; i++ { 381 e[i] = 0xFF 382 } 383 return e 384 } 385 386 func BenchmarkRR256(b *testing.B) { 387 b.ResetTimer() 388 for i := 0; i < b.N; i++ { 389 makeBenchmarkModulus(4) 390 } 391 } 392 393 func BenchmarkModAdd(b *testing.B) { 394 x := makeBenchmarkValue(32) 395 y := makeBenchmarkValue(32) 396 m := makeBenchmarkModulus(32) 397 398 b.ResetTimer() 399 for i := 0; i < b.N; i++ { 400 x.Add(y, m) 401 } 402 } 403 404 func BenchmarkModSub(b *testing.B) { 405 x := makeBenchmarkValue(32) 406 y := makeBenchmarkValue(32) 407 m := makeBenchmarkModulus(32) 408 409 b.ResetTimer() 410 for i := 0; i < b.N; i++ { 411 x.Sub(y, m) 412 } 413 } 414 415 func BenchmarkMontgomeryRepr(b *testing.B) { 416 x := makeBenchmarkValue(32) 417 m := makeBenchmarkModulus(32) 418 419 b.ResetTimer() 420 for i := 0; i < b.N; i++ { 421 x.montgomeryRepresentation(m) 422 } 423 } 424 425 func BenchmarkMontgomeryMul(b *testing.B) { 426 x := makeBenchmarkValue(32) 427 y := makeBenchmarkValue(32) 428 out := makeBenchmarkValue(32) 429 m := makeBenchmarkModulus(32) 430 431 b.ResetTimer() 432 for i := 0; i < b.N; i++ { 433 out.montgomeryMul(x, y, m) 434 } 435 } 436 437 func BenchmarkModMul(b *testing.B) { 438 x := makeBenchmarkValue(32) 439 y := makeBenchmarkValue(32) 440 m := makeBenchmarkModulus(32) 441 442 b.ResetTimer() 443 for i := 0; i < b.N; i++ { 444 x.Mul(y, m) 445 } 446 } 447 448 func BenchmarkModMul256(b *testing.B) { 449 x := makeBenchmarkValue(4) 450 y := makeBenchmarkValue(4) 451 m := makeBenchmarkModulus(4) 452 453 b.ResetTimer() 454 for i := 0; i < b.N; i++ { 455 x.Mul(y, m) 456 } 457 } 458 459 func BenchmarkExpBig(b *testing.B) { 460 out := new(big.Int) 461 exponentBytes := makeBenchmarkExponent() 462 x := new(big.Int).SetBytes(exponentBytes) 463 e := new(big.Int).SetBytes(exponentBytes) 464 n := new(big.Int).SetBytes(exponentBytes) 465 one := new(big.Int).SetUint64(1) 466 n.Add(n, one) 467 468 b.ResetTimer() 469 for i := 0; i < b.N; i++ { 470 out.Exp(x, e, n) 471 } 472 } 473 474 func BenchmarkExp(b *testing.B) { 475 x := makeBenchmarkValue(32) 476 e := makeBenchmarkExponent() 477 out := makeBenchmarkValue(32) 478 m := makeBenchmarkModulus(32) 479 480 b.ResetTimer() 481 for i := 0; i < b.N; i++ { 482 out.Exp(x, e, m) 483 } 484 } 485 486 func TestNewModFromBigZero(t *testing.T) { 487 expected := "modulus must be >= 0" 488 _, err := NewModulusFromBig(big.NewInt(0)) 489 if err == nil || err.Error() != expected { 490 t.Errorf("NewModulusFromBig(0) got %q, want %q", err, expected) 491 } 492 493 expected = "modulus must be odd" 494 _, err = NewModulusFromBig(big.NewInt(2)) 495 if err == nil || err.Error() != expected { 496 t.Errorf("NewModulusFromBig(2) got %q, want %q", err, expected) 497 } 498 }