github.com/hellobchain/newcryptosm@v0.0.0-20221019060107-edb949a317e9/sm9/elliptic.go (about)

     1  package sm9
     2  
     3  import (
     4  	"crypto/elliptic"
     5  	"errors"
     6  	"fmt"
     7  	"math/big"
     8  	"sync"
     9  )
    10  
    11  type Curve struct {
    12  	*elliptic.CurveParams
    13  }
    14  
    15  var curve Curve
    16  var initOnce sync.Once
    17  
    18  func initSM9() {
    19  	curve.CurveParams = &elliptic.CurveParams{}
    20  	curve.CurveParams.P, _ = new(big.Int).SetString("B640000002A3A6F1D603AB4FF58EC74521F2934B1A7AEEDBE56F9B27E351457D", 16)
    21  	curve.CurveParams.N, _ = new(big.Int).SetString("B640000002A3A6F1D603AB4FF58EC74449F2934B18EA8BEEE56EE19CD69ECF25", 16)
    22  	curve.CurveParams.B = new(big.Int).SetInt64(5)
    23  	curve.CurveParams.Gx, _ = new(big.Int).SetString("93DE051D62BF718FF5ED0704487D01D6E1E4086909DC3280E8C4E4817C66DDDD", 16)
    24  	curve.CurveParams.Gy, _ = new(big.Int).SetString("21FE8DDA4F21E607631065125C395BBC1C1C00CBFA6024350C464CD70A3EA616", 16)
    25  	curve.CurveParams.BitSize = 256
    26  	curve.CurveParams.Name = "SM9"
    27  }
    28  
    29  // SM9 return the elliptic.Curve interface of SM9 curve
    30  func SM9() *Curve {
    31  	initOnce.Do(initSM9)
    32  	return &curve
    33  }
    34  
    35  // Params returns the parameters for the curve.
    36  func (curve *Curve) Params() *elliptic.CurveParams {
    37  	return curve.CurveParams
    38  }
    39  
    40  // IsOnCurve reports whether the given (x,y) lies on the curve.
    41  func (curve *Curve) IsOnCurve(x, y *big.Int) bool {
    42  
    43  	// y² = x³ + b
    44  	y2 := new(big.Int).Mul(y, y)
    45  	y2.Mod(y2, curve.P)
    46  
    47  	x3 := new(big.Int).Mul(x, x)
    48  	x3.Mul(x3, x)
    49  
    50  	x3.Add(x3, curve.B)
    51  	x3.Mod(x3, curve.P)
    52  
    53  	return x3.Cmp(y2) == 0
    54  }
    55  
    56  // Add returns the sum of (x1,y1) and (x2,y2)
    57  func (curve *Curve) Add(x1, y1, x2, y2 *big.Int) (x, y *big.Int) {
    58  	p1 := BigToG1(x1, y1)
    59  	p2 := BigToG1(x2, y2)
    60  	rp := new(G1).Add(p1, p2)
    61  	return G1ToBig(rp)
    62  }
    63  
    64  // Double returns 2*(x,y)
    65  func (curve *Curve) Double(x1, y1 *big.Int) (x, y *big.Int) {
    66  	return curve.Add(x1, y1, x1, y1)
    67  }
    68  
    69  // ScalarMult returns k*(Bx,By) where k is a number in big-endian form.
    70  func (curve *Curve) ScalarMult(x1, y1 *big.Int, k []byte) (x, y *big.Int) {
    71  	p1 := BigToG1(x1, y1)
    72  	bigK := new(big.Int).SetBytes(k)
    73  	rp := new(G1).ScalarMult(p1, bigK)
    74  	return G1ToBig(rp)
    75  }
    76  
    77  // ScalarBaseMult returns k*G, where G is the base point of the group
    78  // and k is an integer in big-endian form.
    79  func (curve *Curve) ScalarBaseMult(k []byte) (x, y *big.Int) {
    80  	bigK := new(big.Int).SetBytes(k)
    81  	rp := new(G1).ScalarBaseMult(bigK)
    82  	return G1ToBig(rp)
    83  }
    84  
    85  //Neg is (x, -y)
    86  func (curve *Curve) Neg(x1, y1 *big.Int) (x, y *big.Int) {
    87  	return new(big.Int).Set(x1), new(big.Int).Sub(curve.Params().P, y1)
    88  }
    89  
    90  //CombinedMult do baseScalar*G + scalar*(X,Y)
    91  func (curve *Curve) CombinedMult(bigX, bigY *big.Int, baseScalar, scalar []byte) (x, y *big.Int) {
    92  	x1, y1 := curve.ScalarBaseMult(baseScalar)
    93  	x2, y2 := curve.ScalarMult(bigX, bigY, scalar)
    94  	return curve.Add(x1, y1, x2, y2)
    95  }
    96  
    97  func BigToG1(x, y *big.Int) *G1 {
    98  	m := make([]byte, 64)
    99  	xBytes := x.Bytes()
   100  	yBytes := y.Bytes()
   101  	copy(m[32-len(xBytes):32], xBytes)
   102  	copy(m[64-len(yBytes):64], yBytes)
   103  	r := new(G1)
   104  	_, success := r.Unmarshal(m)
   105  	if success != nil {
   106  		fmt.Printf("error in big int to G1")
   107  		return nil
   108  	}
   109  	return r
   110  }
   111  
   112  func G1ToBig(g1 *G1) (x, y *big.Int) {
   113  	m := g1.Marshal()
   114  	return new(big.Int).SetBytes(m[0:32]), new(big.Int).SetBytes(m[32:64])
   115  }
   116  
   117  func (curve *Curve) Compress(x, y *big.Int) []byte {
   118  	return CompressP(x, y)
   119  }
   120  
   121  func (curve *Curve) Decompress(in []byte) (x, y *big.Int, err error) {
   122  	x, y = DecompressP(in)
   123  	if x == nil || y == nil {
   124  		return nil, nil, errors.New("decompress fail")
   125  	}
   126  	return x, y, nil
   127  }
   128  
   129  func BytesToG2(in []byte) *G2 {
   130  	point := new(G2)
   131  	if _, success := point.Unmarshal(in); success != nil {
   132  		fmt.Printf("point is not on curve G2")
   133  		return nil
   134  	}
   135  	return point
   136  }
   137  
   138  func G2ToBytes(point *G2) []byte {
   139  	return point.Marshal()
   140  }
   141  
   142  func BytesToGt(in []byte) *GT {
   143  	point := new(GT)
   144  	if _, success := point.Unmarshal(in); success != nil {
   145  		fmt.Printf("point is not on curve GT")
   146  		return nil
   147  	}
   148  	return point
   149  }
   150  
   151  func GtToBytes(point *GT) []byte {
   152  	return point.Marshal()
   153  }