github.com/emmansun/gmsm@v0.29.1/sm2/sm2_keyexchange_sample_test.go

     1  package sm2
     2  
     3  import (
     4  	"bytes"
     5  	"crypto/ecdsa"
     6  	"crypto/elliptic"
     7  	"encoding/hex"
     8  	"errors"
     9  	"math/big"
    10  	"testing"
    11  
    12  	"github.com/emmansun/gmsm/sm3"
    13  )
    14  
// CurveParams describes a short Weierstrass curve y² = x³ + a·x + b modulo P.
// It embeds elliptic.CurveParams, which has no field for the coefficient a
// (its generic methods assume a = -3), and adds A so that curves with an
// arbitrary a, such as the sample curve used by these tests, can be used.
type CurveParams struct {
	elliptic.CurveParams
	A *big.Int // the coefficient a in y² = x³ + a·x + b
}
    19  
    20  // polynomial returns x³ +ax + b.
    21  func (curve *CurveParams) polynomial(x *big.Int) *big.Int {
    22  	x3 := new(big.Int).Mul(x, x)
    23  	x3.Mul(x3, x)
    24  
    25  	aX := new(big.Int).Mul(curve.A, x)
    26  
    27  	x3.Add(x3, aX)
    28  	x3.Add(x3, curve.B)
    29  	x3.Mod(x3, curve.P)
    30  
    31  	return x3
    32  }
    33  
    34  func (curve *CurveParams) IsOnCurve(x, y *big.Int) bool {
    35  	if x.Sign() < 0 || x.Cmp(curve.P) >= 0 ||
    36  		y.Sign() < 0 || y.Cmp(curve.P) >= 0 {
    37  		return false
    38  	}
    39  
    40  	// y² = x³ + ax + b
    41  	y2 := new(big.Int).Mul(y, y)
    42  	y2.Mod(y2, curve.P)
    43  
    44  	return curve.polynomial(x).Cmp(y2) == 0
    45  }
    46  
    47  // zForAffine returns a Jacobian Z value for the affine point (x, y). If x and
    48  // y are zero, it assumes that they represent the point at infinity because (0,
    49  // 0) is not on the any of the curves handled here.
    50  func zForAffine(x, y *big.Int) *big.Int {
    51  	z := new(big.Int)
    52  	if x.Sign() != 0 || y.Sign() != 0 {
    53  		z.SetInt64(1)
    54  	}
    55  	return z
    56  }
    57  
    58  // affineFromJacobian reverses the Jacobian transform. See the comment at the
    59  // top of the file. If the point is ∞ it returns 0, 0.
    60  func (curve *CurveParams) affineFromJacobian(x, y, z *big.Int) (xOut, yOut *big.Int) {
    61  	if z.Sign() == 0 {
    62  		return new(big.Int), new(big.Int)
    63  	}
    64  
    65  	zinv := new(big.Int).ModInverse(z, curve.P)
    66  	zinvsq := new(big.Int).Mul(zinv, zinv)
    67  
    68  	xOut = new(big.Int).Mul(x, zinvsq)
    69  	xOut.Mod(xOut, curve.P)
    70  	zinvsq.Mul(zinvsq, zinv)
    71  	yOut = new(big.Int).Mul(y, zinvsq)
    72  	yOut.Mod(yOut, curve.P)
    73  	return
    74  }
    75  
    76  func (curve *CurveParams) Add(x1, y1, x2, y2 *big.Int) (*big.Int, *big.Int) {
    77  	z1 := zForAffine(x1, y1)
    78  	z2 := zForAffine(x2, y2)
    79  	return curve.affineFromJacobian(curve.addJacobian(x1, y1, z1, x2, y2, z2))
    80  }
    81  
// addJacobian takes two points in Jacobian coordinates, (x1, y1, z1) and
// (x2, y2, z2) and returns their sum, also in Jacobian form.
//
// A point with z == 0 is treated as the point at infinity, so the other
// operand is returned (as a copy). Equal inputs are delegated to
// doubleJacobian because the addition formula degenerates when H and r are
// both zero. When only H is zero (P2 = -P1), the returned Z is 0 (infinity).
// The inputs are never modified; all results are freshly allocated.
func (curve *CurveParams) addJacobian(x1, y1, z1, x2, y2, z2 *big.Int) (*big.Int, *big.Int, *big.Int) {
	// See https://hyperelliptic.org/EFD/g1p/data/shortw/jacobian/addition/add-2007-bl
	x3, y3, z3 := new(big.Int), new(big.Int), new(big.Int)
	if z1.Sign() == 0 {
		x3.Set(x2)
		y3.Set(y2)
		z3.Set(z2)
		return x3, y3, z3
	}
	if z2.Sign() == 0 {
		x3.Set(x1)
		y3.Set(y1)
		z3.Set(z1)
		return x3, y3, z3
	}

	// Z1Z1 = Z1², Z2Z2 = Z2²
	z1z1 := new(big.Int).Mul(z1, z1)
	z1z1.Mod(z1z1, curve.P)
	z2z2 := new(big.Int).Mul(z2, z2)
	z2z2.Mod(z2z2, curve.P)

	// U1 = X1·Z2Z2, U2 = X2·Z1Z1, H = U2 - U1 (lifted into [0, P))
	u1 := new(big.Int).Mul(x1, z2z2)
	u1.Mod(u1, curve.P)
	u2 := new(big.Int).Mul(x2, z1z1)
	u2.Mod(u2, curve.P)
	h := new(big.Int).Sub(u2, u1)
	xEqual := h.Sign() == 0
	if h.Sign() == -1 {
		h.Add(h, curve.P)
	}
	// I = (2·H)², J = H·I
	i := new(big.Int).Lsh(h, 1)
	i.Mul(i, i)
	j := new(big.Int).Mul(h, i)

	// S1 = Y1·Z2·Z2Z2, S2 = Y2·Z1·Z1Z1, r = S2 - S1 (doubled below)
	s1 := new(big.Int).Mul(y1, z2)
	s1.Mul(s1, z2z2)
	s1.Mod(s1, curve.P)
	s2 := new(big.Int).Mul(y2, z1)
	s2.Mul(s2, z1z1)
	s2.Mod(s2, curve.P)
	r := new(big.Int).Sub(s2, s1)
	if r.Sign() == -1 {
		r.Add(r, curve.P)
	}
	yEqual := r.Sign() == 0
	if xEqual && yEqual {
		// Same point: the addition formula does not apply, use doubling.
		return curve.doubleJacobian(x1, y1, z1)
	}
	r.Lsh(r, 1)
	// V = U1·I
	v := new(big.Int).Mul(u1, i)

	// X3 = r² - J - 2·V
	x3.Set(r)
	x3.Mul(x3, x3)
	x3.Sub(x3, j)
	x3.Sub(x3, v)
	x3.Sub(x3, v)
	x3.Mod(x3, curve.P)

	// Y3 = r·(V - X3) - 2·S1·J  (v and s1 are reused in place here)
	y3.Set(r)
	v.Sub(v, x3)
	y3.Mul(y3, v)
	s1.Mul(s1, j)
	s1.Lsh(s1, 1)
	y3.Sub(y3, s1)
	y3.Mod(y3, curve.P)

	// Z3 = ((Z1 + Z2)² - Z1Z1 - Z2Z2)·H
	z3.Add(z1, z2)
	z3.Mul(z3, z3)
	z3.Sub(z3, z1z1)
	z3.Sub(z3, z2z2)
	z3.Mul(z3, h)
	z3.Mod(z3, curve.P)

	return x3, y3, z3
}
   159  
   160  func (curve *CurveParams) Double(x1, y1 *big.Int) (*big.Int, *big.Int) {
   161  	z1 := zForAffine(x1, y1)
   162  	return curve.affineFromJacobian(curve.doubleJacobian(x1, y1, z1))
   163  }
   164  
// doubleJacobian takes a point in Jacobian coordinates, (x, y, z), and
// returns its double, also in Jacobian form.
//
// Unlike the a = -3 specialised formulas, this one uses the curve's explicit
// coefficient A (M = 3·XX + a·ZZ²), so it works for any short Weierstrass
// curve. An input with z == 0 (the point at infinity) yields Z3 == 0.
// The inputs are never modified; all results are freshly allocated.
func (curve *CurveParams) doubleJacobian(x, y, z *big.Int) (*big.Int, *big.Int, *big.Int) {
	// See https://hyperelliptic.org/EFD/g1p/data/shortw/jacobian/doubling/dbl-2007-bl
	// XX = X1², YY = Y1², YYYY = YY², ZZ = Z1²
	xx := new(big.Int).Mul(x, x)
	xx.Mod(xx, curve.P)
	yy := new(big.Int).Mul(y, y)
	yy.Mod(yy, curve.P)
	yyyy := new(big.Int).Mul(yy, yy)
	yyyy.Mod(yyyy, curve.P)
	zz := new(big.Int).Mul(z, z)
	zz.Mod(zz, curve.P)

	// S = 2·((X1 + YY)² - XX - YYYY)
	s := new(big.Int).Add(x, yy)
	s.Mul(s, s)
	s.Sub(s, xx)
	if s.Sign() == -1 {
		s.Add(s, curve.P)
	}
	s.Sub(s, yyyy)
	if s.Sign() == -1 {
		s.Add(s, curve.P)
	}
	s.Lsh(s, 1)
	s.Mod(s, curve.P)

	// M = 3·XX + a·ZZ²
	m := new(big.Int).Mul(xx, big.NewInt(3))
	m.Mod(m, curve.P)
	tmp := new(big.Int).Mul(zz, zz)
	tmp.Mul(tmp, curve.A)
	tmp.Mod(tmp, curve.P)
	m.Add(m, tmp)
	m.Mod(m, curve.P)

	// T = M² - 2·S, and X3 = T
	t := new(big.Int).Mul(m, m)
	t.Sub(t, s)
	if t.Sign() == -1 {
		t.Add(t, curve.P)
	}
	t.Sub(t, s)
	if t.Sign() == -1 {
		t.Add(t, curve.P)
	}
	t.Mod(t, curve.P)
	x3 := t

	// Y3 = M·(S - T) - 8·YYYY
	y3 := new(big.Int).Sub(s, t)
	y3.Mul(y3, m)
	yyyy.Lsh(yyyy, 3)
	y3.Sub(y3, yyyy)
	if y3.Sign() == -1 {
		y3.Add(y3, curve.P)
	}
	y3.Mod(y3, curve.P)

	// Z3 = (Y1 + Z1)² - YY - ZZ  (= 2·Y1·Z1)
	z3 := new(big.Int).Add(y, z)
	z3.Mul(z3, z3)
	z3.Sub(z3, yy)
	if z3.Sign() == -1 {
		z3.Add(z3, curve.P)
	}
	z3.Sub(z3, zz)
	if z3.Sign() == -1 {
		z3.Add(z3, curve.P)
	}
	z3.Mod(z3, curve.P)

	return x3, y3, z3
}
   234  
   235  func (curve *CurveParams) ScalarMult(Bx, By *big.Int, k []byte) (*big.Int, *big.Int) {
   236  	Bz := new(big.Int).SetInt64(1)
   237  	x, y, z := new(big.Int), new(big.Int), new(big.Int)
   238  
   239  	for _, byte := range k {
   240  		for bitNum := 0; bitNum < 8; bitNum++ {
   241  			x, y, z = curve.doubleJacobian(x, y, z)
   242  			if byte&0x80 == 0x80 {
   243  				x, y, z = curve.addJacobian(Bx, By, Bz, x, y, z)
   244  			}
   245  			byte <<= 1
   246  		}
   247  	}
   248  
   249  	return curve.affineFromJacobian(x, y, z)
   250  }
   251  
   252  func (curve *CurveParams) ScalarBaseMult(k []byte) (*big.Int, *big.Int) {
   253  	return curve.ScalarMult(curve.Gx, curve.Gy, k)
   254  }
   255  
   256  func bigFromHex(s string) *big.Int {
   257  	b, ok := new(big.Int).SetString(s, 16)
   258  	if !ok {
   259  		panic("sm2/elliptic: internal error: invalid encoding")
   260  	}
   261  	return b
   262  }
   263  
   264  var sampleParams = &CurveParams{
   265  	elliptic.CurveParams{
   266  		Name:    "sampleCurve",
   267  		BitSize: 256,
   268  		P:       bigFromHex("8542D69E4C044F18E8B92435BF6FF7DE457283915C45517D722EDB8B08F1DFC3"),
   269  		N:       bigFromHex("8542D69E4C044F18E8B92435BF6FF7DD297720630485628D5AE74EE7C32E79B7"),
   270  		B:       bigFromHex("63E4C6D3B23B0C849CF84241484BFE48F61D59A5B16BA06E6E12D1DA27C5249A"),
   271  		Gx:      bigFromHex("421DEBD61B62EAB6746434EBC3CC315E32220B3BADD50BDC4C4E6C147FEDD43D"),
   272  		Gy:      bigFromHex("0680512BCBB42C07D47349D2153B70C4E5D7FDFCBFA36EA1A85841B9E46E09A2"),
   273  	},
   274  	bigFromHex("787968B4FA32C3FD2417842E73BBFEFF2F3C848B6831D7E0EC65228B3937E498"),
   275  }
   276  
   277  func TestPublicKey(t *testing.T) {
   278  	d := bigFromHex("6FCBA2EF9AE0AB902BC3BDE3FF915D44BA4CC78F88E2F8E7F8996D3B8CCEEDEE")
   279  	x, y := sampleParams.ScalarBaseMult(d.Bytes())
   280  	if hex.EncodeToString(x.Bytes()) != "3099093bf3c137d8fcbbcdf4a2ae50f3b0f216c3122d79425fe03a45dbfe1655" ||
   281  		hex.EncodeToString(y.Bytes()) != "3df79e8dac1cf0ecbaa2f2b49d51a4b387f2efaf482339086a27a8e05baed98b" {
   282  		t.FailNow()
   283  	}
   284  	d = bigFromHex("5E35D7D3F3C54DBAC72E61819E730B019A84208CA3A35E4C2E353DFCCB2A3B53")
   285  	x, y = sampleParams.ScalarBaseMult(d.Bytes())
   286  	if hex.EncodeToString(x.Bytes()) != "245493d446c38d8cc0f118374690e7df633a8a4bfb3329b5ece604b2b4f37f43" ||
   287  		hex.EncodeToString(y.Bytes()) != "53c0869f4b9e17773de68fec45e14904e0dea45bf6cecf9918c85ea047c60a4c" {
   288  		t.FailNow()
   289  	}
   290  }
   291  
   292  // calculateZA ZA = H256(ENTLA || IDA || a || b || xG || yG || xA || yA)
   293  func calculateSampleZA(pub *ecdsa.PublicKey, a *big.Int, uid []byte) ([]byte, error) {
   294  	uidLen := len(uid)
   295  	if uidLen >= 0x2000 {
   296  		return nil, errors.New("sm2: the uid is too long")
   297  	}
   298  	entla := uint16(uidLen) << 3
   299  	md := sm3.New()
   300  	md.Write([]byte{byte(entla >> 8), byte(entla)})
   301  	if uidLen > 0 {
   302  		md.Write(uid)
   303  	}
   304  	md.Write(toBytes(pub.Curve, a))
   305  	md.Write(toBytes(pub.Curve, pub.Params().B))
   306  	md.Write(toBytes(pub.Curve, pub.Params().Gx))
   307  	md.Write(toBytes(pub.Curve, pub.Params().Gy))
   308  	md.Write(toBytes(pub.Curve, pub.X))
   309  	md.Write(toBytes(pub.Curve, pub.Y))
   310  	return md.Sum(nil), nil
   311  }
   312  
   313  // Sample from Appendix A.2
   314  func TestKeyExchangeRealSample(t *testing.T) {
   315  	initiatorUID := []byte("ALICE123@YAHOO.COM")
   316  	responderUID := []byte("BILL456@YAHOO.COM")
   317  	kenLen := 16
   318  
   319  	// initiator's private key
   320  	privA := new(PrivateKey)
   321  	privA.D = bigFromHex("6FCBA2EF9AE0AB902BC3BDE3FF915D44BA4CC78F88E2F8E7F8996D3B8CCEEDEE")
   322  	privA.Curve = sampleParams
   323  	privA.X, privA.Y = privA.Curve.ScalarBaseMult(privA.D.Bytes())
   324  	if hex.EncodeToString(privA.X.Bytes()) != "3099093bf3c137d8fcbbcdf4a2ae50f3b0f216c3122d79425fe03a45dbfe1655" ||
   325  		hex.EncodeToString(privA.Y.Bytes()) != "3df79e8dac1cf0ecbaa2f2b49d51a4b387f2efaf482339086a27a8e05baed98b" {
   326  		t.Fatalf("unexpected public key PA")
   327  	}
   328  
   329  	// initiator's Z value
   330  	za, _ := calculateSampleZA(&privA.PublicKey, sampleParams.A, initiatorUID)
   331  	if hex.EncodeToString(za) != "e4d1d0c3ca4c7f11bc8ff8cb3f4c02a78f108fa098e51a668487240f75e20f31" {
   332  		t.Fatalf("unexpected ZA")
   333  	}
   334  
   335  	// responder's private key
   336  	privB := new(PrivateKey)
   337  	privB.D = bigFromHex("5E35D7D3F3C54DBAC72E61819E730B019A84208CA3A35E4C2E353DFCCB2A3B53")
   338  	privB.Curve = sampleParams
   339  	privB.X, privB.Y = privB.Curve.ScalarBaseMult(privB.D.Bytes())
   340  	if hex.EncodeToString(privB.X.Bytes()) != "245493d446c38d8cc0f118374690e7df633a8a4bfb3329b5ece604b2b4f37f43" ||
   341  		hex.EncodeToString(privB.Y.Bytes()) != "53c0869f4b9e17773de68fec45e14904e0dea45bf6cecf9918c85ea047c60a4c" {
   342  		t.Fatalf("unexpected public key PB")
   343  	}
   344  	// responder's Z value
   345  	zb, _ := calculateSampleZA(&privB.PublicKey, sampleParams.A, responderUID)
   346  	if hex.EncodeToString(zb) != "6b4b6d0e276691bd4a11bf72f4fb501ae309fdacb72fa6cc336e6656119abd67" {
   347  		t.Fatalf("unexpected ZB")
   348  	}
   349  
   350  	// create initiator
   351  	initiator, err := NewKeyExchange(privA, &privB.PublicKey, initiatorUID, responderUID, kenLen, true)
   352  	if err != nil {
   353  		t.Fatal(err)
   354  	}
   355  	// overwrite Z values, due to different A
   356  	initiator.z = za
   357  	initiator.peerZ = zb
   358  
   359  	// create responder
   360  	responder, err := NewKeyExchange(privB, &privA.PublicKey, responderUID, initiatorUID, kenLen, true)
   361  	if err != nil {
   362  		t.Fatal(err)
   363  	}
   364  	// overwrite Z values, due to different A
   365  	responder.z = zb
   366  	responder.peerZ = za
   367  
   368  	defer func() {
   369  		initiator.Destroy()
   370  		responder.Destroy()
   371  	}()
   372  
   373  	// for initiator's step A1-A3
   374  	rA := bigFromHex("83A2C9C8B96E5AF70BD480B472409A9A327257F1EBB73F5B073354B248668563")
   375  	initKeyExchange(initiator, rA)
   376  	if hex.EncodeToString(initiator.secret.X.Bytes()) != "6cb5633816f4dd560b1dec458310cbcc6856c09505324a6d23150c408f162bf0" ||
   377  		hex.EncodeToString(initiator.secret.Y.Bytes()) != "0d6fcf62f1036c0a1b6daccf57399223a65f7d7bf2d9637e5bbbeb857961bf1a" {
   378  		t.Fatalf("unexpected RA")
   379  	}
   380  
   381  	// for responder's step B1-B8
   382  	rB := bigFromHex("33FE21940342161C55619C4A0C060293D543C80AF19748CE176D83477DE71C80")
   383  	RB, sB, _ := respondKeyExchange(responder, initiator.secret, rB)
   384  	if hex.EncodeToString(RB.X.Bytes()) != "1799b2a2c778295300d9a2325c686129b8f2b5337b3dcf4514e8bbc19d900ee5" ||
   385  		hex.EncodeToString(RB.Y.Bytes()) != "54c9288c82733efdf7808ae7f27d0e732f7c73a7d9ac98b7d8740a91d0db3cf4" {
   386  		t.Fatalf("unexpected RB")
   387  	}
   388  	if hex.EncodeToString(sB) != "284c8f198f141b502e81250f1581c7e9eeb4ca6990f9e02df388b45471f5bc5c" {
   389  		t.Fatalf("unexpected sB")
   390  	}
   391  
   392  	// for initiator's step A4-A10
   393  	keyA, sA, err := initiator.ConfirmResponder(RB, sB)
   394  	if err != nil {
   395  		t.Fatal(err)
   396  	}
   397  	if hex.EncodeToString(sA) != "23444daf8ed7534366cb901c84b3bdbb63504f4065c1116c91a4c00697e6cf7a" {
   398  		t.Fatalf("unexpected sA")
   399  	}
   400  
   401  	// for responder's step B10
   402  	keyB, err := responder.ConfirmInitiator(sA)
   403  	if err != nil {
   404  		t.Fatal(err)
   405  	}
   406  	if !bytes.Equal(keyA, keyB) {
   407  		t.Errorf("got different key")
   408  	}
   409  	if !bytes.Equal(keyA, hexDecode(t, "55B0AC62A6B927BA23703832C853DED4")) {
   410  		t.Errorf("got unexpected keying data %v\n", hex.EncodeToString(keyA))
   411  	}
   412  }