github.com/moqsien/xraycore@v1.8.5/proxy/wireguard/wireguard.go (about) 1 /* 2 3 Some of codes are copied from https://github.com/octeep/wireproxy, license below. 4 5 Copyright (c) 2022 Wind T.F. Wong <octeep@pm.me> 6 7 Permission to use, copy, modify, and distribute this software for any 8 purpose with or without fee is hereby granted, provided that the above 9 copyright notice and this permission notice appear in all copies. 10 11 THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES 12 WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF 13 MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR 14 ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES 15 WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN 16 ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF 17 OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. 18 19 */ 20 21 package wireguard 22 23 import ( 24 "bytes" 25 "context" 26 "fmt" 27 "net/netip" 28 "strings" "sync" 29 30 "github.com/moqsien/xraycore/common" 31 "github.com/moqsien/xraycore/common/buf" 32 "github.com/moqsien/xraycore/common/log" 33 "github.com/moqsien/xraycore/common/net" 34 "github.com/moqsien/xraycore/common/protocol" 35 "github.com/moqsien/xraycore/common/session" 36 "github.com/moqsien/xraycore/common/signal" 37 "github.com/moqsien/xraycore/common/task" 38 "github.com/moqsien/xraycore/core" 39 "github.com/moqsien/xraycore/features/dns" 40 "github.com/moqsien/xraycore/features/policy" 41 "github.com/moqsien/xraycore/transport" 42 "github.com/moqsien/xraycore/transport/internet" 43 "github.com/sagernet/wireguard-go/device" 44 ) 45 46 // Handler is an outbound connection that silently swallow the entire payload.
47 type Handler struct { 48 conf *DeviceConfig 49 net *Net 50 bind *netBindClient 51 policyManager policy.Manager 52 dns dns.Client 53 // cached configuration 54 ipc string 55 endpoints []netip.Addr 56 } 57 58 // New creates a new wireguard handler. 59 func New(ctx context.Context, conf *DeviceConfig) (*Handler, error) { 60 v := core.MustFromContext(ctx) 61 62 endpoints, err := parseEndpoints(conf) 63 if err != nil { 64 return nil, err 65 } 66 67 return &Handler{ 68 conf: conf, 69 policyManager: v.GetFeature(policy.ManagerType()).(policy.Manager), 70 dns: v.GetFeature(dns.ClientType()).(dns.Client), 71 ipc: createIPCRequest(conf), 72 endpoints: endpoints, 73 }, nil 74 } 75 76 // Process implements OutboundHandler.Dispatch(). 77 func (h *Handler) Process(ctx context.Context, link *transport.Link, dialer internet.Dialer) error { 78 if h.bind == nil || h.bind.dialer != dialer || h.net == nil { 79 log.Record(&log.GeneralMessage{ 80 Severity: log.Severity_Info, 81 Content: "switching dialer", 82 }) 83 // bind := conn.NewStdNetBind() // TODO: conn.Bind wrapper for dialer 84 bind := &netBindClient{ 85 dialer: dialer, 86 workers: int(h.conf.NumWorkers), 87 dns: h.dns, 88 reserved: h.conf.Reserved, 89 } 90 91 net, err := h.makeVirtualTun(bind) 92 if err != nil { 93 bind.Close() 94 return newError("failed to create virtual tun interface").Base(err) 95 } 96 97 h.net = net 98 if h.bind != nil { 99 h.bind.Close() 100 } 101 h.bind = bind 102 } 103 104 outbound := session.OutboundFromContext(ctx) 105 if outbound == nil || !outbound.Target.IsValid() { 106 return newError("target not specified") 107 } 108 // Destination of the inner request. 
109 destination := outbound.Target 110 command := protocol.RequestCommandTCP 111 if destination.Network == net.Network_UDP { 112 command = protocol.RequestCommandUDP 113 } 114 115 // resolve dns 116 addr := destination.Address 117 if addr.Family().IsDomain() { 118 ips, err := h.dns.LookupIP(addr.Domain(), dns.IPOption{ 119 IPv4Enable: h.net.HasV4(), 120 IPv6Enable: h.net.HasV6(), 121 }) 122 if err != nil { 123 return newError("failed to lookup DNS").Base(err) 124 } else if len(ips) == 0 { 125 return dns.ErrEmptyResponse 126 } 127 addr = net.IPAddress(ips[0]) 128 } 129 130 var newCtx context.Context 131 var newCancel context.CancelFunc 132 if session.TimeoutOnlyFromContext(ctx) { 133 newCtx, newCancel = context.WithCancel(context.Background()) 134 } 135 136 p := h.policyManager.ForLevel(0) 137 138 ctx, cancel := context.WithCancel(ctx) 139 timer := signal.CancelAfterInactivity(ctx, func() { 140 cancel() 141 if newCancel != nil { 142 newCancel() 143 } 144 }, p.Timeouts.ConnectionIdle) 145 addrPort := netip.AddrPortFrom(toNetIpAddr(addr), destination.Port.Value()) 146 147 var requestFunc func() error 148 var responseFunc func() error 149 150 if command == protocol.RequestCommandTCP { 151 conn, err := h.net.DialContextTCPAddrPort(ctx, addrPort) 152 if err != nil { 153 return newError("failed to create TCP connection").Base(err) 154 } 155 defer conn.Close() 156 157 requestFunc = func() error { 158 defer timer.SetTimeout(p.Timeouts.DownlinkOnly) 159 return buf.Copy(link.Reader, buf.NewWriter(conn), buf.UpdateActivity(timer)) 160 } 161 responseFunc = func() error { 162 defer timer.SetTimeout(p.Timeouts.UplinkOnly) 163 return buf.Copy(buf.NewReader(conn), link.Writer, buf.UpdateActivity(timer)) 164 } 165 } else if command == protocol.RequestCommandUDP { 166 conn, err := h.net.DialUDPAddrPort(netip.AddrPort{}, addrPort) 167 if err != nil { 168 return newError("failed to create UDP connection").Base(err) 169 } 170 defer conn.Close() 171 172 requestFunc = func() error { 173 
defer timer.SetTimeout(p.Timeouts.DownlinkOnly) 174 return buf.Copy(link.Reader, buf.NewWriter(conn), buf.UpdateActivity(timer)) 175 } 176 responseFunc = func() error { 177 defer timer.SetTimeout(p.Timeouts.UplinkOnly) 178 return buf.Copy(buf.NewReader(conn), link.Writer, buf.UpdateActivity(timer)) 179 } 180 } 181 182 if newCtx != nil { 183 ctx = newCtx 184 } 185 186 responseDonePost := task.OnSuccess(responseFunc, task.Close(link.Writer)) 187 if err := task.Run(ctx, requestFunc, responseDonePost); err != nil { 188 common.Interrupt(link.Reader) 189 common.Interrupt(link.Writer) 190 return newError("connection ends").Base(err) 191 } 192 193 return nil 194 } 195 196 // serialize the config into an IPC request 197 func createIPCRequest(conf *DeviceConfig) string { 198 var request bytes.Buffer 199 200 request.WriteString(fmt.Sprintf("private_key=%s\n", conf.SecretKey)) 201 202 for _, peer := range conf.Peers { 203 request.WriteString(fmt.Sprintf("public_key=%s\nendpoint=%s\npersistent_keepalive_interval=%d\npreshared_key=%s\n", 204 peer.PublicKey, peer.Endpoint, peer.KeepAlive, peer.PreSharedKey)) 205 206 for _, ip := range peer.AllowedIps { 207 request.WriteString(fmt.Sprintf("allowed_ip=%s\n", ip)) 208 } 209 } 210 211 return request.String()[:request.Len()] 212 } 213 214 // convert endpoint string to netip.Addr 215 func parseEndpoints(conf *DeviceConfig) ([]netip.Addr, error) { 216 endpoints := make([]netip.Addr, len(conf.Endpoint)) 217 for i, str := range conf.Endpoint { 218 var addr netip.Addr 219 if strings.Contains(str, "/") { 220 prefix, err := netip.ParsePrefix(str) 221 if err != nil { 222 return nil, err 223 } 224 addr = prefix.Addr() 225 if prefix.Bits() != addr.BitLen() { 226 return nil, newError("interface address subnet should be /32 for IPv4 and /128 for IPv6") 227 } 228 } else { 229 var err error 230 addr, err = netip.ParseAddr(str) 231 if err != nil { 232 return nil, err 233 } 234 } 235 endpoints[i] = addr 236 } 237 238 return endpoints, nil 239 } 240 
241 // creates a tun interface on netstack given a configuration 242 func (h *Handler) makeVirtualTun(bind *netBindClient) (*Net, error) { 243 tun, tnet, err := CreateNetTUN(h.endpoints, h.dns, int(h.conf.Mtu)) 244 if err != nil { 245 return nil, err 246 } 247 248 bind.dnsOption.IPv4Enable = tnet.HasV4() 249 bind.dnsOption.IPv6Enable = tnet.HasV6() 250 251 // dev := device.NewDevice(tun, conn.NewDefaultBind(), nil /* device.NewLogger(device.LogLevelVerbose, "") */) 252 dev := device.NewDevice(context.Background(), tun, bind, &device.Logger{ 253 Verbosef: func(format string, args ...any) { 254 log.Record(&log.GeneralMessage{ 255 Severity: log.Severity_Debug, 256 Content: fmt.Sprintf(format, args...), 257 }) 258 }, 259 Errorf: func(format string, args ...any) { 260 log.Record(&log.GeneralMessage{ 261 Severity: log.Severity_Error, 262 Content: fmt.Sprintf(format, args...), 263 }) 264 }, 265 }, int(h.conf.NumWorkers)) 266 err = dev.IpcSet(h.ipc) 267 if err != nil { 268 return nil, err 269 } 270 271 err = dev.Up() 272 if err != nil { 273 return nil, err 274 } 275 276 return tnet, nil 277 } 278 279 func init() { 280 common.Must(common.RegisterConfig((*DeviceConfig)(nil), func(ctx context.Context, config interface{}) (interface{}, error) { 281 return New(ctx, config.(*DeviceConfig)) 282 })) 283 }