github.com/MerlinKodo/sing-tun@v0.1.15/stack_gvisor_udp.go (about) 1 //go:build with_gvisor 2 3 package tun 4 5 import ( 6 "context" 7 "errors" 8 "math" 9 "net/netip" 10 "os" 11 "sync" 12 "syscall" 13 14 "github.com/sagernet/sing/common/buf" 15 E "github.com/sagernet/sing/common/exceptions" 16 M "github.com/sagernet/sing/common/metadata" 17 N "github.com/sagernet/sing/common/network" 18 "github.com/sagernet/sing/common/udpnat" 19 20 "github.com/MerlinKodo/gvisor/pkg/buffer" 21 "github.com/MerlinKodo/gvisor/pkg/tcpip" 22 "github.com/MerlinKodo/gvisor/pkg/tcpip/adapters/gonet" 23 "github.com/MerlinKodo/gvisor/pkg/tcpip/checksum" 24 "github.com/MerlinKodo/gvisor/pkg/tcpip/header" 25 "github.com/MerlinKodo/gvisor/pkg/tcpip/stack" 26 ) 27 28 type UDPForwarder struct { 29 ctx context.Context 30 stack *stack.Stack 31 udpNat *udpnat.Service[netip.AddrPort] 32 33 // cache 34 cacheProto tcpip.NetworkProtocolNumber 35 cacheID stack.TransportEndpointID 36 } 37 38 func NewUDPForwarder(ctx context.Context, stack *stack.Stack, handler Handler, udpTimeout int64) *UDPForwarder { 39 return &UDPForwarder{ 40 ctx: ctx, 41 stack: stack, 42 udpNat: udpnat.New[netip.AddrPort](udpTimeout, handler), 43 } 44 } 45 46 func (f *UDPForwarder) HandlePacket(id stack.TransportEndpointID, pkt stack.PacketBufferPtr) bool { 47 var upstreamMetadata M.Metadata 48 upstreamMetadata.Source = M.SocksaddrFrom(AddrFromAddress(id.RemoteAddress), id.RemotePort) 49 upstreamMetadata.Destination = M.SocksaddrFrom(AddrFromAddress(id.LocalAddress), id.LocalPort) 50 if upstreamMetadata.Source.IsIPv4() { 51 f.cacheProto = header.IPv4ProtocolNumber 52 } else { 53 f.cacheProto = header.IPv6ProtocolNumber 54 } 55 gBuffer := pkt.Data().ToBuffer() 56 sBuffer := buf.NewSize(int(gBuffer.Size())) 57 gBuffer.Apply(func(view *buffer.View) { 58 sBuffer.Write(view.AsSlice()) 59 }) 60 f.cacheID = id 61 f.udpNat.NewPacket( 62 f.ctx, 63 upstreamMetadata.Source.AddrPort(), 64 sBuffer, 65 upstreamMetadata, 66 f.newUDPConn, 67 ) 68 return true 69 } 70 71 func (f *UDPForwarder) newUDPConn(natConn N.PacketConn) N.PacketWriter { 72 return &UDPBackWriter{ 73 stack: f.stack, 74 source: f.cacheID.RemoteAddress, 75 sourcePort: f.cacheID.RemotePort, 76 sourceNetwork: f.cacheProto, 77 } 78 } 79 80 type UDPBackWriter struct { 81 access sync.Mutex 82 stack *stack.Stack 83 source tcpip.Address 84 sourcePort uint16 85 sourceNetwork tcpip.NetworkProtocolNumber 86 packet stack.PacketBufferPtr 87 } 88 89 func (w *UDPBackWriter) WritePacket(packetBuffer *buf.Buffer, destination M.Socksaddr) error { 90 if !destination.IsIP() { 91 return E.Cause(os.ErrInvalid, "invalid destination") 92 } else if destination.IsIPv4() && w.sourceNetwork == header.IPv6ProtocolNumber { 93 destination = M.SocksaddrFrom(netip.AddrFrom16(destination.Addr.As16()), destination.Port) 94 } else if destination.IsIPv6() && (w.sourceNetwork == header.IPv4ProtocolNumber) { 95 return E.New("send IPv6 packet to IPv4 connection") 96 } 97 98 defer packetBuffer.Release() 99 100 route, err := w.stack.FindRoute( 101 defaultNIC, 102 AddressFromAddr(destination.Addr), 103 w.source, 104 w.sourceNetwork, 105 false, 106 ) 107 if err != nil { 108 return wrapStackError(err) 109 } 110 defer route.Release() 111 112 packet := stack.NewPacketBuffer(stack.PacketBufferOptions{ 113 ReserveHeaderBytes: header.UDPMinimumSize + int(route.MaxHeaderLength()), 114 Payload: buffer.MakeWithData(packetBuffer.Bytes()), 115 }) 116 defer packet.DecRef() 117 118 packet.TransportProtocolNumber = header.UDPProtocolNumber 119 udpHdr := header.UDP(packet.TransportHeader().Push(header.UDPMinimumSize)) 120 pLen := uint16(packet.Size()) 121 udpHdr.Encode(&header.UDPFields{ 122 SrcPort: destination.Port, 123 DstPort: w.sourcePort, 124 Length: pLen, 125 }) 126 127 if route.RequiresTXTransportChecksum() && w.sourceNetwork == header.IPv6ProtocolNumber { 128 xsum := udpHdr.CalculateChecksum(checksum.Combine( 129 route.PseudoHeaderChecksum(header.UDPProtocolNumber, pLen), 130 packet.Data().Checksum(), 131 )) 132 if xsum != math.MaxUint16 { 133 xsum = ^xsum 134 } 135 udpHdr.SetChecksum(xsum) 136 } 137 138 err = route.WritePacket(stack.NetworkHeaderParams{ 139 Protocol: header.UDPProtocolNumber, 140 TTL: route.DefaultTTL(), 141 TOS: 0, 142 }, packet) 143 144 if err != nil { 145 route.Stats().UDP.PacketSendErrors.Increment() 146 return wrapStackError(err) 147 } 148 149 route.Stats().UDP.PacketsSent.Increment() 150 return nil 151 } 152 153 type gRequest struct { 154 stack *stack.Stack 155 id stack.TransportEndpointID 156 pkt stack.PacketBufferPtr 157 } 158 159 type gUDPConn struct { 160 *gonet.UDPConn 161 } 162 163 func (c *gUDPConn) Read(b []byte) (n int, err error) { 164 n, err = c.UDPConn.Read(b) 165 if err == nil { 166 return 167 } 168 err = wrapError(err) 169 return 170 } 171 172 func (c *gUDPConn) Write(b []byte) (n int, err error) { 173 n, err = c.UDPConn.Write(b) 174 if err == nil { 175 return 176 } 177 err = wrapError(err) 178 return 179 } 180 181 func (c *gUDPConn) Close() error { 182 return c.UDPConn.Close() 183 } 184 185 func gWriteUnreachable(gStack *stack.Stack, packet stack.PacketBufferPtr, err error) (retErr error) { 186 if errors.Is(err, syscall.ENETUNREACH) { 187 if packet.NetworkProtocolNumber == header.IPv4ProtocolNumber { 188 return gWriteUnreachable4(gStack, packet, stack.RejectIPv4WithICMPNetUnreachable) 189 } else { 190 return gWriteUnreachable6(gStack, packet, stack.RejectIPv6WithICMPNoRoute) 191 } 192 } else if errors.Is(err, syscall.EHOSTUNREACH) { 193 if packet.NetworkProtocolNumber == header.IPv4ProtocolNumber { 194 return gWriteUnreachable4(gStack, packet, stack.RejectIPv4WithICMPHostUnreachable) 195 } else { 196 return gWriteUnreachable6(gStack, packet, stack.RejectIPv6WithICMPNoRoute) 197 } 198 } else if errors.Is(err, syscall.ECONNREFUSED) { 199 if packet.NetworkProtocolNumber == header.IPv4ProtocolNumber { 200 return gWriteUnreachable4(gStack, packet, stack.RejectIPv4WithICMPPortUnreachable) 201 } else { 202 return gWriteUnreachable6(gStack, packet, stack.RejectIPv6WithICMPPortUnreachable) 203 } 204 } 205 return nil 206 } 207 208 func gWriteUnreachable4(gStack *stack.Stack, packet stack.PacketBufferPtr, icmpCode stack.RejectIPv4WithICMPType) error { 209 err := gStack.NetworkProtocolInstance(header.IPv4ProtocolNumber).(stack.RejectIPv4WithHandler).SendRejectionError(packet, icmpCode, true) 210 if err != nil { 211 return wrapStackError(err) 212 } 213 return nil 214 } 215 216 func gWriteUnreachable6(gStack *stack.Stack, packet stack.PacketBufferPtr, icmpCode stack.RejectIPv6WithICMPType) error { 217 err := gStack.NetworkProtocolInstance(header.IPv6ProtocolNumber).(stack.RejectIPv6WithHandler).SendRejectionError(packet, icmpCode, true) 218 if err != nil { 219 return wrapStackError(err) 220 } 221 return nil 222 }