github.com/emmansun/gmsm@v0.29.1/ecdh/sm2ec.go (about)

     1  package ecdh
     2  
     3  import (
     4  	"encoding/binary"
     5  	"errors"
     6  	"hash"
     7  	"io"
     8  	"math/bits"
     9  
    10  	"github.com/emmansun/gmsm/internal/randutil"
    11  	sm2ec "github.com/emmansun/gmsm/internal/sm2ec"
    12  	"github.com/emmansun/gmsm/internal/subtle"
    13  )
    14  
    15  type sm2Curve struct {
    16  	name              string
    17  	newPoint          func() *sm2ec.SM2P256Point
    18  	scalarOrderMinus1 []byte
    19  	constantA         []byte
    20  	constantB         []byte
    21  	generator         []byte
    22  }
    23  
    24  func (c *sm2Curve) String() string {
    25  	return c.name
    26  }
    27  
    28  func (c *sm2Curve) GenerateKey(rand io.Reader) (*PrivateKey, error) {
    29  	key := make([]byte, len(c.scalarOrderMinus1))
    30  	randutil.MaybeReadByte(rand)
    31  
    32  	for {
    33  		if _, err := io.ReadFull(rand, key); err != nil {
    34  			return nil, err
    35  		}
    36  
    37  		// In tests, rand will return all zeros and NewPrivateKey will reject
    38  		// the zero key as it generates the identity as a public key. This also
    39  		// makes this function consistent with crypto/elliptic.GenerateKey.
    40  		key[1] ^= 0x42
    41  
    42  		k, err := c.NewPrivateKey(key)
    43  		if err == errInvalidPrivateKey {
    44  			continue
    45  		}
    46  		return k, err
    47  	}
    48  }
    49  
    50  func (c *sm2Curve) NewPrivateKey(key []byte) (*PrivateKey, error) {
    51  	if len(key) != len(c.scalarOrderMinus1) {
    52  		return nil, errors.New("ecdh: invalid private key size")
    53  	}
    54  	if subtle.ConstantTimeAllZero(key) == 1 || !isLess(key, c.scalarOrderMinus1) {
    55  		return nil, errInvalidPrivateKey
    56  	}
    57  	return &PrivateKey{
    58  		curve:      c,
    59  		privateKey: append([]byte{}, key...),
    60  	}, nil
    61  }
    62  
    63  func (c *sm2Curve) privateKeyToPublicKey(key *PrivateKey) *PublicKey {
    64  	if key.curve != c {
    65  		panic("ecdh: internal error: converting the wrong key type")
    66  	}
    67  	p, err := c.newPoint().ScalarBaseMult(key.privateKey)
    68  	if err != nil {
    69  		// This is unreachable because the only error condition of
    70  		// ScalarBaseMult is if the input is not the right size.
    71  		panic("ecdh: internal error: sm2ec ScalarBaseMult failed for a fixed-size input")
    72  	}
    73  	publicKey := p.Bytes()
    74  	if len(publicKey) == 1 {
    75  		// The encoding of the identity is a single 0x00 byte. This is
    76  		// unreachable because the only scalar that generates the identity is
    77  		// zero, which is rejected by NewPrivateKey.
    78  		panic("ecdh: internal error: sm2ec ScalarBaseMult returned the identity")
    79  	}
    80  	return &PublicKey{
    81  		curve:     key.curve,
    82  		publicKey: publicKey,
    83  	}
    84  }
    85  
    86  func (c *sm2Curve) NewPublicKey(key []byte) (*PublicKey, error) {
    87  	// Reject the point at infinity and compressed encodings.
    88  	if len(key) == 0 || key[0] != 4 {
    89  		return nil, errors.New("ecdh: invalid public key")
    90  	}
    91  	// SetBytes also checks that the point is on the curve.
    92  	if _, err := c.newPoint().SetBytes(key); err != nil {
    93  		return nil, err
    94  	}
    95  
    96  	return &PublicKey{
    97  		curve:     c,
    98  		publicKey: append([]byte{}, key...),
    99  	}, nil
   100  }
   101  
   102  func (c *sm2Curve) ecdh(local *PrivateKey, remote *PublicKey) ([]byte, error) {
   103  	p, err := c.newPoint().SetBytes(remote.publicKey)
   104  	if err != nil {
   105  		return nil, err
   106  	}
   107  	if _, err := p.ScalarMult(p, local.privateKey); err != nil {
   108  		return nil, err
   109  	}
   110  	// BytesX will return an error if p is the point at infinity.
   111  	return p.BytesX()
   112  }
   113  
   114  func (c *sm2Curve) sm2avf(secret *PublicKey) []byte {
   115  	bytes := secret.publicKey[1:33]
   116  	var result [32]byte
   117  	copy(result[16:], bytes[16:])
   118  	result[16] = (result[16] & 0x7f) | 0x80
   119  
   120  	return result[:]
   121  }
   122  
   123  func (c *sm2Curve) sm2mqv(sLocal, eLocal *PrivateKey, sRemote, eRemote *PublicKey) (*PublicKey, error) {
   124  	// implicitSig: (sLocal + avf(eLocal.Pub) * ePriv) mod N
   125  	x2 := c.sm2avf(eLocal.PublicKey())
   126  	t, err := sm2ec.ImplicitSig(sLocal.privateKey, eLocal.privateKey, x2)
   127  	if err != nil {
   128  		return nil, err
   129  	}
   130  
   131  	// new base point: peerPub + [x1](peerSecret)
   132  	x1 := c.sm2avf(eRemote)
   133  	p2, err := c.newPoint().SetBytes(eRemote.publicKey)
   134  	if err != nil {
   135  		return nil, err
   136  	}
   137  	if _, err := p2.ScalarMult(p2, x1); err != nil {
   138  		return nil, err
   139  	}
   140  	p1, err := c.newPoint().SetBytes(sRemote.publicKey)
   141  	if err != nil {
   142  		return nil, err
   143  	}
   144  	p2.Add(p1, p2)
   145  
   146  	if _, err := p2.ScalarMult(p2, t); err != nil {
   147  		return nil, err
   148  	}
   149  	return c.NewPublicKey(p2.Bytes())
   150  }
   151  
   152  var defaultUID = []byte{0x31, 0x32, 0x33, 0x34, 0x35, 0x36, 0x37, 0x38, 0x31, 0x32, 0x33, 0x34, 0x35, 0x36, 0x37, 0x38}
   153  
   154  // CalculateZA ZA = H256(ENTLA || IDA || a || b || xG || yG || xA || yA).
   155  // Compliance with GB/T 32918.2-2016 5.5
   156  func (c *sm2Curve) sm2za(md hash.Hash, pub *PublicKey, uid []byte) ([]byte, error) {
   157  	if len(uid) == 0 {
   158  		uid = defaultUID
   159  	}
   160  	uidLen := len(uid)
   161  	if uidLen >= 0x2000 {
   162  		return nil, errors.New("ecdh: the uid is too long")
   163  	}
   164  	entla := uint16(uidLen) << 3
   165  	md.Write([]byte{byte(entla >> 8), byte(entla)})
   166  	if uidLen > 0 {
   167  		md.Write(uid)
   168  	}
   169  	md.Write(c.constantA)
   170  	md.Write(c.constantB)
   171  	md.Write(c.generator)
   172  	md.Write(pub.publicKey[1:])
   173  
   174  	return md.Sum(nil), nil
   175  }
   176  
   177  // P256 returns a Curve which implements SM2, also known as sm2p256v1
   178  //
   179  // Multiple invocations of this function will return the same value, so it can
   180  // be used for equality checks and switch statements.
   181  func P256() Curve { return sm2P256 }
   182  
   183  var sm2P256 = &sm2Curve{
   184  	name:              "sm2p256v1",
   185  	newPoint:          sm2ec.NewSM2P256Point,
   186  	scalarOrderMinus1: sm2P256OrderMinus1,
   187  	generator:         sm2Generator,
   188  	constantA:         sm2ConstantA,
   189  	constantB:         sm2ConstantB,
   190  }
   191  
// sm2P256OrderMinus1 is n-1, big-endian, where n is the order of the SM2 base
// point. NewPrivateKey accepts only nonzero scalars strictly below this value.
var sm2P256OrderMinus1 = []byte{
	0xff, 0xff, 0xff, 0xfe, 0xff, 0xff, 0xff, 0xff,
	0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
	0x72, 0x03, 0xdf, 0x6b, 0x21, 0xc6, 0x05, 0x2b,
	0x53, 0xbb, 0xf4, 0x09, 0x39, 0xd5, 0x41, 0x22}

// sm2Generator is the SM2 base point G as xG || yG (32 big-endian bytes each,
// no prefix byte). sm2za writes it into ZA.
var sm2Generator = []byte{
	0x32, 0xc4, 0xae, 0x2c, 0x1f, 0x19, 0x81, 0x19,
	0x5f, 0x99, 0x4, 0x46, 0x6a, 0x39, 0xc9, 0x94,
	0x8f, 0xe3, 0xb, 0xbf, 0xf2, 0x66, 0xb, 0xe1,
	0x71, 0x5a, 0x45, 0x89, 0x33, 0x4c, 0x74, 0xc7,
	0xbc, 0x37, 0x36, 0xa2, 0xf4, 0xf6, 0x77, 0x9c,
	0x59, 0xbd, 0xce, 0xe3, 0x6b, 0x69, 0x21, 0x53,
	0xd0, 0xa9, 0x87, 0x7c, 0xc6, 0x2a, 0x47, 0x40,
	0x2, 0xdf, 0x32, 0xe5, 0x21, 0x39, 0xf0, 0xa0}

// sm2ConstantA is the curve coefficient a of y^2 = x^3 + ax + b, big-endian.
// sm2za writes it into ZA.
var sm2ConstantA = []byte{
	0xff, 0xff, 0xff, 0xfe, 0xff, 0xff, 0xff, 0xff,
	0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
	0xff, 0xff, 0xff, 0xff, 0x00, 0x00, 0x00, 0x00,
	0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xfc}

// sm2ConstantB is the curve coefficient b of y^2 = x^3 + ax + b, big-endian.
// sm2za writes it into ZA.
var sm2ConstantB = []byte{
	0x28, 0xe9, 0xfa, 0x9e, 0x9d, 0x9f, 0x5e, 0x34,
	0x4d, 0x5a, 0x9e, 0x4b, 0xcf, 0x65, 0x09, 0xa7,
	0xf3, 0x97, 0x89, 0xf5, 0x15, 0xab, 0x8f, 0x92,
	0xdd, 0xbc, 0xbd, 0x41, 0x4d, 0x94, 0x0e, 0x93}
   216  
   217  // isLess returns whether a < b, where a and b are big-endian buffers of the
   218  // same length and shorter than 72 bytes.
   219  func isLess(a, b []byte) bool {
   220  	if len(a) != len(b) {
   221  		panic("ecdh: internal error: mismatched isLess inputs")
   222  	}
   223  
   224  	// Copy the values into a fixed-size preallocated little-endian buffer.
   225  	// 72 bytes is enough for every scalar in this package, and having a fixed
   226  	// size lets us avoid heap allocations.
   227  	if len(a) > 72 {
   228  		panic("ecdh: internal error: isLess input too large")
   229  	}
   230  	bufA, bufB := make([]byte, 72), make([]byte, 72)
   231  	for i := range a {
   232  		bufA[i], bufB[i] = a[len(a)-i-1], b[len(b)-i-1]
   233  	}
   234  
   235  	// Perform a subtraction with borrow.
   236  	var borrow uint64
   237  	for i := 0; i < len(bufA); i += 8 {
   238  		limbA, limbB := binary.LittleEndian.Uint64(bufA[i:]), binary.LittleEndian.Uint64(bufB[i:])
   239  		_, borrow = bits.Sub64(limbA, limbB, borrow)
   240  	}
   241  
   242  	// If there is a borrow at the end of the operation, then a < b.
   243  	return borrow == 1
   244  }
   245  
   246  var errInvalidPrivateKey = errors.New("ecdh: invalid private key")