github.com/sagernet/sing@v0.4.0-beta.19.0.20240518125136-f67a0988a636/protocol/socks/packet.go (about)

     1  package socks
     2  
     3  import (
     4  	"bytes"
     5  	"net"
     6  
     7  	"github.com/sagernet/sing/common"
     8  	"github.com/sagernet/sing/common/buf"
     9  	"github.com/sagernet/sing/common/bufio"
    10  	E "github.com/sagernet/sing/common/exceptions"
    11  	M "github.com/sagernet/sing/common/metadata"
    12  	N "github.com/sagernet/sing/common/network"
    13  )
    14  
    15  // +----+------+------+----------+----------+----------+
    16  // |RSV | FRAG | ATYP | DST.ADDR | DST.PORT |   DATA   |
    17  // +----+------+------+----------+----------+----------+
    18  // | 2  |  1   |  1   | Variable |    2     | Variable |
    19  // +----+------+------+----------+----------+----------+
    20  
    21  var ErrInvalidPacket = E.New("socks5: invalid packet")
    22  
    23  type AssociatePacketConn struct {
    24  	N.AbstractConn
    25  	conn       N.ExtendedConn
    26  	remoteAddr M.Socksaddr
    27  	underlying net.Conn
    28  }
    29  
    30  func NewAssociatePacketConn(conn net.Conn, remoteAddr M.Socksaddr, underlying net.Conn) *AssociatePacketConn {
    31  	return &AssociatePacketConn{
    32  		AbstractConn: conn,
    33  		conn:         bufio.NewExtendedConn(conn),
    34  		remoteAddr:   remoteAddr,
    35  		underlying:   underlying,
    36  	}
    37  }
    38  
    39  func (c *AssociatePacketConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) {
    40  	n, err = c.conn.Read(p)
    41  	if err != nil {
    42  		return
    43  	}
    44  	if n < 3 {
    45  		return 0, nil, ErrInvalidPacket
    46  	}
    47  	reader := bytes.NewReader(p[3:n])
    48  	destination, err := M.SocksaddrSerializer.ReadAddrPort(reader)
    49  	if err != nil {
    50  		return
    51  	}
    52  	c.remoteAddr = destination
    53  	addr = destination.UDPAddr()
    54  	index := 3 + int(reader.Size()) - reader.Len()
    55  	n = copy(p, p[index:n])
    56  	return
    57  }
    58  
    59  func (c *AssociatePacketConn) WriteTo(p []byte, addr net.Addr) (n int, err error) {
    60  	destination := M.SocksaddrFromNet(addr)
    61  	buffer := buf.NewSize(3 + M.SocksaddrSerializer.AddrPortLen(destination) + len(p))
    62  	defer buffer.Release()
    63  	common.Must(buffer.WriteZeroN(3))
    64  	err = M.SocksaddrSerializer.WriteAddrPort(buffer, destination)
    65  	if err != nil {
    66  		return
    67  	}
    68  	_, err = buffer.Write(p)
    69  	if err != nil {
    70  		return
    71  	}
    72  	return c.conn.Write(buffer.Bytes())
    73  }
    74  
    75  func (c *AssociatePacketConn) ReadPacket(buffer *buf.Buffer) (destination M.Socksaddr, err error) {
    76  	err = c.conn.ReadBuffer(buffer)
    77  	if err != nil {
    78  		return
    79  	}
    80  	if buffer.Len() < 3 {
    81  		return M.Socksaddr{}, ErrInvalidPacket
    82  	}
    83  	buffer.Advance(3)
    84  	destination, err = M.SocksaddrSerializer.ReadAddrPort(buffer)
    85  	if err != nil {
    86  		return
    87  	}
    88  	c.remoteAddr = destination
    89  	return destination.Unwrap(), nil
    90  }
    91  
    92  func (c *AssociatePacketConn) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error {
    93  	header := buf.With(buffer.ExtendHeader(3 + M.SocksaddrSerializer.AddrPortLen(destination)))
    94  	common.Must(header.WriteZeroN(3))
    95  	err := M.SocksaddrSerializer.WriteAddrPort(header, destination)
    96  	if err != nil {
    97  		return err
    98  	}
    99  	return c.conn.WriteBuffer(buffer)
   100  }
   101  
   102  func (c *AssociatePacketConn) Read(b []byte) (n int, err error) {
   103  	n, _, err = c.ReadFrom(b)
   104  	return
   105  }
   106  
   107  func (c *AssociatePacketConn) Write(b []byte) (n int, err error) {
   108  	return c.WriteTo(b, c.remoteAddr)
   109  }
   110  
   111  func (c *AssociatePacketConn) RemoteAddr() net.Addr {
   112  	return c.remoteAddr.UDPAddr()
   113  }
   114  
   115  func (c *AssociatePacketConn) Upstream() any {
   116  	return c.conn
   117  }
   118  
   119  func (c *AssociatePacketConn) FrontHeadroom() int {
   120  	return 3 + M.MaxSocksaddrLength
   121  }
   122  
   123  func (c *AssociatePacketConn) Close() error {
   124  	return common.Close(
   125  		c.conn,
   126  		c.underlying,
   127  	)
   128  }