github.com/cloudflare/circl@v1.5.0/expander/expander.go (about)

     1  // Package expander generates arbitrary bytes from an XOF or Hash function.
     2  package expander
     3  
     4  import (
     5  	"crypto"
     6  	"encoding/binary"
     7  	"errors"
     8  	"io"
     9  
    10  	"github.com/cloudflare/circl/xof"
    11  )
    12  
    13  type Expander interface {
    14  	// Expand generates a pseudo-random byte string of a determined length by
    15  	// expanding an input string.
    16  	Expand(in []byte, length uint) (pseudo []byte)
    17  }
    18  
    19  type expanderMD struct {
    20  	h   crypto.Hash
    21  	dst []byte
    22  }
    23  
    24  // NewExpanderMD returns a hash function based on a Merkle-Damgård hash function.
    25  func NewExpanderMD(h crypto.Hash, dst []byte) *expanderMD {
    26  	return &expanderMD{h, dst}
    27  }
    28  
    29  func (e *expanderMD) calcDSTPrime() []byte {
    30  	var dstPrime []byte
    31  	if l := len(e.dst); l > maxDSTLength {
    32  		H := e.h.New()
    33  		mustWrite(H, longDSTPrefix[:])
    34  		mustWrite(H, e.dst)
    35  		dstPrime = H.Sum(nil)
    36  	} else {
    37  		dstPrime = make([]byte, l, l+1)
    38  		copy(dstPrime, e.dst)
    39  	}
    40  	return append(dstPrime, byte(len(dstPrime)))
    41  }
    42  
    43  func (e *expanderMD) Expand(in []byte, n uint) []byte {
    44  	H := e.h.New()
    45  	bLen := uint(H.Size())
    46  	ell := (n + (bLen - 1)) / bLen
    47  	if ell > 255 {
    48  		panic(errorLongOutput)
    49  	}
    50  
    51  	zPad := make([]byte, H.BlockSize())
    52  	libStr := []byte{0, 0}
    53  	libStr[0] = byte((n >> 8) & 0xFF)
    54  	libStr[1] = byte(n & 0xFF)
    55  	dstPrime := e.calcDSTPrime()
    56  
    57  	mustWrite(H, zPad)
    58  	mustWrite(H, in)
    59  	mustWrite(H, libStr)
    60  	mustWrite(H, []byte{0})
    61  	mustWrite(H, dstPrime)
    62  	b0 := H.Sum(nil)
    63  
    64  	H.Reset()
    65  	mustWrite(H, b0)
    66  	mustWrite(H, []byte{1})
    67  	mustWrite(H, dstPrime)
    68  	bi := H.Sum(nil)
    69  	pseudo := append([]byte{}, bi...)
    70  	for i := uint(2); i <= ell; i++ {
    71  		H.Reset()
    72  		for i := range b0 {
    73  			bi[i] ^= b0[i]
    74  		}
    75  		mustWrite(H, bi)
    76  		mustWrite(H, []byte{byte(i)})
    77  		mustWrite(H, dstPrime)
    78  		bi = H.Sum(nil)
    79  		pseudo = append(pseudo, bi...)
    80  	}
    81  	return pseudo[0:n]
    82  }
    83  
    84  // expanderXOF is based on an extendable output function.
    85  type expanderXOF struct {
    86  	id        xof.ID
    87  	kSecLevel uint
    88  	dst       []byte
    89  }
    90  
    91  // NewExpanderXOF returns an Expander based on an extendable output function.
    92  // The kSecLevel parameter is the target security level in bits, and dst is
    93  // a domain separation string.
    94  func NewExpanderXOF(id xof.ID, kSecLevel uint, dst []byte) *expanderXOF {
    95  	return &expanderXOF{id, kSecLevel, dst}
    96  }
    97  
    98  // Expand panics if output's length is longer than 2^16 bytes.
    99  func (e *expanderXOF) Expand(in []byte, n uint) []byte {
   100  	bLen := []byte{0, 0}
   101  	binary.BigEndian.PutUint16(bLen, uint16(n))
   102  	pseudo := make([]byte, n)
   103  	dstPrime := e.calcDSTPrime()
   104  
   105  	H := e.id.New()
   106  	mustWrite(H, in)
   107  	mustWrite(H, bLen)
   108  	mustWrite(H, dstPrime)
   109  	mustReadFull(H, pseudo)
   110  	return pseudo
   111  }
   112  
   113  func (e *expanderXOF) calcDSTPrime() []byte {
   114  	var dstPrime []byte
   115  	if l := len(e.dst); l > maxDSTLength {
   116  		H := e.id.New()
   117  		mustWrite(H, longDSTPrefix[:])
   118  		mustWrite(H, e.dst)
   119  		max := ((2 * e.kSecLevel) + 7) / 8
   120  		dstPrime = make([]byte, max, max+1)
   121  		mustReadFull(H, dstPrime)
   122  	} else {
   123  		dstPrime = make([]byte, l, l+1)
   124  		copy(dstPrime, e.dst)
   125  	}
   126  	return append(dstPrime, byte(len(dstPrime)))
   127  }
   128  
   129  func mustWrite(w io.Writer, b []byte) {
   130  	if n, err := w.Write(b); err != nil || n != len(b) {
   131  		panic(err)
   132  	}
   133  }
   134  
   135  func mustReadFull(r io.Reader, b []byte) {
   136  	if n, err := io.ReadFull(r, b); err != nil || n != len(b) {
   137  		panic(err)
   138  	}
   139  }
   140  
   141  const maxDSTLength = 255
   142  
   143  var (
   144  	longDSTPrefix = [17]byte{'H', '2', 'C', '-', 'O', 'V', 'E', 'R', 'S', 'I', 'Z', 'E', '-', 'D', 'S', 'T', '-'}
   145  
   146  	errorLongOutput = errors.New("requested too many bytes")
   147  )