github.com/MerlinKodo/sing-shadowsocks2@v0.1.6/shadowaead/method.go (about) 1 package shadowaead 2 3 import ( 4 "context" 5 "crypto/aes" 6 "crypto/cipher" 7 "net" 8 9 C "github.com/MerlinKodo/sing-shadowsocks2/cipher" 10 "github.com/MerlinKodo/sing-shadowsocks2/internal/legacykey" 11 "github.com/MerlinKodo/sing-shadowsocks2/internal/shadowio" 12 "github.com/sagernet/sing/common" 13 "github.com/sagernet/sing/common/buf" 14 "github.com/sagernet/sing/common/bufio" 15 E "github.com/sagernet/sing/common/exceptions" 16 M "github.com/sagernet/sing/common/metadata" 17 N "github.com/sagernet/sing/common/network" 18 "github.com/sagernet/sing/common/rw" 19 20 "github.com/RyuaNerin/go-krypto/lea" 21 "github.com/Yawning/aez" 22 "github.com/ericlagergren/aegis" 23 "github.com/ericlagergren/siv" 24 "github.com/oasisprotocol/deoxysii" 25 "github.com/sina-ghaderi/rabaead" 26 "golang.org/x/crypto/chacha20poly1305" 27 ) 28 29 var MethodList = []string{ 30 "aes-128-gcm", 31 "aes-192-gcm", 32 "aes-256-gcm", 33 "chacha20-ietf-poly1305", 34 "xchacha20-ietf-poly1305", 35 // began not standard methods 36 "rabbit128-poly1305", 37 "aes-128-gcm-siv", 38 "aes-256-gcm-siv", 39 "aegis-128l", 40 "aegis-256", 41 "aez-384", 42 "deoxys-ii-256-128", 43 "lea-128-gcm", 44 "lea-192-gcm", 45 "lea-256-gcm", 46 } 47 48 func init() { 49 C.RegisterMethod(MethodList, func(ctx context.Context, methodName string, options C.MethodOptions) (C.Method, error) { 50 return NewMethod(ctx, methodName, options) 51 }) 52 } 53 54 type Method struct { 55 keySaltLength int 56 constructor func(key []byte) (cipher.AEAD, error) 57 key []byte 58 } 59 60 func NewMethod(ctx context.Context, methodName string, options C.MethodOptions) (*Method, error) { 61 m := &Method{} 62 switch methodName { 63 case "aes-128-gcm": 64 m.keySaltLength = 16 65 m.constructor = aeadCipher(aes.NewCipher, cipher.NewGCM) 66 case "aes-192-gcm": 67 m.keySaltLength = 24 68 m.constructor = aeadCipher(aes.NewCipher, cipher.NewGCM) 69 case "aes-256-gcm": 70 m.keySaltLength = 32 71 m.constructor = aeadCipher(aes.NewCipher, cipher.NewGCM) 72 case "chacha20-ietf-poly1305": 73 m.keySaltLength = 32 74 m.constructor = chacha20poly1305.New 75 case "xchacha20-ietf-poly1305": 76 m.keySaltLength = 32 77 m.constructor = chacha20poly1305.NewX 78 case "rabbit128-poly1305": 79 m.keySaltLength = 16 80 m.constructor = rabaead.NewAEAD 81 case "aes-128-gcm-siv": 82 m.keySaltLength = 16 83 m.constructor = siv.NewGCM 84 case "aes-256-gcm-siv": 85 m.keySaltLength = 32 86 m.constructor = siv.NewGCM 87 case "aegis-128l": 88 m.keySaltLength = 16 89 m.constructor = aegis.New 90 case "aegis-256": 91 m.keySaltLength = 32 92 m.constructor = aegis.New 93 case "aez-384": 94 m.keySaltLength = 3 * 16 95 m.constructor = aez.New 96 case "deoxys-ii-256-128": 97 m.keySaltLength = 32 98 m.constructor = deoxysii.New 99 case "lea-128-gcm": 100 m.keySaltLength = 16 101 m.constructor = aeadCipher(lea.NewCipher, cipher.NewGCM) 102 case "lea-192-gcm": 103 m.keySaltLength = 24 104 m.constructor = aeadCipher(lea.NewCipher, cipher.NewGCM) 105 case "lea-256-gcm": 106 m.keySaltLength = 32 107 m.constructor = aeadCipher(lea.NewCipher, cipher.NewGCM) 108 } 109 if len(options.Key) == m.keySaltLength { 110 m.key = options.Key 111 } else if len(options.Key) > 0 { 112 return nil, E.New("bad key length, required ", m.keySaltLength, ", got ", len(options.Key)) 113 } else if options.Password == "" { 114 return nil, C.ErrMissingPassword 115 } else { 116 m.key = legacykey.Key([]byte(options.Password), m.keySaltLength) 117 } 118 return m, nil 119 } 120 121 func aeadCipher(block func(key []byte) (cipher.Block, error), aead func(block cipher.Block) (cipher.AEAD, error)) func(key []byte) (cipher.AEAD, error) { 122 return func(key []byte) (cipher.AEAD, error) { 123 b, err := block(key) 124 if err != nil { 125 return nil, err 126 } 127 return aead(b) 128 } 129 } 130 131 func (m *Method) DialConn(conn net.Conn, destination M.Socksaddr) (net.Conn, error) { 132 ssConn := &clientConn{ 133 Conn: conn, 134 method: m, 135 destination: destination, 136 } 137 return ssConn, ssConn.writeRequest(nil) 138 } 139 140 func (m *Method) DialEarlyConn(conn net.Conn, destination M.Socksaddr) net.Conn { 141 return &clientConn{ 142 Conn: conn, 143 method: m, 144 destination: destination, 145 } 146 } 147 148 func (m *Method) DialPacketConn(conn net.Conn) N.NetPacketConn { 149 pc := &clientPacketConn{ 150 AbstractConn: conn, 151 reader: bufio.NewExtendedReader(conn), 152 writer: bufio.NewExtendedWriter(conn), 153 method: m, 154 } 155 if waitRead, isWaitRead := N.CastReader[shadowio.WaitReadReader](conn); isWaitRead { 156 return &clientWaitPacketConn{ 157 clientPacketConn: pc, 158 waitRead: waitRead, 159 } 160 } 161 return pc 162 } 163 164 type clientConn struct { 165 net.Conn 166 method *Method 167 destination M.Socksaddr 168 reader *shadowio.Reader 169 writer *shadowio.Writer 170 shadowio.WriterInterface 171 } 172 173 func (c *clientConn) writeRequest(payload []byte) error { 174 requestBuffer := buf.New() 175 requestBuffer.WriteRandom(c.method.keySaltLength) 176 key := make([]byte, c.method.keySaltLength) 177 legacykey.Kdf(c.method.key, requestBuffer.Bytes(), key) 178 writeCipher, err := c.method.constructor(key) 179 if err != nil { 180 return err 181 } 182 bufferedRequestWriter := bufio.NewBufferedWriter(c.Conn, requestBuffer) 183 requestContentWriter := shadowio.NewWriter(bufferedRequestWriter, writeCipher, nil, MaxPacketSize) 184 bufferedRequestContentWriter := bufio.NewBufferedWriter(requestContentWriter, buf.New()) 185 err = M.SocksaddrSerializer.WriteAddrPort(bufferedRequestContentWriter, c.destination) 186 if err != nil { 187 return err 188 } 189 _, err = bufferedRequestContentWriter.Write(payload) 190 if err != nil { 191 return err 192 } 193 err = bufferedRequestContentWriter.Fallthrough() 194 if err != nil { 195 return err 196 } 197 err = bufferedRequestWriter.Fallthrough() 198 if err != nil { 199 return err 200 } 201 c.writer = shadowio.NewWriter(c.Conn, writeCipher, requestContentWriter.TakeNonce(), MaxPacketSize) 202 return nil 203 } 204 205 func (c *clientConn) readResponse() error { 206 buffer := buf.NewSize(c.method.keySaltLength) 207 defer buffer.Release() 208 _, err := buffer.ReadFullFrom(c.Conn, c.method.keySaltLength) 209 if err != nil { 210 return err 211 } 212 legacykey.Kdf(c.method.key, buffer.Bytes(), buffer.Bytes()) 213 readCipher, err := c.method.constructor(buffer.Bytes()) 214 if err != nil { 215 return err 216 } 217 c.reader = shadowio.NewReader(c.Conn, readCipher) 218 return nil 219 } 220 221 func (c *clientConn) Read(p []byte) (n int, err error) { 222 if c.reader == nil { 223 err = c.readResponse() 224 if err != nil { 225 return 226 } 227 } 228 return c.reader.Read(p) 229 } 230 231 func (c *clientConn) ReadBuffer(buffer *buf.Buffer) error { 232 if c.reader == nil { 233 err := c.readResponse() 234 if err != nil { 235 return err 236 } 237 } 238 return c.reader.ReadBuffer(buffer) 239 } 240 241 func (c *clientConn) ReadBufferThreadSafe() (buffer *buf.Buffer, err error) { 242 if c.reader == nil { 243 err = c.readResponse() 244 if err != nil { 245 return 246 } 247 } 248 return c.reader.ReadBufferThreadSafe() 249 } 250 251 func (c *clientConn) Write(p []byte) (n int, err error) { 252 if c.writer == nil { 253 err = c.writeRequest(p) 254 if err == nil { 255 n = len(p) 256 } 257 return 258 } 259 return c.writer.Write(p) 260 } 261 262 func (c *clientConn) WriteBuffer(buffer *buf.Buffer) error { 263 if c.writer == nil { 264 defer buffer.Release() 265 return c.writeRequest(buffer.Bytes()) 266 } 267 return c.writer.WriteBuffer(buffer) 268 } 269 270 func (c *clientConn) NeedHandshake() bool { 271 return c.writer == nil 272 } 273 274 func (c *clientConn) Upstream() any { 275 return c.Conn 276 } 277 278 func (c *clientConn) WriterMTU() int { 279 return MaxPacketSize 280 } 281 282 type clientPacketConn struct { 283 N.AbstractConn 284 reader N.ExtendedReader 285 writer N.ExtendedWriter 286 method *Method 287 } 288 289 func (c *clientPacketConn) ReadPacket(buffer *buf.Buffer) (destination M.Socksaddr, err error) { 290 err = c.reader.ReadBuffer(buffer) 291 if err != nil { 292 return 293 } 294 if buffer.Len() < c.method.keySaltLength { 295 return M.Socksaddr{}, C.ErrPacketTooShort 296 } 297 key := buf.NewSize(c.method.keySaltLength) 298 legacykey.Kdf(c.method.key, buffer.To(c.method.keySaltLength), key.Extend(c.method.keySaltLength)) 299 readCipher, err := c.method.constructor(key.Bytes()) 300 key.Release() 301 if err != nil { 302 return 303 } 304 packet, err := readCipher.Open(buffer.Index(c.method.keySaltLength), rw.ZeroBytes[:readCipher.NonceSize()], buffer.From(c.method.keySaltLength), nil) 305 if err != nil { 306 return 307 } 308 buffer.Advance(c.method.keySaltLength) 309 buffer.Truncate(len(packet)) 310 if err != nil { 311 return 312 } 313 destination, err = M.SocksaddrSerializer.ReadAddrPort(buffer) 314 if err != nil { 315 return 316 } 317 return destination.Unwrap(), nil 318 } 319 320 func (c *clientPacketConn) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error { 321 header := buf.With(buffer.ExtendHeader(c.method.keySaltLength + M.SocksaddrSerializer.AddrPortLen(destination))) 322 header.WriteRandom(c.method.keySaltLength) 323 err := M.SocksaddrSerializer.WriteAddrPort(header, destination) 324 if err != nil { 325 return err 326 } 327 key := buf.NewSize(c.method.keySaltLength) 328 legacykey.Kdf(c.method.key, header.To(c.method.keySaltLength), key.Extend(c.method.keySaltLength)) 329 writeCipher, err := c.method.constructor(key.Bytes()) 330 key.Release() 331 if err != nil { 332 return err 333 } 334 writeCipher.Seal(buffer.Index(c.method.keySaltLength), rw.ZeroBytes[:writeCipher.NonceSize()], buffer.From(c.method.keySaltLength), nil) 335 buffer.Extend(shadowio.Overhead) 336 return c.writer.WriteBuffer(buffer) 337 } 338 339 func (c *clientPacketConn) readFrom(p []byte) (data []byte, addr net.Addr, err error) { 340 if len(p) < c.method.keySaltLength { 341 err = C.ErrPacketTooShort 342 return 343 } 344 key := buf.NewSize(c.method.keySaltLength) 345 legacykey.Kdf(c.method.key, p[:c.method.keySaltLength], key.Extend(c.method.keySaltLength)) 346 readCipher, err := c.method.constructor(key.Bytes()) 347 key.Release() 348 if err != nil { 349 return 350 } 351 packet, err := readCipher.Open(p[c.method.keySaltLength:c.method.keySaltLength], rw.ZeroBytes[:readCipher.NonceSize()], p[c.method.keySaltLength:], nil) 352 if err != nil { 353 return 354 } 355 packetContent := buf.As(packet) 356 destination, err := M.SocksaddrSerializer.ReadAddrPort(packetContent) 357 if err != nil { 358 return 359 } 360 if !destination.IsFqdn() { 361 addr = destination.UDPAddr() 362 } else { 363 addr = destination 364 } 365 data = packetContent.Bytes() 366 return 367 } 368 369 func (c *clientPacketConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) { 370 n, err = c.reader.Read(p) 371 if err != nil { 372 return 373 } 374 var data []byte 375 data, addr, err = c.readFrom(p[:n]) 376 n = copy(p, data) 377 return 378 } 379 380 func (c *clientPacketConn) WriteTo(p []byte, addr net.Addr) (n int, err error) { 381 destination := M.SocksaddrFromNet(addr) 382 buffer := buf.NewSize(c.method.keySaltLength + M.SocksaddrSerializer.AddrPortLen(destination) + len(p) + shadowio.Overhead) 383 defer buffer.Release() 384 buffer.WriteRandom(c.method.keySaltLength) 385 err = M.SocksaddrSerializer.WriteAddrPort(buffer, destination) 386 if err != nil { 387 return 388 } 389 common.Must1(buffer.Write(p)) 390 key := buf.NewSize(c.method.keySaltLength) 391 legacykey.Kdf(c.method.key, buffer.To(c.method.keySaltLength), key.Extend(c.method.keySaltLength)) 392 writeCipher, err := c.method.constructor(key.Bytes()) 393 key.Release() 394 if err != nil { 395 return 396 } 397 writeCipher.Seal(buffer.Index(c.method.keySaltLength), rw.ZeroBytes[:writeCipher.NonceSize()], buffer.From(c.method.keySaltLength), nil) 398 buffer.Extend(shadowio.Overhead) 399 _, err = c.writer.Write(buffer.Bytes()) 400 if err != nil { 401 return 402 } 403 return len(p), nil 404 } 405 406 func (c *clientPacketConn) FrontHeadroom() int { 407 return c.method.keySaltLength + M.MaxSocksaddrLength 408 } 409 410 func (c *clientPacketConn) RearHeadroom() int { 411 return shadowio.Overhead 412 } 413 414 func (c *clientPacketConn) Upstream() any { 415 return c.AbstractConn 416 } 417 418 var _ shadowio.WaitReadFrom = (*clientWaitPacketConn)(nil) 419 420 type clientWaitPacketConn struct { 421 *clientPacketConn 422 waitRead shadowio.WaitRead 423 } 424 425 func (c *clientWaitPacketConn) WaitReadFrom() (data []byte, put func(), addr net.Addr, err error) { 426 data, put, err = c.waitRead.WaitRead() 427 if err != nil { 428 return 429 } 430 if len(data) <= 0 { 431 err = C.ErrPacketTooShort 432 return 433 } 434 data, addr, err = c.readFrom(data) 435 if err != nil { 436 if put != nil { 437 put() 438 } 439 put = nil 440 data = nil 441 return 442 } 443 return 444 }