github.com/metacubex/sing-shadowsocks2@v0.2.0/shadowaead/method.go (about) 1 package shadowaead 2 3 import ( 4 "context" 5 "crypto/aes" 6 "crypto/cipher" 7 "net" 8 9 C "github.com/metacubex/sing-shadowsocks2/cipher" 10 "github.com/metacubex/sing-shadowsocks2/internal/legacykey" 11 "github.com/metacubex/sing-shadowsocks2/internal/shadowio" 12 "github.com/sagernet/sing/common" 13 "github.com/sagernet/sing/common/buf" 14 "github.com/sagernet/sing/common/bufio" 15 E "github.com/sagernet/sing/common/exceptions" 16 M "github.com/sagernet/sing/common/metadata" 17 N "github.com/sagernet/sing/common/network" 18 "github.com/sagernet/sing/common/rw" 19 20 "github.com/RyuaNerin/go-krypto/lea" 21 "github.com/Yawning/aez" 22 "github.com/ericlagergren/aegis" 23 "github.com/ericlagergren/siv" 24 "github.com/oasisprotocol/deoxysii" 25 "github.com/sina-ghaderi/rabaead" 26 "golang.org/x/crypto/chacha20poly1305" 27 ) 28 29 var MethodList = []string{ 30 "aes-128-gcm", 31 "aes-192-gcm", 32 "aes-256-gcm", 33 "chacha20-ietf-poly1305", 34 "xchacha20-ietf-poly1305", 35 // began not standard methods 36 "rabbit128-poly1305", 37 "aes-128-gcm-siv", 38 "aes-256-gcm-siv", 39 "aegis-128l", 40 "aegis-256", 41 "aez-384", 42 "deoxys-ii-256-128", 43 "lea-128-gcm", 44 "lea-192-gcm", 45 "lea-256-gcm", 46 } 47 48 func init() { 49 C.RegisterMethod(MethodList, func(ctx context.Context, methodName string, options C.MethodOptions) (C.Method, error) { 50 return NewMethod(ctx, methodName, options) 51 }) 52 } 53 54 type Method struct { 55 keySaltLength int 56 constructor func(key []byte) (cipher.AEAD, error) 57 key []byte 58 } 59 60 func NewMethod(ctx context.Context, methodName string, options C.MethodOptions) (*Method, error) { 61 m := &Method{} 62 switch methodName { 63 case "aes-128-gcm": 64 m.keySaltLength = 16 65 m.constructor = aeadCipher(aes.NewCipher, cipher.NewGCM) 66 case "aes-192-gcm": 67 m.keySaltLength = 24 68 m.constructor = aeadCipher(aes.NewCipher, cipher.NewGCM) 69 case "aes-256-gcm": 70 m.keySaltLength = 32 71 m.constructor = aeadCipher(aes.NewCipher, cipher.NewGCM) 72 case "chacha20-ietf-poly1305": 73 m.keySaltLength = 32 74 m.constructor = chacha20poly1305.New 75 case "xchacha20-ietf-poly1305": 76 m.keySaltLength = 32 77 m.constructor = chacha20poly1305.NewX 78 case "rabbit128-poly1305": 79 m.keySaltLength = 16 80 m.constructor = rabaead.NewAEAD 81 case "aes-128-gcm-siv": 82 m.keySaltLength = 16 83 m.constructor = siv.NewGCM 84 case "aes-256-gcm-siv": 85 m.keySaltLength = 32 86 m.constructor = siv.NewGCM 87 case "aegis-128l": 88 m.keySaltLength = 16 89 m.constructor = aegis.New 90 case "aegis-256": 91 m.keySaltLength = 32 92 m.constructor = aegis.New 93 case "aez-384": 94 m.keySaltLength = 3 * 16 95 m.constructor = aez.New 96 case "deoxys-ii-256-128": 97 m.keySaltLength = 32 98 m.constructor = deoxysii.New 99 case "lea-128-gcm": 100 m.keySaltLength = 16 101 m.constructor = aeadCipher(lea.NewCipher, cipher.NewGCM) 102 case "lea-192-gcm": 103 m.keySaltLength = 24 104 m.constructor = aeadCipher(lea.NewCipher, cipher.NewGCM) 105 case "lea-256-gcm": 106 m.keySaltLength = 32 107 m.constructor = aeadCipher(lea.NewCipher, cipher.NewGCM) 108 } 109 if len(options.Key) == m.keySaltLength { 110 m.key = options.Key 111 } else if len(options.Key) > 0 { 112 return nil, E.New("bad key length, required ", m.keySaltLength, ", got ", len(options.Key)) 113 } else if options.Password == "" { 114 return nil, C.ErrMissingPassword 115 } else { 116 m.key = legacykey.Key([]byte(options.Password), m.keySaltLength) 117 } 118 return m, nil 119 } 120 121 func aeadCipher(block func(key []byte) (cipher.Block, error), aead func(block cipher.Block) (cipher.AEAD, error)) func(key []byte) (cipher.AEAD, error) { 122 return func(key []byte) (cipher.AEAD, error) { 123 b, err := block(key) 124 if err != nil { 125 return nil, err 126 } 127 return aead(b) 128 } 129 } 130 131 func (m *Method) DialConn(conn net.Conn, destination M.Socksaddr) (net.Conn, error) { 132 ssConn := &clientConn{ 133 Conn: conn, 134 method: m, 135 destination: destination, 136 } 137 return ssConn, ssConn.writeRequest(nil) 138 } 139 140 func (m *Method) DialEarlyConn(conn net.Conn, destination M.Socksaddr) net.Conn { 141 return &clientConn{ 142 Conn: conn, 143 method: m, 144 destination: destination, 145 } 146 } 147 148 func (m *Method) DialPacketConn(conn net.Conn) N.NetPacketConn { 149 return &clientPacketConn{ 150 AbstractConn: conn, 151 reader: bufio.NewExtendedReader(conn), 152 writer: bufio.NewExtendedWriter(conn), 153 method: m, 154 } 155 } 156 157 var _ N.ExtendedConn = (*clientConn)(nil) 158 159 type clientConn struct { 160 net.Conn 161 method *Method 162 destination M.Socksaddr 163 reader *shadowio.Reader 164 readWaitOptions N.ReadWaitOptions 165 writer *shadowio.Writer 166 shadowio.WriterInterface 167 } 168 169 func (c *clientConn) writeRequest(payload []byte) error { 170 requestBuffer := buf.New() 171 requestBuffer.WriteRandom(c.method.keySaltLength) 172 key := make([]byte, c.method.keySaltLength) 173 legacykey.Kdf(c.method.key, requestBuffer.Bytes(), key) 174 writeCipher, err := c.method.constructor(key) 175 if err != nil { 176 return err 177 } 178 bufferedRequestWriter := bufio.NewBufferedWriter(c.Conn, requestBuffer) 179 requestContentWriter := shadowio.NewWriter(bufferedRequestWriter, writeCipher, nil, MaxPacketSize) 180 bufferedRequestContentWriter := bufio.NewBufferedWriter(requestContentWriter, buf.New()) 181 err = M.SocksaddrSerializer.WriteAddrPort(bufferedRequestContentWriter, c.destination) 182 if err != nil { 183 return err 184 } 185 _, err = bufferedRequestContentWriter.Write(payload) 186 if err != nil { 187 return err 188 } 189 err = bufferedRequestContentWriter.Fallthrough() 190 if err != nil { 191 return err 192 } 193 err = bufferedRequestWriter.Fallthrough() 194 if err != nil { 195 return err 196 } 197 c.writer = shadowio.NewWriter(c.Conn, writeCipher, requestContentWriter.TakeNonce(), MaxPacketSize) 198 return nil 199 } 200 201 func (c *clientConn) readResponse() error { 202 buffer := buf.NewSize(c.method.keySaltLength) 203 defer buffer.Release() 204 _, err := buffer.ReadFullFrom(c.Conn, c.method.keySaltLength) 205 if err != nil { 206 return err 207 } 208 legacykey.Kdf(c.method.key, buffer.Bytes(), buffer.Bytes()) 209 readCipher, err := c.method.constructor(buffer.Bytes()) 210 if err != nil { 211 return err 212 } 213 reader := shadowio.NewReader(c.Conn, readCipher) 214 reader.InitializeReadWaiter(c.readWaitOptions) 215 c.reader = reader 216 return nil 217 } 218 219 func (c *clientConn) Read(p []byte) (n int, err error) { 220 if c.reader == nil { 221 err = c.readResponse() 222 if err != nil { 223 return 224 } 225 } 226 return c.reader.Read(p) 227 } 228 229 func (c *clientConn) ReadBuffer(buffer *buf.Buffer) error { 230 if c.reader == nil { 231 err := c.readResponse() 232 if err != nil { 233 return err 234 } 235 } 236 return c.reader.ReadBuffer(buffer) 237 } 238 239 func (c *clientConn) Write(p []byte) (n int, err error) { 240 if c.writer == nil { 241 err = c.writeRequest(p) 242 if err == nil { 243 n = len(p) 244 } 245 return 246 } 247 return c.writer.Write(p) 248 } 249 250 func (c *clientConn) WriteBuffer(buffer *buf.Buffer) error { 251 if c.writer == nil { 252 defer buffer.Release() 253 return c.writeRequest(buffer.Bytes()) 254 } 255 return c.writer.WriteBuffer(buffer) 256 } 257 258 func (c *clientConn) NeedHandshake() bool { 259 return c.writer == nil 260 } 261 262 func (c *clientConn) Upstream() any { 263 return c.Conn 264 } 265 266 func (c *clientConn) WriterMTU() int { 267 return MaxPacketSize 268 } 269 270 type clientPacketConn struct { 271 N.AbstractConn 272 reader N.ExtendedReader 273 writer N.ExtendedWriter 274 method *Method 275 } 276 277 func (c *clientPacketConn) ReadPacket(buffer *buf.Buffer) (destination M.Socksaddr, err error) { 278 err = c.reader.ReadBuffer(buffer) 279 if err != nil { 280 return 281 } 282 return c.readPacket(buffer) 283 } 284 285 func (c *clientPacketConn) readPacket(buffer *buf.Buffer) (destination M.Socksaddr, err error) { 286 if buffer.Len() < c.method.keySaltLength { 287 return M.Socksaddr{}, C.ErrPacketTooShort 288 } 289 key := buf.NewSize(c.method.keySaltLength) 290 legacykey.Kdf(c.method.key, buffer.To(c.method.keySaltLength), key.Extend(c.method.keySaltLength)) 291 readCipher, err := c.method.constructor(key.Bytes()) 292 key.Release() 293 if err != nil { 294 return 295 } 296 packet, err := readCipher.Open(buffer.Index(c.method.keySaltLength), rw.ZeroBytes[:readCipher.NonceSize()], buffer.From(c.method.keySaltLength), nil) 297 if err != nil { 298 return 299 } 300 buffer.Advance(c.method.keySaltLength) 301 buffer.Truncate(len(packet)) 302 if err != nil { 303 return 304 } 305 destination, err = M.SocksaddrSerializer.ReadAddrPort(buffer) 306 if err != nil { 307 return 308 } 309 return destination.Unwrap(), nil 310 } 311 312 func (c *clientPacketConn) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error { 313 header := buf.With(buffer.ExtendHeader(c.method.keySaltLength + M.SocksaddrSerializer.AddrPortLen(destination))) 314 header.WriteRandom(c.method.keySaltLength) 315 err := M.SocksaddrSerializer.WriteAddrPort(header, destination) 316 if err != nil { 317 return err 318 } 319 key := buf.NewSize(c.method.keySaltLength) 320 legacykey.Kdf(c.method.key, header.To(c.method.keySaltLength), key.Extend(c.method.keySaltLength)) 321 writeCipher, err := c.method.constructor(key.Bytes()) 322 key.Release() 323 if err != nil { 324 return err 325 } 326 writeCipher.Seal(buffer.Index(c.method.keySaltLength), rw.ZeroBytes[:writeCipher.NonceSize()], buffer.From(c.method.keySaltLength), nil) 327 buffer.Extend(shadowio.Overhead) 328 return c.writer.WriteBuffer(buffer) 329 } 330 331 func (c *clientPacketConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) { 332 n, err = c.reader.Read(p) 333 if err != nil { 334 return 335 } 336 if n < c.method.keySaltLength { 337 err = C.ErrPacketTooShort 338 return 339 } 340 key := buf.NewSize(c.method.keySaltLength) 341 legacykey.Kdf(c.method.key, p[:c.method.keySaltLength], key.Extend(c.method.keySaltLength)) 342 readCipher, err := c.method.constructor(key.Bytes()) 343 key.Release() 344 if err != nil { 345 return 346 } 347 packet, err := readCipher.Open(p[c.method.keySaltLength:c.method.keySaltLength], rw.ZeroBytes[:readCipher.NonceSize()], p[c.method.keySaltLength:n], nil) 348 if err != nil { 349 return 350 } 351 packetContent := buf.As(packet) 352 destination, err := M.SocksaddrSerializer.ReadAddrPort(packetContent) 353 if err != nil { 354 return 355 } 356 if !destination.IsFqdn() { 357 addr = destination.UDPAddr() 358 } else { 359 addr = destination 360 } 361 n = copy(p, packetContent.Bytes()) 362 return 363 } 364 365 func (c *clientPacketConn) WriteTo(p []byte, addr net.Addr) (n int, err error) { 366 destination := M.SocksaddrFromNet(addr) 367 buffer := buf.NewSize(c.method.keySaltLength + M.SocksaddrSerializer.AddrPortLen(destination) + len(p) + shadowio.Overhead) 368 defer buffer.Release() 369 buffer.WriteRandom(c.method.keySaltLength) 370 err = M.SocksaddrSerializer.WriteAddrPort(buffer, destination) 371 if err != nil { 372 return 373 } 374 common.Must1(buffer.Write(p)) 375 key := buf.NewSize(c.method.keySaltLength) 376 legacykey.Kdf(c.method.key, buffer.To(c.method.keySaltLength), key.Extend(c.method.keySaltLength)) 377 writeCipher, err := c.method.constructor(key.Bytes()) 378 key.Release() 379 if err != nil { 380 return 381 } 382 writeCipher.Seal(buffer.Index(c.method.keySaltLength), rw.ZeroBytes[:writeCipher.NonceSize()], buffer.From(c.method.keySaltLength), nil) 383 buffer.Extend(shadowio.Overhead) 384 _, err = c.writer.Write(buffer.Bytes()) 385 if err != nil { 386 return 387 } 388 return len(p), nil 389 } 390 391 func (c *clientPacketConn) FrontHeadroom() int { 392 return c.method.keySaltLength + M.MaxSocksaddrLength 393 } 394 395 func (c *clientPacketConn) RearHeadroom() int { 396 return shadowio.Overhead 397 } 398 399 func (c *clientPacketConn) ReaderMTU() int { 400 return MaxPacketSize 401 } 402 403 func (c *clientPacketConn) WriterMTU() int { 404 return MaxPacketSize 405 } 406 407 func (c *clientPacketConn) Upstream() any { 408 return c.AbstractConn 409 }