github.com/prysmaticlabs/prysm@v1.4.4/beacon-chain/core/helpers/shuffle.go (about)

     1  package helpers
     2  
     3  import (
     4  	"encoding/binary"
     5  	"fmt"
     6  
     7  	types "github.com/prysmaticlabs/eth2-types"
     8  	"github.com/prysmaticlabs/prysm/shared/bytesutil"
     9  	"github.com/prysmaticlabs/prysm/shared/hashutil"
    10  	"github.com/prysmaticlabs/prysm/shared/params"
    11  	"github.com/prysmaticlabs/prysm/shared/sliceutil"
    12  )
    13  
    14  const seedSize = int8(32)
    15  const roundSize = int8(1)
    16  const positionWindowSize = int8(4)
    17  const pivotViewSize = seedSize + roundSize
    18  const totalSize = seedSize + roundSize + positionWindowSize
    19  
    20  var maxShuffleListSize uint64 = 1 << 40
    21  
    22  // SplitIndices splits a list into n pieces.
    23  func SplitIndices(l []uint64, n uint64) [][]uint64 {
    24  	var divided [][]uint64
    25  	var lSize = uint64(len(l))
    26  	for i := uint64(0); i < n; i++ {
    27  		start := sliceutil.SplitOffset(lSize, n, i)
    28  		end := sliceutil.SplitOffset(lSize, n, i+1)
    29  		divided = append(divided, l[start:end])
    30  	}
    31  	return divided
    32  }
    33  
    34  // ShuffledIndex returns `p(index)` in a pseudorandom permutation `p` of `0...list_size - 1` with ``seed`` as entropy.
    35  // We utilize 'swap or not' shuffling in this implementation; we are allocating the memory with the seed that stays
    36  // constant between iterations instead of reallocating it each iteration as in the spec. This implementation is based
    37  // on the original implementation from protolambda, https://github.com/protolambda/eth2-shuffle
    38  func ShuffledIndex(index types.ValidatorIndex, indexCount uint64, seed [32]byte) (types.ValidatorIndex, error) {
    39  	return ComputeShuffledIndex(index, indexCount, seed, true /* shuffle */)
    40  }
    41  
    42  // UnShuffledIndex returns the inverse of ShuffledIndex. This implementation is based
    43  // on the original implementation from protolambda, https://github.com/protolambda/eth2-shuffle
    44  func UnShuffledIndex(index types.ValidatorIndex, indexCount uint64, seed [32]byte) (types.ValidatorIndex, error) {
    45  	return ComputeShuffledIndex(index, indexCount, seed, false /* un-shuffle */)
    46  }
    47  
    48  // ComputeShuffledIndex returns the shuffled validator index corresponding to seed and index count.
    49  // Spec pseudocode definition:
    50  //   def compute_shuffled_index(index: uint64, index_count: uint64, seed: Bytes32) -> uint64:
    51  //    """
    52  //    Return the shuffled index corresponding to ``seed`` (and ``index_count``).
    53  //    """
    54  //    assert index < index_count
    55  //
    56  //    # Swap or not (https://link.springer.com/content/pdf/10.1007%2F978-3-642-32009-5_1.pdf)
    57  //    # See the 'generalized domain' algorithm on page 3
    58  //    for current_round in range(SHUFFLE_ROUND_COUNT):
    59  //        pivot = bytes_to_uint64(hash(seed + uint_to_bytes(uint8(current_round)))[0:8]) % index_count
    60  //        flip = (pivot + index_count - index) % index_count
    61  //        position = max(index, flip)
    62  //        source = hash(
    63  //            seed
    64  //            + uint_to_bytes(uint8(current_round))
    65  //            + uint_to_bytes(uint32(position // 256))
    66  //        )
    67  //        byte = uint8(source[(position % 256) // 8])
    68  //        bit = (byte >> (position % 8)) % 2
    69  //        index = flip if bit else index
    70  //
    71  //    return index
    72  func ComputeShuffledIndex(index types.ValidatorIndex, indexCount uint64, seed [32]byte, shuffle bool) (types.ValidatorIndex, error) {
    73  	if params.BeaconConfig().ShuffleRoundCount == 0 {
    74  		return index, nil
    75  	}
    76  	if uint64(index) >= indexCount {
    77  		return 0, fmt.Errorf("input index %d out of bounds: %d",
    78  			index, indexCount)
    79  	}
    80  	if indexCount > maxShuffleListSize {
    81  		return 0, fmt.Errorf("list size %d out of bounds",
    82  			indexCount)
    83  	}
    84  	rounds := uint8(params.BeaconConfig().ShuffleRoundCount)
    85  	round := uint8(0)
    86  	if !shuffle {
    87  		// Starting last round and iterating through the rounds in reverse, un-swaps everything,
    88  		// effectively un-shuffling the list.
    89  		round = rounds - 1
    90  	}
    91  	buf := make([]byte, totalSize)
    92  	posBuffer := make([]byte, 8)
    93  	hashfunc := hashutil.CustomSHA256Hasher()
    94  
    95  	// Seed is always the first 32 bytes of the hash input, we never have to change this part of the buffer.
    96  	copy(buf[:32], seed[:])
    97  	for {
    98  		buf[seedSize] = round
    99  		hash := hashfunc(buf[:pivotViewSize])
   100  		hash8 := hash[:8]
   101  		hash8Int := bytesutil.FromBytes8(hash8)
   102  		pivot := hash8Int % indexCount
   103  		flip := (pivot + indexCount - uint64(index)) % indexCount
   104  		// Consider every pair only once by picking the highest pair index to retrieve randomness.
   105  		position := uint64(index)
   106  		if flip > position {
   107  			position = flip
   108  		}
   109  		// Add position except its last byte to []buf for randomness,
   110  		// it will be used later to select a bit from the resulting hash.
   111  		binary.LittleEndian.PutUint64(posBuffer[:8], position>>8)
   112  		copy(buf[pivotViewSize:], posBuffer[:4])
   113  		source := hashfunc(buf)
   114  		// Effectively keep the first 5 bits of the byte value of the position,
   115  		// and use it to retrieve one of the 32 (= 2^5) bytes of the hash.
   116  		byteV := source[(position&0xff)>>3]
   117  		// Using the last 3 bits of the position-byte, determine which bit to get from the hash-byte (note: 8 bits = 2^3)
   118  		bitV := (byteV >> (position & 0x7)) & 0x1
   119  		// index = flip if bit else index
   120  		if bitV == 1 {
   121  			index = types.ValidatorIndex(flip)
   122  		}
   123  		if shuffle {
   124  			round++
   125  			if round == rounds {
   126  				break
   127  			}
   128  		} else {
   129  			if round == 0 {
   130  				break
   131  			}
   132  			round--
   133  		}
   134  	}
   135  	return index, nil
   136  }
   137  
   138  // ShuffleList returns list of shuffled indexes in a pseudorandom permutation `p` of `0...list_size - 1` with ``seed`` as entropy.
   139  // We utilize 'swap or not' shuffling in this implementation; we are allocating the memory with the seed that stays
   140  // constant between iterations instead of reallocating it each iteration as in the spec. This implementation is based
   141  // on the original implementation from protolambda, https://github.com/protolambda/eth2-shuffle
   142  //  improvements:
   143  //   - seed is always the first 32 bytes of the hash input, we just copy it into the buffer one time.
   144  //   - add round byte to seed and hash that part of the buffer.
   145  //   - split up the for-loop in two:
   146  //    1. Handle the part from 0 (incl) to pivot (incl). This is mirrored around (pivot / 2).
   147  //    2. Handle the part from pivot (excl) to N (excl). This is mirrored around ((pivot / 2) + (size/2)).
   148  //   - hash source every 256 iterations.
   149  //   - change byteV every 8 iterations.
   150  //   - we start at the edges, and work back to the mirror point.
   151  //     this makes us process each pear exactly once (instead of unnecessarily twice, like in the spec).
   152  func ShuffleList(input []types.ValidatorIndex, seed [32]byte) ([]types.ValidatorIndex, error) {
   153  	return innerShuffleList(input, seed, true /* shuffle */)
   154  }
   155  
   156  // UnshuffleList un-shuffles the list by running backwards through the round count.
   157  func UnshuffleList(input []types.ValidatorIndex, seed [32]byte) ([]types.ValidatorIndex, error) {
   158  	return innerShuffleList(input, seed, false /* un-shuffle */)
   159  }
   160  
   161  // shuffles or unshuffles, shuffle=false to un-shuffle.
   162  func innerShuffleList(input []types.ValidatorIndex, seed [32]byte, shuffle bool) ([]types.ValidatorIndex, error) {
   163  	if len(input) <= 1 {
   164  		return input, nil
   165  	}
   166  	if uint64(len(input)) > maxShuffleListSize {
   167  		return nil, fmt.Errorf("list size %d out of bounds",
   168  			len(input))
   169  	}
   170  	rounds := uint8(params.BeaconConfig().ShuffleRoundCount)
   171  	hashfunc := hashutil.CustomSHA256Hasher()
   172  	if rounds == 0 {
   173  		return input, nil
   174  	}
   175  	listSize := uint64(len(input))
   176  	buf := make([]byte, totalSize)
   177  	r := uint8(0)
   178  	if !shuffle {
   179  		r = rounds - 1
   180  	}
   181  	copy(buf[:seedSize], seed[:])
   182  	for {
   183  		buf[seedSize] = r
   184  		ph := hashfunc(buf[:pivotViewSize])
   185  		pivot := bytesutil.FromBytes8(ph[:8]) % listSize
   186  		mirror := (pivot + 1) >> 1
   187  		binary.LittleEndian.PutUint32(buf[pivotViewSize:], uint32(pivot>>8))
   188  		source := hashfunc(buf)
   189  		byteV := source[(pivot&0xff)>>3]
   190  		for i, j := uint64(0), pivot; i < mirror; i, j = i+1, j-1 {
   191  			byteV, source = swapOrNot(buf, byteV, types.ValidatorIndex(i), input, types.ValidatorIndex(j), source, hashfunc)
   192  		}
   193  		// Now repeat, but for the part after the pivot.
   194  		mirror = (pivot + listSize + 1) >> 1
   195  		end := listSize - 1
   196  		binary.LittleEndian.PutUint32(buf[pivotViewSize:], uint32(end>>8))
   197  		source = hashfunc(buf)
   198  		byteV = source[(end&0xff)>>3]
   199  		for i, j := pivot+1, end; i < mirror; i, j = i+1, j-1 {
   200  			byteV, source = swapOrNot(buf, byteV, types.ValidatorIndex(i), input, types.ValidatorIndex(j), source, hashfunc)
   201  		}
   202  		if shuffle {
   203  			r++
   204  			if r == rounds {
   205  				break
   206  			}
   207  		} else {
   208  			if r == 0 {
   209  				break
   210  			}
   211  			r--
   212  		}
   213  	}
   214  	return input, nil
   215  }
   216  
   217  // swapOrNot describes the main algorithm behind the shuffle where we swap bytes in the inputted value
   218  // depending on if the conditions are met.
   219  func swapOrNot(buf []byte, byteV byte, i types.ValidatorIndex, input []types.ValidatorIndex,
   220  	j types.ValidatorIndex, source [32]byte, hashFunc func([]byte) [32]byte) (byte, [32]byte) {
   221  	if j&0xff == 0xff {
   222  		// just overwrite the last part of the buffer, reuse the start (seed, round)
   223  		binary.LittleEndian.PutUint32(buf[pivotViewSize:], uint32(j>>8))
   224  		source = hashFunc(buf)
   225  	}
   226  	if j&0x7 == 0x7 {
   227  		byteV = source[(j&0xff)>>3]
   228  	}
   229  	bitV := (byteV >> (j & 0x7)) & 0x1
   230  
   231  	if bitV == 1 {
   232  		input[i], input[j] = input[j], input[i]
   233  	}
   234  	return byteV, source
   235  }