github.com/xmplusdev/xray-core@v1.8.10/proxy/wireguard/server.go (about) 1 package wireguard 2 3 import ( 4 "context" 5 "errors" 6 "io" 7 8 "github.com/xmplusdev/xray-core/common" 9 "github.com/xmplusdev/xray-core/common/buf" 10 "github.com/xmplusdev/xray-core/common/log" 11 "github.com/xmplusdev/xray-core/common/net" 12 "github.com/xmplusdev/xray-core/common/session" 13 "github.com/xmplusdev/xray-core/common/signal" 14 "github.com/xmplusdev/xray-core/common/task" 15 "github.com/xmplusdev/xray-core/core" 16 "github.com/xmplusdev/xray-core/features/dns" 17 "github.com/xmplusdev/xray-core/features/policy" 18 "github.com/xmplusdev/xray-core/features/routing" 19 "github.com/xmplusdev/xray-core/transport/internet/stat" 20 ) 21 22 var nullDestination = net.TCPDestination(net.AnyIP, 0) 23 24 type Server struct { 25 bindServer *netBindServer 26 27 info routingInfo 28 policyManager policy.Manager 29 } 30 31 type routingInfo struct { 32 ctx context.Context 33 dispatcher routing.Dispatcher 34 inboundTag *session.Inbound 35 outboundTag *session.Outbound 36 contentTag *session.Content 37 } 38 39 func NewServer(ctx context.Context, conf *DeviceConfig) (*Server, error) { 40 v := core.MustFromContext(ctx) 41 42 endpoints, hasIPv4, hasIPv6, err := parseEndpoints(conf) 43 if err != nil { 44 return nil, err 45 } 46 47 server := &Server{ 48 bindServer: &netBindServer{ 49 netBind: netBind{ 50 dns: v.GetFeature(dns.ClientType()).(dns.Client), 51 dnsOption: dns.IPOption{ 52 IPv4Enable: hasIPv4, 53 IPv6Enable: hasIPv6, 54 }, 55 }, 56 }, 57 policyManager: v.GetFeature(policy.ManagerType()).(policy.Manager), 58 } 59 60 tun, err := conf.createTun()(endpoints, int(conf.Mtu), server.forwardConnection) 61 if err != nil { 62 return nil, err 63 } 64 65 if err = tun.BuildDevice(createIPCRequest(conf), server.bindServer); err != nil { 66 _ = tun.Close() 67 return nil, err 68 } 69 70 return server, nil 71 } 72 73 // Network implements proxy.Inbound. 74 func (*Server) Network() []net.Network { 75 return []net.Network{net.Network_UDP} 76 } 77 78 // Process implements proxy.Inbound. 79 func (s *Server) Process(ctx context.Context, network net.Network, conn stat.Connection, dispatcher routing.Dispatcher) error { 80 inbound := session.InboundFromContext(ctx) 81 inbound.Name = "wireguard" 82 inbound.SetCanSpliceCopy(3) 83 84 s.info = routingInfo{ 85 ctx: core.ToBackgroundDetachedContext(ctx), 86 dispatcher: dispatcher, 87 inboundTag: session.InboundFromContext(ctx), 88 outboundTag: session.OutboundFromContext(ctx), 89 contentTag: session.ContentFromContext(ctx), 90 } 91 92 ep, err := s.bindServer.ParseEndpoint(conn.RemoteAddr().String()) 93 if err != nil { 94 return err 95 } 96 97 nep := ep.(*netEndpoint) 98 nep.conn = conn 99 100 reader := buf.NewPacketReader(conn) 101 for { 102 mpayload, err := reader.ReadMultiBuffer() 103 if err != nil { 104 return err 105 } 106 107 for _, payload := range mpayload { 108 v, ok := <-s.bindServer.readQueue 109 if !ok { 110 return nil 111 } 112 i, err := payload.Read(v.buff) 113 114 v.bytes = i 115 v.endpoint = nep 116 v.err = err 117 v.waiter.Done() 118 if err != nil && errors.Is(err, io.EOF) { 119 nep.conn = nil 120 return nil 121 } 122 } 123 } 124 } 125 126 func (s *Server) forwardConnection(dest net.Destination, conn net.Conn) { 127 if s.info.dispatcher == nil { 128 newError("unexpected: dispatcher == nil").AtError().WriteToLog() 129 return 130 } 131 defer conn.Close() 132 133 ctx, cancel := context.WithCancel(core.ToBackgroundDetachedContext(s.info.ctx)) 134 plcy := s.policyManager.ForLevel(0) 135 timer := signal.CancelAfterInactivity(ctx, cancel, plcy.Timeouts.ConnectionIdle) 136 137 ctx = log.ContextWithAccessMessage(ctx, &log.AccessMessage{ 138 From: nullDestination, 139 To: dest, 140 Status: log.AccessAccepted, 141 Reason: "", 142 }) 143 144 if s.info.inboundTag != nil { 145 ctx = session.ContextWithInbound(ctx, s.info.inboundTag) 146 } 147 if s.info.outboundTag != nil { 148 ctx = session.ContextWithOutbound(ctx, s.info.outboundTag) 149 } 150 if s.info.contentTag != nil { 151 ctx = session.ContextWithContent(ctx, s.info.contentTag) 152 } 153 154 link, err := s.info.dispatcher.Dispatch(ctx, dest) 155 if err != nil { 156 newError("dispatch connection").Base(err).AtError().WriteToLog() 157 } 158 defer cancel() 159 160 requestDone := func() error { 161 defer timer.SetTimeout(plcy.Timeouts.DownlinkOnly) 162 if err := buf.Copy(buf.NewReader(conn), link.Writer, buf.UpdateActivity(timer)); err != nil { 163 return newError("failed to transport all TCP request").Base(err) 164 } 165 166 return nil 167 } 168 169 responseDone := func() error { 170 defer timer.SetTimeout(plcy.Timeouts.UplinkOnly) 171 if err := buf.Copy(link.Reader, buf.NewWriter(conn), buf.UpdateActivity(timer)); err != nil { 172 return newError("failed to transport all TCP response").Base(err) 173 } 174 175 return nil 176 } 177 178 requestDonePost := task.OnSuccess(requestDone, task.Close(link.Writer)) 179 if err := task.Run(ctx, requestDonePost, responseDone); err != nil { 180 common.Interrupt(link.Reader) 181 common.Interrupt(link.Writer) 182 newError("connection ends").Base(err).AtDebug().WriteToLog() 183 return 184 } 185 }