github.com/moqsien/xraycore@v1.8.5/proxy/wireguard/tun.go (about) 1 /* SPDX-License-Identifier: MIT 2 * 3 * Copyright (C) 2017-2022 WireGuard LLC. All Rights Reserved. 4 */ 5 6 package wireguard 7 8 import ( 9 "context" 10 "fmt" 11 "net" 12 "net/netip" 13 "os" 14 15 "github.com/sagernet/wireguard-go/tun" 16 "github.com/moqsien/xraycore/features/dns" 17 "gvisor.dev/gvisor/pkg/buffer" 18 "gvisor.dev/gvisor/pkg/tcpip" 19 "gvisor.dev/gvisor/pkg/tcpip/adapters/gonet" 20 "gvisor.dev/gvisor/pkg/tcpip/header" 21 "gvisor.dev/gvisor/pkg/tcpip/link/channel" 22 "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" 23 "gvisor.dev/gvisor/pkg/tcpip/network/ipv6" 24 "gvisor.dev/gvisor/pkg/tcpip/stack" 25 "gvisor.dev/gvisor/pkg/tcpip/transport/tcp" 26 "gvisor.dev/gvisor/pkg/tcpip/transport/udp" 27 ) 28 29 type netTun struct { 30 ep *channel.Endpoint 31 stack *stack.Stack 32 events chan tun.Event 33 incomingPacket chan *buffer.View 34 mtu int 35 dnsClient dns.Client 36 hasV4, hasV6 bool 37 } 38 39 type Net netTun 40 41 func CreateNetTUN(localAddresses []netip.Addr, dnsClient dns.Client, mtu int) (tun.Device, *Net, error) { 42 opts := stack.Options{ 43 NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol}, 44 TransportProtocols: []stack.TransportProtocolFactory{tcp.NewProtocol, udp.NewProtocol}, 45 HandleLocal: true, 46 } 47 dev := &netTun{ 48 ep: channel.New(1024, uint32(mtu), ""), 49 stack: stack.New(opts), 50 events: make(chan tun.Event, 10), 51 incomingPacket: make(chan *buffer.View), 52 dnsClient: dnsClient, 53 mtu: mtu, 54 } 55 dev.ep.AddNotify(dev) 56 tcpipErr := dev.stack.CreateNIC(1, dev.ep) 57 if tcpipErr != nil { 58 return nil, nil, fmt.Errorf("CreateNIC: %v", tcpipErr) 59 } 60 for _, ip := range localAddresses { 61 var protoNumber tcpip.NetworkProtocolNumber 62 if ip.Is4() { 63 protoNumber = ipv4.ProtocolNumber 64 } else if ip.Is6() { 65 protoNumber = ipv6.ProtocolNumber 66 } 67 protoAddr := tcpip.ProtocolAddress{ 68 Protocol: protoNumber, 69 AddressWithPrefix: tcpip.AddrFromSlice(ip.AsSlice()).WithPrefix(), 70 } 71 tcpipErr := dev.stack.AddProtocolAddress(1, protoAddr, stack.AddressProperties{}) 72 if tcpipErr != nil { 73 return nil, nil, fmt.Errorf("AddProtocolAddress(%v): %v", ip, tcpipErr) 74 } 75 if ip.Is4() { 76 dev.hasV4 = true 77 } else if ip.Is6() { 78 dev.hasV6 = true 79 } 80 } 81 if dev.hasV4 { 82 dev.stack.AddRoute(tcpip.Route{Destination: header.IPv4EmptySubnet, NIC: 1}) 83 } 84 if dev.hasV6 { 85 dev.stack.AddRoute(tcpip.Route{Destination: header.IPv6EmptySubnet, NIC: 1}) 86 } 87 88 dev.events <- tun.EventUp 89 return dev, (*Net)(dev), nil 90 } 91 92 func (tun *netTun) Name() (string, error) { 93 return "go", nil 94 } 95 96 func (tun *netTun) File() *os.File { 97 return nil 98 } 99 100 func (tun *netTun) Events() <-chan tun.Event { 101 return tun.events 102 } 103 104 func (tun *netTun) Read(buf [][]byte, sizes []int, offset int) (int, error) { 105 view, ok := <-tun.incomingPacket 106 if !ok { 107 return 0, os.ErrClosed 108 } 109 110 return view.Read(buf[0][offset:]) 111 } 112 113 func (tun *netTun) Write(buf [][]byte, offset int) (count int, err error) { 114 for _, b := range buf { 115 packet := b[offset:] 116 if len(packet) == 0 { 117 continue 118 } 119 120 pkb := stack.NewPacketBuffer(stack.PacketBufferOptions{Payload: buffer.MakeWithData(packet)}) 121 switch packet[0] >> 4 { 122 case 4: 123 tun.ep.InjectInbound(header.IPv4ProtocolNumber, pkb) 124 case 6: 125 tun.ep.InjectInbound(header.IPv6ProtocolNumber, pkb) 126 } 127 count++ 128 } 129 return 130 } 131 132 func (tun *netTun) WriteNotify() { 133 pkt := tun.ep.Read() 134 if pkt == nil { 135 return 136 } 137 138 view := pkt.ToView() 139 pkt.DecRef() 140 141 tun.incomingPacket <- view 142 } 143 144 func (tun *netTun) Flush() error { 145 return nil 146 } 147 148 func (tun *netTun) Close() error { 149 tun.stack.RemoveNIC(1) 150 151 if tun.events != nil { 152 close(tun.events) 153 } 154 155 tun.ep.Close() 156 157 if tun.incomingPacket != nil { 158 close(tun.incomingPacket) 159 } 160 161 return nil 162 } 163 164 func (tun *netTun) MTU() (int, error) { 165 return tun.mtu, nil 166 } 167 168 func (tun *netTun) BatchSize() int { 169 return 1 170 } 171 172 func convertToFullAddr(endpoint netip.AddrPort) (tcpip.FullAddress, tcpip.NetworkProtocolNumber) { 173 var protoNumber tcpip.NetworkProtocolNumber 174 if endpoint.Addr().Is4() { 175 protoNumber = ipv4.ProtocolNumber 176 } else { 177 protoNumber = ipv6.ProtocolNumber 178 } 179 return tcpip.FullAddress{ 180 NIC: 1, 181 Addr: tcpip.AddrFromSlice(endpoint.Addr().AsSlice()), 182 Port: endpoint.Port(), 183 }, protoNumber 184 } 185 186 func (net *Net) DialContextTCPAddrPort(ctx context.Context, addr netip.AddrPort) (*gonet.TCPConn, error) { 187 fa, pn := convertToFullAddr(addr) 188 return gonet.DialContextTCP(ctx, net.stack, fa, pn) 189 } 190 191 func (net *Net) DialContextTCP(ctx context.Context, addr *net.TCPAddr) (*gonet.TCPConn, error) { 192 if addr == nil { 193 return net.DialContextTCPAddrPort(ctx, netip.AddrPort{}) 194 } 195 ip, _ := netip.AddrFromSlice(addr.IP) 196 return net.DialContextTCPAddrPort(ctx, netip.AddrPortFrom(ip, uint16(addr.Port))) 197 } 198 199 func (net *Net) DialTCPAddrPort(addr netip.AddrPort) (*gonet.TCPConn, error) { 200 fa, pn := convertToFullAddr(addr) 201 return gonet.DialTCP(net.stack, fa, pn) 202 } 203 204 func (net *Net) DialTCP(addr *net.TCPAddr) (*gonet.TCPConn, error) { 205 if addr == nil { 206 return net.DialTCPAddrPort(netip.AddrPort{}) 207 } 208 ip, _ := netip.AddrFromSlice(addr.IP) 209 return net.DialTCPAddrPort(netip.AddrPortFrom(ip, uint16(addr.Port))) 210 } 211 212 func (net *Net) ListenTCPAddrPort(addr netip.AddrPort) (*gonet.TCPListener, error) { 213 fa, pn := convertToFullAddr(addr) 214 return gonet.ListenTCP(net.stack, fa, pn) 215 } 216 217 func (net *Net) ListenTCP(addr *net.TCPAddr) (*gonet.TCPListener, error) { 218 if addr == nil { 219 return net.ListenTCPAddrPort(netip.AddrPort{}) 220 } 221 ip, _ := netip.AddrFromSlice(addr.IP) 222 return net.ListenTCPAddrPort(netip.AddrPortFrom(ip, uint16(addr.Port))) 223 } 224 225 func (net *Net) DialUDPAddrPort(laddr, raddr netip.AddrPort) (*gonet.UDPConn, error) { 226 var lfa, rfa *tcpip.FullAddress 227 var pn tcpip.NetworkProtocolNumber 228 if laddr.IsValid() || laddr.Port() > 0 { 229 var addr tcpip.FullAddress 230 addr, pn = convertToFullAddr(laddr) 231 lfa = &addr 232 } 233 if raddr.IsValid() || raddr.Port() > 0 { 234 var addr tcpip.FullAddress 235 addr, pn = convertToFullAddr(raddr) 236 rfa = &addr 237 } 238 return gonet.DialUDP(net.stack, lfa, rfa, pn) 239 } 240 241 func (net *Net) ListenUDPAddrPort(laddr netip.AddrPort) (*gonet.UDPConn, error) { 242 return net.DialUDPAddrPort(laddr, netip.AddrPort{}) 243 } 244 245 func (net *Net) DialUDP(laddr, raddr *net.UDPAddr) (*gonet.UDPConn, error) { 246 var la, ra netip.AddrPort 247 if laddr != nil { 248 ip, _ := netip.AddrFromSlice(laddr.IP) 249 la = netip.AddrPortFrom(ip, uint16(laddr.Port)) 250 } 251 if raddr != nil { 252 ip, _ := netip.AddrFromSlice(raddr.IP) 253 ra = netip.AddrPortFrom(ip, uint16(raddr.Port)) 254 } 255 return net.DialUDPAddrPort(la, ra) 256 } 257 258 func (net *Net) ListenUDP(laddr *net.UDPAddr) (*gonet.UDPConn, error) { 259 return net.DialUDP(laddr, nil) 260 } 261 262 func (n *Net) HasV4() bool { 263 return n.hasV4 264 } 265 266 func (n *Net) HasV6() bool { 267 return n.hasV6 268 } 269 270 func IsDomainName(s string) bool { 271 l := len(s) 272 if l == 0 || l > 254 || l == 254 && s[l-1] != '.' { 273 return false 274 } 275 last := byte('.') 276 nonNumeric := false 277 partlen := 0 278 for i := 0; i < len(s); i++ { 279 c := s[i] 280 switch { 281 default: 282 return false 283 case 'a' <= c && c <= 'z' || 'A' <= c && c <= 'Z' || c == '_': 284 nonNumeric = true 285 partlen++ 286 case '0' <= c && c <= '9': 287 partlen++ 288 case c == '-': 289 if last == '.' { 290 return false 291 } 292 partlen++ 293 nonNumeric = true 294 case c == '.': 295 if last == '.' || last == '-' { 296 return false 297 } 298 if partlen > 63 || partlen == 0 { 299 return false 300 } 301 partlen = 0 302 } 303 last = c 304 } 305 if last == '-' || partlen > 63 { 306 return false 307 } 308 return nonNumeric 309 }