/*

Some of codes are copied from https://github.com/octeep/wireproxy, license below.

Copyright (c) 2022 Wind T.F. Wong <octeep@pm.me>

Permission to use, copy, modify, and distribute this software for any
purpose with or without fee is hereby granted, provided that the above
copyright notice and this permission notice appear in all copies.

THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR
ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.

*/

package wireguard

import (
	"bytes"
	"context"
	"fmt"
	"net/netip"
	"strings"
	"sync"

	"github.com/sagernet/wireguard-go/device"
	"github.com/xraypb/Xray-core/common"
	"github.com/xraypb/Xray-core/common/buf"
	"github.com/xraypb/Xray-core/common/log"
	"github.com/xraypb/Xray-core/common/net"
	"github.com/xraypb/Xray-core/common/protocol"
	"github.com/xraypb/Xray-core/common/session"
	"github.com/xraypb/Xray-core/common/signal"
	"github.com/xraypb/Xray-core/common/task"
	"github.com/xraypb/Xray-core/core"
	"github.com/xraypb/Xray-core/features/dns"
	"github.com/xraypb/Xray-core/features/policy"
	"github.com/xraypb/Xray-core/transport"
	"github.com/xraypb/Xray-core/transport/internet"
)
47 type Handler struct { 48 conf *DeviceConfig 49 net *Net 50 bind *netBindClient 51 policyManager policy.Manager 52 dns dns.Client 53 // cached configuration 54 ipc string 55 endpoints []netip.Addr 56 } 57 58 // New creates a new wireguard handler. 59 func New(ctx context.Context, conf *DeviceConfig) (*Handler, error) { 60 v := core.MustFromContext(ctx) 61 62 endpoints, err := parseEndpoints(conf) 63 if err != nil { 64 return nil, err 65 } 66 67 return &Handler{ 68 conf: conf, 69 policyManager: v.GetFeature(policy.ManagerType()).(policy.Manager), 70 dns: v.GetFeature(dns.ClientType()).(dns.Client), 71 ipc: createIPCRequest(conf), 72 endpoints: endpoints, 73 }, nil 74 } 75 76 // Process implements OutboundHandler.Dispatch(). 77 func (h *Handler) Process(ctx context.Context, link *transport.Link, dialer internet.Dialer) error { 78 if h.bind == nil || h.bind.dialer != dialer || h.net == nil { 79 log.Record(&log.GeneralMessage{ 80 Severity: log.Severity_Info, 81 Content: "switching dialer", 82 }) 83 // bind := conn.NewStdNetBind() // TODO: conn.Bind wrapper for dialer 84 bind := &netBindClient{ 85 dialer: dialer, 86 workers: int(h.conf.NumWorkers), 87 dns: h.dns, 88 reserved: h.conf.Reserved, 89 } 90 91 net, err := h.makeVirtualTun(bind) 92 if err != nil { 93 bind.Close() 94 return newError("failed to create virtual tun interface").Base(err) 95 } 96 97 h.net = net 98 if h.bind != nil { 99 h.bind.Close() 100 } 101 h.bind = bind 102 } 103 104 outbound := session.OutboundFromContext(ctx) 105 if outbound == nil || !outbound.Target.IsValid() { 106 return newError("target not specified") 107 } 108 // Destination of the inner request. 
109 destination := outbound.Target 110 command := protocol.RequestCommandTCP 111 if destination.Network == net.Network_UDP { 112 command = protocol.RequestCommandUDP 113 } 114 115 // resolve dns 116 addr := destination.Address 117 if addr.Family().IsDomain() { 118 ips, err := h.dns.LookupIP(addr.Domain(), dns.IPOption{ 119 IPv4Enable: h.net.HasV4(), 120 IPv6Enable: h.net.HasV6(), 121 }) 122 if err != nil { 123 return newError("failed to lookup DNS").Base(err) 124 } else if len(ips) == 0 { 125 return dns.ErrEmptyResponse 126 } 127 addr = net.IPAddress(ips[0]) 128 } 129 130 var newCtx context.Context 131 var newCancel context.CancelFunc 132 if session.TimeoutOnlyFromContext(ctx) { 133 newCtx, newCancel = context.WithCancel(context.Background()) 134 } 135 136 p := h.policyManager.ForLevel(0) 137 138 ctx, cancel := context.WithCancel(ctx) 139 timer := signal.CancelAfterInactivity(ctx, func() { 140 cancel() 141 if newCancel != nil { 142 newCancel() 143 } 144 }, p.Timeouts.ConnectionIdle) 145 addrPort := netip.AddrPortFrom(toNetIpAddr(addr), destination.Port.Value()) 146 147 var requestFunc func() error 148 var responseFunc func() error 149 150 if command == protocol.RequestCommandTCP { 151 conn, err := h.net.DialContextTCPAddrPort(ctx, addrPort) 152 if err != nil { 153 return newError("failed to create TCP connection").Base(err) 154 } 155 156 requestFunc = func() error { 157 defer timer.SetTimeout(p.Timeouts.DownlinkOnly) 158 return buf.Copy(link.Reader, buf.NewWriter(conn), buf.UpdateActivity(timer)) 159 } 160 responseFunc = func() error { 161 defer timer.SetTimeout(p.Timeouts.UplinkOnly) 162 return buf.Copy(buf.NewReader(conn), link.Writer, buf.UpdateActivity(timer)) 163 } 164 } else if command == protocol.RequestCommandUDP { 165 conn, err := h.net.DialUDPAddrPort(netip.AddrPort{}, addrPort) 166 if err != nil { 167 return newError("failed to create UDP connection").Base(err) 168 } 169 170 requestFunc = func() error { 171 defer timer.SetTimeout(p.Timeouts.DownlinkOnly) 
172 return buf.Copy(link.Reader, buf.NewWriter(conn), buf.UpdateActivity(timer)) 173 } 174 responseFunc = func() error { 175 defer timer.SetTimeout(p.Timeouts.UplinkOnly) 176 return buf.Copy(buf.NewReader(conn), link.Writer, buf.UpdateActivity(timer)) 177 } 178 } 179 180 if newCtx != nil { 181 ctx = newCtx 182 } 183 184 responseDonePost := task.OnSuccess(responseFunc, task.Close(link.Writer)) 185 if err := task.Run(ctx, requestFunc, responseDonePost); err != nil { 186 return newError("connection ends").Base(err) 187 } 188 189 return nil 190 } 191 192 // serialize the config into an IPC request 193 func createIPCRequest(conf *DeviceConfig) string { 194 var request bytes.Buffer 195 196 request.WriteString(fmt.Sprintf("private_key=%s\n", conf.SecretKey)) 197 198 for _, peer := range conf.Peers { 199 request.WriteString(fmt.Sprintf("public_key=%s\nendpoint=%s\npersistent_keepalive_interval=%d\npreshared_key=%s\n", 200 peer.PublicKey, peer.Endpoint, peer.KeepAlive, peer.PreSharedKey)) 201 202 for _, ip := range peer.AllowedIps { 203 request.WriteString(fmt.Sprintf("allowed_ip=%s\n", ip)) 204 } 205 } 206 207 return request.String()[:request.Len()] 208 } 209 210 // convert endpoint string to netip.Addr 211 func parseEndpoints(conf *DeviceConfig) ([]netip.Addr, error) { 212 endpoints := make([]netip.Addr, len(conf.Endpoint)) 213 for i, str := range conf.Endpoint { 214 var addr netip.Addr 215 if strings.Contains(str, "/") { 216 prefix, err := netip.ParsePrefix(str) 217 if err != nil { 218 return nil, err 219 } 220 addr = prefix.Addr() 221 if prefix.Bits() != addr.BitLen() { 222 return nil, newError("interface address subnet should be /32 for IPv4 and /128 for IPv6") 223 } 224 } else { 225 var err error 226 addr, err = netip.ParseAddr(str) 227 if err != nil { 228 return nil, err 229 } 230 } 231 endpoints[i] = addr 232 } 233 234 return endpoints, nil 235 } 236 237 // creates a tun interface on netstack given a configuration 238 func (h *Handler) makeVirtualTun(bind 
*netBindClient) (*Net, error) { 239 tun, tnet, err := CreateNetTUN(h.endpoints, h.dns, int(h.conf.Mtu)) 240 if err != nil { 241 return nil, err 242 } 243 244 bind.dnsOption.IPv4Enable = tnet.HasV4() 245 bind.dnsOption.IPv6Enable = tnet.HasV6() 246 247 // dev := device.NewDevice(tun, conn.NewDefaultBind(), nil /* device.NewLogger(device.LogLevelVerbose, "") */) 248 dev := device.NewDevice(tun, bind, &device.Logger{ 249 Verbosef: func(format string, args ...any) { 250 log.Record(&log.GeneralMessage{ 251 Severity: log.Severity_Debug, 252 Content: fmt.Sprintf(format, args...), 253 }) 254 }, 255 Errorf: func(format string, args ...any) { 256 log.Record(&log.GeneralMessage{ 257 Severity: log.Severity_Error, 258 Content: fmt.Sprintf(format, args...), 259 }) 260 }, 261 }, int(h.conf.NumWorkers)) 262 err = dev.IpcSet(h.ipc) 263 if err != nil { 264 return nil, err 265 } 266 267 err = dev.Up() 268 if err != nil { 269 return nil, err 270 } 271 272 return tnet, nil 273 } 274 275 func init() { 276 common.Must(common.RegisterConfig((*DeviceConfig)(nil), func(ctx context.Context, config interface{}) (interface{}, error) { 277 return New(ctx, config.(*DeviceConfig)) 278 })) 279 }