github.com/twelsh-aw/go/src@v0.0.0-20230516233729-a56fe86a7c81/crypto/internal/bigmod/nat_test.go (about)

     1  // Copyright 2021 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package bigmod
     6  
     7  import (
     8  	"math/big"
     9  	"math/bits"
    10  	"math/rand"
    11  	"reflect"
    12  	"testing"
    13  	"testing/quick"
    14  )
    15  
    16  // Generate generates an even nat. It's used by testing/quick to produce random
    17  // *nat values for quick.Check invocations.
    18  func (*Nat) Generate(r *rand.Rand, size int) reflect.Value {
    19  	limbs := make([]uint, size)
    20  	for i := 0; i < size; i++ {
    21  		limbs[i] = uint(r.Uint64()) & ((1 << _W) - 2)
    22  	}
    23  	return reflect.ValueOf(&Nat{limbs})
    24  }
    25  
    26  func testModAddCommutative(a *Nat, b *Nat) bool {
    27  	m := maxModulus(uint(len(a.limbs)))
    28  	aPlusB := new(Nat).set(a)
    29  	aPlusB.Add(b, m)
    30  	bPlusA := new(Nat).set(b)
    31  	bPlusA.Add(a, m)
    32  	return aPlusB.Equal(bPlusA) == 1
    33  }
    34  
    35  func TestModAddCommutative(t *testing.T) {
    36  	err := quick.Check(testModAddCommutative, &quick.Config{})
    37  	if err != nil {
    38  		t.Error(err)
    39  	}
    40  }
    41  
    42  func testModSubThenAddIdentity(a *Nat, b *Nat) bool {
    43  	m := maxModulus(uint(len(a.limbs)))
    44  	original := new(Nat).set(a)
    45  	a.Sub(b, m)
    46  	a.Add(b, m)
    47  	return a.Equal(original) == 1
    48  }
    49  
    50  func TestModSubThenAddIdentity(t *testing.T) {
    51  	err := quick.Check(testModSubThenAddIdentity, &quick.Config{})
    52  	if err != nil {
    53  		t.Error(err)
    54  	}
    55  }
    56  
    57  func testMontgomeryRoundtrip(a *Nat) bool {
    58  	one := &Nat{make([]uint, len(a.limbs))}
    59  	one.limbs[0] = 1
    60  	aPlusOne := new(big.Int).SetBytes(natBytes(a))
    61  	aPlusOne.Add(aPlusOne, big.NewInt(1))
    62  	m := NewModulusFromBig(aPlusOne)
    63  	monty := new(Nat).set(a)
    64  	monty.montgomeryRepresentation(m)
    65  	aAgain := new(Nat).set(monty)
    66  	aAgain.montgomeryMul(monty, one, m)
    67  	return a.Equal(aAgain) == 1
    68  }
    69  
    70  func TestMontgomeryRoundtrip(t *testing.T) {
    71  	err := quick.Check(testMontgomeryRoundtrip, &quick.Config{})
    72  	if err != nil {
    73  		t.Error(err)
    74  	}
    75  }
    76  
    77  func TestShiftIn(t *testing.T) {
    78  	if bits.UintSize != 64 {
    79  		t.Skip("examples are only valid in 64 bit")
    80  	}
    81  	examples := []struct {
    82  		m, x, expected []byte
    83  		y              uint64
    84  	}{{
    85  		m:        []byte{13},
    86  		x:        []byte{0},
    87  		y:        0x7FFF_FFFF_FFFF_FFFF,
    88  		expected: []byte{7},
    89  	}, {
    90  		m:        []byte{13},
    91  		x:        []byte{7},
    92  		y:        0x7FFF_FFFF_FFFF_FFFF,
    93  		expected: []byte{11},
    94  	}, {
    95  		m:        []byte{0x06, 0x80, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x0d},
    96  		x:        make([]byte, 9),
    97  		y:        0x7FFF_FFFF_FFFF_FFFF,
    98  		expected: []byte{0x00, 0x7f, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff},
    99  	}, {
   100  		m:        []byte{0x06, 0x80, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x0d},
   101  		x:        []byte{0x00, 0x7f, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff},
   102  		y:        0,
   103  		expected: []byte{0x03, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x08},
   104  	}}
   105  
   106  	for i, tt := range examples {
   107  		m := modulusFromBytes(tt.m)
   108  		got := natFromBytes(tt.x).ExpandFor(m).shiftIn(uint(tt.y), m)
   109  		if got.Equal(natFromBytes(tt.expected).ExpandFor(m)) != 1 {
   110  			t.Errorf("%d: got %x, expected %x", i, got, tt.expected)
   111  		}
   112  	}
   113  }
   114  
   115  func TestModulusAndNatSizes(t *testing.T) {
   116  	// These are 126 bit (2 * _W on 64-bit architectures) values, serialized as
   117  	// 128 bits worth of bytes. If leading zeroes are stripped, they fit in two
   118  	// limbs, if they are not, they fit in three. This can be a problem because
   119  	// modulus strips leading zeroes and nat does not.
   120  	m := modulusFromBytes([]byte{
   121  		0x3f, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
   122  		0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff})
   123  	xb := []byte{0x3f, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
   124  		0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xfe}
   125  	natFromBytes(xb).ExpandFor(m) // must not panic for shrinking
   126  	NewNat().SetBytes(xb, m)
   127  }
   128  
   129  func TestSetBytes(t *testing.T) {
   130  	tests := []struct {
   131  		m, b []byte
   132  		fail bool
   133  	}{{
   134  		m: []byte{0xff, 0xff},
   135  		b: []byte{0x00, 0x01},
   136  	}, {
   137  		m:    []byte{0xff, 0xff},
   138  		b:    []byte{0xff, 0xff},
   139  		fail: true,
   140  	}, {
   141  		m: []byte{0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff},
   142  		b: []byte{0x00, 0x01},
   143  	}, {
   144  		m: []byte{0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff},
   145  		b: []byte{0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xfe},
   146  	}, {
   147  		m:    []byte{0xff, 0xff},
   148  		b:    []byte{0xff, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00},
   149  		fail: true,
   150  	}, {
   151  		m:    []byte{0xff, 0xff},
   152  		b:    []byte{0xff, 0xff, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00},
   153  		fail: true,
   154  	}, {
   155  		m: []byte{0x7f, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff},
   156  		b: []byte{0x7f, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xfe},
   157  	}, {
   158  		m:    []byte{0x7f, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff},
   159  		b:    []byte{0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xfe},
   160  		fail: true,
   161  	}, {
   162  		m:    []byte{0x7f, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff},
   163  		b:    []byte{0x7f, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff},
   164  		fail: true,
   165  	}, {
   166  		m:    []byte{0x7f, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff},
   167  		b:    []byte{0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xfe},
   168  		fail: true,
   169  	}, {
   170  		m:    []byte{0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xfd},
   171  		b:    []byte{0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff},
   172  		fail: true,
   173  	}}
   174  
   175  	for i, tt := range tests {
   176  		m := modulusFromBytes(tt.m)
   177  		got, err := NewNat().SetBytes(tt.b, m)
   178  		if err != nil {
   179  			if !tt.fail {
   180  				t.Errorf("%d: unexpected error: %v", i, err)
   181  			}
   182  			continue
   183  		}
   184  		if tt.fail {
   185  			t.Errorf("%d: unexpected success", i)
   186  			continue
   187  		}
   188  		if expected := natFromBytes(tt.b).ExpandFor(m); got.Equal(expected) != yes {
   189  			t.Errorf("%d: got %x, expected %x", i, got, expected)
   190  		}
   191  	}
   192  
   193  	f := func(xBytes []byte) bool {
   194  		m := maxModulus(uint(len(xBytes)*8/_W + 1))
   195  		got, err := NewNat().SetBytes(xBytes, m)
   196  		if err != nil {
   197  			return false
   198  		}
   199  		return got.Equal(natFromBytes(xBytes).ExpandFor(m)) == yes
   200  	}
   201  
   202  	err := quick.Check(f, &quick.Config{})
   203  	if err != nil {
   204  		t.Error(err)
   205  	}
   206  }
   207  
   208  func TestExpand(t *testing.T) {
   209  	sliced := []uint{1, 2, 3, 4}
   210  	examples := []struct {
   211  		in  []uint
   212  		n   int
   213  		out []uint
   214  	}{{
   215  		[]uint{1, 2},
   216  		4,
   217  		[]uint{1, 2, 0, 0},
   218  	}, {
   219  		sliced[:2],
   220  		4,
   221  		[]uint{1, 2, 0, 0},
   222  	}, {
   223  		[]uint{1, 2},
   224  		2,
   225  		[]uint{1, 2},
   226  	}}
   227  
   228  	for i, tt := range examples {
   229  		got := (&Nat{tt.in}).expand(tt.n)
   230  		if len(got.limbs) != len(tt.out) || got.Equal(&Nat{tt.out}) != 1 {
   231  			t.Errorf("%d: got %x, expected %x", i, got, tt.out)
   232  		}
   233  	}
   234  }
   235  
   236  func TestMod(t *testing.T) {
   237  	m := modulusFromBytes([]byte{0x06, 0x80, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x0d})
   238  	x := natFromBytes([]byte{0x40, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x80, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01})
   239  	out := new(Nat)
   240  	out.Mod(x, m)
   241  	expected := natFromBytes([]byte{0x04, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x09})
   242  	if out.Equal(expected) != 1 {
   243  		t.Errorf("%+v != %+v", out, expected)
   244  	}
   245  }
   246  
   247  func TestModSub(t *testing.T) {
   248  	m := modulusFromBytes([]byte{13})
   249  	x := &Nat{[]uint{6}}
   250  	y := &Nat{[]uint{7}}
   251  	x.Sub(y, m)
   252  	expected := &Nat{[]uint{12}}
   253  	if x.Equal(expected) != 1 {
   254  		t.Errorf("%+v != %+v", x, expected)
   255  	}
   256  	x.Sub(y, m)
   257  	expected = &Nat{[]uint{5}}
   258  	if x.Equal(expected) != 1 {
   259  		t.Errorf("%+v != %+v", x, expected)
   260  	}
   261  }
   262  
   263  func TestModAdd(t *testing.T) {
   264  	m := modulusFromBytes([]byte{13})
   265  	x := &Nat{[]uint{6}}
   266  	y := &Nat{[]uint{7}}
   267  	x.Add(y, m)
   268  	expected := &Nat{[]uint{0}}
   269  	if x.Equal(expected) != 1 {
   270  		t.Errorf("%+v != %+v", x, expected)
   271  	}
   272  	x.Add(y, m)
   273  	expected = &Nat{[]uint{7}}
   274  	if x.Equal(expected) != 1 {
   275  		t.Errorf("%+v != %+v", x, expected)
   276  	}
   277  }
   278  
   279  func TestExp(t *testing.T) {
   280  	m := modulusFromBytes([]byte{13})
   281  	x := &Nat{[]uint{3}}
   282  	out := &Nat{[]uint{0}}
   283  	out.Exp(x, []byte{12}, m)
   284  	expected := &Nat{[]uint{1}}
   285  	if out.Equal(expected) != 1 {
   286  		t.Errorf("%+v != %+v", out, expected)
   287  	}
   288  }
   289  
   290  func natBytes(n *Nat) []byte {
   291  	return n.Bytes(maxModulus(uint(len(n.limbs))))
   292  }
   293  
   294  func natFromBytes(b []byte) *Nat {
   295  	bb := new(big.Int).SetBytes(b)
   296  	return NewNat().setBig(bb)
   297  }
   298  
   299  func modulusFromBytes(b []byte) *Modulus {
   300  	bb := new(big.Int).SetBytes(b)
   301  	return NewModulusFromBig(bb)
   302  }
   303  
   304  // maxModulus returns the biggest modulus that can fit in n limbs.
   305  func maxModulus(n uint) *Modulus {
   306  	m := big.NewInt(1)
   307  	m.Lsh(m, n*_W)
   308  	m.Sub(m, big.NewInt(1))
   309  	return NewModulusFromBig(m)
   310  }
   311  
   312  func makeBenchmarkModulus() *Modulus {
   313  	return maxModulus(32)
   314  }
   315  
   316  func makeBenchmarkValue() *Nat {
   317  	x := make([]uint, 32)
   318  	for i := 0; i < 32; i++ {
   319  		x[i] = _MASK - 1
   320  	}
   321  	return &Nat{limbs: x}
   322  }
   323  
   324  func makeBenchmarkExponent() []byte {
   325  	e := make([]byte, 256)
   326  	for i := 0; i < 32; i++ {
   327  		e[i] = 0xFF
   328  	}
   329  	return e
   330  }
   331  
   332  func BenchmarkModAdd(b *testing.B) {
   333  	x := makeBenchmarkValue()
   334  	y := makeBenchmarkValue()
   335  	m := makeBenchmarkModulus()
   336  
   337  	b.ResetTimer()
   338  	for i := 0; i < b.N; i++ {
   339  		x.Add(y, m)
   340  	}
   341  }
   342  
   343  func BenchmarkModSub(b *testing.B) {
   344  	x := makeBenchmarkValue()
   345  	y := makeBenchmarkValue()
   346  	m := makeBenchmarkModulus()
   347  
   348  	b.ResetTimer()
   349  	for i := 0; i < b.N; i++ {
   350  		x.Sub(y, m)
   351  	}
   352  }
   353  
   354  func BenchmarkMontgomeryRepr(b *testing.B) {
   355  	x := makeBenchmarkValue()
   356  	m := makeBenchmarkModulus()
   357  
   358  	b.ResetTimer()
   359  	for i := 0; i < b.N; i++ {
   360  		x.montgomeryRepresentation(m)
   361  	}
   362  }
   363  
   364  func BenchmarkMontgomeryMul(b *testing.B) {
   365  	x := makeBenchmarkValue()
   366  	y := makeBenchmarkValue()
   367  	out := makeBenchmarkValue()
   368  	m := makeBenchmarkModulus()
   369  
   370  	b.ResetTimer()
   371  	for i := 0; i < b.N; i++ {
   372  		out.montgomeryMul(x, y, m)
   373  	}
   374  }
   375  
   376  func BenchmarkModMul(b *testing.B) {
   377  	x := makeBenchmarkValue()
   378  	y := makeBenchmarkValue()
   379  	m := makeBenchmarkModulus()
   380  
   381  	b.ResetTimer()
   382  	for i := 0; i < b.N; i++ {
   383  		x.Mul(y, m)
   384  	}
   385  }
   386  
   387  func BenchmarkExpBig(b *testing.B) {
   388  	out := new(big.Int)
   389  	exponentBytes := makeBenchmarkExponent()
   390  	x := new(big.Int).SetBytes(exponentBytes)
   391  	e := new(big.Int).SetBytes(exponentBytes)
   392  	n := new(big.Int).SetBytes(exponentBytes)
   393  	one := new(big.Int).SetUint64(1)
   394  	n.Add(n, one)
   395  
   396  	b.ResetTimer()
   397  	for i := 0; i < b.N; i++ {
   398  		out.Exp(x, e, n)
   399  	}
   400  }
   401  
   402  func BenchmarkExp(b *testing.B) {
   403  	x := makeBenchmarkValue()
   404  	e := makeBenchmarkExponent()
   405  	out := makeBenchmarkValue()
   406  	m := makeBenchmarkModulus()
   407  
   408  	b.ResetTimer()
   409  	for i := 0; i < b.N; i++ {
   410  		out.Exp(x, e, m)
   411  	}
   412  }