github.com/twelsh-aw/go/src@v0.0.0-20230516233729-a56fe86a7c81/crypto/internal/bigmod/nat_test.go (about) 1 // Copyright 2021 The Go Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style 3 // license that can be found in the LICENSE file. 4 5 package bigmod 6 7 import ( 8 "math/big" 9 "math/bits" 10 "math/rand" 11 "reflect" 12 "testing" 13 "testing/quick" 14 ) 15 16 // Generate generates an even nat. It's used by testing/quick to produce random 17 // *nat values for quick.Check invocations. 18 func (*Nat) Generate(r *rand.Rand, size int) reflect.Value { 19 limbs := make([]uint, size) 20 for i := 0; i < size; i++ { 21 limbs[i] = uint(r.Uint64()) & ((1 << _W) - 2) 22 } 23 return reflect.ValueOf(&Nat{limbs}) 24 } 25 26 func testModAddCommutative(a *Nat, b *Nat) bool { 27 m := maxModulus(uint(len(a.limbs))) 28 aPlusB := new(Nat).set(a) 29 aPlusB.Add(b, m) 30 bPlusA := new(Nat).set(b) 31 bPlusA.Add(a, m) 32 return aPlusB.Equal(bPlusA) == 1 33 } 34 35 func TestModAddCommutative(t *testing.T) { 36 err := quick.Check(testModAddCommutative, &quick.Config{}) 37 if err != nil { 38 t.Error(err) 39 } 40 } 41 42 func testModSubThenAddIdentity(a *Nat, b *Nat) bool { 43 m := maxModulus(uint(len(a.limbs))) 44 original := new(Nat).set(a) 45 a.Sub(b, m) 46 a.Add(b, m) 47 return a.Equal(original) == 1 48 } 49 50 func TestModSubThenAddIdentity(t *testing.T) { 51 err := quick.Check(testModSubThenAddIdentity, &quick.Config{}) 52 if err != nil { 53 t.Error(err) 54 } 55 } 56 57 func testMontgomeryRoundtrip(a *Nat) bool { 58 one := &Nat{make([]uint, len(a.limbs))} 59 one.limbs[0] = 1 60 aPlusOne := new(big.Int).SetBytes(natBytes(a)) 61 aPlusOne.Add(aPlusOne, big.NewInt(1)) 62 m := NewModulusFromBig(aPlusOne) 63 monty := new(Nat).set(a) 64 monty.montgomeryRepresentation(m) 65 aAgain := new(Nat).set(monty) 66 aAgain.montgomeryMul(monty, one, m) 67 return a.Equal(aAgain) == 1 68 } 69 70 func TestMontgomeryRoundtrip(t *testing.T) { 71 err := quick.Check(testMontgomeryRoundtrip, &quick.Config{}) 72 if err != nil { 73 t.Error(err) 74 } 75 } 76 77 func TestShiftIn(t *testing.T) { 78 if bits.UintSize != 64 { 79 t.Skip("examples are only valid in 64 bit") 80 } 81 examples := []struct { 82 m, x, expected []byte 83 y uint64 84 }{{ 85 m: []byte{13}, 86 x: []byte{0}, 87 y: 0x7FFF_FFFF_FFFF_FFFF, 88 expected: []byte{7}, 89 }, { 90 m: []byte{13}, 91 x: []byte{7}, 92 y: 0x7FFF_FFFF_FFFF_FFFF, 93 expected: []byte{11}, 94 }, { 95 m: []byte{0x06, 0x80, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x0d}, 96 x: make([]byte, 9), 97 y: 0x7FFF_FFFF_FFFF_FFFF, 98 expected: []byte{0x00, 0x7f, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff}, 99 }, { 100 m: []byte{0x06, 0x80, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x0d}, 101 x: []byte{0x00, 0x7f, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff}, 102 y: 0, 103 expected: []byte{0x03, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x08}, 104 }} 105 106 for i, tt := range examples { 107 m := modulusFromBytes(tt.m) 108 got := natFromBytes(tt.x).ExpandFor(m).shiftIn(uint(tt.y), m) 109 if got.Equal(natFromBytes(tt.expected).ExpandFor(m)) != 1 { 110 t.Errorf("%d: got %x, expected %x", i, got, tt.expected) 111 } 112 } 113 } 114 115 func TestModulusAndNatSizes(t *testing.T) { 116 // These are 126 bit (2 * _W on 64-bit architectures) values, serialized as 117 // 128 bits worth of bytes. If leading zeroes are stripped, they fit in two 118 // limbs, if they are not, they fit in three. This can be a problem because 119 // modulus strips leading zeroes and nat does not. 120 m := modulusFromBytes([]byte{ 121 0x3f, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 122 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff}) 123 xb := []byte{0x3f, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 124 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xfe} 125 natFromBytes(xb).ExpandFor(m) // must not panic for shrinking 126 NewNat().SetBytes(xb, m) 127 } 128 129 func TestSetBytes(t *testing.T) { 130 tests := []struct { 131 m, b []byte 132 fail bool 133 }{{ 134 m: []byte{0xff, 0xff}, 135 b: []byte{0x00, 0x01}, 136 }, { 137 m: []byte{0xff, 0xff}, 138 b: []byte{0xff, 0xff}, 139 fail: true, 140 }, { 141 m: []byte{0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff}, 142 b: []byte{0x00, 0x01}, 143 }, { 144 m: []byte{0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff}, 145 b: []byte{0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xfe}, 146 }, { 147 m: []byte{0xff, 0xff}, 148 b: []byte{0xff, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00}, 149 fail: true, 150 }, { 151 m: []byte{0xff, 0xff}, 152 b: []byte{0xff, 0xff, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00}, 153 fail: true, 154 }, { 155 m: []byte{0x7f, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff}, 156 b: []byte{0x7f, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xfe}, 157 }, { 158 m: []byte{0x7f, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff}, 159 b: []byte{0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xfe}, 160 fail: true, 161 }, { 162 m: []byte{0x7f, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff}, 163 b: []byte{0x7f, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff}, 164 fail: true, 165 }, { 166 m: []byte{0x7f, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff}, 167 b: []byte{0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xfe}, 168 fail: true, 169 }, { 170 m: []byte{0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xfd}, 171 b: []byte{0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff}, 172 fail: true, 173 }} 174 175 for i, tt := range tests { 176 m := modulusFromBytes(tt.m) 177 got, err := NewNat().SetBytes(tt.b, m) 178 if err != nil { 179 if !tt.fail { 180 t.Errorf("%d: unexpected error: %v", i, err) 181 } 182 continue 183 } 184 if tt.fail { 185 t.Errorf("%d: unexpected success", i) 186 continue 187 } 188 if expected := natFromBytes(tt.b).ExpandFor(m); got.Equal(expected) != yes { 189 t.Errorf("%d: got %x, expected %x", i, got, expected) 190 } 191 } 192 193 f := func(xBytes []byte) bool { 194 m := maxModulus(uint(len(xBytes)*8/_W + 1)) 195 got, err := NewNat().SetBytes(xBytes, m) 196 if err != nil { 197 return false 198 } 199 return got.Equal(natFromBytes(xBytes).ExpandFor(m)) == yes 200 } 201 202 err := quick.Check(f, &quick.Config{}) 203 if err != nil { 204 t.Error(err) 205 } 206 } 207 208 func TestExpand(t *testing.T) { 209 sliced := []uint{1, 2, 3, 4} 210 examples := []struct { 211 in []uint 212 n int 213 out []uint 214 }{{ 215 []uint{1, 2}, 216 4, 217 []uint{1, 2, 0, 0}, 218 }, { 219 sliced[:2], 220 4, 221 []uint{1, 2, 0, 0}, 222 }, { 223 []uint{1, 2}, 224 2, 225 []uint{1, 2}, 226 }} 227 228 for i, tt := range examples { 229 got := (&Nat{tt.in}).expand(tt.n) 230 if len(got.limbs) != len(tt.out) || got.Equal(&Nat{tt.out}) != 1 { 231 t.Errorf("%d: got %x, expected %x", i, got, tt.out) 232 } 233 } 234 } 235 236 func TestMod(t *testing.T) { 237 m := modulusFromBytes([]byte{0x06, 0x80, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x0d}) 238 x := natFromBytes([]byte{0x40, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x80, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01}) 239 out := new(Nat) 240 out.Mod(x, m) 241 expected := natFromBytes([]byte{0x04, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x09}) 242 if out.Equal(expected) != 1 { 243 t.Errorf("%+v != %+v", out, expected) 244 } 245 } 246 247 func TestModSub(t *testing.T) { 248 m := modulusFromBytes([]byte{13}) 249 x := &Nat{[]uint{6}} 250 y := &Nat{[]uint{7}} 251 x.Sub(y, m) 252 expected := &Nat{[]uint{12}} 253 if x.Equal(expected) != 1 { 254 t.Errorf("%+v != %+v", x, expected) 255 } 256 x.Sub(y, m) 257 expected = &Nat{[]uint{5}} 258 if x.Equal(expected) != 1 { 259 t.Errorf("%+v != %+v", x, expected) 260 } 261 } 262 263 func TestModAdd(t *testing.T) { 264 m := modulusFromBytes([]byte{13}) 265 x := &Nat{[]uint{6}} 266 y := &Nat{[]uint{7}} 267 x.Add(y, m) 268 expected := &Nat{[]uint{0}} 269 if x.Equal(expected) != 1 { 270 t.Errorf("%+v != %+v", x, expected) 271 } 272 x.Add(y, m) 273 expected = &Nat{[]uint{7}} 274 if x.Equal(expected) != 1 { 275 t.Errorf("%+v != %+v", x, expected) 276 } 277 } 278 279 func TestExp(t *testing.T) { 280 m := modulusFromBytes([]byte{13}) 281 x := &Nat{[]uint{3}} 282 out := &Nat{[]uint{0}} 283 out.Exp(x, []byte{12}, m) 284 expected := &Nat{[]uint{1}} 285 if out.Equal(expected) != 1 { 286 t.Errorf("%+v != %+v", out, expected) 287 } 288 } 289 290 func natBytes(n *Nat) []byte { 291 return n.Bytes(maxModulus(uint(len(n.limbs)))) 292 } 293 294 func natFromBytes(b []byte) *Nat { 295 bb := new(big.Int).SetBytes(b) 296 return NewNat().setBig(bb) 297 } 298 299 func modulusFromBytes(b []byte) *Modulus { 300 bb := new(big.Int).SetBytes(b) 301 return NewModulusFromBig(bb) 302 } 303 304 // maxModulus returns the biggest modulus that can fit in n limbs. 305 func maxModulus(n uint) *Modulus { 306 m := big.NewInt(1) 307 m.Lsh(m, n*_W) 308 m.Sub(m, big.NewInt(1)) 309 return NewModulusFromBig(m) 310 } 311 312 func makeBenchmarkModulus() *Modulus { 313 return maxModulus(32) 314 } 315 316 func makeBenchmarkValue() *Nat { 317 x := make([]uint, 32) 318 for i := 0; i < 32; i++ { 319 x[i] = _MASK - 1 320 } 321 return &Nat{limbs: x} 322 } 323 324 func makeBenchmarkExponent() []byte { 325 e := make([]byte, 256) 326 for i := 0; i < 32; i++ { 327 e[i] = 0xFF 328 } 329 return e 330 } 331 332 func BenchmarkModAdd(b *testing.B) { 333 x := makeBenchmarkValue() 334 y := makeBenchmarkValue() 335 m := makeBenchmarkModulus() 336 337 b.ResetTimer() 338 for i := 0; i < b.N; i++ { 339 x.Add(y, m) 340 } 341 } 342 343 func BenchmarkModSub(b *testing.B) { 344 x := makeBenchmarkValue() 345 y := makeBenchmarkValue() 346 m := makeBenchmarkModulus() 347 348 b.ResetTimer() 349 for i := 0; i < b.N; i++ { 350 x.Sub(y, m) 351 } 352 } 353 354 func BenchmarkMontgomeryRepr(b *testing.B) { 355 x := makeBenchmarkValue() 356 m := makeBenchmarkModulus() 357 358 b.ResetTimer() 359 for i := 0; i < b.N; i++ { 360 x.montgomeryRepresentation(m) 361 } 362 } 363 364 func BenchmarkMontgomeryMul(b *testing.B) { 365 x := makeBenchmarkValue() 366 y := makeBenchmarkValue() 367 out := makeBenchmarkValue() 368 m := makeBenchmarkModulus() 369 370 b.ResetTimer() 371 for i := 0; i < b.N; i++ { 372 out.montgomeryMul(x, y, m) 373 } 374 } 375 376 func BenchmarkModMul(b *testing.B) { 377 x := makeBenchmarkValue() 378 y := makeBenchmarkValue() 379 m := makeBenchmarkModulus() 380 381 b.ResetTimer() 382 for i := 0; i < b.N; i++ { 383 x.Mul(y, m) 384 } 385 } 386 387 func BenchmarkExpBig(b *testing.B) { 388 out := new(big.Int) 389 exponentBytes := makeBenchmarkExponent() 390 x := new(big.Int).SetBytes(exponentBytes) 391 e := new(big.Int).SetBytes(exponentBytes) 392 n := new(big.Int).SetBytes(exponentBytes) 393 one := new(big.Int).SetUint64(1) 394 n.Add(n, one) 395 396 b.ResetTimer() 397 for i := 0; i < b.N; i++ { 398 out.Exp(x, e, n) 399 } 400 } 401 402 func BenchmarkExp(b *testing.B) { 403 x := makeBenchmarkValue() 404 e := makeBenchmarkExponent() 405 out := makeBenchmarkValue() 406 m := makeBenchmarkModulus() 407 408 b.ResetTimer() 409 for i := 0; i < b.N; i++ { 410 out.Exp(x, e, m) 411 } 412 }