github.com/remyoudompheng/bigfft@v0.0.0-20230129092748-24d4a6f8daec/fft_test.go (about)

     1  package bigfft
     2  
     3  import (
     4  	"fmt"
     5  	"math/big"
     6  	"math/rand"
     7  	"testing"
     8  )
     9  
    10  func cmpnat(t *testing.T, x, y nat) int {
    11  	var a, b Int
    12  	a.SetBits(x)
    13  	b.SetBits(y)
    14  	c := a.Cmp(&b)
    15  	if c != 0 {
    16  		t.Logf("a.len=%d, b.len=%d", a.BitLen(), b.BitLen())
    17  		for i := 0; i < len(x) || i < len(y); i++ {
    18  			var u, v Word
    19  			if i < len(x) {
    20  				u = x[i]
    21  			}
    22  			if i < len(y) {
    23  				v = y[i]
    24  			}
    25  			if diff := u ^ v; diff != 0 {
    26  				t.Logf("diff at word %d: %x", i, diff)
    27  			}
    28  		}
    29  	}
    30  	return c
    31  }
    32  
    33  func TestRoundTripIntPoly(t *testing.T) {
    34  	N := 4
    35  	step := 500
    36  	if testing.Short() {
    37  		N = 2
    38  	}
    39  	// Sizes 12800 and 34300 may cause problems.
    40  	for size := 300; size < 50000; size += step {
    41  		n := make(nat, size)
    42  		for i := 0; i < N; i++ {
    43  			for p := range n {
    44  				n[p] = Word(rand.Int63())
    45  			}
    46  			k, m := fftSize(n, nil)
    47  			pol := polyFromNat(n, k, m)
    48  			n2 := pol.Int()
    49  			if cmpnat(t, n, n2) != 0 {
    50  				t.Errorf("different n and n2, size=%d, iter=%d", size, i)
    51  			}
    52  		}
    53  	}
    54  }
    55  
    56  func TestFourierSizes(t *testing.T) {
    57  	sizes := []int{
    58  		2e3, 3e3, 5e3, 7e3, 10e3, 14e3,
    59  		2e4, 3e4, 5e4, 7e4, 10e4, 14e4,
    60  		2e5, 3e5, 5e5, 7e5, 10e5, 14e5,
    61  		2e6, 3e6, 5e6, 7e6, 10e6, 14e6,
    62  		2e7, 3e7, 5e7, 7e7, 10e7, 14e7,
    63  		2e8, 3e8, 5e8, 7e8, 10e8, 14e8,
    64  	}
    65  	for _, s := range sizes {
    66  		k, m := fftSize(make(nat, s/_W), make(nat, s/_W))
    67  		v := valueSize(k, m, 2)
    68  		t.Logf("bits=%d => FFT size %d, chunk size = %d, value size = %d",
    69  			s, 1<<k, m, v)
    70  		needed := 2*m*_W + int(k)
    71  		got := v * _W
    72  		t.Logf("inefficiency: value/chunk_product=%.2f, fftsize/inputsize=%.2f",
    73  			float64(got)/float64(needed), float64(v<<k)/float64(2*s/_W))
    74  		if v > 3*m {
    75  			t.Errorf("FFT word size %d >> input word size %d", v, m)
    76  		}
    77  	}
    78  }
    79  
    80  func testFourier(t *testing.T, N int, k uint) {
    81  	// Random coefficients
    82  	src := make([]fermat, 1<<k)
    83  	for i := range src {
    84  		src[i] = make(fermat, N+1)
    85  		for p := 0; p < N; p++ {
    86  			src[i][p] = Word(rnd.Int63())
    87  		}
    88  	}
    89  	cmpFourier(t, N, k, src, false)
    90  	cmpFourier(t, N, k, src, true)
    91  
    92  	// Saturated coefficients (b^N-1)
    93  	for i := range src {
    94  		for p := 0; p < N; p++ {
    95  			src[i][p] = ^Word(0)
    96  		}
    97  	}
    98  	cmpFourier(t, N, k, src, false)
    99  	cmpFourier(t, N, k, src, true)
   100  }
   101  
   102  // cmpFourier computes the Fourier transform of src
   103  // and compares it to the FFT result.
   104  func cmpFourier(t *testing.T, N int, k uint, src []fermat, inverse bool) {
   105  	t.Logf("testFourier(t, %d, %d, inverse=%v)", N, k, inverse)
   106  	ωshift := (4 * N * _W) >> k
   107  	if inverse {
   108  		ωshift = -ωshift
   109  	}
   110  	dst1 := make([]fermat, 1<<k)
   111  	dst2 := make([]fermat, 1<<k)
   112  	for i := range src {
   113  		dst1[i] = make(fermat, N+1)
   114  		dst2[i] = make(fermat, N+1)
   115  	}
   116  
   117  	// naive transform
   118  	tmp := make(fermat, N+1)
   119  	tmp2 := make(fermat, N+1)
   120  	for i := range src {
   121  		for j := range dst1 {
   122  			tmp.ShiftHalf(src[i], i*j*ωshift, tmp2)
   123  			dst1[j].Add(dst1[j], tmp)
   124  		}
   125  	}
   126  
   127  	// fast transform
   128  	fourier(dst2, src, inverse, N, k)
   129  
   130  	for i := range src {
   131  		if cmpnat(t, nat(dst1[i]), nat(dst2[i])) != 0 {
   132  			var x, y Int
   133  			x.SetBits(dst1[i])
   134  			y.SetBits(dst2[i])
   135  			t.Errorf("difference in dst[%d]: %x %x", i, &x, &y)
   136  		}
   137  	}
   138  }
   139  
   140  func TestFourier(t *testing.T) {
   141  	// 1-word transforms.
   142  	testFourier(t, 1, 2)
   143  	testFourier(t, 1, 3)
   144  	testFourier(t, 1, 4)
   145  
   146  	// 2-word transforms
   147  	testFourier(t, 2, 2)
   148  	testFourier(t, 2, 3)
   149  	testFourier(t, 2, 4)
   150  	testFourier(t, 2, 8)
   151  
   152  	testFourier(t, 4, 4)
   153  	testFourier(t, 4, 5)
   154  	testFourier(t, 4, 6)
   155  	testFourier(t, 4, 8)
   156  
   157  	// Test a few limit cases. This is when
   158  	// N*WordSize is a multiple of 1<<(k-2) but not 1<<(k-1)
   159  	if _W == 64 {
   160  		testFourier(t, 1, 8)
   161  		testFourier(t, 3, 8)
   162  		testFourier(t, 5, 8)
   163  		testFourier(t, 7, 8)
   164  		testFourier(t, 9, 8)
   165  		testFourier(t, 11, 8)
   166  	}
   167  }
   168  
   169  // Tests Fourier transform and its reciprocal.
   170  func TestRoundTripPolyValues(t *testing.T) {
   171  	Size := 100000
   172  	if testing.Short() {
   173  		Size = 50
   174  	}
   175  	// Build a polynomial from an integer.
   176  	n := make(nat, Size)
   177  	for p := range n {
   178  		n[p] = Word(rand.Int63())
   179  	}
   180  	k, m := fftSize(n, nil)
   181  	pol := polyFromNat(n, k, m)
   182  
   183  	// Transform it.
   184  	f := valueSize(k, m, 1)
   185  	values := pol.Transform(f)
   186  
   187  	// Inverse transform.
   188  	pol2 := values.InvTransform()
   189  	pol2.m = m
   190  
   191  	t.Logf("k=%d, m=%d", k, m)
   192  
   193  	// Evaluate and compare.
   194  	n2 := pol2.Int()
   195  	if cmpnat(t, n, n2) != 0 {
   196  		t.Errorf("different n and n2")
   197  	}
   198  }
   199  
// rnd is the pseudo-random generator shared by rndNat and testFourier.
// It is created from a fixed seed, so it produces the same sequence on
// every run.
var rnd = rand.New(rand.NewSource(0x43de683f473542af))
   201  
   202  func rndNat(n int) nat {
   203  	x := make(nat, n)
   204  	for i := 0; i < n; i++ {
   205  		x[i] = Word(rnd.Int63()<<1 + rnd.Int63n(2))
   206  	}
   207  	return x
   208  }
   209  
   210  func TestMul(t *testing.T) {
   211  	sizes := []int{1e3, 5e3, 15e3, 25e3, 70e3, 200e3, 500e3}
   212  	iters := 10
   213  	if testing.Short() {
   214  		iters = 1
   215  	}
   216  
   217  	var x, y Int
   218  	for i := 0; i < iters; i++ {
   219  		for _, size1 := range sizes {
   220  			for _, size2 := range sizes {
   221  				x.SetBits(rndNat(size1 / _W))
   222  				y.SetBits(rndNat(size2 / _W))
   223  				z := new(Int).Mul(&x, &y)
   224  				z2 := Mul(&x, &y)
   225  				if z.Cmp(z2) != 0 {
   226  					t.Errorf("z (%d bits) != z2 (%d bits)", z.BitLen(), z2.BitLen())
   227  					logbig(t, new(Int).Xor(z, z2))
   228  				}
   229  			}
   230  		}
   231  	}
   232  }
   233  
   234  func logbig(t *testing.T, n *Int) {
   235  	s := fmt.Sprintf("%x", n)
   236  	for len(s) > 64 {
   237  		t.Log(s[:64])
   238  		s = s[64:]
   239  	}
   240  	t.Log(s)
   241  }
   242  
   243  func benchmarkMulBig(b *testing.B, sizex, sizey int) {
   244  	mulx := rndNat(sizex / _W)
   245  	muly := rndNat(sizey / _W)
   246  	b.ResetTimer()
   247  	var x, y, z Int
   248  	x.SetBits(mulx)
   249  	y.SetBits(muly)
   250  	for i := 0; i < b.N; i++ {
   251  		z.Mul(&x, &y)
   252  	}
   253  }
   254  
   255  func benchmarkMulFFT(b *testing.B, sizex, sizey int) {
   256  	mulx := rndNat(sizex / _W)
   257  	muly := rndNat(sizey / _W)
   258  	b.ResetTimer()
   259  	var x, y Int
   260  	x.SetBits(mulx)
   261  	y.SetBits(muly)
   262  	for i := 0; i < b.N; i++ {
   263  		_ = mulFFT(&x, &y)
   264  	}
   265  }
   266  
// Balanced Int.Mul benchmarks: both operands have the same size. The two
// numeric arguments are the operand sizes passed to benchmarkMulBig,
// which divides them by _W to obtain word counts; the name suffix gives
// the same size in kilo- or megabits.
func BenchmarkMulBig_1kb(b *testing.B)   { benchmarkMulBig(b, 1e3, 1e3) }
func BenchmarkMulBig_10kb(b *testing.B)  { benchmarkMulBig(b, 1e4, 1e4) }
func BenchmarkMulBig_50kb(b *testing.B)  { benchmarkMulBig(b, 5e4, 5e4) }
func BenchmarkMulBig_100kb(b *testing.B) { benchmarkMulBig(b, 1e5, 1e5) }
func BenchmarkMulBig_200kb(b *testing.B) { benchmarkMulBig(b, 2e5, 2e5) }
func BenchmarkMulBig_500kb(b *testing.B) { benchmarkMulBig(b, 5e5, 5e5) }
func BenchmarkMulBig_1Mb(b *testing.B)   { benchmarkMulBig(b, 1e6, 1e6) }
func BenchmarkMulBig_2Mb(b *testing.B)   { benchmarkMulBig(b, 2e6, 2e6) }
func BenchmarkMulBig_5Mb(b *testing.B)   { benchmarkMulBig(b, 5e6, 5e6) }
func BenchmarkMulBig_10Mb(b *testing.B)  { benchmarkMulBig(b, 10e6, 10e6) }
func BenchmarkMulBig_20Mb(b *testing.B)  { benchmarkMulBig(b, 20e6, 20e6) }
func BenchmarkMulBig_50Mb(b *testing.B)  { benchmarkMulBig(b, 50e6, 50e6) }
func BenchmarkMulBig_100Mb(b *testing.B) { benchmarkMulBig(b, 100e6, 100e6) }
   280  
// Balanced mulFFT benchmarks: both operands have the same size. The two
// numeric arguments are the operand sizes passed to benchmarkMulFFT,
// which divides them by _W to obtain word counts; the name suffix gives
// the same size in kilo-, mega- or gigabits.
func BenchmarkMulFFT_1kb(b *testing.B)   { benchmarkMulFFT(b, 1e3, 1e3) }
func BenchmarkMulFFT_10kb(b *testing.B)  { benchmarkMulFFT(b, 1e4, 1e4) }
func BenchmarkMulFFT_50kb(b *testing.B)  { benchmarkMulFFT(b, 5e4, 5e4) }
func BenchmarkMulFFT_100kb(b *testing.B) { benchmarkMulFFT(b, 1e5, 1e5) }
func BenchmarkMulFFT_200kb(b *testing.B) { benchmarkMulFFT(b, 2e5, 2e5) }
func BenchmarkMulFFT_500kb(b *testing.B) { benchmarkMulFFT(b, 5e5, 5e5) }
func BenchmarkMulFFT_1Mb(b *testing.B)   { benchmarkMulFFT(b, 1e6, 1e6) }
func BenchmarkMulFFT_2Mb(b *testing.B)   { benchmarkMulFFT(b, 2e6, 2e6) }
func BenchmarkMulFFT_5Mb(b *testing.B)   { benchmarkMulFFT(b, 5e6, 5e6) }
func BenchmarkMulFFT_10Mb(b *testing.B)  { benchmarkMulFFT(b, 10e6, 10e6) }
func BenchmarkMulFFT_20Mb(b *testing.B)  { benchmarkMulFFT(b, 20e6, 20e6) }
func BenchmarkMulFFT_50Mb(b *testing.B)  { benchmarkMulFFT(b, 50e6, 50e6) }
func BenchmarkMulFFT_100Mb(b *testing.B) { benchmarkMulFFT(b, 100e6, 100e6) }
func BenchmarkMulFFT_200Mb(b *testing.B) { benchmarkMulFFT(b, 200e6, 200e6) }
func BenchmarkMulFFT_500Mb(b *testing.B) { benchmarkMulFFT(b, 500e6, 500e6) }
func BenchmarkMulFFT_1Gb(b *testing.B)   { benchmarkMulFFT(b, 1e9, 1e9) }
   297  
   298  func benchmarkMul(b *testing.B, sizex, sizey int) {
   299  	mulx := rndNat(sizex / _W)
   300  	muly := rndNat(sizey / _W)
   301  	b.ResetTimer()
   302  	for i := 0; i < b.N; i++ {
   303  		var x, y Int
   304  		x.SetBits(mulx)
   305  		y.SetBits(muly)
   306  		_ = Mul(&x, &y)
   307  	}
   308  }
   309  
// Balanced Mul benchmarks: both operands have the same size. The two
// numeric arguments are the operand sizes passed to benchmarkMul, which
// divides them by _W to obtain word counts; the name suffix gives the
// same size in kilo- or megabits.
func BenchmarkMul_50kb(b *testing.B)  { benchmarkMul(b, 5e4, 5e4) }
func BenchmarkMul_100kb(b *testing.B) { benchmarkMul(b, 1e5, 1e5) }
func BenchmarkMul_200kb(b *testing.B) { benchmarkMul(b, 2e5, 2e5) }
func BenchmarkMul_500kb(b *testing.B) { benchmarkMul(b, 5e5, 5e5) }
func BenchmarkMul_1Mb(b *testing.B)   { benchmarkMul(b, 1e6, 1e6) }
func BenchmarkMul_2Mb(b *testing.B)   { benchmarkMul(b, 2e6, 2e6) }
func BenchmarkMul_5Mb(b *testing.B)   { benchmarkMul(b, 5e6, 5e6) }
func BenchmarkMul_10Mb(b *testing.B)  { benchmarkMul(b, 10e6, 10e6) }
func BenchmarkMul_20Mb(b *testing.B)  { benchmarkMul(b, 20e6, 20e6) }
func BenchmarkMul_50Mb(b *testing.B)  { benchmarkMul(b, 50e6, 50e6) }
func BenchmarkMul_100Mb(b *testing.B) { benchmarkMul(b, 100e6, 100e6) }
   321  
// Unbalanced multiplication benchmarks: the two operands passed to
// benchmarkMul have different sizes. In a name such as 1x5Mb the first
// number is the size of the first operand and the second that of the
// second operand, both in megabits.
func BenchmarkMul_1x5Mb(b *testing.B)  { benchmarkMul(b, 1e6, 5e6) }
func BenchmarkMul_1x10Mb(b *testing.B) { benchmarkMul(b, 1e6, 10e6) }
func BenchmarkMul_1x20Mb(b *testing.B) { benchmarkMul(b, 1e6, 20e6) }
func BenchmarkMul_1x50Mb(b *testing.B) { benchmarkMul(b, 1e6, 50e6) }
func BenchmarkMul_5x20Mb(b *testing.B) { benchmarkMul(b, 5e6, 20e6) }
func BenchmarkMul_5x50Mb(b *testing.B) { benchmarkMul(b, 5e6, 50e6) }
   329  
// Unbalanced Int.Mul benchmarks: the two operands passed to
// benchmarkMulBig have different sizes. In a name such as 1x5Mb the first
// number is the size of the first operand and the second that of the
// second operand, both in megabits.
func BenchmarkMulBig_1x5Mb(b *testing.B)  { benchmarkMulBig(b, 1e6, 5e6) }
func BenchmarkMulBig_1x10Mb(b *testing.B) { benchmarkMulBig(b, 1e6, 10e6) }
func BenchmarkMulBig_1x20Mb(b *testing.B) { benchmarkMulBig(b, 1e6, 20e6) }
func BenchmarkMulBig_1x50Mb(b *testing.B) { benchmarkMulBig(b, 1e6, 50e6) }
func BenchmarkMulBig_5x20Mb(b *testing.B) { benchmarkMulBig(b, 5e6, 20e6) }
func BenchmarkMulBig_5x50Mb(b *testing.B) { benchmarkMulBig(b, 5e6, 50e6) }
   336  
// Unbalanced mulFFT benchmarks: the two operands passed to
// benchmarkMulFFT have different sizes. In a name such as 1x5Mb the first
// number is the size of the first operand and the second that of the
// second operand, both in megabits.
func BenchmarkMulFFT_1x5Mb(b *testing.B)  { benchmarkMulFFT(b, 1e6, 5e6) }
func BenchmarkMulFFT_1x10Mb(b *testing.B) { benchmarkMulFFT(b, 1e6, 10e6) }
func BenchmarkMulFFT_1x20Mb(b *testing.B) { benchmarkMulFFT(b, 1e6, 20e6) }
func BenchmarkMulFFT_1x50Mb(b *testing.B) { benchmarkMulFFT(b, 1e6, 50e6) }
func BenchmarkMulFFT_5x20Mb(b *testing.B) { benchmarkMulFFT(b, 5e6, 20e6) }
func BenchmarkMulFFT_5x50Mb(b *testing.B) { benchmarkMulFFT(b, 5e6, 50e6) }
   343  
   344  func TestIssue1(t *testing.T) {
   345  	e := big.NewInt(1)
   346  	e.SetBit(e, 132048, 1)
   347  	e.Sub(e, big.NewInt(4)) // e == 1<<132048 - 4
   348  	g := big.NewInt(0).Set(e)
   349  	e.Mul(e, e)
   350  	g = Mul(g, g)
   351  	if g.Cmp(e) != 0 {
   352  		t.Fatal("incorrect Mul result")
   353  	}
   354  }