github.com/LagrangeDev/LagrangeGo@v0.0.0-20240512064304-ad4a85e10cb4/utils/crypto/ecdh/ecdh.go (about)

     1  package ecdh
     2  
     3  /*
     4  
     5  import (
     6  	"crypto/md5"
     7  	"crypto/rand"
     8  	"errors"
     9  	"math/big"
    10  )
    11  
    12  var (
    13  	ErrPubKeyLenMismatch = errors.New("public key len mismatch")
    14  	ErrInvalidPubKey     = errors.New("invalid public key")
    15  	ErrECCheckFailed     = errors.New("ec check failed")
    16  	ErrPointUnexist      = errors.New("points is not on the curve")
    17  	ErrInverseUnexist    = errors.New("inverse does not exist")
    18  )
    19  
// provider holds one side of an ECDH exchange: the curve it operates on, its
// private scalar, and the public point derived from that scalar.
type provider struct {
	curve  *ec      // curve parameters used by every point operation
	secret *big.Int // private scalar; newProvider fills it via createSecret
	public *ep      // public point; newProvider derives it via createPublic(secret)
}
    25  
    26  func newProvider(curve *ec) (p *provider, err error) {
    27  	p = &provider{
    28  		curve:  curve,
    29  		secret: big.NewInt(0),
    30  		public: &ep{},
    31  	}
    32  
    33  	p.secret = p.createSecret()
    34  	p.public, err = p.createPublic(p.secret)
    35  
    36  	return
    37  }
    38  
    39  func (p *provider) keyExchange(bobPub []byte, hashed bool) ([]byte, error) {
    40  	unpacked, err := p.unpackPublic(bobPub)
    41  	if err != nil {
    42  		return nil, err
    43  	}
    44  	shared, err := p.createShared(p.secret, unpacked)
    45  	if err != nil {
    46  		return nil, err
    47  	}
    48  	return p.packShared(shared, hashed), nil
    49  }
    50  
    51  func (p *provider) unpackPublic(pub []byte) (*ep, error) {
    52  	length := uint64(len(pub))
    53  	// if length != p.curve.size*2+1 && length != p.curve.size+1
    54  	if length != p.curve.size.Uint64()*2+1 && length != p.curve.size.Uint64()+1 {
    55  		return nil, ErrPubKeyLenMismatch
    56  	}
    57  
    58  	x := append(make([]byte, 1), pub[1:p.curve.size.Uint64()+1]...)
    59  
    60  	if pub[0] == 0x04 {
    61  		y := append(make([]byte, 1), pub[p.curve.size.Uint64()+1:p.curve.size.Uint64()*2+1]...)
    62  		gx := new(big.Int).SetBytes(x)
    63  		gy := new(big.Int).SetBytes(y)
    64  		return &ep{
    65  			x: gx,
    66  			y: gy,
    67  		}, nil
    68  	}
    69  
    70  	px := new(big.Int).SetBytes(x)
    71  	// x3 := (px * px * px) % p.curve.p
    72  	x3 := new(big.Int).Mod(new(big.Int).Exp(px, big.NewInt(3), nil), p.curve.p)
    73  	// ax := px * p.curve.p
    74  	ax := new(big.Int).Mul(px, p.curve.p)
    75  	// right := (x3 + ax + p.curve.b) % p.curve.p
    76  	right := new(big.Int).Mod(new(big.Int).Add(x3, new(big.Int).Add(ax, p.curve.p)), p.curve.p)
    77  
    78  	// tmp := (p.curve.p + 1) >> 2
    79  	tmp := new(big.Int).Rsh(new(big.Int).Add(p.curve.p, big.NewInt(1)), 2)
    80  	// py := pow(right, tmp, p.curve.p)
    81  	py := new(big.Int).Exp(right, tmp, p.curve.p)
    82  
    83  	// if py%2 == 0
    84  	if new(big.Int).Mod(py, big.NewInt(2)).Cmp(big.NewInt(0)) == 0 {
    85  		tmp = p.curve.p
    86  		// tmp -= py
    87  		tmp.Sub(tmp, py)
    88  		py = tmp
    89  	}
    90  
    91  	return &ep{
    92  		x: px,
    93  		y: py,
    94  	}, nil
    95  }
    96  
    97  func (p *provider) packPublic(compress bool) (result []byte) {
    98  	if compress {
    99  		result = append(make([]byte, int(p.curve.size.Uint64())-len(p.public.x.Bytes())), p.public.x.Bytes()...)
   100  		result = append(make([]byte, 1), result...)
   101  		// result[0] = 0x02 if (((self._public.y % 2) == 0) ^ ((self._public.y > 0) < 0)) else 0x03
   102  		// 乱七八糟的,实际上就是 (self._public.y % 2) == 0
   103  		if new(big.Int).Mod(p.public.y, big.NewInt(2)).Cmp(big.NewInt(0)) == 0 {
   104  			result[0] = 0x02
   105  		} else {
   106  			result[0] = 0x03
   107  		}
   108  		return result
   109  	}
   110  	x := append(make([]byte, int(p.curve.size.Uint64())-len(p.public.x.Bytes())), p.public.x.Bytes()...)
   111  	y := append(make([]byte, int(p.curve.size.Uint64())-len(p.public.y.Bytes())), p.public.y.Bytes()...)
   112  
   113  	result = append(append(make([]byte, 1), x...), y...)
   114  	result[0] = 0x04
   115  	return result
   116  }
   117  
   118  func (p *provider) packShared(shared *ep, hashed bool) (x []byte) {
   119  	x = append(make([]byte, int(p.curve.size.Uint64())-len(shared.x.Bytes())), shared.x.Bytes()...)
   120  	if hashed {
   121  		hash := md5.Sum(x[0:p.curve.packSize.Uint64()])
   122  		x = hash[:]
   123  	}
   124  	return x
   125  }
   126  
   127  func (p *provider) createPublic(sec *big.Int) (*ep, error) {
   128  	return p.createShared(sec, p.curve.g)
   129  }
   130  
   131  func (p *provider) createSecret() *big.Int {
   132  	result := big.NewInt(0)
   133  	for result.Cmp(big.NewInt(1)) == -1 || result.Cmp(p.curve.n) != -1 {
   134  		buffer := make([]byte, p.curve.size.Uint64()+1)
   135  		_, _ = rand.Read(buffer)
   136  		buffer[p.curve.size.Uint64()] = 0
   137  		result = new(big.Int).SetBytes(reverseBytes(buffer))
   138  	}
   139  	return result
   140  }
   141  
   142  // TODO 上次看到这里
   143  func (p *provider) createShared(sec *big.Int, pub *ep) (*ep, error) {
   144  	// if sec % p.curve.n == 0 || pub.IsDefault():
   145  	if new(big.Int).Mod(sec, p.curve.n).Cmp(big.NewInt(0)) == 0 || pub.IsDefault() {
   146  		return newEllipticPoint(big.NewInt(0), big.NewInt(0)), nil
   147  	}
   148  	// if sec < 0:
   149  	if sec.Cmp(big.NewInt(0)) == -1 {
   150  		return p.createShared(new(big.Int).Neg(sec), pub.Negate())
   151  	}
   152  
   153  	if !p.curve.checkOn(pub) {
   154  		return nil, ErrInvalidPubKey
   155  	}
   156  
   157  	pr := newEllipticPoint(big.NewInt(0), big.NewInt(0))
   158  	pa := pub
   159  	var err error
   160  	for sec.Cmp(big.NewInt(0)) == 1 {
   161  		// if (sec & 1) > 0
   162  		if new(big.Int).And(sec, big.NewInt(1)).Cmp(big.NewInt(0)) == 1 {
   163  			pr, err = pointAdd(p.curve, pr, pa)
   164  			if err != nil {
   165  				return nil, err
   166  			}
   167  		}
   168  		pa, err = pointAdd(p.curve, pa, pa)
   169  		if err != nil {
   170  			return nil, err
   171  		}
   172  		// sec >>= 1
   173  		sec = new(big.Int).Rsh(sec, 1)
   174  	}
   175  
   176  	if !p.curve.checkOn(pr) {
   177  		return nil, ErrECCheckFailed
   178  	}
   179  
   180  	return pr, nil
   181  }
   182  
   183  func pointAdd(curve *ec, p1, p2 *ep) (*ep, error) {
   184  	if p1.IsDefault() {
   185  		return p2, nil
   186  	}
   187  	if p2.IsDefault() {
   188  		return p1, nil
   189  	}
   190  	if !(curve.checkOn(p1) && curve.checkOn(p2)) {
   191  		return nil, ErrPointUnexist
   192  	}
   193  
   194  	var m *big.Int
   195  	if p1.x.Cmp(p2.x) == 0 {
   196  		if p1.y.Cmp(p2.y) == 0 {
   197  			inv, err := modInverse(new(big.Int).Lsh(p1.y, 1), curve.p)
   198  			if err != nil {
   199  				return nil, err
   200  			}
   201  			m = new(big.Int).Mul(new(big.Int).Add(new(big.Int).Mul(
   202  				big.NewInt(3), new(big.Int).Exp(p1.x, big.NewInt(2), nil)), curve.a),
   203  				inv,
   204  			)
   205  		} else {
   206  			return newEllipticPoint(big.NewInt(0), big.NewInt(0)), nil
   207  		}
   208  	} else {
   209  		inv, err := modInverse(new(big.Int).Sub(p1.x, p2.x), curve.p)
   210  		if err != nil {
   211  			return nil, err
   212  		}
   213  		m = new(big.Int).Mul(new(big.Int).Sub(p1.y, p2.y), inv)
   214  	}
   215  
   216  	// xr = _mod(m * m - p1.x - p2.x, curve.P)
   217  	xr := mod(new(big.Int).Sub(new(big.Int).Exp(m, big.NewInt(2), nil), new(big.Int).Add(p1.x, p2.x)), curve.p)
   218  	// yr = _mod(m * (p1.x - xr) - p1.y, curve.P)
   219  	yr := mod(new(big.Int).Sub(new(big.Int).Mul(m, new(big.Int).Sub(p1.x, xr)), p1.y), curve.p)
   220  	pr := newEllipticPoint(xr, yr)
   221  
   222  	if !curve.checkOn(pr) {
   223  		return nil, ErrPointUnexist
   224  	}
   225  
   226  	return pr, nil
   227  }
   228  
   229  func mod(a, b *big.Int) (result *big.Int) {
   230  	result = new(big.Int).Mod(a, b)
   231  	if result.Cmp(big.NewInt(0)) == -1 {
   232  		result.Add(result, b)
   233  	}
   234  	return result
   235  }
   236  
   237  func modInverse(a, p *big.Int) (*big.Int, error) {
   238  	if a.Cmp(big.NewInt(0)) == -1 {
   239  		inv, err := modInverse(a.Neg(a), p)
   240  		if err != nil {
   241  			return nil, err
   242  		}
   243  		return new(big.Int).Sub(p, inv), nil
   244  	}
   245  
   246  	g := new(big.Int).GCD(nil, nil, a, p)
   247  	if g.Cmp(big.NewInt(1)) != 0 {
   248  		return nil, ErrInverseUnexist
   249  	}
   250  
   251  	return new(big.Int).Exp(a, new(big.Int).Sub(p, big.NewInt(2)), p), nil
   252  }
   253  
   254  func reverseBytes(bytes []byte) []byte {
   255  	reversed := make([]byte, len(bytes))
   256  	for i := range bytes {
   257  		reversed[i] = bytes[len(bytes)-i-1]
   258  	}
   259  	return reversed
   260  }
   261  
   262  */