github.com/code-reading/golang@v0.0.0-20220303082512-ba5bc0e589a3/go/src/math/rand/rand_test.go (about) 1 // Copyright 2009 The Go Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style 3 // license that can be found in the LICENSE file. 4 5 package rand_test 6 7 import ( 8 "bytes" 9 "errors" 10 "fmt" 11 "internal/testenv" 12 "io" 13 "math" 14 . "math/rand" 15 "os" 16 "runtime" 17 "testing" 18 "testing/iotest" 19 ) 20 21 const ( 22 numTestSamples = 10000 23 ) 24 25 var rn, kn, wn, fn = GetNormalDistributionParameters() 26 var re, ke, we, fe = GetExponentialDistributionParameters() 27 28 type statsResults struct { 29 mean float64 30 stddev float64 31 closeEnough float64 32 maxError float64 33 } 34 35 func max(a, b float64) float64 { 36 if a > b { 37 return a 38 } 39 return b 40 } 41 42 func nearEqual(a, b, closeEnough, maxError float64) bool { 43 absDiff := math.Abs(a - b) 44 if absDiff < closeEnough { // Necessary when one value is zero and one value is close to zero. 45 return true 46 } 47 return absDiff/max(math.Abs(a), math.Abs(b)) < maxError 48 } 49 50 var testSeeds = []int64{1, 1754801282, 1698661970, 1550503961} 51 52 // checkSimilarDistribution returns success if the mean and stddev of the 53 // two statsResults are similar. 54 func (this *statsResults) checkSimilarDistribution(expected *statsResults) error { 55 if !nearEqual(this.mean, expected.mean, expected.closeEnough, expected.maxError) { 56 s := fmt.Sprintf("mean %v != %v (allowed error %v, %v)", this.mean, expected.mean, expected.closeEnough, expected.maxError) 57 fmt.Println(s) 58 return errors.New(s) 59 } 60 if !nearEqual(this.stddev, expected.stddev, expected.closeEnough, expected.maxError) { 61 s := fmt.Sprintf("stddev %v != %v (allowed error %v, %v)", this.stddev, expected.stddev, expected.closeEnough, expected.maxError) 62 fmt.Println(s) 63 return errors.New(s) 64 } 65 return nil 66 } 67 68 func getStatsResults(samples []float64) *statsResults { 69 res := new(statsResults) 70 var sum, squaresum float64 71 for _, s := range samples { 72 sum += s 73 squaresum += s * s 74 } 75 res.mean = sum / float64(len(samples)) 76 res.stddev = math.Sqrt(squaresum/float64(len(samples)) - res.mean*res.mean) 77 return res 78 } 79 80 func checkSampleDistribution(t *testing.T, samples []float64, expected *statsResults) { 81 t.Helper() 82 actual := getStatsResults(samples) 83 err := actual.checkSimilarDistribution(expected) 84 if err != nil { 85 t.Errorf(err.Error()) 86 } 87 } 88 89 func checkSampleSliceDistributions(t *testing.T, samples []float64, nslices int, expected *statsResults) { 90 t.Helper() 91 chunk := len(samples) / nslices 92 for i := 0; i < nslices; i++ { 93 low := i * chunk 94 var high int 95 if i == nslices-1 { 96 high = len(samples) - 1 97 } else { 98 high = (i + 1) * chunk 99 } 100 checkSampleDistribution(t, samples[low:high], expected) 101 } 102 } 103 104 // 105 // Normal distribution tests 106 // 107 108 func generateNormalSamples(nsamples int, mean, stddev float64, seed int64) []float64 { 109 r := New(NewSource(seed)) 110 samples := make([]float64, nsamples) 111 for i := range samples { 112 samples[i] = r.NormFloat64()*stddev + mean 113 } 114 return samples 115 } 116 117 func testNormalDistribution(t *testing.T, nsamples int, mean, stddev float64, seed int64) { 118 //fmt.Printf("testing nsamples=%v mean=%v stddev=%v seed=%v\n", nsamples, mean, stddev, seed); 119 120 samples := generateNormalSamples(nsamples, mean, stddev, seed) 121 errorScale := max(1.0, stddev) // Error scales with stddev 122 expected := &statsResults{mean, stddev, 0.10 * errorScale, 0.08 * errorScale} 123 124 // Make sure that the entire set matches the expected distribution. 125 checkSampleDistribution(t, samples, expected) 126 127 // Make sure that each half of the set matches the expected distribution. 128 checkSampleSliceDistributions(t, samples, 2, expected) 129 130 // Make sure that each 7th of the set matches the expected distribution. 131 checkSampleSliceDistributions(t, samples, 7, expected) 132 } 133 134 // Actual tests 135 136 func TestStandardNormalValues(t *testing.T) { 137 for _, seed := range testSeeds { 138 testNormalDistribution(t, numTestSamples, 0, 1, seed) 139 } 140 } 141 142 func TestNonStandardNormalValues(t *testing.T) { 143 sdmax := 1000.0 144 mmax := 1000.0 145 if testing.Short() { 146 sdmax = 5 147 mmax = 5 148 } 149 for sd := 0.5; sd < sdmax; sd *= 2 { 150 for m := 0.5; m < mmax; m *= 2 { 151 for _, seed := range testSeeds { 152 testNormalDistribution(t, numTestSamples, m, sd, seed) 153 if testing.Short() { 154 break 155 } 156 } 157 } 158 } 159 } 160 161 // 162 // Exponential distribution tests 163 // 164 165 func generateExponentialSamples(nsamples int, rate float64, seed int64) []float64 { 166 r := New(NewSource(seed)) 167 samples := make([]float64, nsamples) 168 for i := range samples { 169 samples[i] = r.ExpFloat64() / rate 170 } 171 return samples 172 } 173 174 func testExponentialDistribution(t *testing.T, nsamples int, rate float64, seed int64) { 175 //fmt.Printf("testing nsamples=%v rate=%v seed=%v\n", nsamples, rate, seed); 176 177 mean := 1 / rate 178 stddev := mean 179 180 samples := generateExponentialSamples(nsamples, rate, seed) 181 errorScale := max(1.0, 1/rate) // Error scales with the inverse of the rate 182 expected := &statsResults{mean, stddev, 0.10 * errorScale, 0.20 * errorScale} 183 184 // Make sure that the entire set matches the expected distribution. 185 checkSampleDistribution(t, samples, expected) 186 187 // Make sure that each half of the set matches the expected distribution. 188 checkSampleSliceDistributions(t, samples, 2, expected) 189 190 // Make sure that each 7th of the set matches the expected distribution. 191 checkSampleSliceDistributions(t, samples, 7, expected) 192 } 193 194 // Actual tests 195 196 func TestStandardExponentialValues(t *testing.T) { 197 for _, seed := range testSeeds { 198 testExponentialDistribution(t, numTestSamples, 1, seed) 199 } 200 } 201 202 func TestNonStandardExponentialValues(t *testing.T) { 203 for rate := 0.05; rate < 10; rate *= 2 { 204 for _, seed := range testSeeds { 205 testExponentialDistribution(t, numTestSamples, rate, seed) 206 if testing.Short() { 207 break 208 } 209 } 210 } 211 } 212 213 // 214 // Table generation tests 215 // 216 217 func initNorm() (testKn []uint32, testWn, testFn []float32) { 218 const m1 = 1 << 31 219 var ( 220 dn float64 = rn 221 tn = dn 222 vn float64 = 9.91256303526217e-3 223 ) 224 225 testKn = make([]uint32, 128) 226 testWn = make([]float32, 128) 227 testFn = make([]float32, 128) 228 229 q := vn / math.Exp(-0.5*dn*dn) 230 testKn[0] = uint32((dn / q) * m1) 231 testKn[1] = 0 232 testWn[0] = float32(q / m1) 233 testWn[127] = float32(dn / m1) 234 testFn[0] = 1.0 235 testFn[127] = float32(math.Exp(-0.5 * dn * dn)) 236 for i := 126; i >= 1; i-- { 237 dn = math.Sqrt(-2.0 * math.Log(vn/dn+math.Exp(-0.5*dn*dn))) 238 testKn[i+1] = uint32((dn / tn) * m1) 239 tn = dn 240 testFn[i] = float32(math.Exp(-0.5 * dn * dn)) 241 testWn[i] = float32(dn / m1) 242 } 243 return 244 } 245 246 func initExp() (testKe []uint32, testWe, testFe []float32) { 247 const m2 = 1 << 32 248 var ( 249 de float64 = re 250 te = de 251 ve float64 = 3.9496598225815571993e-3 252 ) 253 254 testKe = make([]uint32, 256) 255 testWe = make([]float32, 256) 256 testFe = make([]float32, 256) 257 258 q := ve / math.Exp(-de) 259 testKe[0] = uint32((de / q) * m2) 260 testKe[1] = 0 261 testWe[0] = float32(q / m2) 262 testWe[255] = float32(de / m2) 263 testFe[0] = 1.0 264 testFe[255] = float32(math.Exp(-de)) 265 for i := 254; i >= 1; i-- { 266 de = -math.Log(ve/de + math.Exp(-de)) 267 testKe[i+1] = uint32((de / te) * m2) 268 te = de 269 testFe[i] = float32(math.Exp(-de)) 270 testWe[i] = float32(de / m2) 271 } 272 return 273 } 274 275 // compareUint32Slices returns the first index where the two slices 276 // disagree, or <0 if the lengths are the same and all elements 277 // are identical. 278 func compareUint32Slices(s1, s2 []uint32) int { 279 if len(s1) != len(s2) { 280 if len(s1) > len(s2) { 281 return len(s2) + 1 282 } 283 return len(s1) + 1 284 } 285 for i := range s1 { 286 if s1[i] != s2[i] { 287 return i 288 } 289 } 290 return -1 291 } 292 293 // compareFloat32Slices returns the first index where the two slices 294 // disagree, or <0 if the lengths are the same and all elements 295 // are identical. 296 func compareFloat32Slices(s1, s2 []float32) int { 297 if len(s1) != len(s2) { 298 if len(s1) > len(s2) { 299 return len(s2) + 1 300 } 301 return len(s1) + 1 302 } 303 for i := range s1 { 304 if !nearEqual(float64(s1[i]), float64(s2[i]), 0, 1e-7) { 305 return i 306 } 307 } 308 return -1 309 } 310 311 func TestNormTables(t *testing.T) { 312 testKn, testWn, testFn := initNorm() 313 if i := compareUint32Slices(kn[0:], testKn); i >= 0 { 314 t.Errorf("kn disagrees at index %v; %v != %v", i, kn[i], testKn[i]) 315 } 316 if i := compareFloat32Slices(wn[0:], testWn); i >= 0 { 317 t.Errorf("wn disagrees at index %v; %v != %v", i, wn[i], testWn[i]) 318 } 319 if i := compareFloat32Slices(fn[0:], testFn); i >= 0 { 320 t.Errorf("fn disagrees at index %v; %v != %v", i, fn[i], testFn[i]) 321 } 322 } 323 324 func TestExpTables(t *testing.T) { 325 testKe, testWe, testFe := initExp() 326 if i := compareUint32Slices(ke[0:], testKe); i >= 0 { 327 t.Errorf("ke disagrees at index %v; %v != %v", i, ke[i], testKe[i]) 328 } 329 if i := compareFloat32Slices(we[0:], testWe); i >= 0 { 330 t.Errorf("we disagrees at index %v; %v != %v", i, we[i], testWe[i]) 331 } 332 if i := compareFloat32Slices(fe[0:], testFe); i >= 0 { 333 t.Errorf("fe disagrees at index %v; %v != %v", i, fe[i], testFe[i]) 334 } 335 } 336 337 func hasSlowFloatingPoint() bool { 338 switch runtime.GOARCH { 339 case "arm": 340 return os.Getenv("GOARM") == "5" 341 case "mips", "mipsle", "mips64", "mips64le": 342 // Be conservative and assume that all mips boards 343 // have emulated floating point. 344 // TODO: detect what it actually has. 345 return true 346 } 347 return false 348 } 349 350 func TestFloat32(t *testing.T) { 351 // For issue 6721, the problem came after 7533753 calls, so check 10e6. 352 num := int(10e6) 353 // But do the full amount only on builders (not locally). 354 // But ARM5 floating point emulation is slow (Issue 10749), so 355 // do less for that builder: 356 if testing.Short() && (testenv.Builder() == "" || hasSlowFloatingPoint()) { 357 num /= 100 // 1.72 seconds instead of 172 seconds 358 } 359 360 r := New(NewSource(1)) 361 for ct := 0; ct < num; ct++ { 362 f := r.Float32() 363 if f >= 1 { 364 t.Fatal("Float32() should be in range [0,1). ct:", ct, "f:", f) 365 } 366 } 367 } 368 369 func testReadUniformity(t *testing.T, n int, seed int64) { 370 r := New(NewSource(seed)) 371 buf := make([]byte, n) 372 nRead, err := r.Read(buf) 373 if err != nil { 374 t.Errorf("Read err %v", err) 375 } 376 if nRead != n { 377 t.Errorf("Read returned unexpected n; %d != %d", nRead, n) 378 } 379 380 // Expect a uniform distribution of byte values, which lie in [0, 255]. 381 var ( 382 mean = 255.0 / 2 383 stddev = 256.0 / math.Sqrt(12.0) 384 errorScale = stddev / math.Sqrt(float64(n)) 385 ) 386 387 expected := &statsResults{mean, stddev, 0.10 * errorScale, 0.08 * errorScale} 388 389 // Cast bytes as floats to use the common distribution-validity checks. 390 samples := make([]float64, n) 391 for i, val := range buf { 392 samples[i] = float64(val) 393 } 394 // Make sure that the entire set matches the expected distribution. 395 checkSampleDistribution(t, samples, expected) 396 } 397 398 func TestReadUniformity(t *testing.T) { 399 testBufferSizes := []int{ 400 2, 4, 7, 64, 1024, 1 << 16, 1 << 20, 401 } 402 for _, seed := range testSeeds { 403 for _, n := range testBufferSizes { 404 testReadUniformity(t, n, seed) 405 } 406 } 407 } 408 409 func TestReadEmpty(t *testing.T) { 410 r := New(NewSource(1)) 411 buf := make([]byte, 0) 412 n, err := r.Read(buf) 413 if err != nil { 414 t.Errorf("Read err into empty buffer; %v", err) 415 } 416 if n != 0 { 417 t.Errorf("Read into empty buffer returned unexpected n of %d", n) 418 } 419 } 420 421 func TestReadByOneByte(t *testing.T) { 422 r := New(NewSource(1)) 423 b1 := make([]byte, 100) 424 _, err := io.ReadFull(iotest.OneByteReader(r), b1) 425 if err != nil { 426 t.Errorf("read by one byte: %v", err) 427 } 428 r = New(NewSource(1)) 429 b2 := make([]byte, 100) 430 _, err = r.Read(b2) 431 if err != nil { 432 t.Errorf("read: %v", err) 433 } 434 if !bytes.Equal(b1, b2) { 435 t.Errorf("read by one byte vs single read:\n%x\n%x", b1, b2) 436 } 437 } 438 439 func TestReadSeedReset(t *testing.T) { 440 r := New(NewSource(42)) 441 b1 := make([]byte, 128) 442 _, err := r.Read(b1) 443 if err != nil { 444 t.Errorf("read: %v", err) 445 } 446 r.Seed(42) 447 b2 := make([]byte, 128) 448 _, err = r.Read(b2) 449 if err != nil { 450 t.Errorf("read: %v", err) 451 } 452 if !bytes.Equal(b1, b2) { 453 t.Errorf("mismatch after re-seed:\n%x\n%x", b1, b2) 454 } 455 } 456 457 func TestShuffleSmall(t *testing.T) { 458 // Check that Shuffle allows n=0 and n=1, but that swap is never called for them. 459 r := New(NewSource(1)) 460 for n := 0; n <= 1; n++ { 461 r.Shuffle(n, func(i, j int) { t.Fatalf("swap called, n=%d i=%d j=%d", n, i, j) }) 462 } 463 } 464 465 // encodePerm converts from a permuted slice of length n, such as Perm generates, to an int in [0, n!). 466 // See https://en.wikipedia.org/wiki/Lehmer_code. 467 // encodePerm modifies the input slice. 468 func encodePerm(s []int) int { 469 // Convert to Lehmer code. 470 for i, x := range s { 471 r := s[i+1:] 472 for j, y := range r { 473 if y > x { 474 r[j]-- 475 } 476 } 477 } 478 // Convert to int in [0, n!). 479 m := 0 480 fact := 1 481 for i := len(s) - 1; i >= 0; i-- { 482 m += s[i] * fact 483 fact *= len(s) - i 484 } 485 return m 486 } 487 488 // TestUniformFactorial tests several ways of generating a uniform value in [0, n!). 489 func TestUniformFactorial(t *testing.T) { 490 r := New(NewSource(testSeeds[0])) 491 top := 6 492 if testing.Short() { 493 top = 3 494 } 495 for n := 3; n <= top; n++ { 496 t.Run(fmt.Sprintf("n=%d", n), func(t *testing.T) { 497 // Calculate n!. 498 nfact := 1 499 for i := 2; i <= n; i++ { 500 nfact *= i 501 } 502 503 // Test a few different ways to generate a uniform distribution. 504 p := make([]int, n) // re-usable slice for Shuffle generator 505 tests := [...]struct { 506 name string 507 fn func() int 508 }{ 509 {name: "Int31n", fn: func() int { return int(r.Int31n(int32(nfact))) }}, 510 {name: "int31n", fn: func() int { return int(Int31nForTest(r, int32(nfact))) }}, 511 {name: "Perm", fn: func() int { return encodePerm(r.Perm(n)) }}, 512 {name: "Shuffle", fn: func() int { 513 // Generate permutation using Shuffle. 514 for i := range p { 515 p[i] = i 516 } 517 r.Shuffle(n, func(i, j int) { p[i], p[j] = p[j], p[i] }) 518 return encodePerm(p) 519 }}, 520 } 521 522 for _, test := range tests { 523 t.Run(test.name, func(t *testing.T) { 524 // Gather chi-squared values and check that they follow 525 // the expected normal distribution given n!-1 degrees of freedom. 526 // See https://en.wikipedia.org/wiki/Pearson%27s_chi-squared_test and 527 // https://www.johndcook.com/Beautiful_Testing_ch10.pdf. 528 nsamples := 10 * nfact 529 if nsamples < 200 { 530 nsamples = 200 531 } 532 samples := make([]float64, nsamples) 533 for i := range samples { 534 // Generate some uniformly distributed values and count their occurrences. 535 const iters = 1000 536 counts := make([]int, nfact) 537 for i := 0; i < iters; i++ { 538 counts[test.fn()]++ 539 } 540 // Calculate chi-squared and add to samples. 541 want := iters / float64(nfact) 542 var χ2 float64 543 for _, have := range counts { 544 err := float64(have) - want 545 χ2 += err * err 546 } 547 χ2 /= want 548 samples[i] = χ2 549 } 550 551 // Check that our samples approximate the appropriate normal distribution. 552 dof := float64(nfact - 1) 553 expected := &statsResults{mean: dof, stddev: math.Sqrt(2 * dof)} 554 errorScale := max(1.0, expected.stddev) 555 expected.closeEnough = 0.10 * errorScale 556 expected.maxError = 0.08 // TODO: What is the right value here? See issue 21211. 557 checkSampleDistribution(t, samples, expected) 558 }) 559 } 560 }) 561 } 562 } 563 564 // Benchmarks 565 566 func BenchmarkInt63Threadsafe(b *testing.B) { 567 for n := b.N; n > 0; n-- { 568 Int63() 569 } 570 } 571 572 func BenchmarkInt63ThreadsafeParallel(b *testing.B) { 573 b.RunParallel(func(pb *testing.PB) { 574 for pb.Next() { 575 Int63() 576 } 577 }) 578 } 579 580 func BenchmarkInt63Unthreadsafe(b *testing.B) { 581 r := New(NewSource(1)) 582 for n := b.N; n > 0; n-- { 583 r.Int63() 584 } 585 } 586 587 func BenchmarkIntn1000(b *testing.B) { 588 r := New(NewSource(1)) 589 for n := b.N; n > 0; n-- { 590 r.Intn(1000) 591 } 592 } 593 594 func BenchmarkInt63n1000(b *testing.B) { 595 r := New(NewSource(1)) 596 for n := b.N; n > 0; n-- { 597 r.Int63n(1000) 598 } 599 } 600 601 func BenchmarkInt31n1000(b *testing.B) { 602 r := New(NewSource(1)) 603 for n := b.N; n > 0; n-- { 604 r.Int31n(1000) 605 } 606 } 607 608 func BenchmarkFloat32(b *testing.B) { 609 r := New(NewSource(1)) 610 for n := b.N; n > 0; n-- { 611 r.Float32() 612 } 613 } 614 615 func BenchmarkFloat64(b *testing.B) { 616 r := New(NewSource(1)) 617 for n := b.N; n > 0; n-- { 618 r.Float64() 619 } 620 } 621 622 func BenchmarkPerm3(b *testing.B) { 623 r := New(NewSource(1)) 624 for n := b.N; n > 0; n-- { 625 r.Perm(3) 626 } 627 } 628 629 func BenchmarkPerm30(b *testing.B) { 630 r := New(NewSource(1)) 631 for n := b.N; n > 0; n-- { 632 r.Perm(30) 633 } 634 } 635 636 func BenchmarkPerm30ViaShuffle(b *testing.B) { 637 r := New(NewSource(1)) 638 for n := b.N; n > 0; n-- { 639 p := make([]int, 30) 640 for i := range p { 641 p[i] = i 642 } 643 r.Shuffle(30, func(i, j int) { p[i], p[j] = p[j], p[i] }) 644 } 645 } 646 647 // BenchmarkShuffleOverhead uses a minimal swap function 648 // to measure just the shuffling overhead. 649 func BenchmarkShuffleOverhead(b *testing.B) { 650 r := New(NewSource(1)) 651 for n := b.N; n > 0; n-- { 652 r.Shuffle(52, func(i, j int) { 653 if i < 0 || i >= 52 || j < 0 || j >= 52 { 654 b.Fatalf("bad swap(%d, %d)", i, j) 655 } 656 }) 657 } 658 } 659 660 func BenchmarkRead3(b *testing.B) { 661 r := New(NewSource(1)) 662 buf := make([]byte, 3) 663 b.ResetTimer() 664 for n := b.N; n > 0; n-- { 665 r.Read(buf) 666 } 667 } 668 669 func BenchmarkRead64(b *testing.B) { 670 r := New(NewSource(1)) 671 buf := make([]byte, 64) 672 b.ResetTimer() 673 for n := b.N; n > 0; n-- { 674 r.Read(buf) 675 } 676 } 677 678 func BenchmarkRead1000(b *testing.B) { 679 r := New(NewSource(1)) 680 buf := make([]byte, 1000) 681 b.ResetTimer() 682 for n := b.N; n > 0; n-- { 683 r.Read(buf) 684 } 685 }