github.com/remyoudompheng/bigfft@v0.0.0-20230129092748-24d4a6f8daec/fermat_test.go (about)

     1  package bigfft
     2  
     3  import (
     4  	"fmt"
     5  	"math/big"
     6  	"testing"
     7  )
     8  
     9  type (
    10  	Int = big.Int
    11  )
    12  
    13  // parseHex reads an hex-formatted number modulo 2^bits+1.
    14  func parseHex(s string, bits int) fermat {
    15  	z := new(Int)
    16  	z, ok := z.SetString(s, 0)
    17  	if !ok {
    18  		panic(s)
    19  	}
    20  	f := fermat(z.Bits())
    21  	for len(f)*_W <= bits {
    22  		f = append(f, 0)
    23  	}
    24  	return f
    25  }
    26  
    27  func compare(t *testing.T, prefix string, a, b fermat) {
    28  	var x, y Int
    29  	x.SetBits(a)
    30  	y.SetBits(b)
    31  	if x.Cmp(&y) != 0 {
    32  		t.Errorf("%s: %x != %x", prefix, &x, &y)
    33  	}
    34  }
    35  
    36  func TestFermatShift(t *testing.T) {
    37  	const n = 4
    38  	f := make(fermat, n+1)
    39  	for i := 0; i < n; i++ {
    40  		f[i] = Word(rnd.Int63())
    41  	}
    42  	b := big.NewInt(1)
    43  	b = b.Lsh(b, uint(n*_W))
    44  	b = b.Add(b, big.NewInt(1))
    45  	z := make(fermat, len(f)) // Test with uninitialized z.
    46  	for shift := -2048; shift < 2048; shift++ {
    47  		z.Shift(f, shift)
    48  
    49  		z2 := new(Int)
    50  		z2.SetBits(f)
    51  		if shift < 0 {
    52  			s2 := (-shift) % (2 * n * _W)
    53  			z2 = z2.Lsh(z2, uint(2*n*_W-s2))
    54  		} else {
    55  			z2 = z2.Lsh(z2, uint(shift))
    56  		}
    57  		z2 = z2.Mod(z2, b)
    58  		compare(t, fmt.Sprintf("shift %d", shift), z, z2.Bits())
    59  	}
    60  }
    61  
    62  func TestFermatShiftHalf(t *testing.T) {
    63  	const n = 3
    64  	f := make(fermat, n+1)
    65  	for i := 0; i < n; i++ {
    66  		f[i] = ^Word(0)
    67  	}
    68  	b := big.NewInt(1)
    69  	b = b.Lsh(b, uint(n*_W))
    70  	b = b.Add(b, big.NewInt(1))
    71  	z := make(fermat, len(f)) // Test with uninitialized z.
    72  	tmp := make(fermat, len(f))
    73  	tmp2 := make(fermat, len(f))
    74  	for shift := 0; shift < 16384; shift++ {
    75  		// Shift twice by shift/2
    76  		z.ShiftHalf(f, shift, tmp)
    77  		copy(tmp, z)
    78  		z.ShiftHalf(tmp, shift, tmp2)
    79  
    80  		z2 := new(Int)
    81  		z2 = z2.Lsh(new(Int).SetBits(f), uint(shift))
    82  		z2 = z2.Mod(z2, b)
    83  		compare(t, fmt.Sprintf("shift %d", shift), z, z2.Bits())
    84  	}
    85  }
    86  
    87  type test struct{ a, b, c fermat }
    88  
    89  // addTests is a series of mod 2^256+1 tests.
    90  var addTests = []test{
    91  	{
    92  		parseHex("0x5555555555555555555555555555555555555555555555555555555555555555", 256),
    93  		parseHex("0xaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaab", 256),
    94  		parseHex("0x10000000000000000000000000000000000000000000000000000000000000000", 256),
    95  	},
    96  	{
    97  		parseHex("0x5555555555555555555555555555555555555555555555555555555555555555", 256),
    98  		parseHex("0xaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa", 256),
    99  		parseHex("0xffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff", 256),
   100  	},
   101  	{
   102  		parseHex("0xaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa", 256),
   103  		parseHex("0xaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa", 256),
   104  		parseHex("0x5555555555555555555555555555555555555555555555555555555555555553", 256),
   105  	},
   106  }
   107  
   108  func TestFermatAdd(t *testing.T) {
   109  	for i, item := range addTests {
   110  		z := make(fermat, len(item.a))
   111  		z = z.Add(item.a, item.b)
   112  		compare(t, fmt.Sprintf("addTests[%d]", i), z, item.c)
   113  	}
   114  }
   115  
   116  var mulTests = []test{
   117  	{ // 3^400 = 3^200 * 3^200
   118  		parseHex("0xc21a937a76f3432ffd73d97e447606b683ecf6f6e4a7ae223c2578e26c486a03", 256),
   119  		parseHex("0xc21a937a76f3432ffd73d97e447606b683ecf6f6e4a7ae223c2578e26c486a03", 256),
   120  		parseHex("0x0e65f4d3508036eaca8faa2b8194ace009c863e44bdc040c459a7127bf8bcc62", 256),
   121  	},
   122  	{ // 2^256 * 2^256 mod (2^256+1) = 1.
   123  		parseHex("0x10000000000000000000000000000000000000000000000000000000000000000", 256),
   124  		parseHex("0x10000000000000000000000000000000000000000000000000000000000000000", 256),
   125  		parseHex("0x1", 256),
   126  	},
   127  	{ // (2^256-1) * (2^256-1) mod (2^256+1) = 4.
   128  		parseHex("0xffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff", 256),
   129  		parseHex("0xffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff", 256),
   130  		parseHex("0x4", 256),
   131  	},
   132  	{ // 1<<(64W) * 1<<(64W) mod (1<<64W+1) = 1
   133  		fermat{64: 1},
   134  		fermat{64: 1},
   135  		fermat{0: 1},
   136  	},
   137  	{
   138  		// Test case from issue 1. One of the squares of the Fourier
   139  		// transforms was miscomputed.
   140  		// The input number is made of 18 words, but we are working modulo 2^1280+1
   141  		parseHex("0xfffffffffffffffffffffffeffffffffffffffffffffffffffffffffffff00000000000000000000000100000000000000000000000000000000000000000000000000000000fffeffffffffffffffffffffffffffffffffffffffffffffffffffffffff000100000000000000000000000100000000000000000000000000000000fffefffffffffffffffffffd", 1280),
   142  		parseHex("0xfffffffffffffffffffffffeffffffffffffffffffffffffffffffffffff00000000000000000000000100000000000000000000000000000000000000000000000000000000fffeffffffffffffffffffffffffffffffffffffffffffffffffffffffff000100000000000000000000000100000000000000000000000000000000fffefffffffffffffffffffd", 1280),
   143  		parseHex("0xfffe00000003fffc0000000000000000fff40003000000000000000000060001fffffffd0001fffffffffffffffe000dfffbfffffffffffffffffffafffe0000000200000000000000000002fff60002fffffffffffffffa00060001ffffffff0000000000000000fffc0007fffe0000000000000007fff8fffdfffffffffffffffffffa00000004fffa0000fffffffffffffff600080000000000000000000a", 1280),
   144  	},
   145  }
   146  
   147  func TestFermatMul(t *testing.T) {
   148  	for i, item := range mulTests {
   149  		z := make(fermat, 3*len(item.a))
   150  		z = z.Mul(item.a, item.b)
   151  		compare(t, fmt.Sprintf("mulTests[%d]", i), z, item.c)
   152  	}
   153  }