gonum.org/v1/gonum@v0.15.1-0.20240517103525-f853624cb1bb/dsp/fourier/radix24_test.go (about)

     1  // Copyright ©2020 The Gonum Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package fourier
     6  
     7  import (
     8  	"bytes"
     9  	"fmt"
    10  	"math/bits"
    11  	"slices"
    12  	"strconv"
    13  	"testing"
    14  	"unsafe"
    15  
    16  	"golang.org/x/exp/rand"
    17  
    18  	"gonum.org/v1/gonum/cmplxs"
    19  )
    20  
    21  func TestCoefficients(t *testing.T) {
    22  	const tol = 1e-8
    23  
    24  	src := rand.NewSource(1)
    25  	for n := 4; n < 1<<20; n <<= 1 {
    26  		for i := 0; i < 10; i++ {
    27  			t.Run(fmt.Sprintf("Radix2/%d", n), func(t *testing.T) {
    28  				d := randComplexes(n, src)
    29  				fft := NewCmplxFFT(n)
    30  				want := fft.Coefficients(nil, d)
    31  				CoefficientsRadix2(d)
    32  				got := d
    33  				if !cmplxs.EqualApprox(got, want, tol) {
    34  					t.Errorf("unexpected result for n=%d |got-want|^2=%g", n, cmplxs.Distance(got, want, 2))
    35  				}
    36  
    37  				want = fft.Sequence(nil, got)
    38  				scale(1/float64(n), want)
    39  
    40  				SequenceRadix2(got)
    41  				scale(1/float64(n), got)
    42  
    43  				if !cmplxs.EqualApprox(got, want, tol) {
    44  					t.Errorf("unexpected ifft result for n=%d |got-want|^2=%g", n, cmplxs.Distance(got, want, 2))
    45  				}
    46  			})
    47  			if bits.Len(uint(n))&0x1 == 0 {
    48  				continue
    49  			}
    50  			t.Run(fmt.Sprintf("Radix4/%d", n), func(t *testing.T) {
    51  				d := randComplexes(n, src)
    52  				fft := NewCmplxFFT(n)
    53  				want := fft.Coefficients(nil, d)
    54  				CoefficientsRadix4(d)
    55  				got := d
    56  				if !cmplxs.EqualApprox(got, want, tol) {
    57  					t.Errorf("unexpected fft result for n=%d |got-want|^2=%g", n, cmplxs.Distance(got, want, 2))
    58  				}
    59  
    60  				want = fft.Sequence(nil, got)
    61  				scale(1/float64(n), want)
    62  
    63  				SequenceRadix4(got)
    64  				scale(1/float64(n), got)
    65  
    66  				if !cmplxs.EqualApprox(got, want, tol) {
    67  					t.Errorf("unexpected ifft result for n=%d |got-want|^2=%g", n, cmplxs.Distance(got, want, 2))
    68  				}
    69  			})
    70  		}
    71  	}
    72  }
    73  
    74  func TestSequence(t *testing.T) {
    75  	const tol = 1e-10
    76  
    77  	src := rand.NewSource(1)
    78  	for n := 4; n < 1<<20; n <<= 1 {
    79  		for i := 0; i < 10; i++ {
    80  			t.Run(fmt.Sprintf("Radix2/%d", n), func(t *testing.T) {
    81  				d := randComplexes(n, src)
    82  				want := make([]complex128, n)
    83  				copy(want, d)
    84  				SequenceRadix2(CoefficientsRadix2(d))
    85  				got := d
    86  
    87  				scale(1/float64(n), got)
    88  
    89  				if !cmplxs.EqualApprox(got, want, tol) {
    90  					t.Errorf("unexpected result for ifft(fft(d)) n=%d |got-want|^2=%g", n, cmplxs.Distance(got, want, 2))
    91  				}
    92  			})
    93  			if bits.Len(uint(n))&0x1 == 0 {
    94  				continue
    95  			}
    96  			t.Run(fmt.Sprintf("Radix4/%d", n), func(t *testing.T) {
    97  				d := randComplexes(n, src)
    98  				want := make([]complex128, n)
    99  				copy(want, d)
   100  				SequenceRadix4(CoefficientsRadix4(d))
   101  				got := d
   102  
   103  				scale(1/float64(n), got)
   104  
   105  				if !cmplxs.EqualApprox(got, want, tol) {
   106  					t.Errorf("unexpected result for ifft(fft(d)) n=%d |got-want|^2=%g", n, cmplxs.Distance(got, want, 2))
   107  				}
   108  			})
   109  		}
   110  	}
   111  }
   112  
   113  func scale(f float64, c []complex128) {
   114  	for i, v := range c {
   115  		c[i] = complex(f*real(v), f*imag(v))
   116  	}
   117  }
   118  
   119  func TestBitReversePermute(t *testing.T) {
   120  	for n := 2; n <= 1024; n <<= 1 {
   121  		x := make([]complex128, n)
   122  		for i := range x {
   123  			x[i] = complex(float64(i), float64(i))
   124  		}
   125  		bitReversePermute(x)
   126  		for i, got := range x {
   127  			j := bits.Reverse(uint(i)) >> bits.LeadingZeros(uint(n-1))
   128  			want := complex(float64(j), float64(j))
   129  			if got != want {
   130  				t.Errorf("unexpected value at %d: got:%f want:%f", i, got, want)
   131  			}
   132  		}
   133  	}
   134  }
   135  
   136  func TestPadRadix2(t *testing.T) {
   137  	for n := 1; n <= 1025; n++ {
   138  		x := make([]complex128, n)
   139  		y := PadRadix2(x)
   140  		if bits.OnesCount(uint(len(y))) != 1 {
   141  			t.Errorf("unexpected length of padded slice: not a power of 2: %d", len(y))
   142  		}
   143  		if len(x) == len(y) && &y[0] != &x[0] {
   144  			t.Errorf("unexpected new allocation for power of 2 input length: len(x)=%d", n)
   145  		}
   146  		if len(y) < len(x) {
   147  			t.Errorf("unexpected short result: len(y)=%d < len(x)=%d", len(y), len(x))
   148  		}
   149  	}
   150  }
   151  
   152  func TestTrimRadix2(t *testing.T) {
   153  	for n := 1; n <= 1025; n++ {
   154  		x := make([]complex128, n)
   155  		y, r := TrimRadix2(x)
   156  		if bits.OnesCount(uint(len(y))) != 1 {
   157  			t.Errorf("unexpected length of trimmed slice: not a power of 2: %d", len(y))
   158  		}
   159  		if len(y)+len(r) != len(x) {
   160  			t.Errorf("unexpected total result: len(y)=%d + len(r)%d != len(x)=%d", len(y), len(r), len(x))
   161  		}
   162  		if len(x) == len(y) && &y[0] != &x[0] {
   163  			t.Errorf("unexpected new allocation for power of 2 input length: len(x)=%d", n)
   164  		}
   165  		if len(y) > len(x) {
   166  			t.Errorf("unexpected long result: len(y)=%d > len(x)=%d", len(y), len(x))
   167  		}
   168  	}
   169  }
   170  
   171  func TestBitPairReversePermute(t *testing.T) {
   172  	for n := 4; n <= 1024; n <<= 2 {
   173  		x := make([]complex128, n)
   174  		for i := range x {
   175  			x[i] = complex(float64(i), float64(i))
   176  		}
   177  		bitPairReversePermute(x)
   178  		for i, got := range x {
   179  			j := reversePairs(uint(i)) >> bits.LeadingZeros(uint(n-1))
   180  			want := complex(float64(j), float64(j))
   181  			if got != want {
   182  				t.Errorf("unexpected value at %d: got:%f want:%f", i, got, want)
   183  			}
   184  		}
   185  	}
   186  }
   187  
   188  func TestReversePairs(t *testing.T) {
   189  	rnd := rand.New(rand.NewSource(1))
   190  	for i := 0; i < 1000; i++ {
   191  		x := uint(rnd.Uint64())
   192  		got := reversePairs(x)
   193  		want := naiveReversePairs(x)
   194  		if got != want {
   195  			t.Errorf("unexpected bit-pair reversal for 0b%064b:\ngot: 0b%064b\nwant:0b%064b", x, got, want)
   196  		}
   197  	}
   198  }
   199  
   200  // naiveReversePairs does a bit-pair reversal by formatting as a base-4 string,
   201  // reversing the digits of the formatted number and then re-parsing the value.
   202  func naiveReversePairs(x uint) uint {
   203  	bits := int(unsafe.Sizeof(uint(0)) * 8)
   204  
   205  	// Format the number as a quaternary, padded with zeros.
   206  	// We avoid the leftpad issue by doing it ourselves.
   207  	b := strconv.AppendUint(bytes.Repeat([]byte("0"), bits/2), uint64(x), 4)
   208  	b = b[len(b)-bits/2:]
   209  
   210  	// Reverse the quits.
   211  	slices.Reverse(b)
   212  
   213  	// Re-parse the reversed number.
   214  	y, err := strconv.ParseUint(string(b), 4, 64)
   215  	if err != nil {
   216  		panic(fmt.Sprintf("unexpected parse error: %v", err))
   217  	}
   218  	return uint(y)
   219  }
   220  
   221  func TestPadRadix4(t *testing.T) {
   222  	for n := 1; n <= 1025; n++ {
   223  		x := make([]complex128, n)
   224  		y := PadRadix4(x)
   225  		if bits.OnesCount(uint(len(y))) != 1 || bits.Len(uint(len(y)))&0x1 == 0 {
   226  			t.Errorf("unexpected length of padded slice: not a power of 4: %d", len(y))
   227  		}
   228  		if len(x) == len(y) && &y[0] != &x[0] {
   229  			t.Errorf("unexpected new allocation for power of 2 input length: len(x)=%d", n)
   230  		}
   231  		if len(y) < len(x) {
   232  			t.Errorf("unexpected short result: len(y)=%d < len(x)=%d", len(y), len(x))
   233  		}
   234  	}
   235  }
   236  
   237  func TestTrimRadix4(t *testing.T) {
   238  	for n := 1; n <= 1025; n++ {
   239  		x := make([]complex128, n)
   240  		y, r := TrimRadix4(x)
   241  		if bits.OnesCount(uint(len(y))) != 1 || bits.Len(uint(len(y)))&0x1 == 0 {
   242  			t.Errorf("unexpected length of trimmed slice: not a power of 4: %d", len(y))
   243  		}
   244  		if len(y)+len(r) != len(x) {
   245  			t.Errorf("unexpected total result: len(y)=%d + len(r)%d != len(x)=%d", len(y), len(r), len(x))
   246  		}
   247  		if len(x) == len(y) && &y[0] != &x[0] {
   248  			t.Errorf("unexpected new allocation for power of 2 input length: len(x)=%d", n)
   249  		}
   250  		if len(y) > len(x) {
   251  			t.Errorf("unexpected long result: len(y)=%d > len(x)=%d", len(y), len(x))
   252  		}
   253  	}
   254  }
   255  
   256  func BenchmarkCoefficients(b *testing.B) {
   257  	for n := 16; n < 1<<24; n <<= 3 {
   258  		d := randComplexes(n, rand.NewSource(1))
   259  		b.Run(fmt.Sprintf("Radix2/%d", n), func(b *testing.B) {
   260  			for i := 0; i < b.N; i++ {
   261  				CoefficientsRadix2(d)
   262  			}
   263  		})
   264  		if bits.Len(uint(n))&0x1 == 0 {
   265  			continue
   266  		}
   267  		b.Run(fmt.Sprintf("Radix4/%d", n), func(b *testing.B) {
   268  			for i := 0; i < b.N; i++ {
   269  				CoefficientsRadix4(d)
   270  			}
   271  		})
   272  	}
   273  }
   274  
   275  func BenchmarkSequence(b *testing.B) {
   276  	for n := 16; n < 1<<24; n <<= 3 {
   277  		d := randComplexes(n, rand.NewSource(1))
   278  		b.Run(fmt.Sprintf("Radix2/%d", n), func(b *testing.B) {
   279  			for i := 0; i < b.N; i++ {
   280  				SequenceRadix2(d)
   281  			}
   282  		})
   283  		if bits.Len(uint(n))&0x1 == 0 {
   284  			continue
   285  		}
   286  		b.Run(fmt.Sprintf("Radix4/%d", n), func(b *testing.B) {
   287  			for i := 0; i < b.N; i++ {
   288  				SequenceRadix4(d)
   289  			}
   290  		})
   291  	}
   292  }