package ecdh

// NOTE(review): everything below is disabled — it sits inside a block comment
// and is not compiled. It refers to the curve type ec (fields a, b, p, n, g,
// size, packSize), the point type ep (fields x, y; methods IsDefault, Negate),
// (*ec).checkOn and newEllipticPoint, none of which are declared in this file;
// presumably they are defined elsewhere in the package — verify before
// re-enabling. The disabled code has been reviewed and corrected so that it is
// usable as-is if it is ever switched back on.

/*

import (
	"crypto/md5"
	"crypto/rand"
	"errors"
	"math/big"
)

var (
	ErrPubKeyLenMismatch = errors.New("public key len mismatch")
	ErrInvalidPubKey     = errors.New("invalid public key")
	ErrECCheckFailed     = errors.New("ec check failed")
	ErrPointUnexist      = errors.New("points is not on the curve")
	ErrInverseUnexist    = errors.New("inverse does not exist")
)

// provider holds one key pair on a given curve.
type provider struct {
	curve  *ec
	secret *big.Int // private scalar in [1, curve.n)
	public *ep      // secret · curve.g
}

// newProvider generates a fresh random key pair on curve. On failure it
// returns a nil provider together with the error.
func newProvider(curve *ec) (*provider, error) {
	p := &provider{curve: curve}
	p.secret = p.createSecret()
	pub, err := p.createPublic(p.secret)
	if err != nil {
		return nil, err
	}
	p.public = pub
	return p, nil
}

// keyExchange derives the shared secret with the peer public key bobPub
// (0x04-prefixed uncompressed or 0x02/0x03-prefixed compressed, see
// unpackPublic). The result is the shared x coordinate left-padded to
// curve.size bytes, or, when hashed is true, the MD5 digest of its first
// curve.packSize bytes.
func (p *provider) keyExchange(bobPub []byte, hashed bool) ([]byte, error) {
	bob, err := p.unpackPublic(bobPub)
	if err != nil {
		return nil, err
	}
	shared, err := p.createShared(p.secret, bob)
	if err != nil {
		return nil, err
	}
	return p.packShared(shared, hashed), nil
}

// unpackPublic decodes a public key encoded as 0x04‖X‖Y (uncompressed) or
// 0x02/0x03‖X (compressed; 0x02 selects the even Y and 0x03 the odd Y, the
// same convention packPublic writes). X and Y are big-endian, curve.size
// bytes each.
//
// The compressed path takes the square root as right^((p+1)/4) mod p, which
// is only valid when curve.p ≡ 3 (mod 4) — TODO confirm for every curve used.
// The decoded point is not checked against the curve here; createShared does.
func (p *provider) unpackPublic(pub []byte) (*ep, error) {
	size := p.curve.size.Uint64()
	length := uint64(len(pub))
	if length != size*2+1 && length != size+1 {
		return nil, ErrPubKeyLenMismatch
	}

	px := new(big.Int).SetBytes(pub[1 : size+1])

	switch {
	case pub[0] == 0x04 && length == size*2+1:
		py := new(big.Int).SetBytes(pub[size+1 : size*2+1])
		return &ep{x: px, y: py}, nil
	case (pub[0] == 0x02 || pub[0] == 0x03) && length == size+1:
		// Compressed key: recover Y below.
	default:
		// Unknown prefix, or a prefix that does not match the key length
		// (e.g. 0x04 on a compressed-length key would otherwise slice out of
		// range).
		return nil, ErrInvalidPubKey
	}

	curveP := p.curve.p

	// right = x³ + a·x + b (mod p). The coefficient must be a, not p: p·x
	// only coincides with a·x modulo p on curves where a == 0.
	right := new(big.Int).Exp(px, big.NewInt(3), curveP)
	right.Add(right, new(big.Int).Mul(p.curve.a, px))
	right.Add(right, p.curve.b)
	right.Mod(right, curveP)

	// py = right^((p+1)/4) mod p
	exp := new(big.Int).Add(curveP, big.NewInt(1))
	exp.Rsh(exp, 2)
	py := new(big.Int).Exp(right, exp, curveP)

	// Choose the root whose parity matches the prefix. The negation goes into
	// a fresh Int so that curve.p itself is never modified.
	wantOdd := pub[0] == 0x03
	if py.Sign() != 0 && (py.Bit(0) == 1) != wantOdd {
		py = new(big.Int).Sub(curveP, py)
	}
	return &ep{x: px, y: py}, nil
}

// packPublic encodes the provider's public key: 0x02/0x03‖X when compress is
// true (0x02 when Y is even, 0x03 when odd), otherwise 0x04‖X‖Y. X and Y are
// big-endian and left-padded to curve.size bytes.
func (p *provider) packPublic(compress bool) []byte {
	size := int(p.curve.size.Uint64())
	if compress {
		out := make([]byte, 1+size)
		p.public.x.FillBytes(out[1:])
		out[0] = 0x02
		if p.public.y.Bit(0) == 1 {
			out[0] = 0x03
		}
		return out
	}
	out := make([]byte, 1+2*size)
	out[0] = 0x04
	p.public.x.FillBytes(out[1 : 1+size])
	p.public.y.FillBytes(out[1+size:])
	return out
}

// packShared returns the x coordinate of shared, left-padded to curve.size
// bytes; when hashed is true it returns the MD5 digest of its first
// curve.packSize bytes instead.
func (p *provider) packShared(shared *ep, hashed bool) []byte {
	x := make([]byte, p.curve.size.Uint64())
	shared.x.FillBytes(x)
	if hashed {
		sum := md5.Sum(x[:p.curve.packSize.Uint64()])
		return sum[:]
	}
	return x
}

// createPublic returns sec · curve.g.
func (p *provider) createPublic(sec *big.Int) (*ep, error) {
	return p.createShared(sec, p.curve.g)
}

// createSecret draws a random scalar in [1, curve.n) by rejection sampling
// over curve.size random bytes.
func (p *provider) createSecret() *big.Int {
	size := p.curve.size.Uint64()
	one := big.NewInt(1)
	result := new(big.Int)
	for result.Cmp(one) < 0 || result.Cmp(p.curve.n) >= 0 {
		buf := make([]byte, size+1)
		// NOTE(review): the error is ignored. Since Go 1.24 crypto/rand.Read
		// never returns one; on older releases a persistent failure could
		// keep this loop spinning.
		_, _ = rand.Read(buf)
		buf[size] = 0 // becomes the most significant byte after reversal
		result = new(big.Int).SetBytes(reverseBytes(buf))
	}
	return result
}

// createShared computes sec · pub by double-and-add. A scalar that is a
// multiple of curve.n, or a pub for which IsDefault reports true, yields the
// (0, 0) point (presumably the identity). A negative scalar is handled as
// (−sec) · (−pub). The loop branches on secret bits, so it is not
// constant-time.
func (p *provider) createShared(sec *big.Int, pub *ep) (*ep, error) {
	if new(big.Int).Mod(sec, p.curve.n).Sign() == 0 || pub.IsDefault() {
		return newEllipticPoint(big.NewInt(0), big.NewInt(0)), nil
	}
	if sec.Sign() < 0 {
		return p.createShared(new(big.Int).Neg(sec), pub.Negate())
	}
	if !p.curve.checkOn(pub) {
		return nil, ErrInvalidPubKey
	}

	acc := newEllipticPoint(big.NewInt(0), big.NewInt(0))
	addend := pub
	k := new(big.Int).Set(sec) // work on a copy; sec belongs to the caller
	for k.Sign() > 0 {
		var err error
		if k.Bit(0) == 1 {
			if acc, err = pointAdd(p.curve, acc, addend); err != nil {
				return nil, err
			}
		}
		if addend, err = pointAdd(p.curve, addend, addend); err != nil {
			return nil, err
		}
		k.Rsh(k, 1)
	}

	if !p.curve.checkOn(acc) {
		return nil, ErrECCheckFailed
	}
	return acc, nil
}

// pointAdd returns p1 + p2 on curve using the affine chord-and-tangent
// formulas. A point for which IsDefault reports true acts as the identity.
// Returns ErrPointUnexist if an input or the result is not on the curve.
func pointAdd(curve *ec, p1, p2 *ep) (*ep, error) {
	if p1.IsDefault() {
		return p2, nil
	}
	if p2.IsDefault() {
		return p1, nil
	}
	if !curve.checkOn(p1) || !curve.checkOn(p2) {
		return nil, ErrPointUnexist
	}

	var m *big.Int
	if p1.x.Cmp(p2.x) == 0 {
		// Equal x: either P + (−P), or doubling a point with y = 0 (vertical
		// tangent). Both give the identity; 2y would have no inverse.
		if p1.y.Cmp(p2.y) != 0 || p1.y.Sign() == 0 {
			return newEllipticPoint(big.NewInt(0), big.NewInt(0)), nil
		}
		// Tangent slope m = (3x² + a) / (2y)
		inv, err := modInverse(new(big.Int).Lsh(p1.y, 1), curve.p)
		if err != nil {
			return nil, err
		}
		m = new(big.Int).Mul(p1.x, p1.x)
		m.Mul(m, big.NewInt(3))
		m.Add(m, curve.a)
		m.Mul(m, inv)
	} else {
		// Chord slope m = (y1 − y2) / (x1 − x2)
		inv, err := modInverse(new(big.Int).Sub(p1.x, p2.x), curve.p)
		if err != nil {
			return nil, err
		}
		m = new(big.Int).Sub(p1.y, p2.y)
		m.Mul(m, inv)
	}

	// xr = (m² − x1 − x2) mod p
	xr := new(big.Int).Mul(m, m)
	xr.Sub(xr, p1.x)
	xr.Sub(xr, p2.x)
	xr = mod(xr, curve.p)

	// yr = (m·(x1 − xr) − y1) mod p
	yr := new(big.Int).Sub(p1.x, xr)
	yr.Mul(yr, m)
	yr.Sub(yr, p1.y)
	yr = mod(yr, curve.p)

	pr := newEllipticPoint(xr, yr)
	if !curve.checkOn(pr) {
		return nil, ErrPointUnexist
	}
	return pr, nil
}

// mod returns a mod b in [0, |b|). big.Int.Mod already implements the
// Euclidean modulus, so no sign correction is needed.
func mod(a, b *big.Int) *big.Int {
	return new(big.Int).Mod(a, b)
}

// modInverse returns the inverse of a modulo p in [0, p). Negative a is
// accepted, and a is never modified. Returns ErrInverseUnexist when
// gcd(a, p) != 1, which includes a ≡ 0.
func modInverse(a, p *big.Int) (*big.Int, error) {
	inv := new(big.Int).ModInverse(a, p)
	if inv == nil {
		return nil, ErrInverseUnexist
	}
	return inv, nil
}

// reverseBytes returns a reversed copy of b.
func reverseBytes(b []byte) []byte {
	reversed := make([]byte, len(b))
	for i := range b {
		reversed[i] = b[len(b)-1-i]
	}
	return reversed
}

*/