github.com/MerlinKodo/sing-tun@v0.1.15/stack_gvisor_udp.go (about)

     1  //go:build with_gvisor
     2  
     3  package tun
     4  
     5  import (
     6  	"context"
     7  	"errors"
     8  	"math"
     9  	"net/netip"
    10  	"os"
    11  	"sync"
    12  	"syscall"
    13  
    14  	"github.com/sagernet/sing/common/buf"
    15  	E "github.com/sagernet/sing/common/exceptions"
    16  	M "github.com/sagernet/sing/common/metadata"
    17  	N "github.com/sagernet/sing/common/network"
    18  	"github.com/sagernet/sing/common/udpnat"
    19  
    20  	"github.com/MerlinKodo/gvisor/pkg/buffer"
    21  	"github.com/MerlinKodo/gvisor/pkg/tcpip"
    22  	"github.com/MerlinKodo/gvisor/pkg/tcpip/adapters/gonet"
    23  	"github.com/MerlinKodo/gvisor/pkg/tcpip/checksum"
    24  	"github.com/MerlinKodo/gvisor/pkg/tcpip/header"
    25  	"github.com/MerlinKodo/gvisor/pkg/tcpip/stack"
    26  )
    27  
// UDPForwarder receives UDP packets from a gVisor stack through HandlePacket
// and feeds their payloads into a sing udpnat.Service keyed by the packet's
// source address. The service is created with the Handler passed to
// NewUDPForwarder.
type UDPForwarder struct {
	ctx    context.Context
	stack  *stack.Stack
	udpNat *udpnat.Service[netip.AddrPort]

	// cache
	// cacheProto and cacheID are read by newUDPConn when it builds the
	// UDPBackWriter for a new NAT session. They live on the forwarder itself,
	// so every packet shares them.
	// NOTE(review): writing these from concurrent handler calls would be a
	// data race; per-packet values are safer kept in locals.
	cacheProto tcpip.NetworkProtocolNumber
	cacheID    stack.TransportEndpointID
}
    37  
    38  func NewUDPForwarder(ctx context.Context, stack *stack.Stack, handler Handler, udpTimeout int64) *UDPForwarder {
    39  	return &UDPForwarder{
    40  		ctx:    ctx,
    41  		stack:  stack,
    42  		udpNat: udpnat.New[netip.AddrPort](udpTimeout, handler),
    43  	}
    44  }
    45  
    46  func (f *UDPForwarder) HandlePacket(id stack.TransportEndpointID, pkt stack.PacketBufferPtr) bool {
    47  	var upstreamMetadata M.Metadata
    48  	upstreamMetadata.Source = M.SocksaddrFrom(AddrFromAddress(id.RemoteAddress), id.RemotePort)
    49  	upstreamMetadata.Destination = M.SocksaddrFrom(AddrFromAddress(id.LocalAddress), id.LocalPort)
    50  	if upstreamMetadata.Source.IsIPv4() {
    51  		f.cacheProto = header.IPv4ProtocolNumber
    52  	} else {
    53  		f.cacheProto = header.IPv6ProtocolNumber
    54  	}
    55  	gBuffer := pkt.Data().ToBuffer()
    56  	sBuffer := buf.NewSize(int(gBuffer.Size()))
    57  	gBuffer.Apply(func(view *buffer.View) {
    58  		sBuffer.Write(view.AsSlice())
    59  	})
    60  	f.cacheID = id
    61  	f.udpNat.NewPacket(
    62  		f.ctx,
    63  		upstreamMetadata.Source.AddrPort(),
    64  		sBuffer,
    65  		upstreamMetadata,
    66  		f.newUDPConn,
    67  	)
    68  	return true
    69  }
    70  
    71  func (f *UDPForwarder) newUDPConn(natConn N.PacketConn) N.PacketWriter {
    72  	return &UDPBackWriter{
    73  		stack:         f.stack,
    74  		source:        f.cacheID.RemoteAddress,
    75  		sourcePort:    f.cacheID.RemotePort,
    76  		sourceNetwork: f.cacheProto,
    77  	}
    78  }
    79  
// UDPBackWriter injects reply datagrams into a gVisor stack, addressed to the
// client identified by source and sourcePort (see WritePacket).
type UDPBackWriter struct {
	// NOTE(review): access is not referenced in this file; possibly unused.
	access sync.Mutex
	stack  *stack.Stack
	// source and sourcePort are the client endpoint that replies are sent to.
	source     tcpip.Address
	sourcePort uint16
	// sourceNetwork is the client's network protocol (IPv4 or IPv6). It
	// selects the route and decides how mixed-family destinations are handled.
	sourceNetwork tcpip.NetworkProtocolNumber
	// NOTE(review): packet is not referenced in this file; possibly unused.
	packet stack.PacketBufferPtr
}
    88  
    89  func (w *UDPBackWriter) WritePacket(packetBuffer *buf.Buffer, destination M.Socksaddr) error {
    90  	if !destination.IsIP() {
    91  		return E.Cause(os.ErrInvalid, "invalid destination")
    92  	} else if destination.IsIPv4() && w.sourceNetwork == header.IPv6ProtocolNumber {
    93  		destination = M.SocksaddrFrom(netip.AddrFrom16(destination.Addr.As16()), destination.Port)
    94  	} else if destination.IsIPv6() && (w.sourceNetwork == header.IPv4ProtocolNumber) {
    95  		return E.New("send IPv6 packet to IPv4 connection")
    96  	}
    97  
    98  	defer packetBuffer.Release()
    99  
   100  	route, err := w.stack.FindRoute(
   101  		defaultNIC,
   102  		AddressFromAddr(destination.Addr),
   103  		w.source,
   104  		w.sourceNetwork,
   105  		false,
   106  	)
   107  	if err != nil {
   108  		return wrapStackError(err)
   109  	}
   110  	defer route.Release()
   111  
   112  	packet := stack.NewPacketBuffer(stack.PacketBufferOptions{
   113  		ReserveHeaderBytes: header.UDPMinimumSize + int(route.MaxHeaderLength()),
   114  		Payload:            buffer.MakeWithData(packetBuffer.Bytes()),
   115  	})
   116  	defer packet.DecRef()
   117  
   118  	packet.TransportProtocolNumber = header.UDPProtocolNumber
   119  	udpHdr := header.UDP(packet.TransportHeader().Push(header.UDPMinimumSize))
   120  	pLen := uint16(packet.Size())
   121  	udpHdr.Encode(&header.UDPFields{
   122  		SrcPort: destination.Port,
   123  		DstPort: w.sourcePort,
   124  		Length:  pLen,
   125  	})
   126  
   127  	if route.RequiresTXTransportChecksum() && w.sourceNetwork == header.IPv6ProtocolNumber {
   128  		xsum := udpHdr.CalculateChecksum(checksum.Combine(
   129  			route.PseudoHeaderChecksum(header.UDPProtocolNumber, pLen),
   130  			packet.Data().Checksum(),
   131  		))
   132  		if xsum != math.MaxUint16 {
   133  			xsum = ^xsum
   134  		}
   135  		udpHdr.SetChecksum(xsum)
   136  	}
   137  
   138  	err = route.WritePacket(stack.NetworkHeaderParams{
   139  		Protocol: header.UDPProtocolNumber,
   140  		TTL:      route.DefaultTTL(),
   141  		TOS:      0,
   142  	}, packet)
   143  
   144  	if err != nil {
   145  		route.Stats().UDP.PacketSendErrors.Increment()
   146  		return wrapStackError(err)
   147  	}
   148  
   149  	route.Stats().UDP.PacketsSent.Increment()
   150  	return nil
   151  }
   152  
// gRequest bundles a gVisor stack with a transport endpoint ID and the packet
// that arrived for it.
// NOTE(review): gRequest is not referenced in this file; any use is elsewhere
// in the package.
type gRequest struct {
	stack *stack.Stack
	id    stack.TransportEndpointID
	pkt   stack.PacketBufferPtr
}
   158  
// gUDPConn wraps a gonet.UDPConn so that errors returned by Read and Write
// pass through wrapError. Close delegates directly, and all other methods are
// promoted from the embedded *gonet.UDPConn.
type gUDPConn struct {
	*gonet.UDPConn
}
   162  
   163  func (c *gUDPConn) Read(b []byte) (n int, err error) {
   164  	n, err = c.UDPConn.Read(b)
   165  	if err == nil {
   166  		return
   167  	}
   168  	err = wrapError(err)
   169  	return
   170  }
   171  
   172  func (c *gUDPConn) Write(b []byte) (n int, err error) {
   173  	n, err = c.UDPConn.Write(b)
   174  	if err == nil {
   175  		return
   176  	}
   177  	err = wrapError(err)
   178  	return
   179  }
   180  
   181  func (c *gUDPConn) Close() error {
   182  	return c.UDPConn.Close()
   183  }
   184  
   185  func gWriteUnreachable(gStack *stack.Stack, packet stack.PacketBufferPtr, err error) (retErr error) {
   186  	if errors.Is(err, syscall.ENETUNREACH) {
   187  		if packet.NetworkProtocolNumber == header.IPv4ProtocolNumber {
   188  			return gWriteUnreachable4(gStack, packet, stack.RejectIPv4WithICMPNetUnreachable)
   189  		} else {
   190  			return gWriteUnreachable6(gStack, packet, stack.RejectIPv6WithICMPNoRoute)
   191  		}
   192  	} else if errors.Is(err, syscall.EHOSTUNREACH) {
   193  		if packet.NetworkProtocolNumber == header.IPv4ProtocolNumber {
   194  			return gWriteUnreachable4(gStack, packet, stack.RejectIPv4WithICMPHostUnreachable)
   195  		} else {
   196  			return gWriteUnreachable6(gStack, packet, stack.RejectIPv6WithICMPNoRoute)
   197  		}
   198  	} else if errors.Is(err, syscall.ECONNREFUSED) {
   199  		if packet.NetworkProtocolNumber == header.IPv4ProtocolNumber {
   200  			return gWriteUnreachable4(gStack, packet, stack.RejectIPv4WithICMPPortUnreachable)
   201  		} else {
   202  			return gWriteUnreachable6(gStack, packet, stack.RejectIPv6WithICMPPortUnreachable)
   203  		}
   204  	}
   205  	return nil
   206  }
   207  
   208  func gWriteUnreachable4(gStack *stack.Stack, packet stack.PacketBufferPtr, icmpCode stack.RejectIPv4WithICMPType) error {
   209  	err := gStack.NetworkProtocolInstance(header.IPv4ProtocolNumber).(stack.RejectIPv4WithHandler).SendRejectionError(packet, icmpCode, true)
   210  	if err != nil {
   211  		return wrapStackError(err)
   212  	}
   213  	return nil
   214  }
   215  
   216  func gWriteUnreachable6(gStack *stack.Stack, packet stack.PacketBufferPtr, icmpCode stack.RejectIPv6WithICMPType) error {
   217  	err := gStack.NetworkProtocolInstance(header.IPv6ProtocolNumber).(stack.RejectIPv6WithHandler).SendRejectionError(packet, icmpCode, true)
   218  	if err != nil {
   219  		return wrapStackError(err)
   220  	}
   221  	return nil
   222  }