github.com/ronperry/cryptoedge@v0.0.0-20150815114006-cc363e290743/eccutil/eccutil.go (about)

// Package eccutil contains utility functions for implementing protocols over elliptic curves (ECC).
     2  package eccutil
     3  
     4  import (
     5  	"crypto/elliptic"
     6  	"crypto/rand"
     7  	"crypto/sha1"
     8  	"errors"
     9  	"io"
    10  	"math/big"
    11  )
    12  
// MaxLoopCount is the maximum number of tries made by parameter-search loops
// before giving up (GenNV returns ErrMaxLoop once this bound is exceeded).
const MaxLoopCount = 1000
    15  
    16  // Point describes a point on a curve
    17  type Point struct {
    18  	X *big.Int
    19  	Y *big.Int
    20  }
    21  
// Curve encapsulates an elliptic curve together with the randomness source
// and hash function used by the helpers in this package. Construct it with
// SetCurve so that Params and Nminus are populated.
type Curve struct {
	// Curve is the underlying curve implementation.
	Curve  elliptic.Curve
	// Rand is the randomness source (used by GenerateKey, RandomElement, GenNV).
	Rand   io.Reader
	// Params caches Curve.Params().
	Params *elliptic.CurveParams
	// Nminus holds N-1 (group order minus one), precomputed by SetCurve.
	Nminus *big.Int
	// Hash is the message hash used by GenHash.
	Hash   func([]byte) []byte
}
    30  
// Sentinel errors returned by this package; compare with errors.Is.
//
// NOTE(review): the messages use the prefix "singhdas:" rather than
// "eccutil:" — possibly copied from another package. Confirm before changing,
// since callers may match on the text.
var (
	// ErrMsgShort is returned when the message is too short
	ErrMsgShort = errors.New("singhdas: Message too short")
	// ErrBadCoordinate is returned if a coordinate or parameter has an illegal value
	ErrBadCoordinate = errors.New("singhdas: Coordinate illegal value")
	// ErrCoordinateBase is returned if a point coincides with, or is a reflection/negation of, another point
	ErrCoordinateBase = errors.New("singhdas: Coordinate illegal value (basepoint,reflection,inverse)")
	// ErrMaxLoop is returned if we cannot find parameters within MaxLoopCount tries, should only happen during testing
	ErrMaxLoop = errors.New("singhdas: Cannot find parameters")
	// ErrNotRelPrime is returned if two numbers must be relatively prime but are not
	ErrNotRelPrime = errors.New("singhdas: Not relative prime")
	// ErrParamReuse is returned if blinding parameters are used again
	ErrParamReuse = errors.New("singhdas: Do not reuse blinding parameters")
	// ErrBadBlindParam is returned if the blinding parameter is unusable
	ErrBadBlindParam = errors.New("singhdas: Blinding parameter unusable")
	// ErrHashDif is returned if the hash of the message differs from the signature
	ErrHashDif = errors.New("singhdas: Hash does not match signature")
	// ErrSigWrong is returned if the signature does not verify for the message and public key of signer
	ErrSigWrong = errors.New("singhdas: Signature does not verify")
)
    51  
// Shared small constants. These are package-level *big.Int pointers that
// functions in this package compare against (e.g. TestCoordinate,
// TestParams), so callers must treat them as read-only: passing one as the
// receiver of a mutating big.Int method would change them for everyone.
var (
	// TestZero testing shortcut variable (value 0)
	TestZero = big.NewInt(0)
	// TestOne testing shortcut variable (value 1)
	TestOne = big.NewInt(1)
	// TestTwo testing shortcut variable (value 2)
	TestTwo = big.NewInt(2)
)
    60  
// Point helpers: GetCoordinates, NewPoint, ZeroPoint
    62  
    63  // GetCoordinates returns the coordinates of a point
    64  func (p *Point) GetCoordinates() (x, y *big.Int) {
    65  	return p.X, p.Y
    66  }
    67  
    68  // NewPoint returns a *Point with x,y coordinates
    69  func NewPoint(x, y *big.Int) (p *Point) {
    70  	p = new(Point)
    71  	p.X, p.Y = x, y
    72  	return p
    73  }
    74  
    75  // ZeroPoint returns a point with (0,0)
    76  func ZeroPoint() *Point {
    77  	p := new(Point)
    78  	p.X, p.Y = big.NewInt(0), big.NewInt(0)
    79  	return p
    80  }
    81  
    82  // SetCurve returns a Curve encapsulating the curve given
    83  func SetCurve(curve func() elliptic.Curve, rand io.Reader, hash func([]byte) []byte) *Curve {
    84  	c := new(Curve)
    85  	c.Curve = curve()
    86  	c.Rand = rand
    87  	c.Params = c.Curve.Params()
    88  	c.Hash = hash
    89  	c.Nminus = new(big.Int)
    90  	c.Nminus = c.Nminus.Sub(c.Params.N, TestOne)
    91  	return c
    92  }
    93  
    94  // GenerateKey returns a new keypair
    95  func (curve Curve) GenerateKey() (priv []byte, pub *Point, err error) {
    96  	priv, x, y, err := elliptic.GenerateKey(curve.Curve, curve.Rand)
    97  	if err != nil {
    98  		return nil, nil, err
    99  	}
   100  	pub = new(Point)
   101  	pub.X, pub.Y = x, y
   102  	return priv, pub, nil
   103  }
   104  
   105  // TestCoordinate verifies that a number is not the  coordinate of the base point of the curve nor
   106  // that it is zero or one
   107  func (curve Curve) TestCoordinate(i *big.Int) (bool, error) {
   108  	if i.Cmp(TestZero) == 0 {
   109  		return false, ErrBadCoordinate
   110  	}
   111  	if i.Cmp(TestOne) == 0 {
   112  		return false, ErrBadCoordinate
   113  	}
   114  	return true, nil
   115  }
   116  
   117  // ModInverse calculates the modular inverse of a over P
   118  func (curve Curve) ModInverse(a *big.Int) (*big.Int, error) {
   119  	// since P should be a prime, any GCD calculation is really a waste if a != P
   120  	if a.Cmp(curve.Params.N) == 0 {
   121  		return nil, ErrNotRelPrime
   122  	}
   123  	if a.Cmp(big.NewInt(1)) == 0 { // this should never happen
   124  		return nil, ErrNotRelPrime
   125  	}
   126  	if a.Cmp(big.NewInt(0)) == 0 { // this should never happen
   127  		return nil, ErrNotRelPrime
   128  	}
   129  	z := new(big.Int)
   130  	z = z.GCD(nil, nil, a, curve.Params.N)
   131  	if z.Cmp(big.NewInt(1)) == 0 {
   132  		z = z.ModInverse(a, curve.Params.N)
   133  		return z, nil
   134  	}
   135  	return nil, ErrNotRelPrime
   136  }
   137  
   138  // TestInverse verifies that a number is the multiplicative inverse over P of another number
   139  func (curve Curve) TestInverse(a, b *big.Int) bool {
   140  	z := new(big.Int)
   141  	z = z.Mul(a, b)
   142  	z = z.Mod(z, curve.Params.N)
   143  	if z.Cmp(big.NewInt(1)) == 0 {
   144  		return true
   145  	}
   146  	return false
   147  }
   148  
   149  // TestPoint verifies that a point (x1,y2) does not equal another point (x1,y2) or it's reflection on 0
   150  func (curve Curve) TestPoint(x1, y1, x2, y2 *big.Int) (bool, error) {
   151  	_, err := curve.TestCoordinate(x1)
   152  	if err != nil {
   153  		return false, err
   154  	}
   155  	_, err = curve.TestCoordinate(x2)
   156  	if err != nil {
   157  		return false, err
   158  	}
   159  	if x1.Cmp(x1) == 0 && y1.Cmp(y2) == 0 { // Same
   160  		return false, ErrCoordinateBase
   161  	}
   162  	if x1.Cmp(y1) == 0 && y1.Cmp(x2) == 0 { // Reflect
   163  		return false, ErrCoordinateBase
   164  	}
   165  	x2neg := new(big.Int)
   166  	x2neg = x2neg.Neg(x2)
   167  	y2neg := new(big.Int)
   168  	y2neg = y2neg.Neg(y2)
   169  	if x1.Cmp(x2neg) == 0 {
   170  		return false, ErrCoordinateBase
   171  	}
   172  	if x1.Cmp(y2neg) == 0 && y1.Cmp(x2neg) == 0 {
   173  		return false, ErrCoordinateBase
   174  	}
   175  	return true, nil
   176  }
   177  
   178  // TestParams runs tests on parameters to make sure they do not form a dangerous combination
   179  func (curve Curve) TestParams(a ...*big.Int) (bool, error) {
   180  	// Simple tests first
   181  	for _, y := range a {
   182  		if y.Cmp(TestZero) == 0 { // 0 is unacceptable
   183  			return false, ErrBadCoordinate
   184  		}
   185  		if y.Cmp(TestOne) == 0 { // 1 is unacceptable
   186  			return false, ErrBadCoordinate
   187  		}
   188  		if y.Cmp(curve.Params.P) == 0 { // Cannot be the mod
   189  			return false, ErrBadCoordinate
   190  		}
   191  		if y.Cmp(curve.Params.N) == 0 { // Cannot be the order
   192  			return false, ErrBadCoordinate
   193  		}
   194  		if y.Cmp(curve.Params.Gx) == 0 { // cannot be the generator point
   195  			return false, ErrBadCoordinate
   196  		}
   197  	}
   198  	// Test duplicates and inverses
   199  	for i := 0; i < len(a)-1; i++ {
   200  		for j := i + 1; j < len(a); j++ {
   201  			if a[i].Cmp(a[j]) == 0 {
   202  				return false, ErrBadCoordinate
   203  			}
   204  			if curve.TestInverse(a[i], a[j]) {
   205  				return false, ErrBadCoordinate
   206  			}
   207  
   208  		}
   209  	}
   210  	return true, nil
   211  }
   212  
   213  // ExtractR extracts the x coordinate of a point and tests it for validity
   214  func (curve Curve) ExtractR(p *Point) (*big.Int, error) {
   215  	r := new(big.Int)
   216  	r = r.Mod(p.X, curve.Params.N)                   // This should be unnecessary since the scalarmult has taken care of it
   217  	if r.Cmp(TestZero) == 0 || r.Cmp(TestOne) == 0 { // Genkey should have taken care of this
   218  		return nil, ErrBadCoordinate
   219  	}
   220  	return r, nil
   221  }
   222  
   223  // BytesToInt is a helper function to convert bytes into Int the quick way
   224  func BytesToInt(b []byte) *big.Int {
   225  	z := new(big.Int)
   226  	z = z.SetBytes(b)
   227  	return z
   228  }
   229  
   230  // ManyMult multiplies all parameters
   231  func ManyMult(a ...*big.Int) *big.Int {
   232  	if len(a) == 1 {
   233  		return a[0]
   234  	}
   235  	z := big.NewInt(1)
   236  	if len(a) == 2 {
   237  		return z.Mul(a[0], a[1])
   238  	}
   239  	for _, y := range a {
   240  		z = z.Mul(z, y)
   241  	}
   242  	return z
   243  }
   244  
   245  // ManyAdd multiplies all parameters
   246  func ManyAdd(a ...*big.Int) *big.Int {
   247  	if len(a) == 1 {
   248  		return a[0]
   249  	}
   250  	z := big.NewInt(1)
   251  	if len(a) == 2 {
   252  		return z.Add(a[0], a[1])
   253  	}
   254  	for _, y := range a {
   255  		z = z.Add(z, y)
   256  	}
   257  	return z
   258  }
   259  
   260  // RandomElement returns a random element within (1,N-1)
   261  func (curve Curve) RandomElement() (*big.Int, error) {
   262  	for {
   263  		i, err := rand.Int(curve.Rand, curve.Nminus)
   264  		if err != nil {
   265  			return nil, err
   266  		}
   267  		if i.Cmp(TestOne) == 1 {
   268  			return i, nil
   269  		}
   270  	}
   271  }
   272  
   273  // AddPoints adds two points and returns the result
   274  func (curve Curve) AddPoints(a, b *Point) *Point {
   275  	r := new(Point)
   276  	r.X, r.Y = curve.Curve.Add(a.X, a.Y, b.X, b.Y)
   277  	return r
   278  }
   279  
   280  // Mod returns a % curve.N
   281  func (curve Curve) Mod(a *big.Int) *big.Int {
   282  	b := new(big.Int)
   283  	b = b.Mod(a, curve.Params.N)
   284  	return b
   285  }
   286  
   287  // ScalarMult returns the result of a scalar multiplication
   288  func (curve Curve) ScalarMult(p *Point, k []byte) *Point {
   289  	r := new(Point)
   290  	r.X, r.Y = curve.Curve.ScalarMult(p.X, p.Y, k)
   291  	return r
   292  }
   293  
   294  // ScalarBaseMult returns the result of a scalar multiplication
   295  func (curve Curve) ScalarBaseMult(k []byte) *Point {
   296  	r := new(Point)
   297  	r.X, r.Y = curve.Curve.ScalarBaseMult(k)
   298  	return r
   299  }
   300  
   301  // WithinRange tests if a number is in the field defined by curve.N
   302  func (curve Curve) WithinRange(i *big.Int) bool {
   303  	if i.Cmp(TestOne) != 1 && i.Cmp(curve.Nminus) != -1 {
   304  		return false
   305  	}
   306  	return true
   307  }
   308  
   309  // GenHash returns the hash of msg as Int
   310  func (curve Curve) GenHash(msg []byte) *big.Int {
   311  	// Make dependent on BitSize
   312  	x := new(big.Int)
   313  	x = x.SetBytes(curve.Hash(msg))
   314  	return x
   315  }
   316  
   317  // Sha1Hash is an example hash function doing sha1 over []byte and returning []byte
   318  func Sha1Hash(b []byte) []byte {
   319  	x := sha1.Sum(b)
   320  	return x[:]
   321  }
   322  
   323  // GenNV generates the signature blind/nonce. Copied from src/crypto/elliptic/elliptic.go GenerateKey
   324  func (curve Curve) GenNV() (nv []byte, err error) {
   325  	var mask = []byte{0xff, 0x1, 0x3, 0x7, 0xf, 0x1f, 0x3f, 0x7f}
   326  	var x *big.Int
   327  	var loopcount int
   328  	bitSize := curve.Curve.Params().BitSize
   329  	byteLen := (bitSize + 7) >> 3
   330  	nv = make([]byte, byteLen)
   331  	for x == nil {
   332  		if loopcount > MaxLoopCount {
   333  			return nil, ErrMaxLoop
   334  		}
   335  		loopcount++
   336  		_, err = io.ReadFull(curve.Rand, nv)
   337  		if err != nil {
   338  			return
   339  		}
   340  		// We have to mask off any excess bits in the case that the size of the
   341  		// underlying field is not a whole number of bytes.
   342  		nv[0] &= mask[bitSize%8]
   343  		// This is because, in tests, rand will return all zeros and we don't
   344  		// want to get the point at infinity and loop forever.
   345  		nv[1] ^= 0x42
   346  		x, _ = curve.Curve.ScalarBaseMult(nv)
   347  		if x.Cmp(big.NewInt(0)) == 0 { // This cannot really happen ever
   348  			x = nil
   349  		}
   350  	}
   351  	return
   352  }
   353  
   354  // GenNVint returns an int from genNV
   355  func (curve Curve) GenNVint() (nvi *big.Int, err error) {
   356  	nv, err := curve.GenNV()
   357  	if err != nil {
   358  		return nil, err
   359  	}
   360  	nvi = new(big.Int)
   361  	nvi = nvi.SetBytes(nv)
   362  	return nvi, nil
   363  }
   364  
   365  // PointEqual returns true if the points a and b are the same
   366  func PointEqual(a, b *Point) bool {
   367  	if a.X.Cmp(b.X) == 0 && a.Y.Cmp(b.Y) == 0 {
   368  		return true
   369  	}
   370  	return false
   371  }