github.com/keybase/client/go@v0.0.0-20241007131713-f10651d043c8/chat/flip/prng.go (about)

     1  package flip
     2  
     3  import (
     4  	"crypto/aes"
     5  	"crypto/cipher"
     6  	"encoding/binary"
     7  	"math/big"
     8  )
     9  
// PRNG is based on AES-CTR. The input key is a 32-byte random secret (as generated
// by our commitment scheme). The output is AES(k,1), AES(k,2), AES(k,3), etc...
// We are relying on the fact that AES is a PRP, which is pretty widely assumed.
//
// Each counter value is written big-endian into the last 8 bytes of an
// otherwise zeroed 16-byte block before being encrypted.
type PRNG struct {
	key    Secret       // AES key; handed to aes.NewCipher on first use
	buf    []byte       // bytes already generated but not yet handed out
	i      uint64       // counter for the next block; NewPRNG starts it at 1
	cipher cipher.Block // built lazily by getCipher; nil until then
}
    19  
    20  func NewPRNG(s Secret) *PRNG {
    21  	return &PRNG{
    22  		key: s,
    23  		i:   uint64(1),
    24  	}
    25  }
    26  
    27  func min(x, y int) int {
    28  	if x < y {
    29  		return x
    30  	}
    31  	return y
    32  }
    33  
    34  func (p *PRNG) read(ret []byte) int {
    35  	n := min(len(p.buf), len(ret))
    36  	copy(ret[0:n], p.buf[0:n])
    37  	p.buf = p.buf[n:]
    38  	return n
    39  }
    40  
    41  type block [16]byte
    42  
    43  func (b *block) counter(i uint64) {
    44  	binary.BigEndian.PutUint64(b[8:], i)
    45  }
    46  
    47  func (p *PRNG) getCipher() cipher.Block {
    48  	if p.cipher == nil {
    49  		var err error
    50  		p.cipher, err = aes.NewCipher(p.key[:])
    51  		if err != nil {
    52  			panic(err.Error())
    53  		}
    54  		var tmp block
    55  		if p.cipher.BlockSize() != len(tmp) {
    56  			panic("Expected a 16-byte block size")
    57  		}
    58  	}
    59  	return p.cipher
    60  }
    61  
    62  func (p *PRNG) replenish() {
    63  	if len(p.buf) == 0 {
    64  		var input block
    65  		var output block
    66  		input.counter(p.i)
    67  		p.i++
    68  		p.getCipher().Encrypt(output[:], input[:])
    69  		p.buf = output[:]
    70  	}
    71  }
    72  
    73  func (p *PRNG) Read(out []byte) int {
    74  	var nRead int
    75  	i := 0
    76  	for nRead < len(out) {
    77  		p.replenish()
    78  		tmp := p.read(out[nRead:])
    79  		nRead += tmp
    80  		i++
    81  	}
    82  	return nRead
    83  }
    84  
    85  func (p *PRNG) Big(modulus *big.Int) *big.Int {
    86  
    87  	sign := modulus.Sign()
    88  	// For moduli of 0, the sign will be 0. Just return it, since there's
    89  	// nothing we can really do.
    90  	if sign == 0 {
    91  		return modulus
    92  	}
    93  
    94  	// Find out how many bits are in numbers that are between 0 and |modulus|, exclusive.
    95  	// To do this, we find the absolute value of modulus, store it into n, and ask
    96  	// how many bits are in (n-1).
    97  	var n big.Int
    98  	n.Abs(modulus)
    99  	var nMinus1 big.Int
   100  	nMinus1.Sub(&n, big.NewInt(1))
   101  	bits := nMinus1.BitLen()
   102  
   103  	// For a modulus n, we want to clear out the bits that are
   104  	// greater than the greatest bit of n. So compute 2^(ceil(log2(n)))-1,
   105  	// and AND our candidate with that mask. That'll get rid of the
   106  	// bits we don't want.
   107  	var mask big.Int
   108  	mask.Lsh(big.NewInt(1), uint(bits))
   109  	mask.Sub(&mask, big.NewInt(1))
   110  
   111  	// Compute the number of bytes it takes to get that many bits.
   112  	// but rounding up.
   113  	bytes := bits / 8
   114  	if bits%8 != 0 {
   115  		bytes++
   116  	}
   117  
   118  	buf := make([]byte, bytes)
   119  	for {
   120  		p.Read(buf)
   121  		var x big.Int
   122  		x.SetBytes(buf)
   123  		x.And(&x, &mask)
   124  		if x.Cmp(&n) < 0 {
   125  			return x.Mul(&x, big.NewInt(int64(sign)))
   126  		}
   127  	}
   128  }
   129  
   130  func (p *PRNG) Int(modulus int64) int64 {
   131  	return p.Big(big.NewInt(modulus)).Int64()
   132  }
   133  
   134  func (p *PRNG) Bool() bool {
   135  	var b [1]byte
   136  	p.Read(b[:])
   137  	var ret bool
   138  	if b[0]&0x1 == byte(1) {
   139  		ret = true
   140  	}
   141  	return ret
   142  }
   143  
   144  // Permutation runs the Fisher-Yates shuffle on the sequence [0,n).
   145  // See: https://en.wikipedia.org/wiki/Fisher–Yates_shuffle
   146  // Be careful for off-by-one errors in this implementation, as we have
   147  // already witnessed one. We bounty bugs like these, so let us know!
   148  func (p *PRNG) Permutation(n int) []int {
   149  	ret := make([]int, n)
   150  	for i := 0; i < n; i++ {
   151  		ret[i] = i
   152  	}
   153  	for i := n - 1; i >= 1; i-- {
   154  		modulus := i + 1
   155  		j := p.Int(int64(modulus))
   156  		ret[j], ret[i] = ret[i], ret[j]
   157  	}
   158  	return ret
   159  }