github.com/xraypb/Xray-core@v1.8.1/proxy/wireguard/tun.go (about) 1 /* SPDX-License-Identifier: MIT 2 * 3 * Copyright (C) 2017-2022 WireGuard LLC. All Rights Reserved. 4 */ 5 6 package wireguard 7 8 import ( 9 "context" 10 "fmt" 11 "net" 12 "net/netip" 13 "os" 14 15 "github.com/sagernet/wireguard-go/tun" 16 "github.com/xraypb/Xray-core/features/dns" 17 "gvisor.dev/gvisor/pkg/bufferv2" 18 "gvisor.dev/gvisor/pkg/tcpip" 19 "gvisor.dev/gvisor/pkg/tcpip/adapters/gonet" 20 "gvisor.dev/gvisor/pkg/tcpip/header" 21 "gvisor.dev/gvisor/pkg/tcpip/link/channel" 22 "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" 23 "gvisor.dev/gvisor/pkg/tcpip/network/ipv6" 24 "gvisor.dev/gvisor/pkg/tcpip/stack" 25 "gvisor.dev/gvisor/pkg/tcpip/transport/tcp" 26 "gvisor.dev/gvisor/pkg/tcpip/transport/udp" 27 ) 28 29 type netTun struct { 30 ep *channel.Endpoint 31 stack *stack.Stack 32 events chan tun.Event 33 incomingPacket chan *bufferv2.View 34 mtu int 35 dnsClient dns.Client 36 hasV4, hasV6 bool 37 } 38 39 type Net netTun 40 41 func CreateNetTUN(localAddresses []netip.Addr, dnsClient dns.Client, mtu int) (tun.Device, *Net, error) { 42 opts := stack.Options{ 43 NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol}, 44 TransportProtocols: []stack.TransportProtocolFactory{tcp.NewProtocol, udp.NewProtocol}, 45 HandleLocal: true, 46 } 47 dev := &netTun{ 48 ep: channel.New(1024, uint32(mtu), ""), 49 stack: stack.New(opts), 50 events: make(chan tun.Event, 10), 51 incomingPacket: make(chan *bufferv2.View), 52 dnsClient: dnsClient, 53 mtu: mtu, 54 } 55 dev.ep.AddNotify(dev) 56 tcpipErr := dev.stack.CreateNIC(1, dev.ep) 57 if tcpipErr != nil { 58 return nil, nil, fmt.Errorf("CreateNIC: %v", tcpipErr) 59 } 60 for _, ip := range localAddresses { 61 var protoNumber tcpip.NetworkProtocolNumber 62 if ip.Is4() { 63 protoNumber = ipv4.ProtocolNumber 64 } else if ip.Is6() { 65 protoNumber = ipv6.ProtocolNumber 66 } 67 protoAddr := tcpip.ProtocolAddress{ 68 Protocol: protoNumber, 69 AddressWithPrefix: tcpip.Address(ip.AsSlice()).WithPrefix(), 70 } 71 tcpipErr := dev.stack.AddProtocolAddress(1, protoAddr, stack.AddressProperties{}) 72 if tcpipErr != nil { 73 return nil, nil, fmt.Errorf("AddProtocolAddress(%v): %v", ip, tcpipErr) 74 } 75 if ip.Is4() { 76 dev.hasV4 = true 77 } else if ip.Is6() { 78 dev.hasV6 = true 79 } 80 } 81 if dev.hasV4 { 82 dev.stack.AddRoute(tcpip.Route{Destination: header.IPv4EmptySubnet, NIC: 1}) 83 } 84 if dev.hasV6 { 85 dev.stack.AddRoute(tcpip.Route{Destination: header.IPv6EmptySubnet, NIC: 1}) 86 } 87 88 dev.events <- tun.EventUp 89 return dev, (*Net)(dev), nil 90 } 91 92 func (tun *netTun) Name() (string, error) { 93 return "go", nil 94 } 95 96 func (tun *netTun) File() *os.File { 97 return nil 98 } 99 100 func (tun *netTun) Events() chan tun.Event { 101 return tun.events 102 } 103 104 func (tun *netTun) Read(buf []byte, offset int) (int, error) { 105 view, ok := <-tun.incomingPacket 106 if !ok { 107 return 0, os.ErrClosed 108 } 109 110 return view.Read(buf[offset:]) 111 } 112 113 func (tun *netTun) Write(buf []byte, offset int) (int, error) { 114 packet := buf[offset:] 115 if len(packet) == 0 { 116 return 0, nil 117 } 118 119 pkb := stack.NewPacketBuffer(stack.PacketBufferOptions{Payload: bufferv2.MakeWithData(packet)}) 120 switch packet[0] >> 4 { 121 case 4: 122 tun.ep.InjectInbound(header.IPv4ProtocolNumber, pkb) 123 case 6: 124 tun.ep.InjectInbound(header.IPv6ProtocolNumber, pkb) 125 } 126 127 return len(buf), nil 128 } 129 130 func (tun *netTun) WriteNotify() { 131 pkt := tun.ep.Read() 132 if pkt == nil { 133 return 134 } 135 136 view := pkt.ToView() 137 pkt.DecRef() 138 139 tun.incomingPacket <- view 140 } 141 142 func (tun *netTun) Flush() error { 143 return nil 144 } 145 146 func (tun *netTun) Close() error { 147 tun.stack.RemoveNIC(1) 148 149 if tun.events != nil { 150 close(tun.events) 151 } 152 153 tun.ep.Close() 154 155 if tun.incomingPacket != nil { 156 close(tun.incomingPacket) 157 } 158 159 return nil 160 } 161 162 func (tun *netTun) MTU() (int, error) { 163 return tun.mtu, nil 164 } 165 166 func convertToFullAddr(endpoint netip.AddrPort) (tcpip.FullAddress, tcpip.NetworkProtocolNumber) { 167 var protoNumber tcpip.NetworkProtocolNumber 168 if endpoint.Addr().Is4() { 169 protoNumber = ipv4.ProtocolNumber 170 } else { 171 protoNumber = ipv6.ProtocolNumber 172 } 173 return tcpip.FullAddress{ 174 NIC: 1, 175 Addr: tcpip.Address(endpoint.Addr().AsSlice()), 176 Port: endpoint.Port(), 177 }, protoNumber 178 } 179 180 func (net *Net) DialContextTCPAddrPort(ctx context.Context, addr netip.AddrPort) (*gonet.TCPConn, error) { 181 fa, pn := convertToFullAddr(addr) 182 return gonet.DialContextTCP(ctx, net.stack, fa, pn) 183 } 184 185 func (net *Net) DialContextTCP(ctx context.Context, addr *net.TCPAddr) (*gonet.TCPConn, error) { 186 if addr == nil { 187 return net.DialContextTCPAddrPort(ctx, netip.AddrPort{}) 188 } 189 ip, _ := netip.AddrFromSlice(addr.IP) 190 return net.DialContextTCPAddrPort(ctx, netip.AddrPortFrom(ip, uint16(addr.Port))) 191 } 192 193 func (net *Net) DialTCPAddrPort(addr netip.AddrPort) (*gonet.TCPConn, error) { 194 fa, pn := convertToFullAddr(addr) 195 return gonet.DialTCP(net.stack, fa, pn) 196 } 197 198 func (net *Net) DialTCP(addr *net.TCPAddr) (*gonet.TCPConn, error) { 199 if addr == nil { 200 return net.DialTCPAddrPort(netip.AddrPort{}) 201 } 202 ip, _ := netip.AddrFromSlice(addr.IP) 203 return net.DialTCPAddrPort(netip.AddrPortFrom(ip, uint16(addr.Port))) 204 } 205 206 func (net *Net) ListenTCPAddrPort(addr netip.AddrPort) (*gonet.TCPListener, error) { 207 fa, pn := convertToFullAddr(addr) 208 return gonet.ListenTCP(net.stack, fa, pn) 209 } 210 211 func (net *Net) ListenTCP(addr *net.TCPAddr) (*gonet.TCPListener, error) { 212 if addr == nil { 213 return net.ListenTCPAddrPort(netip.AddrPort{}) 214 } 215 ip, _ := netip.AddrFromSlice(addr.IP) 216 return net.ListenTCPAddrPort(netip.AddrPortFrom(ip, uint16(addr.Port))) 217 } 218 219 func (net *Net) DialUDPAddrPort(laddr, raddr netip.AddrPort) (*gonet.UDPConn, error) { 220 var lfa, rfa *tcpip.FullAddress 221 var pn tcpip.NetworkProtocolNumber 222 if laddr.IsValid() || laddr.Port() > 0 { 223 var addr tcpip.FullAddress 224 addr, pn = convertToFullAddr(laddr) 225 lfa = &addr 226 } 227 if raddr.IsValid() || raddr.Port() > 0 { 228 var addr tcpip.FullAddress 229 addr, pn = convertToFullAddr(raddr) 230 rfa = &addr 231 } 232 return gonet.DialUDP(net.stack, lfa, rfa, pn) 233 } 234 235 func (net *Net) ListenUDPAddrPort(laddr netip.AddrPort) (*gonet.UDPConn, error) { 236 return net.DialUDPAddrPort(laddr, netip.AddrPort{}) 237 } 238 239 func (net *Net) DialUDP(laddr, raddr *net.UDPAddr) (*gonet.UDPConn, error) { 240 var la, ra netip.AddrPort 241 if laddr != nil { 242 ip, _ := netip.AddrFromSlice(laddr.IP) 243 la = netip.AddrPortFrom(ip, uint16(laddr.Port)) 244 } 245 if raddr != nil { 246 ip, _ := netip.AddrFromSlice(raddr.IP) 247 ra = netip.AddrPortFrom(ip, uint16(raddr.Port)) 248 } 249 return net.DialUDPAddrPort(la, ra) 250 } 251 252 func (net *Net) ListenUDP(laddr *net.UDPAddr) (*gonet.UDPConn, error) { 253 return net.DialUDP(laddr, nil) 254 } 255 256 func (n *Net) HasV4() bool { 257 return n.hasV4 258 } 259 260 func (n *Net) HasV6() bool { 261 return n.hasV6 262 } 263 264 func IsDomainName(s string) bool { 265 l := len(s) 266 if l == 0 || l > 254 || l == 254 && s[l-1] != '.' { 267 return false 268 } 269 last := byte('.') 270 nonNumeric := false 271 partlen := 0 272 for i := 0; i < len(s); i++ { 273 c := s[i] 274 switch { 275 default: 276 return false 277 case 'a' <= c && c <= 'z' || 'A' <= c && c <= 'Z' || c == '_': 278 nonNumeric = true 279 partlen++ 280 case '0' <= c && c <= '9': 281 partlen++ 282 case c == '-': 283 if last == '.' { 284 return false 285 } 286 partlen++ 287 nonNumeric = true 288 case c == '.': 289 if last == '.' || last == '-' { 290 return false 291 } 292 if partlen > 63 || partlen == 0 { 293 return false 294 } 295 partlen = 0 296 } 297 last = c 298 } 299 if last == '-' || partlen > 63 { 300 return false 301 } 302 return nonNumeric 303 }