github.com/yaling888/clash@v1.53.0/transport/wireguard/tun.go (about) 1 //go:build !nogvisor 2 3 package wireguard 4 5 import ( 6 "context" 7 "fmt" 8 "net" 9 "net/netip" 10 "os" 11 "syscall" 12 13 "golang.zx2c4.com/wireguard/tun" 14 "gvisor.dev/gvisor/pkg/buffer" 15 "gvisor.dev/gvisor/pkg/tcpip" 16 "gvisor.dev/gvisor/pkg/tcpip/adapters/gonet" 17 "gvisor.dev/gvisor/pkg/tcpip/header" 18 "gvisor.dev/gvisor/pkg/tcpip/link/channel" 19 "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" 20 "gvisor.dev/gvisor/pkg/tcpip/network/ipv6" 21 "gvisor.dev/gvisor/pkg/tcpip/stack" 22 "gvisor.dev/gvisor/pkg/tcpip/transport/icmp" 23 "gvisor.dev/gvisor/pkg/tcpip/transport/tcp" 24 "gvisor.dev/gvisor/pkg/tcpip/transport/udp" 25 ) 26 27 type netTun struct { 28 ep *channel.Endpoint 29 stack *stack.Stack 30 events chan tun.Event 31 incomingPacket chan *buffer.View 32 mtu int 33 dnsServers []netip.Addr 34 hasV4, hasV6 bool 35 } 36 37 type Net netTun 38 39 func CreateNetTUN(localAddresses, dnsServers []netip.Addr, mtu int) (tun.Device, *Net, error) { 40 opts := stack.Options{ 41 NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol}, 42 TransportProtocols: []stack.TransportProtocolFactory{tcp.NewProtocol, udp.NewProtocol, icmp.NewProtocol6, icmp.NewProtocol4}, 43 HandleLocal: true, 44 } 45 dev := &netTun{ 46 ep: channel.New(1024, uint32(mtu), ""), 47 stack: stack.New(opts), 48 events: make(chan tun.Event, 10), 49 incomingPacket: make(chan *buffer.View), 50 dnsServers: dnsServers, 51 mtu: mtu, 52 } 53 sackEnabledOpt := tcpip.TCPSACKEnabled(true) // TCP SACK is disabled by default 54 tcpipErr := dev.stack.SetTransportProtocolOption(tcp.ProtocolNumber, &sackEnabledOpt) 55 if tcpipErr != nil { 56 return nil, nil, fmt.Errorf("could not enable TCP SACK: %v", tcpipErr) 57 } 58 dev.ep.AddNotify(dev) 59 tcpipErr = dev.stack.CreateNIC(1, dev.ep) 60 if tcpipErr != nil { 61 return nil, nil, fmt.Errorf("CreateNIC: %v", tcpipErr) 62 } 63 for _, ip := range localAddresses { 64 var protoNumber tcpip.NetworkProtocolNumber 65 if ip.Is4() { 66 protoNumber = ipv4.ProtocolNumber 67 } else if ip.Is6() { 68 protoNumber = ipv6.ProtocolNumber 69 } 70 protoAddr := tcpip.ProtocolAddress{ 71 Protocol: protoNumber, 72 AddressWithPrefix: tcpip.AddrFromSlice(ip.AsSlice()).WithPrefix(), 73 } 74 tcpipErr := dev.stack.AddProtocolAddress(1, protoAddr, stack.AddressProperties{}) 75 if tcpipErr != nil { 76 return nil, nil, fmt.Errorf("AddProtocolAddress(%v): %v", ip, tcpipErr) 77 } 78 if ip.Is4() { 79 dev.hasV4 = true 80 } else if ip.Is6() { 81 dev.hasV6 = true 82 } 83 } 84 if dev.hasV4 { 85 dev.stack.AddRoute(tcpip.Route{Destination: header.IPv4EmptySubnet, NIC: 1}) 86 } 87 if dev.hasV6 { 88 dev.stack.AddRoute(tcpip.Route{Destination: header.IPv6EmptySubnet, NIC: 1}) 89 } 90 91 dev.events <- tun.EventUp 92 return dev, (*Net)(dev), nil 93 } 94 95 func (tun *netTun) Name() (string, error) { 96 return "Clash-WireGuard", nil 97 } 98 99 func (tun *netTun) File() *os.File { 100 return nil 101 } 102 103 func (tun *netTun) Events() <-chan tun.Event { 104 return tun.events 105 } 106 107 func (tun *netTun) Read(bufs [][]byte, sizes []int, offset int) (int, error) { 108 view, ok := <-tun.incomingPacket 109 if !ok { 110 return 0, os.ErrClosed 111 } 112 113 defer view.Release() 114 115 n, err := view.Read(bufs[0][offset:]) 116 if err != nil { 117 return 0, err 118 } 119 sizes[0] = n 120 return 1, nil 121 } 122 123 func (tun *netTun) Write(bufs [][]byte, offset int) (int, error) { 124 for i, buf := range bufs { 125 packet := buf[offset:] 126 if len(packet) == 0 { 127 continue 128 } 129 130 pkb := stack.NewPacketBuffer(stack.PacketBufferOptions{Payload: buffer.MakeWithData(packet)}) 131 switch packet[0] >> 4 { 132 case 4: 133 tun.ep.InjectInbound(header.IPv4ProtocolNumber, pkb) 134 case 6: 135 tun.ep.InjectInbound(header.IPv6ProtocolNumber, pkb) 136 default: 137 pkb.DecRef() 138 return i, syscall.EAFNOSUPPORT 139 } 140 pkb.DecRef() 141 } 142 return len(bufs), nil 143 } 144 145 func (tun *netTun) WriteNotify() { 146 pkt := tun.ep.Read() 147 if pkt == nil { 148 return 149 } 150 151 view := pkt.ToView() 152 pkt.DecRef() 153 154 tun.incomingPacket <- view 155 } 156 157 func (tun *netTun) Close() error { 158 tun.stack.Destroy() 159 160 if tun.events != nil { 161 close(tun.events) 162 } 163 164 tun.ep.Close() 165 166 if tun.incomingPacket != nil { 167 close(tun.incomingPacket) 168 } 169 170 return nil 171 } 172 173 func (tun *netTun) MTU() (int, error) { 174 return tun.mtu, nil 175 } 176 177 func (tun *netTun) BatchSize() int { 178 return 1 179 } 180 181 func convertToFullAddr(endpoint netip.AddrPort) (tcpip.FullAddress, tcpip.NetworkProtocolNumber) { 182 var protoNumber tcpip.NetworkProtocolNumber 183 if endpoint.Addr().Is4() { 184 protoNumber = ipv4.ProtocolNumber 185 } else { 186 protoNumber = ipv6.ProtocolNumber 187 } 188 return tcpip.FullAddress{ 189 NIC: 1, 190 Addr: tcpip.AddrFromSlice(endpoint.Addr().AsSlice()), 191 Port: endpoint.Port(), 192 }, protoNumber 193 } 194 195 func (net *Net) DialContextTCPAddrPort(ctx context.Context, addr netip.AddrPort) (*gonet.TCPConn, error) { 196 fa, pn := convertToFullAddr(addr) 197 return gonet.DialContextTCP(ctx, net.stack, fa, pn) 198 } 199 200 func (net *Net) DialContextTCP(ctx context.Context, addr *net.TCPAddr) (*gonet.TCPConn, error) { 201 if addr == nil { 202 return net.DialContextTCPAddrPort(ctx, netip.AddrPort{}) 203 } 204 ip, _ := netip.AddrFromSlice(addr.IP) 205 return net.DialContextTCPAddrPort(ctx, netip.AddrPortFrom(ip, uint16(addr.Port))) 206 } 207 208 func (net *Net) DialTCPAddrPort(addr netip.AddrPort) (*gonet.TCPConn, error) { 209 fa, pn := convertToFullAddr(addr) 210 return gonet.DialTCP(net.stack, fa, pn) 211 } 212 213 func (net *Net) DialTCP(addr *net.TCPAddr) (*gonet.TCPConn, error) { 214 if addr == nil { 215 return net.DialTCPAddrPort(netip.AddrPort{}) 216 } 217 ip, _ := netip.AddrFromSlice(addr.IP) 218 return net.DialTCPAddrPort(netip.AddrPortFrom(ip, uint16(addr.Port))) 219 } 220 221 func (net *Net) ListenTCPAddrPort(addr netip.AddrPort) (*gonet.TCPListener, error) { 222 fa, pn := convertToFullAddr(addr) 223 return gonet.ListenTCP(net.stack, fa, pn) 224 } 225 226 func (net *Net) ListenTCP(addr *net.TCPAddr) (*gonet.TCPListener, error) { 227 if addr == nil { 228 return net.ListenTCPAddrPort(netip.AddrPort{}) 229 } 230 ip, _ := netip.AddrFromSlice(addr.IP) 231 return net.ListenTCPAddrPort(netip.AddrPortFrom(ip, uint16(addr.Port))) 232 } 233 234 func (net *Net) DialUDPAddrPort(laddr, raddr netip.AddrPort) (*gonet.UDPConn, error) { 235 var lfa, rfa *tcpip.FullAddress 236 var pn tcpip.NetworkProtocolNumber 237 if laddr.IsValid() || laddr.Port() > 0 { 238 var addr tcpip.FullAddress 239 addr, pn = convertToFullAddr(laddr) 240 lfa = &addr 241 } 242 if raddr.IsValid() || raddr.Port() > 0 { 243 var addr tcpip.FullAddress 244 addr, pn = convertToFullAddr(raddr) 245 rfa = &addr 246 } 247 return gonet.DialUDP(net.stack, lfa, rfa, pn) 248 } 249 250 func (net *Net) ListenUDPAddrPort(laddr netip.AddrPort) (*gonet.UDPConn, error) { 251 return net.DialUDPAddrPort(laddr, netip.AddrPort{}) 252 } 253 254 func (net *Net) DialUDP(laddr, raddr *net.UDPAddr) (*gonet.UDPConn, error) { 255 var la, ra netip.AddrPort 256 if laddr != nil { 257 ip, _ := netip.AddrFromSlice(laddr.IP) 258 la = netip.AddrPortFrom(ip, uint16(laddr.Port)) 259 } 260 if raddr != nil { 261 ip, _ := netip.AddrFromSlice(raddr.IP) 262 ra = netip.AddrPortFrom(ip, uint16(raddr.Port)) 263 } 264 return net.DialUDPAddrPort(la, ra) 265 } 266 267 func (net *Net) ListenUDP(laddr *net.UDPAddr) (*gonet.UDPConn, error) { 268 return net.DialUDP(laddr, nil) 269 }