github.com/xtls/xray-core@v1.8.12-0.20240518155711-3168d27b0bdb/proxy/wireguard/gvisortun/tun.go (about) 1 /* SPDX-License-Identifier: MIT 2 * 3 * Copyright (C) 2017-2022 WireGuard LLC. All Rights Reserved. 4 */ 5 6 package gvisortun 7 8 import ( 9 "context" 10 "fmt" 11 "net/netip" 12 "os" 13 "syscall" 14 15 "golang.zx2c4.com/wireguard/tun" 16 "gvisor.dev/gvisor/pkg/buffer" 17 "gvisor.dev/gvisor/pkg/tcpip" 18 "gvisor.dev/gvisor/pkg/tcpip/adapters/gonet" 19 "gvisor.dev/gvisor/pkg/tcpip/header" 20 "gvisor.dev/gvisor/pkg/tcpip/link/channel" 21 "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" 22 "gvisor.dev/gvisor/pkg/tcpip/network/ipv6" 23 "gvisor.dev/gvisor/pkg/tcpip/stack" 24 "gvisor.dev/gvisor/pkg/tcpip/transport/icmp" 25 "gvisor.dev/gvisor/pkg/tcpip/transport/tcp" 26 "gvisor.dev/gvisor/pkg/tcpip/transport/udp" 27 ) 28 29 type netTun struct { 30 ep *channel.Endpoint 31 stack *stack.Stack 32 events chan tun.Event 33 incomingPacket chan *buffer.View 34 mtu int 35 hasV4, hasV6 bool 36 } 37 38 type Net netTun 39 40 func CreateNetTUN(localAddresses []netip.Addr, mtu int, promiscuousMode bool) (tun.Device, *Net, *stack.Stack, error) { 41 opts := stack.Options{ 42 NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol}, 43 TransportProtocols: []stack.TransportProtocolFactory{tcp.NewProtocol, udp.NewProtocol, icmp.NewProtocol6, icmp.NewProtocol4}, 44 HandleLocal: !promiscuousMode, 45 } 46 dev := &netTun{ 47 ep: channel.New(1024, uint32(mtu), ""), 48 stack: stack.New(opts), 49 events: make(chan tun.Event, 1), 50 incomingPacket: make(chan *buffer.View), 51 mtu: mtu, 52 } 53 dev.ep.AddNotify(dev) 54 tcpipErr := dev.stack.CreateNIC(1, dev.ep) 55 if tcpipErr != nil { 56 return nil, nil, dev.stack, fmt.Errorf("CreateNIC: %v", tcpipErr) 57 } 58 for _, ip := range localAddresses { 59 var protoNumber tcpip.NetworkProtocolNumber 60 if ip.Is4() { 61 protoNumber = ipv4.ProtocolNumber 62 } else if ip.Is6() { 63 protoNumber = ipv6.ProtocolNumber 64 } 65 protoAddr := tcpip.ProtocolAddress{ 66 Protocol: protoNumber, 67 AddressWithPrefix: tcpip.AddrFromSlice(ip.AsSlice()).WithPrefix(), 68 } 69 tcpipErr := dev.stack.AddProtocolAddress(1, protoAddr, stack.AddressProperties{}) 70 if tcpipErr != nil { 71 return nil, nil, dev.stack, fmt.Errorf("AddProtocolAddress(%v): %v", ip, tcpipErr) 72 } 73 if ip.Is4() { 74 dev.hasV4 = true 75 } else if ip.Is6() { 76 dev.hasV6 = true 77 } 78 } 79 if dev.hasV4 { 80 dev.stack.AddRoute(tcpip.Route{Destination: header.IPv4EmptySubnet, NIC: 1}) 81 } 82 if dev.hasV6 { 83 dev.stack.AddRoute(tcpip.Route{Destination: header.IPv6EmptySubnet, NIC: 1}) 84 } 85 if promiscuousMode { 86 // enable promiscuous mode to handle all packets processed by netstack 87 dev.stack.SetPromiscuousMode(1, true) 88 dev.stack.SetSpoofing(1, true) 89 } 90 91 opt := tcpip.CongestionControlOption("cubic") 92 if err := dev.stack.SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil { 93 return nil, nil, dev.stack, fmt.Errorf("SetTransportProtocolOption(%d, &%T(%s)): %s", tcp.ProtocolNumber, opt, opt, err) 94 } 95 96 dev.events <- tun.EventUp 97 return dev, (*Net)(dev), dev.stack, nil 98 } 99 100 // BatchSize implements tun.Device 101 func (tun *netTun) BatchSize() int { 102 return 1 103 } 104 105 // Name implements tun.Device 106 func (tun *netTun) Name() (string, error) { 107 return "go", nil 108 } 109 110 // File implements tun.Device 111 func (tun *netTun) File() *os.File { 112 return nil 113 } 114 115 // Events implements tun.Device 116 func (tun *netTun) Events() <-chan tun.Event { 117 return tun.events 118 } 119 120 // Read implements tun.Device 121 122 func (tun *netTun) Read(buf [][]byte, sizes []int, offset int) (int, error) { 123 view, ok := <-tun.incomingPacket 124 if !ok { 125 return 0, os.ErrClosed 126 } 127 128 n, err := view.Read(buf[0][offset:]) 129 if err != nil { 130 return 0, err 131 } 132 sizes[0] = n 133 return 1, nil 134 } 135 136 // Write implements tun.Device 137 func (tun *netTun) Write(buf [][]byte, offset int) (int, error) { 138 for _, buf := range buf { 139 packet := buf[offset:] 140 if len(packet) == 0 { 141 continue 142 } 143 144 pkb := stack.NewPacketBuffer(stack.PacketBufferOptions{Payload: buffer.MakeWithData(packet)}) 145 switch packet[0] >> 4 { 146 case 4: 147 tun.ep.InjectInbound(header.IPv4ProtocolNumber, pkb) 148 case 6: 149 tun.ep.InjectInbound(header.IPv6ProtocolNumber, pkb) 150 default: 151 return 0, syscall.EAFNOSUPPORT 152 } 153 } 154 return len(buf), nil 155 } 156 157 // WriteNotify implements channel.Notification 158 func (tun *netTun) WriteNotify() { 159 pkt := tun.ep.Read() 160 if pkt.IsNil() { 161 return 162 } 163 164 view := pkt.ToView() 165 pkt.DecRef() 166 167 tun.incomingPacket <- view 168 } 169 170 // Flush implements tun.Device 171 func (tun *netTun) Flush() error { 172 return nil 173 } 174 175 // Close implements tun.Device 176 func (tun *netTun) Close() error { 177 tun.stack.RemoveNIC(1) 178 179 if tun.events != nil { 180 close(tun.events) 181 } 182 183 tun.ep.Close() 184 185 if tun.incomingPacket != nil { 186 close(tun.incomingPacket) 187 } 188 189 return nil 190 } 191 192 // MTU implements tun.Device 193 func (tun *netTun) MTU() (int, error) { 194 return tun.mtu, nil 195 } 196 197 func convertToFullAddr(endpoint netip.AddrPort) (tcpip.FullAddress, tcpip.NetworkProtocolNumber) { 198 var protoNumber tcpip.NetworkProtocolNumber 199 if endpoint.Addr().Is4() { 200 protoNumber = ipv4.ProtocolNumber 201 } else { 202 protoNumber = ipv6.ProtocolNumber 203 } 204 return tcpip.FullAddress{ 205 NIC: 1, 206 Addr: tcpip.AddrFromSlice(endpoint.Addr().AsSlice()), 207 Port: endpoint.Port(), 208 }, protoNumber 209 } 210 211 func (net *Net) DialContextTCPAddrPort(ctx context.Context, addr netip.AddrPort) (*gonet.TCPConn, error) { 212 fa, pn := convertToFullAddr(addr) 213 return gonet.DialContextTCP(ctx, net.stack, fa, pn) 214 } 215 216 func (net *Net) DialUDPAddrPort(laddr, raddr netip.AddrPort) (*gonet.UDPConn, error) { 217 var lfa, rfa *tcpip.FullAddress 218 var pn tcpip.NetworkProtocolNumber 219 if laddr.IsValid() || laddr.Port() > 0 { 220 var addr tcpip.FullAddress 221 addr, pn = convertToFullAddr(laddr) 222 lfa = &addr 223 } 224 if raddr.IsValid() || raddr.Port() > 0 { 225 var addr tcpip.FullAddress 226 addr, pn = convertToFullAddr(raddr) 227 rfa = &addr 228 } 229 return gonet.DialUDP(net.stack, lfa, rfa, pn) 230 }