github.com/metacubex/sing-tun@v0.2.7-0.20240512075008-89e7c6208eec/stack_mixed.go (about) 1 //go:build with_gvisor 2 3 package tun 4 5 import ( 6 "time" 7 8 "github.com/sagernet/sing/common/bufio" 9 "github.com/sagernet/sing/common/canceler" 10 E "github.com/sagernet/sing/common/exceptions" 11 M "github.com/sagernet/sing/common/metadata" 12 13 "github.com/metacubex/gvisor/pkg/buffer" 14 "github.com/metacubex/gvisor/pkg/tcpip/adapters/gonet" 15 "github.com/metacubex/gvisor/pkg/tcpip/header" 16 "github.com/metacubex/gvisor/pkg/tcpip/link/channel" 17 "github.com/metacubex/gvisor/pkg/tcpip/stack" 18 "github.com/metacubex/gvisor/pkg/tcpip/transport/udp" 19 "github.com/metacubex/gvisor/pkg/waiter" 20 "github.com/metacubex/sing-tun/internal/clashtcpip" 21 ) 22 23 type Mixed struct { 24 *System 25 endpointIndependentNat bool 26 stack *stack.Stack 27 endpoint *channel.Endpoint 28 } 29 30 func NewMixed( 31 options StackOptions, 32 ) (Stack, error) { 33 system, err := NewSystem(options) 34 if err != nil { 35 return nil, err 36 } 37 return &Mixed{ 38 System: system.(*System), 39 endpointIndependentNat: options.EndpointIndependentNat, 40 }, nil 41 } 42 43 func (m *Mixed) Start() error { 44 err := m.System.start() 45 if err != nil { 46 return err 47 } 48 endpoint := channel.New(1024, uint32(m.mtu), "") 49 ipStack, err := newGVisorStack(endpoint) 50 if err != nil { 51 return err 52 } 53 if !m.endpointIndependentNat { 54 udpForwarder := udp.NewForwarder(ipStack, func(request *udp.ForwarderRequest) { 55 var wq waiter.Queue 56 endpoint, err := request.CreateEndpoint(&wq) 57 if err != nil { 58 return 59 } 60 udpConn := gonet.NewUDPConn(&wq, endpoint) 61 lAddr := udpConn.RemoteAddr() 62 rAddr := udpConn.LocalAddr() 63 if lAddr == nil || rAddr == nil { 64 endpoint.Abort() 65 return 66 } 67 gConn := &gUDPConn{UDPConn: udpConn} 68 go func() { 69 var metadata M.Metadata 70 metadata.Source = M.SocksaddrFromNet(lAddr) 71 metadata.Destination = M.SocksaddrFromNet(rAddr) 72 ctx, conn := canceler.NewPacketConn(m.ctx, bufio.NewUnbindPacketConnWithAddr(gConn, metadata.Destination), time.Duration(m.udpTimeout)*time.Second) 73 hErr := m.handler.NewPacketConnection(ctx, conn, metadata) 74 if hErr != nil { 75 endpoint.Abort() 76 } 77 }() 78 }) 79 ipStack.SetTransportProtocolHandler(udp.ProtocolNumber, udpForwarder.HandlePacket) 80 } else { 81 ipStack.SetTransportProtocolHandler(udp.ProtocolNumber, NewUDPForwarder(m.ctx, ipStack, m.handler, m.udpTimeout).HandlePacket) 82 } 83 m.stack = ipStack 84 m.endpoint = endpoint 85 go m.tunLoop() 86 go m.packetLoop() 87 return nil 88 } 89 90 func (m *Mixed) tunLoop() { 91 if winTun, isWinTun := m.tun.(WinTun); isWinTun { 92 m.wintunLoop(winTun) 93 return 94 } 95 if linuxTUN, isLinuxTUN := m.tun.(LinuxTUN); isLinuxTUN { 96 m.frontHeadroom = linuxTUN.FrontHeadroom() 97 m.txChecksumOffload = linuxTUN.TXChecksumOffload() 98 batchSize := linuxTUN.BatchSize() 99 if batchSize > 1 { 100 m.batchLoop(linuxTUN, batchSize) 101 return 102 } 103 } 104 packetBuffer := make([]byte, m.mtu+PacketOffset) 105 for { 106 n, err := m.tun.Read(packetBuffer) 107 if err != nil { 108 if E.IsClosed(err) { 109 return 110 } 111 m.logger.Error(E.Cause(err, "read packet")) 112 } 113 if n < clashtcpip.IPv4PacketMinLength { 114 continue 115 } 116 rawPacket := packetBuffer[:n] 117 packet := packetBuffer[PacketOffset:n] 118 if m.processPacket(packet) { 119 _, err = m.tun.Write(rawPacket) 120 if err != nil { 121 m.logger.Trace(E.Cause(err, "write packet")) 122 } 123 } 124 } 125 } 126 127 func (m *Mixed) wintunLoop(winTun WinTun) { 128 for { 129 packet, release, err := winTun.ReadPacket() 130 if err != nil { 131 return 132 } 133 if len(packet) < clashtcpip.IPv4PacketMinLength { 134 release() 135 continue 136 } 137 if m.processPacket(packet) { 138 _, err = winTun.Write(packet) 139 if err != nil { 140 m.logger.Trace(E.Cause(err, "write packet")) 141 } 142 } 143 release() 144 } 145 } 146 147 func (m *Mixed) batchLoop(linuxTUN LinuxTUN, batchSize int) { 148 packetBuffers := make([][]byte, batchSize) 149 writeBuffers := make([][]byte, batchSize) 150 packetSizes := make([]int, batchSize) 151 for i := range packetBuffers { 152 packetBuffers[i] = make([]byte, m.mtu+m.frontHeadroom) 153 } 154 for { 155 n, err := linuxTUN.BatchRead(packetBuffers, m.frontHeadroom, packetSizes) 156 if err != nil { 157 if E.IsClosed(err) { 158 return 159 } 160 m.logger.Error(E.Cause(err, "batch read packet")) 161 } 162 if n == 0 { 163 continue 164 } 165 for i := 0; i < n; i++ { 166 packetSize := packetSizes[i] 167 if packetSize < clashtcpip.IPv4PacketMinLength { 168 continue 169 } 170 packetBuffer := packetBuffers[i] 171 packet := packetBuffer[m.frontHeadroom : m.frontHeadroom+packetSize] 172 if m.processPacket(packet) { 173 writeBuffers = append(writeBuffers, packetBuffer[:m.frontHeadroom+packetSize]) 174 } 175 } 176 if len(writeBuffers) > 0 { 177 err = linuxTUN.BatchWrite(writeBuffers, m.frontHeadroom) 178 if err != nil { 179 m.logger.Trace(E.Cause(err, "batch write packet")) 180 } 181 writeBuffers = writeBuffers[:0] 182 } 183 } 184 } 185 186 func (m *Mixed) processPacket(packet []byte) bool { 187 var ( 188 writeBack bool 189 err error 190 ) 191 switch ipVersion := packet[0] >> 4; ipVersion { 192 case 4: 193 writeBack, err = m.processIPv4(packet) 194 case 6: 195 writeBack, err = m.processIPv6(packet) 196 default: 197 err = E.New("ip: unknown version: ", ipVersion) 198 } 199 if err != nil { 200 m.logger.Trace(err) 201 return false 202 } 203 return writeBack 204 } 205 206 func (m *Mixed) processIPv4(packet clashtcpip.IPv4Packet) (writeBack bool, err error) { 207 writeBack = true 208 destination := packet.DestinationIP() 209 if destination == m.broadcastAddr || !destination.IsGlobalUnicast() { 210 return 211 } 212 switch packet.Protocol() { 213 case clashtcpip.TCP: 214 err = m.processIPv4TCP(packet, packet.Payload()) 215 case clashtcpip.UDP: 216 writeBack = false 217 pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ 218 Payload: buffer.MakeWithData(packet), 219 IsForwardedPacket: true, 220 }) 221 m.endpoint.InjectInbound(header.IPv4ProtocolNumber, pkt) 222 pkt.DecRef() 223 return 224 case clashtcpip.ICMP: 225 err = m.processIPv4ICMP(packet, packet.Payload()) 226 } 227 return 228 } 229 230 func (m *Mixed) processIPv6(packet clashtcpip.IPv6Packet) (writeBack bool, err error) { 231 writeBack = true 232 if !packet.DestinationIP().IsGlobalUnicast() { 233 return 234 } 235 switch packet.Protocol() { 236 case clashtcpip.TCP: 237 err = m.processIPv6TCP(packet, packet.Payload()) 238 case clashtcpip.UDP: 239 writeBack = false 240 pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ 241 Payload: buffer.MakeWithData(packet), 242 IsForwardedPacket: true, 243 }) 244 m.endpoint.InjectInbound(header.IPv6ProtocolNumber, pkt) 245 pkt.DecRef() 246 case clashtcpip.ICMPv6: 247 err = m.processIPv6ICMP(packet, packet.Payload()) 248 } 249 return 250 } 251 252 func (m *Mixed) packetLoop() { 253 for { 254 packet := m.endpoint.ReadContext(m.ctx) 255 if packet == nil { 256 break 257 } 258 bufio.WriteVectorised(m.tun, packet.AsSlices()) 259 packet.DecRef() 260 } 261 } 262 263 func (m *Mixed) Close() error { 264 m.endpoint.Attach(nil) 265 m.stack.Close() 266 for _, endpoint := range m.stack.CleanupEndpoints() { 267 endpoint.Abort() 268 } 269 return m.System.Close() 270 }