github.com/MerlinKodo/sing-tun@v0.1.15/stack_mixed.go (about) 1 //go:build with_gvisor 2 3 package tun 4 5 import ( 6 "time" 7 8 "github.com/sagernet/sing/common" 9 "github.com/sagernet/sing/common/bufio" 10 "github.com/sagernet/sing/common/canceler" 11 E "github.com/sagernet/sing/common/exceptions" 12 M "github.com/sagernet/sing/common/metadata" 13 N "github.com/sagernet/sing/common/network" 14 15 "github.com/MerlinKodo/gvisor/pkg/buffer" 16 "github.com/MerlinKodo/gvisor/pkg/tcpip/adapters/gonet" 17 "github.com/MerlinKodo/gvisor/pkg/tcpip/header" 18 "github.com/MerlinKodo/gvisor/pkg/tcpip/link/channel" 19 "github.com/MerlinKodo/gvisor/pkg/tcpip/stack" 20 "github.com/MerlinKodo/gvisor/pkg/tcpip/transport/udp" 21 "github.com/MerlinKodo/gvisor/pkg/waiter" 22 "github.com/MerlinKodo/sing-tun/internal/clashtcpip" 23 ) 24 25 type Mixed struct { 26 *System 27 writer N.VectorisedWriter 28 endpointIndependentNat bool 29 stack *stack.Stack 30 endpoint *channel.Endpoint 31 } 32 33 func NewMixed( 34 options StackOptions, 35 ) (Stack, error) { 36 system, err := NewSystem(options) 37 if err != nil { 38 return nil, err 39 } 40 return &Mixed{ 41 System: system.(*System), 42 writer: options.Tun.CreateVectorisedWriter(), 43 endpointIndependentNat: options.EndpointIndependentNat, 44 }, nil 45 } 46 47 func (m *Mixed) Start() error { 48 err := m.System.start() 49 if err != nil { 50 return err 51 } 52 endpoint := channel.New(1024, m.mtu, "") 53 ipStack, err := newGVisorStack(endpoint) 54 if err != nil { 55 return err 56 } 57 if !m.endpointIndependentNat { 58 udpForwarder := udp.NewForwarder(ipStack, func(request *udp.ForwarderRequest) { 59 var wq waiter.Queue 60 endpoint, err := request.CreateEndpoint(&wq) 61 if err != nil { 62 return 63 } 64 udpConn := gonet.NewUDPConn(ipStack, &wq, endpoint) 65 lAddr := udpConn.RemoteAddr() 66 rAddr := udpConn.LocalAddr() 67 if lAddr == nil || rAddr == nil { 68 endpoint.Abort() 69 return 70 } 71 gConn := &gUDPConn{UDPConn: udpConn} 72 go func() { 73 var metadata M.Metadata 74 metadata.Source = M.SocksaddrFromNet(lAddr) 75 metadata.Destination = M.SocksaddrFromNet(rAddr) 76 ctx, conn := canceler.NewPacketConn(m.ctx, bufio.NewPacketConn(&bufio.UnbindPacketConn{ExtendedConn: bufio.NewExtendedConn(gConn), Addr: M.SocksaddrFromNet(rAddr)}), time.Duration(m.udpTimeout)*time.Second) 77 hErr := m.handler.NewPacketConnection(ctx, conn, metadata) 78 if hErr != nil { 79 endpoint.Abort() 80 } 81 }() 82 }) 83 ipStack.SetTransportProtocolHandler(udp.ProtocolNumber, udpForwarder.HandlePacket) 84 } else { 85 ipStack.SetTransportProtocolHandler(udp.ProtocolNumber, NewUDPForwarder(m.ctx, ipStack, m.handler, m.udpTimeout).HandlePacket) 86 } 87 m.stack = ipStack 88 m.endpoint = endpoint 89 go m.tunLoop() 90 go m.packetLoop() 91 return nil 92 } 93 94 func (m *Mixed) tunLoop() { 95 if winTun, isWinTun := m.tun.(WinTun); isWinTun { 96 m.wintunLoop(winTun) 97 return 98 } 99 packetBuffer := make([]byte, m.mtu+PacketOffset) 100 for { 101 n, err := m.tun.Read(packetBuffer) 102 if err != nil { 103 return 104 } 105 if n < clashtcpip.IPv4PacketMinLength { 106 continue 107 } 108 packet := packetBuffer[PacketOffset:n] 109 switch ipVersion := packet[0] >> 4; ipVersion { 110 case 4: 111 err = m.processIPv4(packet) 112 case 6: 113 err = m.processIPv6(packet) 114 default: 115 err = E.New("ip: unknown version: ", ipVersion) 116 } 117 if err != nil { 118 m.logger.Trace(err) 119 } 120 } 121 } 122 123 func (m *Mixed) wintunLoop(winTun WinTun) { 124 for { 125 packet, release, err := winTun.ReadPacket() 126 if err != nil { 127 return 128 } 129 if len(packet) < clashtcpip.IPv4PacketMinLength { 130 release() 131 continue 132 } 133 switch ipVersion := packet[0] >> 4; ipVersion { 134 case 4: 135 err = m.processIPv4(packet) 136 case 6: 137 err = m.processIPv6(packet) 138 default: 139 err = E.New("ip: unknown version: ", ipVersion) 140 } 141 if err != nil { 142 m.logger.Trace(err) 143 } 144 release() 145 } 146 } 147 148 func (m *Mixed) processIPv4(packet clashtcpip.IPv4Packet) error { 149 switch packet.Protocol() { 150 case clashtcpip.TCP: 151 return m.processIPv4TCP(packet, packet.Payload()) 152 case clashtcpip.UDP: 153 pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ 154 Payload: buffer.MakeWithData(packet), 155 }) 156 m.endpoint.InjectInbound(header.IPv4ProtocolNumber, pkt) 157 pkt.DecRef() 158 return nil 159 case clashtcpip.ICMP: 160 return m.processIPv4ICMP(packet, packet.Payload()) 161 default: 162 return common.Error(m.tun.Write(packet)) 163 } 164 } 165 166 func (m *Mixed) processIPv6(packet clashtcpip.IPv6Packet) error { 167 switch packet.Protocol() { 168 case clashtcpip.TCP: 169 return m.processIPv6TCP(packet, packet.Payload()) 170 case clashtcpip.UDP: 171 pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ 172 Payload: buffer.MakeWithData(packet), 173 }) 174 m.endpoint.InjectInbound(header.IPv6ProtocolNumber, pkt) 175 pkt.DecRef() 176 return nil 177 case clashtcpip.ICMPv6: 178 return m.processIPv6ICMP(packet, packet.Payload()) 179 default: 180 return common.Error(m.tun.Write(packet)) 181 } 182 } 183 184 func (m *Mixed) packetLoop() { 185 for { 186 packet := m.endpoint.ReadContext(m.ctx) 187 if packet == nil { 188 break 189 } 190 bufio.WriteVectorised(m.writer, packet.AsSlices()) 191 packet.DecRef() 192 } 193 } 194 195 func (m *Mixed) Close() error { 196 m.endpoint.Attach(nil) 197 m.stack.Close() 198 for _, endpoint := range m.stack.CleanupEndpoints() { 199 endpoint.Abort() 200 } 201 return m.System.Close() 202 }