github.com/emmansun/gmsm@v0.29.1/sm9/bn256/g2.go (about)

     1  package bn256
     2  
     3  import (
     4  	"errors"
     5  	"io"
     6  	"math/big"
     7  	"sync"
     8  )
     9  
    10  // G2 is an abstract cyclic group. The zero value is suitable for use as the
    11  // output of an operation, but cannot be used as an input.
    12  type G2 struct {
    13  	p *twistPoint
    14  }
    15  
    16  // Gen2 is the generator of G2.
    17  var Gen2 = &G2{twistGen}
    18  
    19  var g2GeneratorTable *[32 * 2]twistPointTable
    20  var g2GeneratorTableOnce sync.Once
    21  
    22  func (g *G2) generatorTable() *[32 * 2]twistPointTable {
    23  	g2GeneratorTableOnce.Do(func() {
    24  		g2GeneratorTable = new([32 * 2]twistPointTable)
    25  		base := NewTwistGenerator()
    26  		for i := 0; i < 32*2; i++ {
    27  			g2GeneratorTable[i][0] = &twistPoint{}
    28  			g2GeneratorTable[i][0].Set(base)
    29  
    30  			g2GeneratorTable[i][1] = &twistPoint{}
    31  			g2GeneratorTable[i][1].Double(g2GeneratorTable[i][0])
    32  			g2GeneratorTable[i][2] = &twistPoint{}
    33  			g2GeneratorTable[i][2].Add(g2GeneratorTable[i][1], base)
    34  
    35  			g2GeneratorTable[i][3] = &twistPoint{}
    36  			g2GeneratorTable[i][3].Double(g2GeneratorTable[i][1])
    37  			g2GeneratorTable[i][4] = &twistPoint{}
    38  			g2GeneratorTable[i][4].Add(g2GeneratorTable[i][3], base)
    39  
    40  			g2GeneratorTable[i][5] = &twistPoint{}
    41  			g2GeneratorTable[i][5].Double(g2GeneratorTable[i][2])
    42  			g2GeneratorTable[i][6] = &twistPoint{}
    43  			g2GeneratorTable[i][6].Add(g2GeneratorTable[i][5], base)
    44  
    45  			g2GeneratorTable[i][7] = &twistPoint{}
    46  			g2GeneratorTable[i][7].Double(g2GeneratorTable[i][3])
    47  			g2GeneratorTable[i][8] = &twistPoint{}
    48  			g2GeneratorTable[i][8].Add(g2GeneratorTable[i][7], base)
    49  
    50  			g2GeneratorTable[i][9] = &twistPoint{}
    51  			g2GeneratorTable[i][9].Double(g2GeneratorTable[i][4])
    52  			g2GeneratorTable[i][10] = &twistPoint{}
    53  			g2GeneratorTable[i][10].Add(g2GeneratorTable[i][9], base)
    54  
    55  			g2GeneratorTable[i][11] = &twistPoint{}
    56  			g2GeneratorTable[i][11].Double(g2GeneratorTable[i][5])
    57  			g2GeneratorTable[i][12] = &twistPoint{}
    58  			g2GeneratorTable[i][12].Add(g2GeneratorTable[i][11], base)
    59  
    60  			g2GeneratorTable[i][13] = &twistPoint{}
    61  			g2GeneratorTable[i][13].Double(g2GeneratorTable[i][6])
    62  			g2GeneratorTable[i][14] = &twistPoint{}
    63  			g2GeneratorTable[i][14].Add(g2GeneratorTable[i][13], base)
    64  
    65  			base.Double(base)
    66  			base.Double(base)
    67  			base.Double(base)
    68  			base.Double(base)
    69  		}
    70  	})
    71  	return g2GeneratorTable
    72  }
    73  
    74  // RandomG2 returns x and g₂ˣ where x is a random, non-zero number read from r.
    75  func RandomG2(r io.Reader) (*big.Int, *G2, error) {
    76  	k, err := randomK(r)
    77  	if err != nil {
    78  		return nil, nil, err
    79  	}
    80  	g2, err := new(G2).ScalarBaseMult(NormalizeScalar(k.Bytes()))
    81  	return k, g2, err
    82  }
    83  
    84  func (e *G2) String() string {
    85  	return "sm9.G2" + e.p.String()
    86  }
    87  
    88  // ScalarBaseMult sets e to g*k where g is the generator of the group and then
    89  // returns out.
    90  func (e *G2) ScalarBaseMult(scalar []byte) (*G2, error) {
    91  	if len(scalar) != 32 {
    92  		return nil, errors.New("invalid scalar length")
    93  	}
    94  	if e.p == nil {
    95  		e.p = &twistPoint{}
    96  	}
    97  	//e.p.Mul(twistGen, k)
    98  
    99  	tables := e.generatorTable()
   100  	// This is also a scalar multiplication with a four-bit window like in
   101  	// ScalarMult, but in this case the doublings are precomputed. The value
   102  	// [windowValue]G added at iteration k would normally get doubled
   103  	// (totIterations-k)×4 times, but with a larger precomputation we can
   104  	// instead add [2^((totIterations-k)×4)][windowValue]G and avoid the
   105  	// doublings between iterations.
   106  	t := NewTwistPoint()
   107  	e.p.SetInfinity()
   108  	tableIndex := len(tables) - 1
   109  	for _, byte := range scalar {
   110  		windowValue := byte >> 4
   111  		tables[tableIndex].Select(t, windowValue)
   112  		e.p.Add(e.p, t)
   113  		tableIndex--
   114  		windowValue = byte & 0b1111
   115  		tables[tableIndex].Select(t, windowValue)
   116  		e.p.Add(e.p, t)
   117  		tableIndex--
   118  	}
   119  
   120  	return e, nil
   121  }
   122  
   123  // ScalarMult sets e to a*k and then returns e.
   124  func (e *G2) ScalarMult(a *G2, scalar []byte) (*G2, error) {
   125  	if e.p == nil {
   126  		e.p = &twistPoint{}
   127  	}
   128  	//e.p.Mul(a.p, k)
   129  	// Compute a twistPointTable for the base point a.
   130  	var table = twistPointTable{NewTwistPoint(), NewTwistPoint(), NewTwistPoint(),
   131  		NewTwistPoint(), NewTwistPoint(), NewTwistPoint(), NewTwistPoint(),
   132  		NewTwistPoint(), NewTwistPoint(), NewTwistPoint(), NewTwistPoint(),
   133  		NewTwistPoint(), NewTwistPoint(), NewTwistPoint(), NewTwistPoint()}
   134  	table[0].Set(a.p)
   135  	for i := 1; i < 15; i += 2 {
   136  		table[i].Double(table[i/2])
   137  		table[i+1].Add(table[i], a.p)
   138  	}
   139  	// Instead of doing the classic double-and-add chain, we do it with a
   140  	// four-bit window: we double four times, and then add [0-15]P.
   141  	t := &G2{NewTwistPoint()}
   142  	e.p.SetInfinity()
   143  	for i, byte := range scalar {
   144  		// No need to double on the first iteration, as p is the identity at
   145  		// this point, and [N]∞ = ∞.
   146  		if i != 0 {
   147  			e.p.Double(e.p)
   148  			e.p.Double(e.p)
   149  			e.p.Double(e.p)
   150  			e.p.Double(e.p)
   151  		}
   152  		windowValue := byte >> 4
   153  		table.Select(t.p, windowValue)
   154  		e.Add(e, t)
   155  		e.p.Double(e.p)
   156  		e.p.Double(e.p)
   157  		e.p.Double(e.p)
   158  		e.p.Double(e.p)
   159  		windowValue = byte & 0b1111
   160  		table.Select(t.p, windowValue)
   161  		e.Add(e, t)
   162  	}
   163  	return e, nil
   164  }
   165  
   166  // Add sets e to a+b and then returns e.
   167  func (e *G2) Add(a, b *G2) *G2 {
   168  	if e.p == nil {
   169  		e.p = &twistPoint{}
   170  	}
   171  	e.p.Add(a.p, b.p)
   172  	return e
   173  }
   174  
   175  // Neg sets e to -a and then returns e.
   176  func (e *G2) Neg(a *G2) *G2 {
   177  	if e.p == nil {
   178  		e.p = &twistPoint{}
   179  	}
   180  	e.p.Neg(a.p)
   181  	return e
   182  }
   183  
   184  // Set sets e to a and then returns e.
   185  func (e *G2) Set(a *G2) *G2 {
   186  	if e.p == nil {
   187  		e.p = &twistPoint{}
   188  	}
   189  	e.p.Set(a.p)
   190  	return e
   191  }
   192  
   193  // Marshal converts e into a byte slice.
   194  func (e *G2) Marshal() []byte {
   195  	// Each value is a 256-bit number.
   196  	const numBytes = 256 / 8
   197  	ret := make([]byte, numBytes*4)
   198  	e.fillBytes(ret)
   199  	return ret
   200  }
   201  
   202  // MarshalUncompressed converts e into a byte slice with uncompressed point prefix
   203  func (e *G2) MarshalUncompressed() []byte {
   204  	// Each value is a 256-bit number.
   205  	const numBytes = 256 / 8
   206  	ret := make([]byte, numBytes*4+1)
   207  	ret[0] = 4
   208  	e.fillBytes(ret[1:])
   209  	return ret
   210  }
   211  
   212  // MarshalCompressed converts e into a byte slice with uncompressed point prefix
   213  func (e *G2) MarshalCompressed() []byte {
   214  	// Each value is a 256-bit number.
   215  	const numBytes = 256 / 8
   216  	ret := make([]byte, numBytes*2+1)
   217  	if e.p == nil {
   218  		e.p = &twistPoint{}
   219  	}
   220  	e.p.MakeAffine()
   221  	temp := &gfP{}
   222  	montDecode(temp, &e.p.y.y)
   223  	temp.Marshal(ret[1:])
   224  	ret[0] = (ret[numBytes] & 1) | 2
   225  
   226  	montDecode(temp, &e.p.x.x)
   227  	temp.Marshal(ret[1:])
   228  	montDecode(temp, &e.p.x.y)
   229  	temp.Marshal(ret[numBytes+1:])
   230  
   231  	return ret
   232  }
   233  
   234  // UnmarshalCompressed sets e to the result of converting the output of Marshal back into
   235  // a group element and then returns e.
   236  func (e *G2) UnmarshalCompressed(data []byte) ([]byte, error) {
   237  	// Each value is a 256-bit number.
   238  	const numBytes = 256 / 8
   239  	if len(data) < 1+2*numBytes {
   240  		return nil, errors.New("sm9.G2: not enough data")
   241  	}
   242  	if data[0] != 2 && data[0] != 3 { // compressed form
   243  		return nil, errors.New("sm9.G2: invalid point compress byte")
   244  	}
   245  	var err error
   246  	// Unmarshal the points and check their caps
   247  	if e.p == nil {
   248  		e.p = &twistPoint{}
   249  	}
   250  	if err = e.p.x.x.Unmarshal(data[1:]); err != nil {
   251  		return nil, err
   252  	}
   253  	if err = e.p.x.y.Unmarshal(data[1+numBytes:]); err != nil {
   254  		return nil, err
   255  	}
   256  	montEncode(&e.p.x.x, &e.p.x.x)
   257  	montEncode(&e.p.x.y, &e.p.x.y)
   258  	x3 := e.p.polynomial(&e.p.x)
   259  	e.p.y.Sqrt(x3)
   260  	x3y := &gfP{}
   261  	montDecode(x3y, &e.p.y.y)
   262  	if byte(x3y[0]&1) != data[0]&1 {
   263  		e.p.y.Neg(&e.p.y)
   264  	}
   265  	if e.p.x.IsZero() && e.p.y.IsZero() {
   266  		// This is the point at infinity.
   267  		e.p.y.SetOne()
   268  		e.p.z.SetZero()
   269  		e.p.t.SetZero()
   270  	} else {
   271  		e.p.z.SetOne()
   272  		e.p.t.SetOne()
   273  
   274  		if !e.p.IsOnCurve() {
   275  			return nil, errors.New("sm9.G2: malformed point")
   276  		}
   277  	}
   278  	return data[1+2*numBytes:], nil
   279  }
   280  
   281  func (e *G2) fillBytes(buffer []byte) {
   282  	// Each value is a 256-bit number.
   283  	const numBytes = 256 / 8
   284  
   285  	if e.p == nil {
   286  		e.p = &twistPoint{}
   287  	}
   288  
   289  	e.p.MakeAffine()
   290  	if e.p.IsInfinity() {
   291  		return
   292  	}
   293  	temp := &gfP{}
   294  
   295  	montDecode(temp, &e.p.x.x)
   296  	temp.Marshal(buffer)
   297  	montDecode(temp, &e.p.x.y)
   298  	temp.Marshal(buffer[numBytes:])
   299  	montDecode(temp, &e.p.y.x)
   300  	temp.Marshal(buffer[2*numBytes:])
   301  	montDecode(temp, &e.p.y.y)
   302  	temp.Marshal(buffer[3*numBytes:])
   303  }
   304  
   305  // Unmarshal sets e to the result of converting the output of Marshal back into
   306  // a group element and then returns e.
   307  func (e *G2) Unmarshal(m []byte) ([]byte, error) {
   308  	// Each value is a 256-bit number.
   309  	const numBytes = 256 / 8
   310  	if len(m) < 4*numBytes {
   311  		return nil, errors.New("sm9.G2: not enough data")
   312  	}
   313  	// Unmarshal the points and check their caps
   314  	if e.p == nil {
   315  		e.p = &twistPoint{}
   316  	}
   317  	var err error
   318  	if err = e.p.x.x.Unmarshal(m); err != nil {
   319  		return nil, err
   320  	}
   321  	if err = e.p.x.y.Unmarshal(m[numBytes:]); err != nil {
   322  		return nil, err
   323  	}
   324  	if err = e.p.y.x.Unmarshal(m[2*numBytes:]); err != nil {
   325  		return nil, err
   326  	}
   327  	if err = e.p.y.y.Unmarshal(m[3*numBytes:]); err != nil {
   328  		return nil, err
   329  	}
   330  	// Encode into Montgomery form and ensure it's on the curve
   331  	montEncode(&e.p.x.x, &e.p.x.x)
   332  	montEncode(&e.p.x.y, &e.p.x.y)
   333  	montEncode(&e.p.y.x, &e.p.y.x)
   334  	montEncode(&e.p.y.y, &e.p.y.y)
   335  
   336  	if e.p.x.IsZero() && e.p.y.IsZero() {
   337  		// This is the point at infinity.
   338  		e.p.y.SetOne()
   339  		e.p.z.SetZero()
   340  		e.p.t.SetZero()
   341  	} else {
   342  		e.p.z.SetOne()
   343  		e.p.t.SetOne()
   344  
   345  		if !e.p.IsOnCurve() {
   346  			return nil, errors.New("sm9.G2: malformed point")
   347  		}
   348  	}
   349  	return m[4*numBytes:], nil
   350  }
   351  
   352  // Equal compare e and other
   353  func (e *G2) Equal(other *G2) bool {
   354  	if e.p == nil && other.p == nil {
   355  		return true
   356  	}
   357  	return e.p.x.Equal(&other.p.x) == 1 &&
   358  		e.p.y.Equal(&other.p.y) == 1 &&
   359  		e.p.z.Equal(&other.p.z) == 1 &&
   360  		e.p.t.Equal(&other.p.t) == 1
   361  }
   362  
   363  // IsOnCurve returns true if e is on the twist curve.
   364  func (e *G2) IsOnCurve() bool {
   365  	return e.p.IsOnCurve()
   366  }