github.com/consensys/gnark-crypto@v0.14.0/field/hash/hashutils.go (about)

     1  package hash
     2  
     3  import (
     4  	"crypto/sha256"
     5  	"errors"
     6  )
     7  
     8  // ExpandMsgXmd expands msg to a slice of lenInBytes bytes.
     9  // https://datatracker.ietf.org/doc/html/rfc9380#name-expand_message_xmd
    10  // https://datatracker.ietf.org/doc/html/rfc9380#name-utility-functions (I2OSP/O2ISP)
    11  func ExpandMsgXmd(msg, dst []byte, lenInBytes int) ([]byte, error) {
    12  
    13  	h := sha256.New()
    14  	ell := (lenInBytes + h.Size() - 1) / h.Size() // ceil(len_in_bytes / b_in_bytes)
    15  	if ell > 255 {
    16  		return nil, errors.New("invalid lenInBytes")
    17  	}
    18  	if len(dst) > 255 {
    19  		return nil, errors.New("invalid domain size (>255 bytes)")
    20  	}
    21  	sizeDomain := uint8(len(dst))
    22  
    23  	// Z_pad = I2OSP(0, r_in_bytes)
    24  	// l_i_b_str = I2OSP(len_in_bytes, 2)
    25  	// DST_prime = DST ∥ I2OSP(len(DST), 1)
    26  	// b₀ = H(Z_pad ∥ msg ∥ l_i_b_str ∥ I2OSP(0, 1) ∥ DST_prime)
    27  	h.Reset()
    28  	if _, err := h.Write(make([]byte, h.BlockSize())); err != nil {
    29  		return nil, err
    30  	}
    31  	if _, err := h.Write(msg); err != nil {
    32  		return nil, err
    33  	}
    34  	if _, err := h.Write([]byte{uint8(lenInBytes >> 8), uint8(lenInBytes), uint8(0)}); err != nil {
    35  		return nil, err
    36  	}
    37  	if _, err := h.Write(dst); err != nil {
    38  		return nil, err
    39  	}
    40  	if _, err := h.Write([]byte{sizeDomain}); err != nil {
    41  		return nil, err
    42  	}
    43  	b0 := h.Sum(nil)
    44  
    45  	// b₁ = H(b₀ ∥ I2OSP(1, 1) ∥ DST_prime)
    46  	h.Reset()
    47  	if _, err := h.Write(b0); err != nil {
    48  		return nil, err
    49  	}
    50  	if _, err := h.Write([]byte{uint8(1)}); err != nil {
    51  		return nil, err
    52  	}
    53  	if _, err := h.Write(dst); err != nil {
    54  		return nil, err
    55  	}
    56  	if _, err := h.Write([]byte{sizeDomain}); err != nil {
    57  		return nil, err
    58  	}
    59  	b1 := h.Sum(nil)
    60  
    61  	res := make([]byte, lenInBytes)
    62  	copy(res[:h.Size()], b1)
    63  
    64  	for i := 2; i <= ell; i++ {
    65  		// b_i = H(strxor(b₀, b_(i - 1)) ∥ I2OSP(i, 1) ∥ DST_prime)
    66  		h.Reset()
    67  		strxor := make([]byte, h.Size())
    68  		for j := 0; j < h.Size(); j++ {
    69  			strxor[j] = b0[j] ^ b1[j]
    70  		}
    71  		if _, err := h.Write(strxor); err != nil {
    72  			return nil, err
    73  		}
    74  		if _, err := h.Write([]byte{uint8(i)}); err != nil {
    75  			return nil, err
    76  		}
    77  		if _, err := h.Write(dst); err != nil {
    78  			return nil, err
    79  		}
    80  		if _, err := h.Write([]byte{sizeDomain}); err != nil {
    81  			return nil, err
    82  		}
    83  		b1 = h.Sum(nil)
    84  		copy(res[h.Size()*(i-1):min(h.Size()*i, len(res))], b1)
    85  	}
    86  	return res, nil
    87  }
    88  
    89  func min(a, b int) int {
    90  	if a < b {
    91  		return a
    92  	}
    93  	return b
    94  }