github.com/sagernet/sing-shadowsocks2@v0.2.0/shadowstream/method.go (about) 1 package shadowstream 2 3 import ( 4 "context" 5 "crypto/aes" 6 "crypto/cipher" 7 "crypto/md5" 8 "crypto/rc4" 9 "net" 10 "os" 11 12 C "github.com/sagernet/sing-shadowsocks2/cipher" 13 "github.com/sagernet/sing-shadowsocks2/internal/legacykey" 14 "github.com/sagernet/sing/common" 15 "github.com/sagernet/sing/common/buf" 16 "github.com/sagernet/sing/common/bufio" 17 E "github.com/sagernet/sing/common/exceptions" 18 M "github.com/sagernet/sing/common/metadata" 19 N "github.com/sagernet/sing/common/network" 20 21 "golang.org/x/crypto/chacha20" 22 ) 23 24 var MethodList = []string{ 25 "aes-128-ctr", 26 "aes-192-ctr", 27 "aes-256-ctr", 28 "aes-128-cfb", 29 "aes-192-cfb", 30 "aes-256-cfb", 31 "rc4-md5", 32 "chacha20-ietf", 33 "xchacha20", 34 } 35 36 func init() { 37 C.RegisterMethod(MethodList, NewMethod) 38 } 39 40 type Method struct { 41 keyLength int 42 saltLength int 43 encryptConstructor func(key []byte, salt []byte) (cipher.Stream, error) 44 decryptConstructor func(key []byte, salt []byte) (cipher.Stream, error) 45 key []byte 46 } 47 48 func NewMethod(ctx context.Context, methodName string, options C.MethodOptions) (C.Method, error) { 49 m := &Method{} 50 switch methodName { 51 case "aes-128-ctr": 52 m.keyLength = 16 53 m.saltLength = aes.BlockSize 54 m.encryptConstructor = blockStream(aes.NewCipher, cipher.NewCTR) 55 m.decryptConstructor = blockStream(aes.NewCipher, cipher.NewCTR) 56 case "aes-192-ctr": 57 m.keyLength = 24 58 m.saltLength = aes.BlockSize 59 m.encryptConstructor = blockStream(aes.NewCipher, cipher.NewCTR) 60 m.decryptConstructor = blockStream(aes.NewCipher, cipher.NewCTR) 61 case "aes-256-ctr": 62 m.keyLength = 32 63 m.saltLength = aes.BlockSize 64 m.encryptConstructor = blockStream(aes.NewCipher, cipher.NewCTR) 65 m.decryptConstructor = blockStream(aes.NewCipher, cipher.NewCTR) 66 case "aes-128-cfb": 67 m.keyLength = 16 68 m.saltLength = aes.BlockSize 69 m.encryptConstructor = blockStream(aes.NewCipher, cipher.NewCFBEncrypter) 70 m.decryptConstructor = blockStream(aes.NewCipher, cipher.NewCFBDecrypter) 71 case "aes-192-cfb": 72 m.keyLength = 24 73 m.saltLength = aes.BlockSize 74 m.encryptConstructor = blockStream(aes.NewCipher, cipher.NewCFBEncrypter) 75 m.decryptConstructor = blockStream(aes.NewCipher, cipher.NewCFBDecrypter) 76 case "aes-256-cfb": 77 m.keyLength = 32 78 m.saltLength = aes.BlockSize 79 m.encryptConstructor = blockStream(aes.NewCipher, cipher.NewCFBEncrypter) 80 m.decryptConstructor = blockStream(aes.NewCipher, cipher.NewCFBDecrypter) 81 case "rc4-md5": 82 m.keyLength = 16 83 m.saltLength = 16 84 m.encryptConstructor = func(key []byte, salt []byte) (cipher.Stream, error) { 85 h := md5.New() 86 h.Write(key) 87 h.Write(salt) 88 return rc4.NewCipher(h.Sum(nil)) 89 } 90 m.decryptConstructor = func(key []byte, salt []byte) (cipher.Stream, error) { 91 h := md5.New() 92 h.Write(key) 93 h.Write(salt) 94 return rc4.NewCipher(h.Sum(nil)) 95 } 96 case "chacha20-ietf": 97 m.keyLength = chacha20.KeySize 98 m.saltLength = chacha20.NonceSize 99 m.encryptConstructor = func(key []byte, salt []byte) (cipher.Stream, error) { 100 return chacha20.NewUnauthenticatedCipher(key, salt) 101 } 102 m.decryptConstructor = func(key []byte, salt []byte) (cipher.Stream, error) { 103 return chacha20.NewUnauthenticatedCipher(key, salt) 104 } 105 case "xchacha20": 106 m.keyLength = chacha20.KeySize 107 m.saltLength = chacha20.NonceSizeX 108 m.encryptConstructor = func(key []byte, salt []byte) (cipher.Stream, error) { 109 return chacha20.NewUnauthenticatedCipher(key, salt) 110 } 111 m.decryptConstructor = func(key []byte, salt []byte) (cipher.Stream, error) { 112 return chacha20.NewUnauthenticatedCipher(key, salt) 113 } 114 default: 115 return nil, os.ErrInvalid 116 } 117 if len(options.Key) == m.keyLength { 118 m.key = options.Key 119 } else if len(options.Key) > 0 { 120 return nil, E.New("bad key length, required ", m.keyLength, ", got ", len(options.Key)) 121 } else if options.Password != "" { 122 m.key = legacykey.Key([]byte(options.Password), m.keyLength) 123 } else { 124 return nil, C.ErrMissingPassword 125 } 126 return m, nil 127 } 128 129 func blockStream(blockCreator func(key []byte) (cipher.Block, error), streamCreator func(block cipher.Block, iv []byte) cipher.Stream) func([]byte, []byte) (cipher.Stream, error) { 130 return func(key []byte, iv []byte) (cipher.Stream, error) { 131 block, err := blockCreator(key) 132 if err != nil { 133 return nil, err 134 } 135 return streamCreator(block, iv), err 136 } 137 } 138 139 func (m *Method) DialConn(conn net.Conn, destination M.Socksaddr) (net.Conn, error) { 140 ssConn := &clientConn{ 141 ExtendedConn: bufio.NewExtendedConn(conn), 142 method: m, 143 destination: destination, 144 } 145 return ssConn, common.Error(ssConn.Write(nil)) 146 } 147 148 func (m *Method) DialEarlyConn(conn net.Conn, destination M.Socksaddr) net.Conn { 149 return &clientConn{ 150 ExtendedConn: bufio.NewExtendedConn(conn), 151 method: m, 152 destination: destination, 153 } 154 } 155 156 func (m *Method) DialPacketConn(conn net.Conn) N.NetPacketConn { 157 return &clientPacketConn{ 158 ExtendedConn: bufio.NewExtendedConn(conn), 159 method: m, 160 } 161 } 162 163 type clientConn struct { 164 N.ExtendedConn 165 method *Method 166 destination M.Socksaddr 167 readStream cipher.Stream 168 writeStream cipher.Stream 169 } 170 171 func (c *clientConn) readResponse() error { 172 saltBuffer := buf.NewSize(c.method.saltLength) 173 defer saltBuffer.Release() 174 _, err := saltBuffer.ReadFullFrom(c.ExtendedConn, c.method.saltLength) 175 if err != nil { 176 return err 177 } 178 c.readStream, err = c.method.decryptConstructor(c.method.key, saltBuffer.Bytes()) 179 return err 180 } 181 182 func (c *clientConn) Read(p []byte) (n int, err error) { 183 if c.readStream == nil { 184 err = c.readResponse() 185 if err != nil { 186 return 187 } 188 } 189 n, err = c.ExtendedConn.Read(p) 190 if err != nil { 191 return 192 } 193 c.readStream.XORKeyStream(p[:n], p[:n]) 194 return 195 } 196 197 func (c *clientConn) Write(p []byte) (n int, err error) { 198 if c.writeStream == nil { 199 buffer := buf.NewSize(c.method.saltLength + M.SocksaddrSerializer.AddrPortLen(c.destination) + len(p)) 200 defer buffer.Release() 201 buffer.WriteRandom(c.method.saltLength) 202 err = M.SocksaddrSerializer.WriteAddrPort(buffer, c.destination) 203 if err != nil { 204 return 205 } 206 common.Must1(buffer.Write(p)) 207 c.writeStream, err = c.method.encryptConstructor(c.method.key, buffer.To(c.method.saltLength)) 208 if err != nil { 209 return 210 } 211 c.writeStream.XORKeyStream(buffer.From(c.method.saltLength), buffer.From(c.method.saltLength)) 212 _, err = c.ExtendedConn.Write(buffer.Bytes()) 213 if err == nil { 214 n = len(p) 215 } 216 return 217 } 218 c.writeStream.XORKeyStream(p, p) 219 return c.ExtendedConn.Write(p) 220 } 221 222 func (c *clientConn) ReadBuffer(buffer *buf.Buffer) error { 223 if c.readStream == nil { 224 err := c.readResponse() 225 if err != nil { 226 return err 227 } 228 } 229 err := c.ExtendedConn.ReadBuffer(buffer) 230 if err != nil { 231 return err 232 } 233 c.readStream.XORKeyStream(buffer.Bytes(), buffer.Bytes()) 234 return nil 235 } 236 237 func (c *clientConn) WriteBuffer(buffer *buf.Buffer) error { 238 if c.writeStream == nil { 239 header := buf.With(buffer.ExtendHeader(c.method.saltLength + M.SocksaddrSerializer.AddrPortLen(c.destination))) 240 header.WriteRandom(c.method.saltLength) 241 err := M.SocksaddrSerializer.WriteAddrPort(header, c.destination) 242 if err != nil { 243 return err 244 } 245 c.writeStream, err = c.method.encryptConstructor(c.method.key, header.To(c.method.saltLength)) 246 if err != nil { 247 return err 248 } 249 c.writeStream.XORKeyStream(buffer.From(c.method.saltLength), buffer.From(c.method.saltLength)) 250 } else { 251 c.writeStream.XORKeyStream(buffer.Bytes(), buffer.Bytes()) 252 } 253 return c.ExtendedConn.WriteBuffer(buffer) 254 } 255 256 func (c *clientConn) FrontHeadroom() int { 257 if c.writeStream == nil { 258 return c.method.saltLength + M.SocksaddrSerializer.AddrPortLen(c.destination) 259 } 260 return 0 261 } 262 263 func (c *clientConn) NeedHandshake() bool { 264 return c.writeStream == nil 265 } 266 267 func (c *clientConn) Upstream() any { 268 return c.ExtendedConn 269 } 270 271 type clientPacketConn struct { 272 N.ExtendedConn 273 method *Method 274 } 275 276 func (c *clientPacketConn) ReadPacket(buffer *buf.Buffer) (destination M.Socksaddr, err error) { 277 err = c.ReadBuffer(buffer) 278 if err != nil { 279 return 280 } 281 stream, err := c.method.decryptConstructor(c.method.key, buffer.To(c.method.saltLength)) 282 if err != nil { 283 return 284 } 285 stream.XORKeyStream(buffer.From(c.method.saltLength), buffer.From(c.method.saltLength)) 286 buffer.Advance(c.method.saltLength) 287 destination, err = M.SocksaddrSerializer.ReadAddrPort(buffer) 288 if err != nil { 289 return 290 } 291 return destination.Unwrap(), nil 292 } 293 294 func (c *clientPacketConn) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error { 295 header := buf.With(buffer.ExtendHeader(c.method.saltLength + M.SocksaddrSerializer.AddrPortLen(destination))) 296 header.WriteRandom(c.method.saltLength) 297 err := M.SocksaddrSerializer.WriteAddrPort(header, destination) 298 if err != nil { 299 return err 300 } 301 stream, err := c.method.encryptConstructor(c.method.key, buffer.To(c.method.saltLength)) 302 if err != nil { 303 return err 304 } 305 stream.XORKeyStream(buffer.From(c.method.saltLength), buffer.From(c.method.saltLength)) 306 return c.ExtendedConn.WriteBuffer(buffer) 307 } 308 309 func (c *clientPacketConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) { 310 n, err = c.ExtendedConn.Read(p) 311 if err != nil { 312 return 313 } 314 stream, err := c.method.decryptConstructor(c.method.key, p[:c.method.saltLength]) 315 if err != nil { 316 return 317 } 318 buffer := buf.As(p[c.method.saltLength:n]) 319 stream.XORKeyStream(buffer.Bytes(), buffer.Bytes()) 320 destination, err := M.SocksaddrSerializer.ReadAddrPort(buffer) 321 if err != nil { 322 return 323 } 324 if destination.IsFqdn() { 325 addr = destination 326 } else { 327 addr = destination.UDPAddr() 328 } 329 n = copy(p, buffer.Bytes()) 330 return 331 } 332 333 func (c *clientPacketConn) WriteTo(p []byte, addr net.Addr) (n int, err error) { 334 destination := M.SocksaddrFromNet(addr) 335 buffer := buf.NewSize(c.method.saltLength + M.SocksaddrSerializer.AddrPortLen(destination) + len(p)) 336 defer buffer.Release() 337 buffer.WriteRandom(c.method.saltLength) 338 err = M.SocksaddrSerializer.WriteAddrPort(buffer, destination) 339 if err != nil { 340 return 341 } 342 stream, err := c.method.encryptConstructor(c.method.key, buffer.To(c.method.saltLength)) 343 if err != nil { 344 return 345 } 346 stream.XORKeyStream(buffer.From(c.method.saltLength), buffer.From(c.method.saltLength)) 347 stream.XORKeyStream(buffer.Extend(len(p)), p) 348 _, err = c.ExtendedConn.Write(buffer.Bytes()) 349 if err == nil { 350 n = len(p) 351 } 352 return 353 } 354 355 func (c *clientPacketConn) FrontHeadroom() int { 356 return c.method.saltLength + M.MaxSocksaddrLength 357 } 358 359 func (c *clientPacketConn) Upstream() any { 360 return c.ExtendedConn 361 }