gonum.org/v1/gonum@v0.15.1-0.20240517103525-f853624cb1bb/dsp/fourier/radix24_test.go (about) 1 // Copyright ©2020 The Gonum Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style 3 // license that can be found in the LICENSE file. 4 5 package fourier 6 7 import ( 8 "bytes" 9 "fmt" 10 "math/bits" 11 "slices" 12 "strconv" 13 "testing" 14 "unsafe" 15 16 "golang.org/x/exp/rand" 17 18 "gonum.org/v1/gonum/cmplxs" 19 ) 20 21 func TestCoefficients(t *testing.T) { 22 const tol = 1e-8 23 24 src := rand.NewSource(1) 25 for n := 4; n < 1<<20; n <<= 1 { 26 for i := 0; i < 10; i++ { 27 t.Run(fmt.Sprintf("Radix2/%d", n), func(t *testing.T) { 28 d := randComplexes(n, src) 29 fft := NewCmplxFFT(n) 30 want := fft.Coefficients(nil, d) 31 CoefficientsRadix2(d) 32 got := d 33 if !cmplxs.EqualApprox(got, want, tol) { 34 t.Errorf("unexpected result for n=%d |got-want|^2=%g", n, cmplxs.Distance(got, want, 2)) 35 } 36 37 want = fft.Sequence(nil, got) 38 scale(1/float64(n), want) 39 40 SequenceRadix2(got) 41 scale(1/float64(n), got) 42 43 if !cmplxs.EqualApprox(got, want, tol) { 44 t.Errorf("unexpected ifft result for n=%d |got-want|^2=%g", n, cmplxs.Distance(got, want, 2)) 45 } 46 }) 47 if bits.Len(uint(n))&0x1 == 0 { 48 continue 49 } 50 t.Run(fmt.Sprintf("Radix4/%d", n), func(t *testing.T) { 51 d := randComplexes(n, src) 52 fft := NewCmplxFFT(n) 53 want := fft.Coefficients(nil, d) 54 CoefficientsRadix4(d) 55 got := d 56 if !cmplxs.EqualApprox(got, want, tol) { 57 t.Errorf("unexpected fft result for n=%d |got-want|^2=%g", n, cmplxs.Distance(got, want, 2)) 58 } 59 60 want = fft.Sequence(nil, got) 61 scale(1/float64(n), want) 62 63 SequenceRadix4(got) 64 scale(1/float64(n), got) 65 66 if !cmplxs.EqualApprox(got, want, tol) { 67 t.Errorf("unexpected ifft result for n=%d |got-want|^2=%g", n, cmplxs.Distance(got, want, 2)) 68 } 69 }) 70 } 71 } 72 } 73 74 func TestSequence(t *testing.T) { 75 const tol = 1e-10 76 77 src := rand.NewSource(1) 78 for n := 4; n < 1<<20; n <<= 1 { 79 for i := 0; i < 10; i++ { 80 t.Run(fmt.Sprintf("Radix2/%d", n), func(t *testing.T) { 81 d := randComplexes(n, src) 82 want := make([]complex128, n) 83 copy(want, d) 84 SequenceRadix2(CoefficientsRadix2(d)) 85 got := d 86 87 scale(1/float64(n), got) 88 89 if !cmplxs.EqualApprox(got, want, tol) { 90 t.Errorf("unexpected result for ifft(fft(d)) n=%d |got-want|^2=%g", n, cmplxs.Distance(got, want, 2)) 91 } 92 }) 93 if bits.Len(uint(n))&0x1 == 0 { 94 continue 95 } 96 t.Run(fmt.Sprintf("Radix4/%d", n), func(t *testing.T) { 97 d := randComplexes(n, src) 98 want := make([]complex128, n) 99 copy(want, d) 100 SequenceRadix4(CoefficientsRadix4(d)) 101 got := d 102 103 scale(1/float64(n), got) 104 105 if !cmplxs.EqualApprox(got, want, tol) { 106 t.Errorf("unexpected result for ifft(fft(d)) n=%d |got-want|^2=%g", n, cmplxs.Distance(got, want, 2)) 107 } 108 }) 109 } 110 } 111 } 112 113 func scale(f float64, c []complex128) { 114 for i, v := range c { 115 c[i] = complex(f*real(v), f*imag(v)) 116 } 117 } 118 119 func TestBitReversePermute(t *testing.T) { 120 for n := 2; n <= 1024; n <<= 1 { 121 x := make([]complex128, n) 122 for i := range x { 123 x[i] = complex(float64(i), float64(i)) 124 } 125 bitReversePermute(x) 126 for i, got := range x { 127 j := bits.Reverse(uint(i)) >> bits.LeadingZeros(uint(n-1)) 128 want := complex(float64(j), float64(j)) 129 if got != want { 130 t.Errorf("unexpected value at %d: got:%f want:%f", i, got, want) 131 } 132 } 133 } 134 } 135 136 func TestPadRadix2(t *testing.T) { 137 for n := 1; n <= 1025; n++ { 138 x := make([]complex128, n) 139 y := PadRadix2(x) 140 if bits.OnesCount(uint(len(y))) != 1 { 141 t.Errorf("unexpected length of padded slice: not a power of 2: %d", len(y)) 142 } 143 if len(x) == len(y) && &y[0] != &x[0] { 144 t.Errorf("unexpected new allocation for power of 2 input length: len(x)=%d", n) 145 } 146 if len(y) < len(x) { 147 t.Errorf("unexpected short result: len(y)=%d < len(x)=%d", len(y), len(x)) 148 } 149 } 150 } 151 152 func TestTrimRadix2(t *testing.T) { 153 for n := 1; n <= 1025; n++ { 154 x := make([]complex128, n) 155 y, r := TrimRadix2(x) 156 if bits.OnesCount(uint(len(y))) != 1 { 157 t.Errorf("unexpected length of trimmed slice: not a power of 2: %d", len(y)) 158 } 159 if len(y)+len(r) != len(x) { 160 t.Errorf("unexpected total result: len(y)=%d + len(r)%d != len(x)=%d", len(y), len(r), len(x)) 161 } 162 if len(x) == len(y) && &y[0] != &x[0] { 163 t.Errorf("unexpected new allocation for power of 2 input length: len(x)=%d", n) 164 } 165 if len(y) > len(x) { 166 t.Errorf("unexpected long result: len(y)=%d > len(x)=%d", len(y), len(x)) 167 } 168 } 169 } 170 171 func TestBitPairReversePermute(t *testing.T) { 172 for n := 4; n <= 1024; n <<= 2 { 173 x := make([]complex128, n) 174 for i := range x { 175 x[i] = complex(float64(i), float64(i)) 176 } 177 bitPairReversePermute(x) 178 for i, got := range x { 179 j := reversePairs(uint(i)) >> bits.LeadingZeros(uint(n-1)) 180 want := complex(float64(j), float64(j)) 181 if got != want { 182 t.Errorf("unexpected value at %d: got:%f want:%f", i, got, want) 183 } 184 } 185 } 186 } 187 188 func TestReversePairs(t *testing.T) { 189 rnd := rand.New(rand.NewSource(1)) 190 for i := 0; i < 1000; i++ { 191 x := uint(rnd.Uint64()) 192 got := reversePairs(x) 193 want := naiveReversePairs(x) 194 if got != want { 195 t.Errorf("unexpected bit-pair reversal for 0b%064b:\ngot: 0b%064b\nwant:0b%064b", x, got, want) 196 } 197 } 198 } 199 200 // naiveReversePairs does a bit-pair reversal by formatting as a base-4 string, 201 // reversing the digits of the formatted number and then re-parsing the value. 202 func naiveReversePairs(x uint) uint { 203 bits := int(unsafe.Sizeof(uint(0)) * 8) 204 205 // Format the number as a quaternary, padded with zeros. 206 // We avoid the leftpad issue by doing it ourselves. 207 b := strconv.AppendUint(bytes.Repeat([]byte("0"), bits/2), uint64(x), 4) 208 b = b[len(b)-bits/2:] 209 210 // Reverse the quits. 211 slices.Reverse(b) 212 213 // Re-parse the reversed number. 214 y, err := strconv.ParseUint(string(b), 4, 64) 215 if err != nil { 216 panic(fmt.Sprintf("unexpected parse error: %v", err)) 217 } 218 return uint(y) 219 } 220 221 func TestPadRadix4(t *testing.T) { 222 for n := 1; n <= 1025; n++ { 223 x := make([]complex128, n) 224 y := PadRadix4(x) 225 if bits.OnesCount(uint(len(y))) != 1 || bits.Len(uint(len(y)))&0x1 == 0 { 226 t.Errorf("unexpected length of padded slice: not a power of 4: %d", len(y)) 227 } 228 if len(x) == len(y) && &y[0] != &x[0] { 229 t.Errorf("unexpected new allocation for power of 2 input length: len(x)=%d", n) 230 } 231 if len(y) < len(x) { 232 t.Errorf("unexpected short result: len(y)=%d < len(x)=%d", len(y), len(x)) 233 } 234 } 235 } 236 237 func TestTrimRadix4(t *testing.T) { 238 for n := 1; n <= 1025; n++ { 239 x := make([]complex128, n) 240 y, r := TrimRadix4(x) 241 if bits.OnesCount(uint(len(y))) != 1 || bits.Len(uint(len(y)))&0x1 == 0 { 242 t.Errorf("unexpected length of trimmed slice: not a power of 4: %d", len(y)) 243 } 244 if len(y)+len(r) != len(x) { 245 t.Errorf("unexpected total result: len(y)=%d + len(r)%d != len(x)=%d", len(y), len(r), len(x)) 246 } 247 if len(x) == len(y) && &y[0] != &x[0] { 248 t.Errorf("unexpected new allocation for power of 2 input length: len(x)=%d", n) 249 } 250 if len(y) > len(x) { 251 t.Errorf("unexpected long result: len(y)=%d > len(x)=%d", len(y), len(x)) 252 } 253 } 254 } 255 256 func BenchmarkCoefficients(b *testing.B) { 257 for n := 16; n < 1<<24; n <<= 3 { 258 d := randComplexes(n, rand.NewSource(1)) 259 b.Run(fmt.Sprintf("Radix2/%d", n), func(b *testing.B) { 260 for i := 0; i < b.N; i++ { 261 CoefficientsRadix2(d) 262 } 263 }) 264 if bits.Len(uint(n))&0x1 == 0 { 265 continue 266 } 267 b.Run(fmt.Sprintf("Radix4/%d", n), func(b *testing.B) { 268 for i := 0; i < b.N; i++ { 269 CoefficientsRadix4(d) 270 } 271 }) 272 } 273 } 274 275 func BenchmarkSequence(b *testing.B) { 276 for n := 16; n < 1<<24; n <<= 3 { 277 d := randComplexes(n, rand.NewSource(1)) 278 b.Run(fmt.Sprintf("Radix2/%d", n), func(b *testing.B) { 279 for i := 0; i < b.N; i++ { 280 SequenceRadix2(d) 281 } 282 }) 283 if bits.Len(uint(n))&0x1 == 0 { 284 continue 285 } 286 b.Run(fmt.Sprintf("Radix4/%d", n), func(b *testing.B) { 287 for i := 0; i < b.N; i++ { 288 SequenceRadix4(d) 289 } 290 }) 291 } 292 }