github.com/sagernet/sing@v0.4.0-beta.19.0.20240518125136-f67a0988a636/common/bufio/nat.go (about)

     1  package bufio
     2  
     3  import (
     4  	"net"
     5  	"net/netip"
     6  
     7  	"github.com/sagernet/sing/common/buf"
     8  	M "github.com/sagernet/sing/common/metadata"
     9  	N "github.com/sagernet/sing/common/network"
    10  )
    11  
// NATPacketConn is a packet connection that rewrites addresses between an
// "origin" host and a "destination" host as packets pass through it.
type NATPacketConn interface {
	N.NetPacketConn
	// UpdateDestination replaces the destination host used for address
	// rewriting with destinationAddress.
	UpdateDestination(destinationAddress netip.Addr)
}
    16  
    17  func NewUnidirectionalNATPacketConn(conn N.NetPacketConn, origin M.Socksaddr, destination M.Socksaddr) NATPacketConn {
    18  	return &unidirectionalNATPacketConn{
    19  		NetPacketConn: conn,
    20  		origin:        socksaddrWithoutPort(origin),
    21  		destination:   socksaddrWithoutPort(destination),
    22  	}
    23  }
    24  
    25  func NewNATPacketConn(conn N.NetPacketConn, origin M.Socksaddr, destination M.Socksaddr) NATPacketConn {
    26  	return &bidirectionalNATPacketConn{
    27  		NetPacketConn: conn,
    28  		origin:        socksaddrWithoutPort(origin),
    29  		destination:   socksaddrWithoutPort(destination),
    30  	}
    31  }
    32  
    33  type unidirectionalNATPacketConn struct {
    34  	N.NetPacketConn
    35  	origin      M.Socksaddr
    36  	destination M.Socksaddr
    37  }
    38  
    39  func (c *unidirectionalNATPacketConn) WriteTo(p []byte, addr net.Addr) (n int, err error) {
    40  	destination := M.SocksaddrFromNet(addr)
    41  	if socksaddrWithoutPort(destination) == c.destination {
    42  		destination = M.Socksaddr{
    43  			Addr: c.origin.Addr,
    44  			Fqdn: c.origin.Fqdn,
    45  			Port: destination.Port,
    46  		}
    47  	}
    48  	return c.NetPacketConn.WriteTo(p, destination.UDPAddr())
    49  }
    50  
    51  func (c *unidirectionalNATPacketConn) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error {
    52  	if socksaddrWithoutPort(destination) == c.destination {
    53  		destination = M.Socksaddr{
    54  			Addr: c.origin.Addr,
    55  			Fqdn: c.origin.Fqdn,
    56  			Port: destination.Port,
    57  		}
    58  	}
    59  	return c.NetPacketConn.WritePacket(buffer, destination)
    60  }
    61  
    62  func (c *unidirectionalNATPacketConn) UpdateDestination(destinationAddress netip.Addr) {
    63  	c.destination = M.SocksaddrFrom(destinationAddress, c.destination.Port)
    64  }
    65  
    66  func (c *unidirectionalNATPacketConn) RemoteAddr() net.Addr {
    67  	return c.destination.UDPAddr()
    68  }
    69  
    70  func (c *unidirectionalNATPacketConn) Upstream() any {
    71  	return c.NetPacketConn
    72  }
    73  
    74  type bidirectionalNATPacketConn struct {
    75  	N.NetPacketConn
    76  	origin      M.Socksaddr
    77  	destination M.Socksaddr
    78  }
    79  
    80  func (c *bidirectionalNATPacketConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) {
    81  	n, addr, err = c.NetPacketConn.ReadFrom(p)
    82  	if err != nil {
    83  		return
    84  	}
    85  	destination := M.SocksaddrFromNet(addr)
    86  	if socksaddrWithoutPort(destination) == c.origin {
    87  		destination = M.Socksaddr{
    88  			Addr: c.destination.Addr,
    89  			Fqdn: c.destination.Fqdn,
    90  			Port: destination.Port,
    91  		}
    92  	}
    93  	addr = destination.UDPAddr()
    94  	return
    95  }
    96  
    97  func (c *bidirectionalNATPacketConn) WriteTo(p []byte, addr net.Addr) (n int, err error) {
    98  	destination := M.SocksaddrFromNet(addr)
    99  	if socksaddrWithoutPort(destination) == c.destination {
   100  		destination = M.Socksaddr{
   101  			Addr: c.origin.Addr,
   102  			Fqdn: c.origin.Fqdn,
   103  			Port: destination.Port,
   104  		}
   105  	}
   106  	return c.NetPacketConn.WriteTo(p, destination.UDPAddr())
   107  }
   108  
   109  func (c *bidirectionalNATPacketConn) ReadPacket(buffer *buf.Buffer) (destination M.Socksaddr, err error) {
   110  	destination, err = c.NetPacketConn.ReadPacket(buffer)
   111  	if err != nil {
   112  		return
   113  	}
   114  	if socksaddrWithoutPort(destination) == c.origin {
   115  		destination = M.Socksaddr{
   116  			Addr: c.destination.Addr,
   117  			Fqdn: c.destination.Fqdn,
   118  			Port: destination.Port,
   119  		}
   120  	}
   121  	return
   122  }
   123  
   124  func (c *bidirectionalNATPacketConn) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error {
   125  	if socksaddrWithoutPort(destination) == c.destination {
   126  		destination = M.Socksaddr{
   127  			Addr: c.origin.Addr,
   128  			Fqdn: c.origin.Fqdn,
   129  			Port: destination.Port,
   130  		}
   131  	}
   132  	return c.NetPacketConn.WritePacket(buffer, destination)
   133  }
   134  
   135  func (c *bidirectionalNATPacketConn) UpdateDestination(destinationAddress netip.Addr) {
   136  	c.destination = M.SocksaddrFrom(destinationAddress, c.destination.Port)
   137  }
   138  
   139  func (c *bidirectionalNATPacketConn) Upstream() any {
   140  	return c.NetPacketConn
   141  }
   142  
   143  func (c *bidirectionalNATPacketConn) RemoteAddr() net.Addr {
   144  	return c.destination.UDPAddr()
   145  }
   146  
   147  func socksaddrWithoutPort(destination M.Socksaddr) M.Socksaddr {
   148  	destination.Port = 0
   149  	return destination
   150  }