github.com/sagernet/sing@v0.2.6/protocol/socks/packet.go (about)

     1  package socks
     2  
import (
	"bytes"
	"io"
	"net"

	"github.com/sagernet/sing/common"
	"github.com/sagernet/sing/common/buf"
	"github.com/sagernet/sing/common/bufio"
	M "github.com/sagernet/sing/common/metadata"
	N "github.com/sagernet/sing/common/network"
)
    13  
// SOCKS5 UDP request header (RFC 1928, section 7) prepended to every datagram:
//
// +----+------+------+----------+----------+----------+
// |RSV | FRAG | ATYP | DST.ADDR | DST.PORT |   DATA   |
// +----+------+------+----------+----------+----------+
// | 2  |  1   |  1   | Variable |    2     | Variable |
// +----+------+------+----------+----------+----------+
    19  
// AssociatePacketConn frames datagrams in the SOCKS5 UDP format: every packet
// read or written carries an RSV/FRAG/ATYP/DST.ADDR/DST.PORT header in front of
// its payload.
type AssociatePacketConn struct {
	// NetPacketConn is the transport that framed datagrams are read from and
	// written to.
	N.NetPacketConn
	// remoteAddr is where framed datagrams are sent. ReadFrom and ReadPacket
	// overwrite it with the source address of each received datagram.
	remoteAddr M.Socksaddr
	// underlying is closed together with NetPacketConn in Close.
	// NOTE(review): presumably the SOCKS TCP control connection whose lifetime
	// bounds the UDP association — confirm against callers.
	underlying net.Conn
}
    25  
    26  func NewAssociatePacketConn(conn net.PacketConn, remoteAddr M.Socksaddr, underlying net.Conn) *AssociatePacketConn {
    27  	return &AssociatePacketConn{
    28  		NetPacketConn: bufio.NewPacketConn(conn),
    29  		remoteAddr:    remoteAddr,
    30  		underlying:    underlying,
    31  	}
    32  }
    33  
    34  func NewAssociateConn(conn net.Conn, remoteAddr M.Socksaddr, underlying net.Conn) *AssociatePacketConn {
    35  	return &AssociatePacketConn{
    36  		NetPacketConn: bufio.NewUnbindPacketConn(conn),
    37  		remoteAddr:    remoteAddr,
    38  		underlying:    underlying,
    39  	}
    40  }
    41  
    42  func (c *AssociatePacketConn) RemoteAddr() net.Addr {
    43  	return c.remoteAddr.UDPAddr()
    44  }
    45  
    46  //warn:unsafe
    47  func (c *AssociatePacketConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) {
    48  	n, addr, err = c.NetPacketConn.ReadFrom(p)
    49  	if err != nil {
    50  		return
    51  	}
    52  	c.remoteAddr = M.SocksaddrFromNet(addr)
    53  	reader := bytes.NewReader(p[3:n])
    54  	destination, err := M.SocksaddrSerializer.ReadAddrPort(reader)
    55  	if err != nil {
    56  		return
    57  	}
    58  	addr = destination.UDPAddr()
    59  	index := 3 + int(reader.Size()) - reader.Len()
    60  	n = copy(p, p[index:n])
    61  	return
    62  }
    63  
    64  //warn:unsafe
    65  func (c *AssociatePacketConn) WriteTo(p []byte, addr net.Addr) (n int, err error) {
    66  	destination := M.SocksaddrFromNet(addr)
    67  	_buffer := buf.StackNewSize(3 + M.SocksaddrSerializer.AddrPortLen(destination) + len(p))
    68  	defer common.KeepAlive(_buffer)
    69  	buffer := common.Dup(_buffer)
    70  	defer buffer.Release()
    71  	common.Must(buffer.WriteZeroN(3))
    72  	err = M.SocksaddrSerializer.WriteAddrPort(buffer, destination)
    73  	if err != nil {
    74  		return
    75  	}
    76  	_, err = buffer.Write(p)
    77  	if err != nil {
    78  		return
    79  	}
    80  	return bufio.WritePacketBuffer(c.NetPacketConn, buffer, c.remoteAddr)
    81  }
    82  
    83  func (c *AssociatePacketConn) Read(b []byte) (n int, err error) {
    84  	n, _, err = c.ReadFrom(b)
    85  	return
    86  }
    87  
    88  func (c *AssociatePacketConn) Write(b []byte) (n int, err error) {
    89  	return c.WriteTo(b, c.remoteAddr)
    90  }
    91  
    92  func (c *AssociatePacketConn) ReadPacket(buffer *buf.Buffer) (destination M.Socksaddr, err error) {
    93  	destination, err = c.NetPacketConn.ReadPacket(buffer)
    94  	if err != nil {
    95  		return M.Socksaddr{}, err
    96  	}
    97  	c.remoteAddr = destination
    98  	buffer.Advance(3)
    99  	destination, err = M.SocksaddrSerializer.ReadAddrPort(buffer)
   100  	if err != nil {
   101  		return
   102  	}
   103  	return destination.Unwrap(), nil
   104  }
   105  
   106  func (c *AssociatePacketConn) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error {
   107  	header := buf.With(buffer.ExtendHeader(3 + M.SocksaddrSerializer.AddrPortLen(destination)))
   108  	common.Must(header.WriteZeroN(3))
   109  	common.Must(M.SocksaddrSerializer.WriteAddrPort(header, destination))
   110  	return common.Error(bufio.WritePacketBuffer(c.NetPacketConn, buffer, c.remoteAddr))
   111  }
   112  
   113  func (c *AssociatePacketConn) Upstream() any {
   114  	return c.NetPacketConn
   115  }
   116  
   117  func (c *AssociatePacketConn) FrontHeadroom() int {
   118  	return 3 + M.MaxSocksaddrLength
   119  }
   120  
   121  func (c *AssociatePacketConn) Close() error {
   122  	return common.Close(
   123  		c.NetPacketConn,
   124  		c.underlying,
   125  	)
   126  }