github.com/MerlinKodo/sing-shadowsocks2@v0.1.6/shadowstream/method.go (about) 1 package shadowstream 2 3 import ( 4 "context" 5 "crypto/aes" 6 "crypto/cipher" 7 "crypto/md5" 8 "crypto/rc4" 9 "net" 10 "os" 11 12 C "github.com/MerlinKodo/sing-shadowsocks2/cipher" 13 "github.com/MerlinKodo/sing-shadowsocks2/internal/legacykey" 14 "github.com/MerlinKodo/sing-shadowsocks2/internal/shadowio" 15 "github.com/sagernet/sing/common" 16 "github.com/sagernet/sing/common/buf" 17 "github.com/sagernet/sing/common/bufio" 18 E "github.com/sagernet/sing/common/exceptions" 19 M "github.com/sagernet/sing/common/metadata" 20 N "github.com/sagernet/sing/common/network" 21 22 "github.com/aead/chacha20/chacha" 23 "golang.org/x/crypto/chacha20" 24 ) 25 26 var MethodList = []string{ 27 "aes-128-ctr", 28 "aes-192-ctr", 29 "aes-256-ctr", 30 "aes-128-cfb", 31 "aes-192-cfb", 32 "aes-256-cfb", 33 "rc4-md5", 34 "chacha20-ietf", 35 "xchacha20", 36 "chacha20", 37 } 38 39 func init() { 40 C.RegisterMethod(MethodList, NewMethod) 41 } 42 43 type Method struct { 44 keyLength int 45 saltLength int 46 encryptConstructor func(key []byte, salt []byte) (cipher.Stream, error) 47 decryptConstructor func(key []byte, salt []byte) (cipher.Stream, error) 48 key []byte 49 } 50 51 func NewMethod(ctx context.Context, methodName string, options C.MethodOptions) (C.Method, error) { 52 m := &Method{} 53 switch methodName { 54 case "aes-128-ctr": 55 m.keyLength = 16 56 m.saltLength = aes.BlockSize 57 m.encryptConstructor = blockStream(aes.NewCipher, cipher.NewCTR) 58 m.decryptConstructor = blockStream(aes.NewCipher, cipher.NewCTR) 59 case "aes-192-ctr": 60 m.keyLength = 24 61 m.saltLength = aes.BlockSize 62 m.encryptConstructor = blockStream(aes.NewCipher, cipher.NewCTR) 63 m.decryptConstructor = blockStream(aes.NewCipher, cipher.NewCTR) 64 case "aes-256-ctr": 65 m.keyLength = 32 66 m.saltLength = aes.BlockSize 67 m.encryptConstructor = blockStream(aes.NewCipher, cipher.NewCTR) 68 m.decryptConstructor = blockStream(aes.NewCipher, cipher.NewCTR) 69 case "aes-128-cfb": 70 m.keyLength = 16 71 m.saltLength = aes.BlockSize 72 m.encryptConstructor = blockStream(aes.NewCipher, cipher.NewCFBEncrypter) 73 m.decryptConstructor = blockStream(aes.NewCipher, cipher.NewCFBDecrypter) 74 case "aes-192-cfb": 75 m.keyLength = 24 76 m.saltLength = aes.BlockSize 77 m.encryptConstructor = blockStream(aes.NewCipher, cipher.NewCFBEncrypter) 78 m.decryptConstructor = blockStream(aes.NewCipher, cipher.NewCFBDecrypter) 79 case "aes-256-cfb": 80 m.keyLength = 32 81 m.saltLength = aes.BlockSize 82 m.encryptConstructor = blockStream(aes.NewCipher, cipher.NewCFBEncrypter) 83 m.decryptConstructor = blockStream(aes.NewCipher, cipher.NewCFBDecrypter) 84 case "rc4-md5": 85 m.keyLength = 16 86 m.saltLength = 16 87 m.encryptConstructor = func(key []byte, salt []byte) (cipher.Stream, error) { 88 h := md5.New() 89 h.Write(key) 90 h.Write(salt) 91 return rc4.NewCipher(h.Sum(nil)) 92 } 93 m.decryptConstructor = func(key []byte, salt []byte) (cipher.Stream, error) { 94 h := md5.New() 95 h.Write(key) 96 h.Write(salt) 97 return rc4.NewCipher(h.Sum(nil)) 98 } 99 case "chacha20-ietf": 100 m.keyLength = chacha20.KeySize 101 m.saltLength = chacha20.NonceSize 102 m.encryptConstructor = func(key []byte, salt []byte) (cipher.Stream, error) { 103 return chacha20.NewUnauthenticatedCipher(key, salt) 104 } 105 m.decryptConstructor = func(key []byte, salt []byte) (cipher.Stream, error) { 106 return chacha20.NewUnauthenticatedCipher(key, salt) 107 } 108 case "xchacha20": 109 m.keyLength = chacha20.KeySize 110 m.saltLength = chacha20.NonceSizeX 111 m.encryptConstructor = func(key []byte, salt []byte) (cipher.Stream, error) { 112 return chacha20.NewUnauthenticatedCipher(key, salt) 113 } 114 m.decryptConstructor = func(key []byte, salt []byte) (cipher.Stream, error) { 115 return chacha20.NewUnauthenticatedCipher(key, salt) 116 } 117 case "chacha20": 118 m.keyLength = chacha.KeySize 119 m.saltLength = chacha.NonceSize 120 m.encryptConstructor = func(key []byte, salt []byte) (cipher.Stream, error) { 121 return chacha.NewCipher(salt, key, 20) 122 } 123 m.decryptConstructor = func(key []byte, salt []byte) (cipher.Stream, error) { 124 return chacha.NewCipher(salt, key, 20) 125 } 126 default: 127 return nil, os.ErrInvalid 128 } 129 if len(options.Key) == m.keyLength { 130 m.key = options.Key 131 } else if len(options.Key) > 0 { 132 return nil, E.New("bad key length, required ", m.keyLength, ", got ", len(options.Key)) 133 } else if options.Password != "" { 134 m.key = legacykey.Key([]byte(options.Password), m.keyLength) 135 } else { 136 return nil, C.ErrMissingPassword 137 } 138 return m, nil 139 } 140 141 func blockStream(blockCreator func(key []byte) (cipher.Block, error), streamCreator func(block cipher.Block, iv []byte) cipher.Stream) func([]byte, []byte) (cipher.Stream, error) { 142 return func(key []byte, iv []byte) (cipher.Stream, error) { 143 block, err := blockCreator(key) 144 if err != nil { 145 return nil, err 146 } 147 return streamCreator(block, iv), err 148 } 149 } 150 151 func (m *Method) DialConn(conn net.Conn, destination M.Socksaddr) (net.Conn, error) { 152 ssConn := &clientConn{ 153 ExtendedConn: bufio.NewExtendedConn(conn), 154 method: m, 155 destination: destination, 156 } 157 return ssConn, common.Error(ssConn.Write(nil)) 158 } 159 160 func (m *Method) DialEarlyConn(conn net.Conn, destination M.Socksaddr) net.Conn { 161 return &clientConn{ 162 ExtendedConn: bufio.NewExtendedConn(conn), 163 method: m, 164 destination: destination, 165 } 166 } 167 168 func (m *Method) DialPacketConn(conn net.Conn) N.NetPacketConn { 169 pc := &clientPacketConn{ 170 ExtendedConn: bufio.NewExtendedConn(conn), 171 method: m, 172 } 173 if waitRead, isWaitRead := N.CastReader[shadowio.WaitReadReader](conn); isWaitRead { 174 return &clientWaitPacketConn{ 175 clientPacketConn: pc, 176 waitRead: waitRead, 177 } 178 } 179 return pc 180 } 181 182 type clientConn struct { 183 N.ExtendedConn 184 method *Method 185 destination M.Socksaddr 186 readStream cipher.Stream 187 writeStream cipher.Stream 188 } 189 190 func (c *clientConn) readResponse() error { 191 saltBuffer := buf.NewSize(c.method.saltLength) 192 defer saltBuffer.Release() 193 _, err := saltBuffer.ReadFullFrom(c.ExtendedConn, c.method.saltLength) 194 if err != nil { 195 return err 196 } 197 c.readStream, err = c.method.decryptConstructor(c.method.key, saltBuffer.Bytes()) 198 return err 199 } 200 201 func (c *clientConn) Read(p []byte) (n int, err error) { 202 if c.readStream == nil { 203 err = c.readResponse() 204 if err != nil { 205 return 206 } 207 } 208 n, err = c.ExtendedConn.Read(p) 209 if err != nil { 210 return 211 } 212 c.readStream.XORKeyStream(p[:n], p[:n]) 213 return 214 } 215 216 func (c *clientConn) Write(p []byte) (n int, err error) { 217 if c.writeStream == nil { 218 buffer := buf.NewSize(c.method.saltLength + M.SocksaddrSerializer.AddrPortLen(c.destination) + len(p)) 219 defer buffer.Release() 220 buffer.WriteRandom(c.method.saltLength) 221 err = M.SocksaddrSerializer.WriteAddrPort(buffer, c.destination) 222 if err != nil { 223 return 224 } 225 common.Must1(buffer.Write(p)) 226 c.writeStream, err = c.method.encryptConstructor(c.method.key, buffer.To(c.method.saltLength)) 227 if err != nil { 228 return 229 } 230 c.writeStream.XORKeyStream(buffer.From(c.method.saltLength), buffer.From(c.method.saltLength)) 231 _, err = c.ExtendedConn.Write(buffer.Bytes()) 232 if err == nil { 233 n = len(p) 234 } 235 return 236 } 237 c.writeStream.XORKeyStream(p, p) 238 return c.ExtendedConn.Write(p) 239 } 240 241 func (c *clientConn) ReadBuffer(buffer *buf.Buffer) error { 242 if c.readStream == nil { 243 err := c.readResponse() 244 if err != nil { 245 return err 246 } 247 } 248 err := c.ExtendedConn.ReadBuffer(buffer) 249 if err != nil { 250 return err 251 } 252 c.readStream.XORKeyStream(buffer.Bytes(), buffer.Bytes()) 253 return nil 254 } 255 256 func (c *clientConn) WriteBuffer(buffer *buf.Buffer) error { 257 if c.writeStream == nil { 258 header := buf.With(buffer.ExtendHeader(c.method.saltLength + M.SocksaddrSerializer.AddrPortLen(c.destination))) 259 header.WriteRandom(c.method.saltLength) 260 err := M.SocksaddrSerializer.WriteAddrPort(header, c.destination) 261 if err != nil { 262 return err 263 } 264 c.writeStream, err = c.method.encryptConstructor(c.method.key, header.To(c.method.saltLength)) 265 if err != nil { 266 return err 267 } 268 c.writeStream.XORKeyStream(buffer.From(c.method.saltLength), buffer.From(c.method.saltLength)) 269 } else { 270 c.writeStream.XORKeyStream(buffer.Bytes(), buffer.Bytes()) 271 } 272 return c.ExtendedConn.WriteBuffer(buffer) 273 } 274 275 func (c *clientConn) FrontHeadroom() int { 276 if c.writeStream == nil { 277 return c.method.saltLength + M.SocksaddrSerializer.AddrPortLen(c.destination) 278 } 279 return 0 280 } 281 282 func (c *clientConn) NeedHandshake() bool { 283 return c.writeStream == nil 284 } 285 286 func (c *clientConn) Upstream() any { 287 return c.ExtendedConn 288 } 289 290 type clientPacketConn struct { 291 N.ExtendedConn 292 method *Method 293 } 294 295 func (c *clientPacketConn) ReadPacket(buffer *buf.Buffer) (destination M.Socksaddr, err error) { 296 err = c.ReadBuffer(buffer) 297 if err != nil { 298 return 299 } 300 stream, err := c.method.decryptConstructor(c.method.key, buffer.To(c.method.saltLength)) 301 if err != nil { 302 return 303 } 304 stream.XORKeyStream(buffer.From(c.method.saltLength), buffer.From(c.method.saltLength)) 305 buffer.Advance(c.method.saltLength) 306 destination, err = M.SocksaddrSerializer.ReadAddrPort(buffer) 307 if err != nil { 308 return 309 } 310 return destination.Unwrap(), nil 311 } 312 313 func (c *clientPacketConn) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error { 314 header := buf.With(buffer.ExtendHeader(c.method.saltLength + M.SocksaddrSerializer.AddrPortLen(destination))) 315 header.WriteRandom(c.method.saltLength) 316 err := M.SocksaddrSerializer.WriteAddrPort(header, destination) 317 if err != nil { 318 return err 319 } 320 stream, err := c.method.encryptConstructor(c.method.key, buffer.To(c.method.saltLength)) 321 if err != nil { 322 return err 323 } 324 stream.XORKeyStream(buffer.From(c.method.saltLength), buffer.From(c.method.saltLength)) 325 return c.ExtendedConn.WriteBuffer(buffer) 326 } 327 328 func (c *clientPacketConn) readFrom(p []byte) (data []byte, addr net.Addr, err error) { 329 if len(p) < c.method.saltLength { 330 err = C.ErrPacketTooShort 331 return 332 } 333 stream, err := c.method.decryptConstructor(c.method.key, p[:c.method.saltLength]) 334 if err != nil { 335 return 336 } 337 buffer := buf.As(p[c.method.saltLength:]) 338 stream.XORKeyStream(buffer.Bytes(), buffer.Bytes()) 339 destination, err := M.SocksaddrSerializer.ReadAddrPort(buffer) 340 if err != nil { 341 return 342 } 343 if destination.IsFqdn() { 344 addr = destination 345 } else { 346 addr = destination.UDPAddr() 347 } 348 data = buffer.Bytes() 349 return 350 } 351 352 func (c *clientPacketConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) { 353 n, err = c.ExtendedConn.Read(p) 354 if err != nil { 355 return 356 } 357 var data []byte 358 data, addr, err = c.readFrom(p[:n]) 359 n = copy(p, data) 360 return 361 } 362 363 func (c *clientPacketConn) WriteTo(p []byte, addr net.Addr) (n int, err error) { 364 destination := M.SocksaddrFromNet(addr) 365 buffer := buf.NewSize(c.method.saltLength + M.SocksaddrSerializer.AddrPortLen(destination) + len(p)) 366 defer buffer.Release() 367 buffer.WriteRandom(c.method.saltLength) 368 err = M.SocksaddrSerializer.WriteAddrPort(buffer, destination) 369 if err != nil { 370 return 371 } 372 stream, err := c.method.encryptConstructor(c.method.key, buffer.To(c.method.saltLength)) 373 if err != nil { 374 return 375 } 376 stream.XORKeyStream(buffer.From(c.method.saltLength), buffer.From(c.method.saltLength)) 377 stream.XORKeyStream(buffer.Extend(len(p)), p) 378 _, err = c.ExtendedConn.Write(buffer.Bytes()) 379 if err == nil { 380 n = len(p) 381 } 382 return 383 } 384 385 func (c *clientPacketConn) FrontHeadroom() int { 386 return c.method.saltLength + M.MaxSocksaddrLength 387 } 388 389 func (c *clientPacketConn) Upstream() any { 390 return c.ExtendedConn 391 } 392 393 var _ shadowio.WaitReadFrom = (*clientWaitPacketConn)(nil) 394 395 type clientWaitPacketConn struct { 396 *clientPacketConn 397 waitRead shadowio.WaitRead 398 } 399 400 func (c *clientWaitPacketConn) WaitReadFrom() (data []byte, put func(), addr net.Addr, err error) { 401 data, put, err = c.waitRead.WaitRead() 402 if err != nil { 403 return 404 } 405 if len(data) <= 0 { 406 err = C.ErrPacketTooShort 407 return 408 } 409 data, addr, err = c.readFrom(data) 410 if err != nil { 411 if put != nil { 412 put() 413 } 414 put = nil 415 data = nil 416 return 417 } 418 return 419 }