github.com/metacubex/sing-tun@v0.2.7-0.20240512075008-89e7c6208eec/stack_gvisor_udp.go (about) 1 //go:build with_gvisor 2 3 package tun 4 5 import ( 6 "context" 7 "errors" 8 "math" 9 "net/netip" 10 "os" 11 "sync" 12 "syscall" 13 14 "github.com/sagernet/sing/common/buf" 15 E "github.com/sagernet/sing/common/exceptions" 16 M "github.com/sagernet/sing/common/metadata" 17 N "github.com/sagernet/sing/common/network" 18 "github.com/sagernet/sing/common/udpnat" 19 20 "github.com/metacubex/gvisor/pkg/buffer" 21 "github.com/metacubex/gvisor/pkg/tcpip" 22 "github.com/metacubex/gvisor/pkg/tcpip/adapters/gonet" 23 "github.com/metacubex/gvisor/pkg/tcpip/checksum" 24 "github.com/metacubex/gvisor/pkg/tcpip/header" 25 "github.com/metacubex/gvisor/pkg/tcpip/stack" 26 ) 27 28 type UDPForwarder struct { 29 ctx context.Context 30 stack *stack.Stack 31 udpNat *udpnat.Service[netip.AddrPort] 32 33 // cache 34 cacheProto tcpip.NetworkProtocolNumber 35 cacheID stack.TransportEndpointID 36 } 37 38 func NewUDPForwarder(ctx context.Context, stack *stack.Stack, handler Handler, udpTimeout int64) *UDPForwarder { 39 return &UDPForwarder{ 40 ctx: ctx, 41 stack: stack, 42 udpNat: udpnat.New[netip.AddrPort](udpTimeout, handler), 43 } 44 } 45 46 func (f *UDPForwarder) HandlePacket(id stack.TransportEndpointID, pkt *stack.PacketBuffer) bool { 47 var upstreamMetadata M.Metadata 48 upstreamMetadata.Source = M.SocksaddrFrom(AddrFromAddress(id.RemoteAddress), id.RemotePort) 49 upstreamMetadata.Destination = M.SocksaddrFrom(AddrFromAddress(id.LocalAddress), id.LocalPort) 50 if upstreamMetadata.Source.IsIPv4() { 51 f.cacheProto = header.IPv4ProtocolNumber 52 } else { 53 f.cacheProto = header.IPv6ProtocolNumber 54 } 55 gBuffer := pkt.Data().ToBuffer() 56 sBuffer := buf.NewSize(int(gBuffer.Size())) 57 gBuffer.Apply(func(view *buffer.View) { 58 sBuffer.Write(view.AsSlice()) 59 }) 60 f.cacheID = id 61 f.udpNat.NewPacket( 62 f.ctx, 63 upstreamMetadata.Source.AddrPort(), 64 sBuffer, 65 upstreamMetadata, 66 f.newUDPConn, 67 ) 68 return true 69 } 70 71 func (f *UDPForwarder) newUDPConn(natConn N.PacketConn) N.PacketWriter { 72 return &UDPBackWriter{ 73 stack: f.stack, 74 source: f.cacheID.RemoteAddress, 75 sourcePort: f.cacheID.RemotePort, 76 sourceNetwork: f.cacheProto, 77 } 78 } 79 80 type UDPBackWriter struct { 81 access sync.Mutex 82 stack *stack.Stack 83 source tcpip.Address 84 sourcePort uint16 85 sourceNetwork tcpip.NetworkProtocolNumber 86 } 87 88 func (w *UDPBackWriter) WritePacket(packetBuffer *buf.Buffer, destination M.Socksaddr) error { 89 if !destination.IsIP() { 90 return E.Cause(os.ErrInvalid, "invalid destination") 91 } else if destination.IsIPv4() && w.sourceNetwork == header.IPv6ProtocolNumber { 92 destination = M.SocksaddrFrom(netip.AddrFrom16(destination.Addr.As16()), destination.Port) 93 } else if destination.IsIPv6() && (w.sourceNetwork == header.IPv4ProtocolNumber) { 94 return E.New("send IPv6 packet to IPv4 connection") 95 } 96 97 defer packetBuffer.Release() 98 99 route, err := w.stack.FindRoute( 100 defaultNIC, 101 AddressFromAddr(destination.Addr), 102 w.source, 103 w.sourceNetwork, 104 false, 105 ) 106 if err != nil { 107 return wrapStackError(err) 108 } 109 defer route.Release() 110 111 packet := stack.NewPacketBuffer(stack.PacketBufferOptions{ 112 ReserveHeaderBytes: header.UDPMinimumSize + int(route.MaxHeaderLength()), 113 Payload: buffer.MakeWithData(packetBuffer.Bytes()), 114 }) 115 defer packet.DecRef() 116 117 packet.TransportProtocolNumber = header.UDPProtocolNumber 118 udpHdr := header.UDP(packet.TransportHeader().Push(header.UDPMinimumSize)) 119 pLen := uint16(packet.Size()) 120 udpHdr.Encode(&header.UDPFields{ 121 SrcPort: destination.Port, 122 DstPort: w.sourcePort, 123 Length: pLen, 124 }) 125 126 if route.RequiresTXTransportChecksum() && w.sourceNetwork == header.IPv6ProtocolNumber { 127 xsum := udpHdr.CalculateChecksum(checksum.Combine( 128 route.PseudoHeaderChecksum(header.UDPProtocolNumber, pLen), 129 packet.Data().Checksum(), 130 )) 131 if xsum != math.MaxUint16 { 132 xsum = ^xsum 133 } 134 udpHdr.SetChecksum(xsum) 135 } 136 137 err = route.WritePacket(stack.NetworkHeaderParams{ 138 Protocol: header.UDPProtocolNumber, 139 TTL: route.DefaultTTL(), 140 TOS: 0, 141 }, packet) 142 143 if err != nil { 144 route.Stats().UDP.PacketSendErrors.Increment() 145 return wrapStackError(err) 146 } 147 148 route.Stats().UDP.PacketsSent.Increment() 149 return nil 150 } 151 152 type gUDPConn struct { 153 *gonet.UDPConn 154 } 155 156 func (c *gUDPConn) Read(b []byte) (n int, err error) { 157 n, err = c.UDPConn.Read(b) 158 if err == nil { 159 return 160 } 161 err = wrapError(err) 162 return 163 } 164 165 func (c *gUDPConn) Write(b []byte) (n int, err error) { 166 n, err = c.UDPConn.Write(b) 167 if err == nil { 168 return 169 } 170 err = wrapError(err) 171 return 172 } 173 174 func (c *gUDPConn) Close() error { 175 return c.UDPConn.Close() 176 } 177 178 func gWriteUnreachable(gStack *stack.Stack, packet *stack.PacketBuffer, err error) (retErr error) { 179 if errors.Is(err, syscall.ENETUNREACH) { 180 if packet.NetworkProtocolNumber == header.IPv4ProtocolNumber { 181 return gWriteUnreachable4(gStack, packet, stack.RejectIPv4WithICMPNetUnreachable) 182 } else { 183 return gWriteUnreachable6(gStack, packet, stack.RejectIPv6WithICMPNoRoute) 184 } 185 } else if errors.Is(err, syscall.EHOSTUNREACH) { 186 if packet.NetworkProtocolNumber == header.IPv4ProtocolNumber { 187 return gWriteUnreachable4(gStack, packet, stack.RejectIPv4WithICMPHostUnreachable) 188 } else { 189 return gWriteUnreachable6(gStack, packet, stack.RejectIPv6WithICMPNoRoute) 190 } 191 } else if errors.Is(err, syscall.ECONNREFUSED) { 192 if packet.NetworkProtocolNumber == header.IPv4ProtocolNumber { 193 return gWriteUnreachable4(gStack, packet, stack.RejectIPv4WithICMPPortUnreachable) 194 } else { 195 return gWriteUnreachable6(gStack, packet, stack.RejectIPv6WithICMPPortUnreachable) 196 } 197 } 198 return nil 199 } 200 201 func gWriteUnreachable4(gStack *stack.Stack, packet *stack.PacketBuffer, icmpCode stack.RejectIPv4WithICMPType) error { 202 err := gStack.NetworkProtocolInstance(header.IPv4ProtocolNumber).(stack.RejectIPv4WithHandler).SendRejectionError(packet, icmpCode, true) 203 if err != nil { 204 return wrapStackError(err) 205 } 206 return nil 207 } 208 209 func gWriteUnreachable6(gStack *stack.Stack, packet *stack.PacketBuffer, icmpCode stack.RejectIPv6WithICMPType) error { 210 err := gStack.NetworkProtocolInstance(header.IPv6ProtocolNumber).(stack.RejectIPv6WithHandler).SendRejectionError(packet, icmpCode, true) 211 if err != nil { 212 return wrapStackError(err) 213 } 214 return nil 215 }