github.com/emmansun/gmsm@v0.29.1/sm9/bn256/g1.go (about) 1 package bn256 2 3 import ( 4 "crypto/rand" 5 "errors" 6 "io" 7 "math/big" 8 "sync" 9 ) 10 11 func randomK(r io.Reader) (k *big.Int, err error) { 12 for { 13 k, err = rand.Int(r, Order) 14 if err != nil || k.Sign() > 0 { 15 return 16 } 17 } 18 } 19 20 // G1 is an abstract cyclic group. The zero value is suitable for use as the 21 // output of an operation, but cannot be used as an input. 22 type G1 struct { 23 p *curvePoint 24 } 25 26 // Gen1 is the generator of G1. 27 var Gen1 = &G1{curveGen} 28 29 var g1GeneratorTable *[32 * 2]curvePointTable 30 var g1GeneratorTableOnce sync.Once 31 32 func (g *G1) generatorTable() *[32 * 2]curvePointTable { 33 g1GeneratorTableOnce.Do(func() { 34 g1GeneratorTable = new([32 * 2]curvePointTable) 35 base := NewCurveGenerator() 36 for i := 0; i < 32*2; i++ { 37 g1GeneratorTable[i][0] = &curvePoint{} 38 g1GeneratorTable[i][0].Set(base) 39 40 g1GeneratorTable[i][1] = &curvePoint{} 41 g1GeneratorTable[i][1].Double(g1GeneratorTable[i][0]) 42 g1GeneratorTable[i][2] = &curvePoint{} 43 g1GeneratorTable[i][2].Add(g1GeneratorTable[i][1], base) 44 45 g1GeneratorTable[i][3] = &curvePoint{} 46 g1GeneratorTable[i][3].Double(g1GeneratorTable[i][1]) 47 g1GeneratorTable[i][4] = &curvePoint{} 48 g1GeneratorTable[i][4].Add(g1GeneratorTable[i][3], base) 49 50 g1GeneratorTable[i][5] = &curvePoint{} 51 g1GeneratorTable[i][5].Double(g1GeneratorTable[i][2]) 52 g1GeneratorTable[i][6] = &curvePoint{} 53 g1GeneratorTable[i][6].Add(g1GeneratorTable[i][5], base) 54 55 g1GeneratorTable[i][7] = &curvePoint{} 56 g1GeneratorTable[i][7].Double(g1GeneratorTable[i][3]) 57 g1GeneratorTable[i][8] = &curvePoint{} 58 g1GeneratorTable[i][8].Add(g1GeneratorTable[i][7], base) 59 60 g1GeneratorTable[i][9] = &curvePoint{} 61 g1GeneratorTable[i][9].Double(g1GeneratorTable[i][4]) 62 g1GeneratorTable[i][10] = &curvePoint{} 63 g1GeneratorTable[i][10].Add(g1GeneratorTable[i][9], base) 64 65 g1GeneratorTable[i][11] = &curvePoint{} 66 g1GeneratorTable[i][11].Double(g1GeneratorTable[i][5]) 67 g1GeneratorTable[i][12] = &curvePoint{} 68 g1GeneratorTable[i][12].Add(g1GeneratorTable[i][11], base) 69 70 g1GeneratorTable[i][13] = &curvePoint{} 71 g1GeneratorTable[i][13].Double(g1GeneratorTable[i][6]) 72 g1GeneratorTable[i][14] = &curvePoint{} 73 g1GeneratorTable[i][14].Add(g1GeneratorTable[i][13], base) 74 75 base.Double(base) 76 base.Double(base) 77 base.Double(base) 78 base.Double(base) 79 } 80 }) 81 return g1GeneratorTable 82 } 83 84 // RandomG1 returns x and g₁ˣ where x is a random, non-zero number read from r. 85 func RandomG1(r io.Reader) (*big.Int, *G1, error) { 86 k, err := randomK(r) 87 if err != nil { 88 return nil, nil, err 89 } 90 91 g1, err := new(G1).ScalarBaseMult(NormalizeScalar(k.Bytes())) 92 return k, g1, err 93 } 94 95 func (g *G1) String() string { 96 return "sm9.G1" + g.p.String() 97 } 98 99 func NormalizeScalar(scalar []byte) []byte { 100 if len(scalar) == 32 { 101 return scalar 102 } 103 s := new(big.Int).SetBytes(scalar) 104 if len(scalar) > 32 { 105 s.Mod(s, Order) 106 } 107 out := make([]byte, 32) 108 return s.FillBytes(out) 109 } 110 111 // ScalarBaseMult sets e to scaler*g where g is the generator of the group and then 112 // returns e. 113 func (e *G1) ScalarBaseMult(scalar []byte) (*G1, error) { 114 if len(scalar) != 32 { 115 return nil, errors.New("invalid scalar length") 116 } 117 if e.p == nil { 118 e.p = &curvePoint{} 119 } 120 121 //e.p.Mul(curveGen, k) 122 123 tables := e.generatorTable() 124 // This is also a scalar multiplication with a four-bit window like in 125 // ScalarMult, but in this case the doublings are precomputed. The value 126 // [windowValue]G added at iteration k would normally get doubled 127 // (totIterations-k)×4 times, but with a larger precomputation we can 128 // instead add [2^((totIterations-k)×4)][windowValue]G and avoid the 129 // doublings between iterations. 130 t := NewCurvePoint() 131 e.p.SetInfinity() 132 tableIndex := len(tables) - 1 133 for _, byte := range scalar { 134 windowValue := byte >> 4 135 tables[tableIndex].Select(t, windowValue) 136 e.p.Add(e.p, t) 137 tableIndex-- 138 windowValue = byte & 0b1111 139 tables[tableIndex].Select(t, windowValue) 140 e.p.Add(e.p, t) 141 tableIndex-- 142 } 143 return e, nil 144 } 145 146 // ScalarMult sets e to a*k and then returns e. 147 func (e *G1) ScalarMult(a *G1, scalar []byte) (*G1, error) { 148 if e.p == nil { 149 e.p = &curvePoint{} 150 } 151 //e.p.Mul(a.p, k) 152 // Compute a curvePointTable for the base point a. 153 var table = curvePointTable{NewCurvePoint(), NewCurvePoint(), NewCurvePoint(), 154 NewCurvePoint(), NewCurvePoint(), NewCurvePoint(), NewCurvePoint(), 155 NewCurvePoint(), NewCurvePoint(), NewCurvePoint(), NewCurvePoint(), 156 NewCurvePoint(), NewCurvePoint(), NewCurvePoint(), NewCurvePoint()} 157 table[0].Set(a.p) 158 for i := 1; i < 15; i += 2 { 159 table[i].Double(table[i/2]) 160 table[i+1].Add(table[i], a.p) 161 } 162 // Instead of doing the classic double-and-add chain, we do it with a 163 // four-bit window: we double four times, and then add [0-15]P. 164 t := &G1{NewCurvePoint()} 165 e.p.SetInfinity() 166 for i, byte := range scalar { 167 // No need to double on the first iteration, as p is the identity at 168 // this point, and [N]∞ = ∞. 169 if i != 0 { 170 e.Double(e) 171 e.Double(e) 172 e.Double(e) 173 e.Double(e) 174 } 175 windowValue := byte >> 4 176 table.Select(t.p, windowValue) 177 e.Add(e, t) 178 e.Double(e) 179 e.Double(e) 180 e.Double(e) 181 e.Double(e) 182 windowValue = byte & 0b1111 183 table.Select(t.p, windowValue) 184 e.Add(e, t) 185 } 186 return e, nil 187 } 188 189 // Add sets e to a+b and then returns e. 190 func (e *G1) Add(a, b *G1) *G1 { 191 if e.p == nil { 192 e.p = &curvePoint{} 193 } 194 e.p.Add(a.p, b.p) 195 return e 196 } 197 198 // Double sets e to [2]a and then returns e. 199 func (e *G1) Double(a *G1) *G1 { 200 if e.p == nil { 201 e.p = &curvePoint{} 202 } 203 e.p.Double(a.p) 204 return e 205 } 206 207 // Neg sets e to -a and then returns e. 208 func (e *G1) Neg(a *G1) *G1 { 209 if e.p == nil { 210 e.p = &curvePoint{} 211 } 212 e.p.Neg(a.p) 213 return e 214 } 215 216 // Set sets e to a and then returns e. 217 func (e *G1) Set(a *G1) *G1 { 218 if e.p == nil { 219 e.p = &curvePoint{} 220 } 221 e.p.Set(a.p) 222 return e 223 } 224 225 // Marshal converts e to a byte slice. 226 func (e *G1) Marshal() []byte { 227 // Each value is a 256-bit number. 228 const numBytes = 256 / 8 229 230 ret := make([]byte, numBytes*2) 231 232 e.fillBytes(ret) 233 return ret 234 } 235 236 // MarshalUncompressed converts e to a byte slice with prefix 237 func (e *G1) MarshalUncompressed() []byte { 238 // Each value is a 256-bit number. 239 const numBytes = 256 / 8 240 241 ret := make([]byte, numBytes*2+1) 242 ret[0] = 4 243 244 e.fillBytes(ret[1:]) 245 return ret 246 } 247 248 // MarshalCompressed converts e to a byte slice with compress prefix. 249 // If the point is not on the curve (or is the conventional point at infinity), the behavior is undefined. 250 func (e *G1) MarshalCompressed() []byte { 251 // Each value is a 256-bit number. 252 const numBytes = 256 / 8 253 ret := make([]byte, numBytes+1) 254 if e.p == nil { 255 e.p = &curvePoint{} 256 } 257 258 e.p.MakeAffine() 259 temp := &gfP{} 260 montDecode(temp, &e.p.y) 261 262 temp.Marshal(ret[1:]) 263 ret[0] = (ret[numBytes] & 1) | 2 264 montDecode(temp, &e.p.x) 265 temp.Marshal(ret[1:]) 266 267 return ret 268 } 269 270 // UnmarshalCompressed sets e to the result of converting the output of Marshal back into 271 // a group element and then returns e. 272 func (e *G1) UnmarshalCompressed(data []byte) ([]byte, error) { 273 // Each value is a 256-bit number. 274 const numBytes = 256 / 8 275 if len(data) < 1+numBytes { 276 return nil, errors.New("sm9.G1: not enough data") 277 } 278 if data[0] != 2 && data[0] != 3 { // compressed form 279 return nil, errors.New("sm9.G1: invalid point compress byte") 280 } 281 if e.p == nil { 282 e.p = &curvePoint{} 283 } else { 284 e.p.x.Set(zero) 285 e.p.y.Set(zero) 286 } 287 e.p.x.Unmarshal(data[1:]) 288 montEncode(&e.p.x, &e.p.x) 289 x3 := e.p.polynomial(&e.p.x) 290 e.p.y.Sqrt(x3) 291 montDecode(x3, &e.p.y) 292 if byte(x3[0]&1) != data[0]&1 { 293 gfpNeg(&e.p.y, &e.p.y) 294 } 295 if e.p.x.Equal(zero) == 1 && e.p.y.Equal(zero) == 1 { 296 // This is the point at infinity. 297 e.p.SetInfinity() 298 } else { 299 e.p.z.Set(one) 300 e.p.t.Set(one) 301 302 if !e.p.IsOnCurve() { 303 return nil, errors.New("sm9.G1: malformed point") 304 } 305 } 306 307 return data[numBytes+1:], nil 308 } 309 310 func (e *G1) fillBytes(buffer []byte) { 311 const numBytes = 256 / 8 312 313 if e.p == nil { 314 e.p = &curvePoint{} 315 } 316 317 e.p.MakeAffine() 318 if e.p.IsInfinity() { 319 return 320 } 321 temp := &gfP{} 322 323 montDecode(temp, &e.p.x) 324 temp.Marshal(buffer) 325 montDecode(temp, &e.p.y) 326 temp.Marshal(buffer[numBytes:]) 327 } 328 329 // Unmarshal sets e to the result of converting the output of Marshal back into 330 // a group element and then returns e. 331 func (e *G1) Unmarshal(m []byte) ([]byte, error) { 332 // Each value is a 256-bit number. 333 const numBytes = 256 / 8 334 335 if len(m) < 2*numBytes { 336 return nil, errors.New("sm9.G1: not enough data") 337 } 338 339 if e.p == nil { 340 e.p = &curvePoint{} 341 } else { 342 e.p.x.Set(zero) 343 e.p.y.Set(zero) 344 } 345 346 e.p.x.Unmarshal(m) 347 e.p.y.Unmarshal(m[numBytes:]) 348 montEncode(&e.p.x, &e.p.x) 349 montEncode(&e.p.y, &e.p.y) 350 351 if e.p.x.Equal(zero) == 1 && e.p.y.Equal(zero) == 1 { 352 // This is the point at infinity. 353 e.p.SetInfinity() 354 } else { 355 e.p.z.Set(one) 356 e.p.t.Set(one) 357 358 if !e.p.IsOnCurve() { 359 return nil, errors.New("sm9.G1: malformed point") 360 } 361 } 362 363 return m[2*numBytes:], nil 364 } 365 366 // Equal compare e and other 367 func (e *G1) Equal(other *G1) bool { 368 if e.p == nil && other.p == nil { 369 return true 370 } 371 return e.p.Equal(other.p) 372 } 373 374 // IsOnCurve returns true if e is on the curve. 375 func (e *G1) IsOnCurve() bool { 376 return e.p.IsOnCurve() 377 } 378 379 type G1Curve struct { 380 params *CurveParams 381 g G1 382 } 383 384 var g1Curve = &G1Curve{ 385 params: &CurveParams{ 386 Name: "sm9", 387 BitSize: 256, 388 P: bigFromHex("B640000002A3A6F1D603AB4FF58EC74521F2934B1A7AEEDBE56F9B27E351457D"), 389 N: bigFromHex("B640000002A3A6F1D603AB4FF58EC74449F2934B18EA8BEEE56EE19CD69ECF25"), 390 B: bigFromHex("0000000000000000000000000000000000000000000000000000000000000005"), 391 Gx: bigFromHex("93DE051D62BF718FF5ED0704487D01D6E1E4086909DC3280E8C4E4817C66DDDD"), 392 Gy: bigFromHex("21FE8DDA4F21E607631065125C395BBC1C1C00CBFA6024350C464CD70A3EA616"), 393 }, 394 g: G1{}, 395 } 396 397 func (g1 *G1Curve) pointFromAffine(x, y *big.Int) (a *G1, err error) { 398 a = &G1{&curvePoint{}} 399 if x.Sign() == 0 { 400 a.p.SetInfinity() 401 return a, nil 402 } 403 // Reject values that would not get correctly encoded. 404 if x.Sign() < 0 || y.Sign() < 0 { 405 return a, errors.New("negative coordinate") 406 } 407 if x.BitLen() > g1.params.BitSize || y.BitLen() > g1.params.BitSize { 408 return a, errors.New("overflowing coordinate") 409 } 410 a.p.x = *fromBigInt(x) 411 a.p.y = *fromBigInt(y) 412 a.p.z = *newGFp(1) 413 a.p.t = *newGFp(1) 414 415 if !a.p.IsOnCurve() { 416 return a, errors.New("point not on G1 curve") 417 } 418 419 return a, nil 420 } 421 422 func (g1 *G1Curve) Params() *CurveParams { 423 return g1.params 424 } 425 426 // normalizeScalar brings the scalar within the byte size of the order of the 427 // curve, as expected by the nistec scalar multiplication functions. 428 func (curve *G1Curve) normalizeScalar(scalar []byte) []byte { 429 byteSize := (curve.params.N.BitLen() + 7) / 8 430 s := new(big.Int).SetBytes(scalar) 431 if len(scalar) > byteSize { 432 s.Mod(s, curve.params.N) 433 } 434 out := make([]byte, byteSize) 435 return s.FillBytes(out) 436 } 437 438 func (g1 *G1Curve) ScalarBaseMult(scalar []byte) (*big.Int, *big.Int) { 439 scalar = g1.normalizeScalar(scalar) 440 p, err := g1.g.ScalarBaseMult(scalar) 441 if err != nil { 442 panic("sm9: g1 rejected normalized scalar") 443 } 444 res := p.Marshal() 445 return new(big.Int).SetBytes(res[:32]), new(big.Int).SetBytes(res[32:]) 446 } 447 448 func (g1 *G1Curve) ScalarMult(Bx, By *big.Int, scalar []byte) (*big.Int, *big.Int) { 449 a, err := g1.pointFromAffine(Bx, By) 450 if err != nil { 451 panic("sm9: ScalarMult was called on an invalid point") 452 } 453 scalar = g1.normalizeScalar(scalar) 454 p, err := g1.g.ScalarMult(a, scalar) 455 if err != nil { 456 panic("sm9: g1 rejected normalized scalar") 457 } 458 res := p.Marshal() 459 return new(big.Int).SetBytes(res[:32]), new(big.Int).SetBytes(res[32:]) 460 } 461 462 func (g1 *G1Curve) Add(x1, y1, x2, y2 *big.Int) (*big.Int, *big.Int) { 463 a, err := g1.pointFromAffine(x1, y1) 464 if err != nil { 465 panic("sm9: Add was called on an invalid point") 466 } 467 b, err := g1.pointFromAffine(x2, y2) 468 if err != nil { 469 panic("sm9: Add was called on an invalid point") 470 } 471 res := g1.g.Add(a, b).Marshal() 472 return new(big.Int).SetBytes(res[:32]), new(big.Int).SetBytes(res[32:]) 473 } 474 475 func (g1 *G1Curve) Double(x, y *big.Int) (*big.Int, *big.Int) { 476 a, err := g1.pointFromAffine(x, y) 477 if err != nil { 478 panic("sm9: Double was called on an invalid point") 479 } 480 res := g1.g.Double(a).Marshal() 481 return new(big.Int).SetBytes(res[:32]), new(big.Int).SetBytes(res[32:]) 482 } 483 484 func (g1 *G1Curve) IsOnCurve(x, y *big.Int) bool { 485 _, err := g1.pointFromAffine(x, y) 486 return err == nil 487 } 488 489 func (curve *G1Curve) UnmarshalCompressed(data []byte) (x, y *big.Int) { 490 if len(data) != 33 || (data[0] != 2 && data[0] != 3) { 491 return nil, nil 492 } 493 r := &gfP{} 494 r.Unmarshal(data[1:33]) 495 if lessThanP(r) == 0 { 496 return nil, nil 497 } 498 x = new(big.Int).SetBytes(data[1:33]) 499 p := &curvePoint{} 500 montEncode(r, r) 501 p.x = *r 502 p.z = *newGFp(1) 503 p.t = *newGFp(1) 504 y2 := &gfP{} 505 gfpMul(y2, r, r) 506 gfpMul(y2, y2, r) 507 gfpAdd(y2, y2, curveB) 508 y2.Sqrt(y2) 509 p.y = *y2 510 if !p.IsOnCurve() { 511 return nil, nil 512 } 513 montDecode(y2, y2) 514 ret := make([]byte, 32) 515 y2.Marshal(ret) 516 y = new(big.Int).SetBytes(ret) 517 if byte(y.Bit(0)) != data[0]&1 { 518 gfpNeg(y2, y2) 519 y2.Marshal(ret) 520 y.SetBytes(ret) 521 } 522 return x, y 523 } 524 525 func (curve *G1Curve) Unmarshal(data []byte) (x, y *big.Int) { 526 if len(data) != 65 || (data[0] != 4) { 527 return nil, nil 528 } 529 x1 := &gfP{} 530 x1.Unmarshal(data[1:33]) 531 y1 := &gfP{} 532 y1.Unmarshal(data[33:]) 533 if lessThanP(x1) == 0 || lessThanP(y1) == 0 { 534 return nil, nil 535 } 536 montEncode(x1, x1) 537 montEncode(y1, y1) 538 p := &curvePoint{ 539 x: *x1, 540 y: *y1, 541 z: *newGFp(1), 542 t: *newGFp(1), 543 } 544 if !p.IsOnCurve() { 545 return nil, nil 546 } 547 x = new(big.Int).SetBytes(data[1:33]) 548 y = new(big.Int).SetBytes(data[33:]) 549 return x, y 550 }