github.com/cloudflare/circl@v1.5.0/hpke/kembase.go (about)

     1  package hpke
     2  
     3  import (
     4  	"crypto"
     5  	"crypto/rand"
     6  	"encoding/binary"
     7  	"io"
     8  
     9  	"github.com/cloudflare/circl/kem"
    10  	"golang.org/x/crypto/hkdf"
    11  )
    12  
// dhKEM is the set of group operations that dhKemBase needs from a concrete
// Diffie-Hellman group in order to build a KEM on top of it.
type dhKEM interface {
	// sizeDH is the length in bytes of one DH output as written by calcDH;
	// callers in this file allocate the dh buffer with this size (or twice
	// it for the authenticated variants).
	sizeDH() int
	// calcDH computes the DH value between sk and pk and writes it into dh.
	calcDH(dh []byte, sk kem.PrivateKey, pk kem.PublicKey) error
	// SeedSize is the length of the random seed that Encapsulate and
	// AuthEncapsulate generate before deriving the ephemeral key pair.
	SeedSize() int
	// DeriveKeyPair derives a key pair from seed; used here to produce the
	// ephemeral (sender) key pair during encapsulation.
	DeriveKeyPair(seed []byte) (kem.PublicKey, kem.PrivateKey)
	// UnmarshalBinaryPrivateKey decodes a private key (not used in this file).
	UnmarshalBinaryPrivateKey(data []byte) (kem.PrivateKey, error)
	// UnmarshalBinaryPublicKey decodes a public key; used here to parse the
	// encapsulation (ephemeral public key) during decapsulation.
	UnmarshalBinaryPublicKey(data []byte) (kem.PublicKey, error)
}
    21  
// kemBase holds the parameters shared by every KEM in this package: its
// identifier, a display name and the hash function used for key derivation.
type kemBase struct {
	// id is the KEM identifier; it is encoded into the "KEM" suite ID.
	id KEM
	// name is returned by Name.
	name string
	// Hash drives the labeled HKDF steps (via New) and fixes the shared
	// secret length (via Size).
	crypto.Hash
}
    27  
// dhKemBase is a Diffie-Hellman based KEM: kemBase supplies the identifier
// and hash-based key schedule, and the embedded dhKEM supplies the group
// operations (key derivation, DH, encoding).
type dhKemBase struct {
	kemBase
	dhKEM
}
    32  
    33  func (k kemBase) Name() string       { return k.name }
    34  func (k kemBase) SharedKeySize() int { return k.Hash.Size() }
    35  
    36  func (k kemBase) getSuiteID() (sid [5]byte) {
    37  	sid[0], sid[1], sid[2] = 'K', 'E', 'M'
    38  	binary.BigEndian.PutUint16(sid[3:5], uint16(k.id))
    39  	return
    40  }
    41  
    42  func (k kemBase) extractExpand(dh, kemCtx []byte) []byte {
    43  	eaePkr := k.labeledExtract([]byte(""), []byte("eae_prk"), dh)
    44  	return k.labeledExpand(
    45  		eaePkr,
    46  		[]byte("shared_secret"),
    47  		kemCtx,
    48  		uint16(k.Size()),
    49  	)
    50  }
    51  
    52  func (k kemBase) labeledExtract(salt, label, info []byte) []byte {
    53  	suiteID := k.getSuiteID()
    54  	labeledIKM := append(append(append(append(
    55  		make([]byte, 0, len(versionLabel)+len(suiteID)+len(label)+len(info)),
    56  		versionLabel...),
    57  		suiteID[:]...),
    58  		label...),
    59  		info...)
    60  	return hkdf.Extract(k.New, labeledIKM, salt)
    61  }
    62  
    63  func (k kemBase) labeledExpand(prk, label, info []byte, l uint16) []byte {
    64  	suiteID := k.getSuiteID()
    65  	labeledInfo := make(
    66  		[]byte,
    67  		2,
    68  		2+len(versionLabel)+len(suiteID)+len(label)+len(info),
    69  	)
    70  	binary.BigEndian.PutUint16(labeledInfo[0:2], l)
    71  	labeledInfo = append(append(append(append(labeledInfo,
    72  		versionLabel...),
    73  		suiteID[:]...),
    74  		label...),
    75  		info...)
    76  	b := make([]byte, l)
    77  	rd := hkdf.Expand(k.New, prk, labeledInfo)
    78  	if _, err := io.ReadFull(rd, b); err != nil {
    79  		panic(err)
    80  	}
    81  	return b
    82  }
    83  
    84  func (k dhKemBase) AuthEncapsulate(pkr kem.PublicKey, sks kem.PrivateKey) (
    85  	ct []byte, ss []byte, err error,
    86  ) {
    87  	seed := make([]byte, k.SeedSize())
    88  	_, err = io.ReadFull(rand.Reader, seed)
    89  	if err != nil {
    90  		return nil, nil, err
    91  	}
    92  
    93  	return k.authEncap(pkr, sks, seed)
    94  }
    95  
    96  func (k dhKemBase) Encapsulate(pkr kem.PublicKey) (
    97  	ct []byte, ss []byte, err error,
    98  ) {
    99  	seed := make([]byte, k.SeedSize())
   100  	_, err = io.ReadFull(rand.Reader, seed)
   101  	if err != nil {
   102  		return nil, nil, err
   103  	}
   104  
   105  	return k.encap(pkr, seed)
   106  }
   107  
   108  func (k dhKemBase) AuthEncapsulateDeterministically(
   109  	pkr kem.PublicKey, sks kem.PrivateKey, seed []byte,
   110  ) (ct, ss []byte, err error) {
   111  	return k.authEncap(pkr, sks, seed)
   112  }
   113  
   114  func (k dhKemBase) EncapsulateDeterministically(
   115  	pkr kem.PublicKey, seed []byte,
   116  ) (ct, ss []byte, err error) {
   117  	return k.encap(pkr, seed)
   118  }
   119  
   120  func (k dhKemBase) encap(
   121  	pkR kem.PublicKey,
   122  	seed []byte,
   123  ) (ct []byte, ss []byte, err error) {
   124  	dh := make([]byte, k.sizeDH())
   125  	enc, kemCtx, err := k.coreEncap(dh, pkR, seed)
   126  	if err != nil {
   127  		return nil, nil, err
   128  	}
   129  	ss = k.extractExpand(dh, kemCtx)
   130  	return enc, ss, nil
   131  }
   132  
   133  func (k dhKemBase) authEncap(
   134  	pkR kem.PublicKey,
   135  	skS kem.PrivateKey,
   136  	seed []byte,
   137  ) (ct []byte, ss []byte, err error) {
   138  	dhLen := k.sizeDH()
   139  	dh := make([]byte, 2*dhLen)
   140  	enc, kemCtx, err := k.coreEncap(dh[:dhLen], pkR, seed)
   141  	if err != nil {
   142  		return nil, nil, err
   143  	}
   144  
   145  	err = k.calcDH(dh[dhLen:], skS, pkR)
   146  	if err != nil {
   147  		return nil, nil, err
   148  	}
   149  
   150  	pkS := skS.Public()
   151  	pkSm, err := pkS.MarshalBinary()
   152  	if err != nil {
   153  		return nil, nil, err
   154  	}
   155  	kemCtx = append(kemCtx, pkSm...)
   156  
   157  	ss = k.extractExpand(dh, kemCtx)
   158  	return enc, ss, nil
   159  }
   160  
   161  func (k dhKemBase) coreEncap(
   162  	dh []byte,
   163  	pkR kem.PublicKey,
   164  	seed []byte,
   165  ) (enc []byte, kemCtx []byte, err error) {
   166  	pkE, skE := k.DeriveKeyPair(seed)
   167  	err = k.calcDH(dh, skE, pkR)
   168  	if err != nil {
   169  		return nil, nil, err
   170  	}
   171  
   172  	enc, err = pkE.MarshalBinary()
   173  	if err != nil {
   174  		return nil, nil, err
   175  	}
   176  	pkRm, err := pkR.MarshalBinary()
   177  	if err != nil {
   178  		return nil, nil, err
   179  	}
   180  	kemCtx = append(append([]byte{}, enc...), pkRm...)
   181  
   182  	return enc, kemCtx, nil
   183  }
   184  
   185  func (k dhKemBase) Decapsulate(skr kem.PrivateKey, ct []byte) ([]byte, error) {
   186  	dh := make([]byte, k.sizeDH())
   187  	kemCtx, err := k.coreDecap(dh, skr, ct)
   188  	if err != nil {
   189  		return nil, err
   190  	}
   191  	return k.extractExpand(dh, kemCtx), nil
   192  }
   193  
   194  func (k dhKemBase) AuthDecapsulate(
   195  	skR kem.PrivateKey,
   196  	ct []byte,
   197  	pkS kem.PublicKey,
   198  ) ([]byte, error) {
   199  	dhLen := k.sizeDH()
   200  	dh := make([]byte, 2*dhLen)
   201  	kemCtx, err := k.coreDecap(dh[:dhLen], skR, ct)
   202  	if err != nil {
   203  		return nil, err
   204  	}
   205  
   206  	err = k.calcDH(dh[dhLen:], skR, pkS)
   207  	if err != nil {
   208  		return nil, err
   209  	}
   210  
   211  	pkSm, err := pkS.MarshalBinary()
   212  	if err != nil {
   213  		return nil, err
   214  	}
   215  	kemCtx = append(kemCtx, pkSm...)
   216  	return k.extractExpand(dh, kemCtx), nil
   217  }
   218  
   219  func (k dhKemBase) coreDecap(
   220  	dh []byte,
   221  	skR kem.PrivateKey,
   222  	ct []byte,
   223  ) ([]byte, error) {
   224  	pkE, err := k.UnmarshalBinaryPublicKey(ct)
   225  	if err != nil {
   226  		return nil, err
   227  	}
   228  
   229  	err = k.calcDH(dh, skR, pkE)
   230  	if err != nil {
   231  		return nil, err
   232  	}
   233  
   234  	pkR := skR.Public()
   235  	pkRm, err := pkR.MarshalBinary()
   236  	if err != nil {
   237  		return nil, err
   238  	}
   239  
   240  	return append(append([]byte{}, ct...), pkRm...), nil
   241  }