github.com/twelsh-aw/go/src@v0.0.0-20230516233729-a56fe86a7c81/math/rand/rand_test.go (about) 1 // Copyright 2009 The Go Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style 3 // license that can be found in the LICENSE file. 4 5 package rand_test 6 7 import ( 8 "bytes" 9 "errors" 10 "fmt" 11 "internal/testenv" 12 "io" 13 "math" 14 . "math/rand" 15 "os" 16 "runtime" 17 "sync" 18 "testing" 19 "testing/iotest" 20 ) 21 22 const ( 23 numTestSamples = 10000 24 ) 25 26 var rn, kn, wn, fn = GetNormalDistributionParameters() 27 var re, ke, we, fe = GetExponentialDistributionParameters() 28 29 type statsResults struct { 30 mean float64 31 stddev float64 32 closeEnough float64 33 maxError float64 34 } 35 36 func max(a, b float64) float64 { 37 if a > b { 38 return a 39 } 40 return b 41 } 42 43 func nearEqual(a, b, closeEnough, maxError float64) bool { 44 absDiff := math.Abs(a - b) 45 if absDiff < closeEnough { // Necessary when one value is zero and one value is close to zero. 46 return true 47 } 48 return absDiff/max(math.Abs(a), math.Abs(b)) < maxError 49 } 50 51 var testSeeds = []int64{1, 1754801282, 1698661970, 1550503961} 52 53 // checkSimilarDistribution returns success if the mean and stddev of the 54 // two statsResults are similar. 55 func (this *statsResults) checkSimilarDistribution(expected *statsResults) error { 56 if !nearEqual(this.mean, expected.mean, expected.closeEnough, expected.maxError) { 57 s := fmt.Sprintf("mean %v != %v (allowed error %v, %v)", this.mean, expected.mean, expected.closeEnough, expected.maxError) 58 fmt.Println(s) 59 return errors.New(s) 60 } 61 if !nearEqual(this.stddev, expected.stddev, expected.closeEnough, expected.maxError) { 62 s := fmt.Sprintf("stddev %v != %v (allowed error %v, %v)", this.stddev, expected.stddev, expected.closeEnough, expected.maxError) 63 fmt.Println(s) 64 return errors.New(s) 65 } 66 return nil 67 } 68 69 func getStatsResults(samples []float64) *statsResults { 70 res := new(statsResults) 71 var sum, squaresum float64 72 for _, s := range samples { 73 sum += s 74 squaresum += s * s 75 } 76 res.mean = sum / float64(len(samples)) 77 res.stddev = math.Sqrt(squaresum/float64(len(samples)) - res.mean*res.mean) 78 return res 79 } 80 81 func checkSampleDistribution(t *testing.T, samples []float64, expected *statsResults) { 82 t.Helper() 83 actual := getStatsResults(samples) 84 err := actual.checkSimilarDistribution(expected) 85 if err != nil { 86 t.Errorf(err.Error()) 87 } 88 } 89 90 func checkSampleSliceDistributions(t *testing.T, samples []float64, nslices int, expected *statsResults) { 91 t.Helper() 92 chunk := len(samples) / nslices 93 for i := 0; i < nslices; i++ { 94 low := i * chunk 95 var high int 96 if i == nslices-1 { 97 high = len(samples) - 1 98 } else { 99 high = (i + 1) * chunk 100 } 101 checkSampleDistribution(t, samples[low:high], expected) 102 } 103 } 104 105 // 106 // Normal distribution tests 107 // 108 109 func generateNormalSamples(nsamples int, mean, stddev float64, seed int64) []float64 { 110 r := New(NewSource(seed)) 111 samples := make([]float64, nsamples) 112 for i := range samples { 113 samples[i] = r.NormFloat64()*stddev + mean 114 } 115 return samples 116 } 117 118 func testNormalDistribution(t *testing.T, nsamples int, mean, stddev float64, seed int64) { 119 //fmt.Printf("testing nsamples=%v mean=%v stddev=%v seed=%v\n", nsamples, mean, stddev, seed); 120 121 samples := generateNormalSamples(nsamples, mean, stddev, seed) 122 errorScale := max(1.0, stddev) // Error scales with stddev 123 expected := &statsResults{mean, stddev, 0.10 * errorScale, 0.08 * errorScale} 124 125 // Make sure that the entire set matches the expected distribution. 126 checkSampleDistribution(t, samples, expected) 127 128 // Make sure that each half of the set matches the expected distribution. 129 checkSampleSliceDistributions(t, samples, 2, expected) 130 131 // Make sure that each 7th of the set matches the expected distribution. 132 checkSampleSliceDistributions(t, samples, 7, expected) 133 } 134 135 // Actual tests 136 137 func TestStandardNormalValues(t *testing.T) { 138 for _, seed := range testSeeds { 139 testNormalDistribution(t, numTestSamples, 0, 1, seed) 140 } 141 } 142 143 func TestNonStandardNormalValues(t *testing.T) { 144 sdmax := 1000.0 145 mmax := 1000.0 146 if testing.Short() { 147 sdmax = 5 148 mmax = 5 149 } 150 for sd := 0.5; sd < sdmax; sd *= 2 { 151 for m := 0.5; m < mmax; m *= 2 { 152 for _, seed := range testSeeds { 153 testNormalDistribution(t, numTestSamples, m, sd, seed) 154 if testing.Short() { 155 break 156 } 157 } 158 } 159 } 160 } 161 162 // 163 // Exponential distribution tests 164 // 165 166 func generateExponentialSamples(nsamples int, rate float64, seed int64) []float64 { 167 r := New(NewSource(seed)) 168 samples := make([]float64, nsamples) 169 for i := range samples { 170 samples[i] = r.ExpFloat64() / rate 171 } 172 return samples 173 } 174 175 func testExponentialDistribution(t *testing.T, nsamples int, rate float64, seed int64) { 176 //fmt.Printf("testing nsamples=%v rate=%v seed=%v\n", nsamples, rate, seed); 177 178 mean := 1 / rate 179 stddev := mean 180 181 samples := generateExponentialSamples(nsamples, rate, seed) 182 errorScale := max(1.0, 1/rate) // Error scales with the inverse of the rate 183 expected := &statsResults{mean, stddev, 0.10 * errorScale, 0.20 * errorScale} 184 185 // Make sure that the entire set matches the expected distribution. 186 checkSampleDistribution(t, samples, expected) 187 188 // Make sure that each half of the set matches the expected distribution. 189 checkSampleSliceDistributions(t, samples, 2, expected) 190 191 // Make sure that each 7th of the set matches the expected distribution. 192 checkSampleSliceDistributions(t, samples, 7, expected) 193 } 194 195 // Actual tests 196 197 func TestStandardExponentialValues(t *testing.T) { 198 for _, seed := range testSeeds { 199 testExponentialDistribution(t, numTestSamples, 1, seed) 200 } 201 } 202 203 func TestNonStandardExponentialValues(t *testing.T) { 204 for rate := 0.05; rate < 10; rate *= 2 { 205 for _, seed := range testSeeds { 206 testExponentialDistribution(t, numTestSamples, rate, seed) 207 if testing.Short() { 208 break 209 } 210 } 211 } 212 } 213 214 // 215 // Table generation tests 216 // 217 218 func initNorm() (testKn []uint32, testWn, testFn []float32) { 219 const m1 = 1 << 31 220 var ( 221 dn float64 = rn 222 tn = dn 223 vn float64 = 9.91256303526217e-3 224 ) 225 226 testKn = make([]uint32, 128) 227 testWn = make([]float32, 128) 228 testFn = make([]float32, 128) 229 230 q := vn / math.Exp(-0.5*dn*dn) 231 testKn[0] = uint32((dn / q) * m1) 232 testKn[1] = 0 233 testWn[0] = float32(q / m1) 234 testWn[127] = float32(dn / m1) 235 testFn[0] = 1.0 236 testFn[127] = float32(math.Exp(-0.5 * dn * dn)) 237 for i := 126; i >= 1; i-- { 238 dn = math.Sqrt(-2.0 * math.Log(vn/dn+math.Exp(-0.5*dn*dn))) 239 testKn[i+1] = uint32((dn / tn) * m1) 240 tn = dn 241 testFn[i] = float32(math.Exp(-0.5 * dn * dn)) 242 testWn[i] = float32(dn / m1) 243 } 244 return 245 } 246 247 func initExp() (testKe []uint32, testWe, testFe []float32) { 248 const m2 = 1 << 32 249 var ( 250 de float64 = re 251 te = de 252 ve float64 = 3.9496598225815571993e-3 253 ) 254 255 testKe = make([]uint32, 256) 256 testWe = make([]float32, 256) 257 testFe = make([]float32, 256) 258 259 q := ve / math.Exp(-de) 260 testKe[0] = uint32((de / q) * m2) 261 testKe[1] = 0 262 testWe[0] = float32(q / m2) 263 testWe[255] = float32(de / m2) 264 testFe[0] = 1.0 265 testFe[255] = float32(math.Exp(-de)) 266 for i := 254; i >= 1; i-- { 267 de = -math.Log(ve/de + math.Exp(-de)) 268 testKe[i+1] = uint32((de / te) * m2) 269 te = de 270 testFe[i] = float32(math.Exp(-de)) 271 testWe[i] = float32(de / m2) 272 } 273 return 274 } 275 276 // compareUint32Slices returns the first index where the two slices 277 // disagree, or <0 if the lengths are the same and all elements 278 // are identical. 279 func compareUint32Slices(s1, s2 []uint32) int { 280 if len(s1) != len(s2) { 281 if len(s1) > len(s2) { 282 return len(s2) + 1 283 } 284 return len(s1) + 1 285 } 286 for i := range s1 { 287 if s1[i] != s2[i] { 288 return i 289 } 290 } 291 return -1 292 } 293 294 // compareFloat32Slices returns the first index where the two slices 295 // disagree, or <0 if the lengths are the same and all elements 296 // are identical. 297 func compareFloat32Slices(s1, s2 []float32) int { 298 if len(s1) != len(s2) { 299 if len(s1) > len(s2) { 300 return len(s2) + 1 301 } 302 return len(s1) + 1 303 } 304 for i := range s1 { 305 if !nearEqual(float64(s1[i]), float64(s2[i]), 0, 1e-7) { 306 return i 307 } 308 } 309 return -1 310 } 311 312 func TestNormTables(t *testing.T) { 313 testKn, testWn, testFn := initNorm() 314 if i := compareUint32Slices(kn[0:], testKn); i >= 0 { 315 t.Errorf("kn disagrees at index %v; %v != %v", i, kn[i], testKn[i]) 316 } 317 if i := compareFloat32Slices(wn[0:], testWn); i >= 0 { 318 t.Errorf("wn disagrees at index %v; %v != %v", i, wn[i], testWn[i]) 319 } 320 if i := compareFloat32Slices(fn[0:], testFn); i >= 0 { 321 t.Errorf("fn disagrees at index %v; %v != %v", i, fn[i], testFn[i]) 322 } 323 } 324 325 func TestExpTables(t *testing.T) { 326 testKe, testWe, testFe := initExp() 327 if i := compareUint32Slices(ke[0:], testKe); i >= 0 { 328 t.Errorf("ke disagrees at index %v; %v != %v", i, ke[i], testKe[i]) 329 } 330 if i := compareFloat32Slices(we[0:], testWe); i >= 0 { 331 t.Errorf("we disagrees at index %v; %v != %v", i, we[i], testWe[i]) 332 } 333 if i := compareFloat32Slices(fe[0:], testFe); i >= 0 { 334 t.Errorf("fe disagrees at index %v; %v != %v", i, fe[i], testFe[i]) 335 } 336 } 337 338 func hasSlowFloatingPoint() bool { 339 switch runtime.GOARCH { 340 case "arm": 341 return os.Getenv("GOARM") == "5" 342 case "mips", "mipsle", "mips64", "mips64le": 343 // Be conservative and assume that all mips boards 344 // have emulated floating point. 345 // TODO: detect what it actually has. 346 return true 347 } 348 return false 349 } 350 351 func TestFloat32(t *testing.T) { 352 // For issue 6721, the problem came after 7533753 calls, so check 10e6. 353 num := int(10e6) 354 // But do the full amount only on builders (not locally). 355 // But ARM5 floating point emulation is slow (Issue 10749), so 356 // do less for that builder: 357 if testing.Short() && (testenv.Builder() == "" || hasSlowFloatingPoint()) { 358 num /= 100 // 1.72 seconds instead of 172 seconds 359 } 360 361 r := New(NewSource(1)) 362 for ct := 0; ct < num; ct++ { 363 f := r.Float32() 364 if f >= 1 { 365 t.Fatal("Float32() should be in range [0,1). ct:", ct, "f:", f) 366 } 367 } 368 } 369 370 func testReadUniformity(t *testing.T, n int, seed int64) { 371 r := New(NewSource(seed)) 372 buf := make([]byte, n) 373 nRead, err := r.Read(buf) 374 if err != nil { 375 t.Errorf("Read err %v", err) 376 } 377 if nRead != n { 378 t.Errorf("Read returned unexpected n; %d != %d", nRead, n) 379 } 380 381 // Expect a uniform distribution of byte values, which lie in [0, 255]. 382 var ( 383 mean = 255.0 / 2 384 stddev = 256.0 / math.Sqrt(12.0) 385 errorScale = stddev / math.Sqrt(float64(n)) 386 ) 387 388 expected := &statsResults{mean, stddev, 0.10 * errorScale, 0.08 * errorScale} 389 390 // Cast bytes as floats to use the common distribution-validity checks. 391 samples := make([]float64, n) 392 for i, val := range buf { 393 samples[i] = float64(val) 394 } 395 // Make sure that the entire set matches the expected distribution. 396 checkSampleDistribution(t, samples, expected) 397 } 398 399 func TestReadUniformity(t *testing.T) { 400 testBufferSizes := []int{ 401 2, 4, 7, 64, 1024, 1 << 16, 1 << 20, 402 } 403 for _, seed := range testSeeds { 404 for _, n := range testBufferSizes { 405 testReadUniformity(t, n, seed) 406 } 407 } 408 } 409 410 func TestReadEmpty(t *testing.T) { 411 r := New(NewSource(1)) 412 buf := make([]byte, 0) 413 n, err := r.Read(buf) 414 if err != nil { 415 t.Errorf("Read err into empty buffer; %v", err) 416 } 417 if n != 0 { 418 t.Errorf("Read into empty buffer returned unexpected n of %d", n) 419 } 420 } 421 422 func TestReadByOneByte(t *testing.T) { 423 r := New(NewSource(1)) 424 b1 := make([]byte, 100) 425 _, err := io.ReadFull(iotest.OneByteReader(r), b1) 426 if err != nil { 427 t.Errorf("read by one byte: %v", err) 428 } 429 r = New(NewSource(1)) 430 b2 := make([]byte, 100) 431 _, err = r.Read(b2) 432 if err != nil { 433 t.Errorf("read: %v", err) 434 } 435 if !bytes.Equal(b1, b2) { 436 t.Errorf("read by one byte vs single read:\n%x\n%x", b1, b2) 437 } 438 } 439 440 func TestReadSeedReset(t *testing.T) { 441 r := New(NewSource(42)) 442 b1 := make([]byte, 128) 443 _, err := r.Read(b1) 444 if err != nil { 445 t.Errorf("read: %v", err) 446 } 447 r.Seed(42) 448 b2 := make([]byte, 128) 449 _, err = r.Read(b2) 450 if err != nil { 451 t.Errorf("read: %v", err) 452 } 453 if !bytes.Equal(b1, b2) { 454 t.Errorf("mismatch after re-seed:\n%x\n%x", b1, b2) 455 } 456 } 457 458 func TestShuffleSmall(t *testing.T) { 459 // Check that Shuffle allows n=0 and n=1, but that swap is never called for them. 460 r := New(NewSource(1)) 461 for n := 0; n <= 1; n++ { 462 r.Shuffle(n, func(i, j int) { t.Fatalf("swap called, n=%d i=%d j=%d", n, i, j) }) 463 } 464 } 465 466 // encodePerm converts from a permuted slice of length n, such as Perm generates, to an int in [0, n!). 467 // See https://en.wikipedia.org/wiki/Lehmer_code. 468 // encodePerm modifies the input slice. 469 func encodePerm(s []int) int { 470 // Convert to Lehmer code. 471 for i, x := range s { 472 r := s[i+1:] 473 for j, y := range r { 474 if y > x { 475 r[j]-- 476 } 477 } 478 } 479 // Convert to int in [0, n!). 480 m := 0 481 fact := 1 482 for i := len(s) - 1; i >= 0; i-- { 483 m += s[i] * fact 484 fact *= len(s) - i 485 } 486 return m 487 } 488 489 // TestUniformFactorial tests several ways of generating a uniform value in [0, n!). 490 func TestUniformFactorial(t *testing.T) { 491 r := New(NewSource(testSeeds[0])) 492 top := 6 493 if testing.Short() { 494 top = 3 495 } 496 for n := 3; n <= top; n++ { 497 t.Run(fmt.Sprintf("n=%d", n), func(t *testing.T) { 498 // Calculate n!. 499 nfact := 1 500 for i := 2; i <= n; i++ { 501 nfact *= i 502 } 503 504 // Test a few different ways to generate a uniform distribution. 505 p := make([]int, n) // re-usable slice for Shuffle generator 506 tests := [...]struct { 507 name string 508 fn func() int 509 }{ 510 {name: "Int31n", fn: func() int { return int(r.Int31n(int32(nfact))) }}, 511 {name: "int31n", fn: func() int { return int(Int31nForTest(r, int32(nfact))) }}, 512 {name: "Perm", fn: func() int { return encodePerm(r.Perm(n)) }}, 513 {name: "Shuffle", fn: func() int { 514 // Generate permutation using Shuffle. 515 for i := range p { 516 p[i] = i 517 } 518 r.Shuffle(n, func(i, j int) { p[i], p[j] = p[j], p[i] }) 519 return encodePerm(p) 520 }}, 521 } 522 523 for _, test := range tests { 524 t.Run(test.name, func(t *testing.T) { 525 // Gather chi-squared values and check that they follow 526 // the expected normal distribution given n!-1 degrees of freedom. 527 // See https://en.wikipedia.org/wiki/Pearson%27s_chi-squared_test and 528 // https://www.johndcook.com/Beautiful_Testing_ch10.pdf. 529 nsamples := 10 * nfact 530 if nsamples < 200 { 531 nsamples = 200 532 } 533 samples := make([]float64, nsamples) 534 for i := range samples { 535 // Generate some uniformly distributed values and count their occurrences. 536 const iters = 1000 537 counts := make([]int, nfact) 538 for i := 0; i < iters; i++ { 539 counts[test.fn()]++ 540 } 541 // Calculate chi-squared and add to samples. 542 want := iters / float64(nfact) 543 var χ2 float64 544 for _, have := range counts { 545 err := float64(have) - want 546 χ2 += err * err 547 } 548 χ2 /= want 549 samples[i] = χ2 550 } 551 552 // Check that our samples approximate the appropriate normal distribution. 553 dof := float64(nfact - 1) 554 expected := &statsResults{mean: dof, stddev: math.Sqrt(2 * dof)} 555 errorScale := max(1.0, expected.stddev) 556 expected.closeEnough = 0.10 * errorScale 557 expected.maxError = 0.08 // TODO: What is the right value here? See issue 21211. 558 checkSampleDistribution(t, samples, expected) 559 }) 560 } 561 }) 562 } 563 } 564 565 // Benchmarks 566 567 func BenchmarkInt63Threadsafe(b *testing.B) { 568 for n := b.N; n > 0; n-- { 569 Int63() 570 } 571 } 572 573 func BenchmarkInt63ThreadsafeParallel(b *testing.B) { 574 b.RunParallel(func(pb *testing.PB) { 575 for pb.Next() { 576 Int63() 577 } 578 }) 579 } 580 581 func BenchmarkInt63Unthreadsafe(b *testing.B) { 582 r := New(NewSource(1)) 583 for n := b.N; n > 0; n-- { 584 r.Int63() 585 } 586 } 587 588 func BenchmarkIntn1000(b *testing.B) { 589 r := New(NewSource(1)) 590 for n := b.N; n > 0; n-- { 591 r.Intn(1000) 592 } 593 } 594 595 func BenchmarkInt63n1000(b *testing.B) { 596 r := New(NewSource(1)) 597 for n := b.N; n > 0; n-- { 598 r.Int63n(1000) 599 } 600 } 601 602 func BenchmarkInt31n1000(b *testing.B) { 603 r := New(NewSource(1)) 604 for n := b.N; n > 0; n-- { 605 r.Int31n(1000) 606 } 607 } 608 609 func BenchmarkFloat32(b *testing.B) { 610 r := New(NewSource(1)) 611 for n := b.N; n > 0; n-- { 612 r.Float32() 613 } 614 } 615 616 func BenchmarkFloat64(b *testing.B) { 617 r := New(NewSource(1)) 618 for n := b.N; n > 0; n-- { 619 r.Float64() 620 } 621 } 622 623 func BenchmarkPerm3(b *testing.B) { 624 r := New(NewSource(1)) 625 for n := b.N; n > 0; n-- { 626 r.Perm(3) 627 } 628 } 629 630 func BenchmarkPerm30(b *testing.B) { 631 r := New(NewSource(1)) 632 for n := b.N; n > 0; n-- { 633 r.Perm(30) 634 } 635 } 636 637 func BenchmarkPerm30ViaShuffle(b *testing.B) { 638 r := New(NewSource(1)) 639 for n := b.N; n > 0; n-- { 640 p := make([]int, 30) 641 for i := range p { 642 p[i] = i 643 } 644 r.Shuffle(30, func(i, j int) { p[i], p[j] = p[j], p[i] }) 645 } 646 } 647 648 // BenchmarkShuffleOverhead uses a minimal swap function 649 // to measure just the shuffling overhead. 650 func BenchmarkShuffleOverhead(b *testing.B) { 651 r := New(NewSource(1)) 652 for n := b.N; n > 0; n-- { 653 r.Shuffle(52, func(i, j int) { 654 if i < 0 || i >= 52 || j < 0 || j >= 52 { 655 b.Fatalf("bad swap(%d, %d)", i, j) 656 } 657 }) 658 } 659 } 660 661 func BenchmarkRead3(b *testing.B) { 662 r := New(NewSource(1)) 663 buf := make([]byte, 3) 664 b.ResetTimer() 665 for n := b.N; n > 0; n-- { 666 r.Read(buf) 667 } 668 } 669 670 func BenchmarkRead64(b *testing.B) { 671 r := New(NewSource(1)) 672 buf := make([]byte, 64) 673 b.ResetTimer() 674 for n := b.N; n > 0; n-- { 675 r.Read(buf) 676 } 677 } 678 679 func BenchmarkRead1000(b *testing.B) { 680 r := New(NewSource(1)) 681 buf := make([]byte, 1000) 682 b.ResetTimer() 683 for n := b.N; n > 0; n-- { 684 r.Read(buf) 685 } 686 } 687 688 func BenchmarkConcurrent(b *testing.B) { 689 const goroutines = 4 690 var wg sync.WaitGroup 691 wg.Add(goroutines) 692 for i := 0; i < goroutines; i++ { 693 go func() { 694 defer wg.Done() 695 for n := b.N; n > 0; n-- { 696 Int63() 697 } 698 }() 699 } 700 wg.Wait() 701 }