github.com/metacubex/sing-tun@v0.2.7-0.20240512075008-89e7c6208eec/stack_gvisor_udp.go (about)

     1  //go:build with_gvisor
     2  
     3  package tun
     4  
     5  import (
     6  	"context"
     7  	"errors"
     8  	"math"
     9  	"net/netip"
    10  	"os"
    11  	"sync"
    12  	"syscall"
    13  
    14  	"github.com/sagernet/sing/common/buf"
    15  	E "github.com/sagernet/sing/common/exceptions"
    16  	M "github.com/sagernet/sing/common/metadata"
    17  	N "github.com/sagernet/sing/common/network"
    18  	"github.com/sagernet/sing/common/udpnat"
    19  
    20  	"github.com/metacubex/gvisor/pkg/buffer"
    21  	"github.com/metacubex/gvisor/pkg/tcpip"
    22  	"github.com/metacubex/gvisor/pkg/tcpip/adapters/gonet"
    23  	"github.com/metacubex/gvisor/pkg/tcpip/checksum"
    24  	"github.com/metacubex/gvisor/pkg/tcpip/header"
    25  	"github.com/metacubex/gvisor/pkg/tcpip/stack"
    26  )
    27  
// UDPForwarder receives UDP packets from a gVisor stack (see HandlePacket) and
// feeds them into a udpnat.Service, using each packet's source address/port as
// the NAT key.
type UDPForwarder struct {
	ctx    context.Context
	stack  *stack.Stack
	udpNat *udpnat.Service[netip.AddrPort]

	// cache
	// cacheProto and cacheID record the network protocol and endpoint ID of the
	// packet most recently passed to HandlePacket; newUDPConn reads them to
	// build the reply writer. NOTE(review): this is shared mutable state and
	// assumes HandlePacket is not called concurrently — confirm.
	cacheProto tcpip.NetworkProtocolNumber
	cacheID    stack.TransportEndpointID
}
    37  
    38  func NewUDPForwarder(ctx context.Context, stack *stack.Stack, handler Handler, udpTimeout int64) *UDPForwarder {
    39  	return &UDPForwarder{
    40  		ctx:    ctx,
    41  		stack:  stack,
    42  		udpNat: udpnat.New[netip.AddrPort](udpTimeout, handler),
    43  	}
    44  }
    45  
    46  func (f *UDPForwarder) HandlePacket(id stack.TransportEndpointID, pkt *stack.PacketBuffer) bool {
    47  	var upstreamMetadata M.Metadata
    48  	upstreamMetadata.Source = M.SocksaddrFrom(AddrFromAddress(id.RemoteAddress), id.RemotePort)
    49  	upstreamMetadata.Destination = M.SocksaddrFrom(AddrFromAddress(id.LocalAddress), id.LocalPort)
    50  	if upstreamMetadata.Source.IsIPv4() {
    51  		f.cacheProto = header.IPv4ProtocolNumber
    52  	} else {
    53  		f.cacheProto = header.IPv6ProtocolNumber
    54  	}
    55  	gBuffer := pkt.Data().ToBuffer()
    56  	sBuffer := buf.NewSize(int(gBuffer.Size()))
    57  	gBuffer.Apply(func(view *buffer.View) {
    58  		sBuffer.Write(view.AsSlice())
    59  	})
    60  	f.cacheID = id
    61  	f.udpNat.NewPacket(
    62  		f.ctx,
    63  		upstreamMetadata.Source.AddrPort(),
    64  		sBuffer,
    65  		upstreamMetadata,
    66  		f.newUDPConn,
    67  	)
    68  	return true
    69  }
    70  
    71  func (f *UDPForwarder) newUDPConn(natConn N.PacketConn) N.PacketWriter {
    72  	return &UDPBackWriter{
    73  		stack:         f.stack,
    74  		source:        f.cacheID.RemoteAddress,
    75  		sourcePort:    f.cacheID.RemotePort,
    76  		sourceNetwork: f.cacheProto,
    77  	}
    78  }
    79  
// UDPBackWriter writes reply datagrams back into the gVisor stack, addressed to
// the original sender (source, sourcePort) over the sourceNetwork protocol.
type UDPBackWriter struct {
	// NOTE(review): access is not referenced anywhere in this file; confirm
	// whether it is still needed.
	access        sync.Mutex
	stack         *stack.Stack
	source        tcpip.Address
	sourcePort    uint16
	sourceNetwork tcpip.NetworkProtocolNumber
}
    87  
    88  func (w *UDPBackWriter) WritePacket(packetBuffer *buf.Buffer, destination M.Socksaddr) error {
    89  	if !destination.IsIP() {
    90  		return E.Cause(os.ErrInvalid, "invalid destination")
    91  	} else if destination.IsIPv4() && w.sourceNetwork == header.IPv6ProtocolNumber {
    92  		destination = M.SocksaddrFrom(netip.AddrFrom16(destination.Addr.As16()), destination.Port)
    93  	} else if destination.IsIPv6() && (w.sourceNetwork == header.IPv4ProtocolNumber) {
    94  		return E.New("send IPv6 packet to IPv4 connection")
    95  	}
    96  
    97  	defer packetBuffer.Release()
    98  
    99  	route, err := w.stack.FindRoute(
   100  		defaultNIC,
   101  		AddressFromAddr(destination.Addr),
   102  		w.source,
   103  		w.sourceNetwork,
   104  		false,
   105  	)
   106  	if err != nil {
   107  		return wrapStackError(err)
   108  	}
   109  	defer route.Release()
   110  
   111  	packet := stack.NewPacketBuffer(stack.PacketBufferOptions{
   112  		ReserveHeaderBytes: header.UDPMinimumSize + int(route.MaxHeaderLength()),
   113  		Payload:            buffer.MakeWithData(packetBuffer.Bytes()),
   114  	})
   115  	defer packet.DecRef()
   116  
   117  	packet.TransportProtocolNumber = header.UDPProtocolNumber
   118  	udpHdr := header.UDP(packet.TransportHeader().Push(header.UDPMinimumSize))
   119  	pLen := uint16(packet.Size())
   120  	udpHdr.Encode(&header.UDPFields{
   121  		SrcPort: destination.Port,
   122  		DstPort: w.sourcePort,
   123  		Length:  pLen,
   124  	})
   125  
   126  	if route.RequiresTXTransportChecksum() && w.sourceNetwork == header.IPv6ProtocolNumber {
   127  		xsum := udpHdr.CalculateChecksum(checksum.Combine(
   128  			route.PseudoHeaderChecksum(header.UDPProtocolNumber, pLen),
   129  			packet.Data().Checksum(),
   130  		))
   131  		if xsum != math.MaxUint16 {
   132  			xsum = ^xsum
   133  		}
   134  		udpHdr.SetChecksum(xsum)
   135  	}
   136  
   137  	err = route.WritePacket(stack.NetworkHeaderParams{
   138  		Protocol: header.UDPProtocolNumber,
   139  		TTL:      route.DefaultTTL(),
   140  		TOS:      0,
   141  	}, packet)
   142  
   143  	if err != nil {
   144  		route.Stats().UDP.PacketSendErrors.Increment()
   145  		return wrapStackError(err)
   146  	}
   147  
   148  	route.Stats().UDP.PacketsSent.Increment()
   149  	return nil
   150  }
   151  
// gUDPConn embeds a gonet.UDPConn and overrides Read and Write so that their
// errors are passed through wrapError.
type gUDPConn struct {
	*gonet.UDPConn
}
   155  
   156  func (c *gUDPConn) Read(b []byte) (n int, err error) {
   157  	n, err = c.UDPConn.Read(b)
   158  	if err == nil {
   159  		return
   160  	}
   161  	err = wrapError(err)
   162  	return
   163  }
   164  
   165  func (c *gUDPConn) Write(b []byte) (n int, err error) {
   166  	n, err = c.UDPConn.Write(b)
   167  	if err == nil {
   168  		return
   169  	}
   170  	err = wrapError(err)
   171  	return
   172  }
   173  
   174  func (c *gUDPConn) Close() error {
   175  	return c.UDPConn.Close()
   176  }
   177  
   178  func gWriteUnreachable(gStack *stack.Stack, packet *stack.PacketBuffer, err error) (retErr error) {
   179  	if errors.Is(err, syscall.ENETUNREACH) {
   180  		if packet.NetworkProtocolNumber == header.IPv4ProtocolNumber {
   181  			return gWriteUnreachable4(gStack, packet, stack.RejectIPv4WithICMPNetUnreachable)
   182  		} else {
   183  			return gWriteUnreachable6(gStack, packet, stack.RejectIPv6WithICMPNoRoute)
   184  		}
   185  	} else if errors.Is(err, syscall.EHOSTUNREACH) {
   186  		if packet.NetworkProtocolNumber == header.IPv4ProtocolNumber {
   187  			return gWriteUnreachable4(gStack, packet, stack.RejectIPv4WithICMPHostUnreachable)
   188  		} else {
   189  			return gWriteUnreachable6(gStack, packet, stack.RejectIPv6WithICMPNoRoute)
   190  		}
   191  	} else if errors.Is(err, syscall.ECONNREFUSED) {
   192  		if packet.NetworkProtocolNumber == header.IPv4ProtocolNumber {
   193  			return gWriteUnreachable4(gStack, packet, stack.RejectIPv4WithICMPPortUnreachable)
   194  		} else {
   195  			return gWriteUnreachable6(gStack, packet, stack.RejectIPv6WithICMPPortUnreachable)
   196  		}
   197  	}
   198  	return nil
   199  }
   200  
   201  func gWriteUnreachable4(gStack *stack.Stack, packet *stack.PacketBuffer, icmpCode stack.RejectIPv4WithICMPType) error {
   202  	err := gStack.NetworkProtocolInstance(header.IPv4ProtocolNumber).(stack.RejectIPv4WithHandler).SendRejectionError(packet, icmpCode, true)
   203  	if err != nil {
   204  		return wrapStackError(err)
   205  	}
   206  	return nil
   207  }
   208  
   209  func gWriteUnreachable6(gStack *stack.Stack, packet *stack.PacketBuffer, icmpCode stack.RejectIPv6WithICMPType) error {
   210  	err := gStack.NetworkProtocolInstance(header.IPv6ProtocolNumber).(stack.RejectIPv6WithHandler).SendRejectionError(packet, icmpCode, true)
   211  	if err != nil {
   212  		return wrapStackError(err)
   213  	}
   214  	return nil
   215  }