github.com/cloudflare/circl@v1.5.0/group/short.go (about)

     1  package group
     2  
     3  import (
     4  	"crypto"
     5  	"crypto/elliptic"
     6  	_ "crypto/sha256"
     7  	_ "crypto/sha512"
     8  	"crypto/subtle"
     9  	"fmt"
    10  	"io"
    11  	"math/big"
    12  
    13  	"github.com/cloudflare/circl/ecc/p384"
    14  	"github.com/cloudflare/circl/expander"
    15  )
    16  
    17  var (
    18  	// P256 is the group generated by P-256 elliptic curve.
    19  	P256 Group = wG{elliptic.P256()}
    20  	// P384 is the group generated by P-384 elliptic curve.
    21  	P384 Group = wG{p384.P384()}
    22  	// P521 is the group generated by P-521 elliptic curve.
    23  	P521 Group = wG{elliptic.P521()}
    24  )
    25  
    26  type wG struct {
    27  	c elliptic.Curve
    28  }
    29  
    30  func (g wG) String() string      { return g.c.Params().Name }
    31  func (g wG) NewElement() Element { return g.zeroElement() }
    32  func (g wG) NewScalar() Scalar   { return g.zeroScalar() }
    33  func (g wG) Identity() Element   { return g.zeroElement() }
    34  func (g wG) zeroScalar() *wScl   { return &wScl{g, make([]byte, (g.c.Params().BitSize+7)/8)} }
    35  func (g wG) zeroElement() *wElt  { return &wElt{g, new(big.Int), new(big.Int)} }
    36  func (g wG) Generator() Element  { return &wElt{g, g.c.Params().Gx, g.c.Params().Gy} }
    37  func (g wG) RandomElement(rd io.Reader) Element {
    38  	b := make([]byte, (g.c.Params().BitSize+7)/8)
    39  	if n, err := io.ReadFull(rd, b); err != nil || n != len(b) {
    40  		panic(err)
    41  	}
    42  	return g.HashToElement(b, nil)
    43  }
    44  
    45  func (g wG) RandomScalar(rd io.Reader) Scalar {
    46  	b := make([]byte, (g.c.Params().BitSize+7)/8)
    47  	if n, err := io.ReadFull(rd, b); err != nil || n != len(b) {
    48  		panic(err)
    49  	}
    50  	return g.HashToScalar(b, nil)
    51  }
    52  
    53  func (g wG) RandomNonZeroScalar(rd io.Reader) Scalar {
    54  	zero := g.zeroScalar()
    55  	for {
    56  		s := g.RandomScalar(rd)
    57  		if !s.IsEqual(zero) {
    58  			return s
    59  		}
    60  	}
    61  }
    62  
    63  func (g wG) cvtElt(e Element) *wElt {
    64  	if e == nil {
    65  		return g.zeroElement()
    66  	}
    67  	ee, ok := e.(*wElt)
    68  	if !ok || g.c.Params().BitSize != ee.c.Params().BitSize {
    69  		panic(ErrType)
    70  	}
    71  	return ee
    72  }
    73  
    74  func (g wG) cvtScl(s Scalar) *wScl {
    75  	if s == nil {
    76  		return g.zeroScalar()
    77  	}
    78  	ss, ok := s.(*wScl)
    79  	if !ok || g.c.Params().BitSize != ss.c.Params().BitSize {
    80  		panic(ErrType)
    81  	}
    82  	return ss
    83  }
    84  
    85  func (g wG) Params() *Params {
    86  	fieldLen := uint((g.c.Params().BitSize + 7) / 8)
    87  	return &Params{
    88  		ElementLength:           1 + 2*fieldLen,
    89  		CompressedElementLength: 1 + fieldLen,
    90  		ScalarLength:            fieldLen,
    91  	}
    92  }
    93  
    94  func (g wG) HashToElementNonUniform(b, dst []byte) Element {
    95  	var u [1]big.Int
    96  	mapping, h, L := g.mapToCurveParams()
    97  	xmd := expander.NewExpanderMD(h, dst)
    98  	HashToField(u[:], b, xmd, g.c.Params().P, L)
    99  	return mapping(&u[0])
   100  }
   101  
   102  func (g wG) HashToElement(b, dst []byte) Element {
   103  	var u [2]big.Int
   104  	mapping, h, L := g.mapToCurveParams()
   105  	xmd := expander.NewExpanderMD(h, dst)
   106  	HashToField(u[:], b, xmd, g.c.Params().P, L)
   107  	Q0 := mapping(&u[0])
   108  	Q1 := mapping(&u[1])
   109  	return Q0.Add(Q0, Q1)
   110  }
   111  
   112  func (g wG) HashToScalar(b, dst []byte) Scalar {
   113  	var u [1]big.Int
   114  	_, h, L := g.mapToCurveParams()
   115  	xmd := expander.NewExpanderMD(h, dst)
   116  	HashToField(u[:], b, xmd, g.c.Params().N, L)
   117  	s := g.NewScalar().(*wScl)
   118  	s.fromBig(&u[0])
   119  	return s
   120  }
   121  
   122  type wElt struct {
   123  	wG
   124  	x, y *big.Int
   125  }
   126  
   127  func (e *wElt) Group() Group     { return e.wG }
   128  func (e *wElt) String() string   { return fmt.Sprintf("x: 0x%v\ny: 0x%v", e.x.Text(16), e.y.Text(16)) }
   129  func (e *wElt) IsIdentity() bool { return e.x.Sign() == 0 && e.y.Sign() == 0 }
   130  func (e *wElt) IsEqual(o Element) bool {
   131  	oo := e.cvtElt(o)
   132  	return e.x.Cmp(oo.x) == 0 && e.y.Cmp(oo.y) == 0
   133  }
   134  
   135  func (e *wElt) Set(a Element) Element {
   136  	aa := e.cvtElt(a)
   137  	e.x.Set(aa.x)
   138  	e.y.Set(aa.y)
   139  	return e
   140  }
   141  
   142  func (e *wElt) Copy() Element { return e.wG.zeroElement().Set(e) }
   143  
   144  func (e *wElt) CMov(v int, a Element) Element {
   145  	if !(v == 0 || v == 1) {
   146  		panic(ErrSelector)
   147  	}
   148  	aa := e.cvtElt(a)
   149  	l := (e.wG.c.Params().BitSize + 7) / 8
   150  	bufE := make([]byte, l)
   151  	bufA := make([]byte, l)
   152  	e.x.FillBytes(bufE)
   153  	aa.x.FillBytes(bufA)
   154  	subtle.ConstantTimeCopy(v, bufE, bufA)
   155  	e.x.SetBytes(bufE)
   156  
   157  	e.y.FillBytes(bufE)
   158  	aa.y.FillBytes(bufA)
   159  	subtle.ConstantTimeCopy(v, bufE, bufA)
   160  	e.y.SetBytes(bufE)
   161  
   162  	return e
   163  }
   164  
   165  func (e *wElt) CSelect(v int, a Element, b Element) Element {
   166  	if !(v == 0 || v == 1) {
   167  		panic(ErrSelector)
   168  	}
   169  	aa, bb := e.cvtElt(a), e.cvtElt(b)
   170  	l := (e.wG.c.Params().BitSize + 7) / 8
   171  	bufE := make([]byte, l)
   172  	bufA := make([]byte, l)
   173  	bufB := make([]byte, l)
   174  
   175  	e.x.FillBytes(bufE)
   176  	aa.x.FillBytes(bufA)
   177  	bb.x.FillBytes(bufB)
   178  	for i := range bufE {
   179  		bufE[i] = byte(subtle.ConstantTimeSelect(v, int(bufA[i]), int(bufB[i])))
   180  	}
   181  	e.x.SetBytes(bufE)
   182  
   183  	e.y.FillBytes(bufE)
   184  	aa.y.FillBytes(bufA)
   185  	bb.y.FillBytes(bufB)
   186  	for i := range bufE {
   187  		bufE[i] = byte(subtle.ConstantTimeSelect(v, int(bufA[i]), int(bufB[i])))
   188  	}
   189  	e.y.SetBytes(bufE)
   190  
   191  	return e
   192  }
   193  
   194  func (e *wElt) Add(a, b Element) Element {
   195  	aa, bb := e.cvtElt(a), e.cvtElt(b)
   196  	e.x, e.y = e.c.Add(aa.x, aa.y, bb.x, bb.y)
   197  	return e
   198  }
   199  
   200  func (e *wElt) Dbl(a Element) Element {
   201  	aa := e.cvtElt(a)
   202  	e.x, e.y = e.c.Double(aa.x, aa.y)
   203  	return e
   204  }
   205  
   206  func (e *wElt) Neg(a Element) Element {
   207  	aa := e.cvtElt(a)
   208  	e.x.Set(aa.x)
   209  	e.y.Neg(aa.y).Mod(e.y, e.c.Params().P)
   210  	return e
   211  }
   212  
   213  func (e *wElt) Mul(a Element, s Scalar) Element {
   214  	aa, ss := e.cvtElt(a), e.cvtScl(s)
   215  	e.x, e.y = e.c.ScalarMult(aa.x, aa.y, ss.k)
   216  	return e
   217  }
   218  
   219  func (e *wElt) MulGen(s Scalar) Element {
   220  	ss := e.cvtScl(s)
   221  	e.x, e.y = e.c.ScalarBaseMult(ss.k)
   222  	return e
   223  }
   224  
   225  func (e *wElt) MarshalBinary() ([]byte, error) {
   226  	if e.IsIdentity() {
   227  		return []byte{0x0}, nil
   228  	}
   229  	e.x.Mod(e.x, e.c.Params().P)
   230  	e.y.Mod(e.y, e.c.Params().P)
   231  	return elliptic.Marshal(e.wG.c, e.x, e.y), nil
   232  }
   233  
   234  func (e *wElt) MarshalBinaryCompress() ([]byte, error) {
   235  	if e.IsIdentity() {
   236  		return []byte{0x0}, nil
   237  	}
   238  	e.x.Mod(e.x, e.c.Params().P)
   239  	e.y.Mod(e.y, e.c.Params().P)
   240  	return elliptic.MarshalCompressed(e.wG.c, e.x, e.y), nil
   241  }
   242  
   243  func (e *wElt) UnmarshalBinary(b []byte) error {
   244  	byteLen := (e.c.Params().BitSize + 7) / 8
   245  	l := len(b)
   246  	switch {
   247  	case l == 1 && b[0] == 0x00: // point at infinity
   248  		e.x.SetInt64(0)
   249  		e.y.SetInt64(0)
   250  	case l == 1+byteLen && (b[0] == 0x02 || b[0] == 0x03): // compressed
   251  		x, y := elliptic.UnmarshalCompressed(e.wG.c, b)
   252  		if x == nil {
   253  			return ErrUnmarshal
   254  		}
   255  		e.x, e.y = x, y
   256  	case l == 1+2*byteLen && b[0] == 0x04: // uncompressed
   257  		x, y := elliptic.Unmarshal(e.wG.c, b)
   258  		if x == nil {
   259  			return ErrUnmarshal
   260  		}
   261  		e.x, e.y = x, y
   262  	default:
   263  		return ErrUnmarshal
   264  	}
   265  	return nil
   266  }
   267  
   268  type wScl struct {
   269  	wG
   270  	k []byte
   271  }
   272  
   273  func (s *wScl) Group() Group                { return s.wG }
   274  func (s *wScl) String() string              { return fmt.Sprintf("0x%x", s.k) }
   275  func (s *wScl) SetUint64(n uint64) Scalar   { s.fromBig(new(big.Int).SetUint64(n)); return s }
   276  func (s *wScl) SetBigInt(x *big.Int) Scalar { s.fromBig(x); return s }
   277  func (s *wScl) IsZero() bool {
   278  	return subtle.ConstantTimeCompare(s.k, make([]byte, (s.wG.c.Params().BitSize+7)/8)) == 1
   279  }
   280  
   281  func (s *wScl) IsEqual(a Scalar) bool {
   282  	aa := s.cvtScl(a)
   283  	return subtle.ConstantTimeCompare(s.k, aa.k) == 1
   284  }
   285  
   286  func (s *wScl) fromBig(b *big.Int) {
   287  	k := new(big.Int).Mod(b, s.c.Params().N)
   288  	if err := s.UnmarshalBinary(k.Bytes()); err != nil {
   289  		panic(err)
   290  	}
   291  }
   292  
   293  func (s *wScl) Set(a Scalar) Scalar {
   294  	aa := s.cvtScl(a)
   295  	if err := s.UnmarshalBinary(aa.k); err != nil {
   296  		panic(err)
   297  	}
   298  	return s
   299  }
   300  
   301  func (s *wScl) Copy() Scalar { return s.wG.zeroScalar().Set(s) }
   302  
   303  func (s *wScl) CMov(v int, a Scalar) Scalar {
   304  	if !(v == 0 || v == 1) {
   305  		panic(ErrSelector)
   306  	}
   307  	aa := s.cvtScl(a)
   308  	subtle.ConstantTimeCopy(v, s.k, aa.k)
   309  	return s
   310  }
   311  
   312  func (s *wScl) CSelect(v int, a Scalar, b Scalar) Scalar {
   313  	if !(v == 0 || v == 1) {
   314  		panic(ErrSelector)
   315  	}
   316  	aa, bb := s.cvtScl(a), s.cvtScl(b)
   317  	for i := range s.k {
   318  		s.k[i] = byte(subtle.ConstantTimeSelect(v, int(aa.k[i]), int(bb.k[i])))
   319  	}
   320  	return s
   321  }
   322  
   323  func (s *wScl) Add(a, b Scalar) Scalar {
   324  	aa, bb := s.cvtScl(a), s.cvtScl(b)
   325  	r := new(big.Int)
   326  	r.SetBytes(aa.k).Add(r, new(big.Int).SetBytes(bb.k))
   327  	s.fromBig(r)
   328  	return s
   329  }
   330  
   331  func (s *wScl) Sub(a, b Scalar) Scalar {
   332  	aa, bb := s.cvtScl(a), s.cvtScl(b)
   333  	r := new(big.Int)
   334  	r.SetBytes(aa.k).Sub(r, new(big.Int).SetBytes(bb.k))
   335  	s.fromBig(r)
   336  	return s
   337  }
   338  
   339  func (s *wScl) Mul(a, b Scalar) Scalar {
   340  	aa, bb := s.cvtScl(a), s.cvtScl(b)
   341  	r := new(big.Int)
   342  	r.SetBytes(aa.k).Mul(r, new(big.Int).SetBytes(bb.k))
   343  	s.fromBig(r)
   344  	return s
   345  }
   346  
   347  func (s *wScl) Neg(a Scalar) Scalar {
   348  	aa := s.cvtScl(a)
   349  	r := new(big.Int)
   350  	r.SetBytes(aa.k).Neg(r)
   351  	s.fromBig(r)
   352  	return s
   353  }
   354  
   355  func (s *wScl) Inv(a Scalar) Scalar {
   356  	aa := s.cvtScl(a)
   357  	r := new(big.Int)
   358  	r.SetBytes(aa.k).ModInverse(r, s.c.Params().N)
   359  	s.fromBig(r)
   360  	return s
   361  }
   362  
   363  func (s *wScl) MarshalBinary() (data []byte, err error) {
   364  	data = make([]byte, (s.c.Params().BitSize+7)/8)
   365  	copy(data, s.k)
   366  	return data, nil
   367  }
   368  
   369  func (s *wScl) UnmarshalBinary(b []byte) error {
   370  	l := (s.c.Params().BitSize + 7) / 8
   371  	s.k = make([]byte, l)
   372  	copy(s.k[l-len(b):l], b)
   373  	return nil
   374  }
   375  
   376  func (g wG) mapToCurveParams() (mapping func(u *big.Int) *wElt, h crypto.Hash, L uint) {
   377  	var Z, C2 big.Int
   378  	switch g.c.Params().BitSize {
   379  	case 256:
   380  		Z.SetInt64(-10)
   381  		C2.SetString("0x78bc71a02d89ec07214623f6d0f955072c7cc05604a5a6e23ffbf67115fa5301", 0)
   382  		h = crypto.SHA256
   383  		L = 48
   384  	case 384:
   385  		Z.SetInt64(-12)
   386  		C2.SetString("0x19877cc1041b7555743c0ae2e3a3e61fb2aaa2e0e87ea557a563d8b598a0940d0a697a9e0b9e92cfaa314f583c9d066", 0)
   387  		h = crypto.SHA384
   388  		L = 72
   389  	case 521:
   390  		Z.SetInt64(-4)
   391  		C2.SetInt64(8)
   392  		h = crypto.SHA512
   393  		L = 98
   394  	default:
   395  		panic("curve not supported")
   396  	}
   397  	return func(u *big.Int) *wElt { return g.sswu3mod4Map(u, &Z, &C2) }, h, L
   398  }
   399  
// sswu3mod4Map maps the field element u to an affine point on the curve. It
// implements the simplified Shallue-van de Woestijne-Ulas map specialised for
// fields with p = 3 mod 4. The numbered step comments follow
// map_to_curve_simple_swu_3mod4 from RFC 9380, Appendix G.2.1.
//
// Z and C2 are the per-curve constants supplied by mapToCurveParams. C2 plays
// the role of the RFC's c2, i.e. sqrt(-Z^3) mod p. The curve coefficient A is
// fixed at -3 below; B and p are read from the curve parameters.
//
// Side effect: the sgn helper reduces its argument modulo p in place, so u is
// normalized by this call. NOTE(review): u is presumably already reduced by
// HashToField; confirm.
//
// NOTE(review): unlike the RFC's CMOV, cmv below is an ordinary branch, and
// math/big arithmetic is variable-time. Do not treat this mapping as
// constant-time.
func (g wG) sswu3mod4Map(u *big.Int, Z, C2 *big.Int) *wElt {
	tv1 := new(big.Int)
	tv2 := new(big.Int)
	tv3 := new(big.Int)
	tv4 := new(big.Int)
	xn := new(big.Int)
	xd := new(big.Int)
	x1n := new(big.Int)
	x2n := new(big.Int)
	gx1 := new(big.Int)
	gxd := new(big.Int)
	y1 := new(big.Int)
	y2 := new(big.Int)
	x := new(big.Int)
	y := new(big.Int)

	// Curve constants: a = -3 is hard-coded; b and p come from Params().
	A := big.NewInt(-3)
	B := g.c.Params().B
	p := g.c.Params().P
	c1 := new(big.Int)
	c1.Sub(p, big.NewInt(3)).Rsh(c1, 2) // constant 1: c1 = (q - 3) / 4

	// Field helpers. add, mul, sqr and exp store their result, reduced modulo
	// p, in c. sgn reduces a in place and returns its low bit (sgn0). cmv sets
	// c = b when k is true and c = a otherwise.
	add := func(c, a, b *big.Int) { c.Add(a, b).Mod(c, p) }
	mul := func(c, a, b *big.Int) { c.Mul(a, b).Mod(c, p) }
	sqr := func(c, a *big.Int) { c.Mul(a, a).Mod(c, p) }
	exp := func(c, a, b *big.Int) { c.Exp(a, b, p) }
	sgn := func(a *big.Int) uint { a.Mod(a, p); return a.Bit(0) }
	cmv := func(c, a, b *big.Int, k bool) {
		if k {
			c.Set(b)
		} else {
			c.Set(a)
		}
	}

	sqr(tv1, u)                 // 1.  tv1 = u^2
	mul(tv3, Z, tv1)            // 2.  tv3 = Z * tv1
	sqr(tv2, tv3)               // 3.  tv2 = tv3^2
	add(xd, tv2, tv3)           // 4.   xd = tv2 + tv3
	add(x1n, xd, big.NewInt(1)) // 5.  x1n = xd + 1
	mul(x1n, x1n, B)            // 6.  x1n = x1n * B
	tv4.Neg(A)                  // tv4 = -A
	mul(xd, tv4, xd)            // 7.   xd = -A * xd
	e1 := xd.Sign() == 0        // 8.   e1 = xd == 0
	mul(tv4, Z, A)              // tv4 = Z * A
	cmv(xd, xd, tv4, e1)        // 9.   xd = CMOV(xd, Z * A, e1)
	sqr(tv2, xd)                // 10. tv2 = xd^2
	mul(gxd, tv2, xd)           // 11. gxd = tv2 * xd
	mul(tv2, A, tv2)            // 12. tv2 = A * tv2
	sqr(gx1, x1n)               // 13. gx1 = x1n^2
	add(gx1, gx1, tv2)          // 14. gx1 = gx1 + tv2
	mul(gx1, gx1, x1n)          // 15. gx1 = gx1 * x1n
	mul(tv2, B, gxd)            // 16. tv2 = B * gxd
	add(gx1, gx1, tv2)          // 17. gx1 = gx1 + tv2
	sqr(tv4, gxd)               // 18. tv4 = gxd^2
	mul(tv2, gx1, gxd)          // 19. tv2 = gx1 * gxd
	mul(tv4, tv4, tv2)          // 20. tv4 = tv4 * tv2
	exp(y1, tv4, c1)            // 21.  y1 = tv4^c1
	mul(y1, y1, tv2)            // 22.  y1 = y1 * tv2
	mul(x2n, tv3, x1n)          // 23. x2n = tv3 * x1n
	mul(y2, y1, C2)             // 24.  y2 = y1 * c2
	mul(y2, y2, tv1)            // 25.  y2 = y2 * tv1
	mul(y2, y2, u)              // 26.  y2 = y2 * u
	sqr(tv2, y1)                // 27. tv2 = y1^2
	mul(tv2, tv2, gxd)          // 28. tv2 = tv2 * gxd
	e2 := tv2.Cmp(gx1) == 0     // 29.  e2 = tv2 == gx1
	cmv(xn, x2n, x1n, e2)       // 30.  xn = CMOV(x2n, x1n, e2)
	cmv(y, y2, y1, e2)          // 31.   y = CMOV(y2, y1, e2)
	e3 := sgn(u) == sgn(y)      // 32.  e3 = sgn0(u) == sgn0(y)
	tv1.Neg(y)                  // tv1 = -y
	cmv(y, tv1, y, e3)          // 33.   y = CMOV(-y, y, e3)
	tv1.ModInverse(xd, p)       // tv1 = 1 / xd (xd != 0 after step 9)
	mul(x, xn, tv1)             // 34. return (xn, xd, y, 1); affine x = xn / xd
	y.Mod(y, p)                 // y may be -y from step 33; bring it into [0, p)
	return &wElt{g, x, y}
}