github.com/sagernet/sing@v0.4.0-beta.19.0.20240518125136-f67a0988a636/protocol/socks/packet.go (about) 1 package socks 2 3 import ( 4 "bytes" 5 "net" 6 7 "github.com/sagernet/sing/common" 8 "github.com/sagernet/sing/common/buf" 9 "github.com/sagernet/sing/common/bufio" 10 E "github.com/sagernet/sing/common/exceptions" 11 M "github.com/sagernet/sing/common/metadata" 12 N "github.com/sagernet/sing/common/network" 13 ) 14 15 // +----+------+------+----------+----------+----------+ 16 // |RSV | FRAG | ATYP | DST.ADDR | DST.PORT | DATA | 17 // +----+------+------+----------+----------+----------+ 18 // | 2 | 1 | 1 | Variable | 2 | Variable | 19 // +----+------+------+----------+----------+----------+ 20 21 var ErrInvalidPacket = E.New("socks5: invalid packet") 22 23 type AssociatePacketConn struct { 24 N.AbstractConn 25 conn N.ExtendedConn 26 remoteAddr M.Socksaddr 27 underlying net.Conn 28 } 29 30 func NewAssociatePacketConn(conn net.Conn, remoteAddr M.Socksaddr, underlying net.Conn) *AssociatePacketConn { 31 return &AssociatePacketConn{ 32 AbstractConn: conn, 33 conn: bufio.NewExtendedConn(conn), 34 remoteAddr: remoteAddr, 35 underlying: underlying, 36 } 37 } 38 39 func (c *AssociatePacketConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) { 40 n, err = c.conn.Read(p) 41 if err != nil { 42 return 43 } 44 if n < 3 { 45 return 0, nil, ErrInvalidPacket 46 } 47 reader := bytes.NewReader(p[3:n]) 48 destination, err := M.SocksaddrSerializer.ReadAddrPort(reader) 49 if err != nil { 50 return 51 } 52 c.remoteAddr = destination 53 addr = destination.UDPAddr() 54 index := 3 + int(reader.Size()) - reader.Len() 55 n = copy(p, p[index:n]) 56 return 57 } 58 59 func (c *AssociatePacketConn) WriteTo(p []byte, addr net.Addr) (n int, err error) { 60 destination := M.SocksaddrFromNet(addr) 61 buffer := buf.NewSize(3 + M.SocksaddrSerializer.AddrPortLen(destination) + len(p)) 62 defer buffer.Release() 63 common.Must(buffer.WriteZeroN(3)) 64 err = M.SocksaddrSerializer.WriteAddrPort(buffer, destination) 65 if err != nil { 66 return 67 } 68 _, err = buffer.Write(p) 69 if err != nil { 70 return 71 } 72 return c.conn.Write(buffer.Bytes()) 73 } 74 75 func (c *AssociatePacketConn) ReadPacket(buffer *buf.Buffer) (destination M.Socksaddr, err error) { 76 err = c.conn.ReadBuffer(buffer) 77 if err != nil { 78 return 79 } 80 if buffer.Len() < 3 { 81 return M.Socksaddr{}, ErrInvalidPacket 82 } 83 buffer.Advance(3) 84 destination, err = M.SocksaddrSerializer.ReadAddrPort(buffer) 85 if err != nil { 86 return 87 } 88 c.remoteAddr = destination 89 return destination.Unwrap(), nil 90 } 91 92 func (c *AssociatePacketConn) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error { 93 header := buf.With(buffer.ExtendHeader(3 + M.SocksaddrSerializer.AddrPortLen(destination))) 94 common.Must(header.WriteZeroN(3)) 95 err := M.SocksaddrSerializer.WriteAddrPort(header, destination) 96 if err != nil { 97 return err 98 } 99 return c.conn.WriteBuffer(buffer) 100 } 101 102 func (c *AssociatePacketConn) Read(b []byte) (n int, err error) { 103 n, _, err = c.ReadFrom(b) 104 return 105 } 106 107 func (c *AssociatePacketConn) Write(b []byte) (n int, err error) { 108 return c.WriteTo(b, c.remoteAddr) 109 } 110 111 func (c *AssociatePacketConn) RemoteAddr() net.Addr { 112 return c.remoteAddr.UDPAddr() 113 } 114 115 func (c *AssociatePacketConn) Upstream() any { 116 return c.conn 117 } 118 119 func (c *AssociatePacketConn) FrontHeadroom() int { 120 return 3 + M.MaxSocksaddrLength 121 } 122 123 func (c *AssociatePacketConn) Close() error { 124 return common.Close( 125 c.conn, 126 c.underlying, 127 ) 128 }