github.com/lestrrat-go/jwx/v2@v2.0.21/jwe/internal/concatkdf/concatkdf.go (about)

     1  package concatkdf
     2  
     3  import (
     4  	"crypto"
     5  	"encoding/binary"
     6  	"fmt"
     7  )
     8  
     9  type KDF struct {
    10  	buf       []byte
    11  	otherinfo []byte
    12  	z         []byte
    13  	hash      crypto.Hash
    14  }
    15  
    16  func ndata(src []byte) []byte {
    17  	buf := make([]byte, 4+len(src))
    18  	binary.BigEndian.PutUint32(buf, uint32(len(src)))
    19  	copy(buf[4:], src)
    20  	return buf
    21  }
    22  
    23  func New(hash crypto.Hash, alg, Z, apu, apv, pubinfo, privinfo []byte) *KDF {
    24  	algbuf := ndata(alg)
    25  	apubuf := ndata(apu)
    26  	apvbuf := ndata(apv)
    27  
    28  	concat := make([]byte, len(algbuf)+len(apubuf)+len(apvbuf)+len(pubinfo)+len(privinfo))
    29  	n := copy(concat, algbuf)
    30  	n += copy(concat[n:], apubuf)
    31  	n += copy(concat[n:], apvbuf)
    32  	n += copy(concat[n:], pubinfo)
    33  	copy(concat[n:], privinfo)
    34  
    35  	return &KDF{
    36  		hash:      hash,
    37  		otherinfo: concat,
    38  		z:         Z,
    39  	}
    40  }
    41  
    42  func (k *KDF) Read(out []byte) (int, error) {
    43  	var round uint32 = 1
    44  	h := k.hash.New()
    45  
    46  	for len(out) > len(k.buf) {
    47  		h.Reset()
    48  
    49  		if err := binary.Write(h, binary.BigEndian, round); err != nil {
    50  			return 0, fmt.Errorf(`failed to write round using kdf: %w`, err)
    51  		}
    52  		if _, err := h.Write(k.z); err != nil {
    53  			return 0, fmt.Errorf(`failed to write z using kdf: %w`, err)
    54  		}
    55  		if _, err := h.Write(k.otherinfo); err != nil {
    56  			return 0, fmt.Errorf(`failed to write other info using kdf: %w`, err)
    57  		}
    58  
    59  		k.buf = append(k.buf, h.Sum(nil)...)
    60  		round++
    61  	}
    62  
    63  	n := copy(out, k.buf[:len(out)])
    64  	k.buf = k.buf[len(out):]
    65  	return n, nil
    66  }