github.com/xmplusdev/xray-core@v1.8.10/proxy/wireguard/tun.go (about) 1 package wireguard 2 3 import ( 4 "context" 5 "errors" 6 "fmt" 7 "net" 8 "net/netip" 9 "runtime" 10 "strconv" 11 "strings" 12 "sync" 13 "time" 14 15 "github.com/xmplusdev/xray-core/common/log" 16 xnet "github.com/xmplusdev/xray-core/common/net" 17 "github.com/xmplusdev/xray-core/proxy/wireguard/gvisortun" 18 "gvisor.dev/gvisor/pkg/tcpip" 19 "gvisor.dev/gvisor/pkg/tcpip/adapters/gonet" 20 "gvisor.dev/gvisor/pkg/tcpip/transport/tcp" 21 "gvisor.dev/gvisor/pkg/tcpip/transport/udp" 22 "gvisor.dev/gvisor/pkg/waiter" 23 24 "golang.zx2c4.com/wireguard/conn" 25 "golang.zx2c4.com/wireguard/device" 26 "golang.zx2c4.com/wireguard/tun" 27 ) 28 29 type tunCreator func(localAddresses []netip.Addr, mtu int, handler promiscuousModeHandler) (Tunnel, error) 30 31 type promiscuousModeHandler func(dest xnet.Destination, conn net.Conn) 32 33 type Tunnel interface { 34 BuildDevice(ipc string, bind conn.Bind) error 35 DialContextTCPAddrPort(ctx context.Context, addr netip.AddrPort) (net.Conn, error) 36 DialUDPAddrPort(laddr, raddr netip.AddrPort) (net.Conn, error) 37 Close() error 38 } 39 40 type tunnel struct { 41 tun tun.Device 42 device *device.Device 43 rw sync.Mutex 44 } 45 46 func (t *tunnel) BuildDevice(ipc string, bind conn.Bind) (err error) { 47 t.rw.Lock() 48 defer t.rw.Unlock() 49 50 if t.device != nil { 51 return errors.New("device is already initialized") 52 } 53 54 logger := &device.Logger{ 55 Verbosef: func(format string, args ...any) { 56 log.Record(&log.GeneralMessage{ 57 Severity: log.Severity_Debug, 58 Content: fmt.Sprintf(format, args...), 59 }) 60 }, 61 Errorf: func(format string, args ...any) { 62 log.Record(&log.GeneralMessage{ 63 Severity: log.Severity_Error, 64 Content: fmt.Sprintf(format, args...), 65 }) 66 }, 67 } 68 69 t.device = device.NewDevice(t.tun, bind, logger) 70 if err = t.device.IpcSet(ipc); err != nil { 71 return err 72 } 73 if err = t.device.Up(); err != nil { 74 return err 75 } 76 return nil 77 } 78 79 func (t *tunnel) Close() (err error) { 80 t.rw.Lock() 81 defer t.rw.Unlock() 82 83 if t.device == nil { 84 return nil 85 } 86 87 t.device.Close() 88 t.device = nil 89 err = t.tun.Close() 90 t.tun = nil 91 return nil 92 } 93 94 func CalculateInterfaceName(name string) (tunName string) { 95 if runtime.GOOS == "darwin" { 96 tunName = "utun" 97 } else if name != "" { 98 tunName = name 99 } else { 100 tunName = "tun" 101 } 102 interfaces, err := net.Interfaces() 103 if err != nil { 104 return 105 } 106 var tunIndex int 107 for _, netInterface := range interfaces { 108 if strings.HasPrefix(netInterface.Name, tunName) { 109 index, parseErr := strconv.ParseInt(netInterface.Name[len(tunName):], 10, 16) 110 if parseErr == nil { 111 tunIndex = int(index) + 1 112 } 113 } 114 } 115 tunName = fmt.Sprintf("%s%d", tunName, tunIndex) 116 return 117 } 118 119 var _ Tunnel = (*gvisorNet)(nil) 120 121 type gvisorNet struct { 122 tunnel 123 net *gvisortun.Net 124 } 125 126 func (g *gvisorNet) Close() error { 127 return g.tunnel.Close() 128 } 129 130 func (g *gvisorNet) DialContextTCPAddrPort(ctx context.Context, addr netip.AddrPort) ( 131 net.Conn, error, 132 ) { 133 return g.net.DialContextTCPAddrPort(ctx, addr) 134 } 135 136 func (g *gvisorNet) DialUDPAddrPort(laddr, raddr netip.AddrPort) (net.Conn, error) { 137 return g.net.DialUDPAddrPort(laddr, raddr) 138 } 139 140 func createGVisorTun(localAddresses []netip.Addr, mtu int, handler promiscuousModeHandler) (Tunnel, error) { 141 out := &gvisorNet{} 142 tun, n, stack, err := gvisortun.CreateNetTUN(localAddresses, mtu, handler != nil) 143 if err != nil { 144 return nil, err 145 } 146 147 if handler != nil { 148 // handler is only used for promiscuous mode 149 // capture all packets and send to handler 150 151 tcpForwarder := tcp.NewForwarder(stack, 0, 65535, func(r *tcp.ForwarderRequest) { 152 go func(r *tcp.ForwarderRequest) { 153 var ( 154 wq waiter.Queue 155 id = r.ID() 156 ) 157 158 // Perform a TCP three-way handshake. 159 ep, err := r.CreateEndpoint(&wq) 160 if err != nil { 161 newError(err.String()).AtError().WriteToLog() 162 r.Complete(true) 163 return 164 } 165 r.Complete(false) 166 defer ep.Close() 167 168 // enable tcp keep-alive to prevent hanging connections 169 ep.SocketOptions().SetKeepAlive(true) 170 171 // local address is actually destination 172 handler(xnet.TCPDestination(xnet.IPAddress(id.LocalAddress.AsSlice()), xnet.Port(id.LocalPort)), gonet.NewTCPConn(&wq, ep)) 173 }(r) 174 }) 175 stack.SetTransportProtocolHandler(tcp.ProtocolNumber, tcpForwarder.HandlePacket) 176 177 udpForwarder := udp.NewForwarder(stack, func(r *udp.ForwarderRequest) { 178 go func(r *udp.ForwarderRequest) { 179 var ( 180 wq waiter.Queue 181 id = r.ID() 182 ) 183 184 ep, err := r.CreateEndpoint(&wq) 185 if err != nil { 186 newError(err.String()).AtError().WriteToLog() 187 return 188 } 189 defer ep.Close() 190 191 // prevents hanging connections and ensure timely release 192 ep.SocketOptions().SetLinger(tcpip.LingerOption{ 193 Enabled: true, 194 Timeout: 15 * time.Second, 195 }) 196 197 handler(xnet.UDPDestination(xnet.IPAddress(id.LocalAddress.AsSlice()), xnet.Port(id.LocalPort)), gonet.NewUDPConn(stack, &wq, ep)) 198 }(r) 199 }) 200 stack.SetTransportProtocolHandler(udp.ProtocolNumber, udpForwarder.HandlePacket) 201 } 202 203 out.tun, out.net = tun, n 204 return out, nil 205 }