github.com/sagernet/sing-tun@v0.3.0-beta.5/stack_mixed.go (about) 1 //go:build with_gvisor 2 3 package tun 4 5 import ( 6 "time" 7 8 "github.com/sagernet/gvisor/pkg/buffer" 9 "github.com/sagernet/gvisor/pkg/tcpip/adapters/gonet" 10 "github.com/sagernet/gvisor/pkg/tcpip/header" 11 "github.com/sagernet/gvisor/pkg/tcpip/link/channel" 12 "github.com/sagernet/gvisor/pkg/tcpip/stack" 13 "github.com/sagernet/gvisor/pkg/tcpip/transport/udp" 14 "github.com/sagernet/gvisor/pkg/waiter" 15 "github.com/sagernet/sing-tun/internal/clashtcpip" 16 "github.com/sagernet/sing/common/bufio" 17 "github.com/sagernet/sing/common/canceler" 18 E "github.com/sagernet/sing/common/exceptions" 19 M "github.com/sagernet/sing/common/metadata" 20 ) 21 22 type Mixed struct { 23 *System 24 endpointIndependentNat bool 25 stack *stack.Stack 26 endpoint *channel.Endpoint 27 } 28 29 func NewMixed( 30 options StackOptions, 31 ) (Stack, error) { 32 system, err := NewSystem(options) 33 if err != nil { 34 return nil, err 35 } 36 return &Mixed{ 37 System: system.(*System), 38 endpointIndependentNat: options.EndpointIndependentNat, 39 }, nil 40 } 41 42 func (m *Mixed) Start() error { 43 err := m.System.start() 44 if err != nil { 45 return err 46 } 47 endpoint := channel.New(1024, uint32(m.mtu), "") 48 ipStack, err := newGVisorStack(endpoint) 49 if err != nil { 50 return err 51 } 52 if !m.endpointIndependentNat { 53 udpForwarder := udp.NewForwarder(ipStack, func(request *udp.ForwarderRequest) { 54 var wq waiter.Queue 55 endpoint, err := request.CreateEndpoint(&wq) 56 if err != nil { 57 return 58 } 59 udpConn := gonet.NewUDPConn(&wq, endpoint) 60 lAddr := udpConn.RemoteAddr() 61 rAddr := udpConn.LocalAddr() 62 if lAddr == nil || rAddr == nil { 63 endpoint.Abort() 64 return 65 } 66 gConn := &gUDPConn{UDPConn: udpConn} 67 go func() { 68 var metadata M.Metadata 69 metadata.Source = M.SocksaddrFromNet(lAddr) 70 metadata.Destination = M.SocksaddrFromNet(rAddr) 71 ctx, conn := canceler.NewPacketConn(m.ctx, bufio.NewUnbindPacketConnWithAddr(gConn, metadata.Destination), time.Duration(m.udpTimeout)*time.Second) 72 hErr := m.handler.NewPacketConnection(ctx, conn, metadata) 73 if hErr != nil { 74 endpoint.Abort() 75 } 76 }() 77 }) 78 ipStack.SetTransportProtocolHandler(udp.ProtocolNumber, udpForwarder.HandlePacket) 79 } else { 80 ipStack.SetTransportProtocolHandler(udp.ProtocolNumber, NewUDPForwarder(m.ctx, ipStack, m.handler, m.udpTimeout).HandlePacket) 81 } 82 m.stack = ipStack 83 m.endpoint = endpoint 84 go m.tunLoop() 85 go m.packetLoop() 86 return nil 87 } 88 89 func (m *Mixed) tunLoop() { 90 if winTun, isWinTun := m.tun.(WinTun); isWinTun { 91 m.wintunLoop(winTun) 92 return 93 } 94 if linuxTUN, isLinuxTUN := m.tun.(LinuxTUN); isLinuxTUN { 95 m.frontHeadroom = linuxTUN.FrontHeadroom() 96 m.txChecksumOffload = linuxTUN.TXChecksumOffload() 97 batchSize := linuxTUN.BatchSize() 98 if batchSize > 1 { 99 m.batchLoop(linuxTUN, batchSize) 100 return 101 } 102 } 103 packetBuffer := make([]byte, m.mtu+PacketOffset) 104 for { 105 n, err := m.tun.Read(packetBuffer) 106 if err != nil { 107 if E.IsClosed(err) { 108 return 109 } 110 m.logger.Error(E.Cause(err, "read packet")) 111 } 112 if n < clashtcpip.IPv4PacketMinLength { 113 continue 114 } 115 rawPacket := packetBuffer[:n] 116 packet := packetBuffer[PacketOffset:n] 117 if m.processPacket(packet) { 118 _, err = m.tun.Write(rawPacket) 119 if err != nil { 120 m.logger.Trace(E.Cause(err, "write packet")) 121 } 122 } 123 } 124 } 125 126 func (m *Mixed) wintunLoop(winTun WinTun) { 127 for { 128 packet, release, err := winTun.ReadPacket() 129 if err != nil { 130 return 131 } 132 if len(packet) < clashtcpip.IPv4PacketMinLength { 133 release() 134 continue 135 } 136 if m.processPacket(packet) { 137 _, err = winTun.Write(packet) 138 if err != nil { 139 m.logger.Trace(E.Cause(err, "write packet")) 140 } 141 } 142 release() 143 } 144 } 145 146 func (m *Mixed) batchLoop(linuxTUN LinuxTUN, batchSize int) { 147 packetBuffers := make([][]byte, batchSize) 148 writeBuffers := make([][]byte, batchSize) 149 packetSizes := make([]int, batchSize) 150 for i := range packetBuffers { 151 packetBuffers[i] = make([]byte, m.mtu+m.frontHeadroom) 152 } 153 for { 154 n, err := linuxTUN.BatchRead(packetBuffers, m.frontHeadroom, packetSizes) 155 if err != nil { 156 if E.IsClosed(err) { 157 return 158 } 159 m.logger.Error(E.Cause(err, "batch read packet")) 160 } 161 if n == 0 { 162 continue 163 } 164 for i := 0; i < n; i++ { 165 packetSize := packetSizes[i] 166 if packetSize < clashtcpip.IPv4PacketMinLength { 167 continue 168 } 169 packetBuffer := packetBuffers[i] 170 packet := packetBuffer[m.frontHeadroom : m.frontHeadroom+packetSize] 171 if m.processPacket(packet) { 172 writeBuffers = append(writeBuffers, packetBuffer[:m.frontHeadroom+packetSize]) 173 } 174 } 175 if len(writeBuffers) > 0 { 176 err = linuxTUN.BatchWrite(writeBuffers, m.frontHeadroom) 177 if err != nil { 178 m.logger.Trace(E.Cause(err, "batch write packet")) 179 } 180 writeBuffers = writeBuffers[:0] 181 } 182 } 183 } 184 185 func (m *Mixed) processPacket(packet []byte) bool { 186 var ( 187 writeBack bool 188 err error 189 ) 190 switch ipVersion := packet[0] >> 4; ipVersion { 191 case 4: 192 writeBack, err = m.processIPv4(packet) 193 case 6: 194 writeBack, err = m.processIPv6(packet) 195 default: 196 err = E.New("ip: unknown version: ", ipVersion) 197 } 198 if err != nil { 199 m.logger.Trace(err) 200 return false 201 } 202 return writeBack 203 } 204 205 func (m *Mixed) processIPv4(packet clashtcpip.IPv4Packet) (writeBack bool, err error) { 206 writeBack = true 207 destination := packet.DestinationIP() 208 if destination == m.broadcastAddr || !destination.IsGlobalUnicast() { 209 return 210 } 211 switch packet.Protocol() { 212 case clashtcpip.TCP: 213 err = m.processIPv4TCP(packet, packet.Payload()) 214 case clashtcpip.UDP: 215 writeBack = false 216 pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ 217 Payload: buffer.MakeWithData(packet), 218 IsForwardedPacket: true, 219 }) 220 m.endpoint.InjectInbound(header.IPv4ProtocolNumber, pkt) 221 pkt.DecRef() 222 return 223 case clashtcpip.ICMP: 224 err = m.processIPv4ICMP(packet, packet.Payload()) 225 } 226 return 227 } 228 229 func (m *Mixed) processIPv6(packet clashtcpip.IPv6Packet) (writeBack bool, err error) { 230 writeBack = true 231 if !packet.DestinationIP().IsGlobalUnicast() { 232 return 233 } 234 switch packet.Protocol() { 235 case clashtcpip.TCP: 236 err = m.processIPv6TCP(packet, packet.Payload()) 237 case clashtcpip.UDP: 238 writeBack = false 239 pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ 240 Payload: buffer.MakeWithData(packet), 241 IsForwardedPacket: true, 242 }) 243 m.endpoint.InjectInbound(header.IPv6ProtocolNumber, pkt) 244 pkt.DecRef() 245 case clashtcpip.ICMPv6: 246 err = m.processIPv6ICMP(packet, packet.Payload()) 247 } 248 return 249 } 250 251 func (m *Mixed) packetLoop() { 252 for { 253 packet := m.endpoint.ReadContext(m.ctx) 254 if packet == nil { 255 break 256 } 257 bufio.WriteVectorised(m.tun, packet.AsSlices()) 258 packet.DecRef() 259 } 260 } 261 262 func (m *Mixed) Close() error { 263 m.endpoint.Attach(nil) 264 m.stack.Close() 265 for _, endpoint := range m.stack.CleanupEndpoints() { 266 endpoint.Abort() 267 } 268 return m.System.Close() 269 }