github.com/metacubex/sing-shadowsocks@v0.2.6/shadowaead/protocol.go (about) 1 package shadowaead 2 3 import ( 4 "crypto/aes" 5 "crypto/cipher" 6 "crypto/sha1" 7 "io" 8 "net" 9 10 "github.com/metacubex/sing-shadowsocks" 11 "github.com/sagernet/sing/common" 12 "github.com/sagernet/sing/common/buf" 13 M "github.com/sagernet/sing/common/metadata" 14 N "github.com/sagernet/sing/common/network" 15 "github.com/sagernet/sing/common/rw" 16 17 "github.com/RyuaNerin/go-krypto/lea" 18 "github.com/Yawning/aez" 19 "github.com/ericlagergren/aegis" 20 "github.com/ericlagergren/siv" 21 "github.com/oasisprotocol/deoxysii" 22 "github.com/sina-ghaderi/rabaead" 23 "golang.org/x/crypto/chacha20poly1305" 24 "golang.org/x/crypto/hkdf" 25 ) 26 27 var List = []string{ 28 "aes-128-gcm", 29 "aes-192-gcm", 30 "aes-256-gcm", 31 "chacha20-ietf-poly1305", 32 "xchacha20-ietf-poly1305", 33 // began not standard methods 34 "rabbit128-poly1305", 35 "aes-128-gcm-siv", 36 "aes-256-gcm-siv", 37 "aegis-128l", 38 "aegis-256", 39 "aez-384", 40 "deoxys-ii-256-128", 41 "lea-128-gcm", 42 "lea-192-gcm", 43 "lea-256-gcm", 44 } 45 46 var _ shadowsocks.Method = (*Method)(nil) 47 48 func New(method string, key []byte, password string) (*Method, error) { 49 m := &Method{ 50 name: method, 51 } 52 switch method { 53 case "aes-128-gcm": 54 m.keySaltLength = 16 55 m.constructor = aeadCipher(aes.NewCipher, cipher.NewGCM) 56 case "aes-192-gcm": 57 m.keySaltLength = 24 58 m.constructor = aeadCipher(aes.NewCipher, cipher.NewGCM) 59 case "aes-256-gcm": 60 m.keySaltLength = 32 61 m.constructor = aeadCipher(aes.NewCipher, cipher.NewGCM) 62 case "chacha20-ietf-poly1305": 63 m.keySaltLength = 32 64 m.constructor = chacha20poly1305.New 65 case "xchacha20-ietf-poly1305": 66 m.keySaltLength = 32 67 m.constructor = chacha20poly1305.NewX 68 case "rabbit128-poly1305": 69 m.keySaltLength = 16 70 m.constructor = rabaead.NewAEAD 71 case "aes-128-gcm-siv": 72 m.keySaltLength = 16 73 m.constructor = siv.NewGCM 74 case "aes-256-gcm-siv": 75 m.keySaltLength = 32 76 m.constructor = siv.NewGCM 77 case "aegis-128l": 78 m.keySaltLength = 16 79 m.constructor = aegis.New 80 case "aegis-256": 81 m.keySaltLength = 32 82 m.constructor = aegis.New 83 case "aez-384": 84 m.keySaltLength = 3 * 16 85 m.constructor = aez.New 86 case "deoxys-ii-256-128": 87 m.keySaltLength = 32 88 m.constructor = deoxysii.New 89 case "lea-128-gcm": 90 m.keySaltLength = 16 91 m.constructor = aeadCipher(lea.NewCipher, cipher.NewGCM) 92 case "lea-192-gcm": 93 m.keySaltLength = 24 94 m.constructor = aeadCipher(lea.NewCipher, cipher.NewGCM) 95 case "lea-256-gcm": 96 m.keySaltLength = 32 97 m.constructor = aeadCipher(lea.NewCipher, cipher.NewGCM) 98 } 99 if len(key) == m.keySaltLength { 100 m.key = key 101 } else if len(key) > 0 { 102 return nil, shadowsocks.ErrBadKey 103 } else if password == "" { 104 return nil, shadowsocks.ErrMissingPassword 105 } else { 106 m.key = shadowsocks.Key([]byte(password), m.keySaltLength) 107 } 108 return m, nil 109 } 110 111 func Kdf(key, iv []byte, buffer *buf.Buffer) { 112 kdf := hkdf.New(sha1.New, key, iv, []byte("ss-subkey")) 113 common.Must1(buffer.ReadFullFrom(kdf, buffer.FreeLen())) 114 } 115 116 func aeadCipher(block func(key []byte) (cipher.Block, error), aead func(block cipher.Block) (cipher.AEAD, error)) func(key []byte) (cipher.AEAD, error) { 117 return func(key []byte) (cipher.AEAD, error) { 118 b, err := block(key) 119 if err != nil { 120 return nil, err 121 } 122 return aead(b) 123 } 124 } 125 126 type Method struct { 127 name string 128 keySaltLength int 129 constructor func(key []byte) (cipher.AEAD, error) 130 key []byte 131 } 132 133 func (m *Method) Name() string { 134 return m.name 135 } 136 137 func (m *Method) DialConn(conn net.Conn, destination M.Socksaddr) (net.Conn, error) { 138 shadowsocksConn := &clientConn{ 139 Conn: conn, 140 Method: m, 141 destination: destination, 142 } 143 return shadowsocksConn, shadowsocksConn.writeRequest(nil) 144 } 145 146 func (m *Method) DialEarlyConn(conn net.Conn, destination M.Socksaddr) net.Conn { 147 return &clientConn{ 148 Conn: conn, 149 Method: m, 150 destination: destination, 151 } 152 } 153 154 func (m *Method) DialPacketConn(conn net.Conn) N.NetPacketConn { 155 return &clientPacketConn{m, conn} 156 } 157 158 type clientConn struct { 159 net.Conn 160 *Method 161 destination M.Socksaddr 162 reader *Reader 163 writer *Writer 164 } 165 166 func (c *clientConn) writeRequest(payload []byte) error { 167 salt := buf.NewSize(c.keySaltLength) 168 defer salt.Release() 169 salt.WriteRandom(c.keySaltLength) 170 171 key := buf.NewSize(c.keySaltLength) 172 173 Kdf(c.key, salt.Bytes(), key) 174 writeCipher, err := c.constructor(key.Bytes()) 175 key.Release() 176 if err != nil { 177 return err 178 } 179 writer := NewWriter(c.Conn, writeCipher, MaxPacketSize) 180 header := writer.Buffer() 181 common.Must1(header.Write(salt.Bytes())) 182 bufferedWriter := writer.BufferedWriter(header.Len()) 183 184 if len(payload) > 0 { 185 err = M.SocksaddrSerializer.WriteAddrPort(bufferedWriter, c.destination) 186 if err != nil { 187 return err 188 } 189 190 _, err = bufferedWriter.Write(payload) 191 if err != nil { 192 return err 193 } 194 } else { 195 err = M.SocksaddrSerializer.WriteAddrPort(bufferedWriter, c.destination) 196 if err != nil { 197 return err 198 } 199 } 200 201 err = bufferedWriter.Flush() 202 if err != nil { 203 return err 204 } 205 206 c.writer = writer 207 return nil 208 } 209 210 func (c *clientConn) readResponse() error { 211 salt := buf.NewSize(c.keySaltLength) 212 defer salt.Release() 213 _, err := salt.ReadFullFrom(c.Conn, c.keySaltLength) 214 if err != nil { 215 return err 216 } 217 key := buf.NewSize(c.keySaltLength) 218 defer key.Release() 219 Kdf(c.key, salt.Bytes(), key) 220 readCipher, err := c.constructor(key.Bytes()) 221 if err != nil { 222 return err 223 } 224 c.reader = NewReader( 225 c.Conn, 226 readCipher, 227 MaxPacketSize, 228 ) 229 return nil 230 } 231 232 func (c *clientConn) Read(p []byte) (n int, err error) { 233 if c.reader == nil { 234 if err = c.readResponse(); err != nil { 235 return 236 } 237 } 238 return c.reader.Read(p) 239 } 240 241 func (c *clientConn) WriteTo(w io.Writer) (n int64, err error) { 242 if c.reader == nil { 243 if err = c.readResponse(); err != nil { 244 return 245 } 246 } 247 return c.reader.WriteTo(w) 248 } 249 250 func (c *clientConn) Write(p []byte) (n int, err error) { 251 if c.writer == nil { 252 err = c.writeRequest(p) 253 if err != nil { 254 return 255 } 256 return len(p), nil 257 } 258 return c.writer.Write(p) 259 } 260 261 func (c *clientConn) NeedHandshake() bool { 262 return c.writer == nil 263 } 264 265 func (c *clientConn) NeedAdditionalReadDeadline() bool { 266 return true 267 } 268 269 func (c *clientConn) Upstream() any { 270 return c.Conn 271 } 272 273 type clientPacketConn struct { 274 *Method 275 net.Conn 276 } 277 278 func (c *clientPacketConn) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error { 279 defer buffer.Release() 280 header := buf.With(buffer.ExtendHeader(c.keySaltLength + M.SocksaddrSerializer.AddrPortLen(destination))) 281 header.WriteRandom(c.keySaltLength) 282 err := M.SocksaddrSerializer.WriteAddrPort(header, destination) 283 if err != nil { 284 return err 285 } 286 key := buf.NewSize(c.keySaltLength) 287 Kdf(c.key, buffer.To(c.keySaltLength), key) 288 writeCipher, err := c.constructor(key.Bytes()) 289 key.Release() 290 if err != nil { 291 return err 292 } 293 writeCipher.Seal(buffer.Index(c.keySaltLength), rw.ZeroBytes[:writeCipher.NonceSize()], buffer.From(c.keySaltLength), nil) 294 buffer.Extend(Overhead) 295 return common.Error(c.Write(buffer.Bytes())) 296 } 297 298 func (c *clientPacketConn) ReadPacket(buffer *buf.Buffer) (M.Socksaddr, error) { 299 n, err := c.Read(buffer.FreeBytes()) 300 if err != nil { 301 return M.Socksaddr{}, err 302 } 303 buffer.Truncate(n) 304 if buffer.Len() < c.keySaltLength { 305 return M.Socksaddr{}, io.ErrShortBuffer 306 } 307 key := buf.NewSize(c.keySaltLength) 308 Kdf(c.key, buffer.To(c.keySaltLength), key) 309 readCipher, err := c.constructor(key.Bytes()) 310 key.Release() 311 if err != nil { 312 return M.Socksaddr{}, err 313 } 314 packet, err := readCipher.Open(buffer.Index(c.keySaltLength), rw.ZeroBytes[:readCipher.NonceSize()], buffer.From(c.keySaltLength), nil) 315 if err != nil { 316 return M.Socksaddr{}, err 317 } 318 buffer.Advance(c.keySaltLength) 319 buffer.Truncate(len(packet)) 320 if err != nil { 321 return M.Socksaddr{}, err 322 } 323 return M.SocksaddrSerializer.ReadAddrPort(buffer) 324 } 325 326 func (c *clientPacketConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) { 327 buffer := buf.With(p) 328 destination, err := c.ReadPacket(buffer) 329 if err != nil { 330 return 331 } 332 if destination.IsFqdn() { 333 addr = destination 334 } else { 335 addr = destination.UDPAddr() 336 } 337 n = copy(p, buffer.Bytes()) 338 return 339 } 340 341 func (c *clientPacketConn) WriteTo(p []byte, addr net.Addr) (n int, err error) { 342 destination := M.SocksaddrFromNet(addr) 343 buffer := buf.NewSize(c.keySaltLength + M.SocksaddrSerializer.AddrPortLen(destination) + len(p) + Overhead) 344 defer buffer.Release() 345 buffer.Resize(c.keySaltLength+M.SocksaddrSerializer.AddrPortLen(destination), 0) 346 common.Must1(buffer.Write(p)) 347 err = c.WritePacket(buffer, destination) 348 if err != nil { 349 return 350 } 351 return len(p), nil 352 } 353 354 func (c *clientPacketConn) FrontHeadroom() int { 355 return c.keySaltLength + M.MaxSocksaddrLength 356 } 357 358 func (c *clientPacketConn) RearHeadroom() int { 359 return Overhead 360 } 361 362 func (c *clientPacketConn) ReaderMTU() int { 363 return MaxPacketSize 364 } 365 366 func (c *clientPacketConn) WriterMTU() int { 367 return MaxPacketSize 368 } 369 370 func (c *clientPacketConn) Upstream() any { 371 return c.Conn 372 }