github.com/sagernet/sing-tun@v0.3.0-beta.5/stack_gvisor_udp.go (about)

     1  //go:build with_gvisor
     2  
     3  package tun
     4  
     5  import (
     6  	"context"
     7  	"errors"
     8  	"math"
     9  	"net/netip"
    10  	"os"
    11  	"sync"
    12  	"syscall"
    13  
    14  	"github.com/sagernet/gvisor/pkg/buffer"
    15  	"github.com/sagernet/gvisor/pkg/tcpip"
    16  	"github.com/sagernet/gvisor/pkg/tcpip/adapters/gonet"
    17  	"github.com/sagernet/gvisor/pkg/tcpip/checksum"
    18  	"github.com/sagernet/gvisor/pkg/tcpip/header"
    19  	"github.com/sagernet/gvisor/pkg/tcpip/stack"
    20  	"github.com/sagernet/sing/common/buf"
    21  	E "github.com/sagernet/sing/common/exceptions"
    22  	M "github.com/sagernet/sing/common/metadata"
    23  	N "github.com/sagernet/sing/common/network"
    24  	"github.com/sagernet/sing/common/udpnat"
    25  )
    26  
    27  type UDPForwarder struct {
    28  	ctx    context.Context
    29  	stack  *stack.Stack
    30  	udpNat *udpnat.Service[netip.AddrPort]
    31  
    32  	// cache
    33  	cacheProto tcpip.NetworkProtocolNumber
    34  	cacheID    stack.TransportEndpointID
    35  }
    36  
    37  func NewUDPForwarder(ctx context.Context, stack *stack.Stack, handler Handler, udpTimeout int64) *UDPForwarder {
    38  	return &UDPForwarder{
    39  		ctx:    ctx,
    40  		stack:  stack,
    41  		udpNat: udpnat.New[netip.AddrPort](udpTimeout, handler),
    42  	}
    43  }
    44  
    45  func (f *UDPForwarder) HandlePacket(id stack.TransportEndpointID, pkt *stack.PacketBuffer) bool {
    46  	var upstreamMetadata M.Metadata
    47  	upstreamMetadata.Source = M.SocksaddrFrom(AddrFromAddress(id.RemoteAddress), id.RemotePort)
    48  	upstreamMetadata.Destination = M.SocksaddrFrom(AddrFromAddress(id.LocalAddress), id.LocalPort)
    49  	if upstreamMetadata.Source.IsIPv4() {
    50  		f.cacheProto = header.IPv4ProtocolNumber
    51  	} else {
    52  		f.cacheProto = header.IPv6ProtocolNumber
    53  	}
    54  	gBuffer := pkt.Data().ToBuffer()
    55  	sBuffer := buf.NewSize(int(gBuffer.Size()))
    56  	gBuffer.Apply(func(view *buffer.View) {
    57  		sBuffer.Write(view.AsSlice())
    58  	})
    59  	f.cacheID = id
    60  	f.udpNat.NewPacket(
    61  		f.ctx,
    62  		upstreamMetadata.Source.AddrPort(),
    63  		sBuffer,
    64  		upstreamMetadata,
    65  		f.newUDPConn,
    66  	)
    67  	return true
    68  }
    69  
    70  func (f *UDPForwarder) newUDPConn(natConn N.PacketConn) N.PacketWriter {
    71  	return &UDPBackWriter{
    72  		stack:         f.stack,
    73  		source:        f.cacheID.RemoteAddress,
    74  		sourcePort:    f.cacheID.RemotePort,
    75  		sourceNetwork: f.cacheProto,
    76  	}
    77  }
    78  
    79  type UDPBackWriter struct {
    80  	access        sync.Mutex
    81  	stack         *stack.Stack
    82  	source        tcpip.Address
    83  	sourcePort    uint16
    84  	sourceNetwork tcpip.NetworkProtocolNumber
    85  }
    86  
    87  func (w *UDPBackWriter) WritePacket(packetBuffer *buf.Buffer, destination M.Socksaddr) error {
    88  	if !destination.IsIP() {
    89  		return E.Cause(os.ErrInvalid, "invalid destination")
    90  	} else if destination.IsIPv4() && w.sourceNetwork == header.IPv6ProtocolNumber {
    91  		destination = M.SocksaddrFrom(netip.AddrFrom16(destination.Addr.As16()), destination.Port)
    92  	} else if destination.IsIPv6() && (w.sourceNetwork == header.IPv4ProtocolNumber) {
    93  		return E.New("send IPv6 packet to IPv4 connection")
    94  	}
    95  
    96  	defer packetBuffer.Release()
    97  
    98  	route, err := w.stack.FindRoute(
    99  		defaultNIC,
   100  		AddressFromAddr(destination.Addr),
   101  		w.source,
   102  		w.sourceNetwork,
   103  		false,
   104  	)
   105  	if err != nil {
   106  		return wrapStackError(err)
   107  	}
   108  	defer route.Release()
   109  
   110  	packet := stack.NewPacketBuffer(stack.PacketBufferOptions{
   111  		ReserveHeaderBytes: header.UDPMinimumSize + int(route.MaxHeaderLength()),
   112  		Payload:            buffer.MakeWithData(packetBuffer.Bytes()),
   113  	})
   114  	defer packet.DecRef()
   115  
   116  	packet.TransportProtocolNumber = header.UDPProtocolNumber
   117  	udpHdr := header.UDP(packet.TransportHeader().Push(header.UDPMinimumSize))
   118  	pLen := uint16(packet.Size())
   119  	udpHdr.Encode(&header.UDPFields{
   120  		SrcPort: destination.Port,
   121  		DstPort: w.sourcePort,
   122  		Length:  pLen,
   123  	})
   124  
   125  	if route.RequiresTXTransportChecksum() && w.sourceNetwork == header.IPv6ProtocolNumber {
   126  		xsum := udpHdr.CalculateChecksum(checksum.Combine(
   127  			route.PseudoHeaderChecksum(header.UDPProtocolNumber, pLen),
   128  			packet.Data().Checksum(),
   129  		))
   130  		if xsum != math.MaxUint16 {
   131  			xsum = ^xsum
   132  		}
   133  		udpHdr.SetChecksum(xsum)
   134  	}
   135  
   136  	err = route.WritePacket(stack.NetworkHeaderParams{
   137  		Protocol: header.UDPProtocolNumber,
   138  		TTL:      route.DefaultTTL(),
   139  		TOS:      0,
   140  	}, packet)
   141  
   142  	if err != nil {
   143  		route.Stats().UDP.PacketSendErrors.Increment()
   144  		return wrapStackError(err)
   145  	}
   146  
   147  	route.Stats().UDP.PacketsSent.Increment()
   148  	return nil
   149  }
   150  
   151  type gUDPConn struct {
   152  	*gonet.UDPConn
   153  }
   154  
   155  func (c *gUDPConn) Read(b []byte) (n int, err error) {
   156  	n, err = c.UDPConn.Read(b)
   157  	if err == nil {
   158  		return
   159  	}
   160  	err = wrapError(err)
   161  	return
   162  }
   163  
   164  func (c *gUDPConn) Write(b []byte) (n int, err error) {
   165  	n, err = c.UDPConn.Write(b)
   166  	if err == nil {
   167  		return
   168  	}
   169  	err = wrapError(err)
   170  	return
   171  }
   172  
   173  func (c *gUDPConn) Close() error {
   174  	return c.UDPConn.Close()
   175  }
   176  
   177  func gWriteUnreachable(gStack *stack.Stack, packet *stack.PacketBuffer, err error) (retErr error) {
   178  	if errors.Is(err, syscall.ENETUNREACH) {
   179  		if packet.NetworkProtocolNumber == header.IPv4ProtocolNumber {
   180  			return gWriteUnreachable4(gStack, packet, stack.RejectIPv4WithICMPNetUnreachable)
   181  		} else {
   182  			return gWriteUnreachable6(gStack, packet, stack.RejectIPv6WithICMPNoRoute)
   183  		}
   184  	} else if errors.Is(err, syscall.EHOSTUNREACH) {
   185  		if packet.NetworkProtocolNumber == header.IPv4ProtocolNumber {
   186  			return gWriteUnreachable4(gStack, packet, stack.RejectIPv4WithICMPHostUnreachable)
   187  		} else {
   188  			return gWriteUnreachable6(gStack, packet, stack.RejectIPv6WithICMPNoRoute)
   189  		}
   190  	} else if errors.Is(err, syscall.ECONNREFUSED) {
   191  		if packet.NetworkProtocolNumber == header.IPv4ProtocolNumber {
   192  			return gWriteUnreachable4(gStack, packet, stack.RejectIPv4WithICMPPortUnreachable)
   193  		} else {
   194  			return gWriteUnreachable6(gStack, packet, stack.RejectIPv6WithICMPPortUnreachable)
   195  		}
   196  	}
   197  	return nil
   198  }
   199  
   200  func gWriteUnreachable4(gStack *stack.Stack, packet *stack.PacketBuffer, icmpCode stack.RejectIPv4WithICMPType) error {
   201  	err := gStack.NetworkProtocolInstance(header.IPv4ProtocolNumber).(stack.RejectIPv4WithHandler).SendRejectionError(packet, icmpCode, true)
   202  	if err != nil {
   203  		return wrapStackError(err)
   204  	}
   205  	return nil
   206  }
   207  
   208  func gWriteUnreachable6(gStack *stack.Stack, packet *stack.PacketBuffer, icmpCode stack.RejectIPv6WithICMPType) error {
   209  	err := gStack.NetworkProtocolInstance(header.IPv6ProtocolNumber).(stack.RejectIPv6WithHandler).SendRejectionError(packet, icmpCode, true)
   210  	if err != nil {
   211  		return wrapStackError(err)
   212  	}
   213  	return nil
   214  }