github.com/emmansun/gmsm@v0.29.1/sm9/bn256/g1.go (about)

     1  package bn256
     2  
     3  import (
     4  	"crypto/rand"
     5  	"errors"
     6  	"io"
     7  	"math/big"
     8  	"sync"
     9  )
    10  
    11  func randomK(r io.Reader) (k *big.Int, err error) {
    12  	for {
    13  		k, err = rand.Int(r, Order)
    14  		if err != nil || k.Sign() > 0 {
    15  			return
    16  		}
    17  	}
    18  }
    19  
    20  // G1 is an abstract cyclic group. The zero value is suitable for use as the
    21  // output of an operation, but cannot be used as an input.
    22  type G1 struct {
    23  	p *curvePoint
    24  }
    25  
    26  // Gen1 is the generator of G1.
    27  var Gen1 = &G1{curveGen}
    28  
    29  var g1GeneratorTable *[32 * 2]curvePointTable
    30  var g1GeneratorTableOnce sync.Once
    31  
    32  func (g *G1) generatorTable() *[32 * 2]curvePointTable {
    33  	g1GeneratorTableOnce.Do(func() {
    34  		g1GeneratorTable = new([32 * 2]curvePointTable)
    35  		base := NewCurveGenerator()
    36  		for i := 0; i < 32*2; i++ {
    37  			g1GeneratorTable[i][0] = &curvePoint{}
    38  			g1GeneratorTable[i][0].Set(base)
    39  
    40  			g1GeneratorTable[i][1] = &curvePoint{}
    41  			g1GeneratorTable[i][1].Double(g1GeneratorTable[i][0])
    42  			g1GeneratorTable[i][2] = &curvePoint{}
    43  			g1GeneratorTable[i][2].Add(g1GeneratorTable[i][1], base)
    44  
    45  			g1GeneratorTable[i][3] = &curvePoint{}
    46  			g1GeneratorTable[i][3].Double(g1GeneratorTable[i][1])
    47  			g1GeneratorTable[i][4] = &curvePoint{}
    48  			g1GeneratorTable[i][4].Add(g1GeneratorTable[i][3], base)
    49  
    50  			g1GeneratorTable[i][5] = &curvePoint{}
    51  			g1GeneratorTable[i][5].Double(g1GeneratorTable[i][2])
    52  			g1GeneratorTable[i][6] = &curvePoint{}
    53  			g1GeneratorTable[i][6].Add(g1GeneratorTable[i][5], base)
    54  
    55  			g1GeneratorTable[i][7] = &curvePoint{}
    56  			g1GeneratorTable[i][7].Double(g1GeneratorTable[i][3])
    57  			g1GeneratorTable[i][8] = &curvePoint{}
    58  			g1GeneratorTable[i][8].Add(g1GeneratorTable[i][7], base)
    59  
    60  			g1GeneratorTable[i][9] = &curvePoint{}
    61  			g1GeneratorTable[i][9].Double(g1GeneratorTable[i][4])
    62  			g1GeneratorTable[i][10] = &curvePoint{}
    63  			g1GeneratorTable[i][10].Add(g1GeneratorTable[i][9], base)
    64  
    65  			g1GeneratorTable[i][11] = &curvePoint{}
    66  			g1GeneratorTable[i][11].Double(g1GeneratorTable[i][5])
    67  			g1GeneratorTable[i][12] = &curvePoint{}
    68  			g1GeneratorTable[i][12].Add(g1GeneratorTable[i][11], base)
    69  
    70  			g1GeneratorTable[i][13] = &curvePoint{}
    71  			g1GeneratorTable[i][13].Double(g1GeneratorTable[i][6])
    72  			g1GeneratorTable[i][14] = &curvePoint{}
    73  			g1GeneratorTable[i][14].Add(g1GeneratorTable[i][13], base)
    74  
    75  			base.Double(base)
    76  			base.Double(base)
    77  			base.Double(base)
    78  			base.Double(base)
    79  		}
    80  	})
    81  	return g1GeneratorTable
    82  }
    83  
    84  // RandomG1 returns x and g₁ˣ where x is a random, non-zero number read from r.
    85  func RandomG1(r io.Reader) (*big.Int, *G1, error) {
    86  	k, err := randomK(r)
    87  	if err != nil {
    88  		return nil, nil, err
    89  	}
    90  
    91  	g1, err := new(G1).ScalarBaseMult(NormalizeScalar(k.Bytes()))
    92  	return k, g1, err
    93  }
    94  
    95  func (g *G1) String() string {
    96  	return "sm9.G1" + g.p.String()
    97  }
    98  
    99  func NormalizeScalar(scalar []byte) []byte {
   100  	if len(scalar) == 32 {
   101  		return scalar
   102  	}
   103  	s := new(big.Int).SetBytes(scalar)
   104  	if len(scalar) > 32 {
   105  		s.Mod(s, Order)
   106  	}
   107  	out := make([]byte, 32)
   108  	return s.FillBytes(out)
   109  }
   110  
   111  // ScalarBaseMult sets e to scaler*g where g is the generator of the group and then
   112  // returns e.
   113  func (e *G1) ScalarBaseMult(scalar []byte) (*G1, error) {
   114  	if len(scalar) != 32 {
   115  		return nil, errors.New("invalid scalar length")
   116  	}
   117  	if e.p == nil {
   118  		e.p = &curvePoint{}
   119  	}
   120  
   121  	//e.p.Mul(curveGen, k)
   122  
   123  	tables := e.generatorTable()
   124  	// This is also a scalar multiplication with a four-bit window like in
   125  	// ScalarMult, but in this case the doublings are precomputed. The value
   126  	// [windowValue]G added at iteration k would normally get doubled
   127  	// (totIterations-k)×4 times, but with a larger precomputation we can
   128  	// instead add [2^((totIterations-k)×4)][windowValue]G and avoid the
   129  	// doublings between iterations.
   130  	t := NewCurvePoint()
   131  	e.p.SetInfinity()
   132  	tableIndex := len(tables) - 1
   133  	for _, byte := range scalar {
   134  		windowValue := byte >> 4
   135  		tables[tableIndex].Select(t, windowValue)
   136  		e.p.Add(e.p, t)
   137  		tableIndex--
   138  		windowValue = byte & 0b1111
   139  		tables[tableIndex].Select(t, windowValue)
   140  		e.p.Add(e.p, t)
   141  		tableIndex--
   142  	}
   143  	return e, nil
   144  }
   145  
   146  // ScalarMult sets e to a*k and then returns e.
   147  func (e *G1) ScalarMult(a *G1, scalar []byte) (*G1, error) {
   148  	if e.p == nil {
   149  		e.p = &curvePoint{}
   150  	}
   151  	//e.p.Mul(a.p, k)
   152  	// Compute a curvePointTable for the base point a.
   153  	var table = curvePointTable{NewCurvePoint(), NewCurvePoint(), NewCurvePoint(),
   154  		NewCurvePoint(), NewCurvePoint(), NewCurvePoint(), NewCurvePoint(),
   155  		NewCurvePoint(), NewCurvePoint(), NewCurvePoint(), NewCurvePoint(),
   156  		NewCurvePoint(), NewCurvePoint(), NewCurvePoint(), NewCurvePoint()}
   157  	table[0].Set(a.p)
   158  	for i := 1; i < 15; i += 2 {
   159  		table[i].Double(table[i/2])
   160  		table[i+1].Add(table[i], a.p)
   161  	}
   162  	// Instead of doing the classic double-and-add chain, we do it with a
   163  	// four-bit window: we double four times, and then add [0-15]P.
   164  	t := &G1{NewCurvePoint()}
   165  	e.p.SetInfinity()
   166  	for i, byte := range scalar {
   167  		// No need to double on the first iteration, as p is the identity at
   168  		// this point, and [N]∞ = ∞.
   169  		if i != 0 {
   170  			e.Double(e)
   171  			e.Double(e)
   172  			e.Double(e)
   173  			e.Double(e)
   174  		}
   175  		windowValue := byte >> 4
   176  		table.Select(t.p, windowValue)
   177  		e.Add(e, t)
   178  		e.Double(e)
   179  		e.Double(e)
   180  		e.Double(e)
   181  		e.Double(e)
   182  		windowValue = byte & 0b1111
   183  		table.Select(t.p, windowValue)
   184  		e.Add(e, t)
   185  	}
   186  	return e, nil
   187  }
   188  
   189  // Add sets e to a+b and then returns e.
   190  func (e *G1) Add(a, b *G1) *G1 {
   191  	if e.p == nil {
   192  		e.p = &curvePoint{}
   193  	}
   194  	e.p.Add(a.p, b.p)
   195  	return e
   196  }
   197  
   198  // Double sets e to [2]a and then returns e.
   199  func (e *G1) Double(a *G1) *G1 {
   200  	if e.p == nil {
   201  		e.p = &curvePoint{}
   202  	}
   203  	e.p.Double(a.p)
   204  	return e
   205  }
   206  
   207  // Neg sets e to -a and then returns e.
   208  func (e *G1) Neg(a *G1) *G1 {
   209  	if e.p == nil {
   210  		e.p = &curvePoint{}
   211  	}
   212  	e.p.Neg(a.p)
   213  	return e
   214  }
   215  
   216  // Set sets e to a and then returns e.
   217  func (e *G1) Set(a *G1) *G1 {
   218  	if e.p == nil {
   219  		e.p = &curvePoint{}
   220  	}
   221  	e.p.Set(a.p)
   222  	return e
   223  }
   224  
   225  // Marshal converts e to a byte slice.
   226  func (e *G1) Marshal() []byte {
   227  	// Each value is a 256-bit number.
   228  	const numBytes = 256 / 8
   229  
   230  	ret := make([]byte, numBytes*2)
   231  
   232  	e.fillBytes(ret)
   233  	return ret
   234  }
   235  
   236  // MarshalUncompressed converts e to a byte slice with prefix
   237  func (e *G1) MarshalUncompressed() []byte {
   238  	// Each value is a 256-bit number.
   239  	const numBytes = 256 / 8
   240  
   241  	ret := make([]byte, numBytes*2+1)
   242  	ret[0] = 4
   243  
   244  	e.fillBytes(ret[1:])
   245  	return ret
   246  }
   247  
   248  // MarshalCompressed converts e to a byte slice with compress prefix.
   249  // If the point is not on the curve (or is the conventional point at infinity), the behavior is undefined.
   250  func (e *G1) MarshalCompressed() []byte {
   251  	// Each value is a 256-bit number.
   252  	const numBytes = 256 / 8
   253  	ret := make([]byte, numBytes+1)
   254  	if e.p == nil {
   255  		e.p = &curvePoint{}
   256  	}
   257  
   258  	e.p.MakeAffine()
   259  	temp := &gfP{}
   260  	montDecode(temp, &e.p.y)
   261  
   262  	temp.Marshal(ret[1:])
   263  	ret[0] = (ret[numBytes] & 1) | 2
   264  	montDecode(temp, &e.p.x)
   265  	temp.Marshal(ret[1:])
   266  
   267  	return ret
   268  }
   269  
   270  // UnmarshalCompressed sets e to the result of converting the output of Marshal back into
   271  // a group element and then returns e.
   272  func (e *G1) UnmarshalCompressed(data []byte) ([]byte, error) {
   273  	// Each value is a 256-bit number.
   274  	const numBytes = 256 / 8
   275  	if len(data) < 1+numBytes {
   276  		return nil, errors.New("sm9.G1: not enough data")
   277  	}
   278  	if data[0] != 2 && data[0] != 3 { // compressed form
   279  		return nil, errors.New("sm9.G1: invalid point compress byte")
   280  	}
   281  	if e.p == nil {
   282  		e.p = &curvePoint{}
   283  	} else {
   284  		e.p.x.Set(zero)
   285  		e.p.y.Set(zero)
   286  	}
   287  	e.p.x.Unmarshal(data[1:])
   288  	montEncode(&e.p.x, &e.p.x)
   289  	x3 := e.p.polynomial(&e.p.x)
   290  	e.p.y.Sqrt(x3)
   291  	montDecode(x3, &e.p.y)
   292  	if byte(x3[0]&1) != data[0]&1 {
   293  		gfpNeg(&e.p.y, &e.p.y)
   294  	}
   295  	if e.p.x.Equal(zero) == 1 && e.p.y.Equal(zero) == 1 {
   296  		// This is the point at infinity.
   297  		e.p.SetInfinity()
   298  	} else {
   299  		e.p.z.Set(one)
   300  		e.p.t.Set(one)
   301  
   302  		if !e.p.IsOnCurve() {
   303  			return nil, errors.New("sm9.G1: malformed point")
   304  		}
   305  	}
   306  
   307  	return data[numBytes+1:], nil
   308  }
   309  
   310  func (e *G1) fillBytes(buffer []byte) {
   311  	const numBytes = 256 / 8
   312  
   313  	if e.p == nil {
   314  		e.p = &curvePoint{}
   315  	}
   316  
   317  	e.p.MakeAffine()
   318  	if e.p.IsInfinity() {
   319  		return
   320  	}
   321  	temp := &gfP{}
   322  
   323  	montDecode(temp, &e.p.x)
   324  	temp.Marshal(buffer)
   325  	montDecode(temp, &e.p.y)
   326  	temp.Marshal(buffer[numBytes:])
   327  }
   328  
   329  // Unmarshal sets e to the result of converting the output of Marshal back into
   330  // a group element and then returns e.
   331  func (e *G1) Unmarshal(m []byte) ([]byte, error) {
   332  	// Each value is a 256-bit number.
   333  	const numBytes = 256 / 8
   334  
   335  	if len(m) < 2*numBytes {
   336  		return nil, errors.New("sm9.G1: not enough data")
   337  	}
   338  
   339  	if e.p == nil {
   340  		e.p = &curvePoint{}
   341  	} else {
   342  		e.p.x.Set(zero)
   343  		e.p.y.Set(zero)
   344  	}
   345  
   346  	e.p.x.Unmarshal(m)
   347  	e.p.y.Unmarshal(m[numBytes:])
   348  	montEncode(&e.p.x, &e.p.x)
   349  	montEncode(&e.p.y, &e.p.y)
   350  
   351  	if e.p.x.Equal(zero) == 1 && e.p.y.Equal(zero) == 1 {
   352  		// This is the point at infinity.
   353  		e.p.SetInfinity()
   354  	} else {
   355  		e.p.z.Set(one)
   356  		e.p.t.Set(one)
   357  
   358  		if !e.p.IsOnCurve() {
   359  			return nil, errors.New("sm9.G1: malformed point")
   360  		}
   361  	}
   362  
   363  	return m[2*numBytes:], nil
   364  }
   365  
   366  // Equal compare e and other
   367  func (e *G1) Equal(other *G1) bool {
   368  	if e.p == nil && other.p == nil {
   369  		return true
   370  	}
   371  	return e.p.Equal(other.p)
   372  }
   373  
   374  // IsOnCurve returns true if e is on the curve.
   375  func (e *G1) IsOnCurve() bool {
   376  	return e.p.IsOnCurve()
   377  }
   378  
   379  type G1Curve struct {
   380  	params *CurveParams
   381  	g      G1
   382  }
   383  
   384  var g1Curve = &G1Curve{
   385  	params: &CurveParams{
   386  		Name:    "sm9",
   387  		BitSize: 256,
   388  		P:       bigFromHex("B640000002A3A6F1D603AB4FF58EC74521F2934B1A7AEEDBE56F9B27E351457D"),
   389  		N:       bigFromHex("B640000002A3A6F1D603AB4FF58EC74449F2934B18EA8BEEE56EE19CD69ECF25"),
   390  		B:       bigFromHex("0000000000000000000000000000000000000000000000000000000000000005"),
   391  		Gx:      bigFromHex("93DE051D62BF718FF5ED0704487D01D6E1E4086909DC3280E8C4E4817C66DDDD"),
   392  		Gy:      bigFromHex("21FE8DDA4F21E607631065125C395BBC1C1C00CBFA6024350C464CD70A3EA616"),
   393  	},
   394  	g: G1{},
   395  }
   396  
   397  func (g1 *G1Curve) pointFromAffine(x, y *big.Int) (a *G1, err error) {
   398  	a = &G1{&curvePoint{}}
   399  	if x.Sign() == 0 {
   400  		a.p.SetInfinity()
   401  		return a, nil
   402  	}
   403  	// Reject values that would not get correctly encoded.
   404  	if x.Sign() < 0 || y.Sign() < 0 {
   405  		return a, errors.New("negative coordinate")
   406  	}
   407  	if x.BitLen() > g1.params.BitSize || y.BitLen() > g1.params.BitSize {
   408  		return a, errors.New("overflowing coordinate")
   409  	}
   410  	a.p.x = *fromBigInt(x)
   411  	a.p.y = *fromBigInt(y)
   412  	a.p.z = *newGFp(1)
   413  	a.p.t = *newGFp(1)
   414  
   415  	if !a.p.IsOnCurve() {
   416  		return a, errors.New("point not on G1 curve")
   417  	}
   418  
   419  	return a, nil
   420  }
   421  
   422  func (g1 *G1Curve) Params() *CurveParams {
   423  	return g1.params
   424  }
   425  
   426  // normalizeScalar brings the scalar within the byte size of the order of the
   427  // curve, as expected by the nistec scalar multiplication functions.
   428  func (curve *G1Curve) normalizeScalar(scalar []byte) []byte {
   429  	byteSize := (curve.params.N.BitLen() + 7) / 8
   430  	s := new(big.Int).SetBytes(scalar)
   431  	if len(scalar) > byteSize {
   432  		s.Mod(s, curve.params.N)
   433  	}
   434  	out := make([]byte, byteSize)
   435  	return s.FillBytes(out)
   436  }
   437  
   438  func (g1 *G1Curve) ScalarBaseMult(scalar []byte) (*big.Int, *big.Int) {
   439  	scalar = g1.normalizeScalar(scalar)
   440  	p, err := g1.g.ScalarBaseMult(scalar)
   441  	if err != nil {
   442  		panic("sm9: g1 rejected normalized scalar")
   443  	}
   444  	res := p.Marshal()
   445  	return new(big.Int).SetBytes(res[:32]), new(big.Int).SetBytes(res[32:])
   446  }
   447  
   448  func (g1 *G1Curve) ScalarMult(Bx, By *big.Int, scalar []byte) (*big.Int, *big.Int) {
   449  	a, err := g1.pointFromAffine(Bx, By)
   450  	if err != nil {
   451  		panic("sm9: ScalarMult was called on an invalid point")
   452  	}
   453  	scalar = g1.normalizeScalar(scalar)
   454  	p, err := g1.g.ScalarMult(a, scalar)
   455  	if err != nil {
   456  		panic("sm9: g1 rejected normalized scalar")
   457  	}
   458  	res := p.Marshal()
   459  	return new(big.Int).SetBytes(res[:32]), new(big.Int).SetBytes(res[32:])
   460  }
   461  
   462  func (g1 *G1Curve) Add(x1, y1, x2, y2 *big.Int) (*big.Int, *big.Int) {
   463  	a, err := g1.pointFromAffine(x1, y1)
   464  	if err != nil {
   465  		panic("sm9: Add was called on an invalid point")
   466  	}
   467  	b, err := g1.pointFromAffine(x2, y2)
   468  	if err != nil {
   469  		panic("sm9: Add was called on an invalid point")
   470  	}
   471  	res := g1.g.Add(a, b).Marshal()
   472  	return new(big.Int).SetBytes(res[:32]), new(big.Int).SetBytes(res[32:])
   473  }
   474  
   475  func (g1 *G1Curve) Double(x, y *big.Int) (*big.Int, *big.Int) {
   476  	a, err := g1.pointFromAffine(x, y)
   477  	if err != nil {
   478  		panic("sm9: Double was called on an invalid point")
   479  	}
   480  	res := g1.g.Double(a).Marshal()
   481  	return new(big.Int).SetBytes(res[:32]), new(big.Int).SetBytes(res[32:])
   482  }
   483  
   484  func (g1 *G1Curve) IsOnCurve(x, y *big.Int) bool {
   485  	_, err := g1.pointFromAffine(x, y)
   486  	return err == nil
   487  }
   488  
   489  func (curve *G1Curve) UnmarshalCompressed(data []byte) (x, y *big.Int) {
   490  	if len(data) != 33 || (data[0] != 2 && data[0] != 3) {
   491  		return nil, nil
   492  	}
   493  	r := &gfP{}
   494  	r.Unmarshal(data[1:33])
   495  	if lessThanP(r) == 0 {
   496  		return nil, nil
   497  	}
   498  	x = new(big.Int).SetBytes(data[1:33])
   499  	p := &curvePoint{}
   500  	montEncode(r, r)
   501  	p.x = *r
   502  	p.z = *newGFp(1)
   503  	p.t = *newGFp(1)
   504  	y2 := &gfP{}
   505  	gfpMul(y2, r, r)
   506  	gfpMul(y2, y2, r)
   507  	gfpAdd(y2, y2, curveB)
   508  	y2.Sqrt(y2)
   509  	p.y = *y2
   510  	if !p.IsOnCurve() {
   511  		return nil, nil
   512  	}
   513  	montDecode(y2, y2)
   514  	ret := make([]byte, 32)
   515  	y2.Marshal(ret)
   516  	y = new(big.Int).SetBytes(ret)
   517  	if byte(y.Bit(0)) != data[0]&1 {
   518  		gfpNeg(y2, y2)
   519  		y2.Marshal(ret)
   520  		y.SetBytes(ret)
   521  	}
   522  	return x, y
   523  }
   524  
   525  func (curve *G1Curve) Unmarshal(data []byte) (x, y *big.Int) {
   526  	if len(data) != 65 || (data[0] != 4) {
   527  		return nil, nil
   528  	}
   529  	x1 := &gfP{}
   530  	x1.Unmarshal(data[1:33])
   531  	y1 := &gfP{}
   532  	y1.Unmarshal(data[33:])
   533  	if lessThanP(x1) == 0 || lessThanP(y1) == 0 {
   534  		return nil, nil
   535  	}
   536  	montEncode(x1, x1)
   537  	montEncode(y1, y1)
   538  	p := &curvePoint{
   539  		x: *x1,
   540  		y: *y1,
   541  		z: *newGFp(1),
   542  		t: *newGFp(1),
   543  	}
   544  	if !p.IsOnCurve() {
   545  		return nil, nil
   546  	}
   547  	x = new(big.Int).SetBytes(data[1:33])
   548  	y = new(big.Int).SetBytes(data[33:])
   549  	return x, y
   550  }