github.com/Asutorufa/yuhaiin@v0.3.6-0.20240502055049-7984da7023a0/pkg/net/proxy/tun/gvisor/udp.go (about) 1 package tun 2 3 import ( 4 "context" 5 "fmt" 6 "math" 7 "net" 8 9 "github.com/Asutorufa/yuhaiin/pkg/net/netapi" 10 "github.com/Asutorufa/yuhaiin/pkg/protos/statistic" 11 "github.com/Asutorufa/yuhaiin/pkg/utils/pool" 12 "gvisor.dev/gvisor/pkg/buffer" 13 "gvisor.dev/gvisor/pkg/tcpip" 14 "gvisor.dev/gvisor/pkg/tcpip/checksum" 15 "gvisor.dev/gvisor/pkg/tcpip/header" 16 "gvisor.dev/gvisor/pkg/tcpip/stack" 17 ) 18 19 // func (t *tunServer) udpForwarder() *udp.Forwarder { 20 // return udp.NewForwarder(t.stack, func(fr *udp.ForwarderRequest) { 21 // var wq waiter.Queue 22 // ep, err := fr.CreateEndpoint(&wq) 23 // if err != nil { 24 // log.Error("create endpoint failed:", "err", err) 25 // return 26 // } 27 28 // local := gonet.NewUDPConn(&wq, ep) 29 30 // go func(local *gonet.UDPConn, id stack.TransportEndpointID) { 31 // defer local.Close() 32 33 // addr, ok := netip.AddrFromSlice(id.LocalAddress.AsSlice()) 34 // if !ok { 35 // return 36 // } 37 38 // dst := netapi.ParseAddrPort(statistic.Type_udp, netip.AddrPortFrom(addr, id.LocalPort)) 39 40 // for { 41 // buf := pool.GetBytesBuffer(t.mtu) 42 43 // _ = local.SetReadDeadline(time.Now().Add(nat.IdleTimeout)) 44 // _, src, err := buf.ReadFromPacket(local) 45 // if err != nil { 46 // if ne, ok := err.(net.Error); (ok && ne.Timeout()) || err == io.EOF { 47 // return /* ignore I/O timeout & EOF */ 48 // } 49 50 // log.Error("read udp failed:", "err", err) 51 // return 52 // } 53 54 // err = t.SendPacket(&netapi.Packet{ 55 // Src: src, 56 // Dst: dst, 57 // Payload: buf, 58 // WriteBack: func(b []byte, addr net.Addr) (int, error) { 59 // from, err := netapi.ParseSysAddr(addr) 60 // if err != nil { 61 // return 0, err 62 // } 63 64 // // Symmetric NAT 65 // // gVisor udp.NewForwarder only support Symmetric NAT, 66 // // can't set source in udp header 67 // // TODO: rewrite HandlePacket() to support full cone NAT 68 // if from.String() != dst.String() { 69 // return 0, nil 70 // } 71 72 // n, err := local.WriteTo(b, src) 73 // if err != nil { 74 // return n, err 75 // } 76 77 // _ = local.SetReadDeadline(time.Now().Add(nat.IdleTimeout)) 78 // return n, nil 79 // }, 80 // }) 81 // if err != nil { 82 // return 83 // } 84 // } 85 86 // }(local, fr.ID()) 87 // }) 88 // } 89 90 func (f *tunServer) HandleUDPPacket(id stack.TransportEndpointID, pkt *stack.PacketBuffer) bool { 91 srcPort, dstPort := id.RemotePort, id.LocalPort 92 93 length := pkt.Data().Size() 94 buf := pool.GetBytesWriter(length) 95 96 _, err := pkt.Data().ReadTo(buf, true) 97 if err != nil { 98 return true 99 } 100 101 _ = f.SendPacket(&netapi.Packet{ 102 Src: netapi.ParseIPAddrPort(statistic.Type_udp, id.RemoteAddress.AsSlice(), int(srcPort)), 103 Dst: netapi.ParseIPAddrPort(statistic.Type_udp, id.LocalAddress.AsSlice(), int(dstPort)), 104 Payload: buf.Unwrap(), 105 WriteBack: func(b []byte, addr net.Addr) (int, error) { 106 return f.WriteUDPBack(b, id.RemoteAddress, srcPort, addr) 107 }, 108 }) 109 return true 110 } 111 112 func (w *tunServer) WriteUDPBack(data []byte, sourceAddr tcpip.Address, sourcePort uint16, destination net.Addr) (int, error) { 113 daddr, err := netapi.ParseSysAddr(destination) 114 if err != nil { 115 return 0, err 116 } 117 118 if daddr.IsFqdn() { 119 return 0, fmt.Errorf("send FQDN packet") 120 } 121 122 dip := daddr.AddrPort(context.TODO()).V 123 124 if sourceAddr.Len() == 4 && dip.Addr().Is6() { 125 return 0, fmt.Errorf("send IPv6 packet to IPv4 connection") 126 } 127 128 var addr tcpip.Address 129 var sourceNetwork tcpip.NetworkProtocolNumber 130 if sourceAddr.Len() == 16 { 131 addr = tcpip.AddrFrom16(dip.Addr().As16()) 132 sourceNetwork = header.IPv6ProtocolNumber 133 } else { 134 addr = tcpip.AddrFrom4(dip.Addr().As4()) 135 sourceNetwork = header.IPv4ProtocolNumber 136 } 137 138 route, gerr := w.stack.FindRoute(w.nicID, addr, sourceAddr, sourceNetwork, false) 139 if gerr != nil { 140 return 0, fmt.Errorf("failed to find route: %v", gerr) 141 } 142 defer route.Release() 143 144 packet := stack.NewPacketBuffer(stack.PacketBufferOptions{ 145 ReserveHeaderBytes: header.UDPMinimumSize + int(route.MaxHeaderLength()), 146 Payload: buffer.MakeWithData(data), 147 }) 148 defer packet.DecRef() 149 150 packet.TransportProtocolNumber = header.UDPProtocolNumber 151 udp := header.UDP(packet.TransportHeader().Push(header.UDPMinimumSize)) 152 pLen := uint16(packet.Size()) 153 udp.Encode(&header.UDPFields{ 154 SrcPort: dip.Port(), 155 DstPort: sourcePort, 156 Length: pLen, 157 }) 158 159 // Set the checksum field unless TX checksum offload is enabled. 160 // On IPv4, UDP checksum is optional, and a zero value indicates the 161 // transmitter skipped the checksum generation (RFC768). 162 // On IPv6, UDP checksum is not optional (RFC2460 Section 8.1). 163 if route.RequiresTXTransportChecksum() && sourceNetwork == header.IPv6ProtocolNumber { 164 xsum := udp.CalculateChecksum(checksum.Combine( 165 route.PseudoHeaderChecksum(header.UDPProtocolNumber, pLen), 166 packet.Data().Checksum(), 167 )) 168 if xsum != math.MaxUint16 { 169 xsum = ^xsum 170 } 171 udp.SetChecksum(xsum) 172 } 173 174 gerr = route.WritePacket(stack.NetworkHeaderParams{ 175 Protocol: header.UDPProtocolNumber, 176 TTL: route.DefaultTTL(), 177 TOS: 0, 178 }, packet) 179 if gerr != nil { 180 route.Stats().UDP.PacketSendErrors.Increment() 181 return 0, fmt.Errorf("failed to write packet: %v", gerr) 182 } 183 184 route.Stats().UDP.PacketsSent.Increment() 185 return len(data), nil 186 }