github.com/jingcheng-WU/gonum@v0.9.1-0.20210323123734-f1a2a11a8f7b/dsp/fourier/radix24_test.go (about)

     1  // Copyright ©2020 The Gonum Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package fourier
     6  
     7  import (
     8  	"bytes"
     9  	"fmt"
    10  	"math/bits"
    11  	"strconv"
    12  	"testing"
    13  	"unsafe"
    14  
    15  	"golang.org/x/exp/rand"
    16  
    17  	"github.com/jingcheng-WU/gonum/cmplxs"
    18  )
    19  
    20  func TestCoefficients(t *testing.T) {
    21  	const tol = 1e-8
    22  
    23  	src := rand.NewSource(1)
    24  	for n := 4; n < 1<<20; n <<= 1 {
    25  		for i := 0; i < 10; i++ {
    26  			t.Run(fmt.Sprintf("Radix2/%d", n), func(t *testing.T) {
    27  				d := randComplexes(n, src)
    28  				fft := NewCmplxFFT(n)
    29  				want := fft.Coefficients(nil, d)
    30  				CoefficientsRadix2(d)
    31  				got := d
    32  				if !cmplxs.EqualApprox(got, want, tol) {
    33  					t.Errorf("unexpected result for n=%d |got-want|^2=%g", n, cmplxs.Distance(got, want, 2))
    34  				}
    35  
    36  				want = fft.Sequence(nil, got)
    37  				scale(1/float64(n), want)
    38  
    39  				SequenceRadix2(got)
    40  				scale(1/float64(n), got)
    41  
    42  				if !cmplxs.EqualApprox(got, want, tol) {
    43  					t.Errorf("unexpected ifft result for n=%d |got-want|^2=%g", n, cmplxs.Distance(got, want, 2))
    44  				}
    45  			})
    46  			if bits.Len(uint(n))&0x1 == 0 {
    47  				continue
    48  			}
    49  			t.Run(fmt.Sprintf("Radix4/%d", n), func(t *testing.T) {
    50  				d := randComplexes(n, src)
    51  				fft := NewCmplxFFT(n)
    52  				want := fft.Coefficients(nil, d)
    53  				CoefficientsRadix4(d)
    54  				got := d
    55  				if !cmplxs.EqualApprox(got, want, tol) {
    56  					t.Errorf("unexpected fft result for n=%d |got-want|^2=%g", n, cmplxs.Distance(got, want, 2))
    57  				}
    58  
    59  				want = fft.Sequence(nil, got)
    60  				scale(1/float64(n), want)
    61  
    62  				SequenceRadix4(got)
    63  				scale(1/float64(n), got)
    64  
    65  				if !cmplxs.EqualApprox(got, want, tol) {
    66  					t.Errorf("unexpected ifft result for n=%d |got-want|^2=%g", n, cmplxs.Distance(got, want, 2))
    67  				}
    68  			})
    69  		}
    70  	}
    71  }
    72  
    73  func TestSequence(t *testing.T) {
    74  	const tol = 1e-10
    75  
    76  	src := rand.NewSource(1)
    77  	for n := 4; n < 1<<20; n <<= 1 {
    78  		for i := 0; i < 10; i++ {
    79  			t.Run(fmt.Sprintf("Radix2/%d", n), func(t *testing.T) {
    80  				d := randComplexes(n, src)
    81  				want := make([]complex128, n)
    82  				copy(want, d)
    83  				SequenceRadix2(CoefficientsRadix2(d))
    84  				got := d
    85  
    86  				scale(1/float64(n), got)
    87  
    88  				if !cmplxs.EqualApprox(got, want, tol) {
    89  					t.Errorf("unexpected result for ifft(fft(d)) n=%d |got-want|^2=%g", n, cmplxs.Distance(got, want, 2))
    90  				}
    91  			})
    92  			if bits.Len(uint(n))&0x1 == 0 {
    93  				continue
    94  			}
    95  			t.Run(fmt.Sprintf("Radix4/%d", n), func(t *testing.T) {
    96  				d := randComplexes(n, src)
    97  				want := make([]complex128, n)
    98  				copy(want, d)
    99  				SequenceRadix4(CoefficientsRadix4(d))
   100  				got := d
   101  
   102  				scale(1/float64(n), got)
   103  
   104  				if !cmplxs.EqualApprox(got, want, tol) {
   105  					t.Errorf("unexpected result for ifft(fft(d)) n=%d |got-want|^2=%g", n, cmplxs.Distance(got, want, 2))
   106  				}
   107  			})
   108  		}
   109  	}
   110  }
   111  
   112  func scale(f float64, c []complex128) {
   113  	for i, v := range c {
   114  		c[i] = complex(f*real(v), f*imag(v))
   115  	}
   116  }
   117  
   118  func TestBitReversePermute(t *testing.T) {
   119  	for n := 2; n <= 1024; n <<= 1 {
   120  		x := make([]complex128, n)
   121  		for i := range x {
   122  			x[i] = complex(float64(i), float64(i))
   123  		}
   124  		bitReversePermute(x)
   125  		for i, got := range x {
   126  			j := bits.Reverse(uint(i)) >> bits.LeadingZeros(uint(n-1))
   127  			want := complex(float64(j), float64(j))
   128  			if got != want {
   129  				t.Errorf("unexpected value at %d: got:%f want:%f", i, got, want)
   130  			}
   131  		}
   132  	}
   133  }
   134  
   135  func TestPadRadix2(t *testing.T) {
   136  	for n := 1; n <= 1025; n++ {
   137  		x := make([]complex128, n)
   138  		y := PadRadix2(x)
   139  		if bits.OnesCount(uint(len(y))) != 1 {
   140  			t.Errorf("unexpected length of padded slice: not a power of 2: %d", len(y))
   141  		}
   142  		if len(x) == len(y) && &y[0] != &x[0] {
   143  			t.Errorf("unexpected new allocation for power of 2 input length: len(x)=%d", n)
   144  		}
   145  		if len(y) < len(x) {
   146  			t.Errorf("unexpected short result: len(y)=%d < len(x)=%d", len(y), len(x))
   147  		}
   148  	}
   149  }
   150  
   151  func TestTrimRadix2(t *testing.T) {
   152  	for n := 1; n <= 1025; n++ {
   153  		x := make([]complex128, n)
   154  		y, r := TrimRadix2(x)
   155  		if bits.OnesCount(uint(len(y))) != 1 {
   156  			t.Errorf("unexpected length of trimmed slice: not a power of 2: %d", len(y))
   157  		}
   158  		if len(y)+len(r) != len(x) {
   159  			t.Errorf("unexpected total result: len(y)=%d + len(r)%d != len(x)=%d", len(y), len(r), len(x))
   160  		}
   161  		if len(x) == len(y) && &y[0] != &x[0] {
   162  			t.Errorf("unexpected new allocation for power of 2 input length: len(x)=%d", n)
   163  		}
   164  		if len(y) > len(x) {
   165  			t.Errorf("unexpected long result: len(y)=%d > len(x)=%d", len(y), len(x))
   166  		}
   167  	}
   168  }
   169  
   170  func TestBitPairReversePermute(t *testing.T) {
   171  	for n := 4; n <= 1024; n <<= 2 {
   172  		x := make([]complex128, n)
   173  		for i := range x {
   174  			x[i] = complex(float64(i), float64(i))
   175  		}
   176  		bitPairReversePermute(x)
   177  		for i, got := range x {
   178  			j := reversePairs(uint(i)) >> bits.LeadingZeros(uint(n-1))
   179  			want := complex(float64(j), float64(j))
   180  			if got != want {
   181  				t.Errorf("unexpected value at %d: got:%f want:%f", i, got, want)
   182  			}
   183  		}
   184  	}
   185  }
   186  
   187  func TestReversePairs(t *testing.T) {
   188  	rnd := rand.New(rand.NewSource(1))
   189  	for i := 0; i < 1000; i++ {
   190  		x := uint(rnd.Uint64())
   191  		got := reversePairs(x)
   192  		want := naiveReversePairs(x)
   193  		if got != want {
   194  			t.Errorf("unexpected bit-pair reversal for 0b%064b:\ngot: 0b%064b\nwant:0b%064b", x, got, want)
   195  		}
   196  	}
   197  }
   198  
   199  // naiveReversePairs does a bit-pair reversal by formatting as a base-4 string,
   200  // reversing the digits of the formatted number and then re-parsing the value.
   201  func naiveReversePairs(x uint) uint {
   202  	bits := int(unsafe.Sizeof(uint(0)) * 8)
   203  
   204  	// Format the number as a quaternary, padded with zeros.
   205  	// We avoid the leftpad issue by doing it ourselves.
   206  	b := strconv.AppendUint(bytes.Repeat([]byte("0"), bits/2), uint64(x), 4)
   207  	b = b[len(b)-bits/2:]
   208  
   209  	// Reverse the quits.
   210  	for i, j := 0, len(b)-1; i < j; i, j = i+1, j-1 {
   211  		b[i], b[j] = b[j], b[i]
   212  	}
   213  
   214  	// Re-parse the reversed number.
   215  	y, err := strconv.ParseUint(string(b), 4, 64)
   216  	if err != nil {
   217  		panic(fmt.Sprintf("unexpected parse error: %v", err))
   218  	}
   219  	return uint(y)
   220  }
   221  
   222  func TestPadRadix4(t *testing.T) {
   223  	for n := 1; n <= 1025; n++ {
   224  		x := make([]complex128, n)
   225  		y := PadRadix4(x)
   226  		if bits.OnesCount(uint(len(y))) != 1 || bits.Len(uint(len(y)))&0x1 == 0 {
   227  			t.Errorf("unexpected length of padded slice: not a power of 4: %d", len(y))
   228  		}
   229  		if len(x) == len(y) && &y[0] != &x[0] {
   230  			t.Errorf("unexpected new allocation for power of 2 input length: len(x)=%d", n)
   231  		}
   232  		if len(y) < len(x) {
   233  			t.Errorf("unexpected short result: len(y)=%d < len(x)=%d", len(y), len(x))
   234  		}
   235  	}
   236  }
   237  
   238  func TestTrimRadix4(t *testing.T) {
   239  	for n := 1; n <= 1025; n++ {
   240  		x := make([]complex128, n)
   241  		y, r := TrimRadix4(x)
   242  		if bits.OnesCount(uint(len(y))) != 1 || bits.Len(uint(len(y)))&0x1 == 0 {
   243  			t.Errorf("unexpected length of trimmed slice: not a power of 4: %d", len(y))
   244  		}
   245  		if len(y)+len(r) != len(x) {
   246  			t.Errorf("unexpected total result: len(y)=%d + len(r)%d != len(x)=%d", len(y), len(r), len(x))
   247  		}
   248  		if len(x) == len(y) && &y[0] != &x[0] {
   249  			t.Errorf("unexpected new allocation for power of 2 input length: len(x)=%d", n)
   250  		}
   251  		if len(y) > len(x) {
   252  			t.Errorf("unexpected long result: len(y)=%d > len(x)=%d", len(y), len(x))
   253  		}
   254  	}
   255  }
   256  
   257  func BenchmarkCoefficients(b *testing.B) {
   258  	for n := 16; n < 1<<24; n <<= 3 {
   259  		d := randComplexes(n, rand.NewSource(1))
   260  		b.Run(fmt.Sprintf("Radix2/%d", n), func(b *testing.B) {
   261  			for i := 0; i < b.N; i++ {
   262  				CoefficientsRadix2(d)
   263  			}
   264  		})
   265  		if bits.Len(uint(n))&0x1 == 0 {
   266  			continue
   267  		}
   268  		b.Run(fmt.Sprintf("Radix4/%d", n), func(b *testing.B) {
   269  			for i := 0; i < b.N; i++ {
   270  				CoefficientsRadix4(d)
   271  			}
   272  		})
   273  	}
   274  }
   275  
   276  func BenchmarkSequence(b *testing.B) {
   277  	for n := 16; n < 1<<24; n <<= 3 {
   278  		d := randComplexes(n, rand.NewSource(1))
   279  		b.Run(fmt.Sprintf("Radix2/%d", n), func(b *testing.B) {
   280  			for i := 0; i < b.N; i++ {
   281  				SequenceRadix2(d)
   282  			}
   283  		})
   284  		if bits.Len(uint(n))&0x1 == 0 {
   285  			continue
   286  		}
   287  		b.Run(fmt.Sprintf("Radix4/%d", n), func(b *testing.B) {
   288  			for i := 0; i < b.N; i++ {
   289  				SequenceRadix4(d)
   290  			}
   291  		})
   292  	}
   293  }