github.com/Asutorufa/yuhaiin@v0.3.6-0.20240502055049-7984da7023a0/pkg/net/proxy/tun/gvisor/udp.go (about)

     1  package tun
     2  
     3  import (
     4  	"context"
     5  	"fmt"
     6  	"math"
     7  	"net"
     8  
     9  	"github.com/Asutorufa/yuhaiin/pkg/net/netapi"
    10  	"github.com/Asutorufa/yuhaiin/pkg/protos/statistic"
    11  	"github.com/Asutorufa/yuhaiin/pkg/utils/pool"
    12  	"gvisor.dev/gvisor/pkg/buffer"
    13  	"gvisor.dev/gvisor/pkg/tcpip"
    14  	"gvisor.dev/gvisor/pkg/tcpip/checksum"
    15  	"gvisor.dev/gvisor/pkg/tcpip/header"
    16  	"gvisor.dev/gvisor/pkg/tcpip/stack"
    17  )
    18  
    19  // func (t *tunServer) udpForwarder() *udp.Forwarder {
    20  // 	return udp.NewForwarder(t.stack, func(fr *udp.ForwarderRequest) {
    21  // 		var wq waiter.Queue
    22  // 		ep, err := fr.CreateEndpoint(&wq)
    23  // 		if err != nil {
    24  // 			log.Error("create endpoint failed:", "err", err)
    25  // 			return
    26  // 		}
    27  
    28  // 		local := gonet.NewUDPConn(&wq, ep)
    29  
    30  // 		go func(local *gonet.UDPConn, id stack.TransportEndpointID) {
    31  // 			defer local.Close()
    32  
    33  // 			addr, ok := netip.AddrFromSlice(id.LocalAddress.AsSlice())
    34  // 			if !ok {
    35  // 				return
    36  // 			}
    37  
    38  // 			dst := netapi.ParseAddrPort(statistic.Type_udp, netip.AddrPortFrom(addr, id.LocalPort))
    39  
    40  // 			for {
    41  // 				buf := pool.GetBytesBuffer(t.mtu)
    42  
    43  // 				_ = local.SetReadDeadline(time.Now().Add(nat.IdleTimeout))
    44  // 				_, src, err := buf.ReadFromPacket(local)
    45  // 				if err != nil {
    46  // 					if ne, ok := err.(net.Error); (ok && ne.Timeout()) || err == io.EOF {
    47  // 						return /* ignore I/O timeout & EOF */
    48  // 					}
    49  
    50  // 					log.Error("read udp failed:", "err", err)
    51  // 					return
    52  // 				}
    53  
    54  // 				err = t.SendPacket(&netapi.Packet{
    55  // 					Src:     src,
    56  // 					Dst:     dst,
    57  // 					Payload: buf,
    58  // 					WriteBack: func(b []byte, addr net.Addr) (int, error) {
    59  // 						from, err := netapi.ParseSysAddr(addr)
    60  // 						if err != nil {
    61  // 							return 0, err
    62  // 						}
    63  
    64  // 						// Symmetric NAT
    65  // 						// gVisor udp.NewForwarder only support Symmetric NAT,
    66  // 						// can't set source in udp header
    67  // 						// TODO: rewrite HandlePacket() to support full cone NAT
    68  // 						if from.String() != dst.String() {
    69  // 							return 0, nil
    70  // 						}
    71  
    72  // 						n, err := local.WriteTo(b, src)
    73  // 						if err != nil {
    74  // 							return n, err
    75  // 						}
    76  
    77  // 						_ = local.SetReadDeadline(time.Now().Add(nat.IdleTimeout))
    78  // 						return n, nil
    79  // 					},
    80  // 				})
    81  // 				if err != nil {
    82  // 					return
    83  // 				}
    84  // 			}
    85  
    86  // 		}(local, fr.ID())
    87  // 	})
    88  // }
    89  
    90  func (f *tunServer) HandleUDPPacket(id stack.TransportEndpointID, pkt *stack.PacketBuffer) bool {
    91  	srcPort, dstPort := id.RemotePort, id.LocalPort
    92  
    93  	length := pkt.Data().Size()
    94  	buf := pool.GetBytesWriter(length)
    95  
    96  	_, err := pkt.Data().ReadTo(buf, true)
    97  	if err != nil {
    98  		return true
    99  	}
   100  
   101  	_ = f.SendPacket(&netapi.Packet{
   102  		Src:     netapi.ParseIPAddrPort(statistic.Type_udp, id.RemoteAddress.AsSlice(), int(srcPort)),
   103  		Dst:     netapi.ParseIPAddrPort(statistic.Type_udp, id.LocalAddress.AsSlice(), int(dstPort)),
   104  		Payload: buf.Unwrap(),
   105  		WriteBack: func(b []byte, addr net.Addr) (int, error) {
   106  			return f.WriteUDPBack(b, id.RemoteAddress, srcPort, addr)
   107  		},
   108  	})
   109  	return true
   110  }
   111  
   112  func (w *tunServer) WriteUDPBack(data []byte, sourceAddr tcpip.Address, sourcePort uint16, destination net.Addr) (int, error) {
   113  	daddr, err := netapi.ParseSysAddr(destination)
   114  	if err != nil {
   115  		return 0, err
   116  	}
   117  
   118  	if daddr.IsFqdn() {
   119  		return 0, fmt.Errorf("send FQDN packet")
   120  	}
   121  
   122  	dip := daddr.AddrPort(context.TODO()).V
   123  
   124  	if sourceAddr.Len() == 4 && dip.Addr().Is6() {
   125  		return 0, fmt.Errorf("send IPv6 packet to IPv4 connection")
   126  	}
   127  
   128  	var addr tcpip.Address
   129  	var sourceNetwork tcpip.NetworkProtocolNumber
   130  	if sourceAddr.Len() == 16 {
   131  		addr = tcpip.AddrFrom16(dip.Addr().As16())
   132  		sourceNetwork = header.IPv6ProtocolNumber
   133  	} else {
   134  		addr = tcpip.AddrFrom4(dip.Addr().As4())
   135  		sourceNetwork = header.IPv4ProtocolNumber
   136  	}
   137  
   138  	route, gerr := w.stack.FindRoute(w.nicID, addr, sourceAddr, sourceNetwork, false)
   139  	if gerr != nil {
   140  		return 0, fmt.Errorf("failed to find route: %v", gerr)
   141  	}
   142  	defer route.Release()
   143  
   144  	packet := stack.NewPacketBuffer(stack.PacketBufferOptions{
   145  		ReserveHeaderBytes: header.UDPMinimumSize + int(route.MaxHeaderLength()),
   146  		Payload:            buffer.MakeWithData(data),
   147  	})
   148  	defer packet.DecRef()
   149  
   150  	packet.TransportProtocolNumber = header.UDPProtocolNumber
   151  	udp := header.UDP(packet.TransportHeader().Push(header.UDPMinimumSize))
   152  	pLen := uint16(packet.Size())
   153  	udp.Encode(&header.UDPFields{
   154  		SrcPort: dip.Port(),
   155  		DstPort: sourcePort,
   156  		Length:  pLen,
   157  	})
   158  
   159  	// Set the checksum field unless TX checksum offload is enabled.
   160  	// On IPv4, UDP checksum is optional, and a zero value indicates the
   161  	// transmitter skipped the checksum generation (RFC768).
   162  	// On IPv6, UDP checksum is not optional (RFC2460 Section 8.1).
   163  	if route.RequiresTXTransportChecksum() && sourceNetwork == header.IPv6ProtocolNumber {
   164  		xsum := udp.CalculateChecksum(checksum.Combine(
   165  			route.PseudoHeaderChecksum(header.UDPProtocolNumber, pLen),
   166  			packet.Data().Checksum(),
   167  		))
   168  		if xsum != math.MaxUint16 {
   169  			xsum = ^xsum
   170  		}
   171  		udp.SetChecksum(xsum)
   172  	}
   173  
   174  	gerr = route.WritePacket(stack.NetworkHeaderParams{
   175  		Protocol: header.UDPProtocolNumber,
   176  		TTL:      route.DefaultTTL(),
   177  		TOS:      0,
   178  	}, packet)
   179  	if gerr != nil {
   180  		route.Stats().UDP.PacketSendErrors.Increment()
   181  		return 0, fmt.Errorf("failed to write packet: %v", gerr)
   182  	}
   183  
   184  	route.Stats().UDP.PacketsSent.Increment()
   185  	return len(data), nil
   186  }