github.com/jonasnick/go-ethereum@v0.7.12-0.20150216215225-22176f05d387/crypto/ecies/ecies.go (about)

     1  package ecies
     2  
     3  import (
     4  	"crypto/cipher"
     5  	"crypto/ecdsa"
     6  	"crypto/elliptic"
     7  	"crypto/hmac"
     8  	"crypto/subtle"
     9  	"fmt"
    10  	"hash"
    11  	"io"
    12  	"math/big"
    13  )
    14  
    15  var (
    16  	ErrImport                     = fmt.Errorf("ecies: failed to import key")
    17  	ErrInvalidCurve               = fmt.Errorf("ecies: invalid elliptic curve")
    18  	ErrInvalidParams              = fmt.Errorf("ecies: invalid ECIES parameters")
    19  	ErrInvalidPublicKey           = fmt.Errorf("ecies: invalid public key")
    20  	ErrSharedKeyIsPointAtInfinity = fmt.Errorf("ecies: shared key is point at infinity")
    21  	ErrSharedKeyTooBig            = fmt.Errorf("ecies: shared key params are too big")
    22  )
    23  
    24  // PublicKey is a representation of an elliptic curve public key.
    25  type PublicKey struct {
    26  	X *big.Int
    27  	Y *big.Int
    28  	elliptic.Curve
    29  	Params *ECIESParams
    30  }
    31  
    32  // Export an ECIES public key as an ECDSA public key.
    33  func (pub *PublicKey) ExportECDSA() *ecdsa.PublicKey {
    34  	return &ecdsa.PublicKey{pub.Curve, pub.X, pub.Y}
    35  }
    36  
    37  // Import an ECDSA public key as an ECIES public key.
    38  func ImportECDSAPublic(pub *ecdsa.PublicKey) *PublicKey {
    39  	return &PublicKey{
    40  		X:      pub.X,
    41  		Y:      pub.Y,
    42  		Curve:  pub.Curve,
    43  		Params: ParamsFromCurve(pub.Curve),
    44  	}
    45  }
    46  
     // PrivateKey is an elliptic curve private key for ECIES. It embeds the
     // matching PublicKey, so the curve, public point and ECIES parameters are
     // reachable directly from the private key.
     //
     // Field order matters: ImportECDSA builds this struct with an unkeyed
     // literal.
     type PrivateKey struct {
     	PublicKey
     	// D is the private scalar (GenerateKey fills it from
     	// elliptic.GenerateKey; GenerateShared multiplies by it).
     	D *big.Int
     }
    52  
    53  // Export an ECIES private key as an ECDSA private key.
    54  func (prv *PrivateKey) ExportECDSA() *ecdsa.PrivateKey {
    55  	pub := &prv.PublicKey
    56  	pubECDSA := pub.ExportECDSA()
    57  	return &ecdsa.PrivateKey{*pubECDSA, prv.D}
    58  }
    59  
    60  // Import an ECDSA private key as an ECIES private key.
    61  func ImportECDSA(prv *ecdsa.PrivateKey) *PrivateKey {
    62  	pub := ImportECDSAPublic(&prv.PublicKey)
    63  	return &PrivateKey{*pub, prv.D}
    64  }
    65  
    66  // Generate an elliptic curve public / private keypair. If params is nil,
    67  // the recommended default paramters for the key will be chosen.
    68  func GenerateKey(rand io.Reader, curve elliptic.Curve, params *ECIESParams) (prv *PrivateKey, err error) {
    69  	pb, x, y, err := elliptic.GenerateKey(curve, rand)
    70  	if err != nil {
    71  		return
    72  	}
    73  	prv = new(PrivateKey)
    74  	prv.PublicKey.X = x
    75  	prv.PublicKey.Y = y
    76  	prv.PublicKey.Curve = curve
    77  	prv.D = new(big.Int).SetBytes(pb)
    78  	if params == nil {
    79  		params = ParamsFromCurve(curve)
    80  	}
    81  	prv.PublicKey.Params = params
    82  	return
    83  }
    84  
    85  // MaxSharedKeyLength returns the maximum length of the shared key the
    86  // public key can produce.
    87  func MaxSharedKeyLength(pub *PublicKey) int {
    88  	return (pub.Curve.Params().BitSize + 7) / 8
    89  }
    90  
    91  // ECDH key agreement method used to establish secret keys for encryption.
    92  func (prv *PrivateKey) GenerateShared(pub *PublicKey, skLen, macLen int) (sk []byte, err error) {
    93  	if prv.PublicKey.Curve != pub.Curve {
    94  		return nil, ErrInvalidCurve
    95  	}
    96  	if skLen+macLen > MaxSharedKeyLength(pub) {
    97  		return nil, ErrSharedKeyTooBig
    98  	}
    99  	x, _ := pub.Curve.ScalarMult(pub.X, pub.Y, prv.D.Bytes())
   100  	if x == nil {
   101  		return nil, ErrSharedKeyIsPointAtInfinity
   102  	}
   103  
   104  	sk = make([]byte, skLen+macLen)
   105  	skBytes := x.Bytes()
   106  	copy(sk[len(sk)-len(skBytes):], skBytes)
   107  	return sk, nil
   108  }
   109  
   110  var (
   111  	ErrKeyDataTooLong = fmt.Errorf("ecies: can't supply requested key data")
   112  	ErrSharedTooLong  = fmt.Errorf("ecies: shared secret is too long")
   113  	ErrInvalidMessage = fmt.Errorf("ecies: invalid message")
   114  )
   115  
   116  var (
   117  	big2To32   = new(big.Int).Exp(big.NewInt(2), big.NewInt(32), nil)
   118  	big2To32M1 = new(big.Int).Sub(big2To32, big.NewInt(1))
   119  )
   120  
   121  func incCounter(ctr []byte) {
   122  	if ctr[3]++; ctr[3] != 0 {
   123  		return
   124  	} else if ctr[2]++; ctr[2] != 0 {
   125  		return
   126  	} else if ctr[1]++; ctr[1] != 0 {
   127  		return
   128  	} else if ctr[0]++; ctr[0] != 0 {
   129  		return
   130  	}
   131  	return
   132  }
   133  
   134  // NIST SP 800-56 Concatenation Key Derivation Function (see section 5.8.1).
   135  func concatKDF(hash hash.Hash, z, s1 []byte, kdLen int) (k []byte, err error) {
   136  	if s1 == nil {
   137  		s1 = make([]byte, 0)
   138  	}
   139  
   140  	reps := ((kdLen + 7) * 8) / (hash.BlockSize() * 8)
   141  	if big.NewInt(int64(reps)).Cmp(big2To32M1) > 0 {
   142  		fmt.Println(big2To32M1)
   143  		return nil, ErrKeyDataTooLong
   144  	}
   145  
   146  	counter := []byte{0, 0, 0, 1}
   147  	k = make([]byte, 0)
   148  
   149  	for i := 0; i <= reps; i++ {
   150  		hash.Write(counter)
   151  		hash.Write(z)
   152  		hash.Write(s1)
   153  		k = append(k, hash.Sum(nil)...)
   154  		hash.Reset()
   155  		incCounter(counter)
   156  	}
   157  
   158  	k = k[:kdLen]
   159  	return
   160  }
   161  
   162  // messageTag computes the MAC of a message (called the tag) as per
   163  // SEC 1, 3.5.
   164  func messageTag(hash func() hash.Hash, km, msg, shared []byte) []byte {
   165  	if shared == nil {
   166  		shared = make([]byte, 0)
   167  	}
   168  	mac := hmac.New(hash, km)
   169  	mac.Write(msg)
   170  	tag := mac.Sum(nil)
   171  	return tag
   172  }
   173  
   174  // Generate an initialisation vector for CTR mode.
   175  func generateIV(params *ECIESParams, rand io.Reader) (iv []byte, err error) {
   176  	iv = make([]byte, params.BlockSize)
   177  	_, err = io.ReadFull(rand, iv)
   178  	return
   179  }
   180  
   181  // symEncrypt carries out CTR encryption using the block cipher specified in the
   182  // parameters.
   183  func symEncrypt(rand io.Reader, params *ECIESParams, key, m []byte) (ct []byte, err error) {
   184  	c, err := params.Cipher(key)
   185  	if err != nil {
   186  		return
   187  	}
   188  
   189  	iv, err := generateIV(params, rand)
   190  	if err != nil {
   191  		return
   192  	}
   193  	ctr := cipher.NewCTR(c, iv)
   194  
   195  	ct = make([]byte, len(m)+params.BlockSize)
   196  	copy(ct, iv)
   197  	ctr.XORKeyStream(ct[params.BlockSize:], m)
   198  	return
   199  }
   200  
   201  // symDecrypt carries out CTR decryption using the block cipher specified in
   202  // the parameters
   203  func symDecrypt(rand io.Reader, params *ECIESParams, key, ct []byte) (m []byte, err error) {
   204  	c, err := params.Cipher(key)
   205  	if err != nil {
   206  		return
   207  	}
   208  
   209  	ctr := cipher.NewCTR(c, ct[:params.BlockSize])
   210  
   211  	m = make([]byte, len(ct)-params.BlockSize)
   212  	ctr.XORKeyStream(m, ct[params.BlockSize:])
   213  	return
   214  }
   215  
   216  // Encrypt encrypts a message using ECIES as specified in SEC 1, 5.1. If
   217  // the shared information parameters aren't being used, they should be
   218  // nil.
   219  func Encrypt(rand io.Reader, pub *PublicKey, m, s1, s2 []byte) (ct []byte, err error) {
   220  	params := pub.Params
   221  	if params == nil {
   222  		if params = ParamsFromCurve(pub.Curve); params == nil {
   223  			err = ErrUnsupportedECIESParameters
   224  			return
   225  		}
   226  	}
   227  	R, err := GenerateKey(rand, pub.Curve, params)
   228  	if err != nil {
   229  		return
   230  	}
   231  
   232  	hash := params.Hash()
   233  	z, err := R.GenerateShared(pub, params.KeyLen, params.KeyLen)
   234  	if err != nil {
   235  		return
   236  	}
   237  	K, err := concatKDF(hash, z, s1, params.KeyLen+params.KeyLen)
   238  	if err != nil {
   239  		return
   240  	}
   241  	Ke := K[:params.KeyLen]
   242  	Km := K[params.KeyLen:]
   243  	hash.Write(Km)
   244  	Km = hash.Sum(nil)
   245  	hash.Reset()
   246  
   247  	em, err := symEncrypt(rand, params, Ke, m)
   248  	if err != nil || len(em) <= params.BlockSize {
   249  		return
   250  	}
   251  
   252  	d := messageTag(params.Hash, Km, em, s2)
   253  
   254  	Rb := elliptic.Marshal(pub.Curve, R.PublicKey.X, R.PublicKey.Y)
   255  	ct = make([]byte, len(Rb)+len(em)+len(d))
   256  	copy(ct, Rb)
   257  	copy(ct[len(Rb):], em)
   258  	copy(ct[len(Rb)+len(em):], d)
   259  	return
   260  }
   261  
   262  // Decrypt decrypts an ECIES ciphertext.
   263  func (prv *PrivateKey) Decrypt(rand io.Reader, c, s1, s2 []byte) (m []byte, err error) {
   264  	if c == nil || len(c) == 0 {
   265  		err = ErrInvalidMessage
   266  		return
   267  	}
   268  	params := prv.PublicKey.Params
   269  	if params == nil {
   270  		if params = ParamsFromCurve(prv.PublicKey.Curve); params == nil {
   271  			err = ErrUnsupportedECIESParameters
   272  			return
   273  		}
   274  	}
   275  	hash := params.Hash()
   276  
   277  	var (
   278  		rLen   int
   279  		hLen   int = hash.Size()
   280  		mStart int
   281  		mEnd   int
   282  	)
   283  
   284  	switch c[0] {
   285  	case 2, 3, 4:
   286  		rLen = ((prv.PublicKey.Curve.Params().BitSize + 7) / 4)
   287  		if len(c) < (rLen + hLen + 1) {
   288  			err = ErrInvalidMessage
   289  			return
   290  		}
   291  	default:
   292  		err = ErrInvalidPublicKey
   293  		return
   294  	}
   295  
   296  	mStart = rLen
   297  	mEnd = len(c) - hLen
   298  
   299  	R := new(PublicKey)
   300  	R.Curve = prv.PublicKey.Curve
   301  	R.X, R.Y = elliptic.Unmarshal(R.Curve, c[:rLen])
   302  	if R.X == nil {
   303  		err = ErrInvalidPublicKey
   304  		return
   305  	}
   306  
   307  	z, err := prv.GenerateShared(R, params.KeyLen, params.KeyLen)
   308  	if err != nil {
   309  		return
   310  	}
   311  
   312  	K, err := concatKDF(hash, z, s1, params.KeyLen+params.KeyLen)
   313  	if err != nil {
   314  		return
   315  	}
   316  
   317  	Ke := K[:params.KeyLen]
   318  	Km := K[params.KeyLen:]
   319  	hash.Write(Km)
   320  	Km = hash.Sum(nil)
   321  	hash.Reset()
   322  
   323  	d := messageTag(params.Hash, Km, c[mStart:mEnd], s2)
   324  	if subtle.ConstantTimeCompare(c[mEnd:], d) != 1 {
   325  		err = ErrInvalidMessage
   326  		return
   327  	}
   328  
   329  	m, err = symDecrypt(rand, params, Ke, c[mStart:mEnd])
   330  	return
   331  }