pgregory.net/rand@v1.0.3-0.20230808192358-a0b8ce02f4da/misc/practrand/practrand.go (about)

     1  // Copyright 2022 Gregory Petrosyan <gregory.petrosyan@gmail.com>
     2  //
     3  // This Source Code Form is subject to the terms of the Mozilla Public
     4  // License, v. 2.0. If a copy of the MPL was not distributed with this
     5  // file, You can obtain one at https://mozilla.org/MPL/2.0/.
     6  
     7  package main
     8  
     9  import (
    10  	"encoding/binary"
    11  	"flag"
    12  	"fmt"
    13  	"github.com/valyala/fastrand"
    14  	exprand "golang.org/x/exp/rand"
    15  	"hash/maphash"
    16  	"log"
    17  	"math"
    18  	"math/bits"
    19  	mathrand "math/rand"
    20  	"os"
    21  	"pgregory.net/rand"
    22  )
    23  
    24  const (
    25  	chunkSizeBits  = 1 << 16
    26  	chunkSizeBytes = chunkSizeBits / 8
    27  	numChunks      = 1024
    28  	bufSizeBits    = numChunks * chunkSizeBits
    29  	bufSizeBytes   = bufSizeBits / 8
    30  	bufSizeWords   = bufSizeBytes / 8
    31  	maxInt52       = 1<<52 - 1
    32  )
    33  
    34  type randGen interface {
    35  	Uint64() uint64
    36  	Float64() float64
    37  	NormFloat64() float64
    38  	ExpFloat64() float64
    39  }
    40  
    41  type wyrandSource struct {
    42  	seed uint64
    43  }
    44  
    45  func (s *wyrandSource) Seed(seed uint64) {
    46  	s.seed = seed // bad idea
    47  }
    48  
    49  func (s *wyrandSource) Uint64() uint64 {
    50  	s.seed += 0xa0761d6478bd642f
    51  	hi, lo := bits.Mul64(s.seed, s.seed^0xe7037ed1a0b428db)
    52  	return hi ^ lo
    53  }
    54  
    55  type globalSource struct{}
    56  
    57  func (s globalSource) Seed(_ uint64) {}
    58  
    59  func (s globalSource) Uint64() uint64 {
    60  	a := rand.Intn(math.MaxUint32)
    61  	b := rand.Intn(math.MaxUint32)
    62  	return uint64(a)<<32 | uint64(b)
    63  }
    64  
    65  type fastSource struct {
    66  	rng fastrand.RNG
    67  }
    68  
    69  func (s *fastSource) Seed(seed uint64) {
    70  	s.rng.Seed(uint32(seed))
    71  }
    72  
    73  func (s *fastSource) Uint64() uint64 {
    74  	a := s.rng.Uint32()
    75  	b := s.rng.Uint32()
    76  	return uint64(a)<<32 | uint64(b)
    77  }
    78  
    79  type rand64 struct {
    80  	rng randGen
    81  }
    82  
    83  func (r *rand64) raw() uint64 {
    84  	return r.rng.Uint64()
    85  }
    86  
    87  func (r *rand64) fromF64() uint64 {
    88  	return floatToUniform(r.rng.Float64(), r.rng.Float64())
    89  }
    90  
    91  func (r *rand64) fromNorm() uint64 {
    92  	return floatToUniform(normalCDF(r.rng.NormFloat64()), normalCDF(r.rng.NormFloat64()))
    93  }
    94  
    95  func (r *rand64) fromExp() uint64 {
    96  	return floatToUniform(expCDF(r.rng.ExpFloat64()), expCDF(r.rng.ExpFloat64()))
    97  }
    98  
    99  func floatToUniform(x float64, y float64) uint64 {
   100  	return uint64(x*maxInt52)<<52 | uint64(y*maxInt52)
   101  }
   102  
   103  func normalCDF(x float64) float64 {
   104  	return 0.5 * math.Erfc(-x/math.Sqrt2)
   105  }
   106  
   107  func expCDF(x float64) float64 {
   108  	return -math.Expm1(-x)
   109  }
   110  
   111  func uint16nModulo(g func() uint64, n uint16) uint16 {
   112  	return uint16(g()) % n // biased
   113  }
   114  
   115  func uint16nFixedPoint(g func() uint64, n uint16) uint16 {
   116  	v := uint16(g())
   117  	x := uint32(n) * uint32(v)
   118  	return uint16(x >> 16) // biased
   119  }
   120  
   121  func uint16nLongFixedPoint(g func() uint64, n uint16) uint16 {
   122  	res, _ := bits.Mul32(uint32(n), uint32(g()))
   123  	return uint16(res) // biased with probability 2^-16
   124  }
   125  
   126  func uint16nLemire(g func() uint64, n uint16) uint16 {
   127  	v := uint16(g())
   128  	prod := uint32(v) * uint32(n)
   129  	low := uint16(prod)
   130  	if low < n {
   131  		thresh := -n % n
   132  		for low < thresh {
   133  			v = uint16(g())
   134  			prod = uint32(v) * uint32(n)
   135  			low = uint16(prod)
   136  		}
   137  	}
   138  	return uint16(prod >> 16) // unbiased
   139  }
   140  
   141  func shuffleBits(buf []byte, g func() uint64, b func(func() uint64, uint16) uint16) {
   142  	for i := math.MaxUint16 - 1; i > 0; i-- {
   143  		j := int(b(g, uint16(i+1)))
   144  		bi := getBit(buf, i)
   145  		bj := getBit(buf, j)
   146  		setBit(buf, i, bj)
   147  		setBit(buf, j, bi)
   148  	}
   149  }
   150  
   151  func getBit(buf []byte, i int) bool {
   152  	return buf[i/8]&(1<<(i%8)) > 0
   153  }
   154  
   155  func setBit(buf []byte, i int, b bool) {
   156  	if b {
   157  		buf[i/8] |= 1 << (i % 8)
   158  	} else {
   159  		buf[i/8] &= ^(1 << (i % 8))
   160  	}
   161  }
   162  
   163  func run(gen string, transform string, shuffle string) error {
   164  	var ctor func(uint64) randGen
   165  	switch gen {
   166  	case "rand":
   167  		ctor = func(s uint64) randGen { return rand.New(s) }
   168  	case "std":
   169  		ctor = func(s uint64) randGen { return mathrand.New(mathrand.NewSource(int64(s))) }
   170  	case "x":
   171  		ctor = func(s uint64) randGen { return exprand.New(exprand.NewSource(s)) }
   172  	case "x-wy":
   173  		ctor = func(s uint64) randGen { return exprand.New(&wyrandSource{s}) }
   174  	case "x-rand-g":
   175  		ctor = func(_ uint64) randGen { return exprand.New(globalSource{}) }
   176  	case "x-fast":
   177  		ctor = func(s uint64) randGen {
   178  			var rng fastrand.RNG
   179  			rng.Seed(uint32(s))
   180  			return exprand.New(&fastSource{rng})
   181  		}
   182  	default:
   183  		return fmt.Errorf("unknown RNG: %q", gen)
   184  	}
   185  
   186  	s := new(maphash.Hash).Sum64()
   187  	rng := func(s uint64) *rand64 { return &rand64{ctor(s)} }
   188  	var g func() uint64
   189  	switch transform {
   190  	case "none":
   191  		g = rng(s).raw
   192  	case "f64":
   193  		g = rng(s).fromF64
   194  	case "norm":
   195  		g = rng(s).fromNorm
   196  	case "exp":
   197  		g = rng(s).fromExp
   198  	case "8seed":
   199  		seeds := [8]uint64{1, 2, 4, 8, 16, 32, 64, 128}
   200  		gens := [8]*rand64{}
   201  		for i, s := range seeds {
   202  			gens[i] = rng(s)
   203  		}
   204  		i := 0
   205  		g = func() uint64 {
   206  			u := gens[i].raw()
   207  			i = (i + 1) % 8
   208  			return u
   209  		}
   210  	default:
   211  		return fmt.Errorf("unknown transform: %q", transform)
   212  	}
   213  
   214  	buf := make([]byte, 8*bufSizeWords)
   215  	switch shuffle {
   216  	case "none":
   217  		return output(buf, g, nil)
   218  	case "mod":
   219  		return output(buf, g, uint16nModulo)
   220  	case "fp":
   221  		return output(buf, g, uint16nFixedPoint)
   222  	case "lfp":
   223  		return output(buf, g, uint16nLongFixedPoint)
   224  	case "lemire":
   225  		return output(buf, g, uint16nLemire)
   226  	default:
   227  		return fmt.Errorf("unknown shuffle method: %q", shuffle)
   228  	}
   229  }
   230  
   231  func output(buf []byte, g func() uint64, b func(func() uint64, uint16) uint16) error {
   232  	for {
   233  		if b == nil {
   234  			for i := 0; i < bufSizeWords; i++ {
   235  				binary.LittleEndian.PutUint64(buf[i*8:], g())
   236  			}
   237  		} else {
   238  			for i := 0; i < numChunks; i++ {
   239  				ch := buf[i*chunkSizeBytes : (i+1)*chunkSizeBytes]
   240  				for j := 0; j < len(ch); j++ {
   241  					if j < len(ch)/2 {
   242  						ch[j] = 0xff
   243  					} else {
   244  						ch[j] = 0
   245  					}
   246  				}
   247  				shuffleBits(ch, g, b)
   248  			}
   249  		}
   250  
   251  		_, err := os.Stdout.Write(buf)
   252  		if err != nil {
   253  			return err
   254  		}
   255  	}
   256  }
   257  
   258  func main() {
   259  	var (
   260  		gen       = flag.String("gen", "rand", "RNG to use (rand/std/x/x-wy/x-fast)")
   261  		transform = flag.String("transform", "none", "transform to use (none/f64/norm/rand/8seed)")
   262  		shuffle   = flag.String("shuffle", "none", "shuffle algorithm to use (none/mod/fp/lfp/lemire)")
   263  	)
   264  	flag.Parse()
   265  
   266  	err := run(*gen, *transform, *shuffle)
   267  	if err != nil {
   268  		log.Fatal(err.Error())
   269  	}
   270  }