pgregory.net/rand@v1.0.3-0.20230808192358-a0b8ce02f4da/std_rand_test.go (about) 1 // Copyright 2009 The Go Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style 3 // license that can be found in the LICENSE-go file. 4 5 package rand_test 6 7 import ( 8 "bytes" 9 "errors" 10 "flag" 11 "fmt" 12 "io" 13 "math" 14 "os" 15 . "pgregory.net/rand" 16 "runtime" 17 "testing" 18 "testing/iotest" 19 ) 20 21 const ( 22 numTestSamples = 10000 23 ) 24 25 var ( 26 printtables = flag.Bool("printtables", false, "print ziggurat tables") 27 ) 28 29 var rn, kn, wn, fn = GetNormalDistributionParameters() 30 var re, ke, we, fe = GetExponentialDistributionParameters() 31 32 type statsResults struct { 33 mean float64 34 stddev float64 35 closeEnough float64 36 maxError float64 37 } 38 39 func max(a, b float64) float64 { 40 if a > b { 41 return a 42 } 43 return b 44 } 45 46 func nearEqual(a, b, closeEnough, maxError float64) bool { 47 absDiff := math.Abs(a - b) 48 if absDiff < closeEnough { // Necessary when one value is zero and one value is close to zero. 49 return true 50 } 51 return absDiff/(max(math.Abs(a), math.Abs(b))+math.SmallestNonzeroFloat64) < maxError 52 } 53 54 var testSeeds = []int64{1, 1754801282, 1698661970, 1550503961} 55 56 // checkSimilarDistribution returns success if the mean and stddev of the 57 // two statsResults are similar. 58 func (sr *statsResults) checkSimilarDistribution(expected *statsResults) error { 59 if !nearEqual(sr.mean, expected.mean, expected.closeEnough, expected.maxError) { 60 s := fmt.Sprintf("mean %v != %v (allowed error %v, %v)", sr.mean, expected.mean, expected.closeEnough, expected.maxError) 61 fmt.Println(s) 62 return errors.New(s) 63 } 64 if !nearEqual(sr.stddev, expected.stddev, expected.closeEnough, expected.maxError) { 65 s := fmt.Sprintf("stddev %v != %v (allowed error %v, %v)", sr.stddev, expected.stddev, expected.closeEnough, expected.maxError) 66 fmt.Println(s) 67 return errors.New(s) 68 } 69 return nil 70 } 71 72 func getStatsResults(samples []float64) *statsResults { 73 res := new(statsResults) 74 var sum, squaresum float64 75 for _, s := range samples { 76 sum += s 77 squaresum += s * s 78 } 79 res.mean = sum / float64(len(samples)) 80 res.stddev = math.Sqrt(squaresum/float64(len(samples)) - res.mean*res.mean) 81 return res 82 } 83 84 func checkSampleDistribution(t *testing.T, samples []float64, expected *statsResults) { 85 t.Helper() 86 actual := getStatsResults(samples) 87 err := actual.checkSimilarDistribution(expected) 88 if err != nil { 89 t.Errorf(err.Error()) 90 } 91 } 92 93 func checkSampleSliceDistributions(t *testing.T, samples []float64, nslices int, expected *statsResults) { 94 t.Helper() 95 chunk := len(samples) / nslices 96 for i := 0; i < nslices; i++ { 97 low := i * chunk 98 var high int 99 if i == nslices-1 { 100 high = len(samples) - 1 101 } else { 102 high = (i + 1) * chunk 103 } 104 checkSampleDistribution(t, samples[low:high], expected) 105 } 106 } 107 108 // 109 // Normal distribution tests 110 // 111 112 func generateNormalSamples(nsamples int, mean, stddev float64, seed int64) []float64 { 113 r := New(uint64(seed)) 114 samples := make([]float64, nsamples) 115 for i := range samples { 116 samples[i] = r.NormFloat64()*stddev + mean 117 } 118 return samples 119 } 120 121 func testNormalDistribution(t *testing.T, nsamples int, mean, stddev float64, seed int64) { 122 //fmt.Printf("testing nsamples=%v mean=%v stddev=%v seed=%v\n", nsamples, mean, stddev, seed); 123 124 samples := generateNormalSamples(nsamples, mean, stddev, seed) 125 errorScale := max(1.0, stddev) // Error scales with stddev 126 expected := &statsResults{mean, stddev, 0.10 * errorScale, 0.08 * errorScale} 127 128 // Make sure that the entire set matches the expected distribution. 129 checkSampleDistribution(t, samples, expected) 130 131 // Make sure that each half of the set matches the expected distribution. 132 checkSampleSliceDistributions(t, samples, 2, expected) 133 134 // Make sure that each 7th of the set matches the expected distribution. 135 checkSampleSliceDistributions(t, samples, 7, expected) 136 } 137 138 // Actual tests 139 140 func TestStandardNormalValues(t *testing.T) { 141 for _, seed := range testSeeds { 142 testNormalDistribution(t, numTestSamples, 0, 1, seed) 143 } 144 } 145 146 func TestNonStandardNormalValues(t *testing.T) { 147 sdmax := 1000.0 148 mmax := 1000.0 149 if testing.Short() { 150 sdmax = 5 151 mmax = 5 152 } 153 for sd := 0.5; sd < sdmax; sd *= 2 { 154 for m := 0.5; m < mmax; m *= 2 { 155 for _, seed := range testSeeds { 156 testNormalDistribution(t, numTestSamples, m, sd, seed) 157 if testing.Short() { 158 break 159 } 160 } 161 } 162 } 163 } 164 165 // 166 // Exponential distribution tests 167 // 168 169 func generateExponentialSamples(nsamples int, rate float64, seed int64) []float64 { 170 r := New(uint64(seed)) 171 samples := make([]float64, nsamples) 172 for i := range samples { 173 samples[i] = r.ExpFloat64() / rate 174 } 175 return samples 176 } 177 178 func testExponentialDistribution(t *testing.T, nsamples int, rate float64, seed int64) { 179 //fmt.Printf("testing nsamples=%v rate=%v seed=%v\n", nsamples, rate, seed); 180 181 mean := 1 / rate 182 stddev := mean 183 184 samples := generateExponentialSamples(nsamples, rate, seed) 185 errorScale := max(1.0, 1/rate) // Error scales with the inverse of the rate 186 expected := &statsResults{mean, stddev, 0.10 * errorScale, 0.20 * errorScale} 187 188 // Make sure that the entire set matches the expected distribution. 189 checkSampleDistribution(t, samples, expected) 190 191 // Make sure that each half of the set matches the expected distribution. 192 checkSampleSliceDistributions(t, samples, 2, expected) 193 194 // Make sure that each 7th of the set matches the expected distribution. 195 checkSampleSliceDistributions(t, samples, 7, expected) 196 } 197 198 // Actual tests 199 200 func TestStandardExponentialValues(t *testing.T) { 201 for _, seed := range testSeeds { 202 testExponentialDistribution(t, numTestSamples, 1, seed) 203 } 204 } 205 206 func TestNonStandardExponentialValues(t *testing.T) { 207 for rate := 0.05; rate < 10; rate *= 2 { 208 for _, seed := range testSeeds { 209 testExponentialDistribution(t, numTestSamples, rate, seed) 210 if testing.Short() { 211 break 212 } 213 } 214 } 215 } 216 217 // 218 // Table generation tests 219 // 220 221 func initNorm() (testKn []uint64, testWn, testFn []float64) { 222 const ( 223 m1 = 1 << 52 224 vn = 0.00492867323399 225 ) 226 var ( 227 dn = rn 228 tn = dn 229 ) 230 231 testKn = make([]uint64, 256) 232 testWn = make([]float64, 256) 233 testFn = make([]float64, 256) 234 235 q := vn / math.Exp(-0.5*dn*dn) 236 testKn[0] = uint64((dn / q) * m1) 237 testKn[1] = 0 238 testWn[0] = q / m1 239 testWn[255] = dn / m1 240 testFn[0] = 1.0 241 testFn[255] = math.Exp(-0.5 * dn * dn) 242 for i := 254; i >= 1; i-- { 243 dn = math.Sqrt(-2.0 * math.Log(vn/dn+math.Exp(-0.5*dn*dn))) 244 testKn[i+1] = uint64((dn / tn) * m1) 245 tn = dn 246 testFn[i] = math.Exp(-0.5 * dn * dn) 247 testWn[i] = dn / m1 248 } 249 return 250 } 251 252 func initExp() (testKe []uint64, testWe, testFe []float64) { 253 const ( 254 m2 = 1 << 53 255 ve = 0.0039496598225815571993 256 ) 257 var ( 258 de = re 259 te = de 260 ) 261 262 testKe = make([]uint64, 256) 263 testWe = make([]float64, 256) 264 testFe = make([]float64, 256) 265 266 q := ve / math.Exp(-de) 267 testKe[0] = uint64((de / q) * m2) 268 testKe[1] = 0 269 testWe[0] = q / m2 270 testWe[255] = de / m2 271 testFe[0] = 1.0 272 testFe[255] = math.Exp(-de) 273 for i := 254; i >= 1; i-- { 274 de = -math.Log(ve/de + math.Exp(-de)) 275 testKe[i+1] = uint64((de / te) * m2) 276 te = de 277 testFe[i] = math.Exp(-de) 278 testWe[i] = de / m2 279 } 280 return 281 } 282 283 // compareUint64Slices returns the first index where the two slices 284 // disagree, or <0 if the lengths are the same and all elements 285 // are close to each other. 286 func compareUint64Slices(s1, s2 []uint64) int { 287 if len(s1) != len(s2) { 288 if len(s1) > len(s2) { 289 return len(s2) + 1 290 } 291 return len(s1) + 1 292 } 293 for i := range s1 { 294 if !nearEqual(float64(s1[i]), float64(s2[i]), 0, 1e-12) { 295 return i 296 } 297 } 298 return -1 299 } 300 301 // compareFloat64Slices returns the first index where the two slices 302 // disagree, or <0 if the lengths are the same and all elements 303 // are close to each other. 304 func compareFloat64Slices(s1, s2 []float64) int { 305 if len(s1) != len(s2) { 306 if len(s1) > len(s2) { 307 return len(s2) + 1 308 } 309 return len(s1) + 1 310 } 311 for i := range s1 { 312 if !nearEqual(s1[i], s2[i], 0, 1e-12) { 313 return i 314 } 315 } 316 return -1 317 } 318 319 func TestNormTables(t *testing.T) { 320 testKn, testWn, testFn := initNorm() 321 if *printtables { 322 fmt.Printf("var kn = %#v\n", testKn) 323 fmt.Printf("var wn = %#v\n", testWn) 324 fmt.Printf("var fn = %#v\n", testFn) 325 return 326 } 327 328 if i := compareUint64Slices(kn[0:], testKn); i >= 0 { 329 t.Errorf("kn disagrees at index %v; %v != %v", i, kn[i], testKn[i]) 330 } 331 if i := compareFloat64Slices(wn[0:], testWn); i >= 0 { 332 t.Errorf("wn disagrees at index %v; %v != %v", i, wn[i], testWn[i]) 333 } 334 if i := compareFloat64Slices(fn[0:], testFn); i >= 0 { 335 t.Errorf("fn disagrees at index %v; %v != %v", i, fn[i], testFn[i]) 336 } 337 } 338 339 func TestExpTables(t *testing.T) { 340 testKe, testWe, testFe := initExp() 341 if *printtables { 342 fmt.Printf("var ke = %#v\n", testKe) 343 fmt.Printf("var we = %#v\n", testWe) 344 fmt.Printf("var fe = %#v\n", testFe) 345 return 346 } 347 348 if i := compareUint64Slices(ke[0:], testKe); i >= 0 { 349 t.Errorf("ke disagrees at index %v; %v != %v", i, ke[i], testKe[i]) 350 } 351 if i := compareFloat64Slices(we[0:], testWe); i >= 0 { 352 t.Errorf("we disagrees at index %v; %v != %v", i, we[i], testWe[i]) 353 } 354 if i := compareFloat64Slices(fe[0:], testFe); i >= 0 { 355 t.Errorf("fe disagrees at index %v; %v != %v", i, fe[i], testFe[i]) 356 } 357 } 358 359 func hasSlowFloatingPoint() bool { 360 switch runtime.GOARCH { 361 case "arm": 362 return os.Getenv("GOARM") == "5" 363 case "mips", "mipsle", "mips64", "mips64le": 364 // Be conservative and assume that all mips boards 365 // have emulated floating point. 366 // TODO: detect what it actually has. 367 return true 368 } 369 return false 370 } 371 372 func TestFloat32_6721(t *testing.T) { 373 // For issue 6721, the problem came after 7533753 calls, so check 10e6. 374 num := int(10e6) 375 // But do the full amount only on builders (not locally). 376 // But ARM5 floating point emulation is slow (Issue 10749), so 377 // do less for that builder: 378 if testing.Short() && hasSlowFloatingPoint() { 379 num /= 100 // 1.72 seconds instead of 172 seconds 380 } 381 382 r := New(1) 383 for ct := 0; ct < num; ct++ { 384 f := r.Float32() 385 if f >= 1 { 386 t.Fatal("Float32() should be in range [0,1). ct:", ct, "f:", f) 387 } 388 } 389 } 390 391 func testReadUniformity(t *testing.T, n int, seed int64) { 392 r := New(uint64(seed)) 393 buf := make([]byte, n) 394 nRead, err := r.Read(buf) 395 if err != nil { 396 t.Errorf("Read err %v", err) 397 } 398 if nRead != n { 399 t.Errorf("Read returned unexpected n; %d != %d", nRead, n) 400 } 401 402 // Expect a uniform distribution of byte values, which lie in [0, 255]. 403 var ( 404 mean = 255.0 / 2 405 stddev = 256.0 / math.Sqrt(12.0) 406 errorScale = stddev / math.Sqrt(float64(n)) 407 ) 408 409 expected := &statsResults{mean, stddev, 0.10 * errorScale, 0.08 * errorScale} 410 411 // Cast bytes as floats to use the common distribution-validity checks. 412 samples := make([]float64, n) 413 for i, val := range buf { 414 samples[i] = float64(val) 415 } 416 // Make sure that the entire set matches the expected distribution. 417 checkSampleDistribution(t, samples, expected) 418 } 419 420 func TestReadUniformity(t *testing.T) { 421 testBufferSizes := []int{ 422 2, 4, 7, 64, 1024, 1 << 16, 1 << 20, 423 } 424 for _, seed := range testSeeds { 425 for _, n := range testBufferSizes { 426 testReadUniformity(t, n, seed) 427 } 428 } 429 } 430 431 func TestReadEmpty(t *testing.T) { 432 r := New(1) 433 buf := make([]byte, 0) 434 n, err := r.Read(buf) 435 if err != nil { 436 t.Errorf("Read err into empty buffer; %v", err) 437 } 438 if n != 0 { 439 t.Errorf("Read into empty buffer returned unexpected n of %d", n) 440 } 441 } 442 443 func TestReadByOneByte(t *testing.T) { 444 r := New(1) 445 b1 := make([]byte, 100) 446 _, err := io.ReadFull(iotest.OneByteReader(r), b1) 447 if err != nil { 448 t.Errorf("read by one byte: %v", err) 449 } 450 r = New(1) 451 b2 := make([]byte, 100) 452 _, err = r.Read(b2) 453 if err != nil { 454 t.Errorf("read: %v", err) 455 } 456 if !bytes.Equal(b1, b2) { 457 t.Errorf("read by one byte vs single read:\n%x\n%x", b1, b2) 458 } 459 } 460 461 func TestReadSeedReset(t *testing.T) { 462 r := New(42) 463 b1 := make([]byte, 128) 464 _, err := r.Read(b1) 465 if err != nil { 466 t.Errorf("read: %v", err) 467 } 468 r.Seed(42) 469 b2 := make([]byte, 128) 470 _, err = r.Read(b2) 471 if err != nil { 472 t.Errorf("read: %v", err) 473 } 474 if !bytes.Equal(b1, b2) { 475 t.Errorf("mismatch after re-seed:\n%x\n%x", b1, b2) 476 } 477 } 478 479 func TestShuffleSmall(t *testing.T) { 480 // Check that Shuffle allows n=0 and n=1, but that swap is never called for them. 481 r := New(1) 482 for n := 0; n <= 1; n++ { 483 r.Shuffle(n, func(i, j int) { t.Fatalf("swap called, n=%d i=%d j=%d", n, i, j) }) 484 } 485 } 486 487 // encodePerm converts from a permuted slice of length n, such as Perm generates, to an int in [0, n!). 488 // See https://en.wikipedia.org/wiki/Lehmer_code. 489 // encodePerm modifies the input slice. 490 func encodePerm(s []int) int { 491 // Convert to Lehmer code. 492 for i, x := range s { 493 r := s[i+1:] 494 for j, y := range r { 495 if y > x { 496 r[j]-- 497 } 498 } 499 } 500 // Convert to int in [0, n!). 501 m := 0 502 fact := 1 503 for i := len(s) - 1; i >= 0; i-- { 504 m += s[i] * fact 505 fact *= len(s) - i 506 } 507 return m 508 } 509 510 // TestUniformFactorial tests several ways of generating a uniform value in [0, n!). 511 func TestUniformFactorial(t *testing.T) { 512 r := New(uint64(testSeeds[0])) 513 top := 6 514 if testing.Short() { 515 top = 3 516 } 517 for n := 3; n <= top; n++ { 518 t.Run(fmt.Sprintf("n=%d", n), func(t *testing.T) { 519 // Calculate n!. 520 nfact := 1 521 for i := 2; i <= n; i++ { 522 nfact *= i 523 } 524 525 // Test a few different ways to generate a uniform distribution. 526 p := make([]int, n) // re-usable slice for Shuffle generator 527 tests := [...]struct { 528 name string 529 fn func() int 530 }{ 531 {name: "Int31n", fn: func() int { return int(r.Int31n(int32(nfact))) }}, 532 {name: "int31n", fn: func() int { return int(Int31nForTest(r, int32(nfact))) }}, 533 {name: "Perm", fn: func() int { return encodePerm(r.Perm(n)) }}, 534 {name: "Shuffle", fn: func() int { 535 // Generate permutation using Shuffle. 536 for i := range p { 537 p[i] = i 538 } 539 r.Shuffle(n, func(i, j int) { p[i], p[j] = p[j], p[i] }) 540 return encodePerm(p) 541 }}, 542 {name: "ShuffleSliceGeneric", fn: func() int { 543 if ShuffleSliceGeneric == nil { 544 return int(r.Uint32n(uint32(nfact))) 545 } 546 // Generate permutation using generic Shuffle. 547 for i := range p { 548 p[i] = i 549 } 550 ShuffleSliceGeneric(r, p) 551 return encodePerm(p) 552 }}, 553 } 554 555 for _, test := range tests { 556 t.Run(test.name, func(t *testing.T) { 557 // Gather chi-squared values and check that they follow 558 // the expected normal distribution given n!-1 degrees of freedom. 559 // See https://en.wikipedia.org/wiki/Pearson%27s_chi-squared_test and 560 // https://www.johndcook.com/Beautiful_Testing_ch10.pdf. 561 nsamples := 10 * nfact 562 if nsamples < 200 { 563 nsamples = 200 564 } 565 samples := make([]float64, nsamples) 566 for i := range samples { 567 // Generate some uniformly distributed values and count their occurrences. 568 const iters = 1000 569 counts := make([]int, nfact) 570 for i := 0; i < iters; i++ { 571 counts[test.fn()]++ 572 } 573 // Calculate chi-squared and add to samples. 574 want := iters / float64(nfact) 575 var χ2 float64 576 for _, have := range counts { 577 err := float64(have) - want 578 χ2 += err * err 579 } 580 χ2 /= want 581 samples[i] = χ2 582 } 583 584 // Check that our samples approximate the appropriate normal distribution. 585 dof := float64(nfact - 1) 586 expected := &statsResults{mean: dof, stddev: math.Sqrt(2 * dof)} 587 errorScale := max(1.0, expected.stddev) 588 expected.closeEnough = 0.10 * errorScale 589 expected.maxError = 0.08 // TODO: What is the right value here? See issue 21211. 590 checkSampleDistribution(t, samples, expected) 591 }) 592 } 593 }) 594 } 595 }