github.com/sagernet/sing-tun@v0.3.0-beta.5/stack_gvisor_udp.go (about) 1 //go:build with_gvisor 2 3 package tun 4 5 import ( 6 "context" 7 "errors" 8 "math" 9 "net/netip" 10 "os" 11 "sync" 12 "syscall" 13 14 "github.com/sagernet/gvisor/pkg/buffer" 15 "github.com/sagernet/gvisor/pkg/tcpip" 16 "github.com/sagernet/gvisor/pkg/tcpip/adapters/gonet" 17 "github.com/sagernet/gvisor/pkg/tcpip/checksum" 18 "github.com/sagernet/gvisor/pkg/tcpip/header" 19 "github.com/sagernet/gvisor/pkg/tcpip/stack" 20 "github.com/sagernet/sing/common/buf" 21 E "github.com/sagernet/sing/common/exceptions" 22 M "github.com/sagernet/sing/common/metadata" 23 N "github.com/sagernet/sing/common/network" 24 "github.com/sagernet/sing/common/udpnat" 25 ) 26 27 type UDPForwarder struct { 28 ctx context.Context 29 stack *stack.Stack 30 udpNat *udpnat.Service[netip.AddrPort] 31 32 // cache 33 cacheProto tcpip.NetworkProtocolNumber 34 cacheID stack.TransportEndpointID 35 } 36 37 func NewUDPForwarder(ctx context.Context, stack *stack.Stack, handler Handler, udpTimeout int64) *UDPForwarder { 38 return &UDPForwarder{ 39 ctx: ctx, 40 stack: stack, 41 udpNat: udpnat.New[netip.AddrPort](udpTimeout, handler), 42 } 43 } 44 45 func (f *UDPForwarder) HandlePacket(id stack.TransportEndpointID, pkt *stack.PacketBuffer) bool { 46 var upstreamMetadata M.Metadata 47 upstreamMetadata.Source = M.SocksaddrFrom(AddrFromAddress(id.RemoteAddress), id.RemotePort) 48 upstreamMetadata.Destination = M.SocksaddrFrom(AddrFromAddress(id.LocalAddress), id.LocalPort) 49 if upstreamMetadata.Source.IsIPv4() { 50 f.cacheProto = header.IPv4ProtocolNumber 51 } else { 52 f.cacheProto = header.IPv6ProtocolNumber 53 } 54 gBuffer := pkt.Data().ToBuffer() 55 sBuffer := buf.NewSize(int(gBuffer.Size())) 56 gBuffer.Apply(func(view *buffer.View) { 57 sBuffer.Write(view.AsSlice()) 58 }) 59 f.cacheID = id 60 f.udpNat.NewPacket( 61 f.ctx, 62 upstreamMetadata.Source.AddrPort(), 63 sBuffer, 64 upstreamMetadata, 65 f.newUDPConn, 66 ) 67 return true 68 } 69 70 func (f *UDPForwarder) newUDPConn(natConn N.PacketConn) N.PacketWriter { 71 return &UDPBackWriter{ 72 stack: f.stack, 73 source: f.cacheID.RemoteAddress, 74 sourcePort: f.cacheID.RemotePort, 75 sourceNetwork: f.cacheProto, 76 } 77 } 78 79 type UDPBackWriter struct { 80 access sync.Mutex 81 stack *stack.Stack 82 source tcpip.Address 83 sourcePort uint16 84 sourceNetwork tcpip.NetworkProtocolNumber 85 } 86 87 func (w *UDPBackWriter) WritePacket(packetBuffer *buf.Buffer, destination M.Socksaddr) error { 88 if !destination.IsIP() { 89 return E.Cause(os.ErrInvalid, "invalid destination") 90 } else if destination.IsIPv4() && w.sourceNetwork == header.IPv6ProtocolNumber { 91 destination = M.SocksaddrFrom(netip.AddrFrom16(destination.Addr.As16()), destination.Port) 92 } else if destination.IsIPv6() && (w.sourceNetwork == header.IPv4ProtocolNumber) { 93 return E.New("send IPv6 packet to IPv4 connection") 94 } 95 96 defer packetBuffer.Release() 97 98 route, err := w.stack.FindRoute( 99 defaultNIC, 100 AddressFromAddr(destination.Addr), 101 w.source, 102 w.sourceNetwork, 103 false, 104 ) 105 if err != nil { 106 return wrapStackError(err) 107 } 108 defer route.Release() 109 110 packet := stack.NewPacketBuffer(stack.PacketBufferOptions{ 111 ReserveHeaderBytes: header.UDPMinimumSize + int(route.MaxHeaderLength()), 112 Payload: buffer.MakeWithData(packetBuffer.Bytes()), 113 }) 114 defer packet.DecRef() 115 116 packet.TransportProtocolNumber = header.UDPProtocolNumber 117 udpHdr := header.UDP(packet.TransportHeader().Push(header.UDPMinimumSize)) 118 pLen := uint16(packet.Size()) 119 udpHdr.Encode(&header.UDPFields{ 120 SrcPort: destination.Port, 121 DstPort: w.sourcePort, 122 Length: pLen, 123 }) 124 125 if route.RequiresTXTransportChecksum() && w.sourceNetwork == header.IPv6ProtocolNumber { 126 xsum := udpHdr.CalculateChecksum(checksum.Combine( 127 route.PseudoHeaderChecksum(header.UDPProtocolNumber, pLen), 128 packet.Data().Checksum(), 129 )) 130 if xsum != math.MaxUint16 { 131 xsum = ^xsum 132 } 133 udpHdr.SetChecksum(xsum) 134 } 135 136 err = route.WritePacket(stack.NetworkHeaderParams{ 137 Protocol: header.UDPProtocolNumber, 138 TTL: route.DefaultTTL(), 139 TOS: 0, 140 }, packet) 141 142 if err != nil { 143 route.Stats().UDP.PacketSendErrors.Increment() 144 return wrapStackError(err) 145 } 146 147 route.Stats().UDP.PacketsSent.Increment() 148 return nil 149 } 150 151 type gUDPConn struct { 152 *gonet.UDPConn 153 } 154 155 func (c *gUDPConn) Read(b []byte) (n int, err error) { 156 n, err = c.UDPConn.Read(b) 157 if err == nil { 158 return 159 } 160 err = wrapError(err) 161 return 162 } 163 164 func (c *gUDPConn) Write(b []byte) (n int, err error) { 165 n, err = c.UDPConn.Write(b) 166 if err == nil { 167 return 168 } 169 err = wrapError(err) 170 return 171 } 172 173 func (c *gUDPConn) Close() error { 174 return c.UDPConn.Close() 175 } 176 177 func gWriteUnreachable(gStack *stack.Stack, packet *stack.PacketBuffer, err error) (retErr error) { 178 if errors.Is(err, syscall.ENETUNREACH) { 179 if packet.NetworkProtocolNumber == header.IPv4ProtocolNumber { 180 return gWriteUnreachable4(gStack, packet, stack.RejectIPv4WithICMPNetUnreachable) 181 } else { 182 return gWriteUnreachable6(gStack, packet, stack.RejectIPv6WithICMPNoRoute) 183 } 184 } else if errors.Is(err, syscall.EHOSTUNREACH) { 185 if packet.NetworkProtocolNumber == header.IPv4ProtocolNumber { 186 return gWriteUnreachable4(gStack, packet, stack.RejectIPv4WithICMPHostUnreachable) 187 } else { 188 return gWriteUnreachable6(gStack, packet, stack.RejectIPv6WithICMPNoRoute) 189 } 190 } else if errors.Is(err, syscall.ECONNREFUSED) { 191 if packet.NetworkProtocolNumber == header.IPv4ProtocolNumber { 192 return gWriteUnreachable4(gStack, packet, stack.RejectIPv4WithICMPPortUnreachable) 193 } else { 194 return gWriteUnreachable6(gStack, packet, stack.RejectIPv6WithICMPPortUnreachable) 195 } 196 } 197 return nil 198 } 199 200 func gWriteUnreachable4(gStack *stack.Stack, packet *stack.PacketBuffer, icmpCode stack.RejectIPv4WithICMPType) error { 201 err := gStack.NetworkProtocolInstance(header.IPv4ProtocolNumber).(stack.RejectIPv4WithHandler).SendRejectionError(packet, icmpCode, true) 202 if err != nil { 203 return wrapStackError(err) 204 } 205 return nil 206 } 207 208 func gWriteUnreachable6(gStack *stack.Stack, packet *stack.PacketBuffer, icmpCode stack.RejectIPv6WithICMPType) error { 209 err := gStack.NetworkProtocolInstance(header.IPv6ProtocolNumber).(stack.RejectIPv6WithHandler).SendRejectionError(packet, icmpCode, true) 210 if err != nil { 211 return wrapStackError(err) 212 } 213 return nil 214 }