github.com/emmansun/gmsm@v0.29.1/sm2/sm2_keyexchange_sample_test.go (about) 1 package sm2 2 3 import ( 4 "bytes" 5 "crypto/ecdsa" 6 "crypto/elliptic" 7 "encoding/hex" 8 "errors" 9 "math/big" 10 "testing" 11 12 "github.com/emmansun/gmsm/sm3" 13 ) 14 15 type CurveParams struct { 16 elliptic.CurveParams 17 A *big.Int // the constant of the curve equation 18 } 19 20 // polynomial returns x³ +ax + b. 21 func (curve *CurveParams) polynomial(x *big.Int) *big.Int { 22 x3 := new(big.Int).Mul(x, x) 23 x3.Mul(x3, x) 24 25 aX := new(big.Int).Mul(curve.A, x) 26 27 x3.Add(x3, aX) 28 x3.Add(x3, curve.B) 29 x3.Mod(x3, curve.P) 30 31 return x3 32 } 33 34 func (curve *CurveParams) IsOnCurve(x, y *big.Int) bool { 35 if x.Sign() < 0 || x.Cmp(curve.P) >= 0 || 36 y.Sign() < 0 || y.Cmp(curve.P) >= 0 { 37 return false 38 } 39 40 // y² = x³ + ax + b 41 y2 := new(big.Int).Mul(y, y) 42 y2.Mod(y2, curve.P) 43 44 return curve.polynomial(x).Cmp(y2) == 0 45 } 46 47 // zForAffine returns a Jacobian Z value for the affine point (x, y). If x and 48 // y are zero, it assumes that they represent the point at infinity because (0, 49 // 0) is not on the any of the curves handled here. 50 func zForAffine(x, y *big.Int) *big.Int { 51 z := new(big.Int) 52 if x.Sign() != 0 || y.Sign() != 0 { 53 z.SetInt64(1) 54 } 55 return z 56 } 57 58 // affineFromJacobian reverses the Jacobian transform. See the comment at the 59 // top of the file. If the point is ∞ it returns 0, 0. 60 func (curve *CurveParams) affineFromJacobian(x, y, z *big.Int) (xOut, yOut *big.Int) { 61 if z.Sign() == 0 { 62 return new(big.Int), new(big.Int) 63 } 64 65 zinv := new(big.Int).ModInverse(z, curve.P) 66 zinvsq := new(big.Int).Mul(zinv, zinv) 67 68 xOut = new(big.Int).Mul(x, zinvsq) 69 xOut.Mod(xOut, curve.P) 70 zinvsq.Mul(zinvsq, zinv) 71 yOut = new(big.Int).Mul(y, zinvsq) 72 yOut.Mod(yOut, curve.P) 73 return 74 } 75 76 func (curve *CurveParams) Add(x1, y1, x2, y2 *big.Int) (*big.Int, *big.Int) { 77 z1 := zForAffine(x1, y1) 78 z2 := zForAffine(x2, y2) 79 return curve.affineFromJacobian(curve.addJacobian(x1, y1, z1, x2, y2, z2)) 80 } 81 82 // addJacobian takes two points in Jacobian coordinates, (x1, y1, z1) and 83 // (x2, y2, z2) and returns their sum, also in Jacobian form. 84 func (curve *CurveParams) addJacobian(x1, y1, z1, x2, y2, z2 *big.Int) (*big.Int, *big.Int, *big.Int) { 85 // See https://hyperelliptic.org/EFD/g1p/data/shortw/jacobian/addition/add-2007-bl 86 x3, y3, z3 := new(big.Int), new(big.Int), new(big.Int) 87 if z1.Sign() == 0 { 88 x3.Set(x2) 89 y3.Set(y2) 90 z3.Set(z2) 91 return x3, y3, z3 92 } 93 if z2.Sign() == 0 { 94 x3.Set(x1) 95 y3.Set(y1) 96 z3.Set(z1) 97 return x3, y3, z3 98 } 99 100 z1z1 := new(big.Int).Mul(z1, z1) 101 z1z1.Mod(z1z1, curve.P) 102 z2z2 := new(big.Int).Mul(z2, z2) 103 z2z2.Mod(z2z2, curve.P) 104 105 u1 := new(big.Int).Mul(x1, z2z2) 106 u1.Mod(u1, curve.P) 107 u2 := new(big.Int).Mul(x2, z1z1) 108 u2.Mod(u2, curve.P) 109 h := new(big.Int).Sub(u2, u1) 110 xEqual := h.Sign() == 0 111 if h.Sign() == -1 { 112 h.Add(h, curve.P) 113 } 114 i := new(big.Int).Lsh(h, 1) 115 i.Mul(i, i) 116 j := new(big.Int).Mul(h, i) 117 118 s1 := new(big.Int).Mul(y1, z2) 119 s1.Mul(s1, z2z2) 120 s1.Mod(s1, curve.P) 121 s2 := new(big.Int).Mul(y2, z1) 122 s2.Mul(s2, z1z1) 123 s2.Mod(s2, curve.P) 124 r := new(big.Int).Sub(s2, s1) 125 if r.Sign() == -1 { 126 r.Add(r, curve.P) 127 } 128 yEqual := r.Sign() == 0 129 if xEqual && yEqual { 130 return curve.doubleJacobian(x1, y1, z1) 131 } 132 r.Lsh(r, 1) 133 v := new(big.Int).Mul(u1, i) 134 135 x3.Set(r) 136 x3.Mul(x3, x3) 137 x3.Sub(x3, j) 138 x3.Sub(x3, v) 139 x3.Sub(x3, v) 140 x3.Mod(x3, curve.P) 141 142 y3.Set(r) 143 v.Sub(v, x3) 144 y3.Mul(y3, v) 145 s1.Mul(s1, j) 146 s1.Lsh(s1, 1) 147 y3.Sub(y3, s1) 148 y3.Mod(y3, curve.P) 149 150 z3.Add(z1, z2) 151 z3.Mul(z3, z3) 152 z3.Sub(z3, z1z1) 153 z3.Sub(z3, z2z2) 154 z3.Mul(z3, h) 155 z3.Mod(z3, curve.P) 156 157 return x3, y3, z3 158 } 159 160 func (curve *CurveParams) Double(x1, y1 *big.Int) (*big.Int, *big.Int) { 161 z1 := zForAffine(x1, y1) 162 return curve.affineFromJacobian(curve.doubleJacobian(x1, y1, z1)) 163 } 164 165 // doubleJacobian takes a point in Jacobian coordinates, (x, y, z), and 166 // returns its double, also in Jacobian form. 167 func (curve *CurveParams) doubleJacobian(x, y, z *big.Int) (*big.Int, *big.Int, *big.Int) { 168 // See https://hyperelliptic.org/EFD/g1p/data/shortw/jacobian/doubling/dbl-2007-bl 169 xx := new(big.Int).Mul(x, x) 170 xx.Mod(xx, curve.P) 171 yy := new(big.Int).Mul(y, y) 172 yy.Mod(yy, curve.P) 173 yyyy := new(big.Int).Mul(yy, yy) 174 yyyy.Mod(yyyy, curve.P) 175 zz := new(big.Int).Mul(z, z) 176 zz.Mod(zz, curve.P) 177 178 s := new(big.Int).Add(x, yy) 179 s.Mul(s, s) 180 s.Sub(s, xx) 181 if s.Sign() == -1 { 182 s.Add(s, curve.P) 183 } 184 s.Sub(s, yyyy) 185 if s.Sign() == -1 { 186 s.Add(s, curve.P) 187 } 188 s.Lsh(s, 1) 189 s.Mod(s, curve.P) 190 191 m := new(big.Int).Mul(xx, big.NewInt(3)) 192 m.Mod(m, curve.P) 193 tmp := new(big.Int).Mul(zz, zz) 194 tmp.Mul(tmp, curve.A) 195 tmp.Mod(tmp, curve.P) 196 m.Add(m, tmp) 197 m.Mod(m, curve.P) 198 199 t := new(big.Int).Mul(m, m) 200 t.Sub(t, s) 201 if t.Sign() == -1 { 202 t.Add(t, curve.P) 203 } 204 t.Sub(t, s) 205 if t.Sign() == -1 { 206 t.Add(t, curve.P) 207 } 208 t.Mod(t, curve.P) 209 x3 := t 210 211 y3 := new(big.Int).Sub(s, t) 212 y3.Mul(y3, m) 213 yyyy.Lsh(yyyy, 3) 214 y3.Sub(y3, yyyy) 215 if y3.Sign() == -1 { 216 y3.Add(y3, curve.P) 217 } 218 y3.Mod(y3, curve.P) 219 220 z3 := new(big.Int).Add(y, z) 221 z3.Mul(z3, z3) 222 z3.Sub(z3, yy) 223 if z3.Sign() == -1 { 224 z3.Add(z3, curve.P) 225 } 226 z3.Sub(z3, zz) 227 if z3.Sign() == -1 { 228 z3.Add(z3, curve.P) 229 } 230 z3.Mod(z3, curve.P) 231 232 return x3, y3, z3 233 } 234 235 func (curve *CurveParams) ScalarMult(Bx, By *big.Int, k []byte) (*big.Int, *big.Int) { 236 Bz := new(big.Int).SetInt64(1) 237 x, y, z := new(big.Int), new(big.Int), new(big.Int) 238 239 for _, byte := range k { 240 for bitNum := 0; bitNum < 8; bitNum++ { 241 x, y, z = curve.doubleJacobian(x, y, z) 242 if byte&0x80 == 0x80 { 243 x, y, z = curve.addJacobian(Bx, By, Bz, x, y, z) 244 } 245 byte <<= 1 246 } 247 } 248 249 return curve.affineFromJacobian(x, y, z) 250 } 251 252 func (curve *CurveParams) ScalarBaseMult(k []byte) (*big.Int, *big.Int) { 253 return curve.ScalarMult(curve.Gx, curve.Gy, k) 254 } 255 256 func bigFromHex(s string) *big.Int { 257 b, ok := new(big.Int).SetString(s, 16) 258 if !ok { 259 panic("sm2/elliptic: internal error: invalid encoding") 260 } 261 return b 262 } 263 264 var sampleParams = &CurveParams{ 265 elliptic.CurveParams{ 266 Name: "sampleCurve", 267 BitSize: 256, 268 P: bigFromHex("8542D69E4C044F18E8B92435BF6FF7DE457283915C45517D722EDB8B08F1DFC3"), 269 N: bigFromHex("8542D69E4C044F18E8B92435BF6FF7DD297720630485628D5AE74EE7C32E79B7"), 270 B: bigFromHex("63E4C6D3B23B0C849CF84241484BFE48F61D59A5B16BA06E6E12D1DA27C5249A"), 271 Gx: bigFromHex("421DEBD61B62EAB6746434EBC3CC315E32220B3BADD50BDC4C4E6C147FEDD43D"), 272 Gy: bigFromHex("0680512BCBB42C07D47349D2153B70C4E5D7FDFCBFA36EA1A85841B9E46E09A2"), 273 }, 274 bigFromHex("787968B4FA32C3FD2417842E73BBFEFF2F3C848B6831D7E0EC65228B3937E498"), 275 } 276 277 func TestPublicKey(t *testing.T) { 278 d := bigFromHex("6FCBA2EF9AE0AB902BC3BDE3FF915D44BA4CC78F88E2F8E7F8996D3B8CCEEDEE") 279 x, y := sampleParams.ScalarBaseMult(d.Bytes()) 280 if hex.EncodeToString(x.Bytes()) != "3099093bf3c137d8fcbbcdf4a2ae50f3b0f216c3122d79425fe03a45dbfe1655" || 281 hex.EncodeToString(y.Bytes()) != "3df79e8dac1cf0ecbaa2f2b49d51a4b387f2efaf482339086a27a8e05baed98b" { 282 t.FailNow() 283 } 284 d = bigFromHex("5E35D7D3F3C54DBAC72E61819E730B019A84208CA3A35E4C2E353DFCCB2A3B53") 285 x, y = sampleParams.ScalarBaseMult(d.Bytes()) 286 if hex.EncodeToString(x.Bytes()) != "245493d446c38d8cc0f118374690e7df633a8a4bfb3329b5ece604b2b4f37f43" || 287 hex.EncodeToString(y.Bytes()) != "53c0869f4b9e17773de68fec45e14904e0dea45bf6cecf9918c85ea047c60a4c" { 288 t.FailNow() 289 } 290 } 291 292 // calculateZA ZA = H256(ENTLA || IDA || a || b || xG || yG || xA || yA) 293 func calculateSampleZA(pub *ecdsa.PublicKey, a *big.Int, uid []byte) ([]byte, error) { 294 uidLen := len(uid) 295 if uidLen >= 0x2000 { 296 return nil, errors.New("sm2: the uid is too long") 297 } 298 entla := uint16(uidLen) << 3 299 md := sm3.New() 300 md.Write([]byte{byte(entla >> 8), byte(entla)}) 301 if uidLen > 0 { 302 md.Write(uid) 303 } 304 md.Write(toBytes(pub.Curve, a)) 305 md.Write(toBytes(pub.Curve, pub.Params().B)) 306 md.Write(toBytes(pub.Curve, pub.Params().Gx)) 307 md.Write(toBytes(pub.Curve, pub.Params().Gy)) 308 md.Write(toBytes(pub.Curve, pub.X)) 309 md.Write(toBytes(pub.Curve, pub.Y)) 310 return md.Sum(nil), nil 311 } 312 313 // Sample from Appendix A.2 314 func TestKeyExchangeRealSample(t *testing.T) { 315 initiatorUID := []byte("ALICE123@YAHOO.COM") 316 responderUID := []byte("BILL456@YAHOO.COM") 317 kenLen := 16 318 319 // initiator's private key 320 privA := new(PrivateKey) 321 privA.D = bigFromHex("6FCBA2EF9AE0AB902BC3BDE3FF915D44BA4CC78F88E2F8E7F8996D3B8CCEEDEE") 322 privA.Curve = sampleParams 323 privA.X, privA.Y = privA.Curve.ScalarBaseMult(privA.D.Bytes()) 324 if hex.EncodeToString(privA.X.Bytes()) != "3099093bf3c137d8fcbbcdf4a2ae50f3b0f216c3122d79425fe03a45dbfe1655" || 325 hex.EncodeToString(privA.Y.Bytes()) != "3df79e8dac1cf0ecbaa2f2b49d51a4b387f2efaf482339086a27a8e05baed98b" { 326 t.Fatalf("unexpected public key PA") 327 } 328 329 // initiator's Z value 330 za, _ := calculateSampleZA(&privA.PublicKey, sampleParams.A, initiatorUID) 331 if hex.EncodeToString(za) != "e4d1d0c3ca4c7f11bc8ff8cb3f4c02a78f108fa098e51a668487240f75e20f31" { 332 t.Fatalf("unexpected ZA") 333 } 334 335 // responder's private key 336 privB := new(PrivateKey) 337 privB.D = bigFromHex("5E35D7D3F3C54DBAC72E61819E730B019A84208CA3A35E4C2E353DFCCB2A3B53") 338 privB.Curve = sampleParams 339 privB.X, privB.Y = privB.Curve.ScalarBaseMult(privB.D.Bytes()) 340 if hex.EncodeToString(privB.X.Bytes()) != "245493d446c38d8cc0f118374690e7df633a8a4bfb3329b5ece604b2b4f37f43" || 341 hex.EncodeToString(privB.Y.Bytes()) != "53c0869f4b9e17773de68fec45e14904e0dea45bf6cecf9918c85ea047c60a4c" { 342 t.Fatalf("unexpected public key PB") 343 } 344 // responder's Z value 345 zb, _ := calculateSampleZA(&privB.PublicKey, sampleParams.A, responderUID) 346 if hex.EncodeToString(zb) != "6b4b6d0e276691bd4a11bf72f4fb501ae309fdacb72fa6cc336e6656119abd67" { 347 t.Fatalf("unexpected ZB") 348 } 349 350 // create initiator 351 initiator, err := NewKeyExchange(privA, &privB.PublicKey, initiatorUID, responderUID, kenLen, true) 352 if err != nil { 353 t.Fatal(err) 354 } 355 // overwrite Z values, due to different A 356 initiator.z = za 357 initiator.peerZ = zb 358 359 // create responder 360 responder, err := NewKeyExchange(privB, &privA.PublicKey, responderUID, initiatorUID, kenLen, true) 361 if err != nil { 362 t.Fatal(err) 363 } 364 // overwrite Z values, due to different A 365 responder.z = zb 366 responder.peerZ = za 367 368 defer func() { 369 initiator.Destroy() 370 responder.Destroy() 371 }() 372 373 // for initiator's step A1-A3 374 rA := bigFromHex("83A2C9C8B96E5AF70BD480B472409A9A327257F1EBB73F5B073354B248668563") 375 initKeyExchange(initiator, rA) 376 if hex.EncodeToString(initiator.secret.X.Bytes()) != "6cb5633816f4dd560b1dec458310cbcc6856c09505324a6d23150c408f162bf0" || 377 hex.EncodeToString(initiator.secret.Y.Bytes()) != "0d6fcf62f1036c0a1b6daccf57399223a65f7d7bf2d9637e5bbbeb857961bf1a" { 378 t.Fatalf("unexpected RA") 379 } 380 381 // for responder's step B1-B8 382 rB := bigFromHex("33FE21940342161C55619C4A0C060293D543C80AF19748CE176D83477DE71C80") 383 RB, sB, _ := respondKeyExchange(responder, initiator.secret, rB) 384 if hex.EncodeToString(RB.X.Bytes()) != "1799b2a2c778295300d9a2325c686129b8f2b5337b3dcf4514e8bbc19d900ee5" || 385 hex.EncodeToString(RB.Y.Bytes()) != "54c9288c82733efdf7808ae7f27d0e732f7c73a7d9ac98b7d8740a91d0db3cf4" { 386 t.Fatalf("unexpected RB") 387 } 388 if hex.EncodeToString(sB) != "284c8f198f141b502e81250f1581c7e9eeb4ca6990f9e02df388b45471f5bc5c" { 389 t.Fatalf("unexpected sB") 390 } 391 392 // for initiator's step A4-A10 393 keyA, sA, err := initiator.ConfirmResponder(RB, sB) 394 if err != nil { 395 t.Fatal(err) 396 } 397 if hex.EncodeToString(sA) != "23444daf8ed7534366cb901c84b3bdbb63504f4065c1116c91a4c00697e6cf7a" { 398 t.Fatalf("unexpected sA") 399 } 400 401 // for responder's step B10 402 keyB, err := responder.ConfirmInitiator(sA) 403 if err != nil { 404 t.Fatal(err) 405 } 406 if !bytes.Equal(keyA, keyB) { 407 t.Errorf("got different key") 408 } 409 if !bytes.Equal(keyA, hexDecode(t, "55B0AC62A6B927BA23703832C853DED4")) { 410 t.Errorf("got unexpected keying data %v\n", hex.EncodeToString(keyA)) 411 } 412 }