github.com/coyove/common@v0.0.0-20240403014525-f70e643f9de8/rand/rand.go (about)

     1  package rand
     2  
     3  import (
     4  	"crypto/aes"
     5  	"crypto/cipher"
     6  	"crypto/rand"
     7  	"encoding/binary"
     8  	"fmt"
     9  	"sync"
    10  )
    11  
    12  const bufferLen = 200 // don't exceed 1<<32-1 and keep an integral multiple of 16 bytes
    13  const budget = (1 << 20) / bufferLen
    14  
    15  // Rand is a concurrent random number generator struct
    16  type Rand struct {
    17  	off     uint64
    18  	counter uint32
    19  	block   cipher.Block
    20  	buffer  [bufferLen]byte
    21  	mu      sync.Mutex
    22  }
    23  
    24  // New returns a new Rand struct
    25  func New() *Rand {
    26  	r := &Rand{}
    27  	if _, err := rand.Read(r.buffer[:]); err != nil {
    28  		panic(err)
    29  	}
    30  
    31  	r.block, _ = aes.NewCipher(r.buffer[:16])
    32  	return r
    33  }
    34  
    35  // Int63 returns an int63 number
    36  func (src *Rand) Int63() int64 {
    37  	return int64(src.Uint64() & 0x7fffffffffffffff)
    38  }
    39  
    40  // Int31 returns an int31 number
    41  func (src *Rand) Int31() int32 {
    42  	return int32(src.Int63() >> 32)
    43  }
    44  
    45  // Uint64 returns an uint64 number
    46  // It has many duplicated code from Read(), but to achieve best performance, this redundancy is necessary
    47  func (src *Rand) Uint64() uint64 {
    48  	src.mu.Lock()
    49  	offStart, offEnd := src.off, src.off+8
    50  
    51  	if offEnd > bufferLen {
    52  		if src.counter++; src.counter >= budget {
    53  			if _, err := rand.Read(src.buffer[:]); err != nil {
    54  				panic(err)
    55  			}
    56  			src.counter = 0
    57  		} else {
    58  			for b := 0; b < bufferLen/16; b++ {
    59  				src.block.Encrypt(src.buffer[b*16:], src.buffer[b*16:])
    60  			}
    61  		}
    62  
    63  		src.off = 0
    64  		offStart, offEnd = 0, 8
    65  	}
    66  
    67  	src.off = offEnd
    68  
    69  	i := binary.BigEndian.Uint64(src.buffer[offStart:offEnd])
    70  	src.mu.Unlock()
    71  	return i
    72  }
    73  
    74  // Intn returns an integer within [0, n)
    75  func (src *Rand) Intn(n int) int {
    76  	if n <= 0 {
    77  		panic("invalid argument to Intn")
    78  	}
    79  	if n <= 1<<31-1 {
    80  		return int(src.Int31n(int32(n)))
    81  	}
    82  	return int(src.Int63n(int64(n)))
    83  }
    84  
    85  // Int63n returns an integer within [0, n)
    86  func (src *Rand) Int63n(n int64) int64 {
    87  	if n <= 0 {
    88  		panic("invalid argument to Int63n")
    89  	}
    90  	if n&(n-1) == 0 { // n is power of two, can mask
    91  		return src.Int63() & (n - 1)
    92  	}
    93  	max := int64((1 << 63) - 1 - (1<<63)%uint64(n))
    94  	v := src.Int63()
    95  	for v > max {
    96  		v = src.Int63()
    97  	}
    98  	return v % n
    99  }
   100  
   101  // Int31n returns an integer within [0, n)
   102  func (src *Rand) Int31n(n int32) int32 {
   103  	if n <= 0 {
   104  		panic("invalid argument to Int31n")
   105  	}
   106  	if n&(n-1) == 0 { // n is power of two, can mask
   107  		return src.Int31() & (n - 1)
   108  	}
   109  	max := int32((1 << 31) - 1 - (1<<31)%uint32(n))
   110  	v := src.Int31()
   111  	for v > max {
   112  		v = src.Int31()
   113  	}
   114  	return v % n
   115  }
   116  
   117  // Perm returns an array of shuffled integers from 0 to n-1
   118  func (src *Rand) Perm(n int) []int {
   119  	m := make([]int, n)
   120  	// Note we start from 1, different from the Go official
   121  	for i := 1; i < n; i++ {
   122  		j := src.Intn(i + 1)
   123  		m[i] = m[j]
   124  		m[j] = i
   125  	}
   126  	return m
   127  }
   128  
   129  // Read reads bytes into buf
   130  func (src *Rand) Read(buf []byte) error {
   131  	n := uint64(len(buf))
   132  
   133  	if n > bufferLen {
   134  		return fmt.Errorf("rand: don't read more than %d bytes in a single Read()", bufferLen)
   135  	}
   136  
   137  	src.mu.Lock()
   138  	offStart, offEnd := src.off, src.off+n
   139  
   140  	if offEnd > bufferLen {
   141  		if src.counter++; src.counter >= budget {
   142  			if _, err := rand.Read(src.buffer[:]); err != nil {
   143  				panic(err)
   144  			}
   145  			src.counter = 0
   146  		} else {
   147  			for b := 0; b < bufferLen/16; b++ {
   148  				src.block.Encrypt(src.buffer[b*16:], src.buffer[b*16:])
   149  			}
   150  
   151  		}
   152  
   153  		offStart, offEnd = 0, n
   154  	}
   155  
   156  	src.off = offEnd
   157  
   158  	copy(buf, src.buffer[offStart:offEnd])
   159  	// fmt.Println(buf, offStart)
   160  	src.mu.Unlock()
   161  	return nil
   162  }
   163  
   164  // Fetch fetches n bytes
   165  func (src *Rand) Fetch(n int) []byte {
   166  	buf := make([]byte, n)
   167  	if err := src.Read(buf); err != nil {
   168  		panic(err)
   169  	}
   170  	return buf
   171  }