github.com/xtls/xray-core@v1.8.12-0.20240518155711-3168d27b0bdb/proxy/wireguard/server.go (about) 1 package wireguard 2 3 import ( 4 "context" 5 "errors" 6 "io" 7 8 "github.com/xtls/xray-core/common" 9 "github.com/xtls/xray-core/common/buf" 10 "github.com/xtls/xray-core/common/log" 11 "github.com/xtls/xray-core/common/net" 12 "github.com/xtls/xray-core/common/session" 13 "github.com/xtls/xray-core/common/signal" 14 "github.com/xtls/xray-core/common/task" 15 "github.com/xtls/xray-core/core" 16 "github.com/xtls/xray-core/features/dns" 17 "github.com/xtls/xray-core/features/policy" 18 "github.com/xtls/xray-core/features/routing" 19 "github.com/xtls/xray-core/transport/internet/stat" 20 ) 21 22 var nullDestination = net.TCPDestination(net.AnyIP, 0) 23 24 type Server struct { 25 bindServer *netBindServer 26 27 info routingInfo 28 policyManager policy.Manager 29 } 30 31 type routingInfo struct { 32 ctx context.Context 33 dispatcher routing.Dispatcher 34 inboundTag *session.Inbound 35 outboundTag *session.Outbound 36 contentTag *session.Content 37 } 38 39 func NewServer(ctx context.Context, conf *DeviceConfig) (*Server, error) { 40 v := core.MustFromContext(ctx) 41 42 endpoints, hasIPv4, hasIPv6, err := parseEndpoints(conf) 43 if err != nil { 44 return nil, err 45 } 46 47 server := &Server{ 48 bindServer: &netBindServer{ 49 netBind: netBind{ 50 dns: v.GetFeature(dns.ClientType()).(dns.Client), 51 dnsOption: dns.IPOption{ 52 IPv4Enable: hasIPv4, 53 IPv6Enable: hasIPv6, 54 }, 55 }, 56 }, 57 policyManager: v.GetFeature(policy.ManagerType()).(policy.Manager), 58 } 59 60 tun, err := conf.createTun()(endpoints, int(conf.Mtu), server.forwardConnection) 61 if err != nil { 62 return nil, err 63 } 64 65 if err = tun.BuildDevice(createIPCRequest(conf), server.bindServer); err != nil { 66 _ = tun.Close() 67 return nil, err 68 } 69 70 return server, nil 71 } 72 73 // Network implements proxy.Inbound. 74 func (*Server) Network() []net.Network { 75 return []net.Network{net.Network_UDP} 76 } 77 78 // Process implements proxy.Inbound. 79 func (s *Server) Process(ctx context.Context, network net.Network, conn stat.Connection, dispatcher routing.Dispatcher) error { 80 inbound := session.InboundFromContext(ctx) 81 inbound.Name = "wireguard" 82 inbound.CanSpliceCopy = 3 83 outbounds := session.OutboundsFromContext(ctx) 84 ob := outbounds[len(outbounds) - 1] 85 86 s.info = routingInfo{ 87 ctx: core.ToBackgroundDetachedContext(ctx), 88 dispatcher: dispatcher, 89 inboundTag: session.InboundFromContext(ctx), 90 outboundTag: ob, 91 contentTag: session.ContentFromContext(ctx), 92 } 93 94 ep, err := s.bindServer.ParseEndpoint(conn.RemoteAddr().String()) 95 if err != nil { 96 return err 97 } 98 99 nep := ep.(*netEndpoint) 100 nep.conn = conn 101 102 reader := buf.NewPacketReader(conn) 103 for { 104 mpayload, err := reader.ReadMultiBuffer() 105 if err != nil { 106 return err 107 } 108 109 for _, payload := range mpayload { 110 v, ok := <-s.bindServer.readQueue 111 if !ok { 112 return nil 113 } 114 i, err := payload.Read(v.buff) 115 116 v.bytes = i 117 v.endpoint = nep 118 v.err = err 119 v.waiter.Done() 120 if err != nil && errors.Is(err, io.EOF) { 121 nep.conn = nil 122 return nil 123 } 124 } 125 } 126 } 127 128 func (s *Server) forwardConnection(dest net.Destination, conn net.Conn) { 129 if s.info.dispatcher == nil { 130 newError("unexpected: dispatcher == nil").AtError().WriteToLog() 131 return 132 } 133 defer conn.Close() 134 135 ctx, cancel := context.WithCancel(core.ToBackgroundDetachedContext(s.info.ctx)) 136 plcy := s.policyManager.ForLevel(0) 137 timer := signal.CancelAfterInactivity(ctx, cancel, plcy.Timeouts.ConnectionIdle) 138 139 ctx = log.ContextWithAccessMessage(ctx, &log.AccessMessage{ 140 From: nullDestination, 141 To: dest, 142 Status: log.AccessAccepted, 143 Reason: "", 144 }) 145 146 if s.info.inboundTag != nil { 147 ctx = session.ContextWithInbound(ctx, s.info.inboundTag) 148 } 149 if s.info.outboundTag != nil { 150 ctx = session.ContextWithOutbounds(ctx, []*session.Outbound{s.info.outboundTag}) 151 } 152 if s.info.contentTag != nil { 153 ctx = session.ContextWithContent(ctx, s.info.contentTag) 154 } 155 156 link, err := s.info.dispatcher.Dispatch(ctx, dest) 157 if err != nil { 158 newError("dispatch connection").Base(err).AtError().WriteToLog() 159 } 160 defer cancel() 161 162 requestDone := func() error { 163 defer timer.SetTimeout(plcy.Timeouts.DownlinkOnly) 164 if err := buf.Copy(buf.NewReader(conn), link.Writer, buf.UpdateActivity(timer)); err != nil { 165 return newError("failed to transport all TCP request").Base(err) 166 } 167 168 return nil 169 } 170 171 responseDone := func() error { 172 defer timer.SetTimeout(plcy.Timeouts.UplinkOnly) 173 if err := buf.Copy(link.Reader, buf.NewWriter(conn), buf.UpdateActivity(timer)); err != nil { 174 return newError("failed to transport all TCP response").Base(err) 175 } 176 177 return nil 178 } 179 180 requestDonePost := task.OnSuccess(requestDone, task.Close(link.Writer)) 181 if err := task.Run(ctx, requestDonePost, responseDone); err != nil { 182 common.Interrupt(link.Reader) 183 common.Interrupt(link.Writer) 184 newError("connection ends").Base(err).AtDebug().WriteToLog() 185 return 186 } 187 }