github.com/sagernet/sing@v0.2.6/common/bufio/nat.go (about)

     1  package bufio
     2  
     3  import (
     4  	"net"
     5  	"net/netip"
     6  
     7  	"github.com/sagernet/sing/common/buf"
     8  	M "github.com/sagernet/sing/common/metadata"
     9  	N "github.com/sagernet/sing/common/network"
    10  )
    11  
    12  type NATPacketConn struct {
    13  	N.NetPacketConn
    14  	origin      M.Socksaddr
    15  	destination M.Socksaddr
    16  }
    17  
    18  func NewNATPacketConn(conn N.NetPacketConn, origin M.Socksaddr, destination M.Socksaddr) *NATPacketConn {
    19  	return &NATPacketConn{
    20  		NetPacketConn: conn,
    21  		origin:        origin,
    22  		destination:   destination,
    23  	}
    24  }
    25  
    26  func (c *NATPacketConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) {
    27  	n, addr, err = c.NetPacketConn.ReadFrom(p)
    28  	if err == nil && M.SocksaddrFromNet(addr) == c.origin {
    29  		addr = c.destination.UDPAddr()
    30  	}
    31  	return
    32  }
    33  
    34  func (c *NATPacketConn) WriteTo(p []byte, addr net.Addr) (n int, err error) {
    35  	if M.SocksaddrFromNet(addr) == c.destination {
    36  		addr = c.origin.UDPAddr()
    37  	}
    38  	return c.NetPacketConn.WriteTo(p, addr)
    39  }
    40  
    41  func (c *NATPacketConn) ReadPacket(buffer *buf.Buffer) (destination M.Socksaddr, err error) {
    42  	destination, err = c.NetPacketConn.ReadPacket(buffer)
    43  	if destination == c.origin {
    44  		destination = c.destination
    45  	}
    46  	return
    47  }
    48  
    49  func (c *NATPacketConn) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error {
    50  	if destination == c.destination {
    51  		destination = c.origin
    52  	}
    53  	return c.NetPacketConn.WritePacket(buffer, destination)
    54  }
    55  
    56  func (c *NATPacketConn) UpdateDestination(destinationAddress netip.Addr) {
    57  	c.destination = M.SocksaddrFrom(destinationAddress, c.destination.Port)
    58  }
    59  
    60  func (c *NATPacketConn) Upstream() any {
    61  	return c.NetPacketConn
    62  }