github.com/remyoudompheng/bigfft@v0.0.0-20230129092748-24d4a6f8daec/calibrate_test.go (about)

     1  // Usage: go test -run=TestCalibrate -calibrate
     2  
     3  package bigfft
     4  
     5  import (
     6  	"flag"
     7  	"fmt"
     8  	"testing"
     9  	"time"
    10  )
    11  
    12  var calibrate = flag.Bool("calibrate", false, "run calibration test")
    13  
    14  // measureMul benchmarks math/big versus FFT for a given input size
    15  // (in bits).
    16  func measureMul(th int) (tBig, tFFT time.Duration) {
    17  	bigLoad := func(b *testing.B) { benchmarkMulBig(b, th, th) }
    18  	fftLoad := func(b *testing.B) { benchmarkMulFFT(b, th, th) }
    19  
    20  	res1 := testing.Benchmark(bigLoad)
    21  	res2 := testing.Benchmark(fftLoad)
    22  	tBig = time.Duration(res1.NsPerOp())
    23  	tFFT = time.Duration(res2.NsPerOp())
    24  	return
    25  }
    26  
    27  func roundDur(d time.Duration) time.Duration {
    28  	if d > 100*time.Millisecond {
    29  		return d / time.Millisecond * time.Millisecond
    30  	} else {
    31  		return d / time.Microsecond * time.Microsecond
    32  	}
    33  }
    34  
    35  func TestCalibrateThreshold(t *testing.T) {
    36  	if !*calibrate {
    37  		t.Log("not calibrating, use -calibrate to do so.")
    38  		return
    39  	}
    40  
    41  	lower := int(1e3)   // math/big is faster at this size.
    42  	upper := int(300e3) // FFT is faster at this size.
    43  
    44  	var sizes [9]int
    45  	var speedups [9]float64
    46  	for i := 0; i < 3; i++ {
    47  		for idx := 1; idx <= 9; idx++ {
    48  			sz := ((10-idx)*lower + idx*upper) / 10
    49  			big, fft := measureMul(sz)
    50  			spd := float64(big) / float64(fft)
    51  			sizes[idx-1] = sz
    52  			speedups[idx-1] = spd
    53  			fmt.Printf("speedup of FFT over math/big at size %d bits: %.2f (%s vs %s)\n",
    54  				sz, spd, roundDur(big), roundDur(fft))
    55  		}
    56  		narrow := false
    57  		for idx, s := range speedups {
    58  			if s < .98 {
    59  				lower = sizes[idx]
    60  				narrow = true
    61  			} else {
    62  				break
    63  			}
    64  		}
    65  		for idx := range speedups {
    66  			if speedups[8-idx] > 1.02 {
    67  				upper = sizes[8-idx]
    68  				narrow = true
    69  			} else {
    70  				break
    71  			}
    72  		}
    73  		if lower >= upper {
    74  			panic("impossible")
    75  		}
    76  		if !narrow || (upper-lower) <= 10 {
    77  			break
    78  		}
    79  	}
    80  	fmt.Printf("sizes: %d\n", sizes)
    81  	fmt.Printf("speedups: %.2f\n", speedups)
    82  }
    83  
    84  func measureFFTSize(w int, k uint) time.Duration {
    85  	load := func(b *testing.B) {
    86  		x := rndNat(w)
    87  		y := rndNat(w)
    88  		for i := 0; i < b.N; i++ {
    89  			m := (w+w)>>k + 1
    90  			xp := polyFromNat(x, k, m)
    91  			yp := polyFromNat(y, k, m)
    92  			rp := xp.Mul(&yp)
    93  			_ = rp.Int()
    94  		}
    95  	}
    96  	res := testing.Benchmark(load)
    97  	return time.Duration(res.NsPerOp())
    98  }
    99  
   100  func TestCalibrateFFT(t *testing.T) {
   101  	if !*calibrate {
   102  		t.Log("not calibrating, use -calibrate to do so.")
   103  		return
   104  	}
   105  
   106  	lows := [...]int{10, 10, 10, 10,
   107  		20, 50, 100, 200, 500, // 8
   108  		1000, 2000, 5000, 10000, // 12
   109  		20000, 50000, 100e3, 200e3, // 16
   110  	}
   111  	his := [...]int{100, 100, 100, 200,
   112  		500, 1000, 2000, 5000, 10000, // 8
   113  		50e3, 100e3, 200e3, 800e3, // 12
   114  		2e6, 5e6, 10e6, 20e6, // 16
   115  	}
   116  	for k := uint(3); k <= 16; k++ {
   117  		// Measure the speedup between k and k+1
   118  		low := lows[k] // FFT of size 1<<k known to be faster
   119  		hi := his[k]   // FFT of size 2<<k known to be faster
   120  		var sizes [9]int
   121  		var speedups [9]float64
   122  		for i := 0; i < 3; i++ {
   123  			for idx := 1; idx <= 9; idx++ {
   124  				sz := ((10-idx)*low + idx*hi) / 10
   125  				t1, t2 := measureFFTSize(sz, k), measureFFTSize(sz, k+1)
   126  				spd := float64(t1) / float64(t2)
   127  				sizes[idx-1] = sz
   128  				speedups[idx-1] = spd
   129  				fmt.Printf("speedup of %d vs %d at size %d words: %.2f (%s vs %s)\n",
   130  					k+1, k, sz, spd, roundDur(t1), roundDur(t2))
   131  			}
   132  			narrow := false
   133  			for idx, s := range speedups {
   134  				if s < .98 {
   135  					low = sizes[idx]
   136  					narrow = true
   137  				} else {
   138  					break
   139  				}
   140  			}
   141  			for idx := range speedups {
   142  				if speedups[8-idx] > 1.02 {
   143  					hi = sizes[8-idx]
   144  					narrow = true
   145  				} else {
   146  					break
   147  				}
   148  			}
   149  			if low >= hi {
   150  				panic("impossible")
   151  			}
   152  			if !narrow || (hi-low) <= 10 {
   153  				break
   154  			}
   155  		}
   156  		fmt.Printf("sizes: %d\n", sizes)
   157  		fmt.Printf("speedups: %.2f\n", speedups)
   158  	}
   159  }