github.com/v2fly/v2ray-core/v5@v5.16.2-0.20240507031116-8191faa6e095/app/tun/handler_udp.go (about)

     1  package tun
     2  
import (
	"context"
	"time"

	"gvisor.dev/gvisor/pkg/tcpip/adapters/gonet"
	"gvisor.dev/gvisor/pkg/tcpip/stack"
	gvisor_udp "gvisor.dev/gvisor/pkg/tcpip/transport/udp"
	"gvisor.dev/gvisor/pkg/waiter"

	tun_net "github.com/v2fly/v2ray-core/v5/app/tun/net"
	"github.com/v2fly/v2ray-core/v5/common/buf"
	"github.com/v2fly/v2ray-core/v5/common/net"
	udp_proto "github.com/v2fly/v2ray-core/v5/common/protocol/udp"
	"github.com/v2fly/v2ray-core/v5/common/session"
	"github.com/v2fly/v2ray-core/v5/features/policy"
	"github.com/v2fly/v2ray-core/v5/features/routing"
	"github.com/v2fly/v2ray-core/v5/transport/internet/udp"
)
    20  
// UDPHandler serves UDP flows accepted from the gVisor network stack by
// handing their packets to the routing dispatcher. SetUDPHandler builds the
// instance used for the flows of a stack.
type UDPHandler struct {
	// ctx is the parent context from which every flow's context is derived.
	ctx context.Context
	// dispatcher receives the flow's packets through a UDP split dispatcher.
	dispatcher routing.Dispatcher
	// policyManager is supplied by SetUDPHandler alongside the other fields.
	policyManager policy.Manager
	// config supplies the inbound tag and the sniffing settings.
	config *Config
}
    27  
    28  type udpConn struct {
    29  	*gonet.UDPConn
    30  	id stack.TransportEndpointID
    31  }
    32  
    33  func (c *udpConn) ID() *stack.TransportEndpointID {
    34  	return &c.id
    35  }
    36  
    37  func SetUDPHandler(ctx context.Context, dispatcher routing.Dispatcher, policyManager policy.Manager, config *Config) StackOption {
    38  	return func(s *stack.Stack) error {
    39  		udpForwarder := gvisor_udp.NewForwarder(s, func(r *gvisor_udp.ForwarderRequest) {
    40  			wg := new(waiter.Queue)
    41  			linkedEndpoint, err := r.CreateEndpoint(wg)
    42  			if err != nil {
    43  				newError("failed to create endpoint: ", err).WriteToLog(session.ExportIDToError(ctx))
    44  				return
    45  			}
    46  
    47  			conn := &udpConn{
    48  				UDPConn: gonet.NewUDPConn(s, wg, linkedEndpoint),
    49  				id:      r.ID(),
    50  			}
    51  
    52  			handler := &UDPHandler{
    53  				ctx:           ctx,
    54  				dispatcher:    dispatcher,
    55  				policyManager: policyManager,
    56  				config:        config,
    57  			}
    58  			go handler.Handle(conn)
    59  		})
    60  		s.SetTransportProtocolHandler(gvisor_udp.ProtocolNumber, udpForwarder.HandlePacket)
    61  		return nil
    62  	}
    63  }
    64  
    65  func (h *UDPHandler) Handle(conn tun_net.UDPConn) error {
    66  	defer conn.Close()
    67  	id := conn.ID()
    68  	ctx := session.ContextWithInbound(h.ctx, &session.Inbound{Tag: h.config.Tag})
    69  	content := new(session.Content)
    70  	if h.config.SniffingSettings != nil {
    71  		content.SniffingRequest.Enabled = h.config.SniffingSettings.Enabled
    72  		content.SniffingRequest.OverrideDestinationForProtocol = h.config.SniffingSettings.DestinationOverride
    73  		content.SniffingRequest.MetadataOnly = h.config.SniffingSettings.MetadataOnly
    74  	}
    75  	ctx = session.ContextWithContent(ctx, content)
    76  
    77  	udpDispatcherConstructor := udp.NewSplitDispatcher
    78  
    79  	dest := net.UDPDestination(tun_net.AddressFromTCPIPAddr(id.LocalAddress), net.Port(id.LocalPort))
    80  	src := net.UDPDestination(tun_net.AddressFromTCPIPAddr(id.RemoteAddress), net.Port(id.RemotePort))
    81  
    82  	udpServer := udpDispatcherConstructor(h.dispatcher, func(ctx context.Context, packet *udp_proto.Packet) {
    83  		if _, err := conn.WriteTo(packet.Payload.Bytes(), &net.UDPAddr{
    84  			IP:   src.Address.IP(),
    85  			Port: int(src.Port),
    86  		}); err != nil {
    87  			newError("failed to write UDP packet").Base(err).WriteToLog()
    88  		}
    89  	})
    90  
    91  	for {
    92  		select {
    93  		case <-ctx.Done():
    94  			return nil
    95  		default:
    96  			var buffer [2048]byte
    97  			n, _, err := conn.ReadFrom(buffer[:])
    98  			if err != nil {
    99  				return newError("failed to read UDP packet").Base(err)
   100  			}
   101  			currentPacketCtx := ctx
   102  
   103  			udpServer.Dispatch(currentPacketCtx, dest, buf.FromBytes(buffer[:n]))
   104  		}
   105  	}
   106  }