gonum.org/v1/gonum@v0.14.0/dsp/fourier/radix24_test.go (about) 1 // Copyright ©2020 The Gonum Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style 3 // license that can be found in the LICENSE file. 4 5 package fourier 6 7 import ( 8 "bytes" 9 "fmt" 10 "math/bits" 11 "strconv" 12 "testing" 13 "unsafe" 14 15 "golang.org/x/exp/rand" 16 17 "gonum.org/v1/gonum/cmplxs" 18 ) 19 20 func TestCoefficients(t *testing.T) { 21 const tol = 1e-8 22 23 src := rand.NewSource(1) 24 for n := 4; n < 1<<20; n <<= 1 { 25 for i := 0; i < 10; i++ { 26 t.Run(fmt.Sprintf("Radix2/%d", n), func(t *testing.T) { 27 d := randComplexes(n, src) 28 fft := NewCmplxFFT(n) 29 want := fft.Coefficients(nil, d) 30 CoefficientsRadix2(d) 31 got := d 32 if !cmplxs.EqualApprox(got, want, tol) { 33 t.Errorf("unexpected result for n=%d |got-want|^2=%g", n, cmplxs.Distance(got, want, 2)) 34 } 35 36 want = fft.Sequence(nil, got) 37 scale(1/float64(n), want) 38 39 SequenceRadix2(got) 40 scale(1/float64(n), got) 41 42 if !cmplxs.EqualApprox(got, want, tol) { 43 t.Errorf("unexpected ifft result for n=%d |got-want|^2=%g", n, cmplxs.Distance(got, want, 2)) 44 } 45 }) 46 if bits.Len(uint(n))&0x1 == 0 { 47 continue 48 } 49 t.Run(fmt.Sprintf("Radix4/%d", n), func(t *testing.T) { 50 d := randComplexes(n, src) 51 fft := NewCmplxFFT(n) 52 want := fft.Coefficients(nil, d) 53 CoefficientsRadix4(d) 54 got := d 55 if !cmplxs.EqualApprox(got, want, tol) { 56 t.Errorf("unexpected fft result for n=%d |got-want|^2=%g", n, cmplxs.Distance(got, want, 2)) 57 } 58 59 want = fft.Sequence(nil, got) 60 scale(1/float64(n), want) 61 62 SequenceRadix4(got) 63 scale(1/float64(n), got) 64 65 if !cmplxs.EqualApprox(got, want, tol) { 66 t.Errorf("unexpected ifft result for n=%d |got-want|^2=%g", n, cmplxs.Distance(got, want, 2)) 67 } 68 }) 69 } 70 } 71 } 72 73 func TestSequence(t *testing.T) { 74 const tol = 1e-10 75 76 src := rand.NewSource(1) 77 for n := 4; n < 1<<20; n <<= 1 { 78 for i := 0; i < 10; i++ { 79 t.Run(fmt.Sprintf("Radix2/%d", n), func(t *testing.T) { 80 d := randComplexes(n, src) 81 want := make([]complex128, n) 82 copy(want, d) 83 SequenceRadix2(CoefficientsRadix2(d)) 84 got := d 85 86 scale(1/float64(n), got) 87 88 if !cmplxs.EqualApprox(got, want, tol) { 89 t.Errorf("unexpected result for ifft(fft(d)) n=%d |got-want|^2=%g", n, cmplxs.Distance(got, want, 2)) 90 } 91 }) 92 if bits.Len(uint(n))&0x1 == 0 { 93 continue 94 } 95 t.Run(fmt.Sprintf("Radix4/%d", n), func(t *testing.T) { 96 d := randComplexes(n, src) 97 want := make([]complex128, n) 98 copy(want, d) 99 SequenceRadix4(CoefficientsRadix4(d)) 100 got := d 101 102 scale(1/float64(n), got) 103 104 if !cmplxs.EqualApprox(got, want, tol) { 105 t.Errorf("unexpected result for ifft(fft(d)) n=%d |got-want|^2=%g", n, cmplxs.Distance(got, want, 2)) 106 } 107 }) 108 } 109 } 110 } 111 112 func scale(f float64, c []complex128) { 113 for i, v := range c { 114 c[i] = complex(f*real(v), f*imag(v)) 115 } 116 } 117 118 func TestBitReversePermute(t *testing.T) { 119 for n := 2; n <= 1024; n <<= 1 { 120 x := make([]complex128, n) 121 for i := range x { 122 x[i] = complex(float64(i), float64(i)) 123 } 124 bitReversePermute(x) 125 for i, got := range x { 126 j := bits.Reverse(uint(i)) >> bits.LeadingZeros(uint(n-1)) 127 want := complex(float64(j), float64(j)) 128 if got != want { 129 t.Errorf("unexpected value at %d: got:%f want:%f", i, got, want) 130 } 131 } 132 } 133 } 134 135 func TestPadRadix2(t *testing.T) { 136 for n := 1; n <= 1025; n++ { 137 x := make([]complex128, n) 138 y := PadRadix2(x) 139 if bits.OnesCount(uint(len(y))) != 1 { 140 t.Errorf("unexpected length of padded slice: not a power of 2: %d", len(y)) 141 } 142 if len(x) == len(y) && &y[0] != &x[0] { 143 t.Errorf("unexpected new allocation for power of 2 input length: len(x)=%d", n) 144 } 145 if len(y) < len(x) { 146 t.Errorf("unexpected short result: len(y)=%d < len(x)=%d", len(y), len(x)) 147 } 148 } 149 } 150 151 func TestTrimRadix2(t *testing.T) { 152 for n := 1; n <= 1025; n++ { 153 x := make([]complex128, n) 154 y, r := TrimRadix2(x) 155 if bits.OnesCount(uint(len(y))) != 1 { 156 t.Errorf("unexpected length of trimmed slice: not a power of 2: %d", len(y)) 157 } 158 if len(y)+len(r) != len(x) { 159 t.Errorf("unexpected total result: len(y)=%d + len(r)%d != len(x)=%d", len(y), len(r), len(x)) 160 } 161 if len(x) == len(y) && &y[0] != &x[0] { 162 t.Errorf("unexpected new allocation for power of 2 input length: len(x)=%d", n) 163 } 164 if len(y) > len(x) { 165 t.Errorf("unexpected long result: len(y)=%d > len(x)=%d", len(y), len(x)) 166 } 167 } 168 } 169 170 func TestBitPairReversePermute(t *testing.T) { 171 for n := 4; n <= 1024; n <<= 2 { 172 x := make([]complex128, n) 173 for i := range x { 174 x[i] = complex(float64(i), float64(i)) 175 } 176 bitPairReversePermute(x) 177 for i, got := range x { 178 j := reversePairs(uint(i)) >> bits.LeadingZeros(uint(n-1)) 179 want := complex(float64(j), float64(j)) 180 if got != want { 181 t.Errorf("unexpected value at %d: got:%f want:%f", i, got, want) 182 } 183 } 184 } 185 } 186 187 func TestReversePairs(t *testing.T) { 188 rnd := rand.New(rand.NewSource(1)) 189 for i := 0; i < 1000; i++ { 190 x := uint(rnd.Uint64()) 191 got := reversePairs(x) 192 want := naiveReversePairs(x) 193 if got != want { 194 t.Errorf("unexpected bit-pair reversal for 0b%064b:\ngot: 0b%064b\nwant:0b%064b", x, got, want) 195 } 196 } 197 } 198 199 // naiveReversePairs does a bit-pair reversal by formatting as a base-4 string, 200 // reversing the digits of the formatted number and then re-parsing the value. 201 func naiveReversePairs(x uint) uint { 202 bits := int(unsafe.Sizeof(uint(0)) * 8) 203 204 // Format the number as a quaternary, padded with zeros. 205 // We avoid the leftpad issue by doing it ourselves. 206 b := strconv.AppendUint(bytes.Repeat([]byte("0"), bits/2), uint64(x), 4) 207 b = b[len(b)-bits/2:] 208 209 // Reverse the quits. 210 for i, j := 0, len(b)-1; i < j; i, j = i+1, j-1 { 211 b[i], b[j] = b[j], b[i] 212 } 213 214 // Re-parse the reversed number. 215 y, err := strconv.ParseUint(string(b), 4, 64) 216 if err != nil { 217 panic(fmt.Sprintf("unexpected parse error: %v", err)) 218 } 219 return uint(y) 220 } 221 222 func TestPadRadix4(t *testing.T) { 223 for n := 1; n <= 1025; n++ { 224 x := make([]complex128, n) 225 y := PadRadix4(x) 226 if bits.OnesCount(uint(len(y))) != 1 || bits.Len(uint(len(y)))&0x1 == 0 { 227 t.Errorf("unexpected length of padded slice: not a power of 4: %d", len(y)) 228 } 229 if len(x) == len(y) && &y[0] != &x[0] { 230 t.Errorf("unexpected new allocation for power of 2 input length: len(x)=%d", n) 231 } 232 if len(y) < len(x) { 233 t.Errorf("unexpected short result: len(y)=%d < len(x)=%d", len(y), len(x)) 234 } 235 } 236 } 237 238 func TestTrimRadix4(t *testing.T) { 239 for n := 1; n <= 1025; n++ { 240 x := make([]complex128, n) 241 y, r := TrimRadix4(x) 242 if bits.OnesCount(uint(len(y))) != 1 || bits.Len(uint(len(y)))&0x1 == 0 { 243 t.Errorf("unexpected length of trimmed slice: not a power of 4: %d", len(y)) 244 } 245 if len(y)+len(r) != len(x) { 246 t.Errorf("unexpected total result: len(y)=%d + len(r)%d != len(x)=%d", len(y), len(r), len(x)) 247 } 248 if len(x) == len(y) && &y[0] != &x[0] { 249 t.Errorf("unexpected new allocation for power of 2 input length: len(x)=%d", n) 250 } 251 if len(y) > len(x) { 252 t.Errorf("unexpected long result: len(y)=%d > len(x)=%d", len(y), len(x)) 253 } 254 } 255 } 256 257 func BenchmarkCoefficients(b *testing.B) { 258 for n := 16; n < 1<<24; n <<= 3 { 259 d := randComplexes(n, rand.NewSource(1)) 260 b.Run(fmt.Sprintf("Radix2/%d", n), func(b *testing.B) { 261 for i := 0; i < b.N; i++ { 262 CoefficientsRadix2(d) 263 } 264 }) 265 if bits.Len(uint(n))&0x1 == 0 { 266 continue 267 } 268 b.Run(fmt.Sprintf("Radix4/%d", n), func(b *testing.B) { 269 for i := 0; i < b.N; i++ { 270 CoefficientsRadix4(d) 271 } 272 }) 273 } 274 } 275 276 func BenchmarkSequence(b *testing.B) { 277 for n := 16; n < 1<<24; n <<= 3 { 278 d := randComplexes(n, rand.NewSource(1)) 279 b.Run(fmt.Sprintf("Radix2/%d", n), func(b *testing.B) { 280 for i := 0; i < b.N; i++ { 281 SequenceRadix2(d) 282 } 283 }) 284 if bits.Len(uint(n))&0x1 == 0 { 285 continue 286 } 287 b.Run(fmt.Sprintf("Radix4/%d", n), func(b *testing.B) { 288 for i := 0; i < b.N; i++ { 289 SequenceRadix4(d) 290 } 291 }) 292 } 293 }