github.com/emmansun/gmsm@v0.29.1/sm2/sm2ec/sm2ec.go (about)

     1  // Package sm2ec defines/implements SM2 curve structure.
     2  package sm2ec
     3  
     4  import (
     5  	"crypto/elliptic"
     6  	"errors"
     7  	"math/big"
     8  
     9  	_sm2ec "github.com/emmansun/gmsm/internal/sm2ec"
    10  )
    11  
    12  type sm2Curve struct {
    13  	newPoint func() *_sm2ec.SM2P256Point
    14  	params   *elliptic.CurveParams
    15  }
    16  
    17  var sm2p256 = &sm2Curve{newPoint: _sm2ec.NewSM2P256Point}
    18  
    19  func initSM2P256() {
    20  	sm2p256.params = sm2Params
    21  }
    22  
    23  func (curve *sm2Curve) Params() *elliptic.CurveParams {
    24  	return curve.params
    25  }
    26  
    27  func (curve *sm2Curve) IsOnCurve(x, y *big.Int) bool {
    28  	// IsOnCurve is documented to reject (0, 0), the conventional point at
    29  	// infinity, which however is accepted by pointFromAffine.
    30  	if x.Sign() == 0 && y.Sign() == 0 {
    31  		return false
    32  	}
    33  	_, err := curve.pointFromAffine(x, y)
    34  	return err == nil
    35  }
    36  
    37  func (curve *sm2Curve) pointFromAffine(x, y *big.Int) (p *_sm2ec.SM2P256Point, err error) {
    38  	// (0, 0) is by convention the point at infinity, which can't be represented
    39  	// in affine coordinates. See Issue 37294.
    40  	if x.Sign() == 0 && y.Sign() == 0 {
    41  		return curve.newPoint(), nil
    42  	}
    43  	// Reject values that would not get correctly encoded.
    44  	if x.Sign() < 0 || y.Sign() < 0 {
    45  		return p, errors.New("negative coordinate")
    46  	}
    47  	if x.BitLen() > curve.params.BitSize || y.BitLen() > curve.params.BitSize {
    48  		return p, errors.New("overflowing coordinate")
    49  	}
    50  	// Encode the coordinates and let SetBytes reject invalid points.
    51  	byteLen := (curve.params.BitSize + 7) / 8
    52  	buf := make([]byte, 1+2*byteLen)
    53  	buf[0] = 4 // uncompressed point
    54  	x.FillBytes(buf[1 : 1+byteLen])
    55  	y.FillBytes(buf[1+byteLen : 1+2*byteLen])
    56  	return curve.newPoint().SetBytes(buf)
    57  }
    58  
    59  func (curve *sm2Curve) pointToAffine(p *_sm2ec.SM2P256Point) (x, y *big.Int) {
    60  	out := p.Bytes()
    61  	if len(out) == 1 && out[0] == 0 {
    62  		// This is the encoding of the point at infinity, which the affine
    63  		// coordinates API represents as (0, 0) by convention.
    64  		return new(big.Int), new(big.Int)
    65  	}
    66  	byteLen := (curve.params.BitSize + 7) / 8
    67  	x = new(big.Int).SetBytes(out[1 : 1+byteLen])
    68  	y = new(big.Int).SetBytes(out[1+byteLen:])
    69  	return x, y
    70  }
    71  
    72  func (curve *sm2Curve) Add(x1, y1, x2, y2 *big.Int) (*big.Int, *big.Int) {
    73  	p1, err := curve.pointFromAffine(x1, y1)
    74  	if err != nil {
    75  		panic("sm2/elliptic: Add was called on an invalid point")
    76  	}
    77  	p2, err := curve.pointFromAffine(x2, y2)
    78  	if err != nil {
    79  		panic("sm2/elliptic: Add was called on an invalid point")
    80  	}
    81  	return curve.pointToAffine(p1.Add(p1, p2))
    82  }
    83  
    84  func (curve *sm2Curve) Double(x1, y1 *big.Int) (*big.Int, *big.Int) {
    85  	p, err := curve.pointFromAffine(x1, y1)
    86  	if err != nil {
    87  		panic("sm2/elliptic: Double was called on an invalid point")
    88  	}
    89  	return curve.pointToAffine(p.Double(p))
    90  }
    91  
    92  // normalizeScalar brings the scalar within the byte size of the order of the
    93  // curve, as expected by the nistec scalar multiplication functions.
    94  func (curve *sm2Curve) normalizeScalar(scalar []byte) []byte {
    95  	byteSize := (curve.params.N.BitLen() + 7) / 8
    96  	if len(scalar) == byteSize {
    97  		return scalar
    98  	}
    99  	s := new(big.Int).SetBytes(scalar)
   100  	if len(scalar) > byteSize {
   101  		s.Mod(s, curve.params.N)
   102  	}
   103  	out := make([]byte, byteSize)
   104  	return s.FillBytes(out)
   105  }
   106  
   107  func (curve *sm2Curve) ScalarMult(Bx, By *big.Int, scalar []byte) (*big.Int, *big.Int) {
   108  	p, err := curve.pointFromAffine(Bx, By)
   109  	if err != nil {
   110  		panic("sm2/elliptic: ScalarMult was called on an invalid point")
   111  	}
   112  	scalar = curve.normalizeScalar(scalar)
   113  	p, err = p.ScalarMult(p, scalar)
   114  	if err != nil {
   115  		panic("sm2/elliptic: sm2 rejected normalized scalar")
   116  	}
   117  	return curve.pointToAffine(p)
   118  }
   119  
   120  func (curve *sm2Curve) ScalarBaseMult(scalar []byte) (*big.Int, *big.Int) {
   121  	scalar = curve.normalizeScalar(scalar)
   122  	p, err := curve.newPoint().ScalarBaseMult(scalar)
   123  	if err != nil {
   124  		panic("sm2/elliptic: sm2 rejected normalized scalar")
   125  	}
   126  	return curve.pointToAffine(p)
   127  }
   128  
   129  // CombinedMult returns [s1]G + [s2]P where G is the generator. It's used
   130  // through an interface upgrade in crypto/ecdsa.
   131  func (curve *sm2Curve) CombinedMult(Px, Py *big.Int, s1, s2 []byte) (x, y *big.Int) {
   132  	s1 = curve.normalizeScalar(s1)
   133  	q, err := curve.newPoint().ScalarBaseMult(s1)
   134  	if err != nil {
   135  		panic("sm2/elliptic: sm2 rejected normalized scalar")
   136  	}
   137  	p, err := curve.pointFromAffine(Px, Py)
   138  	if err != nil {
   139  		panic("sm2/elliptic: CombinedMult was called on an invalid point")
   140  	}
   141  	s2 = curve.normalizeScalar(s2)
   142  	p, err = p.ScalarMult(p, s2)
   143  	if err != nil {
   144  		panic("sm2/elliptic: sm2 rejected normalized scalar")
   145  	}
   146  	return curve.pointToAffine(p.Add(p, q))
   147  }
   148  
   149  func (curve *sm2Curve) Unmarshal(data []byte) (x, y *big.Int) {
   150  	if len(data) == 0 || data[0] != 4 {
   151  		return nil, nil
   152  	}
   153  	// Use SetBytes to check that data encodes a valid point.
   154  	_, err := curve.newPoint().SetBytes(data)
   155  	if err != nil {
   156  		return nil, nil
   157  	}
   158  	// We don't use pointToAffine because it involves an expensive field
   159  	// inversion to convert from Jacobian to affine coordinates, which we
   160  	// already have.
   161  	byteLen := (curve.params.BitSize + 7) / 8
   162  	x = new(big.Int).SetBytes(data[1 : 1+byteLen])
   163  	y = new(big.Int).SetBytes(data[1+byteLen:])
   164  	return x, y
   165  }
   166  
   167  func (curve *sm2Curve) UnmarshalCompressed(data []byte) (x, y *big.Int) {
   168  	if len(data) == 0 || (data[0] != 2 && data[0] != 3) {
   169  		return nil, nil
   170  	}
   171  	p, err := curve.newPoint().SetBytes(data)
   172  	if err != nil {
   173  		return nil, nil
   174  	}
   175  	return curve.pointToAffine(p)
   176  }
   177  
   178  // Inverse, implements invertible interface, used by Sign()
   179  func (curve *sm2Curve) Inverse(k *big.Int) *big.Int {
   180  	if k.Sign() < 0 {
   181  		// This should never happen.
   182  		k = new(big.Int).Neg(k)
   183  	}
   184  	if k.Cmp(curve.params.N) >= 0 {
   185  		// This should never happen.
   186  		k = new(big.Int).Mod(k, curve.params.N)
   187  	}
   188  	scalar := k.FillBytes(make([]byte, 32))
   189  	inverse, err := _sm2ec.P256OrdInverse(scalar)
   190  	if err != nil {
   191  		panic("sm2/elliptic: sm2 rejected normalized scalar")
   192  	}
   193  	return new(big.Int).SetBytes(inverse)
   194  }